Commit 56df163

Author: Mark-ZhouWX
Commit message: support amp O2
1 parent: 9db9993

File tree: 8 files changed (+24 / -17 lines)


research/segment-anything/configs/coco_box_finetune.yaml

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@ device: Ascend
mode: 0 # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
-
+amp_level: O2

# ---------------------------------------------
# Part2: module setting

@@ -15,7 +15,7 @@ loss_manager:
  loss_scaler:
    type: dynamic
  grad_clip: False
-
+  drop_overflow_update: False

optimizer:
  type: segment_anything.optim.optimizer.AdamW

@@ -61,9 +61,9 @@ train_loader:

  shuffle: True
  batch_size: 1
-  epoch_size: 50
+  epoch_size: 10
  drop_remainder: True
-  num_workers: 1
+  num_workers: 2
  max_rowsize: 24 # 24M space for dataloader
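Note: the new drop_overflow_update key pairs with the dynamic loss scaler above. Under amp O2 most of the forward pass runs in float16, so gradients can overflow and the scaler has to react. Below is a minimal sketch of what a dynamic-scaling train step typically looks like in MindSpore; network, loss_fn and optimizer are illustrative stand-ins, not this repo's loss manager.

import mindspore as ms
from mindspore import amp, nn, ops

# Illustrative stand-ins (assumption: not the cells this config actually builds).
network = nn.Dense(8, 2)
loss_fn = nn.MSELoss()
optimizer = nn.AdamWeightDecay(network.trainable_params(), learning_rate=1e-4)

loss_scaler = amp.DynamicLossScaler(scale_value=2**12, scale_factor=2, scale_window=1000)

def forward_fn(data, label):
    loss = loss_fn(network(data), label)
    return loss_scaler.scale(loss)              # scale the loss before backprop

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    grads = loss_scaler.unscale(grads)          # back to the true gradient magnitude
    grads_finite = amp.all_finite(grads)        # any float16 overflow this step?
    loss_scaler.adjust(grads_finite)            # shrink on overflow, grow after a clean scale_window
    # drop_overflow_update: False -> apply the update even when overflow is detected;
    # True would skip optimizer(grads) on overflow instead.
    loss = ops.depend(loss, optimizer(grads))
    return loss_scaler.unscale(loss)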

research/segment-anything/configs/flare_box_finetune.yaml

Lines changed: 4 additions & 2 deletions

@@ -5,7 +5,7 @@ device: Ascend
mode: 0 # 0: graph, 1: pynative
work_root: &work_root ./work_dir/
log_level: info
-
+amp_level: O2

# ---------------------------------------------
# Part2: module setting

@@ -15,6 +15,7 @@ loss_manager:
  loss_scaler:
    type: dynamic
  grad_clip: False
+  drop_overflow_update: False


optimizer:

@@ -58,7 +59,7 @@ train_loader:

  shuffle: True
  batch_size: 1
-  epoch_size: 50
+  epoch_size: 20
  drop_remainder: True
  num_workers: 2
  max_rowsize: 64 # 24M space for dataloader

@@ -94,6 +95,7 @@ eval_metric: &eval_metric
callback:
  - type: segment_anything.utils.callbacks.TrainStatusLog
    loss_item: ['focal_loss', 'dice_loss', 'mse_loss'] # for log
+    interval: 20
  - type: segment_anything.utils.callbacks.SaveCkpt
    work_root: *work_root
    interval: 1 # in epoch
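Note: the added interval: 20 throttles TrainStatusLog to one log line every 20 steps. The real callback lives in segment_anything.utils.callbacks; the stand-in below is only a hedged sketch of how such an interval field is typically consumed in a MindSpore callback.

import mindspore as ms

class TrainStatusLogSketch(ms.train.Callback):
    """Hypothetical stand-in for segment_anything.utils.callbacks.TrainStatusLog."""

    def __init__(self, loss_item, interval=1):
        self.loss_item = loss_item    # loss term names used for labelling the log
        self.interval = interval      # log once every `interval` training steps

    def on_train_step_end(self, run_context):
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.interval == 0:
            print(f'step {cb_params.cur_step_num}, loss: {cb_params.net_outputs}')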

research/segment-anything/eval.py

Lines changed: 1 addition & 2 deletions

@@ -16,9 +16,8 @@ def main(args) -> None:
    ms.context.set_context(mode=args.mode, device_target=args.device, pynative_synchronize=False)
    ms.set_seed(42)

-    rank_id, rank_size = set_distributed(args.distributed)
+    rank_id, rank_size, main_device = set_distributed(args.distributed)
    update_rank_to_dataloader_config(rank_id, rank_size, args.train_loader, args.eval_loader)
-    main_device = rank_id == 0

    set_directory_and_log(main_device, rank_id, rank_size, args.work_root, args.log_level)
    logger.info(args.pretty())

research/segment-anything/segment_anything/modeling/image_encoder.py

Lines changed: 3 additions & 3 deletions

@@ -359,11 +359,11 @@ def add_decomposed_rel_pos(

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
-
+    dtype = r_q.dtype
    # rel_h = ops.einsum("bhwc,hkc->bhwk", r_q, Rh)
-    rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.broadcast_to(ops.unsqueeze(Rh, 0), (B, -1, -1, -1)))
+    rel_h = ops.BatchMatMul(transpose_b=True)(r_q, ops.broadcast_to(ops.unsqueeze(Rh, 0).astype(dtype), (B, -1, -1, -1)))
    # rel_w = ops.einsum("bhwc,wkc->bhwk", r_q, Rw)
-    rel_w = ops.mul(ops.unsqueeze(r_q, -2), ops.unsqueeze(ops.unsqueeze(Rw, 0), 0)).sum(axis=-1)
+    rel_w = ops.mul(ops.unsqueeze(r_q, -2), ops.unsqueeze(ops.unsqueeze(Rw, 0), 0).astype(dtype)).sum(axis=-1)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
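Note: these casts are the heart of the O2 support. With amp_level: O2 the attention activations are float16, while the relative-position tables Rh/Rw may still be float32, and BatchMatMul/mul do not accept mixed dtypes in graph mode on Ascend. A toy reproduction of the pattern (shapes are illustrative, not the model's):

import numpy as np
import mindspore as ms
from mindspore import ops

q = ms.Tensor(np.ones((1, 4, 4, 8)), ms.float16)       # float16 activation under O2
rel_table = ms.Tensor(np.ones((4, 4, 8)), ms.float32)  # float32 buffer created at init time

bmm = ops.BatchMatMul(transpose_b=True)
# Mixing float16 and float32 operands here raises a dtype error in graph mode,
# hence the table is cast to the query dtype first:
rel = bmm(q, ops.broadcast_to(ops.unsqueeze(rel_table, 0).astype(q.dtype), (1, -1, -1, -1)))
print(rel.shape, rel.dtype)  # (1, 4, 4, 4), float16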

research/segment-anything/segment_anything/modeling/prompt_encoder.py

Lines changed: 2 additions & 1 deletion

@@ -196,7 +196,8 @@ def _pe_encoding(self, coords: ms.Tensor) -> ms.Tensor:
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        # aa = coords @ self.positional_encoding_gaussian_matrix
-        coords = ops.matmul(coords, self.positional_encoding_gaussian_matrix)
+        dtype = coords.dtype
+        coords = ops.matmul(coords, self.positional_encoding_gaussian_matrix.astype(dtype))
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return ops.cat([ops.sin(coords), ops.cos(coords)], axis=-1)
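Note: _pe_encoding is a random Fourier-feature positional encoding: map coords from [0, 1] to [-1, 1], project through a fixed Gaussian matrix, then take sin/cos. The new cast keeps that projection in the (possibly float16) dtype of coords. A NumPy sketch of the math, assuming the usual (2, num_feats) matrix shape:

import numpy as np

def pe_encoding_sketch(coords, gaussian_matrix):
    # coords: (..., 2) in [0, 1]; gaussian_matrix: (2, num_feats), fixed at init
    coords = 2 * coords - 1                                  # map to [-1, 1]
    coords = coords @ gaussian_matrix.astype(coords.dtype)   # same cast as the fix above
    coords = 2 * np.pi * coords
    return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)  # (..., 2 * num_feats)

out = pe_encoding_sketch(np.random.rand(16, 2).astype(np.float16),
                         np.random.randn(2, 64).astype(np.float16))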

research/segment-anything/segment_anything/modeling/transformer.py

Lines changed: 2 additions & 1 deletion

@@ -233,7 +233,8 @@ def construct(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        attn = ops.softmax(attn, axis=-1)

        # Get output
-        out = attn @ v
+        dtype = attn.dtype
+        out = attn @ v.astype(dtype)
        out = self._recombine_heads(out)
        out = self.out_proj(out)

research/segment-anything/segment_anything/utils/utils.py

Lines changed: 4 additions & 1 deletion

@@ -83,12 +83,15 @@ def set_distributed(distributed):
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=rank_size, gradients_mean=True,
                                          parallel_mode=ParallelMode.DATA_PARALLEL)
+    main_device = rank_id == 0

    # This is the only place where global rank_id and rank_size can be modified
    global RANK_ID, RANK_SIZE
    RANK_ID, RANK_SIZE = rank_id, rank_size

-    return rank_id, rank_size
+    print(f'rank {rank_id}/{rank_size}, main_device: {main_device}')
+
+    return rank_id, rank_size, main_device


def update_rank_to_dataloader_config(rank_id, rank_size, args_train_loader, args_eval_loader, arg_callback=None):
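Note: set_distributed now derives main_device itself and returns a triple, so train.py and eval.py stop re-computing rank_id == 0 separately. The hunk only shows the distributed branch; below is a hedged sketch of the whole function, where the non-distributed fallback values are assumptions:

from mindspore import context
from mindspore.communication import init, get_rank, get_group_size
from mindspore.context import ParallelMode

def set_distributed_sketch(distributed):
    if distributed:
        init()
        rank_id, rank_size = get_rank(), get_group_size()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=rank_size, gradients_mean=True,
                                          parallel_mode=ParallelMode.DATA_PARALLEL)
    else:
        rank_id, rank_size = 0, 1  # single-device fallback (assumed)
    main_device = rank_id == 0
    print(f'rank {rank_id}/{rank_size}, main_device: {main_device}')
    return rank_id, rank_size, main_device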

research/segment-anything/train.py

Lines changed: 4 additions & 3 deletions

@@ -1,6 +1,7 @@
import argparse

import mindspore as ms
+from mindspore import amp

from segment_anything.build_sam import create_model
from segment_anything.dataset.dataset import create_dataloader

@@ -19,20 +20,20 @@ def main(args) -> None:
    ms.context.set_context(mode=args.mode, device_target=args.device, pynative_synchronize=False)
    ms.set_seed(42)

-    rank_id, rank_size = set_distributed(args.distributed)
+    rank_id, rank_size, main_device = set_distributed(args.distributed)
    update_rank_to_dataloader_config(rank_id, rank_size, args.train_loader, args.eval_loader, args.callback)
-    main_device = rank_id == 0

    set_directory_and_log(main_device, rank_id, rank_size, args.work_root, args.log_level, args.callback)
    logger.info(args.pretty())

    # Step2: create dataset
    train_dataloader = create_dataloader(args.train_loader)

-    # create model, also freeze layer if specified
+    # create model, load pretrained ckpt, set amp level, also freeze layer if specified
    network = create_model(args.network.model)
    loss_fn = create_loss_fn(args.network.loss)
    network.set_train()
+    network = amp.auto_mixed_precision(network, args.get('amp_level', 'O0'))

    # Step3: create optimizer, including learning rate scheduler and group parameter settings
    optimizer = create_optimizer(params=network.trainable_params(), args=args.optimizer)
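Note: amp.auto_mixed_precision is the one-line switch this commit adds. With amp_level: O2 from the configs, most ops of the wrapped cell run in float16 while numerically sensitive ones such as BatchNorm stay in float32, and args.get('amp_level', 'O0') keeps older configs without the new key running unchanged. A minimal usage sketch:

import numpy as np
import mindspore as ms
from mindspore import amp, nn

net = nn.Dense(8, 2)
net = amp.auto_mixed_precision(net, 'O2')  # same call train.py now makes
x = ms.Tensor(np.ones((4, 8)), ms.float32)
y = net(x)                                 # the dense matmul executes in float16 under O2
print(y.dtype)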
