一种联邦学习方法

文档序号:35028131发布日期:2023-08-05 16:08阅读:78来源:国知局
一种联邦学习方法

本发明属于边缘人工智能,具体涉及一种联邦学习方法。


背景技术:

1、据《全球移动市场报告》预测,2023年全球智能手机用户或将突破四十亿,这些设备每秒钟都会产生大量数据。将数据上传到云端处理会引起较大的时延,即使应用对实时性的要求不高,云端也会随着设备规模的不断扩大而面临更大的压力,将计算下沉到边缘端是一个缓解云端压力的方法。边缘计算在靠近用户或数据输入的地方提供服务,在边缘端完成部分工作可以更好地保护用户的数据隐私。

2、许多研究将机器学习应用到其他领域后取得了很大突破。边缘设备更靠近用户,能收集到许多数据,而机器学习能从数据中提取有用的信息,在边缘计算中使用机器学习的理论和技术可以为用户提供更好的服务。联邦学习能在保护数据安全的前提下,联合客户端和服务器训练神经网络模型。

3、然而,客户端在服务器的协助下训练模型有一定的困难。神经网络模型规模较大,在训练中需要多次迭代,而边缘设备的资源、计算能力受限,可能难以承受训练和通信开销。为了降低成本,许多研究通过减少参数量、通信数据量或通信次数来轻量化模型,但这忽略了客户端之间数据的非独立同分布事实。此外,由各个客户端之间异构的数据训练出的模型往往也具有异构性,在服务器中聚合后可能会出现模型性能下降的情况。个性化联邦学习是针对数据异构性的很好的解决方案,常见的构建个性化联邦学习的方法包括元学习、迁移学习、自适应调整等,但这些方法都没有考虑客户端的计算和通信能力等都受限的问题。

4、以上减少模型开销的方法会对模型性能有一定的损伤,而针对异构数据的方法需要用额外的参数表示数据,增加了模型的开销。在联邦学习中使用剪枝策略可以在去除冗余参数的同时为客户端训练个性化模型,但是通常采用的剪枝策略是启发式的,没有考虑到权重在训练中的动态性,会造成不小的精度损失,这需要额外的再训练来恢复,这种方法的代价仍然比较高。


技术实现思路

1、本发明的目的在于提供一种联邦学习方法,用以解决采用启发式的剪枝策略造成模型精度损失的问题。

2、为解决上述技术问题,本发明提供了一种联邦学习方法,包括如下步骤:

3、1)各个客户端对服务器下传的网络模型进行稀疏化训练,在每一次稀疏化迭代训练过程中均对网络模型精度和剪枝率进行判断,若网络模型精度大于精度阈值且剪枝率小于目标剪枝率,则将网络模型中部分趋近于0的权重置为0以进行剪枝,且对于置为0的权重,相应掩码中对应位置也置为0,直至达到稀疏化训练的迭代终止条件,从而得到各个客户端的本地模型和掩码,并上传至服务器;

4、2)服务器对各个客户端上传到本地模型进行聚合,得到全局模型;根据全局模型和各个客户端的掩码,得到全局模型的子网并下送至各个客户端。

5、上述技术方案的有益效果为:本发明中的客户端对网络模型进行稀疏化训练后再剪枝,使得网络模型更为轻量化,与直接剪枝的启发式方式相比,加快了收敛速度,更好地解决客户端资源受限的问题,最终使网络模型能更好地运行在客户端中。而且,与没有稀疏化训练的方法相比,剪枝对模型精度的影响更小,且在剪枝之前对网络模型精度和剪枝率进行判断,在两者均达到要求的情况下再剪枝,防止对模型的精度造成永久性损伤。

6、进一步地,在进行本地模型聚合时,仅对本地模型中未进行剪枝的部分进行聚合。

7、上述技术方案的有益效果为:仅对未剪枝部分进行聚合,可以减少数据异构对模型性能的影响。

8、进一步地,将各个本地模型中未进行剪枝的部分相应位置取平均值以实现聚合。

9、进一步地,步骤2)中将全局模型与客户端各自的掩码的对应位置进行相乘得到各个客户端的子网。

10、进一步地,将损失函数转化具有稀疏度要求的优化问题,并采用admm算法求解优化目标,以实现对网络模型进行稀疏化训练。

11、上述技术方案的有益效果为:admm能够很好地解决稀疏约束下的非凸优化问题,因此使用admm进行求解。



技术特征:

1.一种联邦学习方法,其特征在于,包括如下步骤:

2.根据权利要求1所述的联邦学习方法,其特征在于,在进行本地模型聚合时,仅对本地模型中未进行剪枝的部分进行聚合。

3.根据权利要求2所述的联邦学习方法,其特征在于,将各个本地模型中未进行剪枝的部分相应位置取平均值以实现聚合。

4.根据权利要求1所述的联邦学习方法,其特征在于,步骤2)中将全局模型与客户端各自的掩码的对应位置进行相乘得到各个客户端的子网。

5.根据权利要求1~4任一项所述的联邦学习方法,其特征在于,将损失函数转化具有稀疏度要求的优化问题,并采用admm算法求解优化目标,以实现对网络模型进行稀疏化训练。


技术总结
本发明属于边缘人工智能技术领域,具体涉及一种联邦学习方法。该方法中各个客户端对服务器下传的网络模型进行稀疏化训练,在每一次迭代训练过程中,若网络模型精度大于精度阈值且剪枝率小于目标剪枝率,则进行剪枝,且对于置为0的权重,相应掩码中对应位置也置为0,直至达到训练的迭代终止条件,得到各个客户端的本地模型和掩码,并上传至服务器;服务器对各个客户端上传到本地模型进行聚合,得到全局模型,进而根据全局模型和各个客户端的掩码,得到全局模型的子网并进行下发。本发明中的客户端对网络模型进行稀疏化训练后再剪枝,使得网络模型更为轻量化,与直接剪枝的启发式方式相比,加快了收敛速度,更好地解决客户端资源受限的问题。

技术研发人员:袁培燕,石玲,赵晓焱,张俊娜,刘春红
受保护的技术使用者:河南师范大学
技术研发日:
技术公布日:2024/1/14
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1