티스토리 뷰

알고리즘/study

[분할 정복] Strassen algorithm

이즈미르 2020. 11. 26. 23:56

분할 정복 알고리즘(Divide and conquer algorithm) 중에 하나인 Strassen을 알아보자.

 

중고등학생 때 행렬의 곱셈에 대해 배웠을 것이다.

 

행렬 간에 곱셈을 하기 전에 곱셈이 가능한 행과 열로 행렬들이 갖춰졌는지 확인해야 하지만

 

우리는 정사각 행렬(Square matrix)만 다루기로 하자.

 

아래와 같이 각 크기가 n인 정사각 행렬 A, B가 있다고 하자.

 

$$ A_{n,n} = \begin{bmatrix} a_{1,1} & a_{1,2} & \cdots & a_{1,n} \\ a_{2,1} & a_{2,2} & \cdots & a_{2,n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n,1} & a_{n,2} & \cdots & a_{n,n} \end{bmatrix} $$

 

$$ B_{n,n} = \begin{bmatrix} b_{1,1} & b_{1,2} & \cdots & b_{1,n} \\ b_{2,1} & b_{2,2} & \cdots & b_{2,n} \\ \vdots & \vdots & \ddots & \vdots \\ b_{n,1} & b_{n,2} & \cdots & b_{n,n} \end{bmatrix} $$

 

행렬곱 C = AB는 아래와 같이 크기 n인 정사각 행렬로 정의된다.

 

$$ C_{n,n} = \begin{bmatrix} c_{1,1} & c_{1,2} & \cdots & c_{1,n} \\ c_{2,1} & c_{2,2} & \cdots & c_{2,n} \\ \vdots & \vdots & \ddots & \vdots \\ c_{n,1} & c_{n,2} & \cdots & c_{n,n} \end{bmatrix} $$

 

이때 행렬 C의 성분은 아래와 같이 정의된다.

 

$$ c_{i,j} = a_{i,1}b_{1,j} + a_{i,2}b_{2,j} + \cdots + a_{i,n}b_{n,j} = \sum_{k=1}^{n} a_{i,k}b_{k,j} $$

 

이를 다시 정리하면 행렬 C는 아래와 같은 모습이 된다.

 

$$ C_{n,n} = \begin{bmatrix} a_{1,1}b_{1,1}+\cdots+a_{1,n}b_{n,1} & a_{1,1}b_{1,2}+\cdots+a_{1,n}b_{n,2} & \cdots & a_{1,1}b_{1,n}+\cdots+a_{1,n}b_{n,n} \\ a_{2,1}b_{1,1}+\cdots+a_{2,n}b_{n,1} & a_{2,1}b_{1,2}+\cdots+a_{2,n}b_{n,2} & \cdots & a_{2,1}b_{1,n}+\cdots+a_{2,n}b_{n,n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n,1}b_{1,1}+\cdots+a_{n,n}b_{n,1} & a_{n,1}b_{1,2}+\cdots+a_{n,n}b_{n,2} & \cdots & a_{n,1}b_{1,n}+\cdots+a_{n,n}b_{n,n} \end{bmatrix} $$

 

출처 : ko.wikipedia.org/wiki/%ED%96%89%EB%A0%AC_%EA%B3%B1%EC%85%88

 

행렬 곱셈 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. 둘러보기로 가기 검색하러 가기 행렬 곱셈을 위해선 첫째 행렬의 열 갯수와 둘째 행렬의 행 갯수가 동일해야한다. 곱셈의 결과 새롭게 만들어진 행렬은 첫째

ko.wikipedia.org

코드로 표현하면 아래와 같다.

 

int** MultiplyMatrix(int** A, int** B, std::size_t n)
{
	int** C = new int*[n];

	int sum = 0;
	for (std::size_t i = 0; i < n; ++i)
	{
		C[i] = new int[n];

		for (std::size_t j = 0; j < n; ++j)
		{
			sum = 0;
			for (std::size_t k = 0; k < n; ++k)
			{
				sum += (A[i][k] * B[k][j]);
			}
			C[i][j] = sum;
		}
	}

	return C;
}

 

이를 시간 복잡도와 공간 복잡도로 표현하면 아래와 같다.

 

$$time\,complexity\,:\,\Theta(n^{3})\quad space\,complexity\,:\,O(1)$$

 

3중 반복문이 행렬의 크기 n만큼 돌기 때문에 $\Theta(n^{3})$의 시간 복잡도가 나오고

 

