本发明属于图像处理技术领域,特别涉及一种零样本草图检索方法,可用于电子商务、医疗诊断、遥感成像。
背景技术:
草图检索是指根据手绘草图检索真实的自然图像。零样本草图检索方法是一种对未知类别的手绘草图进行真实自然图像检索的方法。现有的草图检索方法主要分为两类:基于人工设计的特征和基于深度学习的方法。其中基于人工设计特征方法包括梯度场hog描述子、sift描述子,而基于深度学习的方法则包括孪生网络、三元组网络、深度草图哈希等,它们的主要思想都是提取图像或者文本信息的判别性特征,然后投影到共同的特征空间中进行相似性度量。但是现有的草图检索方法的前提是所有类别在训练阶段必须都是已知的,这样就无法保证训练数据的规模能够覆盖现实场景中的所有类别,所以当测试未见类别时,检索性能将急剧下降。同时不同的人对草图有不同的理解,导致绘制的草图的类内方差较大,草图检索的任务也更具挑战性。
零样本草图检索就是在零样本的设置下实现从已知类别到未见类别的视觉知识迁移,从而解决现有草图检索的问题。当前,研究人员已经提出两种零样本草图检索的方法,例如,yumingshen和liliu等人在2018年的computervisionandpatternrecognition会议上发表的名为“zero-shotsketch-imagehashing”的文章,公开了一种零样本草图哈希检索方法,该方法构建了一个端到端的三网络框架,其中前两个网络为二进制编码器,第三个网络利用克罗内克融合层和图卷积,减轻草图图像的异质性,增强数据间的语义关系,同时文章还提出了一种哈希生成方法,用于重建零样本检索的语义知识表示;sasikiranyelamarthi等人在2018年的europeanconferenceoncomputervision会议上发表的名为“azero-shotframeworkforsketch-basedimageretrieval”的文章,公开了一种基于对抗自动编码器和变分自动编码器的深度条件生成模型的方法,该方法将草图特征向量作为输入,使用生成模型随机填充缺失的信息来生成自然图像特征向量,然后利用这些生成的自然图像特征向量从数据库中检索图像。尽管上述方法取得了良好的性能,但这两种方法由于都没有考虑到草图类内方差较大的问题,因而通过预训练的卷积神经网络提取出的语义信息的判别能力较弱,且难以准确地将草图的视觉知识从已知类迁移到未见类。
技术实现要素:
本发明的目的在于克服上述已有技术存在的不足,提出一种基于语义对抗网络的零样本草图检索方法,以通过预训练的卷积神经网络提取出较好的判别性语义信息,准确地将草图的视觉知识从已知类迁移到未见类。
本发明的技术思路是,通过采用端到端的语义对抗网络中的语义对抗模块来学习草图的语义特征,降低了草图特征的类内方差;通过在生成模块中加入三元组损失,保证每个类别中生成的rgb图像特征的可判别性,从而解决了零样本设置下视觉知识难以从已知类迁移到未见类的问题。
根据上述思路,本发明的实现步骤包括如下:
(1)获取训练样本集:
(1a)从sketchy草图检索数据库中分别提取10,400幅rgb图像和对应的10,400幅二值草图图像组成成对的第一训练样本;从tu-berlin草图检索数据库中分别提取138,839幅rgb图像和138,839幅对应类别的二值草图图像组成成对的第二训练样本;
(1b)对提取的所有298,478张图片都进行随机水平翻转,得到298,478张随机水平翻转后的图像;
(1c)对298,478张随机水平翻转后的图像重新调整大小至224×224,并将得到的298,478张图像分别组成包含第一训练样本的训练样本集s1和包含第二训练样本的训练样本集s2;
(2)构建语义对抗网络:
设置由语义特征提取网络、词嵌入网络、语义判别器组成语义对抗网络,其中:
语义特征提取网络,用于提取二值草图图像的语义特征;
词嵌入网络,用于提取二值草图图像所对应的类别信息的词向量;
语义判别器,用于对提取出的草图图像的语义特征和对应类标的词向量进行对抗学习,通过一个对抗损失ladv(θs,θd)来更新语义特征提取网络的参数,提升输出草图图像语义特征的判别性;
语义对抗网络中的语义特征提取网络和词嵌入网络的输出都输入到语义判别器中进行对抗学习;
(3)对训练样本集中的rgb图像进行特征提取:
(3a)使用在imagenet数据集上预训练的vgg16网络对第一训练样本集中的rgb图像进行特征提取,选取该网络中第二全连接层的输出作为第一训练样本集最终的rgb图像特征,该图像特征的维度为4096;
(3b)使用在imagenet数据集上预训练的vgg16网络对第二训练样本集中的rgb图像进行特征提取,选取该网络中第二全连接层的输出作为第二训练样本集最终的rgb图像特征,该图像特征的维度为4096;
(4)构建生成网络:
构建依次由concatenate层、条件编码器、三元组损失层、kl损失层、解码器、图像重建损失层、回归器和语义重建损失层组成的生成网络,其中:
concatenate层,用于对语义特征提取网络的输出草图语义特征向量xsem和rgb图像特征向量ximg进行维度上的拼接;
条件编码器,用于将concatenate层输出作为输入,使数据分布p(ximg,xsem)通过条件编码器后得到隐藏潜在变量z的先验分布p(z),计算p(z)先验分布的均值向量μ和标准差向量σ;
三元组损失层,用于保持每个训练类别内生成特征的判别性,将条件编码器的均值向量输出μ作为输入,使用三元组损失函数对编码器进行训练,该损失层的损失函数为ltri;
kl损失层,用于使得数据分布p(ximg,xsem)与变分分布q(z|ximg,xsem)近似,然后通过对损失函数lkl的最小化确定变分下界;
解码器,用于将维度为1024的潜在向量z和学习得到维度为300的语义特征xsem进行拼接作为输入,以生成草图图像对应的rgb图像特征
其中,noise表示随机高斯噪声z~n(0,1),噪声维度为1024,
图像重建损失层,用于保证生成的rgb图像特征具有足够的判别性,使用重建损失函数:
回归器,用于将解码器的输出
其中,noise表示随机高斯噪声z~n(0,1),噪声维度为1024,
语义重建损失层,用于保证生成的rgb图像特征
(5)对语义对抗网络和生成网络进行训练:
(5a)对语义对抗网络和生成网络进行初始化,随机初始化时采用的网络参数服从均值为0、标准差为0.1的高斯分布,得到初始化的语义对抗网络和生成网络;
(5b)设整体网络的损失函数为l=ladv+ltri+lkl+lrecon_img+lrecon_sem;
(5c)将经过步骤1预处理后的草图图像及其对应的类别信息作为初始化的语义对抗网络的输入数据,输出草图对应的语义特征,将草图对应的语义特征和使用预训练的vgg16网络提取出的rgb图像特征作为生成网络的输入数据,通过对损失函数l的最小化实现对语义对抗网络和生成网络的训练,得到训练好的语义对抗网络和生成网络;
(6)对待检索的草图图像进行零样本草图检索:
(6a)从与训练样本集类别不相交的测试样本集中提取草图图像,对草图图像进行裁剪后得到待检索的草图图像;
(6b)将待检索的草图图像输入到训练好的语义特征提取网络中,输出草图图像对应的语义特征向量;
(6c)将语义特征向量和随机高斯噪声进行拼接输入到训练好的生成网络中,经过编码器和解码器生成多个草图对应的rgb图像特征;
(6d)取多个生成的rgb图像特征的平均值作为最终rgb图像特征,再根据余弦距离在图像检索库中寻找与生成的最终rgb图像特征最相似的前200张图像。
本发明与现有技术相比,具有如下优点:
本发明在训练阶段借助类别级语义信息的优点,采用端到端的语义对抗网络中的语义对抗模块来学习草图的语义特征,从而降低了草图图像特征的类内方差;并且在生成网络中加入三元组损失,保证每个类别中生成的rgb图像特征的可判别性,从而解决了零样本设置下视觉知识难以从已知类迁移到未见类的问题。
与现有技术相比,本发明简化了训练过程并有效提高了零样本草图检索的检索性能。
附图说明
图1是本发明的实现流程图;
图2是本发明与现有方法的检索结果对比图。
具体实施方案
以下结合附图和具体实施,对本发明作进一步详细描述:
参照图1,本发明基于语义对抗网络的零样本草图检索方法,其实现步骤包括如下:
步骤1,获取训练样本集。
1.1)从sketchy草图检索数据库中分别提取10,400幅rgb图像和对应的10,400幅二值草图图像组成成对的第一训练样本;从tu-berlin草图检索数据库中分别提取138,839幅rgb图像和138,839幅对应类别的二值草图图像组成成对的第二训练样本;
1.2)对提取的所有298,478张图片都进行随机水平翻转,得到298,478张随机水平翻转后的图像;
1.3)对298,478张随机水平翻转后的图像重新调整大小至224×224,并将得到的298,478张图像分别组成包含第一训练样本的训练样本集s1和包含第二训练样本的训练样本集s2:
其中,
步骤2,构建语义对抗网络。
设置由语义特征提取网络、词嵌入网络、语义判别器组成的语义对抗网络,其中:
语义特征提取网络,用于提取二值草图图像的语义特征,具体是在imagenet上预训练的vgg16网络,选取vgg16网络的第五卷积层作为卷积输出,通过一个全连接层输出维度为300的语义特征向量;
词嵌入网络,用于提取二值草图图像所对应的类别信息的词向量,采用在维基百科上预训练的词向量模型,以获取维度为300的类别级词向量表示;
语义判别器,用于对提取出的草图图像的语义特征和对应类标的词向量进行对抗学习,通过一个对抗损失函数来更新语义特征提取网络的参数,提升输出草图图像语义特征的判别性,该损失函数ladv(θs,θd)的数学表达式为:
其中,
语义对抗网络中的语义特征提取网络和词嵌入网络的输出都输入到语义判别器中进行对抗学习。
步骤3,对训练样本集中的rgb图像进行特征提取。
3.1)使用在imagenet数据集上预训练的vgg16网络对第一训练样本集中的rgb图像进行特征提取,选取该网络中第二全连接层的输出作为第一训练样本集最终的rgb图像特征,该图像特征的维度为4096;
3.2)使用在imagenet数据集上预训练的vgg16网络对第二训练样本集中的rgb图像进行特征提取,选取该网络中第二全连接层的输出作为第二训练样本集最终的rgb图像特征,该图像特征的维度为4096。
步骤4,构建生成网络。
构建依次由concatenate层、条件编码器、三元组损失层、kl损失层、解码器、图像重建损失层、回归器和语义重建损失层组成的生成网络,其中:
所述concatenate层,用于对语义特征提取网络的输出维度为300的草图语义特征向量xsem和维度为4096的rgb图像特征向量ximg进行维度上的拼接,输出维度为4396的特征向量;
所述条件编码器,由依次为输入维度为4396,输出维度为4096的第一全连接层、非线性激活层relu、动量参数为0.99和eps=1e-3的一维批规范化层、失活率为0.3的dropout层、输出维度为2048的第二全连接层、非线性激活层relu、动量参数为0.99和eps=1e-3的一维批规范化层组成,用于将concatenate层输出作为输入,使数据分布p(ximg,xsem)通过条件编码器后得到均值向量μ和标准差向量σ,形成隐藏潜在变量z的先验分布p(z);
所述三元组损失层,用于保持每个训练类别内生成特征的判别性,将条件编码器的均值向量输出μ作为输入,使用三元组损失函数对编码器进行训练,该三元组损失函数ltri的数学表达式为:
其中,d(·,·)表示
所述kl损失层,用于使得数据分布p(ximg,xsem)与变分分布q(z|ximg,xsem)近似,然后通过对损失函数lkl的最小化确定变分下界,lkl的数学表达式为:
其中,
所述解码器,由依次为输入维度为1324,输出维度为4096的第一全连接层、非线性激活层relu、输出维度为4096的第二全连接层和非线性激活层relu组成,用于将维度为1024的潜在向量z和学习得到维度为300的语义特征xsem进行拼接作为输入,以生成草图图像对应的rgb图像特征
其中,noise表示随机高斯噪声z~n(0,1),噪声维度为1024,
所述图像重建损失层,用于保证生成的rgb图像特征具有足够的判别性,使用重建损失函数:
所述回归器,由依次为输入维度为4096,输出维度为2048的第一全连接层、非线性激活层relu、输出维度为300的第二全连接层和非线性激活层tanh组成,用于将解码器的输出
其中,noise表示随机高斯噪声z~n(0,1),噪声维度为1024,
所述语义重建损失层,用于保证生成的rgb图像特征能保存类别级语义信息,该层损失函数为:
步骤5,对语义对抗网络和生成网络进行训练。
5.1)对语义对抗网络和生成网络进行初始化,随机初始化时采用的网络参数服从均值为0、标准差为0.1的高斯分布,得到初始化的语义对抗网络和生成网络;
5.2)设整体网络的损失函数为:l=ladv+ltri+lkl+lrecon_img+lrecon_sem;
5.3)将经过步骤1预处理后的草图图像及其对应的类别信息作为初始化的语义对抗网络的输入数据,输出草图对应的语义特征,将草图对应的语义特征和使用预训练的vgg16网络提取出的rgb图像特征作为生成网络的输入数据,通过对损失函数l的最小化实现对语义对抗网络和生成网络的训练,且在训练网络时采用深度学习工具箱pytorch中的adam优化器,其初始学习率为0.0001,β1=0.5,β2=0.99,同时为了训练的稳定性,在前2次训练中交替训练语义对抗网络和生成网络,在之后的18次训练中以端到端的方式训练整个网络,总共训练20次,得到训练好的语义对抗网络和生成网络。
步骤6,对待检索的草图图像进行零样本草图检索。
6.1)从与训练样本集类别不相交的测试样本集中提取草图图像,对草图图像进行裁剪后得到待检索的草图图像;
6.2)将待检索的草图图像输入到训练好的语义特征提取网络中,输出草图图像对应的语义特征向量;
6.3)将语义特征向量和随机高斯噪声进行拼接输入到训练好的生成网络中,经过编码器和解码器生成多个草图对应的rgb图像特征;
6.4)取多个生成的rgb图像特征的平均值作为最终rgb图像特征,再根据余弦距离在图像检索库中寻找与生成的最终rgb图像特征最相似的前200张图像,最后根据这200张检索图像计算检索精度。
以下结合仿真实验,对本发明的技术效果作进一步说明。
1.仿真条件:
本发明使用型号为nvidiagtxtitanv的gpu,基于深度学习的工具箱pytorch进行仿真实验。
2.仿真内容:
本发明在两个公开的专门用于草图检索方法性能测试的数据集sketchy、tu-berlin上进行仿真实验,其中:
数据集sketchy包含来自125个不同类别的75,479张草图图像和73,002张rgb图像,根据标准零样本学习的实验设置,将125个类别中的104个训练类作为已知类,21个测试类作为未见类;
数据集tu-berlin包含来自250个不同类别的20,000张草图图像和204,070张rgb图像,根据标准零样本学习的实验设置,将250个类别中的194个训练类作为已知类,56个测试类作为未见类。
用本发明和现有基于深度卷积神经网络的草图检索方法、零样本学习方法,在上述两个公开数据集sketchy和tu-berlin上进行仿真对比实验,结果如表1。
表1
表1中的精度@200和map@200分别为前200张检索图像的精度和平均精度均值。
由表1的仿真结果可见,本发明在两个数据集上的精度和平均精度均值都高于现有技术在两个数据集上的精度和平均精度均值。
在sketchy数据集上,对本发明和现有最好cvae方法的检索结果进行可视化,在检索的前200张图片中取前10张进行对比,结果如图2所示。
由图2可见,对于3种不同测试类别的草图图片进行检索,本发明的前10张检索图片与草图图片均属于同一类别,而cvae方法的检索结果中有检索出错的图片出现。