Skip to content

Commit 0e25450

Browse files
committed
Instead of trying to detect notebook vs console, collect one time when we have any kind if ipython environment.
1 parent e48322a commit 0e25450

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

src/dataframe.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use crate::physical_plan::PyExecutionPlan;
5050
use crate::record_batch::PyRecordBatchStream;
5151
use crate::sql::logical::PyLogicalPlan;
5252
use crate::utils::{
53-
get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
53+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
5454
};
5555
use crate::{
5656
errors::PyDataFusionResult,
@@ -288,12 +288,18 @@ impl PyParquetColumnOptions {
288288
#[derive(Clone)]
289289
pub struct PyDataFrame {
290290
df: Arc<DataFrame>,
291+
292+
// In IPython environment cache batches between __repr__ and _repr_html_ calls.
293+
batches: Option<(Vec<RecordBatch>, bool)>,
291294
}
292295

293296
impl PyDataFrame {
294297
/// creates a new PyDataFrame
295298
pub fn new(df: DataFrame) -> Self {
296-
Self { df: Arc::new(df) }
299+
Self {
300+
df: Arc::new(df),
301+
batches: None,
302+
}
297303
}
298304
}
299305

@@ -320,16 +326,22 @@ impl PyDataFrame {
320326
}
321327
}
322328

323-
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
329+
fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
324330
// Get the Python formatter config
325331
let PythonFormatter {
326332
formatter: _,
327333
config,
328334
} = get_python_formatter_with_config(py)?;
329-
let (batches, has_more) = wait_for_future(
330-
py,
331-
collect_record_batches_to_display(self.df.as_ref().clone(), config),
332-
)??;
335+
336+
let should_cache = *is_ipython_env(py) && self.batches.is_none();
337+
let (batches, has_more) = match self.batches.take() {
338+
Some(b) => b,
339+
None => wait_for_future(
340+
py,
341+
collect_record_batches_to_display(self.df.as_ref().clone(), config),
342+
)??,
343+
};
344+
333345
if batches.is_empty() {
334346
// This should not be reached, but do it for safety since we index into the vector below
335347
return Ok("No data to display".to_string());
@@ -343,16 +355,27 @@ impl PyDataFrame {
343355
false => "",
344356
};
345357

358+
if should_cache {
359+
self.batches = Some((batches, has_more));
360+
}
361+
346362
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
347363
}
348364

349-
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
365+
fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
350366
// Get the Python formatter and config
351367
let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
352-
let (batches, has_more) = wait_for_future(
353-
py,
354-
collect_record_batches_to_display(self.df.as_ref().clone(), config),
355-
)??;
368+
369+
let should_cache = *is_ipython_env(py) && self.batches.is_none();
370+
371+
let (batches, has_more) = match self.batches.take() {
372+
Some(b) => b,
373+
None => wait_for_future(
374+
py,
375+
collect_record_batches_to_display(self.df.as_ref().clone(), config),
376+
)??,
377+
};
378+
356379
if batches.is_empty() {
357380
// This should not be reached, but do it for safety since we index into the vector below
358381
return Ok("No data to display".to_string());
@@ -362,7 +385,7 @@ impl PyDataFrame {
362385

363386
// Convert record batches to PyObject list
364387
let py_batches = batches
365-
.into_iter()
388+
.iter()
366389
.map(|rb| rb.to_pyarrow(py))
367390
.collect::<PyResult<Vec<PyObject>>>()?;
368391

@@ -378,6 +401,10 @@ impl PyDataFrame {
378401
let html_result = formatter.call_method("format_html", (), Some(&kwargs))?;
379402
let html_str: String = html_result.extract()?;
380403

404+
if should_cache {
405+
self.batches = Some((batches, has_more));
406+
}
407+
381408
Ok(html_str)
382409
}
383410

src/utils.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
3939
RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
4040
}
4141

42+
#[inline]
43+
pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
44+
static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
45+
IS_IPYTHON_ENV.get_or_init(|| {
46+
py.import("IPython")
47+
.and_then(|ipython| ipython.call_method0("get_ipython"))
48+
.map(|ipython| !ipython.is_none())
49+
.unwrap_or(false)
50+
})
51+
}
52+
4253
/// Utility to get the Global Datafussion CTX
4354
#[inline]
4455
pub(crate) fn get_global_ctx() -> &'static SessionContext {

0 commit comments

Comments
 (0)