Skip to content

Commit 233af1a

Browse files
matthew-brettMarcCote
authored andcommitted
WIP: thinking about API for data dict
A suggested partial refactor of data dict. [skip ci]
1 parent c86185b commit 233af1a

File tree

1 file changed

+45
-74
lines changed

1 file changed

+45
-74
lines changed

nibabel/streamlines/tractogram.py

Lines changed: 45 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,60 +10,55 @@
1010

1111

1212
def is_data_dict(obj):
13-
""" Tells if obj is a :class:`DataDict`. """
13+
""" True if `obj` seems to implement the :class:`DataDict` API """
1414
return hasattr(obj, 'store')
1515

1616

1717
def is_lazy_dict(obj):
18-
""" Tells if obj is a :class:`LazyDict`. """
18+
""" True if `obj` seems to implement the :class:`LazyDict` API """
1919
return is_data_dict(obj) and callable(obj.store.values()[0])
2020

2121

22-
class DataDict(collections.MutableMapping):
23-
""" Dictionary that makes sure data are 2D array.
22+
class SliceableDataDict(collections.MutableMapping):
23+
""" Dictionary for which key access can do slicing on the values.
2424
25-
This container behaves like a standard dictionary but it makes sure its
26-
elements are ndarrays. In addition, it makes sure the amount of data
27-
contained in those ndarrays matches the number of streamlines of the
28-
:class:`Tractogram` object provided at the instantiation of this
29-
dictionary.
25+
This container behaves like a standard dictionary but extends key access to
26+
allow keys for key access to be indices slicing into the contained ndarray
27+
values.
3028
"""
31-
def __init__(self, tractogram, *args, **kwargs):
32-
self.tractogram = tractogram
29+
def __init__(self, *args, **kwargs):
3330
self.store = dict()
34-
3531
# Use update to set the keys.
36-
if len(args) == 1:
37-
if isinstance(args[0], DataDict):
38-
self.update(**args[0])
39-
elif args[0] is None:
40-
return
41-
else:
42-
self.update(dict(*args, **kwargs))
32+
if len(args) != 1:
33+
self.update(dict(*args, **kwargs))
34+
return
35+
if args[0] is None:
36+
return
37+
if isinstance(args[0], SliceableDataDict):
38+
self.update(**args[0])
4339
else:
4440
self.update(dict(*args, **kwargs))
4541

4642
def __getitem__(self, key):
4743
try:
4844
return self.store[key]
49-
except KeyError:
50-
pass # Maybe it is an integer.
51-
except TypeError:
52-
pass # Maybe it is an object for advanced indexing.
53-
54-
# Try to interpret key as an index/slice in which case we
55-
# perform (advanced) indexing on every element of the dictionnary.
45+
except (KeyError, TypeError):
46+
pass # Maybe it is an integer or a slicing object
47+
48+
# Try to interpret key as an index/slice for every data element, in
49+
# which case we perform (maybe advanced) indexing on every element of
50+
# the dictionnary.
51+
idx = key
52+
new_dict = type(self)(None)
5653
try:
57-
idx = key
58-
new_dict = type(self)(None)
5954
for k, v in self.items():
6055
new_dict[k] = v[idx]
61-
62-
return new_dict
6356
except TypeError:
6457
pass
58+
else:
59+
return new_dict
6560

66-
# That means key was not an index/slice after all.
61+
# Key was not a valid index/slice after all.
6762
return self.store[key] # Will raise the proper error.
6863

6964
def __delitem__(self, key):
@@ -76,15 +71,21 @@ def __len__(self):
7671
return len(self.store)
7772

7873

79-
class DataPerStreamlineDict(DataDict):
80-
""" Dictionary that makes sure data are 2D array.
74+
class PerArrayDict(SliceableDataDict):
75+
""" Dictionary for which key access can do slicing on the values.
76+
77+
This container behaves like a standard dictionary but extends key access to
78+
allow keys for key access to be indices slicing into the contained ndarray
79+
values. The elements must also be ndarrays.
8180
82-
This container behaves like a standard dictionary but it makes sure its
83-
elements are ndarrays. In addition, it makes sure the amount of data
84-
contained in those ndarrays matches the number of streamlines of the
85-
:class:`Tractogram` object provided at the instantiation of this
81+
In addition, it makes sure the amount of data contained in those ndarrays
82+
matches the number of streamlines given at the instantiation of this
8683
dictionary.
8784
"""
85+
def __init__(self, n_elements, *args, **kwargs):
86+
self.n_elements = n_elements
87+
super(PerArrayDict, self).__init__(*args, **kwargs)
88+
8889
def __setitem__(self, key, value):
8990
value = np.asarray(list(value))
9091

