1010
1111
1212def is_data_dict (obj ):
13- """ Tells if obj is a :class:`DataDict`. """
13+ """ True if ` obj` seems to implement the :class:`DataDict` API """
1414 return hasattr (obj , 'store' )
1515
1616
1717def is_lazy_dict (obj ):
18- """ Tells if obj is a :class:`LazyDict`. """
18+ """ True if ` obj` seems to implement the :class:`LazyDict` API """
1919 return is_data_dict (obj ) and callable (obj .store .values ()[0 ])
2020
2121
22- class DataDict (collections .MutableMapping ):
23- """ Dictionary that makes sure data are 2D array .
22+ class SliceableDataDict (collections .MutableMapping ):
23+ """ Dictionary for which key access can do slicing on the values .
2424
25- This container behaves like a standard dictionary but it makes sure its
26- elements are ndarrays. In addition, it makes sure the amount of data
27- contained in those ndarrays matches the number of streamlines of the
28- :class:`Tractogram` object provided at the instantiation of this
29- dictionary.
25+ This container behaves like a standard dictionary but extends key access to
26+ allow keys for key access to be indices slicing into the contained ndarray
27+ values.
3028 """
31- def __init__ (self , tractogram , * args , ** kwargs ):
32- self .tractogram = tractogram
29+ def __init__ (self , * args , ** kwargs ):
3330 self .store = dict ()
34-
3531 # Use update to set the keys.
36- if len (args ) = = 1 :
37- if isinstance ( args [ 0 ], DataDict ):
38- self . update ( ** args [ 0 ])
39- elif args [0 ] is None :
40- return
41- else :
42- self .update (dict ( * args , ** kwargs ) )
32+ if len (args ) ! = 1 :
33+ self . update ( dict ( * args , ** kwargs ))
34+ return
35+ if args [0 ] is None :
36+ return
37+ if isinstance ( args [ 0 ], SliceableDataDict ) :
38+ self .update (** args [ 0 ] )
4339 else :
4440 self .update (dict (* args , ** kwargs ))
4541
4642 def __getitem__ (self , key ):
4743 try :
4844 return self .store [key ]
49- except KeyError :
50- pass # Maybe it is an integer.
51- except TypeError :
52- pass # Maybe it is an object for advanced indexing.
53-
54- # Try to interpret key as an index/slice in which case we
55- # perform (advanced) indexing on every element of the dictionnary.
45+ except (KeyError , TypeError ):
46+ pass # Maybe it is an integer or a slicing object
47+
48+ # Try to interpret key as an index/slice for every data element, in
49+ # which case we perform (maybe advanced) indexing on every element of
50+ # the dictionnary.
51+ idx = key
52+ new_dict = type (self )(None )
5653 try :
57- idx = key
58- new_dict = type (self )(None )
5954 for k , v in self .items ():
6055 new_dict [k ] = v [idx ]
61-
62- return new_dict
6356 except TypeError :
6457 pass
58+ else :
59+ return new_dict
6560
66- # That means key was not an index/slice after all.
61+ # Key was not a valid index/slice after all.
6762 return self .store [key ] # Will raise the proper error.
6863
6964 def __delitem__ (self , key ):
@@ -76,15 +71,21 @@ def __len__(self):
7671 return len (self .store )
7772
7873
79- class DataPerStreamlineDict (DataDict ):
80- """ Dictionary that makes sure data are 2D array.
74+ class PerArrayDict (SliceableDataDict ):
75+ """ Dictionary for which key access can do slicing on the values.
76+
77+ This container behaves like a standard dictionary but extends key access to
78+ allow keys for key access to be indices slicing into the contained ndarray
79+ values. The elements must also be ndarrays.
8180
82- This container behaves like a standard dictionary but it makes sure its
83- elements are ndarrays. In addition, it makes sure the amount of data
84- contained in those ndarrays matches the number of streamlines of the
85- :class:`Tractogram` object provided at the instantiation of this
81+ In addition, it makes sure the amount of data contained in those ndarrays
82+ matches the number of streamlines given at the instantiation of this
8683 dictionary.
8784 """
85+ def __init__ (self , n_elements , * args , ** kwargs ):
86+ self .n_elements = n_elements
87+ super (PerArrayDict , self ).__init__ (* args , ** kwargs )
88+
8889 def __setitem__ (self , key , value ):
8990 value = np .asarray (list (value ))
9091
@@ -96,66 +97,36 @@ def __setitem__(self, key, value):
9697 raise ValueError ("data_per_streamline must be a 2D array." )
9798
9899 # We make sure there is the right amount of values
99- # (i.e. same as the number of streamlines in the tractogram).
100- if self .tractogram is not None and len (value ) != len (self .tractogram ):
101- msg = ("The number of values ({0}) should match the number of"
102- " streamlines ({1})." )
103- raise ValueError (msg .format (len (value ), len (self .tractogram )))
104-
105- self .store [key ] = value
106-
107-
108- class DataPerPointDict (DataDict ):
109- """ Dictionary making sure data are :class:`ArraySequence` objects.
110-
111- This container behaves like a standard dictionary but it makes sure its
112- elements are :class:`ArraySequence` objects. In addition, it makes sure
113- the amount of data contained in those :class:`ArraySequence` objects
114- matches the the number of points of the :class:`Tractogram` object
115- provided at the instantiation of this dictionary.
116- """
117-
118- def __setitem__ (self , key , value ):
119- value = ArraySequence (value )
120-
121- # We make sure we have the right amount of values (i.e. same as
122- # the total number of points of all streamlines in the tractogram).
123- if (self .tractogram is not None and
124- len (value ._data ) != len (self .tractogram .streamlines ._data )):
125- msg = ("The number of values ({0}) should match the total"
126- " number of points of all streamlines ({1})." )
127- nb_streamlines_points = self .tractogram .streamlines ._data
128- raise ValueError (msg .format (len (value ._data ),
129- len (nb_streamlines_points )))
100+ if self .n_elements is not None and len (value ) != self .n_elements :
101+ msg = ("The number of values ({0}) should match n_elements "
102+ "({1})." ).format (len (value ), self .n_elements )
103+ raise ValueError (msg )
130104
131105 self .store [key ] = value
132106
133107
134- class LazyDict (DataDict ):
108+ class LazyDict (SliceableDataDict ):
135109 """ Dictionary of generator functions.
136110
137111 This container behaves like an dictionary but it makes sure its elements
138112 are callable objects and assumed to be generator function yielding values.
139113 When getting the element associated to a given key, the element (i.e. a
140114 generator function) is first called before being returned.
141115 """
142- def __init__ (self , tractogram , * args , ** kwargs ):
116+ def __init__ (self , * args , ** kwargs ):
143117 if len (args ) == 1 and isinstance (args [0 ], LazyDict ):
144118 # Copy the generator functions.
145- self .tractogram = tractogram
146119 self .store = dict ()
147120 self .update (** args [0 ].store )
148121 return
149-
150- super (LazyDict , self ).__init__ (tractogram , * args , ** kwargs )
122+ super (LazyDict , self ).__init__ (* args , ** kwargs )
151123
152124 def __getitem__ (self , key ):
153125 return self .store [key ]()
154126
155127 def __setitem__ (self , key , value ):
156128 if value is not None and not callable (value ):
157129 raise TypeError ("`value` must be a generator function or None." )
158-
159130 self .store [key ] = value
160131
161132
0 commit comments