Skip to content

Commit 192e0a8

Browse files
committed
use token in process_dataframe_hierarchy
1 parent 878d4db commit 192e0a8

File tree

4 files changed

+29
-16
lines changed

4 files changed

+29
-16
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,10 @@ def process_dataframe_hierarchy(args):
18681868
discrete_aggs = []
18691869
continuous_aggs = []
18701870

1871+
n_unique_token = nw.generate_temporary_column_name(
1872+
n_bytes=16, columns=[*path, count_colname]
1873+
)
1874+
18711875
if args["color"]:
18721876
if discrete_color:
18731877

@@ -1888,10 +1892,10 @@ def process_dataframe_hierarchy(args):
18881892
# ```
18891893
# However we cannot do that just yet, therefore a workaround is provided
18901894
agg_f[args["color"]] = nw.col(args["color"]).max()
1891-
agg_f[f'{args["color"]}__plotly_n_unique__'] = (
1895+
agg_f[f'{args["color"]}_{n_unique_token}__'] = (
18921896
nw.col(args["color"])
18931897
.n_unique()
1894-
.alias(f'{args["color"]}__plotly_n_unique__')
1898+
.alias(f'{args["color"]}_{n_unique_token}__')
18951899
)
18961900
else:
18971901
# This first needs to be multiplied by `count_colname`
@@ -1909,8 +1913,8 @@ def process_dataframe_hierarchy(args):
19091913
# Similar trick as above
19101914
discrete_aggs.append(col)
19111915
agg_f[col] = nw.col(col).max()
1912-
agg_f[f"{col}__plotly_n_unique__"] = (
1913-
nw.col(col).n_unique().alias(f"{col}__plotly_n_unique__")
1916+
agg_f[f"{col}_{n_unique_token}__"] = (
1917+
nw.col(col).n_unique().alias(f"{col}_{n_unique_token}__")
19141918
)
19151919
# Avoid collisions with reserved names - columns in the path have been copied already
19161920
cols = list(set(cols) - set(["labels", "parent", "id"]))
@@ -1930,12 +1934,12 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
19301934
return dframe.with_columns(
19311935
**{c: nw.col(c) / nw.col(count_colname) for c in continuous_aggs},
19321936
**{
1933-
c: nw.when(nw.col(f"{c}__plotly_n_unique__") == 1)
1937+
c: nw.when(nw.col(f"{c}_{n_unique_token}__") == 1)
19341938
.then(nw.col(c))
19351939
.otherwise(nw.lit("(?)"))
19361940
for c in discrete_aggs
19371941
},
1938-
).drop([f"{c}__plotly_n_unique__" for c in discrete_aggs])
1942+
).drop([f"{c}_{n_unique_token}__" for c in discrete_aggs])
19391943

19401944
for i, level in enumerate(path):
19411945

@@ -1953,11 +1957,13 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
19531957
id=nw.col(level).cast(nw.String()),
19541958
)
19551959
if i < len(path) - 1:
1956-
token = generate_unique_token(n_bytes=8, columns=df_tree.columns)
1960+
_concat_str_token = nw.generate_temporary_column_name(
1961+
n_bytes=8, columns=[*cols, "labels", "parent", "id"]
1962+
)
19571963
df_tree = (
19581964
df_tree.with_columns(
19591965
**{
1960-
token: nw.concat_str(
1966+
_concat_str_token: nw.concat_str(
19611967
[
19621968
nw.col(path[j]).cast(nw.String())
19631969
for j in range(len(path) - 1, i, -1)
@@ -1969,14 +1975,14 @@ def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFra
19691975
.with_columns(
19701976
**{
19711977
"parent": nw.concat_str(
1972-
[nw.col(token), nw.col("parent")], separator="/"
1978+
[nw.col(_concat_str_token), nw.col("parent")], separator="/"
19731979
),
19741980
"id": nw.concat_str(
1975-
[nw.col(token), nw.col("id")], separator="/"
1981+
[nw.col(_concat_str_token), nw.col("id")], separator="/"
19761982
),
19771983
}
19781984
)
1979-
.drop(token)
1985+
.drop(_concat_str_token)
19801986
)
19811987

19821988
# strip "/" if at the end of the string, equivalent to `.str.rstrip`

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,7 @@ def check_label(label, fig):
542542
check_label("density of max of tip", fig)
543543

544544

545-
def test_timeline(request, constructor):
546-
if "pyarrow_table" in str(constructor) or "polars_eager" in str(constructor):
547-
request.applymarker(pytest.mark.xfail)
545+
def test_timeline(constructor):
548546

549547
df = constructor(
550548
{

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,13 @@ def test_sunburst_hoverdict_color(constructor):
191191

192192
def test_date_in_hover(request, constructor):
193193
if "pyarrow_table" in str(constructor) or "polars_eager" in str(constructor):
194+
# fig.data[0].customdata[0][0] is a numpy.datetime64 for non pandas
195+
# input, and it does not keep the timezone when converting to py scalar
194196
request.applymarker(pytest.mark.xfail)
195197

196198
df = nw.from_native(
197199
constructor({"date": ["2015-04-04 19:31:30+01:00"], "value": [3]})
198200
).with_columns(date=nw.col("date").str.to_datetime(format="%Y-%m-%d %H:%M:%S%z"))
199201
fig = px.scatter(df.to_native(), x="value", y="value", hover_data=["date"])
202+
200203
assert fig.data[0].customdata[0][0] == df.item(row=0, column="date")

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def test_with_index():
5353

5454
def test_series(request, constructor):
5555
if "pyarrow_table" in str(constructor):
56+
# By converting to native, we lose the name for pyarrow chunked_array
57+
# and the assertions fail
5658
request.applymarker(pytest.mark.xfail)
5759

5860
data = px.data.tips().to_dict(orient="list")
5961
tips = nw.from_native(constructor(data))
6062
before_tip = (tips.get_column("total_bill") - tips.get_column("tip")).to_native()
61-
# By converting to native, we lose the name for pyarrow chunked_array and the last
62-
# assertion fails
6363
day = tips.get_column("day").to_native()
6464
tips = tips.to_native()
6565

@@ -74,6 +74,8 @@ def test_series(request, constructor):
7474

7575
def test_several_dataframes(request, constructor):
7676
if "pyarrow_table" in str(constructor):
77+
# By converting to native, we lose the name for pyarrow chunked_array
78+
# and the assertions fail
7779
request.applymarker(pytest.mark.xfail)
7880

7981
df = nw.from_native(constructor(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8])))
@@ -153,6 +155,8 @@ def test_several_dataframes(request, constructor):
153155

154156
def test_name_heuristics(request, constructor):
155157
if "pyarrow_table" in str(constructor):
158+
# By converting to native, we lose the name for pyarrow chunked_array
159+
# and the assertions fail
156160
request.applymarker(pytest.mark.xfail)
157161

158162
df = nw.from_native(constructor(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])))
@@ -482,6 +486,8 @@ def test_pass_df_columns(constructor):
482486

483487
def test_size_column(request, constructor):
484488
if "pyarrow_table" in str(constructor):
489+
# By converting to native, we lose the name for pyarrow chunked_array
490+
# and the assertions fail
485491
request.applymarker(pytest.mark.xfail)
486492
data = px.data.tips().to_dict(orient="list")
487493
tips = nw.from_native(constructor(data))

0 commit comments

Comments
 (0)