一种模型训练方法及相关装置与流程

文档序号:38026708发布日期:2024-05-17 13:02阅读:6来源:国知局
一种模型训练方法及相关装置与流程

本技术涉及人工智能(artificial intelligence,ai),尤其涉及一种模型训练方法及相关装置。


背景技术:

1、随着人类社会数字化进程越来越快,产生了大量数据。通过机器学习技术可以自动化地挖掘数据中蕴藏的信息,因此经过大量数据训练出来的机器学习模型已经应用在各类场景中,例如人脸识别、语音翻译、医疗辅助诊断等场景。在实际应用中,机器学习模型的精度、泛化能力等至关重要,而这些都依赖于采用大量的数据对机器学习模型进行训练。

2、受限于法律法规、商业机密、个人隐私等数据隐私安全上的约束,多个数据来源方往往无法直接交换数据,进而导致多个数据来源方的数据无法融合在一起对机器学习模型进行训练,制约了机器学习模型能力的进一步提高。联邦学习的诞生即是为了解决这一问题。

3、联邦学习(federated learning)是一种分布式机器学习技术,其核心思想是通过在多个拥有本地数据的数据源之间进行分布式模型训练,在不需要交换本地数据的前提下,仅通过交换中间结果的方式,构建基于融合数据下的全局模型,从而实现数据隐私保护和数据共享计算的平衡。

4、然而,由于参与联邦学习的多个设备分布于不同的地方,多个设备在执行联邦学习的过程中,往往需要通过网络来交换大量的数据。受限于网络的通信能力,在多个设备之间待交换的数据较多的情况下,设备间交换数据的通信时长往往会大于设备训练模型的时长,进而导致联邦学习的整体时延较长。


技术实现思路

1、本技术提供了一种模型训练方法,能够缩短整个迭代训练过程中的空闲等待时长,提高整体的训练效率。

2、本技术第一方面提供一种模型训练方法,应用于人工智能技术领域。该方法包括:第一训练装置获取训练数据集,该训练数据集包括多个数据。其中,第一训练装置所获取到的多个数据可以为一个批次的数据,即多个数据的数量与模型的批次大小相同。批次大小指示了用于同时对模型进行训练的数据样本的数据量大小,即模型在一次迭代训练过程中需要使用到的数据样本的数据量大小。

3、然后,第一训练装置将多个数据划分为多组数据,多组数据中的每组数据包括至少一个数据,且每组数据的数据量与第一训练装置的通信能力相关。例如,每组数据的数据量与第一训练装置的通信能力具有正相关的关系。即,第一训练装置的通信能力越强,每组数据的数据量则越大。

4、最后,基于第一训练装置中所部署的模型,以组为单位分批对多组数据执行处理,以训练模型并分批向第二训练装置传输处理多组数据所得到的数据,第二训练装置与第一训练装置共同参与模型的训练。具体地,通过模型对多组数据执行处理会得到多部分数据,每组数据对应于唯一的一部分数据。并且,由于本步骤中是分批对多组数据执行处理,因此本步骤中多部分数据中的每一部分数据也是分批向第二训练装置传输的。

5、本方案中,将训练过程中的多个数据划分为多组数据,且每组数据的数据量与训练装置的通信能力相关,由训练装置通过模型分批对多组数据进行处理并将处理多组数据所得到的多部分数据分批向其他训练装置传输,以执行模型的训练。这样,训练装置在向其他训练装置传输处理前面组次数据所得到的数据时,仍能够继续对后面组次的数据进行处理,进而实现将模型训练时间隐藏于训练装置之间的通信时间中,最终缩短整个迭代训练过程中的空闲等待时长,提高整体的训练效率。

6、在一种可能的实现方式中,训练模型的过程中所执行的运算包括第一类运算和第二类运算。其中,第一类运算用于基于多组数据生成向第二训练装置传输的多部分数据,即第一类运算是依赖于多组数据来生成待传输数据的运算。第二类运算用于生成仅由第一训练装置处理的数据,即第二类运算的执行与否并不影响待传输数据的生成。

7、此外,第一类运算的执行优先级高于第二类运算的执行优先级。也就是说,在训练过程中,当同时存在均能够执行的第一类运算和第二类运算时,第一训练装置优先执行第一类运算。在所有的第一类运算执行完毕后,第一训练装置再执行第二类运算。

8、本方案中,对于会影响待传输数据生成的第一类运算以及不会影响待传输数据生成的第二类运算,训练装置优先执行第一类运算,以便于能够持续生成待传输数据,尽可能地避免训练装置处于通信空闲状态,以保证通信和训练并行的时间尽可能地长,从而缩短整个迭代过程中的空闲等待时长,提高整体的训练效率。

