Skip to content

Commit cc432ee

Browse files
committed
StableDiffusionXLPipeline
1 parent 6403c90 commit cc432ee

File tree

10 files changed

+818
-16
lines changed

10 files changed

+818
-16
lines changed

TensorStack.Common/Extensions/TensorExtensions.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,17 @@ public static Tensor<float> ClipTo(this Tensor<float> tensor, float minValue, fl
328328
}
329329

330330

331+
/// <summary>
332+
/// Splits the tensor along axis 0 and returns the first resulting tensor.
333+
/// </summary>
334+
/// <param name="tensor">The tensor.</param>
335+
/// <returns>The first tensor produced by the split, or the default value if the split yields no tensors.</returns>
336+
public static Tensor<T> FirstBatch<T>(this Tensor<T> tensor)
337+
{
338+
return Split(tensor).FirstOrDefault();
339+
}
340+
341+
331342
/// <summary>
332343
/// Reshapes to new tensor.
333344
/// </summary>
@@ -529,7 +540,7 @@ public static Tensor<T> Permute<T>(this Tensor<T> tensor, int[] permutation)
529540
/// <param name="axis">The axis.</param>
530541
/// <returns>IEnumerable&lt;Tensor&lt;T&gt;&gt;.</returns>
531542
/// <exception cref="NotImplementedException">Only axis 0 is supported</exception>
532-
public static IEnumerable<Tensor<float>> Split(this Tensor<float> tensor, int axis = 0)
543+
public static IEnumerable<Tensor<T>> Split<T>(this Tensor<T> tensor, int axis = 0)
533544
{
534545
if (axis != 0)
535546
throw new NotImplementedException("Only axis 0 is supported");
@@ -542,7 +553,7 @@ public static IEnumerable<Tensor<float>> Split(this Tensor<float> tensor, int ax
542553
for (int i = 0; i < count; i++)
543554
{
544555
var start = i * newLength;
545-
yield return new Tensor<float>(tensor.Memory.Slice(start, newLength), dimensions);
556+
yield return new Tensor<T>(tensor.Memory.Slice(start, newLength), dimensions);
546557
}
547558
}
548559

TensorStack.StableDiffusion/Common/GenerateOptions.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
3030
public ImageTensor InputControlImage { get; set; }
3131
public ControlNetModel ControlNet { get; set; }
3232

33+
public int ClipSkip { get; set; }
34+
public float AestheticScore { get; set; } = 6f;
35+
public float AestheticNegativeScore { get; set; } = 2.5f;
36+
3337

3438
public bool IsLowMemoryEnabled { get; set; }
3539
public bool IsLowMemoryComputeEnabled { get; set; }
@@ -67,8 +71,6 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
6771
public float MaximumBeta { get; set; } = 0.999f;
6872
public List<int> Timesteps { get; set; }
6973
public int TrainSteps { get; set; } = 50;
70-
public float AestheticScore { get; set; } = 6f;
71-
public float AestheticNegativeScore { get; set; } = 2.5f;
7274
public float Shift { get; set; } = 1f;
7375

7476
#endregion

TensorStack.StableDiffusion/Common/TextEncoderResult.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ public TextEncoderResult(Tensor<float> hiddenStates, Tensor<float> textEmbeds)
2020
: this([hiddenStates], textEmbeds) { }
2121

2222

23-
public Tensor<float> GetTextEmbeds()
24-
{
25-
return _textEmbeds;
26-
}
23+
public Tensor<float> TextEmbeds => _textEmbeds;
24+
public Tensor<float> HiddenStates => _hiddenStates[0];
25+
2726

28-
public Tensor<float> GetHiddenStates(int index = 0)
27+
public Tensor<float> GetHiddenStates(int index)
2928
{
3029
if (index > 0)
3130
return _hiddenStates[^index];

TensorStack.StableDiffusion/Enums/PipelineType.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ public enum PipelineType
66
{
77
StableDiffusion = 0,
88
StableDiffusion2 = 1,
9-
LatentConsistency = 10
9+
StableDiffusionXL = 2,
10+
LatentConsistency = 10,
11+
1012
}
1113
}

TensorStack.StableDiffusion/Helpers/PromptParser.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public static async Task<TextEncoderResult> EncodePromptAsync(CLIPTextModel text
6060
{
6161
var textEncoderResult = await textEncoder.RunAsync(inputTokens, cancellationToken);
6262
ApplyPromptWeights(inputTokens, textEncoderResult);
63-
return new TextEncoderResult(textEncoderResult.GetHiddenStates(hiddenStateIndex), textEncoderResult.GetTextEmbeds());
63+
return new TextEncoderResult(textEncoderResult.GetHiddenStates(hiddenStateIndex), textEncoderResult.TextEmbeds);
6464
}
6565
else
6666
{
@@ -127,7 +127,7 @@ public static async Task<TextEncoderResult> EncodePromptAsync(CLIPTextModel text
127127
promptEmbeds.AddRange(output);
128128
}
129129

130-
promptPooledEmbeds.AddRange(result.GetTextEmbeds().Span);
130+
promptPooledEmbeds.AddRange(result.TextEmbeds.Span);
131131
}
132132

133133
var hiddenStates = new Tensor<float>(promptEmbeds.ToArray(), [1, promptEmbeds.Count / textEncoder.HiddenSize, textEncoder.HiddenSize]);
@@ -146,7 +146,7 @@ public static async Task<TextEncoderResult> EncodePromptAsync(CLIPTextModel text
146146
/// <param name="encoderOutput">The encoder output.</param>
147147
public static void ApplyPromptWeights(TokenizerResult tokenizerOutput, TextEncoderResult encoderOutput)
148148
{
149-
var hiddenStates = encoderOutput.GetHiddenStates();
149+
var hiddenStates = encoderOutput.HiddenStates;
150150
var numTokens = hiddenStates.Dimensions[1];
151151
var embedDim = hiddenStates.Dimensions[2];
152152
var weights = tokenizerOutput.Weights.Pad(1, numTokens);

TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ public interface IPipelineOptions : IRunOptions
2626
float ControlNetStrength { get; set; }
2727
ImageTensor InputControlImage { get; set; }
2828

29+
int ClipSkip{ get; set; }
30+
float AestheticScore { get; set; }
31+
float AestheticNegativeScore { get; set; }
2932

3033
bool IsLowMemoryEnabled { get; set; }
3134
bool IsLowMemoryComputeEnabled { get; set; }

TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionPipeline.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ private async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, Can
159159
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
160160
await TextEncoder.UnloadAsync();
161161

162-
return new PromptResult(promptEmbeddings.GetHiddenStates(), promptEmbeddings.GetTextEmbeds(), negativePromptEmbeddings.GetHiddenStates(), negativePromptEmbeddings.GetTextEmbeds());
162+
return new PromptResult(promptEmbeddings.HiddenStates, promptEmbeddings.TextEmbeds, negativePromptEmbeddings.HiddenStates, negativePromptEmbeddings.TextEmbeds);
163163
}
164164

165165

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System.IO;
4+
using TensorStack.Common;
5+
using TensorStack.StableDiffusion.Config;
6+
using TensorStack.StableDiffusion.Enums;
7+
8+
namespace TensorStack.StableDiffusion.Pipelines.StableDiffusionXL
9+
{
10+
public record StableDiffusionXLConfig : PipelineConfig
11+
{
12+
/// <summary>
13+
/// Initializes a new instance of the <see cref="StableDiffusionXLConfig"/> class.
14+
/// </summary>
15+
public StableDiffusionXLConfig()
16+
{
17+
Tokenizer = new TokenizerConfig();
18+
Tokenizer2 = new TokenizerConfig();
19+
TextEncoder = new CLIPModelConfig { HiddenSize = 768 };
20+
TextEncoder2 = new CLIPModelConfig { HiddenSize = 1280 };
21+
Unet = new UNetModelConfig { IsOptimizationSupported = true };
22+
AutoEncoder = new AutoEncoderModelConfig { ScaleFactor = 0.13025f };
23+
}
24+
25+
public string Name { get; init; } = "StableDiffusionXL";
26+
public override PipelineType Pipeline { get; } = PipelineType.StableDiffusionXL;
27+
public TokenizerConfig Tokenizer { get; init; }
28+
public TokenizerConfig Tokenizer2 { get; init; }
29+
public CLIPModelConfig TextEncoder { get; init; }
30+
public CLIPModelConfig TextEncoder2 { get; init; }
31+
public UNetModelConfig Unet { get; init; }
32+
public AutoEncoderModelConfig AutoEncoder { get; init; }
33+
34+
35+
/// <summary>
36+
/// Sets the execution provider for all models.
37+
/// </summary>
38+
/// <param name="executionProvider">The execution provider.</param>
39+
public override void SetProvider(ExecutionProvider executionProvider)
40+
{
41+
Tokenizer.SetProvider(executionProvider);
42+
Tokenizer2.SetProvider(executionProvider);
43+
TextEncoder.SetProvider(executionProvider);
44+
TextEncoder2.SetProvider(executionProvider);
45+
Unet.SetProvider(executionProvider);
46+
AutoEncoder.SetProvider(executionProvider);
47+
}
48+
49+
50+
/// <summary>
51+
/// Saves the configuration to file.
52+
/// </summary>
53+
/// <param name="configFile">The configuration file.</param>
54+
/// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
55+
public override void Save(string configFile, bool useRelativePaths = true)
56+
{
57+
ConfigService.Serialize(configFile, this, useRelativePaths);
58+
}
59+
60+
61+
/// <summary>
62+
/// Creates a StableDiffusionXL configuration from default values.
63+
/// </summary>
64+
/// <param name="name">The name.</param>
65+
/// <param name="modelType">Type of the model.</param>
66+
/// <param name="executionProvider">The execution provider.</param>
67+
/// <returns>StableDiffusionXLConfig.</returns>
68+
public static StableDiffusionXLConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
69+
{
70+
var config = new StableDiffusionXLConfig { Name = name };
71+
config.Unet.ModelType = modelType;
72+
config.SetProvider(executionProvider);
73+
return config;
74+
}
75+
76+
77+
/// <summary>
78+
/// Creates a StableDiffusionXL configuration from a JSON file.
79+
/// </summary>
80+
/// <param name="configFile">The configuration file.</param>
81+
/// <param name="executionProvider">The execution provider.</param>
82+
/// <returns>StableDiffusionXLConfig.</returns>
83+
public static StableDiffusionXLConfig FromFile(string configFile, ExecutionProvider executionProvider = default)
84+
{
85+
var config = ConfigService.Deserialize<StableDiffusionXLConfig>(configFile);
86+
config.SetProvider(executionProvider);
87+
return config;
88+
}
89+
90+
91+
/// <summary>
92+
/// Creates a StableDiffusionXL configuration from a model folder structure.
93+
/// </summary>
94+
/// <param name="modelFolder">The model folder.</param>
95+
/// <param name="modelType">Type of the model.</param>
96+
/// <param name="executionProvider">The execution provider.</param>
97+
/// <returns>StableDiffusionXLConfig.</returns>
98+
public static StableDiffusionXLConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
99+
{
100+
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
101+
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
102+
config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "vocab.json");
103+
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
104+
config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx");
105+
config.Unet.Path = Path.Combine(modelFolder, "unet", "model.onnx");
106+
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
107+
config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
108+
var controlNetPath = Path.Combine(modelFolder, "unet", "controlnet.onnx");
109+
if (File.Exists(controlNetPath))
110+
config.Unet.ControlNetPath = controlNetPath;
111+
return config;
112+
}
113+
114+
}
115+
}

0 commit comments

Comments
 (0)