Building BERT-BiLSTM-CRF with TF 2.4 shows the warning "CRF Decoding does not work with KerasTensors in TF2.4". How do I fix it?

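Over the past few days I have been building a BERT-BiLSTM-CRF model for NER with TensorFlow 2.4, and I get this warning: "CRF Decoding does not work with KerasTensors in TF2.4. The bug has since been fixed in tensorflow/tensorflow##45534".
The model still builds and model.summary() prints normally, but will this warning affect training? How can I fix it?
My TensorFlow version is 2.4 and my keras-bert version is 0.89.0.
The CRF layer comes from tensorflow_addons (version 0.13.0), i.e. "from tensorflow_addons.layers import CRF".
The code is as follows: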

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow_addons.layers import CRF  # tensorflow_addons 0.13.0

from tf2crf import ModelWithCRFLoss
import keras_bert


class MyModel(object):
    def __init__(self, vocab_size: int, num_class: int, max_len: int = 100,
                 embedding_dim: int = 128, rnn_units: int = 128, drop_rate: float = 0.5):
        self.vocab_size = vocab_size
        self.num_class = num_class
        self.max_len = max_len
        self.embedding_dim = embedding_dim
        self.rnn_units = rnn_units
        self.drop_rate = drop_rate
        self.bert_config_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_config.json"
        self.bert_check_point_path = "./chinese_bert_wwm_L-12_H-768_A-12/bert_model.ckpt"
        self.vocab_path = "./chinese_bert_wwm_L-12_H-768_A-12/vocab.txt"
        self.crf = CRF(num_class)

    def build_model(self):
        model = keras_bert.load_trained_model_from_checkpoint(
            config_file=self.bert_config_path,
            checkpoint_file=self.bert_check_point_path,
            seq_len=self.max_len,
            trainable=True
        )
        inputs = model.inputs
        embedding = model.output
        x = Bidirectional(LSTM(units=self.rnn_units, return_sequences=True))(embedding)
        x = Dropout(self.drop_rate)(x)
        x = Dense(self.num_class)(x)
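        # Reuse the CRF layer created in __init__. tfa's CRF returns four tensors:
        # decoded_sequence, potentials, sequence_length, chain_kernel.
        x = self.crf(x)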
        model = Model(inputs=inputs, outputs=x)
        return model


if __name__ == '__main__':
    mymodel = MyModel(
        vocab_size=300,
        num_class=5,
        max_len=100
    )
    model = mymodel.build_model()
    model.summary()

The warning (posted as a screenshot in the original question) is the one quoted above.

Solution

This is only a warning, and in most cases it does not affect anything.
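For context, the "CRF decoding" the warning refers to is the step that turns per-token tag scores into the single highest-scoring tag sequence (Viterbi decoding). The following is a minimal pure-Python sketch of that step, written only to illustrate the idea: it is not the tensorflow_addons implementation, it has no separate start/end scores, and the names viterbi_decode, emissions and transitions are illustrative assumptions.

```python
from typing import List, Tuple


def viterbi_decode(emissions: List[List[float]],
                   transitions: List[List[float]]) -> Tuple[List[int], float]:
    """Return the highest-scoring tag sequence and its score.

    emissions[t][j]: score of tag j at position t (e.g. the BiLSTM/Dense output).
    transitions[i][j]: score of moving from tag i to tag j.
    """
    if not emissions:
        return [], 0.0
    num_tags = len(transitions)
    # score[j] = best score of any path that ends in tag j at the current position
    score = list(emissions[0])
    backpointers: List[List[int]] = []
    for t in range(1, len(emissions)):
        new_score, best_prev = [], []
        for j in range(num_tags):
            # previous tag i that maximises score[i] + transitions[i][j]
            i_best = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            best_prev.append(i_best)
            new_score.append(score[i_best] + transitions[i_best][j] + emissions[t][j])
        score = new_score
        backpointers.append(best_prev)
    # start from the best final tag and follow the back-pointers
    last = max(range(num_tags), key=lambda j: score[j])
    best_score = score[last]
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return path, best_score


if __name__ == "__main__":
    # Switching tags is heavily penalised, so staying on tag 0 wins.
    print(viterbi_decode([[1, 0], [0, 1], [1, 0]], [[0, -5], [-5, 0]]))  # prints ([0, 0, 0], 2)
```

The transition matrix is what distinguishes this from a per-token argmax: with all-zero transitions the same emissions decode to [0, 1, 0].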

If you want to get rid of the warning at its source, you could try installing tf2crf instead:

pip install tf2crf

Then import its CRF layer (the package name is all lowercase):

from tf2crf import CRF

Demo:

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Bidirectional, GRU, Dense
from tensorflow.keras.models import Model
from tf2crf import CRF, ModelWithCRFLoss

inputs = Input(shape=(None,), dtype='int32')
output = Embedding(100, 40, trainable=True, mask_zero=True)(inputs)
output = Bidirectional(GRU(64, return_sequences=True))(output)
crf = CRF(units=9, dtype='float32')  # 9 tag classes
output = crf(output)
base_model = Model(inputs, output)
model = ModelWithCRFLoss(base_model, sparse_target=True)
model.compile(optimizer='adam')

x = [[5, 2, 3] * 3] * 10
y = [[1, 2, 3] * 3] * 10

model.fit(x=x, y=y, epochs=2, batch_size=2)
model.save('tests/1')

Reference:

keras-team/keras-contrib: Keras community contributions. https://github.com/keras-team/keras-contrib#install-keras_contrib-for-tensorflowkeras

Feel free to follow up if you have any questions.

The warning does not stop the code from running. Check how the model actually performs first; it should be fine.
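One concrete way to "check how the model performs" on NER is entity-level precision, recall and F1 over the tag sequences. A minimal sketch, assuming BIO-style tags (e.g. B-PER, I-PER, O) and predictions already mapped from class ids back to tag strings; the function names are illustrative, and a stray I- tag is treated as starting a new entity. A library such as seqeval offers a more complete implementation.

```python
from typing import List, Optional, Set, Tuple

Span = Tuple[str, int, int]  # (entity type, start index, end index exclusive)


def bio_spans(tags: List[str]) -> Set[Span]:
    """Collect entity spans from a BIO tag sequence."""
    spans: Set[Span] = set()
    start: Optional[int] = None
    etype: Optional[str] = None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        starts_new = tag.startswith("B-") or (tag.startswith("I-") and tag[2:] != etype)
        if start is not None and (tag == "O" or starts_new):
            spans.add((etype, start, i))
            start, etype = None, None
        if starts_new:
            start, etype = i, tag[2:]
    return spans


def entity_f1(gold: List[List[str]], pred: List[List[str]]) -> Tuple[float, float, float]:
    """Micro-averaged entity-level precision, recall and F1 (exact span match)."""
    tp = n_gold = n_pred = 0
    for g, p in zip(gold, pred):
        gs, ps = bio_spans(g), bio_spans(p)
        tp += len(gs & ps)
        n_gold += len(gs)
        n_pred += len(ps)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


if __name__ == "__main__":
    gold = [["B-PER", "I-PER", "O", "B-LOC"]]
    pred = [["B-PER", "I-PER", "O", "O"]]
    print(entity_f1(gold, pred))  # prints (1.0, 0.5, 0.6666666666666666)
```

Comparing these numbers with and without the warning present (for example on a newer TensorFlow version) is a more direct test than reading the warning text.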

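Don't worry about the warning for now. It is only a warning and generally does not affect the results; what matters is checking whether your model is actually affected. Every time I write this kind of code, I add: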


import warnings

# Suppress all Python warnings (this also hides warnings you may want to see)
warnings.filterwarnings('ignore')
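filterwarnings('ignore') silences every Python warning. A narrower sketch ignores only the CRF message, assuming tensorflow_addons emits it through Python's warnings module (consistent with the blanket filter above silencing it). The message argument is a regular expression matched against the start of the warning text. The emit_demo_warnings function is a hypothetical stand-in for the real model-building code, used only to show which warnings get through.

```python
import warnings

# Regex matched (case-insensitively) against the start of the warning text.
CRF_WARNING_PATTERN = "CRF Decoding does not work with KerasTensors"

# In a real script: install this filter once, before building the model.
warnings.filterwarnings("ignore", message=CRF_WARNING_PATTERN)


def emit_demo_warnings() -> None:
    """Hypothetical stand-in for the library calls that trigger warnings."""
    warnings.warn(
        "CRF Decoding does not work with KerasTensors in TF2.4. "
        "The bug has since been fixed in tensorflow/tensorflow##45534"
    )
    warnings.warn("some unrelated warning")


# Demonstration: record warnings to show which ones get through the filter.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning by default...
    warnings.filterwarnings("ignore", message=CRF_WARNING_PATTERN)  # ...except the CRF one
    emit_demo_warnings()

print([str(w.message) for w in caught])  # prints ['some unrelated warning']
```

This keeps other warnings, such as deprecation notices, visible while training.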