Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8df2118

Browse files
committed
Reduce ApplyMaskedLatents loop
1 parent fa71a4e commit 8df2118

File tree

2 files changed

+16
-45
lines changed

2 files changed

+16
-45
lines changed

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/InpaintLegacyDiffuser.cs

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,11 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
207207
for (int y = 0; y < height; y++)
208208
{
209209
var pixelSpan = img.GetRowSpan(y);
210-
var value = pixelSpan[x].A / 255.0f;
211-
maskTensor[0, 0, y, x] = 1f - value;
212-
maskTensor[0, 1, y, x] = 0f; // Needed for shape only
213-
maskTensor[0, 2, y, x] = 0f; // Needed for shape only
214-
maskTensor[0, 3, y, x] = 0f; // Needed for shape only
210+
var value = 1f - (pixelSpan[x].A / 255.0f);
211+
maskTensor[0, 0, y, x] = value;
212+
maskTensor[0, 1, y, x] = value; // Needed for shape only
213+
maskTensor[0, 2, y, x] = value; // Needed for shape only
214+
maskTensor[0, 3, y, x] = value; // Needed for shape only
215215
}
216216
}
217217
});
@@ -231,24 +231,10 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
231231
private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
232232
{
233233
var result = new DenseTensor<float>(latents.Dimensions);
234-
for (int batch = 0; batch < latents.Dimensions[0]; batch++)
234+
for (int i = 0; i < result.Length; i++)
235235
{
236-
for (int channel = 0; channel < latents.Dimensions[1]; channel++)
237-
{
238-
for (int height = 0; height < latents.Dimensions[2]; height++)
239-
{
240-
for (int width = 0; width < latents.Dimensions[3]; width++)
241-
{
242-
float maskValue = mask[batch, 0, height, width];
243-
float latentsValue = latents[batch, channel, height, width];
244-
float initLatentsProperValue = initLatentsProper[batch, channel, height, width];
245-
246-
//Apply the logic to compute the result based on the mask
247-
float newValue = initLatentsProperValue * maskValue + latentsValue * (1f - maskValue);
248-
result[batch, channel, height, width] = newValue;
249-
}
250-
}
251-
}
236+
float maskValue = mask.GetValue(i);
237+
result.SetValue(i, initLatentsProper.GetValue(i) * maskValue + latents.GetValue(i) * (1f - maskValue));
252238
}
253239
return result;
254240
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,14 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
195195
for (int y = 0; y < height; y++)
196196
{
197197
var pixelSpan = img.GetRowSpan(y);
198-
var value = pixelSpan[x].A / 255.0f;
199-
maskTensor[0, 0, y, x] = 1f - value;
200-
maskTensor[0, 1, y, x] = 0f; // Needed for shape only
201-
maskTensor[0, 2, y, x] = 0f; // Needed for shape only
202-
maskTensor[0, 3, y, x] = 0f; // Needed for shape only
198+
var value = 1f - (pixelSpan[x].A / 255.0f);
199+
maskTensor[0, 0, y, x] = value;
200+
maskTensor[0, 1, y, x] = value; // Needed for shape only
201+
maskTensor[0, 2, y, x] = value; // Needed for shape only
202+
maskTensor[0, 3, y, x] = value; // Needed for shape only
203203
}
204204
}
205205
});
206-
207206
return maskTensor;
208207
}
209208
}
@@ -219,24 +218,10 @@ private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions
219218
private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
220219
{
221220
var result = new DenseTensor<float>(latents.Dimensions);
222-
for (int batch = 0; batch < latents.Dimensions[0]; batch++)
221+
for (int i = 0; i < result.Length; i++)
223222
{
224-
for (int channel = 0; channel < latents.Dimensions[1]; channel++)
225-
{
226-
for (int height = 0; height < latents.Dimensions[2]; height++)
227-
{
228-
for (int width = 0; width < latents.Dimensions[3]; width++)
229-
{
230-
float maskValue = mask[batch, 0, height, width];
231-
float latentsValue = latents[batch, channel, height, width];
232-
float initLatentsProperValue = initLatentsProper[batch, channel, height, width];
233-
234-
//Apply the logic to compute the result based on the mask
235-
float newValue = initLatentsProperValue * maskValue + latentsValue * (1f - maskValue);
236-
result[batch, channel, height, width] = newValue;
237-
}
238-
}
239-
}
223+
float maskValue = mask.GetValue(i);
224+
result.SetValue(i, initLatentsProper.GetValue(i) * maskValue + latents.GetValue(i) * (1f - maskValue));
240225
}
241226
return result;
242227
}

0 commit comments

Comments
 (0)