伪标签生成模型训练方法、装置及伪标签生成方法及装置与流程

文档序号:14749054发布日期:2018-06-22 09:54阅读:来源:国知局
技术特征:

1.一种伪标签生成模型训练方法,其特征在于,该方法包括:

获取带有标签的源域数据、不带标签的第一目标域数据以及带有标签的第二目标域数据;

使用第一辅助神经网络对第一目标域数据进行特征学习,获取在第一辅助神经网络中指定特征提取层的第一特征向量,并且使用第二辅助神经网络对第二目标域数据进行特征学习,获取第二辅助神经网络中指定特征提取层的第二特征向量;

根据所述第一特征向量以及第二特征向量计算第一域混淆损失;

使用目标神经网络对所述源域数据进行特征学习,获取所述目标神经网络中指定特征提取层提取的源域特征向量;并将目标神经网络输出的特征向量输入至目标分类器得到第一分类结果;

根据所述第一特征向量和所述源域特征向量计算第二域混淆损失;

根据所述第一域混淆损失对所述第一辅助神经网络进行本轮训练;以及根据所述第二域混淆损失,以及所述第一分类结果,对所述目标神经网络进行本轮训练;以及根据所述第一分类结果,对所述目标分类器进行本轮训练;

经过对所述目标神经网络和所述目标分类器进行多轮训练,得到伪标签生成模型。

2.根据权利要求1所述的方法,其特征在于,所述指定特征提取层包括位于每个神经网络末端预设数量的特征提取层;

所述根据所述第一特征向量和所述源域特征向量计算第二域混淆损失,具体包括:

将从第一辅助神经网络的各个指定特征提取层提取的所述第一特征向量进行拼接,形成第一拼接向量,以及将从目标神经网络的各个指定特征提取层提取的所述源域特征向量进行拼接,形成目标拼接向量;

根据所述第一拼接向量以及所述目标拼接向量,计算所述第二域混淆损失。

3.根据权利要求2所述的方法,其特征在于,根据所述第二域混淆损失,以及所述第一分类结果,对所述目标神经网络进行本轮训练;以及,根据所述第一分类结果,对所述目标分类器进行本轮训练,具体包括:

执行如下域混淆损失比对操作以及第一分类损失确定操作,直至第二域混淆损失不大于预设的第二混淆损失阈值,以及第一分类损失不大于预设的第一分类损失阈值;

所述域混淆损失比对操作包括:

将所述第二域混淆损失与预设的第二混淆损失阈值进行比对;

如果所述第二域混淆损失大于预设的第二混淆损失阈值,则调整所述目标神经网络的参数;

所述第一分类损失确定操作包括:

根据对所述源域数据的所述第一分类结果,以及所述源域数据的标签,计算第一分类损失;

将所述第一分类损失与预设的第一分类损失阈值进行比对;

如果所述第一分类损失大于预设的第一分类损失阈值,则调整所述目标神经网络的参数以及所述目标分类器的参数。

4.根据权利要求1所述的方法,其特征在于,所述根据所述第一特征向量以及第二特征向量计算第一域混淆损失之后,还包括:

根据所述第一域混淆损失调整所述第二辅助神经网络在训练过程中的参数;

所述根据所述第一特征向量和所述源域特征向量计算第二域混淆损失之后,还包括:

根据所述第二域混淆损失,调整所述第一辅助神经网络在训练过程中的参数。

5.根据权利要求1所述的方法,其特征在于,所述使用第一辅助神经网络对第一目标域数据进行特征学习之后,还包括:

使用第一分类器对所述第一辅助神经网络输出的特征向量进行分类;

根据第一分类器对第一辅助神经网络输出的特征向量进行分类的结果,调整第一辅助神经网络的在训练过程中的参数;

所述使用第二辅助神经网络对第二目标域数据进行特征学习之后,还包括:

使用第二分类器对第二辅助神经网络输出的特征向量进行分类;

