本发明涉及机器学习领域,特别涉及一种自动驾驶模型的训练方法、部署方法、系统、介质和设备。
背景技术:
1、目前已有很多提高感知模型鲁棒性的方法,针对恶劣环境对模型鲁棒性的影响,主要采用数据增强的方式,利用采集或生成的容易导致模型鲁棒性下降的样本训练模型。
2、感知模型的鲁棒性可能受恶劣环境、传感器故障、恶意攻击等多种类型干扰因素的影响,然而,当前提升感知模型鲁棒性的方法通常只针对一种或某几种干扰因素,没有涵盖尽可能多的干扰因素,所获得的模型只能在一种或某几种干扰因素下具备较好的鲁棒性,当模型部署到车端,仍容易遭受其他干扰因素的影响,如针对恶劣天气因素训练的模型,仍会受到恶意攻击的影响,导致鲁棒性下降,威胁自动驾驶安全。此外,当前车端部署的感知模型主要为传统小模型,一方面小模型表达能力较弱,难以同时兼顾不同类型的干扰因素;另一方面,使用针对多种类型干扰因素增强的数据训练小模型,容易造成小模型在正常场景下性能下降。
技术实现思路
1、本发明的目的是提供一种自动驾驶模型的训练方法、系统、存储介质和电子设备,能够提升用于部署至自动驾驶系统的第二感知模型的鲁棒性。
2、为解决上述技术问题,本发明提供一种自动驾驶模型的训练方法,具体技术方案如下:
3、获取数据集;所述数据集中的样本包括正常场景数据和干扰场景数据;
4、从所述数据集中选取第一正常场景样本和第一干扰场景样本,训练得到第一感知模型;
5、从所述数据集中选取第二正常场景样本和第二干扰场景样本,基于所述第一感知模型使用知识蒸馏方法训练得到第二感知模型;所述第二感知模型用于部署至自动驾驶系统。
6、可选的,所述获取数据集包括:
7、获取所述正常场景数据或所述干扰场景数据的数据帧;其中,每个所述数据帧至少包含一种模态。
8、可选的,获取数据集之后,还包括:
9、对所述数据集中的样本进行特定任务标注和干扰因素标注。
10、可选的,所述从所述数据集中选取第一正常场景样本和第一干扰场景样本包括:
11、按照相同概率从所述数据集中选取第一正常场景样本和第一干扰场景样本。
12、可选的,利用第一算法训练得到第一感知模型包括:
13、将所述第一正常场景样本和所述第一干扰场景样本输入基础网络,得到多尺度特征;所述基础网络包含骨干网络和多尺度特征提取网络;
14、利用多模态融合网络融合所述样本多个模态的特征,并输入至元网络;所述元网络由基础元网络组和抗干扰元网络组构成;所述基础元网络组包含基础元网络,所述抗干扰元网络组包含抗干扰元网络;所述元网络包含卷积神经网络和多层感知机;
15、在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征;
16、将所述融合特征输入至任务网络,经过训练得到第一感知模型。
17、可选的,利用第一算法训练得到第一感知模型后,还包括:
18、确定所述第一感知模型的损失函数,并利用反向传播优化器优化所述第一感知模型,直至所述第一感知模型完全收敛。
19、可选的,从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法,训练得到第二感知模型包括:
20、从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法训练得到初始第二感知模型;
21、结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法对所述初始第二感知模型进行关联关系矩阵或特征值的一致化处理,得到第二感知模型。
22、可选的,从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法训练得到初始第二感知模型包括:
23、将所述第二正常场景样本和所述第二干扰场景样本输入第二算法得到的基础网络,得到多尺度特征;所述基础网络包含骨干网络和多尺度特征提取网络;
24、利用多模态融合网络融合所述样本多个模态的特征,并输入至元网络;所述元网络由基础元网络组和抗干扰元网络组构成;所述基础元网络组包含基础元网络,所述抗干扰元网络组包含抗干扰元网络;所述元网络包含卷积神经网络和多层感知机;
25、在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征;
26、将所述融合特征输入至任务网络,经过训练得到初始第二感知模型。
27、可选的,结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法对所述初始第二感知模型进行关联关系矩阵或特征值的一致化处理,得到第二感知模型包括:
28、利用多模态融合网络融合所述样本多个模态的特征后,利用干扰特征知识蒸馏方法对干扰场景样本和正常场景样本之间的前景特征以及背景特征间的关联关系进行知识蒸馏,并利用损失函数计算关联关系矩阵之间的差异;
29、在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征后,利用基础特征判别损失函数对所述融合特征进行预测,并根据预测结果进行特征值对齐。
30、可选的,利用干扰特征知识蒸馏方法对干扰场景样本和正常场景样本之间的前景特征以及背景特征间的关联关系进行知识蒸馏,并利用损失函数计算关联关系矩阵之间的差异包括:
31、输入正常场景样本和干扰场景样本;
32、对于所述第一感知模型,得到正常场景样本的第一多模态融合特征和第一元知识特征;对于第二感知模型,得到正常场景样本的第二多模态融合特征和第二元知识特征;
33、分别确定所述第一多模态融合特征、所述第一元知识特征、所述第二多模态融合特征和所述第二元知识特征的宽、长和通道数量;
34、根据所述第一多模态融合特征、所述第一元知识特征、所述第二多模态融合特征和所述第二元知识特征的宽、长和通道数量确定所述前景特征和所述背景特征;
35、对于所述前景特征,从标注的目标边界框内进行均匀的特征点采样,获得前景特征点;
36、对于所述背景特征,从标注的目标边界框外的区域进行均匀的特征点采样,得到背景特征点
37、根据所述前景特征点和所述背景特征点利用损失函数计算关联关系矩阵之间的差异。
38、可选的,还包括:
39、确定干扰场景样本的干扰特征点,并根据所述干扰特征点确定需要进行知识迁移的特征。
40、可选的,所述确定所述干扰场景样本的干扰特征点包括:
41、确定子任务的损失函数;
42、利用所述子任务的损失函数计算所述第二感知模型的干扰场景样本的多模态融合特征梯度和元知识特征梯度;
43、沿通道方向对所述多模态融合特征梯度和所述元知识特征梯度求和,对宽、高维度上每个位置梯度值求绝对值,并按照梯度绝对值进行降序排序;
44、取所述排序中前预设数量个梯度绝对值的位置索引,按照所述位置索引在所述第二感知模型的干扰场景样本的多模态融合特征中采样第一干扰特征点,按照所述位置索引在所述第一感知模型的干扰场景样本的多模态融合特征中采样第二干扰特征点。
45、可选的,根据所述前景特征点和所述背景特征点利用损失函数计算关联关系矩阵之间的差异包括:
46、根据所述第一感知模型和所述第二感知模型各自的多模态融合特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第一知识蒸馏损失;
47、根据所述第一感知模型和所述第二感知模型各自的元知识特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第二知识蒸馏损失;
48、根据所述第一知识蒸馏损失和所述第二知识蒸馏损失计算关联关系矩阵之间的差异。
49、可选的,根据所述第一感知模型和所述第二感知模型各自的多模态融合特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第一知识蒸馏损失包括:
50、聚合所述第一感知模型的多模态融合特征的前景特征点、背景特征点和干扰特征点,得到第一聚合特征;
51、确定所述第一聚合特征的第一关联关系矩阵;
52、聚合所述第二感知模型的多模态融合特征的前景特征点、背景特征点和干扰特征点,得到第二聚合特征;
53、确定所述第二聚合特征的第二关联关系矩阵;
54、根据所述第一关联关系矩阵和所述第二关联关系矩阵计算干扰特征第一知识蒸馏损失。
55、可选的,根据所述第一感知模型和所述第二感知模型各自的元知识特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第二知识蒸馏损失包括:
56、聚合所述第一感知模型的元知识特征的前景特征点、背景特征点和干扰特征点,得到第三聚合特征;
57、确定所述第三聚合特征的第三关联关系矩阵;
58、聚合所述第二感知模型的元知识特征的前景特征点、背景特征点和干扰特征点,得到第四聚合特征;
59、确定所述第四聚合特征的第四关联关系矩阵;
60、根据所述第三关联关系矩阵和所述第四关联关系矩阵计算干扰特征第二知识蒸馏损失。
61、可选的,利用基础特征判别损失函数对所述融合特征进行预测,并根据预测结果进行特征值对齐包括:
62、确定所述第一感知模型的第一基础特征,和所述第二感知模型的第二基础特征;
63、将所述第一基础特征和所述第二基础特征输入至基础特征判别网络,确定所述第一基础特征或者所述第二基础特征的预测结果;所述基础特征判别网络包含所述基础特征判别损失函数;
64、根据所述预测结果进行特征值对齐。
65、可选的,将所述第一基础特征和所述第二基础特征输入至基础特征判别网络之前,还包括:
66、利用梯度反转层对所述第二感知模型基础元网络或抗干扰元网络进行知识蒸馏。
67、可选的,利用第二算法结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法,训练得到第二感知模型时,还包括:
68、每次训练时,从正常场景数据集采样得到正常场景训练样本,从感知鲁棒性干扰场景数据库中采样一个干扰场景数据训练样本;
69、将所述正常场景训练样本和所述干扰场景数据训练样本输入至所述第一感知模型和所述第二感知模型中,得到各自的多模态融合特征和融合的元知识特征。
70、可选的,得到各自的多模态融合特征和融合的元知识特征之后,还包括:
71、计算多模态融合特征的干扰特征知识蒸馏损失函数,元知识特征的干扰特征知识蒸馏损失函数,基础特征判别损失函数,抗干扰特征判别损失函数,干扰因素判别损失函数的损失值;
72、利用所述多模态融合特征的干扰特征知识蒸馏损失函数、元知识特征的干扰特征知识蒸馏损失函数、基础特征判别损失函数、抗干扰特征判别损失函数、干扰因素判别损失函数的损失值和各子任务网络计算获得的损失值,反向传播优化所述第二感知模型。
73、本发明还提供一种自动驾驶模型的部署方法,包括:
74、获取第二感知模型;
75、将所述第二感知模型部署至自动驾驶系统。
76、本发明还提供一种自动驾驶模型的训练系统,包括:
77、数据获取模块,用于获取数据集;所述数据集中的样本包括正常场景数据和干扰场景数据;
78、第一感知模型训练模块,用于从所述数据集中选取第一正常场景样本和第一干扰场景样本,利用第一算法训练得到第一感知模型;
79、第二感知模型训练模块,用于从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法,训练得到第二感知模型;所述第二感知模型用于部署至自动驾驶系统。
80、可选的,还包括:
81、标注模块,用于对所述数据集中的样本进行特定任务标注和干扰因素标注。
82、可选的,所述第一感知模型训练模块包括:
83、第一选取子模块,用于按照相同概率从所述数据集中选取第一正常场景样本和第一干扰场景样本。
84、可选的,第一感知模型训练模块包括:
85、第一输入子模块,用于将所述第一正常场景样本和所述第一干扰场景样本输入基础网络,得到多尺度特征;所述基础网络包含骨干网络和多尺度特征提取网络;
86、第一融合子模块,用于利用多模态融合网络融合所述样本多个模态的特征,并输入至元网络;所述元网络由基础元网络组和抗干扰元网络组构成;所述基础元网络组包含基础元网络,所述抗干扰元网络组包含抗干扰元网络;所述元网络包含卷积神经网络和多层感知机;
87、第二融合子模块,用于在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征;
88、第一输入子模块,用于将所述融合特征输入至任务网络,经过训练得到第一感知模型。
89、可选的,还包括:
90、损失函数确定子模块,用于确定所述第一感知模型的损失函数,并利用反向传播优化器优化所述第一感知模型,直至所述第一感知模型完全收敛。
91、可选的,第二感知模型训练模块包括:
92、第二选取子模块,用于从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法训练得到初始第二感知模型;
93、一致化处理子模块,用于结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法对所述初始第二感知模型进行关联关系矩阵或特征值的一致化处理,得到第二感知模型。
94、可选的,第二选取子模块包括:
95、第一输入单元,用于将所述第二正常场景样本和所述第二干扰场景样本输入第二算法得到的基础网络,得到多尺度特征;所述基础网络包含骨干网络和多尺度特征提取网络;
96、第一融合单元,用于利用多模态融合网络融合所述样本多个模态的特征,并输入至元网络;所述元网络由基础元网络组和抗干扰元网络组构成;所述基础元网络组包含基础元网络,所述抗干扰元网络组包含抗干扰元网络;所述元网络包含卷积神经网络和多层感知机;
97、第二融合单元,用于在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征;
98、第二输入单元,用于将所述融合特征输入至任务网络,经过训练得到初始第二感知模型。
99、可选的,一致化处理子模块包括:
100、计算单元,用于利用多模态融合网络融合所述样本多个模态的特征后,利用干扰特征知识蒸馏方法对干扰场景样本和正常场景样本之间的前景特征以及背景特征间的关联关系进行知识蒸馏,并利用损失函数计算关联关系矩阵之间的差异;
101、预测单元,用于在元知识融合网络中对各所述元网络输出的结果进行融合,得到融合特征后,利用基础特征判别损失函数对所述融合特征进行预测,并根据预测结果进行特征值对齐。
102、可选的,计算单元包括:
103、输入子单元,用于输入正常场景样本和干扰场景样本;
104、特征确定子单元,用于对于所述第一感知模型,得到正常场景样本的第一多模态融合特征和第一元知识特征;对于第二感知模型,可获得正常场景样本的第二多模态融合特征和第二元知识特征;
105、第一确定子单元,用于分别确定所述第一多模态融合特征、所述第一元知识特征、所述第二多模态融合特征和所述第二元知识特征的宽、长和通道数量;
106、第二确定子单元,用于根据所述第一多模态融合特征、所述第一元知识特征、所述第二多模态融合特征和所述第二元知识特征的宽、长和通道数量确定所述前景特征和所述背景特征;
107、第一采样子单元,用于对于所述前景特征,从标注的目标边界框内进行均匀的特征点采样,获得前景特征点;
108、第二采样子单元,用于对于所述背景特征,从标注的目标边界框外的区域进行均匀的特征点采样,得到背景特征点
109、第一计算子单元,用于根据所述前景特征点和所述背景特征点利用损失函数计算关联关系矩阵之间的差异。
110、可选的,还包括:
111、第三确定子单元,用于确定干扰场景样本的干扰特征点,并根据所述干扰特征点确定需要进行知识迁移的特征。
112、可选的,第三确定子单元包括:
113、第四确定子单元,用于确定子任务的损失函数;
114、第二计算子单元,用于利用所述子任务的损失函数计算所述第二感知模型的干扰场景样本的多模态融合特征梯度和元知识特征梯度;
115、排序子单元,用于沿通道方向对所述多模态融合特征梯度和所述元知识特征梯度求和,对宽、高维度上每个位置梯度值求绝对值,并按照梯度绝对值进行降序排序;
116、第三采样子单元,用于取所述排序中前预设数量个梯度绝对值的位置索引,按照所述位置索引在所述第二感知模型的干扰场景样本的多模态融合特征中采样第一干扰特征点,按照所述位置索引在所述第一感知模型的干扰场景样本的多模态融合特征中采样第二干扰特征点。
117、可选的,第一计算子单元包括:
118、第一知识蒸馏损失计算子单元,用于根据所述第一感知模型和所述第二感知模型各自的多模态融合特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第一知识蒸馏损失;
119、第二知识蒸馏损失计算子单元,用于根据所述第一感知模型和所述第二感知模型各自的元知识特征的前景特征点、背景特征点和干扰特征点,计算干扰特征的第二知识蒸馏损失;
120、差异计算子单元,用于根据所述第一知识蒸馏损失和所述第二知识蒸馏损失计算关联关系矩阵之间的差异。
121、可选的,第一知识蒸馏损失计算子单元包括:
122、第一聚合子单元,用于聚合所述第一感知模型的多模态融合特征的前景特征点、背景特征点和干扰特征点,得到第一聚合特征;
123、第一关联关系矩阵确定子单元,用于确定所述第一聚合特征的第一关联关系矩阵;
124、第二聚合子单元,用于聚合所述第二感知模型的多模态融合特征的前景特征点、背景特征点和干扰特征点,得到第二聚合特征;
125、第二关联关系矩阵确定子单元,用于确定所述第二聚合特征的第二关联关系矩阵;
126、干扰特征第一知识蒸馏损失计算子单元,用于根据所述第一关联关系矩阵和所述第二关联关系矩阵计算干扰特征第一知识蒸馏损失。
127、可选的,第二知识蒸馏损失计算子单元包括:
128、第三聚合子单元,用于聚合所述第一感知模型的元知识特征的前景特征点、背景特征点和干扰特征点,得到第三聚合特征;
129、第三关联关系矩阵确定子单元,用于确定所述第三聚合特征的第三关联关系矩阵;
130、第四聚合子单元,用于聚合所述第二感知模型的元知识特征的前景特征点、背景特征点和干扰特征点,得到第四聚合特征;
131、第四关联关系矩阵确定子单元,用于确定所述第四聚合特征的第四关联关系矩阵;
132、干扰特征第二知识蒸馏损失计算子单元,用于根据所述第三关联关系矩阵和所述第四关联关系矩阵计算干扰特征第二知识蒸馏损失。
133、可选的,预测单元包括:
134、第二基础特征和第二基础特征确定子单元,用于确定所述第一感知模型的第一基础特征,和所述第二感知模型的第二基础特征;
135、预测结果确定子单元,用于将所述第一基础特征和所述第二基础特征输入至基础特征判别网络,确定所述第一基础特征或者所述第二基础特征的预测结果;所述基础特征判别网络包含所述基础特征判别损失函数;
136、特征值对齐子单元,用于根据所述预测结果进行特征值对齐。
137、可选的,在执行预测结果确定子单元中的步骤之前,还包括:
138、知识蒸馏子单元,用于利用梯度反转层对所述第二感知模型基础元网络或抗干扰元网络进行知识蒸馏。
139、可选的,在执行第二感知模型训练模块中的步骤时,还包括:
140、样本获取模块,用于每次训练时,从正常场景数据集采样得到正常场景训练样本,从感知鲁棒性干扰场景数据库中采样一个干扰场景数据训练样本;
141、输入模块,用于将所述正常场景训练样本和所述干扰场景数据训练样本输入至所述第一感知模型和所述第二感知模型中,得到各自的多模态融合特征和融合的元知识特征。
142、可选的,在执行输入模块中的步骤之后,包括:
143、计算模块,用于计算多模态融合特征的干扰特征知识蒸馏损失函数,元知识特征的干扰特征知识蒸馏损失函数,基础特征判别损失函数,抗干扰特征判别损失函数,干扰因素判别损失函数的损失值;
144、第二感知模型传播优化模块,用于利用所述多模态融合特征的干扰特征知识蒸馏损失函数、元知识特征的干扰特征知识蒸馏损失函数、基础特征判别损失函数、抗干扰特征判别损失函数、干扰因素判别损失函数的损失值和各子任务网络计算获得的损失值,反向传播优化所述第二感知模型。
145、本发明还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现如上所述的方法的步骤。
146、本发明还提供一种电子设备,包括存储器和处理器,所述存储器中存有计算机程序,所述处理器调用所述存储器中的计算机程序时实现如上所述的方法的步骤。
147、本发明提供一种自动驾驶模型的训练方法,包括:获取数据集;所述数据集中的样本包括正常场景数据和干扰场景数据;从所述数据集中选取第一正常场景样本和第一干扰场景样本,利用第一算法训练得到第一感知模型;从所述数据集中选取第二正常场景样本和第二干扰场景样本,利用第二算法结合所述第一感知模型使用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法,训练得到第二感知模型;将所述第二感知模型部署至自动驾驶系统中。
148、本发明在训练得到第一感知模型的基础上,进一步利用干扰特征知识蒸馏方法和元网络对抗知识蒸馏方法训练得到第二感知模型,有效对第二感知模型进行特征质量的评估,避免高质量特征和低质量特征同时被蒸馏,提升了知识蒸馏的效果,便于对第二感知模型中低质量的抗干扰特征进行针对性的知识蒸馏,有效提升了第二感知模型的鲁棒性。
149、本发明还提供一种自动驾驶模型的训练系统、存储介质和电子设备,具有上述有益效果,此处不再赘述。