22using OnnxStack . StableDiffusion . Common ;
33using OnnxStack . StableDiffusion . Config ;
44using OnnxStack . WebUI . Models ;
5+ using Services ;
56using System . Diagnostics ;
67using System . Runtime . CompilerServices ;
7- using System . Text . Json ;
8- using System . Text . Json . Serialization ;
98
109namespace OnnxStack . Web . Hubs
1110{
1211 public class StableDiffusionHub : Hub < IStableDiffusionClient >
1312 {
13+ private readonly IFileService _fileService ;
1414 private readonly ILogger < StableDiffusionHub > _logger ;
15- private readonly IWebHostEnvironment _webHostEnvironment ;
16- private readonly JsonSerializerOptions _serializerOptions ;
1715 private readonly IStableDiffusionService _stableDiffusionService ;
1816
1917
@@ -23,12 +21,11 @@ public class StableDiffusionHub : Hub<IStableDiffusionClient>
2321 /// <param name="logger">The logger.</param>
2422 /// <param name="stableDiffusionService">The stable diffusion service.</param>
2523 /// <param name="webHostEnvironment">The web host environment.</param>
26- public StableDiffusionHub ( ILogger < StableDiffusionHub > logger , IStableDiffusionService stableDiffusionService , IWebHostEnvironment webHostEnvironment )
24+ public StableDiffusionHub ( ILogger < StableDiffusionHub > logger , IStableDiffusionService stableDiffusionService , IFileService fileService )
2725 {
2826 _logger = logger ;
29- _webHostEnvironment = webHostEnvironment ;
27+ _fileService = fileService ;
3028 _stableDiffusionService = stableDiffusionService ;
31- _serializerOptions = new JsonSerializerOptions { WriteIndented = true , Converters = { new JsonStringEnumConverter ( ) } } ;
3229 }
3330
3431
@@ -62,17 +59,83 @@ public override async Task OnDisconnectedAsync(Exception exception)
6259 /// <param name="cancellationToken">The cancellation token.</param>
6360 /// <returns></returns>
6461 [ HubMethodName ( "ExecuteTextToImage" ) ]
65- public async IAsyncEnumerable < TextToImageResult > OnExecuteTextToImage ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , [ EnumeratorCancellation ] CancellationToken cancellationToken )
62+ public async IAsyncEnumerable < StableDiffusionResult > OnExecuteTextToImage ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , [ EnumeratorCancellation ] CancellationToken cancellationToken )
6663 {
6764 _logger . Log ( LogLevel . Information , "[OnExecuteTextToImage] - New request received, Connection: {0}" , Context . ConnectionId ) ;
6865 var cancellationTokenSource = CancellationTokenSource . CreateLinkedTokenSource ( Context . ConnectionAborted , cancellationToken ) ;
6966
7067 // TODO: Add support for multiple results
7168 var result = await GenerateTextToImageResult ( promptOptions , schedulerOptions , cancellationTokenSource . Token ) ;
72- if ( result is null )
73- yield break ;
69+ if ( ! result . IsError )
70+ yield return result ;
7471
75- yield return result ;
72+ await Clients . Caller . OnError ( result . Error ) ;
73+ }
74+
75+
76+ /// <summary>
77+ /// Execute Image-To-Image Stable Diffusion
78+ /// </summary>
79+ /// <param name="options">The options.</param>
80+ /// <param name="cancellationToken">The cancellation token.</param>
81+ /// <returns></returns>
82+ [ HubMethodName ( "ExecuteImageToImage" ) ]
83+ public async IAsyncEnumerable < StableDiffusionResult > OnExecuteImageToImage ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , [ EnumeratorCancellation ] CancellationToken cancellationToken )
84+ {
85+ _logger . Log ( LogLevel . Information , "[ExecuteImageToImage] - New request received, Connection: {0}" , Context . ConnectionId ) ;
86+ var cancellationTokenSource = CancellationTokenSource . CreateLinkedTokenSource ( Context . ConnectionAborted , cancellationToken ) ;
87+
88+ // TODO: Add support for multiple results
89+ var result = await GenerateImageToImageResult ( promptOptions , schedulerOptions , cancellationTokenSource . Token ) ;
90+ if ( ! result . IsError )
91+ yield return result ;
92+
93+ await Clients . Caller . OnError ( result . Error ) ;
94+ yield break ;
95+ }
96+
97+
98+ /// <summary>
99+ /// Generates the image to image result.
100+ /// </summary>
101+ /// <param name="promptOptions">The prompt options.</param>
102+ /// <param name="schedulerOptions">The scheduler options.</param>
103+ /// <param name="cancellationToken">The cancellation token.</param>
104+ /// <returns></returns>
105+ private async Task < StableDiffusionResult > GenerateImageToImageResult ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , CancellationToken cancellationToken )
106+ {
107+ var timestamp = Stopwatch . GetTimestamp ( ) ;
108+ schedulerOptions . Seed = GenerateSeed ( schedulerOptions . Seed ) ;
109+
110+ //1. Create filenames
111+ var random = await _fileService . CreateRandomName ( ) ;
112+ var output = $ "Output-{ random } ";
113+ var outputImage = $ "{ output } .png";
114+ var outputBlueprint = $ "{ output } .json";
115+ var inputImage = $ "Input-{ random } .png";
116+ var outputImageUrl = await _fileService . CreateOutputUrl ( outputImage ) ;
117+ var outputImageFile = await _fileService . UrlToPhysicalPath ( outputImageUrl ) ;
118+
119+ //2. Copy input image to new file
120+ var inputImageFile = await _fileService . CopyInputImageFile ( promptOptions . InputImage , inputImage ) ;
121+ if ( inputImageFile is null )
122+ return new StableDiffusionResult ( "Failed to copy input image" ) ;
123+
124+ //3. Generate blueprint
125+ var blueprint = new ImageBlueprint ( promptOptions , schedulerOptions ) ;
126+ var bluprintFile = await _fileService . SaveBlueprintFile ( blueprint , outputBlueprint ) ;
127+ if ( bluprintFile is null )
128+ return new StableDiffusionResult ( "Failed to save blueprint" ) ;
129+
130+ //4. Set full path of input image
131+ promptOptions . InputImage = inputImageFile . FilePath ;
132+
133+ //5. Run stable diffusion
134+ if ( ! await RunStableDiffusion ( promptOptions , schedulerOptions , outputImageFile , cancellationToken ) )
135+ return new StableDiffusionResult ( "Failed to run stable diffusion" ) ;
136+
137+ //6. Return result
138+ return new StableDiffusionResult ( outputImage , outputImageUrl , blueprint , bluprintFile . Filename , bluprintFile . FileUrl , GetElapsed ( timestamp ) ) ;
76139 }
77140
78141
@@ -82,21 +145,31 @@ public async IAsyncEnumerable<TextToImageResult> OnExecuteTextToImage(PromptOpti
82145 /// <param name="options">The options.</param>
83146 /// <param name="cancellationToken">The cancellation token.</param>
84147 /// <returns></returns>
85- private async Task < TextToImageResult > GenerateTextToImageResult ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , CancellationToken cancellationToken )
148+ private async Task < StableDiffusionResult > GenerateTextToImageResult ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , CancellationToken cancellationToken )
86149 {
87150 var timestamp = Stopwatch . GetTimestamp ( ) ;
88151 schedulerOptions . Seed = GenerateSeed ( schedulerOptions . Seed ) ;
89152
153+ //1. Create filenames
154+ var random = await _fileService . CreateRandomName ( ) ;
155+ var output = $ "Output-{ random } ";
156+ var outputImage = $ "{ output } .png";
157+ var outputBlueprint = $ "{ output } .json";
158+ var outputImageUrl = await _fileService . CreateOutputUrl ( outputImage ) ;
159+ var outputImageFile = await _fileService . UrlToPhysicalPath ( outputImageUrl ) ;
160+
161+ //2. Generate blueprint
90162 var blueprint = new ImageBlueprint ( promptOptions , schedulerOptions ) ;
91- var fileInfo = CreateFileInfo ( promptOptions , schedulerOptions ) ;
92- if ( ! await SaveBlueprintFile ( fileInfo , blueprint ) )
93- return null ;
163+ var bluprintFile = await _fileService . SaveBlueprintFile ( blueprint , outputBlueprint ) ;
164+ if ( bluprintFile is null )
165+ return new StableDiffusionResult ( "Failed to save blueprint" ) ;
94166
95- if ( ! await RunStableDiffusion ( promptOptions , schedulerOptions , fileInfo , cancellationToken ) )
96- return null ;
167+ //3. Run stable diffusion
168+ if ( ! await RunStableDiffusion ( promptOptions , schedulerOptions , outputImageFile , cancellationToken ) )
169+ return new StableDiffusionResult ( "Failed to run stable diffusion" ) ;
97170
98- var elapsed = ( int ) Stopwatch . GetElapsedTime ( timestamp ) . TotalSeconds ;
99- return new TextToImageResult ( fileInfo . Image , fileInfo . ImageUrl , blueprint , fileInfo . Blueprint , fileInfo . BlueprintUrl , elapsed ) ;
171+ //4. Return result
172+ return new StableDiffusionResult ( outputImage , outputImageUrl , blueprint , bluprintFile . Filename , bluprintFile . FileUrl , GetElapsed ( timestamp ) ) ;
100173 }
101174
102175
@@ -108,69 +181,38 @@ private async Task<TextToImageResult> GenerateTextToImageResult(PromptOptions pr
108181 /// <param name="fileInfo">The file information.</param>
109182 /// <param name="cancellationToken">The cancellation token.</param>
110183 /// <returns></returns>
111- private async Task < bool > RunStableDiffusion ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , FileInfoResult fileInfo , CancellationToken cancellationToken )
184+ private async Task < bool > RunStableDiffusion ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , string outputImage , CancellationToken cancellationToken )
112185 {
113186 try
114187 {
115- await _stableDiffusionService . TextToImageFile ( promptOptions , schedulerOptions , fileInfo . ImageFile , ProgressCallback ( ) , cancellationToken ) ;
188+ await _stableDiffusionService . TextToImageFile ( promptOptions , schedulerOptions , outputImage , ProgressCallback ( ) , cancellationToken ) ;
116189 return true ;
117190 }
118191 catch ( OperationCanceledException tex )
119192 {
120193 await Clients . Caller . OnCanceled ( tex . Message ) ;
121- _logger . Log ( LogLevel . Warning , tex , "[OnExecuteTextToImage ] - Operation canceled, Connection: {0}" , Context . ConnectionId ) ;
194+ _logger . Log ( LogLevel . Warning , tex , "[RunStableDiffusion ] - Operation canceled, Connection: {0}" , Context . ConnectionId ) ;
122195 }
123196 catch ( Exception ex )
124197 {
125198 await Clients . Caller . OnError ( ex . Message ) ;
126- _logger . Log ( LogLevel . Error , ex , "[OnExecuteTextToImage ] - Error generating image, Connection: {0}" , Context . ConnectionId ) ;
199+ _logger . Log ( LogLevel . Error , ex , "[RunStableDiffusion ] - Error generating image, Connection: {0}" , Context . ConnectionId ) ;
127200 }
128201 return false ;
129202 }
130203
131204
132205 /// <summary>
133- /// Saves the options file .
206+ /// Progress callback .
134207 /// </summary>
135- /// <param name="fileInfo">The file information.</param>
136- /// <param name="options">The options.</param>
137208 /// <returns></returns>
138- private async Task < bool > SaveBlueprintFile ( FileInfoResult fileInfo , ImageBlueprint bluprint )
209+ private Action < int , int > ProgressCallback ( )
139210 {
140- try
141- {
142- using ( var stream = File . Create ( fileInfo . BlueprintFile ) )
143- {
144- await JsonSerializer . SerializeAsync ( stream , bluprint , _serializerOptions ) ;
145- return true ;
146- }
147- }
148- catch ( Exception ex )
211+ return async ( progress , total ) =>
149212 {
150- _logger . Log ( LogLevel . Error , ex , "[SaveOptions] - Error saving model card, Connection: {0}" , Context . ConnectionId ) ;
151- return false ;
152- }
153- }
154-
155-
156- /// <summary>
157- /// Creates the file information.
158- /// </summary>
159- /// <param name="promptOptions">The prompt options.</param>
160- /// <param name="schedulerOptions">The scheduler options.</param>
161- /// <returns></returns>
162- private FileInfoResult CreateFileInfo ( PromptOptions promptOptions , SchedulerOptions schedulerOptions )
163- {
164- var rand = Path . GetFileNameWithoutExtension ( Path . GetRandomFileName ( ) ) ;
165- var output = $ "{ schedulerOptions . Seed } -{ rand } ";
166- var outputImage = $ "{ output } .png";
167- var outputImageUrl = CreateOutputUrl ( "TextToImage" , outputImage ) ;
168- var outputImageFile = UrlToPhysicalPath ( outputImageUrl ) ;
169-
170- var outputJson = $ "{ output } .json";
171- var outputJsonUrl = CreateOutputUrl ( "TextToImage" , outputJson ) ;
172- var outputJsonFile = UrlToPhysicalPath ( outputJsonUrl ) ;
173- return new FileInfoResult ( outputImage , outputImageUrl , outputImageFile , outputJson , outputJsonUrl , outputJsonFile ) ;
213+ _logger . Log ( LogLevel . Information , "[ProgressCallback] - Progress: {0}/{1}, Connection: {2}" , progress , total , Context . ConnectionId ) ;
214+ await Clients . Caller . OnProgress ( new ProgressResult ( progress , total ) ) ;
215+ } ;
174216 }
175217
176218
@@ -189,42 +231,13 @@ private int GenerateSeed(int seed)
189231
190232
191233 /// <summary>
192- /// Progress callback.
193- /// </summary>
194- /// <returns></returns>
195- private Action < int , int > ProgressCallback ( )
196- {
197- return async ( progress , total ) =>
198- {
199- _logger . Log ( LogLevel . Information , "[OnExecuteTextToImage] - Progress: {0}/{1}, Connection: {2}" , progress , total , Context . ConnectionId ) ;
200- await Clients . Caller . OnProgress ( new ProgressResult ( progress , total ) ) ;
201- } ;
202- }
203-
204-
205- /// <summary>
206- /// URL path to physical path.
207- /// </summary>
208- /// <param name="url">The URL.</param>
209- /// <returns></returns>
210- private string UrlToPhysicalPath ( string url )
211- {
212- string webRootPath = _webHostEnvironment . WebRootPath ;
213- string physicalPath = Path . Combine ( webRootPath , url . TrimStart ( '/' ) . Replace ( '/' , '\\ ' ) ) ;
214- return physicalPath ;
215- }
216-
217-
218- /// <summary>
219- /// Creates the output URL.
234+ /// Gets the elapsed time is seconds.
220235 /// </summary>
221- /// <param name="folder">The folder.</param>
222- /// <param name="file">The file.</param>
236+ /// <param name="timestamp">The begin timestamp.</param>
223237 /// <returns></returns>
224- private string CreateOutputUrl ( string folder , string file )
238+ private static int GetElapsed ( long timestamp )
225239 {
226- return $ "/images/results/ { folder } / { file } " ;
240+ return ( int ) Stopwatch . GetElapsedTime ( timestamp ) . TotalSeconds ;
227241 }
228242 }
229-
230243}
0 commit comments