Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d25d482

Browse files
committed
Save options to json with image output
1 parent 0309117 commit d25d482

File tree

2 files changed

+154
-33
lines changed

2 files changed

+154
-33
lines changed

OnnxStack.WebUI/Hubs/StableDiffusionHub.cs

Lines changed: 144 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,88 @@
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
44
using OnnxStack.WebUI.Models;
5-
using System;
65
using System.Runtime.CompilerServices;
6+
using System.Text.Json;
7+
using System.Text.Json.Serialization;
78

89
namespace OnnxStack.Web.Hubs
910
{
1011
public class StableDiffusionHub : Hub
1112
{
1213
private readonly ILogger<StableDiffusionHub> _logger;
13-
private readonly IStableDiffusionService _stableDiffusionService;
1414
private readonly IWebHostEnvironment _webHostEnvironment;
15+
private readonly JsonSerializerOptions _serializerOptions;
16+
private readonly IStableDiffusionService _stableDiffusionService;
17+
18+
19+
/// <summary>
20+
/// Initializes a new instance of the <see cref="StableDiffusionHub"/> class.
21+
/// </summary>
22+
/// <param name="logger">The logger.</param>
23+
/// <param name="stableDiffusionService">The stable diffusion service.</param>
24+
/// <param name="webHostEnvironment">The web host environment.</param>
1525
public StableDiffusionHub(ILogger<StableDiffusionHub> logger, IStableDiffusionService stableDiffusionService, IWebHostEnvironment webHostEnvironment)
1626
{
1727
_logger = logger;
1828
_webHostEnvironment = webHostEnvironment;
1929
_stableDiffusionService = stableDiffusionService;
30+
_serializerOptions = new JsonSerializerOptions { WriteIndented = true, Converters = { new JsonStringEnumConverter() } };
2031
}
2132

33+
34+
/// <summary>
35+
/// Called when a new connection is established with the hub.
36+
/// </summary>
2237
public override async Task OnConnectedAsync()
2338
{
2439
_logger.Log(LogLevel.Information, "[OnConnectedAsync], Id: {0}", Context.ConnectionId);
25-
2640
await Clients.Caller.SendAsync("OnMessage", "OnConnectedAsync");
2741
await base.OnConnectedAsync();
2842
}
2943

3044

45+
/// <summary>
46+
/// Called when a connection with the hub is terminated.
47+
/// </summary>
48+
/// <param name="exception"></param>
3149
public override async Task OnDisconnectedAsync(Exception exception)
3250
{
3351
_logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
34-
3552
await Clients.Caller.SendAsync("OnMessage", "OnDisconnectedAsync");
3653
await base.OnDisconnectedAsync(exception);
3754
}
3855

3956

57+
/// <summary>
58+
/// Execute Text-To-Image Stable Diffusion
59+
/// </summary>
60+
/// <param name="options">The options.</param>
61+
/// <param name="cancellationToken">The cancellation token.</param>
62+
/// <returns></returns>
4063
[HubMethodName("ExecuteTextToImage")]
41-
public async IAsyncEnumerable<DiffusionResult> OnExecuteTextToImage(TextToImageOptions options, [EnumeratorCancellation] CancellationToken cancellationToken)
64+
public async IAsyncEnumerable<TextToImageResult> OnExecuteTextToImage(TextToImageOptions options, [EnumeratorCancellation] CancellationToken cancellationToken)
4265
{
43-
_logger.Log(LogLevel.Information, "[OnExecuteTextToImage] - New prompt received, Connection: {0}", Context.ConnectionId);
44-
var linkedCancellationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken);
66+
_logger.Log(LogLevel.Information, "[OnExecuteTextToImage] - New request received, Connection: {0}", Context.ConnectionId);
67+
var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken);
68+
69+
// TODO: Add support for multiple results
70+
var result = await GenerateTextToImageResult(options, cancellationTokenSource.Token);
71+
if (result is null)
72+
yield break;
4573

74+
yield return result;
75+
}
76+
77+
78+
/// <summary>
79+
/// Generates the text to image result.
80+
/// </summary>
81+
/// <param name="options">The options.</param>
82+
/// <param name="cancellationToken">The cancellation token.</param>
83+
/// <returns></returns>
84+
private async Task<TextToImageResult> GenerateTextToImageResult(TextToImageOptions options, CancellationToken cancellationToken)
85+
{
86+
options.Seed = GenerateSeed(options.Seed);
4687
var promptOptions = new PromptOptions
4788
{
4889
Prompt = options.Prompt,
@@ -54,44 +95,103 @@ public async IAsyncEnumerable<DiffusionResult> OnExecuteTextToImage(TextToImageO
5495
{
5596
Width = options.Width,
5697
Height = options.Height,
57-
Seed = GenerateSeed(options.Seed),
98+
Seed = options.Seed,
5899
InferenceSteps = options.InferenceSteps,
59100
GuidanceScale = options.GuidanceScale,
60101
Strength = options.Strength,
61102
InitialNoiseLevel = options.InitialNoiseLevel
62103
};
63104

64-
// TODO: Add support for multiple results
65-
var result = await GenerateTextToImage(promptOptions, schedulerOptions, cancellationToken);
66-
if(result is null)
67-
yield break;
105+
var fileInfo = CreateFileInfo(promptOptions, schedulerOptions);
106+
if (!await SaveOptionsFile(fileInfo, options))
107+
return null;
68108

69-
yield return result;
109+
if (!await RunStableDiffusion(promptOptions, schedulerOptions, fileInfo, cancellationToken))
110+
return null;
111+
112+
return new TextToImageResult(fileInfo.OutputImage, fileInfo.OutputImageUrl, options, fileInfo.OutputOptionsUrl);
70113
}
71114

72-
private async Task<DiffusionResult> GenerateTextToImage(PromptOptions promptOptions, SchedulerOptions schedulerOptions, CancellationToken cancellationToken)
73-
{
74-
var rand = Path.GetFileNameWithoutExtension(Path.GetRandomFileName());
75-
var outputImage = $"{schedulerOptions.Seed}_{promptOptions.SchedulerType}_{rand}.png";
76-
var outputImageUrl = CreateOutputImageUrl("TextToImage", outputImage);
77-
var outputImageFile = CreateOutputImageFile(outputImageUrl);
78115

116+
/// <summary>
117+
/// Runs the stable diffusion.
118+
/// </summary>
119+
/// <param name="promptOptions">The prompt options.</param>
120+
/// <param name="schedulerOptions">The scheduler options.</param>
121+
/// <param name="fileInfo">The file information.</param>
122+
/// <param name="cancellationToken">The cancellation token.</param>
123+
/// <returns></returns>
124+
private async Task<bool> RunStableDiffusion(PromptOptions promptOptions, SchedulerOptions schedulerOptions, FileInfoResult fileInfo, CancellationToken cancellationToken)
125+
{
79126
try
80127
{
81-
await _stableDiffusionService.TextToImageFile(promptOptions, schedulerOptions, outputImageFile, ProgressCallback(), cancellationToken);
82-
return new DiffusionResult(outputImage, outputImageUrl);
128+
await _stableDiffusionService.TextToImageFile(promptOptions, schedulerOptions, fileInfo.OutputImageFile, ProgressCallback(), cancellationToken);
129+
return true;
83130
}
84131
catch (OperationCanceledException tex)
85132
{
86133
await Clients.Caller.SendAsync("OnCanceled", tex.Message);
134+
_logger.Log(LogLevel.Warning, tex, "[OnExecuteTextToImage] - Operation canceled, Connection: {0}", Context.ConnectionId);
87135
}
88136
catch (Exception ex)
89137
{
90138
await Clients.Caller.SendAsync("OnError", ex.Message);
139+
_logger.Log(LogLevel.Error, ex, "[OnExecuteTextToImage] - Error generating image, Connection: {0}", Context.ConnectionId);
140+
}
141+
return false;
142+
}
143+
144+
145+
/// <summary>
146+
/// Saves the options file.
147+
/// </summary>
148+
/// <param name="fileInfo">The file information.</param>
149+
/// <param name="options">The options.</param>
150+
/// <returns></returns>
151+
private async Task<bool> SaveOptionsFile(FileInfoResult fileInfo, TextToImageOptions options)
152+
{
153+
try
154+
{
155+
using (var stream = File.Create(fileInfo.OutputOptionsFile))
156+
{
157+
await JsonSerializer.SerializeAsync(stream, options, _serializerOptions);
158+
return true;
159+
}
160+
}
161+
catch (Exception ex)
162+
{
163+
_logger.Log(LogLevel.Error, ex, "[SaveOptions] - Error saving model card, Connection: {0}", Context.ConnectionId);
164+
return false;
91165
}
92-
return null;
93166
}
94167

168+
169+
/// <summary>
170+
/// Creates the file information.
171+
/// </summary>
172+
/// <param name="promptOptions">The prompt options.</param>
173+
/// <param name="schedulerOptions">The scheduler options.</param>
174+
/// <returns></returns>
175+
private FileInfoResult CreateFileInfo(PromptOptions promptOptions, SchedulerOptions schedulerOptions)
176+
{
177+
var rand = Path.GetFileNameWithoutExtension(Path.GetRandomFileName());
178+
var output = $"{schedulerOptions.Seed}_{promptOptions.SchedulerType}_{rand}";
179+
var outputImage = $"{output}.png";
180+
var outputImageUrl = CreateOutputUrl("TextToImage", outputImage);
181+
var outputImageFile = UrlToPhysicalPath(outputImageUrl);
182+
183+
var outputJson = $"{output}.json";
184+
var outputJsonUrl = CreateOutputUrl("TextToImage", outputJson);
185+
var outputJsonFile = UrlToPhysicalPath(outputJsonUrl);
186+
return new FileInfoResult(outputImage, outputImageUrl, outputImageFile, outputJson, outputJsonUrl, outputJsonFile);
187+
}
188+
189+
190+
/// <summary>
191+
/// Generates the seed.
192+
/// </summary>
193+
/// <param name="seed">The seed.</param>
194+
/// <returns></returns>
95195
private int GenerateSeed(int seed)
96196
{
97197
if (seed > 0)
@@ -100,6 +200,11 @@ private int GenerateSeed(int seed)
100200
return Random.Shared.Next();
101201
}
102202

203+
204+
/// <summary>
205+
/// Progress callback.
206+
/// </summary>
207+
/// <returns></returns>
103208
private Action<int, int> ProgressCallback()
104209
{
105210
return async (progress, total) =>
@@ -110,19 +215,32 @@ private Action<int, int> ProgressCallback()
110215
}
111216

112217

113-
private string CreateOutputImageFile(string url)
218+
/// <summary>
219+
/// URL path to physical path.
220+
/// </summary>
221+
/// <param name="url">The URL.</param>
222+
/// <returns></returns>
223+
private string UrlToPhysicalPath(string url)
114224
{
115225
string webRootPath = _webHostEnvironment.WebRootPath;
116226
string physicalPath = Path.Combine(webRootPath, url.TrimStart('/').Replace('/', '\\'));
117227
return physicalPath;
118228
}
119229

120-
private string CreateOutputImageUrl(string folder, string imageName)
230+
231+
/// <summary>
232+
/// Creates the output URL.
233+
/// </summary>
234+
/// <param name="folder">The folder.</param>
235+
/// <param name="file">The file.</param>
236+
/// <returns></returns>
237+
private string CreateOutputUrl(string folder, string file)
121238
{
122-
return $"/images/results/{folder}/{imageName}";
239+
return $"/images/results/{folder}/{file}";
123240
}
124241
}
125242

126243
public record ProgressResult(int Progress, int Total);
127-
public record DiffusionResult(string OutputImage, string OutputImageUrl);
244+
public record TextToImageResult(string OutputImage, string OutputImageUrl, TextToImageOptions OutputOptions, string OutputOptionsUrl);
245+
public record FileInfoResult(string OutputImage, string OutputImageUrl, string OutputImageFile, string OutputOptions, string OutputOptionsUrl, string OutputOptionsFile);
128246
}

OnnxStack.WebUI/Pages/StableDiffusion/TextToImage.cshtml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@
144144
let diffusionProcess;
145145
146146
const onResponse = (response) => {
147-
updateResultImage(response)
147+
if (!response)
148+
return;
149+
150+
console.log(response);
151+
updateResultImage(response);
148152
processEnd();
149153
}
150154
@@ -157,25 +161,23 @@
157161
}
158162
159163
const onProgress = (response) => {
160-
console.log(response)
164+
console.log(response);
161165
updateProgress(response);
162166
}
163167
164168
const onCanceled = (response) => {
165-
console.log(response)
169+
console.log(response);
166170
outputContainer.html('');
167171
processEnd();
168172
}
169173
170174
const executeTextToImage = async () => {
171175
172-
// Validate form
173176
const diffusionParams = serializeFormToJson(optionsForm);
174177
if (!validateForm())
175178
return;
176179
177180
processBegin();
178-
179181
outputContainer.html(Mustache.render(progressResultTemplate));
180182
diffusionProcess = await connection
181183
.stream("ExecuteTextToImage", diffusionParams)
@@ -197,8 +199,9 @@
197199
const updateProgress = (response) => {
198200
const increment = Math.max(100 / response.total, 1);
199201
const progressPercent = Math.round(Math.min(increment * response.progress, 100), 0);
200-
$("#progress-result").css("width", progressPercent + "%");
201-
$("#progress-result").text(progressPercent + "%");
202+
const progressBar = $("#progress-result");
203+
progressBar.css("width", progressPercent + "%");
204+
progressBar.text(progressPercent + "%");
202205
}
203206
204207
const processBegin = () => {

0 commit comments

Comments
 (0)