How does the sample function in the SeqGAN generator use the LSTM to generate sequences?

I have been studying SeqGAN using this GitHub repository: https://github.com/HeroKillerEver/SeqGAN-Pytorch, but I don't understand how the LSTM is used in the generator's `sample` function. In particular, in the line `output, (_, _) = self.lstm(embedding, (h, c))`, as far as I can tell, h and c are never updated.

import torch
import torch.nn as nn
import torch.nn.functional as F
import util  # repository helper module that provides to_var


class Generator(nn.Module):
    """Generator"""
    def __init__(self, vocab_size, embedding_size, hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers



    def forward(self, x):
        """
        x: (None, sequence_len) LongTensor
        """
        embedding = self.embedding(x) # (None, sequence_len, embedding_size)
        batch_size = x.size(0)
        h0, c0 = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        output, (_, _) = self.lstm(embedding, (h0, c0))  # (None, sequence_len, hidden_dim)
        logits = self.linear(output) # (None, sequence_len, vocab_size)
        logits = logits.transpose(1, 2) # (None, vocab_size, sequence_len)

        return logits  # (None, vocab_size, sequence_len)


    def step(self, x, h, c):
        """
        Args:
            x: (batch_size,  1), sequence of tokens generated by generator
            h: (1, batch_size, hidden_dim), lstm hidden state
            c: (1, batch_size, hidden_dim), lstm cell state
        """
        embedding = self.embedding(x) # (batch_size, 1, embedding_size)
        self.lstm.flatten_parameters()
        output, (h, c) = self.lstm(embedding, (h, c)) # (batch_size, 1, hidden_dim)
        logits = self.linear(output).squeeze_(1)  # (batch_size, vocab_size)

        return logits, h, c






    def sample(self, batch_size, sequence_len, x=None):

        flag = False
        if x is None:
            x = util.to_var(torch.zeros(batch_size, 1).long())
            flag = True

        h, c = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
        samples = []
        if flag:
            for _ in range(sequence_len):
                logits, h, c = self.step(x, h, c)
                probs = F.softmax(logits, dim=1)
                sample = probs.multinomial(1) # (batch_size, 1)
                samples.append(sample)
        else:
            given_len = x.size(1)
            lis = x.chunk(x.size(1), dim=1)
            for i in range(given_len):
                logits, h, c = self.step(lis[i], h, c)
                samples.append(lis[i])
            x = F.softmax(logits, dim=1).multinomial(1)
            for i in range(given_len, sequence_len):
                samples.append(x)
                logits, h, c = self.step(x, h, c)
                x = F.softmax(logits, dim=1).multinomial(1)
        output = torch.cat(samples, 1)
        return output # (batch_size, sequence_len)


    def init_hidden(self, num_layers, batch_size, hidden_dim):
        """
        initialize h0, c0
        """
        h = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
        c = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))

        return h, c

Thanks in advance, everyone.

Not sure whether you have already solved this; if not:
  • The following answer is from ChatGPT:

    Answer: In the SeqGAN generator, the sample function uses an LSTM to generate sequences because an LSTM can process sequential information and produce sequences that resemble those in the training set. Concretely, the LSTM's memory cells store and give access to information from earlier time steps, so each token it generates can be conditioned on the tokens generated before it.

    The sampling loop can be written in the following simplified, self-contained form (a paraphrase rather than the repository's exact code; it reuses the attribute names of the Generator class above and assumes a self.device attribute):

    def sample(self, batch_size, seq_len, output_type='softmax'):
        self.eval()
        samples = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device)
        # start from all-zero hidden and cell states; no gradients are needed for sampling
        h = torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=self.device)
        c = torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=self.device)

        for t in range(seq_len):
            if t == 0:
                # the first input is the start token (index 0)
                input = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
            else:
                # feed back the token generated at the previous step
                input = samples[:, t - 1].unsqueeze(1)
            embedding = self.embedding(input)                        # (batch_size, 1, embedding_size)
            output, (h, c) = self.lstm(embedding, (h, c))            # (h, c) are updated and reused next step
            logits = self.linear(output.view(-1, self.hidden_dim))   # (batch_size, vocab_size)
            if output_type == 'softmax':
                probs = F.softmax(logits, dim=-1)
                sample = torch.multinomial(probs, 1).view(-1)        # stochastic: draw from the distribution
            else:
                sample = logits.argmax(dim=-1)                       # greedy: take the most likely token
            samples[:, t] = sample
        return samples
    

    In this function, h and c are initialized to all zeros:

    h = torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=self.device)
    c = torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=self.device)

    This zero initialization means the sequence begins from an empty state; decoding then starts with the start token (an all-zero index tensor) as the LSTM's first input.
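    As a quick illustration of those shapes, here is a tiny standalone snippet (toy sizes, not from the repository):

    import torch

    num_layers, batch_size, hidden_dim = 1, 4, 64
    # zero initial hidden and cell states, one slice per LSTM layer
    h0 = torch.zeros(num_layers, batch_size, hidden_dim)
    c0 = torch.zeros(num_layers, batch_size, hidden_dim)
    # index 0 acts as the start-of-sequence token for the first step
    start = torch.zeros(batch_size, 1, dtype=torch.int64)
    print(h0.shape, c0.shape, start.shape)  # (1, 4, 64), (1, 4, 64), (4, 1)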

    At each time step, the LSTM takes the current embedding together with the previous (h, c) and returns both the output for that step and the updated (h, c); those updated states are fed back in at the next iteration, which is how information about earlier tokens is carried forward. This is exactly what the repository's step function does: it returns logits, h, c, and sample passes the returned h and c into the next call to step. A line of the form output, (_, _) = self.lstm(...) that discards the returned states appears only in forward, where the entire sequence is passed through the LSTM in a single call, so the final hidden state is simply not needed afterwards.
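    To see that passing the state along step by step is equivalent to one call over the whole sequence, here is a small standalone check (toy sizes, not from the repository):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
    x = torch.randn(2, 5, 8)                    # (batch, seq_len, embedding_size)
    h = torch.zeros(1, 2, 16)
    c = torch.zeros(1, 2, 16)

    # one call over the full sequence (what forward does)
    full_out, _ = lstm(x, (h, c))

    # one token at a time, reusing the returned state (what step/sample do)
    step_outs = []
    for t in range(5):
        out_t, (h, c) = lstm(x[:, t:t + 1, :], (h, c))
        step_outs.append(out_t)
    step_out = torch.cat(step_outs, dim=1)

    print(torch.allclose(full_out, step_out, atol=1e-6))  # True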

    The output_type branch controls how the next token is chosen: with 'softmax', the logits are converted into a probability distribution and a token is drawn with torch.multinomial (stochastic sampling, which SeqGAN relies on so the generator can explore during policy-gradient training); otherwise, argmax picks the single most likely token (greedy decoding).
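    For illustration, a minimal standalone comparison of the two strategies on a toy logits vector (the numbers are made up):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.1]])    # (batch=1, vocab_size=3)
    probs = F.softmax(logits, dim=-1)           # roughly [0.66, 0.24, 0.10]

    stochastic = torch.multinomial(probs, 1)    # may return 0, 1, or 2, weighted by probs
    greedy = logits.argmax(dim=-1)              # always returns 0
    print(probs, stochastic, greedy)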

    In short, the sample function in the SeqGAN generator runs the LSTM one step at a time: each step consumes the previously generated token, produces the output along with the updated hidden and cell states, and those states are reused at the next step, so information about every earlier token flows through the whole generated sequence.
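    Finally, a small usage sketch of the Generator class from the question (hyperparameters are made up for illustration; the repository's util helper is assumed to be importable):

    import torch

    vocab_size, embedding_size, hidden_dim, num_layers = 5000, 32, 64, 1
    gen = Generator(vocab_size, embedding_size, hidden_dim, num_layers)

    # unconditional sampling: starts from the zero start token
    seqs = gen.sample(batch_size=4, sequence_len=20)      # (4, 20)

    # conditional sampling: complete a given prefix, as a Monte Carlo rollout does
    prefix = torch.randint(0, vocab_size, (4, 5))
    completed = gen.sample(4, 20, x=prefix)               # (4, 20); the first 5 columns equal the prefix
    print(seqs.shape, completed.shape)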

