1212from matplotlib import transforms , collections
1313from matplotlib .backends .backend_agg import FigureCanvasAgg
1414
15+
1516class Exporter (object ):
1617 """Matplotlib Exporter
1718
@@ -44,15 +45,17 @@ def run(self, fig):
4445 # in the correct place.
4546 if fig .canvas is None :
4647 canvas = FigureCanvasAgg (fig )
47- fig .savefig (io .BytesIO (), format = ' png' , dpi = fig .dpi )
48+ fig .savefig (io .BytesIO (), format = " png" , dpi = fig .dpi )
4849 if self .close_mpl :
4950 import matplotlib .pyplot as plt
51+
5052 plt .close (fig )
5153 self .crawl_fig (fig )
5254
5355 @staticmethod
54- def process_transform (transform , ax = None , data = None , return_trans = False ,
55- force_trans = None ):
56+ def process_transform (
57+ transform , ax = None , data = None , return_trans = False , force_trans = None
58+ ):
5659 """Process the transform and convert data to figure or data coordinates
5760
5861 Parameters
@@ -81,8 +84,10 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
8184 Returned only if data is specified
8285 """
8386 if isinstance (transform , transforms .BlendedGenericTransform ):
84- warnings .warn ("Blended transforms not yet supported. "
85- "Zoom behavior may not work as expected." )
87+ warnings .warn (
88+ "Blended transforms not yet supported. "
89+ "Zoom behavior may not work as expected."
90+ )
8691
8792 if force_trans is not None :
8893 if data is not None :
@@ -91,10 +96,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
9196
9297 code = "display"
9398 if ax is not None :
94- for (c , trans ) in [("data" , ax .transData ),
95- ("axes" , ax .transAxes ),
96- ("figure" , ax .figure .transFigure ),
97- ("display" , transforms .IdentityTransform ())]:
99+ for (c , trans ) in [
100+ ("data" , ax .transData ),
101+ ("axes" , ax .transAxes ),
102+ ("figure" , ax .figure .transFigure ),
103+ ("display" , transforms .IdentityTransform ()),
104+ ]:
98105 if transform .contains_branch (trans ):
99106 code , transform = (c , transform - trans )
100107 break
@@ -112,24 +119,23 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
112119
113120 def crawl_fig (self , fig ):
114121 """Crawl the figure and process all axes"""
115- with self .renderer .draw_figure (fig = fig ,
116- props = utils .get_figure_properties (fig )):
122+ with self .renderer .draw_figure (fig = fig , props = utils .get_figure_properties (fig )):
117123 for ax in fig .axes :
118124 self .crawl_ax (ax )
119125
120126 def crawl_ax (self , ax ):
121127 """Crawl the axes and process all elements within"""
122- with self .renderer .draw_axes (ax = ax ,
123- props = utils .get_axes_properties (ax )):
128+ with self .renderer .draw_axes (ax = ax , props = utils .get_axes_properties (ax )):
124129 for line in ax .lines :
125130 self .draw_line (ax , line )
126131 for text in ax .texts :
127132 self .draw_text (ax , text )
128- for (text , ttp ) in zip ([ax .xaxis .label , ax .yaxis .label , ax .title ],
129- ["xlabel" , "ylabel" , "title" ]):
130- if (hasattr (text , 'get_text' ) and text .get_text ()):
131- self .draw_text (ax , text , force_trans = ax .transAxes ,
132- text_type = ttp )
133+ for (text , ttp ) in zip (
134+ [ax .xaxis .label , ax .yaxis .label , ax .title ],
135+ ["xlabel" , "ylabel" , "title" ],
136+ ):
137+ if hasattr (text , "get_text" ) and text .get_text ():
138+ self .draw_text (ax , text , force_trans = ax .transAxes , text_type = ttp )
133139 for artist in ax .artists :
134140 # TODO: process other artists
135141 if isinstance (artist , matplotlib .text .Text ):
@@ -145,107 +151,122 @@ def crawl_ax(self, ax):
145151 if legend is not None :
146152 props = utils .get_legend_properties (ax , legend )
147153 with self .renderer .draw_legend (legend = legend , props = props ):
148- if props [' visible' ]:
154+ if props [" visible" ]:
149155 self .crawl_legend (ax , legend )
150156
151157 def crawl_legend (self , ax , legend ):
152158 """
153159 Recursively look through objects in legend children
154160 """
155- legendElements = list (utils .iter_all_children (legend ._legend_box ,
156- skipContainers = True ))
161+ legendElements = list (
162+ utils .iter_all_children (legend ._legend_box , skipContainers = True )
163+ )
157164 legendElements .append (legend .legendPatch )
158165 for child in legendElements :
159166 # force a large zorder so it appears on top
160- child .set_zorder (1E6 + child .get_zorder ())
167+ child .set_zorder (1e6 + child .get_zorder ())
161168
162169 # reorder border box to make sure marks are visible
163170 if isinstance (child , matplotlib .patches .FancyBboxPatch ):
164- child .set_zorder (child .get_zorder ()- 1 )
171+ child .set_zorder (child .get_zorder () - 1 )
165172
166173 try :
167174 # What kind of object...
168175 if isinstance (child , matplotlib .patches .Patch ):
169176 self .draw_patch (ax , child , force_trans = ax .transAxes )
170177 elif isinstance (child , matplotlib .text .Text ):
171- if child .get_text () != ' None' :
178+ if child .get_text () != " None" :
172179 self .draw_text (ax , child , force_trans = ax .transAxes )
173180 elif isinstance (child , matplotlib .lines .Line2D ):
174181 self .draw_line (ax , child , force_trans = ax .transAxes )
175182 elif isinstance (child , matplotlib .collections .Collection ):
176- self .draw_collection (ax , child ,
177- force_pathtrans = ax .transAxes )
183+ self .draw_collection (ax , child , force_pathtrans = ax .transAxes )
178184 else :
179185 warnings .warn ("Legend element %s not impemented" % child )
180186 except NotImplementedError :
181187 warnings .warn ("Legend element %s not impemented" % child )
182188
183189 def draw_line (self , ax , line , force_trans = None ):
184190 """Process a matplotlib line and call renderer.draw_line"""
185- coordinates , data = self .process_transform (line . get_transform (),
186- ax , line .get_xydata (),
187- force_trans = force_trans )
191+ coordinates , data = self .process_transform (
192+ line . get_transform (), ax , line .get_xydata (), force_trans = force_trans
193+ )
188194 linestyle = utils .get_line_style (line )
189- if (linestyle ['dasharray' ] is None
190- and linestyle ['drawstyle' ] == 'default' ):
195+ if linestyle ["dasharray" ] is None and linestyle ["drawstyle" ] == "default" :
191196 linestyle = None
192197 markerstyle = utils .get_marker_style (line )
193- if (markerstyle ['marker' ] in ['None' , 'none' , None ]
194- or markerstyle ['markerpath' ][0 ].size == 0 ):
198+ if (
199+ markerstyle ["marker" ] in ["None" , "none" , None ]
200+ or markerstyle ["markerpath" ][0 ].size == 0
201+ ):
195202 markerstyle = None
196203 label = line .get_label ()
197204 if markerstyle or linestyle :
198- self .renderer .draw_marked_line (data = data , coordinates = coordinates ,
199- linestyle = linestyle ,
200- markerstyle = markerstyle ,
201- label = label ,
202- mplobj = line )
205+ self .renderer .draw_marked_line (
206+ data = data ,
207+ coordinates = coordinates ,
208+ linestyle = linestyle ,
209+ markerstyle = markerstyle ,
210+ label = label ,
211+ mplobj = line ,
212+ )
203213
204214 def draw_text (self , ax , text , force_trans = None , text_type = None ):
205215 """Process a matplotlib text object and call renderer.draw_text"""
206216 content = text .get_text ()
207217 if content :
208218 transform = text .get_transform ()
209219 position = text .get_position ()
210- coords , position = self .process_transform (transform , ax ,
211- position ,
212- force_trans = force_trans )
220+ coords , position = self .process_transform (
221+ transform , ax , position , force_trans = force_trans
222+ )
213223 style = utils .get_text_style (text )
214- self .renderer .draw_text (text = content , position = position ,
215- coordinates = coords ,
216- text_type = text_type ,
217- style = style , mplobj = text )
224+ self .renderer .draw_text (
225+ text = content ,
226+ position = position ,
227+ coordinates = coords ,
228+ text_type = text_type ,
229+ style = style ,
230+ mplobj = text ,
231+ )
218232
219233 def draw_patch (self , ax , patch , force_trans = None ):
220234 """Process a matplotlib patch object and call renderer.draw_path"""
221235 vertices , pathcodes = utils .SVG_path (patch .get_path ())
222236 transform = patch .get_transform ()
223- coordinates , vertices = self .process_transform (transform ,
224- ax , vertices ,
225- force_trans = force_trans )
237+ coordinates , vertices = self .process_transform (
238+ transform , ax , vertices , force_trans = force_trans
239+ )
226240 linestyle = utils .get_path_style (patch , fill = patch .get_fill ())
227- self .renderer .draw_path (data = vertices ,
228- coordinates = coordinates ,
229- pathcodes = pathcodes ,
230- style = linestyle ,
231- mplobj = patch )
241+ self .renderer .draw_path (
242+ data = vertices ,
243+ coordinates = coordinates ,
244+ pathcodes = pathcodes ,
245+ style = linestyle ,
246+ mplobj = patch ,
247+ )
232248
233- def draw_collection (self , ax , collection ,
234- force_pathtrans = None ,
235- force_offsettrans = None ):
249+ def draw_collection (
250+ self , ax , collection , force_pathtrans = None , force_offsettrans = None
251+ ):
236252 """Process a matplotlib collection and call renderer.draw_collection"""
237- (transform , transOffset ,
238- offsets , paths ) = collection ._prepare_points ()
253+ (transform , transOffset , offsets , paths ) = collection ._prepare_points ()
239254
240255 offset_coords , offsets = self .process_transform (
241- transOffset , ax , offsets , force_trans = force_offsettrans )
242- path_coords = self . process_transform (
243- transform , ax , force_trans = force_pathtrans )
256+ transOffset , ax , offsets , force_trans = force_offsettrans
257+ )
258+ path_coords = self . process_transform ( transform , ax , force_trans = force_pathtrans )
244259
245260 processed_paths = [utils .SVG_path (path ) for path in paths ]
246- processed_paths = [(self .process_transform (
247- transform , ax , path [0 ], force_trans = force_pathtrans )[1 ], path [1 ])
248- for path in processed_paths ]
261+ processed_paths = [
262+ (
263+ self .process_transform (
264+ transform , ax , path [0 ], force_trans = force_pathtrans
265+ )[1 ],
266+ path [1 ],
267+ )
268+ for path in processed_paths
269+ ]
249270
250271 path_transforms = collection .get_transforms ()
251272 try :
@@ -256,30 +277,34 @@ def draw_collection(self, ax, collection,
256277 # matplotlib 1.4: path transforms are already numpy arrays.
257278 pass
258279
259- styles = {'linewidth' : collection .get_linewidths (),
260- 'facecolor' : collection .get_facecolors (),
261- 'edgecolor' : collection .get_edgecolors (),
262- 'alpha' : collection ._alpha ,
263- 'zorder' : collection .get_zorder ()}
280+ styles = {
281+ "linewidth" : collection .get_linewidths (),
282+ "facecolor" : collection .get_facecolors (),
283+ "edgecolor" : collection .get_edgecolors (),
284+ "alpha" : collection ._alpha ,
285+ "zorder" : collection .get_zorder (),
286+ }
264287
265- offset_dict = {"data" : "before" ,
266- "screen" : "after" }
288+ offset_dict = {"data" : "before" , "screen" : "after" }
267289 offset_order = offset_dict [collection .get_offset_position ()]
268290
269- self .renderer .draw_path_collection (paths = processed_paths ,
270- path_coordinates = path_coords ,
271- path_transforms = path_transforms ,
272- offsets = offsets ,
273- offset_coordinates = offset_coords ,
274- offset_order = offset_order ,
275- styles = styles ,
276- mplobj = collection )
291+ self .renderer .draw_path_collection (
292+ paths = processed_paths ,
293+ path_coordinates = path_coords ,
294+ path_transforms = path_transforms ,
295+ offsets = offsets ,
296+ offset_coordinates = offset_coords ,
297+ offset_order = offset_order ,
298+ styles = styles ,
299+ mplobj = collection ,
300+ )
277301
278302 def draw_image (self , ax , image ):
279303 """Process a matplotlib image object and call renderer.draw_image"""
280- self .renderer .draw_image (imdata = utils .image_to_base64 (image ),
281- extent = image .get_extent (),
282- coordinates = "data" ,
283- style = {"alpha" : image .get_alpha (),
284- "zorder" : image .get_zorder ()},
285- mplobj = image )
304+ self .renderer .draw_image (
305+ imdata = utils .image_to_base64 (image ),
306+ extent = image .get_extent (),
307+ coordinates = "data" ,
308+ style = {"alpha" : image .get_alpha (), "zorder" : image .get_zorder ()},
309+ mplobj = image ,
310+ )
0 commit comments