Skip to content

Commit 26bfdc1

Browse files
committed
StableDiffusion3Pipeline
1 parent a669f3f commit 26bfdc1

File tree

15 files changed

+1085
-10
lines changed

15 files changed

+1085
-10
lines changed

TensorStack.Common/Extensions/TensorExtensions.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,40 @@ public static Tensor<float> Join(this IEnumerable<Tensor<float>> tensors, int ax
589589
}
590590

591591

592+
/// <summary>
/// Pads the end (last) dimension by the specified length.
/// </summary>
/// <param name="tensor1">The tensor.</param>
/// <param name="padLength">Number of elements to append to the last dimension. Must not be negative.</param>
/// <returns>
/// A new tensor whose last dimension is <paramref name="padLength"/> larger than the source.
/// The source values are copied to the same indices; the appended elements are not written by this method.
/// </returns>
/// <exception cref="System.ArgumentOutOfRangeException">padLength is negative</exception>
/// <exception cref="System.ArgumentException">Rank 2 or 3 currently supported</exception>
public static Tensor<float> PadEnd(this Tensor<float> tensor1, int padLength)
{
    // A negative pad would shrink the destination while the copy loops still
    // iterate the full source extents, so reject it up front.
    if (padLength < 0)
        throw new ArgumentOutOfRangeException(nameof(padLength), padLength, "Pad length must not be negative");

    var dimensions = tensor1.Dimensions.ToArray();
    var rank = dimensions.Length;

    // Validate the rank before allocating the (potentially large) result tensor.
    if (rank != 2 && rank != 3)
        throw new ArgumentException("Rank 2 or 3 currently supported");

    // Capture the source extents before the last dimension is enlarged.
    var length0 = dimensions[0];
    var length1 = dimensions[1];
    var length2 = rank == 3 ? dimensions[2] : 0;

    dimensions[^1] += padLength;
    var paddedTensor = new Tensor<float>(dimensions);

    if (rank == 2)
    {
        for (int i = 0; i < length0; i++)
            for (int j = 0; j < length1; j++)
                paddedTensor[i, j] = tensor1[i, j];
    }
    else
    {
        for (int i = 0; i < length0; i++)
            for (int j = 0; j < length1; j++)
                for (int k = 0; k < length2; k++)
                    paddedTensor[i, j, k] = tensor1[i, j, k];
    }

    return paddedTensor;
}
624+
625+
592626
/// <summary>
593627
/// Generates the next random tensor
594628
/// </summary>

TensorStack.StableDiffusion/Enums/PipelineType.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ public enum PipelineType
77
StableDiffusion = 0,
88
StableDiffusion2 = 1,
99
StableDiffusionXL = 2,
10-
StableCascade = 4,
11-
LatentConsistency = 10,
10+
StableDiffusion3 = 3,
11+
StableCascade = 10,
12+
LatentConsistency = 20,
1213

1314
}
1415
}

TensorStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ public enum SchedulerType
3131
DDPMWuerstchen = 10,
3232

3333
[Display(Name = "LCM")]
34-
LCM = 20
34+
LCM = 20,
35+
36+
[Display(Name = "FlowMatch-EulerDiscrete")]
37+
FlowMatchEulerDiscrete = 30,
38+
39+
[Display(Name = "FlowMatch-EulerDynamic")]
40+
FlowMatchEulerDynamic = 31
3541
}
3642
}

TensorStack.StableDiffusion/Models/T5EncoderModel.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ public override async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInpu
3636
if (IsFixedSequenceLength)
3737
tokenInput = PadOrTruncate(tokenInput);
3838

39+
if (tokenInput.InputIds.Length == 0)
40+
return new TextEncoderResult(new Tensor<float>([1, 1, HiddenSize]), default);
41+
3942
var sequenceLength = tokenInput.InputIds.Length;
4043
var supportsAttentionMask = Metadata.Outputs.Count == 2;
4144
var inputTensor = new TensorSpan<long>(tokenInput.InputIds, [1, sequenceLength]);

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ protected virtual IScheduler CreateScheduler(GenerateOptions options)
112112
SchedulerType.KDPM2Ancestral => new KDPM2AncestralScheduler(options),
113113
SchedulerType.DDPMWuerstchen => new DDPMWuerstchenScheduler(options),
114114
SchedulerType.LCM => new LCMScheduler(options),
115+
SchedulerType.FlowMatchEulerDiscrete => new FlowMatchEulerDiscreteScheduler(options),
116+
SchedulerType.FlowMatchEulerDynamic => new FlowMatchEulerDynamicScheduler(options),
115117
_ => default
116118
};
117119
scheduler.Initialize(options.Strength);
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System.IO;
4+
using TensorStack.Common;
5+
using TensorStack.StableDiffusion.Config;
6+
using TensorStack.StableDiffusion.Enums;
7+
8+
namespace TensorStack.StableDiffusion.Pipelines.StableDiffusion3
{
    /// <summary>
    /// Pipeline configuration for StableDiffusion3: three tokenizers, three text encoders,
    /// a transformer and an auto encoder.
    /// </summary>
    public record StableDiffusion3Config : PipelineConfig
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="StableDiffusion3Config"/> class
        /// with default sub-configurations for every model.
        /// </summary>
        public StableDiffusion3Config()
        {
            Tokenizer = new TokenizerConfig();
            Tokenizer2 = new TokenizerConfig();
            Tokenizer3 = new TokenizerConfig();
            TextEncoder = new CLIPModelConfig();
            TextEncoder2 = new CLIPModelConfig { HiddenSize = 1280 };
            TextEncoder3 = new CLIPModelConfig
            {
                PadTokenId = 0,
                HiddenSize = 4096,
                IsFixedSequenceLength = false,
                SequenceLength = 512
            };
            Transformer = new TransformerModelConfig
            {
                JointAttention = 4096,
                PooledProjection = 2048,
                CaptionProjection = 1536,
                IsOptimizationSupported = true
            };
            AutoEncoder = new AutoEncoderModelConfig
            {
                LatentChannels = 16,
                ScaleFactor = 1.5305f,
                ShiftFactor = 0.0609f
            };
        }

