Reproducing Danqi Chen's PURE NER extraction locally: why does loading a Chinese dataset throw an error?

The Chinese dataset is CMeEE. The error is: Error(s) in loading state_dict for BertForEntity: size mismatch for ner_classifier.1.weight; size mismatch for ner_classifier.1.bias. It seems related to the entry argument max_span_length, which is used when initializing the pretrained BERT: self.width_embedding = nn.Embedding(max_span_length+1, width_embedding_dim). Could anyone offer some advice? I have been debugging this for a long time without any leads; even setting max_span_length to the number of CMeEE entity categories raises the same error. The traceback is below, with the relevant initialization sketched after it:

03/09/2022 18:35:29 - INFO - transformers.modeling_utils - All the weights of BertForEntity were initialized from the model checkpoint at /data/sciERC-bert-model/ent-scib-ctx0/.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertForEntity for predictions without further training.
Traceback (most recent call last):
  File "E:\java-workspace\PureTest\wonerMain.py", line 154, in <module>
    model = EntityModel(args, num_ner_labels=num_ner_labels)
  File "E:\java-workspace\PureTest\wonerModel.py", line 174, in __init__
    self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)
  File "C:\ProgramData\Anaconda3\envs\pytorch37\lib\site-packages\transformers\modeling_utils.py", line 781, in from_pretrained
    model.__class__.__name__, "\n\t".join(error_msgs)
RuntimeError: Error(s) in loading state_dict for BertForEntity:
    size mismatch for ner_classifier.1.weight: copying a param with shape torch.Size([7, 150]) from checkpoint, the shape in current model is torch.Size([10, 150]).
    size mismatch for ner_classifier.1.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([10]).
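For reference, the span model's initialization in PURE looks roughly like the sketch below (reconstructed, so treat the exact hidden sizes, the AllenNLP FeedForward head, and the default argument names as assumptions). The point is that max_span_length only sizes the span-width embedding table, while the last layer of ner_classifier is an nn.Linear whose output size is num_ner_labels; that output size, not max_span_length, is the dimension that mismatches in the error (7 vs 10).

import torch.nn.functional as F
from torch import nn
from transformers import BertModel, BertPreTrainedModel
from allennlp.modules import FeedForward  # PURE uses AllenNLP's feed-forward module for the head

class BertForEntity(BertPreTrainedModel):
    def __init__(self, config, num_ner_labels, max_span_length,
                 head_hidden_dim=150, width_embedding_dim=150):
        super().__init__(config)
        self.bert = BertModel(config)
        # max_span_length sizes only this embedding table (span widths 0..max_span_length)
        self.width_embedding = nn.Embedding(max_span_length + 1, width_embedding_dim)
        # ner_classifier.1 is the final Linear; its out_features = num_ner_labels,
        # which is where the [7, 150] vs [10, 150] mismatch comes from
        self.ner_classifier = nn.Sequential(
            FeedForward(input_dim=config.hidden_size * 2 + width_embedding_dim,
                        num_layers=2,
                        hidden_dims=head_hidden_dim,
                        activations=F.relu,
                        dropout=0.2),
            nn.Linear(head_hidden_dim, num_ner_labels),
        )
        self.init_weights()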

No rush; paste the full console output and let's take a look.

The full console output is as follows:
C:\ProgramData\Anaconda3\envs\pytorch37\lib\site-packages\tensorflow\python\framework\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
[... eleven more identical numpy FutureWarnings from tensorflow's and tensorboard's dtypes.py omitted ...]
03/10/2022 09:24:56 - INFO - root - ['E:\java-workspace\PureTest\wonerMain.py', '--bert_model_dir', 'E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0', '--do_train', '--do_eval', '--context_window', '0', '--task', 'CMeEE', '--data_dir', 'E:\java-workspace\data\ali_tianchi\CMeEE', '--output_dir', 'E:\java-workspace\PureTest\cme_models\ent-scib-ctx0']
03/10/2022 09:24:56 - INFO - root - Namespace(bert_model_dir='E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0', bertadam=False, context_window=0, data_dir='E:\java-workspace\data\ali_tianchi\CMeEE', dev_data='E:\java-workspace\data\ali_tianchi\CMeEE\dev.json', dev_pred_filename='ent_pred_dev.json', do_eval=True, do_train=True, eval_batch_size=32, eval_per_epoch=1, eval_test=False, learning_rate=1e-05, max_span_length=8, model='bert-base-uncased', num_epoch=100, output_dir='E:\java-workspace\PureTest\cme_models\ent-scib-ctx0', print_loss_step=100, seed=0, task='CMeEE', task_learning_rate=0.0001, test_data='E:\java-workspace\data\ali_tianchi\CMeEE\test.json', test_pred_filename='ent_pred_test.json', train_batch_size=32, train_data='E:\java-workspace\data\ali_tianchi\CMeEE\train.json', train_shuffle=False, use_albert=False, warmup_proportion=0.1)
03/10/2022 09:24:56 - INFO - root - Loading BERT model from E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - Model name 'E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/' is a path, a model identifier, or url to a directory containing tokenizer files.
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - Didn't find file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/added_tokens.json. We won't load it.
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - Didn't find file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/tokenizer.json. We won't load it.
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - loading file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/vocab.txt
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - loading file None
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - loading file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/special_tokens_map.json
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - loading file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/tokenizer_config.json
03/10/2022 09:24:56 - INFO - transformers.tokenization_utils_base - loading file None
03/10/2022 09:24:56 - INFO - transformers.configuration_utils - loading configuration file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/config.json
03/10/2022 09:24:56 - INFO - transformers.configuration_utils - Model config BertConfig {
  "architectures": [
    "BertNerRe"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 31090
}

03/10/2022 09:24:56 - INFO - transformers.modeling_utils - loading weights file E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/pytorch_model.bin
03/10/2022 09:24:58 - INFO - transformers.modeling_utils - All model checkpoint weights were used when initializing BertForEntity.

03/10/2022 09:24:58 - INFO - transformers.modeling_utils - All the weights of BertForEntity were initialized from the model checkpoint at E:/java-workspace/data/sciERC-bert-model/ent-scib-ctx0/.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertForEntity for predictions without further training.
Traceback (most recent call last):
  File "E:\java-workspace\PureTest\wonerMain.py", line 154, in <module>
    model = EntityModel(args, num_ner_labels=num_ner_labels)
  File "E:\java-workspace\PureTest\wonerModel.py", line 174, in __init__
    self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)
  File "C:\ProgramData\Anaconda3\envs\pytorch37\lib\site-packages\transformers\modeling_utils.py", line 781, in from_pretrained
    model.__class__.__name__, "\n\t".join(error_msgs)
RuntimeError: Error(s) in loading state_dict for BertForEntity:
    size mismatch for ner_classifier.1.weight: copying a param with shape torch.Size([7, 150]) from checkpoint, the shape in current model is torch.Size([10, 150]).
    size mismatch for ner_classifier.1.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([10]).

The code at the two frames in the traceback:

1. File "E:\java-workspace\PureTest\wonerMain.py", line 154, in <module>:

model = EntityModel(args, num_ner_labels=num_ner_labels)

2. File "E:\java-workspace\PureTest\wonerModel.py", line 174, in __init__:

class EntityModel():

    def __init__(self, args, num_ner_labels):
        super().__init__()

        bert_model_name = args.model
        vocab_name = bert_model_name

        if args.bert_model_dir is not None:
            bert_model_name = str(args.bert_model_dir) + '/'
            # vocab_name = bert_model_name + 'vocab.txt'
            vocab_name = bert_model_name
            logger.info('Loading BERT model from {}'.format(bert_model_name))

        if args.use_albert:
            self.tokenizer = AlbertTokenizer.from_pretrained(vocab_name)
            self.bert_model = AlbertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(vocab_name)
            # line 174: the call that raises the size-mismatch error
            self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels, max_span_length=args.max_span_length)

        self._model_device = 'cpu'
        self.move_model_to_cuda()
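Reading the shapes in the traceback: the head in the ent-scib-ctx0 checkpoint has 7 outputs because it was fine-tuned on SciERC (6 entity types plus a non-entity class), while a CMeEE model needs 10, so no setting of max_span_length can reconcile them. If the goal is to reuse the checkpoint's encoder weights anyway, one workaround for the transformers 3.x version shown in these logs is to drop the task-specific head before loading. The helper below is a hypothetical sketch, not code from the PURE repo:

import torch
from transformers import BertConfig

def load_without_task_head(checkpoint_dir, num_ner_labels, max_span_length):
    # Build a CMeEE-sized model from the checkpoint's config, then load only
    # the weights that still fit; the SciERC-specific ner_classifier head is
    # dropped and keeps its fresh random initialization.
    config = BertConfig.from_pretrained(checkpoint_dir)
    model = BertForEntity(config, num_ner_labels=num_ner_labels,
                          max_span_length=max_span_length)
    state_dict = torch.load(checkpoint_dir + '/pytorch_model.bin', map_location='cpu')
    # Drop the classifier head trained for SciERC's 7 classes.
    state_dict = {k: v for k, v in state_dict.items()
                  if not k.startswith('ner_classifier')}
    model.load_state_dict(state_dict, strict=False)
    return model

A simpler route for Chinese data is to point --bert_model_dir at a Chinese base model such as bert-base-chinese and train the entity head from scratch, since SciBERT's vocabulary covers English scientific text. Newer transformers releases (4.9+) also accept ignore_mismatched_sizes=True in from_pretrained, but the 3.x version in these logs does not.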

The CMeEE entity categories are set up like this:
task_ner_labels = {
    'ace04': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'ace05': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'scierc': ['Method', 'OtherScientificTerm', 'Task', 'Generic', 'Material', 'Metric'],
    'CMeEE': ['bod', 'dis', 'sym', 'mic', 'pro', 'ite', 'dep', 'dru', 'equ'],
}

task_rel_labels = {
    'ace04': ['PER-SOC', 'OTHER-AFF', 'ART', 'GPE-AFF', 'EMP-ORG', 'PHYS'],
    'ace05': ['ART', 'ORG-AFF', 'GEN-AFF', 'PHYS', 'PER-SOC', 'PART-WHOLE'],
    'scierc': ['PART-OF', 'USED-FOR', 'FEATURE-OF', 'CONJUNCTION', 'EVALUATE-FOR', 'HYPONYM-OF', 'COMPARE'],
}

label2id = {"bod": 0, "dis": 1, "sym": 2, "mic": 3, "pro": 4, "ite": 5, "dep": 6, "dru": 7, "equ": 8}  # CMeEE

def get_labelmap(label_list):
    label2id = {}
    id2label = {}
    for i, label in enumerate(label_list):
        label2id[label] = i + 1
        id2label[i + 1] = label
    return label2id, id2label
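Note that get_labelmap assigns ids starting from 1, reserving 0 for the non-entity span class, so num_ner_labels ends up as len(labels) + 1. That is exactly the arithmetic behind the two shapes in the error (incidentally, the hand-written label2id above starts at 0, which disagrees with get_labelmap's 1-based ids and is worth reconciling):

# id 0 is reserved for "not an entity", so the classifier needs len(labels) + 1 outputs
num_ner_labels_scierc = len(task_ner_labels['scierc']) + 1  # 6 + 1 = 7  -> checkpoint head: [7, 150]
num_ner_labels_cmeee = len(task_ner_labels['CMeEE']) + 1    # 9 + 1 = 10 -> current model head: [10, 150]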