@@ -233,15 +233,78 @@ def write_report(node, report_type=None, is_mapnode=False):
233233 return
234234
235235
236+ def _identify_collapses (hastraits ):
237+ """ Identify traits that will collapse when being set to themselves.
238+
239+ ``OutputMultiObject``s automatically unwrap a list of length 1 to directly
240+ reference the element of that list.
241+ If that element is itself a list of length 1, then the following will
242+ result in modified values.
243+
244+ hastraits.trait_set(**hastraits.trait_get())
245+
246+ Cloning performs this operation on a copy of the original traited object,
247+ allowing us to identify traits that will be affected.
248+ """
249+ raw = hastraits .trait_get ()
250+ cloned = hastraits .clone_traits ().trait_get ()
251+
252+ collapsed = set ()
253+ for key in cloned :
254+ orig = raw [key ]
255+ new = cloned [key ]
256+ # Allow numpy to handle the equality checks, as mixed lists and arrays
257+ # can be problematic.
258+ if isinstance (orig , list ) and len (orig ) == 1 and (
259+ not np .array_equal (orig , new ) and np .array_equal (orig [0 ], new )):
260+ collapsed .add (key )
261+
262+ return collapsed
263+
264+
265+ def _uncollapse (indexable , collapsed ):
266+ """ Wrap collapsible values in a list to prevent double-collapsing.
267+
268+ Should be used with _identify_collapses to provide the following
269+ idempotent operation:
270+
271+ collapsed = _identify_collapses(hastraits)
272+ hastraits.trait_set(**_uncollapse(hastraits.trait_get(), collapsed))
273+
274+ NOTE: Modifies object in-place, in addition to returning it.
275+ """
276+
277+ for key in indexable :
278+ if key in collapsed :
279+ indexable [key ] = [indexable [key ]]
280+ return indexable
281+
282+
283+ def _protect_collapses (hastraits ):
284+ """ A collapse-protected replacement for hastraits.trait_get()
285+
286+ May be used as follows to provide an idempotent trait_set:
287+
288+ hastraits.trait_set(**_protect_collapses(hastraits))
289+ """
290+ collapsed = _identify_collapses (hastraits )
291+ return _uncollapse (hastraits .trait_get (), collapsed )
292+
293+
236294def save_resultfile (result , cwd , name ):
237295 """Save a result pklz file to ``cwd``"""
238296 resultsfile = os .path .join (cwd , 'result_%s.pklz' % name )
239297 if result .outputs :
240298 try :
241- outputs = result .outputs .trait_get ()
299+ collapsed = _identify_collapses (result .outputs )
300+ outputs = _uncollapse (result .outputs .trait_get (), collapsed )
301+ # Double-protect tosave so that the original, uncollapsed trait
302+ # is saved in the pickle file. Thus, when the loading process
303+ # collapses, the original correct value is loaded.
304+ tosave = _uncollapse (outputs .copy (), collapsed )
242305 except AttributeError :
243- outputs = result .outputs .dictcopy () # outputs was a bunch
244- result .outputs .set (** modify_paths (outputs , relative = True , basedir = cwd ))
306+ tosave = outputs = result .outputs .dictcopy () # outputs was a bunch
307+ result .outputs .set (** modify_paths (tosave , relative = True , basedir = cwd ))
245308
246309 savepkl (resultsfile , result )
247310 logger .debug ('saved results in %s' , resultsfile )
@@ -293,7 +356,7 @@ def load_resultfile(path, name):
293356 else :
294357 if result .outputs :
295358 try :
296- outputs = result .outputs . trait_get ( )
359+ outputs = _protect_collapses ( result .outputs )
297360 except AttributeError :
298361 outputs = result .outputs .dictcopy () # outputs == Bunch
299362 try :
0 commit comments