새로 생성할 행렬 C의 공간을 제외하면 sum 변수 정도를 위한 공간을 사용하기 때문에 $\Theta(1)$ 공간 복잡도가 나온다.

 

이렇게 행렬의 곱을 구하는 방법을 [방법 1]이라고 하자.

 

크기가 n인 정사각 행렬 곱셈의 시간 복잡도가 [방법 1]보다 더 나은 것이 없다고 생각할 수 있지만

 

Strassen 알고리즘을 사용하면 더 나은 시간 복잡도를 얻을 수 있다.

 

원리를 간단하게 설명하자면 행렬들을 쪼개서 곱하고 더하는 과정을 재귀적으로 반복하는데

 

행렬의 덧셈이 곱셈보다 더 빠른 점을 이용하기 위해 쪼갠 행렬들의 곱셈 횟수를 줄이고 덧셈 횟수를 늘린다.

 

(곱셈은 앞서 본 것처럼 $\Theta(n^{3})$이지만 덧셈은 $\Theta(n^{2})$으로 더 나은 시간 복잡도를 가진다.)

 

Strassen 알고리즘을 본격적으로 설명하기 앞서 재귀적으로 행렬을 쪼개어 곱하고 더하는 방법을 먼저 살펴보자.

 

행렬을 재귀적으로 쪼개는 작업이 n을 2로 계속 나누기 때문에 사전에 행렬들의 크기를 2의 거듭제곱으로 맞춰 놓는게 좋다.

 

그 다음 행렬 A와 B, C를 아래와 같이 각각 4분할로 쪼갠다.

 

$$ A = \begin{bmatrix} A_{1,1} & A_{1,2} \\ A_{2,1} & A_{2,2} \end{bmatrix},\, B = \begin{bmatrix} B_{1,1} & B_{1,2} \\ B_{2,1} & B_{2,2} \end{bmatrix},\, C = \begin{bmatrix} C_{1,1} & C_{1,2} \\ C_{2,1} & C_{2,2} \end{bmatrix} $$

 

쪼개진 행렬들로 다음 식들이 성립될 수 있다.

 

$$ [식 1] \quad C_{1,1} = A_{1,1}B_{1,1} + A_{1,2}B_{2,1} $$

$$ [식 2] \quad C_{1,2} = A_{1,1}B_{1,2} + A_{1,2}B_{2,2} $$

$$ [식 3] \quad C_{2,1} = A_{2,1}B_{1,1} + A_{2,2}B_{2,1} $$

$$ [식 4] \quad C_{2,2} = A_{2,1}B_{1,2} + A_{2,2}B_{2,2} $$

 

위 식들에 있는 행렬끼리의 곱에서 다시 위 과정을 재귀적으로 반복한다.

 

행렬의 크기가 더이상 쪼개지지 않는 1까지 반복하고 1개 요소만 남았으면 그냥 곱해주면 된다.

 

이렇게 행렬의 곱을 구하는 방법을 [방법 2]라고 하자.

 

그리고 아래와 같이 [방법 2]의 실행 시간을 정의할 수 있다.

 

$$ T(n) = \begin{cases} \Theta(1), & \text{if }n = 1 \\ 8T(n/2) + \Theta(n^{2}), & \text{if }n > 1 \end{cases} $$

 

[식 1~4]에서 쪼개진 행렬들의 곱이 8번 일어나고 그 행렬의 크기는 n/2가 된다.

 

그리고 행렬들의 덧셈이 4번 일어나는데 $4(n/2)^{2}$가 곧 $\Theta(n^{2})$가 된다.

 

$\Theta(n^{2}) = cn^{2} \quad \text{where}\, c > 0\,$로 정의하고 T(n)에 대해서 구체적으로 알아보자.

 

T(n)을 트리로 나타내면 아래와 같다.

 

[트리 1]

 

[트리 1]은 부모 노드가 각각 8개씩 자식 노드를 갖고 있는 트리이다.

(...으로 표시된 노드는 다 표시하지 못한 나머지 자식 노드들임을 알아두자.

그리고 첫 번째 자식 노드만 비교적 구체적으로 표현했고 나머지 노드들도 그런 식으로 자식 노드들이 있어야 한다.)

 

부모 노드에서는 자식 노드들에서 행렬 곱셈한 결과로 나온 행렬들을 더하는 작업의 실행 시간을 나타낸다. 

 

