一种深度学习模型的训练方法及装置与流程

文档序号:34379712发布日期:2023-06-08 01:09阅读:38来源:国知局
一种深度学习模型的训练方法及装置与流程

本发明涉及深度学习,尤其涉及一种深度学习模型的训练方法及装置。


背景技术:

1、目前,智能机器人在商场、机场、车站等公共场所的应用日益广泛,对建筑物、绿化带、行人、车辆等进行分割和识别已经成为其必不可少的功能。这些功能的实现都依赖于resnet、yolo等复杂深度神经网络及其配套学习算法。目标检测就是找出图像中所有感兴趣的物体,包含物体定位和物体分类两个子任务,同时确定物体的类别和位置。

2、目前大多数的深度学习任务的尺度一般为224*224,检测任务coco数据集则是640*640,在许多真实情况下存在一些1280*960分辨率的训练任务,虽然可以在输入深度模型前降低图片的分辨率从而提高batch size,但是在某些实际项目中(电路板缺陷检测,真实场景垃圾分类和某些高精度工业级项目),降低输入图片的分辨率会在训练过程中损失部分特征,且违背使用高分辨率相机的初衷造成浪费,在这种情况下,batch size的设置受限于计算机算力和实际任务,数据存在局部过拟合的问题,导致训练过程无法良好的提取特征最终导致训练出的模型不能落地。

3、由上述可得,现有的深度学习模型的训练方法在batch size设置受限的情况下,导致特征提取过程中数据局部过拟合的问题,最终会造成训练模型鲁棒性不高的问题。


技术实现思路

1、本发明实施例提供一种深度学习模型的训练方法及装置,能够在batch size设置受限的情况下解决数据局部过拟合的问题,从而提高了训练深度学习模型的鲁棒性。

2、本技术实施例的第一方面提供了一种深度学习模型的训练方法,包括:

3、向深度学习模型输入n批mini-batch,以使深度学习模型根据n批mini-batch计算出相应的n个均值和n个方差;

4、根据n个均值计算全局均值后,根据n个均值以及全局均值计算得到权重系数;

5、根据n个方差计算标准差后,根据n个方差以及标准差计算得到偏差;

6、根据权重系数对n个均值进行线性变换,生成第一数据特征;

7、根据偏差对n个方差进行线性变换,生成第二数据特征;

8、根据第一数据特征和第二数据特征训练深度学习模型。

9、在第一方面的一种可能的实现方式中,根据n个均值以及全局均值计算得到权重系数,具体为:

10、

11、

12、σ←σ+α(σb-σ);

13、其中,r为权重系数,σb表示当前训练迭代过程中的实际统计到的均值标准差,σ表示网络推理时的标准差,rmax一般取1-10。

14、在第一方面的一种可能的实现方式中,根据n个方差计算标准差后,根据n个方差以及标准差计算得到偏差,具体为:

15、

16、

17、μ←μ+α(μb-μ);

18、其中,d为偏差,μb表示当前训练迭代过程中的实际统计到的均值。

19、在第一方面的一种可能的实现方式中,根据n个均值计算全局均值,具体为:

20、根据指数滑动平均方法,结合n个均值,计算得到全局均值。

21、在第一方面的一种可能的实现方式中,根据n个方差计算标准差,具体为:

22、根据指数滑动平均方法,结合n个方差,计算得到标准差。

23、本技术实施例的第二方面提供了一种深度学习模型的训练,包括:输入模块、第一计算模块、第二计算模块、第一变换模块、第二变换模块和训练模块;

24、其中,输入模块用于向深度学习模型输入n批mini-batch,以使深度学习模型根据n批mini-batch计算出相应的n个均值和n个方差;

25、第一计算模块用于根据n个均值计算全局均值后,根据n个均值以及全局均值计算得到权重系数;

26、第二计算模块用于根据n个方差计算标准差后,根据n个方差以及标准差计算得到偏差;

27、第一变换模块用于根据权重系数对n个均值进行线性变换,生成第一数据特征;

28、第二变换模块用于根据偏差对n个方差进行线性变换,生成第二数据特征;

29、训练模块用于根据第一数据特征和第二数据特征训练深度学习模型。

30、在第二方面的一种可能的实现方式中,根据n个均值以及全局均值计算得到权重系数,具体为:

31、

32、

33、σ←σ+α(σb-σ);

34、其中,r为权重系数,σb表示当前训练迭代过程中的实际统计到的均值标准差,σ表示网络推理时的标准差,rmax一般取1-10。

35、在第二方面的一种可能的实现方式中,根据n个方差计算标准差后,根据n个方差以及标准差计算得到偏差,具体为:

36、

37、

38、μ←μ+α(μb-μ);

39、其中,d为偏差,μb表示当前训练迭代过程中的实际统计到的均值。

40、在第二方面的一种可能的实现方式中,根据n个均值计算全局均值,具体为:

41、根据指数滑动平均方法,结合n个均值,计算得到全局均值。

42、本技术实施例的第三方面提供了一种基于深度学习模型的目标检测系统,包括:摄像头、通信装置、数据存储装置、中央控制装置以及深度学习模型;

43、其中,摄像头用于根据中央控制装置所发送的第一指令,设置相机设备的参数,获取图像数据;

44、通信装置用于实现中央控制装置与外界设备的双向通信;

45、数据存储装置用于根据中央控制装置所发送的第二指令,向中央控制装置反馈系统当前工作状态信息;

46、中央控制模块用于向摄像头发送第一指令,以获得图像数据并发送至深度学习模型;用于向数据存储装置发送第二指令,以获得系统当前工作状态信息;

47、深度学习模型用于根据中央控制装置发送的图像数据进行目标检测。

48、相比于现有技术,本发明实施例提供的一种深度学习模型的训练方法及装置,所述方法包括:向深度学习模型输入n批mini-batch,以使深度学习模型根据n批mini-batch计算出相应的n个均值和n个方差;根据n个均值计算全局均值后,根据n个均值以及全局均值计算得到权重系数;根据n个方差计算标准差后,根据n个方差以及标准差计算得到偏差;根据权重系数对n个均值进行线性变换,生成第一数据特征;根据偏差对n个方差进行线性变换,生成第二数据特征;根据第一数据特征和第二数据特征训练深度学习模型。

49、其有益效果在于:本发明实施例将n批mini-batch输入至深度学习模型后,计算得到相应的n个均值和n个方差,并根据n个均值和n个方差分别计算得到权重系数和偏差,最后根据权重系数和偏差分别对n个均值和n个方差进行线性变化(即标准化处理),能够避免不同批的n个均值之间差异过大/不同批的n个方差之间差异过大而导致的网络学习过程过于震荡、无法进入全局最优解的问题,根据标准化处理后的均值和方差提取数据特征,利用全局均值和方差优化了在训练过程中每个mini-batch在做bn时的参数,这样做能够解决特征提取过程中数据局部过拟合的问题,从而提高了训练深度学习模型的鲁棒性。

50、同时,本发明实施例能够在无需修改网络结构的情况下提高模型训练的鲁棒性,兼顾模型稳定性和有限容量模型下输入小的mini batch不损失精度的双重目标,节省大量的人工成本和时间成本。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1