@@ -24,24 +24,27 @@ public class StableDiffusionTests
2424{
2525 private readonly IStableDiffusionService _stableDiffusion ;
2626 private readonly ILogger < StableDiffusionTests > _logger ;
27+ private const string StableDiffusionModel = "StableDiffusion 1.5" ;
28+ private const string LatentConsistencyModel = "LCM-Dreamshaper-V7" ;
2729
2830 public StableDiffusionTests ( ITestOutputHelper testOutputHelper )
2931 {
3032 var services = new ServiceCollection ( ) ;
31- services . AddLogging ( builder => builder . AddConsole ( ) ) ;
32- services . AddLogging ( builder => builder . AddXunit ( testOutputHelper ) ) ;
33- services . AddOnnxStack ( ) ;
33+ services . AddLogging ( builder => builder . AddConsole ( ) ) ; //necessary for showing logs when running in docker
34+ services . AddLogging ( builder => builder . AddXunit ( testOutputHelper ) ) ; //necessary for showing logs when running in IDE
3435 services . AddOnnxStackStableDiffusion ( ) ;
3536 var provider = services . BuildServiceProvider ( ) ;
3637 _stableDiffusion = provider . GetRequiredService < IStableDiffusionService > ( ) ;
3738 _logger = provider . GetRequiredService < ILogger < StableDiffusionTests > > ( ) ;
3839 }
3940
40- [ Fact ]
41- public async Task GivenStableDiffusion15_WhenLoadModel_ThenModelIsLoaded ( )
41+ [ Theory ]
42+ [ InlineData ( StableDiffusionModel ) ]
43+ [ InlineData ( LatentConsistencyModel ) ]
44+ public async Task GivenAStableDiffusionModel_WhenLoadModel_ThenModelIsLoaded ( string modelName )
4245 {
4346 //arrange
44- var model = _stableDiffusion . Models . Single ( m => m . Name == "StableDiffusion 1.5" ) ;
47+ var model = _stableDiffusion . Models . Single ( m => m . Name == modelName ) ;
4548
4649 //act
4750 _logger . LogInformation ( "Attempting to load model {0}" , model . Name ) ;
@@ -50,15 +53,19 @@ public async Task GivenStableDiffusion15_WhenLoadModel_ThenModelIsLoaded()
5053 //assert
5154 isModelLoaded . Should ( ) . BeTrue ( ) ;
5255 }
53-
54- [ Fact ]
55- public async Task GivenTextToImage_WhenInference_ThenImageGenerated ( )
56+
57+ [ Theory ]
58+ [ InlineData ( StableDiffusionModel , SchedulerType . EulerAncestral , 10 , 7.0f , "E518D0E4F67CBD5E93513574D30F3FD7" ) ]
59+ [ InlineData ( LatentConsistencyModel , SchedulerType . LCM , 4 , 1.0f , "3554E5E1B714D936805F4C9D890B0711" ) ]
60+ public async Task GivenTextToImage_WhenInference_ThenImageGenerated ( string modelName , SchedulerType schedulerType ,
61+ int inferenceSteps , float guidanceScale , string generatedImageMd5Hash )
62+
5663 {
5764 //arrange
58- var model = _stableDiffusion . Models . Single ( m => m . Name == "StableDiffusion 1.5" ) ;
59- _logger . LogInformation ( "Attempting to load model {0}" , model . Name ) ;
65+ var model = _stableDiffusion . Models . Single ( m => m . Name == modelName ) ;
66+ _logger . LogInformation ( "Attempting to load model: {0}" , model . Name ) ;
6067 await _stableDiffusion . LoadModel ( model ) ;
61-
68+
6269 var prompt = new PromptOptions
6370 {
6471 Prompt = "an astronaut riding a horse in space" ,
@@ -71,14 +78,14 @@ public async Task GivenTextToImage_WhenInference_ThenImageGenerated()
7178 {
7279 Width = 512 ,
7380 Height = 512 ,
74- SchedulerType = SchedulerType . EulerAncestral ,
75- InferenceSteps = 10 ,
76- GuidanceScale = 7.0f ,
81+ SchedulerType = schedulerType ,
82+ InferenceSteps = inferenceSteps ,
83+ GuidanceScale = guidanceScale ,
7784 Seed = 1
7885 } ;
7986
8087 var steps = 0 ;
81-
88+
8289 //act
8390 var image = await _stableDiffusion . GenerateAsImageAsync ( model , prompt , scheduler , ( currentStep , totalSteps ) =>
8491 {
@@ -97,26 +104,27 @@ public async Task GivenTextToImage_WhenInference_ThenImageGenerated()
97104 _logger . LogInformation ( $ "Directory { imagesDirectory } already exists") ;
98105 }
99106
100- var fileName = $ "{ imagesDirectory } /{ nameof ( GivenTextToImage_WhenInference_ThenImageGenerated ) } -{ DateTime . Now : yyyyMMddHHmmss} .png";
107+ var fileName =
108+ $ "{ imagesDirectory } /{ nameof ( GivenTextToImage_WhenInference_ThenImageGenerated ) } -{ DateTime . Now : yyyyMMddHHmmss} .png";
101109 _logger . LogInformation ( $ "Saving generated image to { fileName } ") ;
102110 await image . SaveAsPngAsync ( fileName ) ;
103111
104112 //assert
105113 using ( new AssertionScope ( ) )
106114 {
107- steps . Should ( ) . Be ( 10 ) ;
115+ steps . Should ( ) . Be ( inferenceSteps ) ;
108116 image . Should ( ) . NotBeNull ( ) ;
109117 image . Size . IsEmpty . Should ( ) . BeFalse ( ) ;
110118 image . Width . Should ( ) . Be ( 512 ) ;
111119 image . Height . Should ( ) . Be ( 512 ) ;
112-
120+
113121 File . Exists ( fileName ) . Should ( ) . BeTrue ( ) ;
114122 var md5 = MD5 . Create ( ) ;
115123 var hash = md5 . ComputeHash ( File . ReadAllBytes ( fileName ) ) ;
116124 var hashString = string . Join ( "" , hash . Select ( b => b . ToString ( "X2" ) ) ) ;
117125 _logger . LogInformation ( $ "MD5 Hash of generated image: { hashString } ") ;
118-
119- hashString . Should ( ) . Be ( "E518D0E4F67CBD5E93513574D30F3FD7" ) ;
126+
127+ hashString . Should ( ) . Be ( generatedImageMd5Hash ) ;
120128 }
121129 }
122130}
0 commit comments