diff --git a/generate.py b/generate.py index f7f961931..9d5aa7dd2 100755 --- a/generate.py +++ b/generate.py @@ -19,6 +19,8 @@ import torch import legacy +import functools +import random #---------------------------------------------------------------------------- @@ -43,6 +45,8 @@ def num_range(s: str) -> List[int]: @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') +@click.option('--gpu', help='Use GPU to generate image (using CPU if not specified)', default='False', type=str, show_default=True) +@click.option('--seed-size', 'seed_size', help='The range at which the seed can be chosen from', default=500000, type=int, show_default=True) def generate_images( ctx: click.Context, network_pkl: str, @@ -51,7 +55,9 @@ def generate_images( noise_mode: str, outdir: str, class_idx: Optional[int], - projected_w: Optional[str] + projected_w: Optional[str], + gpu: Optional[str], + seed_size: Optional[str] ): """Generate images using pretrained network pickle. @@ -79,10 +85,19 @@ def generate_images( """ print('Loading networks from "%s"...' % network_pkl) - device = torch.device('cuda') + + if gpu == 'True': + device = torch.device('cuda') + else: + device = torch.device('cpu') + with dnnlib.util.open_url(network_pkl) as f: G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore + + if gpu != 'True': + G.forward = functools.partial(G.forward, force_fp32=True) + os.makedirs(outdir, exist_ok=True) # Synthesize the result of a W projection. @@ -100,7 +115,9 @@ def generate_images( return if seeds is None: - ctx.fail('--seeds option is required when not using --projected-w') + seeds = [] + for i in range(0, 5): + seeds.append(random.randint(0, seed_size)) # Labels. label = torch.zeros([1, G.c_dim], device=device) @@ -126,4 +143,4 @@ def generate_images( if __name__ == "__main__": generate_images() # pylint: disable=no-value-for-parameter -#---------------------------------------------------------------------------- +#---------------------------------------------------------------------------- \ No newline at end of file