Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 9692289

Browse files
committed
StableDiffusionXL support added to ModelView
1 parent 4837ae4 commit 9692289

File tree

7 files changed

+303
-16
lines changed

7 files changed

+303
-16
lines changed

OnnxStack.UI/App.xaml.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public App()
3333
builder.Services.AddTransient<MessageDialog>();
3434
builder.Services.AddTransient<TextInputDialog>();
3535
builder.Services.AddTransient<CropImageDialog>();
36+
builder.Services.AddTransient<AddModelDialog>();
3637
builder.Services.AddSingleton<IDialogService, DialogService>();
3738
builder.Services.AddSingleton<IModelDownloadService, ModelDownloadService>();
3839

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
<Window x:Class="OnnxStack.UI.Dialogs.AddModelDialog"
2+
xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
3+
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
4+
xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
5+
xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006"
6+
mc:Ignorable="d"
7+
Name="UI"
8+
Icon="/Images/Icon.png"
9+
MinWidth="400"
10+
SizeToContent="WidthAndHeight"
11+
WindowStartupLocation="CenterOwner"
12+
SnapsToDevicePixels="True"
13+
UseLayoutRounding="True"
14+
Style="{StaticResource BaseWindow}"
15+
ContentRendered="OnContentRendered">
16+
<DockPanel DataContext="{Binding ElementName=UI}" Margin="15, 15, 15, 10">
17+
<StackPanel DockPanel.Dock="Top">
18+
<TextBlock Text="{Binding ErrorMessage}" FontSize="13" FontWeight="DemiBold" HorizontalAlignment="Center" Foreground="Red" Margin="0,10">
19+
<TextBlock.Style>
20+
<Style TargetType="{x:Type TextBlock}">
21+
<Setter Property="Visibility" Value="Visible" />
22+
<Style.Triggers>
23+
<DataTrigger Binding="{Binding ErrorMessage.Length, ElementName=UI}" Value="0">
24+
<Setter Property="Visibility" Value="Collapsed" />
25+
</DataTrigger>
26+
</Style.Triggers>
27+
</Style>
28+
</TextBlock.Style>
29+
</TextBlock>
30+
<TextBlock Text="Model Pipleine"/>
31+
<ComboBox ItemsSource="{Binding Source={StaticResource DiffuserPipelineType}}" SelectedItem="{Binding PipelineType}" />
32+
<TextBlock Text="Model Name"/>
33+
<TextBox Text="{Binding TextResult, UpdateSourceTrigger=PropertyChanged}" />
34+
</StackPanel>
35+
<StackPanel Orientation="Horizontal" HorizontalAlignment="Right" Margin="0,20,0,0">
36+
<UniformGrid Columns="2" Height="30">
37+
<Button Content="Ok" Command="{Binding SaveCommand}" IsDefault="True"/>
38+
<Button Content="Cancel" Command="{Binding CancelCommand}" Width="100"/>
39+
</UniformGrid>
40+
</StackPanel>
41+
</DockPanel>
42+
</Window>
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
using Microsoft.Extensions.Logging;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Enums;
4+
using OnnxStack.UI.Commands;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.ComponentModel;
8+
using System.Runtime.CompilerServices;
9+
using System.Threading.Tasks;
10+
using System.Windows;
11+
12+
namespace OnnxStack.UI.Dialogs
13+
{
14+
/// <summary>
15+
/// Interaction logic for AddModelDialog.xaml
16+
/// </summary>
17+
public partial class AddModelDialog : Window, INotifyPropertyChanged
18+
{
19+
private readonly ILogger<AddModelDialog> _logger;
20+
21+
private string _textResult;
22+
private string _errorMessage;
23+
private List<string> _invalidOptions;
24+
private DiffuserPipelineType _pipelineType;
25+
26+
public AddModelDialog(ILogger<AddModelDialog> logger)
27+
{
28+
_logger = logger;
29+
WindowCloseCommand = new AsyncRelayCommand(WindowClose);
30+
WindowRestoreCommand = new AsyncRelayCommand(WindowRestore);
31+
WindowMinimizeCommand = new AsyncRelayCommand(WindowMinimize);
32+
WindowMaximizeCommand = new AsyncRelayCommand(WindowMaximize);
33+
SaveCommand = new AsyncRelayCommand(Save, CanExecuteSave);
34+
CancelCommand = new AsyncRelayCommand(Cancel, CanExecuteCancel);
35+
InitializeComponent();
36+
ErrorMessage = string.Empty;
37+
}
38+
public AsyncRelayCommand WindowMinimizeCommand { get; }
39+
public AsyncRelayCommand WindowRestoreCommand { get; }
40+
public AsyncRelayCommand WindowMaximizeCommand { get; }
41+
public AsyncRelayCommand WindowCloseCommand { get; }
42+
public AsyncRelayCommand SaveCommand { get; }
43+
public AsyncRelayCommand CancelCommand { get; }
44+
45+
public DiffuserPipelineType PipelineType
46+
{
47+
get { return _pipelineType; }
48+
set { _pipelineType = value; NotifyPropertyChanged(); }
49+
}
50+
51+
52+
public string TextResult
53+
{
54+
get { return _textResult; }
55+
set { _textResult = value; NotifyPropertyChanged(); ErrorMessage = string.Empty; }
56+
}
57+
58+
public List<string> InvalidOptions
59+
{
60+
get { return _invalidOptions; }
61+
set { _invalidOptions = value; NotifyPropertyChanged(); }
62+
}
63+
64+
65+
public string ErrorMessage
66+
{
67+
get { return _errorMessage; }
68+
set { _errorMessage = value; NotifyPropertyChanged(); }
69+
}
70+
71+
72+
public bool ShowDialog(string title, List<string> invalidOptions = null)
73+
{
74+
Title = title;
75+
InvalidOptions = invalidOptions;
76+
return ShowDialog() ?? false;
77+
}
78+
79+
80+
private Task Save()
81+
{
82+
var result = TextResult.Trim();
83+
if (!InvalidOptions.IsNullOrEmpty() && InvalidOptions.Contains(result))
84+
{
85+
ErrorMessage = $"{result} is an invalid option";
86+
return Task.CompletedTask;
87+
}
88+
89+
_textResult = result;
90+
DialogResult = true;
91+
return Task.CompletedTask;
92+
}
93+
94+
private bool CanExecuteSave()
95+
{
96+
var result = TextResult?.Trim() ?? string.Empty;
97+
return result.Length > 2 && result.Length <= 24;
98+
}
99+
100+
private Task Cancel()
101+
{
102+
DialogResult = false;
103+
return Task.CompletedTask;
104+
}
105+
106+
private bool CanExecuteCancel()
107+
{
108+
return true;
109+
}
110+
111+
#region BaseWindow
112+
113+
private Task WindowClose()
114+
{
115+
Close();
116+
return Task.CompletedTask;
117+
}
118+
119+
private Task WindowRestore()
120+
{
121+
if (WindowState == WindowState.Maximized)
122+
WindowState = WindowState.Normal;
123+
else
124+
WindowState = WindowState.Maximized;
125+
return Task.CompletedTask;
126+
}
127+
128+
private Task WindowMinimize()
129+
{
130+
WindowState = WindowState.Minimized;
131+
return Task.CompletedTask;
132+
}
133+
134+
private Task WindowMaximize()
135+
{
136+
WindowState = WindowState.Maximized;
137+
return Task.CompletedTask;
138+
}
139+
140+
private void OnContentRendered(object sender, EventArgs e)
141+
{
142+
InvalidateVisual();
143+
}
144+
#endregion
145+
146+
#region INotifyPropertyChanged
147+
public event PropertyChangedEventHandler PropertyChanged;
148+
public void NotifyPropertyChanged([CallerMemberName] string property = "")
149+
{
150+
PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(property));
151+
}
152+
#endregion
153+
}
154+
}

