130 changes: 97 additions & 33 deletions flux_minimal_inference.py
@@ -18,6 +18,7 @@

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from networks import oft_flux

init_ipex()
@@ -325,7 +326,8 @@ def encode(prpt: str):

# generate image
logger.info("Generating image...")
model = model.to(device)
if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0):
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50

@@ -411,12 +413,16 @@ def encode(prpt: str):
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage"
)
parser.add_argument(
"--lora_weights",
type=str,
@@ -442,6 +448,8 @@ def is_fp8(dt):
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
if args.fp8_scaled and flux_dtype.itemsize == 1:
raise ValueError("fp8_scaled is not supported for fp8 flux_dtype")

logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")

@@ -470,13 +478,68 @@ def is_fp8(dt):
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)

# check LoRA and OFT weights can be mergeable
mergeable_lora_weights = None
mergeable_lora_multipliers = None
if args.fp8_scaled and args.lora_weights:
assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled"

mergeable_lora_weights = []
mergeable_lora_multipliers = []

for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0

with MemoryEfficientSafeOpen(weights_file) as f:
keys = list(f.keys())

is_lora = is_oft = False
includes_text_encoder = False
for key in keys:
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if key.startswith("lora_te") or key.startswith("oft_te"):
includes_text_encoder = True
if (is_lora or is_oft) and includes_text_encoder:
break

if includes_text_encoder or is_oft:
raise ValueError(
f"LoRA weights {weights_file} include text encoder or OFT weights and cannot be merged when using fp8_scaled"
)

mergeable_lora_weights.append(weights_file)
mergeable_lora_multipliers.append(multiplier)

# DiT
loading_dtype = None if args.fp8_scaled else flux_dtype
loading_device = "cpu" if args.blocks_to_swap or args.offload else device

is_schnell, model = flux_utils.load_flow_model(
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
device,
args.ckpt_path,
loading_dtype,
loading_device,
args.model_type,
args.fp8_scaled,
lora_weights_list=mergeable_lora_weights,
lora_multipliers=mergeable_lora_multipliers,
)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype

if args.blocks_to_swap is not None and args.blocks_to_swap > 0:
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False)
model.move_to_device_except_swap_blocks(device)
model.prepare_block_swap_before_forward()

# logger.info(f"Casting model to {flux_dtype}")
# model.to(flux_dtype) # make sure model is dtype
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
@@ -494,36 +557,37 @@ def is_fp8(dt):

# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0

weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break

module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
lora_model.apply_to([clip_l, t5xxl], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
if not args.fp8_scaled: # LoRA cannot be applied after fp8 scaling and quantization
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0

weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break

module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
lora_model.apply_to([clip_l, t5xxl], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)

lora_models.append(lora_model)
lora_models.append(lora_model)

if not args.interactive:
generate_image(
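For readers skimming the diff, the new fp8_scaled pre-check in flux_minimal_inference.py boils down to two small operations: parse the "path;multiplier" form of each --lora_weights entry, and inspect the safetensors keys to confirm the file is a plain LoRA that does not touch the text encoders. The following is a minimal self-contained sketch, assuming the standard safetensors.safe_open in place of the repository's MemoryEfficientSafeOpen; the helper names are illustrative, not part of the PR:

# Illustrative sketch only; parse_lora_spec / is_mergeable_for_fp8 are hypothetical names.
from typing import Tuple
from safetensors import safe_open  # stand-in for library.safetensors_utils.MemoryEfficientSafeOpen


def parse_lora_spec(spec: str) -> Tuple[str, float]:
    # "--lora_weights" entries are either "path" or "path;multiplier"
    if ";" in spec:
        path, multiplier = spec.split(";")
        return path, float(multiplier)
    return spec, 1.0


def is_mergeable_for_fp8(weights_file: str) -> bool:
    # Mergeable before fp8 scaling only if the file is a LoRA (no "oft*" keys)
    # and contains no text-encoder modules ("lora_te*" / "oft_te*" keys).
    with safe_open(weights_file, framework="pt") as f:
        for key in f.keys():
            if key.startswith("oft") or key.startswith("lora_te"):
                return False
    return True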
4 changes: 2 additions & 2 deletions flux_train.py
@@ -271,7 +271,7 @@ def train(args):

# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
)

if args.gradient_checkpointing:
@@ -302,7 +302,7 @@
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

if not cache_latents:
# load VAE here if not cached
4 changes: 2 additions & 2 deletions flux_train_control_net.py
@@ -265,7 +265,7 @@ def train(args):

# load FLUX
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
)
flux.requires_grad_(False)

@@ -304,7 +304,7 @@
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
# ControlNet only has two blocks, so we can keep it on GPU
# controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
24 changes: 14 additions & 10 deletions flux_train_network.py
@@ -51,7 +51,9 @@ def assert_extra_args(
self.use_clip_l = True
else:
self.use_clip_l = False # Chroma does not use CLIP-L
assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
assert (
args.apply_t5_attn_mask
), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"

if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -100,17 +102,15 @@ def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models

# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device

# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
# load with quantization if needed
_, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
model_type=self.model_type,
accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled
)
if args.fp8_base:

if args.fp8_base and not args.fp8_scaled:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
@@ -130,7 +130,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

if self.use_clip_l:
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -309,6 +309,9 @@ def encode_images_to_latents(self, args, vae, images):
def shift_scale_latents(self, args, latents):
return latents

def cast_unet(self, args):
return not args.fp8_scaled

def get_noise_pred_and_target(
self,
args,
@@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--split_mode",
action="store_true",
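Across all four files, flux_utils.load_flow_model is now called with the same argument order. The actual definition lives in library/flux_utils.py and is not shown in this diff; the stub below is only an inference from the call sites, with assumed parameter names and defaults:

# Inferred from the call sites above; names and defaults are assumptions, not the real definition.
def load_flow_model(
    device,                   # target device (accelerator.device / the inference device)
    ckpt_path,                # path to the FLUX checkpoint
    dtype,                    # loading dtype; None keeps fp8 weights as stored
    loading_device,           # "cpu" when offloading or block-swapping, else the target device
    model_type="flux",        # e.g. "flux" or the Chroma variant
    fp8_scaled=False,         # quantize to scaled fp8 while loading
    lora_weights_list=None,   # LoRA files to merge before fp8 scaling (inference path)
    lora_multipliers=None,    # per-file multipliers
):
    # returns (is_schnell, model)
    ...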