English 中文(简体)
Design and Analysis of Algorithms

Selected Reading

Strassen’s Matrix Multiplication
  • 时间:2024-12-22

Strassen’s Matrix Multiplication


Previous Page Next Page  

Strassen’s Matrix Multiplication is a divide-and-conquer approach to the matrix multiplication problem. The usual matrix multiplication method multiplies each row with each column to obtain the product matrix. The time complexity of that approach is $O(n^3)$, since it takes three nested loops to multiply. Strassen’s method was introduced to reduce the time complexity from $O(n^3)$ to $O(n^{\log_2 7})$.

Naïve Method

First, we will discuss the naïve method and its complexity. Here, we are calculating Z = X × Y. Using the naïve method, two matrices (X and Y) can be multiplied if their orders are p × q and q × r; the resultant matrix will be of order p × r. The following pseudocode describes the naïve multiplication −


Algorithm: Matrix-Multiplication (X, Y, Z) 
for i = 1 to p do 
   for j = 1 to r do 
      Z[i,j] := 0 
      for k = 1 to q do 
         Z[i,j] := Z[i,j] + X[i,k] × Y[k,j] 

Complexity

Here, we assume that integer operations take O(1) time. There are three for loops in this algorithm, each nested inside the previous one, so it performs p · q · r scalar multiplications. Hence, for n × n matrices the algorithm takes $O(n^3)$ time to execute.

Strassen’s Matrix Multiplication Algorithm

In this context, using Strassen’s Matrix multiplication algorithm, the time consumption can be improved a little bit.

Strassen’s Matrix multiplication can be performed only on square matrices where n is a power of 2. The order of both matrices is n × n.

Divide X, Y and Z into four (n/2)×(n/2) matrices as represented below −

$Z = \begin{bmatrix}I & J \\K & L \end{bmatrix}$ $X = \begin{bmatrix}A & B \\C & D \end{bmatrix}$ and $Y = \begin{bmatrix}E & F \\G & H \end{bmatrix}$

Using Strassen’s Algorithm compute the following −

$$M_{1} \: \colon= (A+C) \times (E+F)$$

$$M_{2} \: \colon= (B+D) \times (G+H)$$

$$M_{3} \: \colon= (A-D) \times (E+H)$$

$$M_{4} \: \colon= A \times (F-H)$$

$$M_{5} \: \colon= (C+D) \times (E)$$

$$M_{6} \: \colon= (A+B) \times (H)$$

$$M_{7} \: \colon= D \times (G-E)$$

Then,

$$I \: \colon= M_{2} + M_{3} - M_{6} - M_{7}$$

$$J \: \colon= M_{4} + M_{6}$$

$$K \: \colon= M_{5} + M_{7}$$

$$L \: \colon= M_{1} - M_{3} - M_{4} - M_{5}$$

Analysis

$$T(n)=\begin{cases}c & \text{if } n = 1\\7 \times T(\frac{n}{2}) + d \times n^2 & \text{otherwise}\end{cases} \quad \text{where } c \text{ and } d \text{ are constants}$$

Using this recurrence relation, we get $T(n) = O(n^{\log_2 7})$

Hence, the complexity of Strassen’s matrix multiplication algorithm is $O(n^{\log_2 7}) \approx O(n^{2.81})$.

Example

Let us look at the implementation of Strassen’s Matrix Multiplication in various programming languages: C, C++, Java, Python.


#include<stdio.h>
/* Prints a 2x2 matrix: a newline before each row, a tab after each entry. */
static void print_matrix(int m[2][2]) {
   int i, j;
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", m[i][j]);
   }
}

/*
 * Stores the product of the 2x2 matrices x and y in z, using Strassen's
 * seven scalar products instead of the eight needed by the row-by-column
 * rule.
 */
static void strassen_2x2(int x[2][2], int y[2][2], int z[2][2]) {
   int m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   int m2 = (x[1][0] + x[1][1]) * y[0][0];
   int m3 = x[0][0] * (y[0][1] - y[1][1]);
   int m4 = x[1][1] * (y[1][0] - y[0][0]);
   int m5 = (x[0][0] + x[0][1]) * y[1][1];
   int m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
   int m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

   /* Recombine the seven products into the four entries of the result. */
   z[0][0] = m1 + m4 - m5 + m7;
   z[0][1] = m3 + m5;
   z[1][0] = m2 + m4;
   z[1][1] = m1 - m2 + m3 + m6;
}

/*
 * Demo driver: prints two fixed 2x2 matrices and their product computed
 * with Strassen's method. Always returns 0.
 */
int main(){
   int x[2][2] = {
       {12, 34},
       {22, 10}
   };
   int y[2][2] = {
       {3, 4},
       {2, 1}
   };
   int z[2][2];

   /* The escapes below must stay escapes: a C string literal cannot
      contain a raw line break. */
   printf("\nThe first matrix is\n");
   print_matrix(x);
   printf("\nThe second matrix is\n");
   print_matrix(y);

   strassen_2x2(x, y, z);

   printf("\nProduct achieved using Strassen s algorithm \n");
   print_matrix(z);
   return 0;
}


