Skip to content

Commit c76cfaa

Browse files
authored
Merge pull request #1266 from codeflash-ai/alias_support_in_ts
alias support and vitest imports
2 parents f21c12e + 824b2ec commit c76cfaa

File tree

6 files changed

+136
-7
lines changed

6 files changed

+136
-7
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
288288
normalized_ignore_paths = []
289289
for path in args.ignore_paths:
290290
path_obj = Path(path)
291-
assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist"
292-
normalized_ignore_paths.append(path_obj.resolve())
291+
if path_obj.exists():
292+
normalized_ignore_paths.append(path_obj.resolve())
293+
# Silently skip non-existent paths (e.g., .next, dist before build)
293294
args.ignore_paths = normalized_ignore_paths
294295
# Project root path is one level above the specified directory, because that's where the module can be imported from
295296
args.module_root = Path(args.module_root).resolve()

codeflash/languages/javascript/import_resolver.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def _resolve_module_path(self, module_path: str, source_dir: Path) -> Path | Non
126126
if module_path.startswith("/"):
127127
return self._resolve_absolute_import(module_path)
128128

129+
# Handle @/ path alias (common in Next.js/TypeScript projects)
130+
# @/ maps to the project root
131+
if module_path.startswith("@/"):
132+
return self._resolve_path_alias(module_path[2:]) # Strip @/
133+
134+
# Handle ~/ path alias (another common pattern)
135+
if module_path.startswith("~/"):
136+
return self._resolve_path_alias(module_path[2:]) # Strip ~/
137+
129138
# Bare imports (e.g., 'lodash') are external packages
130139
return None
131140

@@ -197,6 +206,38 @@ def _resolve_absolute_import(self, module_path: str) -> Path | None:
197206

198207
return None
199208

