Skip to content

Commit 357b71b

Browse files
author
Mark-ZhouWX
committed
add model arts config
1 parent b8e855b commit 357b71b

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

official/cv/segment-anything/configs/coco_box_finetune.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ device: Ascend
55
mode: 0 # 0: graph, 1: pynative
66
work_root: &work_root ./work_dir/
77
log_level: info
8-
amp_level: O2
8+
amp_level: O0
99

1010
# ---------------------------------------------
1111
# Part2: module setting

official/cv/segment-anything/segment_anything/utils/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ def parse_args(parser_config):
3232
if args_cmd.override_cfg is not None:
3333
args.merge_with_dotlist(args_cmd.override_cfg)
3434

35+
# model arts
36+
if 'enable_modelarts' in args_cmd and args_cmd.enable_modelarts:
37+
print(f'model arts enabled')
38+
# output
39+
args.work_root = args_cmd.train_url
40+
# input
41+
args.train_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'train2017')
42+
args.train_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations", "instances_train2017.json")
43+
44+
args.eval_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'val2017')
45+
args.eval_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations",
46+
"instances_val2017.json")
47+
3548
return args
3649

3750

official/cv/segment-anything/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import ast
23

34
import mindspore as ms
45
from mindspore import amp
@@ -65,5 +66,11 @@ def main(args) -> None:
6566
"For dict, use key=value format, eg: device=False. "
6667
"For nested dict, use '.' to denote hierarchy, eg: optimizer.weight_decay=1e-3."
6768
"For list, use number to denote position, eg: callback.1.interval=100.")
69+
70+
# model arts
71+
parser_config.add_argument("--enable_modelarts", type=ast.literal_eval, default=False)
72+
parser_config.add_argument("--train_url", type=str, default="", help="obs path to output folder")
73+
parser_config.add_argument("--data_url", type=str, default="", help="obs path to dataset folder")
74+
6875
args = parse_args(parser_config)
6976
main(args)

0 commit comments

Comments
 (0)