Skip to content

Commit 4ae1515

Browse files
Slightly faster latent2rgb previews.
1 parent f37a471 commit 4ae1515

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

latent_preview.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ def __init__(self, latent_rgb_factors):
3737
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
3838

3939
def decode_latent_to_preview(self, x0):
40-
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
40+
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
41+
latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
4142

4243
latents_ubyte = (((latent_image + 1) / 2)
4344
.clamp(0, 1) # change scale from -1..1 to 0..1
4445
.mul(0xFF) # to 0..255
45-
.byte()).cpu()
46+
).to(device="cpu", dtype=torch.uint8, non_blocking=True)
4647

4748
return Image.fromarray(latents_ubyte.numpy())
4849

0 commit comments

Comments
 (0)