一种糖尿病患者病变评估模型的联邦训练方法及系统

文档序号:37261892发布日期:2024-03-12 20:41阅读:26来源:国知局
一种糖尿病患者病变评估模型的联邦训练方法及系统

本发明涉及糖尿病患者视网膜病变领域,具体来说,涉及糖尿病患者眼底图像噪声鲁棒的分布式机器学习领域,更具体来说,涉及一种糖尿病患者病变评估模型的联邦训练方法及系统。


背景技术:

1、糖尿病病变评估是指评估糖尿病患者身体各个器官和系统是否存在糖尿病引起的并发症或病变。糖尿病病变评估的目的是及早发现并治疗糖尿病的并发症,以减少其对患者的健康造成的影响。糖尿病病变评估方法通常包括以下方面:血糖控制评估、肾功能评估、眼底评估以及心血管评估,其中,眼底评估方法是根据糖尿病患者眼底图像进行病变评估,是一种简单、安全无创、可靠、重要的糖尿病病变评估方法,可以帮助医生及早发现糖尿病视网膜病变,为患者提供更好的治疗和预防措施。

2、为了更好的辅助医生通过糖尿病患者眼底图像更准确的评估糖尿病患者视网膜是否病变,每家医院将自己拍摄的糖尿病患者眼底图像结合深度学习模型进行相应的模型训练以实现糖尿病视网膜病变评估任务,但是不同医院之间接受的患者不同,使得各家医院的眼底图像数据信息也不同,并且各家医院之间信息相互独立,会导致每家医院患者信息缺乏完整性,无法充分利用其他医院的知识和经验来进行更好的评估,缺乏整体性的视角,导致出现信息孤岛问题等。

3、受益于机器学习的不断发展,将联邦学习的方式引入糖尿病患者病变评估模型的训练,可以很好地解决信息孤岛问题。其中,联邦学习是一种分布式机器学习方法,它允许服务端和多个客户端在不共享私有数据的情况下协同训练一个共同的机器学习模型。在联邦学习中,客户端是联邦学习中的终端设备或节点,例如智能手机、个人电脑、传感器或其他物联网设备,每个客户端都拥有自己的本地数据集,这些数据通常是隐私敏感的,例如用户文本、图像、位置信息等。联邦学习的目标是在不泄露这些隐私敏感数据的情况下(即在联邦学习中,客户端不会将原始数据上传到服务端(即中心节点),而是在本地进行模型训练,只将本地训练的参数信息上传到服务端,以便与其他客户端共同协作训练全局模型),通过合并每个客户端的本地模型参数来得到一个全局模型。服务端是联邦学习中的中心节点,该中心节点作为联邦学习的协调者和聚合者。它的主要作用是收集来自所有客户端的中间参数,并执行全局模型的聚合。聚合过程可以是简单的加权平均,也可以使用更复杂的算法,例如联邦平均、联邦随机梯度下降(stochastic gradient descent,sgd)等。服务端确保不接触客户端的原始数据,仅通过接收和处理模型参数来促进模型的共同学习。在联邦学习中,深度学习模型是被共同训练的目标模型。初始时,该模型可能是一个预先定义好的模型,如残差神经网络resnet。随着联邦学习的进行,服务端收集来自多个客户端的模型参数,通过聚合这些参数更新来更新全局模型。这个过程通过迭代的方式不断进行,直到模型收敛。因此,在糖尿病患者病变评估模型的联邦训练中,聚合每家医院通过本地数据集训练训练模型得到的局部模型参数可以充分利用其他医院的知识和经验,具备整体性的视角。

4、然而,通过联邦学习的方式训练糖尿病患者病变评估模型时糖尿病患者眼底图像数据集中会出现标签噪声问题(即在联邦学习中每个客户端的糖尿病患者眼底图像数据集中存在错误或不准确的标签),导致最终训练的糖尿病患者病变评估模型性能不好。现有技术下有研究者提出由客户端自身从局部角度出发解决每个客户端的糖尿病患者眼底图像数据集中的标签噪声问题,但是,由客户端自身从局部角度解决标签噪声问题是不够的,原因有两个:1)客户端中有的训练数据或高噪声率使得局部噪声度量具有挑战性;2)从服务端角度评估数据贡献对于模型聚合期间调整权重等至关重要。

