一种面向类不平衡数据分布的联邦学习方法
未命名
08-13
阅读:104
评论:0

1.本发明涉及联邦学习技术领域,具体是一种面向类不平衡数据分布的联邦学习方法。
背景技术:
2.深度学习已在许多领域取得了重要的进展,如语音识别,图像处理等。这些成功的进展通常需要将大量的数据集中起来训练,然而在实际场景中,数据往往分布在不同的设备上,由于存在隐私泄露的风险及高昂的传输成本问题,各数据拥有方不愿意直接共享其数据。联邦学习(federated learning)的分布式框架,可以使多个参与方(客户端)协同训练一个高性能的联邦模型而无需向服务器上传本地私有数据,为数据隐私保护和高效的分布式训练提供了良好的解决方案,目前已经在许多领域中得到广泛关注。
3.在标准的联邦学习框架fedavg中,每轮更新随机选择一部分客户端参与训练,被选择的客户端从服务器下载当前联邦模型,并利用本地数据训练后将更新结果上传到服务器,服务器聚合本地模型的更新结果产生新一轮的联邦模型,这个过程不断重复,直到模型收敛。但是fedavg没有考虑类不平衡问题对联邦学习模型性能的影响。而事实是,类不平衡问题是联邦学习中一个重要的挑战。由于不同客户端的数据可能具有不同的来源和各自的偏好,导致客户端的本地数据集中类别分布不均衡的现象经常存在。例如在医学诊断任务中,患有癌症的人群数量比健康人群的数量少的多;在信用卡欺诈检测任务中,正常交易的数据会占大多数,而真正发生异常情况(存在欺诈风险)的数据却极少。对于联邦学习方法来说,客户端本地数据集中类不平衡的现象,表现为多数类的样本在数据集中的数量远远超过了少数类的样本数量。在本地模型训练过程中,多数类的数据可能非常丰富,这会导致本地模型对这些数据的过拟合和泛化性能下降。另一方面,少数类的数据相对较少,因此很容易被模型忽略,导致这些类别的性能下降。但是在实际应用场景中,对少数类别的正确分类有时具有更加重要的作用,如在医学诊断任务中要比诊断出健康更加敏感的诊断出癌症,而在信用卡欺诈检测任务中要比检测出正常使用更加敏感的检测出欺诈行为。
4.由于缺乏对数据类别分布不平衡的考虑,现有的联邦学习方法通常在类不平衡数据集上的表现不佳,特别是对少样本类别的分类效果较差,而这样有偏差的本地模型会进一步影响联邦模型的聚合效果。因此,如何在联邦学习中类别不平衡的情况下保持联邦模型性能,仍然是一个值得研究和探索的问题。
技术实现要素:
5.本发明的目的是针对现有技术的不足,而提供一种面向类不平衡数据分布的联邦学习方法。这种方法能缓解数据类别不平衡问题对模型训练的不利影响,提高模型性能和效率、方法简单,且有效。
6.实现本发明目的的技术方案是:
7.一种面向类不平衡数据分布的联邦学习方法,包括如下步骤:
8.1)模型初始化:设立联邦学习整体架构,所示整体架构设有一个服务器和与服务器连接的n个客户端,初始化由特征提取器和分类器两部分组成的服务器端的联邦模型,具体为:每个客户端k都有其对应的本地数据集其中是数据集dk中第i个样本数据及其对应标签,i=1,2,...,nk,nk表示客户端k的数据集中样本的总量,每个样本数据都有其对应的类别c,且c∈[1,...,c],其中,表示客户端k的数据集中类别c的样本数量,且每个客户端k的数据集dk中不同类别的样本数量是不平衡的,初始化服务器端的联邦模型网络架构为具有参数w的神经网络,且模型由特征提取器ew(
·
)和分类器fw(
·
)两部分组成,其中,特征提取器ew(
·
),能够从输入的样本数据(xi,yi)中提取出特征表示,并将特征表示映射为一个d维特征向量,特征提取器中提取的特征表示向量将发送给分类器,待分类器对输入数据进行分类预测,分类器fw(
·
),能够为每个输入的数据生成分类预测结果;
[0009]
2)模型下发:服务器从所有客户端中随机选择一组客户端参与当前轮次的联邦训练,并下发联邦模型参数至被选择的客户端,具体为:在第t轮联邦训练中,服务器从n个客户端中随机选择一组参与当前轮次训练的客户端组成客户端子集s
t
,并将联邦模型w
t
下发至被选择的客户端k∈s
t
;
[0010]
3)客户端本地模型训练:客户端k在接收到服务器下发的联邦模型w
t
后,采用本地数据集dk对所得模型进行训练、更新本地模型参数其中,在本地训练阶段,客户端需要分别对本地模型的特征提取器和分类器进行优化,因此本地训练损失包括特征提取损失le和分类损失lf两部分,具体优化包括:
[0011]
3-1)特征提取器优化:采用特征提取损失le优化客户端本地模型训练阶段的特征提取器,使本地模型能够学习到更好的特征表示,客户端k在第t轮本地模型训练的过程中,对于某一输入的样本分别从联邦模型w
t
的中提取的特征表示从本地模型中提取的特征表示从上一轮本地模型中提取的特征表示为了方便表示,令由于联邦模型在每一轮聚合了所有参与训练的客户端学习到的特征表示,具有更好的泛化能力,因此联邦模型通常能够学习到比本地模型更好的特征表示,所以在特征提取损失le中,将和h
t
视为正对,从而拉近本地模型和联邦模型学习到的特征表示的距离,同时,为了使客户端的本地模型学习到比上一轮更好的特征表示,将当前轮次中本地模型学习到的特征表示和上一轮中本地模型学习到的特征表示视为负对,从而增加本地模型在当前轮次学习到的特征表示和上一轮学习到的特征表示的距离,达到特征提取器自身的更优化效果,客户端k在本地模型训练中的特征提取损失le表示为:
[0012][0013]
式中,是余弦相似度函数,且τ表示温度系数,设置τ=0.5;
[0014]
3-2)分类器优化:采用分类损失lf优化客户端本地模型训练阶段的分类器,能够提高客户端在不平衡数据集上的模型训练效果,从而使最终聚合的联邦模型对类不平衡条件更加稳健,客户端k在本地模型训练中的分类损失lf表示为:
[0015][0016]
式中,nb为客户端本地训练中的批量大小,为客户端本地训练中的批量大小,表示在客户端k的一个批量大小中类别c的样本数量,c∈[1,...,c],为符号函数,如果样本的真实类别属于类别c则取1,否则取0,表示客户端k在第t轮本地训练中的分类器对样本属于类别c的预测概率,平衡系数用来调节客户端k中的类别c在分类损失中的权重,从而使分类损失能够适用于类别不平衡的数据,平衡系数的计算方式为:并且平衡系数能够在客户端本地训练中对不同样本数量的类别进行权重的平衡,即某一类别的样本数量越少,则平衡系数越大,使分类器更加注重提高样本数量较少的类别的训练效果,更进一步地,考虑分类器最终分类结果的准确性,除了修正不同类别中样本数量的不平衡,分类器还应该更多关注训练效果不好的样本即分类器输出的预测值与真实值之间差异更大的样本,所以增加一个惩罚项来纠正训练效果不好的样本,这使得训练效果越差的样本将获得越大的惩罚值,在本地模型训练阶段,将客户端k的本地损失函数表示为两部分,第一部分为特征提取损失le,第二部分为分类损失lf,客户端k的本地训练损失表示为:
[0017][0018]
式中,λ是控制特征提取损失和分类损失的权重的超参数,在第t轮本地模型训练中,客户端k采用接收到的联邦模型w
t
在本地数据集上采用学习率η进行随机梯度下降来更新第t轮的本地模型参数更新方式表示为:
[0019]
4)模型参数上传:参与本轮训练的客户端将本地模型更新的参数上传至服务器,用于更新联邦模型;
[0020]
5)模型参数聚合:服务器将接收到的本地模型参数进行加权平均聚合产生新一轮的联邦模型参数w
t+1
,聚合方式表示为:
[0021][0022]
6)判断:服务器判断当前通信轮数是否达到设定的通信轮数的阈值,若未达到阈值,一般阈值设置为100,则返回步骤2)继续训练,否则结束联邦训练、向所有客户端发送最终的联邦模型参数;
[0023]
7)分类:联邦训练结束,各客户端将最终得到的联邦模型应用于本地任务,对本地任务中的待分类图片进行分类,对于本地任务中类不平衡分布的图片,各客户端能够准确地进行分类。
[0024]
与现有技术相比,本技术方案将联邦模型分为特征提取器和分类器两部分,在本地模型训练过程中,通过最大化本地客户端与服务器端的特征表示的一致性优化客户端的本地特征提取器,同时,对不同类别的数据分配不同的损失权重,使模型训练时更多关注于少样本的类别,优化客户端本地模型的分类器。本技术方案缓解了数据类别不平衡问题对联邦模型训练的不利影响,并且不需要向服务器上传客户端本地数据集中真实的数据或类别信息,有效的保障了客户端的数据隐私。
[0025]
这种方法能缓解数据类别不平衡问题对模型训练的不利影响,提高模型性能和效率、方法简单,且有效。
附图说明
[0026]
图1为实施例方法的流程示意图;
[0027]
图2为实施例的系统示意图;
[0028]
图3为实施例中本地模型训练过程示意图。
具体实施方式
[0029]
下面结合附图和实施例对本发明的内容做进一步的阐述,但不是对本发明的限定。
[0030]
实施例:
[0031]
本例中面向类不平衡数据的联邦学习任务,各客户端需要在不共享本地隐私数据的情况下协作训练一个性能良好的图片分类模型,其中,客户端的本地数据集由不同类别的图片数据组成,且不同类别的图片样本数量是不均衡的,
[0032]
参照图2,一种面向类不平衡数据分布的联邦学习方法,包括如下步骤:
[0033]
1)模型初始化:设立联邦学习整体架构,如图1所示,所示整体架构设有一个服务器和与服务器连接的n个客户端,初始化由特征提取器和分类器两部分组成的服务器端的联邦模型,本例中:每个客户端k都有其对应的本地数据集其中是数据集dk中第i个样本数据及其对应标签,i=1,2,...,nk,nk表示客户端k的数据集中样本的总量,每个样本数据都有其对应的类别c,且c∈[1,...,c],其中,表示客户端k的数据集中类别c的样本数量,且每个客户端k的数据集dk中不同类别的样本数
量是不平衡的,初始化服务器端的联邦模型网络架构为具有参数w的神经网络,且模型由特征提取器ew(
·
)和分类器fw(
·
)两部分组成,其中,特征提取器ew(
·
),能够从输入的样本数据(xi,yi)中提取出特征表示,并将特征表示映射为一个d维特征向量,特征提取器中提取的特征表示向量将发送给分类器,待分类器对输入数据进行分类预测,分类器fw(
·
),能够为每个输入的数据生成分类预测结果;
[0034]
2)模型下发:服务器从所有客户端中随机选择一组客户端参与当前轮次的联邦训练,并下发联邦模型参数至被选择的客户端,本例中:在第t轮联邦训练中,服务器从n个客户端中随机选择一组参与当前轮次训练的客户端组成客户端子集s
t
,并将联邦模型w
t
下发至被选择的客户端k∈s
t
;
[0035]
3)客户端本地模型训练:如图3所示,客户端k在接收到服务器下发的联邦模型w
t
后,采用本地数据集dk对所得模型进行训练、更新本地模型参数其中,在本地训练阶段,客户端需要分别对本地模型的特征提取器和分类器进行优化,因此本地训练损失包括特征提取损失le和分类损失lf两部分,具体优化包括:
[0036]
3-1)特征提取器优化:采用特征提取损失le优化客户端本地模型训练阶段的特征提取器,使本地模型能够学习到更好的特征表示,客户端k在第t轮本地模型训练的过程中,对于某一输入的样本分别从联邦模型w
t
的中提取的特征表示从本地模型中提取的特征表示从上一轮本地模型中提取的特征表示为了方便表示,令由于联邦模型在每一轮聚合了所有参与训练的客户端学习到的特征表示,具有更好的泛化能力,因此联邦模型通常能够学习到比本地模型更好的特征表示,所以在特征提取损失le中,将和h
t
视为正对,从而拉近本地模型和联邦模型学习到的特征表示的距离,同时,为了使客户端的本地模型学习到比上一轮更好的特征表示,将当前轮次中本地模型学习到的特征表示和上一轮中本地模型学习到的特征表示视为负对,从而增加本地模型在当前轮次学习到的特征表示和上一轮学习到的特征表示的距离,达到特征提取器自身的更优化效果,客户端k在本地模型训练中的特征提取损失le表示为:
[0037][0038]
式中,是余弦相似度函数,且τ表示温度系数,本例中设置τ=0.5;
[0039]
3-2)分类器优化:采用分类损失lf优化客户端本地模型训练阶段的分类器,能够提高客户端在不平衡数据集上的模型训练效果,从而使最终聚合的联邦模型对类不平衡条件更加稳健,客户端k在本地模型训练中的分类损失lf表示为:
[0040][0041]
式中,nb为客户端本地训练中的批量大小,为客户端本地训练中的批量大小,表示在客户端k的一个批量大小中类别c的样本数量,c∈[1,...,c],为符号函数,如果样本的真实类别属于类别c则取1,否则取0,表示客户端k在第t轮本地训练中的分类器对样本属于类别c的预测概率,平衡系数用来调节客户端k中的类别c在分类损失中的权重,从而使分类损失能够适用于类别不平衡的数据,平衡系数的计算方式为:并且平衡系数能够在客户端本地训练中对不同样本数量的类别进行权重的平衡,即某一类别的样本数量越少,则平衡系数越大,使分类器更加注重提高样本数量较少的类别的训练效果,更进一步地,本例还考虑了分类器最终分类结果的准确性,除了修正不同类别中样本数量的不平衡,分类器还应该更多关注训练效果不好的样本即分类器输出的预测值与真实值之间差异更大的样本,所以增加一个惩罚项来纠正训练效果不好的样本,这使得训练效果越差的样本将获得越大的惩罚值,在本地模型训练阶段,本例将客户端k的本地损失函数表示为两部分,第一部分为特征提取损失le,第二部分为分类损失lf,客户端k的本地训练损失表示为:
[0042][0043]
式中,λ是控制特征提取损失和分类损失的权重的超参数,在第t轮本地模型训练中,客户端k采用接收到的联邦模型w
t
在本地数据集上采用学习率η进行随机梯度下降来更新第t轮的本地模型参数更新方式表示为:
[0044]
4)模型参数上传:参与本轮训练的客户端将本地模型更新的参数上传至服务器,用于更新联邦模型;
[0045]
5)模型参数聚合:服务器将接收到的本地模型参数进行加权平均聚合产生新一轮的联邦模型参数w
t+1
,聚合方式表示为:
[0046][0047]
6)判断:服务器判断当前通信轮数是否达到设定的通信轮数的阈值,本例中阈值设置为100,若未达到阈值,则返回步骤2)继续训练,否则结束联邦训练、向所有客户端发送最终的联邦模型参数;
[0048]
7)分类:联邦训练结束,各客户端将最终得到的联邦模型应用于本地任务,对本地任务中的待分类图片进行分类,对于本地任务中类不平衡分布的图片,各客户端能够准确地进行分类。
技术特征:
1.一种面向类不平衡数据分布的联邦学习方法,其特征在于,包括如下步骤:1)模型初始化:设立联邦学习整体架构,所示整体架构设有一个服务器和与服务器连接的n个客户端,初始化由特征提取器和分类器两部分组成的服务器端的联邦模型,具体为:每个客户端k都有其对应的本地数据集其中是数据集d
k
中第i个样本数据及其对应标签,i=1,2,...,n
k
,n
k
表示客户端k的数据集中样本的总量,每个样本数据都有其对应的类别c,且c∈[1,...,c],其中,表示客户端k的数据集中类别c的样本数量,且每个客户端k的数据集d
k
中不同类别的样本数量是不平衡的,初始化服务器端的联邦模型网络架构为具有参数w的神经网络,且模型由特征提取器e
w
(
·
)和分类器f
w
(
·
)两部分组成,其中,特征提取器e
w
(
·
),能够从输入的样本数据(x
i
,y
i
)中提取出特征表示,并将特征表示映射为一个d维特征向量,特征提取器中提取的特征表示向量将发送给分类器,待分类器对输入数据进行分类预测,分类器f
w
(
·
),能够为每个输入的数据生成分类预测结果;2)模型下发:服务器从所有客户端中随机选择一组客户端参与当前轮次的联邦训练,并下发联邦模型参数至被选择的客户端,具体为:在第t轮联邦训练中,服务器从n个客户端中随机选择一组参与当前轮次训练的客户端组成客户端子集s
t
,并将联邦模型w
t
下发至被选择的客户端k∈s
t
;3)客户端本地模型训练:客户端k在接收到服务器下发的联邦模型w
t
后,采用本地数据集d
k
对所得模型进行训练、更新本地模型参数其中,在本地训练阶段,客户端需要分别对本地模型的特征提取器和分类器进行优化,因此本地训练损失包括特征提取损失l
e
和分类损失l
f
两部分,具体优化包括:3-1)特征提取器优化:采用特征提取损失l
e
优化客户端本地模型训练阶段的特征提取器,使本地模型能够学习到更好的特征表示,客户端k在第t轮本地模型训练的过程中,对于某一输入的样本分别从联邦模型w
t
的中提取的特征表示从本地模型中提取的特征表示从上一轮本地模型中提取的特征表示令令在特征提取损失l
e
中,将和h
t
视为正对,将当前轮次中本地模型学习到的特征表示和上一轮中本地模型学习到的特征表示视为负对,客户端k在本地模型训练中的特征提取损失l
e
表示为:式中,是余弦相似度函数,且τ表示温度系数,设置τ=0.5;3-2)分类器优化:客户端k在本地模型训练中的分类损失l
f
表示为:
式中,n
b
为客户端本地训练中的批量大小,表示在客户端k的一个批量大小中类别c的样本数量,c∈[1,...,c],为符号函数,如果样本的真实类别属于类别c则取1,否则取0,表示客户端k在第t轮本地训练中的分类器对样本属于类别c的预测概率,平衡系数用来调节客户端k中的类别c在分类损失中的权重,使分类损失能够适用于类别不平衡的数据,平衡系数的计算方式为:并且增加一个惩罚项纠正训练效果不好的样本,在本地模型训练阶段将客户端k的本地损失函数表示为两部分,第一部分为特征提取损失l
e
,第二部分为分类损失l
f
,客户端k的本地训练损失表示为:式中,λ是控制特征提取损失和分类损失的权重的超参数,在第t轮本地模型训练中,客户端k采用接收到的联邦模型w
t
在本地数据集上采用学习率η进行随机梯度下降来更新第t轮的本地模型参数更新方式表示为:4)模型参数上传:参与本轮训练的客户端将本地模型更新的参数上传至服务器,用于更新联邦模型;5)模型参数聚合:服务器将接收到的本地模型参数进行加权平均聚合产生新一轮的联邦模型参数w
t+1
,聚合方式表示为:6)判断:服务器判断当前通信轮数是否达到设定的通信轮数的阈值,若未达到阈值,则返回步骤2)继续训练,否则结束联邦训练、向所有客户端发送最终的联邦模型参数;7)分类:联邦训练结束,各客户端将最终得到的联邦模型应用于本地任务,对本地任务中的待分类图片进行分类,对于本地任务中类不平衡分布的图片,各客户端进行分类。
技术总结
本发明公开了一种面向类不平衡数据分布的联邦学习方法,所述方法是将联邦模型分为特征提取器和分类器,客户端在本地模型训练阶段通过最大化本地模型与联邦模型中的特征表示的一致性优化个体客户端的本地特征提取器,同时,在模型训练过程中对不同类别的样本数据分配不同的损失权重,使模型训练时更多关注于样本数量少的类别从而优化客户端本地模型的分类器,纠正有偏的分类器。这种方法能缓解数据类别不平衡问题对模型训练的不利影响,提高模型性能和效率、方法简单,且有效。且有效。且有效。
技术研发人员:彭红艳 吴彤彤 石贞奎 李先贤
受保护的技术使用者:广西师范大学
技术研发日:2023.04.25
技术公布日:2023/8/9
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/