使用聚类损失训练神经网络的制作方法

文档序号:18093335发布日期:2019-07-06 10:53阅读:243来源:国知局
使用聚类损失训练神经网络的制作方法

本说明书涉及训练神经网络。



背景技术:

神经网络是使用非线性单元的一个或多个层来针对所接收的输入预测输出的机器学习模型。一些神经网络除了包括输出层之外还包括一个或多个隐藏层。每个隐藏层的输出被用作网络中的下一层(即,下一个隐藏层或输出层)的输入。网络的每个层根据相应参数集的当前值从接收的输入生成输出。

一些神经网络是递归神经网络。递归神经网络是接收输入序列并从输入序列生成输出序列的神经网络。特别地,递归神经网络可以使用来自前一个时间步长的网络的一些或全部内部状态来在当前时间步长计算输出。递归神经网络的示例是包括一个或多个lstm存储器块的长短期(lstm)神经网络。每个lstm存储器块可以包括一个或多个细胞,每个细胞包括输入门、遗忘门和输出门,它们允许细胞存储该细胞的先前状态,例如,以用于生成当前激活或被提供到lstm神经网络的其他组件。



技术实现要素:

该说明书描述了在一个或多个位置上的一个或多个计算机上实现为计算机程序的系统,该系统训练神经网络,所述神经网络具有网络参数并且被配置为接收输入数据项目并根据网络参数处理输入数据项目以生成输入数据项目的嵌入。在一些特定的非限制性示例中,以本文描述的方式训练的神经网络可以用于图像分类。

可以实现本说明书中描述的主题的特定实施例,以便实现以下优点中的一个或多个。通过训练在本说明书中描述的神经网络,即,通过训练神经网络以优化所描述的目标,训练的神经网络可以生成更准确地反映网络输入之间的相似性的嵌入。特别地,通过如本说明书中描述训练神经网络,由训练的神经网络生成的嵌入可以有效地用作用于各种任务的网络输入的特征或表示,包括基于特征的检索、聚类、近似重复检测、验证、特征匹配、域适应和基于视频的弱监督学习等。

在一些特定示例中,根据本文描述的训练方法训练的神经网络可以用于图像分类。更具体地,在这些示例中,通过以这种方式训练神经网络,由神经网络生成的嵌入可以有效地用于大规模分类任务,即,其中类的数量非常大并且每个类的示例的数量变得稀缺的任务。在此设置中,任何直接分类或回归方法由于可能数量大的类而变得不切实际。然而,所描述的训练技术允许例如通过用相应的中心点表示每个类并确定最接近网络输入的嵌入的中心点来使用由训练的神经网络生成的嵌入来将网络输入精确地分类为类之一。

另外,用于训练神经网络以生成嵌入的许多传统方法在训练数据可用于训练神经网络之前需要对训练数据进行计算密集的预处理。例如,许多现有技术需要单独的数据准备阶段,其中,必须首先成对地准备训练数据,即,每对包括三元组的正和负示例,即每个三元组在训练数据可用于训练之前包括锚示例、正示例和负示例,或者是n对元组格式。该过程具有非常昂贵的时间和空间成本,因为它通常需要复制训练数据并且需要重复访问磁盘以确定如何格式化训练示例。相比之下,本说明书中描述的训练技术在训练中使用一批次的训练项目之前几乎不需要或根本不需要预处理,在仍然如上所述训练网络以有效地生成嵌入的同时,减少了训练神经网络所需的计算成本和时间。

在附图和以下描述中阐述了本说明书中描述的主题的一个或多个实施例的细节。根据说明书、附图和权利要求,本主题的其他特征、方面和优点将变得显而易见。

附图说明

图1示出了示例神经网络训练系统。

图2是使用聚类损失训练神经网络的示例过程的流程图。

图3是用于确定对神经网络的参数的当前值的更新的示例过程的流程图。

各附图中相同的附图标记和名称表示相同的元件。

具体实施方式

