基于学习自动机的深度神经网络优化方法与流程

文档序号:11729867阅读:397来源:国知局
基于学习自动机的深度神经网络优化方法与流程

本发明涉及的是一种信息处理领域的技术,具体是一种基于学习自动机(learningautomata,la)的深度神经网络中去除弱连接的方法。



背景技术:

神经网络是一种传统的机器学习算法,可以实现输入到输出的非线性映射,能应用于特征变换、分类、识别等任务中。由于其具有强大的模型表达能力,在模式识别、人工智能等领域得到了广泛应用。神经网络模型通常包含输入层、输出层和隐藏层,每层由特定个数的神经元组成,每个神经元可描述为y=f(w*x+b),其中:x代表输入向量;y代表输出值;权值向量w和偏置b为可训练的参数,其集合可用θ表示;f为非线性的激活函数(通常采用sigmoid函数或relu)。因此神经网络模型的每一层可描述为对输入进行加权求和,并通过非线性变换得到输出值。对于特定的训练样本,通常采用损失函数来衡量神经网络输出值与期望值之间的偏差,对模型的训练即求解θ,使得在训练样本上最小化损失函数。训练方法通常采用反向传播算法计算梯度,并采用梯度下降法迭代更新参数θ(权重和偏置值),直至得到最优的θ值。

与传统的神经网络相比,深层神经网络模型包含更多的隐藏层,每一层的输出直接作为下一层的输入。其每个隐藏层都对上一层的输出进行特征变换,得到更加抽象的特征,因此深层神经网络具有强大的特征表达能力。而且通过端对端的训练,深层神经网络可以实现完全自主学习特征,避免了人工设计特征的繁琐和盲目性。长期以来由于梯度弥散等理论问题以及硬件计算能力的限制,包含多个隐藏层的神经网络的训练一直是一个难以解决的问题。从2006年起,随着深度学习技术的兴起,深层神经网络的训练在理论上得到了一定程度的解决。计算机计算能力的提升特别是gpu加速的使用,以及更多的训练样本使得深层神经网络的训练成为可能,并且在计算机视觉、语音识别、自然语言处理等领域都取得了显著效果。

然而,由于深度神经网络中包含大量参数,一旦网络设置过大,很容易陷入过拟合,使得测试集上的效果反而会变差。而网络大小的设置往往依靠经验及大量实验进行尝试,具有一定的盲目性。目前已经有一些防止网络过拟合的方法,如:在损失函数中加入正则项,以惩罚较大的权重值;设置验证集以监测泛化误差,当其不再减小时即停止训练;在每次迭代时随机丢掉一部分神经元等。



技术实现要素:

本发明针对深度神经网络冗余参数过多,容易陷入过拟合的问题,提出一种基于学习自动机的深度神经网络弱连接的去除方法,在传统的梯度下降迭代过程中引入la寻找连接中的弱连接,去掉冗余连接以减少网络参数,降低网络计算量,提高在测试样本上的分类精度,使其具有更强的防止过拟合的能力。

本发明是通过以下技术方案实现的:

本发明涉及一种基于学习自动机的深度神经网络优化方法,在深度神经网络的训练阶段,从全连接的初始网络结构出发,在通过梯度下降迭代更新参数的过程中不断找到网络中的弱连接并将其去除,从而得到更为稀疏连接、具有更小的泛化误差的网络结构,以便用于对测试样本进行更高精度的图像分类。

所述的弱连接,通过la在训练过程中不断与神经网络交互而进行判定,具体是指:对神经网络中的每一个连接,分别分配一个la对当前连接的强弱进行判定,即:采用具有两个行为α1和α2的fssa模型,其中:行为α1对应判定当前连接为强连接,α2对应判定当前连接为弱连接;每个行为对应n个内部状态,即la共有2n个状态,记为l2n,n,其中:n代表记忆深度;该学习自动机的输出函数为:当时刻t处于状态q(t)=qi,1≤i≤n,则输出α1,即判定当前连接为强连接;当处于状态q(t)=qi,n+1≤i≤2n,则输出α2,即判定当前连接为弱连接。

所述的与神经网络交互是指:当当前连接权重大于阈值时,对当前la进行奖励,否则,对la进行惩罚:在没有任何先验知识的情况的初始时刻下,la处于状态q1;经过一次迭代过程中,当la得到奖励,则向判定为强连接的状态移动,即从当前状态qi转移到qi-1,当i=1则保持原状态;当la得到惩罚,则向判定为弱连接的状态移动,即从当前状态qi转移到qi+1,当i=2n则保持原状态不变。

所述的去除是指标记或者删除弱连接的过程,该过程最为简单的处理方案是:在前向传播时把当前权重置零;并且在反向传播过程中把当前权重的梯度置零。

所述的la可定义为一个五元组<a,b,q,t,g>,其中:a为输出行为集合,也是la需要最终从中找出最优行为的行为集;b为从环境输入的反馈的集合,通常包含奖励和惩罚两种;q为la内部状态的集合;t为状态转移方程,即la根据环境反馈更新内部状态的策略;g为输出方程,描述从内部状态到输出行为的映射。

技术效果

与现有技术相比,本发明提出对神经网络中的连接进行判断,并去除弱连接的方法,并提出采用增强学习中的la算法来完成弱连接的判断,其优势有以下几点:la模型简洁直观,且不会耗费很大的额外计算量;la采用迭代更新的优化过程,便于和梯度下降的迭代更新过程同步进行;由于训练过程中网络参数不断变化,处于非平稳环境中,而la对非平稳及有噪声的环境具有很强的适应能力。

