一种神经网络的训练方法及相关装置与流程

文档序号:36179675发布日期:2023-11-29 15:28阅读:102来源:国知局
一种神经网络的训练方法及相关装置与流程

本技术实施例涉及人工智能领域,尤其涉及一种神经网络的训练方法及相关装置。


背景技术:

1、许多的工业场景中,都需要检测来自外界的入侵事件,其中,基于光缆传感的入侵检测是一种通用的检测手段。而由于光缆部署于复杂多样的地质环境和背景震动中,因此,光缆很容易获取到与入侵事件很相似的非入侵事件(属于ood事件)的信号。例如,农耕机械旋耕机事件(非入侵事件)跟夯机敲击(入侵事件)很相似,导致入侵识别模型把非入侵事件误报为入侵事件。

2、因此,在光缆感知到外界的入侵信号后,可以采用入侵识别模型来识别入侵事件或非入侵事件的类别。而由于入侵事件信号和非入侵事件信号较为相似,因此,入侵事件信号和非入侵事件信号都会作为入侵识别模型的训练输入,入侵识别模型的输出则为入侵事件或非入侵事件的类别。对此,上述入侵事件信号则属于模型训练过程中的分布内(in-distribution,id)数据,非入侵事件信号则属于模型训练过程中的分布外(out-of-distribution,ood)数据。当获取到新的非入侵事件后,一般会采集该非入侵事件信号(属于ood数据),从而将现有的入侵识别模型重新训练,得到入侵识别模型的输出则包括了该非入侵事件(ood事件)的新增入侵类别。

3、非入侵事件的类别随着时间发展可能会不断增多,因此,入侵识别模型所输出的入侵类别也会不断增多,严重降低入侵识别模型的性能。


技术实现思路

1、本技术实施例提供了一种神经网络的训练方法及相关装置,用于提高分类网络的性能。

2、第一方面,本技术实施例提供了一种神经网络的训练方法。本技术中神经网络的训练方法,可以适用于神经网络模型的输入端存在ood数据的分类网络场景,用于去除ood数据所带来的无效特征贡献的干扰,提高神经网络模型的性能。例如,上述入侵事件识别的场景中,光缆所采集到的信号包括了与入侵事件信号较为相似的非入侵事件信号(高速行驶或管道正常),又例如,在宠物猫的品种分类的图像识别场景中,输入端存在老虎或豹子等与猫较为相似的图片,还可以是其他的存在ood数据干扰的分类网络场景,具体此处不做限定。这些ood数据会对神经网络模型的输出造成极大的干扰。因此,本技术中,训练数据应当包括ood数据和id数据。每个训练数据有各自的标签,用于指示该训练数据属于ood数据或id数据。

3、将训练数据(包括ood数据和id数据)输入第一特征提取网络,得到每个训练数据对应的特征,即得到ood数据对应的第一ood特征和id数据对应的第一id特征。而由于每个训练数据的标签是已知的(属于ood数据或id数据),因此,该训练数据输入第一特征提取网络后所得到的特征的标签,与该训练数据的标签相同。示例性的,若某个训练数据的标签指示其属于ood数据,则该训练数据输入第一特征提取网络后所得到的特征的标签,也同样指示其属于ood数据。

4、在得到ood数据对应的第一ood特征和id数据对应的第一id特征后,其中的第一ood特征则不需要作为第一分类网络的输入,而仅将第一id特征输入到第一分类网络中,得到第一分类结果。

5、根据第一ood特征、第一id特征和所述第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络,以及,根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

6、具体的,可以根据第一ood特征和第一id特征对第一特征提取网络进行训练,而第一分类网络所输出的第一分类结果则可以用于训练第一特征提取网络和第一分类网络。

7、由于在第一特征提取网络生成第一ood特征和第一id特征之后,便已经可以计算第一ood特征和第一id特征之间的损失函数。对此,“根据第一ood特征和第一id特征对第一特征提取网络进行训练”的步骤,可以在第一特征提取网络生成第一ood特征和第一id特征之后执行,也可以在第一分类网络生成第一分类结果之后执行,具体本技术对此不做限定。

8、另一方面,由于训练数据的标签也已经指示了每个训练数据的真实分类结果,因此,可以根据训练数据的真实分类结果与第一分类结果之间的损失函数,来对第一特征提取网络和第一分类网络进行端到端的训练,其训练目的为使得该损失函数的值最小化。