图1示出了示例性神经网络训练系统100。神经网络训练系统100是在一个或多个位置中的一个或多个计算机上实现为计算机程序的系统的示例,其中,可以实现下面描述的系统、组件和技术。

神经网络训练系统100是在训练数据140上训练神经网络110以从网络参数的初始值确定神经网络110的参数(在本说明书中称为网络参数)的训练值的系统。

神经网络110是被配置为接收输入数据项目102并处理输入数据项目以根据网络参数生成输入数据项目102的嵌入112的神经网络。通常,数据项目的嵌入是表示数据项目的数值的有序集合,例如矢量。换句话说,每个嵌入是多维嵌入空间中的一个点。一旦被训练,由神经网络110生成的多维空间中的嵌入的位置可以反映嵌入表示的数据项目之间的相似性。

神经网络110可以被配置为接收任何类型的数字数据输入作为输入并且从输入生成嵌入。例如,输入数据项目(也称为网络输入)可以是图像、文档的部分、文本序列和音频数据等。

神经网络110可以具有适合于由神经网络110处理的网络输入类型的任何架构。例如,当网络输入是图像时,神经网络110可以是卷积神经网络。例如,嵌入可以是已经预先训练用于图像分类的卷积神经网络(例如,在c.szegedy,w.liu,y.jia,p.sermanet,s.reed,d.anguelov,d.erhan,v.vanhoucke,anda.rabinovich,goingdeeperwithconvolutions,incvpr,2015中描述的初始网络)的中间层的输出。

一旦被训练,由网络110生成的嵌入可用于各种目的中的任何目的。

例如,系统100可以将由训练的神经网络生成的嵌入作为相应网络输入的特征提供为另一系统的输入,例如,用于在该网络输入上执行机器学习任务。示例任务可以包括基于特征的检索、聚类、近似重复检测、验证、特征匹配、域自适应和基于视频的弱监督学习等。

作为另一示例,系统100可以使用由训练的神经网络生成的嵌入来对相应的网络输入进行分类。特别地,对于一组多个可能类中的每一个,系统可以保持识别相应的中心点的数据,即嵌入空间中的相应代表点。然后,系统100可以将网络输入分类为属于由最接近由训练的神经网络对于该网络输入生成的嵌入的中心点所表示的类。

由系统100用于训练神经网络110的训练数据140包括多批次训练输入和用于训练输入的真实聚类分配。真实聚类分配将每个训练输入分配到聚类集中的相应聚类中。例如,在分类上下文中,聚类集可以包括网络输入可以被分类为的每个可能的类别或类的相应聚类,并且真实分配将每个数据项目分配给数据项目应该被分类为的类别或类的聚类。

系统100通过优化聚类目标150来在训练数据140上训练神经网络110。具体地,聚类目标150是针对给定批次的多个训练输入因为下述情况而惩罚神经网络110的目标:对于除该批次的真实分配之外的每个可能的聚类分配,产生不会导致用于该批次的oracle聚类分值比用于该可能的聚类分配的聚类分值高至少在该可能的聚类分配和真实分配之间的结构化差额的嵌入。

每个可能的聚类分配通过指定一组聚类中心点(即,嵌入空间中的一组代表点)来定义批次中的训练示例的聚类,该一组聚类中心点包括用于该组聚类中的每个聚类的一个中心点。然后,聚类分配将批次中的每个训练项目分配给最接近训练项目的嵌入的中心点。

给定聚类分配的聚类分值(也称为设施位置分值)测量批次中的训练项目的嵌入每个与该嵌入的最接近的中心点的接近程度。特别地,在一些实现中,生成聚类分值的设施位置函数f满足:

其中,|x|是批次中的训练输入x的总数,总和是遍及批次中的所有训练输入的总和,s是给定聚类分配中的中心点集,对于第i个训练输入xi,f(xi;θ)是根据网络参数θ生成的训练输入的嵌入,并且是从中心点集中的最近中心点到该训练输入的嵌入的距离。

给定真实聚类分配和网络参数,oracle聚类分值测量批次中的训练示例的聚类质量,即根据网络参数生成的批次中的训练输入的嵌入定义的聚类的质量。特别是,生成oracle聚类分值的oracle聚类函数表示为:

