 from __future__ import print_function

 import tensorflow as tf
-import tensorflow.contrib.slim as slim
+import tf_slim as slim
 from core.SRGAN.ops import preprocessLR, preprocess, random_flip, conv2, batchnorm, prelu_tf, pixelShuffler, lrelu, \
     denselayer, vgg_19
 import collections
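Editor's note on the hunk above: `tensorflow.contrib` was removed in TF 2.x, and the standalone `tf_slim` package (pip install tf-slim) is the drop-in replacement for the slim API used here. A minimal, version-tolerant import sketch; the try/except fallback is illustrative and not part of this commit:

    # Illustrative only: prefer tf_slim, fall back to contrib on old TF 1.x.
    try:
        import tf_slim as slim  # pip install tf-slim
    except ImportError:
        import tensorflow.contrib.slim as slim  # TF 1.x only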
@@ -38,7 +38,7 @@ def data_loader(FLAGS):
     image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
     image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)

-    with tf.variable_scope('load_image'):
+    with tf.compat.v1.variable_scope('load_image'):
         # define the image list queue
         # image_list_LR_queue = tf.train.string_input_producer(
         #     image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity)
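`tf.compat.v1.variable_scope` keeps the TF1 scoping and variable-sharing semantics under TF 2.x. Graph-style code like this queue-based loader also needs TF2's eager execution switched off before the graph is built; a minimal sketch of the usual companion call (whether this commit adds it elsewhere is not visible in this diff):

    import tensorflow as tf

    # Restore TF1 graph-mode semantics so variable_scope and queues behave as before.
    tf.compat.v1.disable_v2_behavior()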
@@ -98,7 +98,7 @@ def data_loader(FLAGS):
     inputs = tf.identity(inputs)
     targets = tf.identity(targets)

-    with tf.variable_scope('random_flip'):
+    with tf.compat.v1.variable_scope('random_flip'):
         # Check for random flip:
         if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
             print('[Config] Use random flip')
@@ -228,7 +228,7 @@ def preprocess_test(name):
 def generator(gen_inputs, gen_output_channels, is_training, num_resblock, reuse=False):
     # The Bx residual blocks
     def residual_block(inputs, output_channel, stride, scope):
-        with tf.variable_scope(scope):
+        with tf.compat.v1.variable_scope(scope):
             net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
             net = batchnorm(net, is_training)
             net = prelu_tf(net)
@@ -238,9 +238,9 @@ def residual_block(inputs, output_channel, stride, scope):

         return net

-    with tf.variable_scope('generator_unit', reuse=reuse):
+    with tf.compat.v1.variable_scope('generator_unit', reuse=reuse):
         # The input layer
-        with tf.variable_scope('input_stage'):
+        with tf.compat.v1.variable_scope('input_stage'):
             net = conv2(gen_inputs, 9, 64, 1, scope='conv')
             net = prelu_tf(net)
@@ -251,23 +251,23 @@ def residual_block(inputs, output_channel, stride, scope):
             name_scope = 'resblock_%d' % (i)
             net = residual_block(net, 64, 1, name_scope)

-        with tf.variable_scope('resblock_output'):
+        with tf.compat.v1.variable_scope('resblock_output'):
             net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
             net = batchnorm(net, is_training)

         net = net + stage1_output

-        with tf.variable_scope('subpixelconv_stage1'):
+        with tf.compat.v1.variable_scope('subpixelconv_stage1'):
             net = conv2(net, 3, 256, 1, scope='conv')
             net = pixelShuffler(net, scale=2)
             net = prelu_tf(net)

-        with tf.variable_scope('subpixelconv_stage2'):
+        with tf.compat.v1.variable_scope('subpixelconv_stage2'):
             net = conv2(net, 3, 256, 1, scope='conv')
             net = pixelShuffler(net, scale=2)
             net = prelu_tf(net)

-        with tf.variable_scope('output_stage'):
+        with tf.compat.v1.variable_scope('output_stage'):
             net = conv2(net, 9, gen_output_channels, 1, scope='conv')

     return net
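The renamed scopes preserve the TF1 weight-sharing contract: re-entering `generator_unit` with `reuse=True` returns the existing variables instead of creating new ones. A hedged usage sketch against the signature above; the input tensors and resblock count are hypothetical, not from this repo:

    # Hypothetical LR batches; shapes follow the 4x upscaling model above.
    train_out = generator(train_lr_batch, 3, is_training=True, num_resblock=16, reuse=False)
    # Second call shares all generator weights with the first.
    eval_out = generator(eval_lr_batch, 3, is_training=False, num_resblock=16, reuse=True)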
@@ -280,17 +280,17 @@ def discriminator(dis_inputs, FLAGS=None):

     # Define the discriminator block
     def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
-        with tf.variable_scope(scope):
+        with tf.compat.v1.variable_scope(scope):
             net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1')
             net = batchnorm(net, FLAGS.is_training)
             net = lrelu(net, 0.2)

         return net

     with tf.device('/gpu:0'):
-        with tf.variable_scope('discriminator_unit'):
+        with tf.compat.v1.variable_scope('discriminator_unit'):
             # The input layer
-            with tf.variable_scope('input_stage'):
+            with tf.compat.v1.variable_scope('input_stage'):
                 net = conv2(dis_inputs, 3, 64, 1, scope='conv')
                 net = lrelu(net, 0.2)
@@ -317,13 +317,13 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
             net = discriminator_block(net, 512, 3, 2, 'disblock_7')

             # The dense layer 1
-            with tf.variable_scope('dense_layer_1'):
+            with tf.compat.v1.variable_scope('dense_layer_1'):
                 net = slim.flatten(net)
                 net = denselayer(net, 1024)
                 net = lrelu(net, 0.2)

             # The dense layer 2
-            with tf.variable_scope('dense_layer_2'):
+            with tf.compat.v1.variable_scope('dense_layer_2'):
                 net = denselayer(net, 1)
                 net = tf.nn.sigmoid(net)
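`slim.flatten` on the unchanged line above is one of the calls that forces the `tf_slim` import swap; its API is unchanged from `tf.contrib.slim`. A small self-contained check, assuming tf-slim is installed (shapes are illustrative):

    import tensorflow as tf
    import tf_slim as slim

    tf.compat.v1.disable_eager_execution()
    x = tf.compat.v1.placeholder(tf.float32, [None, 4, 4, 512])
    flat = slim.flatten(x)  # -> shape (None, 8192), same behavior as contrib.slim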
@@ -352,19 +352,19 @@ def SRGAN(inputs, targets, FLAGS):
                         learning_rate')

     # Build the generator part
-    with tf.variable_scope('generator'):
+    with tf.compat.v1.variable_scope('generator'):
         output_channel = targets.get_shape().as_list()[-1]
         gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
         gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])

     # Build the fake discriminator
     with tf.name_scope('fake_discriminator'):
-        with tf.variable_scope('discriminator', reuse=False):
+        with tf.compat.v1.variable_scope('discriminator', reuse=False):
             discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)

     # Build the real discriminator
     with tf.name_scope('real_discriminator'):
-        with tf.variable_scope('discriminator', reuse=True):
+        with tf.compat.v1.variable_scope('discriminator', reuse=True):
             discrim_real_output = discriminator(targets, FLAGS=FLAGS)

     # Use the VGG54 feature
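The `reuse=False` / `reuse=True` pair above is the standard TF1 pattern for running one discriminator on both fake and real batches with a single set of weights, and `tf.compat.v1.variable_scope` preserves it. A minimal standalone sketch of the mechanism (variable name and shape are illustrative):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()
    with tf.compat.v1.variable_scope('discriminator', reuse=False):
        w = tf.compat.v1.get_variable('conv_w', shape=[3, 3, 3, 64])
    with tf.compat.v1.variable_scope('discriminator', reuse=True):
        w_shared = tf.compat.v1.get_variable('conv_w', shape=[3, 3, 3, 64])
    assert w is w_shared  # same underlying variable, no duplicate weights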
@@ -390,47 +390,47 @@ def SRGAN(inputs, targets, FLAGS):
         raise NotImplementedError('Unknown perceptual type!!')

     # Calculating the generator loss
-    with tf.variable_scope('generator_loss'):
+    with tf.compat.v1.variable_scope('generator_loss'):
         # Content loss
-        with tf.variable_scope('content_loss'):
+        with tf.compat.v1.variable_scope('content_loss'):
             # Compute the euclidean distance between the two features
             diff = extracted_feature_gen - extracted_feature_target
             if FLAGS.perceptual_mode == 'MSE':
                 content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
             else:
                 content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))

-        with tf.variable_scope('adversarial_loss'):
+        with tf.compat.v1.variable_scope('adversarial_loss'):
             adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))

         gen_loss = content_loss + (FLAGS.ratio) * adversarial_loss
         print(adversarial_loss.get_shape())
         print(content_loss.get_shape())

     # Calculating the discriminator loss
-    with tf.variable_scope('discriminator_loss'):
+    with tf.compat.v1.variable_scope('discriminator_loss'):
         discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
         discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)

         discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))

     # Define the learning rate and global step
-    with tf.variable_scope('get_learning_rate_and_global_step'):
-        global_step = tf.contrib.framework.get_or_create_global_step()
+    with tf.compat.v1.variable_scope('get_learning_rate_and_global_step'):
+        global_step = tf.compat.v1.train.get_or_create_global_step()
         learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
                                                    staircase=FLAGS.stair)
         incr_global_step = tf.assign(global_step, global_step + 1)

-    with tf.variable_scope('dicriminator_train'):
-        discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
+    with tf.compat.v1.variable_scope('dicriminator_train'):
+        discrim_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
         discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
         discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
         discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)

-    with tf.variable_scope('generator_train'):
+    with tf.compat.v1.variable_scope('generator_train'):
         # Need to wait discriminator to perform train step
-        with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
-            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+        with tf.control_dependencies([discrim_train] + tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
+            gen_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
             gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
             gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
             gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
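Note that several context lines in this hunk still call symbols that TF 2.x removed from the top-level namespace (`tf.log`, `tf.train.exponential_decay`, `tf.assign`, `tf.train.AdamOptimizer`); they run unchanged only on TF 1.x, so a full TF2 migration would move them under `compat.v1` as well. A hedged sketch of those equivalents, with illustrative hyperparameters:

    import tensorflow as tf

    tf.compat.v1.disable_v2_behavior()
    step = tf.compat.v1.train.get_or_create_global_step()
    lr = tf.compat.v1.train.exponential_decay(1e-4, step, 100000, 0.9, staircase=True)
    opt = tf.compat.v1.train.AdamOptimizer(lr, beta1=0.9)
    incr_step = tf.compat.v1.assign(step, step + 1)
    safe_log = tf.compat.v1.log(tf.constant(0.5) + 1e-12)  # compat home of tf.log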
@@ -460,7 +460,7 @@ def SRResnet(inputs, targets, FLAGS):
                         learning_rate')

     # Build the generator part
-    with tf.variable_scope('generator'):
+    with tf.compat.v1.variable_scope('generator'):
         output_channel = targets.get_shape().as_list()[-1]
         gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
         gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
@@ -486,9 +486,9 @@ def SRResnet(inputs, targets, FLAGS):
         raise NotImplementedError('Unknown perceptual type')

     # Calculating the generator loss
-    with tf.variable_scope('generator_loss'):
+    with tf.compat.v1.variable_scope('generator_loss'):
         # Content loss
-        with tf.variable_scope('content_loss'):
+        with tf.compat.v1.variable_scope('content_loss'):
             # Compute the euclidean distance between the two features
             # check=tf.equal(extracted_feature_gen, extracted_feature_target)
             diff = extracted_feature_gen - extracted_feature_target
@@ -500,16 +500,16 @@ def SRResnet(inputs, targets, FLAGS):
         gen_loss = content_loss

     # Define the learning rate and global step
-    with tf.variable_scope('get_learning_rate_and_global_step'):
-        global_step = tf.contrib.framework.get_or_create_global_step()
+    with tf.compat.v1.variable_scope('get_learning_rate_and_global_step'):
+        global_step = tf.compat.v1.train.get_or_create_global_step()
         learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
                                                    staircase=FLAGS.stair)
         incr_global_step = tf.assign(global_step, global_step + 1)

-    with tf.variable_scope('generator_train'):
+    with tf.compat.v1.variable_scope('generator_train'):
         # Need to wait discriminator to perform train step
-        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
-            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
+        with tf.control_dependencies(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
+            gen_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
             gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
             gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
             gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
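The `generator_train` block relies on the classic TF1 recipe: gate the optimizer behind the `UPDATE_OPS` collection so batch-norm moving statistics are refreshed on every training step, which the migrated `tf.compat.v1.get_collection` call preserves. A minimal self-contained sketch of that pattern under compat.v1 (layer and sizes are illustrative, not from this repo):

    import tensorflow as tf

    tf.compat.v1.disable_v2_behavior()
    x = tf.compat.v1.placeholder(tf.float32, [None, 8])
    net = tf.compat.v1.layers.batch_normalization(x, training=True)
    loss = tf.reduce_mean(tf.square(net))
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):  # run BN updates before the step
        train_op = tf.compat.v1.train.AdamOptimizer(1e-4).minimize(loss)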