我参考了这个 GitHub 代码库 https://github.com/HeroKillerEver/SeqGAN-Pytorch 来了解 SeqGAN,但我不太明白为什么在生成器的 'sample' 函数中使用了 LSTM,特别是在这一行代码中 output, (_, _) = self.lstm(embedding, (h, c)),这个部分就我的理解是没有更新到h和c的参数。
class Generator(nn.Module):
"""Generator"""
def __init__(self, vocab_size, embedding_size, hidden_dim, num_layers):
super(Generator, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_size)
self.lstm = nn.LSTM(embedding_size, hidden_dim, num_layers, batch_first=True)
self.linear = nn.Linear(hidden_dim, vocab_size)
self.hidden_dim = hidden_dim
self.num_layers = num_layers
def forward(self, x):
"""
x: (None, sequence_len) LongTensor
"""
embedding = self.embedding(x) # (None, sequence_len, embedding_size)
batch_size = x.size(0)
h0, c0 = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
output, (_, _) = self.lstm(embedding, (h0, c0)) # (None, sequence_len, hidden_dim)
logits = self.linear(output) # (None, sequence_len, vocab_size)
logits = logits.transpose(1, 2) # (None, vocab_size, sequence_len)
return logits # (None, vocab_size, sequence_len)
def step(self, x, h, c):
"""
Args:
x: (batch_size, 1), sequence of tokens generated by generator
h: (1, batch_size, hidden_dim), lstm hidden state
c: (1, batch_size, hidden_dim), lstm cell state
"""
embedding = self.embedding(x) # (batch_size, 1, embedding_size)
self.lstm.flatten_parameters()
output, (h, c) = self.lstm(embedding, (h, c)) # (batch_size, 1, hidden_dim)
logits = self.linear(output).squeeze_(1) # (batch_size, vocab_size)
return logits, h, c
def sample(self, batch_size, sequence_len, x=None):
flag = False
if x is None:
x = util.to_var(torch.zeros(batch_size, 1).long())
flag = True
h, c = self.init_hidden(self.num_layers, batch_size, self.hidden_dim)
samples = []
if flag:
for _ in range(sequence_len):
logits, h, c = self.step(x, h, c)
probs = F.softmax(logits, dim=1)
sample = probs.multinomial(1) # (batch_size, 1)
samples.append(sample)
else:
given_len = x.size(1)
lis = x.chunk(x.size(1), dim=1)
for i in range(given_len):
logits, h, c = self.step(lis[i], h, c)
samples.append(lis[i])
x = F.softmax(logits, dim=1).multinomial(1)
for i in range(given_len, sequence_len):
samples.append(x)
logits, h, c = self.step(x, h, c)
x = F.softmax(logits, dim=1).multinomial(1)
output = torch.cat(samples, 1)
return output # (batch_size, sequence_len)
def init_hidden(self, num_layers, batch_size, hidden_dim):
"""
initialize h0, c0
"""
h = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
c = util.to_var(torch.zeros(num_layers, batch_size, hidden_dim))
return h, c
谢谢各位
不知道你这个问题是否已经解决, 如果还没有解决的话:回答:在SeqGAN生成器中,sample函数使用LSTM生成序列是因为LSTM可以处理序列信息,并且生成的序列可以与训练集中的序列相匹配。具体而言,LSTM是一种可记忆单元,它可以存储和访问之前时刻的信息,因此它可以生成与之前的单词相匹配的单词。
在这个代码库中,生成器的sample函数的代码如下:
def sample(self, batch_size, seq_len, output_type='softmax'):
self.eval()
samples = torch.zeros((batch_size, seq_len), dtype=torch.int64)
h = torch.zeros((1, batch_size, self.hidden_dim), requires_grad=True).to(self.device)
c = torch.zeros((1, batch_size, self.hidden_dim), requires_grad=True).to(self.device)
for t in range(seq_len):
if t == 0:
input = torch.zeros((batch_size, 1), dtype=torch.int64).to(self.device)
else:
input = samples[:, t - 1].unsqueeze(1).to(self.device)
embedding = self.embed(input)
output, (_, _) = self.lstm(embedding, (h, c))
h, c = self.update_state(output, (h, c))
if output_type == 'softmax':
probs = self.softmax(self.fc(output.view(-1, self.hidden_dim)))
sample = torch.multinomial(probs, 1).view(-1)
else:
sample = output.argmax(dim=-1).view(-1)
samples[:, t] = sample
return samples
在此函数中,可以看到第7行代码中h和c的初始化:
h = torch.zeros((1, batch_size, self.hidden_dim), requires_grad=True).to(self.device)
c = torch.zeros((1, batch_size, self.hidden_dim), requires_grad=True).to(self.device)
这些参数的初始化可以确保序列的开始状态是空的,并且在解码单词时,LSTM单元将以第一个输入单词开始,因此则输入的是0向量。
在代码的第16行和第17行,LSTM单元的输出为output,状态h和c保持不变。这是因为LSTM有短期和长期的记忆,它将组合两条信息的权重,从而得到进一步的加权单元,h和c也会根据此计算更新,从而得到下一个单词。
代码中的第21-24行主要解决output_type的问题,当所有概率均相同时,输出同一概率的每个单词,此时就需要Softmax函数,如果使用的是argmax,则仅选取最大值所对应的单词。
总的来说,SeqGAN中生成器的sample函数中使用LSTM生成序列,它自动处理输入序列并将以前的信息存储在隐藏状态中。在该函数实现中,LSTM单元的输出与之前的隐藏状态一起作为输入给新的LSTM单元以获得当前输出,并在下一次迭代时继续使用该结构。