Skip to content

Commit b72aa84

Browse files
authored
Merge branch 'main' into fix/dont-extract-object-properties
2 parents d693eeb + acd4028 commit b72aa84

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1193
-1049
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
288288
normalized_ignore_paths = []
289289
for path in args.ignore_paths:
290290
path_obj = Path(path)
291-
assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist"
292-
normalized_ignore_paths.append(path_obj.resolve())
291+
if path_obj.exists():
292+
normalized_ignore_paths.append(path_obj.resolve())
293+
# Silently skip non-existent paths (e.g., .next, dist before build)
293294
args.ignore_paths = normalized_ignore_paths
294295
# Project root path is one level above the specified directory, because that's where the module can be imported from
295296
args.module_root = Path(args.module_root).resolve()
@@ -358,7 +359,7 @@ def _handle_show_config() -> None:
358359
detected = detect_project(project_root)
359360

360361
# Check if config exists or is auto-detected
361-
config_exists = has_existing_config(project_root)
362+
config_exists, _ = has_existing_config(project_root)
362363
status = "Saved config" if config_exists else "Auto-detected (not saved)"
363364

364365
console.print()
@@ -400,7 +401,8 @@ def _handle_reset_config(confirm: bool = True) -> None:
400401

401402
project_root = Path.cwd()
402403

403-
if not has_existing_config(project_root):
404+
config_exists, _ = has_existing_config(project_root)
405+
if not config_exists:
404406
console.print("[yellow]No Codeflash configuration found to remove.[/yellow]")
405407
return
406408

codeflash/cli_cmds/init_javascript.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from git import InvalidGitRepositoryError, Repo
1717
from rich.console import Group
1818
from rich.panel import Panel
19+
from rich.prompt import Confirm
1920
from rich.table import Table
2021
from rich.text import Text
2122

@@ -26,7 +27,6 @@
2627
from codeflash.code_utils.git_utils import get_git_remotes
2728
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
2829
from codeflash.telemetry.posthog_cf import ph
29-
from rich.prompt import Confirm
3030

3131

