本技术涉及图像处理,特别是涉及一种广义类发现模型训练方法、一种图像识别方法、装置、电子设备及存储介质。
背景技术:
1、近年来,机器学习在自然语言处理、计算机视觉、语音识别等领域都得到了广泛应用,而在计算机视觉领域,图像分类任务是最受关注且应用最广的任务之一,各种分类技术层出不穷,性能不断提升。
2、在机器学习任务中,通过大量人工标注的图像实现分类的监督学习方法是图像分类的传统方法,由于传统方法主要依赖视觉信息完成模型的训练,忽略了图像类别本身的包含的语义信息,从而降低了分类器对新类的泛化能力。
技术实现思路
1、本技术实施例所要解决的技术问题是提供一种广义类发现模型训练方法、一种图像识别方法、装置、电子设备及存储介质,以解决目标分类器对图像新类的泛化能力较低的问题。
2、第一方面,本技术实施例提供了一种广义类发现模型的训练方法,包括
3、将样本图像中的标记图像和无标记图像分别输入到所述广义类发现模型的类知识引导样本相似性学习模块,得到所述标记图像和无标记图像之间的第一相似性矩阵;
4、将所述标记图像的标签信息输入到广义类发现模型的类别相似性模块,得到标签类别之间的相似性得分;
5、将所述样本图像输入到所述广义类发现模型的标签估计模块,得到所述样本图像的聚类结果;
6、根据所述聚类结果、所述第一相似性矩阵和所述相似性得分对所述广义类发现模型的原型网络进行训练,直到所述广义类发现模型收敛。
7、可选地,所述将样本图像中的标记图像和无标记图像输入到所述广义类发现模型的类知识引导样本相似性学习模块,得到所述标记图像和无标记图像之间的第一相似性矩阵包括:
8、将所述标记图像和所述无标记图像分别输入到所述特征提取网络,得到所述标记图像对应的第一嵌入向量和所述无标记图像的对应的第二嵌入向量;
9、将所述第一嵌入向量和所述第二嵌入向量输入到样本关系网络,得到所述标记图像和无标记图像之间的第一相似性矩阵。
10、可选地,所述将所述标记图像的标签信息输入到广义类发现模型的类别相似性模块,得到标签类别之间的相似性得分包括:
11、获取所述标记图像的标签信息对应的语义描述向量;
12、将所述语义描述向量输入到样本关系网络,得到标签类别之间的相似性得分。
13、可选地,所述将所述样本图像输入到所述广义类发现模型的标签估计模块,得到所述样本图像的聚类结果包括:
14、对所述样本图像进行重表示,获得标记后的样本图像;
15、将所述标记后的样本图像输入到所述特征提取网络,得到标记后的样本图像对应的第三嵌入向量;
16、将所述第三嵌入向量输入到样本关系网络,得到所述标记后的样本图像与所述标记图像之间的第二相似性矩阵;
17、使用半监督k-means聚类算法对所述第二相似性矩阵进行聚类,获得聚类结果,所述聚类结果包括:无标记图像的聚类个数、标记图像的聚类个数和无标记图像的伪标签。
18、可选地,所述原型网络包括:所述样本图像中的标记图像的已知类别原型和所述样本图像的新类别原型,其中,将标记图像的标签信息对应的语义描述向量作为原型网络的已知类别原型,将所述第三嵌入向量作为新类别原型;
19、所述根据所述聚类结果、所述第一相似性矩阵和所述相似性得分对所述广义类发现模型的原型网络进行训练,直到所述广义类发现模型收敛包括:
20、根据所述第二相似性矩阵和所述相似性得分确定类别相似性损失值;
21、根据所述标记图像的第二嵌入向量和标记图像的已知类别原型确定聚类损失值;
22、将所述无标记图像输入到所述原型网络,得到匹配得分;
23、根据所述匹配得分、所述伪标签和所述标记图像的标签信息确定分类损失值;
24、根据所述类别相似性损失值、聚类损失值和分类损失值对所述原型网络的网络参数进行更新,直到所述广义类发现模型收敛。
25、第二方面,本技术实施例提供了一种图像识别方法,包括:
26、获取预先训练的广义类发现模型,所述广义类发现模型由所述的训练方法训练得到;
27、将样本图像中的无标记图像输入到广义类发现模型的特征提取网络,得到所述无标记图像的第四嵌入向量;
28、将所述第四嵌入向量输入到所述广义类发现模型的原型网络,得到所述无标记图像与所述样本图像对应的各个类别原型之间的匹配分数,所述原型网络包括:所述样本图像中的标记图像的已知类别原型和所述样本图像的新类别原型;
29、根据所述匹配分数确定所述无标记图像的预测标签信息。
30、可选地,所述将所述第四嵌入向量输入到所述广义类发现模型的原型网络,得到所述无标记图像与所述样本图像对应的各个类别原型之间的匹配分数包括:
31、计算所述无标记图像与所述已知类别原型之间的第一匹配分数以及所述无标记图像与所述新类别原型之间的第二匹配分数。
32、可选地,所述根据所述匹配分数确定所述无标记图像的预测标签信息包括:
33、通过最大化似然函数从所述第一匹配分数和所述第二匹配分数中确定最大的匹配分数;
34、将所述最大的匹配分数对应的类别原型作为无标记图像的预测标签信息。
35、第三方面,本技术实施例一种广义类发现模型的训练装置,包括:
36、类知识学习模块,用于将样本图像中的标记图像和无标记图像分别输入到所述广义类发现模型的类知识引导样本相似性学习模块,得到所述标记图像和无标记图像之间的第一相似性矩阵;
37、类别模块,用于将所述标记图像的标签信息输入到广义类发现模型的类别相似性模块,得到标签类别之间的相似性得分;
38、标签模块,用于将所述样本图像输入到所述广义类发现模型的标签估计模块,得到所述样本图像的聚类结果;
39、训练模块,用于根据所述聚类结果、所述第一相似性矩阵和所述相似性得分对所述广义类发现模型的原型网络进行训练,直到所述广义类发现模型收敛。
40、第四方面,本技术实施例提供了一种图像识别装置,包括:
41、获取模块,用于获取预先训练的广义类发现模型,所述广义类发现模型由所述的训练方法训练得到;
42、输入模块,用于将样本图像中的无标记图像输入到广义类发现模型的特征提取网络,得到所述无标记图像的第四嵌入向量;
43、匹配分数模块,用于将所述第四嵌入向量输入到所述广义类发现模型的原型网络,得到所述无标记图像与所述样本图像对应的各个类别原型之间的匹配分数,所述原型网络包括:所述样本图像中的标记图像的已知类别原型和所述样本图像的新类别原型;
44、标签模块,用于根据所述匹配分数确定所述无标记图像的预测标签信息。
45、第五方面,本技术实施例提供了一种电子设备,包括:
46、处理器、存储器以及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现所述的训练方法或者所述的图像识别方法。
47、第六方面,本技术实施例提供了一种计算机可读存储介质,当所述计算机可读存储介质存储有被处理器执行时实现所述的训练方法或者所述的图像识别方法的计算机程序。
48、与现有技术相比,本技术实施例包括以下优点:
49、本技术实施例中,通过引入类别先验知识,设计了类知识指导的样本相似性学习模块,得到所述标记图像和无标记图像之间的第一相似性矩阵,将所述标记图像的标签信息输入到广义类发现模型的类别相似性模块,得到标签类别之间的相似性得分;将所述样本图像输入到所述广义类发现模型的标签估计模块,得到所述样本图像的聚类结果;根据所述聚类结果、所述第一相似性矩阵和所述相似性得分对所述广义类发现模型的原型网络进行训练,直到所述广义类发现模型收敛,与现有技术相比,由于在训练的过程中使用了标签信息,可以充分利用标签信息本身具有丰富的语义信息,通过标签信息得到标签类别之间存在紧密的关联,因此能够基于这些关联信息推断新类别,从而提升了广义类发现模型对新类别的泛化能力。
50、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本技术。