Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 4ea907c

Browse files
committed
Add ControlNet pipeline scaffold
1 parent 3c12247 commit 4ea907c

File tree

5 files changed

+216
-0
lines changed

5 files changed

+216
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Services;
4+
using OnnxStack.StableDiffusion.Common;
5+
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
7+
using OnnxStack.StableDiffusion.Models;
8+
using OnnxStack.StableDiffusion.Schedulers.StableDiffusion;
9+
using System;
10+
using System.Threading;
11+
using System.Threading.Tasks;
12+
13+
namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
14+
{
15+
public abstract class ControlNetDiffuser : DiffuserBase
16+
{
17+
/// <summary>
18+
/// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.
19+
/// </summary>
20+
/// <param name="configuration">The configuration.</param>
21+
/// <param name="onnxModelService">The onnx model service.</param>
22+
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<ControlNetDiffuser> logger)
23+
: base(onnxModelService, promptService, logger) { }
24+
25+
26+
/// <summary>
27+
/// Gets the type of the pipeline.
28+
/// </summary>
29+
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.ControlNet;
30+
31+
32+
/// <summary>
33+
/// Called on each Scheduler step.
34+
/// </summary>
35+
/// <param name="modelOptions">The model options.</param>
36+
/// <param name="promptOptions">The prompt options.</param>
37+
/// <param name="schedulerOptions">The scheduler options.</param>
38+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
39+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
40+
/// <param name="progressCallback">The progress callback.</param>
41+
/// <param name="cancellationToken">The cancellation token.</param>
42+
/// <returns></returns>
43+
/// <exception cref="System.NotImplementedException"></exception>
44+
protected override Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
45+
{
46+
throw new NotImplementedException();
47+
}
48+
49+
50+
/// <summary>
51+
/// Gets the scheduler.
52+
/// </summary>
53+
/// <param name="options">The options.</param>
54+
/// <param name="schedulerConfig">The scheduler configuration.</param>
55+
/// <returns></returns>
56+
protected override IScheduler GetScheduler(SchedulerOptions options)
57+
{
58+
return options.SchedulerType switch
59+
{
60+
SchedulerType.LMS => new LMSScheduler(options),
61+
SchedulerType.Euler => new EulerScheduler(options),
62+
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
63+
SchedulerType.DDPM => new DDPMScheduler(options),
64+
SchedulerType.DDIM => new DDIMScheduler(options),
65+
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
66+
_ => default
67+
};
68+
}
69+
}
70+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime;
3+
using Microsoft.ML.OnnxRuntime.Tensors;
4+
using OnnxStack.Core;
5+
using OnnxStack.Core.Config;
6+
using OnnxStack.Core.Model;
7+
using OnnxStack.Core.Services;
8+
using OnnxStack.StableDiffusion.Common;
9+
using OnnxStack.StableDiffusion.Config;
10+
using OnnxStack.StableDiffusion.Enums;
11+
using OnnxStack.StableDiffusion.Helpers;
12+
using SixLabors.ImageSharp;
13+
using System;
14+
using System.Collections.Generic;
15+
using System.Linq;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
19+
{
20+
public sealed class ImageDiffuser : ControlNetDiffuser
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<ImageDiffuser> logger)
28+
: base(onnxModelService, promptService, logger)
29+
{
30+
}
31+
32+
33+
/// <summary>
34+
/// Gets the type of the diffuser.
35+
/// </summary>
36+
public override DiffuserType DiffuserType => DiffuserType.ImageToImage;
37+
38+
39+
/// <summary>
40+
/// Gets the timesteps.
41+
/// </summary>
42+
/// <param name="prompt">The prompt.</param>
43+
/// <param name="options">The options.</param>
44+
/// <param name="scheduler">The scheduler.</param>
45+
/// <returns></returns>
46+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
47+
{
48+
var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
49+
var start = Math.Max(options.InferenceSteps - inittimestep, 0);
50+
return scheduler.Timesteps.Skip(start).ToList();
51+
}
52+
53+
54+
/// <summary>
55+
/// Prepares the latents for inference.
56+
/// </summary>
57+
/// <param name="prompt">The prompt.</param>
58+
/// <param name="options">The options.</param>
59+
/// <param name="scheduler">The scheduler.</param>
60+
/// <returns></returns>
61+
protected override async Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
62+
{
63+
var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
64+
65+
//TODO: Model Config, Channels
66+
var outputDimension = options.GetScaledDimension();
67+
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
68+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
69+
{
70+
inferenceParameters.AddInputTensor(imageTensor);
71+
inferenceParameters.AddOutputBuffer(outputDimension);
72+
73+
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
74+
using (var result = results.First())
75+
{
76+
var outputResult = result.ToDenseTensor();
77+
var scaledSample = outputResult.MultiplyBy(model.ScaleFactor);
78+
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
79+
}
80+
}
81+
}
82+
83+
}
84+
}

