Implementing simultaneous gradient descent for GANs in PyTorch, plus the new method from the paper

I've been doing some research on GANs recently and tried running both a DNN-based and a CNN-based GAN on the MNIST dataset, but the results were poor: possibly because the dataset is too simple, the models generalized badly and convergence was also bad. After going through a lot of the literature, I found that The Numerics of GANs also uses fairly simple datasets and still gets good results. That paper uses the simultaneous gradient descent algorithm (SimGA in the paper's terminology). As far as I can tell, nobody online has published a reimplementation of this method; the paper's original code on GitHub runs on TensorFlow 1.0. I'd like to implement it in PyTorch, but I don't know how to make the generator's and discriminator's gradients descend simultaneously rather than updating the two networks in alternating steps. Is there a way to do this in PyTorch? The code is pasted below.

import tensorflow as tf

# Simultaneous gradient steps
class SimGDOptimizer(object):
    def __init__(self, learning_rate):
        self._sgd = tf.train.RMSPropOptimizer(learning_rate)

    def conciliate(self, d_loss, g_loss, d_vars, g_vars, global_step=None):
        # Compute gradients
        d_grads = tf.gradients(d_loss, d_vars)
        g_grads = tf.gradients(g_loss, g_vars)

        # Merge variable and gradient lists
        variables = d_vars + g_vars
        grads = d_grads + g_grads

        # Gradient updates
        reg_grads = list(zip(grads, variables))

        train_op = self._sgd.apply_gradients(reg_grads)

        return [train_op]
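For reference, here is a rough, untested sketch of how I imagine the simultaneous step might look in PyTorch: both gradient lists are computed with torch.autograd.grad at the current parameters, and then a single joint RMSprop step is applied over the merged parameter list, mirroring variables = d_vars + g_vars above. The class and parameter names here are placeholders of my own, not from the paper's code.

import torch

# Rough PyTorch sketch of simultaneous gradient steps (untested).
# `d_params` / `g_params` are the discriminator's and generator's parameters.
class SimGDStep(object):
    def __init__(self, d_params, g_params, learning_rate):
        self.d_params = list(d_params)
        self.g_params = list(g_params)
        # One optimizer over the merged parameter list, mirroring
        # `variables = d_vars + g_vars` in the TF code above.
        self.optimizer = torch.optim.RMSprop(
            self.d_params + self.g_params, lr=learning_rate)

    def step(self, d_loss, g_loss):
        # Compute both gradient lists at the *current* parameters,
        # before either network has been updated.
        d_grads = torch.autograd.grad(d_loss, self.d_params, retain_graph=True)
        g_grads = torch.autograd.grad(g_loss, self.g_params)

        # Write the gradients back and apply a single joint update.
        self.optimizer.zero_grad()
        for p, g in zip(self.d_params + self.g_params,
                        list(d_grads) + list(g_grads)):
            p.grad = g.detach()
        self.optimizer.step()

The point, as I understand it, is that both gradients are evaluated before either network moves, which is what "simultaneous" means here, in contrast to the usual alternating d_optimizer.step() / g_optimizer.step() training loop.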

And here is the new method proposed in the paper (consensus optimization):

# Consensus optimization, method presented in the paper
class ConsensusOptimizer(object):
    def __init__(self, learning_rate, alpha=0.1, beta=0.9, eps=1e-8):
        self.optimizer = tf.train.RMSPropOptimizer(learning_rate)
        self._eps = eps
        self._alpha = alpha
        self._beta = beta

    def conciliate(self, d_loss, g_loss, d_vars, g_vars, global_step=None):
        alpha = self._alpha
        beta = self._beta

        # Compute gradients
        d_grads = tf.gradients(d_loss, d_vars)
        g_grads = tf.gradients(g_loss, g_vars)

        # Merge variable and gradient lists
        variables = d_vars + g_vars
        grads = d_grads + g_grads

        # Regularizer: 0.5 * squared norm of the joint gradient vector
        reg = 0.5 * sum(
            tf.reduce_sum(tf.square(g)) for g in grads
        )
        # Jacobian times gradient (gradient of the regularizer)
        Jgrads = tf.gradients(reg, variables)

        # Gradient updates
        apply_vec = [
             (g + self._alpha * Jg, v)
             for (g, Jg, v) in zip(grads, Jgrads, variables) if Jg is not None
        ]

        train_op = self.optimizer.apply_gradients(apply_vec)

        return [train_op]
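And a similarly rough sketch of how the consensus update might translate to PyTorch: create_graph=True keeps the first gradients differentiable so that the regularizer 0.5 * ||v||^2 can be differentiated a second time, which is what tf.gradients(reg, variables) does above. Again, the names are placeholders of my own and I haven't verified this against the paper.

import torch

# Rough PyTorch sketch of the consensus optimization step (untested).
class ConsensusStep(object):
    def __init__(self, d_params, g_params, learning_rate, alpha=0.1):
        self.d_params = list(d_params)
        self.g_params = list(g_params)
        self.params = self.d_params + self.g_params
        self.alpha = alpha
        self.optimizer = torch.optim.RMSprop(self.params, lr=learning_rate)

    def step(self, d_loss, g_loss):
        # create_graph=True keeps the gradients differentiable so the
        # regularizer below can itself be differentiated.
        d_grads = torch.autograd.grad(d_loss, self.d_params, create_graph=True)
        g_grads = torch.autograd.grad(g_loss, self.g_params, create_graph=True)
        grads = list(d_grads) + list(g_grads)

        # reg = 0.5 * ||v||^2, the squared norm of the joint gradient vector.
        reg = 0.5 * sum((g ** 2).sum() for g in grads)

        # Gradient of the regularizer w.r.t. all parameters
        # (the Jacobian-times-gradient term from the paper).
        Jgrads = torch.autograd.grad(reg, self.params, allow_unused=True)

        # Combined update g + alpha * Jg, applied as one joint RMSprop step.
        self.optimizer.zero_grad()
        for p, g, Jg in zip(self.params, grads, Jgrads):
            update = g if Jg is None else g + self.alpha * Jg
            p.grad = update.detach()
        self.optimizer.step()

If anyone knows whether this is a reasonable way to do the simultaneous/consensus update in PyTorch, or whether there is a cleaner built-in mechanism, I'd appreciate any pointers.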