File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 44import torch .nn as nn
55import torch .nn .init as init
66from torchvision import models
7- from torchvision .models .vgg import model_urls
7+ from torchvision .models .vgg import VGG16_BN_Weights
88
99def init_weights (modules ):
1010 for m in modules :
@@ -22,8 +22,10 @@ def init_weights(modules):
2222class vgg16_bn (torch .nn .Module ):
2323 def __init__ (self , pretrained = True , freeze = True ):
2424 super (vgg16_bn , self ).__init__ ()
25- model_urls ['vgg16_bn' ] = model_urls ['vgg16_bn' ].replace ('https://' , 'http://' )
26- vgg_pretrained_features = models .vgg16_bn (pretrained = pretrained ).features
25+ # Use the weights parameter based on the pretrained flag
26+ weights = VGG16_BN_Weights .IMAGENET1K_V1 if pretrained else None
27+ vgg_pretrained_features = models .vgg16_bn (weights = weights ).features
28+
2729 self .slice1 = torch .nn .Sequential ()
2830 self .slice2 = torch .nn .Sequential ()
2931 self .slice3 = torch .nn .Sequential ()
You can’t perform that action at this time.
0 commit comments