运行 BERT-BiLSTM-CRF 代码时报错，源码如下：
class data_generator(DataGenerator):
    """Data generator yielding padded BERT inputs and per-token label ids.

    Each sample is iterated as ``(w, l)`` pairs. ``w`` is a text span and
    ``l`` is either ``'O'`` or a key of ``label2id`` (presumably an entity
    type such as ``'ORG'`` -- TODO confirm against ``load_data``).
    """

    # Maximum number of token ids per sample, including [CLS] but not [SEP].
    # Words that would push the sequence to this length or beyond are dropped,
    # together with every word after them.
    seq_len_limit = 256

    def __iter__(self, random=True):
        """Yield ``([batch_token_ids, batch_segment_ids], batch_labels)``.

        Label ids per token:
        - 0 for ``'O'`` and for the [CLS]/[SEP] positions;
        - for an entity label ``l``, the first sub-token gets
          ``label2id[l] * 2 + 1`` (B-) and the rest get ``label2id[l] * 2 + 2`` (I-).

        All three outputs are padded with ``sequence_padding`` into arrays of
        shape ``(batch, longest_sequence_in_batch)``.
        """
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, item in self.sample(random):
            # Must be the integer *ids* of [CLS]/[SEP]. ``tokenizer._token_start``
            # and ``tokenizer._token_end`` are the token strings; putting them in
            # the id list turns the padded batch into a string array, and the
            # model then fails with "Cast string to float is not supported".
            token_ids, labels = [tokenizer._token_start_id], [0]
            for w, l in item:
                # encode() returns (token_ids, segment_ids); drop its own [CLS]/[SEP].
                w_token_ids = tokenizer.encode(w)[0][1:-1]
                if not w_token_ids:
                    # The tokenizer produced no sub-tokens (e.g. pure whitespace).
                    # Skip the word: emitting a B- label here would leave labels
                    # one element longer than token_ids.
                    continue
                if len(token_ids) + len(w_token_ids) >= self.seq_len_limit:
                    break
                token_ids += w_token_ids
                if l == 'O':
                    labels += [0] * len(w_token_ids)
                else:
                    B = label2id[l] * 2 + 1
                    I = label2id[l] * 2 + 2
                    # Map the word to B-xx followed by I-xx for its remaining sub-tokens.
                    labels += [B] + [I] * (len(w_token_ids) - 1)
            token_ids += [tokenizer._token_end_id]
            labels += [0]
            segment_ids = [0] * len(token_ids)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append(labels)
            if len(batch_token_ids) == self.batch_size or is_end:
                # Pad every sequence in the batch to the batch's longest length.
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []
if __name__ == '__main__':
    train_data, _ = load_data('./data/train.txt', max_len)
    valid_data, _ = load_data('./data/test.txt', max_len)
    train_generator = data_generator(train_data, batch_size)
    valid_generator = data_generator(valid_data, batch_size * 5)
    # NOTE(review): this callback is constructed but never passed to
    # model.fit() below, so it never runs. Weights are presumably saved by
    # Evaluator instead -- confirm, then either add it to `callbacks` or
    # delete it.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_save_path,
        monitor='val_sparse_accuracy',
        verbose=1,
        save_best_only=True,
        mode='max',
    )
    # Lowercase instance name so the Evaluator class itself is not shadowed.
    evaluator = Evaluator()
    print(len(train_generator))
    print(len(valid_generator))
    model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        validation_data=valid_generator.forfit(),
        validation_steps=len(valid_generator),
        epochs=epochs,
        callbacks=[evaluator],
    )
    # Read the learned CRF transition matrix once, then report and persist it.
    trans = K.eval(CRF.trans)
    print(trans)
    print(trans.shape)
    with open('./checkpoint/crf_trans.pkl', 'wb') as f:
        pickle.dump(trans, f)
else:
    # When imported, restore trained weights and the CRF transition matrix.
    model.load_weights(checkpoint_save_path)
    # NOTE: pickle.load can execute arbitrary code; only load files this
    # script produced itself.
    with open('./checkpoint/crf_trans.pkl', 'rb') as f:
        NER.trans = pickle.load(f)
报错信息如下：
Epoch 1/4
2022-06-10 19:21:01.623720: W tensorflow/core/framework/op_kernel.cc:1722] OP_REQUIRES failed at cast_op.cc:121 : UNIMPLEMENTED: Cast string to float is not supported
Traceback (most recent call last):
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/train.py", line 123, in <module>
model.fit(
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnimplementedError: Graph execution error:
Detected at node 'model_1/Cast' defined at (most recent call last):
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/train.py", line 123, in <module>
model.fit(
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1384, in fit
tmp_logs = self.train_function(iterator)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1021, in train_function
return step_function(self, iterator)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1010, in step_function
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step
outputs = model.train_step(data)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/training.py", line 859, in train_step
y_pred = self(x, training=True)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
return fn(*args, **kwargs)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 1096, in __call__
outputs = call_fn(inputs, *args, **kwargs)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
return fn(*args, **kwargs)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 451, in call
return self._run_internal_graph(
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 571, in _run_internal_graph
y = self._conform_to_reference_input(y, ref_input=x)
File "/Users/jonnes/Desktop/KGcodes/NLP/myself ner/venv/lib/python3.8/site-packages/keras/engine/functional.py", line 671, in _conform_to_reference_input
tensor = tf.cast(tensor, dtype=ref_input.dtype)
Node: 'model_1/Cast'
Cast string to float is not supported
[[{{node model_1/Cast}}]] [Op:__inference_train_function_40171]