Skip to content

Commit 2d41342

Browse files
author
Mark-ZhouWX
committed
update cloud training dataset config
1 parent 8099334 commit 2d41342

File tree

1 file changed

+10
-6
lines changed
  • official/cv/segment-anything/segment_anything/utils

1 file changed

+10
-6
lines changed

official/cv/segment-anything/segment_anything/utils/config.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ def parse_args(parser_config):
3838
# output
3939
args.work_root = args_cmd.train_url
4040
# input
41-
args.train_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'train2017')
42-
args.train_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations", "instances_train2017.json")
43-
44-
args.eval_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'val2017')
45-
args.eval_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations",
46-
"instances_val2017.json")
41+
if args.train_loader.dataset.type.endswith('COCODataset'):
42+
args.train_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'train2017')
43+
args.train_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations", "instances_train2017.json")
44+
45+
args.eval_loader.dataset.data_dir = os.path.join(args_cmd.data_url, 'val2017')
46+
args.eval_loader.dataset.annotation_path = os.path.join(args_cmd.data_url, "annotations", "instances_val2017.json")
47+
elif args.train_loader.dataset.type.endswith('SA1BDataset'):
48+
args.train_loader.dataset.data_dir = args_cmd.data_url
49+
else:
50+
raise NotImplementedError(f'Not supported dataset {args.train_loader.dataset.type}')
4751

4852
return args
4953

0 commit comments

Comments
 (0)