Skip to content

Commit 59f1f8c

Browse files
committed
[rfc][compile] compile method for DiffusionPipeline
1 parent 47ef794 commit 59f1f8c

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ def __init__(
286286

287287
self.gradient_checkpointing = False
288288

289+
self.compile_region_classes = (FluxTransformerBlock, FluxSingleTransformerBlock)
290+
289291
@property
290292
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
291293
def attn_processors(self) -> Dict[str, AttentionProcessor]:

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,8 @@ def __init__(
403403

404404
self.gradient_checkpointing = False
405405

406+
self.compile_region_classes = (WanTransformerBlock,)
407+
406408
def forward(
407409
self,
408410
hidden_states: torch.Tensor,

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,6 +2027,39 @@ def _maybe_raise_error_if_group_offload_active(
20272027
return True
20282028
return False
20292029

2030+
def compile(
2031+
self,
2032+
compile_regions_for_transformer: bool = True,
2033+
transformer_module_name: str = "transformer",
2034+
other_modules_names: List[str] = [],
2035+
**compile_kwargs,
2036+
):
2037+
transformer = getattr(self, transformer_module_name, None)
2038+
if transformer is None:
2039+
raise ValueError(
2040+
f"{transformer_module_name} not found in the pipeline. Set `transformer_module_name` to the correct module name."
2041+
)
2042+
2043+
if compile_regions_for_transformer:
2044+
compile_region_classes = getattr(transformer, "compile_region_classes", None)
2045+
if compile_region_classes is None:
2046+
raise ValueError(
2047+
f"{transformer_module_name} does not have `compile_region_classes` attribute. Set `compile_regions_for_transformer` to False."
2048+
)
2049+
2050+
for submod in transformer.modules():
2051+
if isinstance(submod, compile_region_classes):
2052+
submod.compile(**compile_kwargs)
2053+
else:
2054+
transformer.compile(**compile_kwargs)
2055+
2056+
for module_name in other_modules_names:
2057+
module = getattr(self, module_name, None)
2058+
if module is None:
2059+
raise ValueError(
2060+
f"{module_name} not found in the pipeline. Set `other_modules_names` to the correct module names."
2061+
)
2062+
20302063

20312064
class StableDiffusionMixin:
20322065
r"""

0 commit comments

Comments
 (0)