Skip to content
This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Commit 388ed20

Browse files
committed
Fix: Minor typo, docstring, and format across the project; compatibility bug related to jpeg compression and RGBA image.
1 parent ab52d16 commit 388ed20

File tree

15 files changed

+195
-119
lines changed

15 files changed

+195
-119
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,4 @@ This repository is organized as follows:
8282
└── unit_tests.py # Main entry for unit tests, regarding `./models/` and `./utils/`.
8383
```
8484

85-
To develop a new project (or a new approach), a `runner` (which defines the training procedure, e.g., the data forwarding flow, the optimizing order of various loss terms, how the model(s) should be updated, etc.), a `loss` (which defines the computation of each loss term), and a `config` (which collects all configurations used in the project) are necessary. You may also need to design your own `dataset` (including data `transformation`), `model` structure, evaluation `metric`, `augmentation` pipeline, and running `controller` if they are not supported yet. **NOTE:** All these modules are almost independent of each other. Hence, once a new feature (e.g., `dataset`, `model`, `metric`, `augmentation`, or `controller`) is developed, it can be shared to others with minor effort. It you are interested in sharing your work, we really appreciate your contribution to these modules.
85+
To develop a new project (or a new approach), a `runner` (which defines the training procedure, e.g., the data forwarding flow, the optimizing order of various loss terms, how the model(s) should be updated, etc.), a `loss` (which defines the computation of each loss term), and a `config` (which collects all configurations used in the project) are necessary. You may also need to design your own `dataset` (including data `transformation`), `model` structure, evaluation `metric`, `augmentation` pipeline, and running `controller` if they are not supported yet. **NOTE:** All these modules are almost independent from each other. Hence, once a new feature (e.g., `dataset`, `model`, `metric`, `augmentation`, or `controller`) is developed, it can be shared to others with minor effort. It you are interested in sharing your work, we really appreciate your contribution to these modules.

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ Please find more training demos under `./scripts/training_demos/`.
6868

6969
## Inspect Training Results
7070

71-
Besides using TensorBoard to track the training process, the raw results (e.g., training losses and running time) are saved in JSON format. They can be easily inspected with the following script
71+
Besides using TensorBoard to track the training process, the raw results (e.g., training losses and running time) are saved in [JSON Lines](https://jsonlines.org/) format. They can be easily inspected with the following script
7272

7373
```python
7474
import json
7575
76-
file_name = '<PATH_TO_WORK_DIR>/log.json'
76+
file_name = '<PATH_TO_WORK_DIR>/log.jsonl'
7777
7878
data_entries = []
7979
with open(file_name, 'r') as f:

