Skip to content

Commit 7ef9f28

Browse files
committed
WIP
1 parent 87e5a39 commit 7ef9f28

File tree

3 files changed

+97
-59
lines changed

3 files changed

+97
-59
lines changed
Lines changed: 93 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
22
Built-in datasets for demonstration, educational and test purposes.
33
"""
4+
import narwhals.stable.v1 as nw
45

5-
6-
def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
6+
def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False, return_type="pandas"):
77
"""
88
Each row represents a country on a given year.
99
@@ -17,16 +17,16 @@ def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
1717
If `centroids` is True, two new columns are added: ['centroid_lat', 'centroid_lon']
1818
If `year` is an integer, the dataset will be filtered for that year
1919
"""
20-
df = _get_dataset("gapminder")
20+
df = nw.from_native(_get_dataset("gapminder", return_type=return_type), eager_only=True)
2121
if year:
22-
df = df[df["year"] == year]
22+
df = df.filter(nw.col("year") == year)
2323
if datetimes:
24-
df["year"] = (df["year"].astype(str) + "-01-01").astype("datetime64[ns]")
24+
df = df.with_columns(nw.concat_str([nw.col("year").cast(nw.String()), nw.lit("-01-01")]).cast(nw.Datetime(time_unit="ns")))
2525
if not centroids:
26-
df = df.drop(["centroid_lat", "centroid_lon"], axis=1)
26+
df = df.drop("centroid_lat", "centroid_lon")
2727
if pretty_names:
28-
df.rename(
29-
mapper=dict(
28+
df = df.rename(
29+
dict(
3030
country="Country",
3131
continent="Continent",
3232
year="Year",
@@ -37,14 +37,12 @@ def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
3737
iso_num="ISO Numeric Country Code",
3838
centroid_lat="Centroid Latitude",
3939
centroid_lon="Centroid Longitude",
40-
),
41-
axis="columns",
42-
inplace=True,
40+
)
4341
)
44-
return df
42+
return df.to_native()
4543

4644

47-
def tips(pretty_names=False):
45+
def tips(pretty_names=False, return_type="pandas"):
4846
"""
4947
Each row represents a restaurant bill.
5048
@@ -54,25 +52,23 @@ def tips(pretty_names=False):
5452
A `pandas.DataFrame` with 244 rows and the following columns:
5553
`['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`."""
5654

57-
df = _get_dataset("tips")
55+
df = nw.from_native(_get_dataset("tips", return_type=return_type), eager_only=True)
5856
if pretty_names:
59-
df.rename(
60-
mapper=dict(
57+
df = df.rename(
58+
dict(
6159
total_bill="Total Bill",
6260
tip="Tip",
6361
sex="Payer Gender",
6462
smoker="Smokers at Table",
6563
day="Day of Week",
6664
time="Meal",
6765
size="Party Size",
68-
),
69-
axis="columns",
70-
inplace=True,
66+
)
7167
)
72-
return df
68+
return df.to_native()
7369

7470

75-
def iris():
71+
def iris(return_type="pandas"):
7672
"""
7773
Each row represents a flower.
7874
@@ -81,28 +77,28 @@ def iris():
8177
Returns:
8278
A `pandas.DataFrame` with 150 rows and the following columns:
8379
`['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', 'species_id']`."""
84-
return _get_dataset("iris")
80+
return _get_dataset("iris", return_type=return_type)
8581

8682

87-
def wind():
83+
def wind(return_type="pandas"):
8884
"""
8985
Each row represents a level of wind intensity in a cardinal direction, and its frequency.
9086
9187
Returns:
9288
A `pandas.DataFrame` with 128 rows and the following columns:
9389
`['direction', 'strength', 'frequency']`."""
94-
return _get_dataset("wind")
90+
return _get_dataset("wind", return_type=return_type)
9591

9692

97-
def election():
93+
def election(return_type="pandas"):
9894
"""
9995
Each row represents voting results for an electoral district in the 2013 Montreal
10096
mayoral election.
10197
10298
Returns:
10399
A `pandas.DataFrame` with 58 rows and the following columns:
104100
`['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result', 'district_id']`."""
105-
return _get_dataset("election")
101+
return _get_dataset("election", return_type=return_type)
106102

107103

108104
def election_geojson():
@@ -128,18 +124,18 @@ def election_geojson():
128124
return result
129125

130126

131-
def carshare():
127+
def carshare(return_type="pandas"):
132128
"""
133129
Each row represents the availability of car-sharing services near the centroid of a zone
134130
in Montreal over a month-long period.
135131
136132
Returns:
137133
A `pandas.DataFrame` with 249 rows and the following columns:
138134
`['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`."""
139-
return _get_dataset("carshare")
135+
return _get_dataset("carshare", return_type=return_type)
140136

141137

142-
def stocks(indexed=False, datetimes=False):
138+
def stocks(indexed=False, datetimes=False, return_type="pandas"):
143139
"""
144140
Each row in this wide dataset represents closing prices from 6 tech stocks in 2018/2019.
145141
@@ -149,16 +145,23 @@ def stocks(indexed=False, datetimes=False):
149145
If `indexed` is True, the 'date' column is used as the index and the column index
150146
If `datetimes` is True, the 'date' column will be a datetime column
151147
is named 'company'"""
152-
df = _get_dataset("stocks")
148+
if indexed and return_type != "pandas":
149+
msg = "Cannot set index for backend different from pandas"
150+
raise NotImplementedError(msg)
151+
152+
df = nw.from_native(_get_dataset("stocks", return_type=return_type), eager_only=True)
153153
if datetimes:
154-
df["date"] = df["date"].astype("datetime64[ns]")
155-
if indexed:
156-
df = df.set_index("date")
154+
df = df.with_columns(nw.col("date").cast(nw.Datetime(time_unit="ns")))
155+
156+
if indexed: # then it must be pandas
157+
df = df.to_native().set_index("date")
157158
df.columns.name = "company"
158-
return df
159+
return df
159160

161+
return df.to_native()
160162

161-
def experiment(indexed=False):
163+
164+
def experiment(indexed=False, return_type="pandas"):
162165
"""
163166
Each row in this wide dataset represents the results of 100 simulated participants
164167
on three hypothetical experiments, along with their gender and control/treatment group.
@@ -168,13 +171,20 @@ def experiment(indexed=False):
168171
A `pandas.DataFrame` with 100 rows and the following columns:
169172
`['experiment_1', 'experiment_2', 'experiment_3', 'gender', 'group']`.
170173
If `indexed` is True, the data frame index is named "participant" """
171-
df = _get_dataset("experiment")
172-
if indexed:
174+
175+
if indexed and return_type != "pandas":
176+
msg = "Cannot set index for backend different from pandas"
177+
raise NotImplementedError(msg)
178+
179+
df = nw.from_native(_get_dataset("experiment", return_type=return_type), eager_only=True)
180+
if indexed: # then it must be pandas
181+
df = df.to_native()
173182
df.index.name = "participant"
174-
return df
183+
return df
184+
return df.to_native()
175185

176186

177-
def medals_wide(indexed=False):
187+
def medals_wide(indexed=False, return_type="pandas"):
178188
"""
179189
This dataset represents the medal table for Olympic Short Track Speed Skating for the
180190
top three nations as of 2020.
@@ -184,14 +194,20 @@ def medals_wide(indexed=False):
184194
`['nation', 'gold', 'silver', 'bronze']`.
185195
If `indexed` is True, the 'nation' column is used as the index and the column index
186196
is named 'medal'"""
187-
df = _get_dataset("medals")
188-
if indexed:
189-
df = df.set_index("nation")
197+
198+
if indexed and return_type != "pandas":
199+
msg = "Cannot set index for backend different from pandas"
200+
raise NotImplementedError(msg)
201+
202+
df = nw.from_native(_get_dataset("medals", return_type=return_type), eager_only=True)
203+
if indexed: # then it must be pandas
204+
df = df.to_native().set_index("nation")
190205
df.columns.name = "medal"
191-
return df
206+
return df
207+
return df.to_native()
192208

193209

194-
def medals_long(indexed=False):
210+
def medals_long(indexed=False, return_type="pandas"):
195211
"""
196212
This dataset represents the medal table for Olympic Short Track Speed Skating for the
197213
top three nations as of 2020.
@@ -200,23 +216,42 @@ def medals_long(indexed=False):
200216
A `pandas.DataFrame` with 9 rows and the following columns:
201217
`['nation', 'medal', 'count']`.
202218
If `indexed` is True, the 'nation' column is used as the index."""
203-
df = _get_dataset("medals").melt(
204-
id_vars=["nation"], value_name="count", var_name="medal"
205-
)
219+
220+
if indexed and return_type != "pandas":
221+
msg = "Cannot set index for backend different from pandas"
222+
raise NotImplementedError(msg)
223+
224+
df = (
225+
nw.from_native(_get_dataset("medals", return_type=return_type), eager_only=True)
226+
.unpivot(
227+
index=["nation"],
228+
value_name="count",
229+
variable_name="medal",
230+
))
206231
if indexed:
207-
df = df.set_index("nation")
208-
return df
232+
df = nw.maybe_set_index(df, "nation")
233+
return df.to_native()
209234

210235

211-
def _get_dataset(d):
212-
import pandas
236+
def _get_dataset(d, return_type):
213237
import os
238+
from importlib import import_module
214239

215-
return pandas.read_csv(
216-
os.path.join(
217-
os.path.dirname(os.path.dirname(__file__)),
218-
"package_data",
219-
"datasets",
220-
d + ".csv.gz",
221-
)
240+
AVAILABLE_BACKENDS = {"pandas", "polars", "pyarrow"}
241+
242+
filepath = os.path.join(
243+
os.path.dirname(os.path.dirname(__file__)),
244+
"package_data",
245+
"datasets",
246+
d + ".csv.gz",
222247
)
248+
if return_type not in AVAILABLE_BACKENDS:
249+
msg = f"Unsupported return_type. Found {return_type}, expected one of {AVAILABLE_BACKENDS}"
250+
raise NotImplementedError(msg)
251+
252+
try:
253+
backend = import_module(return_type)
254+
return backend.read_csv(filepath)
255+
except ModuleNotFoundError:
256+
msg = f"return_type={return_type}, but {return_type} is not installed"
257+
raise ModuleNotFoundError(msg)

packages/python/plotly/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
### $ pip install -r requirements.txt ###
55
### ###
66
###################################################
7+
8+
## dataframe agnostic layer ##
9+
narwhals>=1.12.0

packages/python/plotly/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def run(self):
603603
data_files=[
604604
("etc/jupyter/nbconfig/notebook.d", ["jupyterlab-plotly.json"]),
605605
],
606-
install_requires=["packaging"],
606+
install_requires=["narwhals>=1.12.0", "packaging"],
607607
zip_safe=False,
608608
cmdclass=dict(
609609
build_py=js_prerelease(versioneer_cmds["build_py"]),

0 commit comments

Comments
 (0)