Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d5414a3

Browse files
committed
Wrap video bytes in class with ControlImage
1 parent 5cc8bc1 commit d5414a3

File tree

11 files changed

+65
-23
lines changed

11 files changed

+65
-23
lines changed

OnnxStack.Core/Services/IVideoService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,6 @@ public interface IVideoService
104104
/// <param name="targetFPS">The target FPS.</param>
105105
/// <param name="cancellationToken">The cancellation token.</param>
106106
/// <returns></returns>
107-
IAsyncEnumerable<byte[]> StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default);
107+
IAsyncEnumerable<VideoFrame> StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default);
108108
}
109109
}

OnnxStack.Core/Services/VideoService.cs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ public async Task<VideoOutput> CreateVideoAsync(VideoFrames videoFrames, Cancell
104104
/// <returns></returns>
105105
public async Task<VideoOutput> CreateVideoAsync(DenseTensor<float> videoTensor, float videoFPS, CancellationToken cancellationToken = default)
106106
{
107-
var videoFrames = await videoTensor.ToVideoFramesAsBytesAsync().ToListAsync(cancellationToken);
107+
var videoFrames = await videoTensor
108+
.ToVideoFramesAsBytesAsync()
109+
.Select(x => new VideoFrame(x))
110+
.ToListAsync(cancellationToken);
108111
return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken);
109112
}
110113

