33import time
44from typing import Any , Dict , Optional
55
6+ from . import enums
7+
68
79class Step :
810 def __init__ (
@@ -17,7 +19,7 @@ def __init__(
1719 self .output = output
1820 self .metadata = metadata
1921
20- self .step_type = None
22+ self .step_type : enums . StepType = None
2123 self .start_time = time .time ()
2224 self .end_time = None
2325 self .ground_truth = None
@@ -39,7 +41,7 @@ def to_dict(self) -> Dict[str, Any]:
3941 """Dictionary representation of the Step."""
4042 return {
4143 "name" : self .name ,
42- "type" : self .step_type ,
44+ "type" : self .step_type . value ,
4345 "inputs" : self .inputs ,
4446 "output" : self .output ,
4547 "groundTruth" : self .ground_truth ,
@@ -60,10 +62,10 @@ def __init__(
6062 metadata : Dict [str , any ] = {},
6163 ) -> None :
6264 super ().__init__ (name = name , inputs = inputs , output = output , metadata = metadata )
63- self .step_type = "user_call"
65+ self .step_type = enums . StepType . USER_CALL
6466
6567
66- class OpenAIChatCompletionStep (Step ):
68+ class ChatCompletionStep (Step ):
6769 def __init__ (
6870 self ,
6971 name : str ,
@@ -73,7 +75,8 @@ def __init__(
7375 ) -> None :
7476 super ().__init__ (name = name , inputs = inputs , output = output , metadata = metadata )
7577
76- self .step_type = "openai_chat_completion"
78+ self .step_type = enums .StepType .CHAT_COMPLETION
79+ self .provider : str = None
7780 self .prompt_tokens : int = None
7881 self .completion_tokens : int = None
7982 self .tokens : int = None
@@ -83,10 +86,11 @@ def __init__(
8386 self .raw_output : str = None
8487
8588 def to_dict (self ) -> Dict [str , Any ]:
86- """Dictionary representation of the OpenAIChatCompletionStep ."""
89+ """Dictionary representation of the ChatCompletionStep ."""
8790 step_dict = super ().to_dict ()
8891 step_dict .update (
8992 {
93+ "provider" : self .provider ,
9094 "promptTokens" : self .prompt_tokens ,
9195 "completionTokens" : self .completion_tokens ,
9296 "tokens" : self .tokens ,
@@ -100,12 +104,12 @@ def to_dict(self) -> Dict[str, Any]:
100104
101105
102106# ----------------------------- Factory function ----------------------------- #
103- def step_factory (step_type : str , * args , ** kwargs ) -> Step :
107+ def step_factory (step_type : enums . StepType , * args , ** kwargs ) -> Step :
104108 """Factory function to create a step based on the step_type."""
105- if step_type not in ["user_call" , "openai_chat_completion" ]:
106- raise ValueError (f"Step type { step_type } not recognized." )
109+ if step_type . value not in [item . value for item in enums . StepType ]:
110+ raise ValueError (f"Step type { step_type . value } not recognized." )
107111 step_type_mapping = {
108- "user_call" : UserCallStep ,
109- "openai_chat_completion" : OpenAIChatCompletionStep ,
112+ enums . StepType . USER_CALL : UserCallStep ,
113+ enums . StepType . CHAT_COMPLETION : ChatCompletionStep ,
110114 }
111115 return step_type_mapping [step_type ](* args , ** kwargs )
0 commit comments