Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 3169188

Browse files
committed
Add PipelineBase class
1 parent c7b7f8a commit 3169188

File tree

13 files changed

+422
-185
lines changed

13 files changed

+422
-185
lines changed

OnnxStack.Core/Config/OnnxModelConfig.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ namespace OnnxStack.Core.Config
55
{
66
public record OnnxModelConfig
77
{
8-
public OnnxModelType Type { get; set; }
98
public string OnnxModelPath { get; set; }
109

1110
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]

OnnxStack.Core/Config/OnnxModelSetConfig.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.ML.OnnxRuntime;
2-
using System.Collections.Generic;
32

43
namespace OnnxStack.Core.Config
54
{
@@ -13,6 +12,5 @@ public class OnnxModelSetConfig : IOnnxModelSetConfig
1312
public int IntraOpNumThreads { get; set; }
1413
public ExecutionMode ExecutionMode { get; set; }
1514
public ExecutionProvider ExecutionProvider { get; set; }
16-
public List<OnnxModelConfig> ModelConfigurations { get; set; }
1715
}
1816
}

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 0 additions & 16 deletions
This file was deleted.

OnnxStack.Core/Extensions/Extensions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Numerics;
8-
using System.Runtime.CompilerServices;
98

109
namespace OnnxStack.Core
1110
{

OnnxStack.Core/Model/OnnxInferenceParameters.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ public void AddInputTensor(DenseTensor<float> value)
4545
_inputs.Add(metaData, value.ToOrtValue(metaData));
4646
}
4747

48+
49+
/// <summary>
50+
/// Adds the input tensor.
51+
/// </summary>
52+
/// <param name="value">The value.</param>
4853
public void AddInputTensor(DenseTensor<double> value)
4954
{
5055
var metaData = GetNextInputMetadata();
@@ -118,7 +123,6 @@ public void AddOutputBuffer(int index, ReadOnlySpan<int> bufferDimension)
118123
}
119124

120125

121-
122126
/// <summary>
123127
/// Adds an output parameter with unknown output size.
124128
/// </summary>

OnnxStack.StableDiffusion/Common/IPipeline.cs