209+
def _resolve_path_alias(self, module_path: str) -> Path | None:
210+
"""Resolve path alias imports like @/utils or ~/lib/helper.
211+
212+
Args:
213+
module_path: The import path without the alias prefix.
214+
215+
Returns:
216+
Resolved absolute path, or None if not found.
217+
218+
"""
219+
# Treat as relative to project root
220+
base_path = (self.project_root / module_path).resolve()
221+
222+
# Check if path is within project
223+
try:
224+
base_path.relative_to(self.project_root)
225+
except ValueError:
226+
logger.debug("Path alias resolves outside project root: %s", base_path)
227+
return None
228+
229+
# Try adding extensions
230+
resolved = self._try_extensions(base_path)
231+
if resolved:
232+
return resolved
233+
234+
# Try as directory with index file
235+
resolved = self._try_index_file(base_path)
236+
if resolved:
237+
return resolved
238+
239+
return None
240+
200241
def _try_extensions(self, base_path: Path) -> Path | None:
201242
"""Try adding various extensions to find the actual file.
202243
@@ -267,10 +308,19 @@ def _is_external_package(self, module_path: str) -> bool:
267308
if module_path.startswith("/"):
268309
return False
269310

311+
# @/ is a common path alias in Next.js/TypeScript projects mapping to project root
312+
# These are internal imports, not external npm packages
313+
if module_path.startswith("@/"):
314+
return False
315+
316+
# ~/ is another common path alias pattern
317+
if module_path.startswith("~/"):
318+
return False
319+
270320
# Bare imports without ./ or ../ are external packages
271321
# This includes:
272322
# - 'lodash'
273-
# - '@company/utils'
323+
# - '@company/utils' (scoped npm packages)
274324
# - 'react'
275325
# - 'fs' (Node.js built-ins)
276326
return True

codeflash/languages/javascript/module_system.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,70 @@ def ensure_module_system_compatibility(code: str, target_module_system: str) ->
330330
return convert_esm_to_commonjs(code)
331331

332332
return code
333+
334+
335+
def ensure_vitest_imports(code: str, test_framework: str) -> str:
336+
"""Ensure vitest test globals are imported when using vitest framework.
337+
338+
Vitest by default does not enable globals (describe, test, expect, etc.),
339+
so they must be explicitly imported. This function adds the import if missing.
340+
341+
Args:
342+
code: JavaScript/TypeScript test code.
343+
test_framework: The test framework being used (vitest, jest, mocha).
344+
345+
Returns:
346+
Code with vitest imports added if needed.
347+
348+
"""
349+
if test_framework != "vitest":
350+
return code
351+
352+
# Check if vitest imports already exist
353+
if "from 'vitest'" in code or 'from "vitest"' in code:
354+
return code
355+
356+
# Check if the code uses test functions that need to be imported
357+
test_globals = ["describe", "test", "it", "expect", "vi", "beforeEach", "afterEach", "beforeAll", "afterAll"]
358+
needs_import = any(f"{global_name}(" in code or f"{global_name} (" in code for global_name in test_globals)
359+
360+
if not needs_import:
361+
return code
362+
363+
# Determine which globals are actually used in the code
364+
used_globals = [g for g in test_globals if f"{g}(" in code or f"{g} (" in code]
365+
if not used_globals:
366+
return code
367+
368+
# Build the import statement
369+
import_statement = f"import {{ {', '.join(used_globals)} }} from 'vitest';\n"
370+
371+
# Find the first line that isn't a comment or empty
372+
lines = code.split("\n")
373+
insert_index = 0
374+
for i, line in enumerate(lines):
375+
stripped = line.strip()
376+
if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"):
377+
# Check if this line is an import/require - insert after imports
378+
if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "):
379+
continue
380+
insert_index = i
381+
break
382+
insert_index = i + 1
383+
384+
# Find the last import line to insert after it
385+
last_import_index = -1
386+
for i, line in enumerate(lines):
387+
stripped = line.strip()
388+
if stripped.startswith("import ") and "from " in stripped:
389+
last_import_index = i
390+
391+
if last_import_index >= 0:
392+
# Insert after the last import
393+
lines.insert(last_import_index + 1, import_statement.rstrip())
394+
else:
395+
# Insert at the beginning (after any leading comments)
396+
lines.insert(insert_index, import_statement.rstrip())
397+
398+
logger.debug("Added vitest imports: %s", used_globals)
399+
return "\n".join(lines)

codeflash/languages/javascript/support.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def default_file_extension(self) -> str:
5555
@property
5656
def test_framework(self) -> str:
5757
"""Primary test framework for JavaScript."""
58-
return "jest"
58+
from codeflash.languages.test_framework import get_js_test_framework_or_default
59+
60+
return get_js_test_framework_or_default()
5961

6062
@property
6163
def comment_prefix(self) -> str:

codeflash/verification/verification_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,13 @@ class TestConfig:
8686
def test_framework(self) -> str:
8787
"""Returns the appropriate test framework based on language.
8888
89-
Returns 'jest' for JavaScript/TypeScript, 'pytest' for Python (default).
89+
For JavaScript/TypeScript: uses the configured framework (vitest, jest, or mocha).
90+
For Python: uses pytest as default.
9091
"""
9192
if is_javascript():
92-
return "jest"
93+
from codeflash.languages.test_framework import get_js_test_framework_or_default
94+
95+
return get_js_test_framework_or_default()
9396
return "pytest"
9497

9598
def set_language(self, language: str) -> None:

codeflash/verification/verifier.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def generate_tests(
6969
instrument_generated_js_test,
7070
validate_and_fix_import_style,
7171
)
72-
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
72+
from codeflash.languages.javascript.module_system import (
73+
ensure_module_system_compatibility,
74+
ensure_vitest_imports,
75+
)
7376

7477
source_file = Path(function_to_optimize.file_path)
7578

@@ -81,6 +84,9 @@ def generate_tests(
8184
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
8285
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
8386

87+
# Ensure vitest imports are present when using vitest framework
88+
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
89+
8490
# Instrument for behavior verification (writes to SQLite)
8591
instrumented_behavior_test_source = instrument_generated_js_test(
8692
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR

0 commit comments

Comments
 (0)