|
105 | 105 | help='use Native AMP for mixed precision training') |
106 | 106 | parser.add_argument('--amp-dtype', default='float16', type=str, |
107 | 107 | help='lower precision AMP dtype (default: float16)') |
| 108 | +parser.add_argument('--model-dtype', default=None, type=str, |
| 109 | + help='Model dtype override (non-AMP) (default: float32)') |
108 | 110 | parser.add_argument('--fuser', default='', type=str, |
109 | 111 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
110 | 112 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) |
@@ -161,9 +163,15 @@ def main(): |
161 | 163 |
|
162 | 164 | device = torch.device(args.device) |
163 | 165 |
|
| 166 | + model_dtype = None |
| 167 | + if args.model_dtype: |
| 168 | + assert args.model_dtype in ('float32', 'float16', 'bfloat16') |
| 169 | + model_dtype = getattr(torch, args.model_dtype) |
| 170 | + |
164 | 171 | # resolve AMP arguments based on PyTorch / Apex availability |
165 | 172 | amp_autocast = suppress |
166 | 173 | if args.amp: |
| 174 | + assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' |
167 | 175 | assert args.amp_dtype in ('float16', 'bfloat16') |
168 | 176 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 |
169 | 177 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) |
@@ -201,7 +209,7 @@ def main(): |
201 | 209 | if args.test_pool: |
202 | 210 | model, test_time_pool = apply_test_time_pool(model, data_config) |
203 | 211 |
|
204 | | - model = model.to(device) |
| 212 | + model = model.to(device=device, dtype=model_dtype) |
205 | 213 | model.eval() |
206 | 214 | if args.channels_last: |
207 | 215 | model = model.to(memory_format=torch.channels_last) |
@@ -237,6 +245,7 @@ def main(): |
237 | 245 | use_prefetcher=True, |
238 | 246 | num_workers=workers, |
239 | 247 | device=device, |
| 248 | + img_dtype=model_dtype or torch.float32, |
240 | 249 | **data_config, |
241 | 250 | ) |
242 | 251 |
|
@@ -280,7 +289,7 @@ def main(): |
280 | 289 | np_labels = to_label(np_indices) |
281 | 290 | all_labels.append(np_labels) |
282 | 291 |
|
283 | | - all_outputs.append(output.cpu().numpy()) |
| 292 | + all_outputs.append(output.float().cpu().numpy()) |
284 | 293 |
|
285 | 294 | # measure elapsed time |
286 | 295 | batch_time.update(time.time() - end) |
|
0 commit comments