diff --git a/flux_train_network.py b/flux_train_network.py
index def441559..96ed4b70f 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -141,6 +141,11 @@ def load_target_model(self, args, weight_dtype, accelerator):
 
         ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
 
+        # Apply partitioned for Diffusion4k
+        if args.partitioned_vae:
+            ae.decoder.partitioned = True
+            ae.decoder.stride = 2  # Diffusion4k stride
+
         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
 
     def get_tokenize_strategy(self, args):
@@ -359,8 +364,14 @@ def get_noise_pred_and_target(
 
         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
-        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
-        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
+        latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
+
+        if args.partitioned_vae:
+            img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device)
+        else:
+            img_ids = flux_utils.prepare_img_ids(bsz, latent_height // 2, latent_width // 2).to(device=accelerator.device)
+
+        assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
 
         # get guidance
         # ensure guidance_scale in args is float
@@ -408,7 +419,11 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
         )
 
         # unpack latents
-        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
+        # if args.partitioned_vae:
+        #     model_pred = flux_utils.unpack_partitioned_latents(model_pred, latent_width, latent_height)
+        # else:
+        #     # unpack latents
+        model_pred = flux_utils.unpack_latents(model_pred, latents.shape[2] // 2, latents.shape[3] // 2)
 
         # apply model prediction type
         model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
diff --git a/library/flux_models.py b/library/flux_models.py
index 328ad481d..bb608e488 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -54,6 +54,8 @@ class AutoEncoderParams:
     z_channels: int
     scale_factor: float
     shift_factor: float
+    stride: int = 1  # Diffusion4k decoder stride; the default keeps existing constructor calls working
+    partitioned: bool = False  # Diffusion4k partitioned decoding; off unless requested
 
 
 def swish(x: Tensor) -> Tensor:
@@ -228,6 +230,8 @@ def __init__(
         in_channels: int,
         resolution: int,
         z_channels: int,
+        partitioned=False,
+        stride=1,
     ):
         super().__init__()
         self.ch = ch
@@ -236,6 +240,8 @@ def __init__(
         self.resolution = resolution
         self.in_channels = in_channels
         self.ffactor = 2 ** (self.num_resolutions - 1)
+        self.stride = stride
+        self.partitioned = partitioned
 
         # compute in_ch_mult, block_in and curr_res at lowest res
         block_in = ch * ch_mult[self.num_resolutions - 1]
@@ -272,7 +278,7 @@ def __init__(
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
-    def forward(self, z: Tensor) -> Tensor:
+    def forward(self, z: Tensor, partitioned=None) -> Tensor:
         # z to block_in
         h = self.conv_in(z)
 
@@ -291,9 +297,56 @@ def forward(self, z: Tensor) -> Tensor:
                 h = self.up[i_level].upsample(h)
 
         # end
-        h = self.norm_out(h)
-        h = swish(h)
-        h = self.conv_out(h)
+
+        # Diffusion4k
+        partitioned = partitioned if partitioned is not None else self.partitioned  # None -> use the module-level flag
+        if self.stride > 1 and partitioned:
+            h = self.norm_out(h)
+            h = swish(h)
+
+            overlap_size = 1  # because last conv kernel_size = 3
+            res = []
+            partitioned_height = h.shape[2] // self.stride
+            partitioned_width = h.shape[3] // self.stride
+
+            assert self.stride == 2  # only support stride = 2 for now
+            rows = []
+            for i in range(0, h.shape[2], partitioned_height):
+                row = []
+                for j in range(0, h.shape[3], partitioned_width):
+                    partition = h[:,:, max(i - overlap_size, 0) : min(i + partitioned_height + overlap_size, h.shape[2]), max(j - overlap_size, 0) : min(j + partitioned_width + overlap_size, h.shape[3])]
+
+                    # for strih
+                    if i==0 and j==0:
+                        partition = torch.nn.functional.pad(partition, (1, 0, 1, 0), "constant", 0)
+                    elif i==0:
+                        partition = torch.nn.functional.pad(partition, (0, 1, 1, 0), "constant", 0)
+                    elif i>0 and j==0:
+                        partition = torch.nn.functional.pad(partition, (1, 0, 0, 1), "constant", 0)
+                    elif i>0 and j>0:
+                        partition = torch.nn.functional.pad(partition, (0, 1, 0, 1), "constant", 0)
+
+                    partition = torch.nn.functional.interpolate(partition, scale_factor=self.stride, mode='nearest')
+                    partition = self.conv_out(partition)
+                    partition = partition[:,:,overlap_size:partitioned_height*2+overlap_size,overlap_size:partitioned_width*2+overlap_size]
+
+                    row.append(partition)
+                rows.append(row)
+
+            for row in rows:
+                res.append(torch.cat(row, dim=3))
+
+            h = torch.cat(res, dim=2)
+        # Diffusion4k
+        elif self.stride > 1:
+            h = self.norm_out(h)
+            h = torch.nn.functional.interpolate(h, scale_factor=self.stride, mode='nearest')
+            h = swish(h)
+            h = self.conv_out(h)
+        else:
+            h = self.norm_out(h)
+            h = swish(h)
+            h = self.conv_out(h)
         return h
 
 
@@ -404,6 +457,9 @@ class ModelSpec:
             z_channels=16,
             scale_factor=0.3611,
             shift_factor=0.1159,
+            # Diffusion4k
+            stride=1,
+            partitioned=False,
         ),
     ),
     "schnell": ModelSpec(
@@ -436,6 +492,9 @@ class ModelSpec:
             z_channels=16,
             scale_factor=0.3611,
             shift_factor=0.1159,
+            # Diffusion4k
+            stride=1,
+            partitioned=False,
         ),
     ),
 }
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 5f6867a81..cb9583043 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -232,18 +232,49 @@ def encode_prompt(prpt):
 
     # sample image
     weight_dtype = ae.dtype  # TOFO give dtype as argument
