图片分类模型的训练方法、分类方法、装置、设备和介质与流程

文档序号:34234351发布日期:2023-05-24 18:34阅读:146来源:国知局
图片分类模型的训练方法、分类方法、装置、设备和介质与流程

本公开涉及计算机,特别涉及一种图片分类模型的训练方法、分类方法、装置、设备和介质。


背景技术:

1、在计算机技术领域中,用于训练模型的数据集合经常会出现不平衡现象,即少数类别的图像含有大量的样本(定义为头部类别),而大多数类别的图像只包含少量的样本(定义为尾部类别),这种情况下,该数据集合可以称为是长尾分布数据,或呈长尾分布的数据集合。

2、由于长尾分布数据中这种不平衡的现象,容易导致训练得到的模型在头部类别的图像分类上表现较好的效果,而在尾部类别的图像分类上精度较低,影响图片分类的整体准确率。


技术实现思路

1、本公开提供一种图片分类模型的训练方法、分类方法、装置、设备和介质,可以提高模型对图片分类的预测准确性。

2、第一方面,本公开提供了一种图片分类模型的训练方法,该图片分类模型的训练方法包括:将预先获取的训练数据集中的图片样本输入图片分类模型,得到图片样本的预测类别;其中,所述图片分类模型包括多个专家子网络,将所述图片样本输入所述图片分类模型进行类别预测的步骤包括:通过所述多个专家子网络分别对所述图片样本进行特征提取,得到每个专家子网络提取的图片特征;根据所述多个专家子网络提取得到的图片特征进行特征分类,得到所述图片样本的预测类别;基于所述预测类别、所述图片样本的标注类别和n个角度,构建目标损失函数,所述n个角度是基于所述每个专家子网络提取的图片特征确定的角度,n为大于或等于1的整数;根据所述目标损失函数调整所述图片分类模型的参数,得到训练后的图片分类模型。

3、第二方面,本公开提供了一种图片分类方法,该图片分类方法包括:获取待分类图片,将所述待分类图片输入图片分类模型中进行类别预测,得到所述待分类图片的预测类别;其中,所述图片分类模型包括多个专家子网络,将所述待分类图片输入图片分类模型中进行类别预测的步骤包括:通过所述多个专家子网络分别对所述图片样本进行特征提取,得到每个专家子网络提取的图片特征;根据所述多个专家子网络提取得到的图片特征进行特征分类,得到所述图片样本的预测类别。

4、第三方面,本公开提供了一种图片分类装置,该图片分类装置包括:获取模块,用于获取待分类图片;得到模块,用于将所述待分类图片输入图片分类模型中进行类别预测,得到所述待分类图片的预测类别;其中,所述图片分类模型包括多个专家子网络,将所述待分类图片输入图片分类模型中进行类别预测的步骤包括:通过所述多个专家子网络分别对所述图片样本进行特征提取,得到每个专家子网络提取的图片特征;根据所述多个专家子网络提取得到的图片特征进行特征分类,得到所述图片样本的预测类别。

5、第四方面,本公开提供了一种电子设备,该电子设备包括:至少一个处理器;以及与至少一个处理器通信连接的存储器;其中,存储器存储有可被至少一个处理器执行的一个或多个计算机程序,一个或多个计算机程序被至少一个处理器执行,以使至少一个处理器能够执行上述的图片分类模型的训练方法或图片分类方法。

6、第五方面,本公开提供了一种计算机可读存储介质,其上存储有计算机程序,其中,计算机程序在被处理器/处理核执行时实现上述的图片分类模型的训练方法或图片分类方法。

7、本公开所提供的实施例,能够使用图片训练数据集中的图片样本对图片分类模型进行模型训练,在模型训练过程中,图片样本的预测类别和标注类别的差异大小用于确定分类的预测精度,根据每个专家子网络提取的图片特征确定的角度大小代表不同专家子网络学习到的图片特征的差异程度;模型训练过程中,基于预测类别、图片样本的标注类别以及不同专家子网络提取的图片特征之间的角度,进行损失函数的构建,可以在保证图片分类精度的基础上通过特征之间的角度来处理多专家网络架构中学习到的特征差异分布不均匀的问题,从而训练得到能力均衡的图片分类模型,提高模型对于所有图片分类的分类性能,进而提高模型对图片分类的预测准确性。

8、应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。



技术特征:

1.一种图片分类模型的训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述图片样本是通过下述方式得到的:

3.根据权利要求1所述的方法,其特征在于,每个专家子网络均包括卷积层和分类层,每个专家子网络的卷积层均包括共享部分以及独立部分;

4.根据权利要求1所述的方法,其特征在于,所述基于所述预测类别、所述图片样本的标注类别和n个角度,构建目标损失函数,包括:

5.根据权利要求4所述的方法,其特征在于,所述对所述第一损失函数和所述第二损失函数进行加权求和,得到所述目标损失函数,包括:

6.根据权利要求4所述的方法,其特征在于,所述基于所述每个专家子网络提取的图片特征确定特征中心点,包括:

7.一种图片分类方法,其特征在于,包括:

8.一种图片分类装置,其特征在于,

9.一种电子设备,其特征在于,包括:

10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序在被处理器执行时实现如权利要求1-6中任一项所述的图片分类模型训练方法,或权利要求7所述的图片分类方法。


技术总结
本公开提供了一种图片分类模型的训练方法、分类方法、装置、设备和介质,该方法包括:将预先获取的训练数据集中的图片样本输入图片分类模型,得到图片样本的预测类别;其中,图片分类模型包括多个专家子网络,将图片样本输入图片分类模型进行类别预测的步骤包括:通过多个专家子网络分别对图片样本进行特征提取,得到每个专家子网络提取的图片特征;根据多个专家子网络提取得到的图片特征进行特征分类,得到图片样本的预测类别;基于预测类别、图片样本的标注类别和N个角度,构建目标损失函数;根据目标损失函数调整图片分类模型的参数,得到训练后的图片分类模型。根据本公开的实施例能够提高模型对图片分类的预测准确性。

技术研发人员:范峻植,杨烨,冉承祥,夏粉,蒋宁
受保护的技术使用者:马上消费金融股份有限公司
技术研发日:
技术公布日:2024/1/12
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1