TAS-Distilling Arbitrary Teacher And AND Student Via A Hybrid Assistant

TAS: Distilling Arbitrary Teacher And AND Student Via A Hybrid Assistant

Introduction

论文题目 :TAS: Distilling Arbitrary Teacher And AND Student Via A Hybrid Assistant

论文地址https://arxiv.org/pdf/2410.12342v1.pdf

论文出处 :arXiv 武大夏桂松团队+腾讯优图实验室

代码实现:待公开

Idea

引入一个代理模型作为桥梁,以促进异构教师和学生之间的顺利的特征知识转移。代理模型通过合并来自学生和教师模块功能的卷积和注意力模块,结合了跨架构偏置和模块功能的优势。

Detail

作者详细分析了为什么异构蒸馏如果按常规的同构蒸馏策略来表现很差,主要有以下两点:

  1. 归纳偏置:指的是模型用来对从未见过的数据进行的预测的一组假设

不同的model function会在不同阶段产生不同特征
ViT中的浅层特征和深层特征比分层的CNN有更高的相似度

虽然OFA试图将特征投影到Logit空间来解决异构特征的差距,但是对于特征特定知识的实质性,它是次优的

  1. 模块函数:指的是模型如何读取,编码、解码和处理数据。

CNN model生成的特征是由图中的局部感受野内的局部像素衍生的;
MSA model / MLP model生成的特征都是通过全局交换所有 patche 的信息生成的

MSA moel 输入图像分割成块,然后将其投影到查询、键和值中。
MLP model 也首先将输入图像分割成块。然后,它沿所有块的空间和通道维度混合全局信息。

Model

主要创新部分集中在整体架构上和L2G模块上:

整体架构:提出了TAS蒸馏架构,其中辅助模型主要由权重共享模块 S1→3 和 S4 组成,它们不仅统一了不同的模块功能,而且几乎不引入可忽略不计的额外可学习参数。

L2G模块:为了连接 CNN 和 MSA/MLP 模块,作者提出了L2G模块,其中包含了一个PE模块和MSA模块,该模块用于将CNN的特征转换到MSA模块的特征。

对于最终特征S4S^4通过平均池化对它们进行平滑,以减轻不同特征之间的空间差距.

L=LTAS(Kt,Ks)+LTAS(Kt,Ka)+LTAS(Ka,Ks),\mathcal{L}=\mathcal{L}_{\mathrm{TAS}}(\mathrm{K}_{\mathrm{t}},\mathrm{K}_{\mathrm{s}})+\mathcal{L}_{\mathrm{TAS}}(\mathrm{K}_{\mathrm{t}},\mathrm{K}_{\mathrm{a}})+\mathcal{L}_{\mathrm{TAS}}(\mathrm{K}_{\mathrm{a}},\mathrm{K}_{\mathrm{s}}),

LTASL_{TAS}这三者的具体公式展开是一样的:

LTAS(Kt,Ks)={LOFA(pt,ps)=(1+ptc^)γlog(ptc^psc^)+i=1,ic^Cptclog(ptcpsc),LInfoNCE(ft,fs)=logexp(fsft+/τ2)i=0Ftexp(fsfti/τ2)\mathcal{L}_{\mathrm{TAS}}(\mathrm{K}_{\mathrm{t}},\mathrm{K}_{\mathrm{s}})= \begin{cases} \mathcal{L}_{\mathrm{OFA}}(p_{\mathrm{t}},p_{\mathrm{s}})=(1+p_{\mathrm{t}}^{\hat{c}})^{\gamma}\log(\frac{p_{\mathrm{t}}^{\hat{c}}}{p_{\mathrm{s}}^{\hat{c}}})+\sum_{i=1,i\neq\hat{c}}^{C}p_{\mathrm{t}}^{c}\log(\frac{p_{\mathrm{t}}^{c}}{p_{\mathrm{s}}^{c}}), \\ \\ \mathcal{L}_{\mathrm{InfoNCE}}(f_{\mathrm{t}},f_{\mathrm{s}})=-\mathrm{log}\frac{\exp(f_{\mathrm{s}}\cdot f_{\mathrm{t}}^{+}/\tau_{2})}{\sum_{i=0}^{F_{\mathrm{t}}}\exp(f_{\mathrm{s}}\cdot f_{\mathrm{t}}^{i}/\tau_{2})} & & & \end{cases}

这里提出了用LinfoNCEL_{infoNCE}来计算特征蒸馏损失,用LOFAL_{OFA}来计算逻辑蒸馏损失

Result

作者分别在CIFAR-100和ImageNet数据集上进行了实验:

CIFAR-100数据集:

ImageNet数据集:

创新思路

  1. 交替使用CNN 和 MSA /MLP模块
  2. 修改L2G模块,优化特征的转化方式

这篇文章的代码还没有公布,所以我正在看OFA的代码