Skip to main content

unittest

unittest 是 Python 标准库内置的测试框架，CPython 自身的测试套件就是用 unittest 编写的。想要为 CPython 贡献代码，必须掌握它。虽然第三方框架 pytest 更流行，但 unittest 仍是标准库和许多大型项目测试体系的基础。

unittest

基本测试

import unittest

def add(a, b):
    """Return ``a + b``.

    Works for any operands that support ``+``: numbers are summed and
    strings are concatenated.
    """
    result = a + b
    return result

class TestAdd(unittest.TestCase):
    """Check ``add`` across the operand types it is used with."""

    def test_integers(self):
        total = add(1, 2)
        self.assertEqual(total, 3)

    def test_floats(self):
        # 0.1 + 0.2 is not exactly 0.3 in binary floating point, so compare
        # after rounding to 7 decimal places instead of using assertEqual.
        total = add(0.1, 0.2)
        self.assertAlmostEqual(total, 0.3, places=7)

    def test_strings(self):
        # For str operands, + concatenates.
        joined = add("hello ", "world")
        self.assertEqual(joined, "hello world")

    def test_negative(self):
        total = add(-1, -2)
        self.assertEqual(total, -3)


# Executing this file directly runs every TestCase defined in it;
# importing it as a module runs nothing.
if __name__ == "__main__":
    unittest.main()

常用断言方法

import unittest

class TestAssertions(unittest.TestCase):
    """Tour of the most frequently used ``TestCase`` assertion methods."""

    def test_equality(self):
        """assertEqual / assertNotEqual compare with == / !=."""
        two = 1 + 1
        self.assertEqual(two, 2)
        self.assertNotEqual(1, 2)

    def test_truth(self):
        """assertTrue / assertFalse check the truthiness of an expression."""
        bigger, smaller = 3, 2
        self.assertTrue(bigger > smaller)
        self.assertFalse(smaller > bigger)

    def test_identity(self):
        """assertIs / assertIsNot compare object identity, not equality."""
        original = [1, 2, 3]
        alias = original  # a second name bound to the very same list object
        self.assertIs(original, alias)
        # An equal but freshly built list is a distinct object.
        self.assertIsNot(original, [1, 2, 3])

    def test_none(self):
        """assertIsNone / assertIsNotNone test against the None singleton."""
        self.assertIsNone(None)
        self.assertIsNotNone(42)

    def test_membership(self):
        """assertIn / assertNotIn use the ``in`` operator."""
        items = [1, 2, 3]
        self.assertIn(3, items)
        self.assertNotIn(4, items)

    def test_type(self):
        """assertIsInstance checks with isinstance()."""
        self.assertIsInstance(42, int)

    def test_exception(self):
        """assertRaises as a context manager, optionally capturing the error."""
        with self.assertRaises(ZeroDivisionError):
            1 / 0

        # The context object exposes the raised exception for further checks.
        with self.assertRaises(ValueError) as ctx:
            int("not_a_number")
        message = str(ctx.exception)
        self.assertIn("invalid literal", message)

    def test_warning(self):
        """assertWarns checks that the block emits a warning of the given class."""
        import warnings

        with self.assertWarns(DeprecationWarning):
            warnings.warn("old api", DeprecationWarning)

setUp 和 tearDown

import unittest
import tempfile
import os

class TestFileOperations(unittest.TestCase):
    """Demonstrate per-class (setUpClass/tearDownClass) and per-test
    (setUp/tearDown) fixtures around a temporary file."""

    @classmethod
    def setUpClass(cls):
        """Run once, before any test method in this class."""
        super().setUpClass()
        print("测试类开始")

    @classmethod
    def tearDownClass(cls):
        """Run once, after every test method in this class has finished."""
        print("测试类结束")
        super().tearDownClass()

    def setUp(self):
        """Run before each test method: create a temp dir holding one file."""
        # TemporaryDirectory.cleanup() removes the whole tree. The previous
        # os.remove + os.rmdir pair raised OSError (and leaked the directory)
        # whenever a test left any extra file in test_dir.
        self._tmpdir = tempfile.TemporaryDirectory()
        self.test_dir = self._tmpdir.name
        self.test_file = os.path.join(self.test_dir, "test.txt")
        # Explicit encoding: don't depend on the platform's locale default.
        with open(self.test_file, "w", encoding="utf-8") as f:
            f.write("test content")

    def tearDown(self):
        """Run after each test method: delete the temp dir and all its contents."""
        self._tmpdir.cleanup()

    def test_file_exists(self):
        self.assertTrue(os.path.exists(self.test_file))

    def test_file_content(self):
        with open(self.test_file, encoding="utf-8") as f:
            self.assertEqual(f.read(), "test content")

