一种基于联邦学习的模型训练方法与流程

文档序号:20265097发布日期:2020-04-03 18:17阅读:493来源:国知局
一种基于联邦学习的模型训练方法与流程

本说明书实施例涉及信息技术领域,尤其涉及一种基于联邦学习的模型训练方法。



背景技术:

联邦学习(federatedmachinelearning/federatedlearning),是指一种机器学习框架,能有效帮助多个节点(可以代表个人或机构)在满足数据隐私保护的要求下,联合训练模型。

在联邦学习框架下,服务端下发模型参数给多个节点,每个节点将本地的训练样本输入模型进行一次训练,本次训练结束后,每个节点会基于本次训练结果计算得到的梯度。随后,服务端基于安全聚合(sa,secureaggregation)协议,可以计算得到各节点的梯度之和。值得强调的是,服务端收到sa协议的限制,并不能获得单个节点上传的梯度。

如此,既可以使得服务端根据各节点上传的梯度之和调整模型参数,又可以实现节点的数据隐私保护。

然而,在有些场景下,模型参数也不适合暴露给节点。



技术实现要素:

为了解决联邦学习框架下存在的模型参数难以保护的问题,本说明书实施例提供一种基于联邦学习的模型训练方法,技术方案如下:

根据本说明书实施例的第1方面,提供一种基于联邦学习的模型训练方法,应用于包括服务端与n个节点的联邦学习系统,n>1,所述方法包括:

在模型训练的第i次迭代中,执行:

所述服务端基于同态加密算法e对模型参数集合θ进行加密,得到e(θ),并将e(θ)下发给mi个节点;其中,mi≤n,所述mi个节点中存在qi个目标类型节点;

第j个目标类型节点根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij);其中,j=(1,2,…,qi);

第j个目标类型节点确定随机数rj,并计算e(wij)-e(rij),得到e(sij),以及,向所述服务端上传e(sij);

所述服务端根据e(sij)计算并基于安全聚合sa协议,计算

所述服务端计算得到并基于更新θ。

根据本说明书实施例的第2方面,提供一种联邦学习系统,包括服务端与n个节点,n>1;

所述服务端,在模型训练的第i次迭代中,基于同态加密算法e对模型参数集合θ进行加密,得到e(θ),并将e(θ)下发给mi个节点,其中,mi≤n,所述mi个节点中存在qi个目标类型节点;

第j个目标类型节点,根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij),其中,j=(1,2,…,qi);确定随机数rij,并计算e(wij)-e(rij),得到e(sij),以及,向所述服务端上传e(sij);

所述服务端,还根据e(sij)计算并基于安全聚合sa协议,计算计算得到并基于更新θ。

本说明书实施例所提供的技术方案,服务端采用同态加密算法对模型参数集合进行加密后下发给节点,节点基于同态加密原理,使用加密后的模型参数与本地训练样本进行加密状态下的模型计算,得到加密梯度。随后,节点基于同态加密原理,计算加密梯度与加密随机数的差,这个差实质上是加密的某个无意义的值。接着,节点将加密后的值上传给服务端。此外,服务端可以利用sa协议,在不获知每个节点上的随机数的前提下,获知各节点上的随机数之和。如此,服务端就可以根据每个节点上传的加密后的值与各随机数之和来还原出每个节点产生的梯度之和,从而可以更新模型参数,以便进入下一次迭代或者完成训练。

通过本说明书实施例,可以在基于联邦学习框架实现服务端与各节点联合训练模型的前提下,实现服务端向节点隐藏模型参数,从而避免节点根据模型参数破解模型。

应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本说明书实施例。

此外,本说明书实施例中的任一实施例并不需要达到上述的全部效果。

附图说明

为了更清楚地说明本说明书实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本说明书实施例中记载的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的附图。

图1是本说明书实施例提供的一种基于联邦学习的模型训练方法的流程示意图;

图2是本说明书实施例提供的一种基于联邦学习的模型训练方法的原理示意图;

