Skip to content

[rfc][compile] compile method for DiffusionPipeline #11705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Jun 13, 2025

This PR adds a compile method to the DiffusionPipeline to make it really easy for diffusers users to apply torch.compile.

There are a couple of reasons why we want a deeper integration of compile with diffusion models

  • Regional compile (instead of full model compile) is better fit for diffusion models, and its not easy for the user to apply regional compile precisely.
  • A pipeline has many components, and many of them are not suitable for compile. Since we (or a new model author) know the architecture of the models, we can leverage it to apply compile at the most relevant portions of the pipeline.

Option 1 - Pipeline has a compile method

From the user standpoint, they will have to add this line

pipe.compile(**compile_kwargs)

This will by default find the regions in the transformer model (more on how to find regions later) and apply regional compile to those regions. We can use this to provide more control to the user - if they want to apply compile to pre and pos-transformer model.

Pro

  • Easiest for the user

Cons

  • Might be considered a big change, and we might want to phase this out.

Option 2 - ModelMixin overrides nn.Module compile method (this PR does Option 2)

User does something like this

pipe.transformer.compile(use_regional_compile=True/False, **kwargs)

This overrides the nn.Module compile and provides options for regional compile. This is one level deeper integration, so user will have to know that the pipe must have a transformer named attribute. If use_regional_compile is True, we will find the regions (more on that later), and apply torch.compile on them.

Pros

  • ModelMixin is the most time consuming part of the model, so the abstraction seems right.

Cons

  • Users might already be using pipe.transformer.compile to enable full model compilation. I DO want regional compile as the default (supported by numbers later). But when users upgrade to new diffusers, they will see regional compile if we make it a default, and might be confused. Turning off regional compile by defaults solves this problem, but it is sub-par UX because of high compile time w/o any real benefit.
  • Though minor, its less obvious than pipe.compile

Option 3 - Combine option 1 and option 2

This could be the best option. We have both option 1 and 2. In option 2, we keep the default regional compile OFF. But option 1 - pipe.compile has regional compile by default, and propagates the regional compile to pipe.tranformer.compile. This way, the OOTB solution would be pipe.compile and then any more enthusiastic user can play with pipe.transformer.compile if they wish.

How to find regions?

Here, we can make a simple addition to the existing models and provide guidance for the new models. Most of the diffusion models are basically a sequence of Transformer blocks, and these blocks are obvious from the model __init__ method. We can add a new class attribute for each model - _regions_for_compile - which point to the classes of those Transformer blocks. This is very precise and easy. Relatedly, I also find another field - _no_split_modules - which incidentally serves the same purpose. In the absence of the _regions_for_compile attribute, we can rely on _no_split_modules.

Why regional compile as default?

Regional compilation gives as good speedup as full model compilation for diffusion models with 8x-10x smaller compilation time. More data to support regional compilation as default is below

image

One experiment that I am missing is cudagraphs. I am not sure yet if enabling cudagraphs makes full model better than regional. I will rerun the experiments today with cudagraphs.

cc @sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for starting this! LMK what you think of the comments.

