1+ using System . Security . Cryptography ;
12using FluentAssertions ;
3+ using FluentAssertions . Execution ;
24using Microsoft . Extensions . DependencyInjection ;
5+ using Microsoft . Extensions . Logging ;
36using OnnxStack . Core ;
47using OnnxStack . StableDiffusion . Common ;
58using OnnxStack . StableDiffusion . Config ;
69using OnnxStack . StableDiffusion . Enums ;
10+ using SixLabors . ImageSharp ;
711using Xunit . Abstractions ;
812
913namespace OnnxStack . IntegrationTests ;
1014
1115public class StableDiffusionTests
1216{
13- private readonly ITestOutputHelper _testOutputHelper ;
1417 private readonly IStableDiffusionService _stableDiffusion ;
18+ private readonly ILogger < StableDiffusionTests > _logger ;
1519
1620 public StableDiffusionTests ( ITestOutputHelper testOutputHelper )
1721 {
18- _testOutputHelper = testOutputHelper ;
19-
2022 var services = new ServiceCollection ( ) ;
21- services . AddLogging ( ) ;
23+ services . AddLogging ( builder => builder . AddConsole ( ) ) ;
24+ services . AddLogging ( builder => builder . AddXunit ( testOutputHelper ) ) ;
2225 services . AddOnnxStack ( ) ;
2326 services . AddOnnxStackStableDiffusion ( ) ;
2427 var provider = services . BuildServiceProvider ( ) ;
2528 _stableDiffusion = provider . GetRequiredService < IStableDiffusionService > ( ) ;
29+ _logger = provider . GetRequiredService < ILogger < StableDiffusionTests > > ( ) ;
2630 }
2731
2832 [ Fact ]
@@ -32,6 +36,7 @@ public async Task GivenStableDiffusion15_WhenLoadModel_ThenModelIsLoaded()
3236 var model = _stableDiffusion . Models . Single ( m => m . Name == "StableDiffusion 1.5" ) ;
3337
3438 //act
39+ _logger . LogInformation ( "Attempting to load model {0}" , model . Name ) ;
3540 var isModelLoaded = await _stableDiffusion . LoadModel ( model ) ;
3641
3742 //assert
@@ -43,6 +48,7 @@ public async Task GivenTextToImage_WhenInference_ThenImageGenerated()
4348 {
4449 //arrange
4550 var model = _stableDiffusion . Models . Single ( m => m . Name == "StableDiffusion 1.5" ) ;
51+ _logger . LogInformation ( "Attempting to load model {0}" , model . Name ) ;
4652 await _stableDiffusion . LoadModel ( model ) ;
4753
4854 var prompt = new PromptOptions
@@ -60,23 +66,49 @@ public async Task GivenTextToImage_WhenInference_ThenImageGenerated()
6066 Height = 512 ,
6167 InferenceSteps = 10 ,
6268 GuidanceScale = 7.0f ,
63- Seed = - 1
69+ Seed = 1
6470 } ;
6571
6672 var steps = 0 ;
6773
6874 //act
6975 var image = await _stableDiffusion . GenerateAsImageAsync ( model , prompt , scheduler , ( currentStep , totalSteps ) =>
7076 {
71- _testOutputHelper . WriteLine ( $ "Step { currentStep } /{ totalSteps } ") ;
77+ _logger . LogInformation ( $ "Step { currentStep } /{ totalSteps } ") ;
7278 steps ++ ;
7379 } ) ;
7480
81+ var imagesDirectory = Path . Combine ( Directory . GetCurrentDirectory ( ) , "images" ) ;
82+ if ( ! Directory . Exists ( imagesDirectory ) )
83+ {
84+ _logger . LogInformation ( $ "Creating directory { imagesDirectory } ") ;
85+ Directory . CreateDirectory ( imagesDirectory ) ;
86+ }
87+ else
88+ {
89+ _logger . LogInformation ( $ "Directory { imagesDirectory } already exists") ;
90+ }
91+
92+ var fileName = $ "{ imagesDirectory } /{ nameof ( GivenTextToImage_WhenInference_ThenImageGenerated ) } -{ DateTime . Now : yyyyMMddHHmmss} .png";
93+ _logger . LogInformation ( $ "Saving generated image to { fileName } ") ;
94+ await image . SaveAsPngAsync ( fileName ) ;
95+
7596 //assert
76- steps . Should ( ) . Be ( 10 ) ;
77- image . Should ( ) . NotBeNull ( ) ;
78- image . Size . IsEmpty . Should ( ) . BeFalse ( ) ;
79- image . Width . Should ( ) . Be ( 512 ) ;
80- image . Height . Should ( ) . Be ( 512 ) ;
97+ using ( new AssertionScope ( ) )
98+ {
99+ steps . Should ( ) . Be ( 10 ) ;
100+ image . Should ( ) . NotBeNull ( ) ;
101+ image . Size . IsEmpty . Should ( ) . BeFalse ( ) ;
102+ image . Width . Should ( ) . Be ( 512 ) ;
103+ image . Height . Should ( ) . Be ( 512 ) ;
104+
105+ File . Exists ( fileName ) . Should ( ) . BeTrue ( ) ;
106+ var md5 = MD5 . Create ( ) ;
107+ var hash = md5 . ComputeHash ( File . ReadAllBytes ( fileName ) ) ;
108+ var hashString = string . Join ( "" , hash . Select ( b => b . ToString ( "X2" ) ) ) ;
109+ _logger . LogInformation ( $ "MD5 Hash of generated image: { hashString } ") ;
110+
111+ hashString . Should ( ) . Be ( "E518D0E4F67CBD5E93513574D30F3FD7" ) ;
112+ }
81113 }
82114}
0 commit comments