Skip to content

Commit c546c56

Browse files
SNOW-294266 make connection object exit autocommit aware (#641)
1 parent 2dc8f4a commit c546c56

File tree

5 files changed

+47
-31
lines changed

5 files changed

+47
-31
lines changed

src/snowflake/connector/auth.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from os import getenv, makedirs, mkdir, path, remove, removedirs, rmdir
1616
from os.path import expanduser
1717
from threading import Lock, Thread
18+
from typing import Dict, Union
1819

1920
from .auth_keypair import AuthByKeyPair
2021
from .auth_usrpwdmfa import AuthByUsrPwdMfa
@@ -124,7 +125,7 @@ def authenticate(self, auth_instance, account, user,
124125
warehouse=None, role=None, passcode=None,
125126
passcode_in_password=False,
126127
mfa_callback=None, password_callback=None,
127-
session_parameters=None, timeout=120):
128+
session_parameters=None, timeout=120) -> Dict[str, Union[str, int, bool]]:
128129
logger.debug('authenticate')
129130

130131
if session_parameters is None:
@@ -356,9 +357,10 @@ def post_request_wrapper(self, url, headers, body):
356357
self._rest._connection._schema = session_info.get('schemaName')
357358
self._rest._connection._warehouse = session_info.get('warehouseName')
358359
self._rest._connection._role = session_info.get('roleName')
359-
self._rest._connection._set_parameters(ret, session_parameters)
360-
361-
return session_parameters
360+
if 'parameters' in ret['data']:
361+
session_parameters.update({p['name']: p['value'] for p in ret['data']['parameters']})
362+
self._rest._connection._update_parameters(session_parameters)
363+
return session_parameters
362364

363365
def _read_temporary_credential(self, host, user, cred_type):
364366
cred = None

src/snowflake/connector/connection.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def __init__(self, **kwargs):
214214
self._async_sfqids = set()
215215
self._done_async_sfqids = set()
216216
self.telemetry_enabled = False
217+
self._session_parameters: Dict[str, Union[str, int, bool]] = {}
217218
logger.info(
218219
"Snowflake Connector for Python Version: %s, "
219220
"Python Version: %s, Platform: %s",
@@ -889,7 +890,7 @@ def __authenticate(self, auth_instance):
889890
auth_instance, 'consent_cache_id_token', True)
890891

