93 changes: 92 additions & 1 deletion docs/source/en/api/pipelines/ltx_video.md
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025` (these settings are combined in the sketch after this list).
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.

- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
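
Taken together, a minimal sketch of these recommendations, assuming the `Lightricks/LTX-Video-0.9.5` checkpoint and a prompt-only [`LTXConditionPipeline`] call (a non-distilled variant, hence `guidance_scale=5.0`):

```py
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Recommended dtype for transformer, VAE, and text encoder (see the notes above)
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16
).to("cuda")

video = pipe(
    prompt="a chimpanzee walks in the jungle",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=161,
    guidance_scale=5.0,            # use 1.0 for guidance-distilled variants
    decode_timestep=0.05,          # timestep-aware VAE (0.9.1 and above)
    image_cond_noise_scale=0.025,  # timestep-aware VAE (0.9.1 and above)
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```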

@@ -414,6 +414,91 @@ export_to_video(video, "output.mp4", fps=24)

</details>

<details>
<summary>Long image-to-video generation with multi-prompt sliding windows (ComfyUI parity)</summary>

```py
import torch
from diffusers import LTXI2VLongMultiPromptPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
from diffusers.utils import export_to_video
from PIL import Image


# Stage A: long I2V with sliding windows and multi-prompt scheduling
pipe = LTXI2VLongMultiPromptPipeline.from_pretrained(
    "LTX-Video-0.9.8-13B-distilled",
    torch_dtype=torch.bfloat16
).to("cuda")

# Prompts separated by "|" are assigned to successive sliding windows in order
schedule = "a chimpanzee walks in the jungle |a chimpanzee stops and eats a snack |a chimpanzee lays on the ground"
cond_image = Image.open("chimpanzee_l.jpg").convert("RGB")

latents = pipe(
    prompt=schedule,
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=512,             # must be divisible by 32
    num_frames=361,
    temporal_tile_size=120,
    temporal_overlap=32,
    sigmas=[1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0],
    guidance_scale=1.0,     # distilled variants typically use 1.0
    cond_image=cond_image,  # hard-conditions the first frame
    adain_factor=0.25,      # cross-window normalization
    output_type="latent",   # return latent-space video for downstream processing
).frames

# Optional: decode with VAE tiling
video_pil = pipe.vae_decode_tiled(latents, decode_timestep=0.05, decode_noise_scale=0.025, output_type="pil")[0]
export_to_video(video_pil, "ltx_i2v_long_base.mp4", fps=24)

# Stage B (optional): spatial latent upsampling + short refinement
upsampler = LTXLatentUpsamplerModel.from_pretrained("LTX-Video-spatial-upscaler-0.9.8/latent_upsampler", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=upsampler).to(torch.bfloat16).to("cuda")

up_latents = pipe_upsample(
    latents=latents,
    adain_factor=1.0,                # fully re-normalize latent statistics after upsampling
    tone_map_compression_ratio=0.6,  # compress the latent dynamic range before refinement
    output_type="latent"
).frames

# Optional: fuse the IC-LoRA detailer before the refinement pass
try:
    pipe.load_lora_weights(
        "LTX-Video-ICLoRA-detailer-13b-0.9.8/ltxv-098-ic-lora-detailer-diffusers.safetensors",
        adapter_name="ic-detailer",
    )
    pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
    print("[Info] IC-LoRA detailer adapter loaded and fused.")
except Exception as e:
    print(f"[Warn] Failed to load IC-LoRA: {e}. Skipping the second refinement sampling.")

# Short refinement pass (distilled; low steps)
frames_refined = pipe(
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=361,
    temporal_tile_size=80,
    temporal_overlap=24,
    seed=1625,
    adain_factor=0.0,            # disable AdaIN in refinement
    latents=up_latents,          # start from upscaled latents
    guidance_latents=up_latents,
    sigmas=[0.99, 0.9094, 0.0],  # short sigma schedule
    output_type="pil",
).frames[0]

export_to_video(frames_refined, "ltx_i2v_long_refined.mp4", fps=24)
```

Notes:
- Seeding: window-local hard-condition noise uses `seed + w_start` when `seed` is provided; otherwise the passed-in `generator` drives stochasticity (see the sketch after this section).
- Height/width must be divisible by 32; latent shapes follow the pipeline docstrings.
- Use VAE tiled decoding to avoid OOM for high resolutions or long sequences.
- Distilled variants generally prefer `guidance_scale=1.0` and short schedules for refinement.
</details>
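
The notes on windowing and seeding can be made concrete with a small sketch. The `iter_windows` helper below is illustrative only, not part of the pipeline API; it mirrors how overlapping temporal tiles and window-local seeds would be derived under the semantics described above:

```py
import torch

def iter_windows(num_frames: int, tile_size: int, overlap: int):
    # Illustrative only: yield (start, end) frame indices for overlapping temporal tiles
    stride = tile_size - overlap
    start = 0
    while True:
        end = min(start + tile_size, num_frames)
        yield start, end
        if end == num_frames:
            break
        start += stride

seed = 1625
for w_start, w_end in iter_windows(num_frames=361, tile_size=120, overlap=32):
    # Window-local generator seeded with `seed + w_start`, per the seeding note above
    generator = torch.Generator().manual_seed(seed + w_start)
    print(f"window [{w_start}, {w_end}) -> seed {seed + w_start}")
```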

- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].

<details>
@@ -474,6 +559,12 @@ export_to_video(video, "output.mp4", fps=24)

</details>

## LTXI2VLongMultiPromptPipeline

[[autodoc]] LTXI2VLongMultiPromptPipeline
- all
- __call__

## LTXPipeline

[[autodoc]] LTXPipeline
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -514,6 +514,7 @@
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LTXI2VLongMultiPromptPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -1191,6 +1192,7 @@
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LTXI2VLongMultiPromptPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -285,6 +285,7 @@
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
@@ -691,7 +692,7 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline, LTXI2VLongMultiPromptPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/ltx/__init__.py
@@ -27,6 +27,7 @@
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
_import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -41,6 +42,7 @@
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline

else:
import sys