Lines changed: 0 additions & 22 deletions
This file was deleted.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.StableDiffusion.Common;
3+
using OnnxStack.StableDiffusion.Config;
4+
using OnnxStack.StableDiffusion.Enums;
5+
using OnnxStack.StableDiffusion.Models;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
11+
namespace OnnxStack.StableDiffusion.Pipelines
12+
{
13+
public interface IPipeline
14+
{
15+
/// <summary>
16+
/// Gets the pipelines supported diffusers.
17+
/// </summary>
18+
IReadOnlyList<DiffuserType> SupportedDiffusers { get; }
19+
20+
21+
/// <summary>
22+
/// Gets the pipelines supported schedulers.
23+
/// </summary>
24+
IReadOnlyList<SchedulerType> SupportedSchedulers { get; }
25+
26+
27+
/// <summary>
28+
/// Loads the pipeline.
29+
/// </summary>
30+
/// <returns></returns>
31+
Task LoadAsync();
32+
33+
34+
/// <summary>
35+
/// Unloads the pipeline.
36+
/// </summary>
37+
/// <returns></returns>
38+
Task UnloadAsync();
39+
40+
41+
/// <summary>
42+
/// Validates the inputs.
43+
/// </summary>
44+
/// <param name="promptOptions">The prompt options.</param>
45+
/// <param name="schedulerOptions">The scheduler options.</param>
46+
void ValidateInputs(PromptOptions promptOptions, SchedulerOptions schedulerOptions);
47+
48+
49+
/// <summary>
50+
/// Runs the pipeline.
51+
/// </summary>
52+
/// <param name="promptOptions">The prompt options.</param>
53+
/// <param name="schedulerOptions">The scheduler options.</param>
54+
/// <param name="controlNet">The control net.</param>
55+
/// <param name="progressCallback">The progress callback.</param>
56+
/// <param name="cancellationToken">The cancellation token.</param>
57+
/// <returns></returns>
58+
Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
59+
60+
61+
/// <summary>
62+
/// Runs the pipeline batch.
63+
/// </summary>
64+
/// <param name="promptOptions">The prompt options.</param>
65+
/// <param name="schedulerOptions">The scheduler options.</param>
66+
/// <param name="batchOptions">The batch options.</param>
67+
/// <param name="controlNet">The control net.</param>
68+
/// <param name="progressCallback">The progress callback.</param>
69+
/// <param name="cancellationToken">The cancellation token.</param>
70+
/// <returns></returns>
71+
IAsyncEnumerable<BatchResult> RunBatchAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
72+
}
73+
}
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Image;
5+
using OnnxStack.StableDiffusion.Common;
6+
using OnnxStack.StableDiffusion.Config;
7+
using OnnxStack.StableDiffusion.Diffusers;
8+
using OnnxStack.StableDiffusion.Enums;
9+
using OnnxStack.StableDiffusion.Models;
10+
using System;
11+
using System.Collections.Generic;
12+
using System.Runtime.CompilerServices;
13+
using System.Threading;
14+
using System.Threading.Tasks;
15+
16+
namespace OnnxStack.StableDiffusion.Pipelines
17+
{
18+
public abstract class PipelineBase : IPipeline
19+
{
20+
protected readonly ILogger _logger;
21+
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="PipelineBase"/> class.
24+
/// </summary>
25+
/// <param name="logger">The logger.</param>
26+
protected PipelineBase(ILogger logger)
27+
{
28+
_logger = logger;
29+
}
30+
31+
32+
/// <summary>
33+
/// Gets the pipelines friendly name.
34+
/// </summary>
35+
public abstract string Name { get; }
36+
37+
38+
/// <summary>
39+
/// Gets the type of the pipeline.
40+
/// </summary>
41+
public abstract DiffuserPipelineType PipelineType { get; }
42+
43+
44+
/// <summary>
45+
/// Gets the pipelines supported diffusers.
46+
/// </summary>
47+
public abstract IReadOnlyList<DiffuserType> SupportedDiffusers { get; }
48+
49+
50+
/// <summary>
51+
/// Gets the pipelines supported schedulers.
52+
/// </summary>
53+
public abstract IReadOnlyList<SchedulerType> SupportedSchedulers { get; }
54+
55+
56+
/// <summary>
57+
/// Loads the pipeline.
58+
/// </summary>
59+
/// <returns></returns>
60+
public abstract Task LoadAsync();
61+
62+
63+
/// <summary>
64+
/// Unloads the pipeline.
65+
/// </summary>
66+
/// <returns></returns>
67+
public abstract Task UnloadAsync();
68+
69+
70+
/// <summary>
71+
/// Validates the inputs.
72+
/// </summary>
73+
/// <param name="promptOptions">The prompt options.</param>
74+
/// <param name="schedulerOptions">The scheduler options.</param>
75+
public abstract void ValidateInputs(PromptOptions promptOptions, SchedulerOptions schedulerOptions);
76+
77+
78+
/// <summary>
79+
/// Runs the pipeline.
80+
/// </summary>
81+
/// <param name="promptOptions">The prompt options.</param>
82+
/// <param name="schedulerOptions">The scheduler options.</param>
83+
/// <param name="controlNet">The control net.</param>
84+
/// <param name="progressCallback">The progress callback.</param>
85+
/// <param name="cancellationToken">The cancellation token.</param>
86+
/// <returns></returns>
87+
public abstract Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
88+
89+
90+
/// <summary>
91+
/// Runs the pipeline batch.
92+
/// </summary>
93+
/// <param name="promptOptions">The prompt options.</param>
94+
/// <param name="schedulerOptions">The scheduler options.</param>
95+
/// <param name="batchOptions">The batch options.</param>
96+
/// <param name="controlNet">The control net.</param>
97+
/// <param name="progressCallback">The progress callback.</param>
98+
/// <param name="cancellationToken">The cancellation token.</param>
99+
/// <returns></returns>
100+
public abstract IAsyncEnumerable<BatchResult> RunBatchAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, ControlNetModel controlNet = default, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default);
101+
102+
103+
/// <summary>
104+
/// Creates the diffuser.
105+
/// </summary>
106+
/// <param name="diffuserType">Type of the diffuser.</param>
107+
/// <param name="controlNetModel">The control net model.</param>
108+
/// <returns></returns>
109+
protected abstract IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNetModel controlNetModel);
110+
111+
112+
/// <summary>
113+
/// Runs the Diffusion process and returns an image.
114+
/// </summary>
115+
/// <param name="promptOptions">The prompt options.</param>
116+
/// <param name="schedulerOptions">The scheduler options.</param>
117+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
118+
/// <param name="performGuidance">if set to <c>true</c> perform guidance (CFG).</param>
119+
/// <param name="progressCallback">The progress callback.</param>
120+
/// <param name="cancellationToken">The cancellation token.</param>
121+
/// <returns></returns>
122+
protected async Task<DenseTensor<float>> DiffuseImageAsync(IDiffuser diffuser, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
123+
{
124+
var diffuseTime = _logger?.LogBegin("Image Diffuser starting...");
125+
var schedulerResult = await diffuser.DiffuseAsync(promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken);
126+
_logger?.LogEnd($"Image Diffuser complete", diffuseTime);
127+
return schedulerResult;
128+
}
129+
130+
131+
/// <summary>
132+
/// Runs the Diffusion process and returns a video.
133+
/// </summary>
134+
/// <param name="promptOptions">The prompt options.</param>
135+
/// <param name="schedulerOptions">The scheduler options.</param>
136+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
137+
/// <param name="performGuidance">if set to <c>true</c> perform guidance (CFG).</param>
138+
/// <param name="progressCallback">The progress callback.</param>
139+
/// <param name="cancellationToken">The cancellation token.</param>
140+
/// <returns></returns>
141+
protected async Task<DenseTensor<float>> DiffuseVideoAsync(IDiffuser diffuser, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
142+
{
143+
var diffuseTime = _logger?.LogBegin("Video Diffuser starting...");
144+
145+
var frameIndex = 0;
146+
DenseTensor<float> videoTensor = null;
147+
var videoFrames = promptOptions.InputVideo.VideoFrames.Frames;
148+
var schedulerFrameCallback = CreateBatchCallback(progressCallback, videoFrames.Count, () => frameIndex);
149+
foreach (var videoFrame in videoFrames)
150+
{
151+
frameIndex++;
152+
promptOptions.InputImage = promptOptions.DiffuserType == DiffuserType.ControlNet ? default : new InputImage(videoFrame);
153+
promptOptions.InputContolImage = promptOptions.DiffuserType == DiffuserType.ImageToImage ? default : new InputImage(videoFrame);
154+
var frameResultTensor = await diffuser.DiffuseAsync(promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerFrameCallback, cancellationToken);
155+
156+
// Frame Progress
157+
ReportBatchProgress(progressCallback, frameIndex, videoFrames.Count, frameResultTensor);
158+
159+
// Concatenate frame
160+
videoTensor = videoTensor.Concatenate(frameResultTensor);
161+
}
162+
163+
_logger?.LogEnd($"Video Diffuser complete", diffuseTime);
164+
return videoTensor;
165+
}
166+
167+
168+
/// <summary>
169+
/// Check if we should run guidance.
170+
/// </summary>
171+
/// <param name="schedulerOptions">The scheduler options.</param>
172+
/// <returns></returns>
173+
protected virtual bool ShouldPerformGuidance(SchedulerOptions schedulerOptions)
174+
{
175+
return schedulerOptions.GuidanceScale > 1f;
176+
}
177+
178+
179+
/// <summary>
180+
/// Reports the progress.
181+
/// </summary>
182+
/// <param name="progressCallback">The progress callback.</param>
183+
/// <param name="progress">The progress.</param>
184+
/// <param name="progressMax">The progress maximum.</param>
185+
/// <param name="subProgress">The sub progress.</param>
186+
/// <param name="subProgressMax">The sub progress maximum.</param>
187+
/// <param name="output">The output.</param>
188+
protected void ReportBatchProgress(Action<DiffusionProgress> progressCallback, int progress, int progressMax, DenseTensor<float> progressTensor)
189+
{
190+
progressCallback?.Invoke(new DiffusionProgress
191+
{
192+
BatchMax = progressMax,
193+
BatchValue = progress,
194+
BatchTensor = progressTensor
195+
});
196+
}
197+
198+
199+
/// <summary>
200+
/// Creates the batch callback.
201+
/// </summary>
202+
/// <param name="progressCallback">The progress callback.</param>
203+
/// <param name="batchCount">The batch count.</param>
204+
/// <param name="batchIndex">Index of the batch.</param>
205+
/// <returns></returns>
206+
protected Action<DiffusionProgress> CreateBatchCallback(Action<DiffusionProgress> progressCallback, int batchCount, Func<int> batchIndex)
207+
{
208+
if (progressCallback == null)
209+
return progressCallback;
210+
211+
return (DiffusionProgress progress) => progressCallback?.Invoke(new DiffusionProgress
212+
{
213+
StepMax = progress.StepMax,
214+
StepValue = progress.StepValue,
215+
StepTensor = progress.StepTensor,
216+
BatchMax = batchCount,
217+
BatchValue = batchIndex()
218+
});
219+
}
220+
}
221+
}

0 commit comments

Comments
 (0)