[pytorch][jit] 无法在脚本函数中索引 nn.ModuleList

2024-03-20 242 views
6
?漏洞

无法在脚本函数中索引 nn.ModuleList

重现

重现该行为的步骤:


import torch

class Model(torch.jit.ScriptModule):
    def __init__(self, ):
        super(Model, self).__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(2, 2)])

    @torch.jit.script_method
    def forward(self, input, index):
        return self.layers[index](input)

model = Model()
model(torch.randn(2, 2))

提供以下跟踪:

RuntimeError:
python value of type 'ModuleList' cannot be used as a value:
@torch.jit.script_method
def forward(self, input, index):
    return self.layers[index](input)
           ~~~~~~~~~~~ <--- HERE
预期行为

没有错误

环境

火炬1.0.0

抄送@ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @gmagogsfm @suo

回答

0

即使您定义 __contants__ 来使用 nn.ModuleList (如文档中的示例所示),它仍然会出错(略多于一半)。 https://pytorch.org/docs/stable/jit.html?highlight=model%20features

这是代码(从示例中更改了超级调用,以便它可以工作)

import torch
from torch import nn

class SubModule(torch.jit.ScriptModule):
    def __init__(self):
        super(SubModule, self).__init__()
        self.weight = nn.Parameter(torch.randn(2))

    @torch.jit.script_method
    def forward(self, input):
        return self.weight + input

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['mods']

    def __init__(self):
        super(MyModule, self).__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    @torch.jit.script_method
    def forward(self, v):
        for module in self.mods:
            v = m(v)
        return v
MyModule()

产生

RuntimeError: 
undefined value m:
@torch.jit.script_method
def forward(self, v):
    for module in self.mods:
        v = m(v)
            ~ <--- HERE
    return v
环境:

火炬 1.0.0

8

@bswrundquist 它应该可以工作,你有一个错字:

    @torch.jit.script_method
    def forward(self, v):
        for module in self.mods:
            v = m(v)

应该

    @torch.jit.script_method
    def forward(self, v):
        for module in self.mods:
            v = module(v)
5

@zou3519 感谢您的指正。

0

迭代 ModuleList 确实可以工作,但直接索引它仍然不行,这是可以预料的,因为 ModuleList 也不能​​在非 jit 代码中索引。

7

@zou3519 你知道 PyTorch 团队是否有修复此问题的时间表吗?我有一个模块,其中有很多 ModuleList 索引。我正在考虑是否应该等待修复。(因为,重写所有这些代码有点乏味 - 所以我希望尽可能推迟它)

3

我记得反对这样做的论点是,如果我们不能静态地知道 ModuleList 的大小,那么就不可能对其进行索引。然而,如果我们将 ModuleList 标记为__constants__...的一部分,这可能是可能的(@eellison @suo 你们中的一个人可以在这里插话吗?)

8

或者如果我们无法索引它。我们可以支持吗enumerate?对我来说,我更需要为张量列表中的每个项目使用每个模块

像这样:

output = [tensor1, tensor2, tensor3]

for i in range(3):
     self.modules[i](output[i])

因此,对于我的用例,如果您可以支持enumerate,它也会起作用,例如

output = [tensor1, tensor2, tensor3]

for i, mod in enumerate(self.modules):
      mod(output[i])
4

@zou3519如果我理解正确的话,无论如何你都不能引用ModuleList if 它不在常量中。

3

@sidazhang我同意enumerate这非常重要,我们应该支持它,但是如果您需要一个快速的解决方法,这里是穷人的枚举,您可以使用:

i = 0
for mod in self.modules:
   mod(output[i])
   i += 1
2

@zou3519我认为我们将无法实现索引,因为模块/函数不是TorchScript中的第一类值,所以你不能将它们分配给变量(好吧,你可以,但我们需要能够静态地解析它们)。因此,这是一个更根本的困难,需要实现适当的调用机制。今天的要求是存在的,因为我们支持的唯一调用形式是函数内联。

