11"""
22Built-in datasets for demonstration, educational and test purposes.
33"""
4+ import narwhals .stable .v1 as nw
45
5-
6- def gapminder (datetimes = False , centroids = False , year = None , pretty_names = False ):
6+ def gapminder (datetimes = False , centroids = False , year = None , pretty_names = False , return_type = "pandas" ):
77 """
88 Each row represents a country on a given year.
99
@@ -17,16 +17,16 @@ def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
1717 If `centroids` is True, two new columns are added: ['centroid_lat', 'centroid_lon']
1818 If `year` is an integer, the dataset will be filtered for that year
1919 """
20- df = _get_dataset ("gapminder" )
20+ df = nw . from_native ( _get_dataset ("gapminder" , return_type = return_type ), eager_only = True )
2121 if year :
22- df = df [ df [ "year" ] == year ]
22+ df = df . filter ( nw . col ( "year" ) == year )
2323 if datetimes :
24- df [ "year" ] = ( df [ "year" ]. astype ( str ) + "-01-01" ). astype ( "datetime64[ns]" )
24+ df = df . with_columns ( nw . concat_str ([ nw . col ( "year" ). cast ( nw . String ()), nw . lit ( "-01-01" )]). cast ( nw . Datetime ( time_unit = "ns" )) )
2525 if not centroids :
26- df = df .drop ([ "centroid_lat" , "centroid_lon" ], axis = 1 )
26+ df = df .drop ("centroid_lat" , "centroid_lon" )
2727 if pretty_names :
28- df .rename (
29- mapper = dict (
28+ df = df .rename (
29+ dict (
3030 country = "Country" ,
3131 continent = "Continent" ,
3232 year = "Year" ,
@@ -37,14 +37,12 @@ def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False):
3737 iso_num = "ISO Numeric Country Code" ,
3838 centroid_lat = "Centroid Latitude" ,
3939 centroid_lon = "Centroid Longitude" ,
40- ),
41- axis = "columns" ,
42- inplace = True ,
40+ )
4341 )
44- return df
42+ return df . to_native ()
4543
4644
47- def tips (pretty_names = False ):
45+ def tips (pretty_names = False , return_type = "pandas" ):
4846 """
4947 Each row represents a restaurant bill.
5048
@@ -54,25 +52,23 @@ def tips(pretty_names=False):
5452 A `pandas.DataFrame` with 244 rows and the following columns:
5553 `['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`."""
5654
57- df = _get_dataset ("tips" )
55+ df = nw . from_native ( _get_dataset ("tips" , return_type = return_type ), eager_only = True )
5856 if pretty_names :
59- df .rename (
60- mapper = dict (
57+ df = df .rename (
58+ dict (
6159 total_bill = "Total Bill" ,
6260 tip = "Tip" ,
6361 sex = "Payer Gender" ,
6462 smoker = "Smokers at Table" ,
6563 day = "Day of Week" ,
6664 time = "Meal" ,
6765 size = "Party Size" ,
68- ),
69- axis = "columns" ,
70- inplace = True ,
66+ )
7167 )
72- return df
68+ return df . to_native ()
7369
7470
75- def iris ():
71+ def iris (return_type = "pandas" ):
7672 """
7773 Each row represents a flower.
7874
@@ -81,28 +77,28 @@ def iris():
8177 Returns:
8278 A `pandas.DataFrame` with 150 rows and the following columns:
8379 `['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', 'species_id']`."""
84- return _get_dataset ("iris" )
80+ return _get_dataset ("iris" , return_type = return_type )
8581
8682
87- def wind ():
83+ def wind (return_type = "pandas" ):
8884 """
8985 Each row represents a level of wind intensity in a cardinal direction, and its frequency.
9086
9187 Returns:
9288 A `pandas.DataFrame` with 128 rows and the following columns:
9389 `['direction', 'strength', 'frequency']`."""
94- return _get_dataset ("wind" )
90+ return _get_dataset ("wind" , return_type = return_type )
9591
9692
97- def election ():
93+ def election (return_type = "pandas" ):
9894 """
9995 Each row represents voting results for an electoral district in the 2013 Montreal
10096 mayoral election.
10197
10298 Returns:
10399 A `pandas.DataFrame` with 58 rows and the following columns:
104100 `['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result', 'district_id']`."""
105- return _get_dataset ("election" )
101+ return _get_dataset ("election" , return_type = return_type )
106102
107103
108104def election_geojson ():
@@ -128,18 +124,18 @@ def election_geojson():
128124 return result
129125
130126
131- def carshare ():
127+ def carshare (return_type = "pandas" ):
132128 """
133129 Each row represents the availability of car-sharing services near the centroid of a zone
134130 in Montreal over a month-long period.
135131
136132 Returns:
137133 A `pandas.DataFrame` with 249 rows and the following columns:
138134 `['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`."""
139- return _get_dataset ("carshare" )
135+ return _get_dataset ("carshare" , return_type = return_type )
140136
141137
142- def stocks (indexed = False , datetimes = False ):
138+ def stocks (indexed = False , datetimes = False , return_type = "pandas" ):
143139 """
144140 Each row in this wide dataset represents closing prices from 6 tech stocks in 2018/2019.
145141
@@ -149,16 +145,23 @@ def stocks(indexed=False, datetimes=False):
149145 If `indexed` is True, the 'date' column is used as the index and the column index
150146 If `datetimes` is True, the 'date' column will be a datetime column
151147 is named 'company'"""
152- df = _get_dataset ("stocks" )
148+ if indexed and return_type != "pandas" :
149+ msg = "Cannot set index for backend different from pandas"
150+ raise NotImplementedError (msg )
151+
152+ df = nw .from_native (_get_dataset ("stocks" , return_type = return_type ), eager_only = True )
153153 if datetimes :
154- df ["date" ] = df ["date" ].astype ("datetime64[ns]" )
155- if indexed :
156- df = df .set_index ("date" )
154+ df = df .with_columns (nw .col ("date" ).cast (nw .Datetime (time_unit = "ns" )))
155+
156+ if indexed : # then it must be pandas
157+ df = df .to_native ().set_index ("date" )
157158 df .columns .name = "company"
158- return df
159+ return df
159160
161+ return df .to_native ()
160162
161- def experiment (indexed = False ):
163+
164+ def experiment (indexed = False , return_type = "pandas" ):
162165 """
163166 Each row in this wide dataset represents the results of 100 simulated participants
164167 on three hypothetical experiments, along with their gender and control/treatment group.
@@ -168,13 +171,20 @@ def experiment(indexed=False):
168171 A `pandas.DataFrame` with 100 rows and the following columns:
169172 `['experiment_1', 'experiment_2', 'experiment_3', 'gender', 'group']`.
170173 If `indexed` is True, the data frame index is named "participant" """
171- df = _get_dataset ("experiment" )
172- if indexed :
174+
175+ if indexed and return_type != "pandas" :
176+ msg = "Cannot set index for backend different from pandas"
177+ raise NotImplementedError (msg )
178+
179+ df = nw .from_native (_get_dataset ("experiment" , return_type = return_type ), eager_only = True )
180+ if indexed : # then it must be pandas
181+ df = df .to_native ()
173182 df .index .name = "participant"
174- return df
183+ return df
184+ return df .to_native ()
175185
176186
177- def medals_wide (indexed = False ):
187+ def medals_wide (indexed = False , return_type = "pandas" ):
178188 """
179189 This dataset represents the medal table for Olympic Short Track Speed Skating for the
180190 top three nations as of 2020.
@@ -184,14 +194,20 @@ def medals_wide(indexed=False):
184194 `['nation', 'gold', 'silver', 'bronze']`.
185195 If `indexed` is True, the 'nation' column is used as the index and the column index
186196 is named 'medal'"""
187- df = _get_dataset ("medals" )
188- if indexed :
189- df = df .set_index ("nation" )
197+
198+ if indexed and return_type != "pandas" :
199+ msg = "Cannot set index for backend different from pandas"
200+ raise NotImplementedError (msg )
201+
202+ df = nw .from_native (_get_dataset ("medals" , return_type = return_type ), eager_only = True )
203+ if indexed : # then it must be pandas
204+ df = df .to_native ().set_index ("nation" )
190205 df .columns .name = "medal"
191- return df
206+ return df
207+ return df .to_native ()
192208
193209
194- def medals_long (indexed = False ):
210+ def medals_long (indexed = False , return_type = "pandas" ):
195211 """
196212 This dataset represents the medal table for Olympic Short Track Speed Skating for the
197213 top three nations as of 2020.
@@ -200,23 +216,42 @@ def medals_long(indexed=False):
200216 A `pandas.DataFrame` with 9 rows and the following columns:
201217 `['nation', 'medal', 'count']`.
202218 If `indexed` is True, the 'nation' column is used as the index."""
203- df = _get_dataset ("medals" ).melt (
204- id_vars = ["nation" ], value_name = "count" , var_name = "medal"
205- )
219+
220+ if indexed and return_type != "pandas" :
221+ msg = "Cannot set index for backend different from pandas"
222+ raise NotImplementedError (msg )
223+
224+ df = (
225+ nw .from_native (_get_dataset ("medals" , return_type = return_type ), eager_only = True )
226+ .unpivot (
227+ index = ["nation" ],
228+ value_name = "count" ,
229+ variable_name = "medal" ,
230+ ))
206231 if indexed :
207- df = df . set_index ( "nation" )
208- return df
232+ df = nw . maybe_set_index ( df , "nation" )
233+ return df . to_native ()
209234
210235
211- def _get_dataset (d ):
212- import pandas
236+ def _get_dataset (d , return_type ):
213237 import os
238+ from importlib import import_module
214239
215- return pandas . read_csv (
216- os . path . join (
217- os .path .dirname ( os . path . dirname ( __file__ )),
218- "package_data" ,
219- "datasets " ,
220- d + ".csv.gz " ,
221- )
240+ AVAILABLE_BACKENDS = { " pandas" , "polars" , "pyarrow" }
241+
242+ filepath = os .path .join (
243+ os . path . dirname ( os . path . dirname ( __file__ )) ,
244+ "package_data " ,
245+ "datasets " ,
246+ d + ".csv.gz" ,
222247 )
248+ if return_type not in AVAILABLE_BACKENDS :
249+ msg = f"Unsupported return_type. Found { return_type } , expected one of { AVAILABLE_BACKENDS } "
250+ raise NotImplementedError (msg )
251+
252+ try :
253+ backend = import_module (return_type )
254+ return backend .read_csv (filepath )
255+ except ModuleNotFoundError :
256+ msg = f"return_type={ return_type } , but { return_type } is not installed"
257+ raise ModuleNotFoundError (msg )
0 commit comments