@@ -209,6 +209,12 @@ type Classifier struct {
209209 CategoryMapping * CategoryMapping
210210 PIIMapping * PIIMapping
211211 JailbreakMapping * JailbreakMapping
212+
213+ // Category name mapping layer to support generic categories in config
214+ // Maps MMLU-Pro category names -> generic category names (as defined in config.Categories)
215+ MMLUToGeneric map [string ]string
216+ // Maps generic category names -> MMLU-Pro category names
217+ GenericToMMLU map [string ][]string
212218}
213219
214220type option func (* Classifier )
@@ -272,6 +278,9 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
272278 option (classifier )
273279 }
274280
281+ // Build category name mappings to support generic categories in config
282+ classifier .buildCategoryNameMappings ()
283+
275284 return initModels (classifier )
276285}
277286
@@ -331,18 +340,21 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
331340 return "" , float64 (result .Confidence ), nil
332341 }
333342
334- // Convert class index to category name
343+ // Convert class index to category name (MMLU-Pro)
335344 categoryName , ok := c .CategoryMapping .GetCategoryFromIndex (result .Class )
336345 if ! ok {
337346 observability .Warnf ("Class index %d not found in category mapping" , result .Class )
338347 return "" , float64 (result .Confidence ), nil
339348 }
340349
341- // Record the category classification metric
342- metrics . RecordCategoryClassification (categoryName )
350+ // Translate to generic category if mapping is configured
351+ genericCategory := c . translateMMLUToGeneric (categoryName )
343352
344- observability .Infof ("Classified as category: %s" , categoryName )
345- return categoryName , float64 (result .Confidence ), nil
353+ // Record the category classification metric using generic name when available
354+ metrics .RecordCategoryClassification (genericCategory )
355+
356+ observability .Infof ("Classified as category: %s (mmlu=%s)" , genericCategory , categoryName )
357+ return genericCategory , float64 (result .Confidence ), nil
346358}
347359
348360// IsJailbreakEnabled checks if jailbreak detection is enabled and properly configured
@@ -485,11 +497,11 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
485497 observability .Infof ("Classification result: class=%d, confidence=%.4f, entropy_available=%t" ,
486498 result .Class , result .Confidence , len (result .Probabilities ) > 0 )
487499
488- // Get category names for all classes
500+ // Get category names for all classes and translate to generic names when configured
489501 categoryNames := make ([]string , len (result .Probabilities ))
490502 for i := range result .Probabilities {
491503 if name , ok := c .CategoryMapping .GetCategoryFromIndex (i ); ok {
492- categoryNames [i ] = name
504+ categoryNames [i ] = c . translateMMLUToGeneric ( name )
493505 } else {
494506 categoryNames [i ] = fmt .Sprintf ("unknown_%d" , i )
495507 }
@@ -580,20 +592,21 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
580592 return "" , float64 (result .Confidence ), reasoningDecision , nil
581593 }
582594
583- // Convert class index to category name
595+ // Convert class index to category name and translate to generic
584596 categoryName , ok := c .CategoryMapping .GetCategoryFromIndex (result .Class )
585597 if ! ok {
586598 observability .Warnf ("Class index %d not found in category mapping" , result .Class )
587599 return "" , float64 (result .Confidence ), reasoningDecision , nil
588600 }
601+ genericCategory := c .translateMMLUToGeneric (categoryName )
589602
590603 // Record the category classification metric
591- metrics .RecordCategoryClassification (categoryName )
604+ metrics .RecordCategoryClassification (genericCategory )
592605
593- observability .Infof ("Classified as category: %s, reasoning_decision: use=%t, confidence=%.3f, reason=%s" ,
594- categoryName , reasoningDecision .UseReasoning , reasoningDecision .Confidence , reasoningDecision .DecisionReason )
606+ observability .Infof ("Classified as category: %s (mmlu=%s) , reasoning_decision: use=%t, confidence=%.3f, reason=%s" ,
607+ genericCategory , categoryName , reasoningDecision .UseReasoning , reasoningDecision .Confidence , reasoningDecision .DecisionReason )
595608
596- return categoryName , float64 (result .Confidence ), reasoningDecision , nil
609+ return genericCategory , float64 (result .Confidence ), reasoningDecision , nil
597610}
598611
599612// ClassifyPII performs PII token classification on the given text and returns detected PII types
@@ -772,6 +785,51 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
772785 return nil
773786}
774787
788+ // buildCategoryNameMappings builds translation maps between MMLU-Pro and generic categories
789+ func (c * Classifier ) buildCategoryNameMappings () {
790+ c .MMLUToGeneric = make (map [string ]string )
791+ c .GenericToMMLU = make (map [string ][]string )
792+
793+ // Build set of known MMLU-Pro categories from the model mapping (if available)
794+ knownMMLU := make (map [string ]bool )
795+ if c .CategoryMapping != nil {
796+ for _ , label := range c .CategoryMapping .IdxToCategory {
797+ knownMMLU [strings .ToLower (label )] = true
798+ }
799+ }
800+
801+ for _ , cat := range c .Config .Categories {
802+ if len (cat .MMLUCategories ) > 0 {
803+ for _ , mmlu := range cat .MMLUCategories {
804+ key := strings .ToLower (mmlu )
805+ c .MMLUToGeneric [key ] = cat .Name
806+ c .GenericToMMLU [cat .Name ] = append (c .GenericToMMLU [cat .Name ], mmlu )
807+ }
808+ } else {
809+ // Fallback: identity mapping when the generic name matches an MMLU category
810+ nameLower := strings .ToLower (cat .Name )
811+ if knownMMLU [nameLower ] {
812+ c .MMLUToGeneric [nameLower ] = cat .Name
813+ c .GenericToMMLU [cat .Name ] = append (c .GenericToMMLU [cat .Name ], cat .Name )
814+ }
815+ }
816+ }
817+ }
818+
819+ // translateMMLUToGeneric translates an MMLU-Pro category to a generic category if mapping exists
820+ func (c * Classifier ) translateMMLUToGeneric (mmluCategory string ) string {
821+ if mmluCategory == "" {
822+ return ""
823+ }
824+ if c .MMLUToGeneric == nil {
825+ return mmluCategory
826+ }
827+ if generic , ok := c .MMLUToGeneric [strings .ToLower (mmluCategory )]; ok {
828+ return generic
829+ }
830+ return mmluCategory
831+ }
832+
775833// selectBestModelInternal performs the core model selection logic
776834//
777835// modelFilter is optional - if provided, only models passing the filter will be considered
0 commit comments