1- import warnings
21from dataclasses import dataclass , field
32from enum import Enum
4- from typing import Any , Dict , List , Optional
53
4+ from labelbox .schema .tool_building .base_step_reasoning_tool import (
5+ _BaseStepReasoningTool ,
6+ _Definition ,
7+ _Variant ,
8+ )
69from labelbox .schema .tool_building .tool_type import ToolType
7- from labelbox .schema .tool_building .variant import VariantWithActions
810
911
1012class IncorrectStepActions (Enum ):
@@ -14,156 +16,29 @@ class IncorrectStepActions(Enum):
1416 JUSTIFICATION = "justification"
1517
1618
17- @dataclass
18- class StepReasoningVariants :
19- """
20- This class is used to define the possible options for evaluating a step
21- Currently the options are correct, neutral, and incorrect
22- NOTE: do not change the variant values
23- """
24-
25- correct_step : VariantWithActions = field (
26- default_factory = lambda : VariantWithActions (
27- id = 0 ,
28- name = "Correct" ,
29- actions = [],
30- )
31- )
32- neutral_step : VariantWithActions = field (
33- default_factory = lambda : VariantWithActions (
34- id = 1 ,
35- name = "Neutral" ,
36- actions = [],
37- )
38- )
39-
40- incorrect_step : VariantWithActions = field (
41- default_factory = lambda : VariantWithActions (
42- id = 2 ,
43- name = "Incorrect" ,
44- _available_actions = {
45- action .value for action in IncorrectStepActions
46- },
47- actions = [
48- action .value for action in IncorrectStepActions
49- ], # initialize to all IncorrectStepActions by default
50- )
51- )
52-
53- def asdict (self ):
54- return [
55- self .correct_step .asdict (),
56- self .neutral_step .asdict (),
57- self .incorrect_step .asdict (),
58- ]
59-
60- @classmethod
61- def from_dict (cls , dictionary : List [Dict [str , Any ]]):
62- correct_step = None
63- neutral_step = None
64- incorrect_step = None
65-
66- for variant in dictionary :
67- if variant ["id" ] == 0 :
68- correct_step = VariantWithActions (** variant )
69- elif variant ["id" ] == 1 :
70- neutral_step = VariantWithActions (** variant )
71- elif variant ["id" ] == 2 :
72- incorrect_step = VariantWithActions (** variant )
73-
74- if not all ([correct_step , neutral_step , incorrect_step ]):
75- raise ValueError ("Invalid step reasoning variants" )
76-
77- return cls (
78- correct_step = correct_step , # type: ignore
79- neutral_step = neutral_step , # type: ignore
80- incorrect_step = incorrect_step , # type: ignore
81- )
82-
83-
84- @dataclass
85- class StepReasoningDefinition :
86- variants : StepReasoningVariants = field (
87- default_factory = StepReasoningVariants
19+ def build_step_reasoning_definition ():
20+ correct_step = _Variant (id = 0 , name = "Correct" , actions = [])
21+ neutral_step = _Variant (id = 1 , name = "Neutral" , actions = [])
22+ incorrect_step = _Variant (
23+ id = 2 ,
24+ name = "Incorrect" ,
25+ _available_actions = {action .value for action in IncorrectStepActions },
26+ actions = [action .value for action in IncorrectStepActions ],
8827 )
89- version : int = field (default = 1 )
90- title : Optional [str ] = None
91- value : Optional [str ] = None
92-
93- def __post_init__ (self ):
94- if self .version != 1 :
95- raise ValueError ("Invalid version" )
96-
97- def asdict (self ) -> Dict [str , Any ]:
98- result = {"variants" : self .variants .asdict (), "version" : self .version }
99- if self .title is not None :
100- result ["title" ] = self .title
101- if self .value is not None :
102- result ["value" ] = self .value
103- return result
104-
105- @classmethod
106- def from_dict (cls , dictionary : Dict [str , Any ]) -> "StepReasoningDefinition" :
107- variants = StepReasoningVariants .from_dict (dictionary ["variants" ])
108- title = dictionary .get ("title" , None )
109- value = dictionary .get ("value" , None )
110- return cls (variants = variants , title = title , value = value )
28+ variants = [correct_step , neutral_step , incorrect_step ]
29+ return _Definition (variants = variants )
11130
11231
11332@dataclass
114- class StepReasoningTool :
33+ class StepReasoningTool ( _BaseStepReasoningTool ) :
11534 """
11635 Use this class in OntologyBuilder to create a tool for step reasoning
11736 The definition field lists the possible options to evaulate a step
11837
11938 NOTE: color attribute is for backward compatibility only and should not be set directly
12039 """
12140
122- name : str
12341 type : ToolType = field (default = ToolType .STEP_REASONING , init = False )
124- required : bool = False
125- schema_id : Optional [str ] = None
126- feature_schema_id : Optional [str ] = None
127- color : Optional [str ] = None
128- definition : StepReasoningDefinition = field (
129- default_factory = StepReasoningDefinition
42+ definition : _Definition = field (
43+ default_factory = build_step_reasoning_definition
13044 )
131-
132- def __post_init__ (self ):
133- warnings .warn (
134- "This feature is experimental and subject to change." ,
135- )
136-
137- if not self .name .strip ():
138- raise ValueError ("Name is required" )
139-
140- def set_incorrect_step_actions (self , actions : List [IncorrectStepActions ]):
141- """
142- For live models, will invoke the model to generate alternatives if a step is marked as incorrect
143- NOTE by default all actions are set to True
144- Pass empty list to reset to false
145- """
146- actions_values = [action .value for action in actions ]
147- self .definition .variants .incorrect_step .set_actions (actions_values )
148-
149- def asdict (self ) -> Dict [str , Any ]:
150- return {
151- "tool" : self .type .value ,
152- "name" : self .name ,
153- "required" : self .required ,
154- "schemaNodeId" : self .schema_id ,
155- "featureSchemaId" : self .feature_schema_id ,
156- "definition" : self .definition .asdict (),
157- }
158-
159- @classmethod
160- def from_dict (cls , dictionary : Dict [str , Any ]) -> "StepReasoningTool" :
161- return cls (
162- name = dictionary ["name" ],
163- schema_id = dictionary .get ("schemaNodeId" , None ),
164- feature_schema_id = dictionary .get ("featureSchemaId" , None ),
165- required = dictionary .get ("required" , False ),
166- definition = StepReasoningDefinition .from_dict (
167- dictionary ["definition" ]
168- ),
169- )
0 commit comments