Skip to content

Commit 61c0f4a

Browse files
mdesmethashhar
authored andcommitted
Implement roles support in dbapi.connect()
1 parent cc31d82 commit 61c0f4a

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,20 @@ cur.execute('SELECT * FROM system.runtime.nodes')
328328
rows = cur.fetchall()
329329
```
330330

331+
## Roles
332+
333+
Authorization roles to use for catalogs, specified as a dict with key-value pairs for the catalog and role. For example, `{"catalog1": "roleA", "catalog2": "roleB"}` sets `roleA` for `catalog1` and `roleB` for `catalog2`. See Trino docs.
334+
335+
```python
336+
import trino
337+
conn = trino.dbapi.connect(
338+
host='localhost',
339+
port=443,
340+
user='the-user',
341+
roles={"catalog1": "roleA", "catalog2": "roleB"},
342+
)
343+
```
344+
331345
## SSL
332346

333347
### SSL verification

tests/integration/test_dbapi_integration.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,19 @@ def test_set_role_trino_351(run_trino):
10701070
assert_role_headers(cur, "tpch=ALL")
10711071

10721072

1073+
@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
1074+
def test_set_role_in_connection_trino_higher_351(run_trino):
1075+
_, host, port = run_trino
1076+
1077+
trino_connection = trino.dbapi.Connection(
1078+
host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}
1079+
)
1080+
cur = trino_connection.cursor()
1081+
cur.execute('SHOW TABLES FROM information_schema')
1082+
cur.fetchall()
1083+
assert_role_headers(cur, "system=ALL")
1084+
1085+
10731086
def assert_role_headers(cursor, expected_header):
10741087
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header
10751088

tests/unit/test_dbapi.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,13 @@ def test_tags_are_set_when_specified(mock_client):
254254

255255
_, passed_client_tags = mock_client.ClientSession.call_args
256256
assert passed_client_tags["client_tags"] == client_tags
257+
258+
259+
@patch("trino.dbapi.trino.client")
260+
def test_role_is_set_when_specified(mock_client):
261+
roles = {"system": "finance"}
262+
with connect("sample_trino_cluster:443", roles=roles) as conn:
263+
conn.cursor().execute("SOME FAKE QUERY")
264+
265+
_, passed_role = mock_client.ClientSession.call_args
266+
assert passed_role["roles"] == roles

trino/dbapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
http_session=None,
111111
client_tags=None,
112112
experimental_python_types=False,
113+
roles=None,
113114
):
114115
self.host = host
115116
self.port = port
@@ -127,7 +128,8 @@ def __init__(
127128
headers=http_headers,
128129
transaction_id=NO_TRANSACTION,
129130
extra_credential=extra_credential,
130-
client_tags=client_tags
131+
client_tags=client_tags,
132+
roles=roles,
131133
)
132134
# mypy cannot follow module import
133135
if http_session is None:

0 commit comments

Comments
 (0)