 import tensorflow as tf
 
 
-def reconstruct_loss(x, gt, hparams, reuse=None):
-  pred = tf.layers.dense(x, hparams.vocab_size, name="softmax", reuse=reuse)
-  xent, w = common_layers.padded_cross_entropy(pred, gt, 0.0)
-  return xent / w
-
 
 def discriminator(x, compress, hparams, name, reuse=None):
   with tf.variable_scope(name, reuse=reuse):
-    x = tf.stop_gradient(2 * x) - x  # Reverse gradient.
+    x = tf.stop_gradient(2 * x) - x  # Reverse gradient.
     if compress:
-      x = transformer_vae.compress(x, None, hparams, "compress")
+      x = transformer_vae.compress(x, None, False, hparams, "compress")
     else:
-      x = transformer_vae.residual_conv(x, 1, hparams, "compress_rc")
+      x = transformer_vae.residual_conv(x, 1, 3, hparams, "compress_rc")
     y = tf.reduce_mean(x, axis=1)
     return tf.tanh(tf.layers.dense(y, 1, name="reduce"))
 
+def generator(x, hparams, name, reuse=False):
+  with tf.variable_scope(name, reuse=reuse):
+    return transformer_vae.residual_conv(x, 1, 3, hparams, "generator")
 
-def discriminate_loss(x, y, compress, hparams, name):
+
+def loss(real_input, fake_input, compress, hparams, lsgan, name):
+  eps = 1e-12
   with tf.variable_scope(name):
-    d1 = discriminator(x, compress, hparams, "discriminator")
-    d2 = discriminator(y, compress, hparams, "discriminator", reuse=True)
-    dloss = tf.reduce_mean(tf.abs(d1 - d2))
-    return -dloss
-
+    d1 = discriminator(real_input, compress, hparams, "discriminator")
+    d2 = discriminator(fake_input, compress, hparams, "discriminator", reuse=True)
+    if lsgan:
+      dloss = tf.reduce_mean(tf.squared_difference(d1, 0.9)) + tf.reduce_mean(tf.square(d2))
+      gloss = tf.reduce_mean(tf.squared_difference(d2, 0.9))
+      loss = (dloss + gloss) / 2
+    else:  # cross_entropy
+      dloss = -tf.reduce_mean(tf.log(d1 + eps)) - tf.reduce_mean(tf.log(1 - d2 + eps))
+      gloss = -tf.reduce_mean(tf.log(d2 + eps))
+      loss = (dloss + gloss) / 2
+    return loss
+
+
 
 def split_on_batch(x):
   batch_size = tf.shape(x)[0]
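The new loss helper added above folds the discriminator and generator objectives into one scalar, with a least-squares (LSGAN) branch and a cross-entropy branch. A minimal sketch of the same arithmetic, assuming hypothetical tensors d_real and d_fake for discriminator scores on real and generated batches (0.9 acts as a smoothed "real" label):

import tensorflow as tf

def gan_loss(d_real, d_fake, lsgan=True, eps=1e-12, smooth=0.9):
  """Averaged discriminator + generator loss, mirroring the branch above."""
  if lsgan:  # Least-squares GAN: push D(real) toward the smoothed label, D(fake) toward 0.
    dloss = (tf.reduce_mean(tf.squared_difference(d_real, smooth)) +
             tf.reduce_mean(tf.square(d_fake)))
    gloss = tf.reduce_mean(tf.squared_difference(d_fake, smooth))
  else:  # Cross-entropy GAN; eps keeps the logs finite.
    dloss = (-tf.reduce_mean(tf.log(d_real + eps)) -
             tf.reduce_mean(tf.log(1 - d_fake + eps)))
    gloss = -tf.reduce_mean(tf.log(d_fake + eps))
  return (dloss + gloss) / 2

Both call sites in cycle_gan_internal below pass lsgan=True, so only the squared-error branch is exercised in this commit.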
@@ -70,49 +78,39 @@ def cycle_gan_internal(inputs, targets, _, hparams):
     targets = common_layers.embedding(
         targets_orig, hparams.vocab_size, hparams.hidden_size,
         "embed", reuse=True)
-
-    # Split the batch into input-input and target-target parts.
-    inputs1, _ = split_on_batch(inputs)
-    _, targets2 = split_on_batch(targets)
-
-    # Define F and G, called inp2tgt and tgt2inp here.
-    def inp2tgt(x, reuse=False):
-      return transformer_vae.residual_conv(x, 1, hparams, "inp2tgt", reuse)
-    def tgt2inp(x, reuse=False):
-      return transformer_vae.residual_conv(x, 1, hparams, "tgt2inp", reuse)
-
-    # Input-input part.
-    inp1_tgt = inp2tgt(inputs1)
-    inp1_back = tgt2inp(inp1_tgt)
-
-    # Target-target part.
-    tgt2_inp = tgt2inp(targets2, reuse=True)
-    tgt2_back = inp2tgt(tgt2_inp, reuse=True)
-
-    # Reconstruction losses.
-    inp1_orig, _ = split_on_batch(inputs_orig)
-    _, tgt2_orig = split_on_batch(targets_orig)
-    inp1_loss = reconstruct_loss(
-        inp1_back, tf.squeeze(inp1_orig, axis=3), hparams)
-    tgt2_loss = reconstruct_loss(
-        tgt2_back, tf.squeeze(tgt2_orig, axis=3), hparams, reuse=True)
-
-    # Discriminator losses.
-    dloss1 = discriminate_loss(inputs1, tgt2_inp, True, hparams, "inp_disc")
-    dloss2 = discriminate_loss(targets2, inp1_tgt, True, hparams, "tgt_disc")
-
-    # Reconstruct targets from inputs.
-    tgt = inp2tgt(inputs, reuse=True)
-    tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax", reuse=True)
-
-    # We use the reconstruction only for tracking progress, no gradients here!
-    tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2))
-
-    losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss,
-              "target_target": hparams.cycle_loss_multiplier * tgt2_loss,
-              "input_disc": dloss1,
-              "target_disc": dloss2}
-    return tgt, losses
+
+    X, _ = split_on_batch(inputs)
+    _, Y = split_on_batch(targets)
+
+    X_unembeded, _ = split_on_batch(inputs_orig)
+    _, Y_unembeded = split_on_batch(targets_orig)
+
+
+    # Y --> X
+    Y_fake = generator(Y, hparams, "Fy", reuse=False)
+    YtoXloss = loss(X, Y_fake, True, hparams, True, "YtoX")
+
+    # X --> Y
+    X_fake = generator(X, hparams, "Gx", reuse=False)
+    XtoYloss = loss(Y, X_fake, True, hparams, True, "XtoY")
+
+    # Cycle-consistency.
+    Y_fake_ = generator(Y_fake, hparams, "Gx", reuse=True)
+    X_fake_ = generator(X_fake, hparams, "Fy", reuse=True)
+    XtoXloss = hparams.cycle_loss_multiplier1 * tf.reduce_mean(tf.abs(X_fake_ - X))
+    YtoYloss = hparams.cycle_loss_multiplier2 * tf.reduce_mean(tf.abs(Y_fake_ - Y))
+    cycloss = XtoXloss + YtoYloss
+
+
+    sample_generated = generator(inputs, hparams, "Gx", reuse=True)
+    sample_generated = tf.layers.dense(sample_generated, hparams.vocab_size, name="softmax", reuse=None)
+    sample_generated = tf.stop_gradient(tf.expand_dims(sample_generated, axis=2))
+
+    losses = {"cycloss": cycloss,
+              "YtoXloss": YtoXloss,
+              "XtoYloss": XtoYloss}
+
+    return sample_generated, losses
 
 
 @registry.register_model
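The hunk above replaces the old reconstruction losses with a standard CycleGAN setup: two generators ("Gx" for X -> Y, "Fy" for Y -> X), one adversarial loss per direction, and an L1 cycle-consistency penalty scaled by cycle_loss_multiplier1 and cycle_loss_multiplier2. A standalone sketch of that cycle term, assuming hypothetical callables g and f standing in for the two generator scopes:

import tensorflow as tf

def cycle_consistency_loss(x, y, g, f, mult_x=10.0, mult_y=10.0):
  """L1 cycle loss: mult_x * |F(G(x)) - x| + mult_y * |G(F(y)) - y|."""
  x_cycled = f(g(x))  # X -> Y -> X
  y_cycled = g(f(y))  # Y -> X -> Y
  return (mult_x * tf.reduce_mean(tf.abs(x_cycled - x)) +
          mult_y * tf.reduce_mean(tf.abs(y_cycled - y)))

The default multipliers here match the 10.0 values registered in cycle_gan_small below.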
@@ -134,7 +132,15 @@ def cycle_gan_small():
   hparams.weight_decay = 3.0
   hparams.learning_rate = 0.05
   hparams.kl_warmup_steps = 5000
+  #hparams.hidden_size = 8
   hparams.learning_rate_warmup_steps = 3000
-  hparams.add_hparam("vocab_size", 32)  # Vocabulary size, need to set here.
-  hparams.add_hparam("cycle_loss_multiplier", 2.0)
+  hparams.add_hparam("vocab_size", 66)  # Vocabulary size, need to set here.
+  hparams.add_hparam("cycle_loss_multiplier1", 10.0)
+  hparams.add_hparam("cycle_loss_multiplier2", 10.0)
   return hparams
+
+# Lines 43 - 80 - 82 are changed: residual network config.
+# Line 42 is changed: compress function.
+
+
+
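Because cycle_loss_multiplier1, cycle_loss_multiplier2, and vocab_size are ordinary hparams, variants can be registered without touching the model code. A hedged sketch, assuming the usual tensor2tensor registry.register_hparams pattern; the name cycle_gan_tiny and the override values are illustrative only:

@registry.register_hparams
def cycle_gan_tiny():
  """Hypothetical variant of cycle_gan_small with a weaker cycle penalty."""
  hparams = cycle_gan_small()
  hparams.hidden_size = 8  # The commented-out override above, made explicit.
  hparams.cycle_loss_multiplier1 = 5.0
  hparams.cycle_loss_multiplier2 = 5.0
  return hparams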