深度学习网络的训练方法、装置、设备和存储介质与流程

文档序号:18887630发布日期:2019-10-15 21:11阅读:222来源:国知局
深度学习网络的训练方法、装置、设备和存储介质与流程

本发明涉及计算机技术领域,尤其涉及一种深度学习网络的训练方法、装置、设备和存储介质。



背景技术:

随着科技的不断进步,深度学习网络逐渐兴起。在图像处理领域中,该深度学习网络作为一种运算模型,可以对图像中的特征进行提取,例如:深度学习网络可以用于人体关键点识别,人形进行分割等。

目前,在使用深度学习网络之前,需要对深度学习网络进行训练,在训练过程中,通过对深度学习网络进行不断的调整,使得深度学习网络能够预测出准确的图像信息。但是,在训练深度学习网络时,由于深度学习网络的复杂性,导致训练过程中容易出现图像陷入局部最优(鞍点)的情况发生,也即是说,深度学习网络处理的图像的损失梯度不明显,损失函数对像素全零图像的惩罚较小,这样即便深度学习网络输出像素全零的图像,得到的图像损失依旧很小。如果训练僵持在局部最优,针对损失梯度不明显的图像,深度学习网络将会一直输出像素全零的图像,无法达到全局最优,即无法输出准确的图像信息。

例如:如图1所示,为人体关键点的真值热度图,图1中心的白点的像素值为1,白点周围的灰点的像素值处于0~1之间,其余黑点的像素值为0。热度图陷于鞍点的情况即是:深度学习网络输出的预测热度图为像素值全0的全黑的热度图,这是由于真值热度图的损失梯度不明显,真值热度图和深度学习网络输出的预测热度图之间的损失很小,而现有的损失函数对全零热度图的惩罚较小,无法得到正确的预测热度图,这样深度学习网络在训练时很容易陷在全零的局部最优。



技术实现要素:

本发明的主要目的在于提供一种深度学习网络的训练方法、装置、设备和存储介质,以解决现有深度学习网络在训练时很容易陷在全零的局部最优的问题。

针对上述技术问题,本发明是通过以下技术方案来解决的:

本发明提供了一种深度学习网络的训练方法,包括:获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像;计算所述真值图像和所述预测图像之间的像素均值误差;利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差;根据所述图像基本损失,确定图像综合损失,所述图像综合损失用于对所述深度学习网络执行训练。

其中,所述根据所述图像基本损失,确定图像综合损失,包括:将所述图像基本损失,确定为所述图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算所述图像基本损失和所述图像补充损失的加权和,将所述加权和确定为所述图像综合损失。

其中,所述真值图像为真值热度图;所述预测图像为预测热度图;所述获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像,包括:获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

其中,如果获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;计算所有像素均值误差的均值,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

其中,如果获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;分别计算每个所述图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值;对每个所述图像对应的像素均值误差的均值进行平均值计算,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

其中,利用预设的指数损失函数,确定图像基本损失,包括:

其中,hm_emvd_loss为图像基本损失,diff为像素均值误差,e为自然常数,α为预设参数。

本发明提供了一种深度学习网络的训练装置,获取模块,用于获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像;计算模块,用于计算所述真值图像和所述预测图像之间的像素均值误差;第一确定模块,用于利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差;第二确定模块,用于根据所述图像基本损失,确定图像综合损失,所述图像综合损失用于对所述深度学习网络执行训练。

其中,所述第一确定模块,进一步用于:将所述图像基本损失,确定为所述图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算所述图像基本损失和所述图像补充损失的加权和,将所述加权和确定为所述图像综合损失。

其中,所述真值图像为真值热度图;所述预测图像为预测热度图;所述获取模块,进一步用于:获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

其中,如果所述获取模块获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则所述计算模块,进一步用于:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;计算所有像素均值误差的均值,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