configs/base_config.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
├── ${PROFILE_DIR}/
1818
├── ${RESOURCE_DIR}/
1919
├── ${CONFIG_FILENAME} # in JSON format
20-
├── ${LOG_DATA_FILENAME} # in JSON format
20+
├── ${LOG_DATA_FILENAME} # in JSON Lines format
2121
└── ${LOG_FILENAME} # in plain text
2222
"""
2323

@@ -36,6 +36,15 @@
3636

3737
__all__ = ['BaseConfig']
3838

39+
_PARAM_TYPE_TO_VALUE_TYPE = {
40+
'IntegerParamType': 'int',
41+
'FloatParamType': 'float',
42+
'BooleanParamType': 'bool',
43+
'StringParamType': 'str',
44+
'IndexParamType': 'index-string',
45+
'JsonParamType': 'json-string'
46+
}
47+
3948

4049
class BaseConfig(object):
4150
"""Defines the base configuration class.
@@ -75,27 +84,37 @@ class BaseConfig(object):
7584
The base class provides the following functions to parse configuration from
7685
command line:
7786
87+
- Functions requiring implementation in derived class:
88+
7889
(1) get_options(): Declare all options required by a particular task. The
7990
base class has already pre-declared some options that will be shared
8091
across tasks (e.g., data-related options). To declare more options,
8192
the derived class should override this function by first calling
82-
`options = super().get_options()`. (requires implementation)
93+
`options = super().get_options()`.
94+
(2) parse_options(): Parse the options obtained from the command line (as
95+
well as those options with default values) to `self.config`. This is the
96+
core function of the configuration class, which converts `options` to
97+
`configurations`.
98+
(3) get_recommended_options(): Get a list of options that are recommended
99+
for a particular task. The base class has already pre-declared some
100+
recommended options that will be shared across tasks. To recommend more
101+
options, the derived class should override this function by first
102+
calling `recommended_opts = super().get_recommended_options()`.
103+
104+
- Helper functions shared by all derived classes:
105+
106+
(1) inspect_option(): Inspect argument from a particular `click.option`,
107+
including the argument name, argument type, default value, and help
108+
message.
83109
(2) add_options_to_command(): Add all options for a particular task to the
84110
corresponding command. This function is specially designed to show
85-
user-friendly help message. (no need to care about by derived class)
111+
user-friendly help message.
86112
(3) get_command(): Return a `click.command` to get interactive with users.
87113
This function makes it possible to pass options through command line.
88-
(no need to care about by derived class)
89-
(4) parse_options(): Parse the options obtained from the command line (as
90-
well as those options with default values) to `self.config`. This is the
91-
core function of the configuration class, which converts `options` to
92-
`configurations`. (requires implementation)
93-
(5) update_config(): Update the configuration parsed from options with
114+
(4) update_config(): Update the configuration parsed from options with
94115
key-value pairs. This function makes option parsing more flexible.
95-
(no need to care about by derived class)
96-
(6) get_config(): The main function to get the parsed configuration, which
97-
wraps functions `parse_options()` and `tune_config()`. (no need to care
98-
about by derived class)
116+
(5) get_config(): The main function to get the parsed configuration, which
117+
wraps functions `parse_options()` and `update_config()`.
99118
100119
In summary, to define a configuration class for a new task, the derived
101120
class only need to implement `get_options()` to declare changeable settings
@@ -114,6 +133,30 @@ class only need to implement `get_options()` to declare changeable settings
114133
json_type = JsonParamType()
115134
command_option = cloup.option
116135

