ABOUT ME

-

  • Python: 행렬 곱셈 알고리즘 소개
    컴퓨터/파이썬 2020. 11. 15. 23:09
    728x90
    반응형

    행렬 (Matrix)

     

    행렬

    위키백과

    ko.wikipedia.org

     

    1. 소개

    행렬 n x n 을 곱하려면, 보통 for 문 3개, O(n^3) 시간 복잡도를 갖지만,

    arduino
    // 일반 행렬 곱셈 int** multiply(int** A, int** B, int n) { int** C = initializeMatrix(n); setToZero(C, n); // 전부 0으로 초기화 for(int i=0; i<n; i++) for(int j=0; j<n; j++) for(int k=0; k<n; k++) C[i][j] += A[i][k] * B[k][j]; return C; }

     

    슈트라센 알고리즘을 이용하면 O(n^2.81) 정도의 시간 복잡도를 갖게 된다. (재귀)

    즉, 7번의 곱셈과 18번의 덧셈/뺄셈으로 행렬 곱셈을 하는 방법이다.

    (하지만 이 방식은 2^n x 2^n 행렬에서만 통한다, 따라서 크기가 다르면 0으로 채워서 해야 함)

     

    행렬 곱셈: 시간 복잡도 역사

    1. O(n^2.48) : 1986년 슈트라센이 다시 내놓음
    2. O(n^2.376) : 1987년 Coppersmith와 Winograd 기법 @논문 링크
    3. O(n^2.373) : 2010년 Stothers 기법 @논문 링크
    4. O(n^2.3729) : 2012년 Vassilevska Williams 기법 @논문 링크
    5. O(n^2.3728639) : 2014년 Le Gall 기법 @논문 링크

     

    (위 텐서들은 laser 메소드를 통해 분석됨)

     

    이 글에선, O(n^2.81) 슈트라센 기법과,

    O(n^2)이라고 자랑하는 Shrohan Mohapatra분의 알고리즘을 사용해볼 것이다.

     

    2. 슈트라센 기법

    n X n 으로 두 행렬의 크기가 같아야 함!

    n은 2의 제곱 값이 여야함!

     

    먼저, 아래와 같이 2 x 2 행렬 2개를 곱한다고 가정하면,

     

    일반적으론, 아래처럼 곱을 해야 하지만,

    @샌 안토니오 대학 자료

     

    슈트라센은 아래처럼 계산을 하였다.

    (7번의 곱셈과 18번의 덧셈/뺄셈)

    @샌 안토니오 대학 자료

     

    참고로, O(n^2.376) 시간 복잡도의 1987년 Coppersmith와 Winograd 기법은,

    3번 덧셈/뺄셈 연산이 적은 아래 방식이다.

    @샌 안토니오 대학 자료

     

    본론으로 돌아가서, 슈트라센 기법을 사용하기 위해선, 덧셈과 뺄셈 함수를 만들어야 한다.

     

    덧셈, 뺄셈

    python
    # n x n, 0 값을 가진 행렬 def initMatrix(n): return [[0 for _ in range(n)] for _ in range(n)] def add(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] + M2[i][j] return temp def subtract(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] - M2[i][j] return temp

     

    기본 케이스

    1 x 1이면 그냥 곱해서 return

    python
    if n == 1: C = initMatrix(1) C[0][0] = A[0][0] * B[0][0] return C

     

    변수 초기화

    python
    C = initMatrix(n) k = n // 2 # 2 x 2 행렬이면 각각 1 x 1 서브 행렬임 A11 = initMatrix(k) A12 = initMatrix(k) A21 = initMatrix(k) A22 = initMatrix(k) B11 = initMatrix(k) B12 = initMatrix(k) B21 = initMatrix(k) B22 = initMatrix(k) for i in range(k): for j in range(k): A11[i][j] = A[i][j] A12[i][j] = A[i][k + j] A21[i][j] = A[k + i][j] A22[i][j] = A[k + i][k + j] B11[i][j] = B[i][j] B12[i][j] = B[i][k + j] B21[i][j] = B[k + i][j] B22[i][j] = B[k + i][k + j]

     

    P값 재귀

    python
    P1 = strassen(A11, subtract(B12, B22, k), k) P2 = strassen(add(A11, A12, k), B22, k) P3 = strassen(add(A21, A22, k), B11, k) P4 = strassen(A22, subtract(B21, B11, k), k) P5 = strassen(add(A11, A22, k), add(B11, B22, k), k) P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k) P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k)

     

    A * B = C

    python
    C11 = subtract(add(add(P5, P4, k), P6, k), P2, k) C12 = add(P1, P2, k) C21 = add(P3, P4, k) C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k) for i in range(k): for j in range(k): C[i][j] = C11[i][j] C[i][j + k] = C12[i][j] C[k + i][j] = C21[i][j] C[k + i][k + j] = C22[i][j]

     

    결과

    python
    A = [[1, 3], [7, 5]] B = [[6, 8], [4, 2]] print(strassen(A, B, 2)) """ [[18, 14], [62, 66]] """ # 128 x 128 행렬끼리 곱하면 8초가 걸린다.
    풀소스 확인
    python
    def initMatrix(n, m=None): m = m or n return [[0 for _ in range(m)] for _ in range(n)] def add(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] + M2[i][j] return temp def subtract(M1, M2, n): temp = initMatrix(n) for i in range(n): for j in range(n): temp[i][j] = M1[i][j] - M2[i][j] return temp def strassen(A, B, n): if n == 1: C = initMatrix(1) C[0][0] = A[0][0] * B[0][0] return C C = initMatrix(n) k = n // 2 A11 = initMatrix(k) A12 = initMatrix(k) A21 = initMatrix(k) A22 = initMatrix(k) B11 = initMatrix(k) B12 = initMatrix(k) B21 = initMatrix(k) B22 = initMatrix(k) for i in range(k): for j in range(k): A11[i][j] = A[i][j] A12[i][j] = A[i][k + j] A21[i][j] = A[k + i][j] A22[i][j] = A[k + i][k + j] B11[i][j] = B[i][j] B12[i][j] = B[i][k + j] B21[i][j] = B[k + i][j] B22[i][j] = B[k + i][k + j] P1 = strassen(A11, subtract(B12, B22, k), k) P2 = strassen(add(A11, A12, k), B22, k) P3 = strassen(add(A21, A22, k), B11, k) P4 = strassen(A22, subtract(B21, B11, k), k) P5 = strassen(add(A11, A22, k), add(B11, B22, k), k) P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k) P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k) C11 = subtract(add(add(P5, P4, k), P6, k), P2, k) C12 = add(P1, P2, k) C21 = add(P3, P4, k) C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k) for i in range(k): for j in range(k): C[i][j] = C11[i][j] C[i][j + k] = C12[i][j] C[k + i][j] = C21[i][j] C[k + i][k + j] = C22[i][j] return C

     

    3. Shrohan Mohapatra O(n^2) 기법

    양수에서만 작동함

     

    Vassilevska Williams, Le Gall 등 기법을 소개하지 않고, 이 기법을 소개하냐면,

    논문을 보면 알겠지만, 엄청나게 많은 수학적 지식과, 텐서 파워라는 무거운 데이터 구조를 이용해야 한다.

    따라서 위 알고리즘들이, 슈트라센말고는 프로그래밍 언어로 구현된 것이 없는 그 이유다.

     

    덧셈, 뺄셈, 함수를 따로 쓰지도 않고, log10을 이용해 신기하게 구한다.

    설명이 따로 없어서, 함수만 정리해서 가져왔다.

    python
    def quadratic(A, B): from math import log10 #### 길이가 같은지 확인 flag = True and (len(A) == len(B)) if not flag: return None # A B 크기가 다르면 None 반환 N = len(A) for i in range(N): flag = flag and (len(A[i]) == len(B[i])) if not flag: return None # A B 행의 크기가 다르면 None 반환 #### 최대값 찾기 maxi = 0 for i in range(N): for j in range(N): if maxi < A[i][j]: maxi = A[i][j] # A 최대값 갱신 if maxi < B[i][j]: maxi = B[i][j] # B 최대값 갱신 # 최대값을 기반으로 필요한 자릿수 계산 M = int(log10(maxi)) + 1 P = int(log10((10 ** (2 * M) - 1) * N)) + 1 # 결과 행렬 초기화 C, D, E = ( [0 for _ in range(N)], [0 for _ in range(N)], [[0 for _ in range(N)] for _ in range(N)], ) # 행렬 A B 수로 변환 for i in range(N): for j in range(N): C[i] = C[i] * (10 ** P) + A[i][j] # A 행을 하나의 수로 변환 for j in range(N): for i in range(N): D[j] = D[j] * (10 ** P) + B[N - 1 - i][j] # B 열을 하나의 수로 변환, 역순으로 # 수를 이용한 행렬 곱셈 수행 for i in range(N): for j in range(N): E[i][j] = C[i] * D[j] // (10 ** (P * (N - 1))) % (10 ** P) # 곱셈 결과에서 요소 추출 return E # 결과 행렬 반환

     

    4. 벤치마크

    Strassen VS Shrohan Mohapatra(SM)

     

    32 x 32 행렬

    슈트라센: 0.16169s

    SM: 0.00695s

    python
    N = 2 ** 5 A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] # 런타임 비교 from timeit import default_timer start = default_timer() result1 = strassen(A, B, N) print(f"{default_timer() - start:.5f}s") start = default_timer() result2 = mat(A, B) print(f"{default_timer() - start:.5f}s") print(result2 == result1)

     

    256 x 256 행렬

    슈트라센: 55.70412s

    SM: 12.69966s

    python
    N = 2 ** 8 A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)] # 런타임 비교 from timeit import default_timer start = default_timer() result1 = strassen(A, B, N) print(f"{default_timer() - start:.5f}s") start = default_timer() result2 = mat(A, B) print(f"{default_timer() - start:.5f}s") print(result2 == result1)

     

    참고

    슈트라센 위키피디아 @wikipedia

    슈트라센 기법 쉽게 외우는 방법 @GeeksForGeeks

    슈트라센 곱셈 설명 @Medium

    Shrohan Mohapatra n^2 곱셈 기법 @Github

    텐서 파워와 빠른 행렬 곱셈 by Le Gall (도쿄대학) @Youtube

    텍사스 샌 안토니오 대학 자료 @링크

    위 설명 논문들

    728x90

    '컴퓨터 > 파이썬' 카테고리의 다른 글

    댓글