9、进一步的,为了提高神经网络模型的可靠性,在神经网络模型的训练过程中,一般需要经过多轮次的训练迭代来更新模型参数。因此,本技术的神经网络的训练方法,应用于每一个轮次的训练过程中。每一轮训练都需要从总的训练数据集中抽取一定比例的ood数据和id数据,作为神经网络模型本轮次的训练数据,换句话说,本技术的训练数据,可以是总的训练数据集的子集。若完成本轮次的训练后,训练流程未能收敛,则继续从总的训练数据集中抽取另一部分数据作为下一轮次的训练数据,继续训练模型,直至训练流程成功收敛。每一个轮次的训练过程,都执行本技术中的神经网络的训练方法,直至满足预设条件(例如损失函数的值满足预设条件),从而得到第二特征提取网络和第二分类网络。其中,第二特征提取网络为执行过训练操作的第一特征提取网络,第二分类网络为执行过训练操作的第一分类网络。

10、本技术中,只将训练数据所生成的id特征输入到第一分类网络,即第一分类网络只关注对于id特征的处理。因此,即便增加了新的ood数据类别,也不会增加第一分类网络的输出类别,从而提高了第一分类网络训练后的性能。

11、另一方面,第一分类网络的输入只关注id数据,即只有id数据作为神经网络模型的输入时,才会体现在第一分类网络的分类结果中,而ood数据是不会体现在第一分类网络的分类结果中的,从而去除了ood数据所带来的无效特征贡献的干扰,提高神经网络模型的准确性和稳定性。

12、基于第一方面,一种可选的实施方式中,在训练完毕得到第二特征提取网络和第二分类网络之后,可以再根据训练数据中的ood数据来构建ood锚点库,ood锚点库包括至少一个ood特征锚点。ood锚点库用于确定第二特征属于ood特征或id特征,其中,第二特征为第二特征提取网络对待分类数据进行特征提取所得到的特征。具体的,ood锚点库可以通过训练数据中的ood数据和第二特征提取网络来构建:先获取用于构建ood锚点库的训练数据中的ood数据,然后将ood数据输入到第二特征提取网络中,得到ood数据对应的第二ood特征,将该第二ood特征作为ood特征锚点,保存到ood锚点库中。

13、需要说明的是,本技术并不限定ood锚点库中的ood特征锚点的数量。该ood锚点库包括至少一个ood特征锚点,但ood锚点库中的ood特征锚点的数量越多,则过滤器区分ood特征和id特征的准确性和稳定性则越高。一般来说,在模型训练的过程中,ood数据的数量会远低于id数据的数量,因此,过滤器通过结合ood锚点库的方式来区分ood特征和id特征,能够在ood数据的样本数量较少的场景下,仍然能够具备良好的准确性和稳定性。

14、基于第一方面,一种可选的实施方式中,在训练完毕得到第二特征提取网络和第二分类网络之后,便可以进行神经网络模型的预测流程。将待分类数据输入到第二特征提取网络,得到该待分类数据对应的第二特征。若第二特征属于id特征,则将第二特征输入到第二分类网络,从而得到该第二特征对应的分类结果。若第二特征属于ood特征,则第二特征不会输入到第二分类网络,则第二分类网络也不会输出该第二特征的分类结果。

15、基于第一方面,一种可选的实施方式中,训练数据输入到第一特征提取网络,生成训练数据的特征后,是根据该训练数据的标签的指示来确定该特征属于ood特征或id特征的。然而在,模型训练完毕之后(即得到第二特征提取网络和第二分类网络),该神经网络模型的预测过程中,所输入的待分类数据,是没有标签来指示该待分类数据属于ood数据或id数据的,待分类数据输入到第二特征提取网络后,生成待分类数据对应的第二特征,该第二特征同样也没有标签来指示该特征属于ood特征或id特征。因此,需要对第二特征进行过滤,确认第二特征属于ood特征或id特征。

16、若第二特征属于id特征,则将该第二特征输入到第二分类网络中,由第二分类网络对第二特征进行处理,得到该第二特征对应的第二分类结果;若第二特征属于ood特征,则该第二特征不会作为第二分类网络的输入。

17、基于第一方面,一种可选的实施方式中,可以根据ood锚点库来确定第二特征属于ood特征或id特征。具体的,待分类数据输入第二特征提取网络,生成该待分类数据对应的第二特征。第二特征和计算ood锚点库中的ood特征锚点,都可以用二维的向量来进行表示。根据第二特征的向量表示和ood特征锚点的向量表示,计算ood锚点库中的ood特征锚点与第二特征之间的距离,该距离越小,则说明第二特征与ood特征锚点之间的相似度越高、关联度越高,该第二特征有更高的概率属于ood特征;反之,该距离越大,则说明第二特征与ood特征锚点之间的相似度越低、关联度越低,该第二特征有更高的概率属于id特征。因此,可以先配置一个预设阈值,当第二特征与ood锚点库中的ood特征锚点之间的距离小于预设阈值时,确定第二特征为ood特征;当第二特征与ood锚点库中的ood特征锚点之间的距离大于或等于预设阈值时,确定第二特征为id特征。

