Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 19 additions & 103 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,42 +936,12 @@ class runopt:
Represents the metadata about the specific run option
"""

class AutoAlias(IntEnum):
snake_case = 0x1
SNAKE_CASE = 0x2
camelCase = 0x4

@staticmethod
def convert_to_camel_case(alias: str) -> str:
words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias)
words = [w for w in words if w] # Remove empty strings
if not words:
return ""
return words[0].lower() + "".join(w.capitalize() for w in words[1:])

@staticmethod
def convert_to_snake_case(alias: str) -> str:
alias = re.sub(r"[-\s]+", "_", alias)
alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias)
alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias)
return alias.lower()

@staticmethod
def convert_to_const_case(alias: str) -> str:
return runopt.AutoAlias.convert_to_snake_case(alias).upper()

class alias(str):
pass

class deprecated(str):
pass

default: CfgVal
opt_type: Type[CfgVal]
is_required: bool
help: str
aliases: set[alias] | None = None
deprecated_aliases: set[deprecated] | None = None
aliases: list[str] | None = None
deprecated_aliases: list[str] | None = None

@property
def is_type_list_of_str(self) -> bool:
Expand Down Expand Up @@ -1257,85 +1227,23 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
cfg[key] = val
return cfg

def _generate_aliases(
self, auto_alias: int, aliases: set[str]
) -> set[runopt.alias]:
generated_aliases = set()
for alias in aliases:
if auto_alias & runopt.AutoAlias.camelCase:
generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias))
if auto_alias & runopt.AutoAlias.snake_case:
generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias))
if auto_alias & runopt.AutoAlias.SNAKE_CASE:
generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias))
return generated_aliases

def _get_primary_key_and_aliases(
self,
cfg_key: list[str | int] | str,
) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]:
"""
Returns the primary key and aliases for the given cfg_key.
"""
if isinstance(cfg_key, str):
return cfg_key, set(), set()

if len(cfg_key) == 0:
raise ValueError("cfg_key must be a non-empty list")

if isinstance(cfg_key[0], runopt.alias) or isinstance(
cfg_key[0], runopt.deprecated
):
warnings.warn(
"The main name of the run option should be the head of the list.",
UserWarning,
stacklevel=2,
)
primary_key = None
auto_alias = 0x0
aliases = set[runopt.alias]()
deprecated_aliases = set[runopt.deprecated]()
for name in cfg_key:
if isinstance(name, runopt.alias):
aliases.add(name)
elif isinstance(name, runopt.deprecated):
deprecated_aliases.add(name)
elif isinstance(name, int):
auto_alias = auto_alias | name
else:
if primary_key is not None:
raise ValueError(
f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
)
primary_key = name
if primary_key is None or primary_key == "":
raise ValueError(
"Missing cfg_key. Please provide one other than the aliases."
)
if auto_alias != 0x0:
aliases_to_generate_for = aliases | {primary_key}
additional_aliases = self._generate_aliases(
auto_alias, aliases_to_generate_for
)
aliases.update(additional_aliases)
return primary_key, aliases, deprecated_aliases

def add(
self,
cfg_key: str | list[str | int],
cfg_key: str,
type_: Type[CfgVal],
help: str,
default: CfgVal = None,
required: bool = False,
aliases: Optional[list[str]] = None,
deprecated_aliases: Optional[list[str]] = None,
) -> None:
"""
Adds the ``config`` option with the given help string and ``default``
value (if any). If the ``default`` is not specified then this option
is a required option.
"""
primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases(
cfg_key
)
aliases = aliases or []
deprecated_aliases = deprecated_aliases or []
if required and default is not None:
raise ValueError(
f"Required option: {cfg_key} must not specify default value. Given: {default}"
Expand All @@ -1346,12 +1254,20 @@ def add(
f"Option: {cfg_key}, must be of type: {type_}."
f" Given: {default} ({type(default).__name__})"
)
opt = runopt(default, type_, required, help, aliases, deprecated_aliases)

opt = runopt(
default,
type_,
required,
help,
list(set(aliases)),
list(set(deprecated_aliases)),
)
for alias in aliases:
self._alias_to_key[alias] = primary_key
self._alias_to_key[alias] = cfg_key
for deprecated_alias in deprecated_aliases:
self._alias_to_key[deprecated_alias] = primary_key
self._opts[primary_key] = opt
self._alias_to_key[deprecated_alias] = cfg_key
self._opts[cfg_key] = opt

def update(self, other: "runopts") -> None:
self._opts.update(other._opts)
Expand Down
64 changes: 20 additions & 44 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,8 @@ def test_runopts_add(self) -> None:
def test_runopts_add_with_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
"job_priority",
aliases=["jobPriority"],
type_=str,
help="priority for the job",
)
Expand All @@ -616,7 +617,8 @@ def test_runopts_add_with_aliases(self) -> None:
def test_runopts_resolve_with_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
"job_priority",
aliases=["jobPriority"],
type_=str,
help="priority for the job",
)
Expand All @@ -628,71 +630,45 @@ def test_runopts_resolve_with_aliases(self) -> None:
def test_runopts_resolve_with_none_valued_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
"job_priority",
aliases=["jobPriority"],
type_=str,
help="priority for the job",
)
opts.add(
["modelTypeName", runopt.alias("model_type_name")],
"model_type_name",
aliases=["modelTypeName"],
type_=Union[str, None],
help="ML Hub Model Type to attribute resource utilization for job",
)
resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"})
self.assertEqual(resolved_opts.get("model_type_name"), None)
resolved_opts = opts.resolve({"modelTypeName": None, "jobPriority": "low"})
self.assertEqual(resolved_opts.get("modelTypeName"), None)
self.assertEqual(resolved_opts.get("jobPriority"), "low")
self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"})
self.assertEqual(resolved_opts, {"modelTypeName": None, "jobPriority": "low"})

with self.assertRaises(InvalidRunConfigException):
opts.resolve({"model_type_name": None, "modelTypeName": "low"})
opts.resolve({"modelTypeName": None, "model_type_name": "low"})

def test_runopts_add_with_deprecated_aliases(self) -> None:
opts = runopts()
with warnings.catch_warnings(record=True) as w:
opts.add(
[runopt.deprecated("jobPriority"), "job_priority"],
type_=str,
help="run as user",
)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"The main name of the run option should be the head of the list.",
)
opts.add(
"job_priority",
deprecated_aliases=["priority"],
type_=str,
help="run as user",
)

opts.resolve({"job_priority": "high"})
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
opts.resolve({"jobPriority": "high"})
opts.resolve({"priority": "high"})
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"Run option `jobPriority` is deprecated, use `job_priority` instead",
"Run option `priority` is deprecated, use `job_priority` instead",
)

def test_runopt_auto_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.AutoAlias.camelCase],
type_=str,
help="run as user",
)
opts.add(
[
"model_type_name",
runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE,
],
type_=str,
help="run as user",
)
self.assertEqual(2, len(opts._opts))
self.assertIsNotNone(opts.get("job_priority"))
self.assertIsNotNone(opts.get("jobPriority"))
self.assertIsNotNone(opts.get("model_type_name"))
self.assertIsNotNone(opts.get("modelTypeName"))
self.assertIsNotNone(opts.get("MODEL_TYPE_NAME"))

def get_runopts(self) -> runopts:
opts = runopts()
opts.add("run_as", type_=str, help="run as user", required=True)
Expand Down
Loading