77import time
88import warnings
99from abc import ABC , abstractmethod
10- from typing import Any , Dict , List , Optional , Union
10+ from typing import Any , Dict , Generator , List , Optional , Tuple , Union
1111
1212import anthropic
1313import cohere
1414import openai
1515import pandas as pd
1616import pybars
17+ from tqdm import tqdm
1718
1819from . import base_model_runner
1920
@@ -56,64 +57,94 @@ def run(
5657 """Runs the input data through the model."""
5758 if self .in_memory :
5859 return self ._run_in_memory (
59- input_data_df = input_data , output_column_name = output_column_name
60+ input_data_df = input_data ,
61+ output_column_name = output_column_name ,
6062 )
6163 else :
6264 return self ._run_in_conda (
6365 input_data_df = input_data , output_column_name = output_column_name
6466 )
6567
6668 def _run_in_memory (
67- self , input_data_df : pd .DataFrame , output_column_name : Optional [str ] = None
69+ self ,
70+ input_data_df : pd .DataFrame ,
71+ output_column_name : Optional [str ] = None ,
6872 ) -> pd .DataFrame :
69- """Runs the input data through the model in memory."""
73+ """Runs the input data through the model in memory and returns a pandas
74+ dataframe."""
75+ for output_df , _ in tqdm (
76+ self ._run_in_memory_and_yield_progress (input_data_df , output_column_name ),
77+ total = len (input_data_df ),
78+ colour = "BLUE" ,
79+ ):
80+ pass
81+ return output_df
82+
83+ def _run_in_memory_and_yield_progress (
84+ self ,
85+ input_data_df : pd .DataFrame ,
86+ output_column_name : Optional [str ] = None ,
87+ ) -> Generator [Tuple [pd .DataFrame , float ], None , None ]:
88+ """Runs the input data through the model in memory and yields the results
89+ and the progress."""
7090 self .logger .info ("Running LLM in memory..." )
7191
7292 model_outputs = []
7393 timestamps = []
7494 run_exceptions = set ()
7595 run_cost = 0
76- for input_data_row in input_data_df .iterrows ():
96+ total_rows = len (input_data_df )
97+ current_row = 0
98+
99+ for _ , input_data_row in input_data_df .iterrows ():
77100 # Check if output column already has a value to avoid re-running
78- if (
79- output_column_name is not None
80- and output_column_name in input_data_row [1 ]
81- ):
82- if input_data_row [1 ][output_column_name ] is not None :
83- model_outputs .append (input_data_row [1 ][output_column_name ])
101+ if output_column_name and output_column_name in input_data_row :
102+ output_value = input_data_row [output_column_name ]
103+ if output_value is not None :
104+ model_outputs .append (output_value )
105+ current_row += 1
106+ yield pd .DataFrame (
107+ {"predictions" : model_outputs , "timestamps" : timestamps }
108+ ), current_row / total_rows
84109 continue
85110
86- input_variables_dict = input_data_row [1 ][
87- self .model_config ["input_variable_names" ]
88- ].to_dict ()
89- injected_prompt = self ._inject_prompt (
90- input_variables_dict = input_variables_dict
91- )
92- llm_input = self ._get_llm_input (injected_prompt )
93-
94- try :
95- result = self ._get_llm_output (llm_input )
96- model_outputs .append (result ["output" ])
97- run_cost += result ["cost" ]
98- except Exception as exc :
99- model_outputs .append (None )
100- run_exceptions .add (exc )
111+ output , cost , exceptions = self ._run_single_input (input_data_row )
112+
113+ model_outputs .append (output )
114+ run_cost += cost
115+ run_exceptions .update (exceptions )
101116 timestamps .append (time .time ())
117+ current_row += 1
118+
119+ yield pd .DataFrame (
120+ {"predictions" : model_outputs , "timestamps" : timestamps }
121+ ), current_row / total_rows
102122
103123 self .logger .info ("Successfully ran data through the model!" )
104124
105- if run_exceptions :
106- warnings .warn (
107- f"We couldn't get the outputs for all rows.\n "
108- "Encountered the following exceptions while running the model: \n "
109- f"{ run_exceptions } \n "
110- "After you fix the issues, you can call the `run` method again and provide "
111- "the `output_column_name` argument to avoid re-running the model on rows "
112- "that already have an output value."
113- )
125+ self ._report_exceptions (run_exceptions )
114126 self .cost_estimates .append (run_cost )
115127
116- return pd .DataFrame ({"predictions" : model_outputs , "timestamps" : timestamps })
128+ yield pd .DataFrame (
129+ {"predictions" : model_outputs , "timestamps" : timestamps }
130+ ), 1.0
131+
132+ def _run_single_input (self , input_data_row : pd .Series ) -> Tuple [str , float , set ]:
133+ """Runs the LLM on a single row of input data.
134+
135+ Returns a tuple of the output, cost, and exceptions encountered.
136+ """
137+ input_variables_dict = input_data_row [
138+ self .model_config ["input_variable_names" ]
139+ ].to_dict ()
140+ injected_prompt = self ._inject_prompt (input_variables_dict = input_variables_dict )
141+ llm_input = self ._get_llm_input (injected_prompt )
142+
143+ try :
144+ outputs = self ._get_llm_output (llm_input )
145+ return outputs ["output" ], outputs ["cost" ], set ()
146+ except Exception as exc :
147+ return None , 0 , {exc }
117148
118149 def _inject_prompt (self , input_variables_dict : dict ) -> List [Dict [str , str ]]:
119150 """Injects the input variables into the prompt template.
@@ -174,6 +205,17 @@ def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
174205 """Extracts the cost from the response."""
175206 pass
176207
208+ def _report_exceptions (self , exceptions : set ) -> None :
209+ if exceptions :
210+ warnings .warn (
211+ f"We couldn't get the outputs for all rows.\n "
212+ "Encountered the following exceptions while running the model: \n "
213+ f"{ exceptions } \n "
214+ "After you fix the issues, you can call the `run` method again and provide "
215+ "the `output_column_name` argument to avoid re-running the model on rows "
216+ "that already have an output value."
217+ )
218+
177219 def _run_in_conda (
178220 self , input_data_df : pd .DataFrame , output_column_name : Optional [str ] = None
179221 ) -> pd .DataFrame :
@@ -199,6 +241,21 @@ def get_cost_estimate(self, num_of_runs: Optional[int] = None) -> float:
199241 return sum (self .cost_estimates [- num_of_runs :])
200242 return self .cost_estimates [- 1 ]
201243
244+ def run_and_yield_progress (
245+ self , input_data : pd .DataFrame , output_column_name : Optional [str ] = None
246+ ) -> Generator [Tuple [pd .DataFrame , float ], None , None ]:
247+ """Runs the input data through the model and yields progress."""
248+ if self .in_memory :
249+ yield from self ._run_in_memory_and_yield_progress (
250+ input_data_df = input_data ,
251+ output_column_name = output_column_name ,
252+ )
253+ else :
254+ raise NotImplementedError (
255+ "Running LLM in conda environment is not implemented yet. "
256+ "Please use the in-memory runner."
257+ )
258+
202259
203260# -------------------------- Concrete model runners -------------------------- #
204261class AnthropicModelRunner (LLModelRunner ):
0 commit comments