introduction
- 本文主要是写一下看了
gan-mnist
代码之后的一些总结,之前是陈天奇大神写的,我fork到自己的repo中,添加了一些其他的代码,地址:https://github.com/littletomatodonkey/gan-mnist
code
- 主要有2个文件夹,一个
mxgan
,是主要代码,一个example
,用于测试。mxgan
encoder.py
- 这个主要是将输入图像转化为一个向量,可以使用lenet或者一个dcgan model。
- 输出的向量的维度越小,计算量越小,但是可能难以表征所以隐含的特征,越大可能越准确,但是计算量也会越大,目前关于该向量的维度,没有得到证明的上界。
generator.py
- 这个是decoder模块,将向量转化为最终地图像,相当于G网络生成图像,可以使用dcgan或者fcgan方法生成图像。
module.py
- 这个是GAN的module部分,D网络是encoder部分,fake image的输出向量为0,real image的输出向量为1;fake image是使用D网络在给定随机数的输入之后通过CNN生成的。
- mxnet的module在训练的时候,通常情况下有
forward
,backward
,update
三个过程,分别对应前向计算输出、反向计算梯度、根据梯度更新权重。在训练GAN的时候。过程如下:- 首先生成随机数,维度为
batch_size * vector_dims
- 使用G网络生成fake image。
- 设置label为0,将fake image和label输入到D网络中,进行
forward
,backward
,先存储梯度信息,不进行更新(后面进行real image的forward之后再进行更新)。 - 更新G网络:将label设置为1,将fake image传入到D网络中(因为G网络的目的就是使D网络将fake image识别为real image),再使用G网络的输入梯度(
get_input_grads
函数)进行backward
与update
。 - 更新D网络:设置label为
pos label
(作者在这里设置了默认为0.9),将real image和label输入到D网络中;进行backward
,和之前的梯度相加,之后进行update
。
- 首先生成随机数,维度为
example
gan_mnist.py
- 这是作者写的训练gan-mnist的过程,我自己是添加了模型保存(G网络)的代码。
generate_mnist.py
- 根据生成的G网络,随机生成图像数据。