其中,如果所述获取模块获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则所述计算模块,进一步用于:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;分别计算每个所述图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值;对每个所述图像对应的像素均值误差的均值进行平均值计算,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

其中,所述第一确定模块具体用于执行以下计算:

其中,hm_emvd_loss为图像基本损失,diff为像素均值误差,e为自然常数,α为预设参数。

本发明提供一种深度学习网络的训练设备,所述深度学习网络的训练设备包括处理器、存储器;所述处理器用于执行所述存储器中存储的深度学习网络的训练程序,以实现上述的深度学习网络的训练方法。

本发明提供一种存储介质,所述存储介质存储有一个或者多个程序,所述一个或者多个程序可被一个或者多个处理器执行,以实现上述的深度学习网络的训练方法。

本发明有益效果如下:

本发明计算真值图像和预测图像之间的像素均值误差;并且利用预设的指数损失函数,确定图像基本损失,该方式使得真值图像和预测图像之间的差异更加明显,使深度学习网络输出全零图像的损失较大,即增强了对全零图像的惩罚,不容易陷入局部最优,避免出现鞍点的问题。

附图说明

此处所说明的附图用来提供对本发明的进一步理解,构成本申请的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:

图1是现有技术的人体关键点的真值热度图;

图2是根据本发明一实施例的深度学习网络的训练方法的流程图;

图3是根据本发明一实施例的图像基本损失的确定步骤流程图;

图4是根据本发明一实施例的确定图像综合损失的步骤流程图;

图5是根据本发明一实施例的确定图像综合损失的架构示意图;

图6是根据本发明另一实施例的图像基本损失的确定步骤流程图;

图7是根据本发明另一实施例的确定图像综合损失的步骤流程图;

图8是根据本发明一实施例的深度学习网络的训练装置的结构图;

图9是根据本发明一实施例的深度学习网络的训练设备的结构图。

具体实施方式

为使本发明的目的、技术方案和优点更加清楚,以下结合附图及具体实施例,对本发明作进一步地详细说明。

本实施例提供一种深度学习忘了的训练方法。如图2所示,为根据本发明一实施例的深度学习网络的训练方法的流程图。

步骤s210,获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像。

预测图像是深度学习网络预测的分析结果。

真值图像为正确的分析结果。

其中,一个预测图像对应一个真值图像,利用真值图像来衡量预测图像的准确度,确定深度学习网络是否需要继续训练。

进一步地,预先采集样本图像,人工对样本图像进行样本分析,并对分析结果进行标注,得到真值图像。将样本图像输入深度学习网络,训练深度学习网络进行样本分析,输出的分析结果为预测图像。例如:人工对样本图像进行人体关键点分析,获得的分析结果为可以从样本图像中提取出的人体关键点图像,对该人体关键点图像中的人体关键点进行像素标注,得到人体关键点的真值图像;将样本图像输入深度学习网络,训练深度学习网络输出人体关键点图像,深度学习网络输出的人体关键点图像即是人体关键点的预测图像;预测图像的准确度一般低于真值图像,训练深度学习网络目的在于提高预测图像的准确度。

在本实施例中,真值图像可以为真值热度图,预测图像可以为预测热度图。热度图可以是人体关键点的热度图。

进一步地,可以获取同一个图像对应的多个真值热度图以及深度学习网络输出的与每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的与每个真值热度图对应的预测热度图。

步骤s220,计算所述真值图像和所述预测图像之间的像素均值误差。

像素均值误差是指:真值图像的像素均值和对应预测图像的像素均值的误差。像素均值是指,对图像中所有像素点求和取平均值。像素均值的误差是指,真值图像的像素均值和对应预测图像的像素均值的差值的平方。

像素均值误差的计算步骤,将在后面的实施例中进行描述。

步骤s230,利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差。

指数损失函数可以根据需求进行设置。例如:指数损失函数,可以是:

其中,hm_emvd_loss为图像基本损失,diff为像素均值误差,e为自然常数,α为预设参数。α可以根据具体需求设置和调整,例如:根据不同的多任务学习种类进行不同的设置。

