联邦模型训练方法、系统、设备及计算机可读存储介质与流程

文档序号:17892022发布日期:2019-06-13 15:43阅读:205来源:国知局
联邦模型训练方法、系统、设备及计算机可读存储介质与流程

本发明涉及机器学习技术领域,尤其涉及一种联邦模型训练方法、系统、设备及计算机可读存储介质。



背景技术:

联邦模型是利用技术算法加密建造的机器学习模型,联邦学习系统中的多个联邦客户端在模型训练时不用给出己方数据,而是根据协作端下发的参数加密的全局模型和客户端本地的数据集来训练本地模型,并返回本地模型参数供协作端聚合更新全局模型,更新后的全局模型重新下发到客户端,循环往复,直到收敛。联邦学习通过加密机制下参数交换的方式保护客户端数据隐私,客户端数据和客户端的本地模型本身不会进行传输,本地数据不会被反猜,联邦模型在较高程度保持数据完整性的同时,保障了数据隐私。

目前,协作端根据多个客户端返回的本地模型参数聚合更新全局模型时,只是对多个客户端的模型参数做简单平均,将平均后的模型参数作为新的全局模型参数下发至客户端继续迭代训练,然而,实际训练中,每个客户端由于其训练数据的不同,训练出的本地模型的预测性能也是参差不齐的,现有的简单平均的聚合方法会导致全局模型的效果不理想。



技术实现要素:

本发明的主要目的在于提供一种联邦模型训练方法、系统、设备及计算机可读存储介质,旨在解决现有的协作端对多个联邦客户端的模型参数采用简单平均的聚合方式来更新全局模型而导致的联邦模型效果不理想的技术问题。

为实现上述目的,本发明提供一种联邦模型训练方法,所述联邦模型训练方法包括步骤:

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;

根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;

检测所述第二全局模型是否收敛;

若检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端。

可选地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率;

基于每个预测模型的所述预测误差率及预设的计算公式,分别计算得到每个模型参数对应的权重系数。

可选地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率的步骤包括:

当协作终端接收到多个客户终端分别发送的模型参数后,将预设的测试样本集中的多个测试样本输入至所述模型参数对应的预测模型中进行预测,得到所述预测模型针对每个所述测试样本的预测值;

根据多个所述预测值,获取所述测试样本集中预测结果错误的测试样本的数量;

将所述预测结果错误的测试样本的数量与所述测试样本集中全部测试样本数量的比值确定为所述预预测模型的预测误差率。

可选地,所述根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型的步骤包括:

分别将每个模型参数与其对应的权重系数相乘,得到多个相乘后的结果;

将所述多个相乘后的结果相加,并将相加结果确定为第二全局模型的模型参数,得到所述第二全局模型。

可选地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前还包括:

发送参数加密的第一全局模型分别至多个客户终端;

接收所述多个客户终端分别发送的模型参数;

其中,所述客户终端在接收到协作终端下发的所述第一全局模型后,所述客户终端根据所述第一全局模型对第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数。

可选地,所述检测所述第二全局模型是否收敛的步骤之后还包括:

若检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,以使所述多个客户终端分别根据协作终端下发的所述第二全局模型继续迭代训练以返回模型参数至所述协作终端。

可选地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前还包括:

接收多个客户终端分别发送的模型参数;

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

接收所述多个客户终端分别发送的与所述模型参数对应的权重系数;其中,所述多个客户终端分别根据预设的测试样本集,测试并得到所述模型参数对应的预测模型的预测误差率,并根据所述预测误差率及预设的计算公式,计算得到所述模型参数对应的权重系数。

此外,本发明还提出一种联邦模型训练系统,所述系统包括协作终端及分别与所述协作终端通信连接的多个客户终端,所述协作终端包括:

获取模块,用于在接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;

聚合更新模块,用于根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;

检测模块,用于检测所述第二全局模型是否收敛;

确定模块,用于在所述检测模块检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端。

可选地,所述获取模块包括:

测试单元,用于在接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率;

计算单元,用于基于每个预测模型的所述预测误差率及预设的计算公式,分别计算得到每个模型参数对应的权重系数。

可选地,所述测试单元包括:

测试子单元,用于在接收到多个客户终端分别发送的模型参数后,将预设的测试样本集中的多个测试样本输入至所述模型参数对应的预测模型中进行预测,得到所述预测模型针对每个所述测试样本的预测值;

获取子单元,用于根据多个所述预测值,获取所述测试样本集中预测结果错误的测试样本的数量;

确定子单元,用于将所述预测结果错误的测试样本的数量与所述测试样本集中全部测试样本数量的比值确定为所述预预测模型的预测误差率。

可选地,所述聚合更新模块包括:

乘处理单元,用于分别将每个模型参数与其对应的权重系数相乘,得到多个相乘后的结果;

更新单元,用于将所述多个相乘后的结果相加,并将相加结果确定为第二全局模型的模型参数,得到所述第二全局模型。

可选地,所述协作终端还包括:

第一下发模块,用于发送参数加密的第一全局模型分别至多个客户终端;

接收模块,用于接收所述多个客户终端分别发送的模型参数;

其中,所述客户终端在接收到所述第一下发模块下发的所述第一全局模型后,所述客户终端根据所述第一全局模型对第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数。

可选地,所述协作终端还包括:

第二下发模块,用于在所述检测模块检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,以使所述多个客户终端分别根据所述第二下发模块下发的所述第二全局模型继续迭代训练以返回模型参数至所述协作终端。

可选地,所述获取模块,还用于接收所述多个客户终端分别发送的与所述模型参数对应的权重系数;其中,所述多个客户终端分别根据预设的测试样本集,测试并得到所述模型参数对应的预测模型的预测误差率,并根据所述预测误差率及预设的计算公式,计算得到所述模型参数对应的权重系数。

此外,为实现上述目的,本发明还提出一种联邦模型训练设备,所述联邦模型训练设备包括存储器、处理器和存储在所述存储器上并可在所述处理器上运行的联邦模型训练程序,所述联邦模型训练程序被所述处理器执行时实现如上所述的联邦模型训练方法的步骤。

此外,为实现上述目的,本发明还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有联邦模型训练程序,所述联邦模型训练程序被处理器执行时实现如上所述的联邦模型训练方法的步骤。

本发明通过当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;检测所述第二全局模型是否收敛;若检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端;由此,协作终端根据联邦多方的客户端返回的模型参数聚合全局模型时,不是对多个模型参数做简单平均,而是结合每个模型参数的权重系数来更新得到新的全局模型,该权重系数是根据每个客户终端训练模型的预测准确性确定的,提升了联邦模型的预测效果,避免了现有的协作端对多个联邦客户端的模型参数采用简单平均的聚合方式来更新全局模型而导致的联邦模型效果不理想问题。

附图说明

图1是本发明实施例方案涉及的硬件运行环境的结构示意图;

图2为本发明联邦模型训练方法第一实施例的流程示意图;

图3为本发明联邦模型训练方法第二实施例的流程示意图;

图4为本发明联邦模型训练方法第三实施例的流程示意图;

图5为本发明联邦模型训练方法第四实施例的流程示意图。

本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。

具体实施方式

应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。

如图1所示,图1是本发明实施例方案涉及的硬件运行环境的结构示意图。

需要说明的是,图1即可为联邦模型训练设备的硬件运行环境的结构示意图。本发明实施例联邦模型训练设备可以是pc,便携计算机等终端设备。

如图1所示,该联邦模型训练设备可以包括:处理器1001,例如cpu,网络接口1004,用户接口1003,存储器1005,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。用户接口1003可以包括显示屏(display)、输入单元比如键盘(keyboard),可选用户接口1003还可以包括标准的有线接口、无线接口。网络接口1004可选的可以包括标准的有线接口、无线接口(如wi-fi接口)。存储器1005可以是高速ram存储器,也可以是稳定的存储器(non-volatilememory),例如磁盘存储器。存储器1005可选的还可以是独立于前述处理器1001的存储装置。

本领域技术人员可以理解,图1中示出的联邦模型训练设备结构并不构成对联邦模型训练设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。

如图1所示,作为一种计算机存储介质的存储器1005中可以包括操作系统、网络通信模块、用户接口模块以及联邦模型训练程序。其中,操作系统是管理和控制联邦模型训练设备硬件和软件资源的程序,支持联邦模型训练程序以及其它软件或程序的运行。

