Skip to content

Commit b7ff349

Browse files
committed
feat(param-style): workaround to support param-style in statements
1 parent c36efcf commit b7ff349

File tree

3 files changed

+167
-42
lines changed

3 files changed

+167
-42
lines changed

src/sqlitecloud/dbapi2.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ def rowcount(self) -> int:
430430
"""
431431
The number of rows that the last .execute*() returned for DQL statements like SELECT or
432432
the number rows affected by DML statements like UPDATE, INSERT and DELETE.
433+
434+
For the executemany() it returns the number of changes only for the last operation.
433435
"""
434436
if self._is_result_rowset():
435437
return self._resultset.nrows
@@ -490,7 +492,9 @@ def execute(
490492

491493
parameters = self._adapt_parameters(parameters)
492494

493-
# TODO: convert parameters from :name to `?` style
495+
if isinstance(parameters, dict):
496+
parameters = self._named_to_question_mark_parameters(sql, parameters)
497+
494498
result = self._driver.execute_statement(
495499
sql, parameters, self.connection.sqlcloud_connection
496500
)
@@ -529,6 +533,9 @@ def executemany(
529533
commands = ""
530534
params = []
531535
for parameters in seq_of_parameters:
536+
if isinstance(parameters, dict):
537+
parameters = self._named_to_question_mark_parameters(sql, parameters)
538+
532539
params += list(parameters)
533540

534541
if not sql.endswith(";"):
@@ -726,6 +733,25 @@ def _apply_text_factory(self, value: Any) -> Any:
726733

727734
return value
728735

736+
def _named_to_question_mark_parameters(
737+
self, sql: str, params: Dict[str, Any]
738+
) -> Tuple[Any]:
739+
"""
740+
Convert named placeholders parameters from a dictionary to a list of
741+
parameters for question mark style.
742+
743+
SCSP protocol does not support named placeholders yet.
744+
"""
745+
pattern = r":(\w+)"
746+
matches = re.findall(pattern, sql)
747+
748+
params_list = ()
749+
for match in matches:
750+
if match in params:
751+
params_list += (params[match],)
752+
753+
return params_list
754+
729755
def _get_value(self, row: int, col: int) -> Optional[Any]:
730756
if not self._is_result_rowset():
731757
return None

src/tests/integration/test_dbapi2.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,21 @@ def test_execute_with_named_placeholder(self, sqlitecloud_dbapi2_connection):
143143
assert cursor.rowcount == 1
144144
assert cursor.fetchone() == (1, "For Those About To Rock We Salute You", 1)
145145

146+
def test_execute_with_named_placeholder_and_a_fake_one_which_is_not_given(
147+
self, sqlitecloud_dbapi2_connection
148+
):
149+
""" "Expect the converter from name to qmark placeholder to not be fooled by the
150+
fake name with the colon in it."""
151+
connection = sqlitecloud_dbapi2_connection
152+
153+
cursor = connection.execute(
154+
"SELECT * FROM albums WHERE AlbumId = :id and Title != 'special:name'",
155+
{"id": 1},
156+
)
157+
158+
assert cursor.rowcount == 1
159+
assert cursor.fetchone() == (1, "For Those About To Rock We Salute You", 1)
160+
146161
def test_execute_with_qmarks(self, sqlitecloud_dbapi2_connection):
147162
connection = sqlitecloud_dbapi2_connection
148163

@@ -408,7 +423,7 @@ def test_last_rowid_and_rowcount_with_executemany_deletes(
408423
new_name1 = "Jazz" + str(uuid.uuid4())
409424
new_name2 = "Jazz" + str(uuid.uuid4())
410425

411-
cursor_select = connection.executemany(
426+
cursor_insert = connection.executemany(
412427
"INSERT INTO genres (Name) VALUES (?)",
413428
[(new_name1,), (new_name2,)],
414429
)
@@ -418,32 +433,5 @@ def test_last_rowid_and_rowcount_with_executemany_deletes(
418433
)
419434

420435
assert cursor.fetchone() is None
421-
assert cursor.lastrowid == cursor_select.lastrowid
436+
assert cursor.lastrowid == cursor_insert.lastrowid
422437
assert cursor.rowcount == 1
423-
424-
def test_connection_total_changes(self, sqlitecloud_dbapi2_connection):
425-
connection = sqlitecloud_dbapi2_connection
426-
427-
new_name1 = "Jazz" + str(uuid.uuid4())
428-
new_name2 = "Jazz" + str(uuid.uuid4())
429-
new_name3 = "Jazz" + str(uuid.uuid4())
430-
431-
connection.executemany(
432-
"INSERT INTO genres (Name) VALUES (?)",
433-
[(new_name1,), (new_name2,)],
434-
)
435-
assert connection.total_changes == 2
436-
437-
connection.execute("SELECT * FROM genres")
438-
assert connection.total_changes == 2
439-
440-
connection.execute(
441-
"UPDATE genres SET Name = ? WHERE Name = ?", (new_name3, new_name1)
442-
)
443-
assert connection.total_changes == 3
444-
445-
connection.execute(
446-
"DELETE FROM genres WHERE Name in (?, ?, ?)",
447-
(new_name1, new_name2, new_name3),
448-
)
449-
assert connection.total_changes == 5

src/tests/integration/test_sqlite3_parity.py

Lines changed: 123 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import random
22
import sqlite3
3+
import string
34
import sys
45
import time
6+
import uuid
57
from datetime import date, datetime
68

79
import pytest
@@ -69,6 +71,62 @@ def test_create_table_and_insert_many(
6971

7072
assert sqlitecloud_results == sqlite3_results
7173

74+
@pytest.mark.parametrize(
75+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
76+
)
77+
def test_executemany_with_a_iterator(self, connection, request):
78+
connection = request.getfixturevalue(connection)
79+
80+
class IterChars:
81+
def __init__(self):
82+
self.count = ord("a")
83+
84+
def __iter__(self):
85+
return self
86+
87+
def __next__(self):
88+
if self.count > ord("z"):
89+
raise StopIteration
90+
self.count += 1
91+
return (chr(self.count - 1),)
92+
93+
try:
94+
connection.execute("DROP TABLE IF EXISTS characters")
95+
cursor = connection.execute("CREATE TABLE IF NOT EXISTS characters(c)")
96+
97+
theIter = IterChars()
98+
cursor.executemany("INSERT INTO characters(c) VALUES (?)", theIter)
99+
100+
cursor.execute("SELECT c FROM characters")
101+
102+
results = cursor.fetchall()
103+
assert len(results) == 26
104+
finally:
105+
connection.execute("DROP TABLE IF EXISTS characters")
106+
107+
@pytest.mark.parametrize(
108+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
109+
)
110+
def test_executemany_with_yield_generator(self, connection, request):
111+
connection = request.getfixturevalue(connection)
112+
113+
def char_generator():
114+
for c in string.ascii_lowercase:
115+
yield (c,)
116+
117+
try:
118+
connection.execute("DROP TABLE IF EXISTS characters")
119+
cursor = connection.execute("CREATE TABLE IF NOT EXISTS characters(c)")
120+
121+
cursor.executemany("INSERT INTO characters(c) VALUES (?)", char_generator())
122+
123+
cursor.execute("SELECT c FROM characters")
124+
125+
results = cursor.fetchall()
126+
assert len(results) == 26
127+
finally:
128+
connection.execute("DROP TABLE IF EXISTS characters")
129+
72130
def test_execute_with_question_mark_style(
73131
self, sqlitecloud_dbapi2_connection, sqlite3_connection
74132
):
@@ -84,20 +142,37 @@ def test_execute_with_question_mark_style(
84142

85143
assert sqlitecloud_results == sqlite3_results
86144

87-
def test_execute_with_named_param_style(
88-
self, sqlitecloud_dbapi2_connection, sqlite3_connection
89-
):
90-
sqlitecloud_connection = sqlitecloud_dbapi2_connection
145+
@pytest.mark.parametrize(
146+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
147+
)
148+
def test_execute_with_named_param_style(self, connection, request):
149+
connection = request.getfixturevalue(connection)
91150

92-
select_query = "SELECT * FROM albums WHERE AlbumId = :id"
93-
params = {"id": 1}
94-
sqlitecloud_cursor = sqlitecloud_connection.execute(select_query, params)
95-
sqlite3_cursor = sqlite3_connection.execute(select_query, params)
151+
select_query = "SELECT * FROM albums WHERE AlbumId = :id and Title = :title and AlbumId = :id"
152+
params = {"id": 1, "title": "For Those About To Rock We Salute You"}
96153

97-
sqlitecloud_results = sqlitecloud_cursor.fetchall()
98-
sqlite3_results = sqlite3_cursor.fetchall()
154+
cursor = connection.execute(select_query, params)
99155

100-
assert sqlitecloud_results == sqlite3_results
156+
results = cursor.fetchall()
157+
158+
assert len(results) == 1
159+
assert results[0][0] == 1
160+
161+
@pytest.mark.parametrize(
162+
"connection", ["sqlitecloud_dbapi2_connection", "sqlite3_connection"]
163+
)
164+
def test_executemany_with_named_param_style(self, connection, request):
165+
connection = request.getfixturevalue(connection)
166+
167+
select_query = "INSERT INTO customers (FirstName, Email, LastName) VALUES (:name, :email, :name)"
168+
params = [
169+
{"name": "pippo", "email": "pippo@disney.com"},
170+
{"name": "pluto", "email": "pluto@disney.com"},
171+
]
172+
173+
connection.executemany(select_query, params)
174+
175+
assert connection.total_changes == 2
101176

102177
@pytest.mark.skip(
103178
reason="Rowcount does not contain the number of inserted rows yet"
@@ -1151,11 +1226,47 @@ def test_transaction_context_manager_on_failure(self, connection, request):
11511226
"INSERT INTO albums (Title, ArtistId) VALUES ('Test Album 1', 1)"
11521227
)
11531228
id1 = cursor.lastrowid
1154-
connection.execute("INVALID COMMAND")
1229+
connection.execute("insert into pippodd (p) values (1)")
11551230
except Exception:
11561231
assert True
11571232

11581233
cursor = connection.execute("SELECT * FROM albums WHERE AlbumId = ?", (id1,))
11591234
result = cursor.fetchone()
11601235

11611236
assert result is None
1237+
1238+
@pytest.mark.parametrize(
1239+
"connection",
1240+
[
1241+
"sqlitecloud_dbapi2_connection",
1242+
"sqlite3_connection",
1243+
],
1244+
)
1245+
def test_connection_total_changes(self, connection, request):
1246+
connection = request.getfixturevalue(connection)
1247+
1248+
new_name1 = "Jazz" + str(uuid.uuid4())
1249+
new_name2 = "Jazz" + str(uuid.uuid4())
1250+
new_name3 = "Jazz" + str(uuid.uuid4())
1251+
1252+
connection.executemany(
1253+
"INSERT INTO genres (Name) VALUES (?)",
1254+
[(new_name1,), (new_name2,)],
1255+
)
1256+
assert connection.total_changes == 2
1257+
1258+
connection.execute(
1259+
"SELECT * FROM genres WHERE Name IN (?, ?)", (new_name1, new_name2)
1260+
)
1261+
assert connection.total_changes == 2
1262+
1263+
connection.execute(
1264+
"UPDATE genres SET Name = ? WHERE Name = ?", (new_name3, new_name1)
1265+
)
1266+
assert connection.total_changes == 3
1267+
1268+
connection.execute(
1269+
"DELETE FROM genres WHERE Name in (?, ?, ?)",
1270+
(new_name1, new_name2, new_name3),
1271+
)
1272+
assert connection.total_changes == 5

0 commit comments

Comments
 (0)