|
| 1 | +// Copyright (c) TensorStack. All rights reserved. |
| 2 | +// Licensed under the Apache 2.0 License. |
| 3 | +using System.IO; |
| 4 | +using TensorStack.Common; |
| 5 | +using TensorStack.StableDiffusion.Config; |
| 6 | +using TensorStack.StableDiffusion.Enums; |
| 7 | + |
| 8 | +namespace TensorStack.StableDiffusion.Pipelines.StableCascade |
| 9 | +{ |
| 10 | + public record StableCascadeConfig : PipelineConfig |
| 11 | + { |
| 12 | + /// <summary> |
| 13 | + /// Initializes a new instance of the <see cref="StableCascadeConfig"/> class. |
| 14 | + /// </summary> |
| 15 | + public StableCascadeConfig() |
| 16 | + { |
| 17 | + Tokenizer = new TokenizerConfig(); |
| 18 | + PriorUnet = new UNetModelConfig(); |
| 19 | + DecoderUnet = new UNetModelConfig(); |
| 20 | + TextEncoder = new CLIPModelConfig { HiddenSize = 1280 }; |
| 21 | + ImageEncoder = new CLIPModelConfig { HiddenSize = 768 }; |
| 22 | + ImageDecoder = new PaellaVQModelConfig |
| 23 | + { |
| 24 | + Scale = 4, |
| 25 | + ScaleFactor = 0.3764f |
| 26 | + }; |
| 27 | + } |
| 28 | + |
| 29 | + public string Name { get; init; } = "StableCascade"; |
| 30 | + public override PipelineType Pipeline { get; } = PipelineType.StableCascade; |
| 31 | + public TokenizerConfig Tokenizer { get; init; } |
| 32 | + public CLIPModelConfig TextEncoder { get; init; } |
| 33 | + public UNetModelConfig PriorUnet { get; init; } |
| 34 | + public UNetModelConfig DecoderUnet { get; init; } |
| 35 | + public PaellaVQModelConfig ImageDecoder { get; init; } |
| 36 | + public CLIPModelConfig ImageEncoder { get; init; } |
| 37 | + |
| 38 | + |
| 39 | + /// <summary> |
| 40 | + /// Sets the execution provider for all models. |
| 41 | + /// </summary> |
| 42 | + /// <param name="executionProvider">The execution provider.</param> |
| 43 | + public override void SetProvider(ExecutionProvider executionProvider) |
| 44 | + { |
| 45 | + Tokenizer.SetProvider(executionProvider); |
| 46 | + TextEncoder.SetProvider(executionProvider); |
| 47 | + PriorUnet.SetProvider(executionProvider); |
| 48 | + DecoderUnet.SetProvider(executionProvider); |
| 49 | + ImageEncoder.SetProvider(executionProvider); |
| 50 | + ImageDecoder.SetProvider(executionProvider); |
| 51 | + } |
| 52 | + |
| 53 | + |
| 54 | + /// <summary> |
| 55 | + /// Saves the configuration to file. |
| 56 | + /// </summary> |
| 57 | + /// <param name="configFile">The configuration file.</param> |
| 58 | + /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param> |
| 59 | + public override void Save(string configFile, bool useRelativePaths = true) |
| 60 | + { |
| 61 | + ConfigService.Serialize(configFile, this, useRelativePaths); |
| 62 | + } |
| 63 | + |
| 64 | + |
| 65 | + /// <summary> |
| 66 | + /// Create StableCascade configuration from default values |
| 67 | + /// </summary> |
| 68 | + /// <param name="name">The name.</param> |
| 69 | + /// <param name="modelType">Type of the model.</param> |
| 70 | + /// <param name="executionProvider">The execution provider.</param> |
| 71 | + /// <returns>StableCascadeConfig.</returns> |
| 72 | + public static StableCascadeConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default) |
| 73 | + { |
| 74 | + var config = new StableCascadeConfig { Name = name }; |
| 75 | + config.PriorUnet.ModelType = modelType; |
| 76 | + config.DecoderUnet.ModelType = modelType; |
| 77 | + config.SetProvider(executionProvider); |
| 78 | + return config; |
| 79 | + } |
| 80 | + |
| 81 | + |
| 82 | + /// <summary> |
| 83 | + /// Create StableCascade configuration from json file |
| 84 | + /// </summary> |
| 85 | + /// <param name="configFile">The configuration file.</param> |
| 86 | + /// <param name="executionProvider">The execution provider.</param> |
| 87 | + /// <returns>StableCascadeConfig.</returns> |
| 88 | + public static StableCascadeConfig FromFile(string configFile, ExecutionProvider executionProvider = default) |
| 89 | + { |
| 90 | + var config = ConfigService.Deserialize<StableCascadeConfig>(configFile); |
| 91 | + config.SetProvider(executionProvider); |
| 92 | + return config; |
| 93 | + } |
| 94 | + |
| 95 | + |
| 96 | + /// <summary> |
| 97 | + /// Create StableCascade configuration from folder structure |
| 98 | + /// </summary> |
| 99 | + /// <param name="modelFolder">The model folder.</param> |
| 100 | + /// <param name="modelType">Type of the model.</param> |
| 101 | + /// <param name="executionProvider">The execution provider.</param> |
| 102 | + /// <returns>StableCascadeConfig.</returns> |
| 103 | + public static StableCascadeConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) |
| 104 | + { |
| 105 | + var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); |
| 106 | + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json"); |
| 107 | + config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx"); |
| 108 | + config.PriorUnet.Path = Path.Combine(modelFolder, "prior", "model.onnx"); |
| 109 | + config.DecoderUnet.Path = Path.Combine(modelFolder, "decoder", "model.onnx"); |
| 110 | + config.ImageEncoder.Path = Path.Combine(modelFolder, "vae_encoder", "model.onnx"); |
| 111 | + config.ImageDecoder.Path = Path.Combine(modelFolder, "vae_decoder", "model.onnx"); |
| 112 | + return config; |
| 113 | + } |
| 114 | + |
| 115 | + } |
| 116 | +} |
0 commit comments