Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8fd0c79

Browse files
committed
Multiple model support
1 parent c370dd3 commit 8fd0c79

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+904
-481
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,14 @@ private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions op
6262
{
6363
var timestamp = Stopwatch.GetTimestamp();
6464
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
65-
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
65+
66+
//TODO:
67+
var model = new ModelOptions
68+
{
69+
70+
};
71+
72+
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
6673
if (result is not null)
6774
{
6875
await result.SaveAsPngAsync(outputFilename);

OnnxStack.Console/Examples/StableDiffusionExample.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@ public async Task RunAsync()
6060
private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions options)
6161
{
6262
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}.png");
63-
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
63+
//TODO:
64+
var model = new ModelOptions
65+
{
66+
67+
};
68+
69+
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
6470
if (result == null)
6571
return false;
6672

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,14 @@ public async Task RunAsync()
5757

5858
private async Task<bool> GenerateImage(PromptOptions prompt, SchedulerOptions options, string key)
5959
{
60+
//TODO:
61+
var model = new ModelOptions
62+
{
63+
64+
};
65+
6066
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{prompt.SchedulerType}_{key}.png");
61-
var result = await _stableDiffusionService.GenerateAsImageAsync(prompt, options);
67+
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
6268
if (result == null)
6369
return false;
6470

OnnxStack.Console/appsettings.json

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,68 +7,72 @@
77
},
88
"AllowedHosts": "*",
99
"OnnxStackConfig": {
10-
"Name": "StableDiffusion 1.5",
11-
"PadTokenId": 49407,
12-
"BlankTokenId": 49407,
13-
"InputTokenLimit": 512,
14-
"TokenizerLimit": 77,
15-
"EmbeddingsLength": 768,
16-
"ScaleFactor": 0.18215,
17-
"ModelConfigurations": [
10+
"OnnxModelSets": [
1811
{
19-
"Type": "Unet",
20-
"DeviceId": 0,
21-
"InterOpNumThreads": 0,
22-
"IntraOpNumThreads": 0,
23-
"ExecutionMode": "ORT_PARALLEL",
24-
"ExecutionProvider": "DirectML",
25-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx"
26-
},
27-
{
28-
"Type": "Tokenizer",
29-
"DeviceId": 0,
30-
"InterOpNumThreads": 0,
31-
"IntraOpNumThreads": 0,
32-
"ExecutionMode": "ORT_PARALLEL",
33-
"ExecutionProvider": "Cpu",
34-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\cliptokenizer.onnx"
35-
},
36-
{
37-
"Type": "TextEncoder",
38-
"DeviceId": 0,
39-
"InterOpNumThreads": 0,
40-
"IntraOpNumThreads": 0,
41-
"ExecutionMode": "ORT_PARALLEL",
42-
"ExecutionProvider": "Cpu",
43-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\text_encoder\\model.onnx"
44-
},
45-
{
46-
"Type": "VaeEncoder",
47-
"DeviceId": 0,
48-
"InterOpNumThreads": 0,
49-
"IntraOpNumThreads": 0,
50-
"ExecutionMode": "ORT_PARALLEL",
51-
"ExecutionProvider": "Cpu",
52-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_encoder\\model.onnx"
53-
},
54-
{
55-
"Type": "VaeDecoder",
56-
"DeviceId": 0,
57-
"InterOpNumThreads": 0,
58-
"IntraOpNumThreads": 0,
59-
"ExecutionMode": "ORT_PARALLEL",
60-
"ExecutionProvider": "Cpu",
61-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_decoder\\model.onnx"
62-
},
63-
{
64-
"Type": "SafetyChecker",
65-
"IsDisabled": true,
66-
"DeviceId": 0,
67-
"InterOpNumThreads": 0,
68-
"IntraOpNumThreads": 0,
69-
"ExecutionMode": "ORT_PARALLEL",
70-
"ExecutionProvider": "Cpu",
71-
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\safety_checker\\model.onnx"
12+
"Name": "StableDiffusion 1.5",
13+
"PadTokenId": 49407,
14+
"BlankTokenId": 49407,
15+
"InputTokenLimit": 512,
16+
"TokenizerLimit": 77,
17+
"EmbeddingsLength": 768,
18+
"ScaleFactor": 0.18215,
19+
"ModelConfigurations": [
20+
{
21+
"Type": "Unet",
22+
"DeviceId": 0,
23+
"InterOpNumThreads": 0,
24+
"IntraOpNumThreads": 0,
25+
"ExecutionMode": "ORT_PARALLEL",
26+
"ExecutionProvider": "DirectML",
27+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx"
28+
},
29+
{
30+
"Type": "Tokenizer",
31+
"DeviceId": 0,
32+
"InterOpNumThreads": 0,
33+
"IntraOpNumThreads": 0,
34+
"ExecutionMode": "ORT_PARALLEL",
35+
"ExecutionProvider": "DirectML",
36+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\cliptokenizer.onnx"
37+
},
38+
{
39+
"Type": "TextEncoder",
40+
"DeviceId": 0,
41+
"InterOpNumThreads": 0,
42+
"IntraOpNumThreads": 0,
43+
"ExecutionMode": "ORT_PARALLEL",
44+
"ExecutionProvider": "DirectML",
45+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\text_encoder\\model.onnx"
46+
},
47+
{
48+
"Type": "VaeEncoder",
49+
"DeviceId": 0,
50+
"InterOpNumThreads": 0,
51+
"IntraOpNumThreads": 0,
52+
"ExecutionMode": "ORT_PARALLEL",
53+
"ExecutionProvider": "DirectML",
54+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_encoder\\model.onnx"
55+
},
56+
{
57+
"Type": "VaeDecoder",
58+
"DeviceId": 0,
59+
"InterOpNumThreads": 0,
60+
"IntraOpNumThreads": 0,
61+
"ExecutionMode": "ORT_PARALLEL",
62+
"ExecutionProvider": "DirectML",
63+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_decoder\\model.onnx"
64+
},
65+
{
66+
"Type": "SafetyChecker",
67+
"IsDisabled": true,
68+
"DeviceId": 0,
69+
"InterOpNumThreads": 0,
70+
"IntraOpNumThreads": 0,
71+
"ExecutionMode": "ORT_PARALLEL",
72+
"ExecutionProvider": "DirectML",
73+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\safety_checker\\model.onnx"
74+
}
75+
]
7276
}
7377
]
7478
}

OnnxStack.Core/Config/ConfigManager.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ public static OnnxStackConfig LoadConfiguration()
2424
/// </summary>
2525
/// <typeparam name="T">The custom IConfigSection class type, NOTE: json section name MUST match class name</typeparam>
2626
/// <returns>The deserialized custom configuration object</returns>
27-
public static T LoadConfiguration<T>(params JsonConverter[] converters) where T : class, IConfigSection
27+
public static T LoadConfiguration<T>(string sectionName = null, params JsonConverter[] converters) where T : class, IConfigSection
2828
{
29-
return LoadConfigurationSection<T>(converters);
29+
return LoadConfigurationSection<T>(sectionName, converters);
3030
}
3131

3232

@@ -37,13 +37,14 @@ public static T LoadConfiguration<T>(params JsonConverter[] converters) where T
3737
/// <param name="converters">The converters.</param>
3838
/// <returns></returns>
3939
/// <exception cref="System.Exception">Failed to parse json element</exception>
40-
private static T LoadConfigurationSection<T>(params JsonConverter[] converters) where T : class, IConfigSection
40+
private static T LoadConfigurationSection<T>(string sectionName, params JsonConverter[] converters) where T : class, IConfigSection
4141
{
42+
var name = sectionName ?? typeof(T).Name;
4243
var serializerOptions = GetSerializerOptions(converters);
4344
var jsonDocument = GetJsonDocument(serializerOptions);
44-
var configElement = jsonDocument.RootElement.GetProperty(typeof(T).Name);
45+
var configElement = jsonDocument.RootElement.GetProperty(name);
4546
var configuration = configElement.Deserialize<T>(serializerOptions)
46-
?? throw new Exception($"Failed to parse {typeof(T).Name} json element");
47+
?? throw new Exception($"Failed to parse {name} json element");
4748
configuration.Initialize();
4849
return configuration;
4950
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace OnnxStack.Core.Config
2+
{
3+
public interface IOnnxModel
4+
{
5+
string Name { get; set; }
6+
}
7+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using System.Collections.Generic;
2+
3+
namespace OnnxStack.Core.Config
4+
{
5+
public interface IOnnxModelSetConfig : IOnnxModel
6+
{
7+
List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
8+
}
9+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System.Collections.Generic;
2+
3+
namespace OnnxStack.Core.Config
4+
{
5+
public class OnnxModelSetConfig : IOnnxModelSetConfig
6+
{
7+
public string Name { get; set; }
8+
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
9+
}
10+
}
Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,14 @@
11
using OnnxStack.Common.Config;
22
using System.Collections.Generic;
3-
using System.Collections.Immutable;
4-
using System.Linq;
53

64
namespace OnnxStack.Core.Config
75
{
86
public class OnnxStackConfig : IConfigSection
97
{
10-
public string Name { get; set; }
11-
public int PadTokenId { get; set; }
12-
public int BlankTokenId { get; set; }
13-
public int InputTokenLimit { get; set; }
14-
public int TokenizerLimit { get; set; }
15-
public int EmbeddingsLength { get; set; }
16-
public float ScaleFactor { get; set; }
17-
public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }
18-
public ImmutableArray<int> BlankTokenValueArray { get; set; }
8+
public List<OnnxModelSetConfig> OnnxModelSets { get; set; } = new List<OnnxModelSetConfig>();
199

2010
public void Initialize()
2111
{
22-
BlankTokenValueArray = Enumerable.Repeat(BlankTokenId, InputTokenLimit).ToImmutableArray();
2312
}
2413
}
2514
}

OnnxStack.Core/Model/OnnxModelSet.cs

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ namespace OnnxStack.Core.Model
88
{
99
public class OnnxModelSet : IDisposable
1010
{
11-
private readonly OnnxStackConfig _configuration;
11+
private readonly IOnnxModelSetConfig _configuration;
1212
private readonly PrePackedWeightsContainer _prePackedWeightsContainer;
1313
private readonly ImmutableDictionary<OnnxModelType, OnnxModelSession> _modelSessions;
1414

1515
/// <summary>
1616
/// Initializes a new instance of the <see cref="OnnxModelSet"/> class.
1717
/// </summary>
1818
/// <param name="configuration">The configuration.</param>
19-
public OnnxModelSet(OnnxStackConfig configuration)
19+
public OnnxModelSet(IOnnxModelSetConfig configuration)
2020
{
2121
_configuration = configuration;
2222
_prePackedWeightsContainer = new PrePackedWeightsContainer();
@@ -32,37 +32,7 @@ public OnnxModelSet(OnnxStackConfig configuration)
3232
/// Gets the name.
3333
/// </summary>
3434
public string Name => _configuration.Name;
35-
36-
/// <summary>
37-
/// Gets the pad token identifier.
38-
/// </summary>
39-
public int PadTokenId => _configuration.PadTokenId;
40-
41-
/// <summary>
42-
/// Gets the blank token identifier.
43-
/// </summary>
44-
public int BlankTokenId => _configuration.BlankTokenId;
45-
46-
/// <summary>
47-
/// Gets the input token limit.
48-
/// </summary>
49-
public int InputTokenLimit => _configuration.InputTokenLimit;
50-
51-
/// <summary>
52-
/// Gets the tokenizer limit.
53-
/// </summary>
54-
public int TokenizerLimit => _configuration.TokenizerLimit;
55-
56-
/// <summary>
57-
/// Gets the length of the embeddings.
58-
/// </summary>
59-
public int EmbeddingsLength => _configuration.EmbeddingsLength;
60-
61-
/// <summary>
62-
/// Gets the scale factor.
63-
/// </summary>
64-
public float ScaleFactor => _configuration.ScaleFactor;
65-
35+
6636

6737
/// <summary>
6838
/// Checks the specified model type exists in the set.

0 commit comments

Comments (0)