Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_available_datasets() -> list[str]:


def get_available_detectors(dataset: str) -> list[str]:
""" Only for PyTorch models.
"""Only for PyTorch models.

Returns:
The detectors available for the dataset.
Expand All @@ -103,7 +103,7 @@ def get_available_detectors(dataset: str) -> list[str]:


def get_available_models(dataset: str) -> list[str]:
""" Only for PyTorch models.
"""Only for PyTorch models.

Returns:
The pose models available for the dataset.
Expand Down Expand Up @@ -139,19 +139,39 @@ def download_huggingface_model(
model_name: str,
target_dir: str = ".",
remove_hf_folder: bool = True,
rename_mapping: dict | None = None,
rename_mapping: str | dict | None = None,
):
"""
Downloads a DeepLabCut Model Zoo Project from Hugging Face.

Args:
model_name (str): Name of the ModelZoo model.
model_name (str):
Name of the ModelZoo model.
For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
after downloading and decompressing the data into DeepLabCut format. Defaults to True.
rename_mapping (dict, optional): A dictionary to rename the downloaded file.
If None, the original filename is used. Defaults to None.
target_dir (str, optional):
Target directory where the model weights will be stored.
Defaults to the current directory.
remove_hf_folder (bool, optional):
Whether to remove the directory structure created by HuggingFace
after downloading and decompressing the data into DeepLabCut format.
Defaults to True.
rename_mapping (dict | str | None, optional):
- If a dictionary, it should map the original Hugging Face filenames
to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
- If a string, it is interpreted as the new name for the downloaded file
- If None, the original filename is used.
Defaults to None.

Examples:
>>> # Download without renaming, keep original filename
download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)

>>> # Download and rename by specifying the new name directly
download_huggingface_model(
model_name="superanimal_humanbody_rtmpose_x",
target_dir="/path/to/,y/checkpoints",
rename_mapping="superanimal_humanbody_rtmpose_x.pt"
)
"""
net_urls = _load_model_names()
if model_name not in net_urls:
Expand Down Expand Up @@ -180,6 +200,10 @@ def download_huggingface_model(
path_ = os.path.join(target_dir, hf_folder, "snapshots")
commit = os.listdir(path_)[0]
file_name = os.path.join(path_, commit, targzfn)

if isinstance(rename_mapping, str):
rename_mapping = {targzfn: rename_mapping}

_handle_downloaded_file(file_name, target_dir, rename_mapping)

if remove_hf_folder:
Expand Down