@suo 不确定这是否已修复。我在 VGG16 模型上遇到了同样的错误。我正在使用最新的 Pytorchmaster
分支。
代码片段:
class VGG16_frontend(nn.Module):
def __init__(self,block_num=5,decode_num=0,load_weights=True,bn=False,IF_freeze_bn=False):
super(VGG16_frontend,self).__init__()
self.block_num = block_num
self.load_weights = load_weights
self.bn = bn
self.IF_freeze_bn = IF_freeze_bn
self.decode_num = decode_num
block_dict = [[64, 64, 'M'], [128, 128, 'M'], [256, 256, 256, 'M'],\
[512, 512, 512,'M'], [512, 512, 512,'M']]
self.frontend_feat = []
for i in range(block_num):
self.frontend_feat += block_dict[i]
if self.bn:
self.features = make_layers(self.frontend_feat, batch_norm=True)
else:
self.features = make_layers(self.frontend_feat, batch_norm=False)
if self.load_weights:
if self.bn:
pretrained_model = models.vgg16_bn(pretrained = True)
else:
pretrained_model = models.vgg16(pretrained = True)
pretrained_dict = pretrained_model.state_dict()
model_dict = self.state_dict()
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# load the new state dict
self.load_state_dict(model_dict)
if IF_freeze_bn:
self.freeze_bn()
def forward(self,x):
if self.bn:
x = self.features[ 0:7](x)
print(type(x))
print(x.shape)
conv1_feat =x if self.decode_num>=4 else []
x = self.features[ 7:14](x)
conv2_feat =x if self.decode_num>=3 else []
x = self.features[ 14:24](x)
conv3_feat =x if self.decode_num>=2 else []
x = self.features[ 24:34](x)
conv4_feat =x if self.decode_num>=1 else []
x = self.features[ 34:44](x)
conv5_feat =x
else:
x = self.features[ 0: 5](x)
conv1_feat =x if self.decode_num>=4 else []
x = self.features[ 5:10](x)
conv2_feat =x if self.decode_num>=3 else []
x = self.features[ 10:17](x)
conv3_feat =x if self.decode_num>=2 else []
x = self.features[ 17:24](x)
conv4_feat =x if self.decode_num>=1 else []
x = self.features[ 24:31](x)
conv5_feat =x
feature_map = {'conv1':conv1_feat,'conv2': conv2_feat,\
'conv3':conv3_feat,'conv4': conv4_feat, 'conv5': conv5_feat}
# feature_map = [conv1_feat, conv2_feat, conv3_feat, conv4_feat, conv5_feat]
return feature_map
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
运行时出错:
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> (Tensor(a)):
Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.torch.nn.modules.container.Sequential'.
aten::slice.str(str string, int start, int end=9223372036854775807, int step=1) -> (str):
Expected a value of type 'str' for argument 'string' but instead found type '__torch__.torch.nn.modules.container.Sequential'.
aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[]):
Could not match type __torch__.torch.nn.modules.container.Sequential to List[t] in argument 'l': Cannot match List[t] to __torch__.torch.nn.modules.container.Sequential.
The original call is:
File "/home/ubuntu/mayub/Github/SS-DCNet/Network/SSDCNet.py", line 92
def forward(self,x):
if self.bn:
x = self.features[0:7](x)
~~~~~~~~~~~~~~~~~ <--- HERE
conv1_feat =x if self.decode_num>=4 else []
x = self.features[ 7:14](x)
谢谢 !