Skip to content

Commit 15713cb

Browse files
authored
Change deprecated cuda.amp to amp, update requirements.txt (#27)
1 parent 147f3b1 commit 15713cb

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

week03_fast_pipelines/homework/task1/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def train_epoch(
2121
images = images.to(device)
2222
labels = labels.to(device)
2323

24-
with torch.cuda.amp.autocast():
24+
with torch.amp.autocast(device.type, dtype=torch.float16):
2525
outputs = model(images)
2626
loss = criterion(outputs, labels)
2727
# TODO: your code for loss scaling here

week03_fast_pipelines/seminar/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pandas==1.5.3
88
py-spy==0.3.14
99
torch==2.3.0
1010
torchtext
11-
torchvision==0.17.0
11+
torchvision==0.18.0
1212
tqdm==4.64.1
1313
vit_pytorch==0.40.2
1414
matplotlib==3.8.2

0 commit comments

Comments
 (0)