基于改进联邦学习的数据分类方法、系统及相关设备
未命名
08-03
阅读:77
评论:0

1.本发明涉及数据分类技术领域,尤其涉及的是一种基于改进联邦学习的数据分类方法、系统及相关设备。
背景技术:
2.随着科学技术的发展,数据分类技术的应用越来越广泛,例如,在对图像数据进行处理时,可以先对图像数据进行分类以提高后续的处理效率。具体的,可以训练对应的数据分类模型来对数据进行分类,以提高数据分类的效率。
3.现有技术中,基于传统的联邦学习技术训练数据分类模型,所有的客户端使用相同的全局模型进行本地训练,然后将本地模型的参数上传到服务器以聚合更新全局模型的参数,也即所有的客户端共用完全相同的模型参数。现有技术的问题在于,在训练过程中所有的客户端共用完全相同的模型参数会导致各个客户端损失掉本地模型的个性化信息,从而影响各个客户端训练出的数据分类模型对自身需要识别的数据的分类准确性,不利于提高数据分类的准确性。
4.因此,现有技术还有待改进和发展。
技术实现要素:
5.本发明的主要目的在于提供一种基于改进联邦学习的数据分类方法、系统及相关设备,旨在解决现有技术中基于传统的联邦学习技术训练数据分类模型以进行数据分类的方案中,在训练过程中所有的客户端共用完全相同的模型参数,导致各个客户端损失掉本地模型的个性化信息,不利于提高数据分类的准确性的问题。
6.为了实现上述目的,本发明第一方面提供一种基于改进联邦学习的数据分类方法,其中,上述基于改进联邦学习的数据分类方法包括:
7.目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取;
8.上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
9.可选的,上述数据分类模型是图像分类模型,上述待分类数据是待分类图像。
10.可选的,上述目标客户端由上述服务器选择确定。
11.可选的,上述目标客户端根据预设的改进联邦学习算法进行第t轮迭代时根据如下步骤进行模型参数调整:
12.上述目标客户端获取上述服务器下发的第t轮全局特征提取器模型参数和第t轮
全局分类器模型参数,其中,上述第t轮全局特征提取器模型参数和上述第t轮全局分类器模型参数由上述服务器根据第t-1轮迭代时的所有目标客户端对应的第t-1轮全局特征提取器模型更新参数和第t-1轮本地分类器模型更新参数计算获得;
13.上述目标客户端从本地存储的数据中获取第t-1轮本地分类器模型更新参数并作为第t轮本地分类器模型参数;
14.根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数。
15.可选的,上述根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数,包括:
16.上述目标客户端将其对应的待训练的数据分类模型中全局特征提取器的模型参数更新为上述第t轮全局特征提取器模型参数;
17.固定上述待训练的数据分类模型中全局特征提取器对应的第t轮全局特征提取器模型参数,根据上述目标客户端中的训练数据、上述第t轮本地分类器模型参数和固定的上述第t轮全局特征提取器模型参数计算上述训练数据对应的第一损失值,并根据上述第一损失值对上述第t轮本地分类器模型参数进行调整以获得第t轮本地分类器模型更新参数;
18.固定上述待训练的数据分类模型中全局分类器对应的第t轮全局分类器模型参数,根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数和固定的上述第t轮全局分类器模型参数计算上述训练数据对应的第二损失值,并根据上述第二损失值对上述第t轮全局特征提取器模型参数进行调整以获得第t轮全局特征提取器模型更新参数。
19.可选的,上述待训练的数据分类模型中各模型参数通过梯度下降方式进行更新。
20.可选的,在上述根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数之后,上述方法还包括:
21.上述目标客户端向上述服务器发送上述第t轮全局特征提取器模型更新参数和上述第t轮本地分类器模型更新参数,以触发上述服务器根据各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮全局特征提取器模型更新参数进行加权聚合计算获得第t+1轮全局特征提取器模型参数,根据上述各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮本地分类器模型更新参数进行加权聚合计算获得第t+1轮全局分类器模型参数。
22.本发明第二方面提供一种基于改进联邦学习的数据分类系统,其中,上述基于改进联邦学习的数据分类系统包括:
23.模型训练模块,用于控制目标客户端根据预设的改进联邦学习算法对该目标客户
端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取;
24.数据分类模块,用于控制上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
25.本发明第三方面提供一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的基于改进联邦学习的数据分类程序,上述基于改进联邦学习的数据分类程序被上述处理器执行时实现上述任意一种基于改进联邦学习的数据分类方法的步骤。
26.本发明第四方面提供一种计算机可读存储介质,上述计算机可读存储介质上存储有基于改进联邦学习的数据分类程序,上述基于改进联邦学习的数据分类程序被处理器执行时实现上述任意一种基于改进联邦学习的数据分类方法的步骤。
27.由上可见,本发明方案中,目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取;上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
28.与现有技术中相比,本发明方案中在迭代训练过程中,使用全局特征提取器模型参数和全局分类器模型参数作为各个目标客户端之间可以共享的模型参数,同时各个客户端还使用始终存储在其本地的本地分类器模型参数来保留自身的个性化信息。从而保证在各个目标客户端上的数据分类模型进行联邦学习训练的过程中仍然可以保留各个目标客户端对应的本地模型信息,使得各个目标客户端对应的数据分类模型在训练后能充分考虑该目标客户端上的数据的特性,从而有利于提高各个客户端训练出的数据分类模型对自身需要识别的数据的分类准确性,进而提高整体的数据分类准确性。
附图说明
29.为了更清楚地说明本发明实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
30.图1是本发明实施例提供的一种基于改进联邦学习的数据分类方法的流程示意图;
31.图2是本发明实施例提供的一种训练过程中计算损失值的流程示意图;
32.图3是本发明实施例提供的一种基于改进联邦学习的数据分类系统的结构示意图;
33.图4是本发明实施例提供的一种智能终端的内部结构原理框图。
具体实施方式
34.以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况下,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
35.应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
36.还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
37.在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是本发明还可以采用其它不同于在此描述的其它方式来实施,本领域技术人员可以在不违背本发明内涵的情况下做类似推广,因此本发明不受下面公开的具体实施例的限制。
38.随着科学技术的发展,数据分类技术的应用越来越广泛,例如,在对图像数据进行处理时,可以先对图像数据进行分类以提高后续的处理效率。具体的,可以训练对应的数据分类模型来对数据进行分类,以提高数据分类的效率。
39.传统的数据分类模型在训练时通常进行集中式训练,即需要先收集各方数据到一个计算中心或者中心服务器上,然后在计算中心或者中心服务器中基于收集的所有数据对数据分类模型进行训练。但在此过程中,将各方数据都收集到一个计算中心或者中心服务器上可能导致用户的隐私泄露,存在较大风险。
40.为了解决数据集中训练所带来的隐私泄露风险,可以采用联邦学习技术进行模型训练,联邦学习可以在不收集各方原始数据而仅收集各方模型参数的前提下训练数据分类模型。
41.现有技术中,基于现有的联邦学习算法可以实现在不进行源数据的共享的同时还能完成机器学习模型的训练,例如可以采用fedavg,fedavg的每一轮联邦学习训练可以分为四个步骤:服务器选择部分客户端,并向它们分发全局模型;客户端初始化本地模型(即将本地模型参数替换为全局模型参数);客户端在本地数据上,执行若干次的本地模型更新;服务端收集客户端完成本地更新后的本地模型,并对这些模型参数进行加权聚合,产生下一轮的全局模型。
42.当联邦学习的各个参与方的本地数据是独立同分布时,fedavg算法训练的模型能取得与中心化训练得到的模型相当的性能。但是当各个参与方的本地数据是非独立同分布时,fedavg算法训练的模型性能会显著下降,无法达到中心化训练模型相同的效果。具体的,在fedavg算法的第二步中,客户端需要将本地模型参数替换为全局模型参数,该做法在
数据是非独立同分布时,损失掉了本地模型的个性化信息。为了保留联邦学习本地模型更多的个性化信息,提升联邦学习训练的模型在数据是非独立同分布时候的性能,需要对fedavg算法进行改进。
43.现有技术中,基于传统的联邦学习技术训练数据分类模型,所有的客户端使用相同的全局模型进行本地训练,然后将本地模型的参数上传到服务器以聚合更新全局模型的参数,也即所有的客户端共用完全相同的模型参数。现有技术的问题在于,在训练过程中所有的客户端共用完全相同的模型参数会导致各个客户端损失掉本地模型的个性化信息,从而影响各个客户端训练出的数据分类模型对自身需要识别的数据的分类准确性,不利于提高数据分类的准确性。
44.在一种应用场景中,可以将联邦学习训练的模型的所有神经网路层分为共享层和个性化层两部分。每一轮将客户端模型的共享层上传给服务器进行联邦学习,但是个性化层则一直保留在本地进行本地训练。这种基于部分层共享的方法可以保留本地个性化信息,提升联邦学习模型性能。该方案中,通过让部分层只进行本地更新保留了本地模型的个性化信息的,但是也使得上传给服务器的共享模型缺失了个性化层的信息。
45.为了同时保留本地模型的个性化信息又可以保证所有层信息的有效共享,本发明实施例中进行进一步改进,提出一种采用双分类器的个性化联邦学习技术,进一步提高联邦学习模型的性能。
46.具体的,以图像分类为例进行具体说明,则上述数据分类模型可以具体为图像分类模型。图像分类模型是用于对图像进行分类的神经网络模型,一般由若干卷积层、池化层、线性层构成。一个训练好的图像分类模型可以对图像进行分类。它的输入是图片,输出是图片所属的类别。
47.传统图像分类模型通常采用集中式的训练方式,要求各参与方将本地数据上传到服务器进行处理。但是对于一些敏感数据,如个人图像数据,用户不希望将其发送给中央服务器进行处理。联邦学习是一种分布式模型训练框架,可以作为一种解决图像分类模型训练时隐私问题的解决方案。联邦学习可以使得每个客户端只需要将自己的本地模型上传到服务器,而不需要将本地数据上传,从而保护数据隐私性。在传统的联邦学习中,所有客户端都使用相同的全局模型进行本地训练,然后将本地模型的参数上传到服务器进行聚合,更新全局模型的参数。然而在实际应用中,图像分类数据通常是非独立同分布的。这是因为来自不同用户、设备、环境等的图像数据通常具有不同的数据分布和特征。当各个参与方数据是非独立同分布时,他们之间的数据分布差异可能比较大。此时像fedavg这种只训练单一全局模型的算法无法捕获到各个客户端的个性化信息,导致最终得到的模型可能无法在每个客户端数据上达到最佳测试准确率。
48.例如,在一种应用场景中,需要联合多部手机上的相册数据来训练图像分类模型,因为每个用户的使用习惯不同,所以不同用户的手机上的图片会呈现不同的分布。有的用户手机里拍摄了更多植物类别的照片,有些用户拍摄了更多动物的照片。传统的联邦学习最终得到只有唯一的一个模型,这个模型是平均化的,所以无法保证该模型就能够对每个用户的相册图片都做到很好的分类。
49.因此,本发明实施例中,对现有的fedavg算法做出改进,提出一种采用双分类器的个性化联邦学习方案,解决参与方数据非独立同分布导致的联邦学习训练的图像分类模型
测试性能下降的问题,提升联邦学习训练模型的分类准确率。具体的,与fedavg相比,本发明实施例方案让模型的一部分层的参数保留在本地,不被替换掉,这部分层保留了更多的个性化信息。由于这部分层的存在,最终每个客户端得到的模型都是不一样的,是定制化的,可以在本地数据上做更好的分类。
50.为了解决上述多个问题中的至少一个问题,本发明方案中,目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取;上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
51.与现有技术中相比,本发明方案中在迭代训练过程中,使用全局特征提取器模型参数和全局分类器模型参数作为各个目标客户端之间可以共享的模型参数,同时各个客户端还使用始终存储在其本地的本地分类器模型参数来保留自身的个性化信息。从而保证在各个目标客户端上的数据分类模型进行联邦学习训练的过程中仍然可以保留各个目标客户端对应的本地模型信息,使得各个目标客户端对应的数据分类模型在训练后能充分考虑该目标客户端上的数据的特性,从而有利于提高各个客户端训练出的数据分类模型对自身需要识别的数据的分类准确性,进而提高整体的数据分类准确性。
52.如图1所示,本发明实施例提供一种基于改进联邦学习的数据分类方法,具体的,上述方法包括如下步骤:
53.步骤s100,目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取。
54.其中,上述目标客户端是需要进行模型训练的客户端,在模型训练过程中,目标客户端根据预设的改进联邦学习算法进行模型迭代训练。整个训练过程中客户端仅与服务器交流模型参数,而不会交流训练数据。改进联邦学习算法的每一轮联邦训练有四个阶段,包括:模型分发阶段,服务器选择部分客户端参与本轮联邦学习,并向它们下发全局模型(包括全局特征提取器和分类器);本地初始化阶段,客户端利用收到的全局模型对本地模型进行初始化,获得本轮本地更新的初始化模型,具体的,对于fedavg等方法此步骤客户端将整个本地模型替换为收到的全局模型,意味着旧的本地模型将被全部丢弃,在本实施例中客户端只将本地模型的特征提取器替换为全局模型的特征提取器,本地分类器依旧保留着,全局模型中的分类器则作为客户端的第二个分类器指导本地更新训练;本地更新阶段,采用双分类器训练策略(即使用本地分类器模型参数来保留自身的个性化信息,而全局分类
器模型参数则与其它客户端共享)进行迭代训练;聚合阶段,客户端将本地特征提取器和本地分类器上传至服务端聚合以得到下一轮的全局模型。
55.上述全局特征提取器模型参数是数据分类模型的特征提取器对应的全局模型参数,上述全局分类器模型参数是数据分类模型的分类器对应的全局模型参数,全局模型参数即可以由所有的目标客户端(或客户端)共享的模型参数;上述本地分类器模型参数是数据分类模型的分类器对应的本地模型参数,值存储在对应目标客户端的本地数据中,并不会共享给其它目标客户端,用于保存对应目标客户端的本地个性化信息。需要说明的是,上述特征提取器包括数据分类模型的特征提取层,分类器则是数据分类模型的最后一层(即用于进行分类的层)。
56.本实施例中,上述数据分类模型是图像分类模型,上述待分类数据是待分类图像。基于本实施例中的数据分类方法可以对图像数据进行分类,此时对应的数据分类模型是图像分类模型。在一种应用场景中,还可以基于本实施例中的数据分类方法对其他类型的数据进行分类,例如对文本数据、语音数据等进行分类,此时对应的数据分类模型为文本分类模型或语音分类模型,在此不作具体限定。
57.需要说明的是,本实施例中的数据分类模型的训练和测试过程基于图像分类数据集(例如cifar10和cifar100)进行,根据该数据集对图像分类模型进行训练,且保证训练过程中数据不会离开本地设备或本地客户端(即本地数据不会被共享),获得更好的训练效果和隐私保护效果。
58.本实施例中,上述目标客户端由上述服务器选择确定,一个目标客户端上设置有一个与该目标客户端对应的待训练的数据分类模型。需要说明的是,不同的目标客户端上设置的数据分类模型的结构是完全相同的,但训练过程中,各个目标客户端上会保留体现自己个性化信息的本地分类器模型参数,因此不同目标客户端上最终训练获得的已训练的数据分类模型并不相同(结构相同但参数不同),训练后可以获得针对该目标客户端上的数据特点的数据分类模型。其中,一个服务器可以与多个客户端通信,并进行模型参数的交互,在训练过程中,服务器根据实际需求从预设的多个客户端中选择确定需要进行模型训练的客户端作为目标客户端。
59.本实施例中,上述目标客户端根据预设的改进联邦学习算法进行第t轮迭代时根据如下步骤进行模型参数调整:
60.上述目标客户端获取上述服务器下发的第t轮全局特征提取器模型参数和第t轮全局分类器模型参数,其中,上述第t轮全局特征提取器模型参数和上述第t轮全局分类器模型参数由上述服务器根据第t-1轮迭代时的所有目标客户端对应的第t-1轮全局特征提取器模型更新参数和第t-1轮本地分类器模型更新参数计算获得;
61.上述目标客户端从本地存储的数据中获取第t-1轮本地分类器模型更新参数并作为第t轮本地分类器模型参数;
62.根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数。
63.其中,t为整数,代表进行的迭代轮次。需要说明的是,各目标客户端分别对其中设
置的待训练的数据分类模型进行多轮迭代训练,直到满足预设的训练停止条件,从而获得已训练的数据分类模型。各个目标客户端对应的迭代训练次数根据实际需求确定,在此不作具体限定。具体的,上述预设的训练停止条件也可以根据实际需求设置,例如可以设置为迭代次数达到预设的迭代次数阈值,或者模型计算获得损失值(可以是第一损失值、第二损失值或两者的均值)小于预设的损失阈值,还可以设置其它条件,在此不作具体限定。
64.进一步的,上述根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数,包括:
65.上述目标客户端将其对应的待训练的数据分类模型中全局特征提取器的模型参数更新为上述第t轮全局特征提取器模型参数;
66.固定上述待训练的数据分类模型中全局特征提取器对应的第t轮全局特征提取器模型参数,根据上述目标客户端中的训练数据、上述第t轮本地分类器模型参数和固定的上述第t轮全局特征提取器模型参数计算上述训练数据对应的第一损失值,并根据上述第一损失值对上述第t轮本地分类器模型参数进行调整以获得第t轮本地分类器模型更新参数;
67.固定上述待训练的数据分类模型中全局分类器对应的第t轮全局分类器模型参数,根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数和固定的上述第t轮全局分类器模型参数计算上述训练数据对应的第二损失值,并根据上述第二损失值对上述第t轮全局特征提取器模型参数进行调整以获得第t轮全局特征提取器模型更新参数。
68.其中,固定对应的模型参数代表该部分参数在当次处理过程中不进行更新。本实施例中,上述待训练的数据分类模型中各模型参数通过梯度下降方式进行更新。具体的,先通过前向传播获取预测值,根据预测值与真实标注值计算获取损失值,然后反向传播计算获取模型参数的梯度(被固定的参数不会计算),然后根据计算获取的梯度采用梯度下降法更新模型参数。
69.进一步的,在上述根据上述目标客户端中的训练数据、上述第t轮全局特征提取器模型参数、上述第t轮全局分类器模型参数以及上述第t轮本地分类器模型参数对上述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得上述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数之后,上述方法还包括:
70.上述目标客户端向上述服务器发送上述第t轮全局特征提取器模型更新参数和上述第t轮本地分类器模型更新参数,以触发上述服务器根据各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮全局特征提取器模型更新参数进行加权聚合计算获得第t+1轮全局特征提取器模型参数,根据上述各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮本地分类器模型更新参数进行加权聚合计算获得第t+1轮全局分类器模型参数。
71.步骤s200,上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
72.在模型训练完成之后各个目标客户端上获得对应的已训练的数据分类模型,将待
分类数据(例如图片)输入到数据分类模型(例如图像分类模型)中,即可以获得对应的分类(即目标类别)。
73.本实施例中,还基于一种具体应用场景对上述数据分类模型的训练过程进行具体说明。具体的,在第t轮联邦训练过程中,被服务器选中的目标客户端接收从服务器下发的整个全局模型,包括第t轮全局特征提取器模型参数和第t轮全局分类器模型参数需要说明的是,当t等于1,即进行第1轮训练时,可以使用服务器中预设的初始模型参数作为第1轮全局特征提取器模型参数和第1轮全局分类器模型参数。
74.在开始本地训练之前,目标客户端首先对本地模型(即该目标客户端中的待训练模型)进行初始化,具体的,将本地模型的特征提取器的模型参数替换为最新的第t轮全局特征提取器模型参数需要说明的是,在替换之前本地模型的特征提取器的模型参数的值为在进行第t-1轮时获得的第t-1轮全局特征提取器模型更新参数,在替换之后此时,每一个目标客户端中保存有三个部分的模型参数:特征提取器对应的模型参数由最新的全局特征提取器替换而来,具体的目标客户端中特征提取器对应的模型参数为目标客户端中本地分类器的模型参数是一直保留在本地的参数,具体为上一轮本地训练后获得的结果即目标客户端从本地存储的数据中获取第t-1轮本地分类器模型更新参数并作为第t轮本地分类器模型参数目标客户端中全局分类器(也即共享分类器)是最新全局模型中的分类器,其模型参数为从服务器获得的第t轮全局分类器模型参数基于上述初始化后的模型参数,本地模型(即目标客户端上的模型)在本地数据(即目标客户端上的训练数据)上进行更新迭代,本实施例中,记客户端某次采样mini-batch的数据为ξ,记l(θ,φ;ξ)为损失函数,客户端将执行以下操作对本地的特征提取器参数和本地分类器参数分别进行更新。
75.首先,目标客户端进行客户端模型参数的更新。具体的,先更新本地分类器模型参数,获得第t轮本地分类器模型更新参数。如下公式(1)所示,在进行前向传播之前,本地特征提取器参数将被固定,目标客户端输入采样数据ξ先后送入本地特征提取器和本地分类器得到预测值后计算损失函数以获得第一损失值,接着反向传播更新本地分类器参数即获得更新后的第t轮本地分类器模型更新参数:
[0076][0077]
其中,上述公式(1)代表在进行客户端模型参数的更新过程中对本地分类器参数的更新,等式左边代表更新后获得的本地分类器模型的参数,也可以代表第t轮本地分类器模型更新参数,等式右边的代表第t轮本地分类器模型参数,且等式右边的的值实际上等于等式右边的代表固定的上述第t轮全局特征提取器模型参数,其取值
具体等于
[0078]
本实施例中,将上述目标客户端中的训练数据输入由上述第t轮本地分类器模型参数和固定的上述第t轮全局特征提取器模型参数计算上述训练数据对应的第一损失值,从而进行模型参数调整。代表对求梯度,ξ代表该次采样的mini-batch的数据(即训练数据)。ηc代表对本地分类器进行更新时的学习率,可以根据实际需求设置和调整,li代表损失计算函数。然后,更新本地特征提取器参数,更新前其值与第t轮全局特征提取器模型参数相同,因此也可以视为对第t轮全局特征提取器模型参数的更新,获得第t轮全局特征提取器模型更新参数。
[0079]
如下公式(2)所示,在进行前向传播之前共享分类器参数将被固定,客户端输入采样数据ξ先后送入本地特征提取器和共享分类器得到预测值后计算损失函数接着反向传播更新本地特征提取器参数即获得更新后的第t轮全局特征提取器模型更新参数:
[0080][0081]
其中,上述公式(2)代表在进行客户端模型参数的更新过程中对特征提取器的模型参数进行更新,但特征提取器的模型参数采用的是全局共享的第t轮全局特征提取器模型参数,因此也可以视为对第t轮全局特征提取器模型参数的更新。上述公式(2)等式左边是更新后的特征提取器的模型参数,也即获得的第t轮全局特征提取器模型更新参数,等式右边的代表进行客户端更新前的第t轮全局特征提取器模型参数,且等式右边的的值实际上等于实际上等于代表对求梯度,ηe代表对本地特征提取器进行更新时候的学习率。需要说明的是,本实施例中使用的损失函数li为交叉熵损失函数。
[0082]
如上所示,在更新本地分类器参数时,本地特征提取器参数并没有发生变化,因此本地特征提取器部分只需要进行一次前向传播。在更新本地的特征提取器时可以直接将更新本地分类器时得到特征提取器输出输入到全局分类器中(即实现神经网络前向传播的过程)而不需要重复计算。具体的,本地特征提取器只进行了一次更新,本地特征提取器也只执行了一次前向传播。假设输入数据x先输入特征提取器得到b,b输入本地分类器,然后得到预测值,反向传播更新本地分类器。然后紧接着,可以直接再把得到的中间结果b输入到全局分类中去预测,然后更新特征提取器。这个过程中,由x到b只计算了一次,能够有效提高模型训练效率。
[0083]
本实施例中,先根据全局的特征提取器和本地分类器对本地分类器进行训练更新;然后根据全局的特征提取器和全局的分类器对本地特征提取器进行训练更新;如此,本地分类器和全局分类器对应的模型参数数据都得到了利用,有利于提高训练的准确性。使用了两种不同的分类器加入到了模型训练更新的过程,保留模型的本地个性化信息。训练中本地分类器因为其参数自始至终没有被全局模型替换掉,所以它保留了很多客户端相关的个性化信息,而全局分类器中含有来自别的客户端的共享信息。相比fedavg,本实施例中保留了个性化信息,同时又保证了共享的信息不会丢失。
[0084]
进一步的,在完成客户端的更新之后,进行服务端的更新。每个目标客户端在完成本地模型参数的更新之后,将进行第t轮迭代之后获得的第t轮全局特征提取器模型更新参数和第t轮本地分类器模型更新参数上传到服务器,以触发上述服务器根据各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮全局特征提取器模型更新参数进行加权聚合计算获得第t+1轮全局特征提取器模型参数,根据上述各上述目标客户端的训练数据中的样本数量和各上述目标客户端对应的第t轮本地分类器模型更新参数进行加权聚合计算获得第t+1轮全局分类器模型参数。
[0085]
具体的,假设第t轮联邦训练共有k个目标客户端参与,它们的本地样本总数为n,每个客户端的本地样本数量为nk。服务器收到客户端上传的模型参数后,将根据各个客户端本地训练数据的样本数量,根据如下公式(3)和(4)对上述上传到服务器的参数进行加权平均聚合计算新一轮的全局模型参数(θ
t+1
,φ
t+1
):
[0086][0087][0088]
其中,代表第t+1轮全局特征提取器模型参数,代表第t+1轮全局分类器模型参数。服务器计算完一轮全局模型后将接着选择部分目标客户端下发模型参数,执行下一轮的联邦训练。
[0089]
图2是本发明实施例提供的一种训练过程中计算损失值的流程示意图,其中,表示输入数据输入神经网络后,前向传播经过全局特征提取器和全局分类器后得到的预测值,表示输入数据输入神经网络后,前向传播经过全局特征提取器和本地分类器后得到的预测值。如图2所示,在进行迭代训练过程中,目标客户端将训练数据ξ输入到待训练的数据分类模型中,根据更新后本地的全局特征提取器的参数和本地分类器的参数计算获得第一种预测类别根据和训练数据ξ对应的标注类别值y计算获得第一损失值。同样的,根据更新后本地的全局特征提取器的参数和全局分类器的参数计算获得第二种预测类别根据和训练数据ξ对应的标注类别值y计算获得第二损失值;根据第一损失值和第二损失值可以对模型参数进行对应调整,反复迭代,直到满足训练停止的条件。
[0090]
由上可见,本实施例方案中,在迭代训练过程中,使用全局特征提取器模型参数和全局分类器模型参数作为各个目标客户端之间可以共享的模型参数,同时各个客户端还使用始终存储在其本地的本地分类器模型参数来保留自身的个性化信息。从而保证在各个目标客户端上的数据分类模型进行联邦学习训练的过程中仍然可以保留各个目标客户端对应的本地模型信息,使得各个目标客户端对应的数据分类模型在训练后能充分考虑该目标客户端上的数据的特性,从而有利于提高各个客户端训练出的数据分类模型对自身需要识别的数据的分类准确性,进而提高整体的数据分类准确性。
[0091]
本实施例中提出的基于改进联邦学习的数据分类方法在参与方数据是非独立同分布时训练出来的模型性能要优于fedavg算法,当数据是独立同分布时训练的模型性能也能达到甚至超过fedavg算法训练出来的模型。通信效率方面,与fedavg算法相比,本实施例方案仅仅多了一个分类层的前向和反向传播的过程,而且分类器层仅仅只有一层神经网络,因此不会带来很多额外的计算开销。在可拓展性方面,本发明实施例方案可应用于fedavg算法及fedavg的一些改进算法(如fedprox)上,提升原算法的性能。
[0092]
如图3中所示,对应于上述基于改进联邦学习的数据分类方法,本发明实施例还提供一种基于改进联邦学习的数据分类系统,上述基于改进联邦学习的数据分类系统包括:
[0093]
模型训练模块310,用于控制目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,上述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对上述待训练的数据分类模型的模型参数进行调整,上述全局特征提取器模型参数和上述全局分类器模型参数由上述目标客户端从服务器获取,上述本地分类器模型参数由上述目标客户端根据本地存储的数据获取;
[0094]
数据分类模块320,用于控制上述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对上述待分类数据进行分类并获取上述待分类数据对应的目标类别。
[0095]
具体的,本实施例中,上述基于改进联邦学习的数据分类系统及其各模块的具体功能可以参照上述基于改进联邦学习的数据分类方法中的对应描述,在此不再赘述。
[0096]
需要说明的是,上述基于改进联邦学习的数据分类系统的各个模块的划分方式并不唯一,在此也不作为具体限定。
[0097]
基于上述实施例,本发明还提供了一种智能终端,其原理框图可以如图4所示。上述智能终端包括处理器及存储器。该智能终端的存储器包括基于改进联邦学习的数据分类程序,存储器为基于改进联邦学习的数据分类程序的运行提供环境。该基于改进联邦学习的数据分类程序被处理器执行时实现上述任意一种基于改进联邦学习的数据分类方法的步骤。需要说明的是,上述智能终端还可以包括其它功能模块或单元,在此不作具体限定。
[0098]
本领域技术人员可以理解,图4中示出的原理框图,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的智能终端的限定,具体地智能终端可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
[0099]
本发明实施例还提供一种计算机可读存储介质,上述计算机可读存储介质上存储有基于改进联邦学习的数据分类程序,上述基于改进联邦学习的数据分类程序被处理器执行时实现本发明实施例提供的任意一种基于改进联邦学习的数据分类方法的步骤。
[0100]
应理解,上述实施例中各步骤的序号大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
[0101]
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将上述系统的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可
以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
[0102]
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
[0103]
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各实例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟是以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
[0104]
在本发明所提供的实施例中,应该理解到,所揭露的系统/智能终端和方法,可以通过其它的方式实现。例如,以上所描述的系统/智能终端实施例仅仅是示意性的,例如,上述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以由另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
[0105]
上述集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,上述计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,上述计算机程序包括计算机程序代码,上述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。上述计算机可读介质可以包括:能够携带上述计算机程序代码的任何实体或装置、记录介质、u盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(rom,read-only memory)、随机存取存储器(ram,random access memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,上述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减。
[0106]
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解;其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不是相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
技术特征:
1.一种基于改进联邦学习的数据分类方法,其特征在于,所述方法包括:目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,所述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对所述待训练的数据分类模型的模型参数进行调整,所述全局特征提取器模型参数和所述全局分类器模型参数由所述目标客户端从服务器获取,所述本地分类器模型参数由所述目标客户端根据本地存储的数据获取;所述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对所述待分类数据进行分类并获取所述待分类数据对应的目标类别。2.根据权利要求1所述的基于改进联邦学习的数据分类方法,其特征在于,所述数据分类模型是图像分类模型,所述待分类数据是待分类图像。3.根据权利要求1所述的基于改进联邦学习的数据分类方法,其特征在于,所述目标客户端由所述服务器选择确定。4.根据权利要求1所述的基于改进联邦学习的数据分类方法,其特征在于,所述目标客户端根据预设的改进联邦学习算法进行第t轮迭代时根据如下步骤进行模型参数调整:所述目标客户端获取所述服务器下发的第t轮全局特征提取器模型参数和第t轮全局分类器模型参数,其中,所述第t轮全局特征提取器模型参数和所述第t轮全局分类器模型参数由所述服务器根据第t-1轮迭代时的所有目标客户端对应的第t-1轮全局特征提取器模型更新参数和第t-1轮本地分类器模型更新参数计算获得;所述目标客户端从本地存储的数据中获取第t-1轮本地分类器模型更新参数并作为第t轮本地分类器模型参数;根据所述目标客户端中的训练数据、所述第t轮全局特征提取器模型参数、所述第t轮全局分类器模型参数以及所述第t轮本地分类器模型参数对所述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得所述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数。5.根据权利要求4所述的基于改进联邦学习的数据分类方法,其特征在于,所述根据所述目标客户端中的训练数据、所述第t轮全局特征提取器模型参数、所述第t轮全局分类器模型参数以及所述第t轮本地分类器模型参数对所述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得所述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数,包括:所述目标客户端将其对应的待训练的数据分类模型中全局特征提取器的模型参数更新为所述第t轮全局特征提取器模型参数;固定所述待训练的数据分类模型中全局特征提取器对应的第t轮全局特征提取器模型参数,根据所述目标客户端中的训练数据、所述第t轮本地分类器模型参数和固定的所述第t轮全局特征提取器模型参数计算所述训练数据对应的第一损失值,并根据所述第一损失值对所述第t轮本地分类器模型参数进行调整以获得第t轮本地分类器模型更新参数;固定所述待训练的数据分类模型中全局分类器对应的第t轮全局分类器模型参数,根据所述目标客户端中的训练数据、所述第t轮全局特征提取器模型参数和固定的所述第t轮全局分类器模型参数计算所述训练数据对应的第二损失值,并根据所述第二损失值对所述
第t轮全局特征提取器模型参数进行调整以获得第t轮全局特征提取器模型更新参数。6.根据权利要求5所述的基于改进联邦学习的数据分类方法,其特征在于,所述待训练的数据分类模型中各模型参数通过梯度下降方式进行更新。7.根据权利要求5所述的基于改进联邦学习的数据分类方法,其特征在于,在所述根据所述目标客户端中的训练数据、所述第t轮全局特征提取器模型参数、所述第t轮全局分类器模型参数以及所述第t轮本地分类器模型参数对所述待训练的数据分类模型的模型参数进行第t轮迭代更新,以获得所述待训练的数据分类模型对应的第t轮全局特征提取器模型更新参数以及第t轮本地分类器模型更新参数之后,所述方法还包括:所述目标客户端向所述服务器发送所述第t轮全局特征提取器模型更新参数和所述第t轮本地分类器模型更新参数,以触发所述服务器根据各所述目标客户端的训练数据中的样本数量和各所述目标客户端对应的第t轮全局特征提取器模型更新参数进行加权聚合计算获得第t+1轮全局特征提取器模型参数,根据所述各所述目标客户端的训练数据中的样本数量和各所述目标客户端对应的第t轮本地分类器模型更新参数进行加权聚合计算获得第t+1轮全局分类器模型参数。8.一种基于改进联邦学习的数据分类系统,其特征在于,所述系统包括:模型训练模块,用于控制目标客户端根据预设的改进联邦学习算法对该目标客户端对应的待训练的数据分类模型进行模型迭代训练,获得该目标客户端对应的已训练的数据分类模型,其中,所述目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对所述待训练的数据分类模型的模型参数进行调整,所述全局特征提取器模型参数和所述全局分类器模型参数由所述目标客户端从服务器获取,所述本地分类器模型参数由所述目标客户端根据本地存储的数据获取;数据分类模块,用于控制所述目标客户端获取待分类数据,通过该目标客户端对应的已训练的数据分类模型对所述待分类数据进行分类并获取所述待分类数据对应的目标类别。9.一种智能终端,其特征在于,所述智能终端包括存储器、处理器以及存储在所述存储器上并可在所述处理器上运行的基于改进联邦学习的数据分类程序,所述基于改进联邦学习的数据分类程序被所述处理器执行时实现如权利要求1-7任意一项所述基于改进联邦学习的数据分类方法的步骤。10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有基于改进联邦学习的数据分类程序,所述基于改进联邦学习的数据分类程序被处理器执行时实现如权利要求1-7任意一项所述基于改进联邦学习的数据分类方法的步骤。
技术总结
本发明公开了一种基于改进联邦学习的数据分类方法、系统及相关设备,方法包括:目标客户端根据预设的改进联邦学习算法对其对应的待训练的数据分类模型进行模型迭代训练,获得对应的已训练的数据分类模型,目标客户端根据预设的改进联邦学习算法进行一轮迭代时,基于全局特征提取器模型参数、全局分类器模型参数和本地分类器模型参数对待训练的数据分类模型的模型参数进行调整,全局特征提取器模型参数和全局分类器模型参数由目标客户端从服务器获取,本地分类器模型参数由目标客户端从本地存储的数据获取;目标客户端获取待分类数据,通过对应的已训练的数据分类模型进行分类获取待分类数据对应的目标类别。本发明有利于提高数据分类的准确性。提高数据分类的准确性。提高数据分类的准确性。
技术研发人员:刘洋 王家勃 王轩 刘川意 漆舒汉 陈斌 王强
受保护的技术使用者:哈尔滨工业大学(深圳)
技术研发日:2023.03.15
技术公布日:2023/8/1
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/