|
1 | | -import unittest |
2 | | -import sys |
3 | | -import os |
4 | | -from unittest.mock import MagicMock, patch |
5 | | - |
6 | | -# 添加项目根目录到Python路径 |
7 | | -sys.path.insert(0, os.path.abspath(os.path.join( |
8 | | - os.path.dirname(__file__), '..', '..', '..'))) |
9 | | - |
10 | | -# 模拟主要依赖 |
11 | | -sys.modules['langchain_core.tools'] = MagicMock() |
12 | | -sys.modules['consts'] = MagicMock() |
13 | | -sys.modules['consts.model'] = MagicMock() |
14 | | - |
15 | | -# 模拟logger |
16 | | -logger_mock = MagicMock() |
17 | | - |
18 | | - |
19 | | -class TestLangchainUtils(unittest.TestCase): |
20 | | - """测试langchain_utils模块的函数""" |
21 | | - |
22 | | - def setUp(self): |
23 | | - """每个测试方法前的设置""" |
24 | | - # 导入原始函数 |
25 | | - from backend.utils.langchain_utils import discover_langchain_modules, _is_langchain_tool |
26 | | - self.discover_langchain_modules = discover_langchain_modules |
27 | | - self._is_langchain_tool = _is_langchain_tool |
28 | | - |
29 | | - def test_is_langchain_tool(self): |
30 | | - """测试_is_langchain_tool函数""" |
31 | | - # 创建一个BaseTool实例的模拟 |
32 | | - mock_tool = MagicMock() |
33 | | - |
34 | | - # 模拟isinstance返回值 |
35 | | - with patch('backend.utils.langchain_utils.isinstance', return_value=True): |
36 | | - result = self._is_langchain_tool(mock_tool) |
37 | | - self.assertTrue(result) |
38 | | - |
39 | | - # 测试非BaseTool对象 |
40 | | - with patch('backend.utils.langchain_utils.isinstance', return_value=False): |
41 | | - result = self._is_langchain_tool("not a tool") |
42 | | - self.assertFalse(result) |
43 | | - |
44 | | - def test_discover_langchain_modules_success(self): |
45 | | - """测试成功发现LangChain工具的情况""" |
46 | | - # 创建一个临时目录结构 |
47 | | - with patch('os.path.isdir', return_value=True), \ |
48 | | - patch('os.listdir', return_value=['tool1.py', 'tool2.py', '__init__.py', 'not_a_py_file.txt']), \ |
49 | | - patch('importlib.util.spec_from_file_location') as mock_spec, \ |
50 | | - patch('importlib.util.module_from_spec') as mock_module_from_spec: |
| 1 | +from unittest.mock import MagicMock |
51 | 2 |
|
52 | | - # 创建模拟工具对象 |
53 | | - mock_tool1 = MagicMock(name="tool1") |
54 | | - mock_tool2 = MagicMock(name="tool2") |
| 3 | +from backend.utils.langchain_utils import discover_langchain_modules, _is_langchain_tool |
55 | 4 |
|
56 | | - # 设置模拟module |
57 | | - mock_module_obj1 = MagicMock() |
58 | | - mock_module_obj1.tool_obj1 = mock_tool1 |
59 | 5 |
|
60 | | - mock_module_obj2 = MagicMock() |
61 | | - mock_module_obj2.tool_obj2 = mock_tool2 |
| 6 | +class TestLangchainUtils: |
| 7 | + """Tests for backend.utils.langchain_utils functions""" |
62 | 8 |
|
63 | | - mock_module_from_spec.side_effect = [ |
64 | | - mock_module_obj1, mock_module_obj2] |
| 9 | + def test_is_langchain_tool_with_base_tool(self, mocker): |
| 10 | + """Returns True for objects that are instances of BaseTool""" |
| 11 | + # Mock BaseTool class and create instance |
| 12 | + mock_base_tool_class = MagicMock() |
| 13 | + mock_tool_instance = MagicMock() |
65 | 14 |
|
66 | | - # 设置模拟spec和loader |
67 | | - mock_spec_obj1 = MagicMock() |
68 | | - mock_spec_obj2 = MagicMock() |
69 | | - mock_spec.side_effect = [mock_spec_obj1, mock_spec_obj2] |
| 15 | + mocker.patch('langchain_core.tools.BaseTool', |
| 16 | + mock_base_tool_class) |
| 17 | + mocker.patch('backend.utils.langchain_utils.isinstance', |
| 18 | + return_value=True) |
70 | 19 |
|
71 | | - mock_loader1 = MagicMock() |
72 | | - mock_loader2 = MagicMock() |
73 | | - mock_spec_obj1.loader = mock_loader1 |
74 | | - mock_spec_obj2.loader = mock_loader2 |
| 20 | + result = _is_langchain_tool(mock_tool_instance) |
| 21 | + assert result is True |
75 | 22 |
|
76 | | - # 设置过滤函数始终返回True |
77 | | - def mock_filter(obj): |
78 | | - return obj is mock_tool1 or obj is mock_tool2 |
| 23 | + def test_is_langchain_tool_with_non_base_tool(self, mocker): |
| 24 | + """Returns False for objects that are not instances of BaseTool""" |
| 25 | + mock_base_tool_class = MagicMock() |
79 | 26 |
|
80 | | - # 执行函数 |
81 | | - result = self.discover_langchain_modules(filter_func=mock_filter) |
| 27 | + mocker.patch('langchain_core.tools.BaseTool', |
| 28 | + mock_base_tool_class) |
| 29 | + mocker.patch('backend.utils.langchain_utils.isinstance', |
| 30 | + return_value=False) |
82 | 31 |
|
83 | | - # 验证loader.exec_module被调用 |
84 | | - mock_loader1.exec_module.assert_called_once_with(mock_module_obj1) |
85 | | - mock_loader2.exec_module.assert_called_once_with(mock_module_obj2) |
| 32 | + result = _is_langchain_tool("not a tool") |
| 33 | + assert result is False |
86 | 34 |
|
87 | | - # 验证结果 |
88 | | - self.assertEqual(len(result), 2) |
89 | | - discovered_objs = [obj for (obj, _) in result] |
90 | | - self.assertIn(mock_tool1, discovered_objs) |
91 | | - self.assertIn(mock_tool2, discovered_objs) |
92 | | - |
93 | | - def test_discover_langchain_modules_directory_not_found(self): |
| 35 | + def test_discover_langchain_modules_success(self, mocker): |
| 36 | + """测试成功发现LangChain工具的情况""" |
| 37 | + # 创建一个临时目录结构 |
| 38 | + mocker.patch('os.path.isdir', return_value=True) |
| 39 | + mocker.patch('os.listdir', return_value=[ |
| 40 | + 'tool1.py', 'tool2.py', '__init__.py', 'not_a_py_file.txt']) |
| 41 | + mock_spec = mocker.patch('importlib.util.spec_from_file_location') |
| 42 | + mock_module_from_spec = mocker.patch('importlib.util.module_from_spec') |
| 43 | + |
| 44 | + # 创建模拟工具对象 |
| 45 | + mock_tool1 = MagicMock(name="tool1") |
| 46 | + mock_tool2 = MagicMock(name="tool2") |
| 47 | + |
| 48 | + # 设置模拟module |
| 49 | + mock_module_obj1 = MagicMock() |
| 50 | + mock_module_obj1.tool_obj1 = mock_tool1 |
| 51 | + |
| 52 | + mock_module_obj2 = MagicMock() |
| 53 | + mock_module_obj2.tool_obj2 = mock_tool2 |
| 54 | + |
| 55 | + mock_module_from_spec.side_effect = [ |
| 56 | + mock_module_obj1, mock_module_obj2] |
| 57 | + |
| 58 | + # 设置模拟spec和loader |
| 59 | + mock_spec_obj1 = MagicMock() |
| 60 | + mock_spec_obj2 = MagicMock() |
| 61 | + mock_spec.side_effect = [mock_spec_obj1, mock_spec_obj2] |
| 62 | + |
| 63 | + mock_loader1 = MagicMock() |
| 64 | + mock_loader2 = MagicMock() |
| 65 | + mock_spec_obj1.loader = mock_loader1 |
| 66 | + mock_spec_obj2.loader = mock_loader2 |
| 67 | + |
| 68 | + # 设置过滤函数始终返回True |
| 69 | + def mock_filter(obj): |
| 70 | + return obj is mock_tool1 or obj is mock_tool2 |
| 71 | + |
| 72 | + # 执行函数 |
| 73 | + result = discover_langchain_modules(filter_func=mock_filter) |
| 74 | + |
| 75 | + # 验证loader.exec_module被调用 |
| 76 | + mock_loader1.exec_module.assert_called_once_with(mock_module_obj1) |
| 77 | + mock_loader2.exec_module.assert_called_once_with(mock_module_obj2) |
| 78 | + |
| 79 | + # 验证结果 |
| 80 | + assert len(result) == 2 |
| 81 | + discovered_objs = [obj for (obj, _) in result] |
| 82 | + assert mock_tool1 in discovered_objs |
| 83 | + assert mock_tool2 in discovered_objs |
| 84 | + |
| 85 | + def test_discover_langchain_modules_directory_not_found(self, mocker): |
94 | 86 | """测试目录不存在的情况""" |
95 | | - with patch('os.path.isdir', return_value=False): |
96 | | - result = self.discover_langchain_modules( |
97 | | - directory="non_existent_dir") |
98 | | - self.assertEqual(result, []) |
| 87 | + mocker.patch('os.path.isdir', return_value=False) |
| 88 | + result = discover_langchain_modules(directory="non_existent_dir") |
| 89 | + assert result == [] |
99 | 90 |
|
100 | | - def test_discover_langchain_modules_module_exception(self): |
| 91 | + def test_discover_langchain_modules_module_exception(self, mocker, mock_logger): |
101 | 92 | """测试处理模块异常的情况""" |
102 | | - with patch('os.path.isdir', return_value=True), \ |
103 | | - patch('os.listdir', return_value=['error_module.py']), \ |
104 | | - patch('importlib.util.spec_from_file_location') as mock_spec, \ |
105 | | - patch('backend.utils.langchain_utils.logger', logger_mock): |
106 | | - |
107 | | - # 设置spec_from_file_location抛出异常 |
108 | | - mock_spec.side_effect = Exception("Module error") |
109 | | - |
110 | | - # 执行函数 - 应该捕获异常并继续 |
111 | | - result = self.discover_langchain_modules() |
112 | | - |
113 | | - # 验证结果为空列表 |
114 | | - self.assertEqual(result, []) |
115 | | - # 验证错误被记录 |
116 | | - self.assertTrue(logger_mock.error.called) |
117 | | - # 验证错误消息包含预期内容 |
118 | | - logger_mock.error.assert_called_with( |
119 | | - "Error processing module error_module.py: Module error") |
120 | | - |
121 | | - def test_discover_langchain_modules_spec_loader_none(self): |
| 93 | + mocker.patch('os.path.isdir', return_value=True) |
| 94 | + mocker.patch('os.listdir', return_value=['error_module.py']) |
| 95 | + mock_spec = mocker.patch('importlib.util.spec_from_file_location') |
| 96 | + mocker.patch('backend.utils.langchain_utils.logger', mock_logger) |
| 97 | + |
| 98 | + # 设置spec_from_file_location抛出异常 |
| 99 | + mock_spec.side_effect = Exception("Module error") |
| 100 | + |
| 101 | + # 执行函数 - 应该捕获异常并继续 |
| 102 | + result = discover_langchain_modules() |
| 103 | + |
| 104 | + # 验证结果为空列表 |
| 105 | + assert result == [] |
| 106 | + # 验证错误被记录 |
| 107 | + assert mock_logger.error.called |
| 108 | + # 验证错误消息包含预期内容 |
| 109 | + mock_logger.error.assert_called_with( |
| 110 | + "Error processing module error_module.py: Module error") |
| 111 | + |
| 112 | + def test_discover_langchain_modules_spec_loader_none(self, mocker, mock_logger): |
122 | 113 | """测试spec或loader为None的情况""" |
123 | | - with patch('os.path.isdir', return_value=True), \ |
124 | | - patch('os.listdir', return_value=['invalid_module.py']), \ |
125 | | - patch('importlib.util.spec_from_file_location', return_value=None), \ |
126 | | - patch('backend.utils.langchain_utils.logger', logger_mock): |
127 | | - |
128 | | - # 执行函数 |
129 | | - result = self.discover_langchain_modules() |
130 | | - |
131 | | - # 验证结果为空列表 |
132 | | - self.assertEqual(result, []) |
133 | | - # 验证警告被记录 |
134 | | - self.assertTrue(logger_mock.warning.called) |
135 | | - # 验证警告消息包含预期内容 - 检查是否包含文件名 |
136 | | - actual_call = logger_mock.warning.call_args[0][0] |
137 | | - self.assertIn("Failed to load spec for", actual_call) |
138 | | - self.assertIn("invalid_module.py", actual_call) |
139 | | - |
140 | | - def test_discover_langchain_modules_custom_filter(self): |
| 114 | + mocker.patch('os.path.isdir', return_value=True) |
| 115 | + mocker.patch('os.listdir', return_value=['invalid_module.py']) |
| 116 | + mocker.patch('importlib.util.spec_from_file_location', |
| 117 | + return_value=None) |
| 118 | + mocker.patch('backend.utils.langchain_utils.logger', mock_logger) |
| 119 | + |
| 120 | + # 执行函数 |
| 121 | + result = discover_langchain_modules() |
| 122 | + |
| 123 | + # 验证结果为空列表 |
| 124 | + assert result == [] |
| 125 | + # 验证警告被记录 |
| 126 | + assert mock_logger.warning.called |
| 127 | + # 验证警告消息包含预期内容 - 检查是否包含文件名 |
| 128 | + actual_call = mock_logger.warning.call_args[0][0] |
| 129 | + assert "Failed to load spec for" in actual_call |
| 130 | + assert "invalid_module.py" in actual_call |
| 131 | + |
| 132 | + def test_discover_langchain_modules_custom_filter(self, mocker): |
141 | 133 | """测试使用自定义过滤函数的情况""" |
142 | | - with patch('os.path.isdir', return_value=True), \ |
143 | | - patch('os.listdir', return_value=['tool.py']), \ |
144 | | - patch('importlib.util.spec_from_file_location') as mock_spec, \ |
145 | | - patch('importlib.util.module_from_spec') as mock_module_from_spec: |
146 | | - |
147 | | - # 创建两个对象,一个通过过滤,一个不通过 |
148 | | - obj_pass = MagicMock(name="pass_object") |
149 | | - obj_fail = MagicMock(name="fail_object") |
150 | | - |
151 | | - # 设置模拟module,使其包含我们的两个测试对象 |
152 | | - mock_module_obj = MagicMock() |
153 | | - mock_module_obj.obj_pass = obj_pass |
154 | | - mock_module_obj.obj_fail = obj_fail |
155 | | - mock_module_from_spec.return_value = mock_module_obj |
156 | | - |
157 | | - # 设置模拟spec和loader |
158 | | - mock_spec_obj = MagicMock() |
159 | | - mock_spec.return_value = mock_spec_obj |
160 | | - mock_loader = MagicMock() |
161 | | - mock_spec_obj.loader = mock_loader |
162 | | - |
163 | | - # 自定义过滤函数,只接受obj_pass |
164 | | - def custom_filter(obj): |
165 | | - return obj is obj_pass |
166 | | - |
167 | | - # 执行函数 |
168 | | - result = self.discover_langchain_modules(filter_func=custom_filter) |
169 | | - |
170 | | - # 验证loader.exec_module被调用 |
171 | | - mock_loader.exec_module.assert_called_once_with(mock_module_obj) |
172 | | - |
173 | | - # 验证结果 - 应该只有一个对象通过过滤 |
174 | | - self.assertEqual(len(result), 1) |
175 | | - self.assertEqual(result[0][0], obj_pass) |
176 | | - |
177 | | - |
178 | | -if __name__ == "__main__": |
179 | | - unittest.main() |
| 134 | + mocker.patch('os.path.isdir', return_value=True) |
| 135 | + mocker.patch('os.listdir', return_value=['tool.py']) |
| 136 | + mock_spec = mocker.patch('importlib.util.spec_from_file_location') |
| 137 | + mock_module_from_spec = mocker.patch('importlib.util.module_from_spec') |
| 138 | + |
| 139 | + # 创建两个对象,一个通过过滤,一个不通过 |
| 140 | + obj_pass = MagicMock(name="pass_object") |
| 141 | + obj_fail = MagicMock(name="fail_object") |
| 142 | + |
| 143 | + # 设置模拟module,使其包含我们的两个测试对象 |
| 144 | + mock_module_obj = MagicMock() |
| 145 | + mock_module_obj.obj_pass = obj_pass |
| 146 | + mock_module_obj.obj_fail = obj_fail |
| 147 | + mock_module_from_spec.return_value = mock_module_obj |
| 148 | + |
| 149 | + # 设置模拟spec和loader |
| 150 | + mock_spec_obj = MagicMock() |
| 151 | + mock_spec.return_value = mock_spec_obj |
| 152 | + mock_loader = MagicMock() |
| 153 | + mock_spec_obj.loader = mock_loader |
| 154 | + |
| 155 | + # 自定义过滤函数,只接受obj_pass |
| 156 | + def custom_filter(obj): |
| 157 | + return obj is obj_pass |
| 158 | + |
| 159 | + # 执行函数 |
| 160 | + result = discover_langchain_modules(filter_func=custom_filter) |
| 161 | + |
| 162 | + # 验证loader.exec_module被调用 |
| 163 | + mock_loader.exec_module.assert_called_once_with(mock_module_obj) |
| 164 | + |
| 165 | + # 验证结果 - 应该只有一个对象通过过滤 |
| 166 | + assert len(result) == 1 |
| 167 | + assert result[0][0] == obj_pass |
0 commit comments