Working on Lumina-Next #519

Description

@lavinal712

Hello, I have recently been looking into CFG parallel methods and trying to apply xDiT to Lumina-NextDiT. Here is my script:

import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserLuminaPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    is_dp_last_group,
    get_data_parallel_world_size,
    get_runtime_state,
    get_data_parallel_rank,
)


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    pipe = xFuserLuminaPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.float16,
    ).to(f"cuda:{local_rank}")
    model_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
    pipe.prepare_run(input_config)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        use_resolution_binning=input_config.use_resolution_binning,
        guidance_scale=input_config.guidance_scale,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}_tc_{engine_args.use_torch_compile}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if pipe.is_dp_last_group():
            if not os.path.exists("results"):
                os.mkdir("results")
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                img_file = (
                    f"./results/lumina_nextdit_result_{parallel_info}_{image_rank}.png"
                )
                image.save(img_file)
                print(img_file)

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, model memory: {model_memory/1e9:.2f} GB, overall memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destroy_distributed_env()


if __name__ == "__main__":
    main()

After running this script, I run into this problem:

INFO 05-27 09:41:23 [base_pipeline.py:377] Transformer backbone found, paralleling transformer...
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/./examples/lumina_example.py", line 76, in <module>
[rank0]:     main()
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/./examples/lumina_example.py", line 24, in main
[rank0]:     pipe = xFuserLuminaPipeline.from_pretrained(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/xfuser/model_executor/pipelines/pipeline_lumina.py", line 48, in from_pretrained
[rank0]:     return cls(pipeline, engine_config)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 162, in __init__
[rank0]:     pipeline.transformer = self._convert_transformer_backbone(
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 378, in _convert_transformer_backbone
[rank0]:     wrapper = xFuserTransformerWrappersRegister.get_wrapper(transformer)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/azureuser/v-yuqianhong/xDiT/xfuser/model_executor/models/transformers/register.py", line 55, in get_wrapper
[rank0]:     raise ValueError(
[rank0]: ValueError: Transformer class LuminaNextDiT2DModel is not supported by xFuser

How can I deal with this? I'm sure the model has been added to the register (how I expect the registration to be triggered is sketched in a note right after the model code). Here is the model and pipeline code:

Model:

from typing import Any, Dict, Optional
import torch
import torch.distributed
import torch.nn as nn

from diffusers import LuminaNextDiT2DModel
from diffusers.models.embeddings import PatchEmbed, LuminaPatchEmbed
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.utils import is_torch_version

from xfuser.logger import init_logger
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
from xfuser.core.distributed import is_pipeline_first_stage, is_pipeline_last_stage
from .register import xFuserTransformerWrappersRegister
from .base_transformer import xFuserTransformerBaseWrapper

logger = init_logger(__name__)


@xFuserTransformerWrappersRegister.register(LuminaNextDiT2DModel)
class xFuserLuminaNextDiT2DWrapper(xFuserTransformerBaseWrapper):
    def __init__(
        self,
        transformer: LuminaNextDiT2DModel,
    ):
        super().__init__(
            transformer=transformer,
            submodule_classes_to_wrap=[nn.Conv2d, LuminaPatchEmbed],
            submodule_name_to_wrap=["attn1"],
        )
    
    @xFuserBaseWrapper.forward_check_condition
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict=True,
    ) -> torch.Tensor:
        """
        Forward pass of LuminaNextDiT.

        Parameters:
            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
            encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
        """
        #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
        if is_pipeline_first_stage():
            hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
            image_rotary_emb = image_rotary_emb.to(hidden_states.device)

        #! ORIGIN
        # hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
        # image_rotary_emb = image_rotary_emb.to(hidden_states.device)
        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

        encoder_mask = encoder_mask.bool()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        #! ---------------------------------------- ADD BELOW ----------------------------------------
        if is_pipeline_last_stage():
            #! ---------------------------------------- ADD ABOVE ----------------------------------------
            hidden_states = self.norm_out(hidden_states, temb)

            # unpatchify
            height_tokens = width_tokens = self.patch_size
            height, width = img_size[0]
            batch_size = hidden_states.size(0)
            sequence_length = (height // height_tokens) * (width // width_tokens)
            hidden_states = hidden_states[:, :sequence_length].view(
                batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
            )
            output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
        #! ---------------------------------------- ADD BELOW ----------------------------------------
        else:
            output = hidden_states
        #! ---------------------------------------- ADD ABOVE ----------------------------------------

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
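
For context, my understanding is that the register decorator above only takes effect once this module is actually imported, so one way to confirm the registration outside of the full example script is a small check like the sketch below. The module path lumina_nextdit2d is just where I placed the wrapper file locally, so treat it as a placeholder:

import torch
from diffusers import LuminaPipeline

# Importing the wrapper module is what runs the @xFuserTransformerWrappersRegister.register
# decorator; the module path is hypothetical and should match wherever the wrapper file lives.
import xfuser.model_executor.models.transformers.lumina_nextdit2d  # noqa: F401
from xfuser.model_executor.models.transformers.register import (
    xFuserTransformerWrappersRegister,
)

model_path = "..."  # same checkpoint as passed via --model in the script above
pipe = LuminaPipeline.from_pretrained(model_path, torch_dtype=torch.float16)

# If the registration took effect, this resolves the wrapper instead of raising
# "Transformer class LuminaNextDiT2DModel is not supported by xFuser".
print(xFuserTransformerWrappersRegister.get_wrapper(pipe.transformer))

If get_wrapper already fails in this minimal check, the wrapper module is probably never imported (for example, not re-exported from the package __init__), which would also explain the error above.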

Pipeline:

import math
import os
from typing import Dict, List, Tuple, Callable, Optional, Union

import torch
import torch.distributed
from diffusers import LuminaPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina
from diffusers.pipelines.lumina.pipeline_lumina import retrieve_timesteps
from diffusers.utils import deprecate
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

from xfuser.config import EngineConfig
from xfuser.core.distributed import (
    is_dp_last_group,
    get_classifier_free_guidance_world_size,
    get_pipeline_parallel_world_size,
    get_runtime_state,
    get_cfg_group,
    get_pp_group,
    get_sequence_parallel_world_size,
    get_sp_group,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    get_world_group
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister


@xFuserPipelineWrapperRegister.register(LuminaPipeline)
class xFuserLuminaPipeline(xFuserPipelineBaseWrapper):

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        engine_config: EngineConfig,
        return_org_pipeline: bool = False,
        **kwargs,
    ):
        pipeline = LuminaPipeline.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        if return_org_pipeline:
            return pipeline
        return cls(pipeline, engine_config)

    @torch.no_grad()
    @xFuserPipelineBaseWrapper.enable_fast_attn
    @xFuserPipelineBaseWrapper.enable_data_parallel
    @xFuserPipelineBaseWrapper.check_to_use_naive_forward
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 4.0,
        negative_prompt: Union[str, List[str]] = None,
        sigmas: List[float] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        clean_caption: bool = True,
        max_sequence_length: int = 256,
        scaling_watershed: Optional[float] = 1.0,
        proportional_attn: Optional[bool] = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_inference_steps (`int`, *optional*, defaults to 30):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            max_sequence_length (`int`, *optional*, defaults to 256):
                Maximum sequence length to use with the `prompt`.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale

        cross_attention_kwargs = {}

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if proportional_attn:
            cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2

        scaling_factor = math.sqrt(width * height / self.default_image_size**2)

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
        )
        #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
        if do_classifier_free_guidance:
            (
                prompt_embeds,
                prompt_attention_mask,
            ) = self._process_cfg_split_batch(
                negative_prompt_embeds,
                prompt_embeds,
                negative_prompt_attention_mask,
                prompt_attention_mask,
            )

        #! ORIGIN
        # if do_classifier_free_guidance:
        #     prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
        #     prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0)
        #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                current_timestep = t
                if not torch.is_tensor(current_timestep):
                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                    # This would be a good case for the `match` statement (Python 3.10+)
                    is_mps = latent_model_input.device.type == "mps"
                    is_npu = latent_model_input.device.type == "npu"
                    if isinstance(current_timestep, float):
                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                    else:
                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                    current_timestep = torch.tensor(
                        [current_timestep],
                        dtype=dtype,
                        device=latent_model_input.device,
                    )
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(latent_model_input.device)
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(latent_model_input.shape[0])

                # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
                current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps

                # prepare image_rotary_emb for positional encoding
                # dynamic scaling_factor for different resolution.
                # NOTE: For `Time-aware` denoising mechanism from Lumina-Next
                # https://huggingface.co/papers/2406.18583, Sec 2.3
                # NOTE: We should compute different image_rotary_emb with different timestep.
                if current_timestep[0] < scaling_watershed:
                    linear_factor = scaling_factor
                    ntk_factor = 1.0
                else:
                    linear_factor = 1.0
                    ntk_factor = scaling_factor
                image_rotary_emb = get_2d_rotary_pos_embed_lumina(
                    self.transformer.head_dim,
                    384,
                    384,
                    linear_factor=linear_factor,
                    ntk_factor=ntk_factor,
                )

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=current_timestep,
                    encoder_hidden_states=prompt_embeds,
                    encoder_mask=prompt_attention_mask,
                    image_rotary_emb=image_rotary_emb,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.chunk(2, dim=1)[0]

                # perform guidance scale
                # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only
                # three channels by default. The standard approach to cfg applies it to all channels.
                # This can be done by uncommenting the following line and commenting-out the line following that.
                # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
                #! ---------------------------------------- MODIFIED BELOW ----------------------------------------
                if do_classifier_free_guidance:
                    noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:]
                    if get_classifier_free_guidance_world_size() == 1:
                        noise_pred_cond_eps, noise_pred_uncond_eps = torch.split(
                            noise_pred_eps, len(noise_pred_eps) // 2, dim=0
                        )
                    elif get_classifier_free_guidance_world_size() == 2:
                        noise_pred_cond_eps, noise_pred_uncond_eps = get_cfg_group().all_gather(
                            noise_pred_eps, separate_tensors=True
                        )
                    noise_pred_half = noise_pred_uncond_eps + guidance_scale * (
                        noise_pred_cond_eps - noise_pred_uncond_eps
                    )
                    noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0)

                    noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1)
                    noise_pred, _ = noise_pred.chunk(2, dim=0)
                #! ---------------------------------------- MODIFIED ABOVE ----------------------------------------

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                noise_pred = -noise_pred
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                progress_bar.update()

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

        if not output_type == "latent":
            latents = latents / self.vae.config.scaling_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        else:
            image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
