矩阵链乘法的最优结合法

我们定义矩阵链为:$A_1A_2A_3···A_n$

我们可以用一对封闭的小括号表示矩阵链的乘积

e.g.:$A_1A_2A_3A_4$的矩阵链结合方式共有5种

矩阵乘法

在处理矩阵链之前我们必须了解两个矩阵相乘需要执行多少次运算以及结合方式对其造成的影响

1
2
3
4
5
6
7
8
9
10
11
MATRIX-MULTIPLICATION(A,B)
if A.cols ≠ B.rows
then error "incompatible dimension"
else let C be a new A.rows×B.cols matrix
for i ← 1 to A.rows
for j ← 1 to B.cols
C[i,j] ← 0
for k ← 1 to A.cols
C[i,j] ← C[i,j]+A[i,k]*B[j,k]

return C

假如A为p×q,B为q×r,那么算法复杂度为O(pqr)

如果$< A_1,A_2,A_3>=\{(10 \times 100),(100 \times5),(5 \times 50)\}$

$((A_1A_2)A_3)=7500$而$(A_1(A_2A_3))=75000$

后者是前者的10倍!因此结合的方式决定了矩阵链乘法的复杂度

那么我们必须知道结合方式的数量怎么算,思想很简单,因为不管多么长的矩阵链最后终究是两个子积相乘,这两个子积之间必然有一个小括号:"("")"作为分隔符(split),设该位置为$k(k \in (0,n))$,P(n)是长为n的矩阵链的结合方式数量。以该位置分隔,左边k个矩阵的结合方式为P(k),右边n-k个矩阵的结合方式为P(n-k),将它们加起来就是所有分隔的结合方式即总结合方式:

上面的$P(4)=P(1)P(3)+P(2)P(2)+P(3)P(1)=2+1+2=5$

利用这个思想对问题进行分隔,会产生n-1个子问题

最优子结构

设$A_i=p_{i-1} \times p_i$,$A_{i…j}$表示$A_i…A_j$的积

假设$i \ne j$,我们用括号包住$A_i$和$A_j$,表示是一个乘积,然后我们必须在k位置切割得到两个子积:$A_{i…k}$和$A_{k+1…j}$,$k \in [i,j)$,这两个子积分别的执行次数加上子积运算的执行次数就是$A_{i…j}$的执行次数,这是其中一个切割方案,比较所有方案取其最小值就是欲求结果

而$A_{i…k}$和$A_{k+1…j}$的乘积执行的运算次数为$p_{i-1}p_kp_j$

令$DP[i,j]$表示$A_i···A_j$的矩阵链的总乘积次数

那么

然后动态规划优化重叠子问题即可

我们还要准备一个table记录k的位置,通过回溯法构造最优解(括号分布的最优方案)

对于一个乘积我们需要用括号包起来表示,因此构造最优解十分简单

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
MATRIX-CHAIN-MULTIPLICATION(p)
n ← p.length-1
let t[1...n,1...n],s[1...n-1,2...n] be new tables

//初始化
for i ← 1 to n
t[i,i] ← 0

for l ← 2 to n //l 表示矩阵链的长度
for i ← 1 to n-l+1 //i 表示起始位置
j ← l+i-1 //j 表示末尾
t[i,j] ← ∞
for k ← i to j-1
q ← t[i,k]+t[k+1,j]+p[i-1]p[k][j]
if q < t[i,j]
then t[i,j] ← q
s[i,j] ← k

return t and s

PRINT-OPTIMAL-PARENTS(s,i,j)
if i = j //相等表示无法进一步分割
then print "Ai"
else print "("
PRINT-OPTIMAL-PARENTS(s,i,s[i,j]) //分割点左边
PRINT-OPTIMAL-PARENTS(s,s[i.j]+1,j) //分割点右边
print ")"

显然,时间复杂度为$O(n^3)$

C++ Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
template<size_t N>
void print_matrix_chain(int s[N][N], int i, int j) {
if (i == j)
cout << "A" << i;
else {
cout << "(";
print_matrix_chain(s, i, s[i][j]);
print_matrix_chain(s, s[i][j] + 1, j);
cout << ")";
}
}

template<size_t N>
long matrix_chain(int(&p)[N]) {
constexpr int n = N - 1;
int m[n + 1][n + 1];
int s[n][n];
for (int i = 1; i <= n; ++i)
m[i][i] = 0;

for (int l = 2; l <= n; ++l)
for (int i = 1; i <= n - l + 1; ++i) {
int q = 0;
int j = l + i - 1;
m[i][j] = INT_MAX;
for (int k = i; k < j; ++k) {
q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
if (q < m[i][j]) {
m[i][j] = q;
s[i][j] = k;
}
}
}

print_matrix_chain(s, 1, n);
return m[1][n];
}