@@ -867,6 +867,23 @@ def _to_xml_element(self):
867867
868868
869869class Cifti2Matrix (xml .XmlSerializable , collections .MutableSequence ):
870+ """ CIFTI2 Matrix object
871+
872+ This is a list-like container where the elements are instances of
873+ :class:`Cifti2MatrixIndicesMap`.
874+
875+ * Description: contains child elements that describe the meaning of the
876+ values in the matrix.
877+ * Attributes: [NA]
878+ * Child Elements
879+ * MetaData (0 .. 1)
880+ * MatrixIndicesMap (1 .. N)
881+ * Text Content: [NA]
882+ * Parent Element: CIFTI
883+
884+ For each matrix (data) dimension, exactly one MatrixIndicesMap element must
885+ list it in the AppliesToMatrixDimension attribute.
886+ """
870887 def __init__ (self ):
871888 self ._mims = []
872889 self .metadata = None
@@ -909,11 +926,9 @@ def insert(self, index, value):
909926 self ._mims .insert (index , value )
910927
911928 def _to_xml_element (self ):
912- if (len (self ) == 0 and self .metadata is None ):
913- raise CIFTI2HeaderError (
914- 'Matrix element requires either a MatrixIndicesMap or a Metadata element'
915- )
916-
929+ # From the spec: "For each matrix dimension, exactly one
930+ # MatrixIndicesMap element must list it in the AppliesToMatrixDimension
931+ # attribute."
917932 mat = xml .Element ('Matrix' )
918933 if self .metadata :
919934 mat .append (self .metadata ._to_xml_element ())
@@ -970,9 +985,9 @@ def __init__(self,
970985 Parameters
971986 ----------
972987 dataobj : object
973- Object containing image data. It should be some object that returns
974- an array from ``np.asanyarray``. It should have a ``shape``
975- attribute or property.
988+ Object containing image data. It should be some object that
989+ returns an array from ``np.asanyarray``. It should have a
990+ ``shape`` attribute or property.
976991 header : Cifti2Header instance
977992 Header with data for / from XML part of CIFTI2 format.
978993 nifti_header : None or mapping or NIfTI2 header instance, optional
@@ -985,6 +1000,11 @@ def __init__(self,
9851000 super (Cifti2Image , self ).__init__ (dataobj , header = header ,
9861001 extra = extra , file_map = file_map )
9871002 self ._nifti_header = Nifti2Header .from_header (nifti_header )
1003+ # if NIfTI header not specified, get data type from input array
1004+ if nifti_header is None :
1005+ if hasattr (dataobj , 'dtype' ):
1006+ self ._nifti_header .set_data_dtype (dataobj .dtype )
1007+ self .update_headers ()
9881008
9891009 @property
9901010 def nifti_header (self ):
@@ -1055,6 +1075,7 @@ def to_file_map(self, file_map=None):
10551075 None
10561076 """
10571077 from .parse_cifti2 import Cifti2Extension
1078+ self .update_headers ()
10581079 header = self ._nifti_header
10591080 extension = Cifti2Extension (content = self .header .to_xml ())
10601081 header .extensions .append (extension )
@@ -1066,6 +1087,26 @@ def to_file_map(self, file_map=None):
10661087 img = Nifti2Image (data , None , header )
10671088 img .to_file_map (file_map or self .file_map )
10681089
1090+ def update_headers (self ):
1091+ ''' Harmonize CIFTI2 and NIfTI headers with image data
1092+
1093+ >>> import numpy as np
1094+ >>> data = np.zeros((2,3,4))
1095+ >>> img = Cifti2Image(data)
1096+ >>> img.shape == (2, 3, 4)
1097+ True
1098+ >>> img.update_headers()
1099+ >>> img.nifti_header.get_data_shape() == (2, 3, 4)
1100+ True
1101+ '''
1102+ self ._nifti_header .set_data_shape (self ._dataobj .shape )
1103+
1104+ def get_data_dtype (self ):
1105+ return self ._nifti_header .get_data_dtype ()
1106+
1107+ def set_data_dtype (self , dtype ):
1108+ self ._nifti_header .set_data_dtype (dtype )
1109+
10691110
10701111def load (filename ):
10711112 """ Load cifti2 from `filename`
0 commit comments