@@ -63,23 +63,18 @@ def forward(
         input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
-        output_parallel = self.apply(input_parallel)
+        bias_ = (
+            None
+            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
+            else self.base_layer.bias
+        )
+        output_parallel = self.apply(input_parallel, bias_)
         if self.base_layer.reduce_results and self.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.base_layer.skip_bias_add:
-            output = (
-                output_ + self.base_layer.bias
-                if self.base_layer.bias is not None
-                else output_
-            )
-            output_bias = None
+            output = tensor_model_parallel_all_reduce(output_parallel)
         else:
-            output = output_
-            output_bias = self.base_layer.bias
+            output = output_parallel

+        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         if not self.base_layer.return_bias:
             return output

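The change above folds bias handling into the `apply()` call: only `tp_rank == 0` passes the bias, so when the partial outputs are summed by `tensor_model_parallel_all_reduce` the bias is counted exactly once, and `skip_bias_add` now only decides whether the bias is returned separately as `output_bias`. A minimal standalone sketch (illustrative only, not vLLM code) of why the bias must be restricted to a single rank before the reduction:

```python
import torch

# Hypothetical setup: 4 ranks each hold a partial result of a row-parallel matmul.
tp_size = 4
bias = torch.tensor([1.0, 2.0])
partials = [torch.tensor([0.5, 0.5]) for _ in range(tp_size)]

# Wrong: every rank adds the bias before the (simulated) all-reduce,
# so the bias ends up counted tp_size times.
wrong = sum(p + bias for p in partials)  # tensor([6., 10.])

# Right (what the diff does): only rank 0 adds the bias,
# so the reduced result contains it exactly once.
right = sum(p + (bias if rank == 0 else 0)
            for rank, p in enumerate(partials))  # tensor([3., 4.])

print(wrong, right)
```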
@@ -120,7 +115,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         return lora_b

     def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x)
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
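The second hunk threads the same bias into `quant_method.apply`, letting the backend GEMM handle the bias add instead of `forward()` patching it on afterwards. A rough sketch of the kind of hook this calls into (names and signature are assumptions for illustration, not the actual vLLM API):

```python
import torch
import torch.nn.functional as F

def linear_apply(weight: torch.Tensor, x: torch.Tensor,
                 bias: torch.Tensor | None = None) -> torch.Tensor:
    # F.linear folds the optional bias into the same matmul call,
    # which is why apply() can simply forward `bias` to the backend.
    return F.linear(x, weight, bias)
```

With the bias handled there, the caller no longer needs the post-reduce `output_ + self.base_layer.bias` branch that the first hunk removes.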