OnnxStack.UI/Models/ModelConfigTemplate.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ public class ModelConfigTemplate
1111
public string Repository { get; set; }
1212
public string ImageIcon { get; set; }
1313
public ModelTemplateStatus Status { get; set; }
14+
public int SampleSize { get; set; }
1415
public int PadTokenId { get; set; }
1516
public int BlankTokenId { get; set; }
1617
public int TokenizerLimit { get; set; }
18+
public bool IsDualTokenizer { get; set; }
1719
public int EmbeddingsLength { get; set; }
20+
public int DualEmbeddingsLength { get; set; }
1821
public float ScaleFactor { get; set; }
1922
public DiffuserPipelineType PipelineType { get; set; }
2023
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();

OnnxStack.UI/Models/ModelSetViewModel.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ public class ModelSetViewModel : INotifyPropertyChanged
3636
private double _progressValue;
3737
private bool _isDownloading;
3838
private bool _hasChanged;
39+
private int _dualEmbeddingsLength;
40+
private bool _isDualTokenizer;
41+
private int _sampleSize;
3942

4043
public string Name
4144
{
@@ -59,17 +62,37 @@ public int BlankTokenId
5962
get { return _blankTokenId; }
6063
set { _blankTokenId = value; NotifyPropertyChanged(); }
6164
}
65+
66+
public int SampleSize
67+
{
68+
get { return _sampleSize; }
69+
set { _sampleSize = value; NotifyPropertyChanged(); }
70+
}
71+
6272
public float ScaleFactor
6373
{
6474
get { return _scaleFactor; }
6575
set { _scaleFactor = value; NotifyPropertyChanged(); }
6676
}
77+
6778
public int TokenizerLimit
6879
{
6980
get { return _tokenizerLimit; }
7081
set { _tokenizerLimit = value; NotifyPropertyChanged(); }
7182
}
7283

