基于深度学习的学习率确定方法和装置与流程

文档序号:16137015发布日期:2018-12-01 01:07阅读:181来源:国知局

本公开涉及深度学习,尤其涉及基于深度学习的学习率确定方法和装置。

背景技术

目前深度学习模型的求解具有多种优化方式,无论哪种优化方式都需要设置一个学习率来控制梯度移动步长。学习率设置太小会导致目标损失下降过慢,迭代次数过多,需要非常久的时间才能收敛;学习率设置太大很容易导致梯度爆炸(乘积趋向无穷大),使整个深度网络的学习无法继续下去。因此,合适的学习率对深度网络参数的求解非常重要。部分研究人员根据经验设定一个合适的固定值,然而这种固定值无法兼顾整个网络的学习过程,通常会导致网络后期的目标损失值出现震荡。考虑到网络学习过程中,前期梯度下降较快后期梯度下降较慢的特性,大多数研究人员将学习率设计成一个指数的衰减函数来进行参数优化。随着迭代次数的增加,指数衰减函数输出的学习率会越来越小导致网络后期参数更新微乎其微,难以快速收敛及越过一些局部极小值点。

现有的学习率设置方法基本上都是根据经验人为设定一个固定值或固定函数来进行网络参数寻优。人为设定的学习率很难顾及到整个网络的学习过程,即使采用循环变化的学习率也有可能使目标损失函数值震荡导致寻优效率低下。



技术实现要素:

为解决上述技术问题,本公开提供一种基于深度学习的学习率确定方法和装置,技术方案如下:

一种基于深度学习的学习率确定方法,用于确定第t+1次迭代所使用的学习率,所述方法包括:

根据预设的次数m,获取第t-m+1次到第t次迭代所输出的m组目标损失值lst-m+1,lst-m+2……lst;

根据m组目标损失值确定第t+1次迭代的学习率,使所述学习率与m组目标损失值的变化速率成正比。

一种基于深度学习的学习率确定装置,用于确定第t+1次迭代所使用的学习率,所述装置包括:

参数获取模块:用于根据预设的次数m,获取第t-m+1次到第t次迭代所输出的m组目标损失值lst-m+1,lst-m+2……lst;

学习率确定模块:用于根据m组目标损失值确定第t+1次迭代的学习率,使所述学习率与m组目标损失值的变化速率成正比。

本公开将学习率与目标损失函数相关联,当目标损失值的变化速率较快时,将下一次参数迭代的学习率自动调整为一个相对较大的数值,当目标损失值的变化速率较慢时,将下一次参数迭代的学习率自动调整为一个相对较小的数值。本公开利用目标损失值的下降速率来自适应的调整学习率,在减少目标损失值震荡的同时,尽可能加快网络的收敛速率。

附图说明

此处的附图被并入说明书中并构成本公开的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。

图1是本公开一示例性实施例示出的基于深度学习的学习率确定方法的流程图;

图2是本公开一示例性实施例示出的基于深度学习的学习率确定装置的结构示意图;

图3是本公开一示例性实施例示出的一种计算机设备的结构示意图。

具体实施方式

这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本公开相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本公开的一些方面相一致的装置和方法的例子。

在本公开使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本公开。在本公开和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。

应当理解,尽管在本公开可能采用术语第一、第二、第三等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本公开范围的情况下,第一信息也可以被称为第二信息,类似地,第二信息也可以被称为第一信息。取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”。

目前来说,深度学习模型的求解需要设置一个学习率来控制梯度移动步长。学习率设置太小会导致目标损失下降过慢,迭代次数过多,需要非常久的时间才能收敛;学习率设置太大很容易导致梯度爆炸(乘积趋向无穷大),使整个深度网络的学习无法继续下去。因此,合适的学习率对深度网络参数的求解非常重要。部分研究人员根据经验设定一个合适的固定值,然而这种固定值无法兼顾整个网络的学习过程,通常会导致网络后期的目标损失值出现震荡。考虑到网络学习过程中,前期梯度下降较快后期梯度下降较慢的特性,还有部分研究人员将学习率设计成一个指数的衰减函数来进行参数优化。随着迭代次数的增加,指数衰减函数输出的学习率会越来越小导致网络后期参数更新微乎其微,难以快速收敛及越过一些局部极小值点。

