diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 40f6c8c6..99e92431 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -423,17 +423,7 @@ def get_run_context(self) -> RunContext: """ if (run := current_run_span.get()) is None: raise RuntimeError("get_run_context() must be called within a run") - - # Capture OpenTelemetry trace context - trace_context: dict[str, str] = {} - propagate.inject(trace_context) - - return { - "run_id": run.run_id, - "run_name": run.name, - "project": run.project, - "trace_context": trace_context, - } + return run.get_context() ``` @@ -501,30 +491,8 @@ def initialize(self) -> None: f"Failed to connect to the Dreadnode server: {e}", ) from e - headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token} - span_processors.append( - BatchSpanProcessor( - RemovePendingSpansExporter( # This will tell Logfire to emit pending spans to us as well - OTLPSpanExporter( - endpoint=urljoin(self.server, "/api/otel/traces"), - headers=headers, - compression=Compression.Gzip, - ), - ), - ), - ) - # TODO(nick): Metrics - # https://linear.app/dreadnode/issue/ENG-1310/sdk-add-metrics-exports - # metric_readers.append( - # PeriodicExportingMetricReader( - # OTLPMetricExporter( - # endpoint=urljoin(self.server, "/v1/metrics"), - # headers=headers, - # compression=Compression.Gzip, - # # preferred_temporality - # ) - # ) - # ) + span_processors.append(RoutingSpanProcessor(self.server, self.token)) + if self._api is not None: api = self._api self._credential_manager = CredentialManager( @@ -1750,6 +1718,7 @@ run( project: str | None = None, autolog: bool = True, name_prefix: str | None = None, + api_key: str | None = None, attributes: AnyDict | None = None, ) -> RunSpan ``` @@ -1799,6 +1768,16 @@ with dreadnode.run("my_run"): `True` ) –Automatically log task inputs, outputs, and execution metrics if otherwise unspecified. 
+* **`name_prefix`** + (`str | None`, default: + `None` + ) + –A prefix to use when generating a random name for the run. +* **`api_key`** + (`str | None`, default: + `None` + ) + –An optional API key to use for tracing this run instead of the configured one. * **`attributes`** (`AnyDict | None`, default: `None` ) –Additional attributes to attach to the run span. @@ -1823,6 +1802,7 @@ def run( project: str | None = None, autolog: bool = True, name_prefix: str | None = None, + api_key: str | None = None, attributes: AnyDict | None = None, ) -> RunSpan: """ @@ -1849,6 +1829,8 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Automatically log task inputs, outputs, and execution metrics if otherwise unspecified. + name_prefix: A prefix to use when generating a random name for the run. + api_key: An optional API key to use for tracing this run instead of the configured one. attributes: Additional attributes to attach to the run span. Returns: @@ -1870,6 +1852,7 @@ def run( tags=tags, credential_manager=self._credential_manager, # type: ignore[arg-type] autolog=autolog, + export_auth_token=api_key, ) ``` @@ -2649,6 +2632,63 @@ def task_span( ``` + + +### using\_api\_key + +```python +using_api_key(api_key: str) -> t.Iterator[None] +``` + +Context manager to temporarily override the API key used for exporting spans. + +This is useful for multi-user scenarios where you want to log data +on behalf of another user. + +Example + +```python +with dreadnode.using_api_key("other_user_api_key"): + with dreadnode.run("my_run"): + # do some work here + pass +``` + +**Parameters:** + +* **`api_key`** + (`str`) + –The API key to use for exporting spans within the context. + + +```python +@contextlib.contextmanager +def using_api_key(self, api_key: str) -> t.Iterator[None]: + """ + Context manager to temporarily override the API key used for exporting spans. + + This is useful for multi-user scenarios where you want to log data + on behalf of another user. 
+ + Example: + ~~~ + with dreadnode.using_api_key("other_user_api_key"): + with dreadnode.run("my_run"): + # do some work here + pass + ~~~ + + Args: + api_key: The API key to use for exporting spans within the context. + """ + token_token = current_export_auth_token_context.set(api_key) + try: + yield + finally: + current_export_auth_token_context.reset(token_token) +``` + + DreadnodeConfigWarning diff --git a/dreadnode/__init__.py b/dreadnode/__init__.py index 7221ef47..65f86765 100644 --- a/dreadnode/__init__.py +++ b/dreadnode/__init__.py @@ -50,6 +50,7 @@ task = DEFAULT_INSTANCE.task task_span = DEFAULT_INSTANCE.task_span run = DEFAULT_INSTANCE.run +using_api_key = DEFAULT_INSTANCE.using_api_key task_and_run = DEFAULT_INSTANCE.task_and_run scorer = DEFAULT_INSTANCE.scorer score = DEFAULT_INSTANCE.score diff --git a/dreadnode/main.py b/dreadnode/main.py index 54da44de..920d123a 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from urllib.parse import urljoin, urlparse, urlunparse +from urllib.parse import urlparse, urlunparse import coolname # type: ignore [import-untyped] import logfire @@ -14,10 +14,6 @@ from fsspec.implementations.local import ( # type: ignore [import-untyped] LocalFileSystem, ) -from logfire._internal.exporters.remove_pending import RemovePendingSpansExporter -from opentelemetry import propagate -from opentelemetry.exporter.otlp.proto.http import Compression -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.trace.export import BatchSpanProcessor from dreadnode.api.client import ApiClient @@ -56,11 +52,13 @@ FileMetricReader, FileSpanExporter, ) +from dreadnode.tracing.processors import RoutingSpanProcessor from dreadnode.tracing.span import ( RunContext, RunSpan, Span, TaskSpan, + current_export_auth_token_context, current_run_span, current_task_span, ) @@ -337,30 
+335,8 @@ def initialize(self) -> None: f"Failed to connect to the Dreadnode server: {e}", ) from e - headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token} - span_processors.append( - BatchSpanProcessor( - RemovePendingSpansExporter( # This will tell Logfire to emit pending spans to us as well - OTLPSpanExporter( - endpoint=urljoin(self.server, "/api/otel/traces"), - headers=headers, - compression=Compression.Gzip, - ), - ), - ), - ) - # TODO(nick): Metrics - # https://linear.app/dreadnode/issue/ENG-1310/sdk-add-metrics-exports - # metric_readers.append( - # PeriodicExportingMetricReader( - # OTLPMetricExporter( - # endpoint=urljoin(self.server, "/v1/metrics"), - # headers=headers, - # compression=Compression.Gzip, - # # preferred_temporality - # ) - # ) - # ) + span_processors.append(RoutingSpanProcessor(self.server, self.token)) + if self._api is not None: api = self._api self._credential_manager = CredentialManager( @@ -777,6 +753,7 @@ def run( project: str | None = None, autolog: bool = True, name_prefix: str | None = None, + api_key: str | None = None, attributes: AnyDict | None = None, ) -> RunSpan: """ @@ -803,6 +780,8 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Automatically log task inputs, outputs, and execution metrics if otherwise unspecified. + name_prefix: A prefix to use when generating a random name for the run. + api_key: An optional API key to use for tracing this run instead of the configured one. attributes: Additional attributes to attach to the run span. Returns: @@ -824,8 +803,34 @@ def run( tags=tags, credential_manager=self._credential_manager, # type: ignore[arg-type] autolog=autolog, + export_auth_token=api_key, ) + @contextlib.contextmanager + def using_api_key(self, api_key: str) -> t.Iterator[None]: + """ + Context manager to temporarily override the API key used for exporting spans. 
+ + This is useful for multi-user scenarios where you want to log data + on behalf of another user. + + Example: + ``` + with dreadnode.using_api_key("other_user_api_key"): + with dreadnode.run("my_run"): + # do some work here + pass + ``` + + Args: + api_key: The API key to use for exporting spans within the context. + """ + token_token = current_export_auth_token_context.set(api_key) + try: + yield + finally: + current_export_auth_token_context.reset(token_token) + @contextlib.contextmanager def task_and_run( self, @@ -877,17 +882,7 @@ def get_run_context(self) -> RunContext: """ if (run := current_run_span.get()) is None: raise RuntimeError("get_run_context() must be called within a run") - - # Capture OpenTelemetry trace context - trace_context: dict[str, str] = {} - propagate.inject(trace_context) - - return { - "run_id": run.run_id, - "run_name": run.name, - "project": run.project, - "trace_context": trace_context, - } + return run.get_context() def continue_run(self, run_context: RunContext) -> RunSpan: """ diff --git a/dreadnode/tracing/constants.py b/dreadnode/tracing/constants.py index be99b2a2..ae98a432 100644 --- a/dreadnode/tracing/constants.py +++ b/dreadnode/tracing/constants.py @@ -33,3 +33,6 @@ EVENT_ATTRIBUTE_ORIGIN_SPAN_ID = f"{SPAN_NAMESPACE}.origin.span_id" METRIC_ATTRIBUTE_SOURCE_HASH = f"{SPAN_NAMESPACE}.origin.hash" + +# Internal use only - used to support multi-user export flows +SPAN_RESOURCE_ATTRIBUTE_TOKEN = "_dreadnode_token" # noqa: S105 # nosec diff --git a/dreadnode/tracing/processors.py b/dreadnode/tracing/processors.py new file mode 100644 index 00000000..3b00f7f8 --- /dev/null +++ b/dreadnode/tracing/processors.py @@ -0,0 +1,81 @@ +import threading +import typing as t +from urllib.parse import urljoin + +from logfire._internal.exporters.dynamic_batch import DynamicBatchSpanProcessor +from logfire._internal.exporters.remove_pending import RemovePendingSpansExporter +from opentelemetry.exporter.otlp.proto.http import Compression 
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor + +from dreadnode.tracing.constants import SPAN_RESOURCE_ATTRIBUTE_TOKEN +from dreadnode.version import VERSION + +if t.TYPE_CHECKING: + from opentelemetry.context import Context + from opentelemetry.trace import Span + + +class RoutingSpanProcessor(SpanProcessor): + """ + A SpanProcessor that routes spans to different BatchSpanProcessors based on + a token attached to the span object. + + This allows a single application to export spans for multiple users/tokens + to the same backend. + """ + + def __init__( + self, + server_url: str, + default_token: str, + *, + token_header_name: str = "X-Api-Key", # noqa: S107 + span_token_attribute_name: str = SPAN_RESOURCE_ATTRIBUTE_TOKEN, + ): + self._server_url = server_url + self._default_token = default_token + self._token_header_name = token_header_name + self._span_token_attribute_name = span_token_attribute_name + self._processors: dict[str, SpanProcessor] = {} + self._lock = threading.Lock() + self._get_or_create_processor(self._default_token) + + def _get_or_create_processor(self, token: str) -> SpanProcessor: + """Lazily creates and caches a BatchSpanProcessor for a given token.""" + with self._lock: + if token not in self._processors: + headers = {"User-Agent": f"dreadnode/{VERSION}", self._token_header_name: token} + self._processors[token] = DynamicBatchSpanProcessor( + RemovePendingSpansExporter( + OTLPSpanExporter( + endpoint=urljoin(self._server_url, "/api/otel/traces"), + headers=headers, + compression=Compression.Gzip, + ) + ) + ) + return self._processors[token] + + def on_start(self, span: "Span", parent_context: "Context | None" = None) -> None: + """No-op. 
Spans are routed on end.""" + + def on_end(self, span: ReadableSpan) -> None: + """Routes the span to the correct processor based on its token.""" + # We use the resource here to prevent it from being lost during conversions + token = getattr(span.resource, self._span_token_attribute_name, self._default_token) + processor = self._get_or_create_processor(token) + processor.on_end(span) + + def shutdown(self) -> None: + """Shuts down all managed processors.""" + with self._lock: + for processor in self._processors.values(): + processor.shutdown() + + def force_flush(self, timeout_millis: int = 30000) -> bool: + """Flushes all managed processors.""" + with self._lock: + # OTel spec says this should return True only if all flushes succeed. + results = [p.force_flush(timeout_millis) for p in self._processors.values()] + return all(results) diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index d2db90ac..fa17263d 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -64,6 +64,7 @@ SPAN_ATTRIBUTE_TAGS_, SPAN_ATTRIBUTE_TYPE, SPAN_ATTRIBUTE_VERSION, + SPAN_RESOURCE_ATTRIBUTE_TOKEN, SpanType, ) from dreadnode.util import clean_str @@ -87,6 +88,12 @@ default=None, ) +# Used to override the export api-key for multi-user scenarios +current_export_auth_token_context: ContextVar[str | None] = ContextVar( + "_current_token_context", + default=None, +) + def _format_status(status: Status) -> str: """Format the status for display.""" @@ -142,6 +149,10 @@ def __enter__(self) -> te.Self: attributes=prepare_otlp_attributes(self._pre_attributes), ) + if token := current_export_auth_token_context.get(): + # We use the resource here to prevent it from being lost during conversions + setattr(self._span._resource, SPAN_RESOURCE_ATTRIBUTE_TOKEN, token) # type: ignore[attr-defined] # noqa: SLF001 + self._span.__enter__() OPEN_SPANS.add(self._span) # type: ignore [arg-type] @@ -312,13 +323,14 @@ def __str__(self) -> str: return f"{self._span_name} 
({self._label})" if self._label else self._span_name -class RunContext(te.TypedDict): +class RunContext(te.TypedDict, total=False): """Context for transferring and continuing runs in other places.""" run_id: str run_name: str project: str trace_context: dict[str, str] + export_auth_token: str | None class RunUpdateSpan(Span): @@ -383,6 +395,7 @@ def __init__( autolog: bool = True, update_frequency: int = 5, run_id: str | ULID | None = None, + export_auth_token: str | None = None, type: SpanType = "run", ) -> None: self.autolog = autolog @@ -395,6 +408,10 @@ def __init__( self._inputs: list[ObjectRef] = [] self._outputs: list[ObjectRef] = [] + # Export auth token for span export router + self._export_auth_token = export_auth_token + self._export_auth_token_token: Token[str | None] | None = None # For managing context + # Credential manager for S3 operations self._credential_manager = credential_manager @@ -444,6 +461,7 @@ def from_context( type="run_fragment", run_id=context["run_id"], credential_manager=credential_manager, + export_auth_token=context["export_auth_token"], ) self._remote_context = context["trace_context"] @@ -453,6 +471,11 @@ def __enter__(self) -> te.Self: if current_run_span.get() is not None: raise RuntimeError("You cannot start a run span within another run") + if self._export_auth_token: + self._export_auth_token_token = current_export_auth_token_context.set( + self._export_auth_token + ) + if self._remote_context is not None: # If the global propagator is a NoExtract instance, we can't continue # a trace, so we'll bypass it and use the W3C propagator directly. @@ -518,6 +541,27 @@ def __exit__( if self._context_token is not None: current_run_span.reset(self._context_token) + if self._export_auth_token_token: + current_export_auth_token_context.reset(self._export_auth_token_token) + + def get_context(self) -> RunContext: + """ + Capture the current run context for transfer to another host, thread, or process. 
+ """ + # Capture OpenTelemetry trace context + trace_context: dict[str, str] = {} + propagate.inject(trace_context) + + context: RunContext = { + "run_id": self.run_id, + "run_name": self.name, + "project": self.project, + "trace_context": trace_context, + "export_auth_token": current_export_auth_token_context.get(), + } + + return context + def push_update(self, *, force: bool = False) -> None: if self._span is None: return