-
Notifications
You must be signed in to change notification settings - Fork 8
Repo-level and pattern-level configuration in .thetaconfig #214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
"""Base class and utilities for different checkpoint format backends.""" | ||
|
||
import fnmatch | ||
import os | ||
import sys | ||
from abc import ABCMeta, abstractmethod | ||
|
@@ -12,7 +13,7 @@ | |
else: | ||
from importlib.metadata import entry_points | ||
|
||
from git_theta import utils | ||
from git_theta import config, git_utils, utils | ||
|
||
|
||
@utils.abstract_classattributes("name") | ||
|
@@ -102,49 +103,42 @@ def diff(cls, m1: "Checkpoint", m2: "Checkpoint") -> "Checkpoint": | |
return added, removed, modified | ||
|
||
|
||
def get_checkpoint_handler_name(checkpoint_type: Optional[str] = None) -> str: | ||
"""Get the name of the checkpoint handler to use. | ||
|
||
Order of precedence is | ||
1. `checkpoint_type` argument | ||
2. `$GIT_THETA_CHECKPOINT_TYPE` environment variable | ||
3. default value (currently pytorch) | ||
def get_checkpoint_handler_name(checkpoint_path: str) -> Optional[str]: | ||
"""Get the name of the checkpoint handler based on entry in .thetaconfig for the current checkpoint path | ||
|
||
Parameters | ||
---------- | ||
checkpoint_type | ||
Name of the checkpoint handler | ||
checkpoint_path | ||
Path to checkpoint in repo | ||
|
||
Returns | ||
------- | ||
str | ||
Name of the checkpoint handler | ||
""" | ||
# TODO(bdlester): Find a better way to include checkpoint type information | ||
# in git clean filters that are run without `git theta add`. | ||
# TODO: Don't default to pytorch once other checkpoint formats are supported. | ||
return checkpoint_type or utils.EnvVarConstants.CHECKPOINT_TYPE | ||
repo = git_utils.get_git_repo() | ||
checkpoint_path = git_utils.get_relative_path_from_root(repo, checkpoint_path) | ||
thetaconfig = config.ThetaConfigFile(repo) | ||
checkpoint_config = thetaconfig.get_config(checkpoint_path) | ||
return checkpoint_config.get("checkpoint_format", None) | ||
|
||
|
||
def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint: | ||
"""Get the checkpoint handler either by name or from an environment variable. | ||
def get_checkpoint_handler(checkpoint_path: str) -> Checkpoint: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this function is doing too much now, based on SoC, this function should really be about translating a checkpoint name into a class. We should probably have the path -> checkpoint type be done externally and then the result is passed in. |
||
"""Get the checkpoint handler for the current checkpoint path | ||
|
||
Gets the checkpoint handler either for the `checkpoint_type` argument or | ||
`$GIT_THETA_CHECKPOINT_TYPE` environment variable. | ||
|
||
Defaults to pytorch when neither are defined. | ||
Gets the checkpoint handler from an entry in .thetaconfig | ||
|
||
Parameters | ||
---------- | ||
checkpoint_type | ||
Name of the checkpoint handler | ||
checkpoint_path | ||
Path to the checkpoint | ||
|
||
Returns | ||
------- | ||
Checkpoint | ||
The checkpoint handler (usually an instance of `git_theta.checkpoints.Checkpoint`). | ||
Returned handler may be defined in a user installed plugin. | ||
""" | ||
checkpoint_type = get_checkpoint_handler_name(checkpoint_type) | ||
checkpoint_type = get_checkpoint_handler_name(checkpoint_path) | ||
discovered_plugins = entry_points(group="git_theta.plugins.checkpoints") | ||
return discovered_plugins[checkpoint_type].load() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import dataclasses | ||
import fnmatch | ||
import json | ||
import os | ||
import re | ||
from collections import OrderedDict | ||
from typing import Dict, List, Tuple | ||
|
||
import git | ||
|
||
from git_theta import git_utils | ||
|
||
|
||
@dataclasses.dataclass | ||
class PatternAttributes: | ||
pattern: str | ||
attributes: str | ||
|
||
@classmethod | ||
def from_line(cls, line): | ||
# TODO(bdlester): Revisit this regex to see if it when the pattern | ||
# is escaped due to having spaces in it. | ||
match = re.match( | ||
r"^\s*(?P<pattern>[^\s]+)\s+(?P<attributes>.*)$", line.rstrip() | ||
) | ||
if not match: | ||
raise ValueError(f"{line} is an invalid git attribute line") | ||
return cls(pattern=match.group("pattern"), attributes=match.group("attributes")) | ||
|
||
def add_attribute(self, attribute): | ||
# TODO(nkandpa2): It does not work to specify multiple filter/diff/merge attributes for one pattern so | ||
# at some point we may have to do better parsing of the .gitattributes lines to avoid that | ||
if attribute not in self.attributes: | ||
self.attributes = f"{self.attributes.rstrip()} {attribute}" | ||
|
||
def serialize(self): | ||
return f"{self.pattern} {self.attributes}" | ||
|
||
|
||
class GitAttributesFile: | ||
def __init__(self, repo: git.Repo): | ||
self.repo = repo | ||
self.file = os.path.join(repo.working_dir, ".gitattributes") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question as below wrt to holding onto the file name. |
||
self.data = GitAttributesFile.read(self.file) | ||
|
||
@classmethod | ||
def read(cls, file: str): | ||
""" | ||
Read contents of this repo's .gitattributes file | ||
|
||
Parameters | ||
---------- | ||
file | ||
Path to .gitattributes file | ||
|
||
Returns | ||
------- | ||
List[str] | ||
lines in .gitattributes file | ||
""" | ||
if os.path.exists(file): | ||
with open(file, "r") as f: | ||
return [PatternAttributes.from_line(line) for line in f if line.strip()] | ||
else: | ||
return [] | ||
|
||
def write(self): | ||
""" | ||
Write list of attributes to this repo's .gitattributes file | ||
""" | ||
with open(self.file, "w") as f: | ||
f.write(self.serialize()) | ||
|
||
def serialize(self): | ||
return "\n".join([pattern_attrs.serialize() for pattern_attrs in self.data]) | ||
|
||
def add_theta(self, pattern: str) -> List[str]: | ||
"""Set filter, merge, and diff attributes to theta for `pattern`. | ||
|
||
Parameters | ||
---------- | ||
pattern: | ||
The pattern we are adding theta attributes for | ||
""" | ||
pattern = git_utils.get_relative_path_from_root(self.repo, pattern) | ||
pattern_found = False | ||
for pattern_attrs in self.data: | ||
if pattern == pattern_attrs.pattern: | ||
pattern_attrs.add_attribute("filter=theta") | ||
pattern_attrs.add_attribute("merge=theta") | ||
pattern_attrs.add_attribute("diff=theta") | ||
pattern_found = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a comment about how we don't have a |
||
# If we don't find a matching pattern, add a new line that covers this pattern | ||
if not pattern_found: | ||
self.data.append( | ||
PatternAttributes.from_line( | ||
f"{pattern} filter=theta merge=theta diff=theta" | ||
) | ||
) | ||
|
||
def is_theta_tracked(self, path: str) -> bool: | ||
return any( | ||
[ | ||
fnmatch.fnmatchcase(path, pattern_attr.pattern) | ||
and "filter=theta" in pattern_attr.attributes | ||
for pattern_attr in self.data | ||
] | ||
) | ||
|
||
|
||
@dataclasses.dataclass | ||
class Config: | ||
@classmethod | ||
def from_dict(cls, params: Dict) -> "Config": | ||
fields = [field.name for field in dataclasses.fields(cls)] | ||
return cls(**{k: v for k, v in params.items() if k in fields}) | ||
|
||
def serialize(self) -> Dict: | ||
return dataclasses.asdict(self, dict_factory=OrderedDict) | ||
|
||
|
||
@dataclasses.dataclass | ||
class PatternConfig(Config): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of the format being {
....
"patterns": [
{
"pattern": str
"checkpoint_format": str
},
...
]
} How about we do something like {
....
"patterns": {
"pattern": {
"checkpoint_format": str
},
...
}
} We are on python 3.7+ so Dicts are ordered (thus we can copy gitattributes shadow behavior) and we can just return the config dataclass instead of having to fiddle with it to turn it into a dict. |
||
pattern: str | ||
checkpoint_format: str | ||
|
||
|
||
@dataclasses.dataclass | ||
class RepoConfig(Config): | ||
parameter_atol: float = 1e-8 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to think about nesting this a bit more? Like should we group the LSH params into one object? The downside is it gets more complex to make a config file and work with it |
||
parameter_rtol: float = 1e-5 | ||
lsh_signature_size: int = 16 | ||
lsh_threshold: float = 1e-6 | ||
lsh_pool_size: int = 10_000 | ||
max_concurrency: int = -1 | ||
|
||
|
||
class ThetaConfigFile: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to below, we probably want a name like |
||
def __init__(self, repo): | ||
self.repo = repo | ||
self.file = os.path.join(repo.working_dir, ".thetaconfig") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we expecting this object to write the file multiple times? If not it seems like it would be a better match with things like the Checkpoint API to have |
||
self.repo_config, self.pattern_configs = ThetaConfigFile.read(self.file) | ||
|
||
@classmethod | ||
def read(cls, file) -> Tuple[RepoConfig, List[PatternConfig]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like with this approach we are loosing the ability to tweak values from and EnvVar at runtime? We definitely want to be able to do that for things like max concurrency and I could see wanting to do that with other things too, |
||
""" | ||
Read contents of this repo's .thetaconfig file | ||
|
||
Returns | ||
------- | ||
repo_config: RepoConfig | ||
Repository-level configuration parameters | ||
|
||
pattern_configs: List[PatternConfig] | ||
Pattern-level configuration parameters | ||
""" | ||
if os.path.exists(file): | ||
with open(file, "r") as f: | ||
config = json.load(f) | ||
else: | ||
config = {"repo": {}, "patterns": []} | ||
|
||
repo_config = RepoConfig.from_dict(config["repo"]) | ||
pattern_configs = [PatternConfig.from_dict(d) for d in config["patterns"]] | ||
return repo_config, pattern_configs | ||
|
||
def write(self): | ||
""" | ||
Write a dictionary config to this repo's .thetaconfig file | ||
|
||
Parameters | ||
---------- | ||
config_file: | ||
Path to this repo's .thetaconfig file | ||
|
||
config: | ||
Configuration dictionary to write to .thetaconfig | ||
""" | ||
with open(self.file, "w") as f: | ||
json.dump(self.serialize(), f, indent=4) | ||
|
||
def serialize(self): | ||
repo_config = self.repo_config.serialize() | ||
pattern_configs = [pc.serialize() for pc in self.pattern_configs] | ||
config = {"repo": repo_config, "patterns": pattern_configs} | ||
return config | ||
|
||
def add_pattern(self, config: Dict): | ||
pattern_config = PatternConfig.from_dict(config) | ||
self.pattern_configs.append(pattern_config) | ||
|
||
def get_config(self, path: str) -> Dict: | ||
config = {} | ||
for pc in self.pattern_configs: | ||
if fnmatch.fnmatchcase(path, pc.pattern): | ||
config.update(pc.serialize()) | ||
if "pattern" in config: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
_ = config.pop("pattern", None) |
||
config.pop("pattern") | ||
return config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to result in reading the config file a bunch of time yeah? Should we refactor to something like the config file is read once and then the config object is used?