11using Microsoft . ML . OnnxRuntime ;
22using Microsoft . ML . OnnxRuntime . Tensors ;
3- using OnnxStack . Core ;
43using OnnxStack . Core . Config ;
54using OnnxStack . Core . Services ;
65using OnnxStack . StableDiffusion . Common ;
76using OnnxStack . StableDiffusion . Config ;
87using OnnxStack . StableDiffusion . Enums ;
98using OnnxStack . StableDiffusion . Helpers ;
109using OnnxStack . StableDiffusion . Schedulers ;
11- using SixLabors . ImageSharp ;
1210using System ;
1311using System . Collections . Generic ;
1412using System . Linq ;
1513using System . Threading ;
1614using System . Threading . Tasks ;
1715
18-
19- namespace OnnxStack . StableDiffusion . Services
16+ namespace OnnxStack . StableDiffusion . Diffusers
2017{
21- public sealed class SchedulerService : ISchedulerService
18+ public abstract class DiffuserBase : IDiffuser
2219 {
23- private readonly IPromptService _promptService ;
24- private readonly OnnxStackConfig _configuration ;
25- private readonly IOnnxModelService _onnxModelService ;
20+ protected readonly IPromptService _promptService ;
21+ protected readonly OnnxStackConfig _configuration ;
22+ protected readonly IOnnxModelService _onnxModelService ;
2623
2724 /// <summary>
28- /// Initializes a new instance of the <see cref="SchedulerService "/> class.
25+ /// Initializes a new instance of the <see cref="DiffuserBase "/> class.
2926 /// </summary>
3027 /// <param name="configuration">The configuration.</param>
3128 /// <param name="onnxModelService">The onnx model service.</param>
32- public SchedulerService ( IOnnxModelService onnxModelService , IPromptService promptService )
29+ public DiffuserBase ( IOnnxModelService onnxModelService , IPromptService promptService )
3330 {
3431 _promptService = promptService ;
3532 _onnxModelService = onnxModelService ;
@@ -38,12 +35,34 @@ public SchedulerService(IOnnxModelService onnxModelService, IPromptService promp
3835
3936
4037 /// <summary>
41- /// Runs the Stable Diffusion inference.
38+ /// Gets the timesteps.
39+ /// </summary>
40+ /// <param name="prompt">The prompt.</param>
41+ /// <param name="options">The options.</param>
42+ /// <param name="scheduler">The scheduler.</param>
43+ /// <returns></returns>
44+ protected abstract IReadOnlyList < int > GetTimesteps ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler ) ;
45+
46+ /// <summary>
47+ /// Prepares the latents.
48+ /// </summary>
49+ /// <param name="prompt">The prompt.</param>
50+ /// <param name="options">The options.</param>
51+ /// <param name="scheduler">The scheduler.</param>
52+ /// <param name="timesteps">The timesteps.</param>
53+ /// <returns></returns>
54+ protected abstract DenseTensor < float > PrepareLatents ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps ) ;
55+
56+
57+ /// <summary>
58+ /// Rund the stable diffusion loop
4259 /// </summary>
43- /// <param name="promptOptions">The options.</param>
44- /// <param name="schedulerOptions">The scheduler configuration.</param>
60+ /// <param name="promptOptions">The prompt options.</param>
61+ /// <param name="schedulerOptions">The scheduler options.</param>
62+ /// <param name="progress">The progress.</param>
63+ /// <param name="cancellationToken">The cancellation token.</param>
4564 /// <returns></returns>
46- public async Task < DenseTensor < float > > RunAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < int , int > progress = null , CancellationToken cancellationToken = default )
65+ public virtual async Task < DenseTensor < float > > DiffuseAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , Action < int , int > progress = null , CancellationToken cancellationToken = default )
4766 {
4867 // Create random seed if none was set
4968 schedulerOptions . Seed = schedulerOptions . Seed > 0 ? schedulerOptions . Seed : Random . Shared . Next ( ) ;
@@ -103,53 +122,13 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
103122 }
104123 }
105124
106- private IReadOnlyList < int > GetTimesteps ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler )
107- {
108- if ( ! prompt . HasInputImage )
109- return scheduler . Timesteps ;
110-
111- // Image2Image we narrow step the range by the Strength
112- var inittimestep = Math . Min ( ( int ) ( options . InferenceSteps * options . Strength ) , options . InferenceSteps ) ;
113- var start = Math . Max ( options . InferenceSteps - inittimestep , 0 ) ;
114- return scheduler . Timesteps . Skip ( start ) . ToList ( ) ;
115- }
116-
117- /// <summary>
118- /// Prepares the latents for inference.
119- /// </summary>
120- /// <param name="prompt">The prompt.</param>
121- /// <param name="options">The options.</param>
122- /// <param name="scheduler">The scheduler.</param>
123- /// <returns></returns>
124- private DenseTensor < float > PrepareLatents ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps )
125- {
126- // If we dont have an initial image create random sample
127- if ( ! prompt . HasInputImage )
128- return scheduler . CreateRandomSample ( options . GetScaledDimension ( ) , scheduler . InitNoiseSigma ) ;
129-
130- // Image input, decode, add noise, return as latent 0
131- var imageTensor = prompt . InputImage . ToDenseTensor ( options . Width , options . Height ) ;
132- var inputNames = _onnxModelService . GetInputNames ( OnnxModelType . VaeEncoder ) ;
133- var inputParameters = CreateInputParameters ( NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , imageTensor ) ) ;
134- using ( var inferResult = _onnxModelService . RunInference ( OnnxModelType . VaeEncoder , inputParameters ) )
135- {
136- var sample = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
137- var noisySample = sample
138- . AddTensors ( scheduler . CreateRandomSample ( sample . Dimensions , options . InitialNoiseLevel ) )
139- . MultipleTensorByFloat ( _configuration . ScaleFactor ) ;
140- var noise = scheduler . CreateRandomSample ( sample . Dimensions ) ;
141- return scheduler . AddNoise ( noisySample , noise , timesteps ) ;
142- }
143- }
144-
145-
146125 /// <summary>
147126 /// Decodes the latents.
148127 /// </summary>
149128 /// <param name="options">The options.</param>
150129 /// <param name="latents">The latents.</param>
151130 /// <returns></returns>
152- private async Task < DenseTensor < float > > DecodeLatents ( SchedulerOptions options , DenseTensor < float > latents )
131+ protected async Task < DenseTensor < float > > DecodeLatents ( SchedulerOptions options , DenseTensor < float > latents )
153132 {
154133 // Scale and decode the image latents with vae.
155134 // latents = 1 / 0.18215 * latents
@@ -181,7 +160,7 @@ private async Task<DenseTensor<float>> DecodeLatents(SchedulerOptions options, D
181160 /// <returns>
182161 /// <c>true</c> if the specified result image is safe; otherwise, <c>false</c>.
183162 /// </returns>
184- private async Task < bool > IsImageSafe ( SchedulerOptions options , DenseTensor < float > resultImage )
163+ protected async Task < bool > IsImageSafe ( SchedulerOptions options , DenseTensor < float > resultImage )
185164 {
186165 //clip input
187166 var inputTensor = ClipImageFeatureExtractor ( options , resultImage ) ;
@@ -207,7 +186,7 @@ private async Task<bool> IsImageSafe(SchedulerOptions options, DenseTensor<float
207186 /// </summary>
208187 /// <param name="imageTensor">The image tensor.</param>
209188 /// <returns></returns>
210- private static DenseTensor < float > ClipImageFeatureExtractor ( SchedulerOptions options , DenseTensor < float > imageTensor )
189+ protected static DenseTensor < float > ClipImageFeatureExtractor ( SchedulerOptions options , DenseTensor < float > imageTensor )
211190 {
212191 //convert tensor result to image
213192 using ( var image = imageTensor . ToImage ( ) )
@@ -243,7 +222,7 @@ private static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions opt
243222 /// <param name="options">The options.</param>
244223 /// <param name="schedulerConfig">The scheduler configuration.</param>
245224 /// <returns></returns>
246- private static IScheduler GetScheduler ( PromptOptions prompt , SchedulerOptions options )
225+ protected static IScheduler GetScheduler ( PromptOptions prompt , SchedulerOptions options )
247226 {
248227 return prompt . SchedulerType switch
249228 {
@@ -259,7 +238,7 @@ private static IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions op
259238 /// </summary>
260239 /// <param name="parameters">The parameters.</param>
261240 /// <returns></returns>
262- private static IReadOnlyCollection < NamedOnnxValue > CreateInputParameters ( params NamedOnnxValue [ ] parameters )
241+ protected static IReadOnlyCollection < NamedOnnxValue > CreateInputParameters ( params NamedOnnxValue [ ] parameters )
263242 {
264243 return parameters . ToList ( ) . AsReadOnly ( ) ;
265244 }
0 commit comments