在图1所示的联邦模型训练设备中,用户接口1003主要用于连接客户终端等,与各个终端进行数据通信;网络接口1004主要用于连接后台服务器,与后台服务器进行数据通信;而处理器1001可以用于调用存储器1005中存储的联邦模型训练程序,并执行以下操作:

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;

根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;

检测所述第二全局模型是否收敛;

若检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端。

进一步地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率;

基于每个预测模型的所述预测误差率及预设的计算公式,分别计算得到每个模型参数对应的权重系数。

进一步地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率的步骤包括:

当协作终端接收到多个客户终端分别发送的模型参数后,将预设的测试样本集中的多个测试样本输入至所述模型参数对应的预测模型中进行预测,得到所述预测模型针对每个所述测试样本的预测值;

根据多个所述预测值,获取所述测试样本集中预测结果错误的测试样本的数量;

将所述预测结果错误的测试样本的数量与所述测试样本集中全部测试样本数量的比值确定为所述预预测模型的预测误差率。

进一步地,所述根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型的步骤包括:

分别将每个模型参数与其对应的权重系数相乘,得到多个相乘后的结果;

将所述多个相乘后的结果相加,并将相加结果确定为第二全局模型的模型参数,得到所述第二全局模型。

进一步地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前,处理器1001还可以用于调用存储器1005中存储的联邦模型训练程序,并执行以下步骤:

发送参数加密的第一全局模型分别至多个客户终端;

接收所述多个客户终端分别发送的模型参数;

其中,所述客户终端在接收到协作终端下发的所述第一全局模型后,所述客户终端根据所述第一全局模型对第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数。

进一步地,所述检测所述第二全局模型是否收敛的步骤之后,处理器1001还可以用于调用存储器1005中存储的联邦模型训练程序,并执行以下步骤:

若检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,以使所述多个客户终端分别根据协作终端下发的所述第二全局模型继续迭代训练以返回模型参数至所述协作终端。

进一步地,所述当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前,处理器1001还可以用于调用存储器1005中存储的联邦模型训练程序,并执行以下步骤:

接收多个客户终端分别发送的模型参数;

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

接收所述多个客户终端分别发送的与所述模型参数对应的权重系数;其中,所述多个客户终端分别根据预设的测试样本集,测试并得到所述模型参数对应的预测模型的预测误差率,并根据所述预测误差率及预设的计算公式,计算得到所述模型参数对应的权重系数。

基于上述的结构,提出联邦模型训练方法的各个实施例。

参照图2,图2为本发明联邦模型训练方法第一实施例的流程示意图。

本发明实施例提供了联邦模型训练方法的实施例,需要说明的是,虽然在流程图中示出了逻辑顺序,但是在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤。

联邦模型训练方法包括:

步骤s100,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;

其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的。

联邦模型是利用技术算法加密建造的机器学习模型,联邦学习系统为了保证联邦多方客户端在训练过程中数据的保密性,借助第三方协作终端进行加密训练,联邦学习系统中的多个联邦客户端在模型训练时不用给出己方数据,而是根据协作端下发的参数加密的全局模型和客户端本地的数据集来训练本地模型,并返回本地模型参数供协作端聚合更新全局模型,更新后的全局模型重新下发到客户端,循环往复,直到收敛。联邦学习通过加密机制下参数交换的方式保护客户端数据隐私,客户端数据和客户端的本地模型本身不会进行传输,本地数据不会被反猜,能够在较高程度保持数据完整性的同时,保障数据隐私。

但是,现有的协作端根据多个客户端返回的本地模型参数聚合更新全局模型时,只是对多个客户端的模型参数做简单平均,将平均后的模型参数作为新的全局模型参数下发至客户端继续迭代训练,然而,实际训练中,每个客户端由于其训练数据的不同,训练出的本地模型的预测性能也是参差不齐的,现有技术中的简单平均的聚合方法会导致全局模型的效果不理想,若联邦学习系统中每个客户终端的本地模型的预测准确率差异较大,多个参数简单平均的聚合会降低其中本地模型的预测准确率高的客户端的模型效果,全局模型即最终得到的联邦模型的效果不理想。

