Skip to content

Commit b1f179c

Browse files
authored
Merge pull request #48 from IBM/tensorflow-2
Upgrade tensorflow to v2 and MAX-Base to v1.5.1
2 parents edc2a99 + 368d511 commit b1f179c

File tree

9 files changed

+79
-75
lines changed

9 files changed

+79
-75
lines changed

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
2+
# Copyright 2018-2021 IBM Corp. All Rights Reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
FROM quay.io/codait/max-base:v1.4.0
17+
FROM quay.io/codait/max-base:v1.5.1
1818

1919
ARG model_bucket=https://max-cdn.cdn.appdomain.cloud/max-image-resolution-enhancer/1.0.0
2020
ARG model_file=assets.tar.gz

api/predict.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from flask import send_file
1818
from core.model import ModelWrapper
1919
from maxfw.core import MAX_API, PredictAPI
20-
from flask_restplus import abort
20+
from flask_restx import abort
2121
from werkzeug.datastructures import FileStorage
2222

23-
# Set up parser for input data (http://flask-restplus.readthedocs.io/en/stable/parsing.html)
23+
# Set up parser for input data (http://flask-restx.readthedocs.io/en/stable/parsing.html)
2424
input_parser = MAX_API.parser()
2525
input_parser.add_argument('image', type=FileStorage, location='files',
2626
required=True,

app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright 2018-2019 IBM Corp. All Rights Reserved.
2+
# Copyright 2018-2021 IBM Corp. All Rights Reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.

config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
# Flask settings
1818
DEBUG = False
1919

20-
# Flask-restplus settings
21-
RESTPLUS_MASK_SWAGGER = False
20+
# Flask-restx settings
21+
RESTX_MASK_SWAGGER = False
2222
SWAGGER_UI_DOC_EXPANSION = 'none'
2323

2424
# API metadata

core/SRGAN/model.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import print_function
44

55
import tensorflow as tf
6-
import tensorflow.contrib.slim as slim
6+
import tf_slim as slim
77
from core.SRGAN.ops import preprocessLR, preprocess, random_flip, conv2, batchnorm, prelu_tf, pixelShuffler, lrelu, \
88
denselayer, vgg_19
99
import collections
@@ -38,7 +38,7 @@ def data_loader(FLAGS):
3838
image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
3939
image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)
4040

41-
with tf.variable_scope('load_image'):
41+
with tf.compat.v1.variable_scope('load_image'):
4242
# define the image list queue
4343
# image_list_LR_queue = tf.train.string_input_producer(
4444
# image_list_LR, shuffle=False, capacity=FLAGS.name_queue_capacity)
@@ -98,7 +98,7 @@ def data_loader(FLAGS):
9898
inputs = tf.identity(inputs)
9999
targets = tf.identity(targets)
100100

101-
with tf.variable_scope('random_flip'):
101+
with tf.compat.v1.variable_scope('random_flip'):
102102
# Check for random flip:
103103
if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
104104
print('[Config] Use random flip')
@@ -228,7 +228,7 @@ def preprocess_test(name):
228228
def generator(gen_inputs, gen_output_channels, is_training, num_resblock, reuse=False):
229229
# The Bx residual blocks
230230
def residual_block(inputs, output_channel, stride, scope):
231-
with tf.variable_scope(scope):
231+
with tf.compat.v1.variable_scope(scope):
232232
net = conv2(inputs, 3, output_channel, stride, use_bias=False, scope='conv_1')
233233
net = batchnorm(net, is_training)
234234
net = prelu_tf(net)
@@ -238,9 +238,9 @@ def residual_block(inputs, output_channel, stride, scope):
238238

239239
return net
240240

241-
with tf.variable_scope('generator_unit', reuse=reuse):
241+
with tf.compat.v1.variable_scope('generator_unit', reuse=reuse):
242242
# The input layer
243-
with tf.variable_scope('input_stage'):
243+
with tf.compat.v1.variable_scope('input_stage'):
244244
net = conv2(gen_inputs, 9, 64, 1, scope='conv')
245245
net = prelu_tf(net)
246246

