基于滑动窗口采样的分布式机器学习训练方法及其系统与流程

文档序号:12469228阅读:994来源:国知局
基于滑动窗口采样的分布式机器学习训练方法及其系统与流程

本发明涉及大规模机器分布式训练,特别是涉及一种基于滑动窗口采样的分布式机器学习训练方法及其系统。



背景技术:

在大数据集上进行训练的现代神经网络架构可以跨广泛的多种领域获取可观的结果,领域涵盖从语音和图像认知、自然语言处理、到业界关注的诸如欺诈检测和推荐系统这样的应用等各个方面。但是训练这些神经网络模型在计算上有严格要求,尽管近些年来GPU硬件、网络架构和训练方法上均取得了重大的进步,但事实是在单一机器上,网络训练所需要的时间仍然长得不切实际。幸运的是,我们不仅限于单个机器:大量工作和研究已经使有效的神经网络分布式训练成为了可能。分布式训练中的数据并行方法在每一个worker machine上都有一套完整的模型,但分别对训练数据集的不同子集进行处理。数据并行毫无争议是分布式系统中最适的方法,而且也一直是更多研究的焦点。在数据并行(data parallelism)中,不同的机器有着整个模型的完全拷贝;每个机器只获得整个数据的不同部分。计算的结果通过某些方法结合起来。数据并行训练方法均需要一些整合结果和在各工作器(worker)间同步模型参数的方法。现有的分布式机器学习训练方法一般为SGD,目前常用的SGD算法为基于延迟与软同步的SGD即Staleness Aware SGD,然而它们存在以下一些问题:

Staleness Aware SGD使用当前过期梯度(Staleness)调整对应学习器的学习率,将分布式异步训练的节点快慢产生过期梯度这个问题考虑进来,在普适计算的环境中,人和计算机不断的进行着透明性的交互,在这个交互过程中,普适系统获取与用户需求相关的上下文信息来确认为用户提供什么样的服务,这就是上下文感知,它是普适计算的重要技术。考虑到过期梯度,虽然一定程度上缓解了全异步同步参数引起的整体模型收敛效果差与受集群系统波动的影响的问题,但使用当前学习器的过期梯度无法感知该学习器过期梯度的上下文信息,对于该过期梯度处理过于简单,这样造成分布式异步训练稳定性与收敛效果依然不够理想。



技术实现要素:

鉴于以上所述现有技术的缺点,本发明的目的在于提供一种基于滑动窗口采样的分布式机器学习训练方法及其系统,用于解决现有技术中学习器的梯度过期程度无法感知该学习器梯度过期程度的上下文信息、对于该过期梯度处理过于简单、从而造成分布式异步训练稳定性与收敛效果不好的问题。

为实现上述目的,本发明采用以下方案:一种基于滑动窗口采样的分布式机器学习训练方法,包括以下步骤:步骤1),机器学习模型参数初始化;步骤2),获取所有数据的一个数据分片,独立进行模型训练;步骤3),收集历史的若干轮梯度过期程度样本,通过滑动采样样本,并计算梯度过期程度上下文值,调整学习率后发起梯度更新请求;步骤4),异步收集多个梯度过期程度样本,利用调整后的学习率更新全局模型参数并推送更新的参数;步骤5),异步获取推送的全局参数更新,继续下一次训练;步骤6),检验模型收敛性,若不收敛,进入所述步骤2)循环;若收敛,进入步骤7);步骤7),获取模型参数。

于本发明的一实施方式中,在所述步骤4)中,还包括维护一个逻辑时钟记录当前模型参数版本的步骤,在每进行一次从梯度到参数值的优化动作后,逻辑时钟加1。

于本发明的一实施方式中,在每进行一次逻辑时钟加1后,用当前的逻辑时钟更新机器学习的逻辑时钟。

于本发明的一实施方式中,在所述步骤3)中,梯度过期程度的计算为:当前的逻辑时钟-机器学习的逻辑时钟+1,其中,每个机器学习模型保存前N-1次更新时的梯度过期程度,N为当前更新的次数。

于本发明的一实施方式中,在所述步骤3)中,梯度过期程度上下文值通过计算当前梯度过期程度与前N-1个梯度过期程度的均值得到。

此外,本发明还提供了一种应用上述方法的基于滑动窗口采样的分布式机器学习训练系统,所述系统包括:服务器节点,所述服务器节点异步收集若干个梯度更新请求,进行全局模型参数更新并保存,被动推送更新的参数;学习器节点,每个所述学习器节点获取所有数据的一个数据分片,独立进行模型训练,每轮训练完毕后,使用调整过的学习率向所述服务器节点发起梯度更新,并异步获取所述服务器节点推送的更新的参数,发起下一轮训练;滑动采样模块,所述滑动采样模块附属于所述学习器节点,用于完成前若干轮梯度过期程度样本的采样,计算梯度过期程度上下文值,并在所述学习器节点向所述服务器节点推送梯度时,推送当前梯度过期程度上下文值,调整此次更新学习率。