现有的学习率设置方法基本上都是根据经验人为设定一个固定值或固定函数来进行网络参数寻优。人为设定的学习率很难顾及到整个网络的学习过程,即使采用循环变化的学习率也有可能使目标损失函数值震荡导致寻优效率低下。

针对上述问题,本公开提供了一种学习率的确定方法,将学习率与目标损失函数相关联,当目标损失值的变化速率较快时,将下一次参数迭代的学习率自动调整为一个相对较大的数值,当目标损失值的变化速率较慢时,将下一次参数迭代的学习率自动调整为一个相对较小的数值。本公开利用目标损失值的下降速率来自适应的调整学习率,在减少目标损失值震荡的同时,尽可能加快网络的收敛速率。

请参考附图1,附图1为本公开实施例基于深度学习的学习率确定方法的一种流程图。如附图1所示,该流程可包括以下步骤:

s101,根据预设的次数m,获取第t-m+1次到第t次迭代所输出的m组目标损失值lst-m+1,lst-m+2……lst;

在深度学习的模型训练过程中,每一次迭代完成后,目标损失函数都会输出一个目标损失值。深度学习中的损失函数(lossfunction)是一个用于估量模型的预测值与真实值的不一致程度的函数,是一个非负实值函数。通常情况下,随着迭代次数的增加,预测值与真实值逐渐趋于一致,目标损失函数输出的目标损失值会逐渐减小。

假设本次迭代为第t次迭代,需要确定的是第t+1次迭代的学习率。本实施例需要从第t次开始,获取第t次之前(包括第t次)的m组目标损失值。也就是:获取第t-m+1次迭代,第t-m+2次迭代……直到第t次迭代对应的m组目标损失值lst-m+1,lst-m+2……lst。

其中,m是用户预设的数值,举例说明,假设本次是第15次迭代,需要确定第16次迭代的学习率。如果将次数m设定为3,则需要获取第13,14,15次迭代的3组目标损失值;如果将次数设定为5,则需要获取11,12,13,14,15次迭代的5组目标损失值。本实施例是以最近的m次迭代的目标损失值的变化速率来确定下一次迭代所使用的学习率。次数m可以根据经验直接设定,或反复调整以获得一个较佳的值,本实施例对此不作限定。

s102,根据m组目标损失值确定第t+1次迭代的学习率,使所述学习率与m组目标损失值的变化速率成正比。

学习率,也称为步长,该值决定了每次迭代时参数的更新幅度,如果学习率过小,可能导致梯度下降这一过程的速度缓慢,如果学习率过大,则可能导致overshoottheminimum现象,即无法令模型随更新进程而趋近拟合,因此,本实施例需要利用目标损失值的下降速率来自适应的调整学习率。

获取近m次目标损失值后,计算lr't+1,计算公式为:根据计算出的lr't+1确定第t+1次迭代的学习率,使学习率与目标损失值的变化速率成正比。

上述公式中,gx为第x次迭代时目标损失值的变化量,gx=lsx-lsx-1,x为大于1的整数。

如前文所述,目标损失值通常会随着迭代次数t的增加而减少,因此,目标损失量gx是一个负值,学习率应该与目标损失值的负梯度成正比。

在本公开的一种可实施方式中,计算lr't+1,计算公式为:根据lr't+1确定第t+1次迭代的学习率,使学习率与目标损失值的变化速率成正比。

上述公式中,λ为配置参数,用以确定目标损失值的变化速率与学习率的相关度,λ是一个大于0的值,由用户根据实际需求进行设定。ε为修正参数,用以平滑数值。

在本公开的另一种可实施方式中,

确定lr't+1后,继续判定lr't+1是否满足下述条件:

确定lr't+1是否小于lrmin;

如果lr't+1小于lrmin,将第t+1次迭代的学习率lrt+1确定为lrmin;

如果lr't+1不小于lrmin,将第t+1次迭代的学习率lrt+1确定为lr't+1;

其中,lrmin为预设的最小学习率。

确定lr't+1是否大于lrmax;

如果lr't+1大于lrmax,将第t+1次迭代的学习率lrt+1确定为lrt+1;

如果lr't+1不大于lrmax,将第t+1次迭代的学习率lrt+1确定为lr't+1;

其中,lrmax为预设的最大学习率。

