gan-mnist代码解读

introduction

code

  • 主要有2个文件夹,一个mxgan,是主要代码,一个example,用于测试。

    mxgan

    encoder.py

  • 这个主要是将输入图像转化为一个向量,可以使用lenet或者一个dcgan model。
  • 输出的向量的维度越小,计算量越小,但是可能难以表征所以隐含的特征,越大可能越准确,但是计算量也会越大,目前关于该向量的维度,没有得到证明的上界。

generator.py

  • 这个是decoder模块,将向量转化为最终地图像,相当于G网络生成图像,可以使用dcgan或者fcgan方法生成图像。

module.py

  • 这个是GAN的module部分,D网络是encoder部分,fake image的输出向量为0,real image的输出向量为1;fake image是使用D网络在给定随机数的输入之后通过CNN生成的。
  • mxnet的module在训练的时候,通常情况下有forward,backward,update三个过程,分别对应前向计算输出、反向计算梯度、根据梯度更新权重。在训练GAN的时候。过程如下:
    • 首先生成随机数,维度为batch_size * vector_dims
    • 使用G网络生成fake image。
    • 设置label为0,将fake image和label输入到D网络中,进行forward,backward,先存储梯度信息,不进行更新(后面进行real image的forward之后再进行更新)。
    • 更新G网络:将label设置为1,将fake image传入到D网络中(因为G网络的目的就是使D网络将fake image识别为real image),再使用G网络的输入梯度(get_input_grads函数)进行backwardupdate
    • 更新D网络:设置label为pos label(作者在这里设置了默认为0.9),将real image和label输入到D网络中;进行backward,和之前的梯度相加,之后进行update

example

gan_mnist.py

  • 这是作者写的训练gan-mnist的过程,我自己是添加了模型保存(G网络)的代码。

generate_mnist.py

  • 根据生成的G网络,随机生成图像数据。