@@ -251,23 +251,23 @@ def residual_block(inputs, output_channel, stride, scope):
251251
name_scope = 'resblock_%d' % (i)
252252
net = residual_block(net, 64, 1, name_scope)
253253

254-
with tf.variable_scope('resblock_output'):
254+
with tf.compat.v1.variable_scope('resblock_output'):
255255
net = conv2(net, 3, 64, 1, use_bias=False, scope='conv')
256256
net = batchnorm(net, is_training)
257257

258258
net = net + stage1_output
259259

260-
with tf.variable_scope('subpixelconv_stage1'):
260+
with tf.compat.v1.variable_scope('subpixelconv_stage1'):
261261
net = conv2(net, 3, 256, 1, scope='conv')
262262
net = pixelShuffler(net, scale=2)
263263
net = prelu_tf(net)
264264

265-
with tf.variable_scope('subpixelconv_stage2'):
265+
with tf.compat.v1.variable_scope('subpixelconv_stage2'):
266266
net = conv2(net, 3, 256, 1, scope='conv')
267267
net = pixelShuffler(net, scale=2)
268268
net = prelu_tf(net)
269269

270-
with tf.variable_scope('output_stage'):
270+
with tf.compat.v1.variable_scope('output_stage'):
271271
net = conv2(net, 9, gen_output_channels, 1, scope='conv')
272272

273273
return net
@@ -280,17 +280,17 @@ def discriminator(dis_inputs, FLAGS=None):
280280

281281
# Define the discriminator block
282282
def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
283-
with tf.variable_scope(scope):
283+
with tf.compat.v1.variable_scope(scope):
284284
net = conv2(inputs, kernel_size, output_channel, stride, use_bias=False, scope='conv1')
285285
net = batchnorm(net, FLAGS.is_training)
286286
net = lrelu(net, 0.2)
287287

288288
return net
289289

290290
with tf.device('/gpu:0'):
291-
with tf.variable_scope('discriminator_unit'):
291+
with tf.compat.v1.variable_scope('discriminator_unit'):
292292
# The input layer
293-
with tf.variable_scope('input_stage'):
293+
with tf.compat.v1.variable_scope('input_stage'):
294294
net = conv2(dis_inputs, 3, 64, 1, scope='conv')
295295
net = lrelu(net, 0.2)
296296

@@ -317,13 +317,13 @@ def discriminator_block(inputs, output_channel, kernel_size, stride, scope):
317317
net = discriminator_block(net, 512, 3, 2, 'disblock_7')
318318

319319
# The dense layer 1
320-
with tf.variable_scope('dense_layer_1'):
320+
with tf.compat.v1.variable_scope('dense_layer_1'):
321321
net = slim.flatten(net)
322322
net = denselayer(net, 1024)
323323
net = lrelu(net, 0.2)
324324

325325
# The dense layer 2
326-
with tf.variable_scope('dense_layer_2'):
326+
with tf.compat.v1.variable_scope('dense_layer_2'):
327327
net = denselayer(net, 1)
328328
net = tf.nn.sigmoid(net)
329329

@@ -352,19 +352,19 @@ def SRGAN(inputs, targets, FLAGS):
352352
learning_rate')
353353

354354
# Build the generator part
355-
with tf.variable_scope('generator'):
355+
with tf.compat.v1.variable_scope('generator'):
356356
output_channel = targets.get_shape().as_list()[-1]
357357
gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
358358
gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
359359

360360
# Build the fake discriminator
361361
with tf.name_scope('fake_discriminator'):
362-
with tf.variable_scope('discriminator', reuse=False):
362+
with tf.compat.v1.variable_scope('discriminator', reuse=False):
363363
discrim_fake_output = discriminator(gen_output, FLAGS=FLAGS)
364364

365365
# Build the real discriminator
366366
with tf.name_scope('real_discriminator'):
367-
with tf.variable_scope('discriminator', reuse=True):
367+
with tf.compat.v1.variable_scope('discriminator', reuse=True):
368368
discrim_real_output = discriminator(targets, FLAGS=FLAGS)
369369

