@@ -1114,7 +1114,9 @@ def _compute_loss_for_scale(
         # At coarser scales (especially with mixed precision), logits can explode
         # BCEWithLogitsLoss: clamp to [-20, 20] (sigmoid maps to [2e-9, 1-2e-9])
         # MSELoss with tanh: clamp to [-10, 10] (tanh maps to [-0.9999, 0.9999])
-        task_output = torch.clamp(task_output, min=-20.0, max=20.0)
+        clamp_min = getattr(self.cfg.model, 'deep_supervision_clamp_min', -20.0)
+        clamp_max = getattr(self.cfg.model, 'deep_supervision_clamp_max', 20.0)
+        task_output = torch.clamp(task_output, min=clamp_min, max=clamp_max)

         # Apply specified losses for this task
         for loss_idx in loss_indices:
@@ -1142,7 +1144,9 @@ def _compute_loss_for_scale(
         else:
             # Standard deep supervision: apply all losses to all outputs
             # Clamp outputs to prevent numerical instability at coarser scales
-            output_clamped = torch.clamp(output, min=-20.0, max=20.0)
+            clamp_min = getattr(self.cfg.model, 'deep_supervision_clamp_min', -20.0)
+            clamp_max = getattr(self.cfg.model, 'deep_supervision_clamp_max', 20.0)
+            output_clamped = torch.clamp(output, min=clamp_min, max=clamp_max)

             for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
                 loss = loss_fn(output_clamped, target)
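
Both clamp sites resolve their bounds the same way: `getattr` reads the optional `deep_supervision_clamp_min`/`deep_supervision_clamp_max` keys off the model config and falls back to the previous hard-coded ±20.0 when they are absent. A minimal sketch of that resolution, using a stand-in `SimpleNamespace` in place of `self.cfg.model` (illustrative only, not the project's actual config object):

from types import SimpleNamespace

import torch

# Stand-in config: only the min bound is overridden here
model_cfg = SimpleNamespace(deep_supervision_clamp_min=-15.0)

clamp_min = getattr(model_cfg, 'deep_supervision_clamp_min', -20.0)  # -15.0 (configured)
clamp_max = getattr(model_cfg, 'deep_supervision_clamp_max', 20.0)   # 20.0 (default)

logits = torch.tensor([-50.0, 0.0, 50.0])
print(torch.clamp(logits, min=clamp_min, max=clamp_max))  # tensor([-15.,   0.,  20.])

Because the defaults match the old hard-coded values, existing configs that never set these keys keep the previous behavior unchanged.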
@@ -1191,7 +1195,19 @@ def _compute_deep_supervision_loss(
         main_output = outputs['output']
         ds_outputs = [outputs[f'ds_{i}'] for i in range(1, 5) if f'ds_{i}' in outputs]

-        ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+        # Use configured weights or default exponential decay
+        if hasattr(self.cfg.model, 'deep_supervision_weights') and self.cfg.model.deep_supervision_weights is not None:
+            ds_weights = self.cfg.model.deep_supervision_weights
+            # Ensure we have enough weights for all outputs
+            if len(ds_weights) < len(ds_outputs) + 1:
+                warnings.warn(
+                    f"deep_supervision_weights has {len(ds_weights)} weights but "
+                    f"{len(ds_outputs) + 1} outputs; falling back to exponential decay."
+                )
+                ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+        else:
+            ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
+
         all_outputs = [main_output] + ds_outputs

         total_loss = 0.0
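
The weight-selection logic is self-contained enough to isolate. A sketch of the equivalent standalone behavior (`resolve_ds_weights` is a hypothetical helper for illustration; the diff keeps this logic inline):

import warnings

def resolve_ds_weights(configured, num_ds_outputs):
    # One weight per output: the main head plus num_ds_outputs auxiliary heads.
    # Default is exponential decay: 1.0, 0.5, 0.25, ...
    decay = [1.0] + [0.5 ** i for i in range(1, num_ds_outputs + 1)]
    if configured is None:
        return decay
    if len(configured) < num_ds_outputs + 1:
        warnings.warn("Too few deep_supervision_weights; falling back to exponential decay.")
        return decay
    return list(configured)

print(resolve_ds_weights(None, 3))                   # [1.0, 0.5, 0.25, 0.125]
print(resolve_ds_weights([1.0, 0.6, 0.3, 0.1], 3))   # configured weights used as-is
print(resolve_ds_weights([1.0, 0.4], 3))             # warns, falls back to decay

Note that, as in the diff, only the too-few case is guarded: a configured list longer than needed is passed through untouched, with the extra weights simply going unused.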