相互对应的真值图像和预测图像作为一组真值图像和预测图像,如果获取了多组真值图像和预测图像,则需要计算每组真值图像和预测图像的像素均值误差,并且计算出多个像素均值误差的平均值,将平均的像素均值误差代入指数损失函数的指数幂,进一步地,将该平均的像素均值误差作为diff。

在本实施例中,可以将图像基本损失命名为emvd(exponentialmeanvaluediffloss,指数均值差异损失)损失。emvd损失的计算函数可以命名为emvd损失函数。

步骤s240,根据所述图像基本损失,确定图像综合损失,以便根据所述图像综合损失对所述深度学习网络执行训练。

图像综合损失,为预测图像的总损失。

在本实施例中,可以使用emvd损失函数,确定图像综合损失,或者使用emvd损失函数以及预设的损失函数,确定图像综合损失。也即是说,可以将图像基本损失,直接确定为图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算图像基本损失和图像补充损失的加权和,将该加权和确定为图像综合损失。进一步地,预设的损失函数的数量可以是一个或者多个。预设的损失函数例如是l1损失函数,l2损失函数。

在本实施例中,本发明计算真值图像和预测图像之间的像素均值误差;并且根据该像素均值误差,确定图像基本损失,该方式使得真值图像和预测图像之间的差异更加明显,使深度学习网络输出全零图像的损失较大,即增强了对全零图像的惩罚,不容易陷入局部最优,避免出现鞍点的问题。

本实施例可以应用在多任务学习的深度学习网络中,增强对全零图像的惩罚,emvd损失函数连续可导,使损失函数的导数不那么平坦,不容易陷入局部最优。

下面给出一种较为具体的确定图像综合损失的方式。

本实施例针对单样本(单个图像)的情况进行描述,也就是说,获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,利用多个真值热度图以及多个预测热度图确定图像综合损失,以便根据图像综合损失对深度学习网络执行训练。

图3为根据本发明一实施例的图像基本损失的确定步骤流程图。

步骤s310,获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

获取图像对应的k(k≥1)个真值热度图,以及每个真值热度图对应的预测热度图,得到同一图像对应的k个真值热度图和k个预测热度图,即k组真值热度图和预测热度图。

步骤s320,对每个真值热度图中的所有像素点执行求和取平均,得到每个真值热度图的像素均值,形成包括所有真值热度图的像素均值的矩阵。

计算所有真值热度图的像素均值组成的矩阵。具体的,计算公式可以为:

sumtrue=reduce_mean(htrue,axis=(0,1));

其中,sumtrue为各个真值热度图的像素均值组成的矩阵。htrue表示多个真值热度图叠加在一起组成的真值热度图组,axis=(0,1)表示对真值热度图组中宽度轴(0轴)和高度轴(1轴)组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。reduce_mean用于计算平均值。

sumtrue的维度为1×k,即1行k列的矩阵,sumtrue中的每个元素为一个真值热度图的像素均值,每个真值热度图的像素均值是对该真值热度图中的所有像素点进行求和取平均。

htrue的维度为w×h×k,w(0轴)表示真值热度图的宽度,h(1轴)表示真值热度图的高度,k(2轴)表示真值热度图的数量。0轴和1轴组成的平面为真值热度图的平面,也即是说,htrue是k个w×h叠加在一起。

步骤s330,对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值,得到包括所有预测热度图的像素均值的矩阵。

计算所有预测热度图的像素均值组成的矩阵。具体的,计算公式可以为:

sumpred=reduce_mean(hpred,axis=(0,1));

其中,sumpred为各个预测热度图的像素均值组成的矩阵。hpred表示多个预测热度图叠加在一起组成的预测热度图组,axis=(0,1)表示对预测热度图组中宽度轴(0轴)和高度轴(1轴)组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。