@@ -96,66 +97,36 @@ def __setitem__(self, key, value):
9697
raise ValueError("data_per_streamline must be a 2D array.")
9798

9899
# We make sure there is the right amount of values
99-
# (i.e. same as the number of streamlines in the tractogram).
100-
if self.tractogram is not None and len(value) != len(self.tractogram):
101-
msg = ("The number of values ({0}) should match the number of"
102-
" streamlines ({1}).")
103-
raise ValueError(msg.format(len(value), len(self.tractogram)))
104-
105-
self.store[key] = value
106-
107-
108-
class DataPerPointDict(DataDict):
109-
""" Dictionary making sure data are :class:`ArraySequence` objects.
110-
111-
This container behaves like a standard dictionary but it makes sure its
112-
elements are :class:`ArraySequence` objects. In addition, it makes sure
113-
the amount of data contained in those :class:`ArraySequence` objects
114-
matches the the number of points of the :class:`Tractogram` object
115-
provided at the instantiation of this dictionary.
116-
"""
117-
118-
def __setitem__(self, key, value):
119-
value = ArraySequence(value)
120-
121-
# We make sure we have the right amount of values (i.e. same as
122-
# the total number of points of all streamlines in the tractogram).
123-
if (self.tractogram is not None and
124-
len(value._data) != len(self.tractogram.streamlines._data)):
125-
msg = ("The number of values ({0}) should match the total"
126-
" number of points of all streamlines ({1}).")
127-
nb_streamlines_points = self.tractogram.streamlines._data
128-
raise ValueError(msg.format(len(value._data),
129-
len(nb_streamlines_points)))
100+
if self.n_elements is not None and len(value) != self.n_elements:
101+
msg = ("The number of values ({0}) should match n_elements "
102+
"({1}).").format(len(value), self.n_elements)
103+
raise ValueError(msg)
130104

131105
self.store[key] = value
132106

133107

134-
class LazyDict(DataDict):
108+
class LazyDict(SliceableDataDict):
135109
""" Dictionary of generator functions.
136110
137111
This container behaves like an dictionary but it makes sure its elements
138112
are callable objects and assumed to be generator function yielding values.
139113
When getting the element associated to a given key, the element (i.e. a
140114
generator function) is first called before being returned.
141115
"""
142-
def __init__(self, tractogram, *args, **kwargs):
116+
def __init__(self, *args, **kwargs):
143117
if len(args) == 1 and isinstance(args[0], LazyDict):
144118
# Copy the generator functions.
145-
self.tractogram = tractogram
146119
self.store = dict()
147120
self.update(**args[0].store)
148121
return
149-
150-
super(LazyDict, self).__init__(tractogram, *args, **kwargs)
122+
super(LazyDict, self).__init__(*args, **kwargs)
151123

152124
def __getitem__(self, key):
153125
return self.store[key]()
154126

155127
def __setitem__(self, key, value):
156128
if value is not None and not callable(value):
157129
raise TypeError("`value` must be a generator function or None.")
158-
159130
self.store[key] = value
160131

161132

0 commit comments

Comments
 (0)