File tree Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Expand file tree Collapse file tree 3 files changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -286,6 +286,8 @@ def __init__(
286
286
287
287
self .gradient_checkpointing = False
288
288
289
+ self .compile_region_classes = (FluxTransformerBlock , FluxSingleTransformerBlock )
290
+
289
291
@property
290
292
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
291
293
def attn_processors (self ) -> Dict [str , AttentionProcessor ]:
Original file line number Diff line number Diff line change @@ -403,6 +403,8 @@ def __init__(
403
403
404
404
self .gradient_checkpointing = False
405
405
406
+ self .compile_region_classes = (WanTransformerBlock ,)
407
+
406
408
def forward (
407
409
self ,
408
410
hidden_states : torch .Tensor ,
Original file line number Diff line number Diff line change @@ -2027,6 +2027,39 @@ def _maybe_raise_error_if_group_offload_active(
2027
2027
return True
2028
2028
return False
2029
2029
2030
+ def compile (
2031
+ self ,
2032
+ compile_regions_for_transformer : bool = True ,
2033
+ transformer_module_name : str = "transformer" ,
2034
+ other_modules_names : List [str ] = [],
2035
+ ** compile_kwargs ,
2036
+ ):
2037
+ transformer = getattr (self , transformer_module_name , None )
2038
+ if transformer is None :
2039
+ raise ValueError (
2040
+ f"{ transformer_module_name } not found in the pipeline. Set `transformer_module_name` to the correct module name."
2041
+ )
2042
+
2043
+ if compile_regions_for_transformer :
2044
+ compile_region_classes = getattr (transformer , "compile_region_classes" , None )
2045
+ if compile_region_classes is None :
2046
+ raise ValueError (
2047
+ f"{ transformer_module_name } does not have `compile_region_classes` attribute. Set `compile_regions_for_transformer` to False."
2048
+ )
2049
+
2050
+ for submod in transformer .modules ():
2051
+ if isinstance (submod , compile_region_classes ):
2052
+ submod .compile (** compile_kwargs )
2053
+ else :
2054
+ transformer .compile (** compile_kwargs )
2055
+
2056
+ for module_name in other_modules_names :
2057
+ module = getattr (self , module_name , None )
2058
+ if module is None :
2059
+ raise ValueError (
2060
+ f"{ module_name } not found in the pipeline. Set `other_modules_names` to the correct module names."
2061
+ )
2062
+
2030
2063
2031
2064
class StableDiffusionMixin :
2032
2065
r"""
You can’t perform that action at this time.
0 commit comments