1414import operator
1515import sys
1616import time
17- from collections import defaultdict
17+ from collections import Counter , defaultdict
1818from contextlib import contextmanager
19- from typing import TYPE_CHECKING , Any , Union
19+ from typing import TYPE_CHECKING , Any
2020
2121import numpy as np
2222
@@ -204,8 +204,8 @@ def reset(self):
204204 self .fct_call_time = 0.0
205205 self .fct_callcount = 0
206206 self .vm_call_time = 0.0
207- self .apply_time = {}
208- self .apply_callcount = {}
207+ self .apply_time = defaultdict ( float )
208+ self .apply_callcount = Counter ()
209209 # self.apply_cimpl = None
210210 # self.message = None
211211
@@ -234,9 +234,9 @@ def reset(self):
234234 # Total time spent in Function.vm.__call__
235235 #
236236
237- apply_time : dict [Union ["FunctionGraph" , Variable ], float ] | None = None
237+ apply_time : dict [tuple ["FunctionGraph" , Apply ], float ]
238238
239- apply_callcount : dict [Union ["FunctionGraph" , Variable ], int ] | None = None
239+ apply_callcount : dict [tuple ["FunctionGraph" , Apply ], int ]
240240
241241 apply_cimpl : dict [Apply , bool ] | None = None
242242 # dict from node -> bool (1 if c, 0 if py)
@@ -292,10 +292,9 @@ def reset(self):
292292 # param is called flag_time_thunks because most other attributes with time
293293 # in the name are times *of* something, rather than configuration flags.
294294 def __init__ (self , atexit_print = True , flag_time_thunks = None , ** kwargs ):
295- self .apply_callcount = {}
295+ self .apply_callcount = Counter ()
296296 self .output_size = {}
297- # Keys are `(FunctionGraph, Variable)`
298- self .apply_time = {}
297+ self .apply_time = defaultdict (float )
299298 self .apply_cimpl = {}
300299 self .variable_shape = {}
301300 self .variable_strides = {}
@@ -320,37 +319,29 @@ def class_time(self):
320319
321320 """
322321 # timing is stored by node, we compute timing by class on demand
323- rval = {}
324- for (fgraph , node ), t in self .apply_time .items ():
325- typ = type (node .op )
326- rval .setdefault (typ , 0 )
327- rval [typ ] += t
328- return rval
322+ rval = defaultdict (float )
323+ for (_fgraph , node ), t in self .apply_time .items ():
324+ rval [type (node .op )] += t
325+ return dict (rval )
329326
330327 def class_callcount (self ):
331328 """
332329 dict op -> total number of thunk calls
333330
334331 """
335332 # timing is stored by node, we compute timing by class on demand
336- rval = {}
337- for (fgraph , node ), count in self .apply_callcount .items ():
338- typ = type (node .op )
339- rval .setdefault (typ , 0 )
340- rval [typ ] += count
333+ rval = Counter ()
334+ for (_fgraph , node ), count in self .apply_callcount .items ():
335+ rval [type (node .op )] += count
341336 return rval
342337
343- def class_nodes (self ):
338+ def class_nodes (self ) -> Counter :
344339 """
345340 dict op -> total number of nodes
346341
347342 """
348343 # timing is stored by node, we compute timing by class on demand
349- rval = {}
350- for (fgraph , node ), count in self .apply_callcount .items ():
351- typ = type (node .op )
352- rval .setdefault (typ , 0 )
353- rval [typ ] += 1
344+ rval = Counter (type (node .op ) for _fgraph , node in self .apply_callcount )
354345 return rval
355346
356347 def class_impl (self ):
@@ -360,12 +351,9 @@ def class_impl(self):
360351 """
361352 # timing is stored by node, we compute timing by class on demand
362353 rval = {}
363- for fgraph , node in self .apply_callcount :
354+ for _fgraph , node in self .apply_callcount :
364355 typ = type (node .op )
365- if self .apply_cimpl [node ]:
366- impl = "C "
367- else :
368- impl = "Py"
356+ impl = "C " if self .apply_cimpl [node ] else "Py"
369357 rval .setdefault (typ , impl )
370358 if rval [typ ] != impl and len (rval [typ ]) == 2 :
371359 rval [typ ] += impl
@@ -377,11 +365,10 @@ def op_time(self):
377365
378366 """
379367 # timing is stored by node, we compute timing by Op on demand
380- rval = {}
368+ rval = defaultdict ( float )
381369 for (fgraph , node ), t in self .apply_time .items ():
382- rval .setdefault (node .op , 0 )
383370 rval [node .op ] += t
384- return rval
371+ return dict ( rval )
385372
386373 def fill_node_total_time (self , fgraph , node , total_times ):
387374 """
@@ -414,9 +401,8 @@ def op_callcount(self):
414401
415402 """
416403 # timing is stored by node, we compute timing by Op on demand
417- rval = {}
404+ rval = Counter ()
418405 for (fgraph , node ), count in self .apply_callcount .items ():
419- rval .setdefault (node .op , 0 )
420406 rval [node .op ] += count
421407 return rval
422408
@@ -426,10 +412,7 @@ def op_nodes(self):
426412
427413 """
428414 # timing is stored by node, we compute timing by Op on demand
429- rval = {}
430- for (fgraph , node ), count in self .apply_callcount .items ():
431- rval .setdefault (node .op , 0 )
432- rval [node .op ] += 1
415+ rval = Counter (node .op for _fgraph , node in self .apply_callcount )
433416 return rval
434417
435418 def op_impl (self ):
0 commit comments