于本发明的一实施方式中,每个所述学习器节点保存前N-1次更新时的梯度过期程度,其中,N为当前更新的次数。

于本发明的一实施方式中,所述服务器节点每进行一次从梯度到参数值的优化动作后,逻辑时钟加1。

于本发明的一实施方式中,每个所述学习器节点维护一个自己的逻辑时钟,所述学习器节点异步提交梯度并立即获取当前所述服务器节点参数值,并用当前所述服务器节点的逻辑时钟更新自己的逻辑时间。

于本发明的一实施方式中,每个所述学习器节点异步提交梯度至所述服务器节点,所述服务器节点异步累计任意个所述学习器节点的梯度,然后进行梯度到参数的优化动作。

如上所述,本发明的基于滑动窗口采样的分布式机器学习训练方法及其系统,具有以下有益效果:通过使用滑动窗口采样训练过程中的快慢节点特征来感知训练过程的上下文,相比普通的梯度延迟感知的异步更新的训练方法,能更好地感知过期梯度,使用过期梯度上下文来控制学习器的学习率,缓解训练收敛效果不佳的问题,同时减小了分布式系统带来的训练波动,提高了分布式训练的鲁棒性。

附图说明

图1显示为本发明基于滑动窗口采样的分布式机器学习训练方法于一实施例中的流程图。

图2显示为基于服务器节点的数据并行训练模式于一实施例中的架构图。

元件标号说明

1 服务器节点

2 学习器节点

S1~S7 步骤

具体实施方式

以下通过特定的具体实例说明本发明的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本发明的其他优点与功效。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。

需要说明的是,以下实施例中所提供的图示仅以示意方式说明本发明的基本构想,虽图示中仅显示与本发明中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。

本发明涉及大规模机器分布式训练的方法,针对当前分布式机器学习训练方法稳定性与收敛性效果不好的问题,提出了一种基于滑动窗口采样的分布式机器学习训练方法及其系统,在进行机器学习分布式训练时,使用滑动窗口采样感知过期梯度并基于感知到的过期梯度上下文调整学习率来实现提高大规模机器学习分布式训练的稳定性与收敛效果。本发明可以用于以下几种典型的应用场景:数据并行的大规模机器学习系统、深度学习系统。以下结合附图和实施例对本发明作进一步的说明。

实施例一

请参阅图1,为本发明基于滑动窗口采样的分布式机器学习训练方法于一实施例中的流程图。本发明的方法包括以下步骤:

S1,机器学习模型参数初始化;

S2,获取所有数据的一个数据分片,独立进行模型训练;

S3,收集历史的若干轮梯度过期程度样本,通过滑动采样样本,并计算梯度过期程度上下文值,调整学习率后发起梯度更新请求;

S4,异步收集多个梯度过期程度样本,利用调整后的学习率更新全局模型参数并推送更新的参数;

S5,异步获取推送的全局参数更新,继续下一次训练;

S6,检验模型收敛性,若不收敛,进入所述步骤2)循环;若收敛,进入步骤7);

S7,获取模型参数。

作为示例,在所述步骤S4中,还包括维护一个逻辑时钟记录当前模型参数版本的步骤,在每进行一次从梯度到参数值的优化动作后,逻辑时钟加1。

作为示例,在每进行一次逻辑时钟加1后,用当前的逻辑时钟更新机器学习的逻辑时钟。

作为示例,在所述步骤S3中,梯度过期程度的计算为:当前的逻辑时钟-机器学习的逻辑时钟+1,其中,每个机器学习模型保存前N-1次更新时的梯度过期程度,N为当前更新的次数。

作为示例,在所述步骤S3中,梯度过期程度上下文值通过计算当前梯度过期程度与前N-1个梯度过期程度的均值得到,也即滑动窗口采样。

作为示例,设stalenessContext(i)为第i个节点梯度过期程度上下文值,gradient(i)为第i个节点梯度,使用C(C为大于等于1的整数)个平均梯度,则做以下加权平均得到:

G=1/C*sum(stalenessContext(i)*gradient(i)),i={0,1,...nodes},即使用梯度过期程度上下文值来调整学习率,让其感知过期梯度上下文。

实施例二

