一种联邦学习方法、装置及电子设备与流程

文档序号:34023903发布日期:2023-05-05 06:02阅读:34来源:国知局
一种联邦学习方法、装置及电子设备与流程

本发明涉及计算机,特别涉及一种联邦学习方法、装置及电子设备。


背景技术:

1、联邦学习本质上是一种分布式机器学习框架,其核心思想是通过在多个拥有本地数据的数据源之间进行分布式模型训练,在不需要交换本地个体或样本数据的前提下,仅通过交互模型中间参数进行模型联合训练,原始数据可以不出本地。

2、然而,现有的联邦学习方法中,虽然保证了明文数据不出本地,但存在部分客户端的模型发生偏移、导致本地模拟效果较差的问题。


技术实现思路

1、有鉴于此,本发明提供了一种联邦学习方法、装置及电子设备,主要目的在于解决目前存在现有的联邦学习方法中容易造成客户端训练获得模型发生偏移、进而导致本地模拟效果较差的问题。

2、为解决上述问题,本技术提供一种联邦学习方法,包括:

3、向各客户端发送元模型以及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;

4、从各所述客户端中确定出用于进行模型再训练的、若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型的第一模型参数,所述n为正整数;

5、基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数;

6、将所述第二模型参数发送给各所述目标客户端,以使各所述目标客户端基于已接收的元模型、第二模型参数、客户端本地的训练数据以及目标补丁模型进行模型再训练获得当前第二模型;

7、判断各客户端训练获得的当前第二模型是否均符合训练条件,在各客户端训练获得的当前第二模型不符合训练条件时,重新确定出用于进行模型再训练的、若干目标客户端;在各客户端训练获得第二模型符合训练条件时停止训练。

8、可选的,在向各客户端发送元模型以及初始模型参数之前,所述方法还包括:

9、接收各客户端发送的各数据类型以及与各数据类型对应的各训练数据的数据标识;

10、基于各客户端发送的数据类型以及数据标识的数量,计算获得各数据类型之间数据标识总量的数据占比;

11、在进行每一轮模型训练之前,基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识,并将各所述目标数据标识发送给对应的客户端,以为各客户端重新分配用于进行模型训练的训练数据。

12、可选的,所述基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识,具体包括:

13、基于所述数据占比以及各客户端所包含的目标数据类型,分别确定与各客户端对应的目标数据占比;

14、基于各客户端的目标数据占比,从对应客户端所发送的各数据类型的数据标识中确定出若干目标数据标识。

15、可选的,所述基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数,具体包括:

16、基于各目标客户端第n轮训练获得的第一模型参数以及各目标客户端第n-1轮训练获得的历史模型参数,确定各目标客户端对应的梯度参数;

17、基于各目标客户端的梯度参数,采用预定的计算公式计算获得目标梯度参数;

18、基于所述目标梯度参数确定用于进行n+1轮模型训练的第二模型参数。

19、可选的,所述方法还包括:基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构,确定映射补丁模型、残差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型。

20、可选的,所述基于模型训练任务的任务类型、服务器数据与客户端数据的数据差异度以及服务器元模型的结构复杂度,确定映射补丁模型、残差补丁模型、内部补丁模型中的任意一种为所述目标补丁模型,具体包括:

21、在所述任务类型为监控任务或定位任务时,确定所述映射补丁模型为所述目标补丁模型;

22、在所述服务器数据与客户端数据的数据差异度大于预定差异度阈值时,确定所述残差补丁模型为所述目标补丁模型;

23、在服务器元模型的结构复杂度大于预定复杂度时,确定所述内部补丁模型为所述目标补丁模型。

24、可选的,所述映射补丁模型包括:映射网络以及激活层;

25、所述残差补丁模型包括:残差连接层;

26、所述内部补丁模型包括:卷积层以及激活层。

27、为解决上述问题,本技术提供一种联邦学习方法,应用于各客户端,包括:

28、接收服务端发送的元模型、初始模型参数,以基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;