图3是本说明书实施例提供的一种联邦学习系统中的服务端的结构示意图;

图4是本说明书实施例提供的一种联邦学习系统中的节点的结构示意图;

图5是用于配置本说明书实施例方法的一种设备的结构示意图。

具体实施方式

在联邦学习框架下,通常由服务端负责根据节点上传的梯度更新模型参数,并将模型参数下发给节点,由节点基于模型参数与本地训练样本计算梯度。为了防止服务端根据节点上传的梯度推断出节点的本地训练样本,一般基于sa协议来实现节点将梯度上传给服务端,使得服务端仅会获取到各节点上传的梯度之和,却无法获取到单个节点上传的梯度。

可见,在现有的联邦学习架构下,节点可以向服务端隐藏本地训练样本,而服务端却不会向节点隐藏模型参数。

然而,在有些场景下,服务端也并不想将隐私(即模型参数)暴露给节点。例如,假设需要基于联邦学习架构训练诈骗交易识别模型,服务端一方面需要各节点提供诈骗交易作为样本训练该模型,另一方面也不希望节点获知模型参数,否则,模型参数容易暴露给恶意分子,导致基于该模型构筑的诈骗防线容易被攻破。

为了解决上述问题,在本说明书实施例中,基于同态加密算法,对模型参数进行加密后下发到节点,节点基于加密后的模型参数与本地训练样本计算梯度,基于同态加密原理,计算出的实际上是加密后的梯度。

此处需要说明,一方面,sa协议并不支持对加密后的数据进行上传,另一方面,如果节点直接将加密后的梯度上传,则服务端可以直接解密获得节点的梯度明文,反而造成节点的隐私泄露。

因此,在本说明书实施例中,并不会让节点直接将加密后的梯度上传给服务端,而是做如下处理:

1、节点确定一个随机数,基于同态加密算法(记为e)对随机数进行加密,然后计算加密后的梯度与加密后的随机数的差值,记为e(s)。将加密后的梯度记为e(w),w为梯度,将加密后的随机数记为e(r),r为随机数,可以理解,基于同态加密原理,如果e(w)-e(r)=e(s),则w-r=s。

2、节点将e(s)上传给服务端。如此,服务端即便对e(s)解密得到s,由于s是无意义的值,服务端也无法获知节点的隐私。

3、服务端基于sa协议,获取各节点的随机数之和。由于sa协议可以用于将至少两个节点的明文数据上传给服务端,并且确保服务端只能获知各节点上传的明文数据之和,而不能获知单个节点的明文数据。

4、服务端获取每个节点上传的e(s)之后,可以确定各节点上的s之和,将s之和与随机数之和相加,就能得到各节点上的梯度之和,以便对模型参数进行更新。

通过本说明书书实施例,服务端可以在基于联邦学习框架实现服务端与各节点联合训练模型的前提下,实现服务端向节点隐藏模型参数,从而避免节点根据模型参数破解模型。

为了使本领域技术人员更好地理解本说明书实施例中的技术方案,下面将结合本说明书实施例中的附图,对本说明书实施例中的技术方案进行详细地描述,显然,所描述的实施例仅仅是本说明书的一部分实施例,而不是全部的实施例。基于本说明书中的实施例,本领域普通技术人员所获得的所有其他实施例,都应当属于保护的范围。

以下结合附图,详细说明本说明书各实施例提供的技术方案。

图1是本说明书实施例提供的一种基于联邦学习的模型训练方法的流程示意图,包括以下步骤:

s100:服务端基于同态加密算法e对模型参数集合θ进行加密,得到e(θ)。

众所周知,在机器学习领域,一般采用迭代调参的方式来训练模型。步骤s100~s114是训练模型过程中的一次迭代,可以理解,训练模型的过程,实际上是循环执行步骤s100~s114的过程,当模型参数被更新到满足训练停止条件时,就会停止循环。

