Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions ads/telemetry/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging

from ads import set_auth
import oci

from ads.common import oci_client as oc
from ads.common.auth import default_signer
from ads.common.auth import default_signer, resource_principal
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION


logger = logging.getLogger(__name__)


class TelemetryBase:
"""Base class for Telemetry Client."""

Expand All @@ -25,15 +26,21 @@ def __init__(self, bucket: str, namespace: str = None) -> None:
namespace : str, optional
Namespace of the OCI object storage bucket, by default None.
"""
# Use resource principal as authentication method if available,
# however, do not change the ADS authentication if user configured it by set_auth.
if OCI_RESOURCE_PRINCIPAL_VERSION:
set_auth("resource_principal")
self._auth = default_signer()
self.os_client = oc.OCIClientFactory(**self._auth).object_storage
self._auth = resource_principal()
else:
self._auth = default_signer()
self.os_client: oci.object_storage.ObjectStorageClient = oc.OCIClientFactory(
**self._auth
).object_storage
self.bucket = bucket
self._namespace = namespace
self._service_endpoint = None
logger.debug(f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}")

logger.debug(
f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}"
)

@property
def namespace(self) -> str:
Expand All @@ -58,5 +65,5 @@ def service_endpoint(self):
Tenancy-specific endpoint.
"""
if not self._service_endpoint:
self._service_endpoint = self.os_client.base_client.endpoint
self._service_endpoint = str(self.os_client.base_client.endpoint)
return self._service_endpoint
25 changes: 17 additions & 8 deletions ads/telemetry/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import logging
import threading
import traceback
import urllib.parse
import requests

import oci
from requests import Response
from .base import TelemetryBase

from ads.config import DEBUG_TELEMETRY

from .base import TelemetryBase

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +34,7 @@ class TelemetryClient(TelemetryBase):
>>> import traceback
>>> from ads.telemetry.client import TelemetryClient
>>> AQUA_BUCKET = os.environ.get("AQUA_BUCKET", "service-managed-models")
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "ociodscdev")
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "namespace")
>>> telemetry = TelemetryClient(bucket=AQUA_BUCKET, namespace=AQUA_BUCKET_NS)
>>> telemetry.record_event_async(category="aqua/service/model", action="create") # records create action
>>> telemetry.record_event_async(category="aqua/service/model/create", action="shape", detail="VM.GPU.A10.1")
Expand Down Expand Up @@ -69,16 +71,23 @@ def record_event(
raise ValueError("Please specify the category and the action.")
if detail:
category, action = f"{category}/{action}", detail
# Here `endpoint`` is for debugging purpose
# For some federated/domain users, the `endpoint` may not be a valid URL
endpoint = f"{self.service_endpoint}/n/{self.namespace}/b/{self.bucket}/o/telemetry/{category}/{action}"
headers = {"User-Agent": self._encode_user_agent(**kwargs)}
logger.debug(f"Sending telemetry to endpoint: {endpoint}")
signer = self._auth["signer"]
response = requests.head(endpoint, auth=signer, headers=headers)
logger.debug(f"Telemetry status code: {response.status_code}")

self.os_client.base_client.user_agent = self._encode_user_agent(**kwargs)
response: oci.response.Response = self.os_client.head_object(
namespace_name=self.namespace,
bucket_name=self.bucket,
object_name=f"telemetry/{category}/{action}",
)
logger.debug(f"Telemetry status: {response.status}")
return response
except Exception as e:
if DEBUG_TELEMETRY:
logger.error(f"There is an error recording telemetry: {e}")
traceback.print_exc()

def record_event_async(
self, category: str = None, action: str = None, detail: str = None, **kwargs
Expand Down
Loading