@@ -330,3 +330,70 @@ def ensure_module_system_compatibility(code: str, target_module_system: str) ->
330330 return convert_esm_to_commonjs (code )
331331
332332 return code
333+
334+
335+ def ensure_vitest_imports (code : str , test_framework : str ) -> str :
336+ """Ensure vitest test globals are imported when using vitest framework.
337+
338+ Vitest by default does not enable globals (describe, test, expect, etc.),
339+ so they must be explicitly imported. This function adds the import if missing.
340+
341+ Args:
342+ code: JavaScript/TypeScript test code.
343+ test_framework: The test framework being used (vitest, jest, mocha).
344+
345+ Returns:
346+ Code with vitest imports added if needed.
347+
348+ """
349+ if test_framework != "vitest" :
350+ return code
351+
352+ # Check if vitest imports already exist
353+ if "from 'vitest'" in code or 'from "vitest"' in code :
354+ return code
355+
356+ # Check if the code uses test functions that need to be imported
357+ test_globals = ["describe" , "test" , "it" , "expect" , "vi" , "beforeEach" , "afterEach" , "beforeAll" , "afterAll" ]
358+ needs_import = any (f"{ global_name } (" in code or f"{ global_name } (" in code for global_name in test_globals )
359+
360+ if not needs_import :
361+ return code
362+
363+ # Determine which globals are actually used in the code
364+ used_globals = [g for g in test_globals if f"{ g } (" in code or f"{ g } (" in code ]
365+ if not used_globals :
366+ return code
367+
368+ # Build the import statement
369+ import_statement = f"import {{ { ', ' .join (used_globals )} }} from 'vitest';\n "
370+
371+ # Find the first line that isn't a comment or empty
372+ lines = code .split ("\n " )
373+ insert_index = 0
374+ for i , line in enumerate (lines ):
375+ stripped = line .strip ()
376+ if stripped and not stripped .startswith ("//" ) and not stripped .startswith ("/*" ) and not stripped .startswith ("*" ):
377+ # Check if this line is an import/require - insert after imports
378+ if stripped .startswith ("import " ) or stripped .startswith ("const " ) or stripped .startswith ("let " ):
379+ continue
380+ insert_index = i
381+ break
382+ insert_index = i + 1
383+
384+ # Find the last import line to insert after it
385+ last_import_index = - 1
386+ for i , line in enumerate (lines ):
387+ stripped = line .strip ()
388+ if stripped .startswith ("import " ) and "from " in stripped :
389+ last_import_index = i
390+
391+ if last_import_index >= 0 :
392+ # Insert after the last import
393+ lines .insert (last_import_index + 1 , import_statement .rstrip ())
394+ else :
395+ # Insert at the beginning (after any leading comments)
396+ lines .insert (insert_index , import_statement .rstrip ())
397+
398+ logger .debug ("Added vitest imports: %s" , used_globals )
399+ return "\n " .join (lines )
0 commit comments