44import io
55import itertools
66import textwrap
7- from collections import ChainMap
7+ from collections import ChainMap , defaultdict
88from collections .abc import (
99 Callable ,
1010 Hashable ,
1111 Iterable ,
1212 Iterator ,
1313 Mapping ,
1414)
15+ from dataclasses import dataclass , field
1516from html import escape
1617from os import PathLike
1718from typing import (
2122 Literal ,
2223 NoReturn ,
2324 ParamSpec ,
25+ TypeAlias ,
2426 TypeVar ,
2527 Union ,
2628 overload ,
8587 DtCompatible ,
8688 ErrorOptions ,
8789 ErrorOptionsWithWarn ,
90+ NestedDict ,
8891 NetcdfWriteModes ,
8992 T_ChunkDimFreq ,
9093 T_ChunksFreq ,
@@ -441,6 +444,20 @@ def map( # type: ignore[override]
441444 return Dataset (variables , attrs = attrs )
442445
443446
447+ FromDictDataValue : TypeAlias = "CoercibleValue | Dataset | DataTree | None"
448+
449+
450+ @dataclass
451+ class _CoordWrapper :
452+ value : CoercibleValue
453+
454+
455+ @dataclass
456+ class _DatasetArgs :
457+ data_vars : dict [str , CoercibleValue ] = field (default_factory = dict )
458+ coords : dict [str , CoercibleValue ] = field (default_factory = dict )
459+
460+
444461class DataTree (
445462 NamedNode ,
446463 DataTreeAggregations ,
@@ -1154,51 +1171,215 @@ def drop_nodes(
11541171 result ._replace_node (children = children_to_keep )
11551172 return result
11561173
1174+ @overload
1175+ @classmethod
1176+ def from_dict (
1177+ cls ,
1178+ data : Mapping [str , FromDictDataValue ] | None = ...,
1179+ coords : Mapping [str , CoercibleValue ] | None = ...,
1180+ * ,
1181+ name : str | None = ...,
1182+ nested : Literal [False ] = ...,
1183+ ) -> Self : ...
1184+
1185+ @overload
1186+ @classmethod
1187+ def from_dict (
1188+ cls ,
1189+ data : (
1190+ Mapping [str , FromDictDataValue | NestedDict [FromDictDataValue ]] | None
1191+ ) = ...,
1192+ coords : Mapping [str , CoercibleValue | NestedDict [CoercibleValue ]] | None = ...,
1193+ * ,
1194+ name : str | None = ...,
1195+ nested : Literal [True ] = ...,
1196+ ) -> Self : ...
1197+
11571198 @classmethod
11581199 def from_dict (
11591200 cls ,
1160- d : Mapping [str , Dataset | DataTree | None ],
1161- / ,
1201+ data : (
1202+ Mapping [str , FromDictDataValue | NestedDict [FromDictDataValue ]] | None
1203+ ) = None ,
1204+ coords : Mapping [str , CoercibleValue | NestedDict [CoercibleValue ]] | None = None ,
1205+ * ,
11621206 name : str | None = None ,
1207+ nested : bool = False ,
11631208 ) -> Self :
11641209 """
11651210 Create a datatree from a dictionary of data objects, organised by paths into the tree.
11661211
11671212 Parameters
11681213 ----------
1169- d : dict-like
1170- A mapping from path names to xarray.Dataset or DataTree objects.
1214+ data : dict-like, optional
1215+ A mapping from path names to ``None`` (indicating an empty node),
1216+ ``DataTree``, ``Dataset``, objects coercible into a ``DataArray`` or
1217+ a nested dictionary of any of the above types.
11711218
1172- Path names are to be given as unix-like path. If path names
1173- containing more than one part are given, new tree nodes will be
1174- constructed as necessary.
1219+ Path names should be given as unix-like paths, either absolute
1220+ (/path/to/item) or relative to the root node (path/to/item). If path
1221+ names containing more than one part are given, new tree nodes will
1222+ be constructed automatically as necessary.
11751223
11761224 To assign data to the root node of the tree use "", ".", "/" or "./"
11771225 as the path.
1226+ coords : dict-like, optional
1227+ A mapping from path names to objects coercible into a DataArray, or
1228+ nested dictionaries of coercible objects.
11781229 name : Hashable | None, optional
11791230 Name for the root node of the tree. Default is None.
1231+ nested : bool, optional
1232+ If true, nested dictionaries in ``data`` and ``coords`` are
1233+ automatically flattened.
11801234
11811235 Returns
11821236 -------
11831237 DataTree
11841238
1239+ See also
1240+ --------
1241+ Dataset
1242+
11851243 Notes
11861244 -----
1187- If your dictionary is nested you will need to flatten it before using this method.
1188- """
1189- # Find any values corresponding to the root
1190- d_cast = dict (d )
1191- root_data = None
1192- for key in ("" , "." , "/" , "./" ):
1193- if key in d_cast :
1194- if root_data is not None :
1245+ ``DataTree.from_dict`` serves a conceptually different purpose from
1246+ ``Dataset.from_dict`` and ``DataArray.from_dict``. It converts a
1247+ hierarchy of Xarray objects into a DataTree, rather than converting pure
1248+ Python data structures.
1249+
1250+ Examples
1251+ --------
1252+
1253+ Construct a tree from a dict of Dataset objects:
1254+
1255+ >>> dt = DataTree.from_dict(
1256+ ... {
1257+ ... "/": Dataset(coords={"time": [1, 2, 3]}),
1258+ ... "/ocean": Dataset(
1259+ ... {
1260+ ... "temperature": ("time", [4, 5, 6]),
1261+ ... "salinity": ("time", [7, 8, 9]),
1262+ ... }
1263+ ... ),
1264+ ... "/atmosphere": Dataset(
1265+ ... {
1266+ ... "temperature": ("time", [2, 3, 4]),
1267+ ... "humidity": ("time", [3, 4, 5]),
1268+ ... }
1269+ ... ),
1270+ ... }
1271+ ... )
1272+ >>> dt
1273+ <xarray.DataTree>
1274+ Group: /
1275+ │ Dimensions: (time: 3)
1276+ │ Coordinates:
1277+ │ * time (time) int64 24B 1 2 3
1278+ ├── Group: /ocean
1279+ │ Dimensions: (time: 3)
1280+ │ Data variables:
1281+ │ temperature (time) int64 24B 4 5 6
1282+ │ salinity (time) int64 24B 7 8 9
1283+ └── Group: /atmosphere
1284+ Dimensions: (time: 3)
1285+ Data variables:
1286+ temperature (time) int64 24B 2 3 4
1287+ humidity (time) int64 24B 3 4 5
1288+
1289+ Or equivalently, use a dict of values that can be converted into
1290+ `DataArray` objects, with syntax similar to the Dataset constructor:
1291+
1292+ >>> dt2 = DataTree.from_dict(
1293+ ... data={
1294+ ... "/ocean/temperature": ("time", [4, 5, 6]),
1295+ ... "/ocean/salinity": ("time", [7, 8, 9]),
1296+ ... "/atmosphere/temperature": ("time", [2, 3, 4]),
1297+ ... "/atmosphere/humidity": ("time", [3, 4, 5]),
1298+ ... },
1299+ ... coords={"/time": [1, 2, 3]},
1300+ ... )
1301+ >>> assert dt.identical(dt2)
1302+
1303+ Nested dictionaries are automatically flattened if ``nested=True``:
1304+
1305+ >>> DataTree.from_dict({"a": {"b": {"c": {"x": 1, "y": 2}}}}, nested=True)
1306+ <xarray.DataTree>
1307+ Group: /
1308+ └── Group: /a
1309+ └── Group: /a/b
1310+ └── Group: /a/b/c
1311+ Dimensions: ()
1312+ Data variables:
1313+ x int64 8B 1
1314+ y int64 8B 2
1315+
1316+ """
1317+ if data is None :
1318+ data = {}
1319+
1320+ if coords is None :
1321+ coords = {}
1322+
1323+ if nested :
1324+ data_items = utils .flat_items (data )
1325+ coords_items = utils .flat_items (coords )
1326+ else :
1327+ data_items = data .items ()
1328+ coords_items = coords .items ()
1329+ for arg_name , items in [("data" , data_items ), ("coords" , coords_items )]:
1330+ for key , value in items :
1331+ if isinstance (value , dict ):
1332+ raise TypeError (
1333+ f"{ arg_name } contains a dict value at { key = } , "
1334+ "which is not a valid argument to "
1335+ f"DataTree.from_dict() with nested=False: { value } "
1336+ )
1337+
1338+ # Canonicalize and unify paths between `data` and `coords`
1339+ flat_data_and_coords = itertools .chain (
1340+ data_items ,
1341+ ((k , _CoordWrapper (v )) for k , v in coords_items ),
1342+ )
1343+ nodes : dict [NodePath , _CoordWrapper | FromDictDataValue ] = {}
1344+ for key , value in flat_data_and_coords :
1345+ path = NodePath (key ).absolute ()
1346+ if path in nodes :
1347+ raise ValueError (
1348+ f"multiple entries found corresponding to node { str (path )!r} "
1349+ )
1350+ nodes [path ] = value
1351+
1352+ # Merge nodes corresponding to DataArrays into Datasets
1353+ dataset_args : defaultdict [NodePath , _DatasetArgs ] = defaultdict (_DatasetArgs )
1354+ for path in list (nodes ):
1355+ node = nodes [path ]
1356+ if node is not None and not isinstance (node , Dataset | DataTree ):
1357+ if path .parent == path :
1358+ raise ValueError ("cannot set DataArray value at root" )
1359+ if path .parent in nodes :
11951360 raise ValueError (
1196- "multiple entries found corresponding to the root node"
1361+ f"cannot set DataArray value at { str (path )!r} when "
1362+ f"parent node at { str (path .parent )!r} is also set"
11971363 )
1198- root_data = d_cast .pop (key )
1364+ del nodes [path ]
1365+ if isinstance (node , _CoordWrapper ):
1366+ dataset_args [path .parent ].coords [path .name ] = node .value
1367+ else :
1368+ dataset_args [path .parent ].data_vars [path .name ] = node
1369+ for path , args in dataset_args .items ():
1370+ try :
1371+ nodes [path ] = Dataset (args .data_vars , args .coords )
1372+ except (ValueError , TypeError ) as e :
1373+ raise type (e )(
1374+ "failed to construct xarray.Dataset for DataTree node at "
1375+ f"{ str (path )!r} with data_vars={ args .data_vars } and "
1376+ f"coords={ args .coords } "
1377+ ) from e
11991378
12001379 # Create the root node
1201- if isinstance (root_data , DataTree ):
1380+ root_data = nodes .pop (NodePath ("/" ), None )
1381+ if isinstance (root_data , cls ):
1382+ # use cls so type-checkers understand this method returns Self
12021383 obj = root_data .copy ()
12031384 obj .name = name
12041385 elif root_data is None or isinstance (root_data , Dataset ):
@@ -1209,31 +1390,29 @@ def from_dict(
12091390 f"or DataTree, got { type (root_data )} "
12101391 )
12111392
1212- def depth (item ) -> int :
1213- pathstr , _ = item
1214- return len (NodePath ( pathstr ) .parts )
1393+ def depth (item : tuple [ NodePath , object ] ) -> int :
1394+ node_path , _ = item
1395+ return len (node_path .parts )
12151396
1216- if d_cast :
1217- # Populate tree with children determined from data_objects mapping
1397+ if nodes :
1398+ # Populate tree with children
12181399 # Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
1219- for path , data in sorted (d_cast .items (), key = depth ):
1400+ for path , node in sorted (nodes .items (), key = depth ):
12201401 # Create and set new node
1221- if isinstance (data , DataTree ):
1222- new_node = data .copy ()
1223- elif isinstance (data , Dataset ) or data is None :
1224- new_node = cls (dataset = data )
1402+ if isinstance (node , DataTree ):
1403+ new_node = node .copy ()
1404+ elif isinstance (node , Dataset ) or node is None :
1405+ new_node = cls (dataset = node )
12251406 else :
1226- raise TypeError (f"invalid values: { data } " )
1407+ raise TypeError (f"invalid values: { node } " )
12271408 obj ._set_item (
12281409 path ,
12291410 new_node ,
12301411 allow_overwrite = False ,
12311412 new_nodes_along_path = True ,
12321413 )
12331414
1234- # TODO: figure out why mypy is raising an error here, likely something
1235- # to do with the return type of Dataset.copy()
1236- return obj # type: ignore[return-value]
1415+ return obj
12371416
12381417 def to_dict (self , relative : bool = False ) -> dict [str , Dataset ]:
12391418 """
0 commit comments