Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c2ae3dc

Browse files
Niki Parmar and Ryan Sepassi
authored and committed
Add hparams for img2img transformer for current best and dilated
PiperOrigin-RevId: 186061724
1 parent d6b80dd commit c2ae3dc

File tree

1 file changed

+94
-180
lines changed

1 file changed

+94
-180
lines changed

tensor2tensor/models/image_transformer_2d.py

Lines changed: 94 additions & 180 deletions
Original file line number | Diff line number | Diff line change
@@ -300,251 +300,183 @@ def img2img_transformer2d_base():
300300
hparams.layer_preprocess_sequence = "n"
301301
hparams.layer_postprocess_sequence = "da"
302302
# This version seems to benefit from a higher learning rate.
303-
hparams.learning_rate = 0.1
303+
hparams.learning_rate = 0.2
304304
hparams.layer_prepostprocess_dropout = 0.1
305+
hparams.learning_rate_warmup_steps = 12000
306+
hparams.filter_size = 2048
307+
hparams.num_encoder_layers = 4
308+
hparams.num_decoder_layers = 8
305309
hparams.dec_attention_type = cia.AttentionType.LOCAL_2D
306310
hparams.block_rastor_scan = True
307311
return hparams
308312

309313

310314
@registry.register_hparams
311-
def img2img_transformer_base():
312-
"""Base params for local1d attention."""
313-
hparams = image_transformer2d_base()
314-
# learning related flags
315-
hparams.layer_preprocess_sequence = "n"
316-
hparams.layer_postprocess_sequence = "da"
317-
# This version seems to benefit from a higher learning rate.
318-
hparams.learning_rate = 0.1
319-
hparams.layer_prepostprocess_dropout = 0.1
320-
hparams.block_length = 256
321-
hparams.block_width = 256
322-
hparams.dec_attention_type = cia.AttentionType.LOCAL_1D
323-
hparams.block_rastor_scan = False
324-
return hparams
325-
326-
327-
@registry.register_hparams
328-
def imagetransformer2d_tiny():
329-
hparams = imagetransformer2d_base()
330-
hparams.num_decoder_layers = 2
331-
hparams.hidden_size = 64
332-
hparams.batch_size = 1
333-
return hparams
334-
335-
336-
@registry.register_hparams
337-
def img2img_transformer2d_n3():
338-
hparams = img2img_transformer2d_base()
339-
hparams.batch_size = 1
340-
hparams.num_encoder_layers = 4
341-
hparams.num_decoder_layers = 12
342-
hparams.query_shape = (16, 32)
343-
hparams.memory_flange = (16, 16)
344-
hparams.layer_prepostprocess_dropout = 0.0
345-
return hparams
346-
347-
348-
@registry.register_hparams
349-
def img2img_transformer2d_n31():
350-
"""Set of hyperparameters."""
351-
hparams = img2img_transformer2d_base()
352-
hparams.batch_size = 1
353-
hparams.num_encoder_layers = 6
354-
hparams.num_decoder_layers = 12
355-
hparams.num_heads = 8
356-
hparams.query_shape = (16, 32)
357-
hparams.memory_flange = (16, 32)
358-
return hparams
359-
360-
361-
@registry.register_hparams
362-
def img2img_transformer2d_n32():
363-
"""Set of hyperparameters."""
364-
hparams = img2img_transformer2d_base()
365-
hparams.batch_size = 1
366-
hparams.num_heads = 16
367-
hparams.num_encoder_layers = 4
368-
hparams.num_decoder_layers = 12
369-
hparams.query_shape = (16, 32)
370-
hparams.memory_flange = (16, 48)
371-
return hparams
372-
373-
374-
@registry.register_hparams
375-
def img2img_transformer2d_n4():
376-
"""Set of hyperparameters."""
315+
def img2img_transformer2d_q1():
377316
hparams = img2img_transformer2d_base()
378-
hparams.batch_size = 1
379-
hparams.num_decoder_layers = 8
317+
hparams.batch_size = 2
318+
hparams.layer_preprocess_sequence = "none"
319+
hparams.layer_postprocess_sequence = "dan"
380320
hparams.query_shape = (16, 16)
381-
hparams.memory_flange = (16, 16)
321+
hparams.memory_flange = (16, 64)
382322
return hparams
383323

384324

385325
@registry.register_hparams
386-
def img2img_transformer2d_n14():
387-
hparams = img2img_transformer2d_base()
388-
hparams.batch_size = 1
389-
hparams.hidden_size = 1024
390-
hparams.filter_size = 2048
391-
hparams.layer_prepostprocess_dropout = 0.2
392-
hparams.num_decoder_layers = 8
326+
def img2img_transformer2d_q2():
327+
hparams = img2img_transformer2d_q1()
328+
hparams.batch_size = 2
329+
hparams.layer_preprocess_sequence = "none"
330+
hparams.layer_postprocess_sequence = "dan"
393331
hparams.query_shape = (16, 16)
394-
hparams.memory_flange = (16, 16)
332+
hparams.memory_flange = (16, 32)
395333
return hparams
396334

397335

