4646import dataclasses
4747import logging
4848import time
49- from collections import OrderedDict
5049from collections .abc import Callable , Iterable
5150from copy import copy
5251from itertools import chain , product
@@ -2188,7 +2187,7 @@ def infer_shape(self, fgraph, node, input_shapes):
21882187 # corresponding outer inputs that the Scan would use as input for
21892188 # any given iteration. For simplicity, we use iteration 0.
21902189 inner_ins_shapes = []
2191- out_equivalent = OrderedDict ()
2190+ out_equivalent = {}
21922191
21932192 # The two following blocks are commented as it cause in some
21942193 # cases extra scans in the graph. See gh-XXX for the
@@ -2469,7 +2468,7 @@ def compute_all_gradients(known_grads):
24692468 if (x in diff_inputs )
24702469 and get_inp_idx (self_inputs .index (x )) in connected_inputs
24712470 ]
2472- gmp = OrderedDict ()
2471+ gmp = {}
24732472
24742473 # Required in case there is a pair of variables X and Y, with X
24752474 # used to compute Y, for both of which there is an external
@@ -2478,7 +2477,7 @@ def compute_all_gradients(known_grads):
24782477 # it will be the sum of the external gradient signal and the
24792478 # gradient obtained by propagating Y's external gradient signal
24802479 # to X.
2481- known_grads = OrderedDict ([( k .copy (), v ) for (k , v ) in known_grads .items ()])
2480+ known_grads = { k .copy (): v for (k , v ) in known_grads .items ()}
24822481
24832482 grads = grad (
24842483 cost = None ,
@@ -2548,7 +2547,7 @@ def compute_all_gradients(known_grads):
25482547 dC_dXt = safe_new (dC_douts [idx ][0 ])
25492548 dC_dXts .append (dC_dXt )
25502549
2551- known_grads = OrderedDict ()
2550+ known_grads = {}
25522551 dc_dxts_idx = 0
25532552 for i in range (len (diff_outputs )):
25542553 if i < idx_nitsot_start or i >= idx_nitsot_end :
0 commit comments