Skip to content

Commit 4fcda80

Browse files
committed
RedshiftDataAPI serverless support #1530
1 parent 159827c commit 4fcda80

File tree

7 files changed

+91
-32
lines changed

7 files changed

+91
-32
lines changed

awswrangler/data_api/redshift.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ class RedshiftDataApi(connector.DataApiConnector):
1616
Parameters
1717
----------
1818
cluster_id: str
19-
Id for the target Redshift cluster.
19+
ID of the target Redshift cluster. Required only if `workgroup_name` is not provided.
20+
workgroup_name: str
21+
Name of the target Redshift Serverless workgroup. Required only if `cluster_id` is not provided.
2022
database: str
2123
Target database name.
2224
secret_arn: str
@@ -35,8 +37,9 @@ class RedshiftDataApi(connector.DataApiConnector):
3537

3638
def __init__(
3739
self,
38-
cluster_id: str,
39-
database: str,
40+
cluster_id: str = "",
41+
workgroup_name: str = "",
42+
database: str = "",
4043
secret_arn: str = "",
4144
db_user: str = "",
4245
sleep: float = 0.25,
@@ -45,6 +48,7 @@ def __init__(
4548
boto3_session: Optional[boto3.Session] = None,
4649
) -> None:
4750
self.cluster_id = cluster_id
51+
self.workgroup_name = workgroup_name
4852
self.database = database
4953
self.secret_arn = secret_arn
5054
self.db_user = db_user
@@ -53,22 +57,36 @@ def __init__(
5357
logger: logging.Logger = logging.getLogger(__name__)
5458
super().__init__(self.client, logger)
5559

60+
def _validate_redshift_target(self) -> None:
61+
if self.database == "":
62+
raise ValueError("`database` must be set for connection")
63+
if self.cluster_id == "" and self.workgroup_name == "":
64+
raise ValueError("Either `cluster_id` or `workgroup_name`(Redshift Serverless) must be set for connection")
65+
5666
def _validate_auth_method(self) -> None:
57-
if self.secret_arn == "" and self.db_user == "":
67+
if self.workgroup_name == "" and self.secret_arn == "" and self.db_user == "":
5868
raise ValueError("Either `secret_arn` or `db_user` must be set for authentication")
5969

6070
def _execute_statement(self, sql: str, database: Optional[str] = None) -> str:
71+
self._validate_redshift_target()
6172
self._validate_auth_method()
62-
credentials = {"SecretArn": self.secret_arn}
63-
if self.db_user:
73+
credentials = {}
74+
if self.secret_arn:
75+
credentials = {"SecretArn": self.secret_arn}
76+
elif self.db_user:
6477
credentials = {"DbUser": self.db_user}
6578

6679
if database is None:
6780
database = self.database
6881

82+
if self.cluster_id:
83+
redshift_target = {"ClusterIdentifier": self.cluster_id}
84+
elif self.workgroup_name:
85+
redshift_target = {"WorkgroupName": self.workgroup_name}
86+
6987
self.logger.debug("Executing %s", sql)
7088
response: Dict[str, Any] = self.client.execute_statement(
71-
ClusterIdentifier=self.cluster_id,
89+
**redshift_target,
7290
Database=database,
7391
Sql=sql,
7492
**credentials,
@@ -167,8 +185,9 @@ class RedshiftDataApiTimeoutException(Exception):
167185

168186

169187
def connect(
170-
cluster_id: str,
171-
database: str,
188+
cluster_id: str = "",
189+
workgroup_name: str = "",
190+
database: str = "",
172191
secret_arn: str = "",
173192
db_user: str = "",
174193
boto3_session: Optional[boto3.Session] = None,
@@ -179,7 +198,9 @@ def connect(
179198
Parameters
180199
----------
181200
cluster_id: str
182-
Id for the target Redshift cluster.
201+
ID of the target Redshift cluster. Required only if `workgroup_name` is not provided.
202+
workgroup_name: str
203+
Name of the target Redshift Serverless workgroup. Required only if `cluster_id` is not provided.
183204
database: str
184205
Target database name.
185206
secret_arn: str
@@ -196,7 +217,13 @@ def connect(
196217
A RedshiftDataApi connection instance that can be used with `wr.redshift.data_api.read_sql_query`.
197218
"""
198219
return RedshiftDataApi(
199-
cluster_id, database, secret_arn=secret_arn, db_user=db_user, boto3_session=boto3_session, **kwargs
220+
cluster_id=cluster_id,
221+
workgroup_name=workgroup_name,
222+
database=database,
223+
secret_arn=secret_arn,
224+
db_user=db_user,
225+
boto3_session=boto3_session,
226+
**kwargs,
200227
)
201228

202229

poetry.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ classifiers = [
2727
[tool.poetry.dependencies]
2828
python = ">=3.7.1, <3.11"
2929

30-
boto3 = "^1.20.17"
31-
botocore = "^1.23.17"
30+
boto3 = "^1.24.11"
31+
botocore = "^1.27.11"
3232
pandas = "^1.2.0"
3333
numpy = "^1.21.0"
3434
pyarrow = ">=2.0.0, <8.1.0"

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def workgroup3(bucket, kms_key):
130130

131131
@pytest.fixture(scope="session")
132132
def databases_parameters(cloudformation_outputs, db_password):
133-
parameters = dict(postgresql={}, mysql={}, redshift={}, sqlserver={}, mysql_serverless={}, oracle={})
133+
parameters = dict(
134+
postgresql={}, mysql={}, redshift={}, sqlserver={}, mysql_serverless={}, oracle={}, redshift_serverless={}
135+
)
134136
parameters["postgresql"]["host"] = cloudformation_outputs.get("PostgresqlAddress")
135137
parameters["postgresql"]["port"] = 3306
136138
parameters["postgresql"]["schema"] = "public"
@@ -160,6 +162,9 @@ def databases_parameters(cloudformation_outputs, db_password):
160162
parameters["oracle"]["port"] = 1521
161163
parameters["oracle"]["schema"] = "TEST"
162164
parameters["oracle"]["database"] = "ORCL"
165+
parameters["redshift_serverless"]["secret_arn"] = cloudformation_outputs.get("RedshiftServerlessSecretArn")
166+
parameters["redshift_serverless"]["workgroup"] = cloudformation_outputs.get("RedshiftServerlessWorkgroup")
167+
parameters["redshift_serverless"]["database"] = "test"
163168
return parameters
164169

165170

tests/test_data_api.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,44 @@ def redshift_connector(databases_parameters):
1212
cluster_id = databases_parameters["redshift"]["identifier"]
1313
database = databases_parameters["redshift"]["database"]
1414
secret_arn = databases_parameters["redshift"]["secret_arn"]
15-
conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=None)
16-
return conn
15+
con = wr.data_api.redshift.connect(
16+
cluster_id=cluster_id, database=database, secret_arn=secret_arn, boto3_session=None
17+
)
18+
return con
1719

1820

1921
def create_rds_connector(rds_type, parameters):
2022
cluster_id = parameters[rds_type]["arn"]
2123
database = parameters[rds_type]["database"]
2224
secret_arn = parameters[rds_type]["secret_arn"]
23-
conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=boto3.DEFAULT_SESSION)
24-
return conn
25+
con = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=boto3.DEFAULT_SESSION)
26+
return con
2527

2628

2729
@pytest.fixture
2830
def mysql_serverless_connector(databases_parameters):
2931
return create_rds_connector("mysql_serverless", databases_parameters)
3032

3133

34+
def test_connect_redshift_serverless_iam_role(databases_parameters):
35+
workgroup_name = databases_parameters["redshift_serverless"]["workgroup"]
36+
database = databases_parameters["redshift_serverless"]["database"]
37+
con = wr.data_api.redshift.connect(workgroup_name=workgroup_name, database=database, boto3_session=None)
38+
df = wr.data_api.redshift.read_sql_query("SELECT 1", con=con)
39+
assert df.shape == (1, 1)
40+
41+
42+
def test_connect_redshift_serverless_secrets_manager(databases_parameters):
43+
workgroup_name = databases_parameters["redshift_serverless"]["workgroup"]
44+
database = databases_parameters["redshift_serverless"]["database"]
45+
secret_arn = databases_parameters["redshift_serverless"]["secret_arn"]
46+
con = wr.data_api.redshift.connect(
47+
workgroup_name=workgroup_name, database=database, secret_arn=secret_arn, boto3_session=None
48+
)
49+
df = wr.data_api.redshift.read_sql_query("SELECT 1", con=con)
50+
assert df.shape == (1, 1)
51+
52+
3253
@pytest.fixture(scope="function")
3354
def mysql_serverless_table(mysql_serverless_connector):
3455
name = f"tbl_{get_time_str_with_random_suffix()}"

tests/test_moto.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,21 +573,21 @@ def mock_data_api_connector(connector, has_result_set=True):
573573

574574
def test_data_api_redshift_create_connection():
575575
cluster_id = "cluster123"
576-
conn = wr.data_api.redshift.connect(cluster_id, "db1", db_user="admin")
577-
assert conn.cluster_id == cluster_id
576+
con = wr.data_api.redshift.connect(cluster_id=cluster_id, database="db1", db_user="admin")
577+
assert con.cluster_id == cluster_id
578578

579579

580580
def test_data_api_redshift_read_sql_results():
581581
cluster_id = "cluster123"
582-
con = wr.data_api.redshift.connect(cluster_id, "db1", db_user="admin")
582+
con = wr.data_api.redshift.connect(cluster_id=cluster_id, database="db1", db_user="admin")
583583
expected_dataframe = mock_data_api_connector(con)
584584
dataframe = wr.data_api.redshift.read_sql_query("SELECT * FROM test", con=con)
585585
pd.testing.assert_frame_equal(dataframe, expected_dataframe)
586586

587587

588588
def test_data_api_redshift_read_sql_no_results():
589589
cluster_id = "cluster123"
590-
con = wr.data_api.redshift.connect(cluster_id, "db1", db_user="admin")
590+
con = wr.data_api.redshift.connect(cluster_id=cluster_id, database="db1", db_user="admin")
591591
mock_data_api_connector(con, has_result_set=False)
592592
dataframe = wr.data_api.redshift.read_sql_query("DROP TABLE test", con=con)
593593
assert dataframe.empty is True

tutorials/030 - Data Api.ipynb

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,14 @@
6666
" secret_arn=\"arn:aws:secretsmanager:us-east-1:111111111111:secret:aws-sdk-pandas/redshift-ewn43d\"\n",
6767
")\n",
6868
"\n",
69+
"con_redshift_serverless = wr.data_api.redshift.connect(\n",
70+
" workgroup_name=\"aws-sdk-pandas\",\n",
71+
" database=\"test_redshift\",\n",
72+
" secret_arn=\"arn:aws:secretsmanager:us-east-1:111111111111:secret:aws-sdk-pandas/redshift-f3en4w\"\n",
73+
")\n",
74+
"\n",
6975
"con_mysql = wr.data_api.rds.connect(\n",
70-
" cluster_id=\"arn:aws:rds:us-east-1:111111111111:cluster:mysql-serverless-cluster-wrangler\",\n",
76+
" resource_arn=\"arn:aws:rds:us-east-1:111111111111:cluster:mysql-serverless-cluster-wrangler\",\n",
7177
" database=\"test_rds\",\n",
7278
" secret_arn=\"arn:aws:secretsmanager:us-east-1:111111111111:secret:aws-sdk-pandas/mysql-23df3\"\n",
7379
")"
@@ -102,4 +108,4 @@
102108
"metadata": {}
103109
}
104110
]
105-
}
111+
}

0 commit comments

Comments
 (0)