OnnxStack.StableDiffusion/Enums/DiffuserPipelineType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ public enum DiffuserPipelineType
66
StableDiffusionXL = 1,
77
LatentConsistency = 10,
88
LatentConsistencyXL = 11,
9+
ControlNet = 20,
910
InstaFlow = 30,
1011
}
1112
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Microsoft.Extensions.Logging;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Common;
4+
using OnnxStack.StableDiffusion.Diffusers;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace OnnxStack.StableDiffusion.Pipelines
11+
{
12+
public sealed class ControlNetPipeline : IPipeline
13+
{
14+
private readonly DiffuserPipelineType _pipelineType;
15+
private readonly ILogger<ControlNetPipeline> _logger;
16+
private readonly ConcurrentDictionary<DiffuserType, IDiffuser> _diffusers;
17+
18+
19+
/// <summary>
20+
/// Initializes a new instance of the <see cref="ControlNetPipeline"/> class.
21+
/// </summary>
22+
/// <param name="onnxModelService">The onnx model service.</param>
23+
/// <param name="promptService">The prompt service.</param>
24+
public ControlNetPipeline(IEnumerable<IDiffuser> diffusers, ILogger<ControlNetPipeline> logger)
25+
{
26+
_logger = logger;
27+
_pipelineType = DiffuserPipelineType.ControlNet;
28+
_diffusers = diffusers
29+
.Where(x => x.PipelineType == _pipelineType)
30+
.ToConcurrentDictionary(k => k.DiffuserType, v => v);
31+
}
32+
33+
34+
/// <summary>
35+
/// Gets the type of the pipeline.
36+
/// </summary>
37+
public DiffuserPipelineType PipelineType => _pipelineType;
38+
39+
40+
/// <summary>
41+
/// Gets the diffusers.
42+
/// </summary>
43+
public ConcurrentDictionary<DiffuserType, IDiffuser> Diffusers => _diffusers;
44+
45+
46+
/// <summary>
47+
/// Gets the diffuser.
48+
/// </summary>
49+
/// <param name="diffuserType">Type of the diffuser.</param>
50+
/// <returns></returns>
51+
public IDiffuser GetDiffuser(DiffuserType diffuserType)
52+
{
53+
_diffusers.TryGetValue(diffuserType, out var diffuser);
54+
return diffuser;
55+
}
56+
}
57+
}

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
5555
serviceCollection.AddSingleton<IPipeline, LatentConsistencyPipeline>();
5656
serviceCollection.AddSingleton<IPipeline, LatentConsistencyXLPipeline>();
5757
serviceCollection.AddSingleton<IPipeline, InstaFlowPipeline>();
58+
serviceCollection.AddSingleton<IPipeline, ControlNetPipeline>();
5859

5960
//StableDiffusion
6061
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusion.TextDiffuser>();
@@ -79,6 +80,9 @@ private static void RegisterServices(this IServiceCollection serviceCollection)
7980

8081
//InstaFlow
8182
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.InstaFlow.TextDiffuser>();
83+
84+
//ControlNet
85+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.ControlNet.ImageDiffuser>();
8286
}
8387

8488

0 commit comments

Comments
 (0)