@@ -293,6 +293,13 @@ def overlay(self):
293293 """The current overlay """
294294 return self ._overlay
295295
296+ @overlay .setter
297+ def overlay (self , img ):
298+ if img is None :
299+ self ._remove_overlay ()
300+ else :
301+ self .set_overlay (img )
302+
296303 @property
297304 def threshold (self ):
298305 """The current data display threshold """
@@ -382,16 +389,7 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
382389
383390 # we already have a plotted overlay
384391 if self ._overlay is not None :
385- # remove all images + cross hair lines
386- for nn , im in enumerate (self ._overlay ._ims ):
387- im .remove ()
388- for line in self ._overlay ._crosshairs [nn ].values ():
389- line .remove ()
390- # remove the fourth axis, if it was created for the overlay
391- if (self ._overlay .n_volumes > 1 and len (self ._overlay ._axes ) > 3
392- and self .n_volumes == 1 ):
393- a = self ._axes .pop (- 1 )
394- a .remove ()
392+ self ._remove_overlay ()
395393
396394 axes = self ._axes
397395 o_n_volumes = int (np .prod (data .shape [3 :]))
@@ -401,6 +399,9 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
401399 # 4D underlay, 3D overlay
402400 elif o_n_volumes < self .n_volumes and o_n_volumes == 1 :
403401 axes = axes [:- 1 ]
402+ # 4D underlay, 4D overlay
403+ elif o_n_volumes > 1 and self .n_volumes > 1 :
404+ raise TypeError ('Cannot set 4D overlay on top of 4D underlay' )
404405
405406 # mask array for provided threshold
406407 self ._overlay = self .__class__ (data , affine = affine , axes = axes )
@@ -416,6 +417,21 @@ def set_overlay(self, data, affine=None, threshold=None, cmap='viridis',
416417 cross ['vert' ].set_visible (False )
417418 self ._overlay ._draw ()
418419
420+ def _remove_overlay (self ):
421+ """ Removes current overlay image + associated axes """
422+ # remove all images + cross hair lines
423+ for nn , im in enumerate (self ._overlay ._ims ):
424+ im .remove ()
425+ for line in self ._overlay ._crosshairs [nn ].values ():
426+ line .remove ()
427+ # remove the fourth axis, if it was created for the overlay
428+ if (self ._overlay .n_volumes > 1 and len (self ._overlay ._axes ) > 3
429+ and self .n_volumes == 1 ):
430+ a = self ._axes .pop (- 1 )
431+ a .remove ()
432+
433+ self ._overlay = None
434+
419435 def link_to (self , other ):
420436 """Link positional changes between two canvases
421437
0 commit comments