Skip to content

Commit 679fd67

Browse files
authored
Bugfix crnn (#689)
* Bugfix for crnn when amp_level is O0 * Bugfix for CRNN
1 parent 1eeba48 commit 679fd67

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

mindocr/models/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from ._registry import is_model, list_models, model_entrypoint
99
from .base_model import BaseModel
10-
from .utils import load_model
10+
from .utils import load_model, set_amp_attr
1111

1212
__all__ = ["build_model"]
1313

@@ -74,5 +74,6 @@ def build_model(name_or_config: Union[str, dict], **kwargs):
7474

7575
if "amp_level" in kwargs:
7676
auto_mixed_precision(network, amp_level=kwargs["amp_level"])
77+
set_amp_attr(network, kwargs["amp_level"])
7778

7879
return network

mindocr/models/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .attention_cells import *
2-
from .load_model import load_model
2+
from .load_model import load_model, set_amp_attr
33
from .rnn_cells import GRUCell

mindocr/models/utils/load_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import os
33
from typing import Callable, Dict, Optional
44

5-
from mindspore import load_checkpoint, load_param_into_net
5+
from mindspore import load_checkpoint, load_param_into_net, nn
66

77
from ..backbones.mindcv_models.utils import auto_map, download_pretrained
88

9-
__all__ = ["load_model", "drop_inconsistent_shape_parameters"]
9+
__all__ = ["load_model", "drop_inconsistent_shape_parameters", "set_amp_attr"]
1010
_logger = logging.getLogger(__name__)
1111

1212

@@ -78,3 +78,9 @@ def load_model(
7878
f"Finish loading model checkoint from {load_from}. "
7979
"If no parameter fail-load warning displayed, all checkpoint params have been successfully loaded."
8080
)
81+
82+
83+
def set_amp_attr(network : nn.Cell, amp_level : str):
84+
cells = network.name_cells()
85+
for name in cells:
86+
setattr(network._cells[name], "_amp_level", amp_level)

0 commit comments

Comments
 (0)