Output


The first matrix is

12	34	
22	10	
The second matrix is

3	4	
2	1	
Product achieved using Strassen s algorithm 

104	82	
86	98

#include<iostream>
using namespace std;
namespace {

// Prints a 2x2 matrix: a newline before each row, a space after each entry.
void print_matrix(const int (&m)[2][2]) {
   for (const auto& row : m) {
      std::cout << '\n';
      for (const int value : row)
         std::cout << value << ' ';
   }
}

// Stores the product of the 2x2 matrices x and y in z, using Strassen's
// seven scalar products instead of the eight needed by the row-by-column rule.
void strassen_2x2(const int (&x)[2][2], const int (&y)[2][2], int (&z)[2][2]) {
   const int m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   const int m2 = (x[1][0] + x[1][1]) * y[0][0];
   const int m3 = x[0][0] * (y[0][1] - y[1][1]);
   const int m4 = x[1][1] * (y[1][0] - y[0][0]);
   const int m5 = (x[0][0] + x[0][1]) * y[1][1];
   const int m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
   const int m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

   // Recombine the seven products into the four entries of the result.
   z[0][0] = m1 + m4 - m5 + m7;
   z[0][1] = m3 + m5;
   z[1][0] = m2 + m4;
   z[1][1] = m1 - m2 + m3 + m6;
}

}  // namespace

// Demo driver: prints two fixed 2x2 matrices and their product computed with
// Strassen's method. Always returns 0.
int main() {
   const int x[2][2] = {
      {12, 34},
      {22, 10}
   };
   const int y[2][2] = {
      {3, 4},
      {2, 1}
   };
   int z[2][2];

   // The escapes below must stay escapes: a string literal cannot contain a
   // raw line break.
   std::cout << "\nThe first matrix is\n";
   print_matrix(x);
   std::cout << "\nThe second matrix is\n";
   print_matrix(y);

   strassen_2x2(x, y, z);

   std::cout << "\nProduct achieved using Strassen s algorithm \n";
   print_matrix(z);
   return 0;
}

Output


The first matrix is

12 34 
22 10 
The second matrix is

3 4 
2 1 
Product achieved using Strassen s algorithm 

104 82 
86 98 

/**
 * Demonstrates Strassen's method on two fixed 2x2 integer matrices: prints
 * both operands and then their product.
 */
public class Strassens {
   /** Prints every row of the matrix on a fresh line, each entry followed by a space. */
   private static void printMatrix(int[][] matrix) {
      for (int[] row : matrix) {
         System.out.println(); // new line
         for (int value : row) {
            System.out.print(value + " ");
         }
      }
   }

   /**
    * Multiplies two 2x2 matrices using Strassen's seven scalar products
    * instead of the eight needed by the row-by-column rule.
    *
    * @param x left operand (2x2)
    * @param y right operand (2x2)
    * @return a new 2x2 array holding x * y
    */
   private static int[][] multiply(int[][] x, int[][] y) {
      int m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
      int m2 = (x[1][0] + x[1][1]) * y[0][0];
      int m3 = x[0][0] * (y[0][1] - y[1][1]);
      int m4 = x[1][1] * (y[1][0] - y[0][0]);
      int m5 = (x[0][0] + x[0][1]) * y[1][1];
      int m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]);
      int m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]);

      // Recombine the seven products into the four entries of the result.
      return new int[][] {
         {m1 + m4 - m5 + m7, m3 + m5},
         {m2 + m4, m1 - m2 + m3 + m6}
      };
   }

   /** Prints the two sample matrices followed by their Strassen product. */
   public static void main(String[] args) {
      int[][] x = {{12, 34}, {22, 10}};
      int[][] y = {{3, 4}, {2, 1}};

      System.out.print("The first matrix is: ");
      printMatrix(x);
      // The leading line break must be written as an escape: a Java string
      // literal cannot span lines.
      System.out.print("\nThe second matrix is: ");
      printMatrix(y);

      int[][] z = multiply(x, y);

      System.out.print("\nProduct achieved using Strassen s algorithm: ");
      printMatrix(z);
   }
}

Output


The first matrix is: 
12 34 
22 10 
The second matrix is: 
3 4 
2 1 
Product achieved using Strassen s algorithm: 
104 82 
86 98 

# Sample 4x4 operands multiplied by the demo at the end of the script.
a = [
   [1, 2, 3, 4],
   [2, 3, 4, 5],
   [3, 4, 5, 6],
   [4, 5, 6, 7],
]
b = [
   [5, 5, 5, 5],
   [6, 6, 6, 6],
   [7, 7, 7, 7],
   [8, 8, 8, 8],
]
def new_m(p, q):
   """Create a matrix filled with 0s.

   The result is a list of q independent row lists, each holding p zeros.
   """
   return [[0] * p for _ in range(q)]