트리의 높이는 $\log_{2} n$이고 리프 노드의 개수는 $8^{\log_{2}n} = n^{3}$이다.

(트리의 높이는 n을 1이 될 때까지 2로 몇 번 나눌 수 있는가이고

리프 노드 개수는 트리 깊이가 1씩 증가할 때마다 8배씩 증가하므로 $8^{\log_{2}n} = n^{3}$이 된다.)

 

트리의 총 실행 시간을 (리프 노드들을 제외한 모든 레벨의 총 실행 시간) + (리프 노드들의 총 실행 시간)으로 구하면 아래와 같다.

 

$$ \begin{align} T(n) &= (cn^{2} + 8c(n/2)^{2} + 8^{2}c(n/4)^{2} + \cdots + (8/4)^{(\log_{2} n)-1}cn^{2}) + (\Theta(n^{3})) \\ &= \sum_{k=0}^{(\log_{2} n)-1} 2^{k}cn^{2} + \Theta(n^{3}) = {cn^{2}(2^{\log_{2}n} - 1) \over (2 - 1)} + \Theta(n^{3}) \\ &= cn^{2}(n - 1) + \Theta(n^{3}) = \Theta(n^{3}) \end{align} $$

 

결국 [방법 2][방법 1]과 같은 시간 복잡도를 갖는다.

 

이제 Strassen 알고리즘을 알아보자.

 

아래와 같이 7개의 행렬을 정의해 보자.

 

$$ \begin{align} &M_{1} = (A_{1,1} + A_{2,2})(B_{1,1} + B_{2,2}) \\ &M_{2} = (A_{2,1} + A_{2,2})B_{1,1} \\ &M_{3} = A_{1,1}(B_{1,2} - B_{2,2}) \\ &M_{4} = A_{2,2}(B_{2,1} - B_{1,1}) \\ &M_{5} = (A_{1,1} + A_{1,2})B_{2,2} \\ &M_{6} = (A_{2,1} - A_{1,1})(B_{1,1} + B_{1,2}) \\ &M_{7} = (A_{1,2} - A_{2,2})(B_{2,1} + B_{2,2}) \end{align} $$

 

그리고 위의 행렬들로 아래의 식들이 성립된다.

 

$$ \begin{align} &C_{1,1} = M_{1} + M_{4} - M_{5} + M_{7} \\ &C_{1,2} = M_{3} + M_{5} \\ &C_{2,1} = M_{2} + M_{4} \\ &C_{2,2} = M_{1} - M_{2} + M_{3} + M_{6} \end{align} $$

 

위 식들을 직접 전개해 보면 [식 1~4]가 그대로 나온다.

 

출처 : ko.wikipedia.org/wiki/%EC%8A%88%ED%8A%B8%EB%9D%BC%EC%84%BC_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

 

슈트라센 알고리즘 - 위키백과, 우리 모두의 백과사전

위키백과, 우리 모두의 백과사전. 둘러보기로 가기 검색하러 가기 선형대수학에서 슈트라센 알고리즘은 독일의 수학자 폴커 슈트라센(Volker Strassen)이 1969년에 개발한 행렬 곱셈 알고리즘이다.

ko.wikipedia.org

앞서 간단하게 설명했던 것과 같이 행렬의 곱셈을 1번 줄이는 대신 행렬의 덧셈을 여러번 더 수행한 셈이다.

 

행렬의 덧셈이 비록 여러번 더 수행되었지만 이는 $n^{2}$의 상수배에 달하는 시간이고 이는 결국 $\Theta(n^{2})$에 종결된다.

 

다시 실행 시간 T(n)을 정의해 보면 아래와 같다.

(굳이 다시 트리로 펼치지 않고 식으로 정의하겠다.)

 

$$ \begin{align} T(n) &= 7T(n/2) + \Theta(n^{2}) \\ &= (cn^{2} + 7c(n/2)^{2} + 7^{2}c(n/4)^{2} + \cdots + ({7 \over 4})^{(\log_{2} n)-1}cn^{2}) + \Theta(n^{\log_{2} 7}) \\ &= \sum_{k=0}^{(\log_{2} n) - 1} ({7 \over 4})^{k}cn^{2} + \Theta(n^{\log_{2} 7}) \\ &= {{7 \over 4}^{\log_{2} n} - 1 \over {7 \over 4} - 1}cn^{2} + \Theta(n^{\log_{2} 7}) \\ &= {4 \over 3}c(n^{\log_{2} 7} - n^{2}) + \Theta(n^{\log_{2} 7}) \\ &= \Theta(n^{\log_{2} 7}) \quad \because \, 2.80 < \log_{2} 7 < 2.81 \end{align} $$

 