本实施例中,协作终端采用预设的加密算法对第一全局模型的参数加密,并下发参数加密的第一全局模型至联邦学习系统中的多个客户终端,其中,预设的加密算法本实施例不做具体限制,可以是非对称加密算法等等,第一全局模型是本实施例待训练联邦模型完成了若干次迭代运算后得到的全局模型。

进一步地,客户终端在接收到协作终端下发的参数加密的第一全局模型后,每个客户终端根据其本地的训练样本数据对该第一全局模型进行训练得到其各自的本地模型,需要说明的是,本实施例中,多个客户端的训练样本类别均服从独立同分布,客户终端将得到的本地模型的模型参数返回至协作终端。

当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的,作为一种实施方式,协作终端存储有测试样本集,测试样本集中的测试样本与多个客户终端的训练样本均具有相同的特征维度,在接收到多个客户终端分别发送的模型参数后,协作终端采用测试样本集,测试每个模型参数对应的预测模型的预测准确性,具体地,将测试样本集中的多个测试样本输入至每个客户终端回传的预测模型中,得到该预测模型对该测试样本集中多个测试样本的预测结果,筛选出预测错误的测试样本的数量,用筛选出的预测错误的测试样本的数量除该测试样本集中测试样本的总数,即得到当前预测模型的预测误差率,采用同样的方法,得到每个客户终端发送的模型参数对应的预测模型的预测误差率。

进一步地,根据每个模型参数下的预测误差率确定该模型参数参与全局模型聚合时的权重系数,每个模型参数下的预测误差率与该模型参数的权重系数负相关,即模型参数对应的预测模型的预测误差率越小,则该模型参数的权重系数越大,本实施例协作终端根据多个客户终端发送的模型参数聚合时,对预测准确性高的模型增加其模型参数的权重,对预测准确性低的模型降低其模型参数的权重,以此更新得到的新的全局模型保证了每个客户终端模型效果的增长。作为一种实施方式,权重系数的计算可以是根据计算公式计算得到,其中,εi为第i个客户终端的预测模型的预测误差率,αi为第i个客户终端的预测模型的模型参数对应的权重系数,i为大于零的整数,协作终端即获取到每个模型参数对应的权重系数。

需要说明的是,在其它实施例中,联邦学习系统中的多个客户终端可以均存储有相同的测试样本集,该测试样本集中的测试样本与多个客户终端的训练样本均具有相同的特征维度,每个客户终端的模型参数对应的权重系数可以是客户终端根据本地存储的测试样本集测试其预测模型的预测误差率,进而得到其模型参数对应的权重系数,客户终端发送模型参数的同时,将计算得到的模型参数对应的权重系数也一并发至协作终端供协作终端聚合,本实施例在此不做具体限制,进一步地,权重系数的计算方法也不限于本实施例所述的计算方法,在其它实施例中,可以根据需求设置相应的计算规则。

步骤s200,根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;

本实施例中,每个模型参数对应的预测模型的预测误差率与该模型参数的权重系数负相关,即预测误差率越小则该模型参数的权重系数越大,预测误差率越大则该模型参数的权重系数越小,对每个模型参数乘其对应的权重系数,再将乘了权重系数的多个模型参数相加即得到新的全局模型的参数,得到新的全局模型即第二全局模型。

本实施例协作终端根据多个客户终端发送的模型参数聚合时,对预测准确性高的模型增加其模型参数的权重,对预测准确性低的模型降低其模型参数的权重,以此更新得到的新的全局模型保证了每个客户终端模型效果的增长,避免了现有的协作端对多个联邦客户端的模型参数采用简单平均的聚合方式来更新全局模型而导致的联邦模型效果不理想问题。

步骤s300,检测所述第二全局模型是否收敛;

本实施例中,作为一种实施方式,协作终端根据第二全局模型的损失函数得到损失值,根据损失值判断第二全局模型是否收敛,具体地,协作终端存储有第一全局模型下的第一损失值,协作终端根据第二全局模型的损失函数得到第二损失值,计算第一损失值和第二损失值之间的差值,并判断该差值是否小于或者等于预设阈值,若该差值小于或者等于预设阈值,则确定所述第二全局模型处于收敛状态,联邦模型训练完成,实际训练时,预设阈值可以根据用户的需求来自行设定,本实施例对预设阈值不做具体限制。