sumpred的维度为1×k,即1行k列的矩阵,sumpred中的每个元素为一个预测热度图的像素均值,每个预测热度图的像素均值是对该预测热度图中的所有像素点进行求和取平均。

hpred的维度为w×h×k,w(0轴)表示预测热度图的宽度,h(1轴)表示预测热度图的高度,k表示预测热度图的数量。0轴和1轴组成的平面为预测热度图的平面,也即是说,hpred是k个w×h叠加在一起。相互对应的真值热度图和预测热度图,w相等且h相等。

步骤s340,根据每个真值热度图的像素均值以及每个预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差。

在获取的多个真值热度图和多个预测热度图中,根据每个真值热度图的像素均值以及每个预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差,形成像素均值误差组成的矩阵。

具体的,计算公式可以是:

diffs=(sumtrue-sumpred)2

其中,diffs为相互对应的真值热度图和预测热度图之间的像素均值误差组成的矩阵。

diffs的维度为1×k,diffs中的每个元素为sumtrue中的一个真值热度图的像素均值与sumpred中对应的预测热度图的像素均值的差值取平方,即diffs中的每个元素为sumtrue中的一个真值热度图的像素均值与sumpred中对应的预测热度图的像素均值误差。

步骤s350,计算所有像素均值误差的均值,得到各个真值热度图和对应的预测热度图之间的平均的像素均值误差。

计算所有像素均值误差的均值,得到该图像对应的像素均值误差,即:每个真值热度图和对应的预测热度图之间的平均的像素均值误差。

具体的,计算公式可以是:

diff=reduce_mean(diffs);

其中,diff为所有像素均值误差的平均值,即各个真值热度图和对应的预测热度图之间的平均的像素均值误差,reduce_mean(diffs)为对diffs中所有元素(所有像素均值误差)求和取平均。

步骤s360,利用预设的指数损失函数,确定图像基本损失;指数损失函数的指数幂包括:(图像对应的)真值热度图和预测热度图之间的像素均值误差。

具体的,计算公式可以是:

其中,hm_emvd_loss表示图像基本损失,e为自然常数,α为预设参数。α可以根据具体需求设置和调整,例如:根据不同的多任务学习种类进行不同的设置。

在本实施例中,上述确定图像基本损失的过程,可以称为emvd函数的执行过程。

在本实施例中,可以直接将图像基本损失确定为图像综合损失。当然,也可以先确定图像补充损失,根据图像基本损失和图像补充损失,确定图像综合损失。

图4为根据本发明一实施例的确定图像综合损失的步骤流程图。图5为根据本发明一实施例的确定图像综合损失的架构示意图。

步骤s410,针对真值热度图及其对应的预测热度图,利用预设的损失函数,计算图像补充损失。

该预设的损失函数,例如是:l1损失函数,l2损失函数。

以l2损失函数为例,需要执行以下计算步骤:

步骤s1,确定l1=(htrue-hpred)2

其中,l1为真值热度图和对应的预测热度图的像素差值的平方组成的矩阵,l1的维度为w×h×k。w为宽度轴(0轴),h为高度轴(1轴),0轴和1轴组成的平面为真值热度图和对应的预测热度图的像素差值的平方组成的平面。真值热度图和预测热度图的数量相等,都为k个,那么l1为k个w×h叠加在一起。

具体的,获取相互对应的真值热度图和预测热度图,对该真值热度图和该预测热度图的对应像素点做差取平方。

步骤s2,确定l2=reduce_mean(l1,axis=(0,1))。

其中,l2表示对l1中宽度轴(0轴)和高度轴(1轴)组成的各个平面中所有像素点的像素求和取平均,axis=(0,1)表示对l1中0轴和1轴组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。

l2的维度为1×k,l2中的每个元素可以表示一个真值热度图和对应的预测热度图之间的误差。

步骤s3,计算hm_l2_loss=reduce_mean(l2)。