具体地,可以将训练停止条件设定为:循环执行步骤s100~s114的次数达到指定次数g,或者,一次迭代的损失函数值小于指定值。

为了描述的方便,本文将s100~s114视为第i次迭代执行的步骤。可以理解,如果训练停止条件为循环次数达到指定次数g,则i=(1,2,…,g)。

图1所示的方法应用于联邦学习系统,联邦学习系统包括服务端与n个节点(即节点设备),其中,n大于1。

在本文中,为了描述的方便,将同态加密算法记为e,经过同态加密算法加密过的数据记为e(*),*代表被加密的数据。此外,还将模型的模型参数集合记为θ。

s102:服务端将e(θ)下发给mi个节点。

本说明书实施例的应用场景主要有两类,一类是服务端tob场景(服务端与至少两个机构进行联合学习),另一类是服务端toc场景(服务端与至少两个个人用户进行联合学习)。

在服务端tob场景下,节点的数量并不多,在每次迭代中,服务端可以将e(θ)下发给每个节点进行计算。

在服务端toc场景下,一般会有海量的个人用户参与训练,节点的数量很大,因此,在每次迭代中,服务端为了避免数据处理压力过大,可以选择部分节点下发e(θ),仅根据这部分节点反馈的训练效果来更新模型参数集合。

需要说明的是,服务端在每次迭代中选择的节点可以不同,选择的节点数量也可以不同。为了描述的方便,将第i次迭代中选择的节点数量记为mi。

还需要说明的是,mi可以小于n,也可以等于n。

此外,由于在后续的步骤s110中,服务端需要基于sa协议获取节点上传的数据,而sa协议中使用了秘密共享技术。对于秘密共享技术,其用于实现在l个节点中秘密共享数据,其需要满足l个节点中t个节点在线。因此,一般需要满足mi大于等于ti,ti为:以在所述mi个节点中实现秘密共享为目的,所述mi个节点中处于在线状态的节点的数量的下限值,也就是第i次迭代中,sa协议指定的mi个节点中处于在线状态的节点的数量的下限值。

s104:第j个目标类型节点根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij)。

在本说明书实施例中,在应用于服务端toc的场景的情况下,服务端向mi个节点下发e(θ)之后,由于个人用户的节点设备并不一定总是在线(即不一定总是可以连接到网络,与服务端或其他节点进行数据交互),而如果某个节点不在线,则服务端并不能获取到该节点反馈的训练效果,因此,针对所述mi个节点中的任一节点,如果该节点在接收到e(θ)之后,继续处于在线状态直至第i次迭代结束,则该节点的训练效果才能反馈给服务端。本文为了描述的方便,将所述mi个节点中能够向服务端反馈训练效果的节点称为目标类型节点。

图1所示的方法流程中,只描述了一次迭代中每个目标类型节点执行的操作,并没有描述非目标类型节点执行的操作。然而,可以理解,对于所述mi个节点中的非目标类型节点,其在接收到e(θ)之后,也可以执行类似于s104~s106的操作,只不过无法将执行结果上传给服务端。

此外,在本说明书实施例中,在应用于服务端tob的场景的情况下,由于每个机构的节点设备通常是稳定在线的,因此,所述mi个节点可以都属于目标类型节点。

在本说明书实施例中,为了描述的方便,假设所述mi个节点中存在qi个目标类型节点,qi≤mi,并且,针对所述qi个目标类型节点中第j个目标类型节点进行描述。其中,j=(1,2,…,qi),可以理解,wij是第j个目标类型节点计算得到的梯度,e(wij)是第j个目标类型节点计算得到的梯度的加密值。

在本说明书实施例中,由于e(θ)是加密的模型参数,因此,目标类型节点仅能根据e(θ)与本地训练样本进行加密状态下的模型计算。此处需要说明的是,由于同态加密算法通常是加性的(主要包括加法与数乘),因此,为了使得根据e(θ)与本地训练样本进行加密状态下的模型计算后得到的结果尽可能保留本地训练样本的特征信息,训练模型所采用的机器学习算法通常是线性性质的。例如,可以采用线性回归的机器学习算法训练线性回归模型,也可以采用神经网络算法(激活函数经过线性转换)训练神经网络模型。

