Skip to content

Commit cc31d82

Browse files
mdesmethashhar
authored andcommitted
Refactor role into roles
1 parent ecb53be commit cc31d82

File tree

4 files changed

+60
-39
lines changed

4 files changed

+60
-39
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import trino
2121
from tests.integration.conftest import trino_version
22+
from trino import constants
2223
from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
2324
from trino.transaction import IsolationLevel
2425

@@ -1045,11 +1046,11 @@ def test_set_role_trino_higher_351(run_trino):
10451046
cur = trino_connection.cursor()
10461047
cur.execute('SHOW TABLES FROM information_schema')
10471048
cur.fetchall()
1048-
assert cur._request._client_session.role is None
1049+
assert cur._request._client_session.roles == {}
10491050

10501051
cur.execute("SET ROLE ALL")
10511052
cur.fetchall()
1052-
assert cur._request._client_session.role == "system=ALL"
1053+
assert_role_headers(cur, "system=ALL")
10531054

10541055

10551056
@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog")
@@ -1062,11 +1063,15 @@ def test_set_role_trino_351(run_trino):
10621063
cur = trino_connection.cursor()
10631064
cur.execute('SHOW TABLES FROM information_schema')
10641065
cur.fetchall()
1065-
assert cur._request._client_session.role is None
1066+
assert cur._request._client_session.roles == {}
10661067

10671068
cur.execute("SET ROLE ALL")
10681069
cur.fetchall()
1069-
assert cur._request._client_session.role == "tpch=ALL"
1070+
assert_role_headers(cur, "tpch=ALL")
1071+
1072+
1073+
def assert_role_headers(cursor, expected_header):
1074+
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header
10701075

10711076

10721077
def test_prepared_statements(run_trino):

tests/unit/test_client.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import threading
1414
import time
1515
import uuid
16+
from typing import Optional, Dict
1617
from unittest import mock
1718
from urllib.parse import urlparse
1819

@@ -92,10 +93,9 @@ def assert_headers(headers):
9293
assert headers[constants.HEADER_SOURCE] == source
9394
assert headers[constants.HEADER_USER] == user
9495
assert headers[constants.HEADER_SESSION] == ""
95-
assert headers[constants.HEADER_ROLE] is None
9696
assert headers[accept_encoding_header] == accept_encoding_value
9797
assert headers[client_info_header] == client_info_value
98-
assert len(headers.keys()) == 9
98+
assert len(headers.keys()) == 8
9999

100100
req.post("URL")
101101
_, post_kwargs = post.call_args
@@ -988,10 +988,12 @@ def __call__(self):
988988
with_retry(FailerUntil(3).__call__)()
989989

990990

991-
def assert_headers_with_role(headers, role):
991+
def assert_headers_with_roles(headers: Dict[str, str], roles: Optional[str]):
992+
if roles is None:
993+
assert constants.HEADER_ROLE not in headers
994+
else:
995+
assert headers[constants.HEADER_ROLE] == roles
992996
assert headers[constants.HEADER_USER] == "test_user"
993-
assert headers[constants.HEADER_ROLE] == role
994-
assert len(headers.keys()) == 7
995997

996998

997999
def test_request_headers_role_hive_all(mock_get_and_post):
@@ -1001,17 +1003,17 @@ def test_request_headers_role_hive_all(mock_get_and_post):
10011003
port=8080,
10021004
client_session=ClientSession(
10031005
user="test_user",
1004-
role="hive=ALL",
1006+
roles={"hive": "ALL"}
10051007
),
10061008
)
10071009

10081010
req.post("URL")
10091011
_, post_kwargs = post.call_args
1010-
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
1012+
assert_headers_with_roles(post_kwargs["headers"], "hive=ALL")
10111013

10121014
req.get("URL")
10131015
_, get_kwargs = get.call_args
1014-
assert_headers_with_role(post_kwargs["headers"], "hive=ALL")
1016+
assert_headers_with_roles(post_kwargs["headers"], "hive=ALL")
10151017

10161018

10171019
def test_request_headers_role_admin(mock_get_and_post):
@@ -1022,17 +1024,17 @@ def test_request_headers_role_admin(mock_get_and_post):
10221024
port=8080,
10231025
client_session=ClientSession(
10241026
user="test_user",
1025-
role="admin",
1027+
roles={"system": "admin"}
10261028
),
10271029
)
10281030

10291031
req.post("URL")
10301032
_, post_kwargs = post.call_args
1031-
assert_headers_with_role(post_kwargs["headers"], "admin")
1033+
assert_headers_with_roles(post_kwargs["headers"], "system=admin")
10321034

10331035
req.get("URL")
10341036
_, get_kwargs = get.call_args
1035-
assert_headers_with_role(post_kwargs["headers"], "admin")
1037+
assert_headers_with_roles(post_kwargs["headers"], "system=admin")
10361038

10371039

10381040
def test_request_headers_role_empty(mock_get_and_post):
@@ -1043,14 +1045,14 @@ def test_request_headers_role_empty(mock_get_and_post):
10431045
port=8080,
10441046
client_session=ClientSession(
10451047
user="test_user",
1046-
role="",
1048+
roles=None,
10471049
),
10481050
)
10491051

