[d2l-ai/d2l-zh] 15.7.1 Error when loading the pretrained bert.base model

2023-12-11

Loading code

devices = d2l.try_all_gpus()
# devices = [torch.device('cpu')]

# Load the pretrained model and its vocabulary
bert, vocab = load_pretrained_model('bert.base',
                                    num_hiddens=256,
                                    ffn_num_hiddens=512,
                                    num_heads=4,
                                    num_layers=2,
                                    dropout=0.1,
                                    max_len=512,
                                    devices=devices)

# Larger hyperparameters (the actual bert.base configuration)
# bert, vocab = load_pretrained_model('bert.base',
#                                     num_hiddens=768,
#                                     ffn_num_hiddens=3072,
#                                     num_heads=12,
#                                     num_layers=12,
#                                     dropout=0.1,
#                                     max_len=512,
#                                     devices=devices)
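For reference, the small hyperparameters above (num_hiddens=256, ffn_num_hiddens=512, num_heads=4, num_layers=2) are the ones the book pairs with the 'bert.small' archive, not with 'bert.base'. A minimal sketch of the matching call, assuming the chapter's load_pretrained_model and devices are in scope:

# Sketch: download the checkpoint whose architecture actually matches
# the small hyperparameters, instead of bert.base
bert, vocab = load_pretrained_model('bert.small',
                                    num_hiddens=256,
                                    ffn_num_hiddens=512,
                                    num_heads=4,
                                    num_layers=2,
                                    dropout=0.1,
                                    max_len=512,
                                    devices=devices)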
Error screenshot (2023-04-05_222545; the text is excerpted below)

Error excerpt

RuntimeError: Error(s) in loading state_dict for BERTModel: Unexpected key(s) in state_dict: "encoder.blks.2.attention.W_q.weight", "encoder.blks.2.attention.W_q.bias", "encoder.blks.2.attention.W_k.weight", "encoder.blks.2.attention.W_k.bias", "encoder.blks.2.attention.W_v.weight", "encoder.blks.2.attention.W_v.bias", "encoder.blks.2.attention.W_o.weight", "encoder.blks.2.attention.W_o.bias", "encoder.blks.2.addnorm1.ln.weight", "encoder.blks.2.addnorm1.ln.bias", "encoder.blks.2.ffn.dense1.weight", "encoder.blks.2.ffn.dense1.bias", "encoder.blks.2.ffn.dense2.weight", "encoder.blks.2.ffn.dense2.bias", "encoder.blks.2.addnorm2.ln.weight", "encoder.blks.2.addnorm2.ln.bias", "encoder.blks.3.attention.W_q.weight", "encoder.blks.3.attention.W_q.bias", "encoder.blks.3.attention.W_k.weight", "encoder.blks.3.attention.W_k.bias", "encoder.blks.3.attention.W_v.weight", "encoder.blks.3.attention.W_v.bias", "encoder.blks.3.attention.W_o.weight", "encoder.blks.3.attention.W_o.bias", "encoder.blks.3.addnorm1.ln.weight", "encoder.blks.3.addnorm1.ln.bias", "encoder.blks.3.ffn.dense1.weight", "encoder.blks.3.ffn.dense1.bias", "encoder.blks.3.ffn.dense2.weight", "encoder.blks.3.ffn.dense2.bias", "encoder.blks.3.addnorm2.ln.weight", "encoder.blks.3.addnorm2.ln.bias", "encoder.blks.4.attention.W_q.weight", "encoder.blks.4.attention.W_q.bias", "encoder.blks.4.attention.W_k.weight", "encoder.blks.4.attention.W_k.bias", "encoder.blks.4.attention.W_v.weight", "encoder.blks.4.attention.W_v.bias", "encoder.blks.4.attention.W_o.weight", "encoder.blks.4.attention.W_o.bias", "encoder.blks.4.addnorm1.ln.weight", "encoder.blks.4.addnorm1.ln.bias", "encoder.blks.4.ffn.dense1.weight",
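The unexpected keys begin at encoder.blks.2: the bert.base checkpoint contains twelve encoder blocks (blks.0 through blks.11), while the model was constructed with num_layers=2, so only blks.0 and blks.1 exist in the module. One way to confirm this (a sketch; the path to the downloaded pretrained.params file is an assumption):

import torch

# Hypothetical path to the downloaded bert.base parameter file
state_dict = torch.load('bert.base/pretrained.params')
blks = {k.split('.')[2] for k in state_dict if k.startswith('encoder.blks.')}
print(sorted(blks, key=int))  # ['0', '1', ..., '11'] -> 12 encoder blocks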

Answer

The bug persists

Based on the error message, I adjusted the model hyperparameters further:


# BERT model
bert = BERTModel(vocab_size=60005,
                 num_hiddens=768,
                 norm_shape=[768],

                 ffn_num_input=768,
                 ffn_num_hiddens=3072,

                 num_heads=4,
                 num_layers=2,
                 dropout=0.2,
                 max_len=512,

                 key_size=768,
                 query_size=768,
                 value_size=768,

                 hid_in_features=768,
                 mlm_in_features=768,
                 nsp_in_features=768
                )
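Note that num_heads=4 and num_layers=2 above still differ from the checkpoint, which is why the unexpected keys persist. A construction that fully matches a bert.base checkpoint would look like this (a sketch; the 12-layer, 12-head values follow the standard bert.base configuration implied by the error output):

# Sketch: hyperparameters matching a 12-layer bert.base checkpoint
bert = BERTModel(vocab_size=60005, num_hiddens=768, norm_shape=[768],
                 ffn_num_input=768, ffn_num_hiddens=3072,
                 num_heads=12, num_layers=12, dropout=0.2, max_len=512,
                 key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768)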
Load the model
data_dir = r'this is bert abs path'  # absolute path to the pretrained parameter file

bert.load_state_dict(torch.load(data_dir))

The size-mismatch errors at the bottom of the output are gone, but the unexpected keys in state_dict at the top are still there:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[70], line 2
      1 # Load
----> 2 bert.load_state_dict(torch.load(data_dir))

File ~/.virtualenvs/dl-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for BERTModel:
    Unexpected key(s) in state_dict: "encoder.blks.2.attention.W_q.weight", "encoder.blks.2.attention.W_q.bias", "encoder.blks.2.attention.W_k.weight", "encoder.blks.2.attention.W_k.bias", "encoder.blks.2.attention.W_v.weight", "encoder.blks.2.attention.W_v.bias", "encoder.blks.2.attention.W_o.weight", "encoder.blks.2.attention.W_o.bias", "encoder.blks.2.addnorm1.ln.weight", "encoder.blks.2.addnorm1.ln.bias", "encoder.blks.2.ffn.dense1.weight", "encoder.blks.2.ffn.dense1.bias", "encoder.blks.2.ffn.dense2.weight", "encoder.blks.2.ffn.dense2.bias", "encoder.blks.2.addnorm2.ln.weight", "encoder.blks.2.addnorm2.ln.bias", "encoder.blks.3.attention.W_q.weight", "encoder.blks.3.attention.W_q.bias", "encoder.blks.3.attention.W_k.weight", "encoder.blks.3.attention.W_k.bias", "encoder.blks.3.attention.W_v.weight", "encoder.blks.3.attention.W_v.bias", "encoder.blks.3.attention.W_o.weight", "encoder.blks.3.attention.W_o.bias", "encoder.blks.3.addnorm1.ln.weight", "encoder.blks.3.addnorm1.ln.bias", "encoder.blks.3.ffn.dense1.weight", "encoder.blks.3.ffn.dense1.bias", "encoder.blks.3.ffn.dense2.weight", "encoder.blks.3.ffn.dense2.bias", "encoder.blks.3.addnorm2.ln.weight", "encoder.blks.3.addnorm2.ln.bias", "encoder.blks.4.attention.W_q.weight", "encoder.blks.4.attention.W_q.bias", "encoder.blks.4.attention.W_k.weight", "encoder.blks.4.attention.W_k.bias", "encoder.blks.4.attention.W_v.weight", "encoder.blks.4.attention.W_v.bias", "encoder.blks.4.attention.W_o.weight", "encoder.blks.4.attention.W_o.bias", "encoder.blks.4.addnorm1.ln.weight", "encoder.blks.4.addnorm1.ln.bias", "encoder.blks.4.ffn.dense1.weight", "encoder.blks.4.ffn.dense1.bias", "encoder.blks.4.ffn.dense2.weight", "encoder.blks.4.ffn.dense2.bias", "encoder.blks.4.addnorm2.ln.weight", "encoder.blks.4.addnorm2.ln.bias", "encoder.blks.5.attention.W_q.weight", "encoder.blks.5.attention.W_q.bias", "encoder.blks.5.attention.W_k.weight", "encoder.blks.5.attention.W_k.bias", "encoder.blks.5.attention.W_v.weight", "encoder.blks.5.attention.W_v.bias", "encoder.blks.5.attention.W_o.weight", "encoder.blks.5.attention.W_o.bias", "encoder.blks.5.addnorm1.ln.weight", "encoder.blks.5.addnorm1.ln.bias", "encoder.blks.5.ffn.dense1.weight", "encoder.blks.5.ffn.dense1.bias", "encoder.blks.5.ffn.dense2.weight", "encoder.blks.5.ffn.dense2.bias", "encoder.blks.5.addnorm2.ln.weight", "encoder.blks.5.addnorm2.ln.bias", "encoder.blks.6.attention.W_q.weight", "encoder.blks.6.attention.W_q.bias", "encoder.blks.6.attention.W_k.weight", "encoder.blks.6.attention.W_k.bias", "encoder.blks.6.attention.W_v.weight", "encoder.blks.6.attention.W_v.bias", "encoder.blks.6.attention.W_o.weight", "encoder.blks.6.attention.W_o.bias", "encoder.blks.6.addnorm1.ln.weight", "encoder.blks.6.addnorm1.ln.bias", "encoder.blks.6.ffn.dense1.weight", "encoder.blks.6.ffn.dense1.bias", "encoder.blks.6.ffn.dense2.weight", "encoder.blks.6.ffn.dense2.bias", "encoder.blks.6.addnorm2.ln.weight", "encoder.blks.6.addnorm2.ln.bias", "encoder.blks.7.attention.W_q.weight", "encoder.blks.7.attention.W_q.bias", "encoder.blks.7.attention.W_k.weight", "encoder.blks.7.attention.W_k.bias", "encoder.blks.7.attention.W_v.weight", "encoder.blks.7.attention.W_v.bias", "encoder.blks.7.attention.W_o.weight", "encoder.blks.7.attention.W_o.bias", "encoder.blks.7.addnorm1.ln.weight", "encoder.blks.7.addnorm1.ln.bias", "encoder.blks.7.ffn.dense1.weight", "encoder.blks.7.ffn.dense1.bias", "encoder.blks.7.ffn.dense2.weight", "encoder.blks.7.ffn.dense2.bias", "encoder.blks.7.addnorm2.ln.weight", 
"encoder.blks.7.addnorm2.ln.bias", "encoder.blks.8.attention.W_q.weight", "encoder.blks.8.attention.W_q.bias", "encoder.blks.8.attention.W_k.weight", "encoder.blks.8.attention.W_k.bias", "encoder.blks.8.attention.W_v.weight", "encoder.blks.8.attention.W_v.bias", "encoder.blks.8.attention.W_o.weight", "encoder.blks.8.attention.W_o.bias", "encoder.blks.8.addnorm1.ln.weight", "encoder.blks.8.addnorm1.ln.bias", "encoder.blks.8.ffn.dense1.weight", "encoder.blks.8.ffn.dense1.bias", "encoder.blks.8.ffn.dense2.weight", "encoder.blks.8.ffn.dense2.bias", "encoder.blks.8.addnorm2.ln.weight", "encoder.blks.8.addnorm2.ln.bias", "encoder.blks.9.attention.W_q.weight", "encoder.blks.9.attention.W_q.bias", "encoder.blks.9.attention.W_k.weight", "encoder.blks.9.attention.W_k.bias", "encoder.blks.9.attention.W_v.weight", "encoder.blks.9.attention.W_v.bias", "encoder.blks.9.attention.W_o.weight", "encoder.blks.9.attention.W_o.bias", "encoder.blks.9.addnorm1.ln.weight", "encoder.blks.9.addnorm1.ln.bias", "encoder.blks.9.ffn.dense1.weight", "encoder.blks.9.ffn.dense1.bias", "encoder.blks.9.ffn.dense2.weight", "encoder.blks.9.ffn.dense2.bias", "encoder.blks.9.addnorm2.ln.weight", "encoder.blks.9.addnorm2.ln.bias", "encoder.blks.10.attention.W_q.weight", "encoder.blks.10.attention.W_q.bias", "encoder.blks.10.attention.W_k.weight", "encoder.blks.10.attention.W_k.bias", "encoder.blks.10.attention.W_v.weight", "encoder.blks.10.attention.W_v.bias", "encoder.blks.10.attention.W_o.weight", "encoder.blks.10.attention.W_o.bias", "encoder.blks.10.addnorm1.ln.weight", "encoder.blks.10.addnorm1.ln.bias", "encoder.blks.10.ffn.dense1.weight", "encoder.blks.10.ffn.dense1.bias", "encoder.blks.10.ffn.dense2.weight", "encoder.blks.10.ffn.dense2.bias", "encoder.blks.10.addnorm2.ln.weight", "encoder.blks.10.addnorm2.ln.bias", "encoder.blks.11.attention.W_q.weight", "encoder.blks.11.attention.W_q.bias", "encoder.blks.11.attention.W_k.weight", "encoder.blks.11.attention.W_k.bias", "encoder.blks.11.attention.W_v.weight", "encoder.blks.11.attention.W_v.bias", "encoder.blks.11.attention.W_o.weight", "encoder.blks.11.attention.W_o.bias", "encoder.blks.11.addnorm1.ln.weight", "encoder.blks.11.addnorm1.ln.bias", "encoder.blks.11.ffn.dense1.weight", "encoder.blks.11.ffn.dense1.bias", "encoder.blks.11.ffn.dense2.weight", "encoder.blks.11.ffn.dense2.bias", "encoder.blks.11.addnorm2.ln.weight", "encoder.blks.11.addnorm2.ln.bias".