unittest
unittest 是 Python 内置的测试框架,CPython 自身就使用 unittest 编写测试。想要为 CPython 贡献代码,必须掌握它。虽然第三方 pytest 更流行,但 unittest 仍是标准库和许多大型项目的基础。
基本测试
import unittest


def add(a, b):
    """Return ``a + b``; works for any pair of operands that support ``+``."""
    total = a + b
    return total


class TestAdd(unittest.TestCase):
    def test_integers(self):
        self.assertEqual(add(1, 2), 3)

    def test_floats(self):
        # 0.1 + 0.2 is not exactly 0.3 in binary floating point, so compare
        # after rounding to 7 decimal places instead of using assertEqual.
        self.assertAlmostEqual(add(0.1, 0.2), 0.3, places=7)

    def test_strings(self):
        # ``+`` on strings concatenates.
        greeting = add("hello ", "world")
        self.assertEqual(greeting, "hello world")

    def test_negative(self):
        self.assertEqual(add(-1, -2), -3)


if __name__ == "__main__":
    unittest.main()
常用断言方法
import unittest


class TestAssertions(unittest.TestCase):
    # Value equality and inequality.
    def test_equality(self):
        self.assertEqual(1 + 1, 2)
        self.assertNotEqual(1, 2)

    # Truthiness of an expression.
    def test_truth(self):
        self.assertTrue(3 > 2)
        self.assertFalse(2 > 3)

    # Identity (`is`) as opposed to equality (`==`).
    def test_identity(self):
        original = [1, 2, 3]
        alias = original
        self.assertIs(alias, original)
        # An equal but freshly built list is a different object.
        self.assertIsNot(original, [1, 2, 3])

    def test_none(self):
        self.assertIsNone(None)
        self.assertIsNotNone(42)

    def test_membership(self):
        haystack = [1, 2, 3]
        self.assertIn(3, haystack)
        self.assertNotIn(4, haystack)

    def test_type(self):
        self.assertIsInstance(42, int)

    def test_exception(self):
        # Callable form: assertRaises(exception_type, callable).
        self.assertRaises(ZeroDivisionError, lambda: 1 / 0)
        # Context-manager form keeps the raised exception for further checks.
        with self.assertRaises(ValueError) as ctx:
            int("not_a_number")
        self.assertIn("invalid literal", str(ctx.exception))

    def test_warning(self):
        import warnings
        with self.assertWarns(DeprecationWarning):
            warnings.warn("old api", category=DeprecationWarning)
setUp 和 tearDown
import unittest
import tempfile
import os
import shutil


class TestFileOperations(unittest.TestCase):
    def setUp(self):
        """Run before every test method: create a temp dir holding one file."""
        self.test_dir = tempfile.mkdtemp()
        self.test_file = os.path.join(self.test_dir, "test.txt")
        with open(self.test_file, "w", encoding="utf-8") as f:
            f.write("test content")

    def tearDown(self):
        """Run after every test method whose setUp succeeded.

        shutil.rmtree removes the directory together with anything a test
        may have added to it; os.rmdir would raise OSError on a non-empty
        directory and report a passing test as an error.
        """
        shutil.rmtree(self.test_dir)

    def test_file_exists(self):
        self.assertTrue(os.path.exists(self.test_file))

    def test_file_content(self):
        with open(self.test_file, encoding="utf-8") as f:
            self.assertEqual(f.read(), "test content")

    @classmethod
    def setUpClass(cls):
        """Run once for the whole class, before any of its tests."""
        print("测试类开始")

    @classmethod
    def tearDownClass(cls):
        """Run once for the whole class, after all of its tests finished."""
        print("测试类结束")
unittest.mock
unittest.mock 是测试中替换对象行为的工具,可以模拟外部依赖(API 调用、数据库、文件系统等)。
from unittest.mock import MagicMock, Mock, patch

# Basic Mock: child attributes such as get_data are created on demand, and
# the dotted keyword configures what calling get_data() returns.
mock_api = Mock(**{"get_data.return_value": {"status": "ok", "data": [1, 2, 3]}})
result = mock_api.get_data()
print(result)  # {'status': 'ok', 'data': [1, 2, 3]}
# The mock recorded the call, so we can verify it happened exactly once.
mock_api.get_data.assert_called_once()
patch 的用法(上下文管理器与装饰器)
from unittest.mock import patch, mock_open
import json


def load_config(path):
    """Read the file at *path* and return its parsed JSON content."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Context-manager form: replace builtins.open with mock_open, which already
# implements the context-manager protocol and read(), so no real file is
# touched. The patch is undone when the with-block exits.
with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mocked_open:
    config = load_config("config.json")

print(config)  # {'key': 'value'}
mocked_open.assert_called_once_with("config.json", encoding="utf-8")


# Decorator form: patch creates the mock via new_callable (extra keyword
# arguments such as read_data are forwarded to it) and passes it in as an
# extra argument; the patch is active only while the function runs.
@patch("builtins.open", new_callable=mock_open, read_data='{"key": "value"}')
def read_config_with_fake_file(mocked):
    """Call load_config against a fake file and return the parsed result."""
    parsed = load_config("config.json")
    mocked.assert_called_once_with("config.json", encoding="utf-8")
    return parsed
模拟 AI 推理服务
import unittest
from unittest.mock import Mock, patch


class ModelService:
    """Hold a loaded model and forward prediction requests to it."""

    def __init__(self, model_path):
        # Loading goes through _load_model, which the test below patches out.
        loaded_model = self._load_model(model_path)
        self.model = loaded_model

    def _load_model(self, path):
        """Load the model at *path*; not implemented in this example."""
        raise NotImplementedError("需要实际模型文件")

    def predict(self, text):
        """Return whatever the loaded model's predict() returns for *text*."""
        model = self.model
        return model.predict(text)


class TestModelService(unittest.TestCase):
    @patch.object(ModelService, '_load_model')
    def test_predict(self, load_mock):
        # The fake model hands back a canned prediction instead of running inference.
        canned = {"label": "positive", "score": 0.95}
        fake_model = Mock(**{"predict.return_value": canned})
        load_mock.return_value = fake_model

        prompt = "This is great!"
        outcome = ModelService("/fake/model").predict(prompt)

        self.assertEqual(outcome["label"], "positive")
        self.assertGreater(outcome["score"], 0.9)
        fake_model.predict.assert_called_once_with(prompt)
子测试(subTest)
import unittest


class TestMath(unittest.TestCase):
    def test_square(self):
        # Each (input, expected square) pair runs in its own subTest, so a
        # failing pair is reported individually and the loop keeps going.
        cases = ((2, 4), (3, 9), (4, 16), (-1, 1), (0, 0))
        for value, squared in cases:
            with self.subTest(x=value):
                self.assertEqual(value * value, squared)
跳过测试与预期失败
import unittest
import sys


class TestPlatform(unittest.TestCase):
    # Always skipped, whatever the platform.
    @unittest.skip(reason="暂时跳过")
    def test_skipped(self):
        ...

    # Skipped when the condition holds (here: on Windows).
    @unittest.skipIf(condition=sys.platform == "win32", reason="Windows 不支持")
    def test_unix_only(self):
        ...

    # Skipped unless the condition holds (here: everywhere except Linux).
    @unittest.skipUnless(condition=sys.platform == "linux", reason="仅 Linux")
    def test_linux_only(self):
        ...

    # Known bug: the failing assertion is recorded as an expected failure;
    # if it ever passed, it would be reported as an unexpected success.
    @unittest.expectedFailure
    def test_known_bug(self):
        self.assertEqual(1, 2)