From 985135f0120bfb940839af1a72e68d621e05e2b5 Mon Sep 17 00:00:00 2001 From: Nikhil Kandpal Date: Wed, 3 May 2023 15:19:18 -0400 Subject: [PATCH 1/4] store checkpoint format information in .thetaconfig --- git_theta/checkpoints/base.py | 43 +++++++------ git_theta/git_utils.py | 90 ++++++++++++++++++++++++++- git_theta/merges/average.py | 10 +-- git_theta/scripts/git_theta.py | 18 ++++++ git_theta/scripts/git_theta_diff.py | 2 +- git_theta/scripts/git_theta_filter.py | 10 +-- git_theta/updates/base.py | 11 ++-- git_theta/updates/sparse.py | 9 +-- git_theta/utils.py | 3 +- 9 files changed, 147 insertions(+), 49 deletions(-) diff --git a/git_theta/checkpoints/base.py b/git_theta/checkpoints/base.py index c96ef52..d531470 100644 --- a/git_theta/checkpoints/base.py +++ b/git_theta/checkpoints/base.py @@ -1,5 +1,6 @@ """Base class and utilities for different checkpoint format backends.""" +import fnmatch import os import sys from abc import ABCMeta, abstractmethod @@ -12,7 +13,7 @@ else: from importlib.metadata import entry_points -from git_theta import utils +from git_theta import git_utils, utils @utils.abstract_classattributes("name") @@ -102,42 +103,40 @@ def diff(cls, m1: "Checkpoint", m2: "Checkpoint") -> "Checkpoint": return added, removed, modified -def get_checkpoint_handler_name(checkpoint_type: Optional[str] = None) -> str: - """Get the name of the checkpoint handler to use. - - Order of precedence is - 1. `checkpoint_type` argument - 2. `$GIT_THETA_CHECKPOINT_TYPE` environment variable - 3. default value (currently pytorch) +def get_checkpoint_handler_name(checkpoint_path: str) -> Optional[str]: + """Get the name of the checkpoint handler based on entry in .thetaconfig for the current checkpoint path Parameters ---------- - checkpoint_type - Name of the checkpoint handler + checkpoint_path + Path to checkpoint in repo Returns ------- str Name of the checkpoint handler """ - # TODO(bdlester): Find a better way to include checkpoint type information - # in git clean filters that are run without `git theta add`. - # TODO: Don't default to pytorch once other checkpoint formats are supported. - return checkpoint_type or utils.EnvVarConstants.CHECKPOINT_TYPE + repo = git_utils.get_git_repo() + checkpoint_path = git_utils.get_relative_path_from_root(repo, checkpoint_path) + config_file = git_utils.get_config_file(repo) + config = git_utils.read_config(config_file) + for pattern, config_entry in config.items(): + if fnmatch.fnmatchcase(checkpoint_path, pattern): + return config_entry.get("checkpoint_format", None) -def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint: - """Get the checkpoint handler either by name or from an environment variable. + return None - Gets the checkpoint handler either for the `checkpoint_type` argument or - `$GIT_THETA_CHECKPOINT_TYPE` environment variable. - Defaults to pytorch when neither are defined. +def get_checkpoint_handler(checkpoint_path: str) -> Checkpoint: + """Get the checkpoint handler for the current checkpoint path + + Gets the checkpoint handler from an entry in .thetaconfig Parameters ---------- - checkpoint_type - Name of the checkpoint handler + checkpoint_path + Path to the checkpoint Returns ------- @@ -145,6 +144,6 @@ def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint: The checkpoint handler (usually an instance of `git_theta.checkpoints.Checkpoint`). Returned handler may be defined in a user installed plugin. """ - checkpoint_type = get_checkpoint_handler_name(checkpoint_type) + checkpoint_type = get_checkpoint_handler_name(checkpoint_path) discovered_plugins = entry_points(group="git_theta.plugins.checkpoints") return discovered_plugins[checkpoint_type].load() diff --git a/git_theta/git_utils.py b/git_theta/git_utils.py index af27fba..143e257 100644 --- a/git_theta/git_utils.py +++ b/git_theta/git_utils.py @@ -10,7 +10,7 @@ import shutil import subprocess import sys -from typing import List, Sequence, Union +from typing import Dict, List, Sequence, Union import git @@ -148,13 +148,14 @@ def write_gitattributes( gitattributes_file.write("\n") -def add_theta_to_gitattributes(gitattributes: List[str], path: str) -> str: +def add_theta_to_gitattributes(gitattributes: List[str], path: str) -> List[str]: """Add a filter=theta that covers file_name. Parameters ---------- gitattributes: A list of the lines from the gitattribute files. path: The path to the model we are adding a filter to. + checkpoint_type: The checkpoint format type of the model we are adding a filter to Returns ------- @@ -197,6 +198,91 @@ def get_gitattributes_tracked_patterns(gitattributes_file): return patterns +def get_config_file(repo): + """ + Get path to this repo's .thetaconfig file + + Parameters + ---------- + repo : git.Repo + Repo object for the current git repository + + Returns + ------- + str + path to $git_root/.thetaconfig + """ + return os.path.join(repo.working_dir, ".thetaconfig") + + +def read_config(config_file): + """ + Read contents of this repo's .thetaconfig file. This file is a standard json file. + + Parameters + ---------- + config_file : str + Path to this repo's .thetaconfig file + + Returns + ------- + Dict + contents of .thetaconfig file + """ + if os.path.exists(config_file): + with open(config_file, "r") as f: + return json.load(f) + else: + return {} + + +@file_or_name(config_file="w") +def write_config(config_file: Union[str, io.FileIO], config: Dict): + """ + Write a dictionary to this repo's .thetaconfig file + + Parameters + ---------- + config_file: + Path to this repo's .thetaconfig file + + config: + Configuration dictionary to write to .thetaconfig + """ + json.dump(config, config_file, indent=4) + + +def add_path_to_config(config: Dict, path: str, config_params: Dict) -> Dict: + """Add `config_params` to the config entry for `path`. If no entry exists for `path`, add an entry. + + Parameters + ---------- + config: + Configuration dictionary + path: + The path to add configuration params for + config_params: + The configuration params to add + + Returns + ------- + Dict + New configuration dictionary + """ + pattern_found = False + # Check each entry to see if any match `path` + for pattern, config_entry in config.items(): + if fnmatch.fnmatchcase(path, pattern): + pattern_found = True + config_entry.update(config_params) + + # If we don't find an existing entry that matches `path`, add a new entry + if not pattern_found: + config[path] = config_params + + return config + + def add_file(f, repo): """ Add file to git staging area diff --git a/git_theta/merges/average.py b/git_theta/merges/average.py index c9dd9d5..5596ea3 100644 --- a/git_theta/merges/average.py +++ b/git_theta/merges/average.py @@ -43,16 +43,16 @@ def read_parameter( self, param: metadata.ParamMetadata, param_name: ParamName, path: str ) -> Parameter: update_handler = updates.get_update_handler(param.theta_metadata.update_type)( - params.get_update_serializer() + path, params.get_update_serializer() ) return async_utils.run( update_handler.apply(param, param_name, git_utils.get_git_repo(), path) ) - def write_merged(self, averaged: Parameter, param_name: ParamName): + def write_merged(self, averaged: Parameter, param_name: ParamName, path: str): tensor_metadata = metadata.TensorMetadata.from_tensor(averaged) update_handler = updates.get_update_handler("dense")( - params.get_update_serializer() + path, params.get_update_serializer() ) theta_metadata = metadata.ThetaMetadata("dense", None) # Dense only needs these two... @@ -83,7 +83,7 @@ def merge( # Load the other parameter paramB = self.read_parameter(paramB, param_name, path) result = self.average(alpha * paramA, (1 - alpha) * paramB) - return self.write_merged(result, param_name) + return self.write_merged(result, path, param_name) @classmethod def merge_arguments(self) -> List[MergeArgument]: @@ -256,7 +256,7 @@ def merge( # Load the original parameter paramO = self.read_parameter(paramO, param_name, path) result = self.average(alpha * paramB, (1 - alpha) * paramO) - return self.write_merged(result, param_name) + return self.write_merged(result, param_name, path) @classmethod def merge_arguments(self) -> List[MergeArgument]: diff --git a/git_theta/scripts/git_theta.py b/git_theta/scripts/git_theta.py index fb379aa..0d50701 100755 --- a/git_theta/scripts/git_theta.py +++ b/git_theta/scripts/git_theta.py @@ -56,6 +56,12 @@ def parse_args(): track_parser.add_argument( "file", help="model checkpoint file or file pattern to track" ) + track_parser.add_argument( + "--checkpoint_format", + choices=[e.name for e in entry_points(group="git_theta.plugins.checkpoints")], + default="pytorch", + help="Checkpoint format", + ) track_parser.set_defaults(func=track) add_parser = subparsers.add_parser("add", help="add command used to stage files.") @@ -146,6 +152,7 @@ def track(args): repo = git_utils.get_git_repo() model_path = git_utils.get_relative_path_from_root(repo, args.file) + # Read .gitattributes and add/update entry for the tracked file gitattributes_file = git_utils.get_gitattributes_file(repo) gitattributes = git_utils.read_gitattributes(gitattributes_file) @@ -154,6 +161,17 @@ def track(args): git_utils.write_gitattributes(gitattributes_file, new_gitattributes) git_utils.add_file(gitattributes_file, repo) + # Read .thetaconfig and add/update entry for the tracked file + config_file = git_utils.get_config_file(repo) + config = git_utils.read_config(config_file) + + new_config = git_utils.add_path_to_config( + config, model_path, {"checkpoint_format": args.checkpoint_format} + ) + + git_utils.write_config(config_file, new_config) + git_utils.add_file(config_file, repo) + def add(args, unparsed_args): repo = git_utils.get_git_repo() diff --git a/git_theta/scripts/git_theta_diff.py b/git_theta/scripts/git_theta_diff.py index 039712a..67cd178 100644 --- a/git_theta/scripts/git_theta_diff.py +++ b/git_theta/scripts/git_theta_diff.py @@ -84,7 +84,7 @@ def print_modified_params_summary(modified, indent=0, color=None): def main(): args = parse_args() - checkpoint_handler = checkpoints.get_checkpoint_handler() + checkpoint_handler = checkpoints.get_checkpoint_handler(args.path) old_checkpoint = checkpoint_handler.from_file(args.old_checkpoint) new_checkpoint = checkpoint_handler.from_file(args.new_checkpoint) added, removed, modified = checkpoint_handler.diff(new_checkpoint, old_checkpoint) diff --git a/git_theta/scripts/git_theta_filter.py b/git_theta/scripts/git_theta_filter.py index 6d8a392..183bbf7 100755 --- a/git_theta/scripts/git_theta_filter.py +++ b/git_theta/scripts/git_theta_filter.py @@ -55,7 +55,7 @@ def clean( update_serializer = params.get_update_serializer() # Create an update handler based on user input. update_handler = updates.get_update_handler()( - update_serializer, EnvVarConstants.UPDATE_DATA_PATH + path, update_serializer, EnvVarConstants.UPDATE_DATA_PATH ) prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD").flatten() @@ -94,7 +94,7 @@ async def _clean(param_keys, new_param): # for that parameter. param_update_handler = updates.get_update_handler( param_metadata.theta_metadata.update_type - )(update_serializer) + )(path, update_serializer) param = await param_update_handler.apply( param_metadata, param_keys, repo=repo, path=path ) @@ -149,7 +149,7 @@ def run_clean(args): """ logging.debug(f"Running clean filter on {args.file}") repo = git_utils.get_git_repo() - checkpoint_handler = checkpoints.get_checkpoint_handler() + checkpoint_handler = checkpoints.get_checkpoint_handler(args.file) model_checkpoint = checkpoint_handler.from_file(sys.stdin.buffer) new_metadata = clean(model_checkpoint, repo, args.file) new_metadata.write(sys.stdout) @@ -171,7 +171,7 @@ async def _smudge(param_keys, param_metadata): logging.debug(f"Smudging {'/'.join(param_keys)}") update_handler = updates.get_update_handler( param_metadata.theta_metadata.update_type - )(params.get_update_serializer()) + )(path, params.get_update_serializer()) param_value = await update_handler.apply( param_metadata, param_keys, repo=repo, path=path ) @@ -183,7 +183,7 @@ async def _smudge(param_keys, param_metadata): ) ) - checkpoint_handler = checkpoints.get_checkpoint_handler() + checkpoint_handler = checkpoints.get_checkpoint_handler(path) return checkpoint_handler(model_dict).unflatten() diff --git a/git_theta/updates/base.py b/git_theta/updates/base.py index a04917d..d369b33 100644 --- a/git_theta/updates/base.py +++ b/git_theta/updates/base.py @@ -26,7 +26,8 @@ class Update(metaclass=ABCMeta): name: str = NotImplemented # The name used to lookup the plug-in. - def __init__(self, serializer: params.Serializer, *args, **kwargs): + def __init__(self, path: str, serializer: params.Serializer, *args, **kwargs): + self.path = path self.serializer = serializer async def read(self, param_metadata: metadata.ParamMetadata) -> Parameter: @@ -59,15 +60,15 @@ class IncrementalUpdate(Update): required_keys: FrozenSet[str] = NotImplemented # Names for side-loaded information. - def __init__(self, serializer: params.Serializer, update_data: str = ""): - super().__init__(serializer) + def __init__(self, path: str, serializer: params.Serializer, update_data: str = ""): + super().__init__(path, serializer) self.update_information: Dict[str, np.ndarray] = None self.update_names: utils.Trie = None # Flatten the side-loaded information into a of string keys to arrays. if update_data: self.update_information = { "/".join(k): v - for k, v in checkpoints.get_checkpoint_handler() + for k, v in checkpoints.get_checkpoint_handler(path) .from_file(update_data) .flatten() .items() @@ -121,7 +122,7 @@ async def get_previous_value( # getters return classes to be instantiated. prev_serializer = params.get_update_serializer() prev_update = get_update_handler(param_metadata.theta_metadata.update_type)( - prev_serializer + path, prev_serializer ) return await prev_update.apply(param_metadata, param_keys, repo=repo, path=path) diff --git a/git_theta/updates/sparse.py b/git_theta/updates/sparse.py index 4a84516..837ddfe 100644 --- a/git_theta/updates/sparse.py +++ b/git_theta/updates/sparse.py @@ -18,14 +18,9 @@ class SparseUpdate(IncrementalUpdate): name: str = "sparse" required_keys: FrozenSet[str] = frozenset(("data", "indices", "indptr", "shape")) - def __init__( - self, - serializer: params.Serializer, - update_data: str = "", - threshold: float = 1e-12, - ): + def __init__(self, *args, threshold: float = 1e-12, **kwargs): # TODO: Make threshold configurable - super().__init__(serializer, update_data) + super().__init__(*args, **kwargs) self.threshold = threshold @classmethod diff --git a/git_theta/utils.py b/git_theta/utils.py index 052ee13..67d0d8f 100644 --- a/git_theta/utils.py +++ b/git_theta/utils.py @@ -71,9 +71,8 @@ def __get__(self, obj, objtype=None): class EnvVarConstants: - CHECKPOINT_TYPE = EnvVar(name="GIT_THETA_CHECKPOINT_TYPE", default="pytorch") UPDATE_TYPE = EnvVar(name="GIT_THETA_UPDATE_TYPE", default="dense") - UPDATE_DATA_PATH = EnvVar(name="GIT_THETA_UPDATE_DATA_PATH", default="update.pt") + UPDATE_DATA_PATH = EnvVar(name="GIT_THETA_UPDATE_DATA_PATH", default="") PARAMETER_ATOL = EnvVar(name="GIT_THETA_PARAMETER_ATOL", default=1e-8) PARAMETER_RTOL = EnvVar(name="GIT_THETA_PARAMETER_RTOL", default=1e-5) LSH_SIGNATURE_SIZE = EnvVar(name="GIT_THETA_LSH_SIGNATURE_SIZE", default=16) From c66d2863a36910cbe19211cd6120ba39d9a8185c Mon Sep 17 00:00:00 2001 From: Nikhil Kandpal Date: Thu, 4 May 2023 11:56:37 -0400 Subject: [PATCH 2/4] Factor .gitattributes and .thetaconfig reading/writing/parsing into classes --- git_theta/checkpoints/base.py | 13 +- git_theta/config.py | 198 ++++++++++++++++++++++++++ git_theta/git_utils.py | 193 ------------------------- git_theta/lsh/euclidean_lsh.py | 9 +- git_theta/lsh/pool.py | 6 +- git_theta/scripts/git_theta.py | 39 ++--- git_theta/scripts/git_theta_filter.py | 21 +-- git_theta/utils.py | 6 - 8 files changed, 236 insertions(+), 249 deletions(-) create mode 100644 git_theta/config.py diff --git a/git_theta/checkpoints/base.py b/git_theta/checkpoints/base.py index d531470..9bc9e15 100644 --- a/git_theta/checkpoints/base.py +++ b/git_theta/checkpoints/base.py @@ -13,7 +13,7 @@ else: from importlib.metadata import entry_points -from git_theta import git_utils, utils +from git_theta import config, git_utils, utils @utils.abstract_classattributes("name") @@ -118,14 +118,9 @@ def get_checkpoint_handler_name(checkpoint_path: str) -> Optional[str]: """ repo = git_utils.get_git_repo() checkpoint_path = git_utils.get_relative_path_from_root(repo, checkpoint_path) - config_file = git_utils.get_config_file(repo) - config = git_utils.read_config(config_file) - - for pattern, config_entry in config.items(): - if fnmatch.fnmatchcase(checkpoint_path, pattern): - return config_entry.get("checkpoint_format", None) - - return None + thetaconfig = config.ThetaConfigFile(repo) + checkpoint_config = thetaconfig.get_config(checkpoint_path) + return checkpoint_config.get("checkpoint_format", None) def get_checkpoint_handler(checkpoint_path: str) -> Checkpoint: diff --git a/git_theta/config.py b/git_theta/config.py new file mode 100644 index 0000000..b526099 --- /dev/null +++ b/git_theta/config.py @@ -0,0 +1,198 @@ +import dataclasses +import fnmatch +import json +import os +import re +from collections import OrderedDict +from typing import Dict, List, Tuple + +import git + +from git_theta import git_utils + + +@dataclasses.dataclass +class PatternAttributes: + pattern: str + attributes: str + + @classmethod + def from_line(cls, line): + # TODO(bdlester): Revisit this regex to see if it when the pattern + # is escaped due to having spaces in it. + match = re.match( + r"^\s*(?P[^\s]+)\s+(?P.*)$", line.rstrip() + ) + if not match: + raise ValueError(f"{line} is an invalid git attribute line") + return cls(pattern=match.group("pattern"), attributes=match.group("attributes")) + + def add_attribute(self, attribute): + if attribute not in self.attributes: + self.attributes = f"{self.attribute.rstrip()} {attribute}" + + def serialize(self): + return f"{self.pattern} {self.attributes}" + + +class GitAttributesFile: + def __init__(self, repo: git.Repo): + self.repo = repo + self.file = os.path.join(repo.working_dir, ".gitattributes") + self.data = GitAttributesFile.read(self.file) + + @classmethod + def read(cls, file: str): + """ + Read contents of this repo's .gitattributes file + + Parameters + ---------- + file + Path to .gitattributes file + + Returns + ------- + List[str] + lines in .gitattributes file + """ + if os.path.exists(file): + with open(file, "r") as f: + return [PatternAttributes.from_line(line) for line in f] + else: + return [] + + def write(self): + """ + Write list of attributes to this repo's .gitattributes file + """ + with open(self.file, "w") as f: + f.write( + "\n".join([pattern_attrs.serialize() for pattern_attrs in self.data]) + ) + # End file with newline. + f.write("\n") + + def add_theta(self, pattern: str) -> List[str]: + """Set filter, merge, and diff attributes to theta for `pattern`. + + Parameters + ---------- + pattern: + The pattern we are adding theta attributes for + """ + pattern = git_utils.get_relative_path_from_root(self.repo, pattern) + pattern_found = False + for pattern_attrs in self.data: + if pattern == pattern_attrs.pattern: + pattern_attrs.add("filter=theta") + pattern_attrs.add("merge=theta") + pattern_attrs.add("diff=theta") + pattern_found = True + # If we don't find a matching pattern, add a new line that covers this pattern + if not pattern_found: + self.data.append( + PatternAttributes.from_line( + f"{pattern} filter=theta merge=theta diff=theta" + ) + ) + + def is_theta_tracked(self, path: str) -> bool: + return any( + [ + fnmatch.fnmatchcase(path, pattern_attr.pattern) + and "filter=theta" in pattern_attr.attributes + for pattern_attr in self.data + ] + ) + + +@dataclasses.dataclass +class Config: + @classmethod + def from_dict(cls, params: Dict) -> "Config": + fields = [field.name for field in dataclasses.fields(cls)] + return cls(**{k: v for k, v in params.items() if k in fields}) + + def serialize(self) -> Dict: + return dataclasses.asdict(self, dict_factory=OrderedDict) + + +@dataclasses.dataclass +class PatternConfig(Config): + pattern: str + checkpoint_format: str + + +@dataclasses.dataclass +class RepoConfig(Config): + parameter_atol: float = 1e-8 + parameter_rtol: float = 1e-5 + lsh_signature_size: int = 16 + lsh_threshold: float = 1e-6 + lsh_pool_size: int = 10_000 + max_concurrency: int = -1 + + +class ThetaConfigFile: + def __init__(self, repo): + self.repo = repo + self.file = os.path.join(repo.working_dir, ".thetaconfig") + self.repo_config, self.pattern_configs = ThetaConfigFile.read(self.file) + + @classmethod + def read(cls, file) -> Tuple[RepoConfig, List[PatternConfig]]: + """ + Read contents of this repo's .thetaconfig file + + Returns + ------- + repo_config: RepoConfig + Repository-level configuration parameters + + pattern_configs: List[PatternConfig] + Pattern-level configuration parameters + """ + if os.path.exists(file): + with open(file, "r") as f: + config = json.load(f) + else: + config = {"repo": {}, "patterns": []} + + repo_config = RepoConfig.from_dict(config["repo"]) + pattern_configs = [PatternConfig.from_dict(d) for d in config["patterns"]] + return repo_config, pattern_configs + + def write(self): + """ + Write a dictionary config to this repo's .thetaconfig file + + Parameters + ---------- + config_file: + Path to this repo's .thetaconfig file + + config: + Configuration dictionary to write to .thetaconfig + """ + with open(self.file, "w") as f: + json.dump(self.serialize(), f, indent=4) + + def serialize(self): + repo_config = self.repo_config.serialize() + pattern_configs = [pc.serialize() for pc in self.pattern_configs] + config = {"repo": repo_config, "patterns": pattern_configs} + return config + + def add_pattern(self, config: Dict): + pattern_config = PatternConfig.from_dict(config) + self.pattern_configs.append(pattern_config) + + def get_config(self, path: str) -> Dict: + config = {} + for pc in self.pattern_configs: + if fnmatch.fnmatchcase(path, pc.pattern): + config.update(pc.serialize()) + if "pattern" in config: + config.pop("pattern") + return config diff --git a/git_theta/git_utils.py b/git_theta/git_utils.py index 143e257..dc6d3c9 100644 --- a/git_theta/git_utils.py +++ b/git_theta/git_utils.py @@ -90,199 +90,6 @@ def get_absolute_path(repo: git.Repo, relative_path: str) -> str: return os.path.abspath(os.path.join(repo.working_dir, relative_path)) -def get_gitattributes_file(repo): - """ - Get path to this repo's .gitattributes file - - Parameters - ---------- - repo : git.Repo - Repo object for the current git repository - - Returns - ------- - str - path to $git_root/.gitattributes - """ - return os.path.join(repo.working_dir, ".gitattributes") - - -def read_gitattributes(gitattributes_file): - """ - Read contents of this repo's .gitattributes file - - Parameters - ---------- - gitattributes_file : str - Path to this repo's .gitattributes file - - Returns - ------- - List[str] - lines in .gitattributes file - """ - if os.path.exists(gitattributes_file): - with open(gitattributes_file, "r") as f: - return [line.rstrip("\n") for line in f] - else: - return [] - - -@file_or_name(gitattributes_file="w") -def write_gitattributes( - gitattributes_file: Union[str, io.FileIO], attributes: List[str] -): - """ - Write list of attributes to this repo's .gitattributes file - - Parameters - ---------- - gitattributes_file: - Path to this repo's .gitattributes file - - attributes: - Attributes to write to .gitattributes - """ - gitattributes_file.write("\n".join(attributes)) - # End file with newline. - gitattributes_file.write("\n") - - -def add_theta_to_gitattributes(gitattributes: List[str], path: str) -> List[str]: - """Add a filter=theta that covers file_name. - - Parameters - ---------- - gitattributes: A list of the lines from the gitattribute files. - path: The path to the model we are adding a filter to. - checkpoint_type: The checkpoint format type of the model we are adding a filter to - - Returns - ------- - List[str] - The lines to write to the new gitattribute file with a (possibly) new - filter=theta added that covers the given file. - """ - pattern_found = False - new_gitattributes = [] - for line in gitattributes: - # TODO(bdlester): Revisit this regex to see if it when the pattern - # is escaped due to having spaces in it. - match = re.match(r"^\s*(?P[^\s]+)\s+(?P.*)$", line) - if match: - # If there is already a pattern that covers the file, add the filter - # to that. - if fnmatch.fnmatchcase(path, match.group("pattern")): - pattern_found = True - if not "filter=theta" in match.group("attributes"): - line = f"{line.rstrip()} filter=theta" - if not "merge=theta" in match.group("attributes"): - line = f"{line.rstrip()} merge=theta" - if not "diff=theta" in match.group("attributes"): - line = f"{line.rstrip()} diff=theta" - new_gitattributes.append(line) - # If we don't find a matching pattern, add a new line that covers just this - # specific file. - if not pattern_found: - new_gitattributes.append(f"{path} filter=theta merge=theta diff=theta") - return new_gitattributes - - -def get_gitattributes_tracked_patterns(gitattributes_file): - gitattributes = read_gitattributes(gitattributes_file) - theta_attributes = [ - attribute for attribute in gitattributes if "filter=theta" in attribute - ] - # TODO: Correctly handle patterns with escaped spaces in them - patterns = [attribute.split(" ")[0] for attribute in theta_attributes] - return patterns - - -def get_config_file(repo): - """ - Get path to this repo's .thetaconfig file - - Parameters - ---------- - repo : git.Repo - Repo object for the current git repository - - Returns - ------- - str - path to $git_root/.thetaconfig - """ - return os.path.join(repo.working_dir, ".thetaconfig") - - -def read_config(config_file): - """ - Read contents of this repo's .thetaconfig file. This file is a standard json file. - - Parameters - ---------- - config_file : str - Path to this repo's .thetaconfig file - - Returns - ------- - Dict - contents of .thetaconfig file - """ - if os.path.exists(config_file): - with open(config_file, "r") as f: - return json.load(f) - else: - return {} - - -@file_or_name(config_file="w") -def write_config(config_file: Union[str, io.FileIO], config: Dict): - """ - Write a dictionary to this repo's .thetaconfig file - - Parameters - ---------- - config_file: - Path to this repo's .thetaconfig file - - config: - Configuration dictionary to write to .thetaconfig - """ - json.dump(config, config_file, indent=4) - - -def add_path_to_config(config: Dict, path: str, config_params: Dict) -> Dict: - """Add `config_params` to the config entry for `path`. If no entry exists for `path`, add an entry. - - Parameters - ---------- - config: - Configuration dictionary - path: - The path to add configuration params for - config_params: - The configuration params to add - - Returns - ------- - Dict - New configuration dictionary - """ - pattern_found = False - # Check each entry to see if any match `path` - for pattern, config_entry in config.items(): - if fnmatch.fnmatchcase(path, pattern): - pattern_found = True - config_entry.update(config_params) - - # If we don't find an existing entry that matches `path`, add a new entry - if not pattern_found: - config[path] = config_params - - return config - - def add_file(f, repo): """ Add file to git staging area diff --git a/git_theta/lsh/euclidean_lsh.py b/git_theta/lsh/euclidean_lsh.py index 2c4dc08..49a3fc3 100644 --- a/git_theta/lsh/euclidean_lsh.py +++ b/git_theta/lsh/euclidean_lsh.py @@ -5,10 +5,10 @@ import numba as nb import numpy as np +from git_theta import config, git_utils from git_theta.lsh import HashFamily from git_theta.lsh.pool import RandomnessPool from git_theta.lsh.types import Parameter, Signature -from git_theta.utils import EnvVarConstants class EuclideanLSH(HashFamily): @@ -68,8 +68,9 @@ def nb_hash( def get_lsh(): - # TODO we need a better way of keeping track of configuration at the repository level - # For LSH configuration, once it is set for a repository, changing it should be handled with care + repo = git_utils.get_git_repo() + thetaconfig = config.ThetaConfigFile(repo) return FastEuclideanLSH( - EnvVarConstants.LSH_SIGNATURE_SIZE, EnvVarConstants.PARAMETER_ATOL + thetaconfig.repo_config.lsh_signature_size, + thetaconfig.repo_config.parameter_atol, ) diff --git a/git_theta/lsh/pool.py b/git_theta/lsh/pool.py index 50dd72d..e5f1095 100644 --- a/git_theta/lsh/pool.py +++ b/git_theta/lsh/pool.py @@ -6,7 +6,7 @@ import numpy as np from numpy.random import MT19937, Generator -from git_theta.utils import EnvVarConstants +from git_theta import config, git_utils spec = [("pool", nb.float64[:]), ("signature_offsets", nb.int64[:])] @@ -15,9 +15,11 @@ class RandomnessPool: def __init__(self, signature_size): with nb.objmode(pool="float64[:]", signature_offsets="int64[:]"): + repo = git_utils.get_git_repo() + thetaconfig = config.ThetaConfigFile(repo) # N.b. we use a fixed seed so that every instance of RandomPool has the same set of random numbers rng = Generator(MT19937(seed=42)) - pool = rng.normal(size=EnvVarConstants.LSH_POOL_SIZE) + pool = rng.normal(size=thetaconfig.repo_config.lsh_pool_size) int64_range = np.iinfo(np.int64) signature_offsets = rng.integers( int64_range.min, int64_range.max, size=signature_size, dtype=np.int64 diff --git a/git_theta/scripts/git_theta.py b/git_theta/scripts/git_theta.py index 0d50701..58fce61 100755 --- a/git_theta/scripts/git_theta.py +++ b/git_theta/scripts/git_theta.py @@ -1,7 +1,6 @@ """Installation and .git manipulation scripts.""" import argparse -import fnmatch import logging import re import sys @@ -13,7 +12,7 @@ else: from importlib.metadata import entry_points -from git_theta import async_utils, git_utils, metadata, theta, utils +from git_theta import async_utils, config, git_utils, metadata, theta, utils logging.basicConfig( level=logging.DEBUG, @@ -53,9 +52,7 @@ def parse_args(): "track", help="track command used to identify model checkpoint for git-theta to track", ) - track_parser.add_argument( - "file", help="model checkpoint file or file pattern to track" - ) + track_parser.add_argument("pattern", help="model checkpoint file pattern to track") track_parser.add_argument( "--checkpoint_format", choices=[e.name for e in entry_points(group="git_theta.plugins.checkpoints")], @@ -85,13 +82,12 @@ def post_commit(args): repo = git_utils.get_git_repo() theta_commits = theta.ThetaCommits(repo) - gitattributes_file = git_utils.get_gitattributes_file(repo) - patterns = git_utils.get_gitattributes_tracked_patterns(gitattributes_file) + gitattributes = config.GitAttributesFile(repo) oids = set() commit = repo.commit("HEAD") for path in commit.stats.files.keys(): - if any([fnmatch.fnmatchcase(path, pattern) for pattern in patterns]): + if gitattributes.is_theta_tracked(path): curr_metadata = metadata.Metadata.from_file(commit.tree[path].data_stream) prev_metadata = metadata.Metadata.from_commit(repo, path, "HEAD~1") @@ -150,27 +146,16 @@ def track(args): Track a particular model checkpoint file with git-theta """ repo = git_utils.get_git_repo() - model_path = git_utils.get_relative_path_from_root(repo, args.file) - - # Read .gitattributes and add/update entry for the tracked file - gitattributes_file = git_utils.get_gitattributes_file(repo) - gitattributes = git_utils.read_gitattributes(gitattributes_file) - - new_gitattributes = git_utils.add_theta_to_gitattributes(gitattributes, model_path) - git_utils.write_gitattributes(gitattributes_file, new_gitattributes) - git_utils.add_file(gitattributes_file, repo) - - # Read .thetaconfig and add/update entry for the tracked file - config_file = git_utils.get_config_file(repo) - config = git_utils.read_config(config_file) - - new_config = git_utils.add_path_to_config( - config, model_path, {"checkpoint_format": args.checkpoint_format} - ) + gitattributes = config.GitAttributesFile(repo) + gitattributes.add_theta(args.pattern) + gitattributes.write() + git_utils.add_file(gitattributes.file, repo) - git_utils.write_config(config_file, new_config) - git_utils.add_file(config_file, repo) + thetaconfig = config.ThetaConfigFile(repo) + thetaconfig.add_pattern(vars(args)) + thetaconfig.write() + git_utils.add_file(thetaconfig.file, repo) def add(args, unparsed_args): diff --git a/git_theta/scripts/git_theta_filter.py b/git_theta/scripts/git_theta_filter.py index 183bbf7..960dfcd 100755 --- a/git_theta/scripts/git_theta_filter.py +++ b/git_theta/scripts/git_theta_filter.py @@ -12,6 +12,7 @@ from git_theta import ( async_utils, checkpoints, + config, git_utils, lsh, metadata, @@ -50,6 +51,7 @@ def clean( checkpoint: checkpoints.Checkpoint, repo: git.Repo, path: str ) -> metadata.Metadata: """Convert a `Checkpoint` to cleaned `Metadata`.""" + thetaconfig = config.ThetaConfigFile(repo) # Note: If the update serializer is configurable per-parameter, it will # need to be created inside _clean update_serializer = params.get_update_serializer() @@ -83,13 +85,13 @@ async def _clean(param_keys, new_param): hash_distance = hasher.distance( param_metadata.tensor_metadata.hash, new_tensor_metadata.hash ) - # If hash_distance < PARAMETER_ATOL, assume the tensors pass + # If hash_distance < parameter_atol, assume the tensors pass # np.allclose and parameter hasn't changed - if hash_distance < EnvVarConstants.PARAMETER_ATOL: + if hash_distance < thetaconfig.repo_config.parameter_atol: return param_keys, param_metadata - # If PARAMETER_ATOL < hash_distance < LSH_THRESHOLD, load parameters + # If parameter_atol < hash_distance < lsh_threshold, load parameters # and check if parameter has changed with np.allclose - elif hash_distance < EnvVarConstants.LSH_THRESHOLD: + elif hash_distance < thetaconfig.repo_config.lsh_threshold: # Load the previous parameter using the specific update handler # for that parameter. param_update_handler = updates.get_update_handler( @@ -101,8 +103,8 @@ async def _clean(param_keys, new_param): if np.allclose( param, new_param, - rtol=EnvVarConstants.PARAMETER_RTOL, - atol=EnvVarConstants.PARAMETER_ATOL, + rtol=thetaconfig.repo_config.parameter_rtol, + atol=thetaconfig.repo_config.parameter_atol, ): return param_keys, param_metadata @@ -137,7 +139,7 @@ async def _clean(param_keys, new_param): async_utils.run_map( sorted_checkpoint, _clean, - max_concurrency=EnvVarConstants.MAX_CONCURRENCY, + max_concurrency=thetaconfig.repo_config.max_concurrency, ) ) ).unflatten() @@ -165,6 +167,7 @@ def smudge( cleaned_metadata: metadata.Metadata, repo: git.Repo, path: str ) -> checkpoints.Checkpoint: """Convert cleaned `Metadata` to a `Checkpoint`.""" + thetaconfig = config.ThetaConfigFile(repo) curr_metadata = cleaned_metadata.flatten() async def _smudge(param_keys, param_metadata): @@ -179,7 +182,9 @@ async def _smudge(param_keys, param_metadata): model_dict = async_utils.run( async_utils.run_map( - curr_metadata, _smudge, max_concurrency=EnvVarConstants.MAX_CONCURRENCY + curr_metadata, + _smudge, + max_concurrency=thetaconfig.repo_config.max_concurrency, ) ) diff --git a/git_theta/utils.py b/git_theta/utils.py index 67d0d8f..bcccee2 100644 --- a/git_theta/utils.py +++ b/git_theta/utils.py @@ -73,12 +73,6 @@ def __get__(self, obj, objtype=None): class EnvVarConstants: UPDATE_TYPE = EnvVar(name="GIT_THETA_UPDATE_TYPE", default="dense") UPDATE_DATA_PATH = EnvVar(name="GIT_THETA_UPDATE_DATA_PATH", default="") - PARAMETER_ATOL = EnvVar(name="GIT_THETA_PARAMETER_ATOL", default=1e-8) - PARAMETER_RTOL = EnvVar(name="GIT_THETA_PARAMETER_RTOL", default=1e-5) - LSH_SIGNATURE_SIZE = EnvVar(name="GIT_THETA_LSH_SIGNATURE_SIZE", default=16) - LSH_THRESHOLD = EnvVar(name="GIT_THETA_LSH_THRESHOLD", default=1e-6) - LSH_POOL_SIZE = EnvVar(name="GIT_THETA_LSH_POOL_SIZE", default=10_000) - MAX_CONCURRENCY = EnvVar(name="GIT_THETA_MAX_CONCURRENCY", default=-1) MANUAL_MERGE = EnvVar(name="GIT_THETA_MANUAL_MERGE", default=False) From 9e703f8fa619e58d47c41ff3974261cc0894017e Mon Sep 17 00:00:00 2001 From: Nikhil Kandpal Date: Thu, 4 May 2023 23:21:16 -0400 Subject: [PATCH 3/4] Add tests --- git_theta/config.py | 21 ++--- tests/config_test.py | 189 ++++++++++++++++++++++++++++++++++++++++ tests/git_utils_test.py | 165 ----------------------------------- 3 files changed, 200 insertions(+), 175 deletions(-) create mode 100644 tests/config_test.py delete mode 100644 tests/git_utils_test.py diff --git a/git_theta/config.py b/git_theta/config.py index b526099..9594027 100644 --- a/git_theta/config.py +++ b/git_theta/config.py @@ -28,8 +28,10 @@ def from_line(cls, line): return cls(pattern=match.group("pattern"), attributes=match.group("attributes")) def add_attribute(self, attribute): + # TODO(nkandpa2): It does not work to specify multiple filter/diff/merge attributes for one pattern so + # at some point we may have to do better parsing of the .gitattributes lines to avoid that if attribute not in self.attributes: - self.attributes = f"{self.attribute.rstrip()} {attribute}" + self.attributes = f"{self.attributes.rstrip()} {attribute}" def serialize(self): return f"{self.pattern} {self.attributes}" @@ -58,7 +60,7 @@ def read(cls, file: str): """ if os.path.exists(file): with open(file, "r") as f: - return [PatternAttributes.from_line(line) for line in f] + return [PatternAttributes.from_line(line) for line in f if line.strip()] else: return [] @@ -67,11 +69,10 @@ def write(self): Write list of attributes to this repo's .gitattributes file """ with open(self.file, "w") as f: - f.write( - "\n".join([pattern_attrs.serialize() for pattern_attrs in self.data]) - ) - # End file with newline. - f.write("\n") + f.write(self.serialize()) + + def serialize(self): + return "\n".join([pattern_attrs.serialize() for pattern_attrs in self.data]) def add_theta(self, pattern: str) -> List[str]: """Set filter, merge, and diff attributes to theta for `pattern`. @@ -85,9 +86,9 @@ def add_theta(self, pattern: str) -> List[str]: pattern_found = False for pattern_attrs in self.data: if pattern == pattern_attrs.pattern: - pattern_attrs.add("filter=theta") - pattern_attrs.add("merge=theta") - pattern_attrs.add("diff=theta") + pattern_attrs.add_attribute("filter=theta") + pattern_attrs.add_attribute("merge=theta") + pattern_attrs.add_attribute("diff=theta") pattern_found = True # If we don't find a matching pattern, add a new line that covers this pattern if not pattern_found: diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 0000000..5f511c8 --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,189 @@ +"""Tests for config.py""" + +import os + +import git +import pytest + +from git_theta import config + + +@pytest.fixture +def git_repo(): + cwd = os.getcwd() + repo_dir = os.path.abspath(".delete-me") + os.mkdir(repo_dir) + try: + os.chdir(repo_dir) + repo = git.Repo.init(repo_dir) + yield repo + finally: + os.chdir(cwd) + repo.close() + git.rmtree(repo_dir) + + +@pytest.fixture +def gitattributes(): + return [ + "*.pt filter=theta merge=theta diff=theta", + "*.png filter=lfs", + "really-big-file filter=lfs", + "something else, who knows how cool it could be", + ] + + +def test_add_theta_gitattributes_empty_file(git_repo): + gitattributes = config.GitAttributesFile(git_repo) + model_path = "example" + gitattributes.add_theta(model_path) + assert ( + gitattributes.serialize() == f"{model_path} filter=theta merge=theta diff=theta" + ) + + +def test_add_theta_gitattributes_no_match(git_repo): + with open(os.path.join(git_repo.working_dir, ".gitattributes"), "w") as f: + f.write( + "\n".join( + [ + "Some-other-path filter=lfs", + "*-cool-models.pt filter=theta merge=theta diff=theta", + ] + ) + ) + gitattributes = config.GitAttributesFile(git_repo) + + model_path = "path/to/my/model.pt" + gitattributes.add_theta(model_path) + assert ( + gitattributes.data[-1].serialize() + == f"{model_path} filter=theta merge=theta diff=theta" + ) + + +def test_add_theta_gitattributes_exact_match(git_repo): + model_path = "really/cool/model/yall.ckpt" + + with open(os.path.join(git_repo.working_dir, ".gitattributes"), "w") as f: + f.write(f"{model_path} ident") + gitattributes = config.GitAttributesFile(git_repo) + + gitattributes.add_theta(model_path) + + assert ( + len(gitattributes.data) == 1 + and gitattributes.data[0].serialize() + == f"{model_path} ident filter=theta merge=theta diff=theta" + ) + + +def test_add_theta_gitattributes_rest_unchanged(git_repo): + model_path = "model-v3.pt" + + atts = [ + "some-other-path filter=theta merge=theta diff=theta", + "really-reaaaally-big-files filter=lfs", + "another filter=theta merge=theta diff=theta", + ] + + with open(os.path.join(git_repo.working_dir, ".gitattributes"), "w") as f: + f.write("\n".join(atts)) + gitattributes = config.GitAttributesFile(git_repo) + + gitattributes.add_theta(model_path) + + for i, att in enumerate(atts): + assert att == gitattributes.data[i].serialize() + + +def test_read_gitattributes(gitattributes, tmp_path): + file = tmp_path / ".gitattributes" + with open(file, "w") as f: + f.write("\n".join(gitattributes)) + + attrs = config.GitAttributesFile.read(file) + for true_attr, attr in zip(gitattributes, attrs): + assert attr.serialize() == true_attr + + +def test_read_gitattributes_missing_file(tmp_path): + """Test that gitattributes file missing returns an empty list.""" + missing_file = tmp_path / ".gitattributes" + assert not os.path.exists(missing_file) + read_attributes = config.GitAttributesFile.read(missing_file) + assert read_attributes == [] + + +def test_read_gitattributes_empty_file(tmp_path): + """Test that gitattributes file being empty returns an empty list.""" + empty_file = tmp_path / ".gitattributes" + empty_file.touch() + assert os.path.exists(empty_file) + read_attributes = config.GitAttributesFile.read(empty_file) + assert read_attributes == [] + + +def test_read_gitattributes_empty_lines(gitattributes, tmp_path): + """Test that a gitattributes file with empty lines mixed in is read correctly""" + file = tmp_path / ".gitattributes" + with open(file, "w") as f: + f.write("\n\n".join(gitattributes)) + + attrs = config.GitAttributesFile.read(file) + for true_attr, attr in zip(gitattributes, attrs): + assert attr.serialize() == true_attr + + +def test_write_gitattributes(git_repo, gitattributes): + """Test that attributes are written to file unchanged""" + with open(".gitattributes", "w") as f: + f.write("\n".join(gitattributes)) + + ga = config.GitAttributesFile(git_repo) + os.remove(".gitattributes") + ga.write() + + with open(".gitattributes", "r") as f: + written_gitattributes = f.readlines() + + for attr, written_attr in zip(gitattributes, written_gitattributes): + assert attr == written_attr.rstrip() + + +def test_write_gitattributes_creates_file(git_repo): + """Make sure writing the git attributes can create the missing file before writing.""" + gitattributes_path = ".gitattributes" + assert not os.path.exists(gitattributes_path) + ga = config.GitAttributesFile(git_repo) + ga.add_theta("my_model") + ga.write() + assert os.path.exists(gitattributes_path) + + +def test_read_write_gitattributes_write_read_round_trip(git_repo, gitattributes): + """Test that we can write attributes, then read them back and they will match.""" + ga = config.GitAttributesFile(git_repo) + for line in gitattributes: + ga.data.append(config.PatternAttributes.from_line(line)) + ga.write() + + ga = config.GitAttributesFile(git_repo) + assert ga.serialize() == "\n".join(gitattributes) + + +def test_read_write_gitattributes_read_write_round_trip(git_repo, gitattributes): + """Test reading attrs from file, writing to new file and verify file contents match.""" + with open(".gitattributes", "w") as f: + f.write("\n".join(gitattributes)) + + ga = config.GitAttributesFile(git_repo) + os.remove(".gitattributes") + assert not os.path.exists(".gitattributes") + ga.write() + + with open(".gitattributes", "r") as f: + written_gitattributes = f.read().split("\n") + + for attr, written_attr in zip(gitattributes, written_gitattributes): + assert attr == written_attr diff --git a/tests/git_utils_test.py b/tests/git_utils_test.py deleted file mode 100644 index ad8e704..0000000 --- a/tests/git_utils_test.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Tests for git_utils.py""" - -import os - -import pytest - -from git_theta import git_utils - - -def test_add_theta_gitattributes_empty_file(): - assert git_utils.add_theta_to_gitattributes([], "example") == [ - "example filter=theta merge=theta diff=theta" - ] - - -def test_add_theta_gitattributes_no_match(): - atts = [ - "Some-other-path filter=lfs", - "*-cool-models.pt filter=theta merge=theta diff=theta", - ] - model_path = "path/to/my/model.pt" - assert ( - git_utils.add_theta_to_gitattributes(atts, model_path)[-1] - == f"{model_path} filter=theta merge=theta diff=theta" - ) - - -def test_add_theta_gitattributes_exact_match(): - model_path = "really/cool/model/yall.ckpt" - atts = [f"{model_path} filter=lfs"] - assert ( - git_utils.add_theta_to_gitattributes(atts, model_path)[-1] - == f"{model_path} filter=lfs filter=theta merge=theta diff=theta" - ) - - -def test_add_theta_gitattributes_pattern_match(): - model_path = "literal-the-best-checkpoint.pt" - atts = ["*.pt thing"] - assert ( - git_utils.add_theta_to_gitattributes(atts, model_path)[-1] - == f"*.pt thing filter=theta merge=theta diff=theta" - ) - - -def test_add_theta_gitattributes_multiple_matches(): - model_path = "100-on-mnist.npy" - atts = ["*.npy other-filter", f"{model_path} other-filter"] - assert git_utils.add_theta_to_gitattributes(atts, model_path) == [ - f"{attr} filter=theta merge=theta diff=theta" for attr in atts - ] - - -def test_add_theta_gitattributes_match_with_theta_already(): - model_path = "my-bad-model.chkp" - atts = ["my-*-model.chkp filter=theta merge=theta diff=theta"] - assert git_utils.add_theta_to_gitattributes(atts, model_path) == atts - - -def test_add_theta_gitattributes_rest_unchanged(): - model_path = "model-v3.pt" - atts = [ - "some-other-path filter=theta merge=theta diff=theta", - "really-reaaaally-big-files filter=lfs", - r"model-v\d.pt filter", - "another filter=theta merge=theta diff=theta", - ] - results = git_utils.add_theta_to_gitattributes(atts, model_path) - for i, (a, r) in enumerate(zip(atts, results)): - if i == 2: - continue - assert a == r - - -@pytest.fixture -def gitattributes(): - return [ - "*.pt filter=theta merge=theta diff=theta", - "*.png filter=lfs", - "really-big-file filter=lfs", - "something else, who knows how cool it could be", - ] - - -def test_read_gitattributes(gitattributes, tmp_path): - """Test that reading gitattributes removes newlines.""" - gitattributes_file = tmp_path / ".gitattributes" - with open(gitattributes_file, "w") as wf: - wf.write("\n".join(gitattributes)) - read_attributes = git_utils.read_gitattributes(gitattributes_file) - for attr in read_attributes: - assert not attr.endswith("\n") - - -def test_read_gitattributes_missing_file(tmp_path): - """Test that gitattributes file missing returns an empty list.""" - missing_file = tmp_path / ".gitattributes" - assert not os.path.exists(missing_file) - read_attributes = git_utils.read_gitattributes(missing_file) - assert read_attributes == [] - - -def test_read_gitattributes_empty_file(tmp_path): - """Test that gitattributes file being empty returns an empty list.""" - empty_file = tmp_path / ".gitattributes" - empty_file.touch() - assert os.path.exists(empty_file) - read_attributes = git_utils.read_gitattributes(empty_file) - assert read_attributes == [] - - -def test_write_gitattributes(gitattributes, tmp_path): - """Test that attributes are written to file unchanged and include newlines.""" - attr_file = tmp_path / ".gitattributes" - for attr in gitattributes: - assert not attr.endswith("\n") - git_utils.write_gitattributes(attr_file, gitattributes) - with open(attr_file) as wf: - written_attrs = wf.readlines() - # Check for the newlines which I purposely left on with my reading code. - for written_attr, attr in zip(written_attrs, gitattributes): - assert written_attr == f"{attr}\n" - - -def test_write_gitattributes_ends_in_newline(gitattributes, tmp_path): - """Make sure we have a final newline when writing out file.""" - attr_file = tmp_path / ".gitattributes" - git_utils.write_gitattributes(attr_file, gitattributes) - with open(attr_file) as f: - attrs = f.read() - assert attrs[-1] == "\n" - - -def test_write_gitattributes_creates_file(gitattributes, tmp_path): - """Make sure writing the git attributes can create the missing file before writing.""" - attr_file = tmp_path / ".gitattributes" - assert not os.path.exists(attr_file) - git_utils.write_gitattributes(attr_file, gitattributes) - assert os.path.exists(attr_file) - - -def test_read_write_gitattributes_write_read_round_trip(gitattributes, tmp_path): - """Test that we can write attributes, then read them back and they will match.""" - attr_file = tmp_path / ".gitattributes" - git_utils.write_gitattributes(attr_file, gitattributes) - read_attrs = git_utils.read_gitattributes(attr_file) - assert read_attrs == gitattributes - - -def test_read_write_gitattributes_read_write_round_trip(gitattributes, tmp_path): - """Test reading attrs from file, writing to new file and verify file contents match.""" - attr_file = tmp_path / ".gitattributes" - with open(attr_file, "w") as wf: - wf.writelines([f"{attr}\n" for attr in gitattributes]) - - new_attr_file = tmp_path / ".gitattributes-2" - read_attrs = git_utils.read_gitattributes(attr_file) - git_utils.write_gitattributes(new_attr_file, read_attrs) - - with open(attr_file) as old_f: - old_atts = old_f.read() - with open(new_attr_file) as new_f: - new_atts = new_f.read() - - assert old_atts == new_atts From 6b2193064185ee1add6473380d82347462b73441 Mon Sep 17 00:00:00 2001 From: Nikhil Kandpal Date: Thu, 4 May 2023 23:45:34 -0400 Subject: [PATCH 4/4] Fix checkpoint handler tests --- tests/checkpoints_test.py | 74 ++++------------------------- tests/config_test.py | 15 ------ tests/conftest.py | 15 ++++++ tests/tensorflow_checkpoint_test.py | 9 +++- 4 files changed, 32 insertions(+), 81 deletions(-) diff --git a/tests/checkpoints_test.py b/tests/checkpoints_test.py index a6be24a..43eb2c2 100644 --- a/tests/checkpoints_test.py +++ b/tests/checkpoints_test.py @@ -1,79 +1,25 @@ """Tests for checkpoints.py""" import os +import subprocess import pytest from git_theta import checkpoints -ENV_CHECKPOINT_TYPE = "GIT_THETA_CHECKPOINT_TYPE" - -pytest.importorskip("pytorch") - - -@pytest.fixture -def env_var(): - current_env = dict(os.environ) - os.environ[ENV_CHECKPOINT_TYPE] = "env_variable_handler" - - yield - os.environ.clear() - os.environ.update(current_env) - - -@pytest.fixture -def no_env_var(): - current_env = dict(os.environ) - os.environ.pop(ENV_CHECKPOINT_TYPE, None) - - yield - os.environ.clear() - os.environ.update(current_env) +torch = pytest.importorskip("torch") @pytest.fixture -def empty_env_var(): - current_env = dict(os.environ) - os.environ[ENV_CHECKPOINT_TYPE] = "" - - yield - os.environ.clear() - os.environ.update(current_env) - - -def test_get_checkpoint_handler_name_user_input(env_var): - """Check that function prefers user input to environment variable""" +def fake_model(): + return torch.nn.Sequential( + torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) + ) - user_input = "user_input_handler" - name = checkpoints.get_checkpoint_handler_name(user_input) - assert name == user_input - -def test_get_checkpoint_handler_name_env_variable(env_var): - """Check that function uses environment variable no user input specified""" - - name = checkpoints.get_checkpoint_handler_name() - assert name == "env_variable_handler" - - -def test_get_checkpoint_handler_name_default1(no_env_var): - """Check that function has correct default behavior with no user input and environment variable""" - - name = checkpoints.get_checkpoint_handler_name() - assert name == "pytorch" - - -def test_get_checkpoint_handler_name_default2(empty_env_var): - """Check that function has correct default behavior with no user input and environment variable is empty string""" - - name = checkpoints.get_checkpoint_handler_name() - assert name == "pytorch" - - -# TODO: Move this (and other pytorch checkpoint tests) to new file. Remove the -# importorskip too. -def test_get_checkpoint_handler_pytorch(no_env_var): +def test_get_checkpoint_handler_pytorch(git_repo, fake_model): """Check that checkpoint_handler type is correct for when checkpoint_handler name resolves to pytorch""" - - out = checkpoints.get_checkpoint_handler("pytorch") + torch.save(fake_model, "model.bin") + subprocess.run("git theta track model.bin --checkpoint_format pytorch".split(" ")) + out = checkpoints.get_checkpoint_handler("model.bin") assert out == checkpoints.pickled_dict_checkpoint.PickledDictCheckpoint diff --git a/tests/config_test.py b/tests/config_test.py index 5f511c8..24dce0b 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -8,21 +8,6 @@ from git_theta import config -@pytest.fixture -def git_repo(): - cwd = os.getcwd() - repo_dir = os.path.abspath(".delete-me") - os.mkdir(repo_dir) - try: - os.chdir(repo_dir) - repo = git.Repo.init(repo_dir) - yield repo - finally: - os.chdir(cwd) - repo.close() - git.rmtree(repo_dir) - - @pytest.fixture def gitattributes(): return [ diff --git a/tests/conftest.py b/tests/conftest.py index 0a510a6..982a166 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,6 +100,21 @@ def data_generator(): return DataGenerator +@pytest.fixture +def git_repo(): + cwd = os.getcwd() + repo_dir = os.path.abspath(".delete-me") + os.mkdir(repo_dir) + try: + os.chdir(repo_dir) + repo = git.Repo.init(repo_dir) + yield repo + finally: + os.chdir(cwd) + repo.close() + git.rmtree(repo_dir) + + @pytest.fixture def git_repo_with_commits(): commit_infos = [ diff --git a/tests/tensorflow_checkpoint_test.py b/tests/tensorflow_checkpoint_test.py index 3cd1737..6cbb4d9 100644 --- a/tests/tensorflow_checkpoint_test.py +++ b/tests/tensorflow_checkpoint_test.py @@ -1,6 +1,7 @@ """Tensorflow checkpoint tests.""" import os +import subprocess import tempfile from unittest import mock @@ -93,6 +94,10 @@ def test_round_trip_with_modifications(fake_model): np.testing.assert_allclose(og.numpy(), new.numpy()) -def test_get_checkpoint_handler_tensorflow(): - out = checkpoints.get_checkpoint_handler("tensorflow-checkpoint") +def test_get_checkpoint_handler_tensorflow(git_repo, fake_model): + fake_model.save_weights("model.bin") + subprocess.run( + "git theta track model.bin --checkpoint_format tensorflow".split(" ") + ) + out = checkpoints.get_checkpoint_handler("model.bin") assert out == tensorflow_checkpoint.TensorFlowCheckpoint