5、为此,又有研究者提出数据选择和客户端选择等选择方法来进行改进,这样的方法旨在通过部分参与来解决每个客户端的糖尿病患者眼底图像数据集中的标签噪声,以优化糖尿病患者病变评估模型的性能。但是这种选择方法涉及从每个客户端中选择相关数据子集或从候选者中选择客户子集来训练糖尿病患者病变评估模型。这类选择方法可以是随机的,也可以基于验证性能等标准,或局部更新后自我报告的损失,由于客户端之间的噪声异质性,数据质量领域的选择过程至关重要,需要不同的参与机会,上述选择方法无法保证不同质量数据的不同参与机会,可能会引入残余噪声或信息丢失等风险。即使在选择之后,不可避免的噪声仍然会损害糖尿病患者病变评估模型性能,导致结果不理想。

6、此外,除了上述择方法之外,解决糖尿病患者眼底图像数据集中标签噪声还有鲁棒损失设计、鲁棒架构和鲁棒正则化。

7、第一种鲁棒损失设计涉及修改损失函数,为错误标记的糖尿病患者眼底图像数据分配较低的权重或惩罚,减少它们对糖尿病患者病变评估模型参数更新的影响。然而,使用稳健的损失函数通常需要准确度量或了解糖尿病患者眼底图像数据集中标签噪声分布,随着噪声率的增加,准确调整惩罚项的超参数变得更具挑战性。第二种鲁棒的架构通过将噪声适应层合并到糖尿病患者病变评估模型中来学习标签转换过程来解决糖尿病患者眼底图像数据集中的标签噪声问题。然而,这种方法可能不是一种普遍适用的解决方案,无需修改或调整即可轻松扩展到其他架构。第三种鲁棒正则化技术已被证明可以提高糖尿病患者病变评估模型从干净数据中学习模式的能力,并降低其对标签噪声引起的过度拟合的敏感性。鲁棒正则化技术中的典型代表称为渐进式早期停止(pes)的方法,其重点关注糖尿病患者病变评估模型中后层对糖尿病患者眼底图像数据集中标签噪声的敏感性,通过减少对糖尿病患者病变评估模型中后层的训练来防止模型记忆糖尿病患者眼底图像数据集中的标签噪声,防止模型过度记忆糖尿病患者眼底图像数据集中的标签噪声。pes最初需要对前面的层进行较多轮次的训练,然后用较少的轮次对后面的层进行训练,同时保持前面的层固定。pes的有效性依赖于仔细调整超参数,超参数包括前后层之间的断点及其相应的训练轮次。然而,由于渐进式早停技术中网络层断点、训练轮次等超参数难以确定,导致糖尿病患者病变评估模型的性能难以得到保障,使得该模型不能很好地辅助医生通过糖尿病患者眼底图像更准确的评估糖尿病患者视网膜是否病变。

8、从上述现有技术的描述可知,现有技术下的各种方法均不能很好的训练糖尿病患者病变评估模型。因此,亟需一种能够提高糖尿病患者病变评估模型性能的训练方法。


技术实现思路

1、因此,本发明的目的在于克服上述现有技术的缺陷,提供一种糖尿病患者病变评估模型的联邦训练方法及系统。

2、本发明的目的是通过以下技术方案实现的:

3、根据本发明的第一方面,提供一种糖尿病患者病变评估模型的联邦训练方法,所述方法包括:在服务端部署基准糖尿病患者眼底图像训练集和基于基准糖尿病患者眼底图像训练集预训练的初始糖尿病患者病变评估模型;在多个客户端分别部署局部糖尿病患者眼底图像数据集,其中,所述局部糖尿病患者眼底图像数据集与基准糖尿病患者眼底图像数据集没有交集;

