22using OnnxStack . StableDiffusion . Common ;
33using OnnxStack . StableDiffusion . Config ;
44using OnnxStack . WebUI . Models ;
5- using System ;
65using System . Runtime . CompilerServices ;
6+ using System . Text . Json ;
7+ using System . Text . Json . Serialization ;
78
89namespace OnnxStack . Web . Hubs
910{
1011 public class StableDiffusionHub : Hub
1112 {
1213 private readonly ILogger < StableDiffusionHub > _logger ;
13- private readonly IStableDiffusionService _stableDiffusionService ;
1414 private readonly IWebHostEnvironment _webHostEnvironment ;
15+ private readonly JsonSerializerOptions _serializerOptions ;
16+ private readonly IStableDiffusionService _stableDiffusionService ;
17+
18+
19+ /// <summary>
20+ /// Initializes a new instance of the <see cref="StableDiffusionHub"/> class.
21+ /// </summary>
22+ /// <param name="logger">The logger.</param>
23+ /// <param name="stableDiffusionService">The stable diffusion service.</param>
24+ /// <param name="webHostEnvironment">The web host environment.</param>
1525 public StableDiffusionHub ( ILogger < StableDiffusionHub > logger , IStableDiffusionService stableDiffusionService , IWebHostEnvironment webHostEnvironment )
1626 {
1727 _logger = logger ;
1828 _webHostEnvironment = webHostEnvironment ;
1929 _stableDiffusionService = stableDiffusionService ;
30+ _serializerOptions = new JsonSerializerOptions { WriteIndented = true , Converters = { new JsonStringEnumConverter ( ) } } ;
2031 }
2132
33+
34+ /// <summary>
35+ /// Called when a new connection is established with the hub.
36+ /// </summary>
2237 public override async Task OnConnectedAsync ( )
2338 {
2439 _logger . Log ( LogLevel . Information , "[OnConnectedAsync], Id: {0}" , Context . ConnectionId ) ;
25-
2640 await Clients . Caller . SendAsync ( "OnMessage" , "OnConnectedAsync" ) ;
2741 await base . OnConnectedAsync ( ) ;
2842 }
2943
3044
45+ /// <summary>
46+ /// Called when a connection with the hub is terminated.
47+ /// </summary>
48+ /// <param name="exception"></param>
3149 public override async Task OnDisconnectedAsync ( Exception exception )
3250 {
3351 _logger . Log ( LogLevel . Information , "[OnDisconnectedAsync], Id: {0}" , Context . ConnectionId ) ;
34-
3552 await Clients . Caller . SendAsync ( "OnMessage" , "OnDisconnectedAsync" ) ;
3653 await base . OnDisconnectedAsync ( exception ) ;
3754 }
3855
3956
57+ /// <summary>
58+ /// Execute Text-To-Image Stable Diffusion
59+ /// </summary>
60+ /// <param name="options">The options.</param>
61+ /// <param name="cancellationToken">The cancellation token.</param>
62+ /// <returns></returns>
4063 [ HubMethodName ( "ExecuteTextToImage" ) ]
41- public async IAsyncEnumerable < DiffusionResult > OnExecuteTextToImage ( TextToImageOptions options , [ EnumeratorCancellation ] CancellationToken cancellationToken )
64+ public async IAsyncEnumerable < TextToImageResult > OnExecuteTextToImage ( TextToImageOptions options , [ EnumeratorCancellation ] CancellationToken cancellationToken )
4265 {
43- _logger . Log ( LogLevel . Information , "[OnExecuteTextToImage] - New prompt received, Connection: {0}" , Context . ConnectionId ) ;
44- var linkedCancellationToken = CancellationTokenSource . CreateLinkedTokenSource ( Context . ConnectionAborted , cancellationToken ) ;
66+ _logger . Log ( LogLevel . Information , "[OnExecuteTextToImage] - New request received, Connection: {0}" , Context . ConnectionId ) ;
67+ var cancellationTokenSource = CancellationTokenSource . CreateLinkedTokenSource ( Context . ConnectionAborted , cancellationToken ) ;
68+
69+ // TODO: Add support for multiple results
70+ var result = await GenerateTextToImageResult ( options , cancellationTokenSource . Token ) ;
71+ if ( result is null )
72+ yield break ;
4573
74+ yield return result ;
75+ }
76+
77+
78+ /// <summary>
79+ /// Generates the text to image result.
80+ /// </summary>
81+ /// <param name="options">The options.</param>
82+ /// <param name="cancellationToken">The cancellation token.</param>
83+ /// <returns></returns>
84+ private async Task < TextToImageResult > GenerateTextToImageResult ( TextToImageOptions options , CancellationToken cancellationToken )
85+ {
86+ options . Seed = GenerateSeed ( options . Seed ) ;
4687 var promptOptions = new PromptOptions
4788 {
4889 Prompt = options . Prompt ,
@@ -54,44 +95,103 @@ public async IAsyncEnumerable<DiffusionResult> OnExecuteTextToImage(TextToImageO
5495 {
5596 Width = options . Width ,
5697 Height = options . Height ,
57- Seed = GenerateSeed ( options . Seed ) ,
98+ Seed = options . Seed ,
5899 InferenceSteps = options . InferenceSteps ,
59100 GuidanceScale = options . GuidanceScale ,
60101 Strength = options . Strength ,
61102 InitialNoiseLevel = options . InitialNoiseLevel
62103 } ;
63104
64- // TODO: Add support for multiple results
65- var result = await GenerateTextToImage ( promptOptions , schedulerOptions , cancellationToken ) ;
66- if ( result is null )
67- yield break ;
105+ var fileInfo = CreateFileInfo ( promptOptions , schedulerOptions ) ;
106+ if ( ! await SaveOptionsFile ( fileInfo , options ) )
107+ return null ;
68108
69- yield return result ;
109+ if ( ! await RunStableDiffusion ( promptOptions , schedulerOptions , fileInfo , cancellationToken ) )
110+ return null ;
111+
112+ return new TextToImageResult ( fileInfo . OutputImage , fileInfo . OutputImageUrl , options , fileInfo . OutputOptionsUrl ) ;
70113 }
71114
72- private async Task < DiffusionResult > GenerateTextToImage ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , CancellationToken cancellationToken )
73- {
74- var rand = Path . GetFileNameWithoutExtension ( Path . GetRandomFileName ( ) ) ;
75- var outputImage = $ "{ schedulerOptions . Seed } _{ promptOptions . SchedulerType } _{ rand } .png";
76- var outputImageUrl = CreateOutputImageUrl ( "TextToImage" , outputImage ) ;
77- var outputImageFile = CreateOutputImageFile ( outputImageUrl ) ;
78115
116+ /// <summary>
117+ /// Runs the stable diffusion.
118+ /// </summary>
119+ /// <param name="promptOptions">The prompt options.</param>
120+ /// <param name="schedulerOptions">The scheduler options.</param>
121+ /// <param name="fileInfo">The file information.</param>
122+ /// <param name="cancellationToken">The cancellation token.</param>
123+ /// <returns></returns>
124+ private async Task < bool > RunStableDiffusion ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , FileInfoResult fileInfo , CancellationToken cancellationToken )
125+ {
79126 try
80127 {
81- await _stableDiffusionService . TextToImageFile ( promptOptions , schedulerOptions , outputImageFile , ProgressCallback ( ) , cancellationToken ) ;
82- return new DiffusionResult ( outputImage , outputImageUrl ) ;
128+ await _stableDiffusionService . TextToImageFile ( promptOptions , schedulerOptions , fileInfo . OutputImageFile , ProgressCallback ( ) , cancellationToken ) ;
129+ return true ;
83130 }
84131 catch ( OperationCanceledException tex )
85132 {
86133 await Clients . Caller . SendAsync ( "OnCanceled" , tex . Message ) ;
134+ _logger . Log ( LogLevel . Warning , tex , "[OnExecuteTextToImage] - Operation canceled, Connection: {0}" , Context . ConnectionId ) ;
87135 }
88136 catch ( Exception ex )
89137 {
90138 await Clients . Caller . SendAsync ( "OnError" , ex . Message ) ;
139+ _logger . Log ( LogLevel . Error , ex , "[OnExecuteTextToImage] - Error generating image, Connection: {0}" , Context . ConnectionId ) ;
140+ }
141+ return false ;
142+ }
143+
144+
145+ /// <summary>
146+ /// Saves the options file.
147+ /// </summary>
148+ /// <param name="fileInfo">The file information.</param>
149+ /// <param name="options">The options.</param>
150+ /// <returns></returns>
151+ private async Task < bool > SaveOptionsFile ( FileInfoResult fileInfo , TextToImageOptions options )
152+ {
153+ try
154+ {
155+ using ( var stream = File . Create ( fileInfo . OutputOptionsFile ) )
156+ {
157+ await JsonSerializer . SerializeAsync ( stream , options , _serializerOptions ) ;
158+ return true ;
159+ }
160+ }
161+ catch ( Exception ex )
162+ {
163+ _logger . Log ( LogLevel . Error , ex , "[SaveOptions] - Error saving model card, Connection: {0}" , Context . ConnectionId ) ;
164+ return false ;
91165 }
92- return null ;
93166 }
94167
168+
169+ /// <summary>
170+ /// Creates the file information.
171+ /// </summary>
172+ /// <param name="promptOptions">The prompt options.</param>
173+ /// <param name="schedulerOptions">The scheduler options.</param>
174+ /// <returns></returns>
175+ private FileInfoResult CreateFileInfo ( PromptOptions promptOptions , SchedulerOptions schedulerOptions )
176+ {
177+ var rand = Path . GetFileNameWithoutExtension ( Path . GetRandomFileName ( ) ) ;
178+ var output = $ "{ schedulerOptions . Seed } _{ promptOptions . SchedulerType } _{ rand } ";
179+ var outputImage = $ "{ output } .png";
180+ var outputImageUrl = CreateOutputUrl ( "TextToImage" , outputImage ) ;
181+ var outputImageFile = UrlToPhysicalPath ( outputImageUrl ) ;
182+
183+ var outputJson = $ "{ output } .json";
184+ var outputJsonUrl = CreateOutputUrl ( "TextToImage" , outputJson ) ;
185+ var outputJsonFile = UrlToPhysicalPath ( outputJsonUrl ) ;
186+ return new FileInfoResult ( outputImage , outputImageUrl , outputImageFile , outputJson , outputJsonUrl , outputJsonFile ) ;
187+ }
188+
189+
190+ /// <summary>
191+ /// Generates the seed.
192+ /// </summary>
193+ /// <param name="seed">The seed.</param>
194+ /// <returns></returns>
95195 private int GenerateSeed ( int seed )
96196 {
97197 if ( seed > 0 )
@@ -100,6 +200,11 @@ private int GenerateSeed(int seed)
100200 return Random . Shared . Next ( ) ;
101201 }
102202
203+
204+ /// <summary>
205+ /// Progress callback.
206+ /// </summary>
207+ /// <returns></returns>
103208 private Action < int , int > ProgressCallback ( )
104209 {
105210 return async ( progress , total ) =>
@@ -110,19 +215,32 @@ private Action<int, int> ProgressCallback()
110215 }
111216
112217
113- private string CreateOutputImageFile ( string url )
218+ /// <summary>
219+ /// URL path to physical path.
220+ /// </summary>
221+ /// <param name="url">The URL.</param>
222+ /// <returns></returns>
223+ private string UrlToPhysicalPath ( string url )
114224 {
115225 string webRootPath = _webHostEnvironment . WebRootPath ;
116226 string physicalPath = Path . Combine ( webRootPath , url . TrimStart ( '/' ) . Replace ( '/' , '\\ ' ) ) ;
117227 return physicalPath ;
118228 }
119229
120- private string CreateOutputImageUrl ( string folder , string imageName )
230+
231+ /// <summary>
232+ /// Creates the output URL.
233+ /// </summary>
234+ /// <param name="folder">The folder.</param>
235+ /// <param name="file">The file.</param>
236+ /// <returns></returns>
237+ private string CreateOutputUrl ( string folder , string file )
121238 {
122- return $ "/images/results/{ folder } /{ imageName } ";
239+ return $ "/images/results/{ folder } /{ file } ";
123240 }
124241 }
125242
126243 public record ProgressResult ( int Progress , int Total ) ;
127- public record DiffusionResult ( string OutputImage , string OutputImageUrl ) ;
244+ public record TextToImageResult ( string OutputImage , string OutputImageUrl , TextToImageOptions OutputOptions , string OutputOptionsUrl ) ;
245+ public record FileInfoResult ( string OutputImage , string OutputImageUrl , string OutputImageFile , string OutputOptions , string OutputOptionsUrl , string OutputOptionsFile ) ;
128246}
0 commit comments