@@ -39,6 +39,8 @@ class OpenlayerModel(abc.ABC):
3939 Refer to Openlayer's templates for examples of how to implement this class.
4040 """
4141
42+ custom_args : dict = {}
43+
4244 def run_from_cli (self ) -> None :
4345 """Run the model from the command line."""
4446 parser = argparse .ArgumentParser (description = "Run data through a model." )
@@ -51,10 +53,26 @@ def run_from_cli(self) -> None:
5153 required = False ,
5254 help = "Directory to dump the results in" ,
5355 )
56+ parser .add_argument (
57+ "--custom-args" ,
58+ type = str ,
59+ required = False ,
60+ help = "Custom arguments in format 'key1=value1,key2=value2'" ,
61+ )
5462
5563 # Parse the arguments
5664 args = parser .parse_args ()
5765
66+ # Parse custom arguments string
67+ custom_args = {}
68+ if args .custom_args :
69+ pairs = args .custom_args .split ("," )
70+ for pair in pairs :
71+ if "=" in pair :
72+ key , value = pair .split ("=" , 1 )
73+ custom_args [key ] = value
74+ self .custom_args = custom_args
75+
5876 return self .batch (
5977 dataset_path = args .dataset_path ,
6078 output_dir = args .output_dir ,
@@ -69,12 +87,16 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
6987 elif dataset_path .endswith (".json" ):
7088 df = pd .read_json (dataset_path , orient = "records" )
7189 fmt = "json"
90+ else :
91+ raise ValueError (f"Unsupported dataset format: { dataset_path } " )
7292
7393 # Call the model's run_batch method, passing in the DataFrame
74- output_df , config = self .run_batch_from_df (df )
94+ output_df , config = self .run_batch_from_df (df , custom_args = self . custom_args )
7595 self .write_output_to_directory (output_df , config , output_dir , fmt )
7696
77- def run_batch_from_df (self , df : pd .DataFrame ) -> Tuple [pd .DataFrame , dict ]:
97+ def run_batch_from_df (
98+ self , df : pd .DataFrame , custom_args : dict = None
99+ ) -> Tuple [pd .DataFrame , dict ]:
78100 """Function that runs the model and returns the result."""
79101 # Ensure the 'output' column exists
80102 if "output" not in df .columns :
@@ -83,6 +105,10 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
83105 # Get the signature of the 'run' method
84106 run_signature = inspect .signature (self .run )
85107
108+ # If the model has a custom_args attribute, update it
109+ if hasattr (self , "custom_args" ) and custom_args is not None :
110+ self .custom_args .update (custom_args )
111+
86112 for index , row in df .iterrows ():
87113 # Filter row_dict to only include keys that are valid parameters
88114 # for the 'run' method
@@ -112,8 +138,7 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
112138 if "tokens" in processed_trace :
113139 df .at [index , "tokens" ] = processed_trace ["tokens" ]
114140 if "context" in processed_trace :
115- # Convert the context list to a string to avoid pandas issues
116- df .at [index , "context" ] = json .dumps (processed_trace ["context" ])
141+ df .at [index , "context" ] = processed_trace ["context" ]
117142
118143 config = {
119144 "outputColumnName" : "output" ,
@@ -132,6 +157,9 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
132157 if "context" in df .columns :
133158 config ["contextColumnName" ] = "context"
134159
160+ for k , v in self .custom_args .items ():
161+ config ["metadata" ][k ] = v
162+
135163 return df , config
136164
137165 def write_output_to_directory (
0 commit comments