@@ -17,7 +17,7 @@ def __init__(
1717 self .output = output
1818 self .metadata = metadata
1919
20- self .step_type = "user_call"
20+ self .step_type = None
2121 self .start_time = time .time ()
2222 self .end_time = None
2323 self .ground_truth = None
@@ -49,3 +49,61 @@ def to_dict(self) -> Dict[str, Any]:
4949 "startTime" : self .start_time ,
5050 "endTime" : self .end_time ,
5151 }
52+
53+
54+ class UserCallStep (Step ):
55+ def __init__ (
56+ self ,
57+ name : str ,
58+ inputs : Optional [Any ] = None ,
59+ output : Optional [Any ] = None ,
60+ metadata : Dict [str , any ] = {},
61+ ) -> None :
62+ super ().__init__ (name = name , inputs = inputs , output = output , metadata = metadata )
63+ self .step_type = "user_call"
64+
65+
66+ class OpenAIChatCompletionStep (Step ):
67+ def __init__ (
68+ self ,
69+ name : str ,
70+ inputs : Optional [Any ] = None ,
71+ output : Optional [Any ] = None ,
72+ metadata : Dict [str , any ] = {},
73+ ) -> None :
74+ super ().__init__ (name = name , inputs = inputs , output = output , metadata = metadata )
75+
76+ self .step_type = "openai_chat_completion"
77+ self .prompt_tokens : int = None
78+ self .completion_tokens : int = None
79+ self .cost : float = None
80+ self .model : str = None
81+ self .model_parameters : Dict [str , Any ] = None
82+ self .raw_output : str = None
83+
84+ def to_dict (self ) -> Dict [str , Any ]:
85+ """Dictionary representation of the OpenAIChatCompletionStep."""
86+ step_dict = super ().to_dict ()
87+ step_dict .update (
88+ {
89+ "promptTokens" : self .prompt_tokens ,
90+ "completionTokens" : self .completion_tokens ,
91+ "cost" : self .cost ,
92+ "model" : self .model ,
93+ "modelParameters" : self .model_parameters ,
94+ "rawOutput" : self .raw_output ,
95+ }
96+ )
97+ return step_dict
98+
99+
100+ # ----------------------------- Factory function ----------------------------- #
101+ def step_factory (step_type : str , * args , ** kwargs ) -> Step :
102+ """Factory function to create a step based on the step_type."""
103+ if step_type not in ["user_call" , "openai_chat_completion" ]:
104+ raise ValueError (f"Step type { step_type } not recognized." )
105+ step_type_mapping = {
106+ "user_call" : UserCallStep ,
107+ "openai_chat_completion" : OpenAIChatCompletionStep ,
108+ }
109+ return step_type_mapping [step_type ](* args , ** kwargs )
0 commit comments