本发明涉及spark平台技术领域,更具体地,涉及一种稳定的基于spark平台的矩阵求逆算法。
背景技术:
由于spark平台在2012年被提出,主要集中在大数据处理领域,尚不支持大型矩阵的求逆运算,而在spark平台上的大型矩阵求逆算法的研究也非常少,只有2016年liu等人提出的基于递归lu分解的算法以及2018年misra等人提出的基于strassen的递归求逆算法(spin),本发明基于后者,即是spin。
strassen算法由strassen发表于1969年,它应用于矩阵乘法的方法因为把矩阵乘法的复杂度从o(n3)降低到o(n2.8)而广为人知,strassen同时提出的求逆算法则没有那么高的知名度。类似于strassen矩阵乘法,对于
有
那么c矩阵可以通过以下步骤计算得出:
p2=a21×p1
p3=p1×a12
p4=a21×p3
p5=p4-a22
c12=p3×p6
c21=p6×p2
c11=p1-p3×c21
c22=-p6
对子矩阵a11递归使用该算法(spin),到达递归的顶点时,使用已有求逆算法(如lu分解)求解矩阵的逆,执行所有步骤后即可得到矩阵c。strassen求逆算法的伪代码如图1所示,图2给出了spark版本的strassen矩阵求逆算法伪代码。
spark版本的strassen求逆算法,主要涉及到六个函数:
■breakmat:把矩阵分成4个块
■xy:即图2中的_11、_12、_21、_22等四个函数,作用是获取对应位置的子矩阵块
■multiply:矩阵乘法,使用spark自带的矩阵乘法
■subtract:矩阵减法,使用spark自带的矩阵减法
■scalarmul:矩阵乘以一个常数
■arrange:把四个矩阵块合并成一个大矩阵
然而,上述算法是可能导致数值稳定性出问题的。由于传统strassen算法,是递归的对单个子矩阵(a11)做求逆操作,而其余子矩阵(a12,a21,a22)均由其他矩阵运算得出,所以子矩阵a11的状态尤为关键。一般来说,子矩阵a11的状态可以总结为如下3点:
1.当子矩阵a11是良态(well-conditioned)的时候,数值稳定性较好
2.当子矩阵a11是近奇异(near-singular)的时候,数值稳定性较差
3.当子矩阵a11是奇异(sigular)的时候,strassen求逆算法将会失败
情况2和3的矩阵也叫做病态(ill-conditioned)矩阵。可见,只有当子矩阵是良态的时候,才能得到较好的数值稳定性,而misra实现的方法并没有对strassen矩阵求逆算法的数值稳定性不好的问题进行优化。
技术实现要素:
本发明对misra实现的spark平台上的strassen矩阵求逆算法进行改进,混合使用传统算法以及旋转算法,通过结合矩阵条件数判断矩阵的态选择使用原始求逆算法还是旋转求逆算法,使strassen矩阵求逆算法的数值稳定性得到提升。
为实现以上发明目的,采用的技术方案是:
一种稳定的基于spark平台的矩阵求逆算法,包括以下步骤:
s1.初始化算法参数;
s2.判断矩阵的大小是否达到本地求逆的大小,若是的话使用本地求逆算法得到矩阵的逆,然后返回矩阵的逆;否则进行步骤s3;
s3.如果矩阵的大小未达到本地求逆的大小,则把输入矩阵等分成4块a11,a12,a21,a22;然后计算并比较a11,a21的矩阵条件数;如果a11条件数较小,那么对a11进行递归求逆,然后进行步骤s4,否则,对a21递归进行求逆,进行步骤s5;
s4.根据传统strassen求逆步骤计算中间矩阵p2~p6,这里对p6的求逆仍然需要递归求逆,然后进行步骤s6;
s5.根据旋转strassen求逆步骤计算中间矩阵p2~p6,这里对p6的求逆仍然需要递归求逆,然后进行步骤6;
旋转strassen求逆步骤具体如下:
p2=a11×p1
p3=p1×a22
p4=a11×p3
p5=p4-a12
c12=p3×p6
c21=p6×p2
c11=p1-p3×c21
c22=-p6
s6.计算c11,c12,c21,c22,然后合并为输入矩阵的逆矩阵c并返回逆矩阵c。
优选地,所述步骤s3计算a11,a21的矩阵条件数的具体过程如下:先把矩阵a11、a21转换为spark平台中的indexedrowmatrix,然后使用自带的算法计算奇异值分解,然后取奇异值的最大值除以奇异值的最小值,得到矩阵条件数。
与现有技术相比,本发明的有益效果是:
计算机计算是浮点计算,常常会有精度损失问题,计算复杂的时候累计损失会过大,从而影响计算结果,使得结果与真实值发生较大偏差,因此好的数值稳定性对于计算机的计算是非常重要的。misra实现的spin算法对求大规模矩阵逆有良好的性能以及可拓展性也较好,而本发明在此基础上改进,使其可以在spark平台上求得数值稳定性更好的逆矩阵。
附图说明
图1为strassen矩阵求逆算法伪代码的示意图。
图2为spark版本的strassen矩阵求逆算法伪代码的示意图。
图3为本发明提供的矩阵求逆算法的伪代码的示意图。
具体实施方式
附图仅用于示例性说明,不能理解为对本专利的限制;
以下结合附图和实施例对本发明做进一步的阐述。
实施例1
本发明提供了一种稳定的基于spark平台的矩阵求逆算法,如图3所示,其具体包括以下步骤:
s1.初始化算法参数;
s2.判断矩阵的大小是否达到本地求逆的大小,若是的话使用本地求逆算法得到矩阵的逆,然后返回矩阵的逆;否则进行步骤s3;
s3.如果矩阵的大小未达到本地求逆的大小,则把输入矩阵等分成4块a11,a12,a21,a22;然后计算并比较a11,a21的矩阵条件数;如果a11条件数较小,那么对a11进行递归求逆,然后进行步骤s4,否则,对a21递归进行求逆,进行步骤s5;
s4.根据传统strassen求逆步骤计算中间矩阵p2~p6,这里对p6的求逆仍然需要递归求逆,然后进行步骤s6;
s5.根据旋转strassen求逆步骤计算中间矩阵p2~p6,这里对p6的求逆仍然需要递归求逆,然后进行步骤6;
旋转strassen求逆步骤具体如下:
p2=a11×p1
p3=p1×a22
p4=a11×p3
p5=p4-a12
c12=p3×p6
c21=p6×p2
c11=p1-p3×c21
c22=-p6
s6.计算c11,c12,c21,c22,然后合并为输入矩阵的逆矩阵c并返回逆矩阵c。
通过计算矩阵的条件数来衡量一个矩阵是否良态,每次递归都分别计算左上角和左下角矩阵的条件数,选择条件数小的为递归子矩阵。这里使用spark本地的奇异值分解算法来计算,伪代码如下:
先把矩阵转换为spark中的indexedrowmatrix,然后使用自带的算法计算奇异值分解,然后取奇异值的最大值除以奇异值的最小值,得到矩阵条件数。
显然,本发明的上述实施例仅仅是为清楚地说明本发明所作的举例,而并非是对本发明的实施方式的限定。对于所属领域的普通技术人员来说,在上述说明的基础上还可以做出其它不同形式的变化或变动。这里无需也无法对所有的实施方式予以穷举。凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,均应包含在本发明权利要求的保护范围之内。