步骤s400,若检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端。

若检测到第二全局模型处于收敛状态,则联邦模型训练完成,第二全局模型即确定为联邦模型训练的最终结果,协作终端下发参数加密的第二全局模型至所述多个客户终端,多个客户终端即在不用给出己方数据的前提下,实现了本地模型的效果增长,保障数据隐私的同时,提升了预测准确性。

本实施例通过当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;检测所述第二全局模型是否收敛;若检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端;由此,协作终端根据联邦多方的客户端返回的模型参数聚合全局模型时,不是对多个模型参数做简单平均,而是结合每个模型参数的权重系数来更新得到新的全局模型,该权重系数是根据每个客户终端训练模型的预测准确性确定的,提升了联邦模型的预测效果,避免了现有的协作端对多个联邦客户端的模型参数采用简单平均的聚合方式来更新全局模型而导致的联邦模型效果不理想问题。

进一步地,提出本发明联邦模型训练方法第二实施例。

参照图3,图3为本发明联邦模型训练方法第二实施例的流程示意图,基于上述图2所示的实施例,本实施例中,步骤s100,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

步骤s101,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率;

步骤s102,基于每个预测模型的所述预测误差率及预设的计算公式,分别计算得到每个模型参数对应的权重系数。

具体地,在本实施例中,协作终端存储有测试样本集,测试样本集中包括多个测试样本,多个测试样本与多个客户终端的本地训练样本均具有相同的特征维度,协作终端将测试样本集中的多个测试样本输入至每个客户终端回传的预测模型中,得到该预测模型对该测试样本集中多个测试样本的预测结果,筛选得到预测结果错误的测试样本的数量,用预测结果错误的测试样本的数量除该测试样本集中测试样本的总数,即得到当前预测模型的预测误差率,进一步地,采用同样的方法,得到每个客户终端发送的模型参数对应的预测模型的预测误差率。

在本实施例中,预设的计算公式为:其中,εi为第i个客户终端的预测模型的预测误差率,αi为第i个客户终端的预测模型的模型参数对应的权重系数,i为大于零的整数;协作终端通过计算得到每个客户终端的预测模型的预测误差率后,分别将每个预测误差率代入该计算公式计算,得到的结果即为每个模型参数对应的权重系数。

进一步地,基于上述图2所示的实施例,本实施例中,步骤s200,根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型的步骤包括:

步骤s201,分别将每个模型参数与其对应的权重系数相乘,得到多个相乘后的结果;

步骤s202,将所述多个相乘后的结果相加,并将相加结果确定为第二全局模型的模型参数,得到所述第二全局模型。

本实施例中,对每个模型参数乘其对应的权重系数,再将乘了权重系数的多个模型参数相加即得到新的全局模型的参数,得到新的全局模型即第二全局模型。

本实施例协作终端根据多个客户终端发送的模型参数聚合时,对预测准确性高的模型增加其模型参数的权重,对预测准确性低的模型降低其模型参数的权重,以此更新得到的新的全局模型保证了每个客户终端模型效果的增长,避免了现有的协作端对多个联邦客户端的模型参数采用简单平均的聚合方式来更新全局模型而导致的联邦模型效果不理想问题。

进一步地,提出本发明联邦模型训练方法第三实施例。

参照图4,图4为本发明联邦模型训练方法第三实施例的流程示意图,基于上述图2所示的实施例,本实施例中,步骤s100,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前还包括:

步骤s110,发送参数加密的第一全局模型分别至多个客户终端;

步骤s120,接收所述多个客户终端分别发送的模型参数;

其中,所述客户终端在接收到协作终端下发的所述第一全局模型后,所述客户终端根据所述第一全局模型对第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数。

本实施例中,作为一种实施方式,假设待训练的联邦模型已经完成了第k次的迭代运算并得到第一全局模型modelk,其中,k为大于零的整数,本实施例联邦模型训练方法具体包括如下步骤;

步骤a:协作终端发送参数加密的第一全局模型modelk分别至每一个联邦客户终端;

