基于多尺度知识蒸馏的结核性脑膜脑炎预测方法及系统

文档序号:34923518发布日期:2023-07-28 03:40阅读:14来源:国知局
基于多尺度知识蒸馏的结核性脑膜脑炎预测方法及系统与流程

本发明涉及人工智能与智慧医疗,尤其是涉及一种基于多尺度知识蒸馏的结核性脑膜脑炎预测方法及系统。


背景技术:

1、在结核性脑膜脑炎的影像诊断过程中,影像科医生需要通过观察头部各个剖面的所有mri图像,在不同剖面对应位置上同时定位到病灶才能进行诊断。这种诊断方式不仅效率低,诊断的错误率也比较高。因此,如何帮助医生实现高效、准确的进行影像诊断是亟待解决的难题。

2、而在目前的深度学习图像分类模型中,大多数模型都是利用网络结构的最后一层特征图进行分类预测。由于网络的不同特征图上学习到原图信息范围的不同,越靠近顶部的特征图保留着原图更多的细节信息,越靠近底部的特征图保留的是原图整体的语义信息,因此细节信息在底层特征图有所丢失,这就造成了分类的不准确。


技术实现思路

1、本发明旨在至少解决现有技术中存在的技术问题之一。为此,本发明提出一种基于多尺度知识蒸馏的结核性脑膜脑炎预测方法及系统,能够提高结核性脑膜脑炎预测的准确度。

2、第一方面,本发明实施例提供了一种基于多尺度知识蒸馏的结核性脑膜脑炎预测方法,所述基于多尺度知识蒸馏的结核性脑膜脑炎预测方法包括:

3、对输入的结核性脑膜脑炎图像进行多尺度拆分,获得多尺度特征图;

4、根据transformer结构构建特征编码器,并将所述多尺度特征图输入至所述特征编码器中,获得每种尺度特征图的单维特征向量;

5、采用结核性脑膜脑炎图像数据集对教师模型进行预训练,获得所述教师模型的权重,并根据所述教师模型的权重,获得多尺度教师预测结果;

6、将所述多尺度教师预测结果和通过学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第一损失,并将真实标签和通过所述学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第二损失;

7、将所述第一损失和所述第二损失相加,获得总损失;

8、将所述每种尺度特征图的单维特征向量进行拼接,获得拼接向量;

9、根据所述拼接向量和所述总损失,通过所述学生模型获得结核性脑膜脑炎预测结果。

10、与现有技术相比,本发明第一方面具有以下有益效果:

11、本方法通过对输入的结核性脑膜脑炎图像进行多尺度拆分,获得多尺度特征图;根据transformer结构构建特征编码器,并将多尺度特征图输入至特征编码器中,获得每种尺度特征图的单维特征向量,提取多种尺度特征信息,为后面学生模型进行分类预测提高准确度;采用结核性脑膜脑炎图像数据集对教师模型进行预训练,获得教师模型的权重,并根据教师模型的权重,获得多尺度教师预测结果,将多尺度教师预测结果和通过学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第一损失,并将真实标签和通过学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第二损失,将第一损失和第二损失相加,获得总损失,通过将复杂模型(即教师模型)学到的知识迁移到轻量级模型(即学生模型)上,对轻量级模型进行了多尺度损失计算,克服了轻量级模型因数据不足带来的影响;将每种尺度特征图的单维特征向量进行拼接,获得拼接向量,根据拼接向量和总损失,通过学生模型获得结核性脑膜脑炎预测结果,通过将每种尺度特征图的单维特征向量进行拼接,使得顶层的细节信息和底层抽象的语义信息结合起来,采用拼接向量和总损失进行预测,能够获得很好的预测效果。因此,本方法能够辅助医生判断结核性脑膜脑炎,提高结核性脑膜脑炎判断的准确度,减少判断失误。

12、根据本发明的一些实施例,将resnet50网络作为学生模型,所述对输入的结核性脑膜脑炎图像进行多尺度拆分,获得多尺度特征图,包括:

13、将所述结核性脑膜脑炎图像输入所述学生模型进行多尺度拆分;

14、提取所述学生模型中第二层layer2、第三层layer3和第四层layer4的特征图,获得多尺度特征图。

15、根据本发明的一些实施例,所述将所述多尺度特征图输入至所述特征编码器中,获得每种尺度特征图的单维特征向量,包括:

16、通过所述特征编码器中的卷积层和展平层将所述多尺度特征图转换为多个特征向量;

17、采用注意力机制计算每种特征向量与其他特征向量之间的相关性;

