1010import time
1111import traceback
1212import warnings
13- from collections import UserList , defaultdict , deque
13+ from collections import Counter , UserList , defaultdict , deque
1414from collections .abc import Callable , Iterable , Sequence
1515from collections .abc import Iterable as IterableType
1616from functools import _compose_mro , partial , reduce # type: ignore
@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
11531153 r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
11541154
11551155 def __init__ (self ):
1156- self .tracked_instances : dict [Op , list [NodeRewriter ]] = {}
1157- self .tracked_types : dict [type , list [NodeRewriter ]] = {}
1156+ self .tracked_instances : dict [Op , list [NodeRewriter ]] = defaultdict ( list )
1157+ self .tracked_types : dict [type , list [NodeRewriter ]] = defaultdict ( list )
11581158 self .untracked_rewrites : list [NodeRewriter ] = []
11591159
11601160 def add_tracker (self , rw : NodeRewriter ):
@@ -1166,9 +1166,9 @@ def add_tracker(self, rw: NodeRewriter):
11661166 else :
11671167 for c in tracks :
11681168 if isinstance (c , type ):
1169- self .tracked_types . setdefault ( c , []) .append (rw )
1169+ self .tracked_types [ c ] .append (rw )
11701170 else :
1171- self .tracked_instances . setdefault ( c , []) .append (rw )
1171+ self .tracked_instances [ c ] .append (rw )
11721172
11731173 def _find_impl (self , cls ) -> list [NodeRewriter ]:
11741174 r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
@@ -1250,22 +1250,16 @@ def __init__(
12501250
12511251 self .profile = profile
12521252 if self .profile :
1253- self .time_rewrites : dict [Rewriter , float ] = {}
1254- self .process_count : dict [Rewriter , int ] = {}
1255- self .applied_true : dict [Rewriter , int ] = {}
1256- self .node_created : dict [Rewriter , int ] = {}
1253+ self .time_rewrites : dict [Rewriter , float ] = defaultdict ( float )
1254+ self .process_count : dict [Rewriter , int ] = Counter ()
1255+ self .applied_true : dict [Rewriter , int ] = Counter ()
1256+ self .node_created : dict [Rewriter , int ] = Counter ()
12571257
12581258 self .tracker = OpToRewriterTracker ()
12591259
12601260 for o in self .rewrites :
12611261 self .tracker .add_tracker (o )
12621262
1263- if self .profile :
1264- self .time_rewrites .setdefault (o , 0.0 )
1265- self .process_count .setdefault (o , 0 )
1266- self .applied_true .setdefault (o , 0 )
1267- self .node_created .setdefault (o , 0 )
1268-
12691263 def __str__ (self ):
12701264 return getattr (
12711265 self ,
@@ -2316,30 +2310,29 @@ def apply(self, fgraph, start_from=None):
23162310 changed = True
23172311 max_use_abort = False
23182312 rewriter_name = None
2319- global_process_count = {}
2313+ global_process_count = Counter ()
23202314 start_nb_nodes = len (fgraph .apply_nodes )
23212315 max_nb_nodes = len (fgraph .apply_nodes )
23222316 max_use = max_nb_nodes * self .max_use_ratio
23232317
23242318 loop_timing = []
23252319 loop_process_count = []
23262320 global_rewriter_timing = []
2327- time_rewriters = {}
2321+ time_rewriters = defaultdict ( float )
23282322 io_toposort_timing = []
23292323 nb_nodes = []
2330- node_created = {}
2324+ node_created = Counter ()
23312325 global_sub_profs = []
23322326 final_sub_profs = []
23332327 cleanup_sub_profs = []
2334- for rewriter in (
2335- self .global_rewriters
2336- + list (self .get_node_rewriters ())
2337- + self .final_rewriters
2338- + self .cleanup_rewriters
2339- ):
2340- global_process_count .setdefault (rewriter , 0 )
2341- time_rewriters .setdefault (rewriter , 0 )
2342- node_created .setdefault (rewriter , 0 )
2328+
2329+ for rewriter in [
2330+ * self .global_rewriters ,
2331+ * self .get_node_rewriters (),
2332+ * self .final_rewriters ,
2333+ * self .cleanup_rewriters ,
2334+ ]:
2335+ time_rewriters [rewriter ] += 0
23432336
23442337 def apply_cleanup (profs_dict ):
23452338 changed = False
@@ -2351,15 +2344,14 @@ def apply_cleanup(profs_dict):
23512344 time_rewriters [crewriter ] += time .perf_counter () - t_rewrite
23522345 profs_dict [crewriter ].append (sub_prof )
23532346 if change_tracker .changed :
2354- process_count .setdefault (crewriter , 0 )
23552347 process_count [crewriter ] += 1
23562348 global_process_count [crewriter ] += 1
23572349 changed = True
23582350 node_created [crewriter ] += change_tracker .nb_imported - nb
23592351 return changed
23602352
23612353 while changed and not max_use_abort :
2362- process_count = {}
2354+ process_count = Counter ()
23632355 t0 = time .perf_counter ()
23642356 changed = False
23652357 iter_cleanup_sub_profs = {}
@@ -2376,7 +2368,6 @@ def apply_cleanup(profs_dict):
23762368 time_rewriters [grewrite ] += time .perf_counter () - t_rewrite
23772369 sub_profs .append (sub_prof )
23782370 if change_tracker .changed :
2379- process_count .setdefault (grewrite , 0 )
23802371 process_count [grewrite ] += 1
23812372 global_process_count [grewrite ] += 1
23822373 changed = True
@@ -2431,7 +2422,6 @@ def chin_(node, i, r, new_r, reason):
24312422 time_rewriters [node_rewriter ] += time .perf_counter () - t_rewrite
24322423 if not node_rewriter_change :
24332424 continue
2434- process_count .setdefault (node_rewriter , 0 )
24352425 process_count [node_rewriter ] += 1
24362426 global_process_count [node_rewriter ] += 1
24372427 changed = True
@@ -2459,7 +2449,6 @@ def chin_(node, i, r, new_r, reason):
24592449 time_rewriters [grewrite ] += time .perf_counter () - t_rewrite
24602450 sub_profs .append (sub_prof )
24612451 if change_tracker .changed :
2462- process_count .setdefault (grewrite , 0 )
24632452 process_count [grewrite ] += 1
24642453 global_process_count [grewrite ] += 1
24652454 changed = True
@@ -2514,7 +2503,7 @@ def chin_(node, i, r, new_r, reason):
25142503 (start_nb_nodes , end_nb_nodes , max_nb_nodes ),
25152504 global_rewriter_timing ,
25162505 nb_nodes ,
2517- time_rewriters ,
2506+ dict ( time_rewriters ) ,
25182507 io_toposort_timing ,
25192508 node_created ,
25202509 global_sub_profs ,
@@ -2597,14 +2586,7 @@ def print_profile(cls, stream, prof, level=0):
25972586 count_rewrite = []
25982587 not_used = []
25992588 not_used_time = 0
2600- process_count = {}
2601- for o in (
2602- rewrite .global_rewriters
2603- + list (rewrite .get_node_rewriters ())
2604- + list (rewrite .final_rewriters )
2605- + list (rewrite .cleanup_rewriters )
2606- ):
2607- process_count .setdefault (o , 0 )
2589+ process_count = Counter ()
26082590 for count in loop_process_count :
26092591 for o , v in count .items ():
26102592 process_count [o ] += v
0 commit comments