此外,本发明还提供了一种应用上述方法的基于滑动窗口采样的分布式机器学习训练系统,请参阅图2,所述系统包括:服务器节点1,所述服务器节点1异步收集若干个梯度更新请求,进行全局模型参数更新并保存,被动向客户端推送更新的参数;学习器节点2,每个所述学习器节点2获取所有数据的一个数据分片,独立进行模型训练,每轮训练完毕后,使用调整过的学习率向所述服务器节点1发起梯度更新,并异步获取所述服务器节点1推送的更新的参数,发起下一轮训练;滑动采样模块(未示出),所述滑动采样模块附属于所述学习器节点2,用于完成前若干轮梯度过期程度样本的采样,计算梯度过期程度上下文值,并在所述学习器节点2向所述服务器节点1推送梯度时,推送当前梯度过期程度上下文值,调整此次更新学习率。

需要注意的是,滑动采样模块在参数服务器中的学习器节点2完成采样工作,并且提供一个整体训练的技术方案,可用于实际的通用机器学习、深度学习算法的分布式训练过程。

作为示例,所述服务器节点1每进行一次从梯度到参数值的优化动作后,逻辑时钟加1。所述服务器节点1统一维护一个逻辑时钟记录当前参数的版本,所述服务器节点1每进行一次从梯度到参数值的优化动作,逻辑时钟加1。

作为示例,每个所述学习器节点2也维护一个自己的逻辑时钟,所述学习器节点2异步提交梯度并立即获取当前所述服务器节点1参数值,并用当前所述服务器节点1的逻辑时钟更新自己的逻辑时间。

作为示例,每个所述学习器节点2需要保存前N-1次更新时的梯度过期程度,其中N为当前参数更新的次数。

作为示例,若设trainer表示为学习器节点,pserver表示为服务器节点,gradient表示梯度,staleness表示为梯度过期程度,value表示参数值,stalenessContext表示为梯度过期程度上下文值,Timestamp表示逻辑时钟,i表示第i个学习器节点,且i={0,1,...nodes},G为平均梯度值,C表示收集的学习器节点的个数,则梯度过期程度staleness的计算式为:

staleness(trainer)=Timestamp(pserver)-Timestamp(trainer)+1。

当前梯度过期程度staleness与前N-1个梯度过期程度staleness的均值,即滑动窗口采样,任何学习器节点trainer异步提交梯度gradient至服务器节点pserver,服务器节点pserver异步累计C个学习器节点trainer的梯度gradient,然后进行梯度gradient到参数value的优化动作,即使用软同步策略。计算方法是使用来自C个学习器节点trainer的平均梯度G,做以下加权平均则得到:

G=1/C*sum(stalenessContext(i)*gradient(i)),其中i={0,1,...nodes},即使用梯度过期程度上下文值stalenessContext调整学习率,让其感知过期梯度上下文。学习器节点trainer与服务器节点pserver按照上述方法进行训练,直到训练收敛。

需要注意的是,对任何学习器节点trainer而言,它异步提交梯度gradient并立即获取当前的参数值value;对服务器节点pserver而言,它累计来自任意C个学习器节点trainer的梯度gradient,并立即进行优化动作。模型参数更新公式为:

其中,s代表等待s个学习器节点的更新,Wi代表第i个学习器节点训练后的参数,ΔWj代表第j个学习器节点贡献的梯度,λ(ΔWj)代表经过梯度过期程度上下文值stalenessContext放缩过的学习率。

综上所述,首先,对于每个学习器,需要它要保存前N-1个样本,每个样本为每次更新参数时的梯度过期程度值;然后,每个学习器的梯度过期程度上下文值为当前的梯度过期程度值与前N-1个样本值的均值,计算完毕后丢弃最老的样本,保存最新的样本;最后,在学习器向参数服务器发起更新请求时,使用梯度过期程度上下文值控制该学习器的学习率。

本发明的基于滑动窗口采样的分布式机器学习训练方法及其系统,通过使用滑动采样模块训练过程中的快慢节点特征来感知训练过程的上下文,相比普通的梯度延迟感知的异步更新的训练方法,能更好地感知过期梯度,使用过期梯度上下文来控制学习器的学习率,提高大规模机器学习分布式训练的稳定性与收敛效果,同时减小了分布式系统带来的训练波动,提高了分布式训练的鲁棒性。

上述实施例仅例示性说明本发明的原理及其功效,而非用于限制本发明。任何熟悉此技术的人士皆可在不违背本发明的精神及范畴下,对上述实施例进行修饰或改变。因此,举凡所属技术领域中具有通常知识者在未脱离本发明所揭示的精神与技术思想下所完成的一切等效修饰或改变,仍应由本发明的权利要求所涵盖。

当前第1页1 2 3 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1