在根据目标损失值自适应学习率时,可能会各种原因导致数据不可用。如样本噪声等因素出现会导致目标损失值随次数t的增加而增加,进而导致前述公式确定的学习率出现负值;此外,损失值变化率较大时,学习率也相应较大,有一定梯度爆炸的风险。上述的判定条件用于对学习率的上下限进行约束,避免采用异常学习率导致的学习失败。

请参考附图2,附图2为本公开实施例基于深度学习的学习率确定装置的一种结构示意图,所述装置包括:参数获取模块210,学习率确定模块220。

参数获取模块210:用于根据预设的次数m,获取第t-m+1次到第t次迭代所输出的m组目标损失值lst-m+1,lst-m+2……lst;

学习率确定模块220:用于根据m组目标损失值确定第t+1次迭代的学习率,使所述学习率与m组目标损失值的变化速率成正比。

进一步地,学习率确定模块220,具体用于:

计算lr't+1,根据lr't+1确定第t+1次迭代的学习率,使所述学习率与目标损失值的变化速率成正比;

其中,gx为第x次迭代时目标损失值的变化量,gx=lsx-lsx-1,x为大于1的整数。

进一步地,学习率确定模块220,具体还用于:

计算lr't+1,根据lr't+1确定第t+1次迭代的学习率,使所述学习率与目标损失值的变化速率成正比;

其中,gx为第x次迭代时目标损失值的变化量,gx=lsx-lsx-1,x为大于1的整数;

λ为配置参数,用以确定目标损失值的变化速率与学习率的相关度,ε为修正参数,用以平滑数值。

进一步地,学习率确定模块220,具体还用于:

确定lr't+1是否小于lrmin;

如果lr't+1小于lrmin,将第t+1次迭代的学习率lrt+1确定为lrmin;

如果lr't+1不小于lrmin,将第t+1次迭代的学习率lrt+1确定为lr't+1;

其中,lrmin为预设的最小学习率。

进一步地,学习率确定模块220,具体还用于:

确定lr't+1是否大于lrmax;

如果lr't+1大于lrmax,将第t+1次迭代的学习率lrt+1确定为lrt+1;

如果lr't+1不大于lrmax,将第t+1次迭代的学习率lrt+1确定为lr't+1;

其中,lrmax为预设的最大学习率。

本公开实施例还提供一种计算机设备,其至少包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其中,处理器执行所述程序时可以实现前述的基于深度学习的学习率确定方法。

请参考附图3,图3示出了本公开实施例所提供的一种更为具体的计算设备硬件结构示意图,该设备可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。

处理器1010可以采用通用的cpu(centralprocessingunit,中央处理器)、微处理器、应用专用集成电路(applicationspecificintegratedcircuit,asic)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本公开实施例所提供的技术方案。

存储器1020可以采用rom(readonlymemory,只读存储器)、ram(randomaccessmemory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本公开的实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。

输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。

通信接口1040用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如usb、网线等)实现通信,也可以通过无线方式(例如移动网络、wifi、蓝牙等)实现通信。

总线1050包括一通路,在设备的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。

需要说明的是,尽管上述设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本公开实施例方案所必需的组件,而不必包含图中所示的全部组件。

本公开实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现前述基于深度学习的学习率确定方法,该方法至少包括:

根据预设的次数m,获取第t-m+1次到第t次迭代所输出的m组目标损失值lst-m+1,lst-m+2……lst;

根据m组目标损失值确定第t+1次迭代的学习率,使所述学习率与m组目标损失值的变化速率成正比。

计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。

计算机的存储介质的例子包括,但不限于相变内存(pram)、静态随机存取存储器(sram)、动态随机存取存储器(dram)、其他类型的随机存取存储器(ram)、只读存储器(rom)、电可擦除可编程只读存储器(eeprom)、快闪记忆体或其他内存技术、只读光盘只读存储器(cd-rom)、数字多功能光盘(dvd)或其他光学存储、磁盒式磁带,磁带磁磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体(transitorymedia),如调制的数据信号和载波。

以上所述仅为本公开的较佳实施例而已,并不用以限制本公开,凡在本公开的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本公开保护的范围之内。

本公开中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的模块可以是或者也可以不是物理上分开的,作为模块显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。

以上所述仅为本公开的较佳实施例而已,并不用以限制本公开,凡在本公开的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本公开保护的范围之内。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1