修改Transformer后，我的encoder为什么提取不到有效信息？

我改动了Transformer用于文本匹配，但训练后准确率极低，且模型不收敛。
经排查发现，问题从encoder层对premise和hypothesis做self-attention时就已出现，但我水平有限，看不出问题出在哪里。
该层代码如下：


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    The encoding table has shape (max_len, 1, d_model) and is indexed by the
    first dimension of the input, so ``forward`` expects ``x`` shaped
    (seq_len, batch, d_model).

    Args:
        d_model: size of the last dimension of the input. May be odd.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length supported.
    """

    def __init__(self, d_model, dropout=0.5, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column.
        # Trim div_term so the cosine half matches; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model), which broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (seq_len, batch, d_model).

        Raises:
            ValueError: if ``x`` is longer than the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            # Without this check the slice silently truncates and the addition
            # fails with an unhelpful broadcasting error.
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        x = x + self.pe[:seq_len]
        return self.dropout(x)

class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + W2(ReLU(W1(x))))`` over the last dimension of ``x``.

    Args:
        embedding_dim: size of the last dimension of the input and the output.
        hidden_dim: width of the inner layer. When omitted, the module-level
            ``d_ff`` hyper-parameter is used, which matches the previous behaviour.
    """

    def __init__(self, embedding_dim, hidden_dim=None):
        super(PoswiseFeedForwardNet, self).__init__()
        self.embedding_dim = embedding_dim
        # Fall back to the global hyper-parameter so existing callers are unchanged.
        inner_dim = d_ff if hidden_dim is None else hidden_dim
        self.fc = nn.Sequential(
            nn.Linear(self.embedding_dim, inner_dim, bias=False),
            nn.ReLU(),
            nn.Linear(inner_dim, self.embedding_dim, bias=False),
        )
        # The LayerNorm is registered here rather than constructed inside forward.
        # As a submodule, its weight and bias are trained by the optimizer, saved
        # in the state_dict, and moved by model.to(device). A LayerNorm created on
        # each call is re-initialised every step and never learns.
        self.layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, inputs):
        """Apply the FFN to ``inputs`` (last dim = embedding_dim); the output has the same shape."""
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)

