1313from transformers import PretrainedConfig
1414
1515from vllm .config import ModelConfig , PoolerConfig
16+ from vllm .logger import init_logger
1617from vllm .pooling_params import PoolingParams
1718from vllm .sequence import PoolerOutput , PoolingSequenceGroupOutput
1819from vllm .tasks import PoolingTask
1920from vllm .utils import current_stream , resolve_obj_by_qualname
2021from vllm .v1 .pool .metadata import PoolingCursor , PoolingMetadata
2122
23+ logger = init_logger (__name__ )
24+
2225PoolingFn = Callable [
2326 [Union [torch .Tensor , list [torch .Tensor ]], PoolingMetadata ],
2427 Union [torch .Tensor , list [torch .Tensor ]]]
@@ -183,7 +186,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
183186 fn = resolve_obj_by_qualname (function_name )()
184187 return PoolerActivation .wraps (fn )
185188
186- return PoolerScore ()
189+ return PoolerClassify ()
187190
188191
189192def build_output (
@@ -371,22 +374,29 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
371374
372375class PoolerClassify (PoolerActivation ):
373376
374- def forward_chunk (self , pooled_data : torch .Tensor ) -> torch .Tensor :
375- num_labels = pooled_data .shape [- 1 ]
376- if num_labels < 2 :
377- return F .sigmoid (pooled_data .float ()).to (pooled_data .dtype )
378-
379- return F .softmax (pooled_data .float (), dim = - 1 ).to (pooled_data .dtype )
380-
377+ def __init__ (self , * , static_num_labels : bool = True ) -> None :
378+ super ().__init__ ()
381379
382- class PoolerScore (PoolerActivation ):
380+ if static_num_labels :
381+ from vllm .config import get_current_vllm_config
382+ vllm_config = get_current_vllm_config ()
383+ self .num_labels = getattr (vllm_config .model_config .hf_config ,
384+ "num_labels" , 0 )
385+ if self .num_labels == 0 :
386+ logger .warning ("num_labels should be > 0 for classification "
387+ "models, falling back to softmax. "
388+ "Please check if the configuration is correct." )
389+ else :
390+ self .num_labels = None
383391
384392 def forward_chunk (self , pooled_data : torch .Tensor ) -> torch .Tensor :
385- num_labels = pooled_data .shape [- 1 ]
393+ num_labels = (self .num_labels if self .num_labels is not None else
394+ pooled_data .shape [- 1 ])
395+
386396 if num_labels < 2 :
387397 return F .sigmoid (pooled_data .float ()).to (pooled_data .dtype )
388398
389- return pooled_data
399+ return F .softmax (pooled_data .float (), dim = - 1 ).to (pooled_data .dtype )
390400
391401
392402class LambdaPoolerActivation (PoolerActivation ):
@@ -428,6 +438,10 @@ def __init__(self) -> None:
428438 def forward (self , pooled_data : Union [list [torch .Tensor ], torch .Tensor ],
429439 pooling_metadata : PoolingMetadata ):
430440
441+ if isinstance (pooled_data , list ):
442+ pooled_data = torch .stack (pooled_data )
443+ # pooled_data shape: [batchsize, hidden_dimension]
444+
431445 # Apply ST projector
432446 if self .projector is not None :
433447 projector = cast (nn .Module , self .projector )
@@ -437,17 +451,11 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
437451 y = projector (x .to (torch .float32 ))
438452 return y .to (orig_dtype )
439453
440- if isinstance (pooled_data , torch .Tensor ):
441- pooled_data = _proj (pooled_data )
442- else :
443- pooled_data = [_proj (t ) for t in pooled_data ]
454+ pooled_data = _proj (pooled_data )
455+ # pooled_data shape: [batchsize, embedding_dimension]
444456
445457 pooling_params = get_pooling_params (pooling_metadata )
446458
447- if isinstance (pooled_data , list ):
448- pooled_data = torch .stack (pooled_data )
449- # pooled_data shape: [batchsize, embedding_dimension]
450-
451459 # for matryoshka representation
452460 dimensions_list = [
453461 pooling_param .dimensions for pooling_param in pooling_params
@@ -477,13 +485,14 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
477485 for vecs , f in zip (pooled_data , flags )
478486 ]
479487
488+ # pooled_data shape: [batchsize, embedding_dimension]
480489 return pooled_data
481490
482491
483492class RewardPoolerHead (PoolerHead ):
484493
485494 def __init__ (self ) -> None :
486- super ().__init__ (activation = PoolerClassify ())
495+ super ().__init__ (activation = PoolerClassify (static_num_labels = False ))
487496
488497 def forward (self , pooled_data : Union [list [torch .Tensor ], torch .Tensor ],
489498 pooling_metadata : PoolingMetadata ):
@@ -637,19 +646,13 @@ def forward(
637646 pooling_metadata : PoolingMetadata ,
638647 ) -> PoolerOutput :
639648 pooled_data = self .pooling (hidden_states , pooling_metadata )
640-
641649 if isinstance (pooled_data , list ):
642650 pooled_data = torch .stack (pooled_data )
643651 # pooled_data shape: [batchsize, hidden_size]
644652
645653 if self .classifier is not None :
646- # apply classifier once on the full batch if possible
647- if isinstance (pooled_data , torch .Tensor ):
648- pooled_data = self .classifier (pooled_data )
649- elif len ({data .shape for data in pooled_data }) <= 1 :
650- pooled_data = self .classifier (torch .stack (pooled_data ))
651- else :
652- pooled_data = [self .classifier (data ) for data in pooled_data ]
654+ pooled_data = self .classifier (pooled_data )
655+ # pooled_data shape: [batchsize, num_labels]
653656
654657 pooling_params = get_pooling_params (pooling_metadata )
655658 flags = [p .activation for p in pooling_params ]
@@ -662,6 +665,7 @@ def forward(
662665 for vecs , f in zip (pooled_data , flags )
663666 ]
664667
668+ # scores shape: [batchsize, num_labels]
665669 return build_output (scores )
666670
667671
0 commit comments