目标追踪模型的训练方法及装置与流程

文档序号:35794847发布日期:2023-10-21 22:12阅读:34来源:国知局
目标追踪模型的训练方法及装置与流程

本申请涉及目标检测,尤其涉及一种目标追踪模型的训练方法及装置。


背景技术:

1、目标追踪(person re-identification)也称行人再识别,是利用计算机视觉技术判断图像或者视频序列中是否存在特定行人的技术。在目标追踪模型的训练中,检测框不准确以及密集场景下的遮挡对模型精度会造成影响,同时在提升模型精度时,模型鲁棒性无法得到保障。


技术实现思路

1、有鉴于此,本申请实施例提供了一种目标追踪模型的训练方法、装置、电子设备及计算机可读存储介质,以解决现有技术中,检测框不准确以及密集场景下的遮挡会影响目标追踪模型的精度,同时目标追踪模型的精度和鲁棒性无法兼顾的问题。

2、本申请实施例的第一方面,提供了一种目标追踪模型的训练方法,包括:串行连接图片处理网络、特征提取网络和分类网络,得到目标追踪模型,其中,图片处理网络用于对输入的图片进行随机遮挡处理和随机裁剪处理,分类网络由全局平均池化层和全连接层组成;获取训练数据集,将训练数据集中的目标样本及其正样本和负样本输入目标追踪模型:通过图片处理网络对目标样本、正样本和负样本进行处理,得到目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本;通过特征提取网络分别对目标样本、正样本、负样本以及目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本进行处理,得到各个样本对应的样本特征;通过分类网络分别对目标样本及其对应的遮挡样本和裁剪样本的样本特征进行处理,得到第一识别结果、第二识别结果和第三识别结果;基于第一识别结果、第二识别结果和第三识别结果,计算分类损失和散度损失,基于各个样本对应的样本特征,计算三元组损失;依据分类损失、散度损失和三元组损失,更新目标追踪模型的模型参数,以完成对目标追踪模型训练。

3、本申请实施例的第二方面,提供了一种目标追踪模型的训练装置,包括:构建模块,被配置为串行连接图片处理网络、特征提取网络和分类网络,得到目标追踪模型,其中,图片处理网络用于对输入的图片进行随机遮挡处理和随机裁剪处理,分类网络由全局平均池化层和全连接层组成;获取模块,被配置为获取训练数据集,将训练数据集中的目标样本及其正样本和负样本输入目标追踪模型:处理模块,被配置为通过图片处理网络对目标样本、正样本和负样本进行处理,得到目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本;提取模块,被配置为通过特征提取网络分别对目标样本、正样本、负样本以及目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本进行处理,得到各个样本对应的样本特征;识别模块,被配置为通过分类网络分别对目标样本及其对应的遮挡样本和裁剪样本的样本特征进行处理,得到第一识别结果、第二识别结果和第三识别结果;计算模块,被配置为基于第一识别结果、第二识别结果和第三识别结果,计算分类损失和散度损失,基于各个样本对应的样本特征,计算三元组损失;更新模块,被配置为依据分类损失、散度损失和三元组损失,更新目标追踪模型的模型参数,以完成对目标追踪模型训练。

4、本申请实施例的第三方面,提供了一种电子设备,包括存储器、处理器以及存储在存储器中并且可在处理器上运行的计算机程序,该处理器执行计算机程序时实现上述方法的步骤。

5、本申请实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。

6、本申请实施例与现有技术相比存在的有益效果是:因为本申请实施例通过串行连接图片处理网络、特征提取网络和分类网络,得到目标追踪模型,其中,图片处理网络用于对输入的图片进行随机遮挡处理和随机裁剪处理,分类网络由全局平均池化层和全连接层组成;获取训练数据集,将训练数据集中的目标样本及其正样本和负样本输入目标追踪模型:通过图片处理网络对目标样本、正样本和负样本进行处理,得到目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本;通过特征提取网络分别对目标样本、正样本、负样本以及目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本进行处理,得到各个样本对应的样本特征;通过分类网络分别对目标样本及其对应的遮挡样本和裁剪样本的样本特征进行处理,得到第一识别结果、第二识别结果和第三识别结果;基于第一识别结果、第二识别结果和第三识别结果,计算分类损失和散度损失,基于各个样本对应的样本特征,计算三元组损失;依据分类损失、散度损失和三元组损失,更新目标追踪模型的模型参数,以完成对目标追踪模型训练。采用上述技术手段,可以解决现有技术中,检测框不准确以及密集场景下的遮挡会影响目标追踪模型的精度,同时目标追踪模型的精度和鲁棒性无法兼顾的问题,进而提高目标追踪模型的精度和鲁棒性。



技术特征:

1.一种目标追踪模型的训练方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,基于所述第一识别结果、所述第二识别结果和所述第三识别结果,计算分类损失,包括:

3.根据权利要求1所述的方法,其特征在于,基于所述第一识别结果、所述第二识别结果和所述第三识别结果,计算散度损失,包括:

4.根据权利要求1所述的方法,其特征在于,基于各个样本对应的样本特征,计算三元组损失,包括:

5.根据权利要求1所述的方法,其特征在于,获取训练数据集之后,所述方法还包括:

6.根据权利要求1所述的方法,其特征在于,获取训练数据集之后,所述方法还包括:

7.根据权利要求1所述的方法,其特征在于,依据所述分类损失、所述散度损失和所述三元组损失,更新所述目标追踪模型的模型参数,以完成对所述目标追踪模型训练,包括:

8.一种目标追踪模型的训练装置,其特征在于,包括:

9.一种电子设备,包括存储器、处理器以及存储在所述存储器中并且可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7中任一项所述方法的步骤。

10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述方法的步骤。


技术总结
本申请提供了一种目标追踪模型的训练方法及装置。该方法包括:串行连接图片处理网络、特征提取网络和分类网络,得到目标追踪模型;通过图片处理网络对目标样本、正样本和负样本进行处理,得到目标样本、正样本和负样本各自对应的遮挡样本和裁剪样本;通过特征提取网络对各个样本进行处理,得到各个样本对应的样本特征;通过分类网络分别对目标样本及其对应的遮挡样本和裁剪样本的样本特征进行处理,得到第一识别结果、第二识别结果和第三识别结果;基于第一识别结果、第二识别结果和第三识别结果,计算分类损失和散度损失,基于各个样本对应的样本特征,计算三元组损失;依据分类损失、散度损失和三元组损失,更新目标追踪模型的模型参数。

技术研发人员:蒋召,黄泽元
受保护的技术使用者:深圳须弥云图空间科技有限公司
技术研发日:
技术公布日:2024/1/15
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1