其中,hm_l2_loss为l2损失,表示各个真值热度图和对应的预测热度图之间的误差的均值。

步骤s420,针对真值热度图及其对应的预测热度图,利用emvd损失函数,计算图像基本损失。

利用emvd损失函数,计算图像基本损失的步骤已经在上面进行了描述,在此不做赘述。

步骤s430,计算图像基本损失和图像补充损失的加权和,将该加权和确定为图像综合损失。

具体的,计算公式可以为:

hm_loss=w1×hm_l2_loss+w2×hm_emvd_loss;

其中,hm_loss表示图像综合损失,hm_l2_loss为l2损失(图像补充损失),hm_emvd_loss为图像基本损失,w1为第一权重,w2为第二权重。第一权重和第二权重可以根据经验进行设置,也可以在训练深度学习网络的过程中不断调整。

在获得图像综合损失之后,可以根据该图像综合损失来调整深度学习网络,以便使深度学习网络输出更为准确的预测热度图。

本实施例采用图像基本损失和图像补充损失确定图像综合损失,通过使用两种损失函数,进一步增加了真值图像和预测图像之间差异,增强了深度学习网络对全零热度图的惩罚,使陷入局部最优情况下的深度学习网络,输出全零热度图的损失较大,在反向传播过程中增大影响,有利于深度学习网络跳出局部最优,从而达到全局最优的训练效果。

下面给出另一种较为具体的确定图像综合损失的方式。

本实施例针对批量样本(多个图像)的情况进行描述,也就是说,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,利用多个真值热度图以及多个预测热度图确定图像综合损失,以便根据图像综合损失对深度学习网络执行训练。

图6为根据本发明另一实施例的图像基本损失的确定步骤流程图。

步骤s610,获取多个图像中每一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

获取batch(batch≥1)个图像中每一个图像对应的k(k≥1)个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。其中,batch=1,按照上述单样本的情况执行,由于已经对单样本的情况进行了描述,所以在此不做赘述。

例如:在训练深度学习网络的过程中,获取图像a对应的k个真值图像和每个真值图像对应的预测图像,以及获取图像b对应的k个真值图像和每个真值图像对应的预测图像。

步骤s620,对每个真值热度图中的所有像素点执行求和取平均,得到每个真值热度图的像素均值,获得包括多个图像中每个图像对应的多个真值热度图的像素均值的矩阵。

计算所有真值热度图的像素均值组成的矩阵。

具体的,计算公式可以为:

batch_sumtrue=reduce_mean(htrue,axis=(1,2));

其中,batch_sumtrue为各个真值热度图的像素均值组成的矩阵。htrue表示batch个图像中每个图像对应的真值热度图叠加在一起组成的矩阵,axis=(1,2)表示对矩阵中宽度轴(1轴)和高度轴(2轴)组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。

batch_sumtrue的维度为batch×k,即batch行k列的矩阵,batch表示图像的数量,k表示每个图像对应的真值热度图的数量。batch_sumtrue中的每个元素为一个真值热度图的像素均值,每个真值热度图的像素均值是对该真值热度图中的所有像素点进行求和取平均。

htrue的维度为batch×w×h×k,batch(0轴)表示图像的数量,w(1轴)表示真值热度图的宽度,h(2轴)表示真值热度图的高度,k(3轴)表示每个图像对应的真值热度图的数量。0轴和1轴组成的平面为真值热度图的平面,也即是说,htrue是batch列真值热度图组,每列真值热度图组为k个w×h叠加在一起。

步骤s630,对每个预测热度图中的所有像素点执行求和取平均,得到每个预测热度图的像素均值,得到包括多个图像中每个图像对应的多个预测热度图的像素均值的矩阵。

计算所有预测热度图的像素均值组成的矩阵。具体的,计算公式可以为:

batch_sumpred=reduce_mean(hpred,axis=(1,2));

