Skip to content

Commit f266f84

Browse files
authored
Merge pull request #1586 from lorenzbaraldi/eval_loss
Put validation loss under amp_autocast
2 parents 7c4ed4d + 3d6bc42 commit f266f84

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

train.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -970,16 +970,16 @@ def validate(
970970

971971
with amp_autocast():
972972
output = model(input)
973-
if isinstance(output, (tuple, list)):
974-
output = output[0]
973+
if isinstance(output, (tuple, list)):
974+
output = output[0]
975975

976-
# augmentation reduction
977-
reduce_factor = args.tta
978-
if reduce_factor > 1:
979-
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
980-
target = target[0:target.size(0):reduce_factor]
976+
# augmentation reduction
977+
reduce_factor = args.tta
978+
if reduce_factor > 1:
979+
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
980+
target = target[0:target.size(0):reduce_factor]
981981

982-
loss = loss_fn(output, target)
982+
loss = loss_fn(output, target)
983983
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
984984

985985
if args.distributed:

validate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,9 @@ def validate(args):
294294
with amp_autocast():
295295
output = model(input)
296296

297-
if valid_labels is not None:
298-
output = output[:, valid_labels]
299-
loss = criterion(output, target)
297+
if valid_labels is not None:
298+
output = output[:, valid_labels]
299+
loss = criterion(output, target)
300300

301301
if real_labels is not None:
302302
real_labels.add_result(output)

0 commit comments

Comments
 (0)