@@ -286,6 +286,8 @@ def __init__(

self.gradient_checkpointing = False

self.compile_region_classes = (FluxTransformerBlock, FluxSingleTransformerBlock)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a class-level attribute like this?

_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thats better.

@@ -2027,6 +2027,40 @@ def _maybe_raise_error_if_group_offload_active(
return True
return False

def compile(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make it a part of ModelMixin rather than DiffusionPipeline I believe:

class ModelMixin(torch.nn.Module, PushToHubMixin):

For most cases, users want to just do pipe.transformer.compile(). So, perhaps easier with this being added to ModelMixin?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you expect any improvements from compile on the pre/post transformer blocks? If not, then yeah, moving it inside makes sense to me.

Copy link
Contributor Author

@anijain2305 anijain2305 Jun 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, do we have to do pipe.transformer.compile()? Hmm, this could be confusing because torch.nn.Module also has a method named compile. If we move in the ModelMixin, we might want a different method name.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, to tackle those cases, we could do something like this:

@wraps(torch.nn.Module.cuda)

So, we include all the acceptable args and kwargs in the ModelMixin compile()` method but additionally include the regional/hierarchical compilation related kwargs. Would that work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Let me try this out.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there are some advantages of Pipeline.compile:

  1. It captures the user intent better, e.g., "I just want to accelerate this pipeline" (the word "compile" is also great because it signals "expect some initial delay").
  2. It's more future proof (diffuser devs get to tweak things to accommodate changes from diffusers/pytorch versions, or even future hardware and pipeline architecture/perf-characteristics).
  3. It's more user friendly (small things like .transformer does make a noticeable effect on user experience).

(2) and (3) are effectively consequences of (1). What do you think? @sayakpaul

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sayak suggested to phase this out.

Have pipeline.transformer.compile as the first PR

And then pipeline.compile in future which can internally call pipeline.transformer.compile. Its actually better this way because then the compile region classes is also hidden from the pipe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will eventually ship pipe.compile() :)

Comment on lines 2032 to 2033
compile_regions_for_transformer: bool = True,
transformer_module_name: str = "transformer",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two will have to have changed if we proceed to have it inside ModelMixin.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu June 13, 2025 05:37
@yiyixuxu
Copy link
Collaborator

thanks for the PR !
How do we determine default compile_region_classes for each model?

@anijain2305
Copy link
Contributor Author

thanks for the PR ! How do we determine default compile_region_classes for each model?

This is something that we will have to do manually. I deliberately made that choice. It should be be fairly easy to find those classes, basically the transformer classes. In absence of compile_region_classes, we can fallback to the full model compilation with a warning.

I think this is a fair compromise because it gives control to the model author, and with little guidance it should be very easy.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jun 15, 2025

It should be be fairly easy to find those classes, basically the transformer classes

is it because that they are repeated?

can we add a enable_regional_compile() on ModelMixin since this is the main purpose of that?

@sayakpaul
Copy link
Member

can we add a enable_regional_compile() on ModelMixin since this is the main purpose of that?

@yiyixuxu this is a good suggestion. But I wanted to discuss the developer experience a bit.

Our initial (@anijain2305 feel free to correct me) plan was to override compile() at ModelMixin. This overriden compile() method will have all the args and kwargs taken by the original compile() method of torch (note that any nn.Module now comes with compile()). But it will additionally take an argument to enable regional compilation.

So, it would be something like:

model = FluxTransformer2DModel()
`regional_compilation` is just an arg, name can be changed/decided later.
model.compile(regional_compilation=True, ...)

With enable_regional_compile(), would the developer experience look like?

model = FluxTransformer2DModel()
model.enable_regional_compile()
model.compile(...)

@anijain2305
Copy link
Contributor Author

@sayakpaul @yiyixuxu I updated the PR and also the original description with some discussion points (pros and cons). Let me know what you think.

@@ -1402,6 +1403,44 @@ def float(self, *args):
else:
return super().float(*args)

@wraps(torch.nn.Module.compile)
def compile(self, use_regional_compile: bool = True, *args, **kwargs):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we go with Option 3 - we should turn this to False to keep the existing behavior of pipe.tranformer.compile same as before.

@StrongerXi
Copy link

But when users upgrade to new diffusers, they will see regional compile if we make it a default, and might be confused.

I thought the only thing the users will see here is less compilation time, if they already have pipeline.transformer.compile(...)?

@yiyixuxu
Copy link
Collaborator

i see, maybe enalbe_ is indeed a bit confusing, maybe just model.reginal_compile() then?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Jun 16, 2025

one thing I want to understand is this this pattern (regional compile) only beneficial to diffusers models?
should and would this be supported in pytorch compile in the future (or maybe now)?

@anijain2305
Copy link
Contributor Author

anijain2305 commented Jun 16, 2025

one thing I want to understand is this this pattern (regional compile) only beneficial to diffusers models? should and would this be supported in pytorch compile in the future (or maybe now)?

Regional compile is very useful for diffusion models because the transformer blocks are big enough that the compile optimization inside the transformer block is enough to give major speedup, and there is very little to be had with compile optimizations across the transformer blocks.

The regional compile is not a feature per se, its more like saying - "compile responsibly" - there is no specific support required from the pytorch side to enable this feature. For torch.compile, its just another nn.Module (or function) that you are compiling. As a user, we are compiling just small repeated regions to keep the compilation cost low.

For LLMs or other non-diffusion models, it is possible that you do want cross-transformer layer optimizations to see compiler benefits (maybe the presence of static kv cache changes the space of optimizations).

For diffusion models, we can do more testing to convince ourselves. One thing that I do want to test is cudagraphs.

I thought the only thing the users will see here is less compilation time, if they already have pipeline.transformer.compile(...)?

@StrongerXi It will be true in almost all cases, but if a user specifically wants to compile the FULL model for some reason, they will be confused.

@anijain2305
Copy link
Contributor Author

i see, maybe enalbe_ is indeed a bit confusing, maybe just model.reginal_compile() then?

If we go with Option 3 (which is Option 1 + option 2), we might not need model.regional_compile. In this option 3, the ModelMixin will have a compile method which takes the additional regional_compile args. These regional compile will all be defaulted to None by default. But the Pipeline compile method will use regional by default and setup the regional kwargs whlie calling pipeline.transformer.compile. Let me know if you need to see the code.

@yiyixuxu
Copy link
Collaborator

The regional compile is not a feature per se, its more like saying - "compile responsibly" - there is no specific support required from the pytorch side to enable this feature. For torch.compile, its just another nn.Module (or function) that you are compiling. As a user, we are compiling just small repeated regions to keep the compilation cost low.

thanks for the explanation! As I understand, this is something works great in diffusers but not unique to it. i.e. it can potentially benefit many non-diffusers models. I think it would not be a good practice for different libraries to wrap the compile API in different ways so I would like to upstream it if it's possible

if it is not possible or does not make sense to add in pytorch, I would prefer to very quickly support it with a separate method instead of wrapping around the compile API - it is one less layer and less complex, also easier user to discover this and use

@yiyixuxu
Copy link
Collaborator

regards to pipeline.compile(), I agree with @sayakpaul I think we should focus on model level first and eventually come to pipeline.compile() if it is meaningful

@anijain2305
Copy link
Contributor Author

anijain2305 commented Jun 17, 2025

@yiyixuxu I hear you. There is not really a way to upstream this, as this is completely model dependent. In practice, library/model authors are the best folks to determine where to apply torch.compile (in diffusion models, it turns out to be simple).

I am actually ok with anything that makes it easy to give the best ootb torch.compile experience to the end users. With that in mind, I am ok with regional_compile API in ModelMixin, supplemented with good documentation.

Cc @sayakpaul to hear his thoughts.

@sayakpaul
Copy link
Member

@anijain2305 let's go with #11705 (comment)? Thanks for being willing to reiterate. I will help with tests and documentation as well to see this through with you.

I think Yiyi's observations are quite aligned with what we see in the community of our users.

I would also be keen on seeing the results with CUDAgraphs now that @a-r-r-o-w made a nice observation by using it (Slack).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants