@@ -518,9 +518,21 @@ def __init__(
518518 "API key must be provided for self-hosted LLMs. "
519519 "Please pass it as the keyword argument 'api_key'"
520520 )
521+ if kwargs .get ("input_key" ) is None :
522+ raise ValueError (
523+ "Input key must be provided for self-hosted LLMs. "
524+ "Please pass it as the keyword argument 'input_key'"
525+ )
526+ if kwargs .get ("output_key" ) is None :
527+ raise ValueError (
528+ "Output key must be provided for self-hosted LLMs. "
529+ "Please pass it as the keyword argument 'output_key'"
530+ )
521531
522532 self .url = kwargs ["url" ]
523533 self .api_key = kwargs ["api_key" ]
534+ self .input_key = kwargs ["input_key" ]
535+ self .output_key = kwargs ["output_key" ]
524536 self ._initialize_llm ()
525537
526538 def _initialize_llm (self ):
@@ -559,8 +571,7 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
559571 "Authorization" : f"Bearer { self .api_key } " ,
560572 "Content-Type" : "application/json" ,
561573 }
562- # TODO: use correct input key
563- data = {"inputs" : llm_input }
574+ data = {self .input_key : llm_input }
564575 response = requests .post (self .url , headers = headers , json = data )
565576 if response .status_code == 200 :
566577 response_data = response .json ()[0 ]
@@ -570,9 +581,15 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
570581
571582 def _get_output (self , response : Dict [str , Any ]) -> str :
572583 """Gets the output from the response."""
573- # TODO: use correct output key
574- return response ["generated_text" ]
584+ return response [self .output_key ]
575585
576586 def _get_cost_estimate (self , response : Dict [str , Any ]) -> float :
577587 """Estimates the cost from the response."""
578588 return 0
589+
590+
591+ class HuggingFaceModelRunner (SelfHostedLLModelRunner ):
592+ """Wraps LLMs hosted in HuggingFace."""
593+
594+ def __init__ (self , url , api_key ):
595+ super ().__init__ (url , api_key , input_key = "inputs" , output_key = "generated_text" )
0 commit comments