18、根据所述相关性,对每种尺度特征图对应的特征向量通过全局求平均的方式获得所述每种尺度特征图的单维特征向量。

19、根据本发明的一些实施例,通过如下方式获得所述注意力机制:

20、multihead(q,k,v)=concat(head1(q1,k1,v1),…,headh(qh,kh,vh))wo

21、headi(qi,ki,vi)=attention(qi,ki,vi)

22、

23、其中,q、k、v表示一组可学习的权重矩阵,qi,ki,vi表示q、k、v权重矩阵相应的分量,wo表示一个可学习的参数,表示比例因子。

24、根据本发明的一些实施例,采用以注意力机制搭建的swin-transfomer网络模型作为教师模型,所述采用结核性脑膜脑炎图像数据集对教师模型进行预训练,获得所述教师模型的权重,并根据所述教师模型的权重,获得多尺度教师预测结果,包括:

25、采用结核性脑膜脑炎图像数据集对所述教师模型进行预训练,获得所述教师模型的权重;

26、输入结核性脑膜脑炎图像至所述教师模型,获取所述教师模型中第二层layer2、第三层layer3和第四层layer4输出的特征向量;

27、将所述教师模型中第二层layer2、第三层layer3和第四层layer4输出的特征向量通过所述教师模型的权重进行分类预测,获得多尺度教师预测结果。

28、根据本发明的一些实施例,所述将所述多尺度教师预测结果和通过学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第一损失,并将真实标签和通过所述学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第二损失,包括:

29、将所述教师模型中第二层layer2、第三层layer3和第四层layer4对应的所述多尺度教师预测结果分别和所述学生模型中第二层layer2、第三层layer3和第四层layer4对应的多尺度学生预测结果进行交叉熵损失计算,获得第一损失;

30、将所述真实标签和所述学生模型中第二层layer2、第三层layer3和第四层layer4对应的多尺度学生预测结果进行交叉熵损失计算,获得第二损失。

31、根据本发明的一些实施例,通过如下方式计算获得总损失:

32、ζ(x,w)=α*ce(y,σ(zs;t=1))+β*ce(σ(zt;t=τ),σ(zs,t=τ))

33、

34、其中,x表示输入,w表示学生模型的参数,y表示真实标签,ce表示交叉熵损失函数,σ表示softmax temperature激活函数,zs表示学生模型神经元的输出,zt表示教师模型神经元的输出,α和β表示两个权重参数,t表示温度系数,τ表示具体的值,zi表示神经元。

35、第二方面,本发明实施例还提供了一种基于多尺度知识蒸馏的结核性脑膜脑炎预测系统,所述基于多尺度知识蒸馏的结核性脑膜脑炎预测系统包括:

36、多尺度特征图获取单元,用于对输入的结核性脑膜脑炎图像进行多尺度拆分,获得多尺度特征图;

37、单维特征向量获取单元,用于根据transformer结构构建特征编码器,并将所述多尺度特征图输入至所述特征编码器中,获得每种尺度特征图的单维特征向量;

38、教师模型预训练单元,用于采用结核性脑膜脑炎图像数据集对教师模型进行预训练,获得离线蒸馏教师模型,并根据所述离线蒸馏教师模型,获得多尺度教师预测结果;

39、交叉熵损失计算单元,用于将所述多尺度教师预测结果和通过学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第一损失,并将真实标签和通过所述学生模型获得的多尺度学生预测结果进行交叉熵损失计算,获得第二损失;

40、总损失计算单元,用于将所述第一损失和所述第二损失相加,获得总损失;

41、拼接向量获取单元,用于将所述每种尺度特征图的单维特征向量进行拼接,获得拼接向量;

42、预测结果获取单元,用于根据所述拼接向量和所述总损失,通过所述学生模型获得结核性脑膜脑炎预测结果。

43、第三方面,本发明实施例还提供了一种基于多尺度知识蒸馏的结核性脑膜脑炎预测设备,包括至少一个控制处理器和用于与所述至少一个控制处理器通信连接的存储器;所述存储器存储有可被所述至少一个控制处理器执行的指令,所述指令被所述至少一个控制处理器执行,以使所述至少一个控制处理器能够执行如上所述的一种基于多尺度知识蒸馏的结核性脑膜脑炎预测方法。

44、第四方面,本发明实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可执行指令,所述计算机可执行指令用于使计算机执行如上所述的一种基于多尺度知识蒸馏的结核性脑膜脑炎预测方法。

45、可以理解的是,上述第二方面至第四方面与相关技术相比存在的有益效果与上述第一方面与相关技术相比存在的有益效果相同,可以参见上述第一方面中的相关描述,在此不再赘述。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1