一种融合元学习的多终端协同训练算法及系统

文档序号:29611663发布日期:2022-04-13 09:25阅读:128来源:国知局
一种融合元学习的多终端协同训练算法及系统

1.本发明涉及人工智能技术领域,特别涉及一种融合元学习的多终端协同训练算法及系统。


背景技术:

2.现如今,任何一个人工智能(ai)项目都可能涉及多个领域,因此需要对来自各个公司、各个部门的数据进行整合。然而,实际应用中,由于各方对数据所有权和隐私性的关注越来越多,对用户隐私及数据的安全管理日趋严格,想要将分散在各地、各个机构的数据进行整合几乎变得不可能。在这种前提下,基于大数据的训练对于某个ai项目来说是高精度的必要保障,因此要求在满足隐私监管要求的前提下,设计一个机器学习框架,联邦学习算法应运而生。
3.联邦学习中,一种常见的算法由图1所示,在每个客户端分别使用本地数据训练各自的模型,之后将各自训练好的模型传输到服务器上进行融合,再将融合后的模型传回各客户端做继续训练。由于各客户端中的本地数据经常数目十分有限,导致其采用训练算法获得的模型经常会对本地数据产生过度适配,这样在服务器端将来自不同客户端的模型融合时,各模型无法很快地适配其它客户端上的数据处理,导致整体精度有限,且需要更多轮通信才能获得较为彻底的模型融合。
4.现有技术中,为了尽可能减少通信次数,通常采用的方式是通过限制需要和服务器进行通信的本地客户端数目来实现,或者,采用sgd(联邦平均算法)在本地客户端得到测试损失,和服务器进行通信来达到共同训练的效果。然而,sgd的计算效率虽高,但该方法需要大量的训练才能产生较为精确的模型,对于大多数客户端来说,其本地数据量远远不能达到sgd所需的标准,因此,需要在客户端本地的训练过程中引入更高效的算法,以更好地利用本地较少的数据量训练出具备迁移性的改进算法。


技术实现要素:

5.本技术提供了一种融合元学习的多终端协同训练算法及系统,以解决现有技术中,客户端采用少量数据训练出的模型迁移性较差,融合精确度低的问题。
6.第一方面,本技术提供了一种融合元学习的多终端协同训练算法,包括:
7.客户端加载位于本地的训练模型并初始化网络的权重参数;
8.客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;
9.服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。
10.在一些实施例中,得到平均模型后,所述算法还包括:
11.服务器获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;
12.若所述评估结果为满足要求,则停止数据通信与训练;
13.若所述评估结果为未满足要求,则重新执行所述客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型的步骤。
14.在一些实施例中,所述测试数据集中的数据样本根据类别不同分为多个数据包,其中,每个数据包采用n-way k-shot表示,n为每个数据包中随机抽取的类别数,way为类别,k为每个类别包含的数据样本数量,shot为数据单位。
15.在一些实施例中,所述采用元学习算法调整所述训练模型的步骤包括:
16.客户端从本地存储的数据样本中随机抽取一份数据包;
17.利用内循环和外循环更新所述训练模型的模型参数。
18.在一些实施例中,利用内循环更新所述训练模型的模型参数包括:
19.建立多个任务,每个任务采用梯度下降的规则,基于训练模型的原始参数θ得到更新参数θ
i’;其中i表示第i个任务;
20.根据更新参数θ
i’计算交叉熵损失l
ti
,所述交叉熵损失l
ti
由所有任务下得到的更新参数θ
i’相加得到。
21.在一些实施例中,所述外循环更新所述训练模型的模型参数采用下列公式得到:
[0022][0023]
其中,θn为调整后模型的模型参数,β为学习率,ti指i个任务,σti(*)是指对任务求和,指采用参数θ
i’的模型。
[0024]
第二方面,本技术还提供了一种对应与第一方面提供方法的系统。
[0025]
本技术提供的方法在联邦学习的基础上,在各个客户端引入针对小样本情境(即少量训练数据)的元学习算法,在训练中可以高效获取少量样本中的元信息,所训练出的模型对于新数据也有较好的迁移性,采用该方法训练出的客户端模型在服务器端进行融合后对于其它客户端的数据集也具有较高的处理精度。
[0026]
由于客户端训练出来的模型迁移性较好,所需的模型融合的通信次数显著降低,对于每个客户端来说,采用更少的训练次数、更短的训练时间以及更低的能量消耗即能获取同样的模型精度。
附图说明
[0027]
为了更清楚地说明本技术的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0028]
图1为现有技术中一种常见的联邦学习算法的原理图;
[0029]
图2为本技术一种融合元学习的多终端协同训练算法的流程图;
[0030]
图3为图2所示算法中步骤s200的分解步骤图;
[0031]
图4为本技术一种融合元学习的多终端协同训练算法在另一种实施例下的流程图;
[0032]
图5为本技术提供的方法的其中一种实施例的流程图。
具体实施方式
[0033]
鉴于小样本学习和联邦学习目标有重叠,即为了实现客户端侧设备数据隐私保护前提下,训练出一个高精度的集成模型,同时小样本中的元学习训练方案可以帮助模型在未见过的数据上泛化性增强,因此本技术考虑将二者结合,将元学习引入联邦学习的客户端侧训练,提升联邦学习多终端协同训练的性能。提升性能分为三方面:一、在保证学习性能前提下,降低整体通信次数;二、在保证学习性能前提下,降低端侧训练次数;三、相同训练消耗下,集成模型精度提升。该发明提出的解决方法是目前首次提出将端侧元学习引入联邦学习、且有效的方案。
[0034]
在本技术提供的方案中,所提到的联邦学习是指一种学习技术,它允许用户在不需要集中存储数据的情况下,从这些丰富的数据中取获得共享模型的好处。这种方法还允许我们利用网络边缘可用的廉价计算来扩展学习任务。联邦学习适合任务的特点有:一、对来自移动设备的真实数据的训练比对数据中心通常可用的代理数据的训练具有明显优势;二、所处理的数据是隐私敏感性的或者规模较大的,因此它不适合将其记录到数据中心来进行模型训练;三、对于监督型任务,数据集上的标签可以从用户与他们设备的交互过程中自然推理出来。基于联邦学习不能独立解决当客户端侧数据较少的问题时,本技术提供了一种基于联邦学习的改进型算法。
[0035]
参见图2,为本技术一种融合元学习的多终端协同训练算法的结构示意图;
[0036]
由图2可知,本技术实施例提供的一种融合元学习的多终端协同训练算法应用于每一个客户端中时,包括:
[0037]
s100:客户端加载位于本地的训练模型并初始化网络的权重参数;
[0038]
在本实施例中,与服务器进行通信的客户端可以有多个,对于每一个客户端来说均可沿用该方法;各个客户端(端侧)通常在本地配置有不同的训练模型,用于执行对本地数据样本的训练,本技术在读取数据样本前,需要先初始化网络中权重参数,使训练模型保持最初的设定。
[0039]
s200:客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;
[0040]
在本实施例中,客户端所利用的数据样本,是指客户端在本地存储的较少数量的数据样本,不同于现有技术其它方法,本技术方法尤其应用在样本数量较低的情况下,根据数据样本形式的不同,样本数量会相应设置为较低的数量级,例如,如果数据样本是图片,则这里的少量数据样本指的是几十到几千张图片,常规技术中的大量数据样本则一般指成千上万张图片;如果数据样本用数据大小表示,则这里的少量数据可以指几kb至及mb大小的数据,而常规技术中的大量数据样本则一般指以gb或以上数量级的数据,等等。
[0041]
进一步的,上述数据样本可以是原本存储在客户端中的,也可能是客户端由其他方式获取的,例如客户端自行从场景中采集、用户输入等等,还可以是由服务器指定发送的,如果是由服务器指定发送,则在步骤s200之前,本技术的方法还包括数据获取步骤,即由服务器将包含数据样本的数据包分别发送给多个客户端,这里需要说明的是,服务器给各个客户端所发送的数据包均不重复,相互之间可以是互补关系。
[0042]
对于客户端本地及其有限的数据样本,本技术是先引用元学习算法进行对客户端原始训练模型的调整,使各客户端发送给服务器的模型参数并非原始模型参数,而是更有
利于融合的参数。
[0043]
其中,元学习属于小样本学习中的一种,近年来,小样本学习分类发展迅速,面对众多的分类任务,都可以通过训练一个模型来达到任务要求。其中元学习的方法较多,为了最大的适用性,元学习的机制是任务的通用性,即面对不同的任务,不需要构建不同的模型,用同样的一套学习算法即可解决多种不同的任务。定义一个模型的可学习参数θ,面对不同的任务,可以通过改变参数θ的值来解决相应的任务。而参数θ的值可以通过元学习器去学习,在面对不同任务的时候,根据损失函数通过梯度下降的方法不断地更新θ值,使这个模型不断向能解决这个任务的模型靠近,当θ值最终收敛时,我们认为元学习器学习到了一个较好的参数θ,让模型自适应地解决相应任务。这个算法具有高效的特点,因为它没有为学习器引入其他的参数,并且训练学习器的策略使用的是已知的优化过程(如梯度下降等)而不是从头开始构建一个。
[0044]
具体的,参见图3,采用元学习算法调整所述训练模型的步骤包括:
[0045]
s210:客户端从本地存储的数据样本中随机抽取一份数据包;其中每个数据包采用n-way k-shot表示,n为每个数据包中随机抽取的类别数,way为类别,k为每个类别包含的数据样本数量,shot为数据单位,例如5-way 5-shot,即为每次都从剩余的数据样本中随机抽取5个类别,再从每个类别包含的数据中抽取还没有抽取过的5个数据,以此形成5-way 5-shot。
[0046]
s220:利用内循环和外循环更新所述训练模型的模型参数。
[0047]
在本实施例中,内循环也称为本地循环,即在客户端内部执行更新模型参数的过程,外循环也称为全局循环,即在包括多个客户端以及服务器在内的整个系统执行更新模型参数的过程。
[0048]
内循环分为多个任务,每个任务利用梯度下降的规则,基于模型的初始参数,更新得到更新参数,利用更新参数计算模型损失,具体流程为:
[0049]
先获取到本地训练模型;
[0050]
在每轮循环中,建立多个任务,每个任务采用梯度下降的规则,对抽取到的数据包评价模型损失,基于训练模型的原始参数θ得到更新参数θ
i’;
[0051]
根据更新参数θ
i’计算交叉熵损失l
ti
,所述交叉熵损失l
ti
由所有任务下得到的更新参数θ
i’相加得到;
[0052][0053]
其中i表示第i个任务;α为学习率;l
ti
为交叉熵损失;
[0054]
外循环是等待所有内循环的任务结束后,利用下列公式计算得到用于更新原始训练模型的参数:
[0055][0056]
其中,θn为调整后模型的模型参数,ti指i个任务,σti(*)是指对任务求和,指采用参数θ
i’的模型,β为学习率。
[0057]
需要说明的是,在一些可行性实施例中,执行外循环操作的客户端并非是系统中全部与服务器通信连接的客户端,即每次外循环可随机选取一定比例的客户端执行内循环,例如选择20%的客户端进行本地训练,并对该部分数据的训练结果计算更新模型参数;
下一次外循环时,再在其它客户端中随机选取20%客户端执行内循环以获取更新模型参数;这样有利于减小系统的单次消耗,同时也减少了通信次数。
[0058]
s300:服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。
[0059]
在本步骤中,服务器接收各个客户端经过调整后的训练模型(更新参数后的训练模型)后,再根据现有技术中的常用手段进行融合操作,具体的,该融合操作可采用现有技术中的多种方式,例如加权平均操作,求各模型的l2范数操作等;对于具体手段在本实施例中不予限制,应当认为,对模型实施融合操作的方法均可应用到本技术中。
[0060]
在步骤s300中得到的平均模型,是基于多个客户端的融合得到的,但可能仅经过一次融合不能使其对于每一个客户端的新增少量数据均具有较高的精确度,因此,在图4所示出的一些实施例中,需要增加对平均模型精度评估的步骤:
[0061]
s400:服务器获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;其中,所述测试数据集相当于是将各个客户端内存储的数据样本进行了整合,与数据样本的构成相同,测试数据集中也是按照数据样本根据类别分为多个数据包,其中,每个数据包采用n-way k-shot表示,n为每个数据包中随机抽取的类别数,k为每个类别包含的数据样本数量。
[0062]
采用测试数据集评估平均模型的精度,相当于评判平均模型是否能应用于测试数据集中的所有数据样本,如果存在有精度较低的数据样本情况,则认为该平均模型不满足要求,需要继续执行进一步的调整;如果均满足要求,则可认为该平均模型为最终模型,可停止继续训练以及通信过程,节省资源消耗,提高效率。
[0063]
s410:若所述评估结果为满足要求,则停止数据通信与训练;
[0064]
s420:若所述评估结果为未满足要求,说明此时的平均模型需要进一步调整,则重新执行前述的s200-s300的步骤,重新抽取数据包以及执行内循环及外循环的操作。
[0065]
由于在本技术中,在客户端引入了适用于小样本学习的元学习算法,从而让端侧的模型能够调整到最适应新类别的模式,即使加入了新类别导致融合后的模型不满足要求,也可以通过简单的几步即可调整,也就是说,上述步骤s420在实际应用中,循环的次数极少,可能仅需少数几次循环即可停止数据通信与训练过程。
[0066]
参考图5,为本技术提供的方法的其中一种实施例的流程图,首先,可由服务器将用于训练的数据集按照类别进行编号,抽样形成训练集的数据包,每个数据包是5-way5-shot形式,每次划分数据包都从剩余的数据随机抽取5个类别,从每类500个数据中抽取还没有抽取过的5个数据,以此形成5-way 5-shot,共可形成50份测试数据包,每个数据包是5-way 5-shot形式,方法类似训练集数据包,但暂不需要分配给客户端。
[0067]
在数据包划分好后,将训练集数据包中的部分分配给10个客户端,客户端之间的数据包不重复。
[0068]
训练分为内循环(本地循环)和外循环(全局循环)。在外循环中,并行的进行各个客户端的内循环,当外循环开始时,服务器可将更新的模型参数同时通信给各个客户端,客户端根据自身存储的数据包以及模型参数进行内循环和外循环,不断更新自身模型,当内外循环结束后,对客户端的模型在测试集上进行一次评价,若新的模型满足要求,则停止训练和通信,若不满足要求,则使用该新的模型重新进行一次内外循环。
[0069]
由上述技术方案可知,本技术提供了一种融合元学习的多终端协同训练算法,包
括客户端加载位于本地的训练模型并初始化网络的权重参数;客户端利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;服务器对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。本技术提供的方法在联邦学习的基础上,在各个客户端引入针对小样本情境(即少量训练数据)的元学习算法,在训练中可以高效获取少量样本中的元信息,所训练出的模型对于新数据也有较好的迁移性,采用该方法训练出的客户端模型在服务器端进行融合后对于其它客户端的数据集也具有较高的处理精度。
[0070]
对应于上述算法,本技术还提供了一种融合元学习的多终端协同训练系统,包括:
[0071]
服务器以及与服务器通信连接的多个客户端;
[0072]
所述客户端被配置为执行下列方法:
[0073]
加载位于本地的训练模型并初始化网络的权重参数;
[0074]
利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型;
[0075]
将调整后的模型发送给服务器;
[0076]
所述服务器被配置为执行下列方法:
[0077]
对来自多个客户端传输的调整后模型进行融合操作,得到平均模型。
[0078]
进一步的,所述服务器还被配置为:
[0079]
获取包含所有客户端存储的数据样本的测试数据集,根据所述测试数据集评估所述平均模型的精度,得到评估结果;
[0080]
若所述评估结果为满足要求,则停止数据通信与训练;
[0081]
若所述评估结果为未满足要求,则向对应客户端发送控制指令,使客户端重新执行利用本地存储的数据样本,采用元学习算法调整所述训练模型,得到调整后模型的步骤。
[0082]
进一步的,所述客户端配置有:
[0083]
抽取单元,用于从本地存储的数据样本中随机抽取一份数据包;
[0084]
参数更新单元,用于利用内循环和外循环更新所述训练模型的模型参数。
[0085]
利用内循环更新所述训练模型的模型参数包括:
[0086]
建立多个任务,每个任务采用梯度下降的规则,基于训练模型的原始参数θ得到更新参数θ
i’;其中i表示第i个任务;
[0087]
根据更新参数θ
i’计算交叉熵损失l
ti
,所述交叉熵损失l
ti
由所有任务下得到的更新参数θ
i’相加得到;
[0088]
所述外循环更新所述训练模型的模型参数采用下列公式得到:
[0089][0090]
其中,θn为调整后模型的模型参数。
[0091]
本实施例提供的系统的功能作用参见前述方法实施例中的描述,在此不再赘述。
[0092]
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本发明的其它实施方案。本技术旨在涵盖本发明的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本发明的一般性原理并包括本发明未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本发明的真正范围和精神由下面的权利要求指出。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1