File tree Expand file tree Collapse file tree 2 files changed +2
-272
lines changed
aiu_fms_testing_utils/utils Expand file tree Collapse file tree 2 files changed +2
-272
lines changed Load Diff This file was deleted.
Original file line number Diff line number Diff line change 1010# Local Packages
1111from aiu_fms_testing_utils .utils .aiu_setup import dprint , rank , world_size
1212from aiu_fms_testing_utils .utils .args_parsing import get_args
13+ from aiu_fms_testing_utils .utils .direct_quantization import run_dq_roberta
1314from aiu_fms_testing_utils .utils .encoders_utils import (
1415 wrap_encoder ,
1516 run_encoder_eval_qa ,
3738# Main model setup
3839default_dtype , device , dist_strat = setup_model (args )
3940
40- model_path = args .model_path
41- if args .int8_direct_quantization :
42- save_path = None
43-
44- # !!! insert DQ for encoders here
45- # pass default_dtype to DQ function
46-
47- # if DQ is used, args.model_path represent FP16 ckpt but we need to load the
48- # newly-created INT8 ckpt. Without DQ, args.model_path is the INT8 ckpt already.
49- model_path = save_path
50-
5141# Retrieve linear configuration (quantized or not) to instantiate FMS model
5242linear_config = get_linear_config (args )
5343
6454model = get_model (
6555 args .architecture ,
6656 args .variant ,
67- model_path = model_path ,
57+ model_path = args . model_path ,
6858 device_type = "cpu" if args .is_aiu_backend else args .device_type ,
6959 data_type = default_dtype ,
7060 source = args .model_source ,
You can’t perform that action at this time.
0 commit comments