// Code:
//بسم الله الرحمن الرحیم ("In the name of God, the Most Gracious, the Most Merciful")
//Strassen Algorithm Implementation in C++
//Release date: May 5, 2010
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <vector>
#include <windows.h>
using namespace std;
// Multiplies the n x n matrices MatrixA and MatrixB using Strassen's recursive
// divide-and-conquer scheme and writes the product into MatrixC.
void Strassen(int n, int** MatrixA, int ** MatrixB, int ** MatrixC);
// Element-wise sum over a length x length region: MatrixResult = MatrixA + MatrixB.
void ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int length );
// Element-wise difference over a length x length region: MatrixResult = MatrixA - MatrixB.
void SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int length );
// Classical triple-loop product of length x length matrices: MatrixResult = MatrixA * MatrixB.
void MUL(int** MatrixA, int** MatrixB, int** MatrixResult, int length );
// Fills matrix1 with pseudo-random values from rand() and copies each value into matrix2.
void FillMatrix( int** matrix1, int** matrix2, int length);
// Benchmark driver. Reads a matrix size and a threshold from stdin, fills two
// identical pseudo-random size x size matrices via FillMatrix, then times the
// classical triple-loop product (Phase I) against Strassen() (Phase II) and
// prints the raw clock() stamps plus each phase's duration in seconds.
// The threshold value is read but not used anywhere in this program.
int main()
{
    std::srand(static_cast<unsigned int>(std::time(nullptr)));
    std::cout<<"In the name of GOD"<<std::endl;
    std::cout<<"\nPlease Enter your Matrix Size: ";
    int MatrixSize = 0;
    std::cin>>MatrixSize;

    // Strassen() halves the size at every recursion level down to a 2x2 base
    // case, so only accept powers of two >= 2. This also rejects 0, negative
    // values (which would make new[] throw) and non-numeric input (which
    // leaves MatrixSize at 0).
    if (MatrixSize < 2 || (MatrixSize & (MatrixSize - 1)) != 0)
    {
        std::cerr<<"Matrix size must be a power of two (2, 4, 8, ...)."<<std::endl;
        return 1;
    }

    std::cout<<"Please Enter your Threshhold: ";
    int n = 0;  // threshold; not used by either multiplication below
    std::cin>>n;

    int** MatrixA = new int *[MatrixSize];
    int** MatrixB = new int *[MatrixSize];
    int** MatrixC = new int *[MatrixSize];
    for (int i = 0; i < MatrixSize; i++)
    {
        MatrixA[i] = new int [MatrixSize];
        MatrixB[i] = new int [MatrixSize];
        MatrixC[i] = new int [MatrixSize];
    }
    FillMatrix(MatrixA,MatrixB,MatrixSize);

    // Phase I: classical O(n^3) multiplication into MatrixC.
    const std::clock_t normalStart = std::clock();
    std::cout<<"Phase I started: "<<normalStart;
    for (int i = 0; i < MatrixSize; i++)
    {
        for (int j = 0; j < MatrixSize; j++)
        {
            MatrixC[i][j] = 0;
            for (int k = 0; k < MatrixSize; k++)
            {
                MatrixC[i][j] = MatrixC[i][j] + MatrixA[i][k] * MatrixB[k][j];
            }
        }
    }
    const std::clock_t normalEnd = std::clock();
    std::cout<<"\nPhase I ended: "<<normalEnd;

    // Phase II: Strassen multiplication (overwrites MatrixC).
    const std::clock_t strassenStart = std::clock();
    std::cout<<"\nPhase II started: "<<strassenStart;
    Strassen(MatrixSize,MatrixA,MatrixB,MatrixC);
    const std::clock_t strassenEnd = std::clock();
    std::cout<<"\nPhase II ended: "<<strassenEnd;

    // Convert to double before dividing: integer division by CLOCKS_PER_SEC
    // would truncate every duration to whole seconds.
    const double normalSeconds =
        static_cast<double>(normalEnd - normalStart) / CLOCKS_PER_SEC;
    const double strassenSeconds =
        static_cast<double>(strassenEnd - strassenStart) / CLOCKS_PER_SEC;
    std::cout<<"\nStats:\n";
    std::cout<<"Normal mode "<<normalSeconds<<" Sec";
    std::cout<<"\nStrassen mode "<<strassenSeconds<<" Sec";

    // Release everything allocated above.
    for (int i = 0; i < MatrixSize; i++)
    {
        delete[] MatrixA[i];
        delete[] MatrixB[i];
        delete[] MatrixC[i];
    }
    delete[] MatrixA;
    delete[] MatrixB;
    delete[] MatrixC;

    std::system("Pause");
    return 0;
}
namespace
{
// Owns the storage for one size x size int matrix (a single contiguous buffer)
// and exposes it through the int** row-pointer form taken by ADD/SUB/MUL and
// Strassen. The memory is released automatically when the object goes out of
// scope, so temporary matrices can no longer leak.
class SquareMatrix
{
public:
    explicit SquareMatrix(int size)
        : storage_(static_cast<std::size_t>(size) * static_cast<std::size_t>(size)),
          rows_(static_cast<std::size_t>(size))
    {
        for (int i = 0; i < size; i++)
        {
            rows_[i] = storage_.data()
                     + static_cast<std::size_t>(i) * static_cast<std::size_t>(size);
        }
    }

