Skip to content

Repo-level and pattern-level configuration in .thetaconfig #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions git_theta/checkpoints/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base class and utilities for different checkpoint format backends."""

import fnmatch
import os
import sys
from abc import ABCMeta, abstractmethod
Expand All @@ -12,7 +13,7 @@
else:
from importlib.metadata import entry_points

from git_theta import utils
from git_theta import config, git_utils, utils


@utils.abstract_classattributes("name")
Expand Down Expand Up @@ -102,49 +103,42 @@ def diff(cls, m1: "Checkpoint", m2: "Checkpoint") -> "Checkpoint":
return added, removed, modified


def get_checkpoint_handler_name(checkpoint_type: Optional[str] = None) -> str:
"""Get the name of the checkpoint handler to use.

Order of precedence is
1. `checkpoint_type` argument
2. `$GIT_THETA_CHECKPOINT_TYPE` environment variable
3. default value (currently pytorch)
def get_checkpoint_handler_name(checkpoint_path: str) -> Optional[str]:
"""Get the name of the checkpoint handler based on entry in .thetaconfig for the current checkpoint path

Parameters
----------
checkpoint_type
Name of the checkpoint handler
checkpoint_path
Path to checkpoint in repo

Returns
-------
str
Name of the checkpoint handler
"""
# TODO(bdlester): Find a better way to include checkpoint type information
# in git clean filters that are run without `git theta add`.
# TODO: Don't default to pytorch once other checkpoint formats are supported.
return checkpoint_type or utils.EnvVarConstants.CHECKPOINT_TYPE
repo = git_utils.get_git_repo()
checkpoint_path = git_utils.get_relative_path_from_root(repo, checkpoint_path)
thetaconfig = config.ThetaConfigFile(repo)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to result in reading the config file a bunch of time yeah? Should we refactor to something like the config file is read once and then the config object is used?

checkpoint_config = thetaconfig.get_config(checkpoint_path)
return checkpoint_config.get("checkpoint_format", None)


def get_checkpoint_handler(checkpoint_type: Optional[str] = None) -> Checkpoint:
"""Get the checkpoint handler either by name or from an environment variable.
def get_checkpoint_handler(checkpoint_path: str) -> Checkpoint:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function is doing too much now, based on SoC, this function should really be about translating a checkpoint name into a class. We should probably have the path -> checkpoint type be done externally and then the result is passed in.

"""Get the checkpoint handler for the current checkpoint path

Gets the checkpoint handler either for the `checkpoint_type` argument or
`$GIT_THETA_CHECKPOINT_TYPE` environment variable.

Defaults to pytorch when neither are defined.
Gets the checkpoint handler from an entry in .thetaconfig

Parameters
----------
checkpoint_type
Name of the checkpoint handler
checkpoint_path
Path to the checkpoint

Returns
-------
Checkpoint
The checkpoint handler (usually an instance of `git_theta.checkpoints.Checkpoint`).
Returned handler may be defined in a user installed plugin.
"""
checkpoint_type = get_checkpoint_handler_name(checkpoint_type)
checkpoint_type = get_checkpoint_handler_name(checkpoint_path)
discovered_plugins = entry_points(group="git_theta.plugins.checkpoints")
return discovered_plugins[checkpoint_type].load()
199 changes: 199 additions & 0 deletions git_theta/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import dataclasses
import fnmatch
import json
import os
import re
from collections import OrderedDict
from typing import Dict, List, Tuple

import git

from git_theta import git_utils


@dataclasses.dataclass
class PatternAttributes:
pattern: str
attributes: str

@classmethod
def from_line(cls, line):
# TODO(bdlester): Revisit this regex to see if it when the pattern
# is escaped due to having spaces in it.
match = re.match(
r"^\s*(?P<pattern>[^\s]+)\s+(?P<attributes>.*)$", line.rstrip()
)
if not match:
raise ValueError(f"{line} is an invalid git attribute line")
return cls(pattern=match.group("pattern"), attributes=match.group("attributes"))

def add_attribute(self, attribute):
# TODO(nkandpa2): It does not work to specify multiple filter/diff/merge attributes for one pattern so
# at some point we may have to do better parsing of the .gitattributes lines to avoid that
if attribute not in self.attributes:
self.attributes = f"{self.attributes.rstrip()} {attribute}"

def serialize(self):
return f"{self.pattern} {self.attributes}"


class GitAttributesFile:
def __init__(self, repo: git.Repo):
self.repo = repo
self.file = os.path.join(repo.working_dir, ".gitattributes")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as below wrt to holding onto the file name.

self.data = GitAttributesFile.read(self.file)

@classmethod
def read(cls, file: str):
"""
Read contents of this repo's .gitattributes file

Parameters
----------
file
Path to .gitattributes file

Returns
-------
List[str]
lines in .gitattributes file
"""
if os.path.exists(file):
with open(file, "r") as f:
return [PatternAttributes.from_line(line) for line in f if line.strip()]
else:
return []

