Skip to content

Commit 87e6a3c

Browse files
authored
feat: support Option<T> as sql params (#624)
* z * z * z * z * z * z * z * z * z
1 parent d9dd20e commit 87e6a3c

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

bindings/python/src/utils.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ where
3939
py.allow_threads(|| RUNTIME.block_on(f))
4040
}
4141

42-
// params: Option<Bound<'p, PyAny>>
4342
pub(crate) fn to_sql_params(v: Option<Bound<PyAny>>) -> Params {
4443
match v {
4544
Some(v) => {
@@ -74,6 +73,9 @@ pub(crate) fn to_sql_params(v: Option<Bound<PyAny>>) -> Params {
7473
}
7574

7675
fn to_sql_string(v: Bound<PyAny>) -> PyResult<String> {
76+
if v.is_none() {
77+
return Ok("NULL".to_string());
78+
}
7779
match v.downcast::<PyAny>() {
7880
Ok(v) => {
7981
if let Ok(v) = v.extract::<String>() {

bindings/python/tests/cursor/steps/binding.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,26 @@ def _(context):
5050

5151
@then("Select params binding")
5252
def _(context):
53-
context.cursor.execute("SELECT ?, ?, ?, ?", (3, False, 4, "55"))
53+
context.cursor.execute("SELECT ?, ?, ?, ?, ?", (3, False, 4, "55", None))
5454
row = context.cursor.fetchone()
55-
assert row.values() == (3, False, 4, "55"), f"output: {row.values()}"
55+
assert row.values() == (3, False, 4, "55", None), f"output: {row.values()}"
5656

5757
# Test with named parameters
5858
context.cursor.execute(
59-
"SELECT :a, :b, :c, :d", {"a": 3, "b": False, "c": 4, "d": "55"}
59+
"SELECT :a, :b, :c, :d, :e",
60+
{"a": 3, "b": False, "c": 4, "d": "55", "e": None},
6061
)
6162
row = context.cursor.fetchone()
62-
assert row.values() == (3, False, 4, "55"), f"output: {row.values()}"
63+
assert row.values() == (3, False, 4, "55", None), f"output: {row.values()}"
6364

6465
context.cursor.execute("SELECT ?", 4)
6566
row = context.cursor.fetchone()
6667
assert row.values() == (4,), f"output: {row.values()}"
6768

6869
# Test with positional parameters again
69-
context.cursor.execute("SELECT ?, ?, ?, ?", (3, False, 4, "55"))
70+
context.cursor.execute("SELECT ?, ?, ?, ?, ?", (3, False, 4, "55", None))
7071
row = context.cursor.fetchone()
71-
assert row.values() == (3, False, 4, "55"), f"output: {row.values()}"
72+
assert row.values() == (3, False, 4, "55", None), f"output: {row.values()}"
7273

7374

7475
@then("Select string {input} should be equal to {output}")
@@ -171,7 +172,7 @@ def _(context):
171172
assert ret == expected, f"ret: {ret}"
172173

173174
desc = context.cursor.description
174-
assert desc != None
175+
assert desc is not None
175176

176177
# fetchmany
177178
context.cursor.execute("SELECT * FROM test")

cli/src/display.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,11 +743,7 @@ fn format_table_style(value: &Value, max_col_width: usize, replace_newline: bool
743743
value
744744
};
745745
if value.len() + 3 > max_col_width {
746-
let element_size = if max_col_width >= 6 {
747-
max_col_width - 6
748-
} else {
749-
0
750-
};
746+
let element_size = max_col_width.saturating_sub(6);
751747
value = String::from_utf8(
752748
value
753749
.graphemes(true)

driver/src/params.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,25 @@ impl Param for &str {
129129
}
130130
}
131131

132+
// Impl Param for None
133+
impl Param for () {
134+
fn as_sql_string(&self) -> String {
135+
"NULL".to_string()
136+
}
137+
}
138+
139+
impl<T> Param for Option<T>
140+
where
141+
T: Param,
142+
{
143+
fn as_sql_string(&self) -> String {
144+
match self {
145+
Some(s) => s.as_sql_string(),
146+
None => "NULL".to_string(),
147+
}
148+
}
149+
}
150+
132151
impl Param for serde_json::Value {
133152
fn as_sql_string(&self) -> String {
134153
match self {
@@ -170,7 +189,7 @@ impl Param for serde_json::Value {
170189
macro_rules! params {
171190
// Handle named parameters
172191
() => {
173-
$crate::Params::default()
192+
$crate::Params::default()
174193
};
175194
($($key:ident => $value:expr),* $(,)?) => {
176195
$crate::Params::NamedParams({
@@ -308,6 +327,15 @@ mod tests {
308327
}
309328
}
310329

330+
// Test Option<T>
331+
{
332+
let params: Params = (Some(1), None::<()>, Some("44"), None::<()>).into();
333+
match params {
334+
Params::QuestionParams(vec) => assert_eq!(vec, vec!["1", "NULL", "'44'", "NULL"]),
335+
_ => panic!("Expected QuestionParams"),
336+
}
337+
}
338+
311339
// Test into params for serde_json
312340
{
313341
let params: Params = serde_json::json!({

0 commit comments

Comments
 (0)