步骤b:第i个客户终端接收modelk,用modelk对本地的训练样本集xi进行预测,根据预测结果将xi中的训练样本分为两个集合:预测错误的样本数据集(x1,x2,...,xn)和预测正确的样本数据集(y1,y2,...,ym),其中n表示xi中预测错误的样本数量,m表示xi中预测正确的样本数量,n可以等于m,本实施例不做具体限制,故有:

xi=(x1,x2,...,xn)∪(y1,y2,...,ym)且成立;

第i个客户终端在训练modelk之前,先对xi进行采样,具体是选取所述预测错误的样本数据集(x1,x2,...,xn),并从所述预测正确的样本数据集(y1,y2,...,ym)中抽取部分样本(y1,y2,...,yk),k<m,来构成采样后的训练数据集yi,即yi=(x1,x2,...,xn)∪(y1,y2,...,yk),k<n,采用训练数据集yi对modelk进行训练,训练后得到新的本地预测模型

步骤c:第i个客户终端发送训练后得到的本地预测模型至协作终端,协作终端将协作终端存储的测试样本集中的多个测试样本输入至第i个客户终端回传的预测模型中,得到该预测模型对该测试样本集中多个测试样本的预测结果,筛选出预测错误的测试样本的数量,用筛选出的预测错误的测试样本的数量除该测试样本集中测试样本的总数,即得到第i个客户终端的预测模型的预测误差率

步骤d:协作终端将计算得到的第i个客户终端的预测模型的预测误差率代入计算公式中,计算得到第i个客户终端的模型参数对应的权重系数

步骤e:协作终端根据每一个客户终端发送的模型参数及计算得到的每一个模型参数对应的权重系数,聚合更新第一全局模型modelk得到第二全局模型modelk+1,其中q为本实施例联邦客户终端的总数量,协作终端检测modelk+1是否收敛,若收敛,则将modelk+1作为本实施例联邦模型的最终训练结果,并将modelk+1的模型参数加密下发至各个客户终端。

若协作终端检测到modelk+1未收敛,则重复上述步骤a-步骤e,直至联邦模型收敛。

本实施例客户终端根据协作终端下发的参数加密的第一全局模型训练时,客户终端首先根据所述第一全局模型对客户终端的本地第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数并回传所述模型参数至协作终端,实现了对于预测错误的样本数据,在下一次迭代时提高其权重,优化了本地训练模型的性能即提升了每个客户终端发送至协作终端的模型参数的质量,从而提升了全局模型即本实施例联邦模型的预测准确性。

进一步地,提出本发明联邦模型训练方法第四实施例。

参照图5,图5为本发明联邦模型训练方法第四实施例的流程示意图,基于图2所示的实施例,本实施例中,步骤s300,检测所述第二全局模型是否收敛的步骤之后还包括:

步骤s500,若检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,以使所述多个客户终端分别根据协作终端下发的所述第二全局模型继续迭代训练以返回模型参数至所述协作终端。

本实施例中,若检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第二全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;根据多个模型参数及每个模型参数对应的权重系数,聚合得到第三全局模型;检测所述第三全局模型是否收敛;若检测到所述第三全局模型处于收敛状态,则将所述第三全局模型确定为联邦模型训练的最终结果,并下发参数加密的第三全局模型至所述多个客户终端,若检测到所述第三全局模型处于未收敛状态,下发参数加密的第三全局模型分别至多个客户终端,重复本发明上述任一实施例的步骤,继续训练直至模型收敛。

进一步地,提出本发明联邦模型训练方法第五实施例。

基于图2所示的实施例,本实施例中,步骤s100,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤之前还包括步骤:

接收多个客户终端分别发送的模型参数;

步骤s100,当协作终端接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数的步骤包括:

接收所述多个客户终端分别发送的与所述模型参数对应的权重系数;其中,所述多个客户终端分别根据预设的测试样本集,测试并得到所述模型参数对应的预测模型的预测误差率,并根据所述预测误差率及预设的计算公式,计算得到所述模型参数对应的权重系数。

本实施例中,作为一种实施方式,每个客户终端的模型参数对应的权重系数的计算是在各个客户终端分别进行的,联邦学习系统中的多个客户终端均存储有相同的测试样本集,该测试样本集中的测试样本与多个客户终端的训练样本均具有相同的特征维度,每个客户终端的模型参数对应的权重系数可以是客户终端根据本地存储的测试样本集测试其预测模型的预测误差率,进而得到其模型参数对应的权重系数,客户终端发送模型参数的同时,将计算得到的模型参数对应的权重系数也一并发至协作终端供协作终端聚合得到全局模型。