此处以线性回归算法为例进行说明。

假设线性回归模型的模型参数集合θ=(θ0,θ1,…θk),样本可以记为其中,实际上是样本的特征向量,线性映射函数为:梯度为:

对模型参数集合进行加密后,得到e(θ)=[e(θ0),e(θ1),…e(θk)],加密状态的映射函数为其中,同态求和符号上省略了b的取值范围(1,k)。

加密梯度为:

s106:第j个目标类型节点确定随机数rij,并计算e(wij)-e(rij),得到e(sij)。

在本说明书实施例中,为了描述的方便,将第j个目标类型节点确定的随机数记为rij。此外,对于同一节点而言,可以在每次迭代中都重新生成随机数,也可以预先生成随机数,在每次迭代中复用该随机数。

基于同态加密原理,e(wij)-e(rij)的计算结果事实上是加密的某个无意义的值,为了描述的方便,将这个无意义的值记为sij,加密的sij即是e(sij)。显然,wij、rij、sij满足以下关系:wij=rij+sij。

s108:第j个目标类型节点向所述服务端上传e(sij)。

此处值得强调的是,即便服务端可以解密e(sij)获得sij,服务端也无法从sij中获取有价值的信息。

s110:服务端根据e(sij)计算并基于安全聚合sa协议,计算

在本说明书实施例中,服务端可以对e(sij)进行解密,得到sij,进而得到

或者,服务端可以计算进而解密得到

此外,在服务端toc的场景下,还需要判断目标类型节点的数量qi是否达到ti,如果qi小于ti,则各目标类型节点之间将无法基于sa协议向服务端反馈数据。

具体而言,所述服务端若确定qi≥ti,则根据e(sj)计算所述服务端若确定qi<ti,则停止本次迭代,并进入下一次迭代。

s112:服务端计算得到

s114:服务端基于更新θ。

假设本说明书实施例中,为梯度下降法指定的学习率为α,第i次迭代中使用的样本总数为d,则可以采用如下公式更新θ,得到更新后的θ(记为θ):

图2是本说明书实施例提供的一种基于联邦学习的模型训练方法的原理示意图。在图2中,示出了包括服务端与3个节点的联邦学习系统。服务端首先对模型参数集合进行加密,得到e(θ),然后将e(θ)下发给各节点。图2仅标注了服务端与节点2之间的交互,可以理解,服务端与节点1、节点3都会进行同样的交互。随后,服务端基于sa协议,获取各节点的随机数之后,并进而计算得到各节点的梯度之和,以便更新模型参数集合。

此外,本说明书实施例中的模型对应的输入数据可以包括以下一种:图像、文本、语音。也即,模型训练过程中使用的训练样本可以是图像,可以是文本,也可以语音。模型训练完成后,可以相应地用于对图像、文本或语音进行处理。

进一步地,所述文本可以包含实体对象信息。其中,实体对象可以是用户、商户等对象。

还需要说明的是,本文所述的用于处理图像的模型例如可以是图像分类模型、图像分割模型等,本文所述的用于处理文本的模型例如可以是机器人客服模型、实体对象风险识别模型、推荐模型等,本文所述的用于处理语音的模型可以是语音助手模型、语音识别模型等。

一种联邦学习系统,包括服务端与n个节点,n>1;

所述服务端,在模型训练的第i次迭代中,基于同态加密算法e对模型参数集合θ进行加密,得到e(θ),并将e(θ)下发给mi个节点,其中,mi≤n,所述mi个节点中存在qi个目标类型节点;

第j个目标类型节点,根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij),其中,j=(1,2,…,qi);确定随机数rij,并计算e(wij)-e(rij),得到e(sij),以及,向所述服务端上传e(sij);

