@@ -59,6 +59,9 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
5959 // Get Model metadata
6060 var metadata = _onnxModelService . GetModelMetadata ( modelOptions , OnnxModelType . Unet ) ;
6161
62+ // Get Time ids
63+ var addTimeIds = GetAddTimeIds ( modelOptions , schedulerOptions , performGuidance ) ;
64+
6265 // Loop though the timesteps
6366 var step = 0 ;
6467 foreach ( var timestep in timesteps )
@@ -71,7 +74,6 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
7174 var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
7275 var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
7376 var timestepTensor = CreateTimestepTensor ( timestep ) ;
74- var addTimeIds = GetAddTimeIds ( schedulerOptions , performGuidance ) ;
7577
7678 var outputChannels = performGuidance ? 2 : 1 ;
7779 var outputDimension = schedulerOptions . GetScaledDimension ( outputChannels ) ;
@@ -113,19 +115,27 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
113115 /// </summary>
114116 /// <param name="schedulerOptions">The scheduler options.</param>
115117 /// <returns></returns>
116- protected DenseTensor < float > GetAddTimeIds ( SchedulerOptions schedulerOptions , bool performGuidance )
118+ protected DenseTensor < float > GetAddTimeIds ( IModelOptions model , SchedulerOptions schedulerOptions , bool performGuidance )
117119 {
118- var addTimeIds = new float [ ]
120+ float [ ] result ;
121+ if ( model . ModelType == ModelType . Refiner )
119122 {
120- schedulerOptions . Height , schedulerOptions . Width , //original_size
121- 0 , 0 , //crops_coords_top_left
122- schedulerOptions . Height , schedulerOptions . Width //negative_target_size
123- } ;
124- var result = TensorHelper . CreateTensor ( addTimeIds , new [ ] { 1 , addTimeIds . Length } ) ;
125- if ( performGuidance )
126- return result . Repeat ( 2 ) ;
123+ //original_size + crops_coords_top_left + aesthetic_score
124+ //original_size + crops_coords_top_left + negative_aesthetic_score
125+ result = ! performGuidance
126+ ? new float [ ] { schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . AestheticScore }
127+ : new float [ ] { schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . AestheticNegativeScore , schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . AestheticScore } ;
128+ }
129+ else
130+ {
131+ //original_size + crops_coords_top_left + target_size
132+ //original_size + crops_coords_top_left + negative_target_size
133+ result = ! performGuidance
134+ ? new float [ ] { schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . Height , schedulerOptions . Width }
135+ : new float [ ] { schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . Height , schedulerOptions . Width , schedulerOptions . Height , schedulerOptions . Width , 0 , 0 , schedulerOptions . Height , schedulerOptions . Width } ;
136+ }
127137
128- return result ;
138+ return TensorHelper . CreateTensor ( result , new [ ] { 1 , result . Length } ) ;
129139 }
130140
131141
0 commit comments