11 changes: 11 additions & 0 deletions tests/integration_tests/flux.py
@@ -41,6 +41,17 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
"Flux Validation Test",
"validation",
),
OverrideDefinitions(
[
[
"--checkpoint.enable",
],
["--inference.prompt='A beautiful sunset over the ocean'"],
],
"Flux Generation script test",
"test_generate",
ngpu=2,
),
]
return integration_tests_flavors

31 changes: 24 additions & 7 deletions torchtitan/models/flux/inference/infer.py
@@ -25,11 +25,24 @@ def inference(config: JobConfig):
     # Distributed processing setup: Each GPU/process handles a subset of prompts
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
-    original_prompts = open(config.inference.prompts_path).readlines()
-    total_prompts = len(original_prompts)
-
-    # Distribute prompts across processes using round-robin assignment
-    prompts = original_prompts[global_rank::world_size]
+    single_prompt_mode = config.inference.prompt is not None
+
+    # Use single prompt if specified, otherwise read from file
+    if single_prompt_mode:
+        original_prompts = [config.inference.prompt]
+        logger.info(f"Using single prompt: {config.inference.prompt}")
+        bs = 1
+        # If only single prompt, each rank will generate an image with the same prompt
+        prompts = original_prompts
+    else:
+        original_prompts = open(config.inference.prompts_path).readlines()
+        logger.info(f"Reading prompts from: {config.inference.prompts_path}")
+        bs = config.inference.local_batch_size
+        # Distribute prompts across processes using round-robin assignment
+        prompts = original_prompts[global_rank::world_size]
+
+    total_prompts = len(original_prompts)

     trainer.checkpointer.load(step=config.checkpoint.load_step)
     t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config)
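The round-robin slice `original_prompts[global_rank::world_size]` used above (and the matching `global_ids` mapping in the next hunk) is easiest to see with concrete values. Here is a small standalone illustration with hypothetical prompts and a hypothetical world size, not code from the PR:

```python
# Illustration only: how the round-robin split assigns prompts and global indices.
prompts_all = ["p0", "p1", "p2", "p3", "p4"]  # 5 hypothetical prompts
world_size = 2

for rank in range(world_size):
    local_prompts = prompts_all[rank::world_size]                  # rank 0 -> p0, p2, p4; rank 1 -> p1, p3
    global_ids = list(range(rank, len(prompts_all), world_size))   # rank 0 -> [0, 2, 4]; rank 1 -> [1, 3]
    print(rank, local_prompts, global_ids)
```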
@@ -39,15 +52,19 @@ def inference(config: JobConfig):

     if prompts:
         # Generate images for this process's assigned prompts
-        bs = config.inference.local_batch_size
 
         output_dir = os.path.join(
             config.job.dump_folder,
             config.inference.save_img_folder,
         )
 
         # Create mapping from local indices to global prompt indices
-        global_ids = list(range(global_rank, total_prompts, world_size))
+        if single_prompt_mode:
+            # In single prompt mode, all ranks process the same prompt (index 0)
Contributor: Sorry, I somehow feel this PR is doing some overkill stuff. Would it be simpler if we just provided a truncated version of prompts.txt in the test assets, say 1 or 2 prompts on each rank? Our goal is to make the CI job lighter, and users should be fine always specifying prompts in a .txt file. Dealing with two paths doesn't seem elegant.

Contributor (author): Thanks! That would be easier. Specifying prompts in a .txt file is simpler, so I will remove the single-prompt path. However, there is a minor bug to fix: if the number of prompts is smaller than the number of ranks, some ranks get 0 prompts. During the forward pass of the T5/CLIP encoders the program then hangs, because FSDP is applied to the encoders and the ranks that never run forward leave the all_gather waiting. I will modify the PR to fix the bug and always use the prompts.txt file.

Contributor: Sounds good! I think we can just error out early in that case, for simplicity.

+            # But each rank generates a different image (different seed/rank)
+            global_ids = [0] * len(prompts)
+        else:
+            # In multi-prompt mode, use round-robin distribution
+            global_ids = list(range(global_rank, total_prompts, world_size))
 
         for i in range(0, len(prompts), bs):
             images = generate_image(
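Regarding the hang discussed in the review thread above (fewer prompts than ranks leaves some ranks out of the FSDP-wrapped encoder forward), a minimal sketch of the early-exit check the reviewer suggests could look like the following. The names mirror the diff, but this is an illustrative sketch, not the PR's final code:

```python
# Hypothetical early-exit guard, assuming world_size and original_prompts are
# already set up as in the diff above: fail fast instead of letting ranks with
# zero prompts skip the encoder forward and stall the FSDP all_gather.
if len(original_prompts) < world_size:
    raise ValueError(
        f"Got {len(original_prompts)} prompts for {world_size} ranks; "
        "provide at least one prompt per rank (or reduce NGPU) so every rank "
        "runs the T5/CLIP encoder forward pass."
    )
```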
2 changes: 2 additions & 0 deletions torchtitan/models/flux/job_config.py
@@ -62,6 +62,8 @@ class Inference:
"""Path to save the inference results"""
prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt"
"""Path to file with newline separated prompts to generate images for"""
prompt: str = ""
"""Single prompt to generate image for. If specified, takes precedence over prompts_path"""
local_batch_size: int = 2
"""Batch size for inference"""
img_size: int = 256
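As documented above, a non-empty `--inference.prompt` takes precedence over `prompts_path`. A rough sketch of that selection logic, assuming the unset default stays falsy; this is an illustration of the documented precedence, not the code in infer.py:

```python
# Hypothetical helper: pick the prompt source per the documented precedence
# of `prompt` over `prompts_path`.
def resolve_prompts(prompt: str, prompts_path: str) -> list[str]:
    if prompt:  # non-empty single prompt wins
        return [prompt]
    with open(prompts_path) as f:
        return f.readlines()
```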
2 changes: 1 addition & 1 deletion torchtitan/models/flux/run_infer.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/models/flux/run_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}

2 changes: 1 addition & 1 deletion torchtitan/models/flux/run_train.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
