We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f37a471 commit 4ae1515Copy full SHA for 4ae1515
latent_preview.py
@@ -37,12 +37,13 @@ def __init__(self, latent_rgb_factors):
37
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
38
39
def decode_latent_to_preview(self, x0):
40
- latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
+ self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
41
+ latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
42
43
latents_ubyte = (((latent_image + 1) / 2)
44
.clamp(0, 1) # change scale from -1..1 to 0..1
45
.mul(0xFF) # to 0..255
- .byte()).cpu()
46
+ ).to(device="cpu", dtype=torch.uint8, non_blocking=True)
47
48
return Image.fromarray(latents_ubyte.numpy())
49
0 commit comments