Commit e4f86a2

PromptService support for dual Encoders, Tokenizers
1 parent 990ea25 commit e4f86a2

10 files changed: 150 additions, 32 deletions


OnnxStack.StableDiffusion/Common/IPromptService.cs

Lines changed: 1 addition & 3 deletions
@@ -6,8 +6,6 @@ namespace OnnxStack.StableDiffusion.Common
 {
     public interface IPromptService
     {
-        Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled);
-        Task<int[]> DecodeTextAsync(IModelOptions model, string inputText);
-        Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput);
+        Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled);
     }
 }
OnnxStack.StableDiffusion/Common/PromptEmbeddingsResult.cs

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+using Microsoft.ML.OnnxRuntime.Tensors;
+
+namespace OnnxStack.StableDiffusion.Common
+{
+    public record PromptEmbeddingsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> PooledPromptEmbeds = default);
+}
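The new PromptEmbeddingsResult record pairs the per-token prompt embeddings with an optional pooled vector, which is what a second, pooled-output text encoder typically contributes in a dual-encoder setup. The PromptService changes themselves are not among the hunks shown here, so the following is only a rough, self-contained sketch of the usual combination step; apart from PromptEmbeddingsResult, every name in it is illustrative rather than taken from OnnxStack.

using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.StableDiffusion.Common;

internal static class DualEncoderSketch
{
    // Illustrative only: combines the hidden states of two text encoders by
    // concatenating them on the feature axis and pairs the result with the
    // pooled output of the second encoder. Not the actual OnnxStack code.
    internal static PromptEmbeddingsResult Combine(
        DenseTensor<float> hiddenStates1,   // [batch, tokens, dim1]
        DenseTensor<float> hiddenStates2,   // [batch, tokens, dim2]
        DenseTensor<float> pooledOutput)    // [batch, dim2]
    {
        int batch = hiddenStates1.Dimensions[0];
        int tokens = hiddenStates1.Dimensions[1];
        int dim1 = hiddenStates1.Dimensions[2];
        int dim2 = hiddenStates2.Dimensions[2];

        // Concatenate per-token hidden states along the feature dimension.
        var combined = new DenseTensor<float>(new[] { batch, tokens, dim1 + dim2 });
        for (int b = 0; b < batch; b++)
            for (int t = 0; t < tokens; t++)
            {
                for (int d = 0; d < dim1; d++)
                    combined[b, t, d] = hiddenStates1[b, t, d];
                for (int d = 0; d < dim2; d++)
                    combined[b, t, dim1 + d] = hiddenStates2[b, t, d];
            }

        // The pooled output travels alongside the token-level embeddings.
        return new PromptEmbeddingsResult(combined, pooledOutput);
    }
}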

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected abstract Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);
+        protected abstract Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default);


        /// <summary>
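With the abstract signature now taking PromptEmbeddingsResult, every diffuser override in the files below is updated the same way and passes only the token-level half to the UNet. A minimal caller-side sketch, assuming a resolved IPromptService; only CreatePromptAsync and PromptEmbeddingsResult come from this commit, and the option types' namespaces are assumptions.

using System.Threading.Tasks;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;   // assumed namespace for PromptOptions/IModelOptions

internal static class PromptServiceUsageSketch
{
    // Illustrative consumer of the reworked IPromptService.CreatePromptAsync.
    internal static async Task<DenseTensor<float>> GetUnetConditioningAsync(
        IPromptService promptService, IModelOptions model, PromptOptions promptOptions)
    {
        PromptEmbeddingsResult embeddings =
            await promptService.CreatePromptAsync(model, promptOptions, isGuidanceEnabled: true);

        // PooledPromptEmbeds defaults to null and would only be populated when a
        // second, pooled-output text encoder is configured for the model.
        if (embeddings.PooledPromptEmbeds is not null)
        {
            // extra conditioning would be consumed here by dual-encoder pipelines
        }

        // The diffuser hunks below feed only the token-level embeddings to the UNet.
        return embeddings.PromptEmbeds;
    }
}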

OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService prom
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Get Scheduler
            using (var scheduler = GetScheduler(schedulerOptions))
@@ -81,7 +81,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, ISc
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            using (var scheduler = GetScheduler(schedulerOptions))
            {
@@ -111,7 +111,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddInputTensor(guidanceEmbeddings);
                    inferenceParameters.AddOutputBuffer(outputDimension);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@ protected override bool ShouldPerformGuidance(SchedulerOptions schedulerOptions)
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Get Scheduler
            using (var scheduler = GetScheduler(schedulerOptions))
@@ -126,7 +126,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddInputTensor(guidanceEmbeddings);
                    inferenceParameters.AddOutputBuffer(outputDimension);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ public InpaintDiffuser(IOnnxModelService onnxModelService, IPromptService prompt
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Get Scheduler
            using (var scheduler = GetScheduler(schedulerOptions))
@@ -93,7 +93,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            using (var scheduler = GetScheduler(schedulerOptions))
            {
@@ -91,7 +91,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ public StableDiffusionDiffuser(IOnnxModelService onnxModelService, IPromptServic
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns></returns>
-        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Get Scheduler
            using (var scheduler = GetScheduler(schedulerOptions))
@@ -81,7 +81,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
-                    inferenceParameters.AddInputTensor(promptEmbeddings);
+                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