136+
@staticmethod
137+
def inspect_option(option):
138+
"""Inspects argument from a particular option.
139+
140+
Args:
141+
option: The input `click.option` to inspect.
142+
143+
Returns:
144+
An `EasyDict` indicating the `name`, `type`, `default` (default
145+
value), and `help` (help message) of the argument.
146+
"""
147+
148+
@option
149+
def func():
150+
"""A dummy function used to parse decorator."""
151+
152+
arg = func.__click_params__[0]
153+
return EasyDict(
154+
name=arg.name,
155+
type=_PARAM_TYPE_TO_VALUE_TYPE[arg.type.__class__.__name__],
156+
default=arg.default,
157+
help=arg.help
158+
)
159+
117160
@staticmethod
118161
def add_options_to_command(options):
119162
"""Adds task options to a command.
@@ -190,7 +233,7 @@ def get_options(cls):
190233
'--config_path', type=str, default='config.json',
191234
help='To which to save full configuration.'),
192235
cls.command_option(
193-
'--log_data_path', type=str, default='log.json',
236+
'--log_data_path', type=str, default='log.jsonl',
194237
help='To which to save raw log data, e.g. losses.'),
195238
cls.command_option(
196239
'--log_path', type=str, default='log.txt',

datasets/transformations/jpeg_compress.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
fn = None
1212

1313
from utils.formatting_utils import format_range
14-
from utils.formatting_utils import format_image
1514
from .base_transformation import BaseTransformation
1615

1716
__all__ = ['JpegCompress']
@@ -51,8 +50,9 @@ def _CPU_forward(self, data):
5150
outputs = []
5251
for image in data:
5352
_, encoded_image = cv2.imencode('.jpg', image, encode_param)
54-
decoded_image = format_image(
55-
cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED))
53+
decoded_image = cv2.imdecode(encoded_image, cv2.IMREAD_UNCHANGED)
54+
if decoded_image.ndim == 2:
55+
decoded_image = decoded_image[:, :, np.newaxis]
5656
outputs.append(decoded_image)
5757
return outputs
5858

dump_command_args.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
"type": "object",
1515
"properties": {
1616
"arg_1": {
17+
"is_recommended": # true / false
1718
"type": # int / float / bool / str / json-string /
1819
# index-string
19-
"is_recommended": # true / false
2020
"default":
2121
"description":
2222
},
2323
"arg_2": {
24-
"type":
2524
"is_recommended":
25+
"type":
2626
"default":
2727
"description":
2828
}
@@ -32,14 +32,14 @@
3232
"type": "object",
3333
"properties": {
3434
"arg_3": {
35-
"type":
3635
"is_recommended":
36+
"type":
3737
"default":
3838
"description":
3939
},
4040
"arg_4": {
41-
"type":
4241
"is_recommended":
42+
"type":
4343
"default":
4444
"description":
4545
}
@@ -54,8 +54,8 @@
5454
"type": "object",
5555
"properties": {
5656
"arg_1": {
57-
"type":
5857
"is_recommended":
58+
"type":
5959
"default":
6060
"description":
6161
}
@@ -71,15 +71,6 @@
7171

7272
from configs import CONFIG_POOL
7373

74-
PARAM_TYPE_TO_VALUE_TYPE = {
75-
'IntegerParamType': 'int',
76-
'FloatParamType': 'float',
77-
'BooleanParamType': 'bool',
78-
'StringParamType': 'str',
79-
'IndexParamType': 'index-string',
80-
'JsonParamType': 'json-string'
81-
}
82-
8374

8475
def parse_args_from_config(config):
8576
"""Parses available arguments from a configuration class.
@@ -89,21 +80,21 @@ def parse_args_from_config(config):
8980
defined in `configs/`. This class is supposed to derive from
9081
`BaseConfig` defined in `configs/base_config.py`.
9182
"""
92-
def _dummy_func():
93-
"""A dummy function used to parse decorator."""
94-
95-
args = dict()
96-
func = config.add_options_to_command(config.get_options())(_dummy_func)
9783
recommended_opts = config.get_recommended_options()
98-
for opt in reversed(func.__click_params__):
99-
if opt.group.title not in args:
100-
args[opt.group.title] = dict(type='object', properties=dict())
101-
args[opt.group.title]['properties'][opt.name] = dict(
102-
type=PARAM_TYPE_TO_VALUE_TYPE[opt.type.__class__.__name__],
103-
default=opt.default,
104-
is_recommended=opt.name in recommended_opts,
105-
description=opt.help
84+
args = dict()
85+
for opt_group, opts in config.get_options().items():
86+
args[opt_group] = dict(
87+
type='object',
88+
properties=dict()
10689
)
90+
for opt in opts:
91+
arg = config.inspect_option(opt)
92+
args[opt_group]['properties'][arg.name] = dict(
93+
is_recommended=arg.name in recommended_opts,
94+
type=arg.type,
95+
default=arg.default,
96+
description=arg.help
97+
)
10798
return args
10899

109100

metrics/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,9 @@ def compute_gan_precision_recall(fake_features,
216216
top_k=3):
217217
"""Computes precision and recall for GAN evaluation.
218218
219-
KID metric is introduced in https://arxiv.org/pdf/1904.06991.pdf, with
220-
official code
219+
GAN precision and recall are introduced in
220+
221+
https://arxiv.org/pdf/1904.06991.pdf, with official code
221222
222223
https://github.com/kynkaat/improved-precision-and-recall-metric.
223224