7

呃。这有点恶心。哈哈

好的。我能做到

7

请修复它!我也有这个问题,Music Vae 不起作用

3

@Danlanchen 你能给我们更多关于它为什么不起作用的信息吗?就像错误消息和获取错误的可重现方法一样。谢谢!

6

抱歉,我通过在常量@suo中添加 self.modules 来修复它

0

zip如果我们想迭代两个模块列表,我们可以使用吗?@索

9

对于一流的模块来说这可能吗?

3

以下代码在 1.1.0 中仍然无法工作:

class MyMod(torch.jit.ScriptModule):
    __constants__ = ['loc_layers']
    def __init__(self):
        super(MyMod, self).__init__()
        locs = []
        for i in range(4):
            locs.append(nn.Conv2d(64, 21, kernel_size=1))
        self.loc_layers = nn.ModuleList(locs)

    @torch.jit.script_method
    def forward(self, input):
        # type: (List[Tensor]) -> Tensor
        locs = []
        i = 0
        for layer in input:
            loc = self.loc_layers[i](layer)
            locs.append(loc)
            i += 1
        loc_preds = torch.cat(locs, 1)
        return loc_preds

给出以下错误:

module cannot be used as a value:
...
loc = self.loc_layers[i](layer)
      ~~~~~~~~ <--- HERE
8

你好 !我也遇到同样的问题,请问你解决了吗?

6

@LCWdmlearning 使用我的代码作为示例,您可以迭代模块列表而不是张量列表:

i = 0
for layer in self.loc_layers:
    loc = layer(input[i])
    locs.append(loc)
    i += 1
6

这里有进展吗?我使用模块列表遇到同样的错误。

3

这个问题被其他一些正在进行的工作所阻止,以Module在 TorchScript 解释器中创建一流的对象,直到完成@TheCodez 的上述解决方案是推荐的解决方法

3

另一个数据点,我也遇到了这个问题......而且我有两个列表......使用for xxx in module_list会很痛苦......

6

@driazati 能够压缩 2 个 ModueList,这也很棒

9

另一个数据点,我也遇到了这个问题...并且我有两个列表...在 module_list 中使用 for xxx 会很痛苦...

9

这里同样的问题。

2

只是想在这里进行快速更新:这对我们来说是一个高优先级项目,但考虑到 TorchScript 的实现方式,修复并不容易,因此需要一些时间。谢谢大家的报道!

8

PR #29236 可能会有帮助。

0

@suo 不确定这是否已修复。我在 VGG16 模型上遇到了同样的错误。我正在使用最新的 Pytorchmaster分支。

代码片段:

