一种基于对比测试时间适应的心率失常分类方法

文档序号:33624168发布日期:2023-03-25 14:28阅读:65来源:国知局
一种基于对比测试时间适应的心率失常分类方法

1.本发明属于心率分析领域,具体涉及一种基于对比测试时间适应的心率失常分类方法。


背景技术:

2.当测试数据与训练数据来自相同的分布时,深度神经网络在各种应用中工作得非常好,但是当训练分布和测试分布之间存在分布变化时,性能会急剧下降,在现实世界中,模型的性能的下降可能会使模型分类效果变得很差从而无法继续使用。最近有很多工作致力于训练鲁棒的模型,虽然这是一个可行的研究方向,但它需要修改模型的训练过程。出于隐私/存储方面的考虑,修改模型训练过程可能并不总是可行的,因为训练数据可能不再可用,所有可用的仅仅是预训练的模型。因此,人们对测试时间适应越来越感兴趣,在这种设置中模型可以在测试时间进行自适应,而无需改变训练过程或要求访问原始训练数据。
3.心电图(electrocardiogram,简称ecg)被心脏病专家和医疗从业人员广泛用于诊断和应对心律失常和心肌梗塞等严重心血管综合症,基于心电图的心跳分类可以为心脏病专家提供有关慢性心血管疾病的准确信息。由于心血管疾病是全球主要的死亡来源,所以用于诊断心血管疾病的智能系统非常受欢迎。当前应用于心电图分类的机器学习系统要么依赖于手动提取的特征,要么依赖于直接利用一维心电信号的大型复杂深度神经网络,这些传统方法的缺点是特征提取部分和模式分类部分的分离,且这些方法需要有关输入数据和所选特征的专业知识,需要专家手动提取特征,特征提取的过程不仅耗时,而且成本非常高。
4.目前现有技术还存在以下问题:
5.1)ecg数据类别不平衡。类别不平衡(class-imbalance)是指分类任务中不同类别的训练样例数目差别很大的情况,在ecg数据分类任务中类别不平衡现象尤为显著,因为在心电图相关的数据集中,正常心拍数量很多,而异常心拍数量很少,这导致模型的预测性能较差,特别是针对样本较少类别的预测。
6.2)需要手动提取特征。使用ecg信号进行心跳分类的传统方法主要依赖于使用信号处理技术手动制作或手动提取的特征,例如基于数字滤波器的方法、基于阈值的方法、傅里叶变换和小波变换,这些传统方法的缺点是特征提取部分和模式分类部分的分离。此外,这些方法需要有关输入数据和所选特征的专业知识,手动提取特征的过程需要大量的人力财力,且得到的特征很可能包含一些噪声数据。
7.3)分类效率低。传统的ecg数据分类方法通常需要使用复杂的神经网络才能实现较好的分类效果,但是要训练一个复杂的神经网络通常需要花费很长的时间,在现实的机器感知系统中,通常对时效性有一定要求,本文提出的对比测试时间适应方法基本上能达到实时适应,实时分类的效果,且分类效果与传统方法相当。


技术实现要素:

