From 083e198dc722e550acd0ce54a755e82c8662deaa Mon Sep 17 00:00:00 2001
From: tjkemp
Date: Wed, 30 Jul 2025 14:57:03 +0300
Subject: [PATCH 1/2] Fix attention mask handling in batch generation

---
 examples/hunyuan_video_usp_example.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py
index 03856f11..a9041ac7 100644
--- a/examples/hunyuan_video_usp_example.py
+++ b/examples/hunyuan_video_usp_example.py
@@ -92,14 +92,8 @@ def new_forward(
             get_sequence_parallel_world_size(),
             dim=-2)[get_sequence_parallel_rank()]
 
-        encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
-        encoder_hidden_states_indices = torch.arange(
-            encoder_hidden_states.shape[1],
-            device=encoder_hidden_states.device)
-        encoder_hidden_states_indices = encoder_hidden_states_indices[
-            encoder_attention_mask]
-        encoder_hidden_states = encoder_hidden_states[
-            ..., encoder_hidden_states_indices, :]
+        encoder_attention_mask = encoder_attention_mask.to(torch.bool).any(dim=0)
+        encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask, :]
         if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size(
         ) != 0:
             get_runtime_state().split_text_embed_in_sp = False
@@ -297,7 +291,7 @@ def main():
         guidance_scale=input_config.guidance_scale,
         generator=torch.Generator(device="cuda").manual_seed(
             input_config.seed),
-    ).frames[0]
+    )
     end_time = time.time()
     elapsed_time = end_time - start_time
 
@@ -311,9 +305,10 @@
     )
     if is_dp_last_group():
         resolution = f"{input_config.width}x{input_config.height}"
-        output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4"
-        export_to_video(output, output_filename, fps=15)
-        print(f"output saved to {output_filename}")
+        for idx, frames in enumerate(output.frames, start=1):
+            output_filename = f"results/hunyuan_video_{idx:02d}_{parallel_info}_{resolution}.mp4"
+            export_to_video(frames, output_filename, fps=15)
+            print(f"output saved to {output_filename}")
 
     if get_world_group().rank == get_world_group().world_size - 1:
         print(

From a44477327e18571bf179712df9fe006996837370 Mon Sep 17 00:00:00 2001
From: tjkemp
Date: Tue, 2 Sep 2025 16:05:23 +0300
Subject: [PATCH 2/2] Fix hard-coded batch_size

---
 examples/hunyuan_video_usp_example.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py
index a9041ac7..cc354aa3 100644
--- a/examples/hunyuan_video_usp_example.py
+++ b/examples/hunyuan_video_usp_example.py
@@ -228,7 +228,7 @@ def main():
         height=input_config.height,
         width=input_config.width,
         num_frames=input_config.num_frames,
-        batch_size=1,
+        batch_size=input_config.batch_size,
         num_inference_steps=input_config.num_inference_steps,
         split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
     )
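
Note on the mask change in PATCH 1/2: with batch_size > 1, prompts of
different lengths yield per-sample attention masks that disagree, so indexing
with sample 0's mask alone (encoder_attention_mask[0]) silently drops valid
text tokens from every other sample in the batch. Taking the union across the
batch with .any(dim=0) keeps each position that at least one sample marks as
valid; shorter prompts then retain a few padding positions instead of other
samples losing real tokens. A minimal standalone sketch of the two behaviours,
using made-up shapes rather than the real pipeline tensors:

    import torch

    # (batch, seq) mask: sample 0 has 3 valid text tokens, sample 1 has 4.
    encoder_attention_mask = torch.tensor([
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
    ])
    # (batch, seq, hidden) dummy text embeddings.
    encoder_hidden_states = torch.randn(2, 5, 8)

    # Old behaviour: sample 0's mask decides for the whole batch, so
    # sample 1 loses its fourth token.
    old_mask = encoder_attention_mask[0].to(torch.bool)
    print(encoder_hidden_states[:, old_mask, :].shape)  # torch.Size([2, 3, 8])

    # New behaviour: keep every position that any sample marks as valid.
    new_mask = encoder_attention_mask.to(torch.bool).any(dim=0)
    print(encoder_hidden_states[:, new_mask, :].shape)  # torch.Size([2, 4, 8])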
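
Note on the export change in PATCH 1/2: the pipeline call previously kept only
.frames[0], so a batched run generated every video but saved just the first.
The loop in the last hunk writes one numbered file per prompt instead. A
sketch of the same pattern, assuming (as with diffusers video pipelines) that
output.frames holds one frame sequence per batch item and that export_to_video
scales numpy float frames from [0, 1]; the dummy clips below stand in for real
pipeline output:

    import os

    import numpy as np
    from diffusers.utils import export_to_video

    os.makedirs("results", exist_ok=True)

    # Stand-in for output.frames: two 8-frame clips of 64x64 RGB frames,
    # one black and one white.
    frames_per_prompt = [
        [np.zeros((64, 64, 3), dtype=np.float32) for _ in range(8)],
        [np.ones((64, 64, 3), dtype=np.float32) for _ in range(8)],
    ]

    for idx, frames in enumerate(frames_per_prompt, start=1):
        output_filename = f"results/hunyuan_video_{idx:02d}.mp4"
        export_to_video(frames, output_filename, fps=15)
        print(f"output saved to {output_filename}")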