def sppt(matrix): # split matrix into quarters
   """Split a matrix (list of row lists) into its four quadrants.

   Rows and columns are each cut at their midpoint, so the matrix does not
   have to be square. When a dimension is odd, the extra row or column goes
   to the lower or right-hand quadrants. The input matrix is not modified.

   Returns:
      A tuple (top_left, top_right, bottom_left, bottom_right) of new
      lists of lists. An empty matrix yields four empty lists.
   """
   row_mid = len(matrix) // 2
   col_mid = len(matrix[0]) // 2 if matrix else 0

   upper = matrix[:row_mid]
   lower = matrix[row_mid:]

   # Slice every row of each half, so the row count and the column count are
   # handled independently of one another.
   a = [row[:col_mid] for row in upper]
   b = [row[col_mid:] for row in upper]
   c = [row[:col_mid] for row in lower]
   d = [row[col_mid:] for row in lower]
   return a, b, c, d
def add_m(a, b):
   """Return a + b for two ints, or the element-wise sum of two matrices.

   For matrices, the shape of the result follows a: len(a) rows, each as
   wide as a's first row.
   """
   if type(a) == int:
      return a + b
   return [
      [a[i][j] + b[i][j] for j in range(len(a[0]))]
      for i in range(len(a))
   ]
def sub_m(a, b):
   """Return a - b for two ints, or the element-wise difference of two matrices.

   For matrices, the shape of the result follows a: len(a) rows, each as
   wide as a's first row.
   """
   if type(a) == int:
      return a - b
   return [
      [a[i][j] - b[i][j] for j in range(len(a[0]))]
      for i in range(len(a))
   ]
def strassen(a, b, q):
   """Multiply two q x q matrices with Strassen's divide-and-conquer scheme.

   Each operand is split into four quadrants; seven recursive products of
   half-size matrices are then combined into the quadrants of the result.

   Args:
      a: left operand, a list of q rows with q entries each.
      b: right operand, same shape as a.
      q: order of the matrices; must be a power of two.

   Returns:
      A new q x q list of lists holding the product a * b.

   Raises:
      ValueError: if a or b is not q x q, or q is not a power of two.
   """
   # Validate up front: a wrong order would otherwise fail deep inside the
   # recursion or silently produce a wrong product.
   n = len(a)
   if n != q or len(b) != n or n < 1 or n & (n - 1):
      raise ValueError("strassen needs two q x q matrices with q a power of two")
   if any(len(row) != n for row in a) or any(len(row) != n for row in b):
      raise ValueError("strassen needs two q x q matrices with q a power of two")

   # base case: 1x1 matrix
   if n == 1:
      return [[a[0][0] * b[0][0]]]

   # Integer halving keeps the order an int all the way down the recursion.
   half = n // 2

   # split matrices into quarters
   a11, a12, a21, a22 = sppt(a)
   b11, b12, b21, b22 = sppt(b)

   # p1 = (a11+a22) * (b11+b22)
   p1 = strassen(add_m(a11, a22), add_m(b11, b22), half)

   # p2 = (a21+a22) * b11
   p2 = strassen(add_m(a21, a22), b11, half)

   # p3 = a11 * (b12-b22)
   p3 = strassen(a11, sub_m(b12, b22), half)

   # p4 = a22 * (b21-b11)
   p4 = strassen(a22, sub_m(b21, b11), half)

   # p5 = (a11+a12) * b22
   p5 = strassen(add_m(a11, a12), b22, half)

   # p6 = (a21-a11) * (b11+b12)
   p6 = strassen(sub_m(a21, a11), add_m(b11, b12), half)

   # p7 = (a12-a22) * (b21+b22)
   p7 = strassen(sub_m(a12, a22), add_m(b21, b22), half)

   # c11 = p1 + p4 - p5 + p7
   c11 = add_m(sub_m(add_m(p1, p4), p5), p7)

   # c12 = p3 + p5
   c12 = add_m(p3, p5)

   # c21 = p2 + p4
   c21 = add_m(p2, p4)

   # c22 = p1 + p3 - p2 + p6
   c22 = add_m(sub_m(add_m(p1, p3), p2), p6)

   # Copy the four half-size quadrants into one n x n result.
   c = new_m(n, n)
   for i in range(half):
      for j in range(half):
         c[i][j] = c11[i][j]
         c[i][j + half] = c12[i][j]
         c[i + half][j] = c21[i][j]
         c[i + half][j + half] = c22[i][j]
   return c
      
# Demo: multiply the two 4x4 sample matrices and show the product.
print("Output Product:")
product = strassen(a, b, 4)
print(product)

Output


Output Product:
[[70, 70, 70, 70], [96, 96, 96, 96], [122, 122, 122, 122], [148, 148, 148, 148]]
Advertisements