|
13 | 13 | using System.Collections.Generic; |
14 | 14 | using System.Diagnostics; |
15 | 15 | using System.Linq; |
| 16 | +using System.Runtime.CompilerServices; |
16 | 17 | using System.Threading; |
17 | 18 | using System.Threading.Tasks; |
18 | 19 |
|
@@ -86,14 +87,77 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp |
86 | 87 | _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); |
87 | 88 |
|
88 | 89 | // LCM does not support negative prompting |
| 90 | + var performGuidance = false; |
89 | 91 | promptOptions.NegativePrompt = string.Empty; |
90 | 92 |
|
| 93 | + // Process prompts |
| 94 | + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); |
| 95 | + |
| 96 | + // Run Scheduler steps |
| 97 | + var schedulerResult = await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, progressCallback, cancellationToken); |
| 98 | + |
| 99 | + _logger?.LogEnd($"End", diffuseTime); |
| 100 | + |
| 101 | + return schedulerResult; |
| 102 | + } |
| 103 | + |
| 104 | + |
| 105 | + /// <summary> |
| 106 | + /// Runs the stable diffusion batch loop |
| 107 | + /// </summary> |
| 108 | + /// <param name="modelOptions">The model options.</param> |
| 109 | + /// <param name="promptOptions">The prompt options.</param> |
| 110 | + /// <param name="schedulerOptions">The scheduler options.</param> |
| 111 | + /// <param name="batchOptions">The batch options.</param> |
| 112 | + /// <param name="progressCallback">The progress callback.</param> |
| 113 | + /// <param name="cancellationToken">The cancellation token.</param> |
| 114 | + /// <returns>An asynchronous stream of image tensors, one per batch item.</returns> |
| 116 | + public async IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) |
| 117 | + { |
| 118 | + var diffuseBatchTime = _logger?.LogBegin("Begin..."); |
| 119 | + _logger?.Log($"Model: {modelOptions.Name}, Pipeline: {modelOptions.PipelineType}, Diffuser: {promptOptions.DiffuserType}, Scheduler: {promptOptions.SchedulerType}"); |
| 120 | + |
| 121 | + // LCM does not support negative prompting |
| 122 | + var performGuidance = false; |
| 123 | + promptOptions.NegativePrompt = string.Empty; |
| 124 | + |
| 125 | + var batchIndex = 1; |
| 126 | + var batchCount = batchOptions.Count; |
| 127 | + var schedulerCallback = (int p, int t) => progressCallback?.Invoke(batchIndex, batchCount, p, t); |
| 128 | + |
| 129 | + // Process prompts |
| 130 | + var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance); |
| 131 | + |
| 132 | + if (batchOptions.BatchType == BatchOptionType.Seed) |
| 133 | + { |
| 134 | + var randomSeeds = Enumerable.Range(0, Math.Max(1, batchOptions.Count)).Select(x => Random.Shared.Next()); |
| 135 | + foreach (var randomSeed in randomSeeds) |
| 136 | + { |
| 137 | + schedulerOptions.Seed = randomSeed; |
| 138 | + yield return await RunSchedulerSteps(modelOptions, promptOptions, schedulerOptions, promptEmbeddings, performGuidance, schedulerCallback, cancellationToken); |
| 139 | + batchIndex++; |
| 140 | + } |
| 141 | + } |
| 142 | + } |
| 143 | + |
| 144 | + |
| 145 | + /// <summary> |
| 146 | + /// Runs the scheduler steps. |
| 147 | + /// </summary> |
| 148 | + /// <param name="modelOptions">The model options.</param> |
| 149 | + /// <param name="promptOptions">The prompt options.</param> |
| 150 | + /// <param name="schedulerOptions">The scheduler options.</param> |
| 151 | + /// <param name="promptEmbeddings">The prompt embeddings.</param> |
| 152 | + /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param> |
| 153 | + /// <param name="progressCallback">The progress callback.</param> |
| 154 | + /// <param name="cancellationToken">The cancellation token.</param> |
| 155 | + /// <returns>The decoded image tensor.</returns> |
| 156 | + protected virtual async Task<DenseTensor<float>> RunSchedulerSteps(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default) |
| 157 | + { |
91 | 158 | // Get Scheduler |
92 | 159 | using (var scheduler = GetScheduler(promptOptions, schedulerOptions)) |
93 | 160 | { |
94 | | - // Process prompts |
95 | | - var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, false); |
96 | | - |
97 | 161 | // Get timesteps |
98 | 162 | var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler); |
99 | 163 |
|
@@ -137,30 +201,11 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp |
137 | 201 | } |
138 | 202 |
|
139 | 203 | // Decode Latents |
140 | | - var result = await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised); |
141 | | - _logger?.LogEnd($"End", diffuseTime); |
142 | | - return result; |
| 204 | + return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised); |
143 | 205 | } |
144 | 206 | } |
145 | 207 |
|
146 | 208 |
|
147 | | - /// <summary> |
148 | | - /// Runs the stable diffusion batch loop |
149 | | - /// </summary> |
150 | | - /// <param name="modelOptions">The model options.</param> |
151 | | - /// <param name="promptOptions">The prompt options.</param> |
152 | | - /// <param name="schedulerOptions">The scheduler options.</param> |
153 | | - /// <param name="batchOptions">The batch options.</param> |
154 | | - /// <param name="progressCallback">The progress callback.</param> |
155 | | - /// <param name="cancellationToken">The cancellation token.</param> |
156 | | - /// <returns></returns> |
157 | | - /// <exception cref="System.NotImplementedException"></exception> |
158 | | - public IAsyncEnumerable<DenseTensor<float>> DiffuseBatchAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int> progressCallback = null, CancellationToken cancellationToken = default) |
159 | | - { |
160 | | - throw new NotImplementedException(); |
161 | | - } |
162 | | - |
163 | | - |
164 | 209 | /// <summary> |
165 | 210 | /// Decodes the latents. |
166 | 211 | /// </summary> |
@@ -279,6 +324,6 @@ protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params Name |
279 | 324 | return parameters.ToList(); |
280 | 325 | } |
281 | 326 |
|
282 | | - |
| 327 | + |
283 | 328 | } |
284 | 329 | } |
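
For reference, a minimal sketch of how a caller might consume the new DiffuseBatchAsync stream; the diffuser instance, the option objects, the cancellation token, and the SaveImage helper are assumed here for illustration and are not part of this change:

// Hypothetical caller: request four seed variations and handle each image as it is produced.
var batchOptions = new BatchOptions { BatchType = BatchOptionType.Seed, Count = 4 };
await foreach (var imageTensor in diffuser.DiffuseBatchAsync(
    modelOptions, promptOptions, schedulerOptions, batchOptions,
    (batchIndex, batchCount, step, steps) =>
        Console.WriteLine($"Batch {batchIndex}/{batchCount} - Step {step}/{steps}"),
    cancellationToken))
{
    // Each yielded DenseTensor<float> is one decoded image; SaveImage is a hypothetical helper.
    SaveImage(imageTensor, $"batch_{Guid.NewGuid()}.png");
}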