与现有技术相比,本发明可以实现对深度神经网络的结构进行优化,有效地削减冗余连接,并实现用更少的参数得到更低的分类误差,由于神经网络内连接数量的减少,使得图像分类中的计算量有所降低,提高了分类速度。

附图说明

图1为本发明方法示意图;

图2为本发明基于la去除弱连接的神经网络训练流程图;

图3为实施例中应用本专利的图像分类系统图。

具体实施方式

如图1所示,本实施例在深度神经网络的训练阶段,从全连接的初始网络结构出发,在通过梯度下降迭代更新参数的过程中不断找到网络中的弱连接并将其去除,从而得到更为稀疏连接、具有更小的泛化误差的网络结构。

如图3所示,为基于上述方法得到的优化后的深度神经网络对测试样本进行图像分类的过程,具体为:首先对原始的输入图像(如灰度图或rgb图像)进行简单的标准化预处理:对各个维度减去均值并除以方差,然后输入经训练的分类模型,进行分类并得到更高精度的结果。

所述的分类模型包含深度神经网络和la,其中:深度神经网络为全连接的多层前馈神经网络,la负责对网络结构进行优化调整,即削减弱连接。

如图2所示,为分类模型的训练过程,具体通过以下实施例进行描述。

本实施例中采用mnist手写数字数据集对基于专利所提出方法的分类模型进行训练,数据集中包含0-9共十类手写体数字,图像大小为28×28的灰度图。整个数据集分为训练集和测试集两部分,我们用训练集中的60000个样本对深度神经网络和la进行训练,然后测试其在测试集10000个样本上的分类误差。训练过程包含以下步骤:

步骤一:根据分类模型中采用的前馈神经网络层数及神经元数目初始化网络参数θ(包括权重w和偏置值b),对每个连接的权重设置一个la并初始化其状态为q1。

在本实施例中,分别构建了隐藏层数为2、3、4、5的前馈神经网络,每个隐藏层均为1000个神经元,对权重和偏置采用高斯初始化。

步骤二:从训练集中随机抽取一个批次的样本,并采用反向传播方法计算损失函数对网络参数θ的梯度。

为了增加网络的稀疏性,增大弱连接的比例以加快训练速度,可在损失函数中加入l1或l2正则项。此时损失函数中:代表预测误差,λ1、λ2分别为l1和l2正则项的比重。

本实施例中使用relu激活函数,并采用了l2正则项,获得最优训练效果的参数设置:λ2=0.00001。

步骤三:根据得到的梯度值,利用梯度下降法对权重和偏置进行更新:θ(t+1)=θ(t)+αm(t),其中:m(t)为动量项,为时刻t的梯度值,α为学习速率。

由于学习速率的设置对最终训练效果有较大影响,本实施例中采用了初始学习速率为1,并随迭代次数逐渐衰减的学习策略。动量项的参数γ设为0.9。

步骤四:根据当前神经网络的权重值分别对每个la进行一次奖励或惩罚,神经网络与la的交互方式如图2所示,其规则为:若当前权重值大于某个阈值,则对当前la进行奖励;若当前权重值小于某个阈值,则对la进行惩罚。la得到神经网络的反馈后更新自身的内部状态,更新规则为:

当la得到奖励,则向判定为强连接的状态移动,即从当前状态qi转移到qi-1,若i=1则保持原状态不变;当la得到惩罚,则向判定为弱连接的状态移动,即从当前状态qi转移到qi+1,若i=2n则保持原状态不变。

步骤五:每隔一段周期(本实施例中周期为每四次迭代),根据la的当前状态对网络中每个连接的强弱进行判定,并去除判定为弱连接的连接权重。即在后续迭代过程中在前向传播时权重置零;并且在反向传播过程中把权重的梯度置零。

以上步骤二到五不断进行迭代直到损失函数不再达到更优的值为止。

对于上述方法,权重阈值的设定将影响到la进行第一次判定时弱连接的比例。去掉过多的弱连接可以使网络更稀疏,但同时也会损失一部分网络中已经学习到的特征,影响训练效果。根据不同层数的网络中对权重的初始化值对权重阈值进行调整,使得第一次判定时弱连接比例保持在10%左右。同时,本实施例采用了l2正则化以加速每次判定时弱连接的比例。

在mnist手写数字数据集上对所提出的方法的效果进行了测试。在60000个样本上对网络进行训练,并使用另外10000个样本测试其泛化误差。对包含不同层数的深层前馈神经网络进行了实验。实验中采用每个隐藏层都设置为1000个神经元。

结果表明对于包含隐藏层数分别为2、3、4、5的深层神经网络,采用传统的反向传播方法加上l1、l2正则项在训练集上进行训练,它们在测试集上所能达到的分类误差分别为1.44、1.43、1.61、1.79,而采用基于la的方法在测试集上可达到的分类误差为1.39、1.38、1.49、1.47。达到最优测试误差时所用的连接数(即去掉弱连接后剩余的连接数)分别为全连接时的43%、62%、60%、44%。

实验结果可以看出本发明所提出的方法去掉了冗余连接,最终的模型所需的参数更少,并且在测试集上达到了更低的分类误差。

上述具体实施可由本领域技术人员在不背离本发明原理和宗旨的前提下以不同的方式对其进行局部调整,本发明的保护范围以权利要求书为准且不由上述具体实施所限,在其范围内的各个实现方案均受本发明之约束。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1