-    packed_latent_height = height // 16
-    packed_latent_width = width // 16
-    noise = torch.randn(
-        1,
-        packed_latent_height * packed_latent_width,
-        16 * 2 * 2,
-        device=accelerator.device,
-        dtype=weight_dtype,
-        generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
-    )
-    timesteps = get_schedule(sample_steps, noise.shape[1], shift=True)  # FLUX.1 dev -> shift=True
-    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
+
+    if args.partitioned_vae:
+        vae_scale_factor = 32
+        latent_height = 2 * (int(height) // vae_scale_factor)
+        latent_width = 2 * (int(width) // vae_scale_factor)
+
+        print("latent height", latent_height)
+        print("latent width", latent_width)
+
+        noisy_model_input = torch.randn(
+            1,  # Batch size
+            16,  # VAE channels
+            latent_height,
+            latent_width,
+            device=accelerator.device,
+            dtype=weight_dtype,
+            generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
+        )
+
+        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
+        img_ids = flux_utils.prepare_partitioned_img_ids(1, latent_height, latent_width).to(device=accelerator.device)
+
+        print("img_ids: ", img_ids.shape)
+    else:
+        # VAE 8x compression
+        latent_height = height // 8
+        latent_width = width // 8
+        noisy_model_input = torch.randn(
+            1,  # Batch size
+            16,  # VAE channels
+            latent_height,
+            latent_width,
+            device=accelerator.device,
+            dtype=weight_dtype,
+            generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
+        )
+        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
+        latent_height, latent_width = noisy_model_input.shape[2], noisy_model_input.shape[3]
+        img_ids = flux_utils.prepare_img_ids(1, latent_height // 2, latent_width // 2).to(device=accelerator.device)
+
+    assert packed_noisy_model_input.shape[1] == img_ids.shape[1], "Packed latent dimensions are not aligned with img ids"
+
+    timesteps = get_schedule(sample_steps, packed_noisy_model_input.shape[1], shift=True)  # packed sequence length, as before; FLUX.1 dev -> shift=True
     t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
 
     if controlnet_image is not None:
@@ -255,7 +286,7 @@ def encode_prompt(prpt):
     with accelerator.autocast(), torch.no_grad():
         x = denoise(
             flux,
-            noise,
+            packed_noisy_model_input,
             img_ids,
             t5_out,
             txt_ids,
@@ -268,7 +299,11 @@ def encode_prompt(prpt):
             neg_cond=neg_cond,
         )
 
-    x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
+    # unpack latents
+    if args.partitioned_vae:
+        x = flux_utils.unpack_latents(x, height // 32, width // 32)
+    else:
+        x = flux_utils.unpack_latents(x, latent_height // 2, latent_width // 2)
 
     # latent to image
     clean_memory_on_device(accelerator.device)
@@ -680,3 +715,4 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         default=3.0,
         help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
     )
+    parser.add_argument("--partitioned_vae", action="store_true", help="Partitioned VAE from Diffusion4k paper")
diff --git a/library/flux_utils.py b/library/flux_utils.py
index 8be1d63ee..7cc2b2b8f 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -346,23 +346,68 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
     img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
     return img_ids
 
+def prepare_partitioned_img_ids(batch_size: int, height: int, width: int):
+    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
 
-def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
+    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+    latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
+    latent_image_ids = latent_image_ids.reshape(
+        batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+    )
+
+    return latent_image_ids
+
+def unpack_latents(x: torch.FloatTensor, height: int, width: int) -> torch.FloatTensor:
     """
     x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
     """
-    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
+    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=height, w=width, ph=2, pw=2)
     return x
 
-
-def pack_latents(x: torch.Tensor) -> torch.Tensor:
-    """
-    x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
-    """
-    x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+# def unpack_latents(latents, height, width):
+#     batch_size, num_patches, channels = latents.shape
+#
+#     # height = height // vae_scale_factor
+#     # width = width // vae_scale_factor
+#
+#     latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+#     latents = latents.permute(0, 3, 1, 4, 2, 5)
+#
+#     latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+#
+#     return latents
+
+def unpack_partitioned_latents(x, height, width):
+    x = einops.rearrange(
+        x,
+        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+        h=height//2,  # Divide by 2 because each patch is 2x2
+        w=width//2,  # Divide by 2 because each patch is 2x2
+        ph=2,
+        pw=2
+    )
     return x
+# def pack_latents(x: torch.Tensor) -> torch.Tensor:
+#     """
+#     x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
+#     """
+#     x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+#     return x
 
 
+def pack_latents(latents):
+    batch_size, channels, height, width = latents.shape
+    latents = latents.reshape(batch_size, channels, height // 2, 2, width // 2, 2)  # reshape, not view: also accepts non-contiguous latents
+    latents = latents.permute(0, 2, 4, 1, 3, 5)
+    latents = latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4)
+
+    return latents
+
+
+
 # region Diffusers
 
 NUM_DOUBLE_BLOCKS = 19