Merged
nemo_automodel/_transformers/auto_model.py: 31 additions & 19 deletions

@@ -181,6 +181,7 @@ def from_pretrained(
         torch_dtype="auto",
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
+        force_hf: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -210,6 +211,8 @@ def from_pretrained(
             quantization_config (optional): BitsAndBytesConfig configuration object that
                 specifies all quantization settings. If provided, quantization
                 will be applied to the model.
+            force_hf (bool, default=False): If `True`, force the use of the HF model implementation.
+                If `False`, the model is loaded with the custom model implementation when one is available.
             **kwargs: Additional keyword arguments forwarded verbatim to
                 `AutoModelForCausalLM.from_pretrained`.

@@ -254,17 +257,20 @@ def _retry(**override):
         name = cls.__name__
         if name.startswith("NeMo"):
             cls.__name__ = name[4:]
-        try:
-            config = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, trust_remote_code=bool(kwargs.get("trust_remote_code", False))
-            )
-            # if we have a custom model implementation available, we prioritize that over HF
-            if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
-                model = ModelRegistry.model_arch_name_to_cls[config.architectures[0]](config, *model_args, **kwargs)
-                logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                return model
-        except Exception as e:
-            logger.error(f"Failed to use custom model implementation with error: {e}")
+        if not force_hf:
+            try:
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path, trust_remote_code=bool(kwargs.get("trust_remote_code", False))
+                )
+                # if we have a custom model implementation available, we prioritize that over HF
+                if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
+                    model = ModelRegistry.model_arch_name_to_cls[config.architectures[0]](
+                        config, *model_args, **kwargs
+                    )
+                    logger.info(f"Using custom model implementation for {config.architectures[0]}")
+                    return model
+            except Exception as e:
+                logger.error(f"Failed to use custom model implementation with error: {e}")

         if quantization_config is not None:
             kwargs["quantization_config"] = quantization_config
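Taken together, the three `from_pretrained` hunks above change the dispatch order: unless `force_hf=True`, the loader first resolves `config.architectures[0]` against `ModelRegistry.model_arch_name_to_cls` and falls back to the stock Hugging Face path only when no custom implementation is registered or the attempt raises. A minimal usage sketch; the class name `NeMoAutoModelForCausalLM`, the import path, and the checkpoint id are illustrative assumptions, not taken from this diff:

    from nemo_automodel import NeMoAutoModelForCausalLM  # import path assumed

    # Default (force_hf=False): a registered custom implementation,
    # if one exists for the checkpoint's architecture, wins over HF.
    model = NeMoAutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

    # New in this PR: skip the registry probe and always load the HF implementation.
    hf_model = NeMoAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        force_hf=True,
    )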
@@ -317,6 +323,7 @@ def from_config(
         torch_dtype: Union[str, torch.dtype] = "auto",
         attn_implementation: str = "flash_attention_2",
         quantization_config=None,
+        force_hf: bool = False,
         **kwargs,
     ) -> PreTrainedModel:
         """
@@ -344,6 +351,8 @@ def from_config(
                 ``"flash_attention_2"``, ``"eager"``). Only applied when the
                 base model supports this kwarg. Defaults to
                 ``"flash_attention_2"``.
+            force_hf (bool, default=False): If `True`, force the use of the HF model implementation.
+                If `False`, the model is loaded with the custom model implementation when one is available.
             **kwargs:
                 Additional keyword arguments forwarded to the superclass
                 constructor and underlying ``from_config`` logic.
@@ -382,14 +391,17 @@ def _retry(**override):
         name = cls.__name__
         if name.startswith("NeMo"):
             cls.__name__ = name[4:]
-        try:
-            # if we have a custom model implementation available, we prioritize that over HF
-            if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
-                model = ModelRegistry.model_arch_name_to_cls[config.architectures[0]](config, *model_args, **kwargs)
-                logger.info(f"Using custom model implementation for {config.architectures[0]}")
-                return model
-        except Exception as e:
-            logger.error(f"Failed to use custom model implementation with error: {e}")
+        if not force_hf:
+            try:
+                # if we have a custom model implementation available, we prioritize that over HF
+                if config.architectures[0] in ModelRegistry.model_arch_name_to_cls:
+                    model = ModelRegistry.model_arch_name_to_cls[config.architectures[0]](
+                        config, *model_args, **kwargs
+                    )
+                    logger.info(f"Using custom model implementation for {config.architectures[0]}")
+                    return model
+            except Exception as e:
+                logger.error(f"Failed to use custom model implementation with error: {e}")

         if quantization_config is not None:
             kwargs["quantization_config"] = quantization_config
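`from_config` gains the same gate, so models built from a bare config follow the identical precedence. A short sketch under the same naming assumptions as above:

    from transformers import AutoConfig
    from nemo_automodel import NeMoAutoModelForCausalLM  # import path assumed, as above

    config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # checkpoint id assumed

    # force_hf=True bypasses the custom registry here as well.
    model = NeMoAutoModelForCausalLM.from_config(config, force_hf=True)

One design note visible in both hunks: the registry attempt is wrapped in a broad `except Exception`, so a failing custom implementation is logged at error level and the call silently falls through to the HF path; `force_hf=True` is the explicit way to opt out of that probing altogether.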