其中,|γ|是真实聚类分配y*中的聚类总数,总和是遍及真实聚类分配中的所有聚类的总和,i∶y*[i]=k是批次中的训练示例的通过真实聚类分配被分簇为聚类k的子集,并且对于聚类k,是当聚类k的中心点是聚类k中的任何训练项目的任何嵌入时生成的聚类分值中的仅聚类k中的训练项目的最大聚类分值。

给定的可能聚类分配和真实聚类分配之间的结构化差额测量可能聚类分配相对于真实分配的的质量。特别地,在一些实现中,结构化差额基于可能的聚类分配和真实分配之间的标准化互信息度量。特别地,在这些实现中,聚类分配y和真实分配y*之间的结构化差额δ表示为:

δ(y,y*)=1-nmi(y,y*),

其中,nmi(y,y*)是两个分配之间的标准化互信息并且满足:

其中,mi是两个分配之间的互信息,并且h是分配的熵。

通常,互信息和熵都基于两个分配中的聚类的边际概率和一个分配中的一个聚类与另一个分配中的另一个聚类之间的联合概率。用于计算熵和互信息的边际概率可以对于给定的聚类和给定的分配被估计为由给定的分配分配给给定聚类的训练项目的分值。用于计算第一分配中的聚类i和第二分配中的聚类j之间的熵和互信息的联合概率可以被估计为通过第一分配分配给聚类i并且通过第二分配到聚类j的训练项目的分值。

一批次的训练输入x和用于该批次的真实聚类分配y*的聚类损失函数可以然后满足:

其中,最大值遍及可能的聚类分配即与真实分配中存在聚类具有相同数量的中心点的可能的中心点集,γ是正常数值,g(s)是向由分配s中最接近训练项目的嵌入的中心点表示的聚类分配每个训练项目的函数,并且[a]+如果a小于或等于零则等于0,如果a大于0则等于a。用于给定聚类分配的项f(x,s;θ)+γδ(g(s),y*)在本说明书中将被称为聚类分配的增强聚类分值。

下面参考图2和图3更详细地描述在该目标上训练神经网络。

一旦训练了神经网络,系统100就提供指定训练的神经网络以用于处理新网络输入的数据。也就是说,系统100可以例如通过输出到用户设备或者通过在系统100可访问的存储器中存储网络参数的训练值,以便稍后用于使用训练的神经网络处理输入。作为输出训练的神经网络数据的替代或补充,系统100可以实例化具有网络参数的训练值的神经网络的实例,例如通过由系统提供的应用程序编程接口(api)接收要处理的输入,使用训练的神经网络处理接收的输入以生成嵌入,并且然后响应于接收的输入提供生成的嵌入。

图2是用于在一批次训练数据上训练神经网络的示例过程200的流程图。为方便起见,过程200将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,被适当编程的神经网络训练系统(例如,图1的神经网络训练系统100)可以执行过程200。

系统可以针对多个不同批次的训练项目多次执行过程200,以从网络参数的初始值确定网络参数的训练值。

系统获得一批次的训练项目和对该批次中的训练项目到多个聚类的真实分配(步骤202)。真实分配将该批次中的每个训练项目分配给来自聚类集的相应聚类。

系统使用神经网络并根据网络参数的当前值处理该批次中的每个训练项目,以为每个训练项目生成相应的嵌入(步骤204)。

系统基于该批次中的训练项目的嵌入来确定用于真实分配的oracle聚类分值(步骤206)。如上所述,oracle聚类分值在给定真实聚类分配和网络参数的情况下测量聚类(即,由根据网络参数的当前值生成的嵌入所定义的聚类)的质量。

系统通过执行神经网络训练过程的迭代来调整网络参数的当前值以使用oracle聚类分值来优化(即,最小化)聚类目标(步骤208)。通常,训练过程从聚类目标相对于参数的梯度确定对参数的当前值的更新,然后将该更新应用于(例如,添加到)当前值以确定参数的更新值。例如,训练过程可以是随机梯度下降,并且系统可以将梯度乘以学习速率以确定更新,然后将更新添加到网络参数的当前值。下面参考图3更详细地描述确定聚类目标的梯度。

