本发明涉及一种深度学习模型的迭代式训练方法。
背景技术:
1、近年来,深度学习技术在各场景的应用变得越来越广泛,也取得了不错的效果。但单个深度学习模型面对多场景任务时往往不如人意,对于个别场景获得优秀结果的同时也存在检测或识别有困难的场景,其鲁棒性(robustness,也称稳健性)较差。在实际问题中,一个系统通常需要解决多场景的问题。例如某深度学习模型用于从图像中进行文字识别且将文字排列为与图像相同的版面结构,多场景包括文档、票据、表格等。
2、为了解决单个深度学习模型的鲁棒性低的问题,一种解决方案是采用多个深度学习模型融合。但是多模型存在过于冗余、推理时间较长、难以应用的缺点,因此提高单模型的鲁棒性是很有必要的。
技术实现思路
1、本发明所要解决的技术问题是:如何使用于多场景任务的单个深度学习模型在不同的场景上均取得较好的效果,提高单模型的鲁棒性。
2、为解决上述技术问题,本发明提出了一种多场景深度学习模型的迭代式训练方法,包括如下步骤。步骤s1:给定待训练的用于执行多场景任务的初始模型m0、训练数据集tr和测试数据集te;所述训练数据集tr包括多个不同场景下的训练数据子集;所述测试数据集te包括多个不同场景下的测试数据子集。步骤s2:采用训练数据集tr的全部训练数据子集来训练初始模型m0,得到中间模型m1。步骤s3:采用测试数据集te的全部测试数据子集对中间模型m1进行测试;挑选出当前测试结果最差的测试数据子集,这个测试数据子集对应场景就是对当前的中间模型m1来说最具难度的场景,将该最具难度场景下的训练子集称为th。步骤s4:迭代步骤s2-步骤s3,每次迭代时在训练数据集tr中重复增加当前困难场景的训练子集th中的样本,每次迭代的训练对象是上一轮迭代得到的中间模型m1;重复迭代步骤s2-步骤s3,直到获得一个在各个场景均令人满意的最终模型m2。
3、进一步地,所述步骤s1中,不同类型、不同来源的输入数据作为一类场景。
4、进一步地,所述步骤s1中,所述训练数据集tr包括n个不同场景下的训练数据子集;所述测试数据集te包括n个不同场景下的测试数据子集;并且同一个场景下既有训练数据子集也有测试数据子集。
5、进一步地,所述步骤s2中,中间模型m1适用于全部场景。
6、进一步地,所述步骤s2中,在统计学上所述中间模型m1按场景划分,必然存在识别或检测的不同难易程度。
7、进一步地,所述步骤s4中,在训练数据集tr中重复增加当前困难场景的训练子集th中的样本,也就是增加当前困难场景的训练子集th中的样本的出现概率。
8、进一步地,所述步骤s4中,使用交并比iou或准确率指标来判断经过多轮迭代训练后的中间模型m1是否在各个场景上均取得较好的效果,若还没有则重复迭代步骤s2-步骤s3,否则结束训练得到最终模型m2。
9、进一步地,所述步骤s4中,为某指标设置阈值a,经过多次迭代训练的中间模型m1如果在全部测试数据子集上的测试结果都大于或等于a,则表示该中间模型在各个场景上均取得较好的即令人满意的效果;经过多次迭代训练的中间模型m1如果在任意一个或多个测试数据子集上的测试结果小于a,则表示该中间模型还没有在各个场景上均取得较好的即令人满意的效果。
10、本发明还提出了一种多场景深度学习模型的迭代式训练装置,包括初始单元、训练单元、测试单元、迭代单元。所述初始单元用来给定待训练的用于执行多场景任务的初始模型m0、训练数据集tr和测试数据集te;所述训练数据集tr包括多个不同场景下的训练数据子集;所述测试数据集te包括多个不同场景下的测试数据子集。所述训练单元用来采用训练数据集tr的全部训练数据子集来训练初始模型m0,得到中间模型m1。
11、所述测试单元用来采用测试数据集te的全部测试数据子集对中间模型m1进行测试。挑选出当前测试结果最差的测试数据子集,这个测试数据子集对应场景就是对当前的中间模型m1来说最具难度的场景,将该最具难度场景下的训练子集称为th。所述迭代单元用来在训练数据集tr中重复增加当前困难场景的训练子集th中的样本,并且将每一轮迭代得到的中间模型m1作为下一轮迭代训练对象,然后送入训练单元,由训练单元和测试单元重复迭代,直到获得一个在各个场景均令人满意的最终模型m2。
12、本发明取得的技术效果是:提出了一种增强困难场景样本的学习效果的迭代式训练方法,能够让深度学习模型学习并适应不同场景下的样本,尤其针对困难场景增强学习,提高深度学习模型在各种场景下的鲁棒性。
1.一种多场景深度学习模型的迭代式训练方法,其特征是,包括如下步骤;
2.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s1中,不同类型、不同来源的输入数据作为一类场景。
3.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s1中,所述训练数据集tr包括n个不同场景下的训练数据子集;所述测试数据集te包括n个不同场景下的测试数据子集;并且同一个场景下既有训练数据子集也有测试数据子集。
4.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s2中,中间模型m1适用于全部场景。
5.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s2中,在统计学上所述中间模型m1按场景划分,必然存在识别或检测的不同难易程度。
6.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s4中,在训练数据集tr中重复增加当前困难场景的训练子集th中的样本,也就是增加当前困难场景的训练子集th中的样本的出现概率。
7.根据权利要求1所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s4中,使用交并比iou或准确率指标来判断经过多轮迭代训练后的中间模型m1是否在各个场景上均取得较好的效果,若还没有则重复迭代步骤s2-步骤s3,否则结束训练得到最终模型m2。
8.根据权利要求7所述的多场景深度学习模型的迭代式训练方法,其特征是,所述步骤s4中,为某指标设置阈值a,经过多次迭代训练的中间模型m1如果在全部测试数据子集上的测试结果都大于或等于a,则表示该中间模型在各个场景上均取得较好的即令人满意的效果;经过多次迭代训练的中间模型m1如果在任意一个或多个测试数据子集上的测试结果小于a,则表示该中间模型还没有在各个场景上均取得较好的即令人满意的效果。
9.一种多场景深度学习模型的迭代式训练装置,其特征是,包括初始单元、训练单元、测试单元、迭代单元;