트리의 높이는 $\log_{2} n$으로 [방법 2]와 같지만 부모 노드의 자식 노드 개수가 8개에서 7개로 감소했다.

 

그래서 리프 노드들의 총 실행 시간은 $\Theta(n^{\log_{2} 7})$이 된다.

(트리의 각 레벨에 있는 노드들의 개수는 $7^{k}$(사실 k는 깊이(레벨-1)이다.)이고 마지막 레벨인 $\log_{2} n$을 $k$에 대입하면 $7^{\log_{2} n}$이 되며 이는 곧 $n^{\log_{2} 7}$이 된다.)

 

따라서 Strassen 알고리즘은 [방법 1]과 [방법 2]보다 더 나은 시간 복잡도를 보여주고 있다.

 

공간 복잡도는 어떨까?

 

먼저 처음 $M_{1,2, \cdots, 7}$을 위한 공간이 필요하므로 여러 개의 $(n/2)^{2}$ 크기의 공간이 필요하다.

 

그 다음은 여러 개의 $(n/4)^{2}$ 크기가 필요하고 이런 식으로 $n^{2}$의 상수배 공간들이 필요하게 된다.

(멀티 스레드가 아니면 어떤 노드와 그 노드의 사촌 노드를 위한 공간을 동시에 필요로 하지 않는다.

즉, 실행 시간과 달리 공간은 트리의 깊이에 따라 7배씩 증가하지 않는다.)

 

따라서 $\Theta(n^{2})$으로 볼 수 있다.

 

다음으로 실제 실행 시간에 대해서 생각해보자.

 

Strassen 알고리즘은 여러 행렬 덧셈을 수행하므로 매우 큰 $n$이 아니면 [방법 1]보다 오래 걸린다.

 

필자도 Strassen 알고리즘이 더 빠른 $n$을 찾고자 여러 시도를 해봤는데 실행 시간이 감당이 되지 않고 메모리도 부족하여 일반 가정용 컴퓨터로 찾기에는 무리인 것으로 판단했다.

(Strassen 알고리즘을 멀티 스레드로 돌려도 답이 없다.

측정 가능한 $n$들에 대해 걸리는 시간을 그래프로 그리고 Strassen 알고리즘이 더 빨라지는 $n$을 추론하는 방법이 있겠다.)

 

마지막으로 Strassen 알고리즘을 코드로 표현하면 아래와 같다.

 

class Matrix
{
	typedef int (*Calculate)(int, int);
	static int Add(int value1, int value2)
	{
		return value1 + value2;
	}
	static int Sub(int value1, int value2)
	{
		return value1 - value2;
	}
	static int Set(int value1, int value2)
	{
		return value2;
	}

	Matrix(const Matrix& other, std::size_t row1, std::size_t col1, std::size_t row2, std::size_t col2, std::size_t size, Calculate cal)
	{
		_n = _size = size;
		_useStrassen = other._useStrassen; // it must be true
		_name = other._name;

		_values = new Value * [_size];
		for (std::size_t i = 0; i < _size; ++i)
		{
			_values[i] = new Value[_size];
			for (std::size_t j = 0; j < _size; ++j)
			{
				_values[i][j] = cal(other[row1 + i][col1 + j], other[row2 + i][col2 + j]);
			}
		}
	}

	Matrix(const Matrix& other, std::size_t row, std::size_t col, std::size_t size)
	{
		_n = _size = size;
		_useStrassen = other._useStrassen;
		_name = other._name;

		_values = new Value * [_size];
		for (std::size_t i = 0; i < _size; ++i)
		{
			_values[i] = new Value[_size];
			memcpy(_values[i], other[row + i] + col, _size);
		}
	}

	void Calc(const Matrix& other, std::size_t row, std::size_t rowCount, std::size_t col, std::size_t colCount, Calculate cal)
	{
		for (std::size_t i = 0; i < rowCount; ++i)
		{
			for (std::size_t j = 0; j < colCount; ++j)
			{
				_values[row + i][col + j] = cal(_values[row + i][col + j], other[i][j]);
			}
		}
	}

