diff --git a/WGAN_GP.py b/WGAN_GP.py index 347004c5..f9d4e324 100644 --- a/WGAN_GP.py +++ b/WGAN_GP.py @@ -115,7 +115,7 @@ def build_model(self): interpolates = self.inputs + (alpha * differences) _,D_inter,_=self.discriminator(interpolates, is_training=True, reuse=True) gradients = tf.gradients(D_inter, [interpolates])[0] - slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) + slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) self.d_loss += self.lambd * gradient_penalty @@ -264,4 +264,4 @@ def load(self, checkpoint_dir): return True, counter else: print(" [*] Failed to find a checkpoint") - return False, 0 \ No newline at end of file + return False, 0