一种基于模态感知蒸馏网络的肝细胞癌预测方法

文档序号:34014506发布日期:2023-04-29 23:46阅读:89来源:国知局
一种基于模态感知蒸馏网络的肝细胞癌预测方法

本发明涉及生物学,具体涉及一种基于模态感知蒸馏网络的肝细胞癌预测方法。


背景技术:

1、肝细胞癌是指由肝细胞发生的恶性肿瘤,是原发性肝癌常见的病理类型。目前,有采用以下方法进行肝细胞癌的微细血管浸润预测:1、利用极限梯度增强和ct图像的深度学习来预测术前mvi;2、利用3dcnn预测模型,以融合来自多个mr序列的特征;3、将长短期记忆lstm嵌入到cnn中,以融合多模态mr体积来预测hcc患者的mvi;上述这三种方法仅涉及mr图像来预测mvi状态,预测的准确性较低。此外,也有采用知识蒸馏的以下方法进行预测:1、将kd(知识蒸馏)用于从3d光学图像中有效分割神经元结构显微镜图像;2、引用kd的概念,通过扩展掩码边界使用软标签进行脑部损伤分割3、利用kd进行多源传输学习肺模式分析任务;4、制定了一个类别指导的对比蒸馏模块,以在教师和学生模型中从同一类别中拉近正向图像对,同时推开来自不同类别的负向图像对;上述四种方法采用的蒸馏网络只考虑了不同的影像数据,并从输入的影像数据传输信息,分类精度较差,预测的准确性较低。


技术实现思路

1、本发明的目的在于提供一种基于模态感知蒸馏网络的肝细胞癌预测方法,该基于模态感知蒸馏网络的肝细胞癌预测方法通过将具有图像模态和非图像临床数据的教师网络知识迁移到仅具有图像模态的学生网络,提出了用于hccmvi预测的模态感知蒸馏网络(md-net),可有效提高分类精度和预测的准确性。

2、为实现上述目的,本发明采用以下技术方案:

3、一种基于模态感知蒸馏网络的肝细胞癌预测方法,包括以下步骤:

4、s1、获取肝细胞癌患者的数据集,并根据五折交叉验证方案将整个数据集分成五折,在每一轮交叉验证中,将其中一折数据作为测试集,将其他四折数据作为训练集;

5、s2、对数据进行预处理,为所有患者的肿瘤找到最大的外接立方体,再移除除立方体以外的其他非肿瘤区域;

6、s3、建立模态感知蒸馏网络,并对模态感知蒸馏网络进行训练,模态感知蒸馏网络用于将教师网络通过临床数据模态和图像模态融合学习的知识转移到仅具有图像模态的学生网络;

7、s4、通过训练后的模态感知蒸馏网络进行肝细胞癌预测。

8、优选地,步骤s1中所述数据集由270名经病理证实的hcc患者数据组成,270名患者包括128名m0患者、93名m1患者和49名m2患者;其中,m0表示无微血管侵犯,m1表示侵入血管不超过5个或位于肿瘤表面附近1cm以内,m2表示侵入血管超过5个或距离肿瘤表面1cm以上。

9、优选地,步骤s2中将所述立方体的大小设置为80*80*20像素。

10、优选地,步骤s3中所述模态感知蒸馏网络的训练过程具体为:

11、s31、教师网络将hbp图像和临床数据传递到mri-clinicalfusion模块中,提取出512维向量特征;再将pre图像和临床数据输入到另一个mri-clinical fusion模块中,以获得另一个512维向量特征;将得到的两个512维向量特征输入到sa模块中,得到融合了彼此信息的新特征和最后将新特征和拼接起来,生成zt,并将zt传递到两个全连接层中,以预测分类结果pt;

12、s32、学生网络以3dhbpmri图像和3dpremri图像作为输入,将hbp数据传递到mri-only模块以获得特征,再将pre数据传递到另一个mri-only模块中以获得特征;将得到的两个特征输入到sa模块中,得到融合了彼此信息的新特征和其中,新特征和是两个包含512维度的特征向量;最后将新特征和连接起来,生成zs,并将zs输入到全连接层中,以预测mvi分类结果ps;

13、s33、在学生网络中,引入一个回归任务,将输入hbp图像和输入pre图像的连接特征zs输送到两个全连接层中,预测52维向量pc,用于估计潜在的临床信息,再利用输入的临床数据作为预测pc的真实标签;

14、s34、采用分类级蒸馏损失和特征级蒸馏损失,将教师网络临床数据和mri图像中融合的特征蒸馏到从mri图像中提取的特征,利用知识蒸馏策略将教师网络的临床信息转换到学生网络。

