22from typing import Any , Dict , List , Optional
33
44from labelbox .schema .tool_building .tool_type import ToolType
5-
6-
7- @dataclass
8- class StepReasoningVariant :
9- id : int
10- name : str
11-
12- def asdict (self ) -> Dict [str , Any ]:
13- return {"id" : self .id , "name" : self .name }
14-
15-
16- @dataclass
17- class IncorrectStepReasoningVariant :
18- id : int
19- name : str
20- regenerate_conversations_after_incorrect_step : Optional [bool ] = False
21- rate_alternative_responses : Optional [bool ] = False
22-
23- def asdict (self ) -> Dict [str , Any ]:
24- actions = []
25- if self .regenerate_conversations_after_incorrect_step :
26- actions .append ("regenerateSteps" )
27- if self .rate_alternative_responses :
28- actions .append ("generateAndRateAlternativeSteps" )
29- return {"id" : self .id , "name" : self .name , "actions" : actions }
30-
31- @classmethod
32- def from_dict (
33- cls , dictionary : Dict [str , Any ]
34- ) -> "IncorrectStepReasoningVariant" :
35- return cls (
36- id = dictionary ["id" ],
37- name = dictionary ["name" ],
38- regenerate_conversations_after_incorrect_step = "regenerateSteps"
39- in dictionary .get ("actions" , []),
40- rate_alternative_responses = "generateAndRateAlternativeSteps"
41- in dictionary .get ("actions" , []),
42- )
43-
44-
45- def _create_correct_step () -> StepReasoningVariant :
46- return StepReasoningVariant (
47- id = StepReasoningVariants .CORRECT_STEP_ID , name = "Correct"
48- )
49-
50-
51- def _create_neutral_step () -> StepReasoningVariant :
52- return StepReasoningVariant (
53- id = StepReasoningVariants .NEUTRAL_STEP_ID , name = "Neutral"
54- )
55-
56-
57- def _create_incorrect_step () -> IncorrectStepReasoningVariant :
58- return IncorrectStepReasoningVariant (
59- id = StepReasoningVariants .INCORRECT_STEP_ID , name = "Incorrect"
60- )
5+ from labelbox .schema .tool_building .variant import Variant , VariantWithActions
616
627
638@dataclass
@@ -67,18 +12,23 @@ class StepReasoningVariants:
6712 Currently the options are correct, neutral, and incorrect
6813 """
6914
70- CORRECT_STEP_ID = 0
71- NEUTRAL_STEP_ID = 1
72- INCORRECT_STEP_ID = 2
73-
74- correct_step : StepReasoningVariant = field (
75- default_factory = _create_correct_step
15+ correct_step : Variant = field (
16+ default_factory = lambda : Variant (id = 0 , name = "Correct" )
7617 )
77- neutral_step : StepReasoningVariant = field (
78- default_factory = _create_neutral_step
18+ neutral_step : Variant = field (
19+ default_factory = lambda : Variant ( id = 1 , name = "Neutral" )
7920 )
80- incorrect_step : IncorrectStepReasoningVariant = field (
81- default_factory = _create_incorrect_step
21+
22+ incorrect_step : VariantWithActions = field (
23+ default_factory = lambda : VariantWithActions (
24+ id = 2 ,
25+ name = "Incorrect" ,
26+ _available_actions = {
27+ "regenerateSteps" ,
28+ "generateAndRateAlternativeSteps" ,
29+ },
30+ actions = ["regenerateSteps" ], # regenerateSteps is on by default
31+ )
8232 )
8333
8434 def asdict (self ):
@@ -95,14 +45,12 @@ def from_dict(cls, dictionary: List[Dict[str, Any]]):
9545 incorrect_step = None
9646
9747 for variant in dictionary :
98- if variant ["id" ] == cls .CORRECT_STEP_ID :
99- correct_step = StepReasoningVariant (** variant )
100- elif variant ["id" ] == cls .NEUTRAL_STEP_ID :
101- neutral_step = StepReasoningVariant (** variant )
102- elif variant ["id" ] == cls .INCORRECT_STEP_ID :
103- incorrect_step = IncorrectStepReasoningVariant .from_dict (
104- variant
105- )
48+ if variant ["id" ] == 0 :
49+ correct_step = Variant (** variant )
50+ elif variant ["id" ] == 1 :
51+ neutral_step = Variant (** variant )
52+ elif variant ["id" ] == 2 :
53+ incorrect_step = VariantWithActions (** variant )
10654
10755 if not all ([correct_step , neutral_step , incorrect_step ]):
10856 raise ValueError ("Invalid step reasoning variants" )
0 commit comments