4、在所述服务端和多个客户端对所述初始糖尿病患者病变评估模型进行多轮联邦训练直至收敛,其中,每轮联邦训练包括:

5、s1、服务端下发上一轮联邦训练后的糖尿病患者病变评估模型到每个客户端;

6、s2、每个客户端将来自于服务端的上一轮联邦训练后的糖尿病患者病变评估模型作为其局部糖尿病患者病变评估模型,并采用其对应的局部糖尿病患者眼底图像数据集对局部糖尿病患者病变评估模型进行多次训练并将训练完成的局部糖尿病患者病变评估模型参数和局部糖尿病患者眼底图像数据集中的数据量上传所述服务端,其中,

7、每次训练时,每个客户端基于其对应的局部糖尿病患者眼底图像数据集的噪声率得到局部糖尿病患者病变评估模型中关键参数和非关键参数对应的参数矩阵,利用梯度优化方法更新参数矩阵中的关键参数以得到更新后的局部糖尿病患者病变评估模型参数;

8、s3、服务端基于来自于每个客户端的局部糖尿病患者病变评估模型参数和局部糖尿病患者眼底图像数据集中的数据量,利用聚合方法更新上一轮联邦训练后的糖尿病患者病变评估模型参数。

9、在本发明的一些实施例中,所述服务端还部署有基准糖尿病患者眼底图像验证集,在所述步骤s2中,每次训练局部糖尿病患者病变评估模型时,采用如下方式计算每个客户端对应的局部糖尿病患者眼底图像数据集的噪声率:

10、s21、在服务端将所述基准糖尿病患者眼底图像验证集输入上一轮联邦训练后的糖尿病患者病变评估模型得到输出后基于输出计算基准糖尿病患者眼底图像验证集的全局交叉熵损失集合;

11、s22、在每个客户端将所述局部糖尿病患者眼底图像数据集输入上一轮联邦训练后的糖尿病患者病变评估模型得到输出后基于输出计算每个客户端的局部糖尿病患者眼底图像数据集的局部交叉熵损失集合并上传所述服务端;

12、s23、在服务端基于基准糖尿病患者眼底图像验证集的全局交叉熵损失集合和来自于每个客户端的局部糖尿病患者眼底图像数据集的局部交叉熵损失集合,利用散度函数得到每个客户端的局部糖尿病患者眼底图像数据集的损失阈值;

13、s24、在服务端基于每个客户端的局部糖尿病患者眼底图像数据集的损失阈值和来自于每个客户端的局部糖尿病患者眼底图像数据集的局部交叉熵损失集合,利用噪声率计算方法得到每个客户端的局部糖尿病患者眼底图像数据集的噪声率并发送给每个客户端。

14、在本发明的一些实施例中,在所述步骤s21中,采用如下方式得到基准糖尿病患者眼底图像验证集的全局交叉熵损失集合:

15、

16、其中,q表示基准糖尿病患者眼底图像验证集的全局交叉熵损失集合,l(·)表示损失函数,表示将基准糖尿病患者眼底图像验证集中基准糖尿病患者眼底图像xi输入参数为的上一轮联邦训练后的糖尿病患者病变评估模型的输出,yi表示基准糖尿病患者眼底图像xi对应的标签,bte为基准糖尿病患者眼底图像验证集。

17、在本发明的一些实施例中,在所述步骤s22中,采用如下方式得到每个客户端的局部糖尿病患者眼底图像数据集的局部交叉熵损失集合:

18、

19、n∈(1,n)

20、其中,n为客户端个数,pn表示第n个客户端的局部糖尿病患者眼底图像数据集的局部交叉熵损失集合,表示将第n个客户端的局部糖尿病患者眼底图像数据集中糖尿病患者眼底图像输入参数为的上一轮联邦训练后的糖尿病患者病变评估模型的输出,为糖尿病患者眼底图像对应的标签,为第n个客户端的局部糖尿病患者眼底图像数据集。

