From f6137a71757da87ce5624502955ba955b76f7b0b Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 25 Sep 2025 22:10:11 +0900 Subject: [PATCH] feat: add fp8 optimization for FLUX --- flux_minimal_inference.py | 130 ++++++++++++++++++++++++++++---------- flux_train.py | 4 +- flux_train_control_net.py | 4 +- flux_train_network.py | 24 ++++--- library/flux_models.py | 46 ++++++++++---- library/flux_utils.py | 103 ++++++++++++++++++------------ train_network.py | 20 +++--- 7 files changed, 223 insertions(+), 108 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 0664b3c78..d8a4ac1ae 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -18,6 +18,7 @@ from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from library.safetensors_utils import MemoryEfficientSafeOpen from networks import oft_flux init_ipex() @@ -325,7 +326,8 @@ def encode(prpt: str): # generate image logger.info("Generating image...") - model = model.to(device) + if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0): + model = model.to(device) if steps is None: steps = 4 if is_schnell else 50 @@ -411,12 +413,16 @@ def encode(prpt: str): parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") + parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model") parser.add_argument("--seed", type=int, default=None) parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--guidance", type=float, default=3.5) parser.add_argument("--negative_prompt", type=str, default=None) parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--offload", action="store_true", help="Offload to CPU") + parser.add_argument( + "--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage" + ) parser.add_argument( "--lora_weights", type=str, @@ -442,6 +448,8 @@ def is_fp8(dt): t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype) ae_dtype = str_to_dtype(args.ae_dtype, dtype) flux_dtype = str_to_dtype(args.flux_dtype, dtype) + if args.fp8_scaled and flux_dtype.itemsize == 1: + raise ValueError("fp8_scaled is not supported for fp8 flux_dtype") logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}") @@ -470,13 +478,68 @@ def is_fp8(dt): # if is_fp8(t5xxl_dtype): # t5xxl = accelerator.prepare(t5xxl) + # check LoRA and OFT weights can be mergeable + mergeable_lora_weights = None + mergeable_lora_multipliers = None + if args.fp8_scaled and args.lora_weights: + assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled" + + mergeable_lora_weights = [] + mergeable_lora_multipliers = [] + + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + with MemoryEfficientSafeOpen(weights_file) as f: + keys = list(f.keys()) + + is_lora = is_oft = False + includes_text_encoder = False + for key in keys: + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if key.startswith("lora_te") or 
key.startswith("oft_te"): + includes_text_encoder = True + if (is_lora or is_oft) and includes_text_encoder: + break + + if includes_text_encoder or is_oft: + raise ValueError( + f"LoRA weights {weights_file} that includes text encoder or OFT weights cannot be merged when using fp8_scaled" + ) + + mergeable_lora_weights.append(weights_file) + mergeable_lora_multipliers.append(multiplier) + # DiT + loading_dtype = None if args.fp8_scaled else flux_dtype + loading_device = "cpu" if args.blocks_to_swap or args.offload else device + is_schnell, model = flux_utils.load_flow_model( - args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type + device, + args.ckpt_path, + loading_dtype, + loading_device, + args.model_type, + args.fp8_scaled, + lora_weights_list=mergeable_lora_weights, + lora_multipliers=mergeable_lora_multipliers, ) model.eval() - logger.info(f"Casting model to {flux_dtype}") - model.to(flux_dtype) # make sure model is dtype + + if args.blocks_to_swap is not None and args.blocks_to_swap > 0: + model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + + # logger.info(f"Casting model to {flux_dtype}") + # model.to(flux_dtype) # make sure model is dtype # if is_fp8(flux_dtype): # model = accelerator.prepare(model) # if args.offload: @@ -494,36 +557,37 @@ def is_fp8(dt): # LoRA lora_models: List[lora_flux.LoRANetwork] = [] - for weights_file in args.lora_weights: - if ";" in weights_file: - weights_file, multiplier = weights_file.split(";") - multiplier = float(multiplier) - else: - multiplier = 1.0 - - weights_sd = load_file(weights_file) - is_lora = is_oft = False - for key in weights_sd.keys(): - if key.startswith("lora"): - is_lora = True - if key.startswith("oft"): - is_oft = True - if is_lora or is_oft: - break - - module = lora_flux if is_lora else oft_flux - lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) - - if args.merge_lora_weights: - lora_model.merge_to([clip_l, t5xxl], model, weights_sd) - else: - lora_model.apply_to([clip_l, t5xxl], model) - info = lora_model.load_state_dict(weights_sd, strict=True) - logger.info(f"Loaded LoRA weights from {weights_file}: {info}") - lora_model.eval() - lora_model.to(device) + if not args.fp8_scaled: # LoRA cannot be applied after fp8 scaling and quantization + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) - lora_models.append(lora_model) + lora_models.append(lora_model) if not args.interactive: generate_image( diff --git a/flux_train.py b/flux_train.py index 4aa67220f..44e40c23b 100644 --- 
a/flux_train.py +++ b/flux_train.py @@ -271,7 +271,7 @@ def train(args): # load FLUX _, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" + accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux" ) if args.gradient_checkpointing: @@ -302,7 +302,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True) if not cache_latents: # load VAE here if not cached diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 019914058..1b56c9f86 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load FLUX is_schnell, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" + accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux" ) flux.requires_grad_(False) @@ -304,7 +304,7 @@ def train(args): # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True) flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage # ControlNet only has two blocks, so we can keep it on GPU # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) diff --git a/flux_train_network.py b/flux_train_network.py index cfc617088..db61f15d9 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -51,7 +51,9 @@ def assert_extra_args( self.use_clip_l = True else: self.use_clip_l = False # Chroma does not use CLIP-L - assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります" + assert ( + args.apply_t5_attn_mask + ), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります" if args.fp8_base_unet: args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 @@ -100,17 +102,15 @@ def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) - loading_dtype = None if args.fp8_base else weight_dtype + loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype + loading_device = "cpu" if self.is_swapping_blocks else accelerator.device - # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + # load with quantization if needed _, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, - loading_dtype, - "cpu", - disable_mmap=args.disable_mmap_load_safetensors, - model_type=self.model_type, + accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled ) - if args.fp8_base: + + if args.fp8_base and not args.fp8_scaled: # check dtype of model if model.dtype == 
torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") @@ -130,7 +130,7 @@ def load_target_model(self, args, weight_dtype, accelerator): if self.is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") - model.enable_block_swap(args.blocks_to_swap, accelerator.device) + model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True) if self.use_clip_l: clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) @@ -309,6 +309,9 @@ def encode_images_to_latents(self, args, vae, images): def shift_scale_latents(self, args, latents): return latents + def cast_unet(self, args): + return not args.fp8_scaled + def get_noise_pred_and_target( self, args, @@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_dit_training_arguments(parser) flux_train_utils.add_flux_train_arguments(parser) + parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う") parser.add_argument( "--split_mode", action="store_true", diff --git a/library/flux_models.py b/library/flux_models.py index d2d7e06c7..034543f07 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -691,25 +691,32 @@ def _forward( ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) + del vec # prepare image for attention - img_modulated = self.img_norm1(img) + img_modulated = self.img_norm1(img.to(torch.float32)).to(img.dtype) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention - txt_modulated = self.txt_norm1(txt) + txt_modulated = self.txt_norm1(txt.to(torch.float32)).to(txt.dtype) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v # make attention mask if not None attn_mask = None @@ -725,14 +732,24 @@ def _forward( attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + del q, k, v, attn # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + del img_mod1, img_attn + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img.to(torch.float32)).to(img.dtype) + img_mod2.shift + ) + del img_mod2 # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + del txt_mod1, txt_attn + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * 
self.txt_norm2(txt.to(torch.float32)).to(txt.dtype) + txt_mod2.shift + ) + del txt_mod2 + return img, txt def forward( @@ -805,10 +822,14 @@ def disable_gradient_checkpointing(self): def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: mod, _ = self.modulation(vec) - x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + del vec + x_mod = (1 + mod.scale) * self.pre_norm(x.to(torch.float32)) + mod.shift + x_mod = x_mod.to(x.dtype) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del qkv q, k = self.norm(q, k, v) # make attention mask if not None @@ -831,9 +852,12 @@ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optio # compute attention attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + del q, k, v # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + del attn, mlp + return x + mod.gate * output def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: @@ -969,7 +993,7 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int, device: torch.device): + def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False): self.blocks_to_swap = num_blocks double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 @@ -980,10 +1004,10 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1215,7 +1239,7 @@ def disable_gradient_checkpointing(self): print("FLUX: Gradient checkpointing disabled.") - def enable_block_swap(self, num_blocks: int, device: torch.device): + def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False): self.blocks_to_swap = num_blocks double_blocks_to_swap = num_blocks // 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 @@ -1226,10 +1250,10 @@ def enable_block_swap(self, num_blocks: int, device: torch.device): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." 
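Note on the block-swap changes above: enable_block_swap() now forwards supports_backward to custom_offloading_utils.ModelOffloader, so the training scripts pass supports_backward=True while flux_minimal_inference.py passes supports_backward=False. The snippet below is a hypothetical, greatly simplified sketch of the forward-only idea; it is NOT the repository's ModelOffloader (which also prefetches, uses non-blocking transfers, and hooks the backward pass), and the NaiveBlockSwapper name and its parameters are made up for illustration.

# Hypothetical sketch, NOT library/custom_offloading_utils.ModelOffloader:
# the core idea of keeping only part of the transformer blocks in VRAM and
# swapping the remaining blocks between CPU and GPU around their use (forward only).
import torch
import torch.nn as nn


class NaiveBlockSwapper:
    """Keep the first blocks resident on `device`; park the last `blocks_to_swap` blocks on CPU."""

    def __init__(self, blocks: nn.ModuleList, blocks_to_swap: int, device: torch.device):
        self.blocks = blocks
        self.device = device
        self.swap_start = len(blocks) - blocks_to_swap  # blocks from this index onward get swapped
        for i, block in enumerate(blocks):
            block.to(device if i < self.swap_start else "cpu")

    def run(self, x: torch.Tensor) -> torch.Tensor:
        for i, block in enumerate(self.blocks):
            if i >= self.swap_start:
                block.to(self.device)  # bring the swapped block in just before it is needed
            x = block(x)
            if i >= self.swap_start:
                block.to("cpu")  # evict it again to free VRAM for the next block
        return x


# usage: 8 small blocks, swap the last 4 (device kept on CPU so the sketch runs anywhere)
blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(8)])
swapper = NaiveBlockSwapper(blocks, blocks_to_swap=4, device=torch.device("cpu"))
out = swapper.run(torch.randn(2, 64))

The real offloader overlaps these transfers with compute and, when supports_backward=True, re-stages blocks for the backward pass as well, which is why training enables it and inference does not.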
diff --git a/library/flux_utils.py b/library/flux_utils.py index 410b34ce2..c4bc6712f 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -1,7 +1,7 @@ import json import os from dataclasses import replace -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import einops import torch @@ -10,6 +10,8 @@ from safetensors.torch import load_file from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel +from library.fp8_optimization_utils import apply_fp8_monkey_patch +from library.lora_utils import load_safetensors_with_lora_and_fp8 from library.utils import setup_logging setup_logging() @@ -25,6 +27,9 @@ MODEL_NAME_SCHNELL = "schnell" MODEL_VERSION_CHROMA = "chroma" +FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"] +FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"] + def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ @@ -93,17 +98,23 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int def load_flow_model( - ckpt_path: str, - dtype: Optional[torch.dtype], device: Union[str, torch.device], - disable_mmap: bool = False, + ckpt_path: str, + dit_weight_dtype: Optional[torch.dtype], + loading_device: Union[str, torch.device], model_type: str = "flux", + fp8_scaled: bool = False, + lora_weights_list: Optional[Dict[str, torch.Tensor]] = None, + lora_multipliers: Optional[list[float]] = None, ) -> Tuple[bool, flux_models.Flux]: + device = torch.device(device) # device for calculation, typically "cuda" + loading_device = torch.device(loading_device) + + # build model if model_type == "flux": is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - # build model logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") with torch.device("meta"): params = flux_models.configs[name].params @@ -117,61 +128,69 @@ def load_flow_model( params = replace(params, depth_single_blocks=num_single_blocks) model = flux_models.Flux(params) - if dtype is not None: - model = model.to(dtype) + if dit_weight_dtype is not None: + model = model.to(dit_weight_dtype) - # load_sft doesn't support torch.device - logger.info(f"Loading state dict from {ckpt_path}") - sd = {} - for ckpt_path in ckpt_paths: - sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)) + elif model_type == "chroma": + from . 
import chroma_models + # build model + logger.info("Building Chroma model") + with torch.device("meta"): + model = chroma_models.Chroma(chroma_models.chroma_params) + if dit_weight_dtype is not None: + model = model.to(dit_weight_dtype) + + ckpt_paths = [ckpt_path] + + # load state dict + logger.info(f"Loading DiT model from {ckpt_paths}, device={loading_device}") + + sd = load_safetensors_with_lora_and_fp8( + model_files=ckpt_paths, + lora_weights_list=lora_weights_list, + lora_multipliers=lora_multipliers, + fp8_optimization=fp8_scaled, + calc_device=device, + move_to_device=(loading_device == device), + dit_weight_dtype=dit_weight_dtype, + target_keys=FP8_OPTIMIZATION_TARGET_KEYS, + exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS, + ) + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + if fp8_scaled: + apply_fp8_monkey_patch(model, sd, use_scaled_mm=False) + + if loading_device.type != "cpu": # in case of no block swapping + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + if model_type == "flux": # convert Diffusers to BFL if is_diffusers: logger.info("Converting Diffusers to BFL") sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") - # if the key has annoying prefix, remove it - for key in list(sd.keys()): - new_key = key.replace("model.diffusion_model.", "") - if new_key == key: - break # the model doesn't have annoying prefix - sd[new_key] = sd.pop(key) - info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return is_schnell, model elif model_type == "chroma": - from . import chroma_models - - # build model - logger.info("Building Chroma model") - with torch.device("meta"): - model = chroma_models.Chroma(chroma_models.chroma_params) - if dtype is not None: - model = model.to(dtype) - - # load_sft doesn't support torch.device - logger.info(f"Loading state dict from {ckpt_path}") - sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) - - # if the key has annoying prefix, remove it - for key in list(sd.keys()): - new_key = key.replace("model.diffusion_model.", "") - if new_key == key: - break # the model doesn't have annoying prefix - sd[new_key] = sd.pop(key) - info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Chroma: {info}") is_schnell = False # Chroma is not schnell return is_schnell, model - else: - raise ValueError(f"Unsupported model_type: {model_type}. 
Supported types are 'flux' and 'chroma'.") - def load_ae( ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False diff --git a/train_network.py b/train_network.py index 6cebf5fc7..316d03691 100644 --- a/train_network.py +++ b/train_network.py @@ -824,27 +824,31 @@ def train(self, args): accelerator.print("enable full bf16 training.") network.to(weight_dtype) - unet_weight_dtype = te_weight_dtype = weight_dtype + unet_weight_dtype = weight_dtype + te_weight_dtype = weight_dtype if self.cast_text_encoder(args) else None # Experimental Feature: Put base model into fp8 to save vram if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training for U-Net.") - unet_weight_dtype = torch.float8_e4m3fn + if self.cast_unet(args): + accelerator.print("enable fp8 training for U-Net.") + unet_weight_dtype = torch.float8_e4m3fn - if not args.fp8_base_unet: - accelerator.print("enable fp8 training for Text Encoder.") - te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn + if self.cast_text_encoder(args): + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above - logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") - unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator + if self.cast_unet(args): + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.requires_grad_(False) if self.cast_unet(args):
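Note on the two fp8 paths in train_network.py above: with --fp8_base the DiT weights are simply cast to torch.float8_e4m3fn, while with --fp8_scaled the weights are quantized with scaling during loading (load_safetensors_with_lora_and_fp8) and the affected layers are patched via apply_fp8_monkey_patch in library/flux_utils.py, which is why cast_unet() returns False in that case and the plain cast is skipped. The snippet below is an illustrative sketch of per-tensor scaled fp8 quantization only; it is NOT library/fp8_optimization_utils.py, and the helper names are made up.

# Illustrative sketch only, NOT library/fp8_optimization_utils.py: per-tensor
# scaled fp8 quantization of a Linear weight (requires torch>=2.1 for float8 dtypes).
import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn


def quantize_weight_to_fp8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize to float8_e4m3fn with a per-tensor scale chosen so the absmax fills the fp8 range."""
    scale = weight.abs().max().float() / FP8_E4M3_MAX
    scale = torch.clamp(scale, min=1e-12)  # guard against an all-zero weight
    q = (weight.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale


def dequantize_fp8_weight(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Rescale back to a compute dtype just before the matmul."""
    return q.to(dtype) * scale.to(dtype)


# usage: the quantization error stays small relative to the weight's dynamic range
w = torch.randn(3072, 3072, dtype=torch.bfloat16)
q, s = quantize_weight_to_fp8(w)
w_hat = dequantize_fp8_weight(q, s)
print((w.float() - w_hat.float()).abs().max())

Keeping the scale alongside the fp8 tensor is what lets the patched layers restore the original weight magnitude at compute time, which a plain .to(torch.float8_e4m3fn) cast cannot do; it is also why LoRA weights must be merged before quantization when --fp8_scaled is used (see the merge_lora_weights check in flux_minimal_inference.py above).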