
Commit 6ea305e

jizezhang, 2010YOUY01, dependabot[bot], Dandandan, and kumarUjjawal authored
feat: allow custom caching via logical node (#18688)
## Which issue does this PR close?

- Closes #17297.

## Rationale for this change

See #17297.

## What changes are included in this PR?

- Added methods in `SessionState` to enable registering a `CacheProducer` that creates a logical node for caching.
- Added branching in `DataFrame::cache()` to apply the caching logical node on top of the original DataFrame plan when a cache producer is supplied; otherwise the existing implementation is used as is.

## Are these changes tested?

Yes.

## Are there any user-facing changes?

The change in `SessionState` is user-facing.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Yongting You <2010youy01@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Daniël Heres <danielheres@gmail.com>
Co-authored-by: Kumar Ujjawal <ujjawalpathak6@gmail.com>
Co-authored-by: Aryan Bagade <73382554+AryanBagade@users.noreply.github.com>
Co-authored-by: Oleks V <comphead@users.noreply.github.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
Co-authored-by: Vrishabh <psvrishabh@gmail.com>
Co-authored-by: Dhanush <dhanushhs51@gmail.com>
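To make the new extension point concrete, the sketch below shows how an application might wire a custom `CacheFactory` into a `SessionContext` and trigger it through `DataFrame::cache()`. This is a hedged illustration of the APIs added in this commit, not code from the PR: `NoopCacheFactory`, the CSV path, and the pass-through `create` are assumptions, and a real factory would wrap the plan in a user-defined caching node (as the `CacheNode` test utility in this PR does).

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::session_state::{CacheFactory, SessionState, SessionStateBuilder};
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;

/// Hypothetical factory for illustration only: it returns the plan unchanged.
/// A real implementation would wrap `plan` in a `LogicalPlan::Extension`
/// carrying a user-defined caching node.
#[derive(Debug)]
struct NoopCacheFactory;

impl CacheFactory for NoopCacheFactory {
    fn create(
        &self,
        plan: LogicalPlan,
        _session_state: &SessionState,
    ) -> Result<LogicalPlan> {
        Ok(plan)
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    // Register the factory on the session state before building the context.
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_cache_factory(Some(Arc::new(NoopCacheFactory)))
        .build();
    let ctx = SessionContext::new_with_state(state);

    // Assumed sample data path; any registered table works the same way.
    ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new())
        .await?;

    // With a CacheFactory registered, cache() builds a logical plan via the
    // factory instead of eagerly collecting the DataFrame into a MemTable.
    let cached = ctx.table("example").await?.cache().await?;
    println!("{}", cached.into_optimized_plan()?.display_indent());
    Ok(())
}
```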
1 parent 6856dc4 commit 6ea305e

File tree

4 files changed: +159 −13 lines changed

- datafusion/core/src/dataframe/mod.rs
- datafusion/core/src/execution/session_state.rs
- datafusion/core/src/test_util/mod.rs
- datafusion/core/tests/dataframe/mod.rs

datafusion/core/src/dataframe/mod.rs

Lines changed: 18 additions & 8 deletions
```diff
@@ -2330,6 +2330,10 @@ impl DataFrame {

     /// Cache DataFrame as a memory table.
     ///
+    /// Default behavior could be changed using
+    /// a [`crate::execution::session_state::CacheFactory`]
+    /// configured via [`SessionState`].
+    ///
     /// ```
     /// # use datafusion::prelude::*;
     /// # use datafusion::error::Result;
@@ -2344,14 +2348,20 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
-        let context = SessionContext::new_with_state((*self.session_state).clone());
-        // The schema is consistent with the output
-        let plan = self.clone().create_physical_plan().await?;
-        let schema = plan.schema();
-        let task_ctx = Arc::new(self.task_ctx());
-        let partitions = collect_partitioned(plan, task_ctx).await?;
-        let mem_table = MemTable::try_new(schema, partitions)?;
-        context.read_table(Arc::new(mem_table))
+        if let Some(cache_factory) = self.session_state.cache_factory() {
+            let new_plan =
+                cache_factory.create(self.plan, self.session_state.as_ref())?;
+            Ok(Self::new(*self.session_state, new_plan))
+        } else {
+            let context = SessionContext::new_with_state((*self.session_state).clone());
+            // The schema is consistent with the output
+            let plan = self.clone().create_physical_plan().await?;
+            let schema = plan.schema();
+            let task_ctx = Arc::new(self.task_ctx());
+            let partitions = collect_partitioned(plan, task_ctx).await?;
+            let mem_table = MemTable::try_new(schema, partitions)?;
+            context.read_table(Arc::new(mem_table))
+        }
     }

     /// Apply an alias to the DataFrame.
```

datafusion/core/src/execution/session_state.rs

Lines changed: 45 additions & 1 deletion
```diff
@@ -185,6 +185,7 @@ pub struct SessionState {
     /// It will be invoked on `CREATE FUNCTION` statements.
     /// thus, changing dialect o PostgreSql is required
     function_factory: Option<Arc<dyn FunctionFactory>>,
+    cache_factory: Option<Arc<dyn CacheFactory>>,
     /// Cache logical plans of prepared statements for later execution.
     /// Key is the prepared statement name.
     prepared_plans: HashMap<String, Arc<PreparedPlan>>,
@@ -206,6 +207,7 @@ impl Debug for SessionState {
             .field("table_options", &self.table_options)
             .field("table_factories", &self.table_factories)
             .field("function_factory", &self.function_factory)
+            .field("cache_factory", &self.cache_factory)
             .field("expr_planners", &self.expr_planners);

         #[cfg(feature = "sql")]
@@ -355,6 +357,16 @@ impl SessionState {
         self.function_factory.as_ref()
     }

+    /// Register a [`CacheFactory`] for custom caching strategy
+    pub fn set_cache_factory(&mut self, cache_factory: Arc<dyn CacheFactory>) {
+        self.cache_factory = Some(cache_factory);
+    }
+
+    /// Get the cache factory
+    pub fn cache_factory(&self) -> Option<&Arc<dyn CacheFactory>> {
+        self.cache_factory.as_ref()
+    }
+
     /// Get the table factories
     pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
         &self.table_factories
@@ -941,6 +953,7 @@ pub struct SessionStateBuilder {
    table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
    runtime_env: Option<Arc<RuntimeEnv>>,
    function_factory: Option<Arc<dyn FunctionFactory>>,
+   cache_factory: Option<Arc<dyn CacheFactory>>,
    // fields to support convenience functions
    analyzer_rules: Option<Vec<Arc<dyn AnalyzerRule + Send + Sync>>>,
    optimizer_rules: Option<Vec<Arc<dyn OptimizerRule + Send + Sync>>>,
@@ -978,6 +991,7 @@ impl SessionStateBuilder {
            table_factories: None,
            runtime_env: None,
            function_factory: None,
+           cache_factory: None,
            // fields to support convenience functions
            analyzer_rules: None,
            optimizer_rules: None,
@@ -1030,7 +1044,7 @@ impl SessionStateBuilder {
            table_factories: Some(existing.table_factories),
            runtime_env: Some(existing.runtime_env),
            function_factory: existing.function_factory,
-
+           cache_factory: existing.cache_factory,
            // fields to support convenience functions
            analyzer_rules: None,
            optimizer_rules: None,
@@ -1319,6 +1333,15 @@ impl SessionStateBuilder {
         self
     }

+    /// Set a [`CacheFactory`] for custom caching strategy
+    pub fn with_cache_factory(
+        mut self,
+        cache_factory: Option<Arc<dyn CacheFactory>>,
+    ) -> Self {
+        self.cache_factory = cache_factory;
+        self
+    }
+
     /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`]
     /// for more details.
     ///
@@ -1382,6 +1405,7 @@ impl SessionStateBuilder {
            table_factories,
            runtime_env,
            function_factory,
+           cache_factory,
            analyzer_rules,
            optimizer_rules,
            physical_optimizer_rules,
@@ -1418,6 +1442,7 @@ impl SessionStateBuilder {
            table_factories: table_factories.unwrap_or_default(),
            runtime_env,
            function_factory,
+           cache_factory,
            prepared_plans: HashMap::new(),
        };

@@ -1621,6 +1646,11 @@ impl SessionStateBuilder {
        &mut self.function_factory
    }

+   /// Returns the cache factory
+   pub fn cache_factory(&mut self) -> &mut Option<Arc<dyn CacheFactory>> {
+       &mut self.cache_factory
+   }
+
    /// Returns the current analyzer_rules value
    pub fn analyzer_rules(
        &mut self,
@@ -1659,6 +1689,7 @@ impl Debug for SessionStateBuilder {
            .field("table_options", &self.table_options)
            .field("table_factories", &self.table_factories)
            .field("function_factory", &self.function_factory)
+           .field("cache_factory", &self.cache_factory)
            .field("expr_planners", &self.expr_planners);
        #[cfg(feature = "sql")]
        let ret = ret.field("type_planner", &self.type_planner);
@@ -2047,6 +2078,19 @@ pub(crate) struct PreparedPlan {
     pub(crate) plan: Arc<LogicalPlan>,
 }

+/// A [`CacheFactory`] can be registered via [`SessionState`]
+/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`].
+/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`]
+/// may need to be implemented to handle such plans.
+pub trait CacheFactory: Debug + Send + Sync {
+    /// Create a logical plan for caching
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        session_state: &SessionState,
+    ) -> datafusion_common::Result<LogicalPlan>;
+}
+
 #[cfg(test)]
 mod tests {
     use super::{SessionContextProvider, SessionStateBuilder};
```
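The doc comment on the new trait notes that a custom `ExtensionPlanner` (or `QueryPlanner`) is typically needed so the physical planner can handle the node a `CacheFactory` produces. As a hedged sketch only, not part of this diff, a planner for a node like the test-only `CacheNode` might look roughly as follows; `CacheNodePlanner` and its pass-through behavior are assumptions for illustration, and a real planner would substitute a caching `ExecutionPlan`.

```rust
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::error::Result;
use datafusion::execution::session_state::SessionState;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};

/// Hypothetical planner for a user-defined cache node.
#[derive(Debug)]
struct CacheNodePlanner;

#[async_trait]
impl ExtensionPlanner for CacheNodePlanner {
    async fn plan_extension(
        &self,
        _planner: &dyn PhysicalPlanner,
        node: &dyn UserDefinedLogicalNode,
        _logical_inputs: &[&LogicalPlan],
        physical_inputs: &[Arc<dyn ExecutionPlan>],
        _session_state: &SessionState,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // Only handle the node produced by our hypothetical CacheFactory;
        // returning None lets other extension planners try.
        if node.name() == "CacheNode" {
            // Assumption: pass the single physical input through unchanged.
            // A real implementation would wrap it in a caching ExecutionPlan.
            Ok(Some(Arc::clone(&physical_inputs[0])))
        } else {
            Ok(None)
        }
    }
}
```

To take effect, such a planner would be registered through a custom `QueryPlanner` (for example via `DefaultPhysicalPlanner::with_extension_planners`) set on the same session state that carries the `CacheFactory`.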

datafusion/core/src/test_util/mod.rs

Lines changed: 72 additions & 3 deletions
```diff
@@ -25,6 +25,7 @@ pub mod csv;
 use futures::Stream;
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt::Formatter;
 use std::fs::File;
 use std::io::Write;
 use std::path::Path;
@@ -36,16 +37,20 @@ use crate::dataframe::DataFrame;
 use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
 use crate::datasource::{empty::EmptyTable, provider_as_source};
 use crate::error::Result;
+use crate::execution::session_state::CacheFactory;
 use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use crate::physical_plan::ExecutionPlan;
 use crate::prelude::{CsvReadOptions, SessionContext};

-use crate::execution::SendableRecordBatchStream;
+use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_catalog::Session;
-use datafusion_common::TableReference;
-use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
+use datafusion_common::{DFSchemaRef, TableReference};
+use datafusion_expr::{
+    CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType,
+    UserDefinedLogicalNodeCore,
+};
 use std::pin::Pin;

 use async_trait::async_trait;
@@ -282,3 +287,67 @@ impl RecordBatchStream for BoundedStream {
        self.record_batch.schema()
    }
 }
+
+#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)]
+struct CacheNode {
+    input: LogicalPlan,
+}
+
+impl UserDefinedLogicalNodeCore for CacheNode {
+    fn name(&self) -> &str {
+        "CacheNode"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "CacheNode")
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        assert_eq!(inputs.len(), 1, "input size inconsistent");
+        Ok(Self {
+            input: inputs[0].clone(),
+        })
+    }
+}
+
+#[derive(Debug)]
+struct TestCacheFactory {}
+
+impl CacheFactory for TestCacheFactory {
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        _session_state: &SessionState,
+    ) -> Result<LogicalPlan> {
+        Ok(LogicalPlan::Extension(datafusion_expr::Extension {
+            node: Arc::new(CacheNode { input: plan }),
+        }))
+    }
+}
+
+/// Create a test table registered to a session context with an associated cache factory
+pub async fn test_table_with_cache_factory() -> Result<DataFrame> {
+    let session_state = SessionStateBuilder::new()
+        .with_cache_factory(Some(Arc::new(TestCacheFactory {})))
+        .build();
+    let ctx = SessionContext::new_with_state(session_state);
+    let name = "aggregate_test_100";
+    register_aggregate_csv(&ctx, name).await?;
+    ctx.table(name).await
+}
```

datafusion/core/tests/dataframe/mod.rs

Lines changed: 24 additions & 1 deletion
```diff
@@ -61,7 +61,7 @@ use datafusion::prelude::{
 };
 use datafusion::test_util::{
     parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table,
-    test_table_with_name,
+    test_table_with_cache_factory, test_table_with_name,
 };
 use datafusion_catalog::TableProvider;
 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
@@ -2335,6 +2335,29 @@ async fn cache_test() -> Result<()> {
     Ok(())
 }

+#[tokio::test]
+async fn cache_producer_test() -> Result<()> {
+    let df = test_table_with_cache_factory()
+        .await?
+        .select_columns(&["c2", "c3"])?
+        .limit(0, Some(1))?
+        .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
+
+    let cached_df = df.clone().cache().await?;
+
+    assert_snapshot!(
+        cached_df.clone().into_optimized_plan().unwrap(),
+        @r###"
+    CacheNode
+      Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum
+        Projection: aggregate_test_100.c2, aggregate_test_100.c3
+          Limit: skip=0, fetch=1
+            TableScan: aggregate_test_100, fetch=1
+    "###
+    );
+    Ok(())
+}
+
 #[tokio::test]
 async fn partition_aware_union() -> Result<()> {
     let left = test_table().await?.select_columns(&["c1", "c2"])?;
```