def write(self):
"""
Write list of attributes to this repo's .gitattributes file
"""
with open(self.file, "w") as f:
f.write(self.serialize())

def serialize(self):
return "\n".join([pattern_attrs.serialize() for pattern_attrs in self.data])

def add_theta(self, pattern: str) -> List[str]:
"""Set filter, merge, and diff attributes to theta for `pattern`.

Parameters
----------
pattern:
The pattern we are adding theta attributes for
"""
pattern = git_utils.get_relative_path_from_root(self.repo, pattern)
pattern_found = False
for pattern_attrs in self.data:
if pattern == pattern_attrs.pattern:
pattern_attrs.add_attribute("filter=theta")
pattern_attrs.add_attribute("merge=theta")
pattern_attrs.add_attribute("diff=theta")
pattern_found = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment about how we don't have a break here because gitattribute shadow behavior means we want to update every line where the pattern appears?

# If we don't find a matching pattern, add a new line that covers this pattern
if not pattern_found:
self.data.append(
PatternAttributes.from_line(
f"{pattern} filter=theta merge=theta diff=theta"
)
)

def is_theta_tracked(self, path: str) -> bool:
return any(
[
fnmatch.fnmatchcase(path, pattern_attr.pattern)
and "filter=theta" in pattern_attr.attributes
for pattern_attr in self.data
]
)


@dataclasses.dataclass
class Config:
@classmethod
def from_dict(cls, params: Dict) -> "Config":
fields = [field.name for field in dataclasses.fields(cls)]
return cls(**{k: v for k, v in params.items() if k in fields})

def serialize(self) -> Dict:
return dataclasses.asdict(self, dict_factory=OrderedDict)


@dataclasses.dataclass
class PatternConfig(Config):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of the format being

{
  ....
  "patterns": [
    {
      "pattern": str
      "checkpoint_format": str
    },
    ...
  ]
}

How about we do something like

{
  ....
  "patterns": {
      "pattern": {
        "checkpoint_format": str
      },
      ...
  }
}

We are on python 3.7+ so Dicts are ordered (thus we can copy gitattributes shadow behavior) and we can just return the config dataclass instead of having to fiddle with it to turn it into a dict.

pattern: str
checkpoint_format: str


@dataclasses.dataclass
class RepoConfig(Config):
parameter_atol: float = 1e-8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to think about nesting this a bit more? Like should we group the LSH params into one object? The downside is it gets more complex to make a config file and work with it

parameter_rtol: float = 1e-5
lsh_signature_size: int = 16
lsh_threshold: float = 1e-6
lsh_pool_size: int = 10_000
max_concurrency: int = -1


class ThetaConfigFile:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to below, we probably want a name like ThetaConfig without the "File" part as this object is really about the configuration of git-theta, not about the actual file it lives in

def __init__(self, repo):
self.repo = repo
self.file = os.path.join(repo.working_dir, ".thetaconfig")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we expecting this object to write the file multiple times? If not it seems like it would be a better match with things like the Checkpoint API to have write take a file/path and not have to hold onto this location.

self.repo_config, self.pattern_configs = ThetaConfigFile.read(self.file)

@classmethod
def read(cls, file) -> Tuple[RepoConfig, List[PatternConfig]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like with this approach we are loosing the ability to tweak values from and EnvVar at runtime? We definitely want to be able to do that for things like max concurrency and I could see wanting to do that with other things too,

"""
Read contents of this repo's .thetaconfig file

Returns
-------
repo_config: RepoConfig
Repository-level configuration parameters

pattern_configs: List[PatternConfig]
Pattern-level configuration parameters
"""
if os.path.exists(file):
with open(file, "r") as f:
config = json.load(f)
else:
config = {"repo": {}, "patterns": []}

repo_config = RepoConfig.from_dict(config["repo"])
pattern_configs = [PatternConfig.from_dict(d) for d in config["patterns"]]
return repo_config, pattern_configs

def write(self):
"""
Write a dictionary config to this repo's .thetaconfig file

Parameters
----------
config_file:
Path to this repo's .thetaconfig file

config:
Configuration dictionary to write to .thetaconfig
"""
with open(self.file, "w") as f:
json.dump(self.serialize(), f, indent=4)

def serialize(self):
repo_config = self.repo_config.serialize()
pattern_configs = [pc.serialize() for pc in self.pattern_configs]
config = {"repo": repo_config, "patterns": pattern_configs}
return config

def add_pattern(self, config: Dict):
pattern_config = PatternConfig.from_dict(config)
self.pattern_configs.append(pattern_config)

def get_config(self, path: str) -> Dict:
config = {}
for pc in self.pattern_configs:
if fnmatch.fnmatchcase(path, pc.pattern):
config.update(pc.serialize())
if "pattern" in config:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.pop can take a default arg (that is returned if the key doesn't match) so you don't have to do this check.

_ = config.pop("pattern", None)

config.pop("pattern")
return config
Loading