所述服务端,还根据e(sij)计算并基于安全聚合sa协议,计算计算得到并基于更新θ。

图3是本说明书实施例提供的一种联邦学习系统中的服务端的结构示意图,应用于模型训练的第i次迭代中,所述联邦学习系统还包括n个节点,n>1;

所述服务端包括:

模型参数加密模块301,基于同态加密算法e对模型参数集合θ进行加密,得到e(θ);

模型参数下发模块302,将e(θ)下发给mi个节点,其中,mi≤n,所述mi个节点中存在qi个目标类型节点,以使第j个目标类型节点根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij),其中,j=(1,2,…,qi);进而使第j个目标类型节点确定随机数rij,并计算e(wij)-e(rij),得到e(sij),以及,向所述服务端上传e(sij);

第一计算模块303,根据e(sij)计算并基于安全聚合sa协议,计算

第二计算模块304,计算得到

模型参数更新模块305,基于更新θ。

图4是本说明书实施例提供的一种联邦学习系统中的节点的结构示意图,所述联邦学习系统包括n个节点与服务端,n>1;

在模型训练的第i次迭代中,第j个目标类型节点包括:

模型计算模块401,根据e(θ)与本地训练样本,进行加密状态下的模型计算,得到加密梯度e(wij);

随机数确定模块402,确定随机数rij;

上传模块403,计算e(wij)-e(rij),得到e(sij),以及,向所述服务端上传e(sij);

其中:

所述服务端基于同态加密算法e对模型参数集合θ进行加密,得到e(θ),并将e(θ)下发给mi个节点;mi≤n,所述mi个节点中存在qi个目标类型节点,j=(1,2,…,qi);所述服务端根据e(sij)计算并基于安全聚合sa协议,计算所述服务端计算得到并基于更新θ。

本说明书实施例还提供一种计算机设备,其至少包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其中,处理器执行所述程序时实现本说明书中的服务端或目标类型节点的方法。

图5示出了本说明书实施例所提供的一种更为具体的计算设备硬件结构示意图,该设备可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。

处理器1010可以采用通用的cpu(centralprocessingunit,中央处理器)、微处理器、应用专用集成电路(applicationspecificintegratedcircuit,asic)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本说明书实施例所提供的技术方案。

存储器1020可以采用rom(readonlymemory,只读存储器)、ram(randomaccessmemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。

输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。

通信接口1040用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如usb、网线等)实现通信,也可以通过无线方式(例如移动网络、wifi、蓝牙等)实现通信。

总线1050包括一通路,在设备的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。

需要说明的是,尽管上述设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本说明书实施例方案所必需的组件,而不必包含图中所示的全部组件。

本说明书实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现本说明书中的服务端或目标类型节点的方法。

计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(pram)、静态随机存取存储器(sram)、动态随机存取存储器(dram)、其他类型的随机存取存储器(ram)、只读存储器(rom)、电可擦除可编程只读存储器(eeprom)、快闪记忆体或其他内存技术、只读光盘只读存储器(cd-rom)、数字多功能光盘(dvd)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitorymedia),如调制的数据信号和载波。

通过以上的实施方式的描述可知,本领域的技术人员可以清楚地了解到本说明书实施例可借助软件加必需的通用硬件平台的方式来实现。基于这样的理解,本说明书实施例的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在存储介质中,如rom/ram、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务设备,或者网络设备等)执行本说明书实施例各个实施例或者实施例的某些部分所述的方法。

上述实施例阐明的系统、方法、模块或单元,具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现。一种典型的实现设备为计算机,计算机的具体形式可以是个人计算机、膝上型计算机、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件收发设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任意几种设备的组合。

本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的模块可以是或者也可以不是物理上分开的,在实施本说明书实施例方案时可以把各模块的功能在同一个或多个软件和/或硬件中实现。也可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。

以上所述仅是本说明书实施例的具体实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本说明书实施例原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本说明书实施例的保护范围。

当前第1页1 2 3 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1