15、优选地,步骤s31中所述mri-clinicalfusion模块集成了mri数据和非影像临床数据,以3dmri数据和矢量化临床数据为输入,mri-clinicalfusion模块在输入临床数据上应用四个全连接层,获得四个特征图,这些特征通道分别为64、128、256和256;利用输入mri图像上的四个卷积块来获得另一个3d特征图,并且特征通道也被设置为64、128、256和256,每个卷积块由两个3×3卷积层组成;将临床数据中的四个特征图和mri数据中的相应四个特征做channel-wise相乘,以将它们整合在一起,再应用3×3卷积层和一个全连接层,输出具有512维度的特征向量。

16、优选地,步骤s32中所述mri-only模块从3dmri图像中提取512维度的特征向量;所述mri-only模块由九个卷积块和一个全连接层组成,每个卷积块包含一个批处理规范层、一个relu激活层和一个3×3卷积层,用于提高网络的鲁棒性;将九个卷积块的输出特征的通道数设置为不同,前五层的特征通道为32、32、64、64和128,后四层的特征通道为256、128、256和256,用于平衡效率和计算负担。

17、优选地,步骤s31和s32中所述sa模块为对称注意力模块,x和y表示sa模块的输入两个特征图,sa模块在x上应用线性变换层以获得三个特征图,包括query向量qx,key向量kx和value向量vx;sa模块在y上应用性变换层来生成key特征映射ky和value特征映射vy;通过乘以qx和kx的转置来生成score特征向量sx,通过乘以qy和ky的转置来生成另一个score特征向量sy;将所获得的score特征向量sx与value特征向量vx相乘,并将sy与vy相乘,生成两个结果特征向量,再将其相加,最终生成输出细化特征向量

18、

19、sa模块在y上应用另一个线性变换层,获得特征向量query向量qy,通过将qy和ky的转置相乘,并将qy和ky的转置相乘来计算两个score特征向量,再通过以下公式计算精化特征向量

20、

21、优选地,步骤s3中模态感知蒸馏包括分类级蒸馏和特级蒸馏;

22、在分类级蒸馏中,让表示从学生网络产生的mri影像数据xi所属类的类别概率,而表示从教师网络产生的mri影像数据xi所属类的类别概率;定义分类水平蒸馏损失以使来自教师网络的类别概率成为训练学生网络的目标,利用kullback-leibler散度来测量两个分布的差异:

23、

24、其中,n和m分别表示训练样本的数量和总类别的数量,dkl(·)表示两种概率之间的kullback-leibler散度,表示学生网络预测样本,表示教师网络预测样本;

25、在特征级蒸馏中,将特征级蒸馏损失计算为和之间的kullback-leibler散度与和之间的kulback-leibleer散度的组合:

26、

27、其中,β用于加权kullback-leibler散度项,权重β1=1,表示两个特征和之间的kullback-leibler散度,表示两个特征和之间的kullback-leibler散度,表示学生网络预测样本,表示教师网络预测样本;

28、最终损失函数包括教师网络和学生网络上的两个监督损失,临床数据预测的自我监督损失,以及学生网络和教师网络之间的蒸馏损失,损失函数的定义如下:

29、

30、其中,和分别表示教师网络预测的监督损失和学生网络预测的监督损失,使用focalloss损失计算和的预测损失,lclinical表示临床数据预测的自我监督损失,使用交叉熵损失来计算pc的预测误差和临床数据的基本事实;表示分类级蒸馏损失,表示教师网络和学生网络之间的方程特征级蒸馏损失,利用方程ltotal的损失函数来训练模态感知蒸馏网络以进行miv预测。

31、采用上述技术方案后,本发明具有如下有益效果:

32、1、本发明通过将具有图像模态和非图像临床数据的教师网络知识迁移到仅具有图像模态的学生网络,提出了用于hccmvi预测的模态感知蒸馏网络(md-net),可有效提高分类精度和预测的准确性。

33、2、本发明的模态感知蒸馏网络(md-net)的学生网络包括两个仅用于提取mri特征的mri-only模块,以及一个用于从两个mri图像中细化特征的对称注意力(sa)模块,而教师网络包括两种mri-clinicalfusion融合模块,用于将mri数据和带有52维度向量的临床数据融合,以及一种用于细化两个融合特征的sa模块。

34、3、本发明除了原始分类级别的结果蒸馏,模态感知蒸馏网络(md-net)还设计了一个特征级别的蒸馏,以更好地将临床数据从教师网络传输到学生网络。此外,还设计了一个新的自我监督任务,从图像数据中预测临床数据,以进一步增强mvi预测。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1