Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions ads/telemetry/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging

from ads import set_auth
import oci

from ads.common import oci_client as oc
from ads.common.auth import default_signer
from ads.common.auth import default_signer, resource_principal
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION


logger = logging.getLogger(__name__)


class TelemetryBase:
"""Base class for Telemetry Client."""

Expand All @@ -25,15 +26,21 @@ def __init__(self, bucket: str, namespace: str = None) -> None:
namespace : str, optional
Namespace of the OCI object storage bucket, by default None.
"""
# Use resource principal as authentication method if available,
# however, do not change the ADS authentication if user configured it by set_auth.
if OCI_RESOURCE_PRINCIPAL_VERSION:
set_auth("resource_principal")
self._auth = default_signer()
self.os_client = oc.OCIClientFactory(**self._auth).object_storage
self._auth = resource_principal()
else:
self._auth = default_signer()
self.os_client: oci.object_storage.ObjectStorageClient = oc.OCIClientFactory(
**self._auth
).object_storage
self.bucket = bucket
self._namespace = namespace
self._service_endpoint = None
logger.debug(f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}")

logger.debug(
f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}"
)

@property
def namespace(self) -> str:
Expand All @@ -58,5 +65,5 @@ def service_endpoint(self):
Tenancy-specific endpoint.
"""
if not self._service_endpoint:
self._service_endpoint = self.os_client.base_client.endpoint
self._service_endpoint = str(self.os_client.base_client.endpoint)
return self._service_endpoint
46 changes: 33 additions & 13 deletions ads/telemetry/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import logging
import threading
import traceback
import urllib.parse
import requests
from requests import Response
from .base import TelemetryBase
from typing import Optional

import oci

from ads.config import DEBUG_TELEMETRY

from .base import TelemetryBase

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +34,7 @@ class TelemetryClient(TelemetryBase):
>>> import traceback
>>> from ads.telemetry.client import TelemetryClient
>>> AQUA_BUCKET = os.environ.get("AQUA_BUCKET", "service-managed-models")
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "ociodscdev")
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "namespace")
>>> telemetry = TelemetryClient(bucket=AQUA_BUCKET, namespace=AQUA_BUCKET_NS)
>>> telemetry.record_event_async(category="aqua/service/model", action="create") # records create action
>>> telemetry.record_event_async(category="aqua/service/model/create", action="shape", detail="VM.GPU.A10.1")
Expand All @@ -45,7 +47,7 @@ def _encode_user_agent(**kwargs):

def record_event(
self, category: str = None, action: str = None, detail: str = None, **kwargs
) -> Response:
) -> Optional[int]:
"""Send a head request to generate an event record.

Parameters
Expand All @@ -62,23 +64,41 @@ def record_event(

Returns
-------
Response
int
The status code for the telemetry request.
200: The the object exists for the telemetry request
404: The the object does not exist for the telemetry request.
Note that for telemetry purpose, the object does not need to be exist.
`None` will be returned if the telemetry request failed.
"""
try:
if not category or not action:
raise ValueError("Please specify the category and the action.")
if detail:
category, action = f"{category}/{action}", detail
# Here `endpoint`` is for debugging purpose
# For some federated/domain users, the `endpoint` may not be a valid URL
endpoint = f"{self.service_endpoint}/n/{self.namespace}/b/{self.bucket}/o/telemetry/{category}/{action}"
headers = {"User-Agent": self._encode_user_agent(**kwargs)}
logger.debug(f"Sending telemetry to endpoint: {endpoint}")
signer = self._auth["signer"]
response = requests.head(endpoint, auth=signer, headers=headers)
logger.debug(f"Telemetry status code: {response.status_code}")
return response

self.os_client.base_client.user_agent = self._encode_user_agent(**kwargs)
try:
response: oci.response.Response = self.os_client.head_object(
namespace_name=self.namespace,
bucket_name=self.bucket,
object_name=f"telemetry/{category}/{action}",
)
logger.debug(f"Telemetry status: {response.status}")
return response.status
except oci.exceptions.ServiceError as ex:
if ex.status == 404:
return ex.status
raise ex
except Exception as e:
if DEBUG_TELEMETRY:
logger.error(f"There is an error recording telemetry: {e}")
traceback.print_exc()
return None

