ABOUT ME

-

Total
-
  • Python: 행렬 곱셈 알고리즘 소개
    컴퓨터/파이썬 2020. 11. 15. 23:09
    728x90
    반응형

    행렬 (Matrix)

     

    행렬

    위키백과

    ko.wikipedia.org

     

    1. 소개

    행렬 n x n 을 곱하려면, 보통 for 문 3개, O(n^3) 시간 복잡도를 갖지만,

    // 일반 행렬 곱셈
    int** multiply(int** A, int** B, int n) {
            int** C = initializeMatrix(n);
            setToZero(C, n);  // 전부 0으로 초기화
            for(int i=0; i<n; i++)
                    for(int j=0; j<n; j++)
                            for(int k=0; k<n; k++)
                                    C[i][j] += A[i][k] * B[k][j];
            return C;
    }

     

    슈트라센 알고리즘을 이용하면 O(n^2.81) 정도의 시간 복잡도를 갖게 된다. (재귀)

    즉, 7번의 곱셈과 18번의 덧셈/뺄셈으로 행렬 곱셈을 하는 방법이다.

    (하지만 이 방식은 2^n x 2^n 행렬에서만 통한다, 따라서 크기가 다르면 0으로 채워서 해야 함)

     

    행렬 곱셈: 시간 복잡도 역사

    1. O(n^2.48) : 1986년 슈트라센이 다시 내놓음
    2. O(n^2.376) : 1987년 Coppersmith와 Winograd 기법 @논문 링크
    3. O(n^2.373) : 2010년 Stothers 기법 @논문 링크
    4. O(n^2.3729) : 2012년 Vassilevska Williams 기법 @논문 링크
    5. O(n^2.3728639) : 2014년 Le Gall 기법 @논문 링크

     

    (위 텐서들은 laser 메소드를 통해 분석됨)

     

    이 글에선, O(n^2.81) 슈트라센 기법과,

    O(n^2)이라고 자랑하는 Shrohan Mohapatra분의 알고리즘을 사용해볼 것이다.

     

    2. 슈트라센 기법

    n X n 으로 두 행렬의 크기가 같아야 함!

    n은 2의 제곱 값이 여야함!

     

    먼저, 아래와 같이 2 x 2 행렬 2개를 곱한다고 가정하면,

     

    일반적으론, 아래처럼 곱을 해야 하지만,

    @샌 안토니오 대학 자료

     

    슈트라센은 아래처럼 계산을 하였다.

    (7번의 곱셈과 18번의 덧셈/뺄셈)

    @샌 안토니오 대학 자료

     

    참고로, O(n^2.376) 시간 복잡도의 1987년 Coppersmith와 Winograd 기법은,

    3번 덧셈/뺄셈 연산이 적은 아래 방식이다.

    @샌 안토니오 대학 자료

     

    본론으로 돌아가서, 슈트라센 기법을 사용하기 위해선, 덧셈과 뺄셈 함수를 만들어야 한다.

     

    덧셈, 뺄셈

    # n x n, 0 값을 가진 행렬
    def initMatrix(n):
        return [[0 for _ in range(n)] for _ in range(n)]
    
    
    def add(M1, M2, n):
        temp = initMatrix(n)
        for i in range(n):
            for j in range(n):
                temp[i][j] = M1[i][j] + M2[i][j]
        return temp
    
    
    def subtract(M1, M2, n):
        temp = initMatrix(n)
        for i in range(n):
            for j in range(n):
                temp[i][j] = M1[i][j] - M2[i][j]
        return temp

     

    기본 케이스

    1 x 1이면 그냥 곱해서 return

    if n == 1:
        C = initMatrix(1)
        C[0][0] = A[0][0] * B[0][0]
        return C

     

    변수 초기화

    C = initMatrix(n)
    k = n // 2  # 2 x 2 행렬이면 각각 1 x 1 서브 행렬임
    
    A11 = initMatrix(k)
    A12 = initMatrix(k)
    A21 = initMatrix(k)
    A22 = initMatrix(k)
    B11 = initMatrix(k)
    B12 = initMatrix(k)
    B21 = initMatrix(k)
    B22 = initMatrix(k)
    
    
    for i in range(k):
        for j in range(k):
            A11[i][j] = A[i][j]
            A12[i][j] = A[i][k + j]
            A21[i][j] = A[k + i][j]
            A22[i][j] = A[k + i][k + j]
    
            B11[i][j] = B[i][j]
            B12[i][j] = B[i][k + j]
            B21[i][j] = B[k + i][j]
            B22[i][j] = B[k + i][k + j]

     

    P값 재귀

    P1 = strassen(A11, subtract(B12, B22, k), k)
    P2 = strassen(add(A11, A12, k), B22, k)
    P3 = strassen(add(A21, A22, k), B11, k)
    P4 = strassen(A22, subtract(B21, B11, k), k)
    P5 = strassen(add(A11, A22, k), add(B11, B22, k), k)
    P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k)
    P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k)

     

    A * B = C

    C11 = subtract(add(add(P5, P4, k), P6, k), P2, k)
    C12 = add(P1, P2, k)
    C21 = add(P3, P4, k)
    C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k)
    
    for i in range(k):
        for j in range(k):
            C[i][j] = C11[i][j]
            C[i][j + k] = C12[i][j]
            C[k + i][j] = C21[i][j]
            C[k + i][k + j] = C22[i][j]

     

    결과

    A = [[1, 3], [7, 5]]
    B = [[6, 8], [4, 2]]
    
    print(strassen(A, B, 2))
    
    """
    [[18, 14],
     [62, 66]]
    """
    
    # 128 x 128 행렬끼리 곱하면 약 8초가 걸린다.
    
    더보기
    def initMatrix(n, m=None):
        m = m or n
        return [[0 for _ in range(m)] for _ in range(n)]
    
    
    def add(M1, M2, n):
        temp = initMatrix(n)
        for i in range(n):
            for j in range(n):
                temp[i][j] = M1[i][j] + M2[i][j]
        return temp
    
    
    def subtract(M1, M2, n):
        temp = initMatrix(n)
        for i in range(n):
            for j in range(n):
                temp[i][j] = M1[i][j] - M2[i][j]
        return temp
    
    
    def strassen(A, B, n):
        if n == 1:
            C = initMatrix(1)
            C[0][0] = A[0][0] * B[0][0]
            return C
    
        C = initMatrix(n)
        k = n // 2
    
        A11 = initMatrix(k)
        A12 = initMatrix(k)
        A21 = initMatrix(k)
        A22 = initMatrix(k)
        B11 = initMatrix(k)
        B12 = initMatrix(k)
        B21 = initMatrix(k)
        B22 = initMatrix(k)
    
        for i in range(k):
            for j in range(k):
                A11[i][j] = A[i][j]
                A12[i][j] = A[i][k + j]
                A21[i][j] = A[k + i][j]
                A22[i][j] = A[k + i][k + j]
    
                B11[i][j] = B[i][j]
                B12[i][j] = B[i][k + j]
                B21[i][j] = B[k + i][j]
                B22[i][j] = B[k + i][k + j]
    
        P1 = strassen(A11, subtract(B12, B22, k), k)
        P2 = strassen(add(A11, A12, k), B22, k)
        P3 = strassen(add(A21, A22, k), B11, k)
        P4 = strassen(A22, subtract(B21, B11, k), k)
        P5 = strassen(add(A11, A22, k), add(B11, B22, k), k)
        P6 = strassen(subtract(A12, A22, k), add(B21, B22, k), k)
        P7 = strassen(subtract(A11, A21, k), add(B11, B12, k), k)
    
        C11 = subtract(add(add(P5, P4, k), P6, k), P2, k)
        C12 = add(P1, P2, k)
        C21 = add(P3, P4, k)
        C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k)
    
        for i in range(k):
            for j in range(k):
                C[i][j] = C11[i][j]
                C[i][j + k] = C12[i][j]
                C[k + i][j] = C21[i][j]
                C[k + i][k + j] = C22[i][j]
    
        return C

     

    3. Shrohan Mohapatra O(n^2) 기법

    양수에서만 작동함

     

    Vassilevska Williams, Le Gall 등 기법을 소개하지 않고, 이 기법을 소개하냐면,

    논문을 보면 알겠지만, 엄청나게 많은 수학적 지식과, 텐서 파워라는 무거운 데이터 구조를 이용해야 한다.

    따라서 위 알고리즘들이, 슈트라센말고는 프로그래밍 언어로 구현된 것이 없는 그 이유다.

     

    덧셈, 뺄셈, 함수를 따로 쓰지도 않고, log10을 이용해 신기하게 구한다.

    설명이 따로 없어서, 함수만 정리해서 가져왔다.

    def quadratic(A, B):
        from math import log10
    
        #### 길이가 같은지 확인
        flag = True and (len(A) == len(B))
        if not flag:
            return None  # A와 B의 크기가 다르면 None을 반환
    
        N = len(A)
        for i in range(N):
            flag = flag and (len(A[i]) == len(B[i]))
    
        if not flag:
            return None  # A와 B의 각 행의 크기가 다르면 None을 반환
        
        #### 최대값 찾기
        maxi = 0
        for i in range(N):
            for j in range(N):
                if maxi < A[i][j]:
                    maxi = A[i][j]  # A의 최대값 갱신
                if maxi < B[i][j]:
                    maxi = B[i][j]  # B의 최대값 갱신
    
        # 최대값을 기반으로 필요한 자릿수 계산
        M = int(log10(maxi)) + 1
        P = int(log10((10 ** (2 * M) - 1) * N)) + 1
    
        # 결과 행렬 초기화
        C, D, E = (
            [0 for _ in range(N)],
            [0 for _ in range(N)],
            [[0 for _ in range(N)] for _ in range(N)],
        )
    
        # 행렬 A와 B를 큰 수로 변환
        for i in range(N):
            for j in range(N):
                C[i] = C[i] * (10 ** P) + A[i][j]  # A의 각 행을 하나의 큰 수로 변환
        for j in range(N):
            for i in range(N):
                D[j] = D[j] * (10 ** P) + B[N - 1 - i][j]  # B의 각 열을 하나의 큰 수로 변환, 역순으로
    
        # 큰 수를 이용한 행렬 곱셈 수행
        for i in range(N):
            for j in range(N):
                E[i][j] = C[i] * D[j] // (10 ** (P * (N - 1))) % (10 ** P)  # 곱셈 결과에서 각 요소 추출
    
        return E  # 결과 행렬 반환

     

    4. 벤치마크

    Strassen VS Shrohan Mohapatra(SM)

     

    32 x 32 행렬

    슈트라센: 0.16169s

    SM: 0.00695s

    N = 2 ** 5
    A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)]
    B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)]
    
    # 런타임 비교
    from timeit import default_timer
    
    start = default_timer()
    result1 = strassen(A, B, N)
    print(f"{default_timer() - start:.5f}s")
    
    start = default_timer()
    result2 = mat(A, B)
    print(f"{default_timer() - start:.5f}s")
    print(result2 == result1)
    

     

    256 x 256 행렬

    슈트라센: 55.70412s

    SM: 12.69966s

    N = 2 ** 8
    A = [[randint(1, 1000) for _ in range(N)] for _ in range(N)]
    B = [[randint(1, 1000) for _ in range(N)] for _ in range(N)]
    
    # 런타임 비교
    from timeit import default_timer
    
    start = default_timer()
    result1 = strassen(A, B, N)
    print(f"{default_timer() - start:.5f}s")
    
    start = default_timer()
    result2 = mat(A, B)
    print(f"{default_timer() - start:.5f}s")
    print(result2 == result1)
    

     

    참고

    슈트라센 위키피디아 @wikipedia

    슈트라센 기법 쉽게 외우는 방법 @GeeksForGeeks

    슈트라센 곱셈 설명 @Medium

    Shrohan Mohapatra n^2 곱셈 기법 @Github

    텐서 파워와 빠른 행렬 곱셈 by Le Gall (도쿄대학) @Youtube

    텍사스 샌 안토니오 대학 자료 @링크

    위 설명 논문들

    728x90

    '컴퓨터 > 파이썬' 카테고리의 다른 글

    Python: LRU 캐시 만들어보기  (0) 2020.11.16
    파이썬: Monad (모나드)  (1) 2020.11.11
    파이썬 Sørensen–Dice coefficient  (0) 2020.11.05

    댓글