
Commit 6fc886a

Remove all prints, change most to logging calls, tweak alignment of batch logs, improve setup.py
1 parent 1d7f2d9

File tree

13 files changed (+126, -109 lines)

README.md

Lines changed: 1 addition & 4 deletions
@@ -30,7 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
   * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
 * Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
-  * EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
+  * EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
   * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
   * MobileNet-V1 (https://arxiv.org/abs/1704.04861)
   * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@@ -187,9 +187,6 @@ To run inference from a checkpoint:
 ## TODO
 A number of additions planned in the future for various projects, incl
-* Find optimal training hyperparams and create/port pretraiend weights for the generic MobileNet variants
 * Do a model performance (speed + accuracy) benchmarking across all models (make runable as script)
-* More training experiments
-* Make folder/file layout compat with usage as a module
 * Add usage examples to comments, good hyper params for training
 * Comments, cleanup and the usual things that get pushed back
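The README bump to B0-B5 reflects the newly validated EfficientNet definitions. As a hedged usage sketch (the model name is an assumption based on the repo's gen_efficientnet naming conventions; pretrained weight availability depends on release):

from timm.models import create_model

# 'efficientnet_b0' is an assumed model name per the repo's conventions
model = create_model('efficientnet_b0', pretrained=False, num_classes=1000)
print(sum(p.numel() for p in model.parameters()))  # param count, as inference.py now logs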

inference.py

Lines changed: 9 additions & 9 deletions
@@ -8,12 +8,13 @@
 import os
 import time
 import argparse
+import logging
 import numpy as np
 import torch
 
 from timm.models import create_model, apply_test_time_pool
 from timm.data import Dataset, create_loader, resolve_data_config
-from timm.utils import AverageMeter
+from timm.utils import AverageMeter, setup_default_logging
 
 torch.backends.cudnn.benchmark = True

@@ -38,8 +39,8 @@
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
-parser.add_argument('--print-freq', '-p', default=10, type=int,
-                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--log-freq', default=10, type=int,
+                    metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
@@ -53,8 +54,8 @@
 def main():
+    setup_default_logging()
     args = parser.parse_args()
-
     # might as well try to do something useful...
     args.pretrained = args.pretrained or not args.checkpoint

@@ -66,8 +67,8 @@ def main():
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint)
 
-    print('Model %s created, param count: %d' %
-          (args.model, sum([m.numel() for m in model.parameters()])))
+    logging.info('Model %s created, param count: %d' %
+                 (args.model, sum([m.numel() for m in model.parameters()])))
 
     config = resolve_data_config(model, args)
     model, test_time_pool = apply_test_time_pool(model, config, args)
@@ -105,9 +106,8 @@ def main():
             batch_time.update(time.time() - end)
             end = time.time()
 
-            if batch_idx % args.print_freq == 0:
-                print('Predict: [{0}/{1}]\t'
-                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
+            if batch_idx % args.log_freq == 0:
+                logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                     batch_idx, len(loader), batch_time=batch_time))
 
     topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
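The setup_default_logging() call added in main() is what makes these INFO records visible. Its implementation lives in timm.utils and is not part of this diff, so the following is only a sketch of what such a helper plausibly does (the format string is an assumption):

import logging

def setup_default_logging(default_level=logging.INFO):
    # Sketch only: attach a console handler to the root logger so that
    # module-level logging.info(...) calls actually emit. The real
    # timm.utils helper may use a different format or handler setup.
    logging.basicConfig(
        format='%(levelname)s: %(message)s',
        level=default_level)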

setup.py

Lines changed: 11 additions & 5 deletions
@@ -19,21 +19,27 @@
     url='https://github.com/rwightman/pytorch-image-models',
     author='Ross Wightman',
     author_email='hello@rwightman.com',
-    classifiers=[  # Optional
+    classifiers=[
         # How mature is this project? Common values are
         #   3 - Alpha
         #   4 - Beta
         #   5 - Production/Stable
         'Development Status :: 3 - Alpha',
-        'Intended Audience :: Developers',
-        'Topic :: Software Development :: Build Tools',
-        'License :: OSI Approved :: Apache License',
+        'Intended Audience :: Education',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: Apache Software License',
         'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Topic :: Scientific/Engineering',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Topic :: Software Development',
+        'Topic :: Software Development :: Libraries',
+        'Topic :: Software Development :: Libraries :: Python Modules',
     ],
 
     # Note that this is a string of words separated by whitespace, not a list.
     keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
     packages=find_packages(exclude=['convert']),
-    install_requires=['torch', 'torchvision'],
+    install_requires=['torch >= 1.0', 'torchvision'],
     python_requires='>=3.6',
 )
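The classifier additions are PyPI metadata only; the functional change is pinning torch >= 1.0. A small sketch of how that specifier is evaluated, using the packaging library that pip and setuptools build on:

from packaging.requirements import Requirement

req = Requirement('torch >= 1.0')
print(req.name)                         # 'torch'
print(req.specifier.contains('1.1.0'))  # True: accepted by the new pin
print(req.specifier.contains('0.4.1'))  # False: pre-1.0 torch is now rejected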

timm/data/config.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import logging
 from .constants import *
@@ -56,9 +57,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
         new_config['crop_pct'] = default_cfg['crop_pct']
 
     if verbose:
-        print('Data processing configuration for current model + dataset:')
+        logging.info('Data processing configuration for current model + dataset:')
         for n, v in new_config.items():
-            print('\t%s: %s' % (n, str(v)))
+            logging.info('\t%s: %s' % (n, str(v)))
 
     return new_config
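One subtlety of the print-to-logging conversion: a bare logging.info() auto-configures the root logger at WARNING level, so these INFO messages stay invisible unless the entry point configures logging first, which is exactly why inference.py now calls setup_default_logging() before anything else. A minimal standard-library demonstration with sample config values:

import logging

logging.basicConfig(level=logging.INFO)  # without this, the INFO records below are dropped
logging.info('Data processing configuration for current model + dataset:')
for n, v in {'input_size': (3, 224, 224), 'crop_pct': 0.875}.items():  # sample values
    logging.info('\t%s: %s' % (n, str(v)))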

timm/models/adaptive_avgmax_pool.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ def __init__(self, output_size=1, pool_type='avg'):
             self.pool = nn.AdaptiveMaxPool2d(output_size)
         else:
             if pool_type != 'avg':
-                print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
+                assert False, 'Invalid pool type: %s' % pool_type
             self.pool = nn.AdaptiveAvgPool2d(output_size)
 
     def forward(self, x):
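Note this hunk changes behavior, not just the message: an unrecognized pool_type used to warn and silently fall back to average pooling, whereas it now fails fast at construction (under python -O asserts are stripped, so the average fallback would still apply). A standalone sketch of the new pattern, with a hypothetical helper name:

import torch.nn as nn

def make_adaptive_pool(pool_type='avg', output_size=1):
    # hypothetical free-function version of the __init__ logic above
    if pool_type == 'max':
        return nn.AdaptiveMaxPool2d(output_size)
    assert pool_type == 'avg', 'Invalid pool type: %s' % pool_type
    return nn.AdaptiveAvgPool2d(output_size)

make_adaptive_pool('avg')  # fine
make_adaptive_pool('gem')  # now raises AssertionError instead of logging a warning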

timm/models/densenet.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
     r"""Densenet-201 model from
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
-    print(num_classes, in_chans, pretrained)
     default_cfg = default_cfgs['densenet161']
     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                      num_classes=num_classes, in_chans=in_chans, **kwargs)

timm/models/gen_efficientnet.py

Lines changed: 6 additions & 7 deletions
@@ -17,6 +17,7 @@
 import math
 import re
+import logging
 from copy import deepcopy
 
 import torch
@@ -336,7 +337,7 @@ def _make_block(self, ba):
         ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
         assert ba['act_fn'] is not None
         if self.verbose:
-            print('args:', ba)
+            logging.info(' Args: {}'.format(str(ba)))
         # could replace this if with lambdas or functools binding if variety increases
         if bt == 'ir':
             ba['drop_connect_rate'] = self.drop_connect_rate
@@ -358,7 +359,7 @@ def _make_stack(self, stack_args):
         # each stack (stage) contains a list of block arguments
         for block_idx, ba in enumerate(stack_args):
             if self.verbose:
-                print('block', block_idx, end=', ')
+                logging.info(' Block: {}'.format(block_idx))
             if block_idx >= 1:
                 # only the first block in any stack/stage can have a stride > 1
                 ba['stride'] = 1
@@ -370,24 +371,22 @@ def __call__(self, in_chs, block_args):
         """ Build the blocks
         Args:
             in_chs: Number of input-channels passed to first block
-            arch_def: A list of lists, outer list defines stacks (or stages), inner
+            block_args: A list of lists, outer list defines stages, inner
                 list contains strings defining block configuration(s)
         Return:
             List of block stacks (each stack wrapped in nn.Sequential)
         """
         if self.verbose:
-            print('Building model trunk with %d stacks (stages)...' % len(block_args))
+            logging.info('Building model trunk with %d stages...' % len(block_args))
         self.in_chs = in_chs
         blocks = []
         # outer list of block_args defines the stacks ('stages' by some conventions)
         for stack_idx, stack in enumerate(block_args):
             if self.verbose:
-                print('stack', stack_idx)
+                logging.info('Stack: {}'.format(stack_idx))
             assert isinstance(stack, list)
             stack = self._make_stack(stack)
             blocks.append(stack)
-            if self.verbose:
-                print()
         return blocks
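With the end=', ' prints and the trailing print() removed, verbose build output becomes one aligned log record per stage/block/args instead of comma-joined fragments. A toy reproduction of the new nesting (builder internals elided; names hypothetical):

import logging

def log_trunk_build(block_args, verbose=True):
    # hypothetical condensation of the __call__/_make_stack/_make_block logging above
    if verbose:
        logging.info('Building model trunk with %d stages...' % len(block_args))
    for stack_idx, stack in enumerate(block_args):
        if verbose:
            logging.info('Stack: {}'.format(stack_idx))
        for block_idx, ba in enumerate(stack):
            if verbose:
                logging.info(' Block: {}'.format(block_idx))
                logging.info(' Args: {}'.format(str(ba)))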

timm/models/helpers.py

Lines changed: 8 additions & 7 deletions
@@ -1,6 +1,7 @@
 import torch
 import torch.utils.model_zoo as model_zoo
 import os
+import logging
 from collections import OrderedDict
@@ -21,9 +22,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
             model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
-        print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
+        logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
     else:
-        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
+        logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -40,27 +41,27 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
-            print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+            logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
             start_epoch = 0 if start_epoch is None else start_epoch
-            print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+            logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
         return optimizer_state, start_epoch
     else:
-        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
+        logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
 
 
 def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
     if 'url' not in default_cfg or not default_cfg['url']:
-        print("Warning: pretrained model URL is invalid, using random initialization.")
+        logging.warning("Pretrained model URL is invalid, using random initialization.")
         return
 
     state_dict = model_zoo.load_url(default_cfg['url'])
 
     if in_chans == 1:
         conv1_name = default_cfg['first_conv']
-        print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
+        logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
         conv1_weight = state_dict[conv1_name + '.weight']
         state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
     elif in_chans != 3:
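Since a missing file now logs through logging.error and still raises FileNotFoundError, callers can branch on the exception rather than scraping stdout. A hedged usage sketch (the checkpoint path is hypothetical):

import logging
from timm.models import create_model
from timm.models.helpers import load_checkpoint

model = create_model('resnet50', pretrained=False)
try:
    load_checkpoint(model, './output/model_best.pth.tar')  # path is hypothetical
except FileNotFoundError:
    logging.warning('No local checkpoint found, continuing with random init')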

timm/models/test_time_pool.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import logging
 from torch import nn
 import torch.nn.functional as F
 from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
@@ -31,8 +32,8 @@ def apply_test_time_pool(model, config, args):
     if not args.no_test_pool and \
             config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
             config['input_size'][-2] > model.default_cfg['input_size'][-2]:
-        print('Target input size %s > pretrained default %s, using test time pooling' %
-              (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
+        logging.info('Target input size %s > pretrained default %s, using test time pooling' %
+                     (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
         model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
         test_time_pool = True
     return model, test_time_pool
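apply_test_time_pool only engages when the requested input size strictly exceeds the pretrained default in both spatial dims. A hedged sketch of the gate (config keys taken from the hunk above; args needs only the no_test_pool flag read here):

from types import SimpleNamespace
from timm.models import create_model, apply_test_time_pool

model = create_model('resnet50', pretrained=False)  # default_cfg input_size is (3, 224, 224)
config = {'input_size': (3, 288, 288)}              # larger than the default in both dims
args = SimpleNamespace(no_test_pool=False)
model, test_time_pool = apply_test_time_pool(model, config, args)
print(test_time_pool)  # True: the model was wrapped in TestTimePoolHead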

timm/scheduler/tanh_lr.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def __init__(self,
         self.t_in_epochs = t_in_epochs
         if self.warmup_t:
             t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
-            print(t_v)
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
             super().update_groups(self.warmup_lr_init)
         else:
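The deleted print(t_v) was debug output of the per-group target LRs used to size the linear warmup increments. For reference, the arithmetic on the surviving line:

# warmup increment per step for each param group:
#   (target_lr - warmup_lr_init) / warmup_t
warmup_lr_init, warmup_t = 1e-5, 3
t_v = [0.1, 0.01]  # example per-group target LRs; real values come from base_values or _get_lr()
warmup_steps = [(v - warmup_lr_init) / warmup_t for v in t_v]
print(warmup_steps)  # approximately [0.03333, 0.00333]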