def record_event_async(
self, category: str = None, action: str = None, detail: str = None, **kwargs
Expand Down
74 changes: 42 additions & 32 deletions tests/unitary/default_setup/telemetry/test_telemetry_client.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,69 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from unittest.mock import patch


from unittest.mock import patch, PropertyMock
import oci

from ads.telemetry.client import TelemetryClient

class TestTelemetryClient:
"""Contains unittests for TelemetryClient."""
TEST_CONFIG = {
"tenancy": "ocid1.tenancy.oc1..unique_ocid",
"user": "ocid1.user.oc1..unique_ocid",
"fingerprint": "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00",
"key_file": "<path>/<to>/<key_file>",
"region": "test_region",
}

endpoint = "https://objectstorage.us-ashburn-1.oraclecloud.com"
EXPECTED_ENDPOINT = "https://objectstorage.test_region.oraclecloud.com"

def mocked_requests_head(*args, **kwargs):
class MockResponse:
def __init__(self, status_code):
self.status_code = status_code

return MockResponse(200)
class TestTelemetryClient:
"""Contains unittests for TelemetryClient."""

@patch('requests.head', side_effect=mocked_requests_head)
@patch('ads.telemetry.client.TelemetryClient.service_endpoint', new_callable=PropertyMock,
return_value=endpoint)
def test_telemetry_client_record_event(self, mock_endpoint, mock_head):
"""Tests TelemetryClient.record_event() with category/action and path, respectively.
"""
@patch("oci.base_client.BaseClient.request")
@patch("oci.signer.Signer")
def test_telemetry_client_record_event(self, signer, request_call):
"""Tests TelemetryClient.record_event() with category/action and path, respectively."""
data = {
"cmd": "ads aqua model list",
"category": "aqua/service/model",
"action": "list",
"bucket": "test_bucket",
"namespace": "test_namespace",
"value": {
"keyword": "test_service_model_name_or_id"
}
"value": {"keyword": "test_service_model_name_or_id"},
}
category = data["category"]
action = data["action"]
bucket = data["bucket"]
namespace = data["namespace"]
value = data["value"]
expected_endpoint = f"{self.endpoint}/n/{namespace}/b/{bucket}/o/telemetry/{category}/{action}"

telemetry = TelemetryClient(bucket=bucket, namespace=namespace)
with patch("oci.config.from_file", return_value=TEST_CONFIG):
telemetry = TelemetryClient(bucket=bucket, namespace=namespace)
telemetry.record_event(category=category, action=action)
telemetry.record_event(category=category, action=action, **value)

expected_headers = [
{'User-Agent': ''},
{'User-Agent': 'keyword=test_service_model_name_or_id'}
expected_agent_headers = [
"",
"keyword=test_service_model_name_or_id",
]
i = 0
for call_args in mock_head.call_args_list:
args, kwargs = call_args
assert all(endpoint == expected_endpoint for endpoint in args)
assert kwargs['headers'] == expected_headers[i]
i += 1

assert len(request_call.call_args_list) == 2
expected_url = f"{EXPECTED_ENDPOINT}/n/{namespace}/b/{bucket}/o/telemetry/{category}/{action}"

# Event #1, no user-agent
args, _ = request_call.call_args_list[0]
request: oci.request.Request = args[0]
operation = args[2]
assert request.url == expected_url
assert operation == "head_object"
assert request.header_params["user-agent"] == expected_agent_headers[0]

# Event #2, with user-agent
args, _ = request_call.call_args_list[1]
request: oci.request.Request = args[0]
operation = args[2]
assert request.url == expected_url
assert operation == "head_object"
assert request.header_params["user-agent"] == expected_agent_headers[1]
Loading