STPM 利用教师学生网络进行无监督异常检测

本文最后更新于:2022年5月21日 凌晨

异常检测是缺陷检测领域中的重要内容,本文记录运用 STPM 进行异常检测的方法。

STPM

摘要

异常检测问题是一个具有挑战性的任务,通常被定义为针对意外性异常的一类学习问题。本文针对这一问题提出了一种简单而有效的方法,这种方法以其优点在师生框架中得到了实施,但在准确性和效率方面得到了实质性的扩展。在给定一个作为教师的图像分类训练模型的情况下,我们将知识提取到一个具有相同结构的单个学生网络中来学习无异常图像的分布,这种一步转移尽可能地保留了关键线索。此外,我们将多尺度的特征匹配策略集成到框架中,这种层次化的特征匹配使学生网络在更好的监督下能够从特征金字塔中接收到多层次的知识混合,从而允许检测不同规模的异常。两个网络生成的特征金字塔之间的差异可以作为一个评分函数,表明发生异常的概率。由于这样的操作,我们的方法实现了准确和快速的像素级异常检测。非常具有竞争力的结果是在 MVTec 异常检测数据集上提供的,优于最先进的数据集。

核心思路

  • 方法的目标是检测异常出现的数据分布
  • 首先训练一个”见多识广“的教师网络,该网络对测试数据集中绝大多数数据都可以给出自己的正常”见解“,包括异常部分的数据
  • 随后训练一个具有相同结构与规模的学生网络,该网络通过蒸馏教师网络的中间金字塔层输出进行训练,在训练过程中仅使用正常数据,训练得到一个仅识得"仁义礼智"的学生
  • infer 时,将测试数据喂给“教师”和“学生”,正常数据二者表现应该很接近;面对异常数据,“教师”可以泰然处之,“学生”则会方寸大乱,整合二者在金字塔层特征的差异情况来判断是否出现了异常数据

方法流程

学生网络训练阶段

  • 教师、学生网络仅使用 Backbone 即可
  • 在教师 Forward 操作过程中保留金字塔特征 $F_t$,不更新参数
  • 学生网络保留同样位置的特征 $F_s$,设计 Loss 函数,更新自己的参数
  • 每个特征值的损失函数:
$$ \ell^{l}\left(\mathbf{I}_{k}\right)_{i j}=\frac{1}{2}\left\|\hat{F}_{t}^{l}\left(\mathbf{I}_{k}\right)_{i j}-\hat{F}_{s}^{l}\left(\mathbf{I}_{k}\right)_{i j}\right\|_{\ell_{2}}^{2} $$
  • 其中:
$$ \hat{F}_{t}^{l}\left(\mathbf{I}_{k}\right)_{i j}=\frac{F_{t}^{l}\left(\mathbf{I}_{k}\right)_{i j}}{\left\|F_{t}^{l}\left(\mathbf{I}_{k}\right)_{i j}\right\|_{\ell_{2}}}, \hat{F}_{s}^{l}\left(\mathbf{I}_{k}\right)_{i j}=\frac{F_{s}^{l}\left(\mathbf{I}_{k}\right)_{i j}}{\left\|F_{s}^{l}\left(\mathbf{I}_{k}\right)_{i j}\right\|_{\ell_{2}}} $$
  • 每个金字塔特征层的损失为该层中所有特征损失之和:
$$ \ell^{l}\left(\mathbf{I}_{k}\right)=\frac{1}{w_{l} h_{l}} \sum_{i=1}^{w_{l}} \sum_{j=1}^{h_{l}} \ell^{l}\left(\mathbf{I}_{k}\right)_{i j} $$
  • 总损失为所有金字塔特征层的损失之加权和:
$$ \ell(\mathbf{I}_k) =\sum_{l=1}^{L}{α_l }{\ell}^l(\mathbf{I}_k) $$

权重 $α_l$​ 为非负数

训练时,若教师学生输入相似但不同的话,学生加载预训练模型会更容易训练网络

测试阶段

  • 测试阶段,教师、学生网络Forward 得到金字塔特征
  • 特征归一化
  • 逐层特征计算逐个值的损失,在 Channel 维度求和得到金字塔层级数量的损失 map
  • 将各层 map 上采样到原始图像大小
  • 将上采样的特征逐元素求乘积(有的源码实现时采用的求和策略)

$$
\Omega ( J ) = \prod _ { t = 1 } ^ { L } U _ { p s a m p l e } \Omega ^ { l } ( J )
$$

  • 得到最终异常检测 Map

原始论文

参考源码

参考资料