Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion crates/duckdb/src/appender/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{ffi::c_void, fmt, os::raw::c_char};

use crate::{
error::result_from_duckdb_appender,
types::{ToSql, ToSqlOutput},
types::{to_duckdb_decimal, ToSql, ToSqlOutput},
Error,
};

Expand Down Expand Up @@ -183,6 +183,14 @@ impl Appender<'_> {
},
)
},
ValueRef::Decimal(d) => unsafe {
let decimal = to_duckdb_decimal(d);
let mut value = ffi::duckdb_create_decimal(decimal);
let res = ffi::duckdb_append_value(ptr, value);
// free value
ffi::duckdb_destroy_value(&mut value);
res
},
_ => unreachable!("not supported"),
};
if rc != 0 {
Expand Down Expand Up @@ -224,6 +232,8 @@ impl fmt::Debug for Appender<'_> {

#[cfg(test)]
mod test {
use rust_decimal::Decimal;

use crate::{params, Connection, Error, Result};

#[test]
Expand Down Expand Up @@ -424,4 +434,29 @@ mod test {

Ok(())
}

#[test]
fn test_appender_decimal() -> Result<()> {
let d1 = rust_decimal::Decimal::from_i128_with_scale(11344, 4);
let d2 = rust_decimal::Decimal::from_i128_with_scale(12312, 3);
let d3 = rust_decimal::Decimal::from_i128_with_scale(-98765, 5);

let conn = Connection::open_in_memory()?;
conn.execute_batch("CREATE TABLE decimals (value DECIMAL(20, 10));")?;

let mut appender = conn.appender("decimals")?;
appender.append_row(params![d1])?;
appender.append_row(params![d2])?;
appender.append_row(params![d3])?;
appender.flush()?;

let results: Vec<Decimal> = conn
.prepare("SELECT value FROM decimals ORDER BY value ASC")?
.query_map([], |row| row.get(0))?
.collect::<Result<Vec<Decimal>>>()?;

assert_eq!(results, vec![d3, d1, d2]);

Ok(())
}
}
34 changes: 33 additions & 1 deletion crates/duckdb/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{arrow2, polars_dataframe::Polars};
use crate::{
arrow_batch::{Arrow, ArrowStream},
error::result_from_duckdb_prepare,
types::{TimeUnit, ToSql, ToSqlOutput},
types::{to_duckdb_decimal, TimeUnit, ToSql, ToSqlOutput},
};

/// A prepared statement.
Expand Down Expand Up @@ -608,6 +608,10 @@ impl Statement<'_> {
let micros = nanos / 1_000;
ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros })
},
ValueRef::Decimal(d) => unsafe {
let decimal = to_duckdb_decimal(d);
ffi::duckdb_bind_decimal(ptr, col as u64, decimal)
},
_ => unreachable!("not supported: {}", value.data_type()),
};
result_from_duckdb_prepare(rc, ptr)
Expand Down Expand Up @@ -1218,4 +1222,32 @@ mod test {

Ok(())
}

#[test]
fn test_with_decimal() -> Result<()> {
let db = Connection::open_in_memory()?;
db.execute_batch(
"BEGIN; \
CREATE TABLE foo(x DECIMAL(18, 4)); \
CREATE TABLE bar(y DECIMAL(18, 2)); \
COMMIT;",
)?;

// If duckdb's scale is larger than rust_decimal's scale, value should not be truncated.
let value = rust_decimal::Decimal::from_i128_with_scale(12345, 4);
db.execute("INSERT INTO foo(x) VALUES (?)", [&value])?;
let row: rust_decimal::Decimal =
db.query_row("SELECT x FROM foo", [], |r| r.get::<_, rust_decimal::Decimal>(0))?;
assert_eq!(row, value);

// If duckdb's scale is smaller than rust_decimal's scale, value should be truncated (1.2345 -> 1.23).
let value = rust_decimal::Decimal::from_i128_with_scale(12345, 4);
db.execute("INSERT INTO bar(y) VALUES (?)", [&value])?;
let row: rust_decimal::Decimal =
db.query_row("SELECT y FROM bar", [], |r| r.get::<_, rust_decimal::Decimal>(0))?;
let value_from_duckdb = rust_decimal::Decimal::from_i128_with_scale(123, 2);
assert_eq!(row, value_from_duckdb);

Ok(())
}
}
84 changes: 84 additions & 0 deletions crates/duckdb/src/types/decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use rust_decimal::prelude::FromPrimitive as _;

