@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266
266
_keep_in_fp32_modules = None
267
267
_skip_layerwise_casting_patterns = None
268
268
_supports_group_offloading = True
269
+ _regions_for_compile = []
269
270
270
271
def __init__ (self ):
271
272
super ().__init__ ()
@@ -1402,6 +1403,44 @@ def float(self, *args):
1402
1403
else :
1403
1404
return super ().float (* args )
1404
1405
1406
+ @wraps (torch .nn .Module .compile )
1407
+ def compile (self , use_regional_compile : bool = True , * args , ** kwargs ):
1408
+ """ """
1409
+ if use_regional_compile :
1410
+ regions_for_compile = getattr (self , "_regions_for_compile" , None )
1411
+
1412
+ if not regions_for_compile :
1413
+ logger .warning (
1414
+ "_regions_for_compile attribute is empty. Using _no_split_modules to find compile regions."
1415
+ )
1416
+
1417
+ regions_for_compile = getattr (self , "_no_split_modules" , None )
1418
+
1419
+ if not regions_for_compile :
1420
+ logger .warning (
1421
+ "Both _regions_for_compile and _no_split_modules attribute are empty. "
1422
+ "Set _regions_for_compile for the model to benefit from regional compilation. "
1423
+ "Falling back to full model compilation, which could have high first iteration "
1424
+ "latency."
1425
+ )
1426
+ super ().compile (* args , ** kwargs )
1427
+
1428
+ has_compiled_region = False
1429
+ for submod in self .modules ():
1430
+ if submod .__class__ .__name__ in regions_for_compile :
1431
+ has_compiled_region = True
1432
+ submod .compile (* args , ** kwargs )
1433
+
1434
+ if not has_compiled_region :
1435
+ raise ValueError (
1436
+ f"Regional compilation failed because { regions_for_compile } classes are not found in the model. "
1437
+ "Either set them correctly, or set `use_regional_compile` to False while calling copmile, e.g. "
1438
+ "pipe.transformer.compile(use_regional_compile=False) to fallback to full model compilation, "
1439
+ "which could have high iteration latency."
1440
+ )
1441
+ else :
1442
+ super ().compile (* args , ** kwargs )
1443
+
1405
1444
@classmethod
1406
1445
def _load_pretrained_model (
1407
1446
cls ,
0 commit comments