370370
# Use the VGG54 feature
@@ -390,47 +390,47 @@ def SRGAN(inputs, targets, FLAGS):
390390
raise NotImplementedError('Unknown perceptual type!!')
391391

392392
# Calculating the generator loss
393-
with tf.variable_scope('generator_loss'):
393+
with tf.compat.v1.variable_scope('generator_loss'):
394394
# Content loss
395-
with tf.variable_scope('content_loss'):
395+
with tf.compat.v1.variable_scope('content_loss'):
396396
# Compute the euclidean distance between the two features
397397
diff = extracted_feature_gen - extracted_feature_target
398398
if FLAGS.perceptual_mode == 'MSE':
399399
content_loss = tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
400400
else:
401401
content_loss = FLAGS.vgg_scaling * tf.reduce_mean(tf.reduce_sum(tf.square(diff), axis=[3]))
402402

403-
with tf.variable_scope('adversarial_loss'):
403+
with tf.compat.v1.variable_scope('adversarial_loss'):
404404
adversarial_loss = tf.reduce_mean(-tf.log(discrim_fake_output + FLAGS.EPS))
405405

406406
gen_loss = content_loss + (FLAGS.ratio) * adversarial_loss
407407
print(adversarial_loss.get_shape())
408408
print(content_loss.get_shape())
409409

410410
# Calculating the discriminator loss
411-
with tf.variable_scope('discriminator_loss'):
411+
with tf.compat.v1.variable_scope('discriminator_loss'):
412412
discrim_fake_loss = tf.log(1 - discrim_fake_output + FLAGS.EPS)
413413
discrim_real_loss = tf.log(discrim_real_output + FLAGS.EPS)
414414

415415
discrim_loss = tf.reduce_mean(-(discrim_fake_loss + discrim_real_loss))
416416

417417
# Define the learning rate and global step
418-
with tf.variable_scope('get_learning_rate_and_global_step'):
419-
global_step = tf.contrib.framework.get_or_create_global_step()
418+
with tf.compat.v1.variable_scope('get_learning_rate_and_global_step'):
419+
global_step = tf.compat.v1.train.get_or_create_global_step()
420420
learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
421421
staircase=FLAGS.stair)
422422
incr_global_step = tf.assign(global_step, global_step + 1)
423423

424-
with tf.variable_scope('dicriminator_train'):
425-
discrim_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
424+
with tf.compat.v1.variable_scope('dicriminator_train'):
425+
discrim_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
426426
discrim_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
427427
discrim_grads_and_vars = discrim_optimizer.compute_gradients(discrim_loss, discrim_tvars)
428428
discrim_train = discrim_optimizer.apply_gradients(discrim_grads_and_vars)
429429

430-
with tf.variable_scope('generator_train'):
430+
with tf.compat.v1.variable_scope('generator_train'):
431431
# Need to wait discriminator to perform train step
432-
with tf.control_dependencies([discrim_train] + tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
433-
gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
432+
with tf.control_dependencies([discrim_train] + tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
433+
gen_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
434434
gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
435435
gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
436436
gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)
@@ -460,7 +460,7 @@ def SRResnet(inputs, targets, FLAGS):
460460
learning_rate')
461461

462462
# Build the generator part
463-
with tf.variable_scope('generator'):
463+
with tf.compat.v1.variable_scope('generator'):
464464
output_channel = targets.get_shape().as_list()[-1]
465465
gen_output = generator(inputs, output_channel, reuse=False, FLAGS=FLAGS)
466466
gen_output.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
@@ -486,9 +486,9 @@ def SRResnet(inputs, targets, FLAGS):
486486
raise NotImplementedError('Unknown perceptual type')
487487

488488
# Calculating the generator loss
489-
with tf.variable_scope('generator_loss'):
489+
with tf.compat.v1.variable_scope('generator_loss'):
490490
# Content loss
491-
with tf.variable_scope('content_loss'):
491+
with tf.compat.v1.variable_scope('content_loss'):
492492
# Compute the euclidean distance between the two features
493493
# check=tf.equal(extracted_feature_gen, extracted_feature_target)
494494
diff = extracted_feature_gen - extracted_feature_target
@@ -500,16 +500,16 @@ def SRResnet(inputs, targets, FLAGS):
500500
gen_loss = content_loss
501501

502502
# Define the learning rate and global step
503-
with tf.variable_scope('get_learning_rate_and_global_step'):
504-
global_step = tf.contrib.framework.get_or_create_global_step()
503+
with tf.compat.v1.variable_scope('get_learning_rate_and_global_step'):
504+
global_step = tf.compat.v1.train.get_or_create_global_step()
505505
learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, FLAGS.decay_step, FLAGS.decay_rate,
506506
staircase=FLAGS.stair)
507507
incr_global_step = tf.assign(global_step, global_step + 1)
508508