3232
class ProjectLanguage(Enum):
@@ -165,22 +165,21 @@ def get_package_install_command(project_root: Path, package: str, dev: bool = Tr
165165
if dev:
166166
cmd.append("--save-dev")
167167
return cmd
168-
elif pkg_manager == JsPackageManager.YARN:
168+
if pkg_manager == JsPackageManager.YARN:
169169
cmd = ["yarn", "add", package]
170170
if dev:
171171
cmd.append("--dev")
172172
return cmd
173-
elif pkg_manager == JsPackageManager.BUN:
173+
if pkg_manager == JsPackageManager.BUN:
174174
cmd = ["bun", "add", package]
175175
if dev:
176176
cmd.append("--dev")
177177
return cmd
178-
else:
179-
# Default to npm
180-
cmd = ["npm", "install", package]
181-
if dev:
182-
cmd.append("--save-dev")
183-
return cmd
178+
# Default to npm
179+
cmd = ["npm", "install", package]
180+
if dev:
181+
cmd.append("--save-dev")
182+
return cmd
184183

185184

186185
def init_js_project(language: ProjectLanguage) -> None:

codeflash/code_utils/code_extractor.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,9 +1577,11 @@ def get_opt_review_metrics(
15771577
15781578
Returns:
15791579
Markdown-formatted string with code blocks showing calling functions.
1580+
15801581
"""
1581-
from codeflash.languages.base import FunctionInfo, ParentInfo, ReferenceInfo
1582+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
15821583
from codeflash.languages.registry import get_language_support
1584+
from codeflash.models.models import FunctionParent
15831585

15841586
start_time = time.perf_counter()
15851587

@@ -1596,19 +1598,19 @@ def get_opt_review_metrics(
15961598
else:
15971599
function_name, class_name = qualified_name_split[1], qualified_name_split[0]
15981600

1599-
# Create a FunctionInfo for the function
1601+
# Create a FunctionToOptimize for the function
16001602
# We don't have full line info here, so we'll use defaults
1601-
parents = ()
1603+
parents: list[FunctionParent] = []
16021604
if class_name:
1603-
parents = (ParentInfo(name=class_name, type="ClassDef"),)
1605+
parents = [FunctionParent(name=class_name, type="ClassDef")]
16041606

1605-
func_info = FunctionInfo(
1606-
name=function_name,
1607+
func_info = FunctionToOptimize(
1608+
function_name=function_name,
16071609
file_path=file_path,
1608-
start_line=1,
1609-
end_line=1,
16101610
parents=parents,
1611-
language=language,
1611+
starting_line=1,
1612+
ending_line=1,
1613+
language=str(language),
16121614
)
16131615

16141616
# Find references using language support
@@ -1618,9 +1620,7 @@ def get_opt_review_metrics(
16181620
return ""
16191621

16201622
# Format references as markdown code blocks
1621-
calling_fns_details = _format_references_as_markdown(
1622-
references, file_path, project_root, language
1623-
)
1623+
calling_fns_details = _format_references_as_markdown(references, file_path, project_root, language)
16241624

16251625
except Exception as e:
16261626
logger.debug(f"Error getting function references: {e}")
@@ -1631,9 +1631,7 @@ def get_opt_review_metrics(
16311631
return calling_fns_details
16321632

16331633

1634-
def _format_references_as_markdown(
1635-
references: list, file_path: Path, project_root: Path, language: Language
1636-
) -> str:
1634+
def _format_references_as_markdown(references: list, file_path: Path, project_root: Path, language: Language) -> str:
16371635
"""Format references as markdown code blocks with calling function code.
16381636
16391637
Args:
@@ -1644,6 +1642,7 @@ def _format_references_as_markdown(
16441642
16451643
Returns:
16461644
Markdown-formatted string.
1645+
16471646
"""
16481647
# Group references by file
16491648
refs_by_file: dict[Path, list] = {}
@@ -1710,7 +1709,7 @@ def _format_references_as_markdown(
17101709
context_len += len(context_code)
17111710

17121711
if caller_contexts:
1713-
fn_call_context += f"```{lang_hint}:{path_relative}\n"
1712+
fn_call_context += f"```{lang_hint}:{path_relative.as_posix()}\n"
17141713
fn_call_context += "\n".join(caller_contexts)
17151714
fn_call_context += "\n```\n"
17161715

@@ -1728,11 +1727,11 @@ def _extract_calling_function(source_code: str, function_name: str, ref_line: in
17281727
17291728
Returns:
17301729
Source code of the function, or None if not found.
1730+
17311731
"""
17321732
if language == Language.PYTHON:
17331733
return _extract_calling_function_python(source_code, function_name, ref_line)
1734-
else:
1735-
return _extract_calling_function_js(source_code, function_name, ref_line)
1734+
return _extract_calling_function_js(source_code, function_name, ref_line)
17361735

17371736

17381737
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
@@ -1766,6 +1765,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line:
17661765
17671766
Returns:
17681767
Source code of the function, or None if not found.
1768+
17691769
"""
17701770
try:
17711771
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage

codeflash/code_utils/code_replacer.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def replace_function_definitions_for_language(
496496
497497
"""
498498
from codeflash.languages import get_language_support
499-
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
499+
from codeflash.languages.base import Language
500500

501501
original_source_code: str = module_abspath.read_text(encoding="utf8")
502502
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@@ -523,25 +523,15 @@ def replace_function_definitions_for_language(
523523
and function_to_optimize.ending_line
524524
and function_to_optimize.file_path == module_abspath
525525
):
526-
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
527-
func_info = FunctionInfo(
528-
name=function_to_optimize.function_name,
529-
file_path=module_abspath,
530-
start_line=function_to_optimize.starting_line,
531-
end_line=function_to_optimize.ending_line,
532-
parents=parents,
533-
is_async=function_to_optimize.is_async,
534-
language=language,
535-
)
536526
# Extract just the target function from the optimized code
537527
optimized_func = _extract_function_from_code(
538528
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
539529
)
540530
if optimized_func:
541-
new_code = lang_support.replace_function(original_source_code, func_info, optimized_func)
531+
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
542532
else:
543533
# Fallback: use the entire optimized code (for simple single-function files)
544-
new_code = lang_support.replace_function(original_source_code, func_info, code_to_apply)
534+
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
545535
else:
546536
# For helper files or when we don't have precise line info:
547537
# Find each function by name in both original and optimized code
@@ -559,15 +549,17 @@ def replace_function_definitions_for_language(
559549
# Find the function in current code
560550
func = None
561551
for f in current_functions:
562-
if func_name in (f.qualified_name, f.name):
552+
if func_name in (f.qualified_name, f.function_name):
563553
func = f
564554
break
565555

566556
if func is None:
567557
continue
568558

569559
# Extract just this function from the optimized code
570-
optimized_func = _extract_function_from_code(lang_support, code_to_apply, func.name, module_abspath)
560+
optimized_func = _extract_function_from_code(
561+
lang_support, code_to_apply, func.function_name, module_abspath
562+
)
571563
if optimized_func:
572564
new_code = lang_support.replace_function(new_code, func, optimized_func)
573565
modified = True
@@ -606,13 +598,13 @@ def _extract_function_from_code(
606598
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
607599
functions = lang_support.discover_functions_from_source(source_code, file_path)
608600
for func in functions:
609-
if func.name == function_name:
601+
if func.function_name == function_name:
610602
# Extract the function's source using line numbers
611603
# Use doc_start_line if available to include JSDoc/docstring
612604
lines = source_code.splitlines(keepends=True)
613-
effective_start = func.doc_start_line or func.start_line
614-
if effective_start and func.end_line and effective_start <= len(lines):
615-
func_lines = lines[effective_start - 1 : func.end_line]
605+
effective_start = func.doc_start_line or func.starting_line
606+
if effective_start and func.ending_line and effective_start <= len(lines):
607+
func_lines = lines[effective_start - 1 : func.ending_line]
616608
return "".join(func_lines)
617609
except Exception as e:
618610
logger.debug(f"Error extracting function {function_name}: {e}")

codeflash/code_utils/env_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from codeflash.code_utils.code_utils import exit_with_message
1414
from codeflash.code_utils.formatter import format_code
1515
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
16-
from codeflash.languages.base import Language
1716
from codeflash.languages.registry import get_language_support_by_common_formatters
1817
from codeflash.lsp.helpers import is_LSP_enabled
1918

@@ -44,9 +43,9 @@ def check_formatter_installed(
4443
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")
4544
return True
4645

47-
if lang_support.language == Language.PYTHON:
46+
if str(lang_support.language) == "python":
4847
tmp_code = """print("hello world")"""
49-
elif lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
48+
elif str(lang_support.language) in ("javascript", "typescript"):
5049
tmp_code = "console.log('hello world');"
5150
else:
5251
return True

codeflash/code_utils/formatter.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import isort
1414

1515
from codeflash.cli_cmds.console import console, logger
16-
from codeflash.languages.registry import get_language_support
1716
from codeflash.lsp.helpers import is_LSP_enabled
1817

1918

@@ -43,6 +42,8 @@ def split_lines(text: str) -> list[str]:
4342
def apply_formatter_cmds(
4443
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
4544
) -> tuple[Path, str, bool]:
45+
from codeflash.languages.registry import get_language_support
46+
4647
if not path.exists():
4748
msg = f"File {path} does not exist. Cannot apply formatter commands."
4849
raise FileNotFoundError(msg)
@@ -90,6 +91,8 @@ def is_diff_line(line: str) -> bool:
9091

9192

9293
def format_generated_code(generated_test_source: str, formatter_cmds: list[str], language: str = "python") -> str:
94+
from codeflash.languages.registry import get_language_support
95+
9396
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
9497
if formatter_name == "disabled": # nothing to do if no formatter provided
9598
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
@@ -114,6 +117,8 @@ def format_code(
114117
print_status: bool = True,
115118
exit_on_failure: bool = True,
116119
) -> str:
120+
from codeflash.languages.registry import get_language_support
121+
117122
if is_LSP_enabled():
118123
exit_on_failure = False
119124

codeflash/code_utils/static_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import BaseModel, ConfigDict, field_validator
99

1010
if TYPE_CHECKING:
11-
from codeflash.models.models import FunctionParent
11+
from codeflash.models.function_types import FunctionParent
1212

1313

1414
ObjectDefT = TypeVar("ObjectDefT", ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

codeflash/context/code_context_extractor.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
2424

2525
# Language support imports for multi-language code context extraction
26-
from codeflash.languages import is_python
27-
from codeflash.languages.base import Language
26+
from codeflash.languages import Language, is_python
2827
from codeflash.models.models import (
2928
CodeContextType,
3029
CodeOptimizationContext,
@@ -234,27 +233,13 @@ def get_code_optimization_context_for_language(
234233
235234
"""
236235
from codeflash.languages import get_language_support
237-
from codeflash.languages.base import FunctionInfo, ParentInfo
238236

239237
# Get language support for this function
240238
language = Language(function_to_optimize.language)
241239
lang_support = get_language_support(language)
242240

243-
# Convert FunctionToOptimize to FunctionInfo for language support
244-
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
245-
func_info = FunctionInfo(
246-
name=function_to_optimize.function_name,
247-
file_path=function_to_optimize.file_path,
248-
start_line=function_to_optimize.starting_line or 1,
249-
end_line=function_to_optimize.ending_line or 1,
250-
parents=parents,
251-
is_async=function_to_optimize.is_async,
252-
is_method=len(function_to_optimize.parents) > 0,
253-
language=language,
254-
)
255-
256241
# Extract code context using language support
257-
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
242+
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)
258243

259244
# Build imports string if available
260245
imports_code = "\n".join(code_context.imports) if code_context.imports else ""

0 commit comments

Comments
 (0)