本技术涉及计算机,特别是涉及一种半监督模型训练方法、装置、计算机设备、存储介质和计算机程序产品。
背景技术:
1、伴随着工业互联网应用的普及,人工智能技术为实现企业质检产线自动化、无人化、智能化等需求提供了较好的解决方案,取得了较为显著的效果。
2、但是,现有的人工智能技术对于标注样本的规模要求较大,而对于缺陷样品进行标注的人工成本代价十分高昂,这成为了应用人工智能技术进行工业质检中的一个不可忽视的缺陷。
技术实现思路
1、基于此,有必要针对上述技术问题,提供一种能够减少对样本标注的依赖的半监督模型训练方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
2、第一方面,本技术提供了一种半监督模型训练方法,包括:
3、将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;
4、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;
5、将第二预设数量的标注图像样本输入初始化分类模型,得到每一所述标注图像样本各自对应的特征,根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;
6、根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;
7、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签;
8、将所述每一所述未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对所述初始化分类模型进行参数更新,得到训练好的分类模型;所述训练好的分类模型用于对工业质检图像数据进行分类。
9、在其中一个实施例中,所述将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本被分类至每一分类类别各自对应的概率,包括:将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征;通过所述初始化分类模型中的分类器对每一所述未标注图像样本各自对应的特征进行处理,得到每一所述未标注图像样本被分类至每一分类类别各自对应的概率;所述根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型,包括:将每一所述未标注图像样本被分类至每一分类类别各自对应的概率进行归一化处理,得到每一所述未标注图像样本被分类至每一分类类别各自对应的归一化概率;针对多个分类类别中的每一个分类类别,从每一所述未标注图像样本被分类至每一分类类别各自对应的归一化概率中,筛选出每一所述未标注图像样本被分类至所针对的分类类别的目标归一化概率值;将每一所述未标注图像样本各自对应的特征和被分类至所针对的分类类别的目标归一化概率值相乘所得到的数值进行求和处理,得到所述针对的分类类别的第一类别原型所需的第一参数;将每一所述未标注图像样本被分类至所述针对的分类类别的目标归一化概率进行求和处理,得到计算所述针对的分类类别的第一类别原型所需的第二参数;根据所述第一参数和所述第二参数,确定所述针对的分类类别所对应的第一类别原型。
10、在其中一个实施例中,所述根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型,包括:获取预设函数;所述预设函数是根据所述标注图像样本标注的真实标签和被分类至每一分类类别各自对应的概率是否匹配来确定的;根据所述预设函数确定每一所述标注图像样本各自对应的预设数值;将每一所述标注图像样本各自对应的特征和预设数值相乘所得到的数值进行求和处理,得到计算每一分类类别的第二类别原型所需的第三参数;将每一所述标注图像样本各自对应的预设数值进行求和处理,得到计算每一分类类别的第二类别原型所需的第四参数;根据每一分类类别对应的第三参数和第四参数,确定每一分类类别各自对应的第二类别原型。
11、在其中一个实施例中,所述根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型,包括:确定对所述初始化分类模型进行训练的当前迭代次数;根据所述当前迭代次数,确定第一类别原型在所述当前迭代次数下对应的第一权重系数,以及第二类别原型在所述当前迭代次数下对应的第二权重系数;根据所述第一权重系数和所述第二权重系数,对每一所述分类类别在所述当前迭代次数下对应的第一类别原型和第二类别原型进行线性加权处理,得到每一分类类别在所述当前迭代次数下分别对应的修正类别原型。
12、在其中一个实施例中,所述根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签,包括:针对多个未标注图像样本中的每一个未标注图像样本,从每一个未标注图像样本被分类至每一分类类别对应的修正类别原型中,筛选出所述针对的未标注图像样本被分类至每一分类类别对应的目标修正类别原型;根据所述针对的未标注图像样本对应的特征和被分类至每一分类类别对应的目标修正类别原型之间的余弦距离,得到所述针对的未标注图像样本对应的伪标签。
13、在其中一个实施例中,所述根据所述针对的未标注图像样本对应的特征和被分类至每一分类类别对应的目标修正类别原型之间的余弦距离,得到所述针对的未标注图像样本对应的伪标签,包括:将所述针对的未标注图像样本对应的特征和被分类至每一分类类别对应的目标修正类别原型之间的余弦距离进行归一化处理,得到所述针对的未标注图像样本在每一分类类别中对应的子伪标签;将所述针对的未标注图像样本在每一分类类别中对应的子伪标签和所述针对的未标注图像样本被分类至每一分类类别各自对应的概率相乘得到的数值,作为所述针对的未标注图像样本在每一分类类别中对应的目标子伪标签;
14、根据所述针对的未标注图像样本在每一分类类别中对应的目标子伪标签,得到所述针对的未标注图像样本对应的伪标签。
15、在其中一个实施例中,获取工业质检图像数据;将获取到的所述工业质检图像数据输入训练好的分类模型;通过所述训练好的分类模型对所述工业质检图像数据进行特征提取,得到所述工业质检图像数据对应的图像特征;通过所述训练好的分类模型对所述工业质检图像数据对应的图像特征进行分类,得到所述工业质检图像数据的分类结果,所述工业质检图像数据的分类结果包括所述工业质检图像数被分类至每一分类类别对的概率值。
16、第二方面,本技术还提供了一种半监督模型训练装置,包括:
17、初始模块,用于将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;
18、第一确定模块,用于根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;
19、第二确定模块,用于将第二预设数量的标注图像样本输入初始化分类模型,得到每一所述标注图像样本各自对应的特征,根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;
20、修正模块,用于根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;
21、伪标签模块,用于根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签;
22、更新模块,用于将所述每一所述未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对所述初始化分类模型进行参数更新,得到训练好的分类模型;所述训练好的分类模型用于对工业质检图像数据进行分类。
23、第三方面,本技术还提供了一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现以下步骤:
24、将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;
25、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;
26、将第二预设数量的标注图像样本输入初始化分类模型,得到每一所述标注图像样本各自对应的特征,根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;
27、根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;
28、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签;
29、将所述每一所述未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对所述初始化分类模型进行参数更新,得到训练好的分类模型;所述训练好的分类模型用于对工业质检图像数据进行分类。
30、第四方面,本技术还提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现以下步骤:
31、将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;
32、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;
33、将第二预设数量的标注图像样本输入初始化分类模型,得到每一所述标注图像样本各自对应的特征,根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;
34、根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;
35、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签;
36、将所述每一所述未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对所述初始化分类模型进行参数更新,得到训练好的分类模型;所述训练好的分类模型用于对工业质检图像数据进行分类。
37、第五方面,本技术还提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现以下步骤:
38、将第一预设数量的未标注图像样本输入初始化分类模型,得到每一所述未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;
39、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;
40、将第二预设数量的标注图像样本输入初始化分类模型,得到每一所述标注图像样本各自对应的特征,根据每一所述标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;
41、根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;
42、根据每一所述未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一所述未标注图像样本的伪标签;
43、将所述每一所述未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对所述初始化分类模型进行参数更新,得到训练好的分类模型;所述训练好的分类模型用于对工业质检图像数据进行分类。
44、上述半监督模型训练方法、装置、计算机设备、存储介质和计算机程序产品,通过将第一预设数量的未标注图像样本输入初始化分类模型,得到每一未标注图像样本各自对应的特征和被分类至每一分类类别各自对应的概率;根据每一未标注图像样本各自对应的特征和被分类至每一分类类别的概率,确定每一分类类别各自对应的第一类别原型;将第二预设数量的标注图像样本输入初始化分类模型,得到每一标注图像样本各自对应的特征,根据每一标注图像样本各自对应的特征、各自标注的真实标签和被分类至每一分类类别各自对应的概率,确定每一分类类别各自对应的第二类别原型;根据每一分类类别各自对应的第一类别原型和第二类别原型,确定每一分类类别对应的修正类别原型;根据每一未标注图像样本各自对应的特征和被分类至每一分类类别对应的修正类别原型,得到每一未标注图像样本的伪标签;将每一未标注图像样本被分类至每一分类类别各自对应的概率和伪标签之间的差异,对初始化分类模型进行参数更新,得到训练好的分类模型;训练好的分类模型用于对工业质检图像数据进行分类。能够有效为无标注样本构建更为准确的伪标签,并通过伪标签对初始化分类模型的参数进行更新,减少了对样本标注的依赖,并提高了对工业质检图像数据进行分类的准确性。