
Conversation

@likholat (Contributor)

Description

LTX Text2Video

Continuation of #2982

CVS-164653

Checklist:

  • Tests have been updated or added to cover the new code.
  • This patch fully addresses the ticket.
  • I have made corresponding changes to the documentation.

likholat requested a review from Wovchena November 10, 2025 12:04
likholat self-assigned this Nov 10, 2025
likholat mentioned this pull request Nov 10, 2025
github-actions bot added labels: category: image generation, category: cmake / build, category: Python API, category: CPP API, no-match-files, category: Image generation samples (Nov 10, 2025)
likholat requested a review from rkazants November 10, 2025 12:08
Wovchena requested a review from Copilot November 18, 2025 08:01
Copilot AI left a comment


Pull Request Overview

This PR implements a Text2Video pipeline for the LTX-Video model, building on previous work. The implementation includes the text-to-video generation flow, model wrappers for the transformer and VAE components, configuration management, and sample applications demonstrating usage.

Key Changes:

  • Adds a Text2VideoPipeline class and related video generation infrastructure (see the sketch after this list)
  • Implements LTX-Video specific models (transformer and VAE decoder)
  • Extends T5EncoderModel to support attention masks and tokenization parameters
  • Adds video similarity utility and sample applications
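
To make the new surface concrete, here is a minimal usage sketch in the style of the repository's image generation samples. Only the Text2VideoPipeline class name comes from this PR; the header path, the generate() signature, the property helpers, and the result type are assumptions and may differ from the actual implementation.

```cpp
// Hypothetical usage sketch, NOT the confirmed API. Only the
// Text2VideoPipeline name is taken from this PR; the header path,
// generate() signature, and property helpers mirror the image
// generation pipelines and are assumptions.
#include "openvino/genai/video_generation/text2video_pipeline.hpp"

#include <cstdlib>
#include <iostream>

int main(int argc, char* argv[]) try {
    if (argc != 3) {
        std::cerr << "Usage: " << argv[0] << " <MODEL_DIR> '<PROMPT>'\n";
        return EXIT_FAILURE;
    }
    const std::string device = "CPU";  // GPU can be used as well

    ov::genai::Text2VideoPipeline pipe(argv[1], device);

    // Property names below follow the image generation API and are
    // assumptions for the video pipeline.
    ov::Tensor video = pipe.generate(
        argv[2],
        ov::genai::width(512),
        ov::genai::height(320),
        ov::genai::num_inference_steps(25));

    std::cout << "Generated video tensor shape: " << video.get_shape() << '\n';
    return EXIT_SUCCESS;
} catch (const std::exception& e) {
    std::cerr << e.what() << '\n';
    return EXIT_FAILURE;
}
```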

Reviewed Changes

Copilot reviewed 25 out of 25 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| src/cpp/src/video_generation/text2video_pipeline.cpp | Core pipeline implementation with latent packing/unpacking and video post-processing |
| src/cpp/src/video_generation/ltx_video_transformer_3d_model.cpp | Wrapper for the LTX video transformer model |
| src/cpp/src/video_generation/autoencoder_kl_ltx_video.cpp | VAE decoder implementation for video |
| src/cpp/src/image_generation/models/t5_encoder_model.cpp | Extended to support attention masks and custom tokenization parameters |
| src/cpp/include/openvino/genai/video_generation/generation_config.hpp | Video-specific generation configuration |
| samples/cpp/video_generation/text2video.cpp | Sample application demonstrating video generation |
| video_similarity.py | Utility for computing video similarity metrics |
Comments suppressed due to low confidence (1)

src/cpp/src/video_generation/text2video_pipeline.cpp:1

  • String comparison operands are reversed. Should be if __name__ == "__main__":
// Copyright (C) 2025 Intel Corporation



#include "image_generation/numpy_utils.hpp"
#include "utils.hpp"
#include "debug_utils.hpp"

Copilot AI Nov 18, 2025


[nitpick] Debug utility header included in production code. Consider removing this include if debug_utils.hpp is only intended for debugging purposes and not needed in production builds.

Suggested change:
- #include "debug_utils.hpp"

tokenization_params
);
},
py::call_guard<py::gil_scoped_release>(),

Copilot AI Nov 18, 2025


Redundant GIL release guard. The lambda already releases the GIL at line 193, making this call_guard unnecessary and potentially incorrect. Remove this line.

Suggested change:
- py::call_guard<py::gil_scoped_release>(),
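
To illustrate the double-release problem this comment describes, here is a small self-contained pybind11 sketch; none of the names come from the PR:

```cpp
// Hypothetical binding showing why combining an in-lambda
// py::gil_scoped_release with py::call_guard<py::gil_scoped_release>
// is wrong: the GIL would be released twice.
#include <pybind11/pybind11.h>

namespace py = pybind11;

void long_running_work() { /* heavy C++ code, no Python calls */ }

PYBIND11_MODULE(example, m) {
    m.def("generate", [] {
        // The lambda drops the GIL itself around the heavy work...
        py::gil_scoped_release release;
        long_running_work();
    });
    // ...so also passing py::call_guard<py::gil_scoped_release>() here
    // would try to release a GIL the thread no longer holds, which is
    // a fatal error in CPython. Pick one mechanism, not both.
}
```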

Wovchena requested a review from sgonorov November 18, 2025 13:21
@@ -0,0 +1,56 @@
import argparse
Contributor Author


This file will be removed in the final version

@@ -0,0 +1,78 @@
# Based on https://huggingface.co/docs/transformers/main/model_doc/xclip#transformers.XCLIPModel.get_video_features
Contributor Author


This script will be moved to wwb in the final version

likholat marked this pull request as ready for review November 18, 2025 17:51
Wovchena requested a review from Copilot November 19, 2025 06:43

Copilot AI left a comment


Pull Request Overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated 14 comments.

Comments suppressed due to low confidence (1)

src/cpp/src/image_generation/models/t5_encoder_model.cpp:1

  • The TODO suggests skipping attention mask filling when not needed by the pipeline. Implement a mechanism to conditionally compute the attention mask only when required to avoid unnecessary computational overhead.
// Copyright (C) 2023-2025 Intel Corporation


Comment on lines 20 to 52
// Compare with https://github.com/Lightricks/LTX-Video
// TODO: Test GPU, NPU, HETERO, MULTI, AUTO, different steps on different devices
// TODO: describe algo to generate a video in docs and docstrings
// TODO: explain in docstrings available perf metrics
// scheduler needs extra dim?
// Present that will update validation tools later
// new classes LTXVideoTransformer3DModel AutoencoderKLLTXVideo
// private copy constructors to implement clone()
// const VideoGenerationConfig& may outlive VideoGenerationConfig?
// Move negative_prompt to Property
// Allow selecting different models to export from optimum-intel, for example ltxv-2b-0.9.8-distilled.safetensors
// LoRA later: https://huggingface.co/Lightricks/LTX-Video-ICLoRA-depth-13b-0.9.7, https://huggingface.co/Lightricks/LTX-Video-ICLoRA-pose-13b-0.9.7, https://huggingface.co/Lightricks/LTXV-LoRAs Check https://github.com/Lightricks/LTX-Video for updates
// Wasn't needed so far, so not going to implement:
// OVLTXPipeline allows prompt_embeds and prompt_attention_mask instead of prompt; Same for negative_prompt_embeds and negative_prompt_attention_mask
// OVLTXPipeline allows batched generation with multiple prompts
// Tests:
// Functional
// Sample
// Cover all config members in sample. Use default values explicitly
// Prefer patching optimum-intel to include more stuff into a model instead of implementing it in C++
// Add video-to-video, inpainting
// image to video described in https://huggingface.co/Lightricks/LTX-Video (class LTXConditionPipeline)
// Optimum doesn't have LTXLatentUpsamplePipeline class
// Controlled video from https://github.com/Lightricks/LTX-Video
// TODO: decode, perf metrics, set_scheduler, set/get_generation_config, reshape, compile, clone()
// TODO: Rename image->video everywhere
// TODO: test multiple videos per prompt
// TODO: test with different config values
// TODO: test log prompts to check truncation
// TODO: throw if num_frames isn't divisible by 8 + 1. Similar value for resolution. The model works on resolutions that are divisible by 32 and number of frames that are divisible by 8 + 1 (e.g. 257). The model works best on resolutions under 720 x 1280 and number of frames below 257.
// OVLTXPipeline()(num_inference_steps=1) fails. 2 passes. Would be nice to avoid that bug in genai.
// Verify tiny resolution like 32x32
const std::string device = "CPU"; // GPU can be used as well

Copilot AI Nov 19, 2025


Extensive TODO comments in the sample file should be addressed or moved to a proper issue tracker. Sample files should demonstrate clean, production-ready usage patterns without extensive inline TODO lists.

Suggested change (drop the whole TODO block above and keep only):
+ // Set device to CPU; GPU can be used as well
+ const std::string device = "CPU";

Comment on lines +31 to +33
size_t patch_size = 4; // TODO: read from vae_decoder/config.json
std::vector<bool> spatio_temporal_scaling{true, true, true, false}; // TODO: read from vae_decoder/config.json. I use it only to compute sum over it so far, so it may be removed
size_t patch_size_t = 1; // TODO: read from vae_decoder/config.json

Copilot AI Nov 19, 2025


Configuration parameters have TODO comments indicating they should be read from config.json. These should be properly loaded from the configuration file rather than using hardcoded defaults.

Comment on lines +35 to +39
// latents_mean = torch.zeros((latent_channels,), requires_grad=False)
// latents_std = torch.ones((latent_channels,), requires_grad=False)
std::vector<float> latents_mean_data; // TODO: set default value
std::vector<float> latents_std_data; // TODO: set default value


Copilot AI Nov 19, 2025


Configuration parameters latents_mean_data and latents_std_data have TODO comments about default values. These should either have proper defaults set or be marked as required configuration.

Suggested change (default-initialize in a constructor):
- // latents_mean = torch.zeros((latent_channels,), requires_grad=False)
- // latents_std = torch.ones((latent_channels,), requires_grad=False)
- std::vector<float> latents_mean_data; // TODO: set default value
- std::vector<float> latents_std_data; // TODO: set default value
+ std::vector<float> latents_mean_data;
+ std::vector<float> latents_std_data;
+ Config() :
+     latents_mean_data(latent_channels, 0.0f),
+     latents_std_data(latent_channels, 1.0f)
+ {}

Comment on lines +551 to +552
//TODO: move to compute_hidden_states
ov::Tensor rope_interpolation_scale(ov::element::f32, {3});

Copilot AI Nov 19, 2025


The TODO suggests moving rope_interpolation_scale computation to compute_hidden_states. This would improve code organization by grouping related hidden state computations together.

Copilot uses AI. Check for mistakes.
@Wovchena (Collaborator)

@sgonorov, please review

Copy link
Contributor

sgonorov left a comment


Looks good, but it's better to rebase it on the latest master.

auto filtered_properties = extract_adapters_from_properties(properties, &adapters);
OPENVINO_ASSERT(!adapters, "Adapters are not currently supported for Video Generation Pipeline.");
ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, *filtered_properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Flux Transformer 2D model");

Suggested change:
- ov::genai::utils::print_compiled_model_properties(compiled_model, "Flux Transformer 2D model");
+ ov::genai::utils::print_compiled_model_properties(compiled_model, "LTX Video Transformer 3D model");

ImageGenerationConfig::validate();
}

void VideoGenerationConfig::update_generation_config(const ov::AnyMap& properties) {

We should also load num_videos_per_prompt here.
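
A standalone sketch of pulling that value out of the properties map; the "num_videos_per_prompt" key and the default of 1 are assumptions, and in the PR the lookup would live inside update_generation_config rather than main:

```cpp
// Sketch only: shows reading num_videos_per_prompt from an ov::AnyMap.
#include <openvino/core/any.hpp>

#include <cstddef>
#include <iostream>

int main() {
    ov::AnyMap properties{{"num_videos_per_prompt", std::size_t{2}}};

    std::size_t num_videos_per_prompt = 1;  // assumed default
    if (auto it = properties.find("num_videos_per_prompt"); it != properties.end()) {
        num_videos_per_prompt = it->second.as<std::size_t>();
    }
    std::cout << num_videos_per_prompt << '\n';  // prints 2
}
```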


using namespace ov::genai;

namespace {

Maybe it's better to put these utilities into a separate file? text2video_pipeline.cpp looks quite huge already.

std::shared_ptr<LTXVideoTransformer3DModel> m_transformer;
std::shared_ptr<AutoencoderKLLTXVideo> m_vae;
VideoGenerationPerfMetrics m_perf_metrics;
double m_latent_timestep = -1.0; // TODO: float?

I think it should be removed; I don't see any usages.

const float* noise_pred_text = noise_pred_uncond + noisy_residual_tensor.get_size();

for (size_t i = 0; i < noisy_residual_tensor.get_size(); ++i) {
noisy_residual[i] = noise_pred_uncond[i] +

We can use std::fma here for clarity.
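
For reference, a self-contained sketch of that blend written with std::fma, assuming the truncated loop body completes to the standard classifier-free guidance formula uncond + scale * (text - uncond):

```cpp
#include <cmath>
#include <cstddef>

// Classifier-free guidance blend: uncond + scale * (text - uncond),
// computed with one fused multiply-add per element.
void apply_cfg(float* noisy_residual,
               const float* noise_pred_uncond,
               const float* noise_pred_text,
               float guidance_scale,
               std::size_t size) {
    for (std::size_t i = 0; i < size; ++i) {
        noisy_residual[i] = std::fma(guidance_scale,
                                     noise_pred_text[i] - noise_pred_uncond[i],
                                     noise_pred_uncond[i]);
    }
}
```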

double m_latent_timestep = -1.0; // TODO: float?
Ms m_load_time;

size_t m_latent_num_frames = -1;

Are we sure we want to assign -1 to a size_t here? Maybe it's better to replace it with std::optional?
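
A small sketch of the std::optional alternative; only the member name comes from the quoted diff, the struct and helper are hypothetical:

```cpp
#include <cstddef>
#include <optional>

// std::optional makes "not yet computed" explicit instead of relying
// on size_t wrap-around from -1.
struct PipelineState {
    std::optional<std::size_t> latent_num_frames;  // empty until computed
};

bool is_ready(const PipelineState& s) {
    return s.latent_num_frames.has_value();
}
```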

std::function<bool(size_t, size_t, ov::Tensor&)> callback;
auto callback_iter = properties.find(ov::genai::callback.name());
if (callback_iter != properties.end()) {
callback = callback_iter->second.as<std::function<bool(size_t, size_t, ov::Tensor&)>>();

I may be wrong, but I don't see any callback usages further down in the code. Should we call it somewhere?

config.max_sequence_length = LTX_VIDEO_DEFAULT_CONFIG.max_sequence_length;
}
if (std::isnan(config.guidance_rescale)) {
config.guidance_rescale = LTX_VIDEO_DEFAULT_CONFIG.guidance_rescale;

Is this parameter used? Maybe it's better to remove it for now?

OPENVINO_ASSERT(generation_config.height % 32 == 0, "Height have to be divisible by 32 but got ", generation_config.height);
OPENVINO_ASSERT(generation_config.width > 0, "Width must be positive");
OPENVINO_ASSERT(generation_config.width % 32 == 0, "Width have to be divisible by 32 but got ", generation_config.width);
OPENVINO_ASSERT(1.0f == generation_config.strength, "Strength isn't applicable. Must be set to the default 1.0");

This should be checked using a delta instead of ==, which may fail.
Also, this check is repeated on line 89.
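
A sketch of the tolerance-based comparison; the epsilon value is an arbitrary choice, not something from the PR:

```cpp
#include <cmath>

// Compare a float against its expected default with a tolerance
// rather than exact ==; 1e-5f is an arbitrary epsilon.
inline bool is_default_strength(float strength, float eps = 1e-5f) {
    return std::fabs(strength - 1.0f) < eps;
}

// Usage at the assert site (OPENVINO_ASSERT as in the quoted code):
// OPENVINO_ASSERT(is_default_strength(generation_config.strength),
//                 "Strength isn't applicable. Must be set to the default 1.0");
```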

// rcFrame (4 * 16-bit) -> left, top, right, bottom
writer.write_u16(0); // left
writer.write_u16(0); // top
writer.write_u16(static_cast<uint16_t>(w)); // right

We can add a bounds check here just in case.
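
One way to add that check, as a hypothetical helper; write_u16 is the method from the quoted code, everything else here is assumed:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Hypothetical helper: reject dimensions that do not fit the 16-bit
// rcFrame fields before the narrowing cast.
inline std::uint16_t checked_u16(std::size_t value, const std::string& what) {
    if (value > std::numeric_limits<std::uint16_t>::max())
        throw std::out_of_range(what + " exceeds the 16-bit range");
    return static_cast<std::uint16_t>(value);
}

// At the call site:
// writer.write_u16(checked_u16(w, "frame width"));   // right
// writer.write_u16(checked_u16(h, "frame height"));  // bottom
```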
