diff --git a/projector.py b/projector.py
index 36041a086..fe39afef3 100755
--- a/projector.py
+++ b/projector.py
@@ -25,6 +25,8 @@ def project(
     G,
     target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+    label,
+    label_dim,
     *,
     num_steps = 1000,
     w_avg_samples = 10000,
@@ -35,7 +37,7 @@ def project(
     noise_ramp_length = 0.75,
     regularize_noise_weight = 1e5,
     verbose = False,
-    device: torch.device
+    device: torch.device,
 ):
     assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
@@ -47,8 +49,15 @@ def logprint(*args):
     # Compute w stats.
     logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+
     z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
-    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
+
+    if label is not None and label_dim is not None:
+        onehot = np.zeros((w_avg_samples, label_dim), dtype=np.float32)
+        onehot[:, label] = 1
+        label = torch.Tensor(onehot).to(device)
+
+    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), label) # [N, L, C]
     w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
     w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
     w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
@@ -134,6 +143,8 @@ def logprint(*args):
 @click.command()
 @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--label', help='Class Label', type=int, default=None, required=False)
+@click.option('--label-dim', help='Label Dimension', type=int, default=None, required=False)
 @click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
 @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
 @click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@@ -141,6 +152,8 @@ def logprint(*args):
 @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
 def run_projection(
     network_pkl: str,
+    label: int,
+    label_dim: int,
     target_fname: str,
     outdir: str,
     save_video: bool,
@@ -177,6 +190,8 @@ def run_projection(
     projected_w_steps = project(
         G,
         target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
+        label=label,
+        label_dim=label_dim,
         num_steps=num_steps,
         device=device,
         verbose=True
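
For reference, below is a minimal standalone sketch of the conditioning logic this patch adds to project(): a one-hot [N, label_dim] tensor for the requested class, which is then passed to G.mapping() in place of None. The helper name make_onehot_labels and the example class/dimension values are illustrative only, not part of the patch; on the CLI side the same values come from the new --label and --label-dim options added to run_projection().

# Minimal sketch (assumed helper name and example values) of the one-hot
# label construction performed inside project() when label and label_dim
# are both provided.
import numpy as np
import torch

def make_onehot_labels(num_samples: int, label: int, label_dim: int, device: str = 'cpu') -> torch.Tensor:
    onehot = np.zeros((num_samples, label_dim), dtype=np.float32)  # [N, label_dim], all zeros
    onehot[:, label] = 1                                           # mark the requested class for every sample
    return torch.from_numpy(onehot).to(device)                     # float32 tensor suitable for G.mapping()

# e.g. conditioning the 10000 w_avg_samples on class 3 of a 10-class generator:
labels = make_onehot_labels(10000, label=3, label_dim=10)          # shape [10000, 10]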