其中,batch_sumpred为所有预测热度图的像素均值组成的矩阵。hpred表示batch个图像中每个图像对应的预测热度图叠加在一起组成的矩阵,axis=(1,2)表示对矩阵中宽度轴(1轴)和高度轴(2轴)组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。

batch_sumpred的维度为batch×k,即batch行k列的矩阵,batch_sumpred中的每个元素为一个预测热度图的像素均值,每个预测热度图的像素均值是对该预测热度图中的所有像素点进行求和取平均。

hpred的维度为batch×w×h×k,batch(0轴)表示图像的数量,w(1轴)表示真值热度图的宽度,h(2轴)表示真值热度图的高度,k(3轴)表示每个图像对应的真值热度图的数量。0轴和1轴组成的平面为预测热度图的平面,也即是说,hpred是batch列预测热度图组,每列真值热度图组为k个w×h叠加在一起。

步骤s640,根据每个真值热度图的像素均值以及每个预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差。

在获取的多个真值热度图和多个预测热度图中,根据每个真值热度图的像素均值以及每个预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差,形成像素均值误差组成的矩阵。

具体的,计算公式可以为:

batch_diffs=(batch_sumtrue-batch_sumpred)2

其中,batch_diffs为相互对应的真值热度图和预测热度图之间的像素均值误差组成的矩阵。

batch_diffs的维度为batch×k。batch_diffs中的每个元素为batch_sumtrue中的一个真值热度图的像素均值与batch_sumpred中对应的预测热度图的像素均值之间的误差。

步骤s650,分别计算每个图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值。

例如:获取了图像a对应的多个真值图像和每个真值图像对应的预测图像,以及获取了图像b对应的多个真值图像和每个真值图像对应的预测图像;计算图像a对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值a,计算图像b对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值b。

具体的,计算公式可以为:

batch_diff=reduce_mean(batch_diffs,axis=1);

其中,batch_diff为每个图像对应的所有像素均值误差的平均值,即每个图像对应的各个真值热度图和对应的预测热度图之间的平均的像素均值误差,reduce_mean(batch_diffs,axis=1)为对batch_diffs中所有元素(所有像素均值误差)求和取平均。

进一步地,batch_diff的维度为batch×1,batch_diff中的每个元素为一个图像对应k个真值热度图和对应的预测热度图之间的像素均值误差的均值。

步骤s660,对每个图像对应的像素均值误差的均值进行平均值计算,得到各个真值热度图和对应的预测热度图之间的平均的像素均值误差。

例如:上例中,计算图像a对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值a,以及图像b对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值b,本例需要计算a和b的均值,即(a+b)÷2。

具体的,计算公式可以为:

diff=reduce_mean(batch_diffs)/batch;

diff为所有像素均值误差的平均值,即各个(batch乘以k个)真值热度图和对应的预测热度图之间的平均的像素均值误差,reduce_mean(batch_diffs)为对batch_diffs中所有元素(所有像素均值误差)求和取平均。

步骤s670,利用预设的指数损失函数,确定图像基本损失;指数损失函数的指数幂包括:(每个图像对应的)真值热度图和预测热度图之间的像素均值误差。

具体的,计算公式可以是:

其中,hm_emvd_loss表示图像基本损失,e为自然常数,α为预设参数。α可以根据具体需求设置和调整,例如:根据不同的多任务学习种类进行不同的设置。

在本实施例中,上述确定图像基本损失的过程,可以称为emvd函数的执行过程。

在本实施例中,可以直接将图像基本损失确定为图像综合损失。当然,也可以先确定图像补充损失,根据图像基本损失和图像补充损失,确定图像综合损失。

图7为根据本发明另一实施例的确定图像综合损失的步骤流程图。

步骤s710,针对真值热度图及其对应的预测热度图,利用预设的损失函数,计算图像补充损失。

该预设的损失函数,例如是:l1损失函数,l2损失函数。

以l2损失函数为例,需要执行以下计算步骤:

步骤s1,确定l1=(htrue-hpred)2

其中,l1为真值热度图和对应的预测热度图的像素差值的平方组成的矩阵,l1的维度为batch×w×h×k。batch为图像的数量(0轴),w为宽度轴(1轴),h为高度轴(2轴),k为每个图像对应的真值热度图(预测热度图)的数量。0轴和1轴组成的平面为真值热度图和对应的预测热度图的像素差值的平方组成的平面。真值热度图和预测热度图的数量相等,都为k个,那么l1为batch个平面组,每个平面组包括k个w×h叠加在一起。

具体的,获取相互对应的真值热度图和预测热度图,对该真值热度图和该预测热度图的对应像素点做差取平方。

步骤s2,确定l2=reduce_mean(l1,axis=(1,2))。

其中,l2表示对l1中宽度轴(1轴)和高度轴(2轴)组成的各个平面中所有像素点的像素求和取平均,axis=(1,2)表示对l1中1轴和2轴组成的各个平面中所有像素点的像素求和,reduce_mean为平均值函数。

l2的维度为batch×k,l2中的每个元素可以表示一个真值热度图和对应的预测热度图之间的误差。

步骤3,计算hm_l2_losses=reduce_mean(l2,axis=1)。

其中,hm_l2_losses为所有图像的l2损失,表示batch个图像对应的各个真值热度图和对应的预测热度图之间的误差的均值。

步骤s4,计算hm_l2_loss=reduce_mean(hm_l2_losses)/batch。

其中,hm_l2_loss为每个图像的l2损失。

步骤s720,针对真值热度图及其对应的预测热度图,利用emvd损失函数,计算图像基本损失。

利用emvd损失函数,计算图像基本损失的步骤已经在上面进行了描述,在此不做赘述。

步骤s730,计算图像基本损失和图像补充损失的加权和,将该加权和确定为图像综合损失。

具体的,计算公式可以为:

hm_loss=w1×hm_l2_loss+w2×hm_emvd_loss;

其中,hm_loss表示图像综合损失,hm_l2_loss为l2损失(图像补充损失),hm_emvd_loss为图像基本损失,w1为第一权重,w2为第二权重。第一权重和第二权重可以根据经验进行设置,也可以在训练深度学习网络的过程中不断调整。

在获得图像综合损失之后,可以根据该图像综合损失来调整深度学习网络,以便使深度学习网络输出更为准确的预测热度图。

本实施例提供一种深度学习网络的训练装置。如图8所示,为根据本发明一实施例的深度学习网络的训练装置的结构图。

在本实施例中,深度学习网络的训练装置,包括:获取模块810,计算模块820,第一确定模块830和第二确定模块840。

获取模块810,用于获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像。

计算模块820,用于计算所述真值图像和所述预测图像之间的像素均值误差。

第一确定模块830,用于利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差。

第二确定模块840,用于根据所述图像基本损失,确定图像综合损失,所述图像综合损失用于对所述深度学习网络执行训练。

可选的,所述第一确定模块830,进一步用于:将所述图像基本损失,确定为所述图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算所述图像基本损失和所述图像补充损失的加权和,将所述加权和确定为所述图像综合损失。

可选的,所述真值图像为真值热度图;所述预测图像为预测热度图;所述获取模块810,进一步用于:获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

可选的,如果所述获取模块810获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则所述计算模块820,进一步用于:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;计算所有像素均值误差的均值,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,如果所述获取模块810获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则所述计算模块820,进一步用于:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;分别计算每个所述图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值;对每个所述图像对应的像素均值误差的均值进行平均值计算,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,所述第一确定模块830具体用于执行以下计算:

其中,hm_emvd_loss为图像基本损失,diff为平均的像素均值误差,e为自然常数,α为预设参数。

本发明所述的装置的功能已经在图2~图7所示的方法实施例中进行了描述,故本实施例的描述中未详尽之处,可以参见前述实施例中的相关说明,在此不做赘述。

