Commit dea1aa0

[Enhance] Suppress import errors for optional optimizers

- Add try-except blocks around optimizer registration to gracefully handle import failures
- Extend ImportError handling to include RuntimeError for CUDA-related issues
- Add warnings.warn to notify users of failed optimizer imports without breaking execution
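
The change applies one pattern throughout: attempt the registration, and on failure emit a warning instead of letting the exception propagate. Below is a minimal, self-contained sketch of that pattern, not the exact builder.py code; it assumes MMEngine's OPTIMIZERS registry, and the helper name register_optional_optimizers and its module argument are illustrative additions, not part of this commit.

    # Sketch only: register every optimizer class found in an optional module,
    # warning instead of raising when a single registration fails.
    import inspect
    import warnings
    from typing import List

    import torch

    from mmengine.registry import OPTIMIZERS


    def register_optional_optimizers(module) -> List[str]:
        """Register torch.optim.Optimizer subclasses found in ``module``."""
        registered = []
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if not issubclass(obj, torch.optim.Optimizer):
                continue
            try:
                OPTIMIZERS.register_module(module=obj, name=name)
            except Exception as e:  # e.g. a name that is already registered
                warnings.warn(f'Failed to register {name}: {e}')
            else:
                registered.append(name)
        return registered

Catching Exception rather than only ImportError means a bad or duplicate entry is reported to the user but never aborts the surrounding registration loop, which is the behavior the commit message describes.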
1 parent bce1c24 commit dea1aa0

File tree

1 file changed: +14 -5 lines changed


mmengine/optim/optimizer/builder.py

Lines changed: 14 additions & 5 deletions
@@ -125,8 +125,10 @@ def register_sophia_optimizers() -> List[str]:
             _optim = getattr(Sophia, module_name)
             if inspect.isclass(_optim) and issubclass(_optim,
                                                       torch.optim.Optimizer):
-                OPTIMIZERS.register_module(module=_optim)
-                optimizers.append(module_name)
+                try:
+                    OPTIMIZERS.register_module(module=_optim)
+                except Exception as e:
+                    warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
     return optimizers

@@ -146,7 +148,8 @@ def register_bitsandbytes_optimizers() -> List[str]:
     dadaptation_optimizers = []
     try:
         import bitsandbytes as bnb
-    except ImportError:
+    # import bnb may trigger cuda related error without nvidia gpu resources
+    except (ImportError, RuntimeError):
         pass
     else:
         optim_classes = inspect.getmembers(

@@ -155,7 +158,10 @@ def register_bitsandbytes_optimizers() -> List[str]:
         for name, optim_cls in optim_classes:
             if name in OPTIMIZERS:
                 name = f'bnb_{name}'
-            OPTIMIZERS.register_module(module=optim_cls, name=name)
+            try:
+                OPTIMIZERS.register_module(module=optim_cls, name=name)
+            except Exception as e:
+                warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
             dadaptation_optimizers.append(name)
     return dadaptation_optimizers

@@ -170,7 +176,10 @@ def register_transformers_optimizers():
     except ImportError:
         pass
     else:
-        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
+        try:
+            OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
+        except Exception as e:
+            warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
         transformer_optimizers.append('Adafactor')
     return transformer_optimizers
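
The bitsandbytes hunk also widens the import guard: on a machine without an NVIDIA GPU, importing bitsandbytes can fail with a CUDA-related RuntimeError rather than an ImportError, and both now mean the optional backend is simply unavailable. A standalone sketch of that guard follows; the helper name try_import_bitsandbytes is hypothetical and not part of the commit.

    # Sketch of the widened import guard for an optional, CUDA-dependent package.
    def try_import_bitsandbytes():
        try:
            import bitsandbytes as bnb
        except (ImportError, RuntimeError):
            # Package missing, or a CUDA-related failure on a GPU-less machine:
            # treat the optional backend as unavailable instead of crashing.
            return None
        return bnb

A caller can then skip optimizer registration when the helper returns None, which is what the else branch in the diff effectively does.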
