diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py index 321ac1280c..d759fd024b 100755 --- a/tests/integration_tests/flux.py +++ b/tests/integration_tests/flux.py @@ -41,6 +41,17 @@ def build_flux_test_list() -> list[OverrideDefinitions]: "Flux Validation Test", "validation", ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable", + ], + ["--inference.prompt='A beautiful sunset over the ocean'"], + ], + "Flux Generation script test", + "test_generate", + ngpu=2, + ), ] return integration_tests_flavors diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index 0c06a385ef..cdb5e73135 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -25,11 +25,24 @@ def inference(config: JobConfig): # Distributed processing setup: Each GPU/process handles a subset of prompts world_size = int(os.environ["WORLD_SIZE"]) global_rank = int(os.environ["RANK"]) - original_prompts = open(config.inference.prompts_path).readlines() - total_prompts = len(original_prompts) - # Distribute prompts across processes using round-robin assignment - prompts = original_prompts[global_rank::world_size] + single_prompt_mode = bool(config.inference.prompt) + + # Use single prompt if specified, otherwise read from file + if single_prompt_mode: + original_prompts = [config.inference.prompt] + logger.info(f"Using single prompt: {config.inference.prompt}") + bs = 1 + # If only single prompt, each rank will generate an image with the same prompt + prompts = original_prompts + else: + original_prompts = open(config.inference.prompts_path).readlines() + logger.info(f"Reading prompts from: {config.inference.prompts_path}") + bs = config.inference.local_batch_size + # Distribute prompts across processes using round-robin assignment + prompts = original_prompts[global_rank::world_size] + + total_prompts = len(original_prompts) 
trainer.checkpointer.load(step=config.checkpoint.load_step) t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config) @@ -39,15 +52,19 @@ def inference(config: JobConfig): if prompts: # Generate images for this process's assigned prompts - bs = config.inference.local_batch_size output_dir = os.path.join( config.job.dump_folder, config.inference.save_img_folder, ) - # Create mapping from local indices to global prompt indices - global_ids = list(range(global_rank, total_prompts, world_size)) + if single_prompt_mode: + # In single prompt mode, all ranks process the same prompt (index 0) + # But each rank generates a different image (different seed/rank) + global_ids = [0] * len(prompts) + else: + # In multi-prompt mode, use round-robin distribution + global_ids = list(range(global_rank, total_prompts, world_size)) for i in range(0, len(prompts), bs): images = generate_image( diff --git a/torchtitan/models/flux/job_config.py b/torchtitan/models/flux/job_config.py index 60422de2ee..f9bda99760 100644 --- a/torchtitan/models/flux/job_config.py +++ b/torchtitan/models/flux/job_config.py @@ -62,6 +62,8 @@ class Inference: """Path to save the inference results""" prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt" """Path to file with newline separated prompts to generate images for""" + prompt: str = "" + """Single prompt to generate image for. If specified, takes precedence over prompts_path""" local_batch_size: int = 2 """Batch size for inference""" img_size: int = 256 diff --git a/torchtitan/models/flux/run_infer.sh b/torchtitan/models/flux/run_infer.sh index bf1b4aa5a6..67c2690a49 100755 --- a/torchtitan/models/flux/run_infer.sh +++ b/torchtitan/models/flux/run_infer.sh @@ -10,7 +10,7 @@ set -ex # use envs as local overrides for convenience # e.g. 
# LOG_RANK=0,1 NGPU=4 ./torchtitan/models/flux/run_train.sh -NGPU=${NGPU:-"8"} +NGPU=${NGPU:-"4"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"} diff --git a/torchtitan/models/flux/run_train.sh b/torchtitan/models/flux/run_train.sh index 2661e02691..dda0515f19 100755 --- a/torchtitan/models/flux/run_train.sh +++ b/torchtitan/models/flux/run_train.sh @@ -10,7 +10,7 @@ set -ex # use envs as local overrides for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh -NGPU=${NGPU:-"8"} +NGPU=${NGPU:-"4"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}