Skip to content

Commit cd68b2a

Browse files
Implement drivers for geoparquet.
1 parent a6c0ca8 commit cd68b2a

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

pins/drivers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def _assert_is_pandas_df(x, file_type: str) -> None:
2222
)
2323

2424

25+
def _assert_is_geopandas_df(x):
26+
# Assume we have already protected against uninstalled geopandas
27+
import geopandas as gpd
28+
29+
if not isinstance(x, gpd.GeoDataFrame):
30+
raise NotImplementedError(
31+
"Currently only geopandas.GeoDataFrame can be saved to a GeoParquet."
32+
)
33+
34+
2535
def load_path(meta, path_to_version):
2636
# Check that only a single file name was given
2737
fnames = [meta.file] if isinstance(meta.file, str) else meta.file
@@ -104,6 +114,17 @@ def load_data(
104114

105115
return pd.read_csv(f)
106116

117+
elif meta.type == "geoparquet":
118+
try:
119+
import geopandas as gpd
120+
except ModuleNotFoundError:
121+
raise ModuleNotFoundError(
122+
'The "geopandas" package is required to read "geoparquet" type '
123+
"files."
124+
) from None
125+
126+
return gpd.read_parquet(f)
127+
107128
elif meta.type == "joblib":
108129
import joblib
109130

@@ -144,6 +165,8 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
144165
if apply_suffix:
145166
if type == "file":
146167
suffix = "".join(Path(obj).suffixes)
168+
elif type == "geoparquet":
169+
suffix = ".parquet"
147170
else:
148171
suffix = f".{type}"
149172
else:
@@ -175,6 +198,11 @@ def save_data(obj, fname, type=None, apply_suffix: bool = True) -> "str | Sequen
175198

176199
obj.to_parquet(final_name)
177200

201+
elif type == "geoparquet":
202+
_assert_is_geopandas_df(obj)
203+
204+
obj.to_parquet(final_name)
205+
178206
elif type == "joblib":
179207
import joblib
180208

pins/tests/test_drivers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ def test_driver_roundtrip(tmp_path: Path, type_):
7676
assert df.equals(obj)
7777

7878

79+
def test_driver_geoparquet_roundtrip(tmp_dir2):
80+
import geopandas as gpd
81+
82+
gdf = gpd.GeoDataFrame(
83+
{"x": [1, 2, 3], "geometry": gpd.points_from_xy([1, 2, 3], [1, 2, 3])}
84+
)
85+
86+
fname = "some_gdf"
87+
full_file = f"{fname}.parquet"
88+
89+
p_obj = tmp_dir2 / fname
90+
res_fname = save_data(gdf, p_obj, "geoparquet")
91+
92+
assert Path(res_fname).name == full_file
93+
94+
meta = MetaRaw(full_file, "geoparquet", "my_pin")
95+
obj = load_data(meta, fsspec.filesystem("file"), tmp_dir2, allow_pickle_read=True)
96+
97+
assert gdf.equals(obj)
98+
99+
79100
@pytest.mark.parametrize(
80101
"type_",
81102
[

0 commit comments

Comments
 (0)