一种缓解多任务学习中任务冲突方法、装置及存储介质

文档序号:26715122发布日期:2021-09-22 19:51阅读:214来源:国知局
一种缓解多任务学习中任务冲突方法、装置及存储介质

1.本发明涉及计算机技术领域,尤其涉及一种缓解多任务学习中任务冲突方法、装置及存储介质。


背景技术:

2.深度学习在各个领域已经取得了不错的成绩,但是目前的人工智能依赖于海量数据的训练,模型泛化能力不佳,在有限数据领域下的效果和快速拓展到新任务的能力都不尽人意。针对这个问题,一些研究者提出多任务学习方法(mtl,multi

task learning)来解决这个问题。多任务学习方法能够联合多个任务一起学习,一些数据有限的任务能够利用其他任务共享的信息进行训练,从而提高任务的表现。
3.基于优化的多任务学习方法是现有的多任务学习方法中的一种;而现有基于优化的多任务学习方法中,当任务梯度发生冲突或者任务被其他较大梯度的任务支配时,通常是通过设计不同的策略来调整各个任务loss的权重,对于训练比较快的任务,降低其权重,减少模型对其的关注程度,让模型多关注那些没有训练充分的任务,从而实现缓解任务训练的不平衡。但是通过调整各任务loss权重来平衡任务训练,会导致某些任务被其他任务所支配,从而得不到充分训练,无法实现各个任务训练均衡,进而降低了模型的整体性能。


技术实现要素:

4.本发明实施例提供一种缓解多任务学习中任务冲突的方法及装置,能在实现多任务模型中各个任务训练均衡的同时缓解任务冲突。
5.本发明一实施例提供一种缓解多任务学习中任务冲突的方法,包括:
6.获取待缓解多任务学习模型中各个学习任务的梯度值;
7.任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;
8.在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;
9.计算所有所述学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。
10.进一步的,所述获取待缓解多任务学习模型中各个学习任务的梯度值,具体包括:计算每一所述学习任务的损失值,继而根据每一所述学习任务的损失值计算每一所述学习任务对所述待缓解多任务学习模型中网络参数的偏导数,获得每一所述学习任务的梯度值。
11.进一步的,所述对所述选定学习任务的梯度值进行梯度修剪,具体包括:
12.根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面;其中,所述冲突学习任务为与所述选定学习任务存在任务冲突的学习任务;
13.分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值;
14.根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值。
15.进一步的,所述根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面,具体包括:
16.通过以下公式确定所述选定学习任务与所述冲突学习任务的冲突平面:
17.p
γ
=g
i

g
j

18.其中,g
i
为所述选定学习任务的梯度值,g
j
为所述冲突学习任务的梯度值。
19.进一步的,所述分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值,具体包括:
20.通过以下公式计算所述选定学习任务以及所述冲突学习任务与所述冲突平面的夹角的余弦值;
[0021][0022]
通过以下公式计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量:
[0023]
δg
i
=g
i
·
cosφ
i
,δg
j
=g
j
·
cosφ
j

[0024]
通过以下公式计算所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值:
[0025]
δg
{i,j}
=||g
i
·
cosφ
i

