File tree Expand file tree Collapse file tree 7 files changed +91
-6
lines changed
tests/fixtures/deprecated_symbols
visitors/deprecated_symbols Expand file tree Collapse file tree 7 files changed +91
-6
lines changed Original file line number Diff line number Diff line change 1+ import torch
2+
3+ torch .cuda .amp .autocast ()
4+ torch .cuda .amp .custom_fwd ()
5+ torch .cuda .amp .custom_bwd ()
6+
7+ dtype = torch .float32
8+ maybe_autocast = torch .cpu .amp .autocast ()
9+ maybe_autocast = torch .cpu .amp .autocast (dtype = torch .bfloat16 )
10+ maybe_autocast = torch .cpu .amp .autocast (dtype = dtype )
Original file line number Diff line number Diff line change 1+ 3:1 TOR101 Use of deprecated function torch.cuda.amp.autocast
2+ 4:1 TOR101 Use of deprecated function torch.cuda.amp.custom_fwd
3+ 5:1 TOR101 Use of deprecated function torch.cuda.amp.custom_bwd
4+ 8:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
5+ 9:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
6+ 10:18 TOR101 Use of deprecated function torch.cpu.amp.autocast
Original file line number Diff line number Diff line change 1+ import torch
2+
3+ dtype = torch .float32
4+
5+ maybe_autocast = torch .cuda .amp .autocast ()
6+ maybe_autocast = torch .cuda .amp .autocast (dtype = torch .bfloat16 )
7+ maybe_autocast = torch .cuda .amp .autocast (dtype = dtype )
8+
9+ maybe_autocast = torch .cpu .amp .autocast ()
10+ maybe_autocast = torch .cpu .amp .autocast (dtype = torch .bfloat16 )
11+ maybe_autocast = torch .cpu .amp .autocast (dtype = dtype )
Original file line number Diff line number Diff line change 1+ import torch
2+
3+ dtype = torch .float32
4+
5+ maybe_autocast = torch .amp .autocast ("cuda" )
6+ maybe_autocast = torch .amp .autocast ("cuda" , dtype = torch .bfloat16 )
7+ maybe_autocast = torch .amp .autocast ("cuda" , dtype = dtype )
8+
9+ maybe_autocast = torch .amp .autocast ("cpu" )
10+ maybe_autocast = torch .amp .autocast ("cpu" , dtype = torch .bfloat16 )
11+ maybe_autocast = torch .amp .autocast ("cpu" , dtype = dtype )
Original file line number Diff line number Diff line change 8383 remove_pr :
8484 reference : https://github.com/pytorch-labs/torchfix#torchbackendscudasdp_kernel
8585
86+ - name : torch.cuda.amp.autocast
87+ deprecate_pr : TBA
88+ remove_pr :
89+
90+ - name : torch.cuda.amp.custom_fwd
91+ deprecate_pr : TBA
92+ remove_pr :
93+
94+ - name : torch.cuda.amp.custom_bwd
95+ deprecate_pr : TBA
96+ remove_pr :
97+
98+ - name : torch.cpu.amp.autocast
99+ deprecate_pr : TBA
100+ remove_pr :
101+
86102# functorch
87103- name : functorch.vmap
88104 deprecate_pr : TBA
Original file line number Diff line number Diff line change 1- import libcst as cst
21import pkgutil
2+ from typing import List , Optional
3+
4+ import libcst as cst
35import yaml
4- from typing import Optional , List
56
67from ...common import (
7- TorchVisitor ,
8- TorchError ,
98 call_with_name_changes ,
109 check_old_names_in_import_from ,
10+ TorchError ,
11+ TorchVisitor ,
1112)
1213
13- from .range import call_replacement_range
14- from .cholesky import call_replacement_cholesky
14+ from .amp import call_replacement_cpu_amp_autocast , call_replacement_cuda_amp_autocast
1515from .chain_matmul import call_replacement_chain_matmul
16+ from .cholesky import call_replacement_cholesky
1617from .qr import call_replacement_qr
1718
19+ from .range import call_replacement_range
20+
1821
1922class TorchDeprecatedSymbolsVisitor (TorchVisitor ):
2023 ERRORS : List [TorchError ] = [
@@ -49,6 +52,8 @@ def _call_replacement(
4952 "torch.range" : call_replacement_range ,
5053 "torch.chain_matmul" : call_replacement_chain_matmul ,
5154 "torch.qr" : call_replacement_qr ,
55+ "torch.cuda.amp.autocast" : call_replacement_cuda_amp_autocast ,
56+ "torch.cpu.amp.autocast" : call_replacement_cpu_amp_autocast ,
5257 }
5358 replacement = None
5459
Original file line number Diff line number Diff line change 1+ import libcst as cst
2+
3+ from ...common import get_module_name
4+
5+
6+ def call_replacement_cpu_amp_autocast (node : cst .Call ) -> cst .CSTNode :
7+ return _call_replacement_amp (node , "cpu" )
8+
9+
10+ def call_replacement_cuda_amp_autocast (node : cst .Call ) -> cst .CSTNode :
11+ return _call_replacement_amp (node , "cuda" )
12+
13+
14+ def _call_replacement_amp (node : cst .Call , device : str ) -> cst .CSTNode :
15+ """
16+ Replace `torch.cuda.amp.autocast()` with `torch.amp.autocast("cuda")` and
17+ Replace `torch.cpu.amp.autocast()` with `torch.amp.autocast("cpu")`.
18+ """
19+ device_arg = cst .ensure_type (cst .parse_expression (f'f("{ device } ")' ), cst .Call ).args [
20+ 0
21+ ]
22+
23+ module_name = get_module_name (node , "torch" )
24+ replacement = cst .parse_expression (f"{ module_name } .amp.autocast(args)" )
25+ replacement = replacement .with_changes (args = (device_arg , * node .args ))
26+ return replacement
You can’t perform that action at this time.
0 commit comments