本实施例提供一种深度学习网络的训练设备。如图9所示,为根据本发明第五实施例的深度学习网络的训练设备的结构图。

在本实施例中,所述深度学习网络的训练设备,包括但不限于:处理器910、存储器920。

所述处理器910用于执行存储器920中存储的深度学习网络的训练程序,以实现上述的深度学习网络的训练方法。

具体而言,所述处理器910用于执行存储器920中存储的深度学习网络的训练程序,以实现以下步骤:获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像;计算所述真值图像和所述预测图像之间的像素均值误差;利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差;根据所述图像基本损失,确定图像综合损失,所述图像综合损失用于对所述深度学习网络执行训练。

可选的,所述根据所述图像基本损失,确定图像综合损失,包括:将所述图像基本损失,确定为所述图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算所述图像基本损失和所述图像补充损失的加权和,将所述加权和确定为所述图像综合损失。

可选的,所述真值图像为真值热度图;所述预测图像为预测热度图;所述获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像,包括:获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

可选的,如果获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;计算所有像素均值误差的均值,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,如果获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;分别计算每个所述图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值;对每个所述图像对应的像素均值误差的均值进行平均值计算,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,利用预设的指数损失函数,确定图像基本损失,包括:

其中,hm_emvd_loss为图像基本损失,diff为平均的像素均值误差,e为自然常数,α为预设参数。

本实施例提供了一种存储介质。这里的存储介质存储有一个或者多个程序。其中,存储介质可以包括易失性存储器,例如随机存取存储器;存储器也可以包括非易失性存储器,例如只读存储器、快闪存储器、硬盘或固态硬盘;存储器还可以包括上述种类的存储器的组合。

当存储介质中一个或者多个程序可被一个或者多个处理器执行,以实现上述的深度学习网络的训练方法。

具体而言,所述处理器用于执行存储器中存储的深度学习网络的训练程序,以实现以下步骤:获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像;计算所述真值图像和所述预测图像之间的像素均值误差;利用预设的指数损失函数,确定图像基本损失;所述指数损失函数的指数幂包括:所述真值图像和所述预测图像之间的像素均值误差;根据所述图像基本损失,确定图像综合损失,所述图像综合损失用于对所述深度学习网络执行训练。

可选的,所述根据所述图像基本损失,确定图像综合损失,包括:将所述图像基本损失,确定为所述图像综合损失;或者,利用预设的损失函数,确定图像补充损失;计算所述图像基本损失和所述图像补充损失的加权和,将所述加权和确定为所述图像综合损失。

可选的,所述真值图像为真值热度图;所述预测图像为预测热度图;所述获取真值图像以及深度学习网络输出的所述真值图像对应的预测图像,包括:获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图;或者,获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图。

可选的,如果获取同一个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;计算所有像素均值误差的均值,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,如果获取多个图像中每个图像对应的多个真值热度图以及深度学习网络输出的每个真值热度图对应的预测热度图,则计算所述真值图像和所述预测图像之间的像素均值误差,包括:对每个真值热度图中的所有像素点执行求和取平均,得到每个所述真值热度图的像素均值;对每个预测热度图中的所有像素点执行求和取平均,得到每个所述预测热度图的像素均值;根据每个所述真值热度图的像素均值以及每个所述预测热度图的像素均值,计算相互对应的真值热度图和预测热度图之间的像素均值误差;分别计算每个所述图像对应的各个真值热度图和对应的预测热度图之间的像素均值误差的均值;对每个所述图像对应的像素均值误差的均值进行平均值计算,得到各个所述真值热度图和对应的预测热度图之间的平均的像素均值误差。

可选的,利用预设的指数损失函数,确定图像基本损失,包括:

其中,hm_emvd_loss为图像基本损失,diff为平均的像素均值误差,e为自然常数,α为预设参数。

以上所述仅为本发明的实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的权利要求范围之内。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1