根据第二分类器对第二辅助神经网络输出的特征向量进行分类的结果,调整第二辅助神经网络的在训练过程中的参数。

6.根据权利要求5所述的方法,其特征在于,所述根据第一分类器对第一辅助神经网络输出的特征向量进行分类的结果,调整第一辅助神经网络在训练过程中的参数,具体包括:

执行如下交叉熵确定操作,直至交叉熵不大于预设的交叉熵阈值;

所述交叉熵损失确定操作包括:

根据所述第一分类器对第一辅助神经网络输出的特征向量进行分类的结果,以及第一目标域数据和源域数据中每一类数据的相似度,计算第一辅助神经网络的交叉熵损失;

在所述交叉熵损失不小于预设的交叉熵阈值时,调整所述第一辅助神经网络在训练过程中的参数。

7.根据权利要求5所述的方法,其特征在于,所述根据第二分类器对第二辅助神经网络输出的特征向量进行分类的结果,调整第二辅助神经网络的在训练过程中的参数,具体包括:

执行如下第二分类损失确定操作,直至第二分类损失不大于预设的第二分类损失阈值;

所述第二分类损失确定操作包括:

根据第二分类器对所述第二辅助神经网络输出的特征向量进行分类的结果,以及第二目标域数据的标签,计算所述第二分类损失;

将所述第二分类损失与预设的第二分类损失阈值进行比对;

如果所述第二分类损失大于预设的第二分类损失阈值,则调整所述第二辅助神经网络的参数以及所述第二分类器的参数。

8.根据权利要求1所述的方法,其特征在于,所述经过对所述目标神经网络和所述目标分类器进行多轮训练之后,得到伪标签生成模型之前,还包括:

使用经过多轮训练的目标神经网络为所述第一目标域数据提取第三特征向量,并将第三特征向量输入经过多轮训练的目标分类器得到分类结果,将得到的分类结果作为第一目标域数据的临时标签;

将具有临时标签的第一目标域数据、源域数据输入至经过多轮训练的目标神经网络,使用经过多轮训练的目标神经网络对具有临时标签的第一目标域数据、源域数据进行特征学习,获取所述经过多轮训练的目标神经网络中指定特征提取层提取的第四特征向量;

将第二目标域数据输入至经过多轮训练的第二辅助神经网络,使用经过多轮训练的第二辅助神经网络对第二目标域数据进行特征学习,获取所述经过多轮训练的第二辅助神经网络中指定特征提取层提取的第五特征向量;

根据所述第四特征向量和所述第五特征向量计算第三域混淆损失;

根据所述第三域混淆损失,调整所述目标神经网络在训练过程中的参数;根据经过多轮训练的目标分类器对所述经过多轮训练的目标神经网络输出的特征向量进行分类的结果,调整经过多轮训练的目标神经网络在训练过程中的参数,并调整经过多轮训练的目标分类器的参数;

对经过多轮训练的目标神经网络和经过多轮训练的目标分类器的再次进行多轮训练,得到所述伪标签生成模型。

9.根据权利要求8所述的方法,其特征在于,根据所述第四特征向量和所述第五特征向量计算第三域混淆损失之后,还包括:

根据所述第三域混淆损失调整经过多轮训练的第二辅助神经网络的参数;

使用经过多轮训练的第二辅助神经网络对第二目标域数据进行特征学习之后,还包括:

使用经过多轮训练的第二分类器对经过多轮训练的第二辅助神经网络输出的特征向量进行分类;

根据经过多轮训练的第二分类器对经过多轮训练的第二辅助神经网络输出的特征向量进行分类的结果,调整经过多轮训练的第二辅助神经网络的参数。

10.一种伪标签生成方法,其特征在于,该方法包括:

获取待分类数据;

将所述待分类数据输入至通过权利要求1-10任意一项所述的伪标签生成模型训练方法得到的伪标签生成模型,得到所述待分类数据的分类结果;

将所述待分类数据的分类结果作为所述待分类数据的伪标签。

当前第2页1 2 3 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1