Skip to content

Commit 9b08496

Browse files
committed
Refactored trk.py
1 parent 40ec4e5 commit 9b08496

File tree

2 files changed

+61
-50
lines changed

2 files changed

+61
-50
lines changed

nibabel/streamlines/tractogram.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,13 @@ class Tractogram(object):
223223
Sequence of $T$ streamlines. Each streamline is an ndarray of
224224
shape ($N_t$, 3) where $N_t$ is the number of points of
225225
streamline $t$.
226-
data_per_streamline : dict of 2D arrays
226+
data_per_streamline : :class:`PerArrayDict` object
227227
Dictionary where the items are (str, 2D array).
228228
Each key represents an information $i$ to be kept along side every
229229
streamline, and its associated value is a 2D array of shape
230230
($T$, $P_i$) where $T$ is the number of streamlines and $P_i$ is
231231
the number scalar values to store for that particular information $i$.
232-
data_per_point : dict of :class:`ArraySequence` objects
232+
data_per_point : :class:`PerArraySequenceDict` object
233233
Dictionary where the items are (str, :class:`ArraySequence`).
234234
Each key represents an information $i$ to be kept along side every
235235
point of every streamline, and its associated value is an iterable

nibabel/streamlines/trk.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,47 @@ def get_affine_rasmm_to_trackvis(header):
128128
return np.linalg.inv(get_affine_trackvis_to_rasmm(header))
129129

130130

131+
def encode_value_in_name(value, name, max_name_len=20):
132+
""" Encodes a value in the last two bytes of a string.
133+
134+
If `value` is one, then there is no encoding and the last two bytes
135+
are left untouched. This function also verify that the length of name is
136+
less than `max_name_len`.
137+
138+
Parameters
139+
----------
140+
value : int
141+
Integer value to encode.
142+
name : str
143+
Name in which the last two bytes will serve to encode `value`.
144+
max_name_len : int, optional
145+
Maximum length name can have.
146+
147+
Returns
148+
-------
149+
encoded_name : str
150+
Name containing the encoded value.
151+
"""
152+
153+
if len(name) > max_name_len:
154+
msg = ("Data information named '{0}' is too long"
155+
" (max {1} characters.)").format(name, max_name_len)
156+
raise ValueError(msg)
157+
elif len(name) > max_name_len-2 and value > 1:
158+
msg = ("Data information named '{0}' is too long (need to be less"
159+
" than {1} characters when storing more than one value"
160+
" for a given data information."
161+
).format(name, max_name_len-2)
162+
raise ValueError(msg)
163+
164+
if value > 1:
165+
# Use the last two bytes of `name` to store `value`.
166+
name = (asbytes(name[:18].ljust(18, '\x00')) + b'\x00' +
167+
np.array(value, dtype=np.int8).tostring())
168+
169+
return name
170+
171+
131172
class TrkReader(object):
132173
""" Convenience class to encapsulate TRK file format.
133174
@@ -326,8 +367,7 @@ def write(self, tractogram):
326367
self.file.write(self.header.tostring())
327368
return
328369

329-
# Update the 'property_name' field using 'data_per_streamline' of the
330-
# tractogram.
370+
# Update field 'property_name' using 'tractogram.data_per_streamline'.
331371
data_for_streamline = first_item.data_for_streamline
332372
if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
333373
msg = ("Can only store {0} named data_per_streamline (also known"
@@ -336,59 +376,30 @@ def write(self, tractogram):
336376
raise ValueError(msg)
337377

338378
data_for_streamline_keys = sorted(data_for_streamline.keys())
339-
self.header['property_name'] = np.zeros(
340-
MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
341-
dtype='S20')
342-
for i, k in enumerate(data_for_streamline_keys):
343-
nb_values = data_for_streamline[k].shape[0]
344-
345-
if len(k) > 20:
346-
raise ValueError(("Property name '{0}' is too long (max 20"
347-
"characters.)").format(k))
348-
elif len(k) > 18 and nb_values > 1:
349-
raise ValueError(("Property name '{0}' is too long (need to be"
350-
" less than 18 characters when storing more"
351-
" than one value").format(k))
352-
353-
property_name = k
354-
if nb_values > 1:
355-
# Use the last to bytes of the name to store the nb of values
356-
# associated to this data_for_streamline.
357-
property_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' +
358-
np.array(nb_values, dtype=np.int8).tostring())
359-
360-
self.header['property_name'][i] = property_name
361-
362-
# Update the 'scalar_name' field using 'data_per_point' of the
363-
# tractogram.
379+
property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
380+
dtype='S20')
381+
for i, name in enumerate(data_for_streamline_keys):
382+
# Use the last to bytes of the name to store the number of values
383+
# associated to this data_for_streamline.
384+
nb_values = data_for_streamline[name].shape[-1]
385+
property_name[i] = encode_value_in_name(nb_values, name)
386+
self.header['property_name'][:] = property_name
387+
388+
# Update field 'scalar_name' using 'tractogram.data_per_point'.
364389
data_for_points = first_item.data_for_points
365390
if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT:
366391
raise ValueError(("Can only store {0} named data_per_point (also"
367392
" known as 'scalars' in the TRK format)."
368393
).format(MAX_NB_NAMED_SCALARS_PER_POINT))
369394

370395
data_for_points_keys = sorted(data_for_points.keys())
371-
self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT,
372-
dtype='S20')
373-
for i, k in enumerate(data_for_points_keys):
374-
nb_values = data_for_points[k].shape[1]
375-
376-
if len(k) > 20:
377-
raise ValueError(("Scalar name '{0}' is too long (max 18"
378-
" characters.)").format(k))
379-
elif len(k) > 18 and nb_values > 1:
380-
raise ValueError(("Scalar name '{0}' is too long (need to be"
381-
" less than 18 characters when storing more"
382-
" than one value").format(k))
383-
384-
scalar_name = k
385-
if nb_values > 1:
386-
# Use the last to bytes of the name to store the nb of values
387-
# associated to this data_for_streamline.
388-
scalar_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' +
389-
np.array(nb_values, dtype=np.int8).tostring())
390-
391-
self.header['scalar_name'][i] = scalar_name
396+
scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
397+
for i, name in enumerate(data_for_points_keys):
398+
# Use the last two bytes of the name to store the number of values
399+
# associated to this data_for_streamline.
400+
nb_values = data_for_points[name].shape[-1]
401+
scalar_name[i] = encode_value_in_name(nb_values, name)
402+
self.header['scalar_name'][:] = scalar_name
392403

393404
# Make sure streamlines are in rasmm then send them to voxmm.
394405
tractogram = tractogram.to_world(lazy=True)

0 commit comments

Comments
 (0)