@@ -66,8 +66,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
6666 {
6767 // Create random seed if none was set
6868 schedulerOptions . Seed = schedulerOptions . Seed > 0 ? schedulerOptions . Seed : Random . Shared . Next ( ) ;
69- Console . WriteLine ( $ "Scheduler: { promptOptions . SchedulerType } , Size: { schedulerOptions . Width } x{ schedulerOptions . Height } , Seed: { schedulerOptions . Seed } , Steps: { schedulerOptions . InferenceSteps } , Guidance: { schedulerOptions . GuidanceScale } ") ;
70-
69+
7170 // Get Scheduler
7271 using ( var scheduler = GetScheduler ( promptOptions , schedulerOptions ) )
7372 {
@@ -78,7 +77,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
7877 var timesteps = GetTimesteps ( promptOptions , schedulerOptions , scheduler ) ;
7978
8079 // Create latent sample
81- var latentSample = PrepareLatents ( promptOptions , schedulerOptions , scheduler , timesteps ) ;
80+ var latents = PrepareLatents ( promptOptions , schedulerOptions , scheduler , timesteps ) ;
8281
8382 // Loop though the timesteps
8483 var step = 0 ;
@@ -87,8 +86,9 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
8786 cancellationToken . ThrowIfCancellationRequested ( ) ;
8887
8988 // Create input tensor.
90- var inputTensor = scheduler . ScaleInput ( latentSample . Duplicate ( schedulerOptions . GetScaledDimension ( 2 ) ) , timestep ) ;
89+ var inputTensor = scheduler . ScaleInput ( latents . Duplicate ( schedulerOptions . GetScaledDimension ( 2 ) ) , timestep ) ;
9190
91+ // Create Input Parameters
9292 var inputNames = _onnxModelService . GetInputNames ( OnnxModelType . Unet ) ;
9393 var inputParameters = CreateInputParameters (
9494 NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , inputTensor ) ,
@@ -98,27 +98,24 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptO
9898 // Run Inference
9999 using ( var inferResult = await _onnxModelService . RunInferenceAsync ( OnnxModelType . Unet , inputParameters ) )
100100 {
101- var resultTensor = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
102-
103- // Split tensors from 2,4,(H/8),(W/8) to 1,4,(H/8),(W/8)
104- var splitTensors = resultTensor . SplitTensor ( schedulerOptions . GetScaledDimension ( ) , schedulerOptions . GetScaledHeight ( ) , schedulerOptions . GetScaledWidth ( ) ) ;
105- var noisePred = splitTensors . Item1 ;
106- var noisePredText = splitTensors . Item2 ;
101+ var noisePred = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
107102
108103 // Perform guidance
109- noisePred = noisePred . PerformGuidance ( noisePredText , schedulerOptions . GuidanceScale ) ;
104+ if ( schedulerOptions . GuidanceScale > 1.0f )
105+ {
106+ var ( noisePredUncond , noisePredText ) = noisePred . SplitTensor ( schedulerOptions . GetScaledDimension ( ) ) ;
107+ noisePred = noisePredUncond . PerformGuidance ( noisePredText , schedulerOptions . GuidanceScale ) ;
108+ }
110109
111- // LMS Scheduler Step
112- latentSample = scheduler . Step ( noisePred , timestep , latentSample ) ;
113- // ImageHelpers.TensorToImageDebug(latentSample, 64, $@"Examples\StableDebug\Latent_{step}.png");
110+ // Scheduler Step
111+ latents = scheduler . Step ( noisePred , timestep , latents ) ;
114112 }
115113
116- Console . WriteLine ( $ "Step: { ++ step } /{ timesteps . Count } ") ;
117- progress ? . Invoke ( step , timesteps . Count ) ;
114+ progress ? . Invoke ( ++ step , timesteps . Count ) ;
118115 }
119116
120117 // Decode Latents
121- return await DecodeLatents ( schedulerOptions , latentSample ) ;
118+ return await DecodeLatents ( schedulerOptions , latents ) ;
122119 }
123120 }
124121
@@ -192,7 +189,7 @@ protected static DenseTensor<float> ClipImageFeatureExtractor(SchedulerOptions o
192189 using ( var image = imageTensor . ToImage ( ) )
193190 {
194191 // Resize image
195- ImageHelpers . Resize ( image , 224 , 224 ) ;
192+ ImageHelpers . Resize ( image , new [ ] { 1 , 3 , 224 , 224 } ) ;
196193
197194 // Preprocess image
198195 var input = new DenseTensor < float > ( new [ ] { 1 , 3 , 224 , 224 } ) ;
0 commit comments