        /// <summary>The configuration name. Defaults to "StableDiffusion3".</summary>
        public string Name { get; init; } = "StableDiffusion3";

        /// <summary>The pipeline type, always <see cref="PipelineType.StableDiffusion3"/>.</summary>
        public override PipelineType Pipeline { get; } = PipelineType.StableDiffusion3;

        /// <summary>Configuration of the first tokenizer.</summary>
        public TokenizerConfig Tokenizer { get; init; }

        /// <summary>Configuration of the second tokenizer.</summary>
        public TokenizerConfig Tokenizer2 { get; init; }

        /// <summary>Configuration of the third tokenizer.</summary>
        public TokenizerConfig Tokenizer3 { get; init; }

        /// <summary>Configuration of the first text encoder.</summary>
        public CLIPModelConfig TextEncoder { get; init; }

        /// <summary>Configuration of the second text encoder.</summary>
        public CLIPModelConfig TextEncoder2 { get; init; }

        /// <summary>Configuration of the third text encoder.</summary>
        public CLIPModelConfig TextEncoder3 { get; init; }

        /// <summary>Configuration of the transformer model.</summary>
        public TransformerModelConfig Transformer { get; init; }

        /// <summary>Configuration of the auto encoder model.</summary>
        public AutoEncoderModelConfig AutoEncoder { get; init; }


        /// <summary>
        /// Sets the execution provider for all models.
        /// </summary>
        /// <param name="executionProvider">The execution provider.</param>
        public override void SetProvider(ExecutionProvider executionProvider)
        {
            Tokenizer.SetProvider(executionProvider);
            Tokenizer2.SetProvider(executionProvider);
            Tokenizer3.SetProvider(executionProvider);
            TextEncoder.SetProvider(executionProvider);
            TextEncoder2.SetProvider(executionProvider);
            TextEncoder3.SetProvider(executionProvider);
            Transformer.SetProvider(executionProvider);
            AutoEncoder.SetProvider(executionProvider);
        }


        /// <summary>
        /// Saves the configuration to file.
        /// </summary>
        /// <param name="configFile">The configuration file.</param>
        /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
        public override void Save(string configFile, bool useRelativePaths = true)
        {
            ConfigService.Serialize(configFile, this, useRelativePaths);
        }


        /// <summary>
        /// Create StableDiffusion3 configuration from default values
        /// </summary>
        /// <param name="name">The name.</param>
        /// <param name="modelType">Type of the model.</param>
        /// <param name="executionProvider">The execution provider.</param>
        /// <returns>StableDiffusion3Config.</returns>
        public static StableDiffusion3Config FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
        {
            var config = new StableDiffusion3Config { Name = name };
            config.Transformer.ModelType = modelType;
            config.SetProvider(executionProvider);
            return config;
        }


        /// <summary>
        /// Create StableDiffusion3 configuration from json file
        /// </summary>
        /// <param name="configFile">The configuration file.</param>
        /// <param name="executionProvider">The execution provider.</param>
        /// <returns>StableDiffusion3Config.</returns>
        public static StableDiffusion3Config FromFile(string configFile, ExecutionProvider executionProvider = default)
        {
            var config = ConfigService.Deserialize<StableDiffusion3Config>(configFile);
            config.SetProvider(executionProvider);
            return config;
        }


        /// <summary>
        /// Create StableDiffusion3 configuration from folder structure
        /// </summary>
        /// <param name="modelFolder">The model folder.</param>
        /// <param name="modelType">Type of the model.</param>
        /// <param name="executionProvider">The execution provider.</param>
        /// <returns>StableDiffusion3Config.</returns>
        public static StableDiffusion3Config FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
        {
            // Use the full folder name as the configuration name. GetFileNameWithoutExtension
            // would cut a dotted folder name at its last '.' ("sd3.5-medium" -> "sd3") and
            // return an empty string for a path that ends with a directory separator.
            var folderName = Path.GetFileName(Path.TrimEndingDirectorySeparator(modelFolder));

            var config = FromDefault(folderName, modelType, executionProvider);
            config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
            config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "vocab.json");
            config.Tokenizer3.Path = Path.Combine(modelFolder, "tokenizer_3", "spiece.model");
            config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
            config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx");
            config.TextEncoder3.Path = Path.Combine(modelFolder, "text_encoder_3", "model.onnx");
            config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
            config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
            config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");

            // The ControlNet model is optional; only set its path when the file is present.
            var controlNetPath = Path.Combine(modelFolder, "controlnet", "model.onnx");
            if (File.Exists(controlNetPath))
                config.Transformer.ControlNetPath = controlNetPath;
            return config;
        }

    }
}

0 commit comments

Comments
 (0)