unittest.mock

unittest.mock 是在测试中替换对象行为的工具，可以模拟外部依赖（API 调用、数据库、文件系统等）。

from unittest.mock import Mock, patch, MagicMock

# Basic Mock: attributes and methods are created automatically on first
# access, so we only configure what the code under test will call.
canned_response = {"status": "ok", "data": [1, 2, 3]}
mock_api = Mock()
mock_api.get_data.return_value = canned_response

result = mock_api.get_data()
print(result)  # {'status': 'ok', 'data': [1, 2, 3]}

# Verify the interaction: get_data must have been called exactly once.
mock_api.get_data.assert_called_once()

patch 替换对象（上下文管理器写法）

from unittest.mock import patch
import json

def load_config(path):
    """Read the JSON file at *path* and return the decoded object.

    The file is decoded as UTF-8, the encoding the JSON standard mandates,
    rather than the platform's locale encoding. That way non-ASCII config
    values load identically on every OS.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# Use patch to replace builtins.open with mock_open, so load_config reads
# canned JSON without touching the filesystem. mock_open already wires up
# __enter__/__exit__ and read(), so no hand-built Mock is needed (the old
# version also used Mock without importing it).
from unittest.mock import mock_open

with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as fake_open:
    config = load_config("config.json")

print(config)  # {'key': 'value'}
fake_open.assert_called_once()

模拟 AI 推理服务

import unittest
from unittest.mock import patch, Mock

class ModelService:
    """Load a model once at construction and delegate predictions to it."""

    def __init__(self, model_path):
        loaded = self._load_model(model_path)
        self.model = loaded

    def _load_model(self, path):
        """Load the model stored at *path*.

        Placeholder: it always raises, so tests must patch it out.
        """
        raise NotImplementedError("需要实际模型文件")

    def predict(self, text):
        """Return whatever the loaded model's ``predict`` returns for *text*."""
        model = self.model
        return model.predict(text)

class TestModelService(unittest.TestCase):
    """Exercise ModelService.predict without a real model file on disk."""

    @patch.object(ModelService, '_load_model')
    def test_predict(self, mock_load):
        # patch.object swaps ModelService._load_model for a mock during this
        # test; whatever it returns becomes service.model.
        fake_model = Mock()
        fake_model.predict.return_value = {"label": "positive", "score": 0.95}
        mock_load.return_value = fake_model

        prediction = ModelService("/fake/model").predict("This is great!")

        self.assertEqual(prediction["label"], "positive")
        self.assertGreater(prediction["score"], 0.9)
        # The input text must be forwarded to the model unchanged.
        fake_model.predict.assert_called_once_with("This is great!")

子测试（subTest）

import unittest

class TestMath(unittest.TestCase):
    """Use subTest so each input is reported independently."""

    def test_square(self):
        # (input, expected square) pairs. Inside subTest, a failing pair is
        # reported on its own and the loop continues with the remaining pairs.
        cases = ((2, 4), (3, 9), (4, 16), (-1, 1), (0, 0))
        for value, squared in cases:
            with self.subTest(x=value):
                self.assertEqual(value ** 2, squared)

跳过测试

import unittest
import sys

class TestPlatform(unittest.TestCase):
    # Showcase of the decorators that skip tests or mark expected failures.

    # Unconditionally skipped.
    @unittest.skip("暂时跳过")
    def test_skipped(self):
        ...

    # Skipped only when running on Windows.
    @unittest.skipIf(sys.platform == "win32", "Windows 不支持")
    def test_unix_only(self):
        ...

    # Runs only on Linux; skipped on every other platform.
    @unittest.skipUnless(sys.platform == "linux", "仅 Linux")
    def test_linux_only(self):
        ...

    # Deliberately false assertion; expectedFailure reports it as an
    # expected failure instead of an error.
    @unittest.expectedFailure
    def test_known_bug(self):
        expected, actual = 1, 2
        self.assertEqual(expected, actual)