Coverage for src/pylint_sort_functions/test_file_manager.py: 100%
61 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-12 16:06 +0200
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-12 16:06 +0200
1"""Test file discovery and reference detection for privacy fixing.
3This module provides functionality to find test files and analyze them for
4function references that need to be updated when functions are privatized.
5It handles both AST-based and string-based analysis of test files.
7Part of the refactoring described in GitHub Issue #32.
8"""
10import re
11from pathlib import Path
12from typing import List
14import astroid # type: ignore[import-untyped]
15from astroid import nodes
17from pylint_sort_functions import utils
19# Import types that will be referenced
20from pylint_sort_functions.privacy_types import FunctionTestReference
class TestFileManager:
    """Test file discovery and reference detection.

    Handles finding test files and analyzing them for function references
    that need to be updated when functions are privatized. The class keeps
    no instance state; every method works only on its arguments.
    """

    # Public methods

    def find_test_files(self, project_root: Path) -> List[Path]:
        """Find all test files in the project.

        Every Python file returned by ``utils.find_python_files`` is converted
        to a dotted module name relative to ``project_root`` and kept when
        ``utils.is_unittest_file`` accepts that name.

        :param project_root: Root directory of the project
        :returns: List of paths to test files
        """
        # Get all Python files in the project
        all_python_files = utils.find_python_files(project_root)
        test_files = []

        for file_path in all_python_files:
            try:
                relative_path = file_path.relative_to(project_root)
            except ValueError:
                # Skip files that can't be made relative to project root
                continue

            # Join path components instead of replacing "/" so the module
            # name is also correct with Windows "\\" separators.
            module_name = ".".join(relative_path.with_suffix("").parts)
            if utils.is_unittest_file(module_name):
                test_files.append(file_path)

        return test_files

    def find_test_references(
        self, function_name: str, test_files: List[Path]
    ) -> List[FunctionTestReference]:
        """Find all references to a function in test files.

        Scans test files for these kinds of function references:
        - Import statements: ``from module import func`` (AST analysis only)
        - Mock patch targets: ``@patch('module.func')``,
          ``with patch('module.func')``, ``mocker.patch('module.func')``

        Direct calls such as ``module.func()`` or ``func()`` are not detected.

        Files that cannot be read or decoded as UTF-8 are skipped. Files that
        astroid fails to parse are scanned with string matching only, so no
        import references are reported for them.

        :param function_name: Name of the function to find references for
        :param test_files: List of test files to scan
        :returns: List of test file references
        """
        test_references: List[FunctionTestReference] = []

        for test_file in test_files:
            try:
                with open(test_file, "r", encoding="utf-8") as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError):
                # Skip files that can't be read or decoded
                continue

            # Only the parse step is best-effort; analysis errors are not
            # swallowed so genuine bugs surface instead of silently dropping
            # import references.
            try:
                module = astroid.parse(content, module_name=str(test_file))
            except Exception:  # pylint: disable=broad-exception-caught
                # If AST parsing fails, fall back to string-based detection
                file_refs = self._find_string_references_in_test_file(
                    function_name, test_file, content
                )
            else:
                file_refs = self._find_references_in_test_file(
                    function_name, test_file, module, content
                )
            test_references.extend(file_refs)

        return test_references

    # Private methods

    def _find_references_in_test_file(
        self,
        function_name: str,
        test_file: Path,
        module: nodes.Module,
        content: str,
    ) -> List[FunctionTestReference]:
        """Find function references in a test file using AST analysis.

        Reports every ``from ... import`` statement that imports
        ``function_name`` (with or without an alias), followed by the mock
        patch references found by string-based analysis of ``content``.

        :param function_name: Name of the function to find
        :param test_file: Path to the test file being analyzed
        :param module: Parsed AST module
        :param content: File content for line-based analysis
        :returns: List of test references found
        """
        references: List[FunctionTestReference] = []

        # Only ``from module import name`` binds the bare function name;
        # plain ``import module`` statements are not inspected.
        for node in module.nodes_of_class(nodes.ImportFrom):
            # astroid stores the source module in ``modname`` (the stdlib
            # ``ast`` calls it ``module``). Prefix one dot per relative level
            # so ``from .pkg import f`` / ``from . import f`` keep their dots.
            source_module = "." * (node.level or 0) + (node.modname or "")
            for name, alias in node.names:
                if name != function_name:
                    continue
                alias_suffix = f" as {alias}" if alias else ""
                references.append(
                    FunctionTestReference(
                        file_path=test_file,
                        line=node.lineno,
                        col=node.col_offset,
                        context="import",
                        reference_text=(
                            f"from {source_module} import {name}{alias_suffix}"
                        ),
                    )
                )

        # Mock patch targets live in string literals, so scan the raw text
        references.extend(
            self._find_string_references_in_test_file(
                function_name, test_file, content
            )
        )
        return references

    def _find_string_references_in_test_file(
        self, function_name: str, test_file: Path, content: str
    ) -> List[FunctionTestReference]:
        """Find function references in test file using string-based analysis.

        This handles cases where AST parsing fails, and string literals such
        as mock patch targets that contain function names. A reference is any
        ``patch(`` call whose first argument is a quoted dotted target ending
        in ``.function_name``. This covers ``@patch(...)``,
        ``mocker.patch(...)``, ``mock.patch(...)`` and bare ``patch(...)``
        (e.g. ``with patch(...)``). Every match on a line is reported.

        :param function_name: Name of the function to find
        :param test_file: Path to the test file being analyzed
        :param content: File content to search
        :returns: List of test references found
        """
        references: List[FunctionTestReference] = []

        # The leading alternation makes the reported column point at the "@"
        # of "@patch(" and at the "." of "x.patch(". A bare "patch(" must
        # start at a word boundary so names like "my_patch(" don't match.
        escaped_name = re.escape(function_name)
        patch_regex = re.compile(
            rf"(?:@|\.|\b)patch\(['\"]([^'\"]*\.{escaped_name})['\"]"
        )

        # split("\n") keeps line numbers aligned with the physical file lines
        for line_num, line in enumerate(content.split("\n"), 1):
            for match in patch_regex.finditer(line):
                references.append(
                    FunctionTestReference(
                        file_path=test_file,
                        line=line_num,
                        col=match.start(),
                        context="mock_patch",
                        reference_text=match.group(1),
                    )
                )

        return references