ABOUT ME

-

Total
-
  • 멀티스레드 Phaser - flexible 동기화 장벽
    컴퓨터/JAVA 2024. 6. 12. 23:57
    728x90
    반응형

    소개

    java7 부터 도입된 Phaser.

     

    멀티스레드 프로그래밍을 하다 보면 여러 스레드가 일정 시점에서 동기화되어야 하는 상황이 있었다.

    ex) 여러 스레드가 동시에 시작해야 하거나, 특정 작업이 완료될 때까지 기다려야 하는 경우 (Go의 sync.WaitGroup 느낌)

    java에서는 CyclicBarrier나 CountDownLatch를 사용할 수 있지만, 등록된 party 수를 동적으로 변경할 수 없다.

     

    Phaser

    그럼 Phaser는 뭔가?

    java.util.concurrent에 들어있고, CyclicBarrier와 CountDownLatch의 슈퍼 set 느낌이다.

    CyclicBarrier: 여러 스레드가 미리 정의된 지점에서 동기화되고 독립적 작업 반복
            CountDownLatch: 하나 이상의 스레드가 다른 스레드가 수행하는 작업 set가 완료될 때까지 기다림
    파티 = 특정 작업을 수행하는 스레드나 작업 단위

     

    • 동적 파티 등록: 파티 (참여자) 수를 동적으로 변경할 수 있음
    • 반복 가능 동기화: 반복적으로 사용될 수 있어서, 각 반복마다 새로운 단계를 생성
    • 종료 조건: 특정 조건을 만족할 때 Phaser가 종료되도록 할 수 있음

     

    예제

    3개의 작업이 있다고 생각해 보자 (A, B, C)

    각 작업은: 작업 준비 -> 실행 -> 완료 후 기다림 -> 모든 작업 완료 시 다음 단계 진행을 거친다.

    (각 단계가 Phase임)

    # 시작 (3개 작업 등록)
    Phaser(Phase 0, Registered Parties: 3)
    
    # 준비 (준비 완료 후 도착 알림)
    [Phase 0]
    Task A: Arrived
    Task B: Arrived
    Task C: Arrived
    
    # 모든 작업 도착해서 페이즈 1 전환
    Phaser(Phase 1, Registered Parties: 3) 
    
    # 실행 (실행 완료 후 도착 알림)
    [Phase 1]
    Task A: Arrived
    Task B: Arrived
    Task C: Arrived
    
    # 모든 작업 도착해서 페이즈 2 전환
    Phaser(Phase 2, Registered Parties: 3)
    
    # 완료/기다림 (완료 후 도착 알림)
    [Phase 2]
    Task A: Arrived
    Task B: Arrived
    Task C: Arrived
    
    # 모든 작업 도착 페이즈 3 전환
    Phaser(Phase 3, Registered Parties: 3)

     

    위의 시나리오는 아래처럼 Java로 작성할 수 있다.

    import java.util.concurrent.Phaser;
    
    public class PhaserExample {
        public static void main(String[] args) {
            Phaser phaser = new Phaser(3); // 3개의 작업 등록
    
            // 작업 A
            new Thread(() -> {
                System.out.println("Task A: 준비 완료");
                phaser.arriveAndAwaitAdvance(); // Phase 0 완료
    
                System.out.println("Task A: 실행 중");
                phaser.arriveAndAwaitAdvance(); // Phase 1 완료
    
                System.out.println("Task A: 실행 완료, 대기 중");
                phaser.arriveAndAwaitAdvance(); // Phase 2 완료
    
                System.out.println("Task A: 종료");
            }).start();
    
            // 작업 B
            new Thread(() -> {
                System.out.println("Task B: 준비 완료");
                phaser.arriveAndAwaitAdvance(); // Phase 0 완료
    
                System.out.println("Task B: 실행 중");
                phaser.arriveAndAwaitAdvance(); // Phase 1 완료
    
                System.out.println("Task B: 실행 완료, 대기 중");
                phaser.arriveAndAwaitAdvance(); // Phase 2 완료
    
                System.out.println("Task B: 종료");
            }).start();
    
            // 작업 C
            new Thread(() -> {
                System.out.println("Task C: 준비 완료");
                phaser.arriveAndAwaitAdvance(); // Phase 0 완료
    
                System.out.println("Task C: 실행 중");
                phaser.arriveAndAwaitAdvance(); // Phase 1 완료
    
                System.out.println("Task C: 실행 완료, 대기 중");
                phaser.arriveAndAwaitAdvance(); // Phase 2 완료
    
                System.out.println("Task C: 종료");
            }).start();
        }
    }

     

    onAdvance에서 특정 조건을 만족하면 Phaser가 종료되게 하는 예제

    import java.util.concurrent.Phaser;
    
    public class AdvancedPhaserExample {
        public static void main(String[] args) {
            Phaser phaser = new Phaser() {
                @Override
                protected boolean onAdvance(int phase, int registeredParties) {
                    System.out.println("Phase " + phase + " 완료");
                    // 3단계를 완료하거나 등록된 파티가 0개일 때 종료
                    return phase >= 3 || registeredParties == 0;
                }
            };
    
            phaser.bulkRegister(3); // 3개의 파티 등록
    
            for (int i = 0; i < 3; i++) {
                new Thread(new Worker(phaser), "Thread " + i).start();
            }
        }
    
        static class Worker implements Runnable {
            private Phaser phaser;
    
            Worker(Phaser phaser) {
                this.phaser = phaser;
            }
    
            @Override
            public void run() {
                for (int i = 0; i < 3; i++) {
                    System.out.println(Thread.currentThread().getName() + " 단계 " + i + " 완료");
                    phaser.arriveAndAwaitAdvance(); // 각 단계별로 동기화
                }
            }
        }
    }

     

    Building H2O 문제

    Leetcode에 있는 문제인데 Semaphore만을 사용해서 풀 수도 있고, CyclicBarrier도 같이 써서 풀 수도 있다.

    import java.util.concurrent.Phaser;
    import java.util.concurrent.Semaphore;
    
    class H2O {
        private final Phaser phaser = new Phaser(3);  // 3개의 파티로 설정
        private final Semaphore semH = new Semaphore(2);  // 최대 2개의 수소 스레드 허용
        private final Semaphore semO = new Semaphore(1);  // 최대 1개의 산소 스레드 허용
    
        public H2O() {
        }
    
        public void hydrogen(Runnable releaseHydrogen) throws InterruptedException {
            semH.acquire();  // 수소 스레드가 접근할 수 있도록 허용
            phaser.arriveAndAwaitAdvance();  // 모든 파티가 도착할 때까지 대기
            try {
                // releaseHydrogen() 메서드는 "H"를 출력합니다. 이 줄을 변경하거나 제거하지 마세요.
                releaseHydrogen.run();
            } finally {
                semH.release();  // 수소 스레드가 끝나면 세마포어 해제
            }
        }
    
        public void oxygen(Runnable releaseOxygen) throws InterruptedException {
            semO.acquire();  // 산소 스레드가 접근할 수 있도록 허용
            phaser.arriveAndAwaitAdvance();  // 모든 파티가 도착할 때까지 대기
            try {
                // releaseOxygen() 메서드는 "O"를 출력합니다. 이 줄을 변경하거나 제거하지 마세요.
                releaseOxygen.run();
            } finally {
                semO.release();  // 산소 스레드가 끝나면 세마포어 해제
            }
        }
    }
    
    // 테스트 클래스
    public class H2OTest {
        public static void main(String[] args) {
            H2O h2o = new H2O();
            String water = "OOHHHH";
            
            Thread[] threads = new Thread[water.length()];
            for (int i = 0; i < water.length(); i++) {
                if (water.charAt(i) == 'H') {
                    threads[i] = new Thread(() -> {
                        try {
                            h2o.hydrogen(() -> System.out.print("H"));
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        }
                    });
                } else {
                    threads[i] = new Thread(() -> {
                        try {
                            h2o.oxygen(() -> System.out.print("O"));
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        }
                    });
                }
            }
    
            for (Thread thread : threads) {
                thread.start();
            }
    
            for (Thread thread : threads) {
                try {
                    thread.join();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
    
            System.out.println();  // 줄 바꿈
        }
    }

     

    내부 코드

    내부 코드를 보면 Phaser의 상태는 하나의 64bit long 변수로 관리되고, 4가지의 비트 필드로 구성된다 (CAS 연산)

    1. unarrived: 아직 동기화 지점에 도착하지 않은 파티의 수 (0-15 비트)
    2. parties: 동기화 지점에 도달해야 하는 총 파티의 수 (16-31 비트)
    3. phase: 현재 Phase 번호 (32-62 비트)
    4. terminated: 종료 여부 (63비트)
        /**
         * Main implementation for methods arrive and arriveAndDeregister.
         * Manually tuned to speed up and minimize race windows for the
         * common case of just decrementing unarrived field.
         *
         * @param adjust value to subtract from state;
         *               ONE_ARRIVAL for arrive,
         *               ONE_DEREGISTER for arriveAndDeregister
         */
        // arrive和arriveAndDeregister的实现方法。手动减少未到达数
    private int doArrive(int adjust) {
    
    // 글쓴이: 현재 상태 가져오기 및 해석
        final Phaser root = this.root;
        for (;;) {
            long s = (root == this) ? state : reconcileState();
            int phase = (int)(s >>> PHASE_SHIFT);
            if (phase < 0)
                return phase;
            int counts = (int)s;
            //获取未到达数
            int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
            if (unarrived <= 0)
                throw new IllegalStateException(badArrive(s));
                
    // 글쓴이: 상태 업데이트 (CAS 연산, 파티가 마지막 도착 파티면 다음 단계 전환)
            if (UNSAFE.compareAndSwapLong(this, stateOffset, s, s-=adjust)) {//更新state
                if (unarrived == 1) {//当前为最后一个未到达的任务
                    long n = s & PARTIES_MASK;  // base of next state
                    int nextUnarrived = (int)n >>> PARTIES_SHIFT;
                    if (root == this) {
                        if (onAdvance(phase, nextUnarrived))//检查是否需要终止phaser
                            n |= TERMINATION_BIT;
                        else if (nextUnarrived == 0)
                            n |= EMPTY;
                        else
                            n |= nextUnarrived;
                        int nextPhase = (phase + 1) & MAX_PHASE;
                        n |= (long)nextPhase << PHASE_SHIFT;
                        UNSAFE.compareAndSwapLong(this, stateOffset, s, n);
                        releaseWaiters(phase);//释放等待phase的线程
                    }
                    //分层结构,使用父节点管理arrive
                    else if (nextUnarrived == 0) { //propagate deregistration
                        phase = parent.doArrive(ONE_DEREGISTER);
                        UNSAFE.compareAndSwapLong(this, stateOffset,
                                                  s, s | EMPTY);
                    }
                    else
                        phase = parent.doArrive(ONE_ARRIVAL);
                }
                return phase;
            }
        }
    }

     

    @Phaser 텍스트 시각화 블로그 글

    @Leetcode 풀어보면 좋은 문제

     

    근데 Foobar 문제를 Phaser 만으로 풀면 에러가 난다.

    class FooBar {
        private int n;
        private final Phaser phaser;
    
        public FooBar(int n) {
            this.n = n;
            this.phaser = new Phaser(2); // 두 스레드 (foo, bar)를 위한 Phaser
        }
    
        public void foo(Runnable printFoo) {
            for (int i = 0; i < n; i++) {
                phaser.arriveAndAwaitAdvance(); // 다른 스레드와 동기화
                // printFoo.run() outputs "foo". Do not change or remove this line.
                printFoo.run();
                phaser.arriveAndAwaitAdvance(); // 다른 스레드에게 신호
            }
        }
    
        public void bar(Runnable printBar) {
            for (int i = 0; i < n; i++) {
                phaser.arriveAndAwaitAdvance(); // 다른 스레드와 동기화
                // printBar.run() outputs "bar". Do not change or remove this line.
                printBar.run();
                phaser.arriveAndAwaitAdvance(); // 다른 스레드에게 신호
            }
        }
    }
    728x90

    댓글