此外,本发明实施例还提出一种联邦模型训练系统,所述系统包括协作终端及分别与所述协作终端通信连接的多个客户终端,所述协作终端包括:

获取模块,用于在接收到多个客户终端分别发送的模型参数后,根据预设的获取规则,获取每个模型参数对应的权重系数;其中,所述模型参数是客户终端根据协作终端下发的参数加密的第一全局模型进行联邦模型训练得到的,所述权重系数是基于所述模型参数对应的预测模型的预测准确性确定的;

聚合更新模块,用于根据多个模型参数及每个模型参数对应的权重系数,聚合得到第二全局模型;

检测模块,用于检测所述第二全局模型是否收敛;

确定模块,用于在所述检测模块检测到所述第二全局模型处于收敛状态,则将所述第二全局模型确定为联邦模型训练的最终结果,并下发参数加密的第二全局模型至所述多个客户终端。

优选地,所述获取模块包括:

测试单元,用于在接收到多个客户终端分别发送的模型参数后,根据预设的测试样本集,测试并得到每个模型参数对应的预测模型的预测误差率;

计算单元,用于基于每个预测模型的所述预测误差率及预设的计算公式,分别计算得到每个模型参数对应的权重系数。

优选地,所述测试单元包括:

测试子单元,用于在接收到多个客户终端分别发送的模型参数后,将预设的测试样本集中的多个测试样本输入至所述模型参数对应的预测模型中进行预测,得到所述预测模型针对每个所述测试样本的预测值;

获取子单元,用于根据多个所述预测值,获取所述测试样本集中预测结果错误的测试样本的数量;

确定子单元,用于将所述预测结果错误的测试样本的数量与所述测试样本集中全部测试样本数量的比值确定为所述预预测模型的预测误差率。

优选地,所述预设的计算公式为:其中,εi为第i个客户终端的预测模型的预测误差率,αi为第i个客户终端的所述预测模型的模型参数对应的权重系数,i为大于零的整数。

优选地,所述聚合更新模块包括:

乘处理单元,用于分别将每个模型参数与其对应的权重系数相乘,得到多个相乘后的结果;

更新单元,用于将所述多个相乘后的结果相加,并将相加结果确定为第二全局模型的模型参数,得到所述第二全局模型。

优选地,所述协作终端还包括:

第一下发模块,用于发送参数加密的第一全局模型分别至多个客户终端;

接收模块,用于接收所述多个客户终端分别发送的模型参数;

其中,所述客户终端在接收到所述第一下发模块下发的所述第一全局模型后,所述客户终端根据所述第一全局模型对第一训练样本集进行预测以得到预测值,并根据所述预测值对所述第一训练样本集进行采样得到第二训练样本集,所述客户终端基于所述第二训练样本集训练所述第一全局模型,训练后得到所述模型参数。

优选地,所述协作终端还包括:

第二下发模块,用于在所述检测模块检测到所述第二全局模型处于未收敛状态,则下发参数加密的第二全局模型分别至多个客户终端,以使所述多个客户终端分别根据所述第二下发模块下发的所述第二全局模型继续迭代训练以返回模型参数至所述协作终端。

优选地,所述获取模块,还用于接收所述多个客户终端分别发送的与所述模型参数对应的权重系数;其中,所述多个客户终端分别根据预设的测试样本集,测试并得到所述模型参数对应的预测模型的预测误差率,并根据所述预测误差率及预设的计算公式,计算得到所述模型参数对应的权重系数。

本发明联邦模型训练系统具体实施方式与上述联邦模型训练方法各实施例基本相同,在此不再赘述。

此外,本发明实施例还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有联邦模型训练程序,所述联邦模型训练程序被处理器执行时实现如上所述的奖励发送方法的步骤。

本发明计算机可读存储介质具体实施方式与上述联邦模型训练方法各实施例基本相同,在此不再赘述。

需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。

上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。

通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如rom/ram、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。

以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1