模型压缩之模型蒸馏

abstract

  • 论文地址:Distilling the Knowledge in a Neural Network
  • 在机器学习算法中,一个比较常用的方法是在相同的数据集上训练多个模型,然后对这些模型的预测结果进行加权,得到最终的预测结果,这也就是所谓的集成学习,由单个的弱学习器集成得到一个强学习器。但是如果基学习器太大,则这种方法难以在实际过程中使用。
  • 论文中针对mnist数据集进行实验,将集成的模型中的信息(知识)蒸馏到一个模型中。
  • 论文中也介绍了一种新的集成,包含了specialist models,用于去区分full model中无法区分的细粒度的类别。这些specialist models可以快速并行训练。

introduction

  • 一般在训练过程中,训练的模型很大,如果需要部署在有实时性要求的平台上,则可以对之前这个很大的模型进行蒸馏,得到小模型并进行部署。这个过程的难点主要就是怎样在减小模型大小的情况下尽可能多地保留大模型中学到的知识。
  • 复杂模型有一个问题:在分类问题中,它只关注最大概率的那个值,对于其他的值,即使差距很大,它们之间也没有差别(对loss也没有贡献)。
  • 将负载模型的泛化能力迁移到小模型上的一种方式是将复杂模型得到的类别概率作为soft targets,并让小模型去学习soft targets,我们使用当前的模型去,之后再使用这些小模型进行集成,得到最终的结果。相对于复杂模型,小模型在训练数据之间梯度的方差更小,因此小模型可以使用更大的学习率,在更小的模型上进行训练
  • 以mnist为例,数字2的被识别为3的概率是$10^{-6}$,被识别为7的概率是$10^{-9}$,它们之间的概率差距是很大的,这是非常有用的信息,但是这个信息对于最后的loss function几乎没有影响,有一个比较通用的方法,即在最后softmax时,使用更高的温度$T$,将最后的数值进行压缩(squash),使得其值得差异相对不明显,再使用相同的温度训练小模型,来学习这些softmax targets
  • 可以不使用labeled data,只使用得到的soft targets进行训练,但是作者发现同时使用原始带有labeled data的数据集进行训练会取得更好的效果,因此在这里同时预测soft targets与label,将两个loss进行加权求和,得到最终的loss。

distillation

  • match logits只是一个special case,softmax更一般的表达式如下

$${q_i} = \frac{\exp ({z_i}/T)}{\sum\nolimits_j {\exp ({z_j}/T)} }$$

  • 上面的公式中,$T$是温度系数,一般设置为1,在论文中,设置大于1,进行knowledge distillation,T越大,输出的概率之间的差异越小。
  • 在训练复杂模型和小模型的时候,T都是大于1的,在测试的时候,将T修改为1.
  • 当T很大时,loss对$z_i$进行求导,得到的偏导可以近似表示为

$$\frac{\partial C}{\partial {z_i}} \approx \frac{1}{N{T^2}}({z_i} - {v_i})$$

  • 更详细的公式可以参见原文,上式主要的意思就是当T很大时,BP时的类别之间的梯度差异会很小。T很小的时候,则主要会关注positive example
  • 作者同时也发现,当小模型太小,无法提取复杂模型的所有信息时,使用适中的温度$T$,可以去除very negative logits对结果的影响,这有助于最后的模型结果。

experiments and results

  • mnist:soft targets可以蒸馏得到更多的知识。
  • 语音识别任务上进行了实验,这里不详细说明。

discussion

  • 模型蒸馏可以很好的进行知识迁移,论文也是在mnist等小的数据集上做了实验,证明了集成学习中的多个神经网络可以被蒸馏为相同大小的一个网络,这使得部署更加方便。
  • 在大的数据集上,可以通过使用大量specialist nets来提升模型的精度,但是目前还没有证明能够将这些小网络蒸馏为单独的大网络。