本技术涉及模型训练,尤其涉及一种联邦学习模型训练方法、装置、设备及介质。
背景技术:
1、基于联邦学习进行模型训练时,可以在数据不离开参与方本地的情况下,实现各参与方的联合建模。训练好的联邦学习模型(也可称为全局模型)可以在各参与方之间共享和部署。联邦学习在智慧医疗、金融保险和智能物联网等诸多领域有广泛的应用前景。
2、然而,联邦学习面临着数据安全问题的严峻挑战。例如,参与方与中央服务器之间通常只进行模型参数更新的通信,无法分辨参与方设备发送的模型参数是否为来自恶意参与方的具有一定安全威胁的模型参数、参与方的模型参数存在隐私泄露风险等。
3、因此,亟需一种可以提高联邦学习过程的安全性的技术方案。
技术实现思路
1、本技术提供了一种联邦学习模型训练方法、装置、设备及介质,用以提高联邦学习过程的安全性。
2、第一方面,本技术提供了一种联邦学习模型训练方法,应用于中央服务器,所述方法包括:
3、在参与对联邦学习模型的每轮迭代训练过程中,至少执行以下步骤:
4、接收每个参与方设备发送的下一轮的加密本地模型参数;
5、基于每个所述加密本地模型参数以及预设的参数计算算法,分别获得每个参与方的所述加密本地模型参数的l2范数;
6、针对所述每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全;
7、基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数。
8、在一种可能的实施方式中,所述基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全,包括:
9、判断该参与方的l2范数与选取的基准l2范数之间的差值是否超过预设差值阈值,其中所述基准l2范数是从各参与方的l2范数中选取的;或者,
10、基于该参与方的l2范数以及预设的余弦相似度计算算法,确定该参与方的加密本地模型参数与确定的聚合参数之间的相似度,基于所述相似度,判断该参与方的加密本地模型参数是否安全,其中,所述聚合参数是基于各参与方的加密本地模型参数确定的。
11、在一种可能的实施方式中,所述基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全,包括:
12、判断该参与方的l2范数与选取的基准l2范数之间的差值是否超过预设差值阈值,其中所述基准l2范数是从各参与方的l2范数中选取的;
13、若是,则将该参与方确定为不安全参与方,将该参与方的相应数据剔除;
14、若否,则基于该参与方的l2范数以及预设的余弦相似度计算算法,确定该参与方的加密本地模型参数与确定的聚合参数之间的相似度,基于所述相似度,判断该参与方的加密本地模型参数是否安全,其中,所述聚合参数是基于未被剔除数据的各参与方的加密本地模型参数确定的。
15、在一种可能的实施方式中,所述从各参与方的l2范数中选取基准l2范数,包括:
16、对各参与方的l2范数进行排序,将所述排序中的中位数,确定为所述基准l2范数。
17、在一种可能的实施方式中,所述基于所述相似度,判断该参与方的加密本地模型参数是否安全,包括:
18、从各参与方对应的所述相似度中,选取基准相似度;
19、确定除所述基准相似度之外的每个其他相似度与所述基准相似度之间的相似度差值,并确定各相似度差值的平均差值;
20、针对所述每个参与方,判断该参与方对应的所述相似度与所述基准相似度之间的差值是否超过所述平均差值。
21、在一种可能的实施方式中,所述从各参与方对应的所述相似度中,选取基准相似度,包括:
22、将各参与方对应的所述相似度中的最大相似度,确定为所述基准相似度。
23、在一种可能的实施方式中,所述基于该参与方的l2范数以及预设的余弦相似度计算算法,确定该参与方的加密本地模型参数与确定的聚合参数之间的相似度,包括:
24、确定所述聚合参数的l2范数;
25、若所述加密本地模型参数为采用随机数序列干扰的模型参数,将所述聚合参数发送给每个参与方设备,使每个参与方设备采用随机数序列对所述聚合参数进行干扰;接收每个参与方设备发送的干扰后的聚合参数;
26、针对每个参与方,基于该参与方的加密本地模型参数以及干扰后的聚合参数,获得该参与方的加密本地模型参数与所述聚合参数的第一乘积;并确定该参与方的l2范数与所述聚合参数的l2范数的第二乘积;基于所述第一乘积与所述第二乘积的比值,确定该参与方的加密本地模型参数与所述聚合参数之间的相似度。
27、在一种可能的实施方式中,所述基于该参与方的加密本地模型参数以及干扰后的聚合参数,获得该参与方的加密本地模型参数与所述聚合参数的第一乘积,包括:
28、确定该参与方的加密本地模型参数与干扰后的聚合参数的内积,将加密的所述内积发送给该参与方设备,使该参与方设备基于接收到内积以及预设的随机数干扰去除算法,获得所述加密本地模型参数与所述聚合参数的第一乘积;接收该参与方设备发送的所述第一乘积。
29、在一种可能的实施方式中,所述基于每个所述加密本地模型参数以及预设的参数计算算法,分别获得所述每个参与方的所述加密本地模型参数的l2范数,包括:
30、若所述加密本地模型参数为采用同态加密进行了加密且采用随机数序列进行了干扰,针对每个参与方,基于所述参数计算算法,获得该参与方的加密本地模型参数中每个子参数及相应随机数的平方和,将加密后的所述平方和发送给该参与方设备,使该参与方设备基于预设的同态加密算法,获得该参与方的加密本地模型参数的l2范数平方的密文;接收该参与方设备发送的所述密文,并基于所述密文,计算该参与方的所述加密本地模型参数的l2范数。
31、第二方面,本技术提供了一种联邦学习模型训练方法,应用于参与方设备,所述方法包括:
32、在参与对联邦学习模型的每轮迭代训练过程中,至少执行以下步骤:
33、采用预设加密算法对获得的下一轮的本地模型参数进行加密,获得加密本地模型参数;
34、将所述加密本地模型参数发送给中央服务器,使所述中央服务器基于每个参与方设备发送的所述加密本地模型参数以及预设的参数计算算法,分别获得每个参与方的所述加密本地模型参数的l2范数;并使所述中央服务器针对所述每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全;并使所述中央服务器基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数。
35、在一种可能的实施方式中,所述方法还包括:
36、接收所述中央服务器发送的聚合参数;
37、若所述加密本地模型参数为采用随机数序列干扰的模型参数,则采用随机数序列对所述聚合参数进行干扰,并将干扰后的聚合参数发送给所述中央服务器。
38、在一种可能的实施方式中,所述方法还包括:
39、接收所述中央服务器发送的所述加密本地模型参数与干扰后的聚合参数的加密的内积;
40、基于所述内积以及预设的随机数干扰去除算法,获得所述加密本地模型参数与所述聚合参数的第一乘积,将所述第一乘积发送给所述中央服务器。
41、在一种可能的实施方式中,所述方法还包括:
42、接收所述中央服务器发送的所述加密本地模型参数中每个子参数及相应随机数的加密的平方和;
43、若所述加密本地模型参数为采用同态加密进行了加密且采用随机数序列进行了干扰,则基于预设的同态加密算法,获得所述加密本地模型参数的l2范数平方的密文,并将所述密文发送给所述中央服务器。
44、第三方面,本技术还提供了一种联邦学习模型训练系统,所述系统包括:
45、每个参与方设备,用于在参与对联邦学习模型的每轮迭代训练过程中,至少执行以下步骤:采用预设加密算法对获得的下一轮的本地模型参数进行加密,获得加密本地模型参数;将所述加密本地模型参数发送给中央服务器;
46、所述中央服务器,用于基于每个所述加密本地模型参数以及预设的参数计算算法,分别获得所述每个参与方的所述加密本地模型参数的l2范数;针对所述每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全;基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数。
47、第四方面,本技术提供了一种联邦学习模型训练装置,应用于中央服务器,所述装置包括:
48、接收模块,用于在参与对联邦学习模型的每轮迭代训练过程中,接收每个参与方设备发送的下一轮的加密本地模型参数;
49、范数模块,用于基于每个所述加密本地模型参数以及预设的参数计算算法,分别获得所述每个参与方的所述加密本地模型参数的l2范数;
50、安全模块,用于针对所述每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全;
51、聚合模块,用于基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数。
52、在一种可能的实施方式中,所述安全模块,具体用于:
53、判断该参与方的l2范数与选取的基准l2范数之间的差值是否超过预设差值阈值,其中所述基准l2范数是从各参与方的l2范数中选取的;或者,
54、基于该参与方的l2范数以及预设的余弦相似度计算算法,确定该参与方的加密本地模型参数与确定的聚合参数之间的相似度,基于所述相似度,判断该参与方的加密本地模型参数是否安全,其中,所述聚合参数是基于各参与方的加密本地模型参数确定的。
55、在一种可能的实施方式中,所述安全模块,具体用于:
56、判断该参与方的l2范数与选取的基准l2范数之间的差值是否超过预设差值阈值,其中所述基准l2范数是从各参与方的l2范数中选取的;
57、若是,则将该参与方确定为不安全参与方,将该参与方的相应数据剔除;
58、若否,则基于该参与方的l2范数以及预设的余弦相似度计算算法,确定该参与方的加密本地模型参数与确定的聚合参数之间的相似度,基于所述相似度,判断该参与方的加密本地模型参数是否安全,其中,所述聚合参数是基于未被剔除数据的各参与方的加密本地模型参数确定的。
59、在一种可能的实施方式中,所述安全模块,具体用于:
60、对各参与方的l2范数进行排序,将所述排序中的中位数,确定为所述基准l2范数。
61、在一种可能的实施方式中,所述安全模块,具体用于:
62、从各参与方对应的所述相似度中,选取基准相似度;
63、确定除所述基准相似度之外的每个其他相似度与所述基准相似度之间的相似度差值,并确定各相似度差值的平均差值;
64、针对所述每个参与方,判断该参与方对应的所述相似度与所述基准相似度之间的差值是否超过所述平均差值。
65、在一种可能的实施方式中,所述安全模块,具体用于:
66、将各参与方对应的所述相似度中的最大相似度,确定为所述基准相似度。
67、在一种可能的实施方式中,所述安全模块,具体用于:
68、确定所述聚合参数的l2范数;
69、若所述加密本地模型参数为采用随机数序列干扰的模型参数,将所述聚合参数发送给每个参与方设备,使每个参与方设备采用随机数序列对所述聚合参数进行干扰;接收每个参与方设备发送的干扰后的聚合参数;
70、针对每个参与方,基于该参与方的加密本地模型参数以及干扰后的聚合参数,获得该参与方的加密本地模型参数与所述聚合参数的第一乘积;并确定该参与方的l2范数与所述聚合参数的l2范数的第二乘积;基于所述第一乘积与所述第二乘积的比值,确定该参与方的加密本地模型参数与所述聚合参数之间的相似度。
71、在一种可能的实施方式中,所述安全模块,具体用于:
72、确定该参与方的加密本地模型参数与干扰后的聚合参数的内积,将加密的所述内积发送给该参与方设备,使该参与方设备基于接收到内积以及预设的随机数干扰去除算法,获得所述加密本地模型参数与所述聚合参数的第一乘积;接收该参与方设备发送的所述第一乘积。
73、在一种可能的实施方式中,所述范数模块,具体用于:
74、若所述加密本地模型参数为采用同态加密进行了加密且采用随机数序列进行了干扰,针对每个参与方,基于所述参数计算算法,获得该参与方的加密本地模型参数中每个子参数及相应随机数的平方和,将加密后的所述平方和发送给该参与方设备,使该参与方设备基于预设的同态加密算法,获得该参与方的加密本地模型参数的l2范数平方的密文;接收该参与方设备发送的所述密文,并基于所述密文,计算该参与方的所述加密本地模型参数的l2范数。
75、第五方面,本技术提供了一种联邦学习模型训练装置,应用于参与方设备,所述装置包括:
76、加密模块,用于在参与对联邦学习模型的每轮迭代训练过程中,采用预设加密算法对获得的下一轮的本地模型参数进行加密,获得加密本地模型参数;
77、发送模块,用于将所述加密本地模型参数发送给中央服务器,使所述中央服务器基于每个参与方设备发送的所述加密本地模型参数以及预设的参数计算算法,分别获得所述每个参与方的所述加密本地模型参数的l2范数;并使所述中央服务器针对所述每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全;并使所述中央服务器基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数。
78、在一种可能的实施方式中,所述发送模块,还用于:
79、接收所述中央服务器发送的聚合参数;
80、若所述加密本地模型参数为采用随机数序列干扰的模型参数,则采用随机数序列对所述聚合参数进行干扰,并将干扰后的聚合参数发送给所述中央服务器。
81、在一种可能的实施方式中,所述发送模块,还用于:
82、接收所述中央服务器发送的所述加密本地模型参数与干扰后的聚合参数的加密的内积;
83、基于所述内积以及预设的随机数干扰去除算法,获得所述加密本地模型参数与所述聚合参数的第一乘积,将所述第一乘积发送给所述中央服务器。
84、在一种可能的实施方式中,所述发送模块,还用于:
85、接收所述中央服务器发送的所述加密本地模型参数中每个子参数及相应随机数的加密的平方和;
86、若所述加密本地模型参数为采用同态加密进行了加密且采用随机数序列进行了干扰,则基于预设的同态加密算法,获得所述加密本地模型参数的l2范数平方的密文,并将所述密文发送给所述中央服务器。
87、第六方面,本技术还提供了一种电子设备,所述电子设备至少包括处理器和存储器,所述处理器用于执行存储器中存储的计算机程序时实现如上述任一所述方法的步骤。
88、第七方面,本技术提供了一种计算机可读存储介质,其存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一所述方法的步骤。
89、由于本技术实施例中,参与方设备发送给中央服务器的本地模型参数为加密本地模型参数,相较于发送明文本地模型参数,发送加密本地模型参数可以防止由恶意的中央服务器等获取到参与方设备发送的明文模型参数,可以防止恶意的中央服务器等发起逆向推理攻击等,导致泄露参与方的数据隐私,从而可以保证参与方的数据安全。另外,由于本技术实施例还可以基于每个参与方的加密本地模型参数以及预设的参数计算算法,分别获得每个参与方的加密本地模型参数的l2范数;针对每个参与方,基于该参与方的l2范数以及各参与方的l2范数,判断该参与方的加密本地模型参数是否安全,从而可以分辨参与方设备发送的模型参数是否为来自恶意参与方的具有一定安全威胁的模型参数,进而可以基于得到的安全的加密本地模型参数,确定待训练的联邦学习模型的下一轮全局模型参数,从而可以提高联邦学习过程的安全性。