From ed0e20766697942f27bfc82efc740cc4187cb5b1 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 15 Jan 2025 13:02:13 -0500 Subject: [PATCH 01/10] Updated pr. --- ads/llm/__init__.py | 19 +- .../langchain/plugins/embeddings/__init__.py | 4 + ..._data_science_model_deployment_endpoint.py | 207 ++++++++++++++++++ .../large_language_model/langchain_models.rst | 20 ++ .../langchain/embeddings/__init__.py | 5 + .../test_oci_model_deployment_endpoint.py | 35 +++ 6 files changed, 282 insertions(+), 8 deletions(-) create mode 100644 ads/llm/langchain/plugins/embeddings/__init__.py create mode 100644 ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py create mode 100644 tests/unitary/with_extras/langchain/embeddings/__init__.py create mode 100644 tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py diff --git a/ads/llm/__init__.py b/ads/llm/__init__.py index b6e9bcab6..4667d6aa9 100644 --- a/ads/llm/__init__.py +++ b/ads/llm/__init__.py @@ -1,21 +1,24 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*-- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ try: import langchain - from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import ( - OCIModelDeploymentVLLM, - OCIModelDeploymentTGI, - ) + + from ads.llm.chat_template import ChatTemplates from ads.llm.langchain.plugins.chat_models.oci_data_science import ( ChatOCIModelDeployment, - ChatOCIModelDeploymentVLLM, ChatOCIModelDeploymentTGI, + ChatOCIModelDeploymentVLLM, + ) + from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import ( + OCIModelDeploymentEndpointEmbeddings, + ) + from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import ( + OCIModelDeploymentTGI, + OCIModelDeploymentVLLM, ) - from ads.llm.chat_template import ChatTemplates except ImportError as ex: if ex.name == "langchain": raise ImportError( diff --git a/ads/llm/langchain/plugins/embeddings/__init__.py b/ads/llm/langchain/plugins/embeddings/__init__.py new file mode 100644 index 000000000..0fb4f1549 --- /dev/null +++ b/ads/llm/langchain/plugins/embeddings/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py new file mode 100644 index 000000000..cedc242e7 --- /dev/null +++ b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +from typing import Any, Callable, Dict, List, Mapping, Optional + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.llms import create_base_retry_decorator +from langchain_core.utils import get_from_dict_or_env +from pydantic import BaseModel, Field, model_validator + +DEFAULT_HEADER = { + "Content-Type": "application/json", +} + + +class TokenExpiredError(Exception): + pass + + +def _create_retry_decorator(llm) -> Callable[[Any], Any]: + """Creates a retry decorator.""" + errors = [requests.exceptions.ConnectTimeout, TokenExpiredError] + decorator = create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries + ) + return decorator + + +class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings): + """Embedding model deployed on OCI Data Science Model Deployment. + + Example: + + .. code-block:: python + + from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings + + embeddings = OCIModelDeploymentEndpointEmbeddings( + endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + ) + """ # noqa: E501 + + auth: dict = Field(default_factory=dict, exclude=True) + """ADS auth dictionary for OCI authentication: + https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html. + This can be generated by calling `ads.common.auth.api_keys()` + or `ads.common.auth.resource_principal()`. If this is not + provided then the `ads.common.default_signer()` will be used.""" + + endpoint: str = "" + """The uri of the endpoint from the deployed Model Deployment model.""" + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + endpoint_kwargs: Optional[Dict] = None + """Optional attributes (except for headers) passed to the request.post + function. + """ + + max_retries: int = 1 + """The maximum number of retries to make when generating.""" + + @model_validator(mode="before") + def validate_environment( # pylint: disable=no-self-argument + cls, values: Dict + ) -> Dict: + """Validate that python package exists in environment.""" + try: + import ads + + except ImportError as ex: + raise ImportError( + "Could not import ads python package. " + "Please install it with `pip install oracle_ads`." + ) from ex + if not values.get("auth"): + values["auth"] = ads.common.auth.default_signer() + values["endpoint"] = get_from_dict_or_env( + values, + "endpoint", + "OCI_LLM_ENDPOINT", + ) + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} + return { + **{"endpoint": self.endpoint}, + **{"model_kwargs": _model_kwargs}, + } + + def _embed_with_retry(self, **kwargs) -> Any: + """Use tenacity to retry the call.""" + retry_decorator = _create_retry_decorator(self) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + try: + response = requests.post(self.endpoint, **kwargs) + response.raise_for_status() + return response + except requests.exceptions.HTTPError as http_err: + if response.status_code == 401 and self._refresh_signer(): + raise TokenExpiredError() from http_err + else: + raise ValueError( + f"Server error: {str(http_err)}. Message: {response.text}" + ) from http_err + except Exception as e: + raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e + + return _completion_with_retry(**kwargs) + + def _embedding(self, texts: List[str]) -> List[List[float]]: + """Call out to OCI Data Science Model Deployment Endpoint. + + Args: + texts: A list of texts to embed. + + Returns: + A list of list of floats representing the embeddings, or None if an + error occurs. + """ + _model_kwargs = self.model_kwargs or {} + body = self._construct_request_body(texts, _model_kwargs) + request_kwargs = self._construct_request_kwargs(body) + response = self._embed_with_retry(**request_kwargs) + return self._proceses_response(response) + + def _construct_request_kwargs(self, body: Any) -> dict: + """Constructs the request kwargs as a dictionary.""" + from ads.model.common.utils import _is_json_serializable + + _endpoint_kwargs = self.endpoint_kwargs or {} + headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER) + return ( + dict( + headers=headers, + json=body, + auth=self.auth.get("signer"), + **_endpoint_kwargs, + ) + if _is_json_serializable(body) + else dict( + headers=headers, + data=body, + auth=self.auth.get("signer"), + **_endpoint_kwargs, + ) + ) + + def _construct_request_body(self, texts: List[str], params: dict) -> Any: + """Constructs the request body.""" + return {"input": texts} + + def _proceses_response(self, response: requests.Response) -> List[List[float]]: + """Extracts results from requests.Response.""" + try: + res_json = response.json() + embeddings = res_json["data"][0]["embedding"] + except Exception as e: + raise ValueError( + f"Error raised by inference API: {e}.\nResponse: {response.text}" + ) + return embeddings + + def embed_documents( + self, + texts: List[str], + chunk_size: Optional[int] = None, + ) -> List[List[float]]: + """Compute doc embeddings using OCI Data Science Model Deployment Endpoint. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size defines how many input texts will + be grouped together as request. If None, will use the + chunk size specified by the class. + + Returns: + List of embeddings, one for each text. + """ + results = [] + _chunk_size = ( + len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size + ) + for i in range(0, len(texts), _chunk_size): + response = self._embedding(texts[i : i + _chunk_size]) + results.extend(response) + return results + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using OCI Data Science Model Deployment Endpoint. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self._embedding([text])[0] diff --git a/docs/source/user_guide/large_language_model/langchain_models.rst b/docs/source/user_guide/large_language_model/langchain_models.rst index a8163b8dc..4273bf7ad 100644 --- a/docs/source/user_guide/large_language_model/langchain_models.rst +++ b/docs/source/user_guide/large_language_model/langchain_models.rst @@ -127,6 +127,26 @@ Chat models takes `chat messages `_. + + +.. code-block:: python3 + + from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings + + # Create an instance of OCI Model Deployment Endpoint + # Replace the endpoint uri with your own + embeddings = OCIModelDeploymentEndpointEmbeddings( + endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + ) + + query = "Hello World!" + embeddings.embed_query(query) + + Tool Calling ============ diff --git a/tests/unitary/with_extras/langchain/embeddings/__init__.py b/tests/unitary/with_extras/langchain/embeddings/__init__.py new file mode 100644 index 000000000..3d8af46df --- /dev/null +++ b/tests/unitary/with_extras/langchain/embeddings/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py new file mode 100644 index 000000000..11fc868df --- /dev/null +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Test OCI Data Science Model Deployment Endpoint.""" + +import responses +from pytest_mock import MockerFixture +from ads.llm import OCIModelDeploymentEndpointEmbeddings + + +@responses.activate +def test_embedding_call(mocker: MockerFixture) -> None: + """Test valid call to oci model deployment endpoint.""" + endpoint = "https://MD_OCID/predict" + documents = ["Hello", "World"] + expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] + responses.add( + responses.POST, + endpoint, + json={ + "data": [{"embedding": expected_output}], + }, + status=200, + ) + mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + + embeddings = OCIModelDeploymentEndpointEmbeddings( + endpoint=endpoint, + ) + + output = embeddings.embed_documents(documents) + assert output == expected_output From cf9ca93f4022eab159ebe8d5492319fd3e6685cd Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 15 Jan 2025 14:13:53 -0500 Subject: [PATCH 02/10] Updated pr. --- .../test_oci_model_deployment_endpoint.py | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py index 11fc868df..f70aeaa44 100644 --- a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -6,26 +6,23 @@ """Test OCI Data Science Model Deployment Endpoint.""" -import responses -from pytest_mock import MockerFixture +from unittest.mock import MagicMock, patch from ads.llm import OCIModelDeploymentEndpointEmbeddings -@responses.activate -def test_embedding_call(mocker: MockerFixture) -> None: +@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") +def test_embed_documents(mock_embed_with_retry) -> None: """Test valid call to oci model deployment endpoint.""" - endpoint = "https://MD_OCID/predict" - documents = ["Hello", "World"] expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] - responses.add( - responses.POST, - endpoint, - json={ + result = MagicMock() + result.json = MagicMock( + return_value={ "data": [{"embedding": expected_output}], - }, - status=200, + } ) - mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + mock_embed_with_retry.return_value = result + endpoint = "https://MD_OCID/predict" + documents = ["Hello", "World"] embeddings = OCIModelDeploymentEndpointEmbeddings( endpoint=endpoint, @@ -33,3 +30,25 @@ def test_embedding_call(mocker: MockerFixture) -> None: output = embeddings.embed_documents(documents) assert output == expected_output + + +@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") +def test_embed_query(mock_embed_with_retry) -> None: + """Test valid call to oci model deployment endpoint.""" + expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] + result = MagicMock() + result.json = MagicMock( + return_value={ + "data": [{"embedding": expected_output}], + } + ) + mock_embed_with_retry.return_value = result + endpoint = "https://MD_OCID/predict" + query = "Hello world" + + embeddings = OCIModelDeploymentEndpointEmbeddings( + endpoint=endpoint, + ) + + output = embeddings.embed_query(query) + assert output == expected_output[0] From 9d7cce5f7243cdfb13a5dd72bbd381785dc8e76f Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 15 Jan 2025 14:59:08 -0500 Subject: [PATCH 03/10] Updated pr. --- .../embeddings/test_oci_model_deployment_endpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py index f70aeaa44..403e49013 100644 --- a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -7,7 +7,9 @@ """Test OCI Data Science Model Deployment Endpoint.""" from unittest.mock import MagicMock, patch -from ads.llm import OCIModelDeploymentEndpointEmbeddings +from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import ( + OCIModelDeploymentEndpointEmbeddings, +) @patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") From 0c6d6ac4adb3d949910d9d2c4f9a3f89907bccf6 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 15 Jan 2025 16:26:12 -0500 Subject: [PATCH 04/10] Updated pr. --- ads/llm/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ads/llm/__init__.py b/ads/llm/__init__.py index 4667d6aa9..7e5ffee6d 100644 --- a/ads/llm/__init__.py +++ b/ads/llm/__init__.py @@ -12,9 +12,6 @@ ChatOCIModelDeploymentTGI, ChatOCIModelDeploymentVLLM, ) - from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import ( - OCIModelDeploymentEndpointEmbeddings, - ) from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import ( OCIModelDeploymentTGI, OCIModelDeploymentVLLM, From fced37d219951100d8030a6fd9fd723ab4689fa6 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Wed, 22 Jan 2025 17:40:46 -0500 Subject: [PATCH 05/10] Updated pr. --- .../oci_data_science_model_deployment_endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py index cedc242e7..1b12932ed 100644 --- a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +++ b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py @@ -8,8 +8,8 @@ import requests from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.utils import get_from_dict_or_env -from pydantic import BaseModel, Field, model_validator DEFAULT_HEADER = { "Content-Type": "application/json", @@ -64,7 +64,7 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings): max_retries: int = 1 """The maximum number of retries to make when generating.""" - @model_validator(mode="before") + @root_validator() def validate_environment( # pylint: disable=no-self-argument cls, values: Dict ) -> Dict: @@ -167,7 +167,7 @@ def _proceses_response(self, response: requests.Response) -> List[List[float]]: except Exception as e: raise ValueError( f"Error raised by inference API: {e}.\nResponse: {response.text}" - ) + ) from e return embeddings def embed_documents( From 5f520850c417797ee6219fc1529f9f380292a4f4 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Thu, 23 Jan 2025 13:49:51 -0500 Subject: [PATCH 06/10] Updated pr. --- ads/llm/__init__.py | 3 ++ ..._data_science_model_deployment_endpoint.py | 31 +++---------------- .../large_language_model/langchain_models.rst | 4 +-- .../test_oci_model_deployment_endpoint.py | 12 +++---- 4 files changed, 14 insertions(+), 36 deletions(-) diff --git a/ads/llm/__init__.py b/ads/llm/__init__.py index 7e5ffee6d..c3c76e97a 100644 --- a/ads/llm/__init__.py +++ b/ads/llm/__init__.py @@ -12,6 +12,9 @@ ChatOCIModelDeploymentTGI, ChatOCIModelDeploymentVLLM, ) + from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import ( + OCIDataScienceEmbedding, + ) from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import ( OCIModelDeploymentTGI, OCIModelDeploymentVLLM, diff --git a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py index 1b12932ed..95414b414 100644 --- a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +++ b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py @@ -8,8 +8,7 @@ import requests from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import BaseModel, Field DEFAULT_HEADER = { "Content-Type": "application/json", @@ -29,16 +28,16 @@ def _create_retry_decorator(llm) -> Callable[[Any], Any]: return decorator -class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings): +class OCIDataScienceEmbedding(BaseModel, Embeddings): """Embedding model deployed on OCI Data Science Model Deployment. Example: .. code-block:: python - from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings + from ads.llm import OCIDataScienceEmbedding - embeddings = OCIModelDeploymentEndpointEmbeddings( + embeddings = OCIDataScienceEmbedding( endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) """ # noqa: E501 @@ -64,28 +63,6 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings): max_retries: int = 1 """The maximum number of retries to make when generating.""" - @root_validator() - def validate_environment( # pylint: disable=no-self-argument - cls, values: Dict - ) -> Dict: - """Validate that python package exists in environment.""" - try: - import ads - - except ImportError as ex: - raise ImportError( - "Could not import ads python package. " - "Please install it with `pip install oracle_ads`." - ) from ex - if not values.get("auth"): - values["auth"] = ads.common.auth.default_signer() - values["endpoint"] = get_from_dict_or_env( - values, - "endpoint", - "OCI_LLM_ENDPOINT", - ) - return values - @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" diff --git a/docs/source/user_guide/large_language_model/langchain_models.rst b/docs/source/user_guide/large_language_model/langchain_models.rst index 4273bf7ad..2eaed250c 100644 --- a/docs/source/user_guide/large_language_model/langchain_models.rst +++ b/docs/source/user_guide/large_language_model/langchain_models.rst @@ -135,11 +135,11 @@ You can also use embedding model that's hosted on a `OCI Data Science Model Depl .. code-block:: python3 - from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings + from ads.llm import OCIDataScienceEmbedding # Create an instance of OCI Model Deployment Endpoint # Replace the endpoint uri with your own - embeddings = OCIModelDeploymentEndpointEmbeddings( + embeddings = OCIDataScienceEmbedding( endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py index 403e49013..b12ef297b 100644 --- a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -7,12 +7,10 @@ """Test OCI Data Science Model Deployment Endpoint.""" from unittest.mock import MagicMock, patch -from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import ( - OCIModelDeploymentEndpointEmbeddings, -) +from ads.llm import OCIDataScienceEmbedding -@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") +@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry") def test_embed_documents(mock_embed_with_retry) -> None: """Test valid call to oci model deployment endpoint.""" expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] @@ -26,7 +24,7 @@ def test_embed_documents(mock_embed_with_retry) -> None: endpoint = "https://MD_OCID/predict" documents = ["Hello", "World"] - embeddings = OCIModelDeploymentEndpointEmbeddings( + embeddings = OCIDataScienceEmbedding( endpoint=endpoint, ) @@ -34,7 +32,7 @@ def test_embed_documents(mock_embed_with_retry) -> None: assert output == expected_output -@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") +@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry") def test_embed_query(mock_embed_with_retry) -> None: """Test valid call to oci model deployment endpoint.""" expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] @@ -48,7 +46,7 @@ def test_embed_query(mock_embed_with_retry) -> None: endpoint = "https://MD_OCID/predict" query = "Hello world" - embeddings = OCIModelDeploymentEndpointEmbeddings( + embeddings = OCIDataScienceEmbedding( endpoint=endpoint, ) From 3ce4dc0befb39efc7b0f73b604be1c47918079c7 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Mon, 27 Jan 2025 11:46:19 -0500 Subject: [PATCH 07/10] Updated pr. --- .../embeddings/oci_data_science_model_deployment_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py index 95414b414..a2bdb4a4a 100644 --- a/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py +++ b/ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py @@ -8,7 +8,7 @@ import requests from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field DEFAULT_HEADER = { "Content-Type": "application/json", From 5fefbd669a4a2f2960807a88f41695f8003afae3 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Mon, 27 Jan 2025 14:54:37 -0500 Subject: [PATCH 08/10] Updated test run yaml. --- .github/workflows/run-unittests-py38-cov-report.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/run-unittests-py38-cov-report.yml b/.github/workflows/run-unittests-py38-cov-report.yml index 9fcfa4cbe..25fcc391c 100644 --- a/.github/workflows/run-unittests-py38-cov-report.yml +++ b/.github/workflows/run-unittests-py38-cov-report.yml @@ -46,7 +46,8 @@ jobs: --ignore tests/unitary/with_extras/feature_store \ --ignore tests/unitary/with_extras/operator/feature-store \ --ignore tests/unitary/with_extras/operator/forecast \ - --ignore tests/unitary/with_extras/hpo + --ignore tests/unitary/with_extras/hpo \ + --ignore tests/unitary/with_extras/langchain - name: "slow_tests" test-path: "tests/unitary/with_extras/model" From be9a2e383e9156b3b51388ede1cc591402282490 Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Mon, 3 Feb 2025 21:06:38 -0500 Subject: [PATCH 09/10] Updated pr. --- .github/workflows/run-unittests-py38-cov-report.yml | 3 +-- .../embeddings/test_oci_model_deployment_endpoint.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-unittests-py38-cov-report.yml b/.github/workflows/run-unittests-py38-cov-report.yml index 25fcc391c..9fcfa4cbe 100644 --- a/.github/workflows/run-unittests-py38-cov-report.yml +++ b/.github/workflows/run-unittests-py38-cov-report.yml @@ -46,8 +46,7 @@ jobs: --ignore tests/unitary/with_extras/feature_store \ --ignore tests/unitary/with_extras/operator/feature-store \ --ignore tests/unitary/with_extras/operator/forecast \ - --ignore tests/unitary/with_extras/hpo \ - --ignore tests/unitary/with_extras/langchain + --ignore tests/unitary/with_extras/hpo - name: "slow_tests" test-path: "tests/unitary/with_extras/model" diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py index b12ef297b..8cdbb9ed7 100644 --- a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -6,9 +6,14 @@ """Test OCI Data Science Model Deployment Endpoint.""" +import pytest +import sys from unittest.mock import MagicMock, patch from ads.llm import OCIDataScienceEmbedding +if sys.version_info < (3, 9): + pytest.skip(allow_module_level=True) + @patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry") def test_embed_documents(mock_embed_with_retry) -> None: From 5d702678242df26eb41fa087b05d3185438a7afb Mon Sep 17 00:00:00 2001 From: Lu Peng Date: Tue, 4 Feb 2025 10:19:41 -0500 Subject: [PATCH 10/10] Updated pr. --- .../langchain/embeddings/test_oci_model_deployment_endpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py index 8cdbb9ed7..1dcbe1581 100644 --- a/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py +++ b/tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py @@ -9,11 +9,12 @@ import pytest import sys from unittest.mock import MagicMock, patch -from ads.llm import OCIDataScienceEmbedding if sys.version_info < (3, 9): pytest.skip(allow_module_level=True) +from ads.llm import OCIDataScienceEmbedding + @patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry") def test_embed_documents(mock_embed_with_retry) -> None: