一种基于知识蒸馏小样本增量学习的图片分类方法及系统与流程

未命名 08-03 阅读:99 评论:0


1.本发明涉及一种图片分类方法及系统,尤其是基于知识蒸馏小样本增量学习的图片分类方法及系统。


背景技术:

2.随着人工智能技术的不断发展和应用,增量学习因其强大的适用性逐渐受到了学术界和工业界的关注。增量学习,指的是对于一个已经训练好的模型,在面临新数据时,不需要使用全部数据重新训练整个模型,而是渐进地对模型进行更新。通过不断修正和加强以前的知识,使得模型在新数据上具有泛化性。增量学习降低了模型训练过程中对时间和空间的需求,广泛应用于推荐系统、图片分类等领域中。当前大多数增量学习方法的训练需要大量新类样本,而在现实环境中,受到人力、物力和客观因素的制约,数据获取往往十分困难导致样本量稀少,这严重影响了传统增量学习方法的性能。
3.知识蒸馏作为一种重要的学习范式,通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息训练这个轻量化模型,以达到更好的性能和精度。其中来自大模型输出的监督信息称之为知识,而小模型学习迁移来自大模型的监督信息称之为蒸馏。然而,传统的知识蒸馏方法往往依赖于大量的训练样本。在小样本场景中,由于缺乏足够多的样本,新旧类别之间的样本数量差异较大,模型在训练或预测过程中往往倾向于更大的旧类训练样本集,容易造成严重的类别不平衡问题导致性能下降,基类与新类样本之间的不平衡也使得模型难以学习新类别。


技术实现要素:

