Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 45 additions & 32 deletions codeflash/cli_cmds/init_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,40 +392,37 @@ def _prompt_custom_directory(dir_type: str) -> str:

def _get_git_remote_for_setup() -> str:
"""Get git remote for project setup."""
try:
repo = Repo(Path.cwd(), search_parent_directories=True)
git_remotes = get_git_remotes(repo)
if not git_remotes:
return ""

if len(git_remotes) == 1:
return git_remotes[0]

git_panel = Panel(
Text(
"Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
style="blue",
),
title="Git Remote Setup",
border_style="bright_blue",
)
console.print(git_panel)
console.print()
cwd = Path.cwd().as_posix()
git_remotes = _cached_git_remotes_for_cwd(cwd)
if not git_remotes:
return ""

git_questions = [
inquirer.List(
"git_remote",
message="Which git remote should Codeflash use?",
choices=git_remotes,
default="origin",
carousel=True,
)
]
if len(git_remotes) == 1:
return git_remotes[0]

git_answers = inquirer.prompt(git_questions, theme=_get_theme())
return git_answers["git_remote"] if git_answers else git_remotes[0]
except InvalidGitRepositoryError:
return ""
git_panel = Panel(
Text(
"Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
style="blue",
),
title="Git Remote Setup",
border_style="bright_blue",
)
console.print(git_panel)
console.print()

git_questions = [
inquirer.List(
"git_remote",
message="Which git remote should Codeflash use?",
choices=git_remotes,
default="origin",
carousel=True,
)
]

git_answers = inquirer.prompt(git_questions, theme=_get_theme())
return git_answers["git_remote"] if git_answers else git_remotes[0]


def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[str]:
Expand Down Expand Up @@ -547,6 +544,22 @@ def get_java_test_command(build_tool: JavaBuildTool) -> str:
return "mvn test"


@lru_cache(maxsize=32)
def _cached_repo_for_cwd(cwd: str) -> Repo | None:
try:
return Repo(Path(cwd), search_parent_directories=True)
except InvalidGitRepositoryError:
return None


@lru_cache(maxsize=32)
def _cached_git_remotes_for_cwd(cwd: str) -> list[str]:
repo = _cached_repo_for_cwd(cwd)
if not repo:
return []
return get_git_remotes(repo)


formatter_warning_shown = False

_SPOTLESS_COMMANDS = {
Expand Down
Loading