How can I use functorch.jacrev to compute the Jacobian of BertForMaskedLM?

I want to use functorch.jacrev to compute the Jacobian of BertForMaskedLM. Here is what I have tried:

Relevant code
import numpy as np
from transformers import BertTokenizer,BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev
device = 'cuda:2'
torch.cuda.empty_cache()
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)

net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)

def fnet_single(params,x,y):
    result = fnet(params, buffers, x.unsqueeze(0).unsqueeze(0),y.unsqueeze(0).unsqueeze(0))['logits']
    return result.squeeze(0).squeeze(0)

text = u'江苏省苏州市读者马玉兰有一个在外地上学的朋友'
inputs = tokenizer.encode_plus(text)

segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)

for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = 103
        batch_token_ids[2 * i - 1, i + 1] = 103
    batch_token_ids[2 * i, i + 1] = 103
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
    x,y = batch_token_ids[2 * i],batch_segment_ids[2*i]
    jacobian1 = jacrev(fnet_single,argnums=1)(params,x,y)
    x,y = batch_token_ids[2 * i - 1],batch_segment_ids[2*i-1]
    jacobian2 = jacrev(fnet_single,argnums=1)(params,x,y)
    print(jacobian1,end='-----------------jacobian1-----------------\n')  
    print(jacobian2,end='-----------------jacobian2-----------------\n') 
Output and error message

Traceback (most recent call last):
File "study_jacrev.py", line 49, in <module>
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
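
The error can be reproduced without BERT at all. The following is a minimal sketch, assuming only that PyTorch is installed; the token ids are illustrative values, not taken from the tokenizer output above. It shows that an integer tensor cannot be created with requires_grad=True, while a floating-point tensor can.

```python
import torch

# Illustrative integer token ids (hypothetical values, for demonstration only).
token_ids = [101, 3736, 5722, 102]

# Integer tensors cannot require gradients, so this call raises RuntimeError.
try:
    torch.tensor(token_ids, requires_grad=True)
    raised = False
    message = ""
except RuntimeError as err:
    raised = True
    message = str(err)

print(raised, message)

# A floating-point tensor (for example an embedding) can require gradients.
float_ok = torch.tensor(token_ids, dtype=torch.float32, requires_grad=True)
print(float_ok.requires_grad)
```

Token ids are discrete indices into an embedding table, so the model output is not a differentiable function of them. That is why the call on line 49 of the script fails before jacrev is ever reached.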

Desired result

Compute the Jacobian (gradient) of BertForMaskedLM.
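
One possible direction, offered as a sketch rather than a verified fix for the script above, is to take the Jacobian of the logits with respect to the input embeddings, which are floating point, instead of the integer token ids. The example below uses a small stand-in model so that it runs without downloading weights; ToyMaskedLM, forward_from_embeds, and the sizes VOCAB, HIDDEN, and SEQ are hypothetical names introduced only for illustration.

```python
import torch
import torch.nn as nn

try:
    from torch.func import jacrev      # PyTorch 2.0 and later
except ImportError:
    from functorch import jacrev       # older standalone functorch

torch.manual_seed(0)

# Hypothetical toy sizes; a real BERT is far larger.
VOCAB, HIDDEN, SEQ = 50, 8, 5

class ToyMaskedLM(nn.Module):
    """Stand-in for a masked LM: embedding -> hidden layer -> vocab logits."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.mix = nn.Linear(HIDDEN, HIDDEN)
        self.head = nn.Linear(HIDDEN, VOCAB)

    def forward_from_embeds(self, embeds):
        # embeds: (batch, seq, hidden) -> logits: (batch, seq, vocab)
        return self.head(torch.tanh(self.mix(embeds)))

model = ToyMaskedLM().eval()

# Integer ids are only used for the lookup; they never need gradients.
token_ids = torch.randint(0, VOCAB, (SEQ,))
embeds = model.embed(token_ids).detach()   # float tensor, shape (SEQ, HIDDEN)

def logits_fn(e):
    # Add and remove the batch dimension around the model call.
    return model.forward_from_embeds(e.unsqueeze(0)).squeeze(0)

# Jacobian of logits (SEQ, VOCAB) w.r.t. embeddings (SEQ, HIDDEN).
jacobian = jacrev(logits_fn)(embeds)
print(jacobian.shape)
```

With the real model, the analogous step would presumably be to look up the embeddings with bert_model.get_input_embeddings()(token_ids) and call the model with inputs_embeds=... and token_type_ids=... in place of input_ids, reading the result from the logits field. I have not run that against bert-base-chinese, and a full Jacobian over the vocabulary dimension may be too large to hold in memory, so it may be necessary to restrict the output to the masked positions of interest.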
