Skip to content

Commit 862a5b6

Browse files
feat: accept custom inference id
1 parent d65007f commit 862a5b6

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/openlayer/lib/integrations/langchain_callback.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def __init__(self, **kwargs: Any) -> None:
5050
self.metadata: Dict[str, Any] = kwargs or {}
5151
self.steps: Dict[UUID, steps.Step] = {}
5252
self.root_steps: set[UUID] = set() # Track which steps are root
53+
# Extract inference_id from kwargs if provided
54+
self._inference_id = kwargs.get("inference_id")
5355

5456
def _start_step(
5557
self,
@@ -105,6 +107,9 @@ def _start_step(
105107
# Track root steps (those without parent_run_id)
106108
if parent_run_id is None:
107109
self.root_steps.add(run_id)
110+
# Override step ID with custom inference_id if provided
111+
if self._inference_id is not None:
112+
step.id = self._inference_id
108113

109114
self.steps[run_id] = step
110115
return step
@@ -748,8 +753,12 @@ def __init__(
748753
ignore_chain=False,
749754
ignore_retriever=False,
750755
ignore_agent=False,
756+
inference_id: Optional[Any] = None,
751757
**kwargs: Any,
752758
) -> None:
759+
# Add inference_id to kwargs so it gets passed to mixin
760+
if inference_id is not None:
761+
kwargs["inference_id"] = inference_id
753762
super().__init__(**kwargs)
754763
# Store the ignore flags as instance variables
755764
self._ignore_llm = ignore_llm
@@ -890,8 +899,12 @@ def __init__(
890899
ignore_chain=False,
891900
ignore_retriever=False,
892901
ignore_agent=False,
902+
inference_id: Optional[Any] = None,
893903
**kwargs: Any,
894904
) -> None:
905+
# Add inference_id to kwargs so it gets passed to mixin
906+
if inference_id is not None:
907+
kwargs["inference_id"] = inference_id
895908
super().__init__(**kwargs)
896909
# Store the ignore flags as instance variables
897910
self._ignore_llm = ignore_llm
@@ -962,6 +975,10 @@ def _start_step(
962975
self._traces_by_root[run_id] = trace
963976
self.root_steps.add(run_id)
964977

978+
# Override step ID with custom inference_id if provided
979+
if self._inference_id is not None:
980+
step.id = self._inference_id
981+
965982
self.steps[run_id] = step
966983
return step
967984

0 commit comments

Comments
 (0)