29、接收服务端发送的第二模型参数,所述第二模型参数是由服务端基于若干目标客户端第n轮训练获得的第一模型参数所计算获得的;基于服务端发送的所述第二模型参数、已接收的元模型、客户端本地的训练数据以及目标补丁模型进行第n+1轮模型训练。

30、可选的,在接收服务端发送的元模型、初始模型参数之前,所述方法还包括:

31、将各数据类型以及与各数据类型对应的各训练数据的数据标识发送给服务端,以使服务端基于各客户端发送的数据类型以及数据标识的数量,计算获得各数据类型之间数据标识总量的数据占比;并使服务端在进行每一轮模型训练之前,基于所述数据占比以及各客户端所包含的目标数据类型,从各目标数据类型对应的若干数据标识中确定若干目标数据标识;

32、接收服务端发送的各目标数据标识,基于各所述目标数据标识从本地的训练数据中确定与目标数据标识对应的目标训练数据,以基于重新分配获得的所述目标训练数据进行模型训练。

33、可选的,所述目标补丁模型包括如下任意一种:映射补丁模型、残差补丁模型、内部补丁模型;

34、其中,所述映射补丁模型包括:映射网络以及激活层;

35、所述残差补丁模型包括:残差连接层;

36、所述内部补丁模型包括:卷积层以及激活层。

37、为解决上述问题,本技术提供一种联邦学习装置,包括:

38、第一发送模块,用于向各客户端发送元模型以及初始模型参数,以使各客户端基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练;

39、采集模块,用于从各所述客户端中确定出用于进行模型再训练的、若干目标客户端,并采集各目标客户端当前第n轮训练获得的第一模型的第一模型参数,所述n为正整数;

40、计算模块,用于基于各目标客户端的第一模型参数,采用预定的计算方式计算获得用于进行n+1轮模型训练的第二模型参数;

41、第二发送模块,用于将所述第二模型参数发送给各所述目标客户端,以使各所述目标客户端基于已接收的元模型、第二模型参数、客户端本地的训练数据以及目标补丁模型进行模型再训练,获得当前第二模型;

42、判断模块,用于判断各客户端训练获得的当前第二模型是否均符合训练条件,在各客户端训练获得的当前第二模型不符合训练条件时,基于所述采集模块重新确定出用于进行模型再训练的、若干目标客户端;在各客户端训练获得第二模型符合训练条件时停止训练。

43、为解决上述问题,本技术提供一种联邦学习装置,包括:接收模块以及模型训练模块;

44、所述接收模块用于,接收服务端发送的元模型、初始模型参数,以及用于接收服务端发送的第二模型参数,所述第二模型参数是由服务端基于若干目标客户端第n轮训练获得的第一模型参数所计算获得的;

45、所述模型训练模块用于,基于所述元模型、初始模型参数、客户端本地的训练数据以及目标补丁模型进行第一轮模型训练,以及用于基于服务端发送的所述第二模型参数、已接收的元模型、客户端本地的训练数据以及目标补丁模型进行第n+1轮模型训练。

46、为解决上述问题,本技术提供一种电子设备,至少包括存储器、处理器,所述存储器上存储有计算机程序,所述处理器在执行所述存储器上的计算机程序时实现上述任一项所述联邦学习方法的步骤。

47、本技术中的联邦学习方法、装置及电子设备,在每一轮模型训练时,通过从各客户端中确定出需要进行下一轮模型训练的目标客户端,然后基于各目标客户端当前模型参数计算获得下一轮模型训练的模型参数,能够使得模型参数的确定更加合理准确,使得目标客户端在进行下一轮训练时,能够基于补丁模型以及服务端计算获得模型参数,精准的训练获得符合本地模拟情况的模型,避免了最终训练获得模型发生偏移的问题,使得各客户端训练获得模型均具有良好的本地模拟效果。

48、上述说明仅是本发明技术方案的概述,为了能够更清楚了解本发明的技术手段,而可依照说明书的内容予以实施,并且为了让本发明的上述和其它目的、特征和优点能够更明显易懂,以下特举本发明的具体实施方式。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1