
Commit c0f89e6

Img2vid pipeline (#227)

* updated running instructions
* fix logging issue
* img2vid initial
* img2vid initial commit
* deleted images
* resolved readme conflict
* resolved readme conflict
* img2vid pipeline added
* compatible with text2vid

1 parent 38283b5 commit c0f89e6

File tree: 5 files changed (+401, -40 lines)

README.md

Lines changed: 4 additions & 2 deletions
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

 # What's new?
+- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
 - **`2025/7/29`**: LTX-Video text2vid generation is now supported.
 - **`2025/04/17`**: Flux Finetuning.
 - **`2025/02/12`**: Flux LoRA for inference.
@@ -42,7 +43,7 @@ MaxDiffusion supports
 * Load Multiple LoRA (SDXL inference).
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x, 2.x.
-* LTX-Video text2vid (inference).
+* LTX-Video text2vid, img2vid (inference).


 # Table of Contents
@@ -183,7 +184,8 @@ To generate images, run the following command:
 ```bash
 python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
 ```
-- Other generation parameters can be set in ltx_video.yml file.
+- Img2vid generation:
+  Add the conditioning image path as `conditioning_media_paths` in the form `["IMAGE_PATH"]`, along with the other generation parameters, in the ltx_video.yml file. Then follow the same instructions as above.
 ## Flux

 First make sure you have permissions to access the Flux repos in Huggingface.
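In practice, enabling img2vid amounts to switching `conditioning_media_paths` from its disabled default to a list in `src/maxdiffusion/configs/ltx_video.yml`, as in this sketch (the image path is a placeholder, not a file shipped with the repo):

```yml
# Hypothetical img2vid settings for ltx_video.yml; the path is a placeholder.
conditioning_media_paths: ["/tmp/first_frame.png"]  # image(s) to condition the video on
conditioning_start_frames: [0]                      # frame index where each image is injected
```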

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 3 additions & 1 deletion
@@ -22,7 +22,7 @@ sampler: "from_checkpoint"

 # Generation parameters
 pipeline_type: multi-scale
-prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. "
+prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
 #negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
 height: 512
 width: 512
@@ -35,6 +35,8 @@ stg_mode: "attention_values"
 decode_timestep: 0.05
 decode_noise_scale: 0.025
 seed: 10
+conditioning_media_paths: None #["IMAGE_PATH"]
+conditioning_start_frames: [0]


 first_pass:
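Note that the default leaves `conditioning_media_paths` as the literal `None`, and the generation script (next file) only activates conditioning when the parsed value is actually a list. A minimal sketch of that gate, with hypothetical values:

```python
# Sketch of the type gate used in generate_ltx_video.py's run(): any non-list
# value (such as the YAML default above, however pyconfig parses it) falls
# back to the plain text2vid path.
from typing import List


def resolve_conditioning_paths(raw):
  return raw if isinstance(raw, List) else None


print(resolve_conditioning_paths("None"))          # None -> text2vid
print(resolve_conditioning_paths(["frame0.png"]))  # ['frame0.png'] -> img2vid
```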

src/maxdiffusion/generate_ltx_video.py

Lines changed: 112 additions & 2 deletions
@@ -16,15 +16,19 @@

 import numpy as np
 from absl import app
-from typing import Sequence
+from typing import Sequence, List, Optional, Union
 from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
-from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
+from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
+import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
 from maxdiffusion import pyconfig, max_logging
+import torchvision.transforms.functional as TVF
 import imageio
 from datetime import datetime
 import os
 import time
 from pathlib import Path
+from PIL import Image
+import torch


 def calculate_padding(
@@ -44,6 +48,79 @@ def calculate_padding(
   return padding


+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, Image.Image],
+    target_height: int = 512,
+    target_width: int = 768,
+    just_crop: bool = False,
+) -> torch.Tensor:
+  """Load and process an image into a tensor.
+
+  Args:
+      image_input: Either a file path (str) or a PIL Image object
+      target_height: Desired height of output tensor
+      target_width: Desired width of output tensor
+      just_crop: If True, only crop the image to the target size without resizing
+  """
+  if isinstance(image_input, str):
+    image = Image.open(image_input).convert("RGB")
+  elif isinstance(image_input, Image.Image):
+    image = image_input
+  else:
+    raise ValueError("image_input must be either a file path or a PIL Image object")
+
+  input_width, input_height = image.size
+  aspect_ratio_target = target_width / target_height
+  aspect_ratio_frame = input_width / input_height
+  if aspect_ratio_frame > aspect_ratio_target:
+    new_width = int(input_height * aspect_ratio_target)
+    new_height = input_height
+    x_start = (input_width - new_width) // 2
+    y_start = 0
+  else:
+    new_width = input_width
+    new_height = int(input_width / aspect_ratio_target)
+    x_start = 0
+    y_start = (input_height - new_height) // 2
+
+  image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+  if not just_crop:
+    image = image.resize((target_width, target_height))
+
+  frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
+  frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
+  frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
+  frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
+  frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
+  frame_tensor = (frame_tensor / 127.5) - 1.0
+  # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+  return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+
+def prepare_conditioning(
+    conditioning_media_paths: List[str],
+    conditioning_strengths: List[float],
+    conditioning_start_frames: List[int],
+    height: int,
+    width: int,
+    padding: tuple[int, int, int, int],
+) -> Optional[List[ConditioningItem]]:
+  """Prepare conditioning items based on input media paths and their parameters."""
+  conditioning_items = []
+  for path, strength, start_frame in zip(conditioning_media_paths, conditioning_strengths, conditioning_start_frames):
+    num_input_frames = 1
+    media_tensor = load_media_file(
+        media_path=path,
+        height=height,
+        width=width,
+        max_frames=num_input_frames,
+        padding=padding,
+        just_crop=True,
+    )
+    conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+  return conditioning_items
+
+
 def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
   # Remove non-letters and convert to lowercase
   clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
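The crop-then-resize order in the new helper guarantees the conditioning frame is never stretched: the source is first trimmed symmetrically to the target aspect ratio, then resized. A standalone sketch of that arithmetic, with hypothetical dimensions:

```python
# Hypothetical 16:9 source cropped for a 1.5:1 target; mirrors the
# "frame wider than target" branch of load_image_to_tensor_with_resize_and_crop.
input_width, input_height = 1920, 1080
target_width, target_height = 768, 512

aspect_target = target_width / target_height   # 1.5
aspect_frame = input_width / input_height      # ~1.78 > 1.5, so trim width
new_width = int(input_height * aspect_target)  # 1620 px of width survive the crop
x_start = (input_width - new_width) // 2       # 150 px trimmed from each side
crop_box = (x_start, 0, x_start + new_width, input_height)
print(crop_box)  # (150, 0, 1770, 1080) -> then resize to (768, 512), no distortion
```

After the resize, the helper blurs the frame, round-trips it through the CRF compressor, rescales to [-1, 1], and unsqueezes to the (1, 3, 1, H, W) layout the pipeline expects.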
@@ -68,6 +145,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
   return "-".join(result)


+def load_media_file(
+    media_path: str,
+    height: int,
+    width: int,
+    max_frames: int,
+    padding: tuple[int, int, int, int],
+    just_crop: bool = False,
+) -> torch.Tensor:
+  media_tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width, just_crop=just_crop)
+  media_tensor = torch.nn.functional.pad(media_tensor, padding)
+  return media_tensor
+
+
 def get_unique_filename(
     base: str,
     ext: str,
@@ -97,6 +187,25 @@ def run(config):
   pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
   if config.pipeline_type == "multi-scale":
     pipeline = LTXMultiScalePipeline(pipeline)
+  conditioning_media_paths = config.conditioning_media_paths if isinstance(config.conditioning_media_paths, List) else None
+  conditioning_start_frames = config.conditioning_start_frames
+  conditioning_strengths = None
+  if conditioning_media_paths:
+    if not conditioning_strengths:
+      conditioning_strengths = [1.0] * len(conditioning_media_paths)
+  conditioning_items = (
+      prepare_conditioning(
+          conditioning_media_paths=conditioning_media_paths,
+          conditioning_strengths=conditioning_strengths,
+          conditioning_start_frames=conditioning_start_frames,
+          height=config.height,
+          width=config.width,
+          padding=padding,
+      )
+      if conditioning_media_paths
+      else None
+  )
+
   s0 = time.perf_counter()
   images = pipeline(
       height=height_padded,
@@ -106,6 +215,7 @@ def run(config):
       output_type="pt",
       config=config,
       enhance_prompt=enhance_prompt,
+      conditioning_items=conditioning_items,
       seed=config.seed,
   )
   max_logging.log(f"Compile time: {time.perf_counter() - s0:.1f}s.")
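The conditioning path can also be exercised in isolation. A minimal sketch, assuming the helpers above are in scope; the image path, sizes, and zero padding are hypothetical (a real run uses whatever `calculate_padding` returns):

```python
# Hypothetical standalone driver for the new helpers; not the repo's CLI.
paths = ["/tmp/first_frame.png"]  # would come from conditioning_media_paths
starts = [0]                      # would come from conditioning_start_frames
strengths = [1.0] * len(paths)    # the default run() applies when none are set

items = prepare_conditioning(
    conditioning_media_paths=paths,
    conditioning_strengths=strengths,
    conditioning_start_frames=starts,
    height=512,
    width=512,
    padding=(0, 0, 0, 0),  # no-op padding for already-aligned dimensions
)
# Each ConditioningItem wraps a (1, 3, 1, 512, 512) tensor plus its start frame
# and strength, and is handed to the pipeline via conditioning_items=items.
```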
src/maxdiffusion/pipelines/ltx_video/crf_compressor.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright 2025 Lightricks Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This implementation is based on the Torch version available at:
+# https://github.com/Lightricks/LTX-Video/tree/main
+import av
+import torch
+import io
+import numpy as np
+
+
+def _encode_single_frame(output_file, image_array: np.ndarray, crf):
+  container = av.open(output_file, "w", format="mp4")
+  try:
+    stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
+    stream.height = image_array.shape[0]
+    stream.width = image_array.shape[1]
+    av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
+    container.mux(stream.encode(av_frame))
+    container.mux(stream.encode())
+  finally:
+    container.close()
+
+
+def _decode_single_frame(video_file):
+  container = av.open(video_file)
+  try:
+    stream = next(s for s in container.streams if s.type == "video")
+    frame = next(container.decode(stream))
+  finally:
+    container.close()
+  return frame.to_ndarray(format="rgb24")
+
+
+def compress(image: torch.Tensor, crf=29):
+  if crf == 0:
+    return image
+
+  image_array = (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
+  with io.BytesIO() as output_file:
+    _encode_single_frame(output_file, image_array, crf)
+    video_bytes = output_file.getvalue()
+  with io.BytesIO(video_bytes) as video_file:
+    image_array = _decode_single_frame(video_file)
+  tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+  return tensor
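The compressor can be sanity-checked on its own: it truncates odd dimensions to satisfy yuv420p, encodes one H.264 frame in memory, and decodes it back. A small sketch with a synthetic frame (requires PyAV with libx264 available):

```python
# Round-trip a synthetic (H, W, C) frame through the single-frame H.264 pass.
import torch
from maxdiffusion.pipelines.ltx_video import crf_compressor

frame = torch.rand(512, 768, 3)             # [0, 1] float, even H and W
out = crf_compressor.compress(frame, crf=29)
print(out.shape, out.dtype)                 # torch.Size([512, 768, 3]) torch.float32
print((out - frame).abs().mean().item())    # nonzero: quantifies the codec loss
assert crf_compressor.compress(frame, crf=0) is frame  # crf=0 is a bypass
```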
