Skip to content

Commit 03fa149

Browse files
committed
Merge branch 'main' into grad_checkpointing
2 parents 0638708 + 03f4f4d commit 03fa149

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+5128
-1766
lines changed

timm/data/loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ def __init__(
123123
def __iter__(self):
124124
first = True
125125
if self.is_cuda:
126-
stream = torch.cuda.Stream()
126+
stream = torch.cuda.Stream(device=self.device)
127127
stream_context = partial(torch.cuda.stream, stream=stream)
128128
elif self.is_npu:
129-
stream = torch.npu.Stream()
129+
stream = torch.npu.Stream(device=self.device)
130130
stream_context = partial(torch.npu.stream, stream=stream)
131131
else:
132132
stream = None
@@ -148,9 +148,9 @@ def __iter__(self):
148148

149149
if stream is not None:
150150
if self.is_cuda:
151-
torch.cuda.current_stream().wait_stream(stream)
151+
torch.cuda.current_stream(device=self.device).wait_stream(stream)
152152
elif self.is_npu:
153-
torch.npu.current_stream().wait_stream(stream)
153+
torch.npu.current_stream(device=self.device).wait_stream(stream)
154154

155155
input = next_input
156156
target = next_target

timm/data/naflex_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
9191
"""
9292
first = True
9393
if self.is_cuda:
94-
stream = torch.cuda.Stream()
94+
stream = torch.cuda.Stream(device=self.device)
9595
stream_context = partial(torch.cuda.stream, stream=stream)
9696
elif self.is_npu:
97-
stream = torch.npu.Stream()
97+
stream = torch.npu.Stream(device=self.device)
9898
stream_context = partial(torch.npu.stream, stream=stream)
9999
else:
100100
stream = None
@@ -152,9 +152,9 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
152152

153153
if stream is not None:
154154
if self.is_cuda:
155-
torch.cuda.current_stream().wait_stream(stream)
155+
torch.cuda.current_stream(device=self.device).wait_stream(stream)
156156
elif self.is_npu:
157-
torch.npu.current_stream().wait_stream(stream)
157+
torch.npu.current_stream(device=self.device).wait_stream(stream)
158158

159159
input_dict = next_input_dict
160160
target = next_target

timm/layers/blur_pool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
Hacked together by Chris Ha and Ross Wightman
77
"""
88
from functools import partial
9+
from math import comb # Python 3.8
910
from typing import Optional, Type
1011

1112
import torch
1213
import torch.nn as nn
1314
import torch.nn.functional as F
14-
import numpy as np
1515

1616
from .padding import get_padding
1717
from .typing import LayerType
@@ -45,7 +45,11 @@ def __init__(
4545
self.pad_mode = pad_mode
4646
self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
4747

48-
coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
48+
# (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
49+
coeffs = torch.tensor(
50+
[comb(filt_size - 1, k) for k in range(filt_size)],
51+
dtype=torch.float32,
52+
) / (2 ** (filt_size - 1)) # normalise so coefficients sum to 1
4953
blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
5054
if channels is not None:
5155
blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)

timm/layers/cond_conv2d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import math
1010
from functools import partial
11-
import numpy as np
1211
import torch
1312
from torch import nn as nn
1413
from torch.nn import functional as F
@@ -21,7 +20,7 @@
2120
def get_condconv_initializer(initializer, num_experts, expert_shape):
2221
def condconv_initializer(weight):
2322
"""CondConv initializer function."""
24-
num_params = np.prod(expert_shape)
23+
num_params = math.prod(expert_shape)
2524
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
2625
weight.shape[1] != num_params):
2726
raise (ValueError(
@@ -75,7 +74,7 @@ def reset_parameters(self):
7574
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
7675
init_weight(self.weight)
7776
if self.bias is not None:
78-
fan_in = np.prod(self.weight_shape[1:])
77+
fan_in = math.prod(self.weight_shape[1:])
7978
bound = 1 / math.sqrt(fan_in)
8079
init_bias = get_condconv_initializer(
8180
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)

timm/models/_builder.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from copy import deepcopy
55
from pathlib import Path
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
77

88
from torch import nn as nn
99
from torch.hub import load_state_dict_from_url
@@ -26,11 +26,21 @@
2626
_CHECK_HASH = False
2727
_USE_OLD_CACHE = int(os.environ.get('TIMM_USE_OLD_CACHE', 0)) > 0
2828

29-
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
30-
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
29+
__all__ = [
30+
'set_pretrained_download_progress',
31+
'set_pretrained_check_hash',
32+
'load_custom_pretrained',
33+
'load_pretrained',
34+
'pretrained_cfg_for_features',
35+
'resolve_pretrained_cfg',
36+
'build_model_with_cfg',
37+
]
3138

3239

33-
def _resolve_pretrained_source(pretrained_cfg):
40+
ModelT = TypeVar("ModelT", bound=nn.Module) # any subclass of nn.Module
41+
42+
43+
def _resolve_pretrained_source(pretrained_cfg: Dict[str, Any]) -> Tuple[str, str]:
3444
cfg_source = pretrained_cfg.get('source', '')
3545
pretrained_url = pretrained_cfg.get('url', None)
3646
pretrained_file = pretrained_cfg.get('file', None)
@@ -78,25 +88,25 @@ def _resolve_pretrained_source(pretrained_cfg):
7888
return load_from, pretrained_loc
7989

8090

81-
def set_pretrained_download_progress(enable=True):
91+
def set_pretrained_download_progress(enable: bool = True) -> None:
8292
""" Set download progress for pretrained weights on/off (globally). """
8393
global _DOWNLOAD_PROGRESS
8494
_DOWNLOAD_PROGRESS = enable
8595

8696

87-
def set_pretrained_check_hash(enable=True):
97+
def set_pretrained_check_hash(enable: bool = True) -> None:
8898
""" Set hash checking for pretrained weights on/off (globally). """
8999
global _CHECK_HASH
90100
_CHECK_HASH = enable
91101

92102

93103
def load_custom_pretrained(
94104
model: nn.Module,
95-
pretrained_cfg: Optional[Dict] = None,
105+
pretrained_cfg: Optional[Dict[str, Any]] = None,
96106
load_fn: Optional[Callable] = None,
97107
cache_dir: Optional[Union[str, Path]] = None,
98-
):
99-
r"""Loads a custom (read non .pth) weight file
108+
) -> None:
109+
"""Loads a custom (read non .pth) weight file
100110
101111
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
102112
a passed in custom load fun, or the `load_pretrained` model member fn.
@@ -141,13 +151,13 @@ def load_custom_pretrained(
141151

142152
def load_pretrained(
143153
model: nn.Module,
144-
pretrained_cfg: Optional[Dict] = None,
154+
pretrained_cfg: Optional[Dict[str, Any]] = None,
145155
num_classes: int = 1000,
146156
in_chans: int = 3,
147157
filter_fn: Optional[Callable] = None,
148158
strict: bool = True,
149159
cache_dir: Optional[Union[str, Path]] = None,
150-
):
160+
) -> None:
151161
""" Load pretrained checkpoint
152162
153163
Args:
@@ -278,7 +288,7 @@ def load_pretrained(
278288
f' This may be expected if model is being adapted.')
279289

280290

281-
def pretrained_cfg_for_features(pretrained_cfg):
291+
def pretrained_cfg_for_features(pretrained_cfg: Dict[str, Any]) -> Dict[str, Any]:
282292
pretrained_cfg = deepcopy(pretrained_cfg)
283293
# remove default pretrained cfg fields that don't have much relevance for feature backbone
284294
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
@@ -287,14 +297,14 @@ def pretrained_cfg_for_features(pretrained_cfg):
287297
return pretrained_cfg
288298

289299

290-
def _filter_kwargs(kwargs, names):
300+
def _filter_kwargs(kwargs: Dict[str, Any], names: List[str]) -> None:
291301
if not kwargs or not names:
292302
return
293303
for n in names:
294304
kwargs.pop(n, None)
295305

296306

297-
def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
307+
def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter) -> None:
298308
""" Update the default_cfg and kwargs before passing to model
299309
300310
Args:
@@ -340,6 +350,7 @@ def resolve_pretrained_cfg(
340350
pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
341351
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
342352
) -> PretrainedCfg:
353+
"""Resolve pretrained configuration from various sources."""
343354
model_with_tag = variant
344355
pretrained_tag = None
345356
if pretrained_cfg:
@@ -371,7 +382,7 @@ def resolve_pretrained_cfg(
371382

372383

373384
def build_model_with_cfg(
374-
model_cls: Callable,
385+
model_cls: Union[Type[ModelT], Callable[..., ModelT]],
375386
variant: str,
376387
pretrained: bool,
377388
pretrained_cfg: Optional[Dict] = None,
@@ -383,7 +394,7 @@ def build_model_with_cfg(
383394
cache_dir: Optional[Union[str, Path]] = None,
384395
kwargs_filter: Optional[Tuple[str]] = None,
385396
**kwargs,
386-
):
397+
) -> ModelT:
387398
""" Build model with specified default_cfg and optional model_cfg
388399
389400
This helper fn aids in the construction of a model including:

timm/models/_factory.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Dict, Optional, Union
3+
from typing import Any, Dict, Optional, Tuple, Union
44
from urllib.parse import urlsplit
55

6+
from torch import nn
7+
68
from timm.layers import set_layer_config
79
from ._helpers import load_checkpoint
810
from ._hub import load_model_config_from_hf, load_model_config_from_path
@@ -13,7 +15,8 @@
1315
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
1416

1517

16-
def parse_model_name(model_name: str):
18+
def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
19+
"""Parse source and name from potentially prefixed model name."""
1720
if model_name.startswith('hf_hub'):
1821
# NOTE for backwards compat, deprecate hf_hub use
1922
model_name = model_name.replace('hf_hub', 'hf-hub')
@@ -29,9 +32,9 @@ def parse_model_name(model_name: str):
2932
return None, model_name
3033

3134

32-
def safe_model_name(model_name: str, remove_source: bool = True):
33-
# return a filename / path safe model name
34-
def make_safe(name):
35+
def safe_model_name(model_name: str, remove_source: bool = True) -> str:
36+
"""Return a filename / path safe model name."""
37+
def make_safe(name: str) -> str:
3538
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
3639
if remove_source:
3740
model_name = parse_model_name(model_name)[-1]
@@ -42,14 +45,14 @@ def create_model(
4245
model_name: str,
4346
pretrained: bool = False,
4447
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
45-
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
48+
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
4649
checkpoint_path: Optional[Union[str, Path]] = None,
4750
cache_dir: Optional[Union[str, Path]] = None,
4851
scriptable: Optional[bool] = None,
4952
exportable: Optional[bool] = None,
5053
no_jit: Optional[bool] = None,
51-
**kwargs,
52-
):
54+
**kwargs: Any,
55+
) -> nn.Module:
5356
"""Create a model.
5457
5558
Lookup model's entrypoint function and pass relevant args to create a new model.

0 commit comments

Comments
 (0)