891892
auth = Auth(self.rest)
892-
self._session_parameters = auth.authenticate(
893+
auth.authenticate(
893894
auth_instance=auth_instance,
894895
account=self.account,
895896
user=self.user,
@@ -1114,17 +1115,15 @@ def _validate_client_prefetch_threads(self):
11141115
self.client_prefetch_threads)
11151116
return self.client_prefetch_threads
11161117

1117-
def _set_parameters(self, ret, session_parameters):
1118-
"""Sets session parameters."""
1119-
if 'parameters' not in ret['data']:
1120-
return
1121-
parameters = ret['data']['parameters']
1118+
def _update_parameters(
1119+
self,
1120+
parameters: Dict[str, Union[str, int, bool]],
1121+
) -> None:
1122+
"""Update session parameters."""
11221123
with self._lock_converter:
11231124
self.converter.set_parameters(parameters)
1124-
for kv in parameters:
1125-
name = kv['name']
1126-
value = kv['value']
1127-
session_parameters[name] = value
1125+
for name, value in parameters.items():
1126+
self._session_parameters[name] = value
11281127
if PARAMETER_CLIENT_TELEMETRY_ENABLED == name:
11291128
self.telemetry_enabled = value
11301129
elif PARAMETER_CLIENT_TELEMETRY_OOB_ENABLED == name:
@@ -1153,10 +1152,12 @@ def __enter__(self):
11531152

11541153
def __exit__(self, exc_type, exc_val, exc_tb):
11551154
"""Context manager with commit or rollback teardown."""
1156-
if exc_tb is None:
1157-
self.commit()
1158-
else:
1159-
self.rollback()
1155+
if not self._session_parameters.get('AUTOCOMMIT', False):
1156+
# Either AUTOCOMMIT is turned off, or is not set so we default to old behavior
1157+
if exc_tb is None:
1158+
self.commit()
1159+
else:
1160+
self.rollback()
11601161
self.close()
11611162

11621163
def _get_query_status(self, sf_qid: str) -> Tuple[QueryStatus, Dict[str, Any]]:

src/snowflake/connector/converter.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from datetime import time as dt_t
1212
from datetime import timedelta
1313
from logging import getLogger
14-
from typing import Any, Tuple, Union
14+
from typing import Any, Dict, Optional, Tuple, Union
1515

1616
import pytz
1717

@@ -129,24 +129,22 @@ def _generate_tzinfo_from_tzoffset(tzoffset_minutes: int) -> pytz._FixedOffset:
129129

130130
class SnowflakeConverter(object):
131131
def __init__(self, **kwargs):
132-
self._parameters = {}
132+
self._parameters: Dict[str, Union[str, int, bool]] = {}
133133
self._use_numpy = kwargs.get('use_numpy', False) and numpy is not None
134134

135135
logger.debug('use_numpy: %s', self._use_numpy)
136136

137-
def set_parameters(self, parameters):
138-
self._parameters = {}
139-
for kv in parameters:
140-
self._parameters[kv['name']] = kv['value']
137+
def set_parameters(self, new_parameters: Dict) -> None:
138+
self._parameters = new_parameters
141139

142-
def set_parameter(self, param, value):
140+
def set_parameter(self, param: Any, value: Any) -> None:
143141
self._parameters[param] = value
144142

145-
def get_parameters(self):
143+
def get_parameters(self) -> Dict[str, Union[str, int, bool]]:
146144
return self._parameters
147145

148-
def get_parameter(self, param):
149-
return self._parameters[param] if param in self._parameters else None
146+
def get_parameter(self, param: str) -> Optional[Union[str, int, bool]]:
147+
return self._parameters.get(param)
150148

151149
#
152150
# FROM Snowflake to Python Objects

src/snowflake/connector/cursor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ def interrupt_handler(*_): # pragma: no cover
415415
logger.debug('cancelled timebomb in finally')
416416

417417
if 'data' in ret and 'parameters' in ret['data']:
418-
for kv in ret['data']['parameters']:
418+
parameters = ret['data']['parameters']
419+
# Set session parameters for cursor object
420+
for kv in parameters:
419421
if 'TIMESTAMP_OUTPUT_FORMAT' in kv['name']:
420422
self._timestamp_output_format = kv['value']
421423
if 'TIMESTAMP_NTZ_OUTPUT_FORMAT' in kv['name']:
@@ -432,8 +434,8 @@ def interrupt_handler(*_): # pragma: no cover
432434
self._timezone = kv['value']
433435
if 'BINARY_OUTPUT_FORMAT' in kv['name']:
434436
self._binary_output_format = kv['value']
435-
self._connection._set_parameters(
436-
ret, self._connection._session_parameters)
437+
# Set session parameters for connection object
438+
self._connection._update_parameters({p['name']: p['value'] for p in parameters})
437439

438440
self._sequence_counter = -1
439441
return ret

test/integ/test_connection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,3 +926,16 @@ def test_process_param_error(conn_cnx):
926926
side_effect=Exception('test')):
927927
conn._process_params(mock.Mock())
928928
assert pe.errno == ER_FAILED_PROCESSING_PYFORMAT
929+
930+
931+
@pytest.mark.parametrize('auto_commit', [pytest.param(True, marks=pytest.mark.skipolddriver), False])
932+
def test_autocommit(conn_cnx, db_parameters, auto_commit):
933+
conn = snowflake.connector.connect(**db_parameters)
934+
with mock.patch.object(conn, 'commit') as mocked_commit:
935+
with conn:
936+
with conn.cursor() as cur:
937+
cur.execute(f"alter session set autocommit = {auto_commit}")
938+
if auto_commit:
939+
assert not mocked_commit.called
940+
else:
941+
assert mocked_commit.called

0 commit comments

Comments
 (0)