Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9ae9b63

Browse files
authored
Merge pull request #610 from bwilbertz/packed-reverse-copy
fix reverse/copy for packed problems
2 parents 5c0d89e + 8c9292d commit 9ae9b63

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tensor2tensor/data_generators/problem.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,22 @@ def maybe_reverse_features(self, feature_map):
425425
return
426426
inputs, targets = feature_map["inputs"], feature_map["targets"]
427427
feature_map["inputs"], feature_map["targets"] = targets, inputs
428+
if "inputs_segmentation" in feature_map:
429+
inputs, targets = feature_map["inputs_segmentation"], feature_map["targets_segmentation"]
430+
feature_map["inputs_segmentation"], feature_map["targets_segmentation"] = targets, inputs
431+
if "inputs_position" in feature_map:
432+
inputs, targets = feature_map["inputs_position"], feature_map["targets_position"]
433+
feature_map["inputs_position"], feature_map["targets_position"] = targets, inputs
434+
428435

429436
def maybe_copy_features(self, feature_map):
430437
if not self._was_copy:
431438
return
432439
feature_map["targets"] = feature_map["inputs"]
440+
if "inputs_segmentation" in feature_map:
441+
feature_map["targets_segmentation"] = feature_map["inputs_segmentation"]
442+
if "inputs_position" in feature_map:
443+
feature_map["targets_position"] = feature_map["inputs_position"]
433444

434445
def dataset(self,
435446
mode,

0 commit comments

Comments
 (0)