Skip to content

Commit b7366a7

Browse files
committed
[rfc][compile] compile method for DiffusionPipeline
1 parent 47ef794 commit b7366a7

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
266266
_keep_in_fp32_modules = None
267267
_skip_layerwise_casting_patterns = None
268268
_supports_group_offloading = True
269+
_regions_for_compile = []
269270

270271
def __init__(self):
271272
super().__init__()
@@ -1402,6 +1403,44 @@ def float(self, *args):
14021403
else:
14031404
return super().float(*args)
14041405

1406+
@wraps(torch.nn.Module.compile)
1407+
def compile(self, use_regional_compile: bool = True, *args, **kwargs):
1408+
""" """
1409+
if use_regional_compile:
1410+
regions_for_compile = getattr(self, "_regions_for_compile", None)
1411+
1412+
if not regions_for_compile:
1413+
logger.warning(
1414+
"_regions_for_compile attribute is empty. Using _no_split_modules to find compile regions."
1415+
)
1416+
1417+
regions_for_compile = getattr(self, "_no_split_modules", None)
1418+
1419+
if not regions_for_compile:
1420+
logger.warning(
1421+
"Both _regions_for_compile and _no_split_modules attribute are empty. "
1422+
"Set _regions_for_compile for the model to benefit from regional compilation. "
1423+
"Falling back to full model compilation, which could have high first iteration "
1424+
"latency."
1425+
)
1426+
super().compile(*args, **kwargs)
1427+
1428+
has_compiled_region = False
1429+
for submod in self.modules():
1430+
if submod.__class__.__name__ in regions_for_compile:
1431+
has_compiled_region = True
1432+
submod.compile(*args, **kwargs)
1433+
1434+
if not has_compiled_region:
1435+
raise ValueError(
1436+
f"Regional compilation failed because {regions_for_compile} classes are not found in the model. "
1437+
"Either set them correctly, or set `use_regional_compile` to False while calling copmile, e.g. "
1438+
"pipe.transformer.compile(use_regional_compile=False) to fallback to full model compilation, "
1439+
"which could have high iteration latency."
1440+
)
1441+
else:
1442+
super().compile(*args, **kwargs)
1443+
14051444
@classmethod
14061445
def _load_pretrained_model(
14071446
cls,

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
227227
_supports_gradient_checkpointing = True
228228
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
229229
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
230+
_regions_for_compile = _no_split_modules
230231

231232
@register_to_config
232233
def __init__(

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
345345
_no_split_modules = ["WanTransformerBlock"]
346346
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
347347
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
348+
_regions_for_compile = _no_split_modules
348349

349350
@register_to_config
350351
def __init__(

0 commit comments

Comments
 (0)