File tree Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Expand file tree Collapse file tree 3 files changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -556,19 +556,19 @@ def __init__(
556556
557557 dpr = [x .tolist () for x in torch .linspace (0 , drop_path_rate , sum (depths )).split (depths )]
558558 stages = []
559- for stage_idx in range (num_stages ):
560- out_chs = embed_dims [stage_idx ]
559+ for i in range (num_stages ):
560+ out_chs = embed_dims [i ]
561561 stage = DaVitStage (
562562 in_chs ,
563563 out_chs ,
564- depth = depths [stage_idx ],
565- downsample = stage_idx > 0 ,
564+ depth = depths [i ],
565+ downsample = i > 0 ,
566566 attn_types = attn_types ,
567- num_heads = num_heads [stage_idx ],
567+ num_heads = num_heads [i ],
568568 window_size = window_size ,
569569 mlp_ratio = mlp_ratio ,
570570 qkv_bias = qkv_bias ,
571- drop_path_rates = dpr [stage_idx ],
571+ drop_path_rates = dpr [i ],
572572 norm_layer = norm_layer ,
573573 norm_layer_cl = norm_layer_cl ,
574574 ffn = ffn ,
@@ -579,7 +579,7 @@ def __init__(
579579 )
580580 in_chs = out_chs
581581 stages .append (stage )
582- self .feature_info += [dict (num_chs = out_chs , reduction = 2 , module = f'stages.{ stage_idx } ' )]
582+ self .feature_info += [dict (num_chs = out_chs , reduction = 2 ** ( i + 2 ) , module = f'stages.{ i } ' )]
583583
584584 self .stages = nn .Sequential (* stages )
585585
Original file line number Diff line number Diff line change @@ -407,7 +407,7 @@ def __init__(
407407 )
408408 prev_dim = embed_dims [i ]
409409 stages .append (stage )
410- self .feature_info += [dict (num_chs = embed_dims [i ], reduction = 2 ** (1 + i ), module = f'stages.{ i } ' )]
410+ self .feature_info += [dict (num_chs = embed_dims [i ], reduction = 2 ** (i + 2 ), module = f'stages.{ i } ' )]
411411 self .stages = nn .Sequential (* stages )
412412
413413 # Classifier head
Original file line number Diff line number Diff line change @@ -541,7 +541,7 @@ def __init__(
541541 ** kwargs ,
542542 )]
543543 prev_dim = dims [i ]
544- self .feature_info += [dict (num_chs = dims [i ], reduction = 2 , module = f'stages.{ i } ' )]
544+ self .feature_info += [dict (num_chs = dims [i ], reduction = 2 ** ( i + 2 ) , module = f'stages.{ i } ' )]
545545
546546 self .stages = nn .Sequential (* stages )
547547
You can’t perform that action at this time.
0 commit comments