84+
public bool IsDualTokenizer
85+
{
86+
get { return _isDualTokenizer; }
87+
set { _isDualTokenizer = value; NotifyPropertyChanged(); }
88+
}
89+
90+
public int DualEmbeddingsLength
91+
{
92+
get { return _dualEmbeddingsLength; }
93+
set { _dualEmbeddingsLength = value; NotifyPropertyChanged(); }
94+
}
95+
7396
public int EmbeddingsLength
7497
{
7598
get { return _embeddingsLength; }

OnnxStack.UI/Views/ModelView.xaml.cs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -444,18 +444,13 @@ private bool SetModelPaths(ModelSetViewModel modelSet, string modelDirectory)
444444
var textEncoderPath = Path.Combine(modelDirectory, "text_encoder", "model.onnx");
445445
var vaeDecoder = Path.Combine(modelDirectory, "vae_decoder", "model.onnx");
446446
var vaeEncoder = Path.Combine(modelDirectory, "vae_encoder", "model.onnx");
447+
var tokenizer2Path = Path.Combine(modelDirectory, "tokenizer_2", "model.onnx");
448+
var textEncoder2Path = Path.Combine(modelDirectory, "text_encoder_2", "model.onnx");
449+
447450
if (!File.Exists(tokenizerPath))
448451
tokenizerPath = _defaultTokenizerPath;
449-
450-
// Validate Files
451-
foreach (var modelFile in new[] { unetPath, tokenizerPath, textEncoderPath, vaeDecoder, vaeEncoder })
452-
{
453-
if (!File.Exists(modelFile))
454-
{
455-
_logger.LogError($"Model file not found, ModelFile: {modelFile}");
456-
return false;
457-
}
458-
}
452+
if (!File.Exists(tokenizer2Path))
453+
tokenizer2Path = _defaultTokenizerPath;
459454

460455
// Set Model Paths
461456
foreach (var modelConfig in modelSet.ModelFiles)
@@ -467,6 +462,8 @@ private bool SetModelPaths(ModelSetViewModel modelSet, string modelDirectory)
467462
OnnxModelType.TextEncoder => textEncoderPath,
468463
OnnxModelType.VaeDecoder => vaeDecoder,
469464
OnnxModelType.VaeEncoder => vaeEncoder,
465+
OnnxModelType.Tokenizer2 => tokenizer2Path,
466+
OnnxModelType.TextEncoder2 => textEncoder2Path,
470467
_ => default
471468
};
472469
}
@@ -565,10 +562,10 @@ private async Task<bool> SaveModelAsync(ModelSetViewModel modelSet)
565562
private Task Add()
566563
{
567564
var invalidNames = ModelSets.Select(x => x.Name).ToList();
568-
var textInputDialog = _dialogService.GetDialog<TextInputDialog>();
569-
if (textInputDialog.ShowDialog("Add Model Set", "Name", 1, 30, invalidNames))
565+
var textInputDialog = _dialogService.GetDialog<AddModelDialog>();
566+
if (textInputDialog.ShowDialog("Add Model Set", invalidNames))
570567
{
571-
var models = Enum.GetValues<OnnxModelType>().Select(x => new ModelFileViewModel { Type = x });
568+
var pipeline = textInputDialog.PipelineType;
572569
var newModelTemplate = new ModelConfigTemplate
573570
{
574571
Name = textInputDialog.TextResult,
@@ -580,11 +577,13 @@ private Task Add()
580577
Images = Enumerable.Range(0, 6).Select(x => string.Empty).ToList(),
581578

582579
// TODO: Select pipleine in dialog, then setting any required bits
583-
PipelineType = DiffuserPipelineType.StableDiffusion,
584-
ScaleFactor = 0.18215f,
580+
PipelineType = pipeline,
581+
ScaleFactor = pipeline == DiffuserPipelineType.StableDiffusionXL ? 0.13025f : 0.18215f,
585582
TokenizerLimit = 77,
586-
PadTokenId = 49407,
583+
PadTokenId = pipeline == DiffuserPipelineType.StableDiffusionXL ? 1 : 49407,
587584
EmbeddingsLength = 768,
585+
DualEmbeddingsLength = 1280,
586+
IsDualTokenizer = pipeline == DiffuserPipelineType.StableDiffusionXL,
588587
BlankTokenId = 49407,
589588
Diffusers = Enum.GetValues<DiffuserType>().ToList(),
590589
};
@@ -935,6 +934,9 @@ private ModelSetViewModel CreateViewModel(ModelConfigTemplate modelTemplate)
935934
PadTokenId = modelTemplate.PadTokenId,
936935
ScaleFactor = modelTemplate.ScaleFactor,
937936
TokenizerLimit = modelTemplate.TokenizerLimit,
937+
IsDualTokenizer = modelTemplate.IsDualTokenizer,
938+
SampleSize = modelTemplate.SampleSize,
939+
DualEmbeddingsLength = modelTemplate.DualEmbeddingsLength,
938940
PipelineType = modelTemplate.PipelineType,
939941
EnableTextToImage = modelTemplate.Diffusers.Contains(DiffuserType.TextToImage),
940942
EnableImageToImage = modelTemplate.Diffusers.Contains(DiffuserType.ImageToImage),
@@ -957,6 +959,9 @@ private ModelSetViewModel CreateViewModel(ModelConfigTemplate modelTemplate)
957959
Images = modelTemplate.Images,
958960
ModelFiles = modelTemplate.ModelFiles.ToList(),
959961
Repository = modelTemplate.Repository,
962+
IsDualTokenizer = modelTemplate.IsDualTokenizer,
963+
SampleSize = modelTemplate.SampleSize,
964+
DualEmbeddingsLength = modelTemplate.DualEmbeddingsLength,
960965
Status = ModelTemplateStatus.Installed
961966
}
962967
};
@@ -986,6 +991,9 @@ private ModelSetViewModel CreateViewModel(ModelOptions modelOptions)
986991
InterOpNumThreads = modelOptions.InterOpNumThreads,
987992
PadTokenId = modelOptions.PadTokenId,
988993
ScaleFactor = modelOptions.ScaleFactor,
994+
IsDualTokenizer = modelOptions.IsDualTokenizer,
995+
SampleSize = modelOptions.SampleSize,
996+
DualEmbeddingsLength = modelOptions.DualEmbeddingsLength,
989997
TokenizerLimit = modelOptions.TokenizerLimit,
990998
PipelineType = modelOptions.PipelineType,
991999
EnableTextToImage = modelOptions.Diffusers.Contains(DiffuserType.TextToImage),
@@ -1017,6 +1025,9 @@ private ModelSetViewModel CreateViewModel(ModelOptions modelOptions)
10171025
ScaleFactor = modelOptions.ScaleFactor,
10181026
TokenizerLimit = modelOptions.TokenizerLimit,
10191027
PipelineType = modelOptions.PipelineType,
1028+
IsDualTokenizer = modelOptions.IsDualTokenizer,
1029+
SampleSize = modelOptions.SampleSize,
1030+
DualEmbeddingsLength = modelOptions.DualEmbeddingsLength,
10201031
Description = "",
10211032
Diffusers = modelOptions.Diffusers,
10221033
EmbeddingsLength = modelOptions.EmbeddingsLength,
@@ -1053,6 +1064,9 @@ private ModelOptions CreateModelOptions(ModelSetViewModel editModel)
10531064
TokenizerLimit = editModel.TokenizerLimit,
10541065
PipelineType = editModel.PipelineType,
10551066
Diffusers = new List<DiffuserType>(editModel.GetDiffusers()),
1067+
DualEmbeddingsLength = editModel.DualEmbeddingsLength,
1068+
SampleSize = editModel.SampleSize,
1069+
IsDualTokenizer = editModel.IsDualTokenizer,
10561070
ModelConfigurations = new List<OnnxModelSessionConfig>(editModel.ModelFiles.Select(x => new OnnxModelSessionConfig
10571071
{
10581072
Type = x.Type,

0 commit comments

Comments
 (0)