Skip to content
This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Commit 3eeeb81

Browse files
committed
Update: Adopt PR NVlabs/stylegan2-ada-pytorch#197 to fix Update: Adopt PR NVlabs/stylegan2-ada-pytorch#197 to fix on PyTorch 1.10 and above.
1 parent 388ed20 commit 3eeeb81

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

third_party/stylegan2_official_ops/conv2d_gradfix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import warnings
3030
import contextlib
3131
import torch
32+
from distutils.version import LooseVersion
3233

3334
enabled = True # Enable the custom op by setting this to true.
3435
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
@@ -61,7 +62,7 @@ def _should_use_custom_op(input):
6162
return False
6263
if input.device.type != 'cuda':
6364
return False
64-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
65+
if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
6566
return True
6667
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
6768
return False

third_party/stylegan2_official_ops/grid_sample_gradfix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import warnings
2828
import torch
29+
from distutils.version import LooseVersion
2930

3031
#----------------------------------------------------------------------------
3132

@@ -43,7 +44,7 @@ def grid_sample(input, grid, impl='cuda'):
4344
def _should_use_custom_op():
4445
if not enabled:
4546
return False
46-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
47+
if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
4748
return True
4849
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
4950
return False

0 commit comments

Comments
 (0)