@@ -1065,7 +1065,7 @@ def __init__(
10651065 minmax_lr : float = None ,
10661066 disable_quanted_input : bool = False ,
10671067 nsamples : int = 512 ,
1068- iters : int = 200 ,
1068+ iters : int = None ,
10691069 use_ggml : bool = False ,
10701070 use_neural_speed : bool = False ,
10711071 llm_int8_skip_modules = None ,
@@ -1091,7 +1091,6 @@ def __init__(
10911091 self .lr = lr
10921092 self .minmax_lr = minmax_lr
10931093 self .disable_quanted_input = disable_quanted_input
1094- self .iters = iters
10951094 self .llm_int8_skip_modules = (
10961095 llm_int8_skip_modules if llm_int8_skip_modules else []
10971096 )
@@ -1101,7 +1100,14 @@ def __init__(
11011100 self .calib_dataloader = kwargs .get ("calib_dataloader" , None )
11021101 self .calib_len = kwargs .get ("calib_len" , 2048 )
11031102 self .calib_func = kwargs .get ("calib_func" , None )
1104- self .calib_iters = kwargs .get ("calib_iters" , 100 )
1103+ calib_iters = kwargs .get ("calib_iters" , None )
1104+ if iters is not None :
1105+ self .calib_iters = iters
1106+ if calib_iters is not None :
1107+ logger .info ("cannot be set simultaneously for 'iters' and 'calib_iters', "
1108+ "we will use 'iters' as calibration iterations!" )
1109+ else :
1110+ self .calib_iters = 200 if calib_iters is None else calib_iters
11051111 self .scheme = "sym" if self .sym else "asym"
11061112 if isinstance (compute_dtype , torch .dtype ):
11071113 self .compute_dtype = convert_dtype_torch2str (compute_dtype )
0 commit comments