18、在构建了ood锚点库之后,即便输入了新增的ood数据类型,该ood锚点库也同样适用于识别过滤新增的ood数据,不需要针对该新增的ood数据类型来重新构建或优化ood锚点库中的ood特征锚点,可以继续基于该ood锚点库来识别新增的ood数据,提高了神经网络模型的效率。

19、基于第一方面,一种可选的实施方式中,以增加第一ood特征和第一id特征之间的距离为方向,来构建损失函数。其中,第一ood特征和第一id特征之间的距离越大,则该损失函数的值越大,说明第一ood特征和第一id特征之间的差异越大,越有利于后续进行区分第一ood特征和第一id特征。因此,可以根据第一ood特征和第一id特征之间的差异的损失函数,以该损失函数的值越来越大作为优化方向,来对第一特征提取网络进行训练,从而使得第一特征提取网络所提取得到的第一ood特征和第一id特征之间的差异越来越大。

20、第二方面,本技术实施例提供了一种神经网络的训练装置,该装置包括:

21、处理单元,用于获取训练数据,训练数据包括分布外ood数据和分布内id数据;用于将训练数据输入第一特征提取网络,得到ood数据对应的第一ood特征和id数据对应的第一id特征;还用于将第一id特征输入第一分类网络,得到第一分类结果;

22、训练单元,用于根据第一ood特征、第一id特征和第一分类结果对第一特征提取网络进行训练,得到第二特征提取网络;根据第一分类结果对第一分类网络进行训练,得到第二分类网络。

23、基于第二方面,一种可选的实施方式中,处理单元,还用于:

24、将训练数据中的ood数据输入第二特征网络,得到ood数据对应的第二ood特征;将第二ood特征作为ood特征锚点,保存到ood锚点库,其中,ood锚点库用于确定第二特征属于ood特征或id特征,第二特征为第二特征提取网络对待分类数据进行特征提取所得到的特征。

25、基于第二方面,一种可选的实施方式中,处理单元,还用于:

26、将待分类数据输入第二特征提取网络,得到第二特征;若第二特征属于id特征,则将第二特征输入第二分类网络,得到第二分类结果。

27、基于第二方面,一种可选的实施方式中,处理单元,还用于:对第二特征进行过滤,确认第二特征属于ood特征或id特征。

28、基于第二方面,一种可选的实施方式中,处理单元,具体用于:

29、获取ood锚点库,ood锚点库包括至少一个ood特征锚点;计算ood锚点库中的ood特征锚点与第二特征之间的距离;若距离的值大于预设阈值,则确定第二特征为id特征。

30、基于第二方面,一种可选的实施方式中,处理单元,具体用于:

31、计算第一ood特征与第一id特征之间的距离;以增加第一ood特征与第一id特征之间的距离作为损失函数,对第一特征提取网络进行训练;根据第一分类结果对第一特征提取网络进行训练。

32、第三方面,本发明实施例提供了一种计算机设备,包括处理器,用于执行上述任一方面的神经网络的训练方法。

33、基于第三方面,一种可选的实施方式中,还包括存储器,用于存储代码,与所述处理器耦合;所述处理器具体用于执行所述存储器中的代码,来实现上述任一方面的神经网络的训练方法。

34、第四方面,本技术实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,当其在计算机上运行时,使得计算机执行上述任一方面所述的神经网络的训练方法。

35、第五方面,本技术实施例提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机程序或指令,当其在计算机上运行时,使得计算机执行上述任一方面所述的神经网络的训练方法。

36、第六方面,本技术实施例提供了一种芯片系统,该芯片系统包括处理器,用于实现上述各个方面中所涉及的功能,例如,发送或处理上述方法中所涉及的数据和/或信息。在一种可能的设计中,所述芯片系统还包括存储器,所述存储器,用于保存服务器或通信设备必要的程序指令和数据。该芯片系统,可以由芯片构成,也可以包括芯片和其他分立器件。

37、从以上技术方案可以看出,本技术实施例具有以下优点:

38、本技术公开了一种神经网络的训练方法及相关装置。获取训练数据,训练数据包括分布外ood数据和分布内id数据;将训练数据输入第一特征提取网络,得到ood数据对应的第一ood特征和id数据对应的第一id特征;将第一id特征输入第一分类网络,得到第一分类结果;根据第一ood特征、第一id特征对和第一分类结果第一特征提取网络进行训练,得到第二特征提取网络,以及,根据第一分类结果对第一分类网络进行训练,得到第二分类网络。本技术中,只将训练数据所生成的id特征输入到第一分类网络,即第一分类网络只关注对于id特征的处理。因此,即便增加了新的ood数据类别,也不会增加第一分类网络的输出类别,从而提高了第一分类网络训练后的性能。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1