@@ -118,7 +121,8 @@ public async Task<VideoOutput> CreateVideoAsync(DenseTensor<float> videoTensor,
118121
/// <returns></returns>
119122
public async Task<VideoOutput> CreateVideoAsync(IEnumerable<byte[]> videoFrames, float videoFPS, CancellationToken cancellationToken = default)
120123
{
121-
return await CreateVideoInternalAsync(videoFrames, videoFPS, cancellationToken);
124+
var frames = videoFrames.Select(x => new VideoFrame(x));
125+
return await CreateVideoInternalAsync(frames, videoFPS, cancellationToken);
122126
}
123127

124128

@@ -190,7 +194,7 @@ public async Task<VideoFrames> CreateFramesAsync(Stream videoStream, float video
190194
/// <param name="targetFPS">The target FPS.</param>
191195
/// <param name="cancellationToken">The cancellation token.</param>
192196
/// <returns></returns>
193-
public IAsyncEnumerable<byte[]> StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default)
197+
public IAsyncEnumerable<VideoFrame> StreamFramesAsync(byte[] videoBytes, float targetFPS, CancellationToken cancellationToken = default)
194198
{
195199
return CreateFramesInternalAsync(videoBytes, targetFPS, cancellationToken);
196200
}
@@ -220,13 +224,13 @@ private async Task<VideoInfo> GetVideoInfoInternalAsync(MemoryStream videoStream
220224
/// <param name="fps">The FPS.</param>
221225
/// <param name="cancellationToken">The cancellation token.</param>
222226
/// <returns></returns>
223-
private async Task<VideoOutput> CreateVideoInternalAsync(IEnumerable<byte[]> imageData, float fps = 15, CancellationToken cancellationToken = default)
227+
private async Task<VideoOutput> CreateVideoInternalAsync(IEnumerable<VideoFrame> imageData, float fps = 15, CancellationToken cancellationToken = default)
224228
{
225229
string tempVideoPath = GetTempFilename();
226230
try
227231
{
228232
// Analyze first frame to get some details
229-
var frameInfo = await GetVideoInfoAsync(imageData.First());
233+
var frameInfo = await GetVideoInfoAsync(imageData.First().Frame);
230234
var aspectRatio = (double)frameInfo.Width / frameInfo.Height;
231235
using (var videoWriter = CreateWriter(tempVideoPath, fps, aspectRatio))
232236
{
@@ -235,7 +239,7 @@ private async Task<VideoOutput> CreateVideoInternalAsync(IEnumerable<byte[]> ima
235239
foreach (var image in imageData)
236240
{
237241
// Write each frame to the input stream of FFMPEG
238-
await videoWriter.StandardInput.BaseStream.WriteAsync(image, cancellationToken);
242+
await videoWriter.StandardInput.BaseStream.WriteAsync(image.Frame, cancellationToken);
239243
}
240244

241245
// Done; close the stream and wait for the app to finish processing
@@ -265,7 +269,7 @@ private async Task<VideoOutput> CreateVideoInternalAsync(IEnumerable<byte[]> ima
265269
/// <param name="cancellationToken">The cancellation token.</param>
266270
/// <returns></returns>
267271
/// <exception cref="Exception">Invalid PNG header</exception>
268-
private async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
272+
private async IAsyncEnumerable<VideoFrame> CreateFramesInternalAsync(byte[] videoData, float fps = 15, [EnumeratorCancellation] CancellationToken cancellationToken = default)
269273
{
270274
string tempVideoPath = GetTempFilename();
271275
try
@@ -325,7 +329,7 @@ private async IAsyncEnumerable<byte[]> CreateFramesInternalAsync(byte[] videoDat
325329
break;
326330
}
327331

328-
yield return buffer[..currentIndex];
332+
yield return new VideoFrame(buffer[..currentIndex]);
329333
}
330334

331335
if (cancellationToken.IsCancellationRequested)

OnnxStack.Core/Video/VideoFrame.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using OnnxStack.Core.Image;
2+
3+
namespace OnnxStack.Core.Video
4+
{
5+
public record VideoFrame(byte[] Frame, InputImage ControlImage = default);
6+
}

OnnxStack.Core/Video/VideoFrames.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
namespace OnnxStack.Core.Video
44
{
5-
public record VideoFrames(VideoInfo Info, IReadOnlyList<byte[]> Frames);
5+
public record VideoFrames(VideoInfo Info, IReadOnlyList<VideoFrame> Frames);
66
}

OnnxStack.ImageUpscaler/Services/UpscaleService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ public async Task<List<DenseTensor<float>>> GenerateInternalAsync(UpscaleModelSe
265265
var outputTensors = new List<DenseTensor<float>>();
266266
foreach (var frame in videoFrames.Frames)
267267
{
268-
using (var imageFrame = Image.Load<Rgba32>(frame))
268+
using (var imageFrame = Image.Load<Rgba32>(frame.Frame))
269269
{
270270
var input = CreateInputParams(imageFrame, modelSession.SampleSize, modelSession.ScaleFactor);
271271
var outputDimension = new[] { 1, modelSession.Channels, 0, 0 };

OnnxStack.StableDiffusion/Helpers/ModelFactory.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
5050
var textEncoder2Config = default(TextEncoderModelConfig);
5151
var tokenizerConfig = new TokenizerModelConfig
5252
{
53-
TokenizerLength = modelType == ModelType.Turbo ? 1024 : 768,
5453
OnnxModelPath = tokenizerPath
5554
};
5655

@@ -106,6 +105,10 @@ public static StableDiffusionModelSet CreateModelSet(string modelFolder, Diffuse
106105
};
107106
}
108107

108+
// SD-Turbo has TokenizerLength 1024
109+
if (pipeline == DiffuserPipelineType.StableDiffusion && modelType == ModelType.Turbo)
110+
tokenizerConfig.TokenizerLength = 1024;
111+
109112
var configuration = new StableDiffusionModelSet
110113
{
111114
IsEnabled = true,

OnnxStack.StableDiffusion/Pipelines/Base/PipelineBase.cs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,20 @@ protected async Task<DenseTensor<float>> DiffuseVideoAsync(IDiffuser diffuser, P
155155
foreach (var videoFrame in videoFrames)
156156
{
157157
frameIndex++;
158-
promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
159-
promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
158+
// byte[] videoFrame = videoFrames[i].Frame;
159+
if (promptOptions.DiffuserType == DiffuserType.ControlNet || promptOptions.DiffuserType == DiffuserType.ControlNetImage)
160+
{
161+
// ControlNetImage uses frame as input image
162+
if (promptOptions.DiffuserType == DiffuserType.ControlNetImage)
163+
promptOptions.InputImage = new InputImage(videoFrame.Frame);
164+
165+
promptOptions.InputContolImage = videoFrame.ControlImage;
166+
}
167+
else
168+
{
169+
promptOptions.InputImage = new InputImage(videoFrame.Frame);
170+
}
171+
160172
var frameResultTensor = await diffuser.DiffuseAsync(promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
161173

162174
// Frame Progress
@@ -223,5 +235,24 @@ protected Action<DiffusionProgress> CreateBatchCallback(Action<DiffusionProgress
223235
BatchValue = batchIndex()
224236
});
225237
}
238+
239+
240+
/// <summary>
241+
/// Creates the pipeline from a ModelSet configuration.
242+
/// </summary>
243+
/// <param name="modelSet">The model set.</param>
244+
/// <param name="logger">The logger.</param>
245+
/// <returns></returns>
246+
public static IPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
247+
{
248+
return modelSet.PipelineType switch
249+
{
250+
DiffuserPipelineType.StableDiffusionXL => StableDiffusionXLPipeline.CreatePipeline(modelSet, logger),
251+
DiffuserPipelineType.LatentConsistency => LatentConsistencyPipeline.CreatePipeline(modelSet, logger),
252+
DiffuserPipelineType.LatentConsistencyXL => LatentConsistencyXLPipeline.CreatePipeline(modelSet, logger),
253+
DiffuserPipelineType.InstaFlow => InstaFlowPipeline.CreatePipeline(modelSet, logger),
254+
_ => StableDiffusionPipeline.CreatePipeline(modelSet, logger)
255+
};
256+
}
226257
}
227258
}

OnnxStack.StableDiffusion/Pipelines/StableDiffusionPipeline.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,14 @@ public override Task LoadAsync()
123123
/// Unloads the pipeline.
124124
/// </summary>
125125
/// <returns></returns>
126-
public override Task UnloadAsync()
126+
public override async Task UnloadAsync()
127127
{
128+
await Task.Yield();
128129
_unet?.Dispose();
129130
_tokenizer?.Dispose();
130131
_textEncoder?.Dispose();
131132
_vaeDecoder?.Dispose();
132133
_vaeEncoder?.Dispose();
133-
return Task.CompletedTask;
134134
}
135135

136136

@@ -371,7 +371,7 @@ protected IEnumerable<int> PadWithBlankTokens(IEnumerable<int> inputs, int requi
371371
/// <param name="modelSet">The model set.</param>
372372
/// <param name="logger">The logger.</param>
373373
/// <returns></returns>
374-
public static StableDiffusionPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
374+
public static new StableDiffusionPipeline CreatePipeline(StableDiffusionModelSet modelSet, ILogger logger = default)
375375
{
376376
var unet = new UNetConditionModel(modelSet.UnetConfig.ApplyDefaults(modelSet));
377377
var tokenizer = new TokenizerModel(modelSet.TokenizerConfig.ApplyDefaults(modelSet));

OnnxStack.StableDiffusion/Pipelines/StableDiffusionXLPipeline.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public StableDiffusionXLPipeline(string name, TokenizerModel tokenizer, Tokenize
4343
{
4444
SchedulerType.Euler,
4545
SchedulerType.EulerAncestral,
46-
SchedulerType.DDIM,
46+
SchedulerType.DDPM,
4747
SchedulerType.KDPM2
4848
};
4949
_defaultSchedulerOptions = defaultSchedulerOptions ?? new SchedulerOptions

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
using OnnxStack.Core.Services;
44
using OnnxStack.StableDiffusion.Common;
55
using OnnxStack.StableDiffusion.Config;
6-
using OnnxStack.StableDiffusion.Diffusers;
7-
using OnnxStack.StableDiffusion.Pipelines;
86
using OnnxStack.StableDiffusion.Services;
97
using SixLabors.ImageSharp;
108
using SixLabors.ImageSharp.Memory;

0 commit comments

Comments
 (0)