|
| 1 | +// Copyright (c) TensorStack. All rights reserved. |
| 2 | +// Licensed under the Apache 2.0 License. |
| 3 | +using System.IO; |
| 4 | +using TensorStack.Common; |
| 5 | +using TensorStack.StableDiffusion.Config; |
| 6 | +using TensorStack.StableDiffusion.Enums; |
| 7 | + |
| 8 | +namespace TensorStack.StableDiffusion.Pipelines.StableDiffusionXL |
| 9 | +{ |
| 10 | + public record StableDiffusionXLConfig : PipelineConfig |
| 11 | + { |
| 12 | + /// <summary> |
| 13 | + /// Initializes a new instance of the <see cref="StableDiffusionXLConfig"/> class. |
| 14 | + /// </summary> |
| 15 | + public StableDiffusionXLConfig() |
| 16 | + { |
| 17 | + Tokenizer = new TokenizerConfig(); |
| 18 | + Tokenizer2 = new TokenizerConfig(); |
| 19 | + TextEncoder = new CLIPModelConfig { HiddenSize = 768 }; |
| 20 | + TextEncoder2 = new CLIPModelConfig { HiddenSize = 1280 }; |
| 21 | + Unet = new UNetModelConfig { IsOptimizationSupported = true }; |
| 22 | + AutoEncoder = new AutoEncoderModelConfig { ScaleFactor = 0.13025f }; |
| 23 | + } |
| 24 | + |
| 25 | + public string Name { get; init; } = "StableDiffusionXL"; |
| 26 | + public override PipelineType Pipeline { get; } = PipelineType.StableDiffusionXL; |
| 27 | + public TokenizerConfig Tokenizer { get; init; } |
| 28 | + public TokenizerConfig Tokenizer2 { get; init; } |
| 29 | + public CLIPModelConfig TextEncoder { get; init; } |
| 30 | + public CLIPModelConfig TextEncoder2 { get; init; } |
| 31 | + public UNetModelConfig Unet { get; init; } |
| 32 | + public AutoEncoderModelConfig AutoEncoder { get; init; } |
| 33 | + |
| 34 | + |
| 35 | + /// <summary> |
| 36 | + /// Sets the execution provider for all models. |
| 37 | + /// </summary> |
| 38 | + /// <param name="executionProvider">The execution provider.</param> |
| 39 | + public override void SetProvider(ExecutionProvider executionProvider) |
| 40 | + { |
| 41 | + Tokenizer.SetProvider(executionProvider); |
| 42 | + Tokenizer2.SetProvider(executionProvider); |
| 43 | + TextEncoder.SetProvider(executionProvider); |
| 44 | + TextEncoder2.SetProvider(executionProvider); |
| 45 | + Unet.SetProvider(executionProvider); |
| 46 | + AutoEncoder.SetProvider(executionProvider); |
| 47 | + } |
| 48 | + |
| 49 | + |
| 50 | + /// <summary> |
| 51 | + /// Saves the configuration to file. |
| 52 | + /// </summary> |
| 53 | + /// <param name="configFile">The configuration file.</param> |
| 54 | + /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param> |
| 55 | + public override void Save(string configFile, bool useRelativePaths = true) |
| 56 | + { |
| 57 | + ConfigService.Serialize(configFile, this, useRelativePaths); |
| 58 | + } |
| 59 | + |
| 60 | + |
| 61 | + /// <summary> |
| 62 | + /// Create StableDiffusion configuration from default values |
| 63 | + /// </summary> |
| 64 | + /// <param name="name">The name.</param> |
| 65 | + /// <param name="modelType">Type of the model.</param> |
| 66 | + /// <param name="executionProvider">The execution provider.</param> |
| 67 | + /// <returns>StableDiffusionXLConfig.</returns> |
| 68 | + public static StableDiffusionXLConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default) |
| 69 | + { |
| 70 | + var config = new StableDiffusionXLConfig { Name = name }; |
| 71 | + config.Unet.ModelType = modelType; |
| 72 | + config.SetProvider(executionProvider); |
| 73 | + return config; |
| 74 | + } |
| 75 | + |
| 76 | + |
| 77 | + /// <summary> |
| 78 | + /// Create StableDiffusion configuration from json file |
| 79 | + /// </summary> |
| 80 | + /// <param name="configFile">The configuration file.</param> |
| 81 | + /// <param name="executionProvider">The execution provider.</param> |
| 82 | + /// <returns>StableDiffusionXLConfig.</returns> |
| 83 | + public static StableDiffusionXLConfig FromFile(string configFile, ExecutionProvider executionProvider = default) |
| 84 | + { |
| 85 | + var config = ConfigService.Deserialize<StableDiffusionXLConfig>(configFile); |
| 86 | + config.SetProvider(executionProvider); |
| 87 | + return config; |
| 88 | + } |
| 89 | + |
| 90 | + |
| 91 | + /// <summary> |
| 92 | + /// Create StableDiffusion configuration from folder structure |
| 93 | + /// </summary> |
| 94 | + /// <param name="modelFolder">The model folder.</param> |
| 95 | + /// <param name="modelType">Type of the model.</param> |
| 96 | + /// <param name="executionProvider">The execution provider.</param> |
| 97 | + /// <returns>StableDiffusionXLConfig.</returns> |
| 98 | + public static StableDiffusionXLConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) |
| 99 | + { |
| 100 | + var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); |
| 101 | + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json"); |
| 102 | + config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "vocab.json"); |
| 103 | + config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx"); |
| 104 | + config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx"); |
| 105 | + config.Unet.Path = Path.Combine(modelFolder, "unet", "model.onnx"); |
| 106 | + config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx"); |
| 107 | + config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx"); |
| 108 | + var controlNetPath = Path.Combine(modelFolder, "unet", "controlnet.onnx"); |
| 109 | + if (File.Exists(controlNetPath)) |
| 110 | + config.Unet.ControlNetPath = controlNetPath; |
| 111 | + return config; |
| 112 | + } |
| 113 | + |
| 114 | + } |
| 115 | +} |
0 commit comments