@@ -133,6 +133,23 @@ def execute_node(self, context, graph):
133133 pass
134134 else :
135135 raise Exception ("Unknown data_layout and input ndim" " combination for MultiThreshold." )
136+
137+ # Remember whether the shape has been modified to handle 1d or 3d data
138+ # layouts
139+ orig_shape = None
140+ # If the input tensor has dimensions not covered by the NC or NCWH data
141+ # layouts, the shape needs to be adapted such that it can be handled by
142+ # multithreshold.
143+ # TODO: Seems like a rather sketchy solution to support arbitrary data
144+ # layouts. This does not even validate the assumption of channel last
145+ # layout.
146+ if v .ndim not in {2 , 4 }:
147+ # Remember the original shape to be restored later
148+ orig_shape = v .shape
149+ # Assume last dimension to be the channel dimension C and reshape
150+ # into NC layout which is supported by multithreshold
151+ v = v .reshape ((- 1 , v .shape [- 1 ]))
152+
136153 # calculate output
137154 output = multithreshold (v , thresholds , out_scale , out_bias )
138155 # setting context according to output
@@ -145,6 +162,13 @@ def execute_node(self, context, graph):
145162 pass
146163 else :
147164 raise Exception ("Unknown data_layout and output ndim" " combination for MultiThreshold." )
165+
166+ # If the shape has been modified to support arbitrary layouts, restore
167+ # the original shape
168+ # TODO: Part of the rather sketchy solution above.
169+ if orig_shape is not None :
170+ output = output .reshape (orig_shape )
171+
148172 context [node .output [0 ]] = output
149173
150174 def verify_node (self ):
0 commit comments