9、在一种可能的实现方式中,第一类运算包括第一子类运算和第二子类运算,第一子类运算的运算结果用于得到第二子类运算的输入,第二子类运算的运算结果用于向第二训练装置传输。也就是说,第一训练装置执行第二子类运算所得到的运算结果即为需要向第二训练装置传输的数据,而第一训练装置执行第一子类运算所得到的运算结果是用于作为第二子类运算的输入,即第二子类运算的执行依赖于第一子类运算。在实际训练过程中,只有先执行第一子类运算后,才能够得到作为第二子类运算的输入的数据,进而才能执行第二子类运算。

10、在一种可能的实现方式中,第二子类运算的执行优先级高于第一子类运算。

11、简单来说,由于第二子类运算的执行依赖于第一子类运算,因此在训练过程中,第一训练装置需要先执行第一子类运算,并得到运算结果后,才能够基于所得到的运算结果继续执行第二子类运算。即,第二子类运算的运算条件满足后,才能够执行第二子类运算。但是,在第二子类运算的运算条件满足后,第一训练装置则优先执行第二子类运算,以便于尽快生成待传输至第二训练装置的数据。

12、在一种可能的实现方式中,在训练模型的过程中,第一训练装置将执行第一类运算所产生的数据缓存至第一队列中,第一队列用于缓存待传输至第二训练装置的数据。在第一队列中数据的数据量大于或等于第一阈值的情况下,第一训练装置则停止执行第一类运算,并执行第二类运算。

13、具体来说,在第一队列中的数据较多的情况下,代表第一训练装置向第二训练装置发送数据的速度远跟不上第一训练装置执行第一类运算而产生数据的速度,因此第一训练装置继续优先执行第一类运算也不会提高整体的训练效率。在这种情况下,为了避免内存开销过大,第一训练装置则可以是选择转至执行第二类运算,以免产生过多的待传输数据积压在内存中。

14、在一种可能的实现方式中,在第一队列中的数据小于第一阈值的情况下,或在用于支持执行第二类运算的数据已处理完毕的情况下,第一训练装置则停止执行第二类运算,并继续执行第一类运算。其中,用于支持执行第二类运算的数据可以是指作为第二类运算的输入的数据,即第二类运算的输入数据。

15、也就是说,在第一队列中的数据被消耗至一定数量的情况下,为了避免第一队列的数据被消耗完毕而导致出现通信空闲的现象,第一训练装置可以是继续优先执行第一类运算,以持续产生待传输至第二训练装置的数据,保证数据通信的持续性。并且,在第一训练装置处理完毕第二类运算的输入数据之后,用于支持执行第二类运算的数据则已处理完毕,第一训练装置无法继续执行第二类运算,因此第一训练装置转至继续执行第一类运算。

16、在一种可能的实现方式中,第一训练装置所参与的训练为联邦学习,例如横向联邦学习、纵向联邦学习或联邦迁移学习。

17、在一种可能的实现方式中,第一类运算包括基于模型对多组数据进行处理的运算,第二类运算包括基于从第二训练装置获取的数据对模型进行反向梯度计算的运算。

18、在一种可能的实现方式中,模型包括第一子模型和第二子模型;第一类运算包括基于第一子模型对多组数据进行处理的运算,以及基于从第二训练装置获取的数据对第二子模型进行处理的运算;第二类运算包括对第一子模型进行反向梯度计算的运算。

19、在一种可能的实现方式中,多个数据的数量与模型的批次大小相关;在模型的训练过程中,目标梯度与基于多组数据所得到的多个梯度相关,目标梯度用于第一训练装置更新模型。例如,在基于多组数据分别得到多个梯度的情况下,通过求取多个梯度的平均值,得到目标梯度。

20、本方案通过基于多个梯度来求取目标梯度,进而基于目标梯度对模型进行更新,能够保证是基于一个批次的数据来对模型进行更新,确保模型训练的精度不会受到影响。

21、在一种可能的实现方式中,多个数据的数量与模型的批次大小相关,多组数据的组数与待传输数据量具有正相关的关系,且多组数据的组数与第一训练装置的通信能力以及训练时长具有负相关的关系;其中,待传输数据量为处理多组数据后所生成的待传输数据的数据量,训练时长为第一训练装置基于多组数据训练模型的时长。

22、也就是说,处理多组数据后所生成的待传输数据的数据量越大,多组数据的组数则越多,以减少处理每组数据所生成的待传输数据,避免处理每组数据后所生成的待传输数据的数据量过多。并且,第一训练装置的通信能力越强,第一训练装置单位时间内能够向第二训练装置传输的数据则越多,多组数据的组数则可以划分得越少。第一训练装置基于多组数据训练模型的时长越长,则代表第一训练装置处理多组数据以生成待传输数据的速度越慢,多组数据的组数则可以划分得越少。

23、本技术第二方面提供一种模型训练装置,该模型训练装置为第一训练装置,包括:

24、获取模块,用于获取训练数据集,训练数据集包括多个数据;

25、处理模块,用于将多个数据划分为多组数据,多组数据中的每组数据包括至少一个数据,且每组数据的数据量与第一训练装置的通信能力相关;