    // rows_ points into storage_, so a copy would alias the source's buffer.
    SquareMatrix(const SquareMatrix&) = delete;
    SquareMatrix& operator=(const SquareMatrix&) = delete;

    int** get() { return rows_.data(); }
    int* operator[](int row) { return rows_[row]; }

private:
    std::vector<int> storage_;
    std::vector<int*> rows_;
};
} // namespace

// Computes MatrixC = MatrixA * MatrixB for N x N matrices with Strassen's
// algorithm. An even-sized problem is split into four HalfSize x HalfSize
// quadrants and solved with seven recursive half-size products (M1..M7)
// instead of eight.
// Sizes that cannot be split into equal quadrants (N <= 2 or odd N) are
// multiplied classically with MUL, so every N >= 0 terminates correctly.
// Powers of two use the Strassen step at every level above 2x2.
void Strassen(int N, int** MatrixA, int** MatrixB, int** MatrixC)
{
    // Base case. Previously only N == 2 stopped the recursion, so N == 1, N == 0
    // or any odd size recursed forever with HalfSize shrinking to 0.
    if (N <= 2 || N % 2 != 0)
    {
        MUL(MatrixA, MatrixB, MatrixC, N);
        return;
    }

    const int HalfSize = N / 2;

    // Quadrants of the inputs (only HalfSize x HalfSize is needed).
    SquareMatrix A11(HalfSize), A12(HalfSize), A21(HalfSize), A22(HalfSize);
    SquareMatrix B11(HalfSize), B12(HalfSize), B21(HalfSize), B22(HalfSize);
    // The seven Strassen products.
    SquareMatrix M1(HalfSize), M2(HalfSize), M3(HalfSize), M4(HalfSize),
                 M5(HalfSize), M6(HalfSize), M7(HalfSize);
    // Scratch space for the left/right operand sums of each product.
    SquareMatrix AResult(HalfSize), BResult(HalfSize);

    // Split A and B into quadrants.
    for (int i = 0; i < HalfSize; i++)
    {
        for (int j = 0; j < HalfSize; j++)
        {
            A11[i][j] = MatrixA[i][j];
            A12[i][j] = MatrixA[i][j + HalfSize];
            A21[i][j] = MatrixA[i + HalfSize][j];
            A22[i][j] = MatrixA[i + HalfSize][j + HalfSize];
            B11[i][j] = MatrixB[i][j];
            B12[i][j] = MatrixB[i][j + HalfSize];
            B21[i][j] = MatrixB[i + HalfSize][j];
            B22[i][j] = MatrixB[i + HalfSize][j + HalfSize];
        }
    }

    // M1 = (A11 + A22)(B11 + B22)
    ADD(A11.get(), A22.get(), AResult.get(), HalfSize);
    ADD(B11.get(), B22.get(), BResult.get(), HalfSize);
    Strassen(HalfSize, AResult.get(), BResult.get(), M1.get());
    // M2 = (A21 + A22) B11
    ADD(A21.get(), A22.get(), AResult.get(), HalfSize);
    Strassen(HalfSize, AResult.get(), B11.get(), M2.get());
    // M3 = A11 (B12 - B22)
    SUB(B12.get(), B22.get(), BResult.get(), HalfSize);
    Strassen(HalfSize, A11.get(), BResult.get(), M3.get());
    // M4 = A22 (B21 - B11)
    SUB(B21.get(), B11.get(), BResult.get(), HalfSize);
    Strassen(HalfSize, A22.get(), BResult.get(), M4.get());
    // M5 = (A11 + A12) B22
    ADD(A11.get(), A12.get(), AResult.get(), HalfSize);
    Strassen(HalfSize, AResult.get(), B22.get(), M5.get());
    // M6 = (A21 - A11)(B11 + B12)
    SUB(A21.get(), A11.get(), AResult.get(), HalfSize);
    ADD(B11.get(), B12.get(), BResult.get(), HalfSize);
    Strassen(HalfSize, AResult.get(), BResult.get(), M6.get());
    // M7 = (A12 - A22)(B21 + B22)
    SUB(A12.get(), A22.get(), AResult.get(), HalfSize);
    ADD(B21.get(), B22.get(), BResult.get(), HalfSize);
    Strassen(HalfSize, AResult.get(), BResult.get(), M7.get());

    // Assemble the quadrants of C directly into MatrixC:
    //   C11 = (M1 + M4) + (M7 - M5)    C12 = M3 + M5
    //   C21 = M2 + M4                  C22 = (M1 + M3) + (M6 - M2)
    for (int i = 0; i < HalfSize; i++)
    {
        for (int j = 0; j < HalfSize; j++)
        {
            MatrixC[i][j]                       = (M1[i][j] + M4[i][j]) + (M7[i][j] - M5[i][j]);
            MatrixC[i][j + HalfSize]            = M3[i][j] + M5[i][j];
            MatrixC[i + HalfSize][j]            = M2[i][j] + M4[i][j];
            MatrixC[i + HalfSize][j + HalfSize] = (M1[i][j] + M3[i][j]) + (M6[i][j] - M2[i][j]);
        }
    }
}
// Element-wise sum of two square matrices over the leading
// MatrixSize x MatrixSize region: MatrixResult[r][c] = MatrixA[r][c] + MatrixB[r][c].
void ADD(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for (int row = 0; row < MatrixSize; ++row)
    {
        const int* lhs = MatrixA[row];
        const int* rhs = MatrixB[row];
        int* out = MatrixResult[row];
        for (int col = 0; col < MatrixSize; ++col)
            out[col] = lhs[col] + rhs[col];
    }
}
// Element-wise difference of two square matrices over the leading
// MatrixSize x MatrixSize region: MatrixResult[r][c] = MatrixA[r][c] - MatrixB[r][c].
void SUB(int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for (int row = 0; row < MatrixSize; ++row)
    {
        const int* minuend = MatrixA[row];
        const int* subtrahend = MatrixB[row];
        int* out = MatrixResult[row];
        for (int col = 0; col < MatrixSize; ++col)
            out[col] = minuend[col] - subtrahend[col];
    }
}
// Classical matrix product of MatrixSize x MatrixSize matrices:
// MatrixResult = MatrixA * MatrixB. Each output cell is zeroed in place and
// then accumulated over the shared dimension (textbook i-j-k loop order).
void MUL( int** MatrixA, int** MatrixB, int** MatrixResult, int MatrixSize )
{
    for (int row = 0; row < MatrixSize; ++row)
    {
        int* out = MatrixResult[row];
        const int* lhsRow = MatrixA[row];
        for (int col = 0; col < MatrixSize; ++col)
        {
            out[col] = 0;
            for (int k = 0; k < MatrixSize; ++k)
            {
                out[col] += lhsRow[k] * MatrixB[k][col];
            }
        }
    }
}
// Populates MatrixA row by row with pseudo-random values (rand() % 9999999)
// and mirrors every value into MatrixB, so both operands hold identical data.
// NOTE(review): depending on RAND_MAX, values this large can make the int
// products/sums in MUL and Strassen overflow (signed overflow is undefined
// behaviour); consider a smaller range if results are ever compared.
void FillMatrix( int** MatrixA, int** MatrixB, int length)
{
    for (int r = 0; r < length; ++r)
    {
        int* primary = MatrixA[r];
        int* mirror = MatrixB[r];
        for (int c = 0; c < length; ++c)
        {
            const int value = std::rand() % 9999999;
            primary[c] = value;
            mirror[c] = value;
        }
    }
}