diff --git a/src/stack_pr/cli.py b/src/stack_pr/cli.py index d068998..47e6b9a 100755 --- a/src/stack_pr/cli.py +++ b/src/stack_pr/cli.py @@ -631,18 +631,78 @@ def get_next_available_branch_name(branch_name_template: str, name: str) -> str: return generate_branch_name(branch_name_template, int(branch_id or 0) + 1) +def get_commit_branch_name(commit_id: str) -> str | None: + """Get the branch name that a commit is currently on. + + Args: + commit_id: The commit hash to check. + + Returns: + The branch name if the commit is on a branch, None otherwise. + """ + # Get all branches that contain this commit + branches = get_command_output( + ["git", "branch", "--points-at", commit_id] + ).splitlines() + + # Clean up branch names by removing leading spaces and asterisks + branches = [branch.lstrip(" *") for branch in branches] + # Filter out HEAD references + branches = [ + branch + for branch in branches + if branch != "HEAD" and not branch.startswith("(HEAD") + ] + + if not branches: + return None + + # If the commit is on multiple branches, prefer the current branch + current_branch = get_current_branch_name() + if current_branch in branches: + return current_branch + + return branches[0] + + def set_head_branches( - st: list[StackEntry], remote: str, *, verbose: bool, branch_name_template: str + stack: list[StackEntry], + remote: str, + *, + verbose: bool, + branch_name_template: str, ) -> None: - """Set the head ref for each stack entry if it doesn't already have one.""" + """Set the head ref for each stack entry using the branch name the commit is currently on. + + If a commit is not on any branch, it will use the default branch name template. + Args: + stack: List of stack entries to process. + remote: Name of the remote repository. + verbose: Whether to print verbose output. + branch_name_template: Template string for generating branch names. + """ run_shell_command(["git", "fetch", "--prune", remote], quiet=not verbose) available_name = get_available_branch_name(remote, branch_name_template) - for e in filter(lambda e: not e.has_head(), st): - e.head = available_name - available_name = get_next_available_branch_name( - branch_name_template, available_name - ) + + for entry in filter(lambda e: not e.has_head(), stack): + # Try to get the branch name the commit is currently on + branch_name = get_commit_branch_name(entry.commit.commit_id()) + if branch_name: + # If the commit is on a branch, use that branch name + entry.head = branch_name + else: + # If the commit is not on any branch, fall back to the default template + entry.head = available_name + available_name = get_next_available_branch_name( + branch_name_template, available_name + ) + + if verbose: + log( + f"Using branch name '{entry.head}' for commit {entry.commit.commit_id()[:8]}", + level=2, + ) def init_local_branches(