@@ -16,7 +16,9 @@ class RedshiftDataApi(connector.DataApiConnector):
1616 Parameters
1717 ----------
1818 cluster_id: str
19- Id for the target Redshift cluster.
19+ Id for the target Redshift cluster - only required if `workgroup_name` not provided.
20+ workgroup_name: str
21+ Name for the target serverless Redshift workgroup - only required if `cluster_id` not provided.
2022 database: str
2123 Target database name.
2224 secret_arn: str
@@ -35,8 +37,9 @@ class RedshiftDataApi(connector.DataApiConnector):
3537
3638 def __init__ (
3739 self ,
38- cluster_id : str ,
39- database : str ,
40+ cluster_id : str = "" ,
41+ workgroup_name : str = "" ,
42+ database : str = "" ,
4043 secret_arn : str = "" ,
4144 db_user : str = "" ,
4245 sleep : float = 0.25 ,
@@ -45,6 +48,7 @@ def __init__(
4548 boto3_session : Optional [boto3 .Session ] = None ,
4649 ) -> None :
4750 self .cluster_id = cluster_id
51+ self .workgroup_name = workgroup_name
4852 self .database = database
4953 self .secret_arn = secret_arn
5054 self .db_user = db_user
@@ -53,22 +57,36 @@ def __init__(
5357 logger : logging .Logger = logging .getLogger (__name__ )
5458 super ().__init__ (self .client , logger )
5559
60+ def _validate_redshift_target (self ) -> None :
61+ if self .database == "" :
62+ raise ValueError ("`database` must be set for connection" )
63+ if self .cluster_id == "" and self .workgroup_name == "" :
64+ raise ValueError ("Either `cluster_id` or `workgroup_name`(Redshift Serverless) must be set for connection" )
65+
5666 def _validate_auth_method (self ) -> None :
57- if self .secret_arn == "" and self .db_user == "" :
67+ if self .workgroup_name == "" and self . secret_arn == "" and self .db_user == "" :
5868 raise ValueError ("Either `secret_arn` or `db_user` must be set for authentication" )
5969
6070 def _execute_statement (self , sql : str , database : Optional [str ] = None ) -> str :
71+ self ._validate_redshift_target ()
6172 self ._validate_auth_method ()
62- credentials = {"SecretArn" : self .secret_arn }
63- if self .db_user :
73+ credentials = {}
74+ if self .secret_arn :
75+ credentials = {"SecretArn" : self .secret_arn }
76+ elif self .db_user :
6477 credentials = {"DbUser" : self .db_user }
6578
6679 if database is None :
6780 database = self .database
6881
82+ if self .cluster_id :
83+ redshift_target = {"ClusterIdentifier" : self .cluster_id }
84+ elif self .workgroup_name :
85+ redshift_target = {"WorkgroupName" : self .workgroup_name }
86+
6987 self .logger .debug ("Executing %s" , sql )
7088 response : Dict [str , Any ] = self .client .execute_statement (
71- ClusterIdentifier = self . cluster_id ,
89+ ** redshift_target ,
7290 Database = database ,
7391 Sql = sql ,
7492 ** credentials ,
@@ -167,8 +185,9 @@ class RedshiftDataApiTimeoutException(Exception):
167185
168186
169187def connect (
170- cluster_id : str ,
171- database : str ,
188+ cluster_id : str = "" ,
189+ workgroup_name : str = "" ,
190+ database : str = "" ,
172191 secret_arn : str = "" ,
173192 db_user : str = "" ,
174193 boto3_session : Optional [boto3 .Session ] = None ,
@@ -179,7 +198,9 @@ def connect(
179198 Parameters
180199 ----------
181200 cluster_id: str
182- Id for the target Redshift cluster.
201+ Id for the target Redshift cluster - only required if `workgroup_name` not provided.
202+ workgroup_name: str
203+ Name for the target serverless Redshift workgroup - only required if `cluster_id` not provided.
183204 database: str
184205 Target database name.
185206 secret_arn: str
@@ -196,7 +217,13 @@ def connect(
196217 A RedshiftDataApi connection instance that can be used with `wr.redshift.data_api.read_sql_query`.
197218 """
198219 return RedshiftDataApi (
199- cluster_id , database , secret_arn = secret_arn , db_user = db_user , boto3_session = boto3_session , ** kwargs
220+ cluster_id = cluster_id ,
221+ workgroup_name = workgroup_name ,
222+ database = database ,
223+ secret_arn = secret_arn ,
224+ db_user = db_user ,
225+ boto3_session = boto3_session ,
226+ ** kwargs ,
200227 )
201228
202229
0 commit comments