509-
with tf.variable_scope('generator_train'):
509+
with tf.compat.v1.variable_scope('generator_train'):
510510
# Need to wait discriminator to perform train step
511-
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
512-
gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
511+
with tf.control_dependencies(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
512+
gen_tvars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
513513
gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=FLAGS.beta)
514514
gen_grads_and_vars = gen_optimizer.compute_gradients(gen_loss, gen_tvars)
515515
gen_train = gen_optimizer.apply_gradients(gen_grads_and_vars)

core/SRGAN/model_dense.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def denseBlock(block_inputs, num_layers, bottleneck_scale, growth_rate, FLAGS):
3939
# Build each layer consecutively
4040
net = block_inputs
4141
for i in range(num_layers):
42-
with tf.variable_scope('dense_conv_layer%d' % (i + 1)):
42+
with tf.compat.v1.variable_scope('dense_conv_layer%d' % (i + 1)):
4343
net = denseConvlayer(net, bottleneck_scale, growth_rate, FLAGS.is_training)
4444

4545
return net
@@ -52,9 +52,9 @@ def generatorDense(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
5252
raise ValueError('No FLAGS is provided for generator')
5353

5454
# The main netowrk
55-
with tf.variable_scope('generator_unit', reuse=reuse):
55+
with tf.compat.v1.variable_scope('generator_unit', reuse=reuse):
5656
# The input stage
57-
with tf.variable_scope('input_stage'):
57+
with tf.compat.v1.variable_scope('input_stage'):
5858
net = conv2(gen_inputs, 9, 64, 1, scope='conv')
5959
net = prelu_tf(net)
6060

@@ -64,23 +64,23 @@ def generatorDense(gen_inputs, gen_output_channels, reuse=False, FLAGS=None):
6464
bottleneck_scale = 4
6565
growth_rate = 12
6666
transition_output_channel = 128
67-
with tf.variable_scope('denseBlock_1'):
67+
with tf.compat.v1.variable_scope('denseBlock_1'):
6868
net = denseBlock(net, layer_per_block, bottleneck_scale, growth_rate, FLAGS)
6969

70-
with tf.variable_scope('transition_layer_1'):
70+
with tf.compat.v1.variable_scope('transition_layer_1'):
7171
net = transitionLayer(net, transition_output_channel, FLAGS.is_training)
7272

73-
with tf.variable_scope('subpixelconv_stage1'):
73+
with tf.compat.v1.variable_scope('subpixelconv_stage1'):
7474
net = conv2(net, 3, 256, 1, scope='conv')
7575
net = pixelShuffler(net, scale=2)
7676
net = prelu_tf(net)
7777

78-
with tf.variable_scope('subpixelconv_stage2'):
78+
with tf.compat.v1.variable_scope('subpixelconv_stage2'):
7979
net = conv2(net, 3, 256, 1, scope='conv')
8080
net = pixelShuffler(net, scale=2)
8181
net = prelu_tf(net)
8282

83-
with tf.variable_scope('output_stage'):
83+
with tf.compat.v1.variable_scope('output_stage'):
8484
net = conv2(net, 9, gen_output_channels, 1, scope='conv')
8585

8686
return net

0 commit comments

Comments
 (0)