Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/labelbox/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Labelbox Python SDK Documentation
search-filters
send-to-annotate-params
slice
step_reasoning_tool
task
task-queue
user
Expand Down
6 changes: 6 additions & 0 deletions docs/labelbox/step_reasoning_tool.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Step Reasoning Tool
===============================================================================================

.. automodule:: labelbox.schema.tool_building.step_reasoning_tool
:members:
:show-inheritance:
5 changes: 3 additions & 2 deletions libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
FeatureSchema,
Ontology,
PromptResponseClassification,
Tool,
tool_type_cls_from_type,
)
from labelbox.schema.ontology_kind import (
EditorTaskType,
Expand Down Expand Up @@ -1098,7 +1098,8 @@ def create_ontology_from_feature_schemas(
if "tool" in feature_schema.normalized:
tool = feature_schema.normalized["tool"]
try:
Tool.Type(tool)
tool_type_cls = tool_type_cls_from_type(tool)
tool_type_cls(tool)
tools.append(feature_schema.normalized)
except ValueError:
raise ValueError(
Expand Down
25 changes: 20 additions & 5 deletions libs/labelbox/src/labelbox/schema/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
from labelbox.schema.tool_building.tool_type import ToolType

FeatureSchemaId: Type[str] = Annotated[
str, StringConstraints(min_length=25, max_length=25)
Expand Down Expand Up @@ -187,7 +189,7 @@ def __post_init__(self):
@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
return cls(
class_type=cls.Type(dictionary["type"]),
class_type=Classification.Type(dictionary["type"]),
name=dictionary["name"],
instructions=dictionary["instructions"],
required=dictionary.get("required", False),
Expand Down Expand Up @@ -351,7 +353,7 @@ class Type(Enum):
@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
return cls(
class_type=cls.Type(dictionary["type"]),
class_type=PromptResponseClassification.Type(dictionary["type"]),
name=dictionary["name"],
instructions=dictionary["instructions"],
required=True, # always required
Expand Down Expand Up @@ -458,7 +460,7 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
schema_id=dictionary.get("schemaNodeId", None),
feature_schema_id=dictionary.get("featureSchemaId", None),
required=dictionary.get("required", False),
tool=cls.Type(dictionary["tool"]),
tool=Tool.Type(dictionary["tool"]),
classifications=[
Classification.from_dict(c)
for c in dictionary["classifications"]
Expand Down Expand Up @@ -488,6 +490,18 @@ def add_classification(self, classification: Classification) -> None:
self.classifications.append(classification)


def tool_cls_from_type(tool_type: str):
if tool_type.lower() == ToolType.STEP_REASONING.value:
return StepReasoningTool
return Tool


def tool_type_cls_from_type(tool_type: str):
if tool_type.lower() == ToolType.STEP_REASONING.value:
return ToolType
return Tool.Type


class Ontology(DbObject):
"""An ontology specifies which tools and classifications are available
to a project. This is read only for now.
Expand Down Expand Up @@ -525,7 +539,8 @@ def tools(self) -> List[Tool]:
"""Get list of tools (AKA objects) in an Ontology."""
if self._tools is None:
self._tools = [
Tool.from_dict(tool) for tool in self.normalized["tools"]
tool_cls_from_type(tool["tool"]).from_dict(tool)
for tool in self.normalized["tools"]
]
return self._tools

Expand Down Expand Up @@ -581,7 +596,7 @@ class OntologyBuilder:

"""

tools: List[Tool] = field(default_factory=list)
tools: List[Union[Tool, StepReasoningTool]] = field(default_factory=list)
classifications: List[
Union[Classification, PromptResponseClassification]
] = field(default_factory=list)
Expand Down
2 changes: 2 additions & 0 deletions libs/labelbox/src/labelbox/schema/tool_building/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import labelbox.schema.tool_building.tool_type
import labelbox.schema.tool_building.step_reasoning_tool
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from labelbox.schema.tool_building.tool_type import ToolType


@dataclass
class StepReasoningVariant:
id: int
name: str

def asdict(self) -> Dict[str, Any]:
return {"id": self.id, "name": self.name}


@dataclass
class IncorrectStepReasoningVariant:
id: int
name: str
regenerate_conversations_after_incorrect_step: Optional[bool] = True
rate_alternative_responses: Optional[bool] = False

def asdict(self) -> Dict[str, Any]:
actions = []
if self.regenerate_conversations_after_incorrect_step:
actions.append("regenerateSteps")
if self.rate_alternative_responses:
actions.append("generateAndRateAlternativeSteps")
return {"id": self.id, "name": self.name, "actions": actions}

@classmethod
def from_dict(
cls, dictionary: Dict[str, Any]
) -> "IncorrectStepReasoningVariant":
return cls(
id=dictionary["id"],
name=dictionary["name"],
regenerate_conversations_after_incorrect_step="regenerateSteps"
in dictionary.get("actions", []),
rate_alternative_responses="generateAndRateAlternativeSteps"
in dictionary.get("actions", []),
)


def _create_correct_step() -> StepReasoningVariant:
return StepReasoningVariant(
id=StepReasoningVariants.CORRECT_STEP_ID, name="Correct"
)


def _create_neutral_step() -> StepReasoningVariant:
return StepReasoningVariant(
id=StepReasoningVariants.NEUTRAL_STEP_ID, name="Neutral"
)


def _create_incorrect_step() -> IncorrectStepReasoningVariant:
return IncorrectStepReasoningVariant(
id=StepReasoningVariants.INCORRECT_STEP_ID, name="Incorrect"
)


@dataclass
class StepReasoningVariants:
"""
This class is used to define the possible options for evaluating a step
Currently the options are correct, neutral, and incorrect
"""

CORRECT_STEP_ID = 0
NEUTRAL_STEP_ID = 1
INCORRECT_STEP_ID = 2

correct_step: StepReasoningVariant = field(
default_factory=_create_correct_step
)
neutral_step: StepReasoningVariant = field(
default_factory=_create_neutral_step
)
incorrect_step: IncorrectStepReasoningVariant = field(
default_factory=_create_incorrect_step
)

def asdict(self):
return [
self.correct_step.asdict(),
self.neutral_step.asdict(),
self.incorrect_step.asdict(),
]

@classmethod
def from_dict(cls, dictionary: List[Dict[str, Any]]):
correct_step = None
neutral_step = None
incorrect_step = None

for variant in dictionary:
if variant["id"] == cls.CORRECT_STEP_ID:
correct_step = StepReasoningVariant(**variant)
elif variant["id"] == cls.NEUTRAL_STEP_ID:
neutral_step = StepReasoningVariant(**variant)
elif variant["id"] == cls.INCORRECT_STEP_ID:
incorrect_step = IncorrectStepReasoningVariant.from_dict(
variant
)

if not all([correct_step, neutral_step, incorrect_step]):
raise ValueError("Invalid step reasoning variants")

return cls(
correct_step=correct_step, # type: ignore
neutral_step=neutral_step, # type: ignore
incorrect_step=incorrect_step, # type: ignore
)


@dataclass
class StepReasoningDefinition:
variants: StepReasoningVariants = field(
default_factory=StepReasoningVariants
)
version: int = field(default=1)
title: Optional[str] = None
value: Optional[str] = None

def asdict(self) -> Dict[str, Any]:
result = {"variants": self.variants.asdict(), "version": self.version}
if self.title is not None:
result["title"] = self.title
if self.value is not None:
result["value"] = self.value
return result

@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningDefinition":
variants = StepReasoningVariants.from_dict(dictionary["variants"])
title = dictionary.get("title", None)
value = dictionary.get("value", None)
return cls(variants=variants, title=title, value=value)


@dataclass
class StepReasoningTool:
"""
Use this class in OntologyBuilder to create a tool for step reasoning
The definition field lists the possible options to evaulate a step
"""

name: str
type: ToolType = field(default=ToolType.STEP_REASONING, init=False)
required: bool = False
schema_id: Optional[str] = None
feature_schema_id: Optional[str] = None
color: Optional[str] = None
definition: StepReasoningDefinition = field(
default_factory=StepReasoningDefinition
)

def reset_regenerate_conversations_after_incorrect_step(self):
"""
For live models, the default acation will invoke the model to generate alternatives if a step is marked as incorrect
This method will reset the action to not regenerate the conversation
"""
self.definition.variants.incorrect_step.regenerate_conversations_after_incorrect_step = False

def set_rate_alternative_responses(self):
"""
For live models, will require labelers to rate the alternatives generated by the model
"""
self.definition.variants.incorrect_step.rate_alternative_responses = (
True
)

def asdict(self) -> Dict[str, Any]:
return {
"tool": self.type.value,
"name": self.name,
"required": self.required,
"schemaNodeId": self.schema_id,
"featureSchemaId": self.feature_schema_id,
"definition": self.definition.asdict(),
}

@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> "StepReasoningTool":
return cls(
name=dictionary["name"],
schema_id=dictionary.get("schemaNodeId", None),
feature_schema_id=dictionary.get("featureSchemaId", None),
required=dictionary.get("required", False),
definition=StepReasoningDefinition.from_dict(
dictionary["definition"]
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from enum import Enum


class ToolType(Enum):
STEP_REASONING = "step-reasoning"
6 changes: 4 additions & 2 deletions libs/labelbox/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
)
from labelbox.schema.data_row import DataRowMetadataField
from labelbox.schema.ontology_kind import OntologyKind
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
from labelbox.schema.tool_building.tool_type import ToolType
from labelbox.schema.user import User


Expand Down Expand Up @@ -562,6 +564,7 @@ def feature_schema(client, point):
@pytest.fixture
def chat_evaluation_ontology(client, rand_gen):
ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}"

ontology_builder = OntologyBuilder(
tools=[
Tool(
Expand All @@ -576,6 +579,7 @@ def chat_evaluation_ontology(client, rand_gen):
tool=Tool.Type.MESSAGE_RANKING,
name="model output multi ranking",
),
StepReasoningTool(name="step reasoning"),
],
classifications=[
Classification(
Expand Down Expand Up @@ -626,14 +630,12 @@ def chat_evaluation_ontology(client, rand_gen):
),
],
)

ontology = client.create_ontology(
ontology_name,
ontology_builder.asdict(),
media_type=MediaType.Conversational,
ontology_kind=OntologyKind.ModelEvaluation,
)

yield ontology

try:
Expand Down
Loading
Loading