g
j
·
cosφ
j
||;
[0026]
其中,cosφ
i
为所述选定学习任务与所述冲突平面的夹角的余弦值,cosφ
j
为所述冲突学习任务与所述冲突平面的夹角的余弦值,δg
i
为所述选定学习任务在所述冲突平面上的梯度分量,δg
j
为所述冲突学习任务在所述冲突平面上的梯度分量。
[0027]
进一步的,所述根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值,具体包括:根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值
[0028]
在上述方法项实施例的基础上,本发明对应提供了装置项实施例;
[0029]
本发明一实施例提供了一种缓解多任务学习中任务冲突的装置,包括梯度值获取模块、学习任务梯度值更新模块以及模型参数值更新模块;
[0030]
所述梯度值获取模块,用于获取待缓解多任务学习模型中各个学习任务的梯度值;
[0031]
所述学习任务梯度值更新模块,用于任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习
任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;以及,
[0032]
在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;
[0033]
所述模型参数值更新模块,用于计算所有所述学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。
[0034]
在上述方法项实施例的基础上,本发明对应提供了一存储介质项实施例;
[0035]
本发明一实施例提供了一种存储介质,所述存储介质包括存储的计算机程序,其中,在所述计算机程序运行时控制所述所在设备执行本发明任意一项所述的缓解多任务学习中任务冲突的方法。
[0036]
通过实施本发明实施例具有如下有益效果:
[0037]
本发明实施例提供了一种种缓解多任务学习中任务冲突的方法、装置及存储介质,所述方法首先获取待缓解多任务学习模型中各个学习任务的梯度值,紧接着判断各学习任务中选定学习任务与其余各学习任务之间是否存在任务冲突,在判定存在任务冲突时对选定学习任务的梯度值进行修剪,并将选定学习任务的梯度值更新为修剪后的梯度值,在选定学习任务的梯度值更新执行完毕后重新选定一学习任务作为选定学习任务重复梯度值更新操作,直至待缓解多任务学习模型中所有学习任务的梯度值更新完毕,最后计算所有所述学习任务完成梯度值更新后的梯度值,获得平均梯度,将平均梯度作为上述待缓解多任务学习模型的整体梯度值,然后根据平均梯度对待缓解多任务学习模型的网络参数进行更新。与现有技术相比,本发明通过对各学习任务的梯度值进行梯度修剪来更新模型的网络参数,从而缓解多任务学习模型中任务冲突的问题,并且在此过程中无需调整各任务的loss权重,避免了任务训练不均衡的问题,提高了模型的整体性能。
附图说明
[0038]
图1是本发明一实施例提供的一种缓解多任务学习中任务冲突的方法的流程示意图。
[0039]
图2是本发明一是合理提供的一种缓解多任务学习中任务冲突的装置示意图。
具体实施方式
[0040]
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
[0041]
如图1所示,本发明一实施例提供了一种缓解多任务学习中任务冲突的方法,具体包括:
[0042]
步骤s101:获取待缓解多任务学习模型中各个学习任务的梯度值。
[0043]
步骤s102:任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突。
[0044]
步骤s103:在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕。
[0045]
步骤s104:计算所有所述学习任务完成梯度值更新后的梯度值的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。
[0046]
对于步骤s101:在一个优选的实施例中,所述获取待缓解多任务学习模型中各个学习任务的梯度值,具体包括:计算每一所述学习任务的损失值,继而根据每一所述学习任务的损失值计算每一所述学习任务对所述待缓解多任务学习模型中网络参数的偏导数,获得每一所述学习任务的梯度值。
[0047]
具体的,假设上述待缓解多任务学习模型存在三个学习任务i、j以及f;在一次训练迭代之后,计算得出学习任务i的损失值学习任务j的损失值学习任务f的损失值然后计算学习任务i对于待缓解多任务学习模型中网络参数的偏导数,得到学习任务i的梯度值g
i
,计算任务j对于待缓解多任务学习模型中网络参数的偏导数,得到学习任务j的梯度值g
j
,计算任务f对于待缓解多任务学习模型中网络参数的偏导数,得到学习任务f的梯度值g
f
,具体计算公式如下:
[0048]
其中θ为待缓解多任务学习模型的网络参数。
[0049]
对于步骤s102:在一个优选的实施例中,所述对所述选定学习任务的梯度值进行梯度修剪,具体包括:
[0050]
根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面;其中,所述冲突学习任务为与所述选定学习任务存在任务冲突的学习任务;
[0051]
分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值;
[0052]
根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪。
[0053]
所述根据所述选定学习任务的梯度值以及冲突学习任务的梯度值,确定所述选定学习任务与所述冲突学习任务的冲突平面,具体包括:
[0054]
通过以下公式确定所述选定学习任务与所述冲突学习任务的冲突平面:
[0055]
p
γ
=g
i

