Skip to content

Commit e6f8be8

Browse files
committed
FluxPipeline
1 parent 0d76d49 commit e6f8be8

File tree

4 files changed

+780
-3
lines changed

4 files changed

+780
-3
lines changed

TensorStack.StableDiffusion/Enums/PipelineType.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ public enum PipelineType
1010
StableDiffusion3 = 3,
1111
StableCascade = 10,
1212
LatentConsistency = 20,
13-
13+
Flux = 30,
1414
}
1515
}

TensorStack.StableDiffusion/Models/TransformerFluxModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public TransformerFluxModel(TransformerModelConfig configuration)
3232
/// <param name="txtIds">The text ids.</param>
3333
/// <param name="guidanceTensor">The guidance tensor.</param>
3434
/// <param name="cancellationToken">The cancellation token.</param>
35-
public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStates, Tensor<float> encoderHiddenStates, Tensor<float> pooledProjections, Tensor<float> imgIds, Tensor<float> txtIds, Tensor<float> guidanceTensor, CancellationToken cancellationToken = default)
35+
public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStates, Tensor<float> encoderHiddenStates, Tensor<float> pooledProjections, Tensor<float> imgIds, Tensor<float> txtIds, float guidanceScale, CancellationToken cancellationToken = default)
3636
{
3737
if (!Transformer.IsLoaded())
3838
await Transformer.LoadAsync(cancellationToken: cancellationToken);
@@ -48,7 +48,7 @@ public async Task<Tensor<float>> RunAsync(int timestep, Tensor<float> hiddenStat
4848
transformerParams.AddInput(imgIds.AsTensorSpan());
4949
transformerParams.AddInput(txtIds.AsTensorSpan());
5050
if (supportsGuidance)
51-
transformerParams.AddInput(guidanceTensor.AsTensorSpan());
51+
transformerParams.AddScalarInput(guidanceScale);
5252

5353
// Outputs
5454
transformerParams.AddOutput(hiddenStates.Dimensions);
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System.IO;
4+
using TensorStack.Common;
5+
using TensorStack.StableDiffusion.Config;
6+
using TensorStack.StableDiffusion.Enums;
7+
8+
namespace TensorStack.StableDiffusion.Pipelines.Flux
9+
{
10+
public record FluxConfig : PipelineConfig
11+
{
12+
/// <summary>
13+
/// Initializes a new instance of the <see cref="FluxConfig"/> class.
14+
/// </summary>
15+
public FluxConfig()
16+
{
17+
Tokenizer = new TokenizerConfig();
18+
Tokenizer2 = new TokenizerConfig();
19+
TextEncoder = new CLIPModelConfig();
20+
TextEncoder2 = new CLIPModelConfig
21+
{
22+
PadTokenId = 0,
23+
HiddenSize = 4096,
24+
IsFixedSequenceLength = false,
25+
SequenceLength = 512
26+
};
27+
Transformer = new TransformerModelConfig
28+
{
29+
JointAttention = 4096,
30+
PooledProjection = 768,
31+
IsOptimizationSupported = true
32+
};
33+
AutoEncoder = new AutoEncoderModelConfig
34+
{
35+
LatentChannels = 16,
36+
ScaleFactor = 0.3611f,
37+
ShiftFactor = 0.1159f
38+
};
39+
}
40+
41+
public string Name { get; init; } = "Flux";
42+
public override PipelineType Pipeline { get; } = PipelineType.Flux;
43+
public TokenizerConfig Tokenizer { get; init; }
44+
public TokenizerConfig Tokenizer2 { get; init; }
45+
public CLIPModelConfig TextEncoder { get; init; }
46+
public CLIPModelConfig TextEncoder2 { get; init; }
47+
public TransformerModelConfig Transformer { get; init; }
48+
public AutoEncoderModelConfig AutoEncoder { get; init; }
49+
50+
51+
/// <summary>
52+
/// Sets the execution provider for all models.
53+
/// </summary>
54+
/// <param name="executionProvider">The execution provider.</param>
55+
public override void SetProvider(ExecutionProvider executionProvider)
56+
{
57+
Tokenizer.SetProvider(executionProvider);
58+
Tokenizer2.SetProvider(executionProvider);
59+
TextEncoder.SetProvider(executionProvider);
60+
TextEncoder2.SetProvider(executionProvider);
61+
Transformer.SetProvider(executionProvider);
62+
AutoEncoder.SetProvider(executionProvider);
63+
}
64+
65+
66+
/// <summary>
67+
/// Saves the configuration to file.
68+
/// </summary>
69+
/// <param name="configFile">The configuration file.</param>
70+
/// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
71+
public override void Save(string configFile, bool useRelativePaths = true)
72+
{
73+
ConfigService.Serialize(configFile, this, useRelativePaths);
74+
}
75+
76+
77+
/// <summary>
78+
/// Create Flux configuration from default values
79+
/// </summary>
80+
/// <param name="name">The name.</param>
81+
/// <param name="modelType">Type of the model.</param>
82+
/// <param name="executionProvider">The execution provider.</param>
83+
/// <returns>FluxConfig.</returns>
84+
public static FluxConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
85+
{
86+
var config = new FluxConfig { Name = name };
87+
config.Transformer.ModelType = modelType;
88+
config.SetProvider(executionProvider);
89+
return config;
90+
}
91+
92+
93+
/// <summary>
94+
/// Create StableDiffusionv configuration from json file
95+
/// </summary>
96+
/// <param name="configFile">The configuration file.</param>
97+
/// <param name="executionProvider">The execution provider.</param>
98+
/// <returns>FluxConfig.</returns>
99+
public static FluxConfig FromFile(string configFile, ExecutionProvider executionProvider = default)
100+
{
101+
var config = ConfigService.Deserialize<FluxConfig>(configFile);
102+
config.SetProvider(executionProvider);
103+
return config;
104+
}
105+
106+
107+
/// <summary>
108+
/// Create Flux configuration from folder structure
109+
/// </summary>
110+
/// <param name="modelFolder">The model folder.</param>
111+
/// <param name="modelType">Type of the model.</param>
112+
/// <param name="executionProvider">The execution provider.</param>
113+
/// <returns>FluxConfig.</returns>
114+
public static FluxConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
115+
{
116+
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
117+
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
118+
config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "spiece.model");
119+
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
120+
config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx");
121+
config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx");
122+
config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
123+
config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
124+
var controlNetPath = Path.Combine(modelFolder, "transformer", "controlnet.onnx");
125+
if (File.Exists(controlNetPath))
126+
config.Transformer.ControlNetPath = controlNetPath;
127+
return config;
128+
}
129+
130+
}
131+
}

0 commit comments

Comments
 (0)