1+ using Microsoft . Extensions . Logging ;
2+ using Microsoft . ML . OnnxRuntime . Tensors ;
3+ using OnnxStack . Core ;
4+ using OnnxStack . Core . Image ;
5+ using OnnxStack . StableDiffusion . Common ;
6+ using OnnxStack . StableDiffusion . Config ;
7+ using OnnxStack . StableDiffusion . Diffusers ;
8+ using OnnxStack . StableDiffusion . Enums ;
9+ using OnnxStack . StableDiffusion . Models ;
10+ using System ;
11+ using System . Collections . Generic ;
12+ using System . Runtime . CompilerServices ;
13+ using System . Threading ;
14+ using System . Threading . Tasks ;
15+
16+ namespace OnnxStack . StableDiffusion . Pipelines
17+ {
18+ public abstract class PipelineBase : IPipeline
19+ {
20+ protected readonly ILogger _logger ;
21+
22+ /// <summary>
23+ /// Initializes a new instance of the <see cref="PipelineBase"/> class.
24+ /// </summary>
25+ /// <param name="logger">The logger.</param>
26+ protected PipelineBase ( ILogger logger )
27+ {
28+ _logger = logger ;
29+ }
30+
31+
32+ /// <summary>
33+ /// Gets the pipelines friendly name.
34+ /// </summary>
35+ public abstract string Name { get ; }
36+
37+
38+ /// <summary>
39+ /// Gets the type of the pipeline.
40+ /// </summary>
41+ public abstract DiffuserPipelineType PipelineType { get ; }
42+
43+
44+ /// <summary>
45+ /// Gets the pipelines supported diffusers.
46+ /// </summary>
47+ public abstract IReadOnlyList < DiffuserType > SupportedDiffusers { get ; }
48+
49+
50+ /// <summary>
51+ /// Gets the pipelines supported schedulers.
52+ /// </summary>
53+ public abstract IReadOnlyList < SchedulerType > SupportedSchedulers { get ; }
54+
55+
56+ /// <summary>
57+ /// Loads the pipeline.
58+ /// </summary>
59+ /// <returns></returns>
60+ public abstract Task LoadAsync ( ) ;
61+
62+
63+ /// <summary>
64+ /// Unloads the pipeline.
65+ /// </summary>
66+ /// <returns></returns>
67+ public abstract Task UnloadAsync ( ) ;
68+
69+
70+ /// <summary>
71+ /// Validates the inputs.
72+ /// </summary>
73+ /// <param name="promptOptions">The prompt options.</param>
74+ /// <param name="schedulerOptions">The scheduler options.</param>
75+ public abstract void ValidateInputs ( PromptOptions promptOptions , SchedulerOptions schedulerOptions ) ;
76+
77+
78+ /// <summary>
79+ /// Runs the pipeline.
80+ /// </summary>
81+ /// <param name="promptOptions">The prompt options.</param>
82+ /// <param name="schedulerOptions">The scheduler options.</param>
83+ /// <param name="controlNet">The control net.</param>
84+ /// <param name="progressCallback">The progress callback.</param>
85+ /// <param name="cancellationToken">The cancellation token.</param>
86+ /// <returns></returns>
87+ public abstract Task < DenseTensor < float > > RunAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default ) ;
88+
89+
90+ /// <summary>
91+ /// Runs the pipeline batch.
92+ /// </summary>
93+ /// <param name="promptOptions">The prompt options.</param>
94+ /// <param name="schedulerOptions">The scheduler options.</param>
95+ /// <param name="batchOptions">The batch options.</param>
96+ /// <param name="controlNet">The control net.</param>
97+ /// <param name="progressCallback">The progress callback.</param>
98+ /// <param name="cancellationToken">The cancellation token.</param>
99+ /// <returns></returns>
100+ public abstract IAsyncEnumerable < BatchResult > RunBatchAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default ) ;
101+
102+
103+ /// <summary>
104+ /// Creates the diffuser.
105+ /// </summary>
106+ /// <param name="diffuserType">Type of the diffuser.</param>
107+ /// <param name="controlNetModel">The control net model.</param>
108+ /// <returns></returns>
109+ protected abstract IDiffuser CreateDiffuser ( DiffuserType diffuserType , ControlNetModel controlNetModel ) ;
110+
111+
112+ /// <summary>
113+ /// Runs the Diffusion process and returns an image.
114+ /// </summary>
115+ /// <param name="promptOptions">The prompt options.</param>
116+ /// <param name="schedulerOptions">The scheduler options.</param>
117+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
118+ /// <param name="performGuidance">if set to <c>true</c> perform guidance (CFG).</param>
119+ /// <param name="progressCallback">The progress callback.</param>
120+ /// <param name="cancellationToken">The cancellation token.</param>
121+ /// <returns></returns>
122+ protected async Task < DenseTensor < float > > DiffuseImageAsync ( IDiffuser diffuser , PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
123+ {
124+ var diffuseTime = _logger ? . LogBegin ( "Image Diffuser starting..." ) ;
125+ var schedulerResult = await diffuser . DiffuseAsync ( promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
126+ _logger ? . LogEnd ( $ "Image Diffuser complete", diffuseTime ) ;
127+ return schedulerResult ;
128+ }
129+
130+
131+ /// <summary>
132+ /// Runs the Diffusion process and returns a video.
133+ /// </summary>
134+ /// <param name="promptOptions">The prompt options.</param>
135+ /// <param name="schedulerOptions">The scheduler options.</param>
136+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
137+ /// <param name="performGuidance">if set to <c>true</c> perform guidance (CFG).</param>
138+ /// <param name="progressCallback">The progress callback.</param>
139+ /// <param name="cancellationToken">The cancellation token.</param>
140+ /// <returns></returns>
141+ protected async Task < DenseTensor < float > > DiffuseVideoAsync ( IDiffuser diffuser , PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
142+ {
143+ var diffuseTime = _logger ? . LogBegin ( "Video Diffuser starting..." ) ;
144+
145+ var frameIndex = 0 ;
146+ DenseTensor < float > videoTensor = null ;
147+ var videoFrames = promptOptions . InputVideo . VideoFrames . Frames ;
148+ var schedulerFrameCallback = CreateBatchCallback ( progressCallback , videoFrames . Count , ( ) => frameIndex ) ;
149+ foreach ( var videoFrame in videoFrames )
150+ {
151+ frameIndex ++ ;
152+ promptOptions . InputImage = promptOptions . DiffuserType == DiffuserType . ControlNet ? default : new InputImage ( videoFrame ) ;
153+ promptOptions . InputContolImage = promptOptions . DiffuserType == DiffuserType . ImageToImage ? default : new InputImage ( videoFrame ) ;
154+ var frameResultTensor = await diffuser . DiffuseAsync ( promptOptions , schedulerOptions , promptEmbeddings , performGuidance , schedulerFrameCallback , cancellationToken ) ;
155+
156+ // Frame Progress
157+ ReportBatchProgress ( progressCallback , frameIndex , videoFrames . Count , frameResultTensor ) ;
158+
159+ // Concatenate frame
160+ videoTensor = videoTensor . Concatenate ( frameResultTensor ) ;
161+ }
162+
163+ _logger ? . LogEnd ( $ "Video Diffuser complete", diffuseTime ) ;
164+ return videoTensor ;
165+ }
166+
167+
168+ /// <summary>
169+ /// Check if we should run guidance.
170+ /// </summary>
171+ /// <param name="schedulerOptions">The scheduler options.</param>
172+ /// <returns></returns>
173+ protected virtual bool ShouldPerformGuidance ( SchedulerOptions schedulerOptions )
174+ {
175+ return schedulerOptions . GuidanceScale > 1f ;
176+ }
177+
178+
179+ /// <summary>
180+ /// Reports the progress.
181+ /// </summary>
182+ /// <param name="progressCallback">The progress callback.</param>
183+ /// <param name="progress">The progress.</param>
184+ /// <param name="progressMax">The progress maximum.</param>
185+ /// <param name="subProgress">The sub progress.</param>
186+ /// <param name="subProgressMax">The sub progress maximum.</param>
187+ /// <param name="output">The output.</param>
188+ protected void ReportBatchProgress ( Action < DiffusionProgress > progressCallback , int progress , int progressMax , DenseTensor < float > progressTensor )
189+ {
190+ progressCallback ? . Invoke ( new DiffusionProgress
191+ {
192+ BatchMax = progressMax ,
193+ BatchValue = progress ,
194+ BatchTensor = progressTensor
195+ } ) ;
196+ }
197+
198+
199+ /// <summary>
200+ /// Creates the batch callback.
201+ /// </summary>
202+ /// <param name="progressCallback">The progress callback.</param>
203+ /// <param name="batchCount">The batch count.</param>
204+ /// <param name="batchIndex">Index of the batch.</param>
205+ /// <returns></returns>
206+ protected Action < DiffusionProgress > CreateBatchCallback ( Action < DiffusionProgress > progressCallback , int batchCount , Func < int > batchIndex )
207+ {
208+ if ( progressCallback == null )
209+ return progressCallback ;
210+
211+ return ( DiffusionProgress progress ) => progressCallback ? . Invoke ( new DiffusionProgress
212+ {
213+ StepMax = progress . StepMax ,
214+ StepValue = progress . StepValue ,
215+ StepTensor = progress . StepTensor ,
216+ BatchMax = batchCount ,
217+ BatchValue = batchIndex ( )
218+ } ) ;
219+ }
220+ }
221+ }
0 commit comments