diff --git a/bdiff/__init__.py b/bdiff/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bdiff/git_bdiff.py b/bdiff/git_bdiff.py index 77f4cf0..2db51dd 100644 --- a/bdiff/git_bdiff.py +++ b/bdiff/git_bdiff.py @@ -37,14 +37,10 @@ def __init__(self, cmd): ) -class GitBDiff: - """Class which generates a branch diff.""" - - # Name of primary branch - default is main - primary_branch = "main" - - # Match hex commit IDs - _hash_pattern = re.compile(r"^\s*([0-9a-f]{40})\s*$") +class GitBase: + """ + Base class for gitbdiff functionality + """ # Match branch names. This should catch all valid names but may # also some invalid names through. This should matter given that @@ -52,9 +48,10 @@ class GitBDiff: # overview of the naming scheme, see man git check-ref-format _branch_pattern = re.compile(r"^\s*([^\s~\^\:\?\*\[]+[^.])\s*$") - def __init__(self, parent=None, repo=None): - self.parent = parent or self.primary_branch + # Text returned if in a detached head + detached_head_reference = "detched_head_state" + def __init__(self, parent=None, repo=None): if repo is None: self._repo = None else: @@ -62,9 +59,71 @@ def __init__(self, parent=None, repo=None): if not self._repo.is_dir(): raise GitBDiffError(f"{repo} is not a directory") + def get_branch_name(self): + """Get the name of the current branch.""" + result = None + for line in self.run_git(["branch", "--show-current"]): + # Set m to self._branch_pattern result + # Then check m evaluates to True + if m := self._branch_pattern.match(line): + result = m.group(1) + break + else: + # Check for being in a Detached Head state + for line in self.run_git(["branch"]): + if "HEAD detached" in line: + result = self.detached_head_reference + break + else: + raise GitBDiffError("unable to get branch name") + return result + + def run_git(self, args): + """Run a git command and yield the output.""" + + if not isinstance(args, list): + raise TypeError("args must be a list") + cmd = ["git"] + args + + # Run the the command in the repo directory, capture the + # output, and check for errors. The build in error check is + # not used to allow specific git errors to be treated more + # precisely + proc = subprocess.run( + cmd, capture_output=True, check=False, shell=False, cwd=self._repo + ) + + for line in proc.stderr.decode("utf-8").split("\n"): + if line.startswith("fatal: not a git repository"): + raise GitBDiffNotGit(cmd) + if line.startswith("fatal: "): + raise GitBDiffError(line[7:]) + + if proc.returncode != 0: + raise GitBDiffError(f"command returned {proc.returncode}") + + yield from proc.stdout.decode("utf-8").split("\n") + + +class GitBDiff(GitBase): + """Class which generates a branch diff.""" + + # Name of primary branch - default is main + primary_branch = "main" + + # Match hex commit IDs + _hash_pattern = re.compile(r"^\s*([0-9a-f]{40})\s*$") + + def __init__(self, parent=None, repo=None): + self.parent = parent or self.primary_branch + + super().__init__(parent, repo) + self.ancestor = self.get_branch_point() self.current = self.get_latest_commit() self.branch = self.get_branch_name() + if self.branch == self.detached_head_reference: + raise GitBDiffError("Can't get a diff for a repo in detached head state") def get_branch_point(self): """Get the branch point from the parent repo. @@ -96,17 +155,6 @@ def get_latest_commit(self): raise GitBDiffError("current revision not found") return result - def get_branch_name(self): - """Get the name of the current branch.""" - result = None - for line in self.run_git(["branch", "--show-current"]): - if m := self._branch_pattern.match(line): - result = m.group(1) - break - else: - raise GitBDiffError("unable to get branch name") - return result - @property def is_branch(self): """Whether this is a branch or main.""" @@ -126,28 +174,24 @@ def files(self): if line != "": yield line - def run_git(self, args): - """Run a git command and yield the output.""" - if not isinstance(args, list): - raise TypeError("args must be a list") - cmd = ["git"] + args +class GitInfo(GitBase): + """ + Class to contain info of a git repo + """ - # Run the the command in the repo directory, capture the - # output, and check for errors. The build in error check is - # not used to allow specific git errors to be treated more - # precisely - proc = subprocess.run( - cmd, capture_output=True, check=False, shell=False, cwd=self._repo - ) + def __init__(self, repo=None): + super().__init__(repo=repo) - for line in proc.stderr.decode("utf-8").split("\n"): - if line.startswith("fatal: not a git repository"): - raise GitBDiffNotGit(cmd) - if line.startswith("fatal: "): - raise GitBDiffError(line[7:]) + self.branch = self.get_branch_name() - if proc.returncode != 0: - raise GitBDiffError(f"command returned {proc.returncode}") + def is_main(self): + """ + Returns true if branch matches a main-like branch name as defined below + Count detached head as main-like as we cannot get a diff for this + """ - yield from proc.stdout.decode("utf-8").split("\n") + main_like = ("main", "stable", "trunk", "master", self.detached_head_reference) + if self.branch in main_like: + return True + return False diff --git a/bdiff/tests/__init__.py b/bdiff/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bdiff/tests/test_git_bdiff.py b/bdiff/tests/test_git_bdiff.py index 6170a38..27302f8 100644 --- a/bdiff/tests/test_git_bdiff.py +++ b/bdiff/tests/test_git_bdiff.py @@ -12,7 +12,7 @@ import subprocess import pytest -from git_bdiff import GitBDiff, GitBDiffError, GitBDiffNotGit +from ..git_bdiff import GitBDiff, GitBDiffError, GitBDiffNotGit, GitInfo, GitBase # Disable warnings caused by the use of pytest fixtures @@ -58,9 +58,14 @@ def git_repo(tmpdir_factory): subprocess.run(["git", "checkout", "-b", "overwrite"], check=True) add_to_repo(0, 10, "Overwriting", "at") - # Switch back to the main branch ready for testing + # Switch back to the main branch subprocess.run(["git", "checkout", "main"], check=True) + # Add other trunk-like branches, finishing back in main + for branch in ("stable", "master", "trunk"): + subprocess.run(["git", "checkout", "-b", branch], check=True) + subprocess.run(["git", "checkout", "main"], check=True) + return location @@ -214,3 +219,41 @@ def test_git_run(git_repo): # Run a command that should return non-zero list(i for i in bdiff.run_git(["commit", "-m", "''"])) assert "command returned 1" in str(exc.value) + + +def test_is_main(git_repo): + """Test is_trunk function""" + + os.chdir(git_repo) + + for branch in ("stable", "master", "trunk", "main", "mybranch"): + info = GitInfo() + subprocess.run(["git", "checkout", branch], check=True) + if branch == "my_branch": + assert not info.is_main() + else: + assert info.is_main() + + +def find_previous_hash(): + """ + Loop over a git log output and extract a hash that isn't the current head + """ + + result = subprocess.run(["git", "log"], check=True, capture_output=True, text=True) + for line in result.stdout.split("\n"): + if line.startswith("commit") and "HEAD" not in line: + return line.split()[1] + + +def test_detached_head(git_repo): + """Test Detached Head State""" + + os.chdir(git_repo) + subprocess.run(["git", "checkout", "main"], check=True) + + commit_hash = find_previous_hash() + subprocess.run(["git", "checkout", commit_hash], check=True) + + git_base = GitBase() + assert git_base.get_branch_name() == git_base.detached_head_reference