一种基于扩展非线性核残差网络的手写字符识别方法与流程

文档序号:11251422阅读:365来源:国知局
一种基于扩展非线性核残差网络的手写字符识别方法与流程

本发明涉及深度学习、机器学习技术领域,特别涉及一种基于扩展非线性核残差网络的手写字符识别方法。



背景技术:

手写数字识别作为图像识别应用中的重要分支,其在生产生活中的重要性也逐渐体现出来。手写数字识别可以用于读取银行支票信息、信封邮政编码信息、海关等需要处理大量字符信息录入的场合。因此人们对计算机所建立的手写数字识别系统的要求也不断提高,手写数字识别系统要完成识别阿拉伯数字的任务,先决条件是构建手写数字识别模型,所以手写数字识别方法研究中最基础的问题是手写数字的特征提取及分类。

目前解决手写数字识别的方法有很多,比较常用的方法主要分为两大类:基于传统特征提取与模式分类的手写字符识别方法和基于深度学习的手写字符识别方法。

公开号为cn104298987a的专利:一种基于点密度加权在线fcm聚类的手写数字识别方法,用于处理大规模的脱机手写数字识别问题,包括步骤:1)预处理所有手写数字图像集合;2)初始化聚类中心,令数据点顺序进入处理流程;3)计算当前数据点与各聚类中心隶属度;4)若隶属度达到阈值更新最近聚类中心位置;5)若未达到阈值则不处理该点并暂时放入待处理区;6)待处理区达到一定标准则用点密度加权fcm算发聚类待处理区中数据,更新聚类中心;7)继续循环直至数据点全部处理完毕;8)用获得的聚类中心分块计算全部数据点的隶属度,并划分类别,通过一次扫描完成数据归类。该发明在处理大规模手写数字识别问题方面能够降低空间复杂度以及时间复杂度。但是该方法在形变非常大且数字相似时数字识别的效果不是太好。

公开号为cn102982343a的专利:包括手写数字的图像的采集和二值化处理;对采集的图像进行分割,构造由手写数字的图像为输入和0-9数字为输出的训练集;构造增量函数,并将该增量函数映射到区间[0,1];设置以λ表示增量参数和模糊支持向量机的计算复杂性参数;确定手写数字的类别,根据任何两个手写数字之间的分类超平面,确定手写数字的类别,在已知类别的手写数字上检验识别精度,确定手写数字类别的方法。但是该方法在手写数字数据量增大时,不能提取出表达能力很好的特征。

有鉴于此,有必要提供一种基于扩展非线性核残差网络的手写字符识别方法,以解决上述问题。



技术实现要素:

为了解决现有技术存在的问题,本发明的目的在于提供一种高效的基于扩展非线性核残差网络的手写字符识别方法,提出了一种新型深度学习方法—基于扩展非线性核残差网络算法,并将该深度学习算法应用在手写字符识别中,提出基于扩展非线性核残差网络的手写字符识别方法。

为了达到上述目的,本发明所采用的技术方案是:一种基于扩展非线性核残差网络的手写字符识别方法,其特征在于,所述手写字符识别方法包括以下步骤:

步骤1:采集手写数字图像作为样本,生成训练数据和测试数据,初始化基于扩展非线性核的残差网络结构;

步骤2:将图像样本引入网络训练之前,使用无监督的聚类算法对实验数据进行预处理;

步骤3:将先验知识优化后的训练数据均匀分批输入基于扩展非线性核的残差网络中,训练数据分别经过卷积层、池化层、基于扩展非线性核的卷积层、池化层,全连接层,完成前向传播;

步骤4:对步骤3中的网络进行梯度计算和误差计算;

步骤5:将步骤4中得到的误差和梯度用反向传播算法,经过池化层,基于扩展非线性核的卷积层、池化层、卷积层、输入层逐层传播,并且根据要求自动反向更新网络的权重,判断是否为输入层,若是则跳转至步骤3,否则重复步骤5,直至提取出有效的数据特征;

步骤6:直到权重更新稳定,建立出基于扩展非线性核的残差网手写字符识别训练模型;

步骤7:将测试数据按上述步骤输入,最终,得到准确的识别结果。

进一步地,所述步骤1中初始化基于扩展非线性核的残差网络结构为,设置该网络的初始参数,其中包括:扩展非线性残差核的数量、池化层数量、扩展非线性残差核的大小、池化层的降幅,并随机初始化扩展非线性残差核的权重和偏置。

进一步地,所述步骤2使用无监督的聚类算法对实验数据进行预处理,其中包括:数据在进入参数训练网络之前,根据需求给出聚类个数n,n为已知数据集的类别个数,实施聚类操作,生成n个已知聚类中心,在运算过程中围绕已知的中心在类内进行优化聚类。

更进一步地,所述步骤2进一步包括:

步骤2.1:初始化聚类中心根据不同数据集特点,并选取n个类别的代表样本作为初始化聚类个数和聚类中心;

步骤2.2:分配各样本xj到相邻近的聚类集合,样本分配依据为:

式中i=1,2,…,k,p≠j;

步骤2.3:根据步骤2.2的分配结果,分别在类别内部更新聚类中心;

步骤2.4:若迭代达到最大迭代步数或者前后两次迭代的差小于设定阈值ε,即则算法结束;否则重复步骤(2.2);

其中,表示将样本集组成的矢量空间划分的多个区域,表示每个区域存在的一个相关区域。

进一步地,所述步骤3的卷积层采用7*7的卷积核,滑动步长为2,池化层采用3*3的最大池化操作,滑动步长为2,基于扩展非线性核的卷积层在手写字符识别方法中使用了9层,层内根据需要调整了滤波器数量,后一个池化层采用步长为1的平均池化方式,再连接上全连接层后,进行前向传播输出。

