Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit a889bd1

Browse files
committed
Support splitting models across multiple GPU/Devices
1 parent 656d452 commit a889bd1

File tree

14 files changed

+485
-170
lines changed

14 files changed

+485
-170
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,75 @@
11
{
2-
"Logging": {
3-
"LogLevel": {
4-
"Default": "Information",
5-
"Microsoft.AspNetCore": "Warning"
6-
}
7-
},
8-
"AllowedHosts": "*",
9-
"OnnxStackConfig": {
10-
"DeviceId": 0,
11-
"InterOpNumThreads": 0,
12-
"IntraOpNumThreads": 0,
13-
"ExecutionMode": "ORT_PARALLEL",
14-
"IsSafetyModelEnabled": false,
15-
"ExecutionProviderTarget": "DirectML",
16-
"OnnxUnetPath": "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx",
17-
"OnnxTokenizerPath": "D:\\Repositories\\stable-diffusion-v1-5\\cliptokenizer.onnx",
18-
"OnnxVaeDecoderPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_decoder\\model.onnx",
19-
"OnnxTextEncoderPath": "D:\\Repositories\\stable-diffusion-v1-5\\text_encoder\\model.onnx",
20-
"OnnxSafetyModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\safety_checker\\model.onnx",
21-
"OnnxVaeEncoderPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_encoder\\model.onnx"
22-
}
2+
"Logging": {
3+
"LogLevel": {
4+
"Default": "Information",
5+
"Microsoft.AspNetCore": "Warning"
6+
}
7+
},
8+
"AllowedHosts": "*",
9+
"OnnxStackConfig": {
10+
"Name": "StableDiffusion 1.5",
11+
"PadTokenId": 49407,
12+
"BlankTokenId": 49407,
13+
"InputTokenLimit": 512,
14+
"TokenizerLimit": 77,
15+
"EmbeddingsLength": 768,
16+
"ScaleFactor": 0.18215,
17+
"ModelConfigurations": [
18+
{
19+
"Type": "Unet",
20+
"DeviceId": 0,
21+
"InterOpNumThreads": 0,
22+
"IntraOpNumThreads": 0,
23+
"ExecutionMode": "ORT_PARALLEL",
24+
"ExecutionProvider": "DirectML",
25+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\unet\\model.onnx"
26+
},
27+
{
28+
"Type": "Tokenizer",
29+
"DeviceId": 0,
30+
"InterOpNumThreads": 0,
31+
"IntraOpNumThreads": 0,
32+
"ExecutionMode": "ORT_PARALLEL",
33+
"ExecutionProvider": "Cpu",
34+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\cliptokenizer.onnx"
35+
},
36+
{
37+
"Type": "TextEncoder",
38+
"DeviceId": 0,
39+
"InterOpNumThreads": 0,
40+
"IntraOpNumThreads": 0,
41+
"ExecutionMode": "ORT_PARALLEL",
42+
"ExecutionProvider": "Cpu",
43+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\text_encoder\\model.onnx"
44+
},
45+
{
46+
"Type": "VaeEncoder",
47+
"DeviceId": 0,
48+
"InterOpNumThreads": 0,
49+
"IntraOpNumThreads": 0,
50+
"ExecutionMode": "ORT_PARALLEL",
51+
"ExecutionProvider": "Cpu",
52+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_encoder\\model.onnx"
53+
},
54+
{
55+
"Type": "VaeDecoder",
56+
"DeviceId": 0,
57+
"InterOpNumThreads": 0,
58+
"IntraOpNumThreads": 0,
59+
"ExecutionMode": "ORT_PARALLEL",
60+
"ExecutionProvider": "Cpu",
61+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\vae_decoder\\model.onnx"
62+
},
63+
{
64+
"Type": "SafetyChecker",
65+
"IsDisabled": true,
66+
"DeviceId": 0,
67+
"InterOpNumThreads": 0,
68+
"IntraOpNumThreads": 0,
69+
"ExecutionMode": "ORT_PARALLEL",
70+
"ExecutionProvider": "Cpu",
71+
"OnnxModelPath": "D:\\Repositories\\stable-diffusion-v1-5\\safety_checker\\model.onnx"
72+
}
73+
]
74+
}
2375
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
using Microsoft.ML.OnnxRuntime;

namespace OnnxStack.Core.Config
{
    /// <summary>
    /// Configuration for a single Onnx model inference session: which model it is,
    /// where the .onnx file lives, and the execution provider / threading settings
    /// used to create its <see cref="SessionOptions"/>.
    /// </summary>
    public class OnnxModelSessionConfig
    {
        /// <summary>Gets or sets the model role this session fulfils (Unet, Tokenizer, etc.).</summary>
        public OnnxModelType Type { get; set; }

        /// <summary>Gets or sets a value indicating whether this model session is disabled.</summary>
        public bool IsDisabled { get; set; }

        /// <summary>Gets or sets the device identifier used by device-bound providers (e.g. DirectML, CUDA).</summary>
        public int DeviceId { get; set; }

        /// <summary>Gets or sets the path to the .onnx model file.</summary>
        public string OnnxModelPath { get; set; }

        /// <summary>Gets or sets the number of threads used to parallelize execution across operators (0 = default).</summary>
        public int InterOpNumThreads { get; set; }

        /// <summary>Gets or sets the number of threads used to parallelize execution within an operator (0 = default).</summary>
        public int IntraOpNumThreads { get; set; }

        /// <summary>Gets or sets the OnnxRuntime execution mode (sequential or parallel).</summary>
        public ExecutionMode ExecutionMode { get; set; }

        /// <summary>Gets or sets the execution provider used to run this model.</summary>
        public ExecutionProvider ExecutionProvider { get; set; }
    }
}

OnnxStack.Core/Config/OnnxModelType.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
public enum OnnxModelType
44
{
55
Unet = 0,
6-
Tokenizer = 1,
7-
VaeDecoder = 2,
8-
TextEncoder = 3,
9-
SafetyModel = 4,
10-
VaeEncoder = 5,
6+
Tokenizer = 10,
7+
TextEncoder = 20,
8+
VaeEncoder = 30,
9+
VaeDecoder = 40,
10+
SafetyChecker = 100,
1111
}
1212
}
Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
using OnnxStack.Common.Config;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;


