本发明涉及金融科技(fintech)领域,尤其涉及一种联邦学习模型的训练方法及装置。
背景技术:
随着计算机技术的发展,越来越多的技术(例如:区块链、云计算或大数据)应用在金融领域,传统金融业正在逐步向金融科技转变,大数据技术也不例外,但由于金融、支付行业的安全性、实时性要求,也对大数据技术提出的更高的要求。
现有技术的联邦学习是通过传输参数的方式来进行节点之间的交流,在训练过程中,通过参数平均的方式来整合各个节点数据提供的信息。联邦学习在学习过程中,需要对数据进行训练,并在每次开始随机选择节点,将全局模型发放给该节点,利用节点的数据进行迭代,然后将训练得到的模型参数发回给中心服务器,中心服务器对各个节点训练得到的模型参数进行平均确定出下一次迭代的模型。
但在现有技术中的联邦学习中,使用非独立同分布的数据训练出来的模型准确率低,效果不佳。因此,需一种方法提高非独立同分布的数据模型参数训练准确率。
技术实现要素:
本发明实施例提供一种联邦学习模型的训练方法及装置,用于提高非独立同分布的数据训练出来的模型的准确率,优化根据非独立同分布的数据训练出来的模型。
第一方面,本发明实施例提供一种联邦学习模型的训练方法,包括:
客户端获取服务器广播的第k-1次迭代的全局模型参数;所述k为正整数;
所述客户端将所述全局模型参数作为设有正则化约束的本地模型参数,使用本地数据进行第k次迭代训练,得到第k次迭代训练的本地模型参数;所述正则化约束是根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
所述客户端将所述第k次迭代训练的本地模型参数发送给所述服务器,以使所述服务器更新第k次迭代的所述全局模型参数。
上述技术方案中,客户端获取k-1次迭代的全局模型参数,在k为1时,即全局模型参数为初始的全局模型参数,将初始的全局模型参数设置一个正则化约束,作为本地模型参数,以使在迭代过程中全局模型参数和本地模型参数都存在有正则化约束,在本地模型参数对本地的数据进行训练时,约束本地模型参数的损失函数,优化本地模型参数的梯度,进而减小本地数据中极端数据对本地模型参数的训练结果的影响,提高本地模型参数对非独立同分布数据训练的准确率,得到第k次迭代训练的本地模型,再将第k次迭代训练的本地模型参数发送给服务器,以使服务器更新第k次迭代的全局模型参数,进而提高全局模型参数对非独立同分布数据训练的准确率。
可选的,所述根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的所述正则化约束,包括:
对所述全局模型参数与所述本地模型参数的差值进行f范数计算,得到所述正则化约束。
上述技术方案,根据全局模型参数与本地模型参数得到正则化约束,用于优化本地模型参数中的损失函数,提高本地模型参数对非独立同分布数据训练的准确率。
可选的,根据下述公式(1)确定出所述本地模型参数的最终损失函数;
其中,
可选的,根据下述公式(2)确定出所述第k次迭代训练的本地模型参数;
其中,wk(i)为所述第k次迭代的本地模型参数,wk-1(i)为第k-1次迭代的本地模型参数,αk为第k次迭代的学习率,
第二方面,本发明实施例提供一种联邦学习模型的训练方法,包括:
服务器获取多个客户端发送的第k次迭代的本地模型参数;所述第k次迭代的本地模型参数是客户端使用设有正则化约束的本地模型参数对第k-1次迭代的全局模型参数进行训练得到的;所述正则化约束是所述客户端根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
所述服务器根据所述第k次迭代的本地模型参数与k-1次迭代的全局模型参数,确定出第k次迭代的全局模型参数;
所述服务器将所述第k次迭代的全局模型参数广播给所述多个客户端,以使多个客户端进行第k+1次迭代训练。
上述技术方案中,客户端获取多个客户端发送的第k次迭代的本地模型参数,确定出第k次迭代的全局模型参数,由于客户端的本地模型参数通过本地数据进行训练后,提高了本地模型参数对非独立同分布数据训练的准确率,致使全局模型参数根据本地模型参数进行更新后,全局模型参数也提高了对非独立同分布数据训练的准确率,然后再将第k次迭代的全局模型参数广播给多个客户端,以进行下一次迭代。
可选的,根据下述公式(3),确定出所述第k次迭代的全局模型参数;
其中,wk(0)为所述第k次迭代的全局模型参数,wk-1(0)为第k-1次迭代的所述全局模型参数,αk为第k次迭代的学习率,wk(i)为第i个客户端的第k次迭代的本地模型参数。
第三方面,本发明实施例提供一种联邦学习模型的训练装置,包括:
获取模块,用于获取服务器广播的第k-1次迭代的全局模型参数;所述k为正整数;
处理模块,用于将所述全局模型参数作为设有正则化约束的本地模型的参数,使用本地数据进行第k次迭代训练,得到第k次迭代训练的本地模型参数;所述正则化约束是根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
将所述第k次迭代训练的本地模型参数发送给所述服务器,以使所述服务器更新第k次迭代的所述全局模型参数。
可选的,所述处理模块具体用于:
对所述全局模型参数与所述本地模型参数的差值进行f范数计算,得到所述正则化约束。
可选的,所述处理模块具体用于:
根据下述公式(1)确定出所述本地模型参数的最终损失函数;
其中,
可选的,所述处理模块具体用于:
根据下述公式(2)确定出所述第k次迭代的本地模型参数;
其中,wk(i)为所述第k次迭代的本地模型参数,wk-1(i)为第k-1次迭代的本地模型参数,αk为第k次迭代的学习率,
第四方面,本发明实施例提供一种联邦学习模型的训练装置,包括:
获取单元,用于获取多个客户端发送的第k次迭代的本地模型参数;所述第k次迭代的本地模型参数是客户端使用设有正则化约束的本地模型参数对第k-1次迭代的全局模型参数进行训练得到的;所述正则化约束是所述客户端根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
处理单元,根据所述第k次迭代的本地模型参数与k-1次迭代的全局模型参数,确定出第k次迭代的全局模型参数;
将所述第k次迭代的全局模型参数广播给所述多个客户端,以使多个客户端进行第k+1次迭代训练。
可选的,所述处理单元具体用于:
根据下述公式(3),确定出所述第k次迭代的全局模型参数;
其中,wk(0)为所述第k次迭代的全局模型参数,wk-1(0)为第k-1次迭代的所述全局模型参数,αk为第k次迭代的学习率,wk(i)为第i个客户端的第k次迭代的本地模型参数。
第五方面,本发明实施例还提供一种计算设备,包括:
存储器,用于存储程序指令;
处理器,用于调用所述存储器中存储的程序指令,按照获得的程序执行上述联邦学习模型的训练方法。
第六方面,本发明实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行上述联邦学习模型的训练方法。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简要介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种系统架构示意图;
图2为本发明实施例提供的一种联邦学习模型的训练方法的流程示意图;
图3为本发明实施例提供的一种联邦学习模型的训练方法的流程示意图;
图4为本发明实施例提供的一种联邦学习模型的训练装置的结构示意图;
图5为本发明实施例提供的一种联邦学习模型的训练装置的结构示意图。
具体实施方式
为了使本发明的目的、技术方案和优点更加清楚,下面将结合附图对本发明作进一步地详细描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
图1示例性的示出了本发明实施例所适用的一种系统架构,该系统架构包括服务器100,客户端200。
其中,服务器100用于与客户端200进行连接,并向客户端200发送第k-1次迭代的全局模型参数,需要说明的是,图1中仅示例性的示出了一个客户端200,实际可以为多个客户端200,在这里不做限定。
客户端200用于获取服务器100发送的第k-1次迭代的全局模型参数,并将第k-1次迭代的全局模型参数作为设有正则化约束的本地模型参数,然后使用本地数据进行训练,得到第k次迭代训练的本地模型参数,再将第k次迭代训练的本地模型参数发送至服务器100,以使服务器100更新第k次迭代的全局模型参数。
需要说明的是,上述图1所示的结构仅是一种示例,本发明实施例对此不做限定。
基于上述描述,图2示例性的示出了本发明实施例提供的一种联邦学习模型的训练方法的流程,该流程可由联邦学习模型的训练装置执行。
如图2所示,该流程具体包括:
步骤201,客户端获取服务器广播的第k-1次迭代的全局模型参数;所述k为正整数。
本发明实施例,客户端在获取服务器广播的第k-1次迭代的全局模型参数之前,会将第k-1次迭代的本地模型参数发送至服务器,以使服务器进行第k-1次迭代的全局模型参数之后,广播第k-1次迭代的全局模型参数。
步骤202,所述客户端将所述全局模型参数作为设有正则化约束的本地模型参数,使用本地数据进行第k次迭代训练,得到第k次迭代训练的本地模型参数;所述正则化约束是根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的。
本发明实施例,客户端在获取k=1时的全局模型参数后,将获取到的全局模型参数作为本地模型参数,并设置正则化约束,以使本次模型收敛之前,全局模型参数和本地模型参数设置有正则化约束,从而得到本地模型参数中的最终损失函数,减小极端数据对模型训练的影响。
进一步地,正则化约束是根据服务器的全局模型参数和客户端的本地模型参数确定的,包括:对所述全局模型参数与所述本地模型参数的差值进行范数计算,得到所述正则化约束。
通过计算服务器发送的全局模型参数中的权重和客户端的本地模型参数中的权重的f范数来得到正则化约束,其中,f范数是一种矩阵范数,指的是矩阵的每个元素的平方和的开方,具体的根据下述公式(4)确定出正则化约束。
其中,jt-s(w(i))为正则化约束,w(0)为全局模型参数中的权重,w(i)为本地模型参数中的权重。
进一步地,根据下述公式(1)确定出本地模型参数的最终损失函数;
其中,
客户端获取第k-1次迭代的全局模型参数后,计算正则化约束,将设置有正则化约束的全局模型参数作为本地模型参数,然后进行模型训练,得到本地模型参数的最终损失函数,进而得到本地模型参数。
通过上述算法,对于极端的数据分布相对于本地模型参数的原本的损失函数(w(i))影响较大,但根据正则化约束的计算,当w(0)不变时,w(i)增大,则正则化约束(jt-s(w(i)))减小,致使本地模型的最终损失函数
根据下述公式(2)确定出所述第k次迭代训练的本地模型参数;
其中,wk(i)为所述第k次迭代的本地模型参数,wk-1(i)为第k-1次迭代的本地模型参数,αk为第k次迭代的学习率,
通过对本地模型参数的最终损失函数进行求导,再与k-1次迭代的本地模型参数相加求和,得到更新的第k次迭代的本地模型参数。
为了更好的解释上述技术方案,下面在具体的实例中进行描述。
实例1
初始化非独立同分布的数据,作为模型训练的训练集,包括下述两种方式。
1、将数据按照数字标签排序,然后将数据分成多份,如十份,每个客户端将持有多种的数字标签的数据,如两种,例如,将数据分成十份后,进行标签排序,1-10,然后取1和8对第1个客户端进行模型训练,取9和7对第2个客户端进行模型训练,从而使每个客户端的数据都不能作为全局数据分布的代表。
2、使用基准数据集,将数据分成十份,使每个客户端上的数据数量相差较大,使每个客户端的数据都不能作为全局数据分布的代表。
选择10个客户端使用上述数据进行模型训练,10个客户端获取服务器广播的第k-1次的全局模型参数wk-1(0),根据最小化最终损失函数,得到第k次迭代的本地模型参数:
步骤203,客户端将所述第k次迭代训练的本地模型参数发送给所述服务器,以使所述服务器更新第k次迭代的所述全局模型参数。
本发明实施例,多个客户端将训练后得到的第k次迭代的本地模型参数发送给服务器,以使服务器获取整个数据集对应的多个第k次迭代的本地模型参数,进而更新第k次迭代的全局模型参数,进行下一次广播第k次迭代的全局模型参数。
在本发明实施例中,通过将获取的k-1次迭代的全局模型参数设置正则化约束,作为本地模型参数,以使本地模型参数进行训练时,减小本地数据中极端数据对本地模型参数的训练结果的影响,提高本地模型参数对非独立同分布数据训练的准确率,得到第k次迭代训练的本地模型,再将第k次迭代训练的本地模型参数发送给服务器,以提高全局模型参数对非独立同分布数据训练的准确率。
图2示例性的示出了本发明实施例提供的一种联邦学习模型的训练方法的流程。
如图3所示,具体流程包括:
步骤301,服务器获取多个客户端发送的第k次迭代的本地模型参数;所述第k次迭代的本地模型参数是客户端使用设有正则化约束的本地模型参数对第k-1次迭代的全局模型参数进行训练得到的;所述正则化约束是所述客户端根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的。
本发明实施例,服务器获取多个客户端的发送的第k次迭代的本地模型参数,其中本地模型参数设置有正则化约束。
步骤302,所述服务器根据所述第k次迭代的本地模型参数与k-1次迭代的全局模型参数,确定出第k次迭代的全局模型参数。
本发明实施例,服务器将获取到的多个客户端的第k次迭代的本地模型参数与k-1次迭代的全局模型参数求差值,再将所求的差值求和,然后根据学习率与第k-1次迭代的全局模型参数,得到第k次迭代的全局模型参数,
进一步地,根据下述公式(3),确定出第k次迭代的全局模型参数;
其中,wk(0)为第k次迭代的所述全局模型参数,wk-1(0)为第k-1次迭代的所述全局模型参数,αk为第k次迭代的学习率,wk(i)为第i个客户端的第k次迭代的本地模型参数。
相对于传统的联邦学习中求均值的方法,本发明实施例通过将所有客户端的第k次迭代的本地模型参数与第k-1次迭代的全局模型参数之间的差值求和,将和作为步长,与第k-1次迭代的全局模型参数相加,作为更新的第k次迭代的全局模型参数。
结合上述图2中实例1,下面在具体实例中描述上述技术方案。
实例2
获取10个客户端发送的第k次迭代的本地模型参数,然后将所有客户端的第k次迭代的本地模型参数与第k-1次迭代的全局模型参数之间的差值求和,得到和为:
步骤303,所述服务器将所述第k次迭代的全局模型参数广播给所述多个客户端,以使多个客户端进行第k+1次迭代训练。
本发明实施例,服务器将第k次迭代的全局模型参数广播给多个客户端,客户端不需要再重新设置正则化约束,直接使多个客户端进行第k+1次迭代训练,并获取多个客户端发送的第k+1次迭代的本地训练模型参数,以进行更新第k+1次迭代的全局模型参数。
基于相同的技术构思,图4示例性的示出了本发明实施例提供的一种联邦学习模型的训练装置的结构,该装置可以执行图2中的联邦学习模型的训练方法的流程。
如图4所示,该装置具体包括:
获取模块401,用于获取服务器广播的第k-1次迭代的全局模型参数;所述k为正整数;
处理模块402,用于将所述全局模型参数作为设有正则化约束的本地模型的参数,使用本地数据进行第k次迭代训练,得到第k次迭代训练的本地模型参数;所述正则化约束是根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
将所述第k次迭代训练的本地模型参数发送给所述服务器,以使所述服务器更新第k次迭代的所述全局模型参数。
可选的,所述处理模块402具体用于:
对所述全局模型参数与所述本地模型参数的差值进行f范数计算,得到所述正则化约束。
可选的,所述处理模块402具体用于:
根据下述公式(1)确定出所述本地模型参数的最终损失函数;
其中,
可选的,所述处理模块402具体用于:
根据下述公式(2)确定出所述第k次迭代训练的本地模型参数;
其中,wk(i)为所述第k次迭代的本地模型参数,wk-1(i)为第k-1次迭代的本地模型参数,αk为第k次迭代的学习率,
图5示例性的示出了本发明实施例提供的一种联邦学习模型的训练装置的结构,该装置可以执行图3中的联邦学习模型的训练方法的流程。如图5所示,该装置具体包括:
获取单元501,用于获取多个客户端发送的第k次迭代的本地模型参数;所述第k次迭代的本地模型参数是客户端使用设有正则化约束的本地模型参数对第k-1次迭代的全局模型参数进行训练得到的;所述正则化约束是所述客户端根据所述服务器的全局模型参数和所述客户端的本地模型参数确定的;
处理单元502,根据所述第k次迭代的本地模型参数与k-1次迭代的全局模型参数,确定出第k次迭代的全局模型参数;
将所述第k次迭代的全局模型参数广播给所述多个客户端,以使多个客户端进行第k+1次迭代训练。
可选的,所述处理单元502具体用于:
根据下述公式(3),确定出所述第k次迭代的全局模型参数;
其中,wk(0)为所述第k次迭代的全局模型参数,wk-1(0)为第k-1次迭代的所述全局模型参数,αk为第k次迭代的学习率,wk(i)为第i个客户端的第k次迭代的本地模型参数。
基于相同的技术构思,本发明实施例还提供一种计算设备,包括:
存储器,用于存储程序指令;
处理器,用于调用所述存储器中存储的程序指令,按照获得的程序执行上述联邦学习模型的训练方法。
基于相同的技术构思,本发明实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行上述联邦学习模型的训练方法。
本领域内的技术人员应明白,本申请的实施例可提供为方法、系统、或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、cd-rom、光学存储器等)上实施的计算机程序产品的形式。
本申请是参照根据本申请的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
显然,本领域的技术人员可以对本申请进行各种改动和变型而不脱离本申请的精神和范围。这样,倘若本申请的这些修改和变型属于本申请权利要求及其等同技术的范围之内,则本申请也意图包含这些改动和变型在内。