@@ -113,7 +113,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
113113 }
114114
115115 // Decode Latents
116- return await DecodeLatents ( modelOptions , schedulerOptions , latents ) ;
116+ return await DecodeLatents ( modelOptions , promptOptions , schedulerOptions , latents ) ;
117117 }
118118 }
119119
@@ -123,26 +123,42 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
123123 /// <param name="options">The options.</param>
124124 /// <param name="latents">The latents.</param>
125125 /// <returns></returns>
126- protected async Task < DenseTensor < float > > DecodeLatents ( IModelOptions model , SchedulerOptions options , DenseTensor < float > latents )
126+ protected async Task < DenseTensor < float > > DecodeLatents ( IModelOptions model , PromptOptions prompt , SchedulerOptions options , DenseTensor < float > latents )
127127 {
128128 // Scale and decode the image latents with vae.
129129 latents = latents . MultiplyBy ( 1.0f / model . ScaleFactor ) ;
130130
131- var inputNames = _onnxModelService . GetInputNames ( model , OnnxModelType . VaeDecoder ) ;
132- var inputParameters = CreateInputParameters ( NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , latents ) ) ;
133-
134- // Run inference.
135- using ( var inferResult = await _onnxModelService . RunInferenceAsync ( model , OnnxModelType . VaeDecoder , inputParameters ) )
131+ var images = prompt . BatchCount > 1
132+ ? latents . Split ( prompt . BatchCount )
133+ : new [ ] { latents } ;
134+ var imageTensors = new List < DenseTensor < float > > ( ) ;
135+ foreach ( var image in images )
136136 {
137- var resultTensor = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
138- if ( await _onnxModelService . IsEnabledAsync ( model , OnnxModelType . SafetyChecker ) )
137+ var inputNames = _onnxModelService . GetInputNames ( model , OnnxModelType . VaeDecoder ) ;
138+ var inputParameters = CreateInputParameters ( NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , image ) ) ;
139+
140+ // Run inference.
141+ using ( var inferResult = await _onnxModelService . RunInferenceAsync ( model , OnnxModelType . VaeDecoder , inputParameters ) )
139142 {
140- // Check if image contains NSFW content,
141- if ( ! await IsImageSafe ( model , options , resultTensor ) )
142- return resultTensor . CloneEmpty ( ) . ToDenseTensor ( ) ; //TODO: blank image?, exception?, null?
143+ var resultTensor = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
144+ if ( await _onnxModelService . IsEnabledAsync ( model , OnnxModelType . SafetyChecker ) )
145+ {
146+ // Check if image contains NSFW content,
147+ if ( ! await IsImageSafe ( model , options , resultTensor ) )
148+ {
149+ //TODO: blank image?, exception?, null?
150+ imageTensors . Add ( resultTensor . CloneEmpty ( ) . ToDenseTensor ( ) ) ;
151+ continue ;
152+ }
153+ }
154+
155+ if ( prompt . BatchCount == 1 )
156+ return resultTensor . ToDenseTensor ( ) ;
157+
158+ imageTensors . Add ( resultTensor . ToDenseTensor ( ) ) ;
143159 }
144- return resultTensor . ToDenseTensor ( ) ;
145160 }
161+ return imageTensors . Join ( ) ;
146162 }
147163
148164
0 commit comments