@@ -49,13 +49,38 @@ def validate_minimum_viable_config(self) -> None:
4949 "'system', 'user', or 'assistant'."
5050 )
5151
52- def _run_in_memory (self , input_data_df : pd .DataFrame ) -> pd .DataFrame :
def run(
    self, input_data: pd.DataFrame, output_column_name: Optional[str] = None
) -> pd.DataFrame:
    """Dispatches the input data to the configured execution backend.

    Args:
        input_data: Data to run through the model.
        output_column_name: Optional column name, forwarded unchanged to
            the selected runner.

    Returns:
        Whatever the selected runner returns.
    """
    # Only the runner matching `self.in_memory` is looked up, so the other
    # backend is never touched.
    runner = self._run_in_memory if self.in_memory else self._run_in_conda
    return runner(input_data_df=input_data, output_column_name=output_column_name)
64+
65+ def _run_in_memory (
66+ self , input_data_df : pd .DataFrame , output_column_name : Optional [str ] = None
67+ ) -> pd .DataFrame :
5368 """Runs the input data through the model in memory."""
5469 self .logger .info ("Running LLM in memory..." )
5570 model_outputs = []
56-
71+ run_exceptions = set ()
5772 run_cost = 0
73+
5874 for input_data_row in input_data_df .iterrows ():
75+ # Check if output column already has a value to avoid re-running
76+ if (
77+ output_column_name is not None
78+ and output_column_name in input_data_row [1 ]
79+ ):
80+ if input_data_row [1 ][output_column_name ] is not None :
81+ model_outputs .append (input_data_row [1 ][output_column_name ])
82+ continue
83+
5984 input_variables_dict = input_data_row [1 ][
6085 self .model_config ["input_variable_names" ]
6186 ].to_dict ()
@@ -69,12 +94,22 @@ def _run_in_memory(self, input_data_df: pd.DataFrame) -> pd.DataFrame:
6994 model_outputs .append (result ["output" ])
7095 run_cost += result ["cost" ]
7196 except Exception as exc :
72- model_outputs .append (
73- f"[Error] Could not get predictions for row: { exc } "
74- )
97+ model_outputs .append (None )
98+ run_exceptions .add (exc )
7599
76100 self .logger .info ("Successfully ran data through the model!" )
101+
102+ if run_exceptions :
103+ warnings .warn (
104+ f"We couldn't get the outputs for all rows.\n "
105+ "Encountered the following exceptions while running the model: \n "
106+ f"{ run_exceptions } \n "
107+ "After you fix the issues, you can call the `run` method again and provide "
108+ "the `output_column_name` argument to avoid re-running the model on rows "
109+ "that already have an output value."
110+ )
77111 self .cost_estimates .append (run_cost )
112+
78113 return pd .DataFrame ({"predictions" : model_outputs })
79114
80115 def _inject_prompt (self , input_variables_dict : dict ) -> List [Dict [str , str ]]:
@@ -136,7 +171,9 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
136171 """Extracts the cost from the response."""
137172 pass
138173
139- def _run_in_conda (self , input_data : pd .DataFrame ) -> pd .DataFrame :
174+ def _run_in_conda (
175+ self , input_data_df : pd .DataFrame , output_column_name : Optional [str ] = None
176+ ) -> pd .DataFrame :
140177 """Runs LLM prediction job in a conda environment."""
141178 raise NotImplementedError (
142179 "Running LLM in conda environment is not implemented yet. "
0 commit comments