10501052
req.post("URL")
10511053
_, post_kwargs = post.call_args
1052-
assert_headers_with_role(post_kwargs["headers"], "")
1054+
assert_headers_with_roles(post_kwargs["headers"], None)
10531055

10541056
req.get("URL")
10551057
_, get_kwargs = get.call_args
1056-
assert_headers_with_role(post_kwargs["headers"], "")
1058+
assert_headers_with_roles(post_kwargs["headers"], None)

tests/unit/test_dbapi.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,9 @@ def run(self) -> None:
248248

249249
@patch("trino.dbapi.trino.client")
250250
def test_tags_are_set_when_specified(mock_client):
251-
# WHEN
252251
client_tags = ["TAG1", "TAG2"]
253252
with connect("sample_trino_cluster:443", client_tags=client_tags) as conn:
254253
conn.cursor().execute("SOME FAKE QUERY")
255254

256-
# THEN
257255
_, passed_client_tags = mock_client.ClientSession.call_args
258256
assert passed_client_tags["client_tags"] == client_tags

trino/client.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class ClientSession(object):
9898
:param extra_credential: extra credentials. as list of ``(key, value)``
9999
tuples.
100100
:param client_tags: Client tags as list of strings.
101-
:param role: role for the current session. Some connectors do not
101+
:param roles: roles for the current session. Some connectors do not
102102
support role management. See connector documentation for more details.
103103
"""
104104

@@ -113,7 +113,7 @@ def __init__(
113113
transaction_id: str = None,
114114
extra_credential: List[Tuple[str, str]] = None,
115115
client_tags: List[str] = None,
116-
role: str = None,
116+
roles: Dict[str, str] = None,
117117
):
118118
self._user = user
119119
self._catalog = catalog
@@ -124,7 +124,7 @@ def __init__(
124124
self._transaction_id = transaction_id
125125
self._extra_credential = extra_credential
126126
self._client_tags = client_tags
127-
self._role = role
127+
self._roles = roles or {}
128128
self._prepared_statements: Dict[str, str] = {}
129129
self._object_lock = threading.Lock()
130130

@@ -188,24 +188,15 @@ def extra_credential(self):
188188
def client_tags(self):
189189
return self._client_tags
190190

191-
def __getstate__(self):
192-
state = self.__dict__.copy()
193-
del state["_object_lock"]
194-
return state
195-
196-
def __setstate__(self, state):
197-
self.__dict__.update(state)
198-
self._object_lock = threading.Lock()
199-
200191
@property
201-
def role(self):
192+
def roles(self):
202193
with self._object_lock:
203-
return self._role
194+
return self._roles
204195

205-
@role.setter
206-
def role(self, role):
196+
@roles.setter
197+
def roles(self, roles):
207198
with self._object_lock:
208-
self._role = role
199+
self._roles = roles
209200

210201
@property
211202
def prepared_statements(self):
@@ -216,6 +207,15 @@ def prepared_statements(self, prepared_statements):
216207
with self._object_lock:
217208
self._prepared_statements = prepared_statements
218209

210+
def __getstate__(self):
211+
state = self.__dict__.copy()
212+
del state["_object_lock"]
213+
return state
214+
215+
def __setstate__(self, state):
216+
self.__dict__.update(state)
217+
self._object_lock = threading.Lock()
218+
219219

220220
def get_header_values(headers, header):
221221
return [val.strip() for val in headers[header].split(",")]
@@ -237,6 +237,14 @@ def get_prepared_statement_values(headers, header):
237237
]
238238

239239

240+
def get_roles_values(headers, header):
241+
kvs = get_header_values(headers, header)
242+
return [
243+
(k.strip(), urllib.parse.unquote_plus(v.strip()))
244+
for k, v in (kv.split("=", 1) for kv in kvs)
245+
]
246+
247+
240248
class TrinoStatus(object):
241249
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
242250
self.id = id
@@ -400,7 +408,12 @@ def http_headers(self) -> Dict[str, str]:
400408
headers[constants.HEADER_SCHEMA] = self._client_session.schema
401409
headers[constants.HEADER_SOURCE] = self._client_session.source
402410
headers[constants.HEADER_USER] = self._client_session.user
403-
headers[constants.HEADER_ROLE] = self._client_session.role
411+
if len(self._client_session.roles.values()):
412+
headers[constants.HEADER_ROLE] = ",".join(
413+
# ``name`` must not contain ``=``
414+
"{}={}".format(catalog, urllib.parse.quote(str(role)))
415+
for catalog, role in self._client_session.roles.items()
416+
)
404417
if self._client_session.client_tags is not None and len(self._client_session.client_tags) > 0:
405418
headers[constants.HEADER_CLIENT_TAGS] = ",".join(self._client_session.client_tags)
406419

@@ -579,7 +592,10 @@ def process(self, http_response) -> TrinoStatus:
579592
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]
580593

581594
if constants.HEADER_SET_ROLE in http_response.headers:
582-
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]
595+
for key, value in get_roles_values(
596+
http_response.headers, constants.HEADER_SET_ROLE
597+
):
598+
self._client_session.roles[key] = value
583599

584600
if constants.HEADER_ADDED_PREPARE in http_response.headers:
585601
for name, statement in get_prepared_statement_values(

0 commit comments

Comments
 (0)