diff --git a/crates/duckdb/src/appender/mod.rs b/crates/duckdb/src/appender/mod.rs index 72871e76..26b1b1eb 100644 --- a/crates/duckdb/src/appender/mod.rs +++ b/crates/duckdb/src/appender/mod.rs @@ -3,7 +3,7 @@ use std::{ffi::c_void, fmt, os::raw::c_char}; use crate::{ error::result_from_duckdb_appender, - types::{ToSql, ToSqlOutput}, + types::{to_duckdb_decimal, ToSql, ToSqlOutput}, Error, }; @@ -183,6 +183,14 @@ impl Appender<'_> { }, ) }, + ValueRef::Decimal(d) => unsafe { + let decimal = to_duckdb_decimal(d); + let mut value = ffi::duckdb_create_decimal(decimal); + let res = ffi::duckdb_append_value(ptr, value); + // free value + ffi::duckdb_destroy_value(&mut value); + res + }, _ => unreachable!("not supported"), }; if rc != 0 { @@ -224,6 +232,8 @@ impl fmt::Debug for Appender<'_> { #[cfg(test)] mod test { + use rust_decimal::Decimal; + use crate::{params, Connection, Error, Result}; #[test] @@ -424,4 +434,29 @@ mod test { Ok(()) } + + #[test] + fn test_appender_decimal() -> Result<()> { + let d1 = rust_decimal::Decimal::from_i128_with_scale(11344, 4); + let d2 = rust_decimal::Decimal::from_i128_with_scale(12312, 3); + let d3 = rust_decimal::Decimal::from_i128_with_scale(-98765, 5); + + let conn = Connection::open_in_memory()?; + conn.execute_batch("CREATE TABLE decimals (value DECIMAL(20, 10));")?; + + let mut appender = conn.appender("decimals")?; + appender.append_row(params![d1])?; + appender.append_row(params![d2])?; + appender.append_row(params![d3])?; + appender.flush()?; + + let results: Vec = conn + .prepare("SELECT value FROM decimals ORDER BY value ASC")? + .query_map([], |row| row.get(0))? + .collect::>>()?; + + assert_eq!(results, vec![d3, d1, d2]); + + Ok(()) + } } diff --git a/crates/duckdb/src/statement.rs b/crates/duckdb/src/statement.rs index af8dc7b4..a9cbbd1a 100644 --- a/crates/duckdb/src/statement.rs +++ b/crates/duckdb/src/statement.rs @@ -8,7 +8,7 @@ use crate::{arrow2, polars_dataframe::Polars}; use crate::{ arrow_batch::{Arrow, ArrowStream}, error::result_from_duckdb_prepare, - types::{TimeUnit, ToSql, ToSqlOutput}, + types::{to_duckdb_decimal, TimeUnit, ToSql, ToSqlOutput}, }; /// A prepared statement. @@ -608,6 +608,10 @@ impl Statement<'_> { let micros = nanos / 1_000; ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros }) }, + ValueRef::Decimal(d) => unsafe { + let decimal = to_duckdb_decimal(d); + ffi::duckdb_bind_decimal(ptr, col as u64, decimal) + }, _ => unreachable!("not supported: {}", value.data_type()), }; result_from_duckdb_prepare(rc, ptr) @@ -1218,4 +1222,32 @@ mod test { Ok(()) } + + #[test] + fn test_with_decimal() -> Result<()> { + let db = Connection::open_in_memory()?; + db.execute_batch( + "BEGIN; \ + CREATE TABLE foo(x DECIMAL(18, 4)); \ + CREATE TABLE bar(y DECIMAL(18, 2)); \ + COMMIT;", + )?; + + // If duckdb's scale is larger than rust_decimal's scale, value should not be truncated. + let value = rust_decimal::Decimal::from_i128_with_scale(12345, 4); + db.execute("INSERT INTO foo(x) VALUES (?)", [&value])?; + let row: rust_decimal::Decimal = + db.query_row("SELECT x FROM foo", [], |r| r.get::<_, rust_decimal::Decimal>(0))?; + assert_eq!(row, value); + + // If duckdb's scale is smaller than rust_decimal's scale, value should be truncated (1.2345 -> 1.23). + let value = rust_decimal::Decimal::from_i128_with_scale(12345, 4); + db.execute("INSERT INTO bar(y) VALUES (?)", [&value])?; + let row: rust_decimal::Decimal = + db.query_row("SELECT y FROM bar", [], |r| r.get::<_, rust_decimal::Decimal>(0))?; + let value_from_duckdb = rust_decimal::Decimal::from_i128_with_scale(123, 2); + assert_eq!(row, value_from_duckdb); + + Ok(()) + } } diff --git a/crates/duckdb/src/types/decimal.rs b/crates/duckdb/src/types/decimal.rs new file mode 100644 index 00000000..8cb11d8f --- /dev/null +++ b/crates/duckdb/src/types/decimal.rs @@ -0,0 +1,84 @@ +use rust_decimal::prelude::FromPrimitive as _; + +use super::TimeUnit; +use crate::ffi; +use crate::types::{FromSql, FromSqlError, FromSqlResult, Value, ValueRef}; +use crate::Result; +use crate::{types::ToSqlOutput, ToSql}; + +/// Convert a rust_decimal::Decimal to a ffi::duckdb_decimal +pub fn to_duckdb_decimal(d: rust_decimal::Decimal) -> ffi::duckdb_decimal { + // The max size of rust_decimal's scale is 28. + let d_scale = d.scale() as u8; + let d_width = decimal_width(d); + let d_value = { + let mantissa = d.mantissa(); + let lo = mantissa as u64; + let hi = (mantissa >> 64) as i64; + ffi::duckdb_hugeint { lower: lo, upper: hi } + }; + + ffi::duckdb_decimal { + width: d_width, + scale: d_scale, + value: d_value, + } +} + +/// Get the length of the decimal significant digits of a rust_decimal::Decimal +fn decimal_width(d: rust_decimal::Decimal) -> u8 { + let mut num = d.mantissa(); + + if num == 0 { + return 1; + } + + let mut len = 0; + num = num.abs(); + + while num > 0 { + len += 1; + num /= 10; + } + + len +} + +impl ToSql for rust_decimal::Decimal { + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Owned(Value::Decimal(*self))) + } +} + +impl FromSql for rust_decimal::Decimal { + fn column_result(value: ValueRef<'_>) -> FromSqlResult { + match value { + ValueRef::TinyInt(i) => rust_decimal::Decimal::from_i8(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::SmallInt(i) => rust_decimal::Decimal::from_i16(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::Int(i) => rust_decimal::Decimal::from_i32(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::BigInt(i) => rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::HugeInt(i) => rust_decimal::Decimal::from_i128(i).ok_or(FromSqlError::OutOfRange(i)), + ValueRef::UTinyInt(i) => rust_decimal::Decimal::from_u8(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::USmallInt(i) => rust_decimal::Decimal::from_u16(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::UInt(i) => rust_decimal::Decimal::from_u32(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::UBigInt(i) => rust_decimal::Decimal::from_u64(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::Float(f) => rust_decimal::Decimal::from_f32(f).ok_or(FromSqlError::OutOfRange(f as i128)), + ValueRef::Double(d) => rust_decimal::Decimal::from_f64(d).ok_or(FromSqlError::OutOfRange(d as i128)), + ValueRef::Decimal(decimal) => Ok(decimal), + ValueRef::Timestamp(_, i) => rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::Date32(i) => rust_decimal::Decimal::from_i32(i).ok_or(FromSqlError::OutOfRange(i as i128)), + ValueRef::Time64(TimeUnit::Microsecond, i) => { + rust_decimal::Decimal::from_i64(i).ok_or(FromSqlError::OutOfRange(i as i128)) + } + ValueRef::Text(_) => { + let s = value.as_str()?; + s.parse::().or_else(|_| { + s.parse::() + .map_err(|_| FromSqlError::InvalidType) + .and_then(|i| Err(FromSqlError::OutOfRange(i as i128))) + }) + } + _ => Err(FromSqlError::InvalidType), + } + } +} diff --git a/crates/duckdb/src/types/mod.rs b/crates/duckdb/src/types/mod.rs index 87227143..9f34c0a4 100644 --- a/crates/duckdb/src/types/mod.rs +++ b/crates/duckdb/src/types/mod.rs @@ -3,6 +3,7 @@ //! a value was NULL (which gets translated to `None`). pub use self::{ + decimal::to_duckdb_decimal, from_sql::{FromSql, FromSqlError, FromSqlResult}, ordered_map::OrderedMap, string::DuckString, @@ -25,6 +26,7 @@ mod url; mod value; mod value_ref; +mod decimal; mod ordered_map; mod string;