g
j

[0056]
其中,g
i
为所述选定学习任务的梯度值,g
j
为所述冲突学习任务的梯度值。
[0057]
所述分别计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量,继而根据所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差
值,具体包括:
[0058]
通过以下公式计算所述选定学习任务以及所述冲突学习任务与所述冲突平面的夹角的余弦值;
[0059][0060]
通过以下公式计算所述选定学习任务以及所述冲突学习任务在所述冲突平面上的梯度分量:
[0061]
δg
i
=g
i
·
cosφ
i
,δg
j
=g
j
·
cosφ
j

[0062]
通过以下公式计算所述梯度分量计算选定学习任务与所述冲突学习任务之间的梯度分量差值:
[0063]
δg
{i,j}
=||g
i
·
cosφ
i

g
j
·
cosφ
j
||;
[0064]
其中,cosφ
i
为所述选定学习任务与所述冲突平面的夹角的余弦值,cosφ
j
为所述冲突学习任务与所述冲突平面的夹角的余弦值,δg
i
为所述选定学习任务在所述冲突平面上的梯度分量,δg
j
为所述冲突学习任务在所述冲突平面上的梯度分量。
[0065]
在一个优选的实施例中,所述根据所述梯度分量差值对所述选定学习任务的梯度值进行梯度值修剪,获得修剪后的梯度值,具体包括:
[0066]
通过以下公式计算所述选定学习任务修剪后的梯度值:
[0067][0068]
其中,为所述冲突平面的方向。
[0069]
示意性的,以学习任务i为上述选定学习任务,则根据选定学习任务i执行梯度更新的具体步骤如下:
[0070]
步骤1、首先判定学习任务i与其余任意一学习任务,例如是学习任务j之间是否存在任务冲突,具体的计算以下公式的数值:α
i,j
=sign(<g
i
,g
j
>);如果α
i,j


1,则说明学习任务i和学习任务j之间存在冲突;如果α
i,j


1,则说明学习任务i和学习任务j之间不存在冲突。如果学习任务i和学习任务j之间存在冲突,则将学习任务j作为上述冲突学习任务并执行步骤2,步骤3以及步骤4;如果学习任务i和学习任务j之间不存在冲突,则继续判断学习任务i与学习任务f之间是否存在冲突;如果学习任务i与学习任务f之间存在冲突,则将学习任务f作为上述冲突学习任务,将步骤2和步骤3中的冲突学习任务替换为学习任务f,对选定学习任务i的梯度值进行更新;如果不存在冲突,则选定学习任务i的梯度值更新执行完毕,此时选定学习任务i执行梯度更新后的梯度值依旧为原始的梯度值。
[0071]
步骤2、确定选定学习任务i与冲突学习任务j之间的冲突平面p
γ
,p
γ
=g
i

g
j
,然后计算选定学习任务i与冲突平面的夹角φ
i
的余弦值cosφ
i
,计算冲突学习任务j与冲突平面的夹角φ
j
的余弦值cosφ
j
,然后计算选定学习任务i与冲突学习任务j之间的梯度分量差值
[0072]
然后计算选定学习任务i与冲突学习任务j之间的梯度分量差值。所用到的公式计算公式如下:
[0073][0074]
δg
i
=g
i
·
cosφ
i
,δg
j
=g
j
·
cosφ
j
[0075]
δg
{i,j}
=||g
i
·
cosφ
i

g
j
·
cosφ
j
||;
[0076]
步骤3、梯度修剪:计算任务冲突平面的方向然后根据步骤2计算出选定学习任务i与冲突学习任务j之间的梯度分量差值来修剪选定学习任务i的梯度g
i
。所用到的公式计算公式如下:经过梯度修剪后选定学习任务i的梯度值从g
i
更新为g

