@@ -42,9 +42,7 @@ class OpenlayerModel(abc.ABC):
4242 def run_from_cli (self ) -> None :
4343 """Run the model from the command line."""
4444 parser = argparse .ArgumentParser (description = "Run data through a model." )
45- parser .add_argument (
46- "--dataset-path" , type = str , required = True , help = "Path to the dataset"
47- )
45+ parser .add_argument ("--dataset-path" , type = str , required = True , help = "Path to the dataset" )
4846 parser .add_argument (
4947 "--output-dir" ,
5048 type = str ,
@@ -87,9 +85,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
8785 # Filter row_dict to only include keys that are valid parameters
8886 # for the 'run' method
8987 row_dict = row .to_dict ()
90- filtered_kwargs = {
91- k : v for k , v in row_dict .items () if k in run_signature .parameters
92- }
88+ filtered_kwargs = {k : v for k , v in row_dict .items () if k in run_signature .parameters }
9389
9490 # Call the run method with filtered kwargs
9591 output = self .run (** filtered_kwargs )
@@ -111,6 +107,8 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
111107 df .at [index , "cost" ] = processed_trace ["cost" ]
112108 if "tokens" in processed_trace :
113109 df .at [index , "tokens" ] = processed_trace ["tokens" ]
110+ if "context" in processed_trace :
111+ df .at [index , "context" ] = processed_trace ["context" ]
114112
115113 config = {
116114 "outputColumnName" : "output" ,
@@ -126,6 +124,8 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
126124 config ["costColumnName" ] = "cost"
127125 if "tokens" in df .columns :
128126 config ["numOfTokenColumnName" ] = "tokens"
127+ if "context" in df .columns :
128+ config ["contextColumnName" ] = "context"
129129
130130 return df , config
131131
0 commit comments