@@ -300,251 +300,183 @@ def img2img_transformer2d_base():
300300 hparams .layer_preprocess_sequence = "n"
301301 hparams .layer_postprocess_sequence = "da"
302302 # This version seems to benefit from a higher learning rate.
303- hparams .learning_rate = 0.1
303+ hparams .learning_rate = 0.2
304304 hparams .layer_prepostprocess_dropout = 0.1
305+ hparams .learning_rate_warmup_steps = 12000
306+ hparams .filter_size = 2048
307+ hparams .num_encoder_layers = 4
308+ hparams .num_decoder_layers = 8
305309 hparams .dec_attention_type = cia .AttentionType .LOCAL_2D
306310 hparams .block_rastor_scan = True
307311 return hparams
308312
309313
310314@registry .register_hparams
311- def img2img_transformer_base ():
312- """Base params for local1d attention."""
313- hparams = image_transformer2d_base ()
314- # learning related flags
315- hparams .layer_preprocess_sequence = "n"
316- hparams .layer_postprocess_sequence = "da"
317- # This version seems to benefit from a higher learning rate.
318- hparams .learning_rate = 0.1
319- hparams .layer_prepostprocess_dropout = 0.1
320- hparams .block_length = 256
321- hparams .block_width = 256
322- hparams .dec_attention_type = cia .AttentionType .LOCAL_1D
323- hparams .block_rastor_scan = False
324- return hparams
325-
326-
327- @registry .register_hparams
328- def imagetransformer2d_tiny ():
329- hparams = imagetransformer2d_base ()
330- hparams .num_decoder_layers = 2
331- hparams .hidden_size = 64
332- hparams .batch_size = 1
333- return hparams
334-
335-
336- @registry .register_hparams
337- def img2img_transformer2d_n3 ():
338- hparams = img2img_transformer2d_base ()
339- hparams .batch_size = 1
340- hparams .num_encoder_layers = 4
341- hparams .num_decoder_layers = 12
342- hparams .query_shape = (16 , 32 )
343- hparams .memory_flange = (16 , 16 )
344- hparams .layer_prepostprocess_dropout = 0.0
345- return hparams
346-
347-
348- @registry .register_hparams
349- def img2img_transformer2d_n31 ():
350- """Set of hyperparameters."""
351- hparams = img2img_transformer2d_base ()
352- hparams .batch_size = 1
353- hparams .num_encoder_layers = 6
354- hparams .num_decoder_layers = 12
355- hparams .num_heads = 8
356- hparams .query_shape = (16 , 32 )
357- hparams .memory_flange = (16 , 32 )
358- return hparams
359-
360-
361- @registry .register_hparams
362- def img2img_transformer2d_n32 ():
363- """Set of hyperparameters."""
364- hparams = img2img_transformer2d_base ()
365- hparams .batch_size = 1
366- hparams .num_heads = 16
367- hparams .num_encoder_layers = 4
368- hparams .num_decoder_layers = 12
369- hparams .query_shape = (16 , 32 )
370- hparams .memory_flange = (16 , 48 )
371- return hparams
372-
373-
374- @registry .register_hparams
375- def img2img_transformer2d_n4 ():
376- """Set of hyperparameters."""
315+ def img2img_transformer2d_q1 ():
377316 hparams = img2img_transformer2d_base ()
378- hparams .batch_size = 1
379- hparams .num_decoder_layers = 8
317+ hparams .batch_size = 2
318+ hparams .layer_preprocess_sequence = "none"
319+ hparams .layer_postprocess_sequence = "dan"
380320 hparams .query_shape = (16 , 16 )
381- hparams .memory_flange = (16 , 16 )
321+ hparams .memory_flange = (16 , 64 )
382322 return hparams
383323
384324
385325@registry .register_hparams
386- def img2img_transformer2d_n14 ():
387- hparams = img2img_transformer2d_base ()
388- hparams .batch_size = 1
389- hparams .hidden_size = 1024
390- hparams .filter_size = 2048
391- hparams .layer_prepostprocess_dropout = 0.2
392- hparams .num_decoder_layers = 8
326+ def img2img_transformer2d_q2 ():
327+ hparams = img2img_transformer2d_q1 ()
328+ hparams .batch_size = 2
329+ hparams .layer_preprocess_sequence = "none"
330+ hparams .layer_postprocess_sequence = "dan"
393331 hparams .query_shape = (16 , 16 )
394- hparams .memory_flange = (16 , 16 )
332+ hparams .memory_flange = (16 , 32 )
395333 return hparams
396334
397335
398336@registry .register_hparams
399- def img2img_transformer2d_n24 ():
400- hparams = img2img_transformer2d_base ()
401- hparams .batch_size = 1
402- hparams .hidden_size = 1024
403- hparams .filter_size = 2048
404- hparams .layer_prepostprocess_dropout = 0.2
405- hparams .num_decoder_layers = 8
337+ def img2img_transformer2d_q3 ():
338+ """Current best hparams for local 2d."""
339+ hparams = img2img_transformer2d_q1 ()
340+ hparams .batch_size = 2
406341 hparams .query_shape = (8 , 16 )
407342 hparams .memory_flange = (8 , 32 )
408343 return hparams
409344
410345
411346@registry .register_hparams
412- def img2img_transformer2d_n41 ():
413- hparams = img2img_transformer2d_n4 ()
414- hparams .num_decoder_layers = 10
415- hparams .num_encoder_layers = 6
416- hparams .query_shape = (16 , 16 )
417- hparams .memory_flange = (32 , 32 )
347+ def img2img_transformer_base ():
348+ """Base params for local1d attention."""
349+ hparams = image_transformer2d_base ()
350+ # learning related flags
351+ hparams .layer_preprocess_sequence = "n"
352+ hparams .layer_postprocess_sequence = "da"
353+ # This version seems to benefit from a higher learning rate.
354+ hparams .learning_rate = 0.2
355+ hparams .layer_prepostprocess_dropout = 0.1
356+ hparams .learning_rate_warmup_steps = 12000
357+ hparams .filter_size = 2048
358+ hparams .num_encoder_layers = 4
359+ hparams .num_decoder_layers = 8
360+ hparams .block_length = 256
361+ hparams .block_width = 256
362+ hparams .dec_attention_type = cia .AttentionType .LOCAL_1D
363+ hparams .block_rastor_scan = False
418364 return hparams
419365
420366
421367@registry .register_hparams
422- def img2img_transformer2d_n42 ():
423- hparams = img2img_transformer2d_n4 ()
424- hparams .num_decoder_layers = 12
425- hparams .num_encoder_layers = 6
426- hparams .query_shape = (16 , 16 )
427- hparams .memory_flange = (32 , 16 )
428- hparams .layer_prepostprocess_dropout = 0.1
368+ def img2img_transformer_b1 ():
369+ hparams = img2img_transformer_base ()
370+ hparams .batch_size = 2
371+ hparams .layer_preprocess_sequence = "none"
372+ hparams .layer_postprocess_sequence = "dan"
373+ hparams .block_length = 512
429374 return hparams
430375
431376
432377@registry .register_hparams
433- def img2img_transformer2d_n43 ():
434- hparams = img2img_transformer2d_n4 ()
435- hparams .num_decoder_layers = 12
436- hparams .num_encoder_layers = 6
437- hparams .query_shape = ( 8 , 16 )
438- hparams .memory_flange = ( 8 , 64 )
378+ def img2img_transformer_b2 ():
379+ hparams = img2img_transformer_base ()
380+ hparams .batch_size = 2
381+ hparams .layer_preprocess_sequence = "none"
382+ hparams .layer_postprocess_sequence = "dan"
383+ hparams .block_length = 256
439384 return hparams
440385
441386
442387@registry .register_hparams
443- def img2img_transformer2d_n44 ():
444- hparams = img2img_transformer2d_base ()
445- hparams .batch_size = 1
446- hparams .num_decoder_layers = 8
447- hparams .query_shape = (8 , 16 )
448- hparams .memory_flange = (8 , 32 )
449- hparams .layer_prepostprocess_dropout = 0.1
388+ def img2img_transformer_b3 ():
389+ """Current best hparams for local 1d."""
390+ hparams = img2img_transformer_base ()
391+ hparams .batch_size = 2
392+ hparams .layer_preprocess_sequence = "none"
393+ hparams .layer_postprocess_sequence = "dan"
394+ hparams .block_length = 128
395+ hparams .sampling_temp = 0.9
450396 return hparams
451397
452398
453399@registry .register_hparams
454- def img2img_transformer2d_n5 ():
455- hparams = img2img_transformer2d_base ()
456- hparams .batch_size = 1
400+ def img2img_transformer_dilated ():
401+ """Try dilated."""
402+ hparams = img2img_transformer_base ()
403+ hparams .add_hparam ("num_memory_blocks" , 1 )
457404 hparams .num_heads = 8
458- hparams .num_decoder_layers = 16
405+ hparams .attention_key_channels = hparams .attention_value_channels = 0
406+ hparams .hidden_size = 512
407+ hparams .filter_size = 2048
408+ hparams .num_decoder_layers = 8
409+ hparams .sampling_method = "random"
410+ hparams .gap_sizes = [0 , 16 , 64 , 0 , 16 , 64 , 128 , 0 ]
411+ hparams .dec_attention_type = cia .AttentionType .DILATED
412+ hparams .img_len = 64
413+ hparams .block_length = 128
414+ hparams .block_width = 128
459415 return hparams
460416
461417
462418@registry .register_hparams
463- def img2img_transformer2d_n6 ():
464- hparams = img2img_transformer2d_base ()
419+ def imagetransformer2d_tiny ():
420+ hparams = imagetransformer2d_base ()
421+ hparams .num_decoder_layers = 2
422+ hparams .hidden_size = 64
465423 hparams .batch_size = 1
466- hparams .learning_rate = 0.05
467- hparams .num_decoder_layers = 12
468- hparams .num_encoder_layers = 6
469- hparams .query_shape = (8 , 32 )
470- hparams .memory_flange = (8 , 32 )
471424 return hparams
472425
473426
474427@registry .register_hparams
475- def img2img_transformer2d_n7 ():
428+ def img2img_transformer2d_n3 ():
476429 hparams = img2img_transformer2d_base ()
477430 hparams .batch_size = 1
478431 hparams .num_encoder_layers = 4
479- hparams .num_decoder_layers = 8
432+ hparams .num_decoder_layers = 12
480433 hparams .query_shape = (16 , 32 )
481434 hparams .memory_flange = (16 , 16 )
482435 hparams .layer_prepostprocess_dropout = 0.0
483436 return hparams
484437
485438
486439@registry .register_hparams
487- def img2img_transformer2d_n8 ():
488- hparams = img2img_transformer2d_base ()
489- hparams .batch_size = 1
490- hparams .num_decoder_layers = 8
491- hparams .query_shape = (16 , 16 )
492- hparams .memory_flange = (16 , 16 )
493- return hparams
494-
495-
496- @registry .register_hparams
497- def img2img_transformer2d_n9 ():
440+ def img2img_transformer2d_n31 ():
441+ """Set of hyperparameters."""
498442 hparams = img2img_transformer2d_base ()
499443 hparams .batch_size = 1
500- hparams .num_heads = 8
444+ hparams .num_encoder_layers = 6
501445 hparams .num_decoder_layers = 12
502- hparams .filter_size = 2048
503- hparams .learning_rate = 0.05
446+ hparams .num_heads = 8
447+ hparams .query_shape = (16 , 32 )
448+ hparams .memory_flange = (16 , 32 )
504449 return hparams
505450
506451
507452@registry .register_hparams
508- def img2img_transformer2d_n10 ():
453+ def img2img_transformer2d_n24 ():
509454 hparams = img2img_transformer2d_base ()
510455 hparams .batch_size = 1
511- hparams .learning_rate = 0.05
512- hparams .num_decoder_layers = 12
513- hparams .num_encoder_layers = 6
514- hparams .query_shape = (8 , 32 )
456+ hparams .hidden_size = 1024
457+ hparams .filter_size = 2048
458+ hparams .layer_prepostprocess_dropout = 0.2
459+ hparams .num_decoder_layers = 8
460+ hparams .query_shape = (8 , 16 )
515461 hparams .memory_flange = (8 , 32 )
516- hparams .layer_prepostprocess_dropout = 0.0
517462 return hparams
518463
519464
520465@registry .register_hparams
521- def img2img_transformer2d_n101 ():
522- hparams = img2img_transformer2d_n10 ()
466+ def img2img_transformer2d_n44 ():
467+ hparams = img2img_transformer2d_base ()
523468 hparams .batch_size = 1
524- hparams .num_decoder_layers = 12
525- hparams .num_encoder_layers = 6
526- hparams .query_shape = (8 , 32 )
469+ hparams .num_decoder_layers = 8
470+ hparams .query_shape = (8 , 16 )
527471 hparams .memory_flange = (8 , 32 )
528472 hparams .layer_prepostprocess_dropout = 0.1
529473 return hparams
530474
531475
532- @registry .register_hparams
533- def img2img_transformer2d_n102 ():
534- hparams = img2img_transformer2d_n10 ()
535- hparams .batch_size = 1
536- hparams .num_decoder_layers = 12
537- hparams .num_encoder_layers = 4
538- hparams .query_shape = (8 , 32 )
539- hparams .memory_flange = (16 , 32 )
540- hparams .layer_prepostprocess_dropout = 0.1
541- return hparams
542-
543-
544476@registry .register_hparams
545477def img2img_transformer2d_n103 ():
546478 """Best config for img2img."""
547- hparams = img2img_transformer2d_n10 ()
479+ hparams = img2img_transformer2d_base ()
548480 hparams .batch_size = 1
549481 hparams .num_decoder_layers = 12
550482 hparams .num_encoder_layers = 6
@@ -570,24 +502,6 @@ def img2img_transformer2d_tiny():
570502 return hparams
571503
572504
573- @registry .register_hparams
574- def img2img_transformer_n3 ():
575- hparams = img2img_transformer_base ()
576- hparams .batch_size = 1
577- hparams .num_decoder_layers = 12
578- return hparams
579-
580-
581- @registry .register_hparams
582- def img2img_transformer_n4 ():
583- hparams = img2img_transformer_base ()
584- hparams .batch_size = 1
585- hparams .num_decoder_layers = 8
586- hparams .block_length = 256
587- hparams .block_width = 128
588- return hparams
589-
590-
591505@registry .register_hparams
592506def img2img_transformer_tiny ():
593507 """Tiny params."""