namespace OnnxStack.Core.Config
{
    /// <summary>
    /// Top-level OnnxStack configuration: pipeline-wide tokenizer/embedding
    /// parameters plus one <see cref="OnnxModelSessionConfig"/> per model,
    /// allowing each model to target its own device/provider.
    /// </summary>
    public class OnnxStackConfig : IConfigSection
    {
        /// <summary>Gets or sets the display name of this model set.</summary>
        public string Name { get; set; }

        /// <summary>Gets or sets the tokenizer pad token id.</summary>
        public int PadTokenId { get; set; }

        /// <summary>Gets or sets the tokenizer blank token id.</summary>
        public int BlankTokenId { get; set; }

        /// <summary>Gets or sets the maximum number of input tokens accepted for a prompt.</summary>
        public int InputTokenLimit { get; set; }

        /// <summary>Gets or sets the tokenizer's per-chunk token limit (77 for CLIP).</summary>
        public int TokenizerLimit { get; set; }

        /// <summary>Gets or sets the text-encoder embeddings length (768 for SD 1.5).</summary>
        public int EmbeddingsLength { get; set; }

        /// <summary>Gets or sets the latent scale factor.</summary>
        public float ScaleFactor { get; set; }

        /// <summary>Gets or sets the per-model session configurations.</summary>
        public List<OnnxModelSessionConfig> ModelConfigurations { get; set; }

        /// <summary>Gets or sets the cached array of blank tokens used to pad an empty/short prompt; populated by <see cref="Initialize"/>.</summary>
        public ImmutableArray<int> BlankTokenValueArray { get; set; }


        /// <summary>
        /// Called after the section is bound; caches enough blank tokens to fill an empty prompt.
        /// </summary>
        public void Initialize() => BlankTokenValueArray = Enumerable
            .Repeat(BlankTokenId, InputTokenLimit)
            .ToImmutableArray();
    }
}

OnnxStack.Core/Constants.cs

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,9 @@
11
using System.Collections.Generic;
2-
using System.Collections.Immutable;
3-
using System.Linq;
42