4.发明目的:本发明的目的是提供一种能够提高小样本学习性能的基于知识蒸馏小样本增量学习的图片分类方法;本发明的第二目的是提供一种能够提高小样本学习性能的基于知识蒸馏小样本增量学习的图片分类系统。
5.技术方案:本发明所述的基于知识蒸馏小样本增量学习的图片分类方法,通过蒸馏网络判断输入图片所属类别,包括如下步骤:(1)将随机初始化的resnet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;(2)冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;(3)冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;(4)利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高
的类别为所述输入图像的所属类别。
6.进一步地,步骤(1)中利用所述预热网络计算类别原型包括:类别c的类别原型pc为:;其中,s
ic
表示小样本分类任务中支持集si中类别为c的数据集,|s
ic
|表示s
ic
的大小,x
t
为小样本分类任务中样本的特征向量,y
t
为对应样本的标签,代表该样本所属的类别;为预热网络。
7.进一步地,步骤(1)中训练所述预热网络的预热损失函数lh为:;其中,qi为小样本分类任务的查询集,xq为查询集qi中的新样本,yq为对应样本的标签,代表该样本所属的类别;为归一化分类函数,每个类别c的归一化分类分数为,为softmax函数;为权重,每个类别c的权重为,dc为类别c中类别原型与其他同类别样本的距离和,m为si中除类别c以外的其他类别。
8.进一步地,步骤(2)中利用所述增量网络计算新增类别的类别原型包括:新增类别c' 的类别原型p'
c' 为:;其中,s_new
jc' 表示小样本增量任务中增量支持集s_newj中类别为c' 的数据集;|s_new
jc' |表示s_new
jc' 的大小,x
t' 为小样本增量任务中样本的特征向量,y
t' 为对应样本的标签,代表该样本所属的类别;为增量网络。
9.进一步地,步骤(2)中训练所述增量网络的增量损失函数lr为:;其中,q_newj为小样本增量任务中的增量查询集,x
q' 为增量查询集q_newj中的新样本,y
q'
为对应样本的标签,代表该样本所属的类别;为权重,为增量网络;
;q_new
jc' 为q_newj中类别为的数据集,xn表示q_new
jc' 中类别为c'的其他样本,yn为对应样本的标签。
10.进一步地,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛前,利用蒸馏损失函数对所述增量网络进行训练直至收敛,所述蒸馏损失函数的计算方法为:使用任务无关的数据集du进行蒸馏学习,根据du在预热网络和增量网络上的输出分布f
θ (xu)和g
φ (xu)分别计算蒸馏损失项:;;蒸馏损失函数为;其中为蒸馏网络,t为蒸馏温度系数,xu为du中的样本,λ为参数。
11.进一步地,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛包括:利用增量损失函数更新增量网络参数,冻结增量网络;计算蒸馏损失函数并更新蒸馏网络参数和增量网络参数;冻结蒸馏网络,利用更新的增量网络参数重新计算增量损失函数,优化增量网络;重复上述步骤训练所述增量网络和所述蒸馏网络直至收敛。
12.进一步地,步骤(4)包括以下内容:计算新增类别c' 的最终类别原型 为:;计算样本与每个最终类别原型之间的相似度,;利用上述公式计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
13.本发明所述基于知识蒸馏小样本增量学习的图片分类系统,用于通过蒸馏网络判断输入图片所属类别,包括:预热网络模块,用于将随机初始化的resnet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对于每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;
增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
14.本发明所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述的基于知识蒸馏小样本增量学习的图片分类方法。
15.有益效果:与现有技术相比,本发明的优点在于:(1)提出了一个三阶段算法,通过预热、增量学习和知识蒸馏三个阶段有效地提升了模型性能;(2)将增量学习任务和小样本学习任务分割开,减少小样本对模型带来的过拟合问题;(3)基于任务无关数据集的知识蒸馏方法,有利于提高模型的可扩展性,有效缓解了增量学习中的类别遗忘问题;(4)交互迭代更新方法能够使目标函数进一步收敛,更适用于小样本情况下的模型训练,同时提高了小样本问题下的图片分类效果。
附图说明
16.图1为本发明的图片分类方法流程图。
17.图2为本发明的模型迭代训练阶段的流程示意图。
18.图3为本发明实施例的分类检测准确率对比图。
具体实施方式
19.下面结合附图对本发明的技术方案作进一步说明。
20.对于给定图片判别该图片所属类别的任务,可以看作使用训练集训练出一个高效的检测模型,之后使用该模型对图片进行分类检测。如图1所示,本发明所述的基于知识蒸馏小样本增量学习的图片分类方法,包括以下步骤:数据预处理阶段、模型迭代训练阶段、预测阶段。
21.(1)数据预处理阶段:选择miniimagenet作为预热阶段和增量学习阶段的数据集,cifar10作为知识蒸馏阶段中与任务无关的数据集。
22.对于miniimagenet和cifar10数据集分别按照8:2的比例划分为训练集和测试集。
23.将测试集中的图片使用中心裁剪方法,以图像中心点为参照,按照224
×
224像素大小从外向内进行裁剪。对于小于指定尺寸的图片,在原始图像外侧填充0,再进行中心裁剪。
24.对训练集中的样本进行数据增广。具体地,使用水平翻转方法基于随机的概率水平翻转图片,改变部分图片方向;然后使用图像抖动方法,随机改变图像的亮度、对比度、锐度和饱和度,让特征值在随机因子范围[10,30]内随机变换,增加样本多样性;之后使用图片旋转方法,不改变图片大小、亮度等特征的基础上得到新的数据,使图片分别旋转至
、和,增大样本数量;最后使用随机裁剪方法,将图片随机裁剪为不同大小并改变宽高比,然后缩放至224
×
224像素大小。
[0025]
(2)模型迭代训练阶段,优化嵌入网络,如图2所示。
[0026]
(2.1)将预处理后的训练集作为模型的输入,随机在miniimagenet数据集中选取60个类别作为预热阶段的基础数据集,首先随机初始化的resnet18作为预热网络,使用预热网络计算类别原型;具体步骤如下:使用基于任务的episode训练策略,对于每个episode,从基础训练集中随机抽取n个类,在每类都分别抽取k个样本组成支持集s,然后再从这n个类中剩余的样本抽取一部分数据作为查询集q,构成的分类问题被称为n-way k-shot小样本任务,整体的训练任务由若干个小样本任务构成。对于每个episode执行一个小样本任务t
i = {s
i ,qi},ti为预热阶段第i个子任务,si为子任务中的支持集,qi为子任务中的查询集,计算类别c的原型为:;
[0027]
其中,pc表示在特征空间中样本类别为c的类别原型,s
ic
表示支持集si中类别为c的数据集,|s
ic
| 表示s
ic
的大小,x
t
为样本的特征向量,y
t
为对应样本的标签。
[0028]
(2.2)计算每个类别中原型与其他同类样本的距离之和:;
[0029]
根据每个类别中原型与同类样本距离计算类别的权重:;其中为softmax函数,m为属于si数据集的其他类别。
[0030]
(2.3)对于来自查询集qi中的新样本xq,利用如下的距离判别得到每个类别c的归一化分类分数:;其中为softmax函数。
[0031]
(2.4)指定预热损失函数lh为:;其中,xq为属于查询集qi的样本,yq为对应样本的标签,代表该样本所属的类别;使用上述损失函数对预热网络进行迭代训练直至模型收敛。
[0032]
(2.5)使用miniimagenet数据集中剩余的40个类别作为增量阶段新类数据集,分8次逐步加入,每次增量学习任务新加入5个类别,每个类别随机采样k个样本。构建增量学习
网络,首先冻结经过预热训练后的参数θ,并使用其作为增量网络的初始值;对于一个小样本增量任务,为增量学习阶段的第j个子任务,s_newj为增量子任务中的支持集,q_newj为增量子任务中的查询集,计算新增类别c'的原型为:;其中,p'
c'
表示在特征空间中样本类别为c' 的类别原型,s_new
jc' 表示增量支持集s_newj中类别为c' 的数据集,|s_new
jc' |表示数据集s_new
jc'
的大小,x
t' 为样本的特征向量,y
t' 为对应样本的标签。
[0033]
(2.6)对于来自增量查询集q_newj中的新样本x
q' ,根据每个样本到所属类别原型p'
c' 的距离计算样本的权重:;其中表示标签为c'的样本x
q' 的权重值,为softmax函数,xn表示类别为c'的其他样本,yn为对应样本的标签。
[0034]
(2.7)构建增量损失函数lr:;其中,x
q' 为属于增量查询集的样本,y
q' 为对应样本的标签;使用上述损失函数对增量网络进行迭代训练直至模型收敛。
[0035]
(2.8)使用cifar10作为知识蒸馏阶段的任务无关数据集,随机选取10个类别,每个类别随机选择1000张图片。构建蒸馏网络,首先冻结预热网络和增量网络的参数,并拷贝训练后的增量网络的参数作为蒸馏网络的初始值;使用任务无关的数据集进行蒸馏学习;根据其在预热网络和增量网络上的输出分布f
θ (xu)和g
φ (xu),其中x
u ∈du,分别计算蒸馏损失项:;;其中,为softmax函数,t为蒸馏温度系数。
[0036]
(2.9)用参数调整新旧类别比例并作累加,计算蒸馏损失函数为l
kd

;在实验中设置λ = 0.1;使用上述损失函数对增量网络进行迭代训练直至模型收敛。
[0037]
(2.10)在模型训练过程中,使用交叉迭代网络更新方法,具体包括:首先使用增量损失函数lr更新增量网络中的参数φ,之后冻结增量网络,计算得到蒸馏损失函数l
kd
,对蒸馏网络中的参数σ和增量网络中的φ进行更新,接下来冻结蒸馏网络,根据更新后的参数φ得到新的增量损失函数lr,进一步优化增量网络,重复上述交叉迭代网络更新步骤直至训练函数收敛。
[0038]
(3)预测阶段:(3.1)将预处理后的miniimagenet新类数据集中的测试集数据作为模型输入,使用训练后的蒸馏网络来计算样本的特征向量,计算每个类对应的支持集样本的平均值作为该类的原型:;其中,为特征空间中样本类别为c' 的最终类别原型。
[0039]
(3.2)通过小样本图像分类函数计算测试样本与每个类别原型之间的相似度,最后得到相似度最高的类作为最终检测结果,小样本图像分类函数为:。
[0040]
通过仿真实验对本发明所述的基于知识蒸馏小样本增量学习的图片分类方法进行验证,使用python实现所述的模型训练方法与测试方法,并与icarl、eeil、topic等小样本增量学习方法对比,在miniimagenet数据集5-way 5-shot任务下对比结果如图3所示。所有的程序都是在配有intel core i7-8700 cpu,3.20ghz,32 gbram和nvidia titan rtx的标准服务器上执行的,采用激活函数为relu函数的resnet18神经网络, 设置优化器为adam。在预热阶段和增量学习阶段中,使用0.1作为初始学习率,并在训练过程中使之逐步递减为原值的十分之一。在知识蒸馏阶段学习率固定为0.001,迭代20轮后停止。从图3中可以看出,本发明所述的基于知识蒸馏小样本增量学习的图片分类方法的分类识别准确率比其他方法取得了较大程度的领先,相较于topic算法提高了10%左右的最终分类准确率,表现出了更适合小样本学习这一特殊任务的优越性,显著高效地提升了模型性能。
[0041]
本发明所述基于知识蒸馏小样本增量学习的图片分类系统,用于通过蒸馏网络判断输入图片所属类别,包括:预热网络模块,用于将随机初始化的resnet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对于每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任
务,对所述增量网络进行训练直至收敛;蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。
[0042]
本发明所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现所述的基于知识蒸馏小样本增量学习的图片分类方法。
[0043]
所述计算机可读存储媒体可包括ram、rom、eeprom、cd-rom 或其它光盘存储装置、磁盘存储装置或其它磁性存储装置、快闪存储器或可用来存储指令或数据结构的形式的所要程序代码并且可由计算机存取的任何其它媒体。
[0044]
处理器用于执行存储器存储的计算机程序,以实现上述实施例涉及的方法中的各个步骤。

技术特征:
1.一种基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,通过蒸馏网络判断输入图片所属类别,包括如下步骤:(1)将随机初始化的resnet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;(2)冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;(3)冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;(4)利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。2.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(1)中利用所述预热网络计算类别原型包括:类别c的类别原型p
c
为:;其中,s
ic 表示小样本分类任务中支持集s
i 中类别为c的数据集,|s
ic | 表示s
ic 的大小, x
t 为小样本分类任务中样本的特征向量,y
t 为对应样本的标签,代表该样本所属的类别;为预热网络。3.根据权利要求2所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(1)中训练所述预热网络的预热损失函数l
h
为:;其中,q
i 为小样本分类任务的查询集,x
q 为查询集q
i 中的新样本,y
q 为对应样本的标签,代表该样本所属的类别;为归一化分类函数,每个类别c的归一化分类分数为,为softmax函数;为权重,每个类别c的权重为,d
c 为类别c中类别原型与其他同类别样本的距离和,m为s
i 中除类别c以外的其他类别。4.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(2)中利用所述增量网络计算新增类别的类别原型包括:新增类别c' 的类别原型p'
c' 为:
;其中,s_new
jc' 表示小样本增量任务中增量支持集s_new
j 中类别为c'的数据集;表示s_new
jc' 的大小,x
t' 为小样本增量任务中样本的特征向量,y
t' 为对应样本的标签,代表该样本所属的类别;为增量网络。5.根据权利要求4所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(2)中训练所述增量网络的增量损失函数l
r
为:;其中,q_new
j
为小样本增量任务中的增量查询集,x
q' 为增量查询集q_new
j
中的新样本,y
q'
为对应样本的标签,代表该样本所属的类别;为权重,为增量网络;;q_new
jc' 为q_new
j
中类别为的数据集,x
n
表示q_new
jc' 中类别为c' 的其他样本, y
n
为对应样本的标签;为softmax函数。6.根据权利要求5所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛前,利用蒸馏损失函数对所述增量网络进行训练直至收敛,所述蒸馏损失函数的计算方法为:使用任务无关的数据集d
u
进行蒸馏学习,根据d
u
在预热网络和增量网络上的输出分布f
θ (x
u
)和g
φ (x
u
)分别计算蒸馏损失项:;;蒸馏损失函数为;其中为蒸馏网络,为softmax函数,t为蒸馏温度系数,x
u
为d
u
中的样本,λ为参数。7.根据权利要求6所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(3)所述交叉迭代训练所述增量网络和所述蒸馏网络直至收敛包括:利用增量损失函数更新增量网络参数,冻结增量网络;计算蒸馏损失函数并更新蒸馏
网络参数和增量网络参数;冻结蒸馏网络,利用更新的增量网络参数重新计算增量损失函数,优化增量网络;重复上述步骤训练所述增量网络和所述蒸馏网络直至收敛。8.根据权利要求1所述的基于知识蒸馏小样本增量学习的图片分类方法,其特征在于,步骤(4)包括以下内容:计算新增类别c' 的最终类别原型 为:;计算样本与每个最终类别原型之间的相似度,;其中为蒸馏网络,s_new
jc'
表示小样本增量任务中增量支持集s_new
j 中类别为c' 的数据集,表示s_new
jc'
的大小,x
t' 为小样本增量任务中样本的特征向量,y
t' 为对应样本的标签,代表该样本所属的类别;利用上述公式计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。9.一种基于知识蒸馏小样本增量学习的图片分类系统,其特征在于,用于通过蒸馏网络判断输入图片所属类别,包括:预热网络模块,用于将随机初始化的resnet18作为预热网络,利用所述预热网络计算类别原型,使用基于任务的episode训练策略,对每个episode执行一个小样本分类任务,对所述预热网络进行训练直至收敛;增量网络模块,用于冻结所述预热网络的参数,并将该参数作为增量网络的初始值,利用所述增量网络计算新增类别的类别原型,对每个episode执行一个小样本增量任务,对所述增量网络进行训练直至收敛;蒸馏网络及交叉迭代模块,用于冻结所述预热网络和所述增量网络的参数,将所述预热网络和所述增量网络通过知识蒸馏形成蒸馏网络,将所述增量网络的参数作为所述蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;预测模块,用于利用所述蒸馏网络计算输入图像与每个类别原型之间的相似度,相似度最高的类别为所述输入图像的所属类别。10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现根据权利要求1-8任一项所述的基于知识蒸馏小样本增量学习的图片分类方法。

技术总结
本发明公开了一种基于知识蒸馏小样本增量学习的图片分类方法及系统,通过蒸馏网络判断输入图片所属类别,该方法利用预热网络计算类别原型,对于每个episode执行一个小样本分类任务;然后将预热网络的参数作为增量网络的初始值,计算新增类别的类别原型,对每个episode执行一个小样本增量任务;将预热网络和增量网络通过知识蒸馏形成蒸馏网络,将增量网络的参数作为蒸馏网络的初始值,交叉迭代训练所述增量网络和所述蒸馏网络直至收敛;利用所述蒸馏网络计算相似度得到输入图像的所属类别;本发明通过预热、增量学习和知识蒸馏三个阶段减少小样本过拟合问题,缓解了增量学习中的类别遗忘问题,提高了小样本问题下的图片分类效果。分类效果。分类效果。


技术研发人员:许扬汶 韩冬 刘天鹏 罗广宁 孙腾中 李彦辰
受保护的技术使用者:南京大数据集团有限公司
技术研发日:2023.06.27
技术公布日:2023/8/1
版权声明

本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)

航空之家 https://www.aerohome.com.cn/

飞机超市 https://mall.aerohome.com.cn/

航空资讯 https://news.aerohome.com.cn/

分享:

扫一扫在手机阅读、分享本文

相关推荐