	void Strassen(const Matrix& other, std::size_t thisRow, std::size_t thisCol, std::size_t otherRow, std::size_t otherCol)
	{
		// reference : https://ko.wikipedia.org/wiki/%EC%8A%88%ED%8A%B8%EB%9D%BC%EC%84%BC_%EC%95%8C%EA%B3%A0%EB%A6%AC%EC%A6%98

		if (1 < _size)
		{
			std::size_t half = _size / 2;

			Matrix M1(*this, thisRow, thisCol, thisRow + half, thisCol + half, half, Add); // M1 = A11 + A22
			M1.Strassen(Matrix(other, otherRow, otherCol, otherRow + half, otherCol + half, half, Add), 0, 0, 0, 0); // M1 *= (B11 + B22)

			Matrix M2(*this, thisRow + half, thisCol, thisRow + half, thisCol + half, half, Add); // M2 = A21 + A22
			M2.Strassen(other, 0, 0, otherRow, otherCol); // M2 *= B11

			Matrix M3(*this, thisRow, thisCol, half); // M3 = A11
			M3.Strassen(Matrix(other, otherRow, otherCol + half, otherRow + half, otherCol + half, half, Sub), 0, 0, 0, 0); // M3 *= (B12 - B22)

			Matrix M4(*this, thisRow + half, thisCol + half, half); // M4 = A22
			M4.Strassen(Matrix(other, otherRow + half, otherCol, otherRow, otherCol, half, Sub), 0, 0, 0, 0); // M4 *= (B21 - B11)

			Matrix M5(*this, thisRow, thisCol, thisRow, thisCol + half, half, Add); // M5 = A11 + A12
			M5.Strassen(other, 0, 0, otherRow + half, otherCol + half); // M5 *= B22

			Matrix M6(*this, thisRow + half, thisCol, thisRow, thisCol, half, Sub); // M6 = A21 - A11
			M6.Strassen(Matrix(other, otherRow, otherCol, otherRow, otherCol + half, half, Add), 0, 0, 0, 0); // M6 *= (B11 + B12)

			Matrix M7(*this, thisRow, thisCol + half, thisRow + half, thisCol + half, half, Sub); // M7 = A12 - A22
			M7.Strassen(Matrix(other, otherRow + half, otherCol, otherRow + half, otherCol + half, half, Add), 0, 0, 0, 0); // M7 *= (B21 + B22)

			// for C11
			//Calc(M1, 0, half, 0, half, Set); // C11 = M1
			for (std::size_t i = 0; i < half; ++i)
			{
				memcpy(_values[i], M1[i], half);
			}
			Calc(M4, 0, half, 0, half, Add); // C11 += M4
			Calc(M5, 0, half, 0, half, Sub); // C11 -= M5
			Calc(M7, 0, half, 0, half, Add); // C11 += M7

			// for C12
			//Calc(M3, 0, half, half, half, Set); // C12 = M3
			for (std::size_t i = 0; i < half; ++i)
			{
				memcpy(_values[i] + half, M3[i], half);
			}
			Calc(M5, 0, half, half, half, Add); // C12 += M5

			// for C21
			//Calc(M2, half, half, 0, half, Set); // C21 = M2
			for (std::size_t i = 0; i < half; ++i)
			{
				memcpy(_values[half + i], M2[i], half);
			}
			Calc(M4, half, half, 0, half, Add); // C21 += M4

			// for C22
			//Calc(M1, half, half, half, half, Set); // C22 = M1
			for (std::size_t i = 0; i < half; ++i)
			{
				memcpy(_values[half + i] + half, M1[i], half);
			}
			Calc(M2, half, half, half, half, Sub); // C22 -= M2
			Calc(M3, half, half, half, half, Add); // C22 += M3
			Calc(M6, half, half, half, half, Add); // C22 += m6
		}
		else
		{
			_values[thisRow][thisCol] *= other[otherRow][otherCol];
			Progress::Proceed(_name, 1);
		}
	}

	void Strassen(const Matrix&& other, std::size_t row1, std::size_t col1, std::size_t row2, std::size_t col2)
	{
		Strassen(other, row1, col1, row2, col2);
	}
 };

 

Strassen 알고리즘의 주요 코드 일부를 가져왔다.

 

최대한 새로운 행렬을 생성하는 것을 막고자 인덱스로 행렬 성분들을 참조하도록 구현했다.

(그리고 recurrence를 iteration으로 바꾸면 더욱 효율적일 것이다.)

 

여기까지 Strassen 알고리즘에 대해 알아보았다.

댓글