diff --git a/ads/oracledb/oracle_db.py b/ads/oracledb/oracle_db.py index 117d69259..2d8a363bb 100644 --- a/ads/oracledb/oracle_db.py +++ b/ads/oracledb/oracle_db.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*-- -# Copyright (c) 2021, 2023 Oracle and/or its affiliates. +# Copyright (c) 2021, 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """ @@ -17,19 +16,20 @@ Note: We need to account for cx_Oracle though oracledb can operate in thick mode. The end user may be is using one of the old conda packs or an environment where cx_Oracle is the only available driver. """ -from ads.common.utils import ORACLE_DEFAULT_PORT - import logging -import numpy as np import os -import pandas as pd import tempfile -from time import time -from typing import Dict, Optional, List, Union, Iterator import zipfile +from time import time +from typing import Dict, Iterator, List, Optional, Union + +import numpy as np +import pandas as pd + from ads.common.decorator.runtime_dependency import ( OptionalDependency, ) +from ads.common.utils import ORACLE_DEFAULT_PORT logger = logging.getLogger("ads.oracle_connector") CX_ORACLE = "cx_Oracle" @@ -40,17 +40,17 @@ import oracledb as oracle_driver # Both the driver share same signature for the APIs that we are using. PYTHON_DRIVER_NAME = PYTHON_ORACLEDB -except: +except ModuleNotFoundError: logger.info("oracledb package not found. Trying to load cx_Oracle") try: import cx_Oracle as oracle_driver PYTHON_DRIVER_NAME = CX_ORACLE - except ModuleNotFoundError: + except ModuleNotFoundError as err2: raise ModuleNotFoundError( f"Neither `oracledb` nor `cx_Oracle` module was not found. Please run " f"`pip install {OptionalDependency.DATA}`." - ) + ) from err2 class OracleRDBMSConnection(oracle_driver.Connection): @@ -75,7 +75,7 @@ def __init__( logger.info( "Running oracledb driver in thick mode. For mTLS based connection, thick mode is default." 
def get_adw_connection(vault_secret_id: str) -> "oracledb.Connection":
    """Create an ADW connection from the credentials stored in a vault secret.

    Parameters
    ----------
    vault_secret_id: str
        Identifier of the vault secret. It is passed unchanged to
        ``ADBSecretKeeper.load_secret``.

    Returns
    -------
    oracledb.Connection
        The object returned by ``oracledb.connect``.

    Raises
    ------
    ValueError
        If the loaded secret has no (or an empty) ``user_name`` or
        ``password`` entry.
    """
    # Imported at call time rather than at module import: the module itself
    # can load with only cx_Oracle installed, while this helper specifically
    # needs oracledb.
    import oracledb

    from ads.secrets.adb import ADBSecretKeeper

    # Use the module's named logger (the same object as the module-level
    # `logger`) instead of the root logger, so these records carry the
    # "ads.oracle_connector" name like the rest of this file's logging.
    log = logging.getLogger("ads.oracle_connector")

    log.debug("A secret id was used to retrieve credentials.")
    creds = ADBSecretKeeper.load_secret(vault_secret_id).to_dict()

    # Pull the two mandatory fields out of the dict; whatever remains is
    # forwarded verbatim to oracledb.connect as keyword arguments.
    user = creds.pop("user_name", None)
    password = creds.pop("password", None)
    if not user or not password:
        raise ValueError(f"The user or password is missing in {vault_secret_id}")

    log.debug("Downloaded secrets successfully.")
    return oracledb.connect(user=user, password=password, **creds)