8.为了解决上述背景技术提到的技术问题,本发明提出了一种。
9.为了实现上述技术目的,本发明的技术方案为:
10.一种基于对比测试时间适应的心率失常分类方法,包括以下步骤:
11.s1、采集心律失常数据并生成数据集,通过数据集自身标注的q峰位置采集心拍;
12.s2、使用smote算法对步骤s1采集的心拍进行类别平衡采样;
13.s3、使用多模态图像融合框架将一维心拍转换成二维图像,划分数据集;
14.s4、搭建训练阶段的卷积神经网络模型;
15.s5、初始化步骤s4搭建的卷积神经网络模型参数;将步骤s3划分的数据集输入到卷积神经网络模型中进行训练;
16.s6、将步骤s5训练好的卷积神经网络模型进行对比测试时间域适应;包括改进伪标签、自监督对比学习以及正则化;经过对比测试时间域适应后,在适应的过程中更新源模型的权重参数,将经过调整后的模型用于最终的心率失常分类任务。
17.优选地,步骤s1具体指:采用mitdb、mitsa心律失常数据库作为数据集,其中mitdb心律失常数据库采样频率为360hz,通过波形数据库工具包将mitsa心律失常数据重采样到360hz,将除正常心拍、室上异位搏动、室性异位搏动、融合拍以及不可分类拍以外的数据删除,根据数据集的注释文件得到q峰所在的位置下标索引,然后以q峰的位置为基准,向左包括150个点,向右包括100个点,然后将这一段数据截取出来作为心拍。
18.优选地,步骤s2具体指:通过在少数类样本与其3个最近邻样本之间线性内插的方法合成新的样本,使用的smote算法将正常心拍、室上异位搏动、室性异位搏动、融合拍以及不可分类拍均采样到20000个样本。
19.优选地,步骤s3中使用多模态图像融合框架将一维心拍转换成二维图像;将一维ecg数据转换成三种不同类型的二维灰度图像,分别是格拉姆角场、递归图和马尔可夫变迁场,然后将这三个灰度图像组合起来形成三通道彩色图像,将mitdb心律失常数据库重新打上标签并打乱重排,将mitdb心律失常数据库分为训练集和测试集,取前90%的数据作为训练集,后10%的数据作为测试集。
20.优选地,步骤s4中训练阶段使用的网络模型是resnet-18,resnet-18中包含17个卷积层和一个全连接层,所以将resnet-18的全连接层输出维度改为5,。
21.优选地,步骤s5中,模型学习率为10^(-4),样本批次大小为256,迭代次数为20,使用adam优化学习率,其中平滑因子β_1=0.6,β_2=0.98,损失函数使用交叉熵。
22.优选地,步骤s6中改进伪标签具体指:使用预训练模型的权重初始化目标模型,使用目标模型为未标记的目标数据生成伪标签,建立一个长度为m的内存队列qw存储弱增强目标样本的特征向量ew和预测概率pw,每当输入一个测试样本x
t
后,使用目标模型得到特征向量ew和预测概率pw对队列进行更新,使用动量更新的方法更新目标模型参数θ,公式表达如下:
23.θ
t+1
=mθ
t+1
+(1-m)θ
t
24.其中,θ
t+1
表示时间步t+1时目标模型的参数,m是动量超参数,本发明中,m=0.2;每当输入一个测试样本x
t
后,得到x
t
的弱增强图像的特征向量e
t
,然后从内存队列qw中取出与e
t
余弦距离最近的10个特征向量ew和预测概率pw,然后求10个概率的平均值最终伪
标签公式表达如下:
[0025][0026]
其中,argmax是pytorch框架中内置的函数,能返回指定维度最大值的序号。
[0027]
优选地,步骤s6中自监督对比学习具体指:将测试图像编码成两个不同强增强视图的特征query(q)和key(k),q,k被视为正样本。维持一个长度为n的负样本的队列q
negtive
存储负样本的特征k-,对比学习损失函数的目标是最小化q,k之间的余弦距离,同时最大化q和队列q
negtive
中每一个负样本特征k-之间的余弦距离,对比损失函数公式表达如下:
[0028][0029]
其中,n表示队列q
negtive
的长度,q表示特征query,k
+
表示特征key,ki表示队列中的第i个负样本,τ是温度系数超参数,τ=0.1。
[0030]
优选地,步骤s6中采用了一致性正则化维持对图像的强增强图像和弱增强图像之间预测的一致性,公式表达如下:
[0031][0032]
其中,c表示分类任务的类别数,c=5,pw表示弱增强图像通过模型输出的预测概率,qs表示强增强图像通过模型输出的预测概率
[0033]
优选地,步骤s6采用了多样性正则化减少模型在适应过程中盲目相信错误标签导致模型出现错误积累,公式表达如下:
[0034][0035]
其中,c表示分类任务的类别数,c=5,ps表示强增强图像通过模型输出的预测概率;
[0036]
最终损失函数表达式如下所示:
[0037]
l=αl
cont
+βl
con
+γl
div
[0038]
其中,α,β,γ为控制各损失之前的权重值,α=0.9,β=0.7,γ=0.5。
[0039]
采用上述技术方案带来的有益效果:
[0040]
(1)本发明提出的基于多模态图像融合与对比测试时间适应的心率失常分类方法可以解决由于ecg数据集类别不平衡而导致模型分类性能差,通过使用smote算法将类别数目都采样到相同的数量,从而解决类别不平衡对模型性能的影响。
[0041]
(2)本发明提出的基于多模态图像融合的心率失常分类方法不需要手动提取特征,使用多模态图像融合框架将一维心电数据转换成二维图像,而多模态融合可以通过高效的深度神经网络准确的提取出特征。
[0042]
(3)本发明提出的对比测试时间适应方法效率很高,不需要训练复杂的神经网络,对比测试时间适应的过程基本上可以实现实时适应,实时分类,且分类效果很好。
附图说明
[0043]
图1是基于多模态图像融合与对比测试时间适应的心率失常分类方法的工作流程图;
[0044]
图2是训练过程中模型损失值变化图;
[0045]
图3是对比测试时间域适应过程图。
具体实施方式
[0046]
以下将结合附图,对本发明的技术方案进行详细说明。
[0047]
本发明公开了基于多模态图像融合与对比测试时间适应的心率失常分类方法,如图1所示,具体流程如下:
[0048]
步骤1:数据预处理
[0049]
1.1通过数据集自身标注的q峰位置采集心拍。本次实验采用mitdb、mitsa心律失常数据库作为数据集,其中mitdb心律失常数据库采样频率为360hz,因此,mitdb心律失常数据库可直接使用。而mitsa心律失常数据库的采样频率为128hz,需要通过wfdb(波形数据库)工具包进行重采样到360hz,保证心电图样本集中的ecg数据的采样频率均为360hz。按照美国医疗仪器促进协会提出的标准,所有的心拍可以被分为五大类:正常心拍(n)、室上异位搏动(s)、室性异位搏动(v)、融合拍(f)以及不可分类拍(q),将五大类标签编码为0、1、2、3、4,将不在这五大类中的其他数据删除。根据数据集的注释文件得到q峰所在的位置下标索引,然后以q峰的位置为基准,向左包括150个点,向右包括100个点,然后将这一段数据截取出来作为心拍。
[0050]
1.2使用smote算法进行类别平衡采样。在mitdb心律失常数据库中,n、s、v、f、q五个类样本的数目分别为72471、2223、5788、641、6431,存在严重的类别不平衡现象,若不使用smote算法进行类别平衡采样,最后训练得到的模型性能非常差。smote算法通过在少数类样本与其k个(本发明中k=3)最近邻样本之间线性内插的方法合成新的样本,在本发明中,使用smote算法将五大类均采样到20000个样本,从而实现类平衡。
[0051]
1.3使用多模态图像融合框架将一维心拍转换成二维图像。将一维ecg数据转换成三种不同类型的二维灰度图像,分别是格拉姆角场(gramian angular field gaf)、递归图(recurrence plot rp)和马尔可夫变迁场(markov transition field mtf),然后将这三个灰度图像组合起来形成三通道彩色图像(gaf-rp-mtf),三个灰度图像都是由原始ecg数据通过不同的统计方法形成的,保持了信号的时间依赖性且不会丢失一维信号的任何信息,因此得到的三通道彩色图像含有更多的信息量,且三通道图像可以很容易地与alexnet、resnet等现成的卷积神经网络一起使用。
[0052]
1.4划分数据集。为mitdb心律失常数据库重新打上标签并打乱重排,将mitdb心律失常数据库分为训练集和测试集,取前90%的数据作为训练集,后10%的数据作为测试集。
[0053]
步骤2:搭建训练阶段的卷积神经网络
[0054]
2.1搭建训练阶段的卷积神经网络。训练阶段使用的网络模型是resnet-18,resnet-18中包含17个卷积层和一个全连接层,由于本发明是一个五分类任务,所以将resnet-18的全连接层输出维度改为5,而卷积层不需要更改。
[0055]
步骤3:预训练卷积神经网络
[0056]
3.1初始化模型参数。设置模型学习率为10-4
,样本批次大小为256,迭代次数为20。使用adam优化学习率,其中平滑因子β1=0.6,β2=0.98,损失函数使用交叉熵。
[0057]
3.2将预处理后的mitdb数据集输入到模型中进行预训练。将预处理后的mitdb数据集输入到模型中,首先初始化模型参数,然后开始训练模型,损失函数的值越来越小直到收敛时,此时初始化的参数会随着发生变化,保存变化后的参数,将经过预训练后的卷积神经网络作为源模型。训练过程的模型损失值变化如图2所示,预训练后的卷积神经网络对mitdb数据集分类准确率为99.5%。
[0058]
3.3保存模型。将训练好的模型保存起来,在进行ecg分类任务时,直接将训练好的卷积神经网络参数作为ecg心率失常分类任务的初始化参数。
[0059]
步骤4:对比测试时间域适应
[0060]
4.1改进伪标签。本发明提出的对比测试时间域适应方法是基于无监督的,所以伪标签的质量决定了模型最终的性能,为了得到准确的伪标签,本发明使用了基于内存队列的最近邻投票的伪标签改进方案,在适应过程中的每一轮在线改进伪标签。首先,使用预训练模型的权重初始化目标模型,使用目标模型为未标记的目标数据生成伪标签,然后,为了实现最近邻搜索,维护了一个长度为m的内存队列qw,用于存储弱增强目标样本的特征向量ew和预测概率pw,每当输入一个测试样本x
t
后,使用目标模型得到特征向量ew和预测概率pw对队列进行更新,为了保证特征空间更加稳定,使用动量更新的方法更新目标模型参数θ,更新规则如等式(1)。每当输入一个测试样本x
t
后,得到x
t
的弱增强图像的特征向量e
t
,然后从内存队列qw中取出与e
t
余弦距离最近的k个(本发明中k=10)特征向量ew和预测概率pw,然后求10个概率的平均值最终伪标签可由等式(2)求得:
[0061]
θ
t+1
=mθ
t+1
+(1-m)θ
t
ꢀꢀꢀꢀꢀ
(1)
[0062]
其中,θ
t+1
表示时间步t+1时目标模型的参数,m是动量超参数,本发明中,m=0.2。
[0063][0064]
其中,argmax是pytorch框架中内置的函数,能返回指定维度最大值的序号。
[0065]
4.2自监督对比学习。本发明中对比学习使用的代理任务为个体判别(instance-discrimination),通过数据增强获得不同的图像视图,同一图像的不同视图的特征被视为正对拉近,而不同图像的特征被视为负对推开。将测试图像编码成两个不同强增强视图的特征query(q)和key(k),q,k被视为正样本。维持一个长度为n的负样本的队列q
negtive
存储负样本的特征k-,对比学习损失函数的目标是最小化q,k之间的余弦距离,同时最大化q和队列q
negtive
中每一个负样本特征k
_
之间的余弦距离。对比损失函数如等式(3):
[0066][0067]
其中,n表示队列q
negtive
的长度,q表示特征query,k
+
表示特征key,ki表示队列中的第i个负样本,τ是温度系数超参数,本发明中τ=0.1。
[0068]
4.3正则化。本发明中使用了两种正则化方法,分别是一致性正则化和多样性正则化。一致性正则化的目的是使得模型维持对图像的强增强图像和弱增强图像之间预测的一致性,从而提高模型的鲁棒性。一致性正则化表达式如等式(4):
[0069][0070]
其中,c表示分类任务的类别数,本发明中c=5,pw表示弱增强图像通过模型输出
的预测概率,qs表示强增强图像通过模型输出的预测概率。
[0071]
多样性正则化的目的是进一步提高伪标签的质量,虽然在步骤4.1中改进伪标签有效地减少了域移位带来的伪标签中的噪声,但它们仍然不适合作为真正的标签。为了防止模型在适应过程中盲目相信错误标签导致模型出现错误积累,在损失函数中使用正则化项来鼓励类多样化。多样性正则化表达式如等式(5):
[0072][0073]
其中,c表示分类任务的类别数,本发明中c=5,ps表示强增强图像通过模型输出的预测概率。
[0074]
最终损失函数表达式如等式(6):
[0075]
l=αl
cont
+βl
con
+γl
div
ꢀꢀꢀꢀꢀꢀꢀꢀꢀ
(6)
[0076]
其中,α,β,γ为控制各损失之前的权重值,本发明中α=0.9,β=0.7,γ=0.5。
[0077]
4.4心率失常分类任务。经过对比测试时间域适应后,源模型不断适应目标数据,在适应的过程中更新源模型的权重参数,将经过调整后的模型用于最终的心率失常分类任务。
[0078]
本领域内的技术人员应明白,本技术的实施例可提供为方法、系统、或计算机程序产品。因此,本技术可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本技术可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、cd-rom、光学存储器等)上实施的计算机程序产品的形式。本技术实施例中的方案可以采用各种计算机语言实现,例如,面向对象的程序设计语言java和直译式脚本语言javascript等。
[0079]
本技术是参照根据本技术实施例的方法、设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
[0080]
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
[0081]
这些计算机程序指令也可装载到计算机或其他可编程数据处理设备上,使得在计算机或其他可编程设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
[0082]
尽管已描述了本技术的优选实施例,但本领域内的技术人员一旦得知了基本创造性概念,则可对这些实施例作出另外的变更和修改。所以,所附权利要求意欲解释为包括优选实施例以及落入本技术范围的所有变更和修改。
[0083]
显然,本领域的技术人员可以对本技术进行各种改动和变型而不脱离本技术的精神和范围。这样,倘若本技术的这些修改和变型属于本技术权利要求及其等同技术的范围
之内,则本技术也意图包含这些改动和变型在内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1