diff --git a/.gitignore b/.gitignore
index 6333025..d45e338 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,5 @@
 checkpoint
 *.npy
 model.ckpt-*
+snapshots/*
+*.sh
\ No newline at end of file
diff --git a/network.py b/network.py
index 1ea5939..7deb67a 100644
--- a/network.py
+++ b/network.py
@@ -73,10 +73,12 @@ def create_session(self):
         self.sess.run([global_init, local_init])
 
     def restore(self, data_path, var_list=None):
+        if var_list is None:
+            var_list = tf.global_variables()
         if data_path.endswith('.npy'):
-            self.load_npy(data_path, self.sess)
+            self.load_npy(data_path, self.sess, var_list=var_list)
         else:
-            loader = tf.train.Saver(var_list=tf.global_variables())
+            loader = tf.train.Saver(var_list=var_list)
             loader.restore(self.sess, data_path)
 
         print('Restore from {}'.format(data_path))
@@ -92,13 +94,16 @@ def save(self, saver, save_dir, step):
         print('The checkpoint has been created, step: {}'.format(step))
 
     ## Restore from .npy
-    def load_npy(self, data_path, session, ignore_missing=False):
+    def load_npy(self, data_path, session, ignore_missing=False, var_list=None):
         '''Load network weights.
         data_path: The path to the numpy-serialized network weights
         session: The current TensorFlow session
         ignore_missing: If true, serialized weights for missing layers are ignored.
         '''
+        if var_list is None:
+            var_list = tf.global_variables()
         data_dict = np.load(data_path, encoding='latin1').item()
+        var_names = [v.name for v in var_list]
         for op_name in data_dict:
             with tf.variable_scope(op_name, reuse=True):
                 for param_name, data in data_dict[op_name].items():
@@ -107,6 +112,9 @@
                             param_name = BN_param_map[param_name]
 
                         var = tf.get_variable(param_name)
+                        if var.name not in var_names:
+                            print("Not restored: %s" % var.name)
+                            continue
                         session.run(var.assign(data))
                     except ValueError:
                         if not ignore_missing:
diff --git a/utils/image_reader.py b/utils/image_reader.py
index 3d7a65e..ec7d623 100644
--- a/utils/image_reader.py
+++ b/utils/image_reader.py
@@ -100,6 +100,7 @@ def _random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_l
     combined_crop = tf.random_crop(combined_pad, [crop_h, crop_w, 4])
     img_crop = combined_crop[:, :, :last_image_dim]
     label_crop = combined_crop[:, :, last_image_dim:]
+    label_crop = label_crop + ignore_label  # add back the shift applied before zero padding
     label_crop = tf.cast(label_crop, dtype=tf.uint8)
 
     # Set static shape so that tensorflow knows shape at compile time.
@@ -133,10 +134,18 @@ def _infer_preprocess(img, swap_channel=False):
     return img, o_shape, n_shape
 
 
-def _eval_preprocess(img, label, shape, dataset):
-    if dataset == 'cityscapes':
+def _eval_preprocess(img, label, shape, dataset, ignore_label=255):
+    if 'cityscapes' in dataset:
         img = tf.image.pad_to_bounding_box(img, 0, 0, shape[0], shape[1])
         img.set_shape([shape[0], shape[1], 3])
+
+        label = tf.cast(label, dtype=tf.float32)
+        label = label - ignore_label  # Subtract before padding and add back after, so zero-padded pixels become ignore_label.
+        label = tf.image.pad_to_bounding_box(label, 0, 0, shape[0], shape[1])
+        label = label + ignore_label
+        label = tf.cast(label, dtype=tf.uint8)
+        label.set_shape([shape[0], shape[1], 1])
+
     else:
         img = tf.image.resize_images(img, shape, align_corners=True)
 
@@ -178,7 +187,7 @@ def create_tf_dataset(self, cfg):
 
         else: # Evaluation phase
             dataset = dataset.map(lambda x, y:
-                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset),
+                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset, cfg.param['ignore_label']),
                 num_parallel_calls=cfg.N_WORKERS)
             dataset = dataset.batch(1)
 
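
Context on the `var_list` change in `network.py`: the old code always restored `tf.global_variables()`, so partial restores (e.g. dropping a classification head when fine-tuning) were impossible. The patched `restore()` simply forwards `var_list` to `tf.train.Saver`. A minimal, self-contained TF 1.x sketch of the mechanism this exposes; the variable names and checkpoint path are placeholders, not repo constants:

```python
import tensorflow as tf

# Two variables stand in for "backbone" and "classifier" weights.
backbone = tf.get_variable('backbone/w', initializer=tf.ones([2]))
classifier = tf.get_variable('classifier/w', initializer=tf.ones([3]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, '/tmp/demo.ckpt')

    # Partial restore, as the patched restore() now allows via var_list:
    keep = [v for v in tf.global_variables() if 'classifier' not in v.name]
    loader = tf.train.Saver(var_list=keep)
    loader.restore(sess, '/tmp/demo.ckpt')  # classifier/w is left untouched
```

With the patch, the equivalent repo call would be `net.restore('/tmp/demo.ckpt', var_list=keep)`; omitting `var_list` keeps the old restore-everything behaviour.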
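
The same list now reaches the `.npy` path: `load_npy` skips any serialized weight whose variable is not in `var_list` and logs it instead of assigning. A usage sketch, assuming `net` is an instance of the network class patched above; the scope name and file path are hypothetical:

```python
import tensorflow as tf

# Select only variables under an assumed 'conv1' scope; everything else in the
# snapshot is skipped and reported as "Not restored: <name>".
encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')
net.restore('model/weights.npy', var_list=encoder_vars)
```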
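
Why `_eval_preprocess` subtracts and re-adds `ignore_label`: `tf.image.pad_to_bounding_box` can only pad with zeros, and a padded value of 0 would be scored as a real class. Shifting the labels down by `ignore_label` before padding and back up afterwards turns the zero-padded border into `ignore_label`, which the evaluation discards. A self-contained TF 1.x sketch of the arithmetic with toy values:

```python
import tensorflow as tf

ignore_label = 255
# A 1x2 toy label map with one channel, as _eval_preprocess expects.
label = tf.constant([[[10.0], [20.0]]])

shifted = label - ignore_label                              # 10 -> -245, 20 -> -235
padded = tf.image.pad_to_bounding_box(shifted, 0, 0, 2, 2)  # zero-pads to 2x2
restored = tf.cast(padded + ignore_label, tf.uint8)         # padded 0 -> 255

with tf.Session() as sess:
    print(sess.run(restored[:, :, 0]))
    # [[ 10  20]
    #  [255 255]]
```

The added line in `_random_crop_and_pad_image_and_labels` is the second half of the same trick: that function already subtracts `ignore_label` before its zero-padded crop, and the patch restores the shift afterwards.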