 import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse

 import cv2
 import detectron2.data.transforms as T  # noqa: N812
@@ -373,6 +374,50 @@ def train(self):
             verify_results(self.cfg, self._last_eval_results)
             return self._last_eval_results

+    def resume_or_load(self, resume=True):
+        """
+        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+        a `last_checkpoint` file), resume from that file. Resuming means loading all
+        available states (e.g. optimizer and scheduler) and updating the iteration
+        counter from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+
+        Otherwise, this is considered an independent training run. The method will load
+        model weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states)
+        and start from iteration 0.
+
+        Args:
+            resume (bool): whether to attempt to resume from the last checkpoint
+        """
+        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+        if resume and self.checkpointer.has_checkpoint():
+            # The checkpoint stores the training iteration that just finished, so we
+            # start at the next iteration.
+            self.start_iter = self.iter + 1
+
+        if self.cfg.MODEL.WEIGHTS:
+            # Read the conv1 weights directly from the checkpoint file, stripping any
+            # query string from the weights URL before resolving it to a local path.
+            checkpoint = torch.tensor(
+                self.checkpointer._load_file(
+                    self.checkpointer.path_manager.get_local_path(
+                        urlparse(self.cfg.MODEL.WEIGHTS)._replace(query="").geturl()
+                    )
+                )['model']['backbone.bottom_up.stem.conv1.weight']
+            ).to(self.model.backbone.bottom_up.stem.conv1.weight.device)
+            input_channels_in_checkpoint = checkpoint.shape[1]
+            input_channels_in_model = self.model.backbone.bottom_up.stem.conv1.weight.shape[1]
+            if input_channels_in_checkpoint != input_channels_in_model:
+                logger = logging.getLogger("detectree2")
+                if input_channels_in_checkpoint != 3:
+                    logger.warning(
+                        "Input channel modification only works if the checkpoint was trained on RGB images (3 channels). The first three channels will be copied and then repeated in the model."
+                    )
+                logger.warning(
+                    "Mismatch in input channels between checkpoint and model, meaning fvcore would not have been able to load them automatically. Adjusting weights for 'backbone.bottom_up.stem.conv1.weight' manually."
+                )
+                with torch.no_grad():
+                    self.model.backbone.bottom_up.stem.conv1.weight[:, :input_channels_in_checkpoint] = checkpoint[:, :input_channels_in_checkpoint]
+                multiply_conv1_weights(self.model)
+
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_folder=None):
         """
@@ -964,7 +1009,7 @@ def predictions_on_data(
         json.dump(evaluations, dest)


-def modify_conv1_weights(model, num_input_channels):
+def multiply_conv1_weights(model):
     """
     Modify the weights of the first convolutional layer (conv1) to accommodate a different number of input channels.

@@ -974,12 +1019,12 @@ def modify_conv1_weights(model, num_input_channels):

     Args:
         model (torch.nn.Module): The model containing the convolutional layer to modify.
-        num_input_channels (int): The number of input channels for the new conv1 layer.

     """
     with torch.no_grad():
         # Retrieve the original weights of the conv1 layer
         old_weights = model.backbone.bottom_up.stem.conv1.weight
+        num_input_channels = model.backbone.bottom_up.stem.conv1.weight.shape[1]  # the number of input channels

         # Create a new weight tensor with the desired number of input channels
         # The shape is (out_channels, in_channels, height, width)
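
The diff cuts off before the rest of multiply_conv1_weights, so here is a minimal standalone sketch of the channel-repetition idea the docstring and warnings describe: copy the pretrained RGB kernels and tile them across any extra input bands. The helper name repeat_rgb_weights and the exact tiling order are illustrative assumptions, not code from this commit.

import torch


def repeat_rgb_weights(old_weights: torch.Tensor, num_input_channels: int) -> torch.Tensor:
    """Illustrative only: tile 3-channel (RGB) conv1 kernels across extra input bands."""
    out_channels, rgb_channels, kh, kw = old_weights.shape
    assert rgb_channels == 3, "expects weights pretrained on RGB images"
    new_weights = torch.empty(
        out_channels, num_input_channels, kh, kw,
        dtype=old_weights.dtype, device=old_weights.device,
    )
    for c in range(num_input_channels):
        # Bands beyond the third reuse the RGB kernels in a repeating 0, 1, 2 pattern.
        new_weights[:, c] = old_weights[:, c % rgb_channels]
    return new_weights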
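
Finally, a hedged usage sketch of the new resume_or_load() override (not part of the diff): the trainer class is assumed to be importable as MyTrainer from this module, its constructor arguments may differ, and dataset registration plus the rest of the config setup are omitted.

from detectron2 import model_zoo
from detectron2.config import get_cfg

from detectree2.models.train import MyTrainer  # assumed import path for the trainer class

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.OUTPUT_DIR = "./train_outputs"

trainer = MyTrainer(cfg)  # constructor arguments shown for illustration only

# resume=True would pick up optimizer/scheduler state and the iteration counter from
# OUTPUT_DIR/last_checkpoint if one exists, ignoring cfg.MODEL.WEIGHTS.
# resume=False loads cfg.MODEL.WEIGHTS only and starts from iteration 0; the conv1
# weights are adapted if the model expects more than three input bands.
trainer.resume_or_load(resume=False)
trainer.train()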