models/stylegan2_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,12 @@ def forward(self,
301301
For layers in range [0, truncation_layers), the truncated w-code is
302302
computed as
303303
304-
w_new = w_avg + (w - w_avg) * truncation_psi
304+
w_new = w_avg + (w - w_avg) * trunc_psi
305305
306306
To disable truncation, please set
307307
308-
(1) truncation_psi = 1.0 (None) OR
309-
(2) truncation_layers = 0 (None)
308+
(1) trunc_psi = 1.0 (None) OR
309+
(2) trunc_layers = 0 (None)
310310
"""
311311

312312
mapping_results = self.mapping(z, label, impl=impl)

models/stylegan3_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,12 @@ def forward(self,
277277
For layers in range [0, truncation_layers), the truncated w-code is
278278
computed as
279279
280-
w_new = w_avg + (w - w_avg) * truncation_psi
280+
w_new = w_avg + (w - w_avg) * trunc_psi
281281
282282
To disable truncation, please set
283283
284-
(1) truncation_psi = 1.0 (None) OR
285-
(2) truncation_layers = 0 (None)
284+
(1) trunc_psi = 1.0 (None) OR
285+
(2) trunc_layers = 0 (None)
286286
"""
287287

288288
mapping_results = self.mapping(z, label, impl=impl)

runners/base_runner.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -507,25 +507,30 @@ def log_model_info(self):
507507
for param_name, param in model.named_parameters():
508508
param_shapes[param_name] = f'{list(param.shape)}'
509509
param_numels[param_name] = param.numel()
510-
param_name_max_len = max(map(len, param_shapes.keys()))
511-
param_name_max_len = max(param_name_max_len, len(name_header))
512-
param_shape_max_len = max(map(len, param_shapes.values()))
513-
param_shape_max_len = max(param_shape_max_len, len(shape_header))
514-
param_numel_max_len = int(np.ceil(
515-
np.log10(max(param_numels.values()))
516-
))
517-
param_numel_max_len = max(param_numel_max_len, len(numel_header))
518-
model_info += f'{name_header:<{param_name_max_len + 2}}'
519-
model_info += f'{shape_header:<{param_shape_max_len + 2}}'
520-
model_info += f'{numel_header:>{param_numel_max_len + 2}}\n'
521-
model_info += f'{name_separator:<{param_name_max_len + 2}}'
522-
model_info += f'{shape_separator:<{param_shape_max_len + 2}}'
523-
model_info += f'{numel_separator:>{param_numel_max_len + 2}}\n'
524-
for param_name, param_shape in param_shapes.items():
525-
param_numel = param_numels[param_name]
526-
model_info += f'{param_name:<{param_name_max_len + 2}}'
527-
model_info += f'{param_shape:<{param_shape_max_len + 2}}'
528-
model_info += f'{param_numel:{param_numel_max_len + 2}d}\n'
510+
if len(param_shapes) == 0: # no parameters
511+
model_info += 'The model contains no parameter.\n'
512+
else:
513+
param_name_max_len = max(map(len, param_shapes.keys()))
514+
param_name_max_len = max(param_name_max_len, len(name_header))
515+
param_shape_max_len = max(map(len, param_shapes.values()))
516+
param_shape_max_len = max(param_shape_max_len,
517+
len(shape_header))
518+
param_numel_max_len = int(np.ceil(
519+
np.log10(max(param_numels.values()))
520+
))
521+
param_numel_max_len = max(param_numel_max_len,
522+
len(numel_header))
523+
model_info += f'{name_header:<{param_name_max_len + 2}}'
524+
model_info += f'{shape_header:<{param_shape_max_len + 2}}'
525+
model_info += f'{numel_header:>{param_numel_max_len + 2}}\n'
526+
model_info += f'{name_separator:<{param_name_max_len + 2}}'
527+
model_info += f'{shape_separator:<{param_shape_max_len + 2}}'
528+
model_info += f'{numel_separator:>{param_numel_max_len + 2}}\n'
529+
for param_name, param_shape in param_shapes.items():
530+
param_numel = param_numels[param_name]
531+
model_info += f'{param_name:<{param_name_max_len + 2}}'
532+
model_info += f'{param_shape:<{param_shape_max_len + 2}}'
533+
model_info += f'{param_numel:{param_numel_max_len + 2}d}\n'
529534
model_info += ('-' * 50 + '\n')
530535

531536
model_info += 'Buffers:\n\n'
@@ -534,25 +539,31 @@ def log_model_info(self):
534539
for buffer_name, buffer in model.named_buffers():
535540
buffer_shapes[buffer_name] = f'{list(buffer.shape)}'
536541
buffer_numels[buffer_name] = buffer.numel()
537-
buffer_name_max_len = max(map(len, buffer_shapes.keys()))
538-
buffer_name_max_len = max(buffer_name_max_len, len(name_header))
539-
buffer_shape_max_len = max(map(len, buffer_shapes.values()))
540-
buffer_shape_max_len = max(buffer_shape_max_len, len(shape_header))
541-
buffer_numel_max_len = int(np.ceil(
542-
np.log10(max(buffer_numels.values()))
543-
))
544-
buffer_numel_max_len = max(buffer_numel_max_len, len(numel_header))
545-
model_info += f'{name_header:<{buffer_name_max_len + 2}}'
546-
model_info += f'{shape_header:<{buffer_shape_max_len + 2}}'
547-
model_info += f'{numel_header:>{buffer_numel_max_len + 2}}\n'
548-
model_info += f'{name_separator:<{buffer_name_max_len + 2}}'
549-
model_info += f'{shape_separator:<{buffer_shape_max_len + 2}}'
550-
model_info += f'{numel_separator:>{buffer_numel_max_len + 2}}\n'
551-
for buffer_name, buffer_shape in buffer_shapes.items():
552-
buffer_numel = buffer_numels[buffer_name]
553-
model_info += f'{buffer_name:<{buffer_name_max_len + 2}}'
554-
model_info += f'{buffer_shape:<{buffer_shape_max_len + 2}}'
555-
model_info += f'{buffer_numel:{buffer_numel_max_len + 2}d}\n'
542+
if len(buffer_shapes) == 0: # no buffers
543+
model_info += 'The model contains no buffer.\n'
544+
else:
545+
buffer_name_max_len = max(map(len, buffer_shapes.keys()))
546+
buffer_name_max_len = max(buffer_name_max_len, len(name_header))
547+
buffer_shape_max_len = max(map(len, buffer_shapes.values()))
548+
buffer_shape_max_len = max(buffer_shape_max_len,
549+
len(shape_header))
550+
buffer_numel_max_len = int(np.ceil(
551+
np.log10(max(buffer_numels.values()))
552+
))
553+
buffer_numel_max_len = max(buffer_numel_max_len,
554+
len(numel_header))
555+
model_info += f'{name_header:<{buffer_name_max_len + 2}}'
556+
model_info += f'{shape_header:<{buffer_shape_max_len + 2}}'
557+
model_info += f'{numel_header:>{buffer_numel_max_len + 2}}\n'
558+
model_info += f'{name_separator:<{buffer_name_max_len + 2}}'
559+
model_info += f'{shape_separator:<{buffer_shape_max_len + 2}}'
560+
model_info += f'{numel_separator:>{buffer_numel_max_len + 2}}\n'
561+
for buffer_name, buffer_shape in buffer_shapes.items():
562+
buffer_numel = buffer_numels[buffer_name]
563+
model_info += f'{buffer_name:<{buffer_name_max_len + 2}}'
564+
model_info += f'{buffer_shape:<{buffer_shape_max_len + 2}}'
565+
model_info += f'{buffer_numel:{buffer_numel_max_len + 2}d}'
566+
model_info += '\n'
556567
model_info += ('-' * 50 + '\n')
557568

558569
model_info += 'Size (using `Float32` for size computation):\n\n'

0 commit comments

Comments
 (0)