class VGG16_frontend(nn.Module):
    def __init__(self,block_num=5,decode_num=0,load_weights=True,bn=False,IF_freeze_bn=False):
        super(VGG16_frontend,self).__init__()
        self.block_num = block_num
        self.load_weights = load_weights
        self.bn = bn
        self.IF_freeze_bn = IF_freeze_bn
        self.decode_num = decode_num

        block_dict = [[64, 64, 'M'], [128, 128, 'M'], [256, 256, 256, 'M'],\
             [512, 512, 512,'M'], [512, 512, 512,'M']]

        self.frontend_feat = []
        for i in range(block_num):
            self.frontend_feat += block_dict[i]

        if self.bn:
            self.features = make_layers(self.frontend_feat, batch_norm=True)
        else:
            self.features = make_layers(self.frontend_feat, batch_norm=False)

        if self.load_weights:
            if self.bn:
                pretrained_model = models.vgg16_bn(pretrained = True)
            else:
                pretrained_model = models.vgg16(pretrained = True)
            pretrained_dict = pretrained_model.state_dict()
            model_dict = self.state_dict()
            # filter out unnecessary keys
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # overwrite entries in the existing state dict
            model_dict.update(pretrained_dict) 
            # load the new state dict
            self.load_state_dict(model_dict)

        if IF_freeze_bn:
            self.freeze_bn()

    def forward(self,x):
        if self.bn: 
            x = self.features[ 0:7](x)
            print(type(x))
            print(x.shape)
            conv1_feat =x if self.decode_num>=4 else []
            x = self.features[ 7:14](x)
            conv2_feat =x if self.decode_num>=3 else []
            x = self.features[ 14:24](x)
            conv3_feat =x if self.decode_num>=2 else []
            x = self.features[ 24:34](x)
            conv4_feat =x if self.decode_num>=1 else []
            x = self.features[ 34:44](x)
            conv5_feat =x 
        else:
            x = self.features[ 0: 5](x)
            conv1_feat =x if self.decode_num>=4 else []
            x = self.features[ 5:10](x)
            conv2_feat =x if self.decode_num>=3 else []
            x = self.features[ 10:17](x)
            conv3_feat =x if self.decode_num>=2 else []
            x = self.features[ 17:24](x)
            conv4_feat =x if self.decode_num>=1 else []
            x = self.features[ 24:31](x)
            conv5_feat =x 

        feature_map = {'conv1':conv1_feat,'conv2': conv2_feat,\
            'conv3':conv3_feat,'conv4': conv4_feat, 'conv5': conv5_feat}   

        # feature_map = [conv1_feat, conv2_feat, conv3_feat, conv4_feat, conv5_feat]

        return feature_map

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

运行时出错:


RuntimeError: 
Arguments for call are not valid.
The following variants are available:

  aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> (Tensor(a)):
  Expected a value of type 'Tensor' for argument 'self' but instead found type '__torch__.torch.nn.modules.container.Sequential'.

  aten::slice.str(str string, int start, int end=9223372036854775807, int step=1) -> (str):
  Expected a value of type 'str' for argument 'string' but instead found type '__torch__.torch.nn.modules.container.Sequential'.

  aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[]):
  Could not match type __torch__.torch.nn.modules.container.Sequential to List[t] in argument 'l': Cannot match List[t] to __torch__.torch.nn.modules.container.Sequential.

The original call is:
  File "/home/ubuntu/mayub/Github/SS-DCNet/Network/SSDCNet.py", line 92
    def forward(self,x):
        if self.bn: 
            x = self.features[0:7](x)
                ~~~~~~~~~~~~~~~~~ <--- HERE
            conv1_feat =x if self.decode_num>=4 else []
            x = self.features[ 7:14](x)

谢谢 !

8

目前的情况是,我们支持使用整数文字对 ModuleList 进行索引,就像我们对元组所做的那样。我们不支持切片,这涉及创建新的 ModuleList。这在理论上是可能的;介意用你的例子提出一个单独的问题吗?

9

谢谢。我做了一个新的#38034

7

有关支持 ModuleDict 索引的任何更新吗?

6

这对任何人都有效吗?使用此(最新的 pytorch 版本)时,我仍然得到“' torch .torch.nn.modules.container.ModuleList'对象不可迭代”。

1

@kwanUm 是的,我用这种方法对 HRNetV2 进行了即时编译

0

在 PyTorch 1.6.0 上,在 ModuleList 上做索引仍然会出现此错误。

5

实际上在 pytorch 1.6 上对我有用。确保您没有迭代您所处理的类的子模块/数据成员的 ModuleList。这有一个未解决的错误。

7

在 PyTorch 1.6 下,ModuleDict字符串索引仍然给我Only ModuleList, Sequential, and ModuleDict modules are subscriptable:错误。有修复此问题的预计到达时间吗?如果我们不能使用字符串来索引 dict,这基本上就违背了 dict 的目的,对吗?

8

这个问题在1.7.0解决了吗?在 1.6.0 中,我使用 ModuleDict 遇到了同样的问题。:(

2

您能分享一下 JITTed HRNet 版本吗?