本技术涉及人工智能,尤其涉及一种模型训练方法、装置、电子设备及存储介质。
背景技术:
1、阿尔茨海默症(alzheimer's disease)是一种以全面性认知功能衰退为主要表现的渐进性退行性脑病,包括学习、记忆、语言、执行功能、复杂注意等能力减退和人格改变,其起病隐匿,不伴有意识障碍,病程多数为不可逆,是威胁老年人健康的严重疾病之一,且目前尚无有效的治疗药物。
2、近年来,许多研究工作在眼底图像上有效利用了深度学习技术,探索了对一些疾病进行筛查和预测的可行性。但目前现有工作存在很大的局限性,主要存在以下几项主要问题:
3、1、阿尔茨海默症患者数据的稀缺性;
4、2、眼底图像中与阿尔茨海默症相关的明显特征难以观察,因此增大了模型的学习难度。
5、基于上述内容可知,现有技术中由于阿尔茨海默症患者数据稀缺、眼底图像中与阿尔茨海默症相关的明显特征难以观察,从而增加了阿尔茨海默症预测模型的训练难度,导致训练得到的阿尔茨海默症预测模型的预测准确率较低。
技术实现思路
1、本技术实施例提供一种模型训练方法、装置、电子设备及存储介质,以解决现有技术中由于阿尔茨海默症患者数据稀缺、眼底图像中与阿尔茨海默症相关的明显特征难以观察,从而增加了阿尔茨海默症预测模型的训练难度,导致训练得到的阿尔茨海默症预测模型的预测准确率较低的问题。
2、为了解决上述技术问题,本技术实施例是这样实现的:
3、第一方面,本技术实施例提供了一种模型训练方法,所述方法包括:
4、获取用户的符合模型训练条件的样本眼底图像,所述样本眼底图像标注有阿尔兹海默症的真实疾病类别;
5、将所述样本眼底图像输入至待训练疾病预测模型,所述待训练疾病预测模型包括:特征筛选及血管分割层和特征融合层;
6、调用所述特征筛选及血管分割层对所述样本眼底图像进行特征筛选及血管分割标注处理,得到注意力特征图和血管标注特征图;
7、调用所述特征融合层对所述注意力特征图和所述血管标注特征图进行特征融合处理,得到所述样本眼底图像对应的预测疾病概率;
8、基于所述真实疾病类别和所述预测疾病概率,计算得到所述待训练疾病预测模型的损失值;
9、在所述损失值处于预设范围内的情况下,得到用于预测阿尔兹海默症的疾病预测模型。
10、可选地,所述特征筛选及血管分割层包括:注意力层和血管分割层,
11、所述调用所述特征筛选及血管分割层对所述样本眼底图像进行特征筛选及血管分割标注处理,得到注意力特征图和血管标注特征图,包括:
12、调用所述注意力层对所述样本眼底图像进行特征筛选处理,输出注意力特征图;
13、调用所述血管分割层对所述样本眼底图像进行血管分割标注处理,输出血管标注特征图。
14、可选地,所述注意力层包括:n个卷积层、特征图处理层和输出层,n个卷积层的参数不同;
15、所述调用所述注意力层对所述样本眼底图像进行特征筛选处理,输出注意力特征图,包括:
16、调用所述n个卷积层分别对所述样本眼底图像进行通道压缩处理,生成n个分支特征图;
17、调用所述特征图处理层对所述n个分支特征图进行处理,生成融合特征图;
18、调用所述输出层对所述融合特征图进行处理,输出所述注意力特征图。
19、可选地,在n=3时,所述n个分支特征图包括:第一分支特征图、第二分支特征图和第三分支特征图,
20、所述调用所述特征图处理层对所述n个分支特征图进行处理,生成融合特征图,包括:
21、调用所述特征图处理层对所述第一分支特征图进行转置操作,得到第一转置特征图;
22、将所述第一转置特征图与所述第二分支特征图进行矩阵相乘,得到中间特征图;
23、对所述中间特征图进行归一化处理,得到归一化特征图;
24、对所述归一化特征图和所述第三分支特征图进行矩阵相乘,得到所述融合特征图。
25、可选地,所述调用所述血管分割层对所述样本眼底图像进行血管分割标注处理,输出血管标注特征图,包括:
26、调用所述血管分割层对所述样本眼底图像内的血管进行血管分割标注处理,得到血管分割特征图;
27、对所述血管分割特征图进行池化处理,输出与所述注意力特征图维度相同的所述血管标注特征图。
28、可选地,所述调用所述特征融合层对所述注意力特征图和所述血管标注特征图进行特征融合处理,得到所述样本眼底图像对应的预测疾病概率,包括:
29、调用所述特征融合层对所述注意力特征图和所述血管标注特征图进行矩阵点乘处理,得到点乘特征图;
30、对所述点乘特征图与所述注意力特征图进行叉乘处理,得到目标特征图;
31、基于所述目标特征图确定所述样本眼底图像对应的预测疾病概率。
32、可选地,所述基于所述真实疾病类别和所述预测疾病概率,计算得到所述待训练疾病预测模型的损失值,包括:
33、基于所述真实疾病类别和所述预测疾病概率,计算得到交叉熵损失值;
34、基于本轮训练的样本数量和所述交叉熵损失值,计算得到所述待训练疾病预测模型的损失值。
35、可选地,所述获取用户的符合模型训练条件的样本眼底图像,包括:
36、获取所述用户的多幅初始眼底图像;
37、调用预设模型对所述初始眼底图像进行处理,以筛选出所述初始眼底图像中符合模型训练条件的所述样本眼底图像,所述样本眼底图像为真实眼底图像类别且图像质量大于质量阈值的眼底图像。
38、可选地,所述调用预设模型对所述初始眼底图像进行处理,以筛选出所述初始眼底图像中符合模型训练条件的所述样本眼底图像,包括:
39、对所述初始眼底图像进行预处理,生成预处理眼底图像;
40、调用眼底图像识别模型对所述预处理眼底图像进行处理,得到眼底图像识别结果;
41、根据所述眼底图像识别结果,确定所述初始眼底图像中的标准眼底图像;
42、调用图像质量等级分类模型对所述预处理眼底图像进行处理,得到所述预处理眼底图像属于各个预设质量等级的概率;
43、根据所述概率,从所述标准眼底图像中筛选出符合模型训练条件的样本眼底图像。
44、第二方面,本技术实施例提供了一种模型训练装置,所述装置包括:
45、样本图像获取模块,用于获取用户的符合模型训练条件的样本眼底图像,所述样本眼底图像标注有阿尔兹海默症的真实疾病类别;
46、样本图像输入模块,用于将所述样本眼底图像输入至待训练疾病预测模型,所述待训练疾病预测模型包括:特征筛选及血管分割层和特征融合层;
47、标注特征图获取模块,用于调用所述特征筛选及血管分割层对所述样本眼底图像进行特征筛选及血管分割标注处理,得到注意力特征图和血管标注特征图;
48、预测类别获取模块,用于调用所述特征融合层对所述注意力特征图和所述血管标注特征图进行特征融合处理,得到所述样本眼底图像对应的预测疾病概率;
49、损失值计算模块,用于基于所述真实疾病类别和所述预测疾病概率,计算得到所述待训练疾病预测模型的损失值;
50、预测模型获取模块,用于在所述损失值处于预设范围内的情况下,得到用于预测阿尔兹海默症的疾病预测模型。
51、可选地,所述特征筛选及血管分割层包括:注意力层和血管分割层,
52、所述标注特征图获取模块包括:
53、注意力特征图输出单元,用于调用所述注意力层对所述样本眼底图像进行特征筛选处理,输出注意力特征图;
54、标注特征图输出单元,用于调用所述血管分割层对所述样本眼底图像进行血管分割标注处理,输出血管标注特征图。
55、可选地,所述注意力层包括:n个卷积层、特征图处理层和输出层,n个卷积层的参数不同;
56、所述注意力图输出单元包括:
57、分支特征图生成子单元,用于调用所述n个卷积层分别对所述样本眼底图像进行通道压缩处理,生成n个分支特征图;
58、融合特征图生成子单元,用于调用所述特征图处理层对所述n个分支特征图进行处理,生成融合特征图;
59、注意力特征图输出子单元,用于调用所述输出层对所述融合特征图进行处理,输出所述注意力特征图。
60、可选地,在n=3时,所述n个分支特征图包括:第一分支特征图、第二分支特征图和第三分支特征图,
61、所述融合特征图生成单元包括:
62、转置特征图获取子单元,用于调用所述特征图处理层对所述第一分支特征图进行转置操作,得到第一转置特征图;
63、中间特征图获取子单元,用于将所述第一转置特征图与所述第二分支特征图进行矩阵相乘,得到中间特征图;
64、归一化特征图获取子单元,用于对所述中间特征图进行归一化处理,得到归一化特征图;
65、融合特征图获取子单元,用于对所述归一化特征图和所述第三分支特征图进行矩阵相乘,得到所述融合特征图。
66、可选地,所述标注特征图输出单元包括:
67、分割特征图获取子单元,用于调用所述血管分割层对所述样本眼底图像内的血管进行血管分割标注处理,得到血管分割特征图;
68、标注特征图输出子单元,用于对所述血管分割特征图进行池化处理,输出与所述注意力特征图维度相同的所述血管标注特征图。
69、可选地,所述预测类别获取模块包括:
70、点乘特征图获取单元,用于调用所述特征融合层对所述注意力特征图和所述血管标注特征图进行矩阵点乘处理,得到点乘特征图;
71、目标特征图获取单元,用于对所述点乘特征图与所述注意力特征图进行叉乘处理,得到目标特征图;
72、预测类别确定单元,用于基于所述目标特征图确定所述样本眼底图像对应的预测疾病概率。
73、可选地,所述损失值计算模块包括:
74、交叉熵损失计算单元,用于基于所述真实疾病类别和所述预测疾病概率,计算得到交叉熵损失值;
75、损失值计算单元,用于基于本轮训练的样本数量和所述交叉熵损失值,计算得到所述待训练疾病预测模型的损失值。
76、可选地,所述样本图像获取模块包括:
77、初始图像获取单元,用于获取所述用户的多幅初始眼底图像;
78、样本图像筛选单元,用于调用预设模型对所述初始眼底图像进行处理,以筛选出所述初始眼底图像中符合模型训练条件的所述样本眼底图像,所述样本眼底图像为真实眼底图像类别且图像质量大于质量阈值的眼底图像。
79、可选地,所述样本图像筛选单元包括:
80、预处理图像生成子单元,用于对所述初始眼底图像进行预处理,生成预处理眼底图像;
81、识别结果获取子单元,用于调用眼底图像识别模型对所述预处理眼底图像进行处理,得到眼底图像识别结果;
82、标准图像确定子单元,用于根据所述眼底图像识别结果,确定所述初始眼底图像中的标准眼底图像;
83、质量概率获取子单元,用于调用图像质量等级分类模型对所述预处理眼底图像进行处理,得到所述预处理眼底图像属于各个预设质量等级的概率;
84、样本图像筛选子单元,用于根据所述概率,从所述标准眼底图像中筛选出符合模型训练条件的样本眼底图像。
85、第三方面,本技术实施例提供了一种电子设备,包括:
86、存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现上述任一项所述的模型训练方法。
87、第四方面,本技术实施例提供了一种可读存储介质,当所述存储介质中的指令由电子设备的处理器执行时,使得电子设备能够执行上述任一项所述的模型训练方法。
88、在本技术实施例中,通过获取用户的符合模型训练条件的样本眼底图像,样本眼底图像标注有阿尔兹海默症的真实疾病类别。将样本眼底图像输入至待训练疾病预测模型,待训练疾病预测模型包括:特征筛选及血管分割层和特征融合层。调用特征筛选及血管分割层对样本眼底图像进行特征筛选及血管分割标注处理,得到血管标注特征图。调用特征融合层对注意力特征图和血管标注特征图进行特征融合处理,得到样本眼底图像对应的预测疾病概率。基于真实疾病类别和预测疾病概率,计算得到待训练疾病预测模型的损失值。在损失值处于预设范围内的情况下,得到用于预测阿尔兹海默症的疾病预测模型。本技术实施例通过对眼底图像中的血管进行分割标注,从而可以使模型能够更有效地学习阿尔茨海默症需要的关键临床信息(即血管分割标注特征),通过对眼底图像中的特征进行筛选,能够有效避免错误信息的干扰,弥补了传统方法难以直接从眼底图像发现疾病特征的局限性,可以较大提高对阿尔茨海默症的预测准确率。
89、上述说明仅是本技术技术方案的概述,为了能够更清楚了解本技术的技术手段,而可依照说明书的内容予以实施,并且为了让本技术的上述和其它目的、特征和优点能够更明显易懂,以下特举本技术的具体实施方式。