|
| 1 | +// Copyright (c) TensorStack. All rights reserved. |
| 2 | +// Licensed under the Apache 2.0 License. |
| 3 | +using System.IO; |
| 4 | +using TensorStack.Common; |
| 5 | +using TensorStack.StableDiffusion.Config; |
| 6 | +using TensorStack.StableDiffusion.Enums; |
| 7 | + |
| 8 | +namespace TensorStack.StableDiffusion.Pipelines.Flux |
| 9 | +{ |
| 10 | + public record FluxConfig : PipelineConfig |
| 11 | + { |
| 12 | + /// <summary> |
| 13 | + /// Initializes a new instance of the <see cref="FluxConfig"/> class. |
| 14 | + /// </summary> |
| 15 | + public FluxConfig() |
| 16 | + { |
| 17 | + Tokenizer = new TokenizerConfig(); |
| 18 | + Tokenizer2 = new TokenizerConfig(); |
| 19 | + TextEncoder = new CLIPModelConfig(); |
| 20 | + TextEncoder2 = new CLIPModelConfig |
| 21 | + { |
| 22 | + PadTokenId = 0, |
| 23 | + HiddenSize = 4096, |
| 24 | + IsFixedSequenceLength = false, |
| 25 | + SequenceLength = 512 |
| 26 | + }; |
| 27 | + Transformer = new TransformerModelConfig |
| 28 | + { |
| 29 | + JointAttention = 4096, |
| 30 | + PooledProjection = 768, |
| 31 | + IsOptimizationSupported = true |
| 32 | + }; |
| 33 | + AutoEncoder = new AutoEncoderModelConfig |
| 34 | + { |
| 35 | + LatentChannels = 16, |
| 36 | + ScaleFactor = 0.3611f, |
| 37 | + ShiftFactor = 0.1159f |
| 38 | + }; |
| 39 | + } |
| 40 | + |
| 41 | + public string Name { get; init; } = "Flux"; |
| 42 | + public override PipelineType Pipeline { get; } = PipelineType.Flux; |
| 43 | + public TokenizerConfig Tokenizer { get; init; } |
| 44 | + public TokenizerConfig Tokenizer2 { get; init; } |
| 45 | + public CLIPModelConfig TextEncoder { get; init; } |
| 46 | + public CLIPModelConfig TextEncoder2 { get; init; } |
| 47 | + public TransformerModelConfig Transformer { get; init; } |
| 48 | + public AutoEncoderModelConfig AutoEncoder { get; init; } |
| 49 | + |
| 50 | + |
| 51 | + /// <summary> |
| 52 | + /// Sets the execution provider for all models. |
| 53 | + /// </summary> |
| 54 | + /// <param name="executionProvider">The execution provider.</param> |
| 55 | + public override void SetProvider(ExecutionProvider executionProvider) |
| 56 | + { |
| 57 | + Tokenizer.SetProvider(executionProvider); |
| 58 | + Tokenizer2.SetProvider(executionProvider); |
| 59 | + TextEncoder.SetProvider(executionProvider); |
| 60 | + TextEncoder2.SetProvider(executionProvider); |
| 61 | + Transformer.SetProvider(executionProvider); |
| 62 | + AutoEncoder.SetProvider(executionProvider); |
| 63 | + } |
| 64 | + |
| 65 | + |
| 66 | + /// <summary> |
| 67 | + /// Saves the configuration to file. |
| 68 | + /// </summary> |
| 69 | + /// <param name="configFile">The configuration file.</param> |
| 70 | + /// <param name="useRelativePaths">if set to <c>true</c> use relative paths.</param> |
| 71 | + public override void Save(string configFile, bool useRelativePaths = true) |
| 72 | + { |
| 73 | + ConfigService.Serialize(configFile, this, useRelativePaths); |
| 74 | + } |
| 75 | + |
| 76 | + |
| 77 | + /// <summary> |
| 78 | + /// Create Flux configuration from default values |
| 79 | + /// </summary> |
| 80 | + /// <param name="name">The name.</param> |
| 81 | + /// <param name="modelType">Type of the model.</param> |
| 82 | + /// <param name="executionProvider">The execution provider.</param> |
| 83 | + /// <returns>FluxConfig.</returns> |
| 84 | + public static FluxConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default) |
| 85 | + { |
| 86 | + var config = new FluxConfig { Name = name }; |
| 87 | + config.Transformer.ModelType = modelType; |
| 88 | + config.SetProvider(executionProvider); |
| 89 | + return config; |
| 90 | + } |
| 91 | + |
| 92 | + |
| 93 | + /// <summary> |
| 94 | + /// Create StableDiffusionv configuration from json file |
| 95 | + /// </summary> |
| 96 | + /// <param name="configFile">The configuration file.</param> |
| 97 | + /// <param name="executionProvider">The execution provider.</param> |
| 98 | + /// <returns>FluxConfig.</returns> |
| 99 | + public static FluxConfig FromFile(string configFile, ExecutionProvider executionProvider = default) |
| 100 | + { |
| 101 | + var config = ConfigService.Deserialize<FluxConfig>(configFile); |
| 102 | + config.SetProvider(executionProvider); |
| 103 | + return config; |
| 104 | + } |
| 105 | + |
| 106 | + |
| 107 | + /// <summary> |
| 108 | + /// Create Flux configuration from folder structure |
| 109 | + /// </summary> |
| 110 | + /// <param name="modelFolder">The model folder.</param> |
| 111 | + /// <param name="modelType">Type of the model.</param> |
| 112 | + /// <param name="executionProvider">The execution provider.</param> |
| 113 | + /// <returns>FluxConfig.</returns> |
| 114 | + public static FluxConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) |
| 115 | + { |
| 116 | + var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); |
| 117 | + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "vocab.json"); |
| 118 | + config.Tokenizer2.Path = Path.Combine(modelFolder, "tokenizer_2", "spiece.model"); |
| 119 | + config.TextEncoder.Path = Path.Combine(modelFolder, "text_encoder", "model.onnx"); |
| 120 | + config.TextEncoder2.Path = Path.Combine(modelFolder, "text_encoder_2", "model.onnx"); |
| 121 | + config.Transformer.Path = Path.Combine(modelFolder, "transformer", "model.onnx"); |
| 122 | + config.AutoEncoder.DecoderModelPath = Path.Combine(modelFolder, "vae_decoder", "model.onnx"); |
| 123 | + config.AutoEncoder.EncoderModelPath = Path.Combine(modelFolder, "vae_encoder", "model.onnx"); |
| 124 | + var controlNetPath = Path.Combine(modelFolder, "transformer", "controlnet.onnx"); |
| 125 | + if (File.Exists(controlNetPath)) |
| 126 | + config.Transformer.ControlNetPath = controlNetPath; |
| 127 | + return config; |
| 128 | + } |
| 129 | + |
| 130 | + } |
| 131 | +} |
0 commit comments