From 13910780216893e7602d26632411c8ae06869485 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 16 Jun 2025 17:18:55 -0400 Subject: [PATCH 01/14] feat: upgrade df48 dependency (#1143) * Upgrade to DF 48 * Update unit test * Resolve clippy warnings * Update wrapper test to look for __repr__ special function * Add __repr__ where missing * Error in return of __repr__ * Remove patch now that DF48 is released * Expose lit_with_metadata and add unit test --- Cargo.lock | 184 ++++++++++++++------------ Cargo.toml | 10 +- python/datafusion/__init__.py | 19 +++ python/datafusion/catalog.py | 12 ++ python/datafusion/context.py | 4 + python/datafusion/expr.py | 18 +++ python/datafusion/user_defined.py | 12 ++ python/tests/test_expr.py | 60 ++++++++- python/tests/test_wrapper_coverage.py | 7 +- src/context.rs | 47 +++++-- src/expr.rs | 34 +++-- src/expr/literal.rs | 16 ++- src/expr/window.rs | 29 ++-- src/functions.rs | 6 +- src/pyarrow_filter_expression.rs | 4 +- src/udwf.rs | 8 +- 16 files changed, 325 insertions(+), 145 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 39489ed94..112167cb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -359,6 +359,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] @@ -859,9 +861,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080" +checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96" dependencies = [ "arrow", "arrow-ipc", @@ -887,7 +889,6 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", - "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -902,7 +903,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "regex", "sqlparser", "tempfile", @@ -915,9 +916,9 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66" +checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee" dependencies = [ "arrow", "async-trait", @@ -941,9 +942,9 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00" +checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038" dependencies = [ "arrow", "async-trait", @@ -964,9 +965,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c" +checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12" dependencies = [ "ahash", "apache-avro", @@ -989,9 +990,9 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b" 
+checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2" dependencies = [ "futures", "log", @@ -1000,9 +1001,9 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1" +checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b" dependencies = [ "arrow", "async-compression", @@ -1025,7 +1026,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -1036,9 +1037,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4ea5111aab9d3f2a8bff570343cccb03ce4c203875ef5a566b7d6f1eb72559e" +checksum = "1371cb4ef13c2e3a15685d37a07398cf13e3b0a85e705024b769fc4c511f5fef" dependencies = [ "apache-avro", "arrow", @@ -1061,9 +1062,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6" +checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc" dependencies = [ "arrow", "async-trait", @@ -1086,9 +1087,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352" +checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c" dependencies = [ "arrow", "async-trait", @@ -1111,9 +1112,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923" +checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b" dependencies = [ "arrow", "async-trait", @@ -1136,21 +1137,21 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292" +checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea" [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84" +checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a" dependencies = [ "arrow", "dashmap", @@ -1160,16 +1161,16 @@ dependencies = [ "log", "object_store", "parking_lot", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964" +checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b" dependencies = [ "arrow", "chrono", @@ -1188,9 +1189,9 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" 
+version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839" +checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a" dependencies = [ "arrow", "datafusion-common", @@ -1201,9 +1202,9 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c" +checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef" dependencies = [ "abi_stable", "arrow", @@ -1211,7 +1212,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "futures", "log", "prost", @@ -1221,9 +1224,9 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e" +checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357" dependencies = [ "arrow", "arrow-buffer", @@ -1241,7 +1244,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", - "rand 0.8.5", + "rand 0.9.1", "regex", "sha2", "unicode-segmentation", @@ -1250,9 +1253,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba" +checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4" dependencies = [ "ahash", "arrow", @@ -1271,9 +1274,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4" +checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696" dependencies = [ "ahash", "arrow", @@ -1284,9 +1287,9 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537" +checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b" dependencies = [ "arrow", "arrow-ord", @@ -1305,9 +1308,9 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a" +checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b" dependencies = [ "arrow", "async-trait", @@ -1321,10 +1324,11 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193" +checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4" dependencies = [ + "arrow", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -1338,9 +1342,9 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313" +checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1348,9 +1352,9 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1" +checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8" dependencies = [ "datafusion-expr", "quote", @@ -1359,9 +1363,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515" +checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058" dependencies = [ "arrow", "chrono", @@ -1378,9 +1382,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65" +checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6" dependencies = [ "ahash", "arrow", @@ -1395,14 +1399,14 @@ dependencies = [ "itertools 0.14.0", "log", "paste", - "petgraph", + "petgraph 0.8.2", ] [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863" +checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df" dependencies = [ "ahash", "arrow", @@ -1414,9 +1418,9 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2" +checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae" dependencies = [ "arrow", "datafusion-common", @@ -1433,9 +1437,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55" +checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23" dependencies = [ "ahash", "arrow", @@ -1463,9 +1467,9 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83" +checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6" dependencies = [ "arrow", "chrono", @@ -1479,9 +1483,9 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f" +checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e" dependencies = [ "arrow", "datafusion-common", @@ -1513,9 +1517,9 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" 
+version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f" +checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49" dependencies = [ "arrow", "async-trait", @@ -1537,9 +1541,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607" +checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df" dependencies = [ "arrow", "bigdecimal", @@ -1554,9 +1558,9 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061efc0937f0ce3abb37ed0d56cfa01dd0e654b90e408656d05e846c8b7599fe" +checksum = "f2f3973b1a4f6e9ee7fd99a22d58e1c06e6723a28dc911a60df575974c8339aa" dependencies = [ "async-recursion", "async-trait", @@ -2717,6 +2721,18 @@ dependencies = [ "indexmap", ] +[[package]] +name = "petgraph" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.3", + "indexmap", + "serde", +] + [[package]] name = "phf" version = "0.11.3" @@ -2837,7 +2853,7 @@ dependencies = [ "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.7.1", "prettyplease", "prost", "prost-types", @@ -3661,9 +3677,9 @@ dependencies = [ [[package]] name = "substrait" -version = "0.55.1" +version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "048fe52a3664881ccdfdc9bdb0f4e8805f3444ee64abf299d365c54f6a2ffabb" +checksum = "13de2e20128f2a018dab1cfa30be83ae069219a65968c6f89df66ad124de2397" dependencies = [ "heck", "pbjson", @@ -4016,9 +4032,9 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" +checksum = "6c6c647a34e851cf0260ccc14687f17cdcb8302ff1a8a687a24b97ca0f82406f" dependencies = [ "typify-impl", "typify-macro", @@ -4026,9 +4042,9 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" +checksum = "741b7f1e2e1338c0bee5ad5a7d3a9bbd4e24c33765c08b7691810e68d879365d" dependencies = [ "heck", "log", @@ -4046,9 +4062,9 @@ dependencies = [ [[package]] name = "typify-macro" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" +checksum = "7560adf816a1e8dad7c63d8845ef6e31e673e39eab310d225636779230cbedeb" dependencies = [ "proc-macro2", "quote", @@ -4116,9 +4132,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", 
"js-sys", diff --git a/Cargo.toml b/Cargo.toml index 8107d76d3..4135e64e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,11 +37,11 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} -arrow = { version = "55.0.0", features = ["pyarrow"] } -datafusion = { version = "47.0.0", features = ["avro", "unicode_expressions"] } -datafusion-substrait = { version = "47.0.0", optional = true } -datafusion-proto = { version = "47.0.0" } -datafusion-ffi = { version = "47.0.0" } +arrow = { version = "55.1.0", features = ["pyarrow"] } +datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } +datafusion-substrait = { version = "48.0.0", optional = true } +datafusion-proto = { version = "48.0.0" } +datafusion-ffi = { version = "48.0.0" } prost = "0.13.1" # keep in line with `datafusion-substrait` uuid = { version = "1.16", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index c3468eb4a..4f7700251 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -21,6 +21,10 @@ See https://datafusion.apache.org/python for more information. """ +from __future__ import annotations + +from typing import Any + try: import importlib.metadata as importlib_metadata except ImportError: @@ -130,3 +134,18 @@ def str_lit(value): def lit(value) -> Expr: """Create a literal expression.""" return Expr.literal(value) + + +def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Creates a new expression representing a scalar value with metadata. + + Args: + value: A valid PyArrow scalar value or easily castable to one. + metadata: Metadata to attach to the expression. 
+ """ + return Expr.literal_with_metadata(value, metadata) + + +def lit_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Alias for literal_with_metadata.""" + return literal_with_metadata(value, metadata) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 6c3f188cc..67ab3ead2 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -34,6 +34,10 @@ def __init__(self, catalog: df_internal.Catalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog + def __repr__(self) -> str: + """Print a string representation of the catalog.""" + return self.catalog.__repr__() + def names(self) -> list[str]: """Returns the list of databases in this catalog.""" return self.catalog.names() @@ -50,6 +54,10 @@ def __init__(self, db: df_internal.Database) -> None: """This constructor is not typically called by the end user.""" self.db = db + def __repr__(self) -> str: + """Print a string representation of the database.""" + return self.db.__repr__() + def names(self) -> set[str]: """Returns the list of all tables in this database.""" return self.db.names() @@ -66,6 +74,10 @@ def __init__(self, table: df_internal.Table) -> None: """This constructor is not typically called by the end user.""" self.table = table + def __repr__(self) -> str: + """Print a string representation of the table.""" + return self.table.__repr__() + @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 26c3d2e22..4ed465c99 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -496,6 +496,10 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) + def __repr__(self) -> str: + """Print a string representation of the Session Context.""" + return self.ctx.__repr__() + @classmethod def global_ctx(cls) -> SessionContext: """Retrieve the global context as a `SessionContext` wrapper. diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 9e58873d0..e785cab06 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -435,6 +435,20 @@ def literal(value: Any) -> Expr: value = pa.scalar(value) return Expr(expr_internal.RawExpr.literal(value)) + @staticmethod + def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Creates a new expression representing a scalar value with metadata. + + Args: + value: A valid PyArrow scalar value or easily castable to one. + metadata: Metadata to attach to the expression. + """ + if isinstance(value, str): + value = pa.scalar(value, type=pa.string_view()) + value = value if isinstance(value, pa.Scalar) else pa.scalar(value) + + return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata)) + @staticmethod def string_literal(value: str) -> Expr: """Creates a new expression representing a UTF8 literal value. 
@@ -1172,6 +1186,10 @@ def __init__( end_bound = end_bound.cast(pa.uint64()) self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) + def __repr__(self) -> str: + """Print a string representation of the window frame.""" + return self.window_frame.__repr__() + def get_frame_units(self) -> str: """Returns the window frame units for the bounds.""" return self.window_frame.get_frame_units() diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 9ec3679a6..dd634c7fb 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -102,6 +102,10 @@ def __init__( name, func, input_types, return_type, str(volatility) ) + def __repr__(self) -> str: + """Print a string representation of the Scalar UDF.""" + return self._udf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDF. @@ -268,6 +272,10 @@ def __init__( str(volatility), ) + def __repr__(self) -> str: + """Print a string representation of the Aggregate UDF.""" + return self._udaf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDAF. @@ -604,6 +612,10 @@ def __init__( name, func, input_types, return_type, str(volatility) ) + def __repr__(self) -> str: + """Print a string representation of the Window UDF.""" + return self._udwf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDWF. diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index adca783b5..40a98dc4d 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -19,7 +19,14 @@ import pyarrow as pa import pytest -from datafusion import SessionContext, col, functions, lit +from datafusion import ( + SessionContext, + col, + functions, + lit, + lit_with_metadata, + literal_with_metadata, +) from datafusion.expr import ( Aggregate, AggregateFunction, @@ -103,7 +110,7 @@ def test_limit(test_ctx): plan = plan.to_variant() assert isinstance(plan, Limit) - assert "Skip: Some(Literal(Int64(5)))" in str(plan) + assert "Skip: Some(Literal(Int64(5), None))" in str(plan) def test_aggregate_query(test_ctx): @@ -824,3 +831,52 @@ def test_expr_functions(ctx, function, expected_result): assert len(result) == 1 assert result[0].column(0).equals(expected_result) + + +def test_literal_metadata(ctx): + result = ( + ctx.from_pydict({"a": [1]}) + .select( + lit(1).alias("no_metadata"), + lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"), + literal_with_metadata(3, {"key2": "value2"}).alias( + "literal_with_metadata_fn" + ), + ) + .collect() + ) + + expected_schema = pa.schema( + [ + pa.field("no_metadata", pa.int64(), nullable=False), + pa.field( + "lit_with_metadata_fn", + pa.int64(), + nullable=False, + metadata={"key1": "value1"}, + ), + pa.field( + "literal_with_metadata_fn", + pa.int64(), + nullable=False, + metadata={"key2": "value2"}, + ), + ] + ) + + expected = pa.RecordBatch.from_pydict( + { + "no_metadata": pa.array([1]), + "lit_with_metadata_fn": pa.array([2]), + "literal_with_metadata_fn": pa.array([3]), + }, + schema=expected_schema, + ) + + assert result[0] == expected + + # Testing result[0].schema == expected_schema does not check each key/value pair + # so we want to explicitly test these + for expected_field in expected_schema: + actual_field = result[0].schema.field(expected_field.name) + assert expected_field.metadata == actual_field.metadata diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index 926a65961..f484cb282 100644 --- 
a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -28,14 +28,14 @@ from enum import EnumMeta as EnumType -def missing_exports(internal_obj, wrapped_obj) -> None: +def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 """ Identify if any of the rust exposted structs or functions do not have wrappers. Special handling for: - Raw* classes: Internal implementation details that shouldn't be exposed - _global_ctx: Internal implementation detail - - __self__, __class__: Python special attributes + - __self__, __class__, __repr__: Python special attributes """ # Special case enums - EnumType overrides a some of the internal functions, # so check all of the values exist and move on @@ -45,6 +45,9 @@ def missing_exports(internal_obj, wrapped_obj) -> None: assert value in dir(wrapped_obj) return + if "__repr__" in internal_obj.__dict__ and "__repr__" not in wrapped_obj.__dict__: + pytest.fail(f"Missing __repr__: {internal_obj.__name__}") + for internal_attr_name in dir(internal_obj): wrapped_attr_name = internal_attr_name.removeprefix("Raw") assert wrapped_attr_name in dir(wrapped_obj) diff --git a/src/context.rs b/src/context.rs index b0af566e4..55c92a8fa 100644 --- a/src/context.rs +++ b/src/context.rs @@ -61,7 +61,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::{ DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, }; -use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; use datafusion::execution::options::ReadOptions; use datafusion::execution::runtime_env::RuntimeEnvBuilder; @@ -183,22 +183,49 @@ impl PyRuntimeEnvBuilder { } fn with_disk_manager_disabled(&self) -> Self { - let mut builder = self.builder.clone(); - builder = builder.with_disk_manager(DiskManagerConfig::Disabled); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::Disabled); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_disk_manager_os(&self) -> Self { - let builder = self.builder.clone(); - let builder = builder.with_disk_manager(DiskManagerConfig::NewOs); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_disk_manager_specified(&self, paths: Vec) -> Self { - let builder = self.builder.clone(); let paths = paths.iter().map(|s| s.into()).collect(); - let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths)); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths)); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_unbounded_memory_pool(&self) -> Self { diff --git a/src/expr.rs b/src/expr.rs index bc7dbeffd..6b1d01d65 
100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion::logical_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion::logical_expr::expr::AggregateFunctionParams; use datafusion::logical_expr::utils::exprlist_to_fields; use datafusion::logical_expr::{ - ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition, + lit_with_metadata, ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition, }; use pyo3::IntoPyObjectExt; use pyo3::{basic::CompareOp, prelude::*}; @@ -150,7 +150,7 @@ impl PyExpr { Ok(PyScalarVariable::new(data_type, variables).into_bound_py_any(py)?) } Expr::Like(value) => Ok(PyLike::from(value.clone()).into_bound_py_any(py)?), - Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_bound_py_any(py)?), + Expr::Literal(value, metadata) => Ok(PyLiteral::new_with_metadata(value.clone(), metadata.clone()).into_bound_py_any(py)?), Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_bound_py_any(py)?), Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_bound_py_any(py)?), Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_bound_py_any(py)?), @@ -282,6 +282,14 @@ impl PyExpr { lit(value.0).into() } + #[staticmethod] + pub fn literal_with_metadata( + value: PyScalarValue, + metadata: HashMap, + ) -> PyExpr { + lit_with_metadata(value.0, metadata).into() + } + #[staticmethod] pub fn column(value: &str) -> PyExpr { col(value).into() @@ -377,7 +385,7 @@ impl PyExpr { /// Extracts the Expr value into a PyObject that can be shared with Python pub fn python_value(&self, py: Python) -> PyResult { match &self.expr { - Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py), + Expr::Literal(scalar_value, _) => scalar_to_pyarrow(scalar_value, py), _ => Err(py_type_err(format!( "Non Expr::Literal encountered in types: {:?}", &self.expr @@ -417,11 +425,13 @@ impl PyExpr { params: AggregateFunctionParams { args, .. }, .. }) - | Expr::ScalarFunction(ScalarFunction { args, .. }) - | Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { args, .. }, - .. - }) => Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()), + | Expr::ScalarFunction(ScalarFunction { args, .. 
}) => { + Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()) + } + Expr::WindowFunction(boxed_window_fn) => { + let args = &boxed_window_fn.params.args; + Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()) + } // Expr(s) that require more specific processing Expr::Case(Case { @@ -600,10 +610,10 @@ impl PyExpr { ) -> PyDataFusionResult { match &self.expr { Expr::AggregateFunction(agg_fn) => { - let window_fn = Expr::WindowFunction(WindowFunction::new( + let window_fn = Expr::WindowFunction(Box::new(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()), agg_fn.params.args.clone(), - )); + ))); add_builder_fns_to_window( window_fn, @@ -743,7 +753,7 @@ impl PyExpr { | Operator::QuestionPipe => Err(py_type_err(format!("Unsupported expr: ${op}"))), }, Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type), - Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value), + Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value), _ => Err(py_type_err(format!( "Non Expr::Literal encountered in types: {:?}", expr diff --git a/src/expr/literal.rs b/src/expr/literal.rs index a660ac914..45303a104 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -18,11 +18,22 @@ use crate::errors::PyDataFusionError; use datafusion::common::ScalarValue; use pyo3::{prelude::*, IntoPyObjectExt}; +use std::collections::BTreeMap; #[pyclass(name = "Literal", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLiteral { pub value: ScalarValue, + pub metadata: Option>, +} + +impl PyLiteral { + pub fn new_with_metadata( + value: ScalarValue, + metadata: Option>, + ) -> PyLiteral { + Self { value, metadata } + } } impl From for ScalarValue { @@ -33,7 +44,10 @@ impl From for ScalarValue { impl From for PyLiteral { fn from(value: ScalarValue) -> PyLiteral { - PyLiteral { value } + PyLiteral { + value, + metadata: None, + } } } diff --git a/src/expr/window.rs b/src/expr/window.rs index c5467bf94..052d9eeb4 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -16,7 +16,6 @@ // under the License. use datafusion::common::{DataFusionError, ScalarValue}; -use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits}; use pyo3::{prelude::*, IntoPyObjectExt}; use std::fmt::{self, Display, Formatter}; @@ -118,10 +117,9 @@ impl PyWindowExpr { /// Returns order by columns in a window function expression pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { order_by, .. }, - .. - }) => py_sort_expr_list(&order_by), + Expr::WindowFunction(boxed_window_fn) => { + py_sort_expr_list(&boxed_window_fn.params.order_by) + } other => Err(not_window_function_err(other)), } } @@ -129,10 +127,9 @@ impl PyWindowExpr { /// Return partition by columns in a window function expression pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. 
- }) => py_expr_list(&partition_by), + Expr::WindowFunction(boxed_window_fn) => { + py_expr_list(&boxed_window_fn.params.partition_by) + } other => Err(not_window_function_err(other)), } } @@ -140,10 +137,7 @@ impl PyWindowExpr { /// Return input args for window function pub fn get_args(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { args, .. }, - .. - }) => py_expr_list(&args), + Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args), other => Err(not_window_function_err(other)), } } @@ -151,7 +145,7 @@ impl PyWindowExpr { /// Return window function name pub fn window_func_name(&self, expr: PyExpr) -> PyResult { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()), + Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()), other => Err(not_window_function_err(other)), } } @@ -159,10 +153,9 @@ impl PyWindowExpr { /// Returns a Pywindow frame for a given window function expression pub fn get_frame(&self, expr: PyExpr) -> Option { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { window_frame, .. }, - .. - }) => Some(window_frame.into()), + Expr::WindowFunction(boxed_window_fn) => { + Some(boxed_window_fn.params.window_frame.into()) + } _ => None, } } diff --git a/src/functions.rs b/src/functions.rs index caa79b8ad..b2bafcb65 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -103,7 +103,7 @@ fn array_cat(exprs: Vec) -> PyExpr { #[pyo3(signature = (array, element, index=None))] fn array_position(array: PyExpr, element: PyExpr, index: Option) -> PyExpr { let index = ScalarValue::Int64(index); - let index = Expr::Literal(index); + let index = Expr::Literal(index, None); datafusion::functions_nested::expr_fn::array_position(array.into(), element.into(), index) .into() } @@ -334,7 +334,7 @@ fn window( .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty()))); Ok(PyExpr { - expr: datafusion::logical_expr::Expr::WindowFunction(WindowFunction { + expr: datafusion::logical_expr::Expr::WindowFunction(Box::new(WindowFunction { fun, params: WindowFunctionParams { args: args.into_iter().map(|x| x.expr).collect::>(), @@ -351,7 +351,7 @@ fn window( window_frame, null_treatment: None, }, - }), + })), }) } diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs index 4b4c86597..7fbb1dc2a 100644 --- a/src/pyarrow_filter_expression.rs +++ b/src/pyarrow_filter_expression.rs @@ -61,7 +61,7 @@ fn extract_scalar_list<'py>( .iter() .map(|expr| match expr { // TODO: should we also leverage `ScalarValue::to_pyarrow` here? - Expr::Literal(v) => match v { + Expr::Literal(v, _) => match v { // The unwraps here are for infallible conversions ScalarValue::Boolean(Some(b)) => Ok(b.into_bound_py_any(py)?), ScalarValue::Int8(Some(i)) => Ok(i.into_bound_py_any(py)?), @@ -106,7 +106,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { let op_module = Python::import(py, "operator")?; let pc_expr: PyDataFusionResult> = match expr { Expr::Column(Column { name, .. 
}) => Ok(pc.getattr("field")?.call1((name,))?), - Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)), + Expr::Literal(scalar, _) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let operator = operator_to_py(op, &op_module)?; let left = PyArrowFilterExpression::try_from(left.as_ref())?.0; diff --git a/src/udwf.rs b/src/udwf.rs index defd9c522..a0c8cc59a 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -300,13 +300,9 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self.signature } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { // TODO: Should nullable always be `true`? - Ok(arrow::datatypes::Field::new( - field_args.name(), - self.return_type.clone(), - true, - )) + Ok(arrow::datatypes::Field::new(field_args.name(), self.return_type.clone(), true).into()) } // TODO: Enable passing partition_evaluator_args to python? From 9b6acec075f49d551a2b90608b0c7114de84d718 Mon Sep 17 00:00:00 2001 From: Michele Gregori Date: Thu, 19 Jun 2025 19:58:22 +0200 Subject: [PATCH 02/14] Support types other than String and Int for partition columns (#1154) * impl impl * fix test * format rust * support for old logic dasdas * also on io * fix formatting --------- Co-authored-by: michele gregori --- python/datafusion/context.py | 66 ++++++++++++++++++++++---- python/datafusion/io.py | 8 ++-- python/tests/test_sql.py | 26 ++++++----- src/context.rs | 89 +++++++++++++++++++++++------------- 4 files changed, 132 insertions(+), 57 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 4ed465c99..5b99b0d26 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,8 +19,11 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Protocol +import pyarrow as pa + try: from warnings import deprecated # Python 3.13+ except ImportError: @@ -42,7 +45,6 @@ import pandas as pd import polars as pl - import pyarrow as pa from datafusion.plan import ExecutionPlan, LogicalPlan @@ -539,7 +541,7 @@ def register_listing_table( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, file_sort_order: list[list[Expr | SortExpr]] | None = None, @@ -560,6 +562,7 @@ def register_listing_table( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order_raw = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -778,7 +781,7 @@ def register_parquet( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -806,6 +809,7 @@ def register_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_parquet( name, str(path), @@ -869,7 +873,7 @@ def register_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + 
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> None: """Register a JSON file as a table. @@ -890,6 +894,7 @@ def register_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_json( name, str(path), @@ -906,7 +911,7 @@ def register_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, file_extension: str = ".avro", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, ) -> None: """Register an Avro file as a table. @@ -922,6 +927,7 @@ def register_avro( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -981,7 +987,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. @@ -1001,6 +1007,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) return DataFrame( self.ctx.read_json( str(path), @@ -1020,7 +1027,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -1045,6 +1052,7 @@ def read_csv( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) path = [str(p) for p in path] if isinstance(path, list) else str(path) @@ -1064,7 +1072,7 @@ def read_csv( def read_parquet( self, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -1093,6 +1101,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -1114,7 +1123,7 @@ def read_avro( self, path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. 
@@ -1130,6 +1139,7 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] + file_partition_cols = self._convert_table_partition_cols(file_partition_cols) return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) @@ -1146,3 +1156,41 @@ def read_table(self, table: Table) -> DataFrame: def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions)) + + @staticmethod + def _convert_table_partition_cols( + table_partition_cols: list[tuple[str, str | pa.DataType]], + ) -> list[tuple[str, pa.DataType]]: + warn = False + converted_table_partition_cols = [] + + for col, data_type in table_partition_cols: + if isinstance(data_type, str): + warn = True + if data_type == "string": + converted_data_type = pa.string() + elif data_type == "int": + converted_data_type = pa.int32() + else: + message = ( + f"Unsupported literal data type '{data_type}' for partition " + "column. Supported types are 'string' and 'int'" + ) + raise ValueError(message) + else: + converted_data_type = data_type + + converted_table_partition_cols.append((col, converted_data_type)) + + if warn: + message = ( + "using literals for table_partition_cols data types is deprecated," + "use pyarrow types instead" + ) + warnings.warn( + message, + category=DeprecationWarning, + stacklevel=2, + ) + + return converted_table_partition_cols diff --git a/python/datafusion/io.py b/python/datafusion/io.py index ef5ebf96f..551e20a6f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -34,7 +34,7 @@ def read_parquet( path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -83,7 +83,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. @@ -124,7 +124,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -171,7 +171,7 @@ def read_csv( def read_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. 
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index b6348e3a0..41cee4ef3 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} -@pytest.mark.parametrize("path_to_str", [True, False]) -def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): +@pytest.mark.parametrize( + ("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)] +) +def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a").mkdir(exist_ok=False) @@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = str(dir_root) if path_to_str else dir_root + partition_data_type = "string" if legacy_data_type else pa.string() + ctx.register_parquet( "datapp", dir_root, - table_partition_cols=[("grp", "string")], + table_partition_cols=[("grp", partition_data_type)], parquet_pruning=True, file_extension=".parquet", ) @@ -488,9 +492,9 @@ def test_register_listing_table( ): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) - (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True) table = pa.Table.from_arrays( [ @@ -501,13 +505,13 @@ def test_register_listing_table( names=["int", "str", "float"], ) pa.parquet.write_table( - table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet" + table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet" + table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet" + table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet" ) dir_root = f"file://{dir_root}/" if path_to_str else dir_root @@ -515,7 +519,7 @@ def test_register_listing_table( ctx.register_listing_table( "my_table", dir_root, - table_partition_cols=[("grp", "string"), ("date_id", "int")], + table_partition_cols=[("grp", pa.string()), ("date", pa.date64())], file_extension=".parquet", schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, @@ -531,7 +535,7 @@ def test_register_listing_table( assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2} result = ctx.sql( - "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501 + "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501 ).collect() result = pa.Table.from_batches(result) diff --git a/src/context.rs b/src/context.rs index 55c92a8fa..6ce1f12bc 100644 --- a/src/context.rs +++ b/src/context.rs @@ -380,7 +380,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_extension: &str, schema: Option>, file_sort_order: Option>>, @@ -388,7 +388,12 @@ impl PySessionContext { ) -> 
PyDataFusionResult<()> { let options = ListingOptions::new(Arc::new(ParquetFormat::new())) .with_file_extension(file_extension) - .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .with_table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .with_file_sort_order( file_sort_order .unwrap_or_default() @@ -656,7 +661,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -665,7 +670,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult<()> { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -745,7 +755,7 @@ impl PySessionContext { schema: Option>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult<()> { @@ -755,7 +765,12 @@ impl PySessionContext { let mut options = NdJsonReadOptions::default() .file_compression_type(parse_file_compression_type(file_compression_type)?) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -778,15 +793,19 @@ impl PySessionContext { path: PathBuf, schema: Option>, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, py: Python, ) -> PyDataFusionResult<()> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -887,7 +906,7 @@ impl PySessionContext { schema: Option>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult { @@ -895,7 +914,12 @@ impl PySessionContext { .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; let mut options = NdJsonReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) 
+ .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; @@ -928,7 +952,7 @@ impl PySessionContext { delimiter: &str, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult { @@ -944,7 +968,12 @@ impl PySessionContext { .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) .file_extension(file_extension) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema = schema.as_ref().map(|x| &x.0); @@ -974,7 +1003,7 @@ impl PySessionContext { pub fn read_parquet( &self, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -983,7 +1012,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -1005,12 +1039,16 @@ impl PySessionContext { &self, path: &str, schema: Option>, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_extension: &str, py: Python, ) -> PyDataFusionResult { - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.file_extension = file_extension; let df = if let Some(schema) = schema { options.schema = Some(&schema.0); @@ -1109,21 +1147,6 @@ impl PySessionContext { } } -pub fn convert_table_partition_cols( - table_partition_cols: Vec<(String, String)>, -) -> PyDataFusionResult> { - table_partition_cols - .into_iter() - .map(|(name, ty)| match ty.as_str() { - "string" => Ok((name, DataType::Utf8)), - "int" => Ok((name, DataType::Int32)), - _ => Err(crate::errors::PyDataFusionError::Common(format!( - "Unsupported data type '{ty}' for partition column. 
Supported types are 'string' and 'int'" - ))), - }) - .collect::, _>>() -} - pub fn parse_file_compression_type( file_compression_type: Option, ) -> Result { From 98dc06b553b69576d50157542f761b44f456dcf2 Mon Sep 17 00:00:00 2001 From: Nuno Faria Date: Sat, 21 Jun 2025 15:20:21 +0100 Subject: [PATCH 03/14] feat: Support Parquet writer options (#1123) * feat: Support Parquet writer options * Create dedicated write_parquet_options function * Rename write_parquet_options to write_parquet_with_options * Merge remote-tracking branch 'origin/main' into write_parquet_options * Fix ruff errors --- python/datafusion/__init__.py | 4 +- python/datafusion/dataframe.py | 221 ++++++++++++++++++ python/tests/test_dataframe.py | 406 +++++++++++++++++++++++++++++++++ src/dataframe.rs | 126 +++++++++- src/lib.rs | 2 + 5 files changed, 757 insertions(+), 2 deletions(-) diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 4f7700251..16d65f685 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -46,7 +46,7 @@ SessionContext, SQLOptions, ) -from .dataframe import DataFrame +from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions from .expr import ( Expr, WindowFrame, @@ -80,6 +80,8 @@ "ExecutionPlan", "Expr", "LogicalPlan", + "ParquetColumnOptions", + "ParquetWriterOptions", "RecordBatch", "RecordBatchStream", "RuntimeEnvBuilder", diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index a1df7e080..769271c7e 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -38,6 +38,8 @@ from typing_extensions import deprecated # Python 3.12 from datafusion._internal import DataFrame as DataFrameInternal +from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal +from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import Expr, SortExpr, sort_or_default from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream @@ -114,6 +116,173 @@ def get_default_level(self) -> Optional[int]: return None +class ParquetWriterOptions: + """Advanced parquet writer options. + + Allows settings the writer options that apply to the entire file. Some options can + also be set on a column by column basis, with the field `column_specific_options` + (see `ParquetColumnOptions`). + + Attributes: + data_pagesize_limit: Sets best effort maximum size of data page in bytes. + write_batch_size: Sets write_batch_size in bytes. + writer_version: Sets parquet writer version. Valid values are `1.0` and + `2.0`. + skip_arrow_metadata: Skip encoding the embedded arrow metadata in the + KV_meta. + compression: Compression type to use. Default is "zstd(3)". + Available compression types are + - "uncompressed": No compression. + - "snappy": Snappy compression. + - "gzip(n)": Gzip compression with level n. + - "brotli(n)": Brotli compression with level n. + - "lz4": LZ4 compression. + - "lz4_raw": LZ4_RAW compression. + - "zstd(n)": Zstandard compression with level n. + dictionary_enabled: Sets if dictionary encoding is enabled. If None, uses + the default parquet writer setting. + dictionary_page_size_limit: Sets best effort maximum dictionary page size, + in bytes. + statistics_enabled: Sets if statistics are enabled for any column Valid + values are `none`, `chunk`, and `page`. If None, uses the default + parquet writer setting. 
+ max_row_group_size: Target maximum number of rows in each row group
+ (defaults to 1M rows). Writing larger row groups requires more memory to
+ write, but can get better compression and be faster to read.
+ created_by: Sets "created by" property.
+ column_index_truncate_length: Sets column index truncate length.
+ statistics_truncate_length: Sets statistics truncate length. If None, uses
+ the default parquet writer setting.
+ data_page_row_count_limit: Sets best effort maximum number of rows in a data
+ page.
+ encoding: Sets default encoding for any column. Valid values are `plain`,
+ `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`,
+ `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and
+ `byte_stream_split`. If None, uses the default parquet writer setting.
+ bloom_filter_on_write: Write bloom filters for all columns when creating
+ parquet files.
+ bloom_filter_fpp: Sets bloom filter false positive probability. If None,
+ uses the default parquet writer setting.
+ bloom_filter_ndv: Sets bloom filter number of distinct values. If None, uses
+ the default parquet writer setting.
+ allow_single_file_parallelism: Controls whether DataFusion will attempt to
+ speed up writing parquet files by serializing them in parallel. Each
+ column in each row group in each output file is serialized in parallel,
+ leveraging a maximum possible core count of n_files * n_row_groups *
+ n_columns.
+ maximum_parallel_row_group_writers: By default the parallel parquet writer is
+ tuned for minimum memory usage in a streaming execution plan. You may
+ see a performance benefit when writing large parquet files by increasing
+ `maximum_parallel_row_group_writers` and
+ `maximum_buffered_record_batches_per_stream` if your system has idle
+ cores and can tolerate additional memory usage. Boosting these values is
+ likely worthwhile when writing out already in-memory data, such as from
+ a cached data frame.
+ maximum_buffered_record_batches_per_stream: See
+ `maximum_parallel_row_group_writers`.
+ column_specific_options: Overrides options for specific columns. If a column
+ is not a part of this dictionary, it will use the global options defined above.
+ """ + + def __init__( + self, + data_pagesize_limit: int = 1024 * 1024, + write_batch_size: int = 1024, + writer_version: str = "1.0", + skip_arrow_metadata: bool = False, + compression: Optional[str] = "zstd(3)", + dictionary_enabled: Optional[bool] = True, + dictionary_page_size_limit: int = 1024 * 1024, + statistics_enabled: Optional[str] = "page", + max_row_group_size: int = 1024 * 1024, + created_by: str = "datafusion-python", + column_index_truncate_length: Optional[int] = 64, + statistics_truncate_length: Optional[int] = None, + data_page_row_count_limit: int = 20_000, + encoding: Optional[str] = None, + bloom_filter_on_write: bool = False, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + allow_single_file_parallelism: bool = True, + maximum_parallel_row_group_writers: int = 1, + maximum_buffered_record_batches_per_stream: int = 2, + column_specific_options: Optional[dict[str, ParquetColumnOptions]] = None, + ) -> None: + """Initialize the ParquetWriterOptions.""" + self.data_pagesize_limit = data_pagesize_limit + self.write_batch_size = write_batch_size + self.writer_version = writer_version + self.skip_arrow_metadata = skip_arrow_metadata + self.compression = compression + self.dictionary_enabled = dictionary_enabled + self.dictionary_page_size_limit = dictionary_page_size_limit + self.statistics_enabled = statistics_enabled + self.max_row_group_size = max_row_group_size + self.created_by = created_by + self.column_index_truncate_length = column_index_truncate_length + self.statistics_truncate_length = statistics_truncate_length + self.data_page_row_count_limit = data_page_row_count_limit + self.encoding = encoding + self.bloom_filter_on_write = bloom_filter_on_write + self.bloom_filter_fpp = bloom_filter_fpp + self.bloom_filter_ndv = bloom_filter_ndv + self.allow_single_file_parallelism = allow_single_file_parallelism + self.maximum_parallel_row_group_writers = maximum_parallel_row_group_writers + self.maximum_buffered_record_batches_per_stream = ( + maximum_buffered_record_batches_per_stream + ) + self.column_specific_options = column_specific_options + + +class ParquetColumnOptions: + """Parquet options for individual columns. + + Contains the available options that can be applied for an individual Parquet column, + replacing the global options in `ParquetWriterOptions`. + + Attributes: + encoding: Sets encoding for the column path. Valid values are: `plain`, + `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, + `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and + `byte_stream_split`. These values are not case-sensitive. If `None`, uses + the default parquet options + dictionary_enabled: Sets if dictionary encoding is enabled for the column path. + If `None`, uses the default parquet options + compression: Sets default parquet compression codec for the column path. Valid + values are `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, + `lz4`, `zstd(level)`, and `lz4_raw`. These values are not case-sensitive. If + `None`, uses the default parquet options. + statistics_enabled: Sets if statistics are enabled for the column Valid values + are: `none`, `chunk`, and `page` These values are not case sensitive. If + `None`, uses the default parquet options. + bloom_filter_enabled: Sets if bloom filter is enabled for the column path. If + `None`, uses the default parquet options. + bloom_filter_fpp: Sets bloom filter false positive probability for the column + path. 
If `None`, uses the default parquet options. + bloom_filter_ndv: Sets bloom filter number of distinct values. If `None`, uses + the default parquet options. + """ + + def __init__( + self, + encoding: Optional[str] = None, + dictionary_enabled: Optional[bool] = None, + compression: Optional[str] = None, + statistics_enabled: Optional[str] = None, + bloom_filter_enabled: Optional[bool] = None, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + ) -> None: + """Initialize the ParquetColumnOptions.""" + self.encoding = encoding + self.dictionary_enabled = dictionary_enabled + self.compression = compression + self.statistics_enabled = statistics_enabled + self.bloom_filter_enabled = bloom_filter_enabled + self.bloom_filter_fpp = bloom_filter_fpp + self.bloom_filter_ndv = bloom_filter_ndv + + class DataFrame: """Two dimensional table representation of data. @@ -737,6 +906,58 @@ def write_parquet( self.df.write_parquet(str(path), compression.value, compression_level) + def write_parquet_with_options( + self, path: str | pathlib.Path, options: ParquetWriterOptions + ) -> None: + """Execute the :py:class:`DataFrame` and write the results to a Parquet file. + + Allows advanced writer options to be set with `ParquetWriterOptions`. + + Args: + path: Path of the Parquet file to write. + options: Sets the writer parquet options (see `ParquetWriterOptions`). + """ + options_internal = ParquetWriterOptionsInternal( + options.data_pagesize_limit, + options.write_batch_size, + options.writer_version, + options.skip_arrow_metadata, + options.compression, + options.dictionary_enabled, + options.dictionary_page_size_limit, + options.statistics_enabled, + options.max_row_group_size, + options.created_by, + options.column_index_truncate_length, + options.statistics_truncate_length, + options.data_page_row_count_limit, + options.encoding, + options.bloom_filter_on_write, + options.bloom_filter_fpp, + options.bloom_filter_ndv, + options.allow_single_file_parallelism, + options.maximum_parallel_row_group_writers, + options.maximum_buffered_record_batches_per_stream, + ) + + column_specific_options_internal = {} + for column, opts in (options.column_specific_options or {}).items(): + column_specific_options_internal[column] = ParquetColumnOptionsInternal( + bloom_filter_enabled=opts.bloom_filter_enabled, + encoding=opts.encoding, + dictionary_enabled=opts.dictionary_enabled, + compression=opts.compression, + statistics_enabled=opts.statistics_enabled, + bloom_filter_fpp=opts.bloom_filter_fpp, + bloom_filter_ndv=opts.bloom_filter_ndv, + ) + + self.df.write_parquet_with_options( + str(path), + options_internal, + column_specific_options_internal, + ) + def write_json(self, path: str | pathlib.Path) -> None: """Execute the :py:class:`DataFrame` and write the results to a JSON file. 
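As a usage sketch (not part of this patch), the writer options API added above can be exercised from Python roughly as follows. The classes and the write_parquet_with_options method come from the diff above; the input data, option values, and output path are illustrative assumptions only.

from datafusion import ParquetColumnOptions, ParquetWriterOptions, SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# Global writer options, with a per-column override for "name" only;
# "id" falls back to the global settings. Values here are illustrative.
options = ParquetWriterOptions(
    compression="zstd(5)",
    max_row_group_size=100_000,
    column_specific_options={
        "name": ParquetColumnOptions(statistics_enabled="none"),
    },
)
df.write_parquet_with_options("output_dir/", options)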
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 64220ce9c..3c9b97f23 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -27,6 +27,8 @@ import pytest from datafusion import ( DataFrame, + ParquetColumnOptions, + ParquetWriterOptions, SessionContext, WindowFrame, column, @@ -66,6 +68,21 @@ def df(): return ctx.from_arrow(batch) +@pytest.fixture +def large_df(): + ctx = SessionContext() + + rows = 100000 + data = { + "a": list(range(rows)), + "b": [f"s-{i}" for i in range(rows)], + "c": [float(i + 0.1) for i in range(rows)], + } + batch = pa.record_batch(data) + + return ctx.from_arrow(batch) + + @pytest.fixture def struct_df(): ctx = SessionContext() @@ -1632,6 +1649,395 @@ def test_write_compressed_parquet_default_compression_level(df, tmp_path, compre df.write_parquet(str(path), compression=compression) +def test_write_parquet_with_options_default_compression(df, tmp_path): + """Test that the default compression is ZSTD.""" + df.write_parquet(tmp_path) + + for file in tmp_path.rglob("*.parquet"): + metadata = pq.ParquetFile(file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["compression"].lower() == "zstd" + + +@pytest.mark.parametrize( + "compression", + ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"], +) +def test_write_parquet_with_options_compression(df, tmp_path, compression): + import re + + path = tmp_path + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + # test that the actual compression scheme is the one written + for _root, _dirs, files in os.walk(path): + for file in files: + if file.endswith(".parquet"): + metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["compression"].lower() == re.sub( + r"\(\d+\)", "", compression + ) + + result = pq.read_table(str(path)).to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize( + "compression", + ["gzip(12)", "brotli(15)", "zstd(23)"], +) +def test_write_parquet_with_options_wrong_compression_level(df, tmp_path, compression): + path = tmp_path + + with pytest.raises(Exception, match=r"valid compression range .*? exceeded."): + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + +@pytest.mark.parametrize("compression", ["wrong", "wrong(12)"]) +def test_write_parquet_with_options_invalid_compression(df, tmp_path, compression): + path = tmp_path + + with pytest.raises(Exception, match="Unknown or unsupported parquet compression"): + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + +@pytest.mark.parametrize( + ("writer_version", "format_version"), + [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")], +) +def test_write_parquet_with_options_writer_version( + df, tmp_path, writer_version, format_version +): + """Test the Parquet writer version. 
Note that writer_version=2.0 results in + format_version=2.6""" + if writer_version is None: + df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) + else: + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(writer_version=writer_version) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["format_version"] == format_version + + +@pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"]) +def test_write_parquet_with_options_wrong_writer_version(df, tmp_path, writer_version): + """Test that invalid writer versions in Parquet throw an exception.""" + with pytest.raises( + Exception, match="Unknown or unsupported parquet writer version" + ): + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(writer_version=writer_version) + ) + + +@pytest.mark.parametrize("dictionary_enabled", [True, False, None]) +def test_write_parquet_with_options_dictionary_enabled( + df, tmp_path, dictionary_enabled +): + """Test enabling/disabling the dictionaries in Parquet.""" + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(dictionary_enabled=dictionary_enabled) + ) + # by default, the dictionary is enabled, so None results in True + result = dictionary_enabled if dictionary_enabled is not None else True + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["has_dictionary_page"] == result + + +@pytest.mark.parametrize( + ("statistics_enabled", "has_statistics"), + [("page", True), ("chunk", True), ("none", False), (None, True)], +) +def test_write_parquet_with_options_statistics_enabled( + df, tmp_path, statistics_enabled, has_statistics +): + """Test configuring the statistics in Parquet. In pyarrow we can only check for + column-level statistics, so "page" and "chunk" are tested in the same way.""" + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(statistics_enabled=statistics_enabled) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + if has_statistics: + assert col["statistics"] is not None + else: + assert col["statistics"] is None + + +@pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000]) +def test_write_parquet_with_options_max_row_group_size( + large_df, tmp_path, max_row_group_size +): + """Test configuring the max number of rows per group in Parquet. 
These test cases + guarantee that the number of rows for each row group is max_row_group_size, given + the total number of rows is a multiple of max_row_group_size.""" + large_df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(max_row_group_size=max_row_group_size) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + for row_group in metadata["row_groups"]: + assert row_group["num_rows"] == max_row_group_size + + +@pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"]) +def test_write_parquet_with_options_created_by(df, tmp_path, created_by): + """Test configuring the created by metadata in Parquet.""" + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(created_by=created_by)) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["created_by"] == created_by + + +@pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50]) +def test_write_parquet_with_options_statistics_truncate_length( + df, tmp_path, statistics_truncate_length +): + """Test configuring the truncate limit in Parquet's row-group-level statistics.""" + ctx = SessionContext() + data = { + "a": [ + "a_the_quick_brown_fox_jumps_over_the_lazy_dog", + "m_the_quick_brown_fox_jumps_over_the_lazy_dog", + "z_the_quick_brown_fox_jumps_over_the_lazy_dog", + ], + "b": ["a_smaller", "m_smaller", "z_smaller"], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, + ParquetWriterOptions(statistics_truncate_length=statistics_truncate_length), + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + statistics = col["statistics"] + assert len(statistics["min"]) <= statistics_truncate_length + assert len(statistics["max"]) <= statistics_truncate_length + + +def test_write_parquet_with_options_default_encoding(tmp_path): + """Test that, by default, Parquet files are written with dictionary encoding. 
+ Note that dictionary encoding is not used for boolean values, so it is not tested + here.""" + ctx = SessionContext() + data = { + "a": [1, 2, 3], + "b": ["1", "2", "3"], + "c": [1.01, 2.02, 3.03], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == ("PLAIN", "RLE", "RLE_DICTIONARY") + + +@pytest.mark.parametrize( + ("encoding", "data_types", "result"), + [ + ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")), + ("rle", ["bool"], ("RLE",)), + ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")), + ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")), + ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")), + ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")), + ], +) +def test_write_parquet_with_options_encoding(tmp_path, encoding, data_types, result): + """Test different encodings in Parquet in their respective support column types.""" + ctx = SessionContext() + + data = {} + for data_type in data_types: + if data_type == "int": + data["int"] = [1, 2, 3] + elif data_type == "float": + data["float"] = [1.01, 2.02, 3.03] + elif data_type == "str": + data["str"] = ["a", "b", "c"] + elif data_type == "bool": + data["bool"] = [True, False, True] + + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(encoding=encoding, dictionary_enabled=False) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == result + + +@pytest.mark.parametrize("encoding", ["bit_packed"]) +def test_write_parquet_with_options_unsupported_encoding(df, tmp_path, encoding): + """Test that unsupported Parquet encodings do not work.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises(BaseException, match="Encoding .*? is not supported"): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +@pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"]) +def test_write_parquet_with_options_invalid_encoding(df, tmp_path, encoding): + """Test that invalid Parquet encodings do not work.""" + with pytest.raises(Exception, match="Unknown or unsupported parquet encoding"): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +@pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"]) +def test_write_parquet_with_options_dictionary_encoding_fallback( + df, tmp_path, encoding +): + """Test that the dictionary encoding cannot be used as fallback in Parquet.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises( + BaseException, match="Dictionary encoding can not be used as fallback encoding" + ): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +def test_write_parquet_with_options_bloom_filter(df, tmp_path): + """Test Parquet files with and without (default) bloom filters. 
Since pyarrow does + not expose any information about bloom filters, the easiest way to confirm that they + are actually written is to compare the file size.""" + path_no_bloom_filter = tmp_path / "1" + path_bloom_filter = tmp_path / "2" + + df.write_parquet_with_options(path_no_bloom_filter, ParquetWriterOptions()) + df.write_parquet_with_options( + path_bloom_filter, ParquetWriterOptions(bloom_filter_on_write=True) + ) + + size_no_bloom_filter = 0 + for file in path_no_bloom_filter.rglob("*.parquet"): + size_no_bloom_filter += os.path.getsize(file) + + size_bloom_filter = 0 + for file in path_bloom_filter.rglob("*.parquet"): + size_bloom_filter += os.path.getsize(file) + + assert size_no_bloom_filter < size_bloom_filter + + +def test_write_parquet_with_options_column_options(df, tmp_path): + """Test writing Parquet files with different options for each column, which replace + the global configs (when provided).""" + data = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [False, True, False], + "d": [1.01, 2.02, 3.03], + "e": [4, 5, 6], + } + + column_specific_options = { + "a": ParquetColumnOptions(statistics_enabled="none"), + "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False), + "c": ParquetColumnOptions( + compression="snappy", encoding="rle", dictionary_enabled=False + ), + "d": ParquetColumnOptions( + compression="zstd(6)", + encoding="byte_stream_split", + dictionary_enabled=False, + statistics_enabled="none", + ), + # column "e" will use the global configs + } + + results = { + "a": { + "statistics": False, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + "b": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE"), + }, + "c": { + "statistics": True, + "compression": "snappy", + "encodings": ("RLE",), + }, + "d": { + "statistics": False, + "compression": "zstd", + "encodings": ("RLE", "BYTE_STREAM_SPLIT"), + }, + "e": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + } + + ctx = SessionContext() + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, + ParquetWriterOptions( + compression="brotli(8)", column_specific_options=column_specific_options + ), + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + column_name = col["path_in_schema"] + result = results[column_name] + assert (col["statistics"] is not None) == result["statistics"] + assert col["compression"].lower() == result["compression"].lower() + assert col["encodings"] == result["encodings"] + + def test_dataframe_export(df) -> None: # Guarantees that we have the canonical implementation # reading our dataframe export diff --git a/src/dataframe.rs b/src/dataframe.rs index 7711a0782..3d68db279 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::ffi::CString; use std::sync::Arc; @@ -27,7 +28,7 @@ use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; use datafusion::common::UnnestOptions; -use datafusion::config::{CsvOptions, TableParquetOptions}; +use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; @@ -185,6 +186,101 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult< Ok(config) } +/// Python mapping of `ParquetOptions` (includes just the writer-related options). +#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)] +#[derive(Clone, Default)] +pub struct PyParquetWriterOptions { + options: ParquetOptions, +} + +#[pymethods] +impl PyParquetWriterOptions { + #[new] + #[allow(clippy::too_many_arguments)] + pub fn new( + data_pagesize_limit: usize, + write_batch_size: usize, + writer_version: String, + skip_arrow_metadata: bool, + compression: Option, + dictionary_enabled: Option, + dictionary_page_size_limit: usize, + statistics_enabled: Option, + max_row_group_size: usize, + created_by: String, + column_index_truncate_length: Option, + statistics_truncate_length: Option, + data_page_row_count_limit: usize, + encoding: Option, + bloom_filter_on_write: bool, + bloom_filter_fpp: Option, + bloom_filter_ndv: Option, + allow_single_file_parallelism: bool, + maximum_parallel_row_group_writers: usize, + maximum_buffered_record_batches_per_stream: usize, + ) -> Self { + Self { + options: ParquetOptions { + data_pagesize_limit, + write_batch_size, + writer_version, + skip_arrow_metadata, + compression, + dictionary_enabled, + dictionary_page_size_limit, + statistics_enabled, + max_row_group_size, + created_by, + column_index_truncate_length, + statistics_truncate_length, + data_page_row_count_limit, + encoding, + bloom_filter_on_write, + bloom_filter_fpp, + bloom_filter_ndv, + allow_single_file_parallelism, + maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream, + ..Default::default() + }, + } + } +} + +/// Python mapping of `ParquetColumnOptions`. +#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)] +#[derive(Clone, Default)] +pub struct PyParquetColumnOptions { + options: ParquetColumnOptions, +} + +#[pymethods] +impl PyParquetColumnOptions { + #[new] + pub fn new( + bloom_filter_enabled: Option, + encoding: Option, + dictionary_enabled: Option, + compression: Option, + statistics_enabled: Option, + bloom_filter_fpp: Option, + bloom_filter_ndv: Option, + ) -> Self { + Self { + options: ParquetColumnOptions { + bloom_filter_enabled, + encoding, + dictionary_enabled, + compression, + statistics_enabled, + bloom_filter_fpp, + bloom_filter_ndv, + ..Default::default() + }, + } + } +} + /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. @@ -689,6 +785,34 @@ impl PyDataFrame { Ok(()) } + /// Write a `DataFrame` to a Parquet file, using advanced options. 
+ fn write_parquet_with_options( + &self, + path: &str, + options: PyParquetWriterOptions, + column_specific_options: HashMap, + py: Python, + ) -> PyDataFusionResult<()> { + let table_options = TableParquetOptions { + global: options.options, + column_specific_options: column_specific_options + .into_iter() + .map(|(k, v)| (k, v.options)) + .collect(), + ..Default::default() + }; + + wait_for_future( + py, + self.df.as_ref().clone().write_parquet( + path, + DataFrameWriteOptions::new(), + Option::from(table_options), + ), + )??; + Ok(()) + } + /// Executes a query and writes the results to a partitioned JSON file. fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> { wait_for_future( diff --git a/src/lib.rs b/src/lib.rs index 7dced1fbd..1293eee3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,6 +86,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; From 1812a0d3e88976f51b40ccfd1a02fddb17ddb8bb Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Tue, 24 Jun 2025 10:21:53 -0700 Subject: [PATCH 04/14] Fix signature of `__arrow_c_stream__` (#1168) --- python/datafusion/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 769271c7e..1fd63bdc6 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -1053,7 +1053,7 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram columns = list(columns) return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) - def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any: + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """Export an Arrow PyCapsule Stream. This will execute and collect the DataFrame. 
We will attempt to respect the From 8b7dac2b12dce56cbd3949dc822fc0a270f2a946 Mon Sep 17 00:00:00 2001 From: renato2099 Date: Sun, 1 Jun 2025 22:06:38 +0200 Subject: [PATCH 05/14] Exposing FFI to python Exposing FFI to python Workin progress on python catalog Flushing out schema and catalog providers Adding implementation of python based catalog and schema providers Small updates after rebase --- Cargo.lock | 19 + Cargo.toml | 2 + examples/datafusion-ffi-example/Cargo.lock | 1 + examples/datafusion-ffi-example/Cargo.toml | 1 + .../python/tests/_test_catalog_provider.py | 60 +++ .../src/catalog_provider.rs | 179 +++++++ examples/datafusion-ffi-example/src/lib.rs | 3 + python/datafusion/__init__.py | 3 +- python/datafusion/catalog.py | 86 +++- python/datafusion/context.py | 15 + python/tests/test_catalog.py | 102 +++- src/catalog.rs | 437 ++++++++++++++++-- src/context.rs | 52 ++- src/functions.rs | 2 +- src/lib.rs | 10 +- 15 files changed, 899 insertions(+), 73 deletions(-) create mode 100644 examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py create mode 100644 examples/datafusion-ffi-example/src/catalog_provider.rs diff --git a/Cargo.lock b/Cargo.lock index 112167cb4..a3e9336cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "zstd", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -1503,6 +1509,7 @@ dependencies = [ "datafusion-proto", "datafusion-substrait", "futures", + "log", "mimalloc", "object_store", "prost", @@ -1510,6 +1517,7 @@ dependencies = [ "pyo3", "pyo3-async-runtimes", "pyo3-build-config", + "pyo3-log", "tokio", "url", "uuid", @@ -2953,6 +2961,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.24.2" diff --git a/Cargo.toml b/Cargo.toml index 4135e64e2..1f7895a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} +pyo3-log = "0.12.4" arrow = { version = "55.1.0", features = ["pyarrow"] } datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } datafusion-substrait = { version = "48.0.0", optional = true } @@ -49,6 +50,7 @@ async-trait = "0.1.88" futures = "0.3" object_store = { version = "0.12.1", features = ["aws", "gcp", "azure", "http"] } url = "2" +log = "0.4.27" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index 075ebd5a1..e5a1ca8d1 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -1448,6 +1448,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-schema", + "async-trait", "datafusion", "datafusion-ffi", "pyo3", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 
0e17567b9..319163554 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } +async-trait = "0.1.88" [build-dependencies] pyo3-build-config = "0.23" diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py new file mode 100644 index 000000000..72aadf64c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyCatalogProvider + + +def test_catalog_provider(): + ctx = SessionContext() + + my_catalog_name = "my_catalog" + expected_schema_name = "my_schema" + expected_table_name = "my_table" + expected_table_columns = ["units", "price"] + + catalog_provider = MyCatalogProvider() + ctx.register_catalog_provider(my_catalog_name, catalog_provider) + my_catalog = ctx.catalog(my_catalog_name) + + my_catalog_schemas = my_catalog.names() + assert expected_schema_name in my_catalog_schemas + my_database = my_catalog.database(expected_schema_name) + assert expected_table_name in my_database.names() + my_table = my_database.table(expected_table_name) + assert expected_table_columns == my_table.schema.names + + result = ctx.table( + f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}" + ).collect() + assert len(result) == 2 + + col0_result = [r.column(0) for r in result] + col1_result = [r.column(1) for r in result] + expected_col0 = [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array([5, 7], type=pa.int32()), + ] + expected_col1 = [ + pa.array([1, 2, 5], type=pa.float64()), + pa.array([1.5, 2.5], type=pa.float64()), + ] + assert col0_result == expected_col0 + assert col1_result == expected_col1 diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs new file mode 100644 index 000000000..54e61cf3e --- /dev/null +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::datatypes::Schema; +use async_trait::async_trait; +use datafusion::{ + catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider, + }, + common::exec_err, + datasource::MemTable, + error::{DataFusionError, Result}, +}; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use pyo3::types::PyCapsule; + +pub fn my_table() -> Arc { + use arrow::datatypes::{DataType, Field}; + use datafusion::common::record_batch; + + let schema = Arc::new(Schema::new(vec![ + Field::new("units", DataType::Int32, true), + Field::new("price", DataType::Float64, true), + ])); + + let partitions = vec![ + record_batch!( + ("units", Int32, vec![10, 20, 30]), + ("price", Float64, vec![1.0, 2.0, 5.0]) + ) + .unwrap(), + record_batch!( + ("units", Int32, vec![5, 7]), + ("price", Float64, vec![1.5, 2.5]) + ) + .unwrap(), + ]; + + Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) +} + +#[derive(Debug)] +pub struct FixedSchemaProvider { + inner: MemorySchemaProvider, +} + +impl Default for FixedSchemaProvider { + fn default() -> Self { + let inner = MemorySchemaProvider::new(); + + let table = my_table(); + + let _ = inner.register_table("my_table".to_string(), table).unwrap(); + + Self { inner } + } +} + +#[async_trait] +impl SchemaProvider for FixedSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.inner.table_names() + } + + async fn table(&self, name: &str) -> Result>, DataFusionError> { + self.inner.table(name).await + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + self.inner.register_table(name, table) + } + + fn deregister_table(&self, name: &str) -> Result>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + +/// This catalog provider is intended only for unit tests. It prepopulates with one +/// schema and only allows for schemas named after four types of fruit. 
+#[pyclass( + name = "MyCatalogProvider", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = + FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b65..3a4cf2247 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::catalog_provider::MyCatalogProvider; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; use pyo3::prelude::*; +pub(crate) mod catalog_provider; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -26,5 +28,6 @@ pub(crate) mod table_provider; fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 16d65f685..8e38741bc 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -32,7 +32,7 @@ from datafusion.col import col, column -from . import functions, object_store, substrait, unparser +from . import catalog, functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. 
from ._internal import Config @@ -93,6 +93,7 @@ "TableFunction", "WindowFrame", "WindowUDF", + "catalog", "col", "column", "common", diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 67ab3ead2..bebd38161 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -26,11 +26,23 @@ if TYPE_CHECKING: import pyarrow as pa +try: + from warnings import deprecated # Python 3.13+ +except ImportError: + from typing_extensions import deprecated # Python 3.12 + + +__all__ = [ + "Catalog", + "Schema", + "Table", +] + class Catalog: """DataFusion data catalog.""" - def __init__(self, catalog: df_internal.Catalog) -> None: + def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog @@ -38,39 +50,79 @@ def __repr__(self) -> str: """Print a string representation of the catalog.""" return self.catalog.__repr__() - def names(self) -> list[str]: - """Returns the list of databases in this catalog.""" - return self.catalog.names() + def names(self) -> set[str]: + """This is an alias for `schema_names`.""" + return self.schema_names() + + def schema_names(self) -> set[str]: + """Returns the list of schemas in this catalog.""" + return self.catalog.schema_names() + + def schema(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + schema = self.catalog.schema(name) + + return ( + Schema(schema) + if isinstance(schema, df_internal.catalog.RawSchema) + else schema + ) - def database(self, name: str = "public") -> Database: + @deprecated("Use `schema` instead.") + def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" - return Database(self.catalog.database(name)) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: + """Register a schema with this catalog.""" + return self.catalog.register_schema(name, schema) -class Database: - """DataFusion Database.""" + def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: + """Deregister a schema from this catalog.""" + return self.catalog.deregister_schema(name, cascade) - def __init__(self, db: df_internal.Database) -> None: + +class Schema: + """DataFusion Schema.""" + + def __init__(self, schema: df_internal.catalog.RawSchema) -> None: """This constructor is not typically called by the end user.""" - self.db = db + self._raw_schema = schema def __repr__(self) -> str: - """Print a string representation of the database.""" - return self.db.__repr__() + """Print a string representation of the schema.""" + return self._raw_schema.__repr__() def names(self) -> set[str]: - """Returns the list of all tables in this database.""" - return self.db.names() + """This is an alias for `table_names`.""" + return self.table_names() + + def table_names(self) -> set[str]: + """Returns the list of all tables in this schema.""" + return self._raw_schema.table_names def table(self, name: str) -> Table: - """Return the table with the given ``name`` from this database.""" - return Table(self.db.table(name)) + """Return the table with the given ``name`` from this schema.""" + return Table(self._raw_schema.table(name)) + + def register_table(self, name, table) -> None: + """Register a table provider in this schema.""" + return self._raw_schema.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Deregister a table provider from this 
schema.""" + return self._raw_schema.deregister_table(name) + + +@deprecated("Use `Schema` instead.") +class Database(Schema): + """See `Schema`.""" class Table: """DataFusion table.""" - def __init__(self, table: df_internal.Table) -> None: + def __init__(self, table: df_internal.catalog.RawTable) -> None: """This constructor is not typically called by the end user.""" self.table = table diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5b99b0d26..c080931e9 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -80,6 +80,15 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -749,6 +758,12 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def register_catalog_provider( + self, name: str, provider: CatalogProviderExportable + ) -> None: + """Register a catalog provider.""" + self.ctx.register_catalog_provider(name, provider) + def register_table_provider( self, name: str, provider: TableProviderExportable ) -> None: diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 23b328458..21b0a3e0a 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. +import datafusion as dfn import pyarrow as pa +import pyarrow.dataset as ds import pytest +from datafusion import SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -27,7 +30,7 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ["public"] + assert default.names() == {"public"} for db in [default.database("public"), default.database()]: assert db.names() == {"csv1", "csv", "csv2"} @@ -41,3 +44,100 @@ def test_basic(ctx, database): pa.field("float", pa.float64(), nullable=True), ] ) + + +class CustomTableProvider: + def __init__(self): + pass + + +def create_dataset() -> pa.dataset.Dataset: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ds.dataset([batch]) + + +class CustomSchemaProvider: + def __init__(self): + self.tables = {"table1": create_dataset()} + + def table_names(self) -> set[str]: + return set(self.tables.keys()) + + def register_table(self, name: str, table: Table): + self.tables[name] = table + + def deregister_table(self, name, cascade: bool = True): + del self.tables[name] + + +class CustomCatalogProvider: + def __init__(self): + self.schemas = {"my_schema": CustomSchemaProvider()} + + def schema_names(self) -> set[str]: + return set(self.schemas.keys()) + + def schema(self, name: str): + return self.schemas[name] + + def register_schema(self, name: str, schema: dfn.catalog.Schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade: bool): + del self.schemas[name] + + +def test_python_catalog_provider(ctx: SessionContext): + ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) + + # Check the default catalog provider + assert 
ctx.catalog("datafusion").names() == {"public"} + + my_catalog = ctx.catalog("my_catalog") + assert my_catalog.names() == {"my_schema"} + + my_catalog.register_schema("second_schema", CustomSchemaProvider()) + assert my_catalog.schema_names() == {"my_schema", "second_schema"} + + my_catalog.deregister_schema("my_schema") + assert my_catalog.schema_names() == {"second_schema"} + + +def test_python_schema_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.deregister_schema("public") + + catalog.register_schema("test_schema1", CustomSchemaProvider()) + assert catalog.names() == {"test_schema1"} + + catalog.register_schema("test_schema2", CustomSchemaProvider()) + catalog.deregister_schema("test_schema1") + assert catalog.names() == {"test_schema2"} + + +def test_python_table_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.register_schema("custom_schema", CustomSchemaProvider()) + schema = catalog.schema("custom_schema") + + assert schema.table_names() == {"table1"} + + schema.deregister_table("table1") + schema.register_table("table2", create_dataset()) + assert schema.table_names() == {"table2"} + + # Use the default schema instead of our custom schema + + schema = catalog.schema() + + schema.register_table("table3", create_dataset()) + assert schema.table_names() == {"table3"} + + schema.deregister_table("table3") + schema.register_table("table4", create_dataset()) + assert schema.table_names() == {"table4"} diff --git a/src/catalog.rs b/src/catalog.rs index 83f8d08cb..ba96ce471 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -15,44 +15,50 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use crate::errors::{PyDataFusionError, PyDataFusionResult}; -use crate::utils::wait_for_future; +use crate::dataset::Dataset; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::utils::{validate_pycapsule, wait_for_future}; +use async_trait::async_trait; +use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; +use pyo3::IntoPyObjectExt; +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; -#[pyclass(name = "Catalog", module = "datafusion", subclass)] +#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub struct PyDatabase { - pub database: Arc, +#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +pub struct PySchema { + pub schema: Arc, } -#[pyclass(name = "Table", module = "datafusion", subclass)] +#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] pub struct PyTable { pub table: Arc, } -impl PyCatalog { - pub fn new(catalog: Arc) -> Self { +impl From> for PyCatalog { + fn from(catalog: Arc) -> Self { Self { catalog } } } -impl PyDatabase { - pub fn new(database: Arc) -> Self { - Self { database } +impl From> for PySchema { + fn from(schema: Arc) -> Self { + Self { schema } } } @@ -68,36 +74,93 @@ impl 
PyTable { #[pymethods] impl PyCatalog { - fn names(&self) -> Vec { - self.catalog.schema_names() + #[new] + fn new(catalog: PyObject) -> Self { + let catalog_provider = + Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc; + catalog_provider.into() + } + + fn schema_names(&self) -> HashSet { + self.catalog.schema_names().into_iter().collect() } #[pyo3(signature = (name="public"))] - fn database(&self, name: &str) -> PyResult { - match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), - None => Err(PyKeyError::new_err(format!( - "Database with name {name} doesn't exist." - ))), - } + fn schema(&self, name: &str) -> PyResult { + let schema = self + .catalog + .schema(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)), + None => PySchema::from(schema).into_py_any(py), + } + }) + } + + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { + let capsule = schema_provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPySchemaProvider::new(schema_provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self + .catalog + .register_schema(name, provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> { + let _ = self + .catalog + .deregister_schema(name, cascade) + .map_err(py_datafusion_err)?; + + Ok(()) } fn __repr__(&self) -> PyResult { - Ok(format!( - "Catalog(schema_names=[{}])", - self.names().join(";") - )) + let mut names: Vec = self.schema_names().into_iter().collect(); + names.sort(); + Ok(format!("Catalog(schema_names=[{}])", names.join(", "))) } } #[pymethods] -impl PyDatabase { - fn names(&self) -> HashSet { - self.database.table_names().into_iter().collect() +impl PySchema { + #[new] + fn new(schema_provider: PyObject) -> Self { + let schema_provider = + Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc; + schema_provider.into() + } + + #[getter] + fn table_names(&self) -> HashSet { + self.schema.table_names().into_iter().collect() } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { - if let Some(table) = wait_for_future(py, self.database.table(name))?? { + if let Some(table) = wait_for_future(py, self.schema.table(name))?? { Ok(PyTable::new(table)) } else { Err(PyDataFusionError::Common(format!( @@ -107,14 +170,44 @@ impl PyDatabase { } fn __repr__(&self) -> PyResult { - Ok(format!( - "Database(table_names=[{}])", - Vec::from_iter(self.names()).join(";") - )) + let mut names: Vec = self.table_names().into_iter().collect(); + names.sort(); + Ok(format!("Schema(table_names=[{}])", names.join(";"))) + } + + fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let capsule = table_provider + .getattr("__datafusion_table_provider__")? 
+ .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc + }; + + let _ = self + .schema + .register_table(name.to_string(), provider) + .map_err(py_datafusion_err)?; + + Ok(()) } - // register_table - // deregister_table + fn deregister_table(&self, name: &str) -> PyResult<()> { + let _ = self + .schema + .deregister_table(name) + .map_err(py_datafusion_err)?; + + Ok(()) + } } #[pymethods] @@ -145,3 +238,265 @@ impl PyTable { // fn has_exact_statistics // fn supports_filter_pushdown } + +#[derive(Debug)] +pub(crate) struct RustWrappedPySchemaProvider { + schema_provider: PyObject, + owner_name: Option, +} + +impl RustWrappedPySchemaProvider { + pub fn new(schema_provider: PyObject) -> Self { + let owner_name = Python::with_gil(|py| { + schema_provider + .bind(py) + .getattr("owner_name") + .ok() + .map(|name| name.to_string()) + }); + + Self { + schema_provider, + owner_name, + } + } + + fn table_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let py_table_method = provider.getattr("table")?; + + let py_table = py_table_method.call((name,), None)?; + if py_table.is_none() { + return Ok(None); + } + + if py_table.hasattr("__datafusion_table_provider__")? { + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + + Ok(Some(Arc::new(ds) as Arc)) + } + }) + } +} + +#[async_trait] +impl SchemaProvider for RustWrappedPySchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + + provider + .getattr("table_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get table_names: {err}"); + Vec::default() + }) + }) + } + + async fn table( + &self, + name: &str, + ) -> datafusion::common::Result>, DataFusionError> { + self.table_inner(name).map_err(to_datafusion_err) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion::common::Result>> { + let py_table = PyTable::new(table); + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let _ = provider + .call_method1("register_table", (name, py_table)) + .map_err(to_datafusion_err)?; + // Since the definition of `register_table` says that an error + // will be returned if the table already exists, there is no + // case where we want to return a table provider as output. 
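
For reference, the Python object wrapped by RustWrappedPySchemaProvider only needs to satisfy a small duck-typed surface: a table_names attribute, table(name), register_table(name, table), deregister_table(name) and table_exist(name), plus an optional owner_name. The sketch below is illustrative only, derived from how the wrapper calls into the object; it is not the CustomSchemaProvider used by the tests earlier in this patch, and the class name is invented.

    class MinimalSchemaProvider:
        """Sketch of the duck-typed surface the Rust wrapper calls."""

        def __init__(self):
            self._tables = {}
            self.owner_name = "my_catalog"  # optional; read once when wrapped

        @property
        def table_names(self):
            # Read as an attribute and extracted as a list of strings.
            return list(self._tables)

        def table(self, name):
            # Return None for a missing table, otherwise something the bindings
            # can consume: a PyArrow dataset/table or an object exposing
            # __datafusion_table_provider__.
            return self._tables.get(name)

        def register_table(self, name, table):
            # Raising here when the name already exists matches the
            # SchemaProvider contract described in the wrapper above.
            self._tables[name] = table

        def deregister_table(self, name):
            # Returning None is also acceptable; the wrapper handles both.
            return self._tables.pop(name, None)

        def table_exist(self, name):
            return name in self._tables
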
+ Ok(None) + }) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let table = provider + .call_method1("deregister_table", (name,)) + .map_err(to_datafusion_err)?; + if table.is_none() { + return Ok(None); + } + + // If we can turn this table provider into a `Dataset`, return it. + // Otherwise, return None. + let dataset = match Dataset::new(&table, py) { + Ok(dataset) => Some(Arc::new(dataset) as Arc), + Err(_) => None, + }; + + Ok(dataset) + }) + } + + fn table_exist(&self, name: &str) -> bool { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .call_method1("table_exist", (name,)) + .and_then(|pyobj| pyobj.extract()) + .unwrap_or(false) + }) + } +} + +#[derive(Debug)] +pub(crate) struct RustWrappedPyCatalogProvider { + pub(crate) catalog_provider: PyObject, +} + +impl RustWrappedPyCatalogProvider { + pub fn new(catalog_provider: PyObject) -> Self { + Self { catalog_provider } + } + + fn schema_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + + let py_schema = provider.call_method1("schema", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + if py_schema.hasattr("__datafusion_schema_provider__")? { + let capsule = provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + + Ok(Some(Arc::new(py_schema) as Arc)) + } + }) + } +} + +#[async_trait] +impl CatalogProvider for RustWrappedPyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + provider + .getattr("schema_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get schema_names: {err}"); + Vec::default() + }) + }) + } + + fn schema(&self, name: &str) -> Option> { + self.schema_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider schema returned error: {err}"); + None + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion::common::Result>> { + // JRIGHT HERE + // let py_schema: PySchema = schema.into(); + Python::with_gil(|py| { + let py_schema = match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(), + None => &PySchema::from(schema) + .into_py_any(py) + .map_err(to_datafusion_err)?, + }; + + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("register_schema", (name, py_schema)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("deregister_schema", (name, cascade)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = 
Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/src/context.rs b/src/context.rs index 6ce1f12bc..cb15c5f0b 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,7 +31,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use crate::catalog::{PyCatalog, PyTable}; +use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::CatalogProvider; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -69,8 +70,10 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; +use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext @@ -614,6 +617,31 @@ impl PySessionContext { Ok(()) } + pub fn register_catalog_provider( + &mut self, + name: &str, + provider: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + let provider = if provider.hasattr("__datafusion_catalog_provider__")? { + let capsule = provider + .getattr("__datafusion_catalog_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_catalog_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignCatalogProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPyCatalogProvider::new(provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self.ctx.register_catalog(name, provider); + + Ok(()) + } + /// Construct datafusion dataframe from Arrow Table pub fn register_table_provider( &mut self, @@ -845,14 +873,20 @@ impl PySessionContext { } #[pyo3(signature = (name="datafusion"))] - pub fn catalog(&self, name: &str) -> PyResult { - match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::new(catalog)), - None => Err(PyKeyError::new_err(format!( - "Catalog with name {} doesn't exist.", - &name, - ))), - } + pub fn catalog(&self, name: &str) -> PyResult { + let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( + "Catalog with name {name} doesn't exist." 
+ )))?; + + Python::with_gil(|py| { + match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)), + None => PyCatalog::from(catalog).into_py_any(py), + } + }) } pub fn tables(&self) -> HashSet { diff --git a/src/functions.rs b/src/functions.rs index b2bafcb65..b40500b8b 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; - m.add_wrapped(wrap_pyfunction!(log))?; + m.add_wrapped(wrap_pyfunction!(self::log))?; m.add_wrapped(wrap_pyfunction!(log10))?; m.add_wrapped(wrap_pyfunction!(log2))?; m.add_wrapped(wrap_pyfunction!(lower))?; diff --git a/src/lib.rs b/src/lib.rs index 1293eee3c..29d3f41da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,10 +77,10 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime); /// datafusion directory. #[pymodule] fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + // Initialize logging + pyo3_log::init(); + // Register the python classes - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -98,6 +98,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + let catalog = PyModule::new(py, "catalog")?; + catalog::init_module(&catalog)?; + m.add_submodule(&catalog)?; + // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/ let common = PyModule::new(py, "common")?; common::init_module(&common)?; From e75237fe615945d9a7867ab713cfd9214a3f3d8c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 15 Feb 2025 13:37:07 -0500 Subject: [PATCH 06/14] Intermediate work adding ffi scalar udf Add scalar UDF and example Add aggregate udf via ffi Initial commit for window ffi integration Remove patch --- docs/source/contributor-guide/ffi.rst | 2 +- examples/datafusion-ffi-example/Cargo.lock | 217 ++++++++++-------- examples/datafusion-ffi-example/Cargo.toml | 8 +- .../python/tests/_test_aggregate_udf.py | 77 +++++++ .../python/tests/_test_scalar_udf.py | 70 ++++++ .../python/tests/_test_window_udf.py | 89 +++++++ .../src/aggregate_udf.rs | 81 +++++++ examples/datafusion-ffi-example/src/lib.rs | 9 + .../datafusion-ffi-example/src/scalar_udf.rs | 91 ++++++++ .../datafusion-ffi-example/src/window_udf.rs | 81 +++++++ python/datafusion/user_defined.py | 107 ++++++++- src/functions.rs | 2 +- src/udaf.rs | 31 ++- src/udf.rs | 25 +- src/udwf.rs | 27 ++- 15 files changed, 805 insertions(+), 112 deletions(-) create mode 100644 examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py create mode 100644 examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py create mode 100644 examples/datafusion-ffi-example/python/tests/_test_window_udf.py create mode 100644 examples/datafusion-ffi-example/src/aggregate_udf.rs create mode 100644 examples/datafusion-ffi-example/src/scalar_udf.rs create mode 100644 examples/datafusion-ffi-example/src/window_udf.rs diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index c1f9806b3..a40af1234 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -176,7 +176,7 @@ By convention the ``datafusion-python`` library expects a Python object that has ``TableProvider`` PyCapsule to have this 
capsule accessible by calling a function named ``__datafusion_table_provider__``. You can see a complete working example of how to share a ``TableProvider`` from one python library to DataFusion Python in the -`repository examples folder `_. +`repository examples folder `_. This section has been written using ``TableProvider`` as an example. It is the first extension that has been written using this approach and the most thoroughly implemented. diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index e5a1ca8d1..1b4ca6bee 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -323,6 +323,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] @@ -748,9 +750,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080" +checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96" dependencies = [ "arrow", "arrow-ipc", @@ -775,7 +777,6 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", - "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -790,7 +791,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "regex", "sqlparser", "tempfile", @@ -803,9 +804,9 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66" +checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee" dependencies = [ "arrow", "async-trait", @@ -829,9 +830,9 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00" +checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038" dependencies = [ "arrow", "async-trait", @@ -852,9 +853,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c" +checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12" dependencies = [ "ahash", "arrow", @@ -876,9 +877,9 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b" +checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2" dependencies = [ "futures", "log", @@ -887,9 +888,9 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1" +checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b" dependencies = [ 
"arrow", "async-compression", @@ -912,7 +913,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -923,9 +924,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6" +checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc" dependencies = [ "arrow", "async-trait", @@ -948,9 +949,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352" +checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c" dependencies = [ "arrow", "async-trait", @@ -973,9 +974,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923" +checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b" dependencies = [ "arrow", "async-trait", @@ -998,21 +999,21 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292" +checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea" [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84" +checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a" dependencies = [ "arrow", "dashmap", @@ -1022,16 +1023,16 @@ dependencies = [ "log", "object_store", "parking_lot", - "rand", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964" +checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b" dependencies = [ "arrow", "chrono", @@ -1050,9 +1051,9 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839" +checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a" dependencies = [ "arrow", "datafusion-common", @@ -1063,9 +1064,9 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c" +checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef" dependencies = [ "abi_stable", "arrow", @@ -1073,7 +1074,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "futures", "log", "prost", @@ -1081,11 +1084,25 @@ 
dependencies = [ "tokio", ] +[[package]] +name = "datafusion-ffi-example" +version = "0.2.0" +dependencies = [ + "arrow", + "arrow-array", + "arrow-schema", + "async-trait", + "datafusion", + "datafusion-ffi", + "pyo3", + "pyo3-build-config", +] + [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e" +checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357" dependencies = [ "arrow", "arrow-buffer", @@ -1103,7 +1120,7 @@ dependencies = [ "itertools", "log", "md-5", - "rand", + "rand 0.9.1", "regex", "sha2", "unicode-segmentation", @@ -1112,9 +1129,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba" +checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4" dependencies = [ "ahash", "arrow", @@ -1133,9 +1150,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4" +checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696" dependencies = [ "ahash", "arrow", @@ -1146,9 +1163,9 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537" +checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b" dependencies = [ "arrow", "arrow-ord", @@ -1167,9 +1184,9 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a" +checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b" dependencies = [ "arrow", "async-trait", @@ -1183,10 +1200,11 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193" +checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4" dependencies = [ + "arrow", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -1200,9 +1218,9 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313" +checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1210,9 +1228,9 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1" +checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8" dependencies = [ "datafusion-expr", "quote", @@ 
-1221,9 +1239,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515" +checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058" dependencies = [ "arrow", "chrono", @@ -1240,9 +1258,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65" +checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6" dependencies = [ "ahash", "arrow", @@ -1262,9 +1280,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863" +checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df" dependencies = [ "ahash", "arrow", @@ -1276,9 +1294,9 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2" +checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae" dependencies = [ "arrow", "datafusion-common", @@ -1295,9 +1313,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55" +checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23" dependencies = [ "ahash", "arrow", @@ -1325,9 +1343,9 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83" +checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6" dependencies = [ "arrow", "chrono", @@ -1341,9 +1359,9 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f" +checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e" dependencies = [ "arrow", "datafusion-common", @@ -1352,9 +1370,9 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f" +checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49" dependencies = [ "arrow", "async-trait", @@ -1376,9 +1394,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607" +checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df" dependencies = [ "arrow", "bigdecimal", @@ -1441,20 +1459,6 @@ version = "2.3.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "ffi-table-provider" -version = "0.1.0" -dependencies = [ - "arrow", - "arrow-array", - "arrow-schema", - "async-trait", - "datafusion", - "datafusion-ffi", - "pyo3", - "pyo3-build-config", -] - [[package]] name = "fixedbitset" version = "0.5.7" @@ -1488,6 +1492,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1666,6 +1676,11 @@ name = "hashbrown" version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -2271,12 +2286,14 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset", + "hashbrown 0.15.3", "indexmap", + "serde", ] [[package]] @@ -2305,7 +2322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -2484,19 +2501,27 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.3", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -2504,8 +2529,14 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.3.3", ] [[package]] @@ -3032,9 +3063,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = 
"3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 319163554..b26ab48e3 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -16,13 +16,13 @@ # under the License. [package] -name = "ffi-table-provider" -version = "0.1.0" +name = "datafusion-ffi-example" +version = "0.2.0" edition = "2021" [dependencies] -datafusion = { version = "47.0.0" } -datafusion-ffi = { version = "47.0.0" } +datafusion = { version = "48.0.0" } +datafusion-ffi = { version = "48.0.0" } pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } diff --git a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py new file mode 100644 index 000000000..7ea6b295c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udaf +from datafusion_ffi_example import MySumUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3, None], type=pa.int64()), + pa.array([1, 1, 2, 2], type=pa.int64()), + ], + names=["a", "b"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_aggregate_register(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + ctx.register_udaf(my_udaf) + + result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect() + + assert len(result) == 2 + assert result[0].num_columns == 1 + + result = [r.column(0) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected + + +def test_ffi_aggregate_call_directly(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + + result = ( + ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect() + ) + + assert len(result) == 2 + assert result[0].num_columns == 2 + + result = [r.column(1) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py new file mode 100644 index 000000000..0c949c34a --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udf +from datafusion_ffi_example import IsNullUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, None])], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_scalar_register(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + ctx.register_udf(my_udf) + + result = ctx.sql("select my_custom_is_null(a) from test_table").collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected + + +def test_ffi_scalar_call_directly(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + + result = ctx.table("test_table").select(my_udf(col("a"))).collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_window_udf.py b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py new file mode 100644 index 000000000..7d96994b9 --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udwf +from datafusion_ffi_example import MyRankUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([40, 10, 30, 20], type=pa.int64()), + ], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_window_register(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + ctx.register_udwf(my_udwf) + + result = ctx.sql( + "select a, my_custom_rank() over (order by a) from test_table" + ).collect() + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected + + +def test_ffi_window_call_directly(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + + result = ( + ctx.table("test_table") + .select(col("a"), my_udwf().order_by(col("a")).build()) + .collect() + ) + + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected diff --git a/examples/datafusion-ffi-example/src/aggregate_udf.rs b/examples/datafusion-ffi-example/src/aggregate_udf.rs new file mode 100644 index 000000000..9481fe9c6 --- /dev/null +++ b/examples/datafusion-ffi-example/src/aggregate_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::DataType; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_aggregate::sum::Sum; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; +use datafusion_ffi::udaf::FFI_AggregateUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MySumUDF { + inner: Arc, +} + +#[pymethods] +impl MySumUDF { + #[new] + fn new() -> Self { + Self { + inner: Arc::new(Sum::new()), + } + } + + fn __datafusion_aggregate_udf__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_aggregate_udf".into(); + + let func = Arc::new(AggregateUDF::from(self.clone())); + let provider = FFI_AggregateUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl AggregateUDFImpl for MySumUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_sum" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + self.inner.return_type(arg_types) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> DataFusionResult> { + self.inner.accumulator(acc_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 3a4cf2247..f5f96cd49 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,19 +15,28 @@ // specific language governing permissions and limitations // under the License. +use crate::aggregate_udf::MySumUDF; use crate::catalog_provider::MyCatalogProvider; +use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; +use crate::window_udf::MyRankUDF; use pyo3::prelude::*; +pub(crate) mod aggregate_udf; pub(crate) mod catalog_provider; +pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; +pub(crate) mod window_udf; #[pymodule] fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/scalar_udf.rs b/examples/datafusion-ffi-example/src/scalar_udf.rs new file mode 100644 index 000000000..727666638 --- /dev/null +++ b/examples/datafusion-ffi-example/src/scalar_udf.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{Array, BooleanArray}; +use arrow_schema::DataType; +use datafusion::common::ScalarValue; +use datafusion::error::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_ffi::udf::FFI_ScalarUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "IsNullUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct IsNullUDF { + signature: Signature, +} + +#[pymethods] +impl IsNullUDF { + #[new] + fn new() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } + } + + fn __datafusion_scalar_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_scalar_udf".into(); + + let func = Arc::new(ScalarUDF::from(self.clone())); + let provider = FFI_ScalarUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl ScalarUDFImpl for IsNullUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_is_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let input = &args.args[0]; + + Ok(match input { + ColumnarValue::Array(arr) => match arr.is_nullable() { + true => { + let nulls = arr.nulls().unwrap(); + let nulls = BooleanArray::from_iter(nulls.iter().map(|x| Some(!x))); + ColumnarValue::Array(Arc::new(nulls)) + } + false => ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))), + }, + ColumnarValue::Scalar(sv) => { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(sv == &ScalarValue::Null))) + } + }) + } +} diff --git a/examples/datafusion-ffi-example/src/window_udf.rs b/examples/datafusion-ffi-example/src/window_udf.rs new file mode 100644 index 000000000..e0d397956 --- /dev/null +++ b/examples/datafusion-ffi-example/src/window_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
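
The scalar, aggregate and window examples in this patch all follow the same shape: implement (or wrap) the corresponding DataFusion UDF trait, convert it into the matching datafusion_ffi FFI_* struct, and hand that struct to Python inside a PyCapsule whose name ("datafusion_scalar_udf", "datafusion_aggregate_udf" or "datafusion_window_udf") is what the bindings later check with validate_pycapsule. From Python the capsule itself is opaque; a small illustrative check, assuming the example module has been built:

    from datafusion_ffi_example import IsNullUDF

    capsule = IsNullUDF().__datafusion_scalar_udf__()
    # The capsule is an opaque handle; only its type is visible from Python.
    assert type(capsule).__name__ == "PyCapsule"
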
+ +use arrow_schema::{DataType, FieldRef}; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_window::rank::rank_udwf; +use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; +use datafusion::logical_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl}; +use datafusion_ffi::udwf::FFI_WindowUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MyRankUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MyRankUDF { + inner: Arc, +} + +#[pymethods] +impl MyRankUDF { + #[new] + fn new() -> Self { + Self { inner: rank_udwf() } + } + + fn __datafusion_window_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_window_udf".into(); + + let func = Arc::new(WindowUDF::from(self.clone())); + let provider = FFI_WindowUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl WindowUDFImpl for MyRankUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_rank" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> DataFusionResult> { + self.inner + .inner() + .partition_evaluator(partition_evaluator_args) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> DataFusionResult { + self.inner.inner().field(field_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index dd634c7fb..bd686acbb 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload import pyarrow as pa @@ -77,6 +77,12 @@ def __str__(self) -> str: return self.name.lower() +class ScalarUDFExportable(Protocol): + """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" + + def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). @@ -96,6 +102,9 @@ def __init__( See helper method :py:func:`udf` for argument details. """ + if hasattr(func, "__datafusion_scalar_udf__"): + self._udf = df_internal.ScalarUDF.from_pycapsule(func) + return if isinstance(input_types, pa.DataType): input_types = [input_types] self._udf = df_internal.ScalarUDF( @@ -134,6 +143,10 @@ def udf( name: Optional[str] = None, ) -> ScalarUDF: ... + @overload + @staticmethod + def udf(func: ScalarUDFExportable) -> ScalarUDF: ... + @staticmethod def udf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Function (UDF). @@ -147,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 Args: func (Callable, optional): Only needed when calling as a function. - Skip this argument when using ``udf`` as a decorator. + Skip this argument when using `udf` as a decorator. If you have a Rust + backed ScalarUDF within a PyCapsule, you can pass this parameter + and ignore the rest. They will be determined directly from the + underlying function. See the online documentation for more information. 
input_types (list[pa.DataType]): The data types of the arguments to ``func``. This list must be of the same length as the number of arguments. @@ -215,12 +231,31 @@ def wrapper(*args: Any, **kwargs: Any): return decorator + if hasattr(args[0], "__datafusion_scalar_udf__"): + return ScalarUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF: + """Create a Scalar UDF from ScalarUDF PyCapsule object. + + This function will instantiate a Scalar UDF that uses a DataFusion + ScalarUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return ScalarUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @@ -242,6 +277,12 @@ def evaluate(self) -> pa.Scalar: """Return the resultant value.""" +class AggregateUDFExportable(Protocol): + """Type hint for object that has __datafusion_aggregate_udf__ PyCapsule.""" + + def __datafusion_aggregate_udf__(self) -> object: ... # noqa: D105 + + class AggregateUDF: """Class for performing scalar user-defined functions (UDF). @@ -263,6 +304,9 @@ def __init__( See :py:func:`udaf` for a convenience function and argument descriptions. """ + if hasattr(accumulator, "__datafusion_aggregate_udf__"): + self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator) + return self._udaf = df_internal.AggregateUDF( name, accumulator, @@ -307,7 +351,7 @@ def udaf( ) -> AggregateUDF: ... @staticmethod - def udaf(*args: Any, **kwargs: Any): # noqa: D417 + def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Aggregate Function (UDAF). This class allows you to define an aggregate function that can be used in @@ -364,6 +408,10 @@ def udf4() -> Summarize: Args: accum: The accumulator python function. Only needed when calling as a function. Skip this argument when using ``udaf`` as a decorator. + If you have a Rust backed AggregateUDF within a PyCapsule, you can + pass this parameter and ignore the rest. They will be determined + directly from the underlying function. See the online documentation + for more information. input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. @@ -422,12 +470,32 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + if hasattr(args[0], "__datafusion_aggregate_udf__"): + return AggregateUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF: + """Create an Aggregate UDF from AggregateUDF PyCapsule object. + + This function will instantiate a Aggregate UDF that uses a DataFusion + AggregateUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return AggregateUDF( + name=name, + accumulator=func, + input_types=None, + return_type=None, + state_type=None, + volatility=None, + ) + class WindowEvaluator: """Evaluator class for user-defined window functions (UDWF). 
@@ -588,6 +656,12 @@ def include_rank(self) -> bool: return False +class WindowUDFExportable(Protocol): + """Type hint for object that has __datafusion_window_udf__ PyCapsule.""" + + def __datafusion_window_udf__(self) -> object: ... # noqa: D105 + + class WindowUDF: """Class for performing window user-defined functions (UDF). @@ -608,6 +682,9 @@ def __init__( See :py:func:`udwf` for a convenience function and argument descriptions. """ + if hasattr(func, "__datafusion_window_udf__"): + self._udwf = df_internal.WindowUDF.from_pycapsule(func) + return self._udwf = df_internal.WindowUDF( name, func, input_types, return_type, str(volatility) ) @@ -683,7 +760,10 @@ def biased_numbers() -> BiasedNumbers: Args: func: Only needed when calling as a function. Skip this argument when - using ``udwf`` as a decorator. + using ``udwf`` as a decorator. If you have a Rust backed WindowUDF + within a PyCapsule, you can pass this parameter and ignore the rest. + They will be determined directly from the underlying function. See + the online documentation for more information. input_types: The data types of the arguments. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. @@ -692,6 +772,9 @@ def biased_numbers() -> BiasedNumbers: Returns: A user-defined window function that can be used in window function calls. """ + if hasattr(args[0], "__datafusion_window_udf__"): + return WindowUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return WindowUDF._create_window_udf(*args, **kwargs) @@ -759,6 +842,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + @staticmethod + def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: + """Create a Window UDF from WindowUDF PyCapsule object. + + This function will instantiate a Window UDF that uses a DataFusion + WindowUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return WindowUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class TableFunction: """Class for performing user-defined table functions (UDTF). 
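
With the user_defined.py changes above, any object exposing the matching __datafusion_*_udf__ capsule can be passed directly to udf, udaf or udwf (or to the explicit from_pycapsule constructors); the function name, signature and volatility come from the FFI definition rather than from Python arguments. A short usage sketch built on the example module added in this patch:

    from datafusion import SessionContext, udf, udaf, udwf
    from datafusion_ffi_example import IsNullUDF, MySumUDF, MyRankUDF

    ctx = SessionContext()

    # Each factory detects the __datafusion_*_udf__ capsule and builds the UDF
    # from the FFI definition, so no input_types/return_type/volatility are needed.
    ctx.register_udf(udf(IsNullUDF()))
    ctx.register_udaf(udaf(MySumUDF()))
    ctx.register_udwf(udwf(MyRankUDF()))

    # The functions are then available from SQL under their Rust-side names
    # (my_custom_is_null, my_custom_sum, my_custom_rank), as exercised by the
    # example tests above.
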
diff --git a/src/functions.rs b/src/functions.rs index b40500b8b..eeef48385 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -682,7 +682,7 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } -// We handle first_value explicitly because the signature expects an order_by +// We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] #[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))] diff --git a/src/udaf.rs b/src/udaf.rs index 34a9cd51d..78f4e2b0c 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -19,6 +19,10 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; +use crate::common::data_type::PyScalarValue; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; +use crate::expr::PyExpr; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::array::{Array, ArrayRef}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; @@ -27,11 +31,8 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF, }; - -use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; -use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF}; +use pyo3::types::PyCapsule; #[derive(Debug)] struct RustAccumulator { @@ -183,6 +184,26 @@ impl PyAggregateUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_aggregate_udf__")? { + let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_aggregate_udf")?; + + let udaf = unsafe { capsule.reference::() }; + let udaf: ForeignAggregateUDF = udaf.try_into()?; + + Ok(Self { + function: udaf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udf.rs b/src/udf.rs index 574c9d7b5..de1e3f18c 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -17,6 +17,8 @@ use std::sync::Arc; +use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF}; +use pyo3::types::PyCapsule; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; @@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF; use datafusion::logical_expr::{create_udf, ColumnarValue}; use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; /// Create a Rust callable function from a python function that expects pyarrow arrays fn pyarrow_function_to_rust( @@ -105,6 +108,26 @@ impl PyScalarUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_scalar_udf__")? 
{ + let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_scalar_udf")?; + + let udf = unsafe { capsule.reference::() }; + let udf: ForeignScalarUDF = udf.try_into()?; + + Ok(Self { + function: udf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udwf.rs b/src/udwf.rs index a0c8cc59a..4fb98916b 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -27,16 +27,17 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, }; -use pyo3::types::{PyList, PyTuple}; +use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; #[derive(Debug)] struct RustPartitionEvaluator { @@ -245,6 +246,26 @@ impl PyWindowUDF { Ok(self.function.call(args).into()) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_window_udf__")? 
{ + let capsule = func.getattr("__datafusion_window_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_window_udf")?; + + let udwf = unsafe { capsule.reference::() }; + let udwf: ForeignWindowUDF = udwf.try_into()?; + + Ok(Self { + function: udwf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_window_udf__ does not exist on WindowUDF object.".to_string(), + )) + } + } + fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } From bfd7a2de096a9854a89ad89dca9a7eefeba246b0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Jun 2025 15:27:57 -0400 Subject: [PATCH 07/14] Add default in memory options for adding schema and catalogs --- python/datafusion/catalog.py | 5 +++++ python/datafusion/context.py | 9 +++++++++ src/catalog.rs | 11 +++++++++++ src/context.rs | 13 ++++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index bebd38161..5f1a317f6 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -73,6 +73,11 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) + def new_in_memory_schema(self, name: str) -> Schema: + """Create a new schema in this catalog using an in-memory provider.""" + self.catalog.new_in_memory_schema(name) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: """Register a schema with this catalog.""" return self.catalog.register_schema(name, schema) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index c080931e9..f752272bb 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -758,6 +758,15 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def catalog_names(self) -> set[str]: + """Returns the list of catalogs in this context.""" + return self.ctx.catalog_names() + + def new_in_memory_catalog(self, name: str) -> Catalog: + """Create a new catalog in this context using an in-memory provider.""" + self.ctx.new_in_memory_catalog(name) + return self.catalog(name) + def register_catalog_provider( self, name: str, provider: CatalogProviderExportable ) -> None: diff --git a/src/catalog.rs b/src/catalog.rs index ba96ce471..9a24f2d44 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -19,6 +19,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{validate_pycapsule, wait_for_future}; use async_trait::async_trait; +use datafusion::catalog::MemorySchemaProvider; use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, @@ -105,6 +106,16 @@ impl PyCatalog { }) } + fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> { + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let _ = self + .catalog + .register_schema(name, schema) + .map_err(py_datafusion_err)?; + + Ok(()) + } + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? 
{ let capsule = schema_provider diff --git a/src/context.rs b/src/context.rs index cb15c5f0b..c97f2f618 100644 --- a/src/context.rs +++ b/src/context.rs @@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::CatalogProvider; +use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -617,6 +617,13 @@ impl PySessionContext { Ok(()) } + pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> { + let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc; + let _ = self.ctx.register_catalog(name, catalog); + + Ok(()) + } + pub fn register_catalog_provider( &mut self, name: &str, @@ -889,6 +896,10 @@ impl PySessionContext { }) } + pub fn catalog_names(&self) -> HashSet { + self.ctx.catalog_names().into_iter().collect() + } + pub fn tables(&self) -> HashSet { self.ctx .catalog_names() From 226077777e92f1d7e48d70b04220422a1b205a43 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 23 Jun 2025 10:35:30 -0400 Subject: [PATCH 08/14] Only collect one time during display() in jupyter notebooks --- python/datafusion/dataframe.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 1fd63bdc6..20dc4f4bc 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -318,6 +318,16 @@ def __repr__(self) -> str: Returns: String representation of the DataFrame. """ + # Check if we're in IPython/Jupyter. If so, we will only use + # the _repr_html_ output to avoid calling collect() twice. + try: + from IPython import get_ipython + + if get_ipython() is not None: + return "" # Return empty string to effectively disable + except ImportError: + pass + return self.df.__repr__() def _repr_html_(self) -> str: From d19d64965623efb456b367c93f9bef07c87f03e1 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 23 Jun 2025 20:36:50 -0400 Subject: [PATCH 09/14] Check for juypter notebook environment specifically --- python/datafusion/dataframe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 20dc4f4bc..3ebfcfbcf 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -318,14 +318,15 @@ def __repr__(self) -> str: Returns: String representation of the DataFrame. """ - # Check if we're in IPython/Jupyter. If so, we will only use + # Check if we're in a Jupyter notebook. If so, we will only use # the _repr_html_ output to avoid calling collect() twice. 
try: from IPython import get_ipython - if get_ipython() is not None: + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": return "" # Return empty string to effectively disable - except ImportError: + except (ImportError, NameError): pass return self.df.__repr__() From e48322a9526e3fce6ca436ae22fe99f0e772415e Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 24 Jun 2025 09:57:11 -0400 Subject: [PATCH 10/14] Remove approach of checking environment which could not differentiate between jupyter console and notebook --- python/datafusion/dataframe.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 3ebfcfbcf..1fd63bdc6 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -318,17 +318,6 @@ def __repr__(self) -> str: Returns: String representation of the DataFrame. """ - # Check if we're in a Jupyter notebook. If so, we will only use - # the _repr_html_ output to avoid calling collect() twice. - try: - from IPython import get_ipython - - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return "" # Return empty string to effectively disable - except (ImportError, NameError): - pass - return self.df.__repr__() def _repr_html_(self) -> str: From 0e2545053a691b6200478000421e7e4b0d3c1425 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 24 Jun 2025 11:19:06 -0400 Subject: [PATCH 11/14] Instead of trying to detect notebook vs console, collect one time when we have any kind if ipython environment. --- src/dataframe.rs | 53 ++++++++++++++++++++++++++++++++++++------------ src/utils.rs | 11 ++++++++++ 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 3d68db279..c117fab1d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -50,7 +50,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, }; use crate::{ errors::PyDataFusionResult, @@ -288,12 +288,18 @@ impl PyParquetColumnOptions { #[derive(Clone)] pub struct PyDataFrame { df: Arc, + + // In IPython environment cache batches between __repr__ and _repr_html_ calls. 
+ batches: Option<(Vec, bool)>, } impl PyDataFrame { /// creates a new PyDataFrame pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } + Self { + df: Arc::new(df), + batches: None, + } } } @@ -320,16 +326,22 @@ impl PyDataFrame { } } - fn __repr__(&self, py: Python) -> PyDataFusionResult { + fn __repr__(&mut self, py: Python) -> PyDataFusionResult { // Get the Python formatter config let PythonFormatter { formatter: _, config, } = get_python_formatter_with_config(py)?; - let (batches, has_more) = wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??; + + let should_cache = *is_ipython_env(py) && self.batches.is_none(); + let (batches, has_more) = match self.batches.take() { + Some(b) => b, + None => wait_for_future( + py, + collect_record_batches_to_display(self.df.as_ref().clone(), config), + )??, + }; + if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -343,16 +355,27 @@ impl PyDataFrame { false => "", }; + if should_cache { + self.batches = Some((batches, has_more)); + } + Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } - fn _repr_html_(&self, py: Python) -> PyDataFusionResult { + fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; - let (batches, has_more) = wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??; + + let should_cache = *is_ipython_env(py) && self.batches.is_none(); + + let (batches, has_more) = match self.batches.take() { + Some(b) => b, + None => wait_for_future( + py, + collect_record_batches_to_display(self.df.as_ref().clone(), config), + )??, + }; + if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -362,7 +385,7 @@ impl PyDataFrame { // Convert record batches to PyObject list let py_batches = batches - .into_iter() + .iter() .map(|rb| rb.to_pyarrow(py)) .collect::>>()?; @@ -378,6 +401,10 @@ impl PyDataFrame { let html_result = formatter.call_method("format_html", (), Some(&kwargs))?; let html_str: String = html_result.extract()?; + if should_cache { + self.batches = Some((batches, has_more)); + } + Ok(html_str) } diff --git a/src/utils.rs b/src/utils.rs index 90d654385..f4e121fd5 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) } +#[inline] +pub(crate) fn is_ipython_env(py: Python) -> &'static bool { + static IS_IPYTHON_ENV: OnceLock = OnceLock::new(); + IS_IPYTHON_ENV.get_or_init(|| { + py.import("IPython") + .and_then(|ipython| ipython.call_method0("get_ipython")) + .map(|ipython| !ipython.is_none()) + .unwrap_or(false) + }) +} + /// Utility to get the Global Datafussion CTX #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext { From 0f48cb50993538aa702302eebaa02471139b3f1b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 24 Jun 2025 15:02:48 -0400 Subject: [PATCH 12/14] Add string formatter --- python/datafusion/html_formatter.py | 28 ++++++ src/dataframe.rs | 134 +++++++++++++++------------- 2 files changed, 99 insertions(+), 63 deletions(-) diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py 
index 12a7e4553..e26537dbf 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/html_formatter.py @@ -26,6 +26,8 @@ runtime_checkable, ) +from datafusion._internal import DataFrame as DataFrameInternal + def _validate_positive_int(value: Any, param_name: str) -> None: """Validate that a parameter is a positive integer. @@ -345,6 +347,32 @@ def format_html( return "\n".join(html) + def format_str( + self, + batches: list, + schema: Any, + has_more: bool = False, + table_uuid: str | None = None, + ) -> str: + """Format record batches as a string. + + This method is used by DataFrame's __repr__ implementation and can be + called directly when string rendering is needed. + + Args: + batches: List of Arrow RecordBatch objects + schema: Arrow Schema object + has_more: Whether there are more batches not shown + table_uuid: Unique ID for the table, used for JavaScript interactions + + Returns: + String representation of the data + + Raises: + TypeError: If schema is invalid and no batches are provided + """ + return DataFrameInternal.default_str_repr(batches, schema, has_more, table_uuid) + def _build_html_header(self) -> list[str]: """Build the HTML header with CSS styles.""" html = [] diff --git a/src/dataframe.rs b/src/dataframe.rs index c117fab1d..69a6ec248 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -24,6 +24,7 @@ use arrow::compute::can_cast_types; use arrow::error::ArrowError; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::pyarrow::FromPyArrow; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; @@ -301,68 +302,8 @@ impl PyDataFrame { batches: None, } } -} - -#[pymethods] -impl PyDataFrame { - /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]` - fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult { - if let Ok(key) = key.extract::() { - // df[col] - self.select_columns(vec![key]) - } else if let Ok(tuple) = key.downcast::() { - // df[col1, col2, col3] - let keys = tuple - .iter() - .map(|item| item.extract::()) - .collect::>>()?; - self.select_columns(keys) - } else if let Ok(keys) = key.extract::>() { - // df[[col1, col2, col3]] - self.select_columns(keys) - } else { - let message = "DataFrame can only be indexed by string index or indices".to_string(); - Err(PyDataFusionError::Common(message)) - } - } - - fn __repr__(&mut self, py: Python) -> PyDataFusionResult { - // Get the Python formatter config - let PythonFormatter { - formatter: _, - config, - } = get_python_formatter_with_config(py)?; - - let should_cache = *is_ipython_env(py) && self.batches.is_none(); - let (batches, has_more) = match self.batches.take() { - Some(b) => b, - None => wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??, - }; - - if batches.is_empty() { - // This should not be reached, but do it for safety since we index into the vector below - return Ok("No data to display".to_string()); - } - - let batches_as_displ = - pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; - - let additional_str = match has_more { - true => "\nData truncated.", - false => "", - }; - - if should_cache { - self.batches = Some((batches, has_more)); - } - - Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) - } - fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { + fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> 
PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; @@ -398,15 +339,82 @@ impl PyDataFrame { kwargs.set_item("has_more", has_more)?; kwargs.set_item("table_uuid", table_uuid)?; - let html_result = formatter.call_method("format_html", (), Some(&kwargs))?; - let html_str: String = html_result.extract()?; + let method_name = match as_html { + true => "format_html", + false => "format_str", + }; + let html_result = formatter.call_method(method_name, (), Some(&kwargs))?; + let html_str: String = html_result.extract()?; if should_cache { self.batches = Some((batches, has_more)); } Ok(html_str) } +} + +#[pymethods] +impl PyDataFrame { + /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]` + fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult { + if let Ok(key) = key.extract::() { + // df[col] + self.select_columns(vec![key]) + } else if let Ok(tuple) = key.downcast::() { + // df[col1, col2, col3] + let keys = tuple + .iter() + .map(|item| item.extract::()) + .collect::>>()?; + self.select_columns(keys) + } else if let Ok(keys) = key.extract::>() { + // df[[col1, col2, col3]] + self.select_columns(keys) + } else { + let message = "DataFrame can only be indexed by string index or indices".to_string(); + Err(PyDataFusionError::Common(message)) + } + } + + fn __repr__(&mut self, py: Python) -> PyDataFusionResult { + self.prepare_repr_string(py, false) + } + + fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { + self.prepare_repr_string(py, true) + } + + #[staticmethod] + #[expect(unused_variables)] + fn default_str_repr<'py>( + batches: Vec>, + schema: &Bound<'py, PyAny>, + has_more: bool, + table_uuid: &str, + ) -> PyResult { + let batches = batches + .into_iter() + .map(|batch| RecordBatch::from_pyarrow_bound(&batch)) + .collect::>>()? + .into_iter() + .filter(|batch| batch.num_rows() > 0) + .collect::>(); + + if batches.is_empty() { + return Ok("No data to display".to_owned()); + } + + let batches_as_displ = + pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?; + + let additional_str = match has_more { + true => "\nData truncated.", + false => "", + }; + + Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) + } /// Calculate summary statistics for a DataFrame fn describe(&self, py: Python) -> PyDataFusionResult { From 4d8da3d5bcd21c17fa76c428418fdc9cfbe0e0aa Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 24 Jun 2025 15:12:52 -0400 Subject: [PATCH 13/14] Rename html_formatter to dataframe_formatter --- docs/source/api/dataframe.rst | 12 ++++++------ python/datafusion/__init__.py | 2 +- python/datafusion/dataframe.py | 15 ++++++++++++++- .../{html_formatter.py => dataframe_formatter.py} | 2 +- python/tests/test_dataframe.py | 4 ++-- src/dataframe.rs | 4 ++-- 6 files changed, 26 insertions(+), 13 deletions(-) rename python/datafusion/{html_formatter.py => dataframe_formatter.py} (99%) diff --git a/docs/source/api/dataframe.rst b/docs/source/api/dataframe.rst index a9e9e47c8..0efa2c6ed 100644 --- a/docs/source/api/dataframe.rst +++ b/docs/source/api/dataframe.rst @@ -174,7 +174,7 @@ HTML Rendering Customization ---------------------------- DataFusion provides extensive customization options for HTML table rendering through the -``datafusion.html_formatter`` module. +``datafusion.dataframe_formatter`` module. 
Configuring the HTML Formatter ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -183,7 +183,7 @@ You can customize how DataFrames are rendered by configuring the formatter: .. code-block:: python - from datafusion.html_formatter import configure_formatter + from datafusion.dataframe_formatter import configure_formatter configure_formatter( max_cell_length=30, # Maximum length of cell content before truncation @@ -206,7 +206,7 @@ For advanced styling needs, you can create a custom style provider class: .. code-block:: python - from datafusion.html_formatter import configure_formatter + from datafusion.dataframe_formatter import configure_formatter class CustomStyleProvider: def get_cell_style(self) -> str: @@ -225,7 +225,7 @@ You can register custom formatters for specific data types: .. code-block:: python - from datafusion.html_formatter import get_formatter + from datafusion.dataframe_formatter import get_formatter formatter = get_formatter() @@ -285,7 +285,7 @@ The HTML formatter maintains global state that can be managed: .. code-block:: python - from datafusion.html_formatter import reset_formatter, reset_styles_loaded_state, get_formatter + from datafusion.dataframe_formatter import reset_formatter, reset_styles_loaded_state, get_formatter # Reset the formatter to default settings reset_formatter() @@ -303,7 +303,7 @@ This example shows how to create a dashboard-like styling for your DataFrames: .. code-block:: python - from datafusion.html_formatter import configure_formatter, get_formatter + from datafusion.dataframe_formatter import configure_formatter, get_formatter # Define custom CSS custom_css = """ diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 8e38741bc..77fed2a94 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -47,11 +47,11 @@ SQLOptions, ) from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions +from .dataframe_formatter import configure_formatter from .expr import ( Expr, WindowFrame, ) -from .html_formatter import configure_formatter from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 1fd63bdc6..c747c24d5 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -52,7 +52,6 @@ import polars as pl import pyarrow as pa - from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import expr as expr_internal from enum import Enum @@ -1112,3 +1111,17 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame: - For columns not in subset, the original column is kept unchanged """ return DataFrame(self.df.fill_null(value, subset)) + + @staticmethod + def default_str_repr( + batches: list[pa.RecordBatch], + schema: pa.Schema, + has_more: bool, + table_uuid: str | None = None, + ) -> str: + """Return the default string representation of a DataFrame. + + This method is used by the default formatter and implemented in Rust for + performance reasons. 
+ """ + return DataFrameInternal.default_str_repr(batches, schema, has_more, table_uuid) diff --git a/python/datafusion/html_formatter.py b/python/datafusion/dataframe_formatter.py similarity index 99% rename from python/datafusion/html_formatter.py rename to python/datafusion/dataframe_formatter.py index e26537dbf..27f00f9c3 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/dataframe_formatter.py @@ -271,7 +271,7 @@ def is_styles_loaded(cls) -> bool: True if styles have been loaded, False otherwise Example: - >>> from datafusion.html_formatter import DataFrameHtmlFormatter + >>> from datafusion.dataframe_formatter import DataFrameHtmlFormatter >>> DataFrameHtmlFormatter.is_styles_loaded() False """ diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 3c9b97f23..3b816bc85 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -37,14 +37,14 @@ from datafusion import ( functions as f, ) -from datafusion.expr import Window -from datafusion.html_formatter import ( +from datafusion.dataframe_formatter import ( DataFrameHtmlFormatter, configure_formatter, get_formatter, reset_formatter, reset_styles_loaded_state, ) +from datafusion.expr import Window from pyarrow.csv import write_csv MB = 1024 * 1024 diff --git a/src/dataframe.rs b/src/dataframe.rs index 69a6ec248..f554f340e 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -151,9 +151,9 @@ fn get_python_formatter_with_config(py: Python) -> PyResult { Ok(PythonFormatter { formatter, config }) } -/// Get the Python formatter from the datafusion.html_formatter module +/// Get the Python formatter from the datafusion.dataframe_formatter module fn import_python_formatter(py: Python) -> PyResult> { - let formatter_module = py.import("datafusion.html_formatter")?; + let formatter_module = py.import("datafusion.dataframe_formatter")?; let get_formatter = formatter_module.getattr("get_formatter")?; get_formatter.call0() } From 06ed0ffc27a049de89b884a7c24fe1f5fa6bc7cf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 24 Jun 2025 15:55:56 -0400 Subject: [PATCH 14/14] Add deprecation warning --- python/datafusion/html_formatter.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 python/datafusion/html_formatter.py diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py new file mode 100644 index 000000000..37558b913 --- /dev/null +++ b/python/datafusion/html_formatter.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+"""Deprecated module for dataframe formatting."""
+
+import warnings
+
+from datafusion.dataframe_formatter import *  # noqa: F403
+
+warnings.warn(
+    "The module 'html_formatter' is deprecated and will be removed in the next "
+    "release. Please use 'dataframe_formatter' instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
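
With the shim in place, the old import path keeps working but warns, while new code imports from `datafusion.dataframe_formatter`. A quick sketch, assuming an installed build that includes these commits:

    import warnings

    # Preferred import after the rename.
    from datafusion.dataframe_formatter import configure_formatter, reset_formatter

    configure_formatter(max_cell_length=30)
    reset_formatter()

    # The legacy module still re-exports everything, but importing it now
    # emits a DeprecationWarning pointing at dataframe_formatter.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import datafusion.html_formatter  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)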