398336
@registry.register_hparams
399-
def img2img_transformer2d_n24():
400-
hparams = img2img_transformer2d_base()
401-
hparams.batch_size = 1
402-
hparams.hidden_size = 1024
403-
hparams.filter_size = 2048
404-
hparams.layer_prepostprocess_dropout = 0.2
405-
hparams.num_decoder_layers = 8
337+
def img2img_transformer2d_q3():
338+
"""Current best hparams for local 2d."""
339+
hparams = img2img_transformer2d_q1()
340+
hparams.batch_size = 2
406341
hparams.query_shape = (8, 16)
407342
hparams.memory_flange = (8, 32)
408343
return hparams
409344

410345

411346
@registry.register_hparams
412-
def img2img_transformer2d_n41():
413-
hparams = img2img_transformer2d_n4()
414-
hparams.num_decoder_layers = 10
415-
hparams.num_encoder_layers = 6
416-
hparams.query_shape = (16, 16)
417-
hparams.memory_flange = (32, 32)
347+
def img2img_transformer_base():
348+
"""Base params for local1d attention."""
349+
hparams = image_transformer2d_base()
350+
# learning related flags
351+
hparams.layer_preprocess_sequence = "n"
352+
hparams.layer_postprocess_sequence = "da"
353+
# This version seems to benefit from a higher learning rate.
354+
hparams.learning_rate = 0.2
355+
hparams.layer_prepostprocess_dropout = 0.1
356+
hparams.learning_rate_warmup_steps = 12000
357+
hparams.filter_size = 2048
358+
hparams.num_encoder_layers = 4
359+
hparams.num_decoder_layers = 8
360+
hparams.block_length = 256
361+
hparams.block_width = 256
362+
hparams.dec_attention_type = cia.AttentionType.LOCAL_1D
363+
hparams.block_rastor_scan = False
418364
return hparams
419365

420366

421367
@registry.register_hparams
422-
def img2img_transformer2d_n42():
423-
hparams = img2img_transformer2d_n4()
424-
hparams.num_decoder_layers = 12
425-
hparams.num_encoder_layers = 6
426-
hparams.query_shape = (16, 16)
427-
hparams.memory_flange = (32, 16)
428-
hparams.layer_prepostprocess_dropout = 0.1
368+
def img2img_transformer_b1():
369+
hparams = img2img_transformer_base()
370+
hparams.batch_size = 2
371+
hparams.layer_preprocess_sequence = "none"
372+
hparams.layer_postprocess_sequence = "dan"
373+
hparams.block_length = 512
429374
return hparams
430375

431376

432377
@registry.register_hparams
433-
def img2img_transformer2d_n43():
434-
hparams = img2img_transformer2d_n4()
435-
hparams.num_decoder_layers = 12
436-
hparams.num_encoder_layers = 6
437-
hparams.query_shape = (8, 16)
438-
hparams.memory_flange = (8, 64)
378+
def img2img_transformer_b2():
379+
hparams = img2img_transformer_base()
380+
hparams.batch_size = 2
381+
hparams.layer_preprocess_sequence = "none"
382+
hparams.layer_postprocess_sequence = "dan"
383+
hparams.block_length = 256
439384
return hparams
440385

441386

442387
@registry.register_hparams
443-
def img2img_transformer2d_n44():
444-
hparams = img2img_transformer2d_base()
445-
hparams.batch_size = 1
446-
hparams.num_decoder_layers = 8
447-
hparams.query_shape = (8, 16)
448-
hparams.memory_flange = (8, 32)
449-
hparams.layer_prepostprocess_dropout = 0.1
388+
def img2img_transformer_b3():
389+
"""Current best hparams for local 1d."""
390+
hparams = img2img_transformer_base()
391+
hparams.batch_size = 2
392+
hparams.layer_preprocess_sequence = "none"
393+
hparams.layer_postprocess_sequence = "dan"
394+
hparams.block_length = 128
395+
hparams.sampling_temp = 0.9
450396
return hparams
451397

452398

453399
@registry.register_hparams
454-
def img2img_transformer2d_n5():
455-
hparams = img2img_transformer2d_base()
456-
hparams.batch_size = 1
400+
def img2img_transformer_dilated():
401+
"""Try dilated."""
402+
hparams = img2img_transformer_base()
403+
hparams.add_hparam("num_memory_blocks", 1)
457404
hparams.num_heads = 8
458-
hparams.num_decoder_layers = 16
405+
hparams.attention_key_channels = hparams.attention_value_channels = 0
406+
hparams.hidden_size = 512
407+
hparams.filter_size = 2048
408+
hparams.num_decoder_layers = 8
409+
hparams.sampling_method = "random"
410+
hparams.gap_sizes = [0, 16, 64, 0, 16, 64, 128, 0]
411+
hparams.dec_attention_type = cia.AttentionType.DILATED
412+
hparams.img_len = 64
413+
hparams.block_length = 128
414+
hparams.block_width = 128
459415
return hparams
460416

461417