图3是用于确定对网络参数的当前值的更新的示例过程300的流程图。为方便起见,过程300将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,被适当编程的神经网络训练系统(例如,图1的神经网络训练系统100)可以执行过程300。

系统可以在一批次的训练输入上训练神经网络期间执行过程300,以确定对用于该批次的网络参数的当前值的更新。然后,系统可以应用(即,添加)针对该批次中的输入确定的更新,以生成网络参数的更新值。

系统确定除真实分配之外的具有最高增强聚类分值的可能聚类分配(步骤302)。

如上所述,增强聚类分值是可能聚类分配的聚类分值加上在可能聚类分配和真实分配之间的结构化差额。

在一些实现中,为了确定最高评分聚类分配,系统使用迭代损失增强推理技术确定初始最佳可能聚类分配。特别地,在推理技术的每次迭代中,系统将新中心点添加到聚类分配。也就是说,系统以具有零中心点的聚类分配开始,即,不向任何聚类分配任何嵌入,并继续添加中心点,直到聚类分配中的中心点数量等于在真实分配中的中心点的数量即聚类的数量。

在推理技术的每个步骤中,系统向当前聚类分配添加最大增加聚类分配的增强聚类分值的中心点。

然后,系统使用损失增强细化技术修改初始最佳可能聚类分配,以确定最高评分的可能聚类分配。特别地,系统执行细化技术的多次迭代,以从初始最佳可能聚类分配确定最高评分的可能聚类分配。

在每次迭代并且对于当前最佳可能聚类分配中的每个聚类,系统确定通过根据当前最佳可能聚类分配执行聚类的当前中心点与同一聚类中的替代点的成对交换修改当前最佳可能聚类分配是否将增大聚类分配的增强聚类分值,如果是,则将当前中心点交换为替代点以更新聚类分配。

要执行的损失增强细化技术的迭代次数可以是固定的,例如,三次、五次或七次迭代,或者系统可以继续执行该技术,直到逐点交换都不会改善增强的聚类分值为止。

在一些实现中,对于迭代损失增强推理技术和损失增强细化技术两者,系统仅考虑批次中的训练项目的嵌入作为候选中心点。在一些其他实现中,系统还考虑嵌入空间中不是该批次中的训练项目的嵌入的点,例如,空间中的可能点的整个空间或可能点的预定离散子集。

系统使用最高评分聚类分配来确定聚类目标相对于网络参数的梯度(步骤304)。

特别地,系统可以将梯度确定为目标函数的增强聚类评分函数项相对于网络参数的梯度与目标函数的oracle评分函数项相对于网络参数的梯度之间的差异。也就是说,总梯度是第一梯度项即增强聚类评分函数的梯度减去第二梯度项即oracle评分函数的梯度。

更具体地说,第一个梯度项满足:

其中,总和遍及批次中的所有训练输入的总和,f(xi;θ)是第i个训练输入xi的嵌入,是在最高评分分配中的中心点中最接近于嵌入f(xi;θ)的中心点,并且是相对于网络参数的梯度。

第二个梯度项满足:

其中,总和遍及真实分配中的聚类,并且是通过真实聚类分配分配给聚类k的训练输入相对于最高评分分配中的聚类k的中心点的聚类分值。

系统可以使用传统的神经网络训练技术,即通过经由神经网络反向传播梯度,确定相对于神经网络的所有参数的第一梯度项和第二梯度项中的梯度。

在一些实现中,如果批次的损失,即批次的聚类损失函数l(x,y*)的值大于零,则系统仅如上所述计算梯度。如果损失小于或等于零,则系统将梯度设置为零,并且不更新网络参数的当前值。

然后,系统从聚类目标的梯度确定对网络参数的当前值的更新(步骤306)。例如,系统可以通过将学习速率应用于梯度来确定更新。