53
namespace OnnxStack.Core
64
{
75
public static class Constants
86
{
9-
/// <summary>
10-
/// The blank token value
11-
/// </summary>
12-
public const int BlankTokenValue = 49407;
13-
14-
/// <summary>
15-
/// The maximum input token count
16-
/// </summary>
17-
public const int MaxInputTokenCount = 2048;
18-
19-
/// <summary>
20-
/// The clip tokenizer input token limit
21-
/// </summary>
22-
public const int ClipTokenizerTokenLimit = 77;
23-
24-
/// <summary>
25-
/// The clip tokenizer embeddings length
26-
/// </summary>
27-
public const int ClipTokenizerEmbeddingsLength = 768;
28-
29-
/// <summary>
30-
/// The model scale factor
31-
/// </summary>
32-
public const float ModelScaleFactor = 0.18215f;
33-
34-
/// <summary>
35-
/// The cached blank token value array
36-
/// </summary>
37-
public static readonly ImmutableArray<int> BlankTokenValueArray;
38-
397
/// <summary>
408
/// The width/height valid sizes
419
/// </summary>
@@ -44,7 +12,6 @@ public static class Constants
4412
static Constants()
4513
{
4614
// Cache an array with enough blank tokens to fill an empty prompt
47-
BlankTokenValueArray = Enumerable.Repeat(BlankTokenValue, MaxInputTokenCount).ToImmutableArray();
4815
ValidSizes = new List<int> { 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024 };
4916
}
5017
}

OnnxStack.Core/Extensions.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,15 @@ namespace OnnxStack.Core
88
{
99
public static class Extensions
1010
{
11-
/// <summary>
12-
/// Gets the Onnx InferenceSession options.
13-
/// </summary>
14-
/// <param name="configuration">The configuration.</param>
15-
/// <returns></returns>
16-
public static SessionOptions GetSessionOptions(this OnnxStackConfig configuration)
11+
public static SessionOptions GetSessionOptions(this OnnxModelSessionConfig configuration)
1712
{
1813
var sessionOptions = new SessionOptions
1914
{
2015
ExecutionMode = configuration.ExecutionMode,
2116
InterOpNumThreads = configuration.InterOpNumThreads,
2217
IntraOpNumThreads = configuration.InterOpNumThreads
2318
};
24-
switch (configuration.ExecutionProviderTarget)
19+
switch (configuration.ExecutionProvider)
2520
{
2621
case ExecutionProvider.DirectML:
2722
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
using Microsoft.Extensions.Configuration;
using Microsoft.ML.OnnxRuntime;
using OnnxStack.Core.Config;
using System;
using System.IO;

namespace OnnxStack.Core.Model
{
    /// <summary>
    /// Owns one Onnx <see cref="InferenceSession"/> together with the
    /// <see cref="SessionOptions"/> it was created from; both are native
    /// resources and are released via <see cref="Dispose"/>.
    /// </summary>
    public class OnnxModelSession : IDisposable
    {
        private readonly SessionOptions _options;
        private readonly InferenceSession _session;
        private readonly OnnxModelSessionConfig _configuration;

        /// <summary>
        /// Initializes a new instance of the <see cref="OnnxModelSession"/> class.
        /// </summary>
        /// <param name="configuration">The session configuration (model path, device, provider, threading).</param>
        /// <param name="container">The shared pre-packed weights container.</param>
        /// <exception cref="System.IO.FileNotFoundException">Onnx model file not found</exception>
        public OnnxModelSession(OnnxModelSessionConfig configuration, PrePackedWeightsContainer container)
        {
            if (!File.Exists(configuration.OnnxModelPath))
                throw new FileNotFoundException("Onnx model file not found", configuration.OnnxModelPath);

            _configuration = configuration;
            _options = configuration.GetSessionOptions();
            try
            {
                _options.RegisterOrtExtensions();
                _session = new InferenceSession(_configuration.OnnxModelPath, _options, container);
            }
            catch
            {
                // Don't leak the native SessionOptions when extension registration
                // or session creation throws; the caller never receives the instance.
                _options.Dispose();
                throw;
            }
        }


        /// <summary>
        /// Gets the SessionOptions.
        /// </summary>
        public SessionOptions Options => _options;


        /// <summary>
        /// Gets the InferenceSession.
        /// </summary>
        public InferenceSession Session => _session;


        /// <summary>
        /// Gets the configuration.
        /// </summary>
        public OnnxModelSessionConfig Configuration => _configuration;


        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            // Dispose the session before the options it was built from; the
            // session is the consumer of the options' native state.
            _session?.Dispose();
            _options?.Dispose();
            GC.SuppressFinalize(this);
        }
    }
}

0 commit comments

Comments
 (0)