|
51 | 51 | logger = logging.getLogger(__name__) |
52 | 52 |
|
53 | 53 |
|
54 | | -class OCIAuth(httpx.Auth): |
| 54 | +class HttpxOCIAuth(httpx.Auth): |
55 | 55 | """ |
56 | 56 | Custom HTTPX authentication class that uses the OCI Signer for request signing. |
57 | 57 |
|
58 | 58 | Attributes: |
59 | 59 | signer (oci.signer.Signer): The OCI signer used to sign requests. |
60 | 60 | """ |
61 | 61 |
|
62 | | - def __init__(self, signer: oci.signer.Signer): |
| 62 | + def __init__(self, signer: Optional[oci.signer.Signer] = None): |
63 | 63 | """ |
64 | | - Initialize the OCIAuth instance. |
| 64 | + Initialize the HttpxOCIAuth instance. |
65 | 65 |
|
66 | 66 | Args: |
67 | 67 | signer (oci.signer.Signer): The OCI signer to use for signing requests. |
68 | 68 | """ |
69 | | - self.signer = signer |
| 69 | + |
| 70 | + self.signer = signer or authutil.default_signer().get("signer") |
70 | 71 |
|
71 | 72 | def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]: |
72 | 73 | """ |
@@ -256,7 +257,7 @@ def __init__( |
256 | 257 | auth = auth or authutil.default_signer() |
257 | 258 | if not callable(auth.get("signer")): |
258 | 259 | raise ValueError("Auth object must have a 'signer' callable attribute.") |
259 | | - self.auth = OCIAuth(auth["signer"]) |
| 260 | + self.auth = HttpxOCIAuth(auth["signer"]) |
260 | 261 |
|
261 | 262 | logger.debug( |
262 | 263 | f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, " |
@@ -352,7 +353,7 @@ def __init__(self, *args, **kwargs) -> None: |
352 | 353 | **kwargs: Keyword arguments forwarded to BaseClient. |
353 | 354 | """ |
354 | 355 | super().__init__(*args, **kwargs) |
355 | | - self._client = httpx.Client(timeout=self.timeout) |
| 356 | + self._client = httpx.Client(timeout=self.timeout, auth=self.auth) |
356 | 357 |
|
357 | 358 | def is_closed(self) -> bool: |
358 | 359 | return self._client.is_closed |
@@ -400,7 +401,6 @@ def _request( |
400 | 401 | response = self._client.post( |
401 | 402 | self.endpoint, |
402 | 403 | headers=self._prepare_headers(stream=False, headers=headers), |
403 | | - auth=self.auth, |
404 | 404 | json=payload, |
405 | 405 | ) |
406 | 406 | logger.debug(f"Received response with status code: {response.status_code}") |
@@ -447,7 +447,6 @@ def _stream( |
447 | 447 | "POST", |
448 | 448 | self.endpoint, |
449 | 449 | headers=self._prepare_headers(stream=True, headers=headers), |
450 | | - auth=self.auth, |
451 | 450 | json={**payload, "stream": True}, |
452 | 451 | ) as response: |
453 | 452 | try: |
@@ -581,7 +580,7 @@ def __init__(self, *args, **kwargs) -> None: |
581 | 580 | **kwargs: Keyword arguments forwarded to BaseClient. |
582 | 581 | """ |
583 | 582 | super().__init__(*args, **kwargs) |
584 | | - self._client = httpx.AsyncClient(timeout=self.timeout) |
| 583 | + self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth) |
585 | 584 |
|
586 | 585 | def is_closed(self) -> bool: |
587 | 586 | return self._client.is_closed |
@@ -637,7 +636,6 @@ async def _request( |
637 | 636 | response = await self._client.post( |
638 | 637 | self.endpoint, |
639 | 638 | headers=self._prepare_headers(stream=False, headers=headers), |
640 | | - auth=self.auth, |
641 | 639 | json=payload, |
642 | 640 | ) |
643 | 641 | logger.debug(f"Received response with status code: {response.status_code}") |
@@ -683,7 +681,6 @@ async def _stream( |
683 | 681 | "POST", |
684 | 682 | self.endpoint, |
685 | 683 | headers=self._prepare_headers(stream=True, headers=headers), |
686 | | - auth=self.auth, |
687 | 684 | json={**payload, "stream": True}, |
688 | 685 | ) as response: |
689 | 686 | try: |
@@ -797,3 +794,43 @@ async def embeddings( |
797 | 794 | logger.debug(f"Generating embeddings with input: {input}, payload: {payload}") |
798 | 795 | payload = {**(payload or {}), "input": input} |
799 | 796 | return await self._request(payload=payload, headers=headers) |
| 797 | + |
| 798 | + |
| 799 | +def get_httpx_client(**kwargs: Any) -> httpx.Client: |
| 800 | + """ |
| 801 | + Creates and returns a synchronous httpx Client configured with OCI authentication signer based |
| 802 | + the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE. |
| 803 | + More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html |
| 804 | +
|
| 805 | + Parameters |
| 806 | + ---------- |
| 807 | + **kwargs : Any |
| 808 | + Keyword arguments supported by httpx.Client |
| 809 | +
|
| 810 | + Returns |
| 811 | + ------- |
| 812 | + Client |
| 813 | + A configured synchronous httpx Client instance. |
| 814 | + """ |
| 815 | + kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth() |
| 816 | + return httpx.Client(**kwargs) |
| 817 | + |
| 818 | + |
| 819 | +def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient: |
| 820 | + """ |
| 821 | + Creates and returns a synchronous httpx Client configured with OCI authentication signer based |
| 822 | + the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE. |
| 823 | + More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html |
| 824 | +
|
| 825 | + Parameters |
| 826 | + ---------- |
| 827 | + **kwargs : Any |
| 828 | + Keyword arguments supported by httpx.Client |
| 829 | +
|
| 830 | + Returns |
| 831 | + ------- |
| 832 | + AsyncClient |
| 833 | + A configured asynchronous httpx AsyncClient instance. |
| 834 | + """ |
| 835 | + kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth() |
| 836 | + return httpx.AsyncClient(**kwargs) |
0 commit comments