@@ -81,33 +81,17 @@ def call_observer(
8181 base_name is "weight", then the module's weight tensor will be used
8282 """
8383 with align_module_device (module ):
84- if base_name == "weight" :
85- value = module .weight
86- g_idx = getattr (module , "weight_g_idx" , None )
87- elif value is not None :
88- g_idx = None
89- else :
90- raise ValueError (
91- "Must provide a value to observe if not using weight observer"
92- )
93-
84+ value = module .weight if base_name == "weight" else value
9485 observer : Observer = getattr (module , f"{ base_name } _observer" )
9586
9687 if should_calculate_gparam :
9788 global_scale = observer .get_global_scale (value )
9889 update_offload_parameter (module , f"{ base_name } _global_scale" , global_scale )
99- else :
100- global_scale = getattr (module , f"{ base_name } _global_scale" , None )
10190
10291 if should_calculate_qparams :
103- updated_scale , updated_zero_point = observer (
104- value , g_idx = g_idx , global_scale = global_scale
105- )
106- # register or update scale & zero_point parameters (supports block shapes)
107- scale_name = f"{ base_name } _scale"
108- zp_name = f"{ base_name } _zero_point"
109- update_offload_parameter (module , scale_name , updated_scale )
110- update_offload_parameter (module , zp_name , updated_zero_point )
92+ scale , zero_point = observer (value )
93+ update_offload_parameter (module , f"{ base_name } _scale" , scale )
94+ update_offload_parameter (module , f"{ base_name } _zero_point" , zero_point )
11195
11296
11397def update_weight_global_scale (module : Module ):
0 commit comments