
Commit 51e2b23

make sure column + token is unique, replace **{} with .alias()
1 parent b855352 commit 51e2b23

File tree

  • packages/python/plotly/plotly/express

1 file changed: +98 −91 lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 98 additions & 91 deletions
@@ -163,6 +163,50 @@ def _is_continuous(df: nw.DataFrame, col_name: str) -> bool:
     return df.get_column(col_name).dtype.is_numeric()
 
 
+def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series:
+    dtype = s.dtype
+    if dtype == nw.Date:
+        return s.dt.timestamp("ms") / 1_000
+    if dtype == nw.Datetime:
+        if dtype.time_unit in ("s", "ms"):
+            return s.dt.timestamp("ms") / 1_000
+        elif dtype.time_unit == "us":
+            return s.dt.timestamp("us") / 1_000_000
+        elif dtype.time_unit == "ns":
+            return s.dt.timestamp("ns") / 1_000_000_000
+        else:
+            msg = "Unexpected dtype, please report a bug"
+            raise ValueError(msg)
+    else:
+        msg = f"Expected Date or Datetime, got {dtype}"
+        raise TypeError(msg)
+
+
+def _generate_temporary_column_name(n_bytes: int, columns: list[str]) -> str:
+    """Wraps Narwhals' generate_temporary_column_name to generate a token
+    which is guaranteed not to be in columns, nor in [col + token for col in columns].
+    """
+    counter = 0
+    while True:
+        # This is guaranteed to not be in columns by Narwhals
+        token = nw.generate_temporary_column_name(n_bytes, columns=columns)
+
+        # Now check that it is not in the [col + token for col in columns] list
+        if token not in {f"{c}{token}" for c in columns}:
+            return token
+
+        counter += 1
+        if counter > 100:
+            msg = (
+                "Internal Error: Plotly was not able to generate a column name with "
+                f"{n_bytes=} and not in {columns}.\n"
+                "Please report this to "
+                "https://github.com/plotly/plotly.py/issues/new and we will try to "
+                "replicate and fix it."
+            )
+            raise AssertionError(msg)
+
+
 def get_decorated_label(args, column, role):
     original_label = label = get_label(args, column)
     if "histfunc" in args and (
@@ -443,7 +487,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
             # dict.fromkeys(customdata_cols) allows to deduplicate column
             # names, yet maintaining the original order.
             trace_patch["customdata"] = trace_data.select(
-                [nw.col(c) for c in dict.fromkeys(customdata_cols)]
+                *[nw.col(c) for c in dict.fromkeys(customdata_cols)]
             )
         elif attr_name == "color":
             if trace_spec.constructor in [
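
Note: the only change here is unpacking the deduplicated expressions positionally with `*`, matching select's variadic signature. The dedup idiom itself is plain Python, since dict keys are unique and insertion-ordered:

cols = ["a", "b", "a", "c", "b"]
print(list(dict.fromkeys(cols)))  # ['a', 'b', 'c'] - duplicates dropped, order kept
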
@@ -1693,7 +1737,7 @@ def build_dataframe(args, constructor):
             other_dim = "x" if missing_bar_dim == "y" else "y"
             if not _is_continuous(df_output, args[other_dim]):
                 args[missing_bar_dim] = count_name
-                df_output = df_output.with_columns(**{count_name: nw.lit(1)})
+                df_output = df_output.with_columns(nw.lit(1).alias(count_name))
             else:
                 # on the other hand, if the non-missing dimension is continuous, then we
                 # can use this information to override the normal auto-orientation code
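
Note: this is the commit's recurring pattern: with_columns(**{name: expr}) becomes with_columns(expr.alias(name)). Both spellings produce the same column for a dynamically named column; .alias() also composes with generator expressions, as in the hierarchy changes below. A sketch assuming a pandas backend:

import pandas as pd
import narwhals as nw

count_name = "count"
df = nw.from_native(pd.DataFrame({"x": ["a", "b"]}), eager_only=True)

via_kwargs = df.with_columns(**{count_name: nw.lit(1)})
via_alias = df.with_columns(nw.lit(1).alias(count_name))
assert via_kwargs.to_native().equals(via_alias.to_native())
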
@@ -1760,7 +1804,7 @@ def build_dataframe(args, constructor):
             else:
                 args["x" if orient_v else "y"] = value_name
                 args["y" if orient_v else "x"] = count_name
-                df_output = df_output.with_columns(**{count_name: nw.lit(1)})
+                df_output = df_output.with_columns(nw.lit(1).alias(count_name))
                 args["color"] = args["color"] or var_name
         elif constructor in [go.Violin, go.Box]:
             args["x" if orient_v else "y"] = wide_cross_name or var_name
@@ -1773,12 +1817,12 @@ def build_dataframe(args, constructor):
                 args["histfunc"] = None
                 args["orientation"] = "h"
                 args["x"] = count_name
-                df_output = df_output.with_columns(**{count_name: nw.lit(1)})
+                df_output = df_output.with_columns(nw.lit(1).alias(count_name))
             else:
                 args["histfunc"] = None
                 args["orientation"] = "v"
                 args["y"] = count_name
-                df_output = df_output.with_columns(**{count_name: nw.lit(1)})
+                df_output = df_output.with_columns(nw.lit(1).alias(count_name))
 
     if no_color:
         args["color"] = None
@@ -1789,10 +1833,10 @@
 def _check_dataframe_all_leaves(df: nw.DataFrame) -> None:
     cols = df.columns
     df_sorted = df.sort(by=cols, descending=False, nulls_last=True)
-    null_mask = df_sorted.select(*[nw.col(c).is_null() for c in cols])
-    df_sorted = df_sorted.with_columns(nw.col(*cols).cast(nw.String()))
+    null_mask = df_sorted.select(nw.all().is_null())
+    df_sorted = df_sorted.select(nw.all().cast(nw.String()))
     null_indices_mask = null_mask.select(
-        null_mask=nw.any_horizontal(nw.col(cols))
+        null_mask=nw.any_horizontal(nw.all())
     ).get_column("null_mask")
 
     for row_idx, row in zip(
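
Note: nw.all() expands to every column of the frame, replacing the explicit per-column comprehensions. In isolation, with a pandas backend and the narwhals version used here:

import pandas as pd
import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1, None], "b": [None, None]}), eager_only=True)
null_mask = df.select(nw.all().is_null())
any_null = null_mask.select(null_mask=nw.any_horizontal(nw.all())).get_column("null_mask")
print(any_null.to_list())  # [True, True] - each row has at least one null
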
@@ -1854,26 +1898,15 @@ def process_dataframe_hierarchy(args):
 
         new_path = [col_name + "_path_copy" for col_name in path]
         df = df.with_columns(
-            **{
-                new_col_name: nw.col(col_name)
-                for new_col_name, col_name in zip(new_path, path)
-            }
+            nw.col(col_name).alias(new_col_name)
+            for new_col_name, col_name in zip(new_path, path)
         )
         path = new_path
     # ------------ Define aggregation functions --------------------------------
     agg_f = {}
     if args["values"]:
         try:
-            if isinstance(args["values"], Sequence) and not isinstance(
-                args["values"], str
-            ):
-                df = df.with_columns(
-                    **{c: nw.col(c).cast(nw.Float64()) for c in args["values"]}
-                )
-            else:
-                df = df.with_columns(
-                    **{args["values"]: nw.col(args["values"]).cast(nw.Float64())}
-                )
+            df = df.with_columns(nw.col(args["values"]).cast(nw.Float64()))
 
         except Exception:  # pandas, Polars and pyarrow exception types are different
             raise ValueError(
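
Note: the isinstance branching collapses because nw.col accepts either a single name or a sequence of names, so args["values"] can be passed through in both cases:

import pandas as pd
import narwhals as nw

df = nw.from_native(pd.DataFrame({"v": [1, 2], "w": [3, 4]}), eager_only=True)
# One expression covers both columns; nw.col("v") would cover just one.
print(df.with_columns(nw.col(["v", "w"]).cast(nw.Float64())).to_native().dtypes)
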
@@ -1883,7 +1916,7 @@ def process_dataframe_hierarchy(args):
 
     if args["color"] and args["color"] == args["values"]:
         new_value_col_name = args["values"] + "_sum"
-        df = df.with_columns(**{new_value_col_name: nw.col(args["values"])})
+        df = df.with_columns(nw.col(args["values"]).alias(new_value_col_name))
         args["values"] = new_value_col_name
         count_colname = args["values"]
     else:
@@ -1894,7 +1927,7 @@ def process_dataframe_hierarchy(args):
             "count" if "count" not in columns else "".join([str(el) for el in columns])
         )
         # we can modify df because it's a copy of the px argument
-        df = df.with_columns(**{count_colname: nw.lit(1)})
+        df = df.with_columns(nw.lit(1).alias(count_colname))
         args["values"] = count_colname
 
     # Since count_colname is always in agg_f, it can be used later to normalize color
@@ -1904,8 +1937,8 @@ def process_dataframe_hierarchy(args):
     discrete_aggs = []
     continuous_aggs = []
 
-    n_unique_token = nw.generate_temporary_column_name(
-        n_bytes=16, columns=[*path, count_colname]
+    n_unique_token = _generate_temporary_column_name(
+        n_bytes=16, columns=df.collect_schema().names()
     )
 
     # In theory, for discrete columns aggregation, we should have a way to do
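
Note: this is the "column + token is unique" half of the commit message. The token is appended to column names further down (f'{col}{n_unique_token}'), so it is now generated against the full schema rather than just [*path, count_colname], and the new wrapper re-checks the appended forms. A hypothetical collision of the kind being guarded against (the names here are invented for illustration):

columns = ["region", "regionXYZ"]
token = "XYZ"  # suppose the generator happened to return this suffix
appended = [f"{c}{token}" for c in columns]
# "region" + "XYZ" reproduces the existing column "regionXYZ":
assert appended[0] in columns
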
@@ -1941,10 +1974,10 @@ def process_dataframe_hierarchy(args):
 
             discrete_aggs.append(args["color"])
             agg_f[args["color"]] = nw.col(args["color"]).max()
-            agg_f[f'{args["color"]}_{n_unique_token}__'] = (
+            agg_f[f'{args["color"]}{n_unique_token}'] = (
                 nw.col(args["color"])
                 .n_unique()
-                .alias(f'{args["color"]}_{n_unique_token}__')
+                .alias(f'{args["color"]}{n_unique_token}')
             )
         else:
             # This first needs to be multiplied by `count_colname`
@@ -1954,16 +1987,15 @@ def process_dataframe_hierarchy(args):
 
     # Other columns (for color, hover_data, custom_data etc.)
     cols = list(set(df.collect_schema().names()).difference(path))
-    df = df.with_columns(
-        **{c: nw.col(c).cast(nw.String()) for c in cols if c not in agg_f}
-    )
+    df = df.with_columns(nw.col(c).cast(nw.String()) for c in cols if c not in agg_f)
+
     for col in cols:  # for hover_data, custom_data etc.
         if col not in agg_f:
             # Similar trick as above
             discrete_aggs.append(col)
             agg_f[col] = nw.col(col).max()
-            agg_f[f"{col}_{n_unique_token}__"] = (
-                nw.col(col).n_unique().alias(f"{col}_{n_unique_token}__")
+            agg_f[f"{col}{n_unique_token}"] = (
+                nw.col(col).n_unique().alias(f"{col}{n_unique_token}")
             )
     # Avoid collisions with reserved names - columns in the path have been copied already
     cols = list(set(cols) - set(["labels", "parent", "id"]))
@@ -1972,7 +2004,7 @@ def process_dataframe_hierarchy(args):
 
     if args["color"] and not discrete_color:
         df = df.with_columns(
-            **{args["color"]: nw.col(args["color"]) * nw.col(count_colname)}
+            (nw.col(args["color"]) * nw.col(count_colname)).alias(args["color"])
        )
 
     def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
@@ -1981,14 +2013,14 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
         - discrete_aggs is either [args["color"], <rest_of_cols>] or [<rest_of cols>]
         """
         return dframe.with_columns(
-            **{c: nw.col(c) / nw.col(count_colname) for c in continuous_aggs},
+            **{col: nw.col(col) / nw.col(count_colname) for col in continuous_aggs},
             **{
-                c: nw.when(nw.col(f"{c}_{n_unique_token}__") == 1)
-                .then(nw.col(c))
+                col: nw.when(nw.col(f"{col}{n_unique_token}") == 1)
+                .then(nw.col(col))
                 .otherwise(nw.lit("(?)"))
-                for c in discrete_aggs
+                for col in discrete_aggs
             },
-        ).drop([f"{c}_{n_unique_token}__" for c in discrete_aggs])
+        ).drop([f"{col}{n_unique_token}" for col in discrete_aggs])
 
     for i, level in enumerate(path):
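
Note: post_agg keeps a discrete value only when it was constant within the aggregated group (its n_unique helper column equals 1) and otherwise masks it as "(?)". The same when/then/otherwise shape in isolation, with color_n standing in for the tokenized helper column:

import pandas as pd
import narwhals as nw

df = nw.from_native(
    pd.DataFrame({"color": ["red", "red"], "color_n": [1, 2]}), eager_only=True
)
out = df.with_columns(
    color=nw.when(nw.col("color_n") == 1).then(nw.col("color")).otherwise(nw.lit("(?)"))
)
print(out.get_column("color").to_list())  # ['red', '(?)']
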
19942026

@@ -2006,30 +2038,26 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
                 id=nw.col(level).cast(nw.String()),
             )
         if i < len(path) - 1:
-            _concat_str_token = nw.generate_temporary_column_name(
-                n_bytes=8, columns=[*cols, "labels", "parent", "id"]
+            _concat_str_token = _generate_temporary_column_name(
+                n_bytes=16, columns=[*cols, "labels", "parent", "id"]
             )
             df_tree = (
                 df_tree.with_columns(
-                    **{
-                        _concat_str_token: nw.concat_str(
-                            [
-                                nw.col(path[j]).cast(nw.String())
-                                for j in range(len(path) - 1, i, -1)
-                            ],
-                            separator="/",
-                        )
-                    }
+                    nw.concat_str(
+                        [
+                            nw.col(path[j]).cast(nw.String())
+                            for j in range(len(path) - 1, i, -1)
+                        ],
+                        separator="/",
+                    ).alias(_concat_str_token)
                 )
                 .with_columns(
-                    **{
-                        "parent": nw.concat_str(
-                            [nw.col(_concat_str_token), nw.col("parent")], separator="/"
-                        ),
-                        "id": nw.concat_str(
-                            [nw.col(_concat_str_token), nw.col("id")], separator="/"
-                        ),
-                    }
+                    parent=nw.concat_str(
+                        [nw.col(_concat_str_token), nw.col("parent")], separator="/"
+                    ),
+                    id=nw.concat_str(
+                        [nw.col(_concat_str_token), nw.col("id")], separator="/"
+                    ),
                 )
                 .drop(_concat_str_token)
             )
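
Note: nw.concat_str joins string expressions row-wise with a separator; here it builds the "/"-delimited parent and id paths. Minimal sketch:

import pandas as pd
import narwhals as nw

df = nw.from_native(
    pd.DataFrame({"continent": ["Europe"], "country": ["France"]}), eager_only=True
)
out = df.with_columns(
    id=nw.concat_str([nw.col("continent"), nw.col("country")], separator="/")
)
print(out.get_column("id").to_list())  # ['Europe/France']
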
@@ -2049,7 +2077,7 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
         while sort_col_name in df_all_trees.columns:
             sort_col_name += "0"
         df_all_trees = df_all_trees.with_columns(
-            **{sort_col_name: nw.col(args["color"]).cast(nw.String())}
+            nw.col(args["color"]).cast(nw.String()).alias(sort_col_name)
         ).sort(by=sort_col_name, nulls_last=True)
 
     # Now modify arguments
@@ -2080,10 +2108,8 @@ def process_dataframe_timeline(args):
     try:
         df: nw.DataFrame = args["data_frame"]
         df = df.with_columns(
-            **{
-                args["x_start"]: nw.col(args["x_start"]).str.to_datetime(),
-                args["x_end"]: nw.col(args["x_end"]).str.to_datetime(),
-            }
+            nw.col(args["x_start"]).str.to_datetime().alias(args["x_start"]),
+            nw.col(args["x_end"]).str.to_datetime().alias(args["x_end"]),
         )
     except Exception:
         raise TypeError(
@@ -2092,11 +2118,9 @@ def process_dataframe_timeline(args):
 
     # note that we are not adding any columns to the data frame here, so no risk of overwrite
     args["data_frame"] = df.with_columns(
-        **{
-            args["x_end"]: (
-                nw.col(args["x_end"]) - nw.col(args["x_start"])
-            ).dt.total_milliseconds()
-        }
+        (nw.col(args["x_end"]) - nw.col(args["x_start"]))
+        .dt.total_milliseconds()
+        .alias(args["x_end"])
     )
     args["x"] = args["x_end"]
     args["base"] = args["x_start"]
@@ -2594,20 +2618,22 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
                     group_sum = group.get_column(
                         var
                     ).sum()  # compute here before next line mutates
-                    group = group.with_columns(**{var: nw.col(var).cum_sum()})
+                    group = group.with_columns(nw.col(var).cum_sum().alias(var))
                     if not ascending:
                         group = group.sort(by=base, descending=False, nulls_last=True)
 
                     if args.get("ecdfmode", "standard") == "complementary":
                         group = group.with_columns(
-                            **{var: (nw.col(var) - nw.lit(group_sum)) * (-1)}
+                            ((nw.col(var) - nw.lit(group_sum)) * (-1)).alias(var)
                         )
 
                     if args["ecdfnorm"] == "probability":
-                        group = group.with_columns(**{var: nw.col(var) / nw.lit(group_sum)})
+                        group = group.with_columns(
+                            (nw.col(var) / nw.lit(group_sum)).alias(var)
+                        )
                     elif args["ecdfnorm"] == "percent":
                         group = group.with_columns(
-                            **{var: nw.col(var) / nw.lit(group_sum) * nw.lit(100.0)}
+                            (nw.col(var) / nw.lit(group_sum) * nw.lit(100.0)).alias(var)
                         )
 
                     patch, fit_results = make_trace_kwargs(
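
Note: the ECDF code builds a running total with cum_sum and then normalizes by the group sum captured before the mutation. The "probability" branch, sketched on a pandas-backed frame:

import pandas as pd
import narwhals as nw

var = "y"
group = nw.from_native(pd.DataFrame({var: [1, 1, 1, 1]}), eager_only=True)
group_sum = group.get_column(var).sum()  # compute before cum_sum mutates the column
group = group.with_columns(nw.col(var).cum_sum().alias(var))
group = group.with_columns((nw.col(var) / nw.lit(group_sum)).alias(var))
print(group.get_column(var).to_list())  # [0.25, 0.5, 0.75, 1.0]
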
@@ -2835,22 +2861,3 @@ def _spacing_error_translator(e, direction, facet_arg):
             annot.update(font=None)
 
     return fig
-
-
-def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series:
-    dtype = s.dtype
-    if dtype == nw.Date:
-        return s.dt.timestamp("ms") / 1_000
-    if dtype == nw.Datetime:
-        if dtype.time_unit in ("s", "ms"):
-            return s.dt.timestamp("ms") / 1_000
-        elif dtype.time_unit == "us":
-            return s.dt.timestamp("us") / 1_000_000
-        elif dtype.time_unit == "ns":
-            return s.dt.timestamp("ns") / 1_000_000_000
-        else:
-            msg = "Unexpected dtype, please report a bug"
-            raise ValueError(msg)
-    else:
-        msg = f"Expected Date or Datetime, got {dtype}"
-        raise TypeError(msg)