class MultiHeadAttention(nn.Module):
    """Multi-head attention sub-layer with a residual connection and post-LayerNorm.

    Uses the module-level hyper-parameters ``n_heads``, ``d_k`` and ``d_v``.
    It also uses ``ScaledDotProductAttention``, which is defined elsewhere in this file.

    Args:
        embedding_dim: size of the last dimension of the query/key/value inputs
            and of the output.
    """

    def __init__(self, embedding_dim):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.W_Q = nn.Linear(self.embedding_dim, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(self.embedding_dim, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(self.embedding_dim, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, self.embedding_dim, bias=False)
        # The LayerNorm is registered here rather than built inside forward, so
        # its affine weight and bias are learned, saved in the state_dict, and
        # follow model.to(device). A per-call LayerNorm is re-initialised every step.
        self.layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        """Attend from ``input_Q`` over ``input_K`` / ``input_V``.

        Args:
            input_Q: tensor whose leading dim is the batch and last dim is
                embedding_dim. It is also used as the residual.
            input_K: tensor with the same batch size and last dim embedding_dim.
            input_V: tensor with the same batch size and last dim embedding_dim.
            attn_mask: 3-D mask, presumably (batch, len_q, len_k). It is
                repeated across heads. Which value marks a masked position is
                decided by ``ScaledDotProductAttention`` (confirm there).

        Returns:
            ``(output, attn)``. ``output`` is ``LayerNorm(fc(context) + input_Q)``
            and has the shape of ``input_Q``. ``attn`` is whatever
            ``ScaledDotProductAttention`` returns as its second value.
        """
        residual, batch_size = input_Q, input_Q.size(0)
        # Project, then split the feature dim into heads: (batch, n_heads, len, d_k/d_v).
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # Give the mask a head axis so it lines up with Q/K/V.
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        # Merge heads back into one feature axis: (batch, len, n_heads * d_v).
        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)
        output = self.fc(context)
        return self.layer_norm(output + residual), attn

class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by a position-wise FFN.

    Args:
        embedding_dim: feature size of the layer's input and output.
    """

    def __init__(self, embedding_dim):
        super(EncoderLayer, self).__init__()
        # Assigned before the sub-modules below read it. Reading an attribute
        # that was never set on an nn.Module raises AttributeError at construction.
        self.embedding_dim = embedding_dim
        self.enc_self_attn = MultiHeadAttention(self.embedding_dim)
        self.pos_ffn = PoswiseFeedForwardNet(self.embedding_dim)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run self-attention (Q = K = V = ``enc_inputs``) and then the FFN.

        Returns:
            ``(enc_outputs, attn)``, where ``attn`` is the attention output
            returned by the self-attention sub-layer.
        """
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

class Encoder(nn.Module):
    """A stack of ``n_layers`` EncoderLayer modules applied to already-embedded inputs.

    Args:
        vocab_size: stored on the module; not used by ``forward``.
        embedding_dim: feature size passed to every EncoderLayer.
    """

    def __init__(self,
                 vocab_size,
                 embedding_dim):
        super(Encoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.layers = nn.ModuleList(
            EncoderLayer(self.embedding_dim) for _ in range(n_layers)
        )

    def forward(self, embedded_inputs, inputs):
        """Feed ``embedded_inputs`` through every layer in order.

        ``inputs`` (the raw token ids) is only used to build the self-attention
        padding mask via ``get_attn_pad_mask``.

        Returns:
            ``(hidden, attn_maps)``. ``hidden`` is the output of the last layer,
            or ``embedded_inputs`` itself when there are no layers. ``attn_maps``
            holds one attention output per layer.
        """
        pad_mask = get_attn_pad_mask(inputs, inputs)
        hidden = embedded_inputs
        attn_maps = []
        for encoder_layer in self.layers:
            hidden, layer_attn = encoder_layer(hidden, pad_mask)
            attn_maps.append(layer_attn)
        return hidden, attn_maps

class Transformer_01(nn.Module):
    """Token embedding, sinusoidal positional encoding and an Encoder stack.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding / model width.
        hidden_size: stored on the module; not used in the lines shown here.
        embeddings: optional pretrained weight matrix of shape
            (vocab_size, embedding_dim), given as a tensor or array-like.
            ``None`` means the embedding table is randomly initialised.
        padding_idx: index passed to ``nn.Embedding`` as the padding token.
        dropout: stored, and used as the positional-encoding dropout rate.
        num_classes: stored on the module.
        device: the encoder is moved to this device at construction.
    """

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_size,
                 embeddings=None,
                 padding_idx=0,
                 dropout=0.5,
                 num_classes=3,
                 device="cuda"
                 ):
        super(Transformer_01, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.dropout = dropout
        self.device = device
        # nn.Embedding's ``_weight`` must be None or a (vocab_size, embedding_dim)
        # tensor. The previous default (the int 512) made construction fail.
        # Array-likes such as numpy matrices are converted to float32 here.
        weight = None if embeddings is None else torch.as_tensor(embeddings, dtype=torch.float)
        self.src_emb = nn.Embedding(self.vocab_size,
                                    self.embedding_dim,
                                    padding_idx=padding_idx,
                                    _weight=weight)
        # Honour the caller's dropout setting; the default is unchanged at 0.5.
        self.pos_emb = PositionalEncoding(self.embedding_dim, dropout=self.dropout)
        self.encoder = Encoder(self.vocab_size, self.embedding_dim).to(device)
...

def forward(self, premises, premises_lengths, hypotheses, hypotheses_lengths):
        """Embed premise and hypothesis, add positional encodings, and encode each one.

        NOTE(review): in this listing the ``def`` sits at column 0 even though it
        takes ``self``. Presumably it is meant to be a method of Transformer_01;
        otherwise nn.Module would never call it. Confirm the indentation in the
        real file.

        Args:
            premises: token ids, presumably (batch, seq_len). TODO confirm.
            premises_lengths: not used in the lines shown here.
            hypotheses: token ids, presumably (batch, seq_len). TODO confirm.
            hypotheses_lengths: not used in the lines shown here.
        """

        # embedding
        embedded_premises = self.src_emb(premises) 
        embedded_hypotheses = self.src_emb(hypotheses) 

        # position-embedding
        # pos_emb indexes positions along dim 0. The tensors are therefore
        # transposed to sequence-first and back, which assumes src_emb returns
        # (batch, seq_len, embedding_dim). TODO confirm.
        # NOTE(review): the raw embeddings are added to the sinusoidal encoding
        # (values in [-1, 1]) without the sqrt(embedding_dim) scaling used in the
        # original Transformer. With small-magnitude pretrained vectors, the
        # position signal may swamp the word information. Worth checking.
        embedded_premises = self.pos_emb(embedded_premises.transpose(0, 1)).transpose(0, 1)  
        embedded_hypotheses = self.pos_emb(embedded_hypotheses.transpose(0, 1)).transpose(0, 1)  

        # self-encoder
        # The same self.encoder (same weights) encodes both sentences. The raw
        # token ids are passed alongside so the encoder can build its padding mask.
        pre_enc_outputs, pre_enc_self_attns = self.encoder(embedded_premises, premises)
        hyp_enc_outputs, hyp_enc_self_attns = self.encoder(embedded_hypotheses, hypotheses)
...