进一步地,所述步骤4是在步骤3完成前向传播后,训练数据被传送到全连接层中,使用交叉熵代价函数计算输出值与期望值之间的差量,激活函数采用relu,提升训练速度,按极小化误差的方法使之收敛,并将误差向量保存起来。

进一步地,所述步骤5的反向传播算法为:按极小化误差的方法反向传播并且调整基于扩展非线性核残差网络中的权值参数,首先对样本批量前向传播,计算出基于扩展非线性核残差网络中所有的激活值,然后根据每层节点,计算每层网络中的复合残差,并根据需要,用复合残差去逼近表达每层函数,并且中间卷积层采用双通道卷积核。

进一步地,所述步骤5中的每一个扩展非线性残差核在输出至下一层时都加入了dropout层,此时,可在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来。

与现有技术相比,本发明的有益效果是:本发明提出的基于扩展非线性核残差网络的手写字符识别方法,该方法能深度地描述样本数据和期望数据的相关性,能高效的从原始数据中自动地学习数字图像特征;其次,该方法引入了合适的类内无监督聚类算法,克服了深度学习网络在手写字符识别领域现有的技术不足。本发明简单且易于实现,提升手写字符识别性能的同时,也提高了网络的训练效率。

附图说明

图1是本发明基于扩展非线性核残差网络手写字符识别方法的流程图。

图2是本发明所提出的扩展非线性核结构。

具体实施方式

以下通过实施例形式对本发明的上述内容再作进一步的详细说明,但不应将此理解为本发明上述主题的范围仅限于以下的实施例,凡基于本发明上述内容所实现的技术均属于本发明的范围。

图1是本发明基于扩展非线性核残差网络手写字符识别的流程图。

在本实施例中,如图1所示,本发明基于扩展非线性核残差网络的手写字符识别方法,包括以下步骤:

(1)、以标准手写数字识别库mnist进行实例验证

(2)、初始化各个参数:卷积层采用7*7的卷积核,滑动步长为2。池化层采用3*3的最大池化操作,滑动步长为2。基于扩展非线性核的卷积层在本手写识别系统中使用了9层,层内根据需要调整了滤波器数量。后一个池化层采用步长为1的平均池化方式,再连接上全连接层。

引入先验知识:从mnist手写字符库中提取手写字符图像样本,在将图像样本引入网络训练之前,使用无监督的聚类算法对实验数据进行预处理。数据在进入参数训练网络之前,根据需求给出聚类个数10,并分别在0-9中抽取一个(共10个)作为起始聚类中心,实施聚类操作,生成10个已知聚类中心。在运算过程中围绕这些已知的中心在类内(非整个数据集)进行优化聚类。其中算法包含样本分配、类内更新聚类中心等4个步骤。

(2.1)、初始化聚类中心根据不同数据集特点,并选取10个类别的代表样本作为初始化聚类个数和聚类中心;

(2.2)、分配各样本xj到相邻近的聚类集合,样本分配依据为:

式中i=1,2,…,10,p≠j。

(2.3)、根据(2.2)的分配结果,分别在类别内部(非整个数据集)更新聚类中心。

(2.4)、若迭代达到最大迭代步数或者前后两次迭代的差小于设定阈值ε,即则算法结束;否则重复步骤(2.2)。

其中,表示将样本集组成的矢量空间划分的多个区域,表示每个区域存在的一个相关区域。本文采用的聚类算法,步骤(2.3)更新聚类中心的操作设计为:在某一个类别中(类内)进行聚类。对比其他聚类算法在整个数据集上更新聚类中心,很大程度上缩短了参数训练时间。

(3)、将先验知识优化后的训练数据均匀分批输入基于扩展非线性核的手写字符深度网络中,训练数据分别经过卷积层、池化层、基于扩展非线性核的卷积层、池化层,全连接层、完成前向传播;其中能够使本发明的效果优于其他手写字符识别效果的关键在于:引入了本发明提出的基于扩展非线性核的手写字符识别原理。如图2所示。该扩展核在网络自动训练参数的过程中,根据每层节点情况,计算出每层网络中的复合残差并逐级保存记录,根据需要,用复合残差去逼近表达每层函数。减小了每层函数表达的误差。并且中间卷积层采用双通道卷积核,在相同深度的网络结构下,提升了网络表达复杂函数的能力,得到了更好的手写数字识别效果。另外,系统中的每一个扩展非线性残差核在输出至下一层时都加入了dropout层。此时,可在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来。加入该结构后,解决了网络过拟合问题和改善了训练网络参数耗时过大的问题。

(4)、对步骤(3)中的网络进行梯度计算和误差计算,使用交叉熵代价函数计算输出值与期望值之间的差量,激活函数采用relu,提升训练速度。按极小化误差的方法使之收敛,并将误差向量保存起来。并判断误差是否收敛;若是,则跳转至步骤6,否则跳转至步骤5;

(5)、将步骤(4)中得到的误差和梯度用反向传播算法,经过池化层,基于扩展非线性核的卷积层、池化层、卷积层、输入层逐层传播,并且根据要求自动反向更新网络的权重,判断是否为输入层,若是则跳转至步骤(3),否则重复步骤(5),直至提取出有效的数据特征;

(6)直到权重更新稳定,建立出基于扩展非线性核的残差网手写字符识别训练模型;

(7)将测试数据按上述步骤输入,最终,得到准确的识别结果。

以上所述仅是本发明的优选实施方式,具体实施方式中牵涉到的数值参数仅仅用来对上述的具体实施方式进行详细说明,不能作为限制本发明保护范围的依据。应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明技术原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1