一种轻量级联邦学习训练方法

文档序号:36238612发布日期:2023-12-01 22:05阅读:33来源:国知局
一种轻量级联邦学习训练方法

本发明涉及信息,特别涉及一种轻量级联邦学习训练方法。


背景技术:

1、随着移动设备的功能越来越强大,越来越多的基于神经网络的智能应用已被开发用于移动设备,例如图像识别、视频分析、目标检测等。为了使智能应用能够达到预计效果,通常会通过大量的数据训练智能应用的神经网络模型,然而,单个移动设备的数据量是有限的,不太可能帮助神经网络达到理想的精度。同时,考虑到隐私保护和通信量过大等原因,将数据从许多移动设备传输到一个中央服务器并进行集中训练将不再可行。在联邦学习中的中央服务器的编排下,以分散的方式训练共享全局模型,实现在保护用户数据隐私的同时,最大化提升模型的训练效率和模型的整体精度。

2、目前,由于联邦学习在解决隐私保护和数据孤岛等问题方面的优势,已经逐步成为流行的机器学习范式。此类方法通常分为四个步骤:首先,在每轮通信中,每个参与设备从中央服务器下载当前模型;其次,通过本地数据训练局部模型;第三,通过中央服务器聚合所有局部模型;第四,将聚合后的全局模型发送回设备。然而,由于移动设备通信成本高且通信传输不稳定,联邦学习通信负载较大等问题,常规的联邦学习方法难以在一定设备尤其是告诉移动设备中使用。因此,目前的面向移动设备的联邦学习方法却存在以下不可忽略的技术问题:

3、传统的联邦学习训练方法主要考虑的是稳定通信的设备或者是慢速的移动设备,从而忽略了联邦学习算法应用在高速移动设备上的挑战。在高速移动场景下,例如高速车联网中,车辆的高速移动性带来了信号质量的下降,导致车载网络无法实现最佳带宽和通信速度,这意味着参与训练的设备将会消耗大量的时间和资源在模型的传输过程中。同时,由于不同设备的网络时延不同,中央服务器的聚合过程将会导致更长的等待时间,这将导致联邦学习的效率进一步降低,这些问题严重影响了传统联邦学习在移动场景下的应用效果。


技术实现思路

1、本发明提供了一种轻量级联邦学习训练方法,其目的是为了节约模型传输过程中的传输时间和减少模型聚合过程中的等待时间。

2、为了达到上述目的,本发明提供了一种轻量级联邦学习训练方法,包括:

3、步骤1,中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;

4、步骤2,客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;

5、步骤3,客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;

6、步骤4,客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;

7、步骤5,客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;

8、步骤6,中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2。

9、进一步来说,在客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别之前,还包括:

10、对采集的本地图像数据进行数据标签规范化处理和异常数据删除处理,得到处理后的本地图像数据;

11、客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别。

12、进一步来说,步骤4包括:

13、根据训练后的本地深度卷积神经网络模型的网络特性,将训练后的本地深度卷积神经网络模型分为编码器和分类器;

14、利用结构化剪枝的方式,将编码器的权重绝对值小于预设阈值的权重进行修剪,得到剪枝后的编码器;

15、利用非结构化剪枝的方式,评估分类器中每个卷积层中每个过滤器的影响系数,并影响系数低于预设值的过滤器进行修剪,得到剪枝后的分类器;

16、将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型。

17、进一步来说,根据训练后的本地深度卷积神经网络模型的网络特性,通过对训练后的本地深度卷积神经网络模型进行正则化,得到编码器和分类器,正则化的表达式为:

18、r(w)=renc(we)+rcls(wc)

19、

20、

21、其中,r(w)表示本地深度卷积神经网络模型的剪枝权重,renc表示编码器的剪枝权重,rcls表示分类器的剪枝权重,we表示编码器的权重,wc表示分类器的权重,||·||g是group lasso算法,fl是第l个卷积层中滤波器的数量,chl是第l个卷积层中通道的个数,rowl代表分类器中第l层的行数,col1代表分类器中第l层的列数。

22、进一步来说,轻量化深度卷积神经网络模型的损失函数为:

23、f(w)=fd(w)+λr(w)

24、其中,fd(w)是轻量化深度卷积神经网络模型的损失函数,λ是结构化稀疏正则化的系数。

25、进一步来说,本地局部模型的损失函数为:

26、

27、其中,β表示控制来自数据或其他模型知识比例的超参数,表示本地局部模型的交叉熵损失函数,dkl表示kl散度,pl表示本地深度卷积神经网络模型的预测值,pm表示本地局部模型的预测值。

28、进一步来说,训练终止的条件为:

29、直至全局模型的精度达到预设训练精度或迭代次数达到预设上限时,终止训练。

30、本发明的上述方案有如下的有益效果:

31、本发明通过中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2;与现有技术相比,本发明采用双向蒸馏的方式压缩模型的参数,极大程度上提高了通信效率并减少了聚合时的等待时间,通过剪枝算法对模型做进一步的压缩,有效的去除了局部模型中多余的参数以减少模型参数量,从而能够在提高通信和聚合效率的同时提升模型的精确性。

32、本发明的其它有益效果将在随后的具体实施方式部分予以详细说明。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1