use super::TimeUnit;
use crate::ffi;
use crate::types::{FromSql, FromSqlError, FromSqlResult, Value, ValueRef};
use crate::Result;
use crate::{types::ToSqlOutput, ToSql};

/// Convert a rust_decimal::Decimal to a ffi::duckdb_decimal
pub fn to_duckdb_decimal(d: rust_decimal::Decimal) -> ffi::duckdb_decimal {
// The max size of rust_decimal's scale is 28.
let d_scale = d.scale() as u8;
let d_width = decimal_width(d);
let d_value = {
let mantissa = d.mantissa();
let lo = mantissa as u64;
let hi = (mantissa >> 64) as i64;
ffi::duckdb_hugeint { lower: lo, upper: hi }
};

ffi::duckdb_decimal {
width: d_width,
scale: d_scale,
value: d_value,
}
}

/// Get the length of the decimal significant digits of a rust_decimal::Decimal
fn decimal_width(d: rust_decimal::Decimal) -> u8 {
let mut num = d.mantissa();

if num == 0 {
return 1;
}

let mut len = 0;
num = num.abs();

while num > 0 {
len += 1;
num /= 10;
}

len
}

impl ToSql for rust_decimal::Decimal {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
Ok(ToSqlOutput::Owned(Value::Decimal(*self)))
}
}

impl FromSql for rust_decimal::Decimal {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value {
ValueRef::TinyInt(i) => rust_decimal::Decimal::from_i8(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::SmallInt(i) => rust_decimal::Decimal::from_i16(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::Int(i) => rust_decimal::Decimal::from_i32(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::BigInt(i) => rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::HugeInt(i) => rust_decimal::Decimal::from_i128(i).ok_or(FromSqlError::OutOfRange(i)),
ValueRef::UTinyInt(i) => rust_decimal::Decimal::from_u8(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::USmallInt(i) => rust_decimal::Decimal::from_u16(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::UInt(i) => rust_decimal::Decimal::from_u32(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::UBigInt(i) => rust_decimal::Decimal::from_u64(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::Float(f) => rust_decimal::Decimal::from_f32(f).ok_or(FromSqlError::OutOfRange(f as i128)),
ValueRef::Double(d) => rust_decimal::Decimal::from_f64(d).ok_or(FromSqlError::OutOfRange(d as i128)),
ValueRef::Decimal(decimal) => Ok(decimal),
ValueRef::Timestamp(_, i) => rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::Date32(i) => rust_decimal::Decimal::from_i32(i).ok_or(FromSqlError::OutOfRange(i as i128)),
ValueRef::Time64(TimeUnit::Microsecond, i) => {
rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128))
}
ValueRef::Text(_) => {
let s = value.as_str()?;
s.parse::<rust_decimal::Decimal>().or_else(|_| {
s.parse::<i128>()
.map_err(|_| FromSqlError::InvalidType)
.and_then(|i| Err(FromSqlError::OutOfRange(i as i128)))
})
}
_ => Err(FromSqlError::InvalidType),
}
}
}
2 changes: 2 additions & 0 deletions crates/duckdb/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! a value was NULL (which gets translated to `None`).

pub use self::{
decimal::to_duckdb_decimal,
from_sql::{FromSql, FromSqlError, FromSqlResult},
ordered_map::OrderedMap,
string::DuckString,
Expand All @@ -25,6 +26,7 @@ mod url;
mod value;
mod value_ref;

mod decimal;
mod ordered_map;
mod string;

Expand Down