@@ -155,7 +155,14 @@ def wrapper(*func_args, **func_kwargs):
155155 if step_kwargs .get ("name" ) is None :
156156 step_kwargs ["name" ] = func .__name__
157157 with create_step (* step_args , ** step_kwargs ) as step :
158- output = func (* func_args , ** func_kwargs )
158+ output = None
159+ exception = None
160+ try :
161+ output = func (* func_args , ** func_kwargs )
162+ # pylint: disable=broad-except
163+ except Exception as exc :
164+ step .log (metadata = {"Exceptions" : str (exc )})
165+ exception = exc
159166 end_time = time .time ()
160167 latency = (end_time - step .start_time ) * 1000 # in ms
161168
@@ -171,6 +178,9 @@ def wrapper(*func_args, **func_kwargs):
171178 end_time = end_time ,
172179 latency = latency ,
173180 )
181+
182+ if exception is not None :
183+ raise exception
174184 return output
175185
176186 return wrapper
@@ -189,12 +199,14 @@ def process_trace_for_upload(
189199 root_step = trace_obj .steps [0 ]
190200
191201 input_variables = root_step .inputs
192- input_variable_names = list (input_variables .keys ())
202+ if input_variables :
203+ input_variable_names = list (input_variables .keys ())
204+ else :
205+ input_variable_names = []
193206
194207 processed_steps = bubble_up_costs_and_tokens (trace_obj .to_dict ())
195208
196209 trace_data = {
197- ** input_variables ,
198210 "inferenceTimestamp" : root_step .start_time ,
199211 "inferenceId" : str (root_step .id ),
200212 "output" : root_step .output ,
@@ -204,6 +216,8 @@ def process_trace_for_upload(
204216 "tokens" : processed_steps [0 ].get ("tokens" , 0 ),
205217 "steps" : processed_steps ,
206218 }
219+ if input_variables :
220+ trace_data .update (input_variables )
207221
208222 return trace_data , input_variable_names
209223
0 commit comments