462418
@registry.register_hparams
463-
def img2img_transformer2d_n6():
464-
hparams = img2img_transformer2d_base()
419+
def imagetransformer2d_tiny():
420+
hparams = imagetransformer2d_base()
421+
hparams.num_decoder_layers = 2
422+
hparams.hidden_size = 64
465423
hparams.batch_size = 1
466-
hparams.learning_rate = 0.05
467-
hparams.num_decoder_layers = 12
468-
hparams.num_encoder_layers = 6
469-
hparams.query_shape = (8, 32)
470-
hparams.memory_flange = (8, 32)
471424
return hparams
472425

473426

474427
@registry.register_hparams
475-
def img2img_transformer2d_n7():
428+
def img2img_transformer2d_n3():
476429
hparams = img2img_transformer2d_base()
477430
hparams.batch_size = 1
478431
hparams.num_encoder_layers = 4
479-
hparams.num_decoder_layers = 8
432+
hparams.num_decoder_layers = 12
480433
hparams.query_shape = (16, 32)
481434
hparams.memory_flange = (16, 16)
482435
hparams.layer_prepostprocess_dropout = 0.0
483436
return hparams
484437

485438

486439
@registry.register_hparams
487-
def img2img_transformer2d_n8():
488-
hparams = img2img_transformer2d_base()
489-
hparams.batch_size = 1
490-
hparams.num_decoder_layers = 8
491-
hparams.query_shape = (16, 16)
492-
hparams.memory_flange = (16, 16)
493-
return hparams
494-
495-
496-
@registry.register_hparams
497-
def img2img_transformer2d_n9():
440+
def img2img_transformer2d_n31():
441+
"""Set of hyperparameters."""
498442
hparams = img2img_transformer2d_base()
499443
hparams.batch_size = 1
500-
hparams.num_heads = 8
444+
hparams.num_encoder_layers = 6
501445
hparams.num_decoder_layers = 12
502-
hparams.filter_size = 2048
503-
hparams.learning_rate = 0.05
446+
hparams.num_heads = 8
447+
hparams.query_shape = (16, 32)
448+
hparams.memory_flange = (16, 32)
504449
return hparams
505450

506451

507452
@registry.register_hparams
508-
def img2img_transformer2d_n10():
453+
def img2img_transformer2d_n24():
509454
hparams = img2img_transformer2d_base()
510455
hparams.batch_size = 1
511-
hparams.learning_rate = 0.05
512-
hparams.num_decoder_layers = 12
513-
hparams.num_encoder_layers = 6
514-
hparams.query_shape = (8, 32)
456+
hparams.hidden_size = 1024
457+
hparams.filter_size = 2048
458+
hparams.layer_prepostprocess_dropout = 0.2
459+
hparams.num_decoder_layers = 8
460+
hparams.query_shape = (8, 16)
515461
hparams.memory_flange = (8, 32)
516-
hparams.layer_prepostprocess_dropout = 0.0
517462
return hparams
518463

519464

520465
@registry.register_hparams
521-
def img2img_transformer2d_n101():
522-
hparams = img2img_transformer2d_n10()
466+
def img2img_transformer2d_n44():
467+
hparams = img2img_transformer2d_base()
523468
hparams.batch_size = 1
524-
hparams.num_decoder_layers = 12
525-
hparams.num_encoder_layers = 6
526-
hparams.query_shape = (8, 32)
469+
hparams.num_decoder_layers = 8
470+
hparams.query_shape = (8, 16)
527471
hparams.memory_flange = (8, 32)
528472
hparams.layer_prepostprocess_dropout = 0.1
529473
return hparams
530474

531475

532-
@registry.register_hparams
533-
def img2img_transformer2d_n102():
534-
hparams = img2img_transformer2d_n10()
535-
hparams.batch_size = 1
536-
hparams.num_decoder_layers = 12
537-
hparams.num_encoder_layers = 4
538-
hparams.query_shape = (8, 32)
539-
hparams.memory_flange = (16, 32)
540-
hparams.layer_prepostprocess_dropout = 0.1
541-
return hparams
542-
543-
544476
@registry.register_hparams
545477
def img2img_transformer2d_n103():
546478
"""Best config for img2img."""
547-
hparams = img2img_transformer2d_n10()
479+
hparams = img2img_transformer2d_base()
548480
hparams.batch_size = 1
549481
hparams.num_decoder_layers = 12
550482
hparams.num_encoder_layers = 6
@@ -570,24 +502,6 @@ def img2img_transformer2d_tiny():
570502
return hparams
571503

572504

573-
@registry.register_hparams
574-
def img2img_transformer_n3():
575-
hparams = img2img_transformer_base()
576-
hparams.batch_size = 1
577-
hparams.num_decoder_layers = 12
578-
return hparams
579-
580-
581-
@registry.register_hparams
582-
def img2img_transformer_n4():
583-
hparams = img2img_transformer_base()
584-
hparams.batch_size = 1
585-
hparams.num_decoder_layers = 8
586-
hparams.block_length = 256
587-
hparams.block_width = 128
588-
return hparams
589-
590-
591505
@registry.register_hparams
592506
def img2img_transformer_tiny():
593507
"""Tiny params."""

0 commit comments

Comments (0)