-
Python: 행렬 곱셈 알고리즘 소개컴퓨터/파이썬 2020. 11. 15. 23:09728x90반응형
행렬 (Matrix)
1. 소개
행렬 n x n 을 곱하려면, 보통 for 문 3개, O(n^3) 시간 복잡도를 갖지만,
// 일반 행렬 곱셈 int** multiply(int** A, int** B, int n) { int** C = initializeMatrix(n); setToZero(C, n); // 전부 0으로 초기화 for(int i=0; i<n; i++) for(int j=0; j<n; j++) for(int k=0; k<n; k++) C[i][j] += A[i][k] * B[k][j]; return C; }
슈트라센 알고리즘을 이용하면 O(n^2.81) 정도의 시간 복잡도를 갖게 된다. (재귀)
즉, 7번의 곱셈과 18번의 덧셈/뺄셈으로 행렬 곱셈을 하는 방법이다.
(하지만 이 방식은 2^n x 2^n 행렬에서만 통한다, 따라서 크기가 다르면 0으로 채워서 해야 함)
행렬 곱셈: 시간 복잡도 역사
- O(n^2.48) : 1986년 슈트라센이 다시 내놓음
- O(n^2.376) : 1987년 Coppersmith와 Winograd 기법 @논문 링크
- O(n^2.373) : 2010년 Stothers 기법 @논문 링크
- O(n^2.3729) : 2012년 Vassilevska Williams 기법 @논문 링크
- O(n^2.3728639) : 2014년 Le Gall 기법 @논문 링크
(위 텐서들은 laser 메소드를 통해 분석됨)
이 글에선, O(n^2.81) 슈트라센 기법과,
O(n^2)이라고 자랑하는 Shrohan Mohapatra분의 알고리즘을 사용해볼 것이다.
2. 슈트라센 기법
n X n 으로 두 행렬의 크기가 같아야 함!
n은 2의 제곱 값이 여야함!
먼저, 아래와 같이 2 x 2 행렬 2개를 곱한다고 가정하면,
일반적으론, 아래처럼 곱을 해야 하지만,
슈트라센은 아래처럼 계산을 하였다.
(7번의 곱셈과 18번의 덧셈/뺄셈)
참고로, O(n^2.376) 시간 복잡도의 1987년 Coppersmith와 Winograd 기법은,
3번 덧셈/뺄셈 연산이 적은 아래 방식이다.
본론으로 돌아가서, 슈트라센 기법을 사용하기 위해선, 덧셈과 뺄셈 함수를 만들어야 한다.
덧셈, 뺄셈
# n x n, 0 값을 가진 행렬 def initMatrix(n): return [[0 for _ in range(n)] for _ in range(n)] def add(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] + M2[i][j] return temp def subtract(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] - M2[i][j] return temp
기본 케이스
1 x 1이면 그냥 곱해서 return
if n == 1: C = initMatrix(1) C[0][0] = A[0][0] * B[0][0] return C
변수 초기화
C = initMatrix(n) k = n // 2 # 2 x 2 행렬이면 각각 1 x 1 서브 행렬임 A11 = initMatrix(k) A12 = initMatrix(k) A21 = initMatrix(k) A22 = initMatrix(k) B11 = initMatrix(k) B12 = initMatrix(k) B21 = initMatrix(k) B22 = initMatrix(k) for i in range(k): for j in range(k): A11[i][j] = A[i][j] A12[i][j] = A[i][k + j] A21[i][j] = A[k + i][j] A22[i][j] = A[k + i][k + j] B11[i][j] = B[i][j] B12[i][j] = B[i][k + j] B21[i][j] = B[k + i][j] B22[i][j] = B[k + i][k + j]
P값 재귀
P1 = strassen(A11, subtract(B12, B22, k), k) P2 = strassen(add(A11, A12, k), B22, k) P3 = strassen(add(A21, A22, k), B11, k) P4 = strassen(A22, subtract(B21, B11, k), k) P5 = strassen(add(A11, A22, k), add(B11, B22, k), k) P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k) P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k)
A * B = C
C11 = subtract(add(add(P5, P4, k), P6, k), P2, k) C12 = add(P1, P2, k) C21 = add(P3, P4, k) C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k) for i in range(k): for j in range(k): C[i][j] = C11[i][j] C[i][j + k] = C12[i][j] C[k + i][j] = C21[i][j] C[k + i][k + j] = C22[i][j]
결과
A = [[1, 3], [7, 5]] B = [[6, 8], [4, 2]] print(strassen(A, B, 2)) """ [[18, 14], [62, 66]] """ # 128 x 128 행렬끼리 곱하면 약 8초가 걸린다.
더보기def initMatrix(n, m=None): m = m or n return [[0 for _ in range(m)] for _ in range(n)] def add(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] + M2[i][j] return temp def subtract(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] - M2[i][j] return temp def strassen(A, B, n): if n == 1: C = initMatrix(1) C[0][0] = A[0][0] * B[0][0] return C C = initMatrix(n) k = n // 2 A11 = initMatrix(k) A12 = initMatrix(k) A21 = initMatrix(k) A22 = initMatrix(k) B11 = initMatrix(k) B12 = initMatrix(k) B21 = initMatrix(k) B22 = initMatrix(k) for i in range(k): for j in range(k): A11[i][j] = A[i][j] A12[i][j] = A[i][k + j] A21[i][j] = A[k + i][j] A22[i][j] = A[k + i][k + j] B11[i][j] = B[i][j] B12[i][j] = B[i][k + j] B21[i][j] = B[k + i][j] B22[i][j] = B[k + i][k + j] P1 = strassen(A11, subtract(B12, B22, k), k) P2 = strassen(add(A11, A12, k), B22, k) P3 = strassen(add(A21, A22, k), B11, k) P4 = strassen(A22, subtract(B21, B11, k), k) P5 = strassen(add(A11, A22, k), add(B11, B22, k), k) P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k) P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k) C11 = subtract(add(add(P5, P4, k), P6, k), P2, k) C12 = add(P1, P2, k) C21 = add(P3, P4, k) C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k) for i in range(k): for j in range(k): C[i][j] = C11[i][j] C[i][j + k] = C12[i][j] C[k + i][j] = C21[i][j] C[k + i][k + j] = C22[i][j] return C
3. Shrohan Mohapatra O(n^2) 기법
양수에서만 작동함
왜 Vassilevska Williams, Le Gall 등 기법을 소개하지 않고, 이 기법을 소개하냐면,
논문을 보면 알겠지만, 엄청나게 많은 수학적 지식과, 텐서 파워라는 무거운 데이터 구조를 이용해야 한다.
따라서 위 알고리즘들이, 슈트라센말고는 프로그래밍 언어로 구현된 것이 없는 그 이유다.
덧셈, 뺄셈, 함수를 따로 쓰지도 않고, log10을 이용해 신기하게 구한다.
설명이 따로 없어서, 함수만 정리해서 가져왔다.
def quadratic(A, B): from math import log10 #### 길이가 같은지 확인 flag = True and (len(A) == len(B)) if not flag: return None # A와 B의 크기가 다르면 None을 반환 N = len(A) for i in range(N): flag = flag and (len(A[i]) == len(B[i])) if not flag: return None # A와 B의 각 행의 크기가 다르면 None을 반환 #### 최대값 찾기 maxi = 0 for i in range(N): for j in range(N): if maxi < A[i][j]: maxi = A[i][j] # A의 최대값 갱신 if maxi < B[i][j]: maxi = B[i][j] # B의 최대값 갱신 # 최대값을 기반으로 필요한 자릿수 계산 M = int(log10(maxi)) + 1 P = int(log10((10 ** (2 * M) - 1) * N)) + 1 # 결과 행렬 초기화 C, D, E = ( [0 for _ in range(N)], [0 for _ in range(N)], [[0 for _ in range(N)] for _ in range(N)], ) # 행렬 A와 B를 큰 수로 변환 for i in range(N): for j in range(N): C[i] = C[i] * (10 ** P) + A[i][j] # A의 각 행을 하나의 큰 수로 변환 for j in range(N): for i in range(N): D[j] = D[j] * (10 ** P) + B[N - 1 - i][j] # B의 각 열을 하나의 큰 수로 변환, 역순으로 # 큰 수를 이용한 행렬 곱셈 수행 for i in range(N): for j in range(N): E[i][j] = C[i] * D[j] // (10 ** (P * (N - 1))) % (10 ** P) # 곱셈 결과에서 각 요소 추출 return E # 결과 행렬 반환
4. 벤치마크
Strassen VS Shrohan Mohapatra(SM)
32 x 32 행렬
슈트라센: 0.16169s
SM: 0.00695s
N = 2 ** 5 A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] # 런타임 비교 from timeit import default_timer start = default_timer() result1 = strassen(A, B, N) print(f"{default_timer() - start:.5f}s") start = default_timer() result2 = mat(A, B) print(f"{default_timer() - start:.5f}s") print(result2 == result1)
256 x 256 행렬
슈트라센: 55.70412s
SM: 12.69966s
N = 2 ** 8 A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] # 런타임 비교 from timeit import default_timer start = default_timer() result1 = strassen(A, B, N) print(f"{default_timer() - start:.5f}s") start = default_timer() result2 = mat(A, B) print(f"{default_timer() - start:.5f}s") print(result2 == result1)
참고
슈트라센 위키피디아 @wikipedia
슈트라센 기법 쉽게 외우는 방법 @GeeksForGeeks
슈트라센 곱셈 설명 @Medium
Shrohan Mohapatra n^2 곱셈 기법 @Github
텐서 파워와 빠른 행렬 곱셈 by Le Gall (도쿄대학) @Youtube
텍사스 샌 안토니오 대학 자료 @링크
위 설명 논문들
728x90'컴퓨터 > 파이썬' 카테고리의 다른 글
Python: LRU 캐시 만들어보기 (0) 2020.11.16 파이썬: Monad (모나드) (1) 2020.11.11 파이썬 Sørensen–Dice coefficient (0) 2020.11.05