Skip to content

Commit a669f3f

Browse files
committed
StableCascadePipeline
1 parent cc432ee commit a669f3f

File tree

8 files changed

+786
-7
lines changed

8 files changed

+786
-7
lines changed

TensorStack.StableDiffusion/Common/GenerateOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public record GenerateOptions : IPipelineOptions, ISchedulerOptions
2020
public string Prompt { get; set; }
2121
public string NegativePrompt { get; set; }
2222
public float GuidanceScale { get; set; }
23+
public float GuidanceScale2 { get; set; }
2324
public SchedulerType Scheduler { get; set; }
2425

2526
public float Strength { get; set; } = 1;

TensorStack.StableDiffusion/Enums/PipelineType.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ public enum PipelineType
77
StableDiffusion = 0,
88
StableDiffusion2 = 1,
99
StableDiffusionXL = 2,
10+
StableCascade = 4,
1011
LatentConsistency = 10,
1112

1213
}

TensorStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ public enum SchedulerType
2727
[Display(Name = "KDPM2-Ancestral")]
2828
KDPM2Ancestral = 6,
2929

30+
[Display(Name = "DDPM-Wuerstchen")]
31+
DDPMWuerstchen = 10,
32+
3033
[Display(Name = "LCM")]
3134
LCM = 20
3235
}

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ protected virtual IScheduler CreateScheduler(GenerateOptions options)
110110
SchedulerType.DDIM => new DDIMScheduler(options),
111111
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
112112
SchedulerType.KDPM2Ancestral => new KDPM2AncestralScheduler(options),
113+
SchedulerType.DDPMWuerstchen => new DDPMWuerstchenScheduler(options),
113114
SchedulerType.LCM => new LCMScheduler(options),
114115
_ => default
115116
};
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
using System.IO;
4+
using TensorStack.Common;
5+
using TensorStack.StableDiffusion.Config;
6+
using TensorStack.StableDiffusion.Enums;
7+
8+
namespace TensorStack.StableDiffusion.Pipelines.StableCascade
9+
{
10+
public record StableCascadeConfig : PipelineConfig
11+
{
12+
/// <summary>
13+
/// Initializes a new instance of the <see cref="StableCascadeConfig"/> class.
14+
/// </summary>
15+
public StableCascadeConfig()
16+
{
17+
Tokenizer = new TokenizerConfig();
18+
PriorUnet = new UNetModelConfig();
19+
DecoderUnet = new UNetModelConfig();
20+
TextEncoder = new CLIPModelConfig { HiddenSize = 1280 };
21+
ImageEncoder = new CLIPModelConfig { HiddenSize = 768 };
22+
ImageDecoder = new PaellaVQModelConfig
23+
{
24+
Scale = 4,
25+
ScaleFactor = 0.3764f
26+
};
27+
}
28+
29+
public string Name { get; init; } = "StableCascade";
30+
public override PipelineType Pipeline { get; } = PipelineType.StableCascade;
31+
public TokenizerConfig Tokenizer { get; init; }
32+
public CLIPModelConfig TextEncoder { get; init; }
33+
public UNetModelConfig PriorUnet { get; init; }
34+
public UNetModelConfig DecoderUnet { get; init; }
35+
public PaellaVQModelConfig ImageDecoder { get; init; }
36+
public CLIPModelConfig ImageEncoder { get; init; }
37+
38+
39+
/// <summary>
40+
/// Sets the execution provider for all models.
41+
/// </summary>
42+
/// <param name="executionProvider">The execution provider.</param>
43+
public override void SetProvider(ExecutionProvider executionProvider)
44+
{
45+
Tokenizer.SetProvider(executionProvider);
46+
TextEncoder.SetProvider(executionProvider);
47+
PriorUnet.SetProvider(executionProvider);
48+
DecoderUnet.SetProvider(executionProvider);
49+
ImageEncoder.SetProvider(executionProvider);
50+
ImageDecoder.SetProvider(executionProvider);
51+
}
52+
53+
54+
/// <summary>
55+
/// Saves the configuration to file.
56+
/// </summary>
57+
/// <param name="configFile">The configuration file.</param>
58+
/// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param>
59+
public override void Save(string configFile, bool useRelativePaths = true)
60+
{
61+
ConfigService.Serialize(configFile, this, useRelativePaths);
62+
}
63+
64+
65+
/// <summary>
66+
/// Create StableCascade configuration from default values
67+
/// </summary>
68+
/// <param name="name">The name.</param>
69+
/// <param name="modelType">Type of the model.</param>
70+
/// <param name="executionProvider">The execution provider.</param>
71+
/// <returns>StableCascadeConfig.</returns>
72+
public static StableCascadeConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default)
73+
{
74+
var config = new StableCascadeConfig { Name = name };
75+
config.PriorUnet.ModelType = modelType;
76+
config.DecoderUnet.ModelType = modelType;
77+
config.SetProvider(executionProvider);
78+
return config;
79+
}
80+
81+
82+
/// <summary>
83+
/// Create StableCascade configuration from json file
84+
/// </summary>
85+
/// <param name="configFile">The configuration file.</param>
86+
/// <param name="executionProvider">The execution provider.</param>
87+
/// <returns>StableCascadeConfig.</returns>
88+
public static StableCascadeConfig FromFile(string configFile, ExecutionProvider executionProvider = default)
89+
{
90+
var config = ConfigService.Deserialize<StableCascadeConfig>(configFile);
91+
config.SetProvider(executionProvider);
92+
return config;
93+
}
94+
95+
96+
/// <summary>
97+
/// Create StableCascade configuration from folder structure
98+
/// </summary>
99+
/// <param name="modelFolder">The model folder.</param>
100+
/// <param name="modelType">Type of the model.</param>
101+
/// <param name="executionProvider">The execution provider.</param>
102+
/// <returns>StableCascadeConfig.</returns>
103+
public static StableCascadeConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default)
104+
{
105+
var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider);
106+
config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json");
107+
config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx");
108+
config.PriorUnet.Path = Path.Combine(modelFolder, "prior", "model.onnx");
109+
config.DecoderUnet.Path = Path.Combine(modelFolder, "decoder", "model.onnx");
110+
config.ImageEncoder.Path = Path.Combine(modelFolder, "vae_encoder", "model.onnx");
111+
config.ImageDecoder.Path = Path.Combine(modelFolder, "vae_decoder", "model.onnx");
112+
return config;
113+
}
114+
115+
}
116+
}

0 commit comments

Comments
 (0)