i

[0077]
步骤4、将g

i
作为选定学习任务i更新后的梯度值,然后根据g

i
以及学习任务f的梯度值g
f
,继续判断选定学习任务i与学习任务f是否存在冲突,如果存不存在冲突,则选定学习任务i的梯度值更新执行完毕;如果存在冲突,则将学习任务f作为新的冲突学习任务,紧接着将步骤2和步骤3的冲突学习任务j替换为学习f,将步骤2和步骤3中选定任务的梯度值替换为g

i
,然后执行步骤2和步骤3,对选定任务的梯度值进行再次修剪和更新。
[0078]
对于步骤s103:按上述方法完成学习任务i的梯度值更新后,从其余的学习任务中(如上述学习任务j和学习任务f中),重新选取一个学习任务作为选定学习任务,重复执行步骤2中的梯度值更新,以此类推直至完成待缓解多任务学习模型中所有学习任务的梯度值更新。
[0079]
对于步骤s104,在所有学习任务的梯度值更新完毕后,获取所有学习任务最终的梯度值,然后计算平均值,将平均值作为待缓解多任务学习模型整体梯度值,即将待缓解多任务学习模型的梯度替换为上述平均值,然后根据平均梯度值继续神经网络的训练,使用优化器进行参数更新。
[0080]
具体公式如下,
[0081]
式中,g为待缓解多任务学习模型更新后梯度值,g

x
为模型中各个学习任务完成所述梯度值更新后的梯度值,t为模型中学习任务的个数。
[0082]
本发明上述实施例通过对各学习任务的梯度值进行梯度修剪来更新模型中网络参数,从而缓解多任务学习模型中任务冲突的问题,并且在此过程中无需调整各任务的loss权重,避免了任务训练不均衡的问题,提高了模型的整体性能。
[0083]
在上述方法项实施例的基础上,本发明对应提供了装置项实施例;
[0084]
如图2所示,在一个优选的实施例中,包括梯度值获取模块、学习任务梯度值更新模块以及模型参数值更新模块;
[0085]
所述梯度值获取模块,用于获取待缓解多任务学习模型中各个学习任务的梯度值;
[0086]
所述学习任务梯度值更新模块,用于任意选取一学习任务作为选定学习任务,根据所述选定学习任务执行梯度值更新;其中,所述梯度值更新具体包括:根据所述选定学习任务的梯度值以及其余各学习任务的梯度值,逐一判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;在每一次判定存在任务冲突时,对所述选定学习任务的梯度值
进行梯度修剪,并将所述选定学习任务的梯度值更新为修剪后的梯度值,根据所述选定学习任务更新后的梯度值继续判断所述选定学习任务与其余各学习任务之间是否存在任务冲突;以及,
[0087]
在所述梯度值更新执行完毕后,重新选取一学习任务作为更新后的选定学习任务,并重复执行所述梯度值更新,直至所有所述学习任务的梯度值更新完毕;
[0088]
所述模型梯度值更新模块,用于计算所有所述学习任务完成梯度值更新后的平均值,获得平均梯度,根据所述平均梯度对所述待缓解多任务学习模型的网络参数进行更新。
[0089]
需说明的是,上述装置项实施例是于本发明方法项实施例相对应的,其能够实现本发明任意一项所述的缓解多任务学习中任务冲突的方法。此外,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本发明提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
[0090]
在上述方法项实施例的基础上,本发明对应提供了存储介质项实施例;
[0091]
本发明一实施例提供了一种存储介质,所述存储介质包括存储的计算机程序,其中,在所述计算机程序运行时控制所述所在设备执行本发明任意一项所述的缓解多任务学习中任务冲突的方法。
[0092]
上述存储介质为计算机可读存储介质,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、u盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(rom,read

only memory)、随机存取存储器(ram,random access memory)、电载波信号、电信信号以及软件分发介质等。
[0093]
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也视为本发明的保护范围。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1