26、处理模块,还用于基于第一训练装置中所部署的模型,以组为单位分批对多组数据执行处理,以训练模型并分批向第二训练装置传输处理多组数据所得到的数据,第二训练装置与第一训练装置共同参与模型的训练。

27、在一种可能的实现方式中,训练模型的过程中所执行的运算包括第一类运算和第二类运算,第一类运算用于基于多组数据生成向第二训练装置传输的数据,第二类运算用于生成仅由第一训练装置处理的数据,第一类运算的执行优先级高于第二类运算的执行优先级。

28、在一种可能的实现方式中,第一类运算包括第一子类运算和第二子类运算,第一子类运算的运算结果用于得到第二子类运算的输入,第二子类运算的运算结果用于向第二训练装置传输。

29、在一种可能的实现方式中,第二子类运算的执行优先级高于第一子类运算。

30、在一种可能的实现方式中,处理模块,还用于:

31、将执行第一类运算所产生的数据缓存至第一队列中,第一队列用于缓存待传输至第二训练装置的数据;

32、在第一队列中数据的数据量大于或等于第一阈值的情况下,停止执行第一类运算,并执行第二类运算。

33、在一种可能的实现方式中,处理模块,还用于:

34、在第一队列中的数据小于第一阈值的情况下,或在用于支持执行第二类运算的数据已处理完毕的情况下,停止执行第二类运算,并继续执行第一类运算。

35、在一种可能的实现方式中,第一训练装置所参与的训练为联邦学习。

36、在一种可能的实现方式中,第一类运算包括基于模型对多组数据进行处理的运算,第二类运算包括基于从第二训练装置获取的数据对模型进行反向梯度计算的运算。

37、在一种可能的实现方式中,模型包括第一子模型和第二子模型;

38、第一类运算包括基于第一子模型对多组数据进行处理的运算,以及基于从第二训练装置获取的数据对第二子模型进行处理的运算;

39、第二类运算包括对第一子模型进行反向梯度计算的运算。

40、在一种可能的实现方式中,多个数据的数量与模型的批次大小相关;

41、在模型的训练过程中,目标梯度与基于多组数据所得到的多个梯度相关,目标梯度用于第一训练装置更新模型。

42、在一种可能的实现方式中,多个数据的数量与模型的批次大小相关,多组数据的组数与待传输数据量具有正相关的关系,且多组数据的组数与第一训练装置的通信能力以及训练时长具有负相关的关系;

43、其中,待传输数据量为处理多组数据后所生成的待传输数据的数据量,训练时长为第一训练装置基于多组数据训练模型的时长。

44、本技术第三方面提供一种模型训练装置,可以包括处理器,处理器和存储器耦合,存储器存储有程序指令,当存储器存储的程序指令被处理器执行时实现上述第一方面或第一方面任一实现方式所述的方法。对于处理器执行第一方面的各个可能实现方式中的步骤,具体均可以参阅第一方面,此处不再赘述。

45、在一种可能的实现方式中,该模型训练装置还包括通信接口,该模型训练装置通过通信接口向其他的模型训练装置传输数据,或者是接收其他的模型训练装置所传输的数据。具体地,该模型训练装置可以是基于通信接口,通过远程直接数据存取(remote directmemory access,rdma)技术来向其他的模型训练装置传输数据,在此并不限定模型训练装置之间传输数据的具体方式。

46、其中,rdma是指通过网络把数据直接传入计算机的存储区,即将数据从一个系统快速移动到远程系统存储器中,而不对操作系统造成任何影响,从而不占用计算机的处理资源。因此,在模型训练装置之间通过rdma技术来实现数据传输的情况下,模型训练装置可以将待传输的数据存储至内存中,以便于通过网络将内存中的数据直接穿入另一个模型训练装置的内存中,进而减少对两个模型训练装置的处理资源的消耗。

47、本技术第四方面提供了一种模型训练系统,包括至少两个模型训练装置,该至少两个模型训练装置共同参与模型的训练,且至少两个模型训练装置中的任意一个模型训练装置采用上述第一方面任一实现方式所述的方法与其他的模型训练装置进行交互并执行模型的训练。

48、本技术第五方面提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,当其在计算机上运行时,使得计算机执行上述第一方面任一实现方式所述的方法。

49、本技术第六方面提供了一种电路系统,所述电路系统包括处理电路,所述处理电路配置为执行上述第一方面任一实现方式所述的方法。

50、本技术第七方面提供了一种计算机程序产品,当其在计算机上运行时,使得计算机执行上述第一方面任一实现方式所述的方法。

51、本技术第八方面提供了一种芯片系统,该芯片系统包括处理器,用于支持服务器或门限值获取装置实现上述第一方面任一实现方式中所涉及的功能,例如,发送或处理上述方法中所涉及的数据和/或信息。在一种可能的设计中,所述芯片系统还包括存储器,所述存储器,用于保存服务器或通信设备必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包括芯片和其他分立器件。

52、上述第二方面至第八方面的有益效果可以参考上述第一方面的介绍,在此不再赘述。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1