21、在本发明的一些实施例中,在所述步骤s23中,采用如下方式得到每个客户端的局部糖尿病患者眼底图像数据集的损失阈值:

22、

23、其中,λn表示第n个客户端的局部糖尿病患者眼底图像数据集的损失阈值,js(·)表示js散度函数,fq(·)表示在基准糖尿病患者眼底图像验证集的全局交叉熵损失集合q上的累积分布函数,表示第n个客户端在损失阈值λn和局部糖尿病患者眼底图像数据集的局部交叉熵损失函数集合pn上的累积分布函数。

24、在本发明的一些实施例中,在所述步骤s24中,服务端采用如下方式得到每个客户端的局部糖尿病患者眼底图像数据集的噪声率:

25、

26、其中,σn表示第n个客户端的局部糖尿病患者眼底图像数据集的噪声率。

27、在本发明的一些实施例中,在所述步骤s2中,每次训练局部糖尿病患者病变评估模型时,每个客户端采用如下方式得到局部糖尿病患者病变评估模型中关键参数和非关键参数对应的参数矩阵:

28、

29、

30、

31、i∈(1,m)

32、其中,表示在第t次训练时第n个客户端的局部糖尿病患者病变评估模型中关键参数和非关键参数对应的参数矩阵,表示在第t次训练时第n个客户端的局部糖尿病患者病变评估模型的第i个参数,表示在第t次训练时第n个客户端的局部糖尿病患者病变评估模型的非关键参数集合,表示在第t次训练时损失函数l(·)关于局部糖尿病患者病变评估模型中第i个参数的梯度,表示在第t次训练时第n个客户端确定关键参数和非关键参数的阈值,表示在第t次训练时第n个客户端的局部糖尿病患者病变评估模型中所有参数梯度的有序集合,m为第n个客户端的局部糖尿病患者病变评估模型的参数量。

33、在本发明的一些实施例中,在所述步骤s2中,每次训练局部糖尿病患者病变评估模型时,每个客户端采用如下方式更新参数矩阵中的关键参数以得到更新后的局部糖尿病患者病变评估模型参数:

34、

35、t∈(1,t)

36、其中,表示在第t次训练时第n个客户端更新后的局部糖尿病患者病变评估模型参数,表示在第t次训练时第n个客户端的局部糖尿病患者病变评估模型参数,η为预设的学习率,表示在第t次训练时损失函数l(·)关于第n个客户端的局部糖尿病患者病变评估模型参数的梯度,t为局部糖尿病患者病变评估模型的训练次数。

37、在本发明的一些实施例中,在所述步骤s3中,服务端采用如下方式更新上一轮联邦训练后的糖尿病患者病变评估模型参数:

38、

39、

40、其中,θ(s)表示更新后的上一轮联邦训练后的糖尿病患者病变评估模型参数,表示第n个客户端局部糖尿病患者眼底图像数据集中的数据量,表示经t次训练后的第n个客户端更新后的局部糖尿病患者病变评估模型参数。

41、根据本发明的第二方面,提供一种糖尿病患者病变评估模型的联邦训练系统,所述系统被配置为采用上述实施例中所述的方法训练糖尿病患者病变评估模型。

42、根据本发明的第三方面,提供一种糖尿病患者视网膜病变评估方法,所述方法包括:

43、获取待评估的糖尿病患者眼底图像;

44、采用上述实施例中所述联邦训练方法训练得到的糖尿病患者病变评估模型对所述待评估的糖尿病患者眼底图像进行评估以得到病变结果。

45、根据本发明的第四方面,提供一种电子设备,包括:一个或多个处理器;存储装置,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述电子设备实现第一、三方面所述方法的步骤。

46、与现有技术相比,本发明的优点在于:

47、本发明上述实施例中提出的糖尿病患者病变评估模型的训练方法通过建立噪声率与非关键参数率之间的关系,为不同的客户端确定不同的重要模型参数,确保高质量数据和模型参数对糖尿病患者病变评估模型的贡献更大,从而引导糖尿病患者病变评估模型快速优化和收敛,提高模型的性能。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1