本说明书结合系统和计算机程序组件使用术语“被配置”。对于被配置为执行特定操作或动作的一个或多个计算机的系统意味着系统已经在其上安装了软件、固件、硬件或它们的组合,其在运行中使得系统执行操作或动作。对于被配置为执行特定操作或动作的一个或多个计算机程序,意味着一个或多个程序包括当由数据处理装置执行时使装置执行操作或动作的指令。

本说明书中描述的主题和功能操作的实施例可以被实现在数字电子电路中、在有形地实施的计算机软件或固件中、在计算机硬件(包括本说明书中公开的结构及其结构等同物)中或在它们的一个或多个的组合中。本说明书中描述的主题的实施例可以被实现为一个或多个计算机程序,即编码在有形非暂时性程序载体上的计算机程序指令的一个或多个模块,用于由数据处理装置执行或控制数据处理装置的操作。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行存取存储器设备或它们中的一个或多个的组合。替代地或补充地,程序指令可以编码在人工生成的传播信号上,例如机器生成的电、光或电磁信号,其被生成以编码信息以便传输到合适的接收器装置以供数据处理装置执行。

术语“数据处理装置”指代数据处理硬件,并且涵盖用于处理数据的所有种类的装置、设备和机器,例如包括可编程处理器、计算机或多个处理器或计算机。所述装置也可以是或还包括专用逻辑电路,例如,fpga(现场可编程门阵列)或asic(专用集成电路)。除了硬件之外,所述装置可以选用地包括创建用于计算机程序的执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。

计算机程序(也称为程序、软件、软件应用、app、模块、软件模块、脚本或代码)可以以任何形式的编程语言编写,该任何形式的编程语言包括编译或解释语言或者声明性或过程语言,并且该计算机程序可以以任何形式部署,包括作为独立程序或作为适于在计算环境中使用的模块、组件、子例程或其他单元。程序可以但不需要对应于文件系统中的文件。程序可以存储在保存其他程序或数据(例如,存储在标记语言文档中的一个或多个脚本)的文件的一部分中、在专用于所涉及的程序的单个文件中或者在多个协同文件中(例如,存储一个或多个模块、子程序或代码部分的文件)。计算机程序可以被部署为在一个计算机上或在位于一个地点或分布在多个地点并通过通信网络互连的多个计算机上执行。

在本说明书中,术语“数据库”广泛用于指代任何数据集合:数据不需要以任何特定方式结构化或根本不结构化,并且它可以在一个或多个位置的存储设备上存储。因此,例如,索引数据库可以包括多个数据集合,每个数据集合可以被不同地组织和访问。

类似地,在本说明书中,术语“引擎”广泛用于指代被编程为执行一个或多个特定功能的基于软件的系统、子系统或过程。通常,引擎将被实现为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件。在某些情况下,一台或多台计算机将专用于特定的引擎;在其他情况下,可以在同一台计算机或多个计算机上安装和运行多个引擎。

本说明书中描述的过程和逻辑流程可以由一个或多个可编程处理器执行,该一个或多个可编程处理器执行一个或多个计算机程序以通过对输入数据进行操作并生成输出来执行动作。过程和逻辑流程也可以由专用逻辑电路(例如fpga或asic)或专用逻辑电路和一个或多个编程计算机的组合来执行。

适合于执行计算机程序的计算机可以基于通用或专用微处理器或两者,或任何其他种类的中央处理单元。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的基本元件是用于执行指令的中央处理单元和用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器可以由专用逻辑电路补充或并入专用逻辑电路中。通常,计算机还将包括或可操作地耦合以从一个或多个大容量存储设备接收数据或将数据传输到一个或多个大容量存储设备,所述一个或多个大容量存储设备用于存储数据,例如是磁盘、磁光盘或光盘。但是,计算机不需要这样的设备。此外,计算机可以嵌入在另一个设备中,例如移动电话、个人数字助理(pda)、移动音频或视频播放器、游戏控制台、全球定位系统(gps)接收器或便携式存储设备(例如,通用串行总线(usb)闪存驱动器),此处仅举几例。

适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,包括例如:半导体存储器设备,例如eprom、eeprom和闪存设备;磁盘,例如内部硬盘或可移动磁盘;磁光盘;以及,cdrom和dvd-rom盘。

为了提供与用户的交互,本说明书中描述的主题的实施例可以在具有用于向用户显示信息的显示设备(例如,crt(阴极射线管)或lcd(液晶显示器)监视器)和键盘以及指示设备(例如,鼠标或轨迹球)的计算机上实现,用户可以通过显示设备和键盘以及指示设备向计算机提供输入。其他类型的设备也可用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感觉反馈,例如视觉反馈、听觉反馈或触觉反馈;并且可以以任何形式接收来自用户的输入,包括声学、语音或触觉输入。另外,计算机可以通过下述方式来与用户交互:向用户使用的设备发送文档和从用户使用的设备接收文档;例如,响应于从web浏览器接收的请求将网页发送到用户设备上的web浏览器。此外,计算机可以通过向个人设备(例如,运行消息收发应用的智能电话)发送文本消息或其他形式的消息并且从用户接收响应消息作为回报来与用户交互。

用于实现机器学习模型的数据处理装置还可以包括例如专用硬件加速器单元,用于处理机器学习训练或生产的公共和计算密集部分,即推断、工作负载。

可以使用机器学习框架(例如,tensorflow框架、microsoft认知工具包框架、apachesinga框架或apachemxnet框架)来实现和部署机器学习模型,

本说明书中描述的主题的实施例可以实现在计算系统中,该计算系统包括诸如作为数据服务器的后端组件,或者包括诸如应用服务器的中间件组件,或者包括诸如具有图形用户界面、web浏览器或app的客户端计算机的前端组件,或者包括一个或多个这样的后端、中间件或前端组件的任何组合,用户可以通过该图形用户界面、web浏览器或app与本说明书中描述的主题的实现交互。系统的组件可以通过任何形式或介质的数字数据通信例如通信网络互连。通信网络的示例包括局域网(“lan”)和广域网(“wan”),例如因特网。

计算系统可以包括客户端和服务器。客户端和服务器通常彼此远离并且通常通过通信网络交互。客户端和服务器的关系借助于在相应计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器将数据(例如,html页面)发送到用户设备,例如,用于向与作为客户端的设备交互的用户显示数据和从该用户接收用户输入的目的。可以在服务器处从设备接收在用户设备处生成的数据,例如,用户交互的结果。

虽然本说明书包含许多具体实施细节,但是这些不应被解释为对任何发明的范围或对所要求保护内容的范围的限制,而是作为可以对特定发明的特定实施例特定的特征的描述。在本说明书中在单独实施例的上下文中描述的某些特征也可以在单个实施例中组合实现。相反,在单个实施例的上下文中描述的各种特征也可以在多个实施例中单独地或以任何合适的子组合来实现。此外,虽然特征可以在上面描述为在某些组合中起作用并且甚至最初如此要求保护,但是来自所要求保护的组合的一个或多个特征在一些情况下可以从组合中删除,并且所要求保护的组合可以涉及子组合或子组合的变体。

类似地,虽然在附图中描绘了并且在权利要求中以特定顺序叙述了操作,但是这不应被理解为要求这些操作以所示的特定顺序或以顺序次序执行,或者所有所示的操作被执行,以实现期望的结果。在某些情况下,多任务和并行处理可能是有利的。此外,上述实施例中的各种系统模块和组件的分离不应被理解为在所有实施例中都需要这样的分离,并且应当理解,所描述的程序组件和系统通常可以一起集成在单个软件产品中,或者封装成多个软件产品。

因此,已经描述了主题的特定实施例。其他实施例在所附权利要求的范围内。例如,权利要求中所述的动作可以以不同的顺序执行并且仍然实现期望的结果。作为示例,附图中描绘的过程不一定需要所示的特定顺序或顺序的顺序以实现期望的结果。在某些情况下,多任务和并行处理可以是有利的。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1