  • Deep Learning Optimizer: AdaBelief Optimizer
    Computer/Python 2020. 10. 27. 12:48

    Adabelief v0.1.0: Adapting Stepsizes by the Belief in Observed Gradients

    Project page: juntang-zhuang.github.io

     

    1. Introduction

    Optimization on the Beale function (click to run in the original post): convergence is remarkably fast. @Zhuang et al.

     

    Official description

    As fast as Adam, generalizes as well as SGD, and stable enough to train GANs.

     

    AdaBelief is a deep learning optimization algorithm that modifies Adam.

    (In fact, it changes only a single line of Adam.)

    • Faster convergence in training
    • More stable training
    • Better generalization
    • Higher model accuracy

     

    2. The Problem with Adam

    Adam optimizer algorithm @Zhuang et al.

     

    Adam solved the problem of SGD (stochastic gradient descent) converging too slowly early in training.

     

    However, Adam has a problem: when the gradient is large but its variance is small,

    Adam still takes a small step size (learning rate), even though a large step would be ideal in that situation. (See the figure below.)

     

    AdaBelief Optimizer @Zhuang et al.
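
    To make this concrete, here is a minimal toy sketch (my own, not from the paper or its code) comparing the effective step sizes of Adam and AdaBelief when the gradient is large but constant, i.e., has (near) zero variance:

    import math

    lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
    g = 10.0            # large, constant gradient -> zero gradient variance
    m = v = s = 0.0     # gradient EMA, Adam's EMA of g^2, AdaBelief's EMA of (g - m)^2

    for t in range(1, 1001):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g          # Adam second moment
        s = beta2 * s + (1 - beta2) * (g - m) ** 2   # AdaBelief second moment
        bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t    # bias corrections
        adam_step = lr * (m / bc1) / (math.sqrt(v / bc2) + eps)
        belief_step = lr * (m / bc1) / (math.sqrt(s / bc2) + eps)

    print(adam_step)    # stays around lr (~1e-3), no matter how large the gradient is
    print(belief_step)  # much larger: with no gradient variance, AdaBelief takes an SGD-like big step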

     

    3. How AdaBelief Fixes the Problem

    In Adam, the second-moment term (the squared-gradient momentum) is computed as shown below.

    (In other words, an EMA, or exponential moving average.)

    Adam EMA @Zhuang et al.
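
    (The referenced figure is not reproduced here; for reference, Adam's second-moment EMA in standard notation is:)

    $v_t = \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2$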

     

    To fix the problem, rather than accumulating the squared gradient,

    AdaBelief incrementally tracks the variance of the gradient instead.

    AdaBelief EMA @Zhuang et al.
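
    (Likewise, AdaBelief's second-moment EMA in standard notation, where $m_t$ is the EMA of the gradient and $\epsilon$ is a small stability constant:)

    $s_t = \beta_2 \, s_{t-1} + (1 - \beta_2) \, (g_t - m_t)^2 + \epsilon$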

     

    This is where the word "belief" comes from:

    the variance is computed against the current momentum estimate,

    which is essentially the squared distance from the predicted (expected, "believed") gradient.

     

    AdaBelief optimizer algorithm @Zhuang et al.
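
    (The full algorithm figure is not reproduced here; ignoring weight decay and the optional rectification, the resulting update differs from Adam's only in the denominator:)

    $\theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{s}_t} + \epsilon)$, with bias-corrected $\hat{m}_t = m_t / (1 - \beta_1^t)$ and $\hat{s}_t = s_t / (1 - \beta_2^t)$.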

     

    4. Benchmarks

    You might wonder whether changing roughly one line of the formula can really change the results much,

    but as the results below show, you can expect a substantial performance gain.

     

    Image classification (blue line)

    @Zhuang et al.

     

    GAN (generative adversarial network)

    Lower is better.

    @Zhuang et al.

     

    5. Using the Code

    The source code is available on the paper authors' GitHub. @link

    1. PyTorch

    AdaBelief Optimizer

    import math
    import torch
    from torch.optim.optimizer import Optimizer
    
    version_higher = torch.__version__ >= "1.5.0"
    
    
    class AdaBelief(Optimizer):
        r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
        Arguments:
            params (iterable): iterable of parameters to optimize or dicts defining
                parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve
                numerical stability (default: 1e-16)
            weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            amsgrad (boolean, optional): whether to use the AMSGrad variant of this
                algorithm from the paper `On the Convergence of Adam and Beyond`_
                (default: False)
            weight_decouple (boolean, optional): ( default: True) If set as True, then
                the optimizer uses decoupled weight decay as in AdamW
            fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
                is set as True.
                When fixed_decay == True, the weight decay is performed as
                $W_{new} = W_{old} - W_{old} \times decay$.
                When fixed_decay == False, the weight decay is performed as
                $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
                weight decay ratio decreases with learning rate (lr).
            rectify (boolean, optional): (default: True) If set as True, then perform the rectified
                update similar to RAdam
            degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
                when variance of gradient is high
        reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
        """
    
        def __init__(
            self,
            params,
            lr=1e-3,
            betas=(0.9, 0.999),
            eps=1e-16,
            weight_decay=0,
            amsgrad=False,
            weight_decouple=True,
            fixed_decay=False,
            rectify=True,
            degenerated_to_sgd=True,
        ):
    
            if not 0.0 <= lr:
                raise ValueError("Invalid learning rate: {}".format(lr))
            if not 0.0 <= eps:
                raise ValueError("Invalid epsilon value: {}".format(eps))
            if not 0.0 <= betas[0] < 1.0:
                raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
            if not 0.0 <= betas[1] < 1.0:
                raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
    
            self.degenerated_to_sgd = degenerated_to_sgd
            if (
                isinstance(params, (list, tuple))
                and len(params) > 0
                and isinstance(params[0], dict)
            ):
                for param in params:
                    if "betas" in param and (
                        param["betas"][0] != betas[0] or param["betas"][1] != betas[1]
                    ):
                        param["buffer"] = [[None, None, None] for _ in range(10)]
    
            defaults = dict(
                lr=lr,
                betas=betas,
                eps=eps,
                weight_decay=weight_decay,
                amsgrad=amsgrad,
                buffer=[[None, None, None] for _ in range(10)],
            )
            super(AdaBelief, self).__init__(params, defaults)
    
            self.weight_decouple = weight_decouple
            self.rectify = rectify
            self.fixed_decay = fixed_decay
            if self.weight_decouple:
                print("Weight decoupling enabled in AdaBelief")
                if self.fixed_decay:
                    print("Weight decay fixed")
            if self.rectify:
                print("Rectification enabled in AdaBelief")
            if amsgrad:
                print("AMSGrad enabled in AdaBelief")
    
        def __setstate__(self, state):
            super(AdaBelief, self).__setstate__(state)
            for group in self.param_groups:
                group.setdefault("amsgrad", False)
    
        def reset(self):
            for group in self.param_groups:
                for p in group["params"]:
                    state = self.state[p]
                    amsgrad = group["amsgrad"]
    
                    # State initialization
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = (
                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
    
                    # Exponential moving average of squared gradient values
                    state["exp_avg_var"] = (
                        torch.zeros_like(p.data, memory_format=torch.preserve_format)
                        if version_higher
                        else torch.zeros_like(p.data)
                    )
    
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_var"] = (
                            torch.zeros_like(p.data, memory_format=torch.preserve_format)
                            if version_higher
                            else torch.zeros_like(p.data)
                        )
    
        def step(self, closure=None):
            """Performs a single optimization step.
            Arguments:
                closure (callable, optional): A closure that reevaluates the model
                    and returns the loss.
            """
            loss = None
            if closure is not None:
                loss = closure()
    
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError(
                            "AdaBelief does not support sparse gradients, please consider SparseAdam instead"
                        )
                    amsgrad = group["amsgrad"]
    
                    state = self.state[p]
    
                    beta1, beta2 = group["betas"]
    
                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = (
                            torch.zeros_like(p.data, memory_format=torch.preserve_format)
                            if version_higher
                            else torch.zeros_like(p.data)
                        )
                        # Exponential moving average of squared gradient values
                        state["exp_avg_var"] = (
                            torch.zeros_like(p.data, memory_format=torch.preserve_format)
                            if version_higher
                            else torch.zeros_like(p.data)
                        )
                        if amsgrad:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state["max_exp_avg_var"] = (
                                torch.zeros_like(
                                    p.data, memory_format=torch.preserve_format
                                )
                                if version_higher
                                else torch.zeros_like(p.data)
                            )
    
                    # get current state variable
                    exp_avg, exp_avg_var = state["exp_avg"], state["exp_avg_var"]
    
                    state["step"] += 1
                    bias_correction1 = 1 - beta1 ** state["step"]
                    bias_correction2 = 1 - beta2 ** state["step"]
    
                    # Update first and second moment running average
                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                    grad_residual = grad - exp_avg
                    exp_avg_var.mul_(beta2).addcmul_(
                        grad_residual, grad_residual, value=1 - beta2
                    )
    
                    if amsgrad:
                        max_exp_avg_var = state["max_exp_avg_var"]
                        # Maintains the maximum of all 2nd moment running avg. till now
                        torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
    
                        # Use the max. for normalizing running avg. of gradient
                        denom = (
                            max_exp_avg_var.add_(group["eps"]).sqrt()
                            / math.sqrt(bias_correction2)
                        ).add_(group["eps"])
                    else:
                        denom = (
                            exp_avg_var.add_(group["eps"]).sqrt()
                            / math.sqrt(bias_correction2)
                        ).add_(group["eps"])
    
                    # perform weight decay, check if decoupled weight decay
                    if self.weight_decouple:
                        if not self.fixed_decay:
                            p.data.mul_(1.0 - group["lr"] * group["weight_decay"])
                        else:
                            p.data.mul_(1.0 - group["weight_decay"])
                    else:
                        if group["weight_decay"] != 0:
                            grad.add_(p.data, alpha=group["weight_decay"])
    
                    # update
                    if not self.rectify:
                        # Default update
                        step_size = group["lr"] / bias_correction1
                        p.data.addcdiv_(exp_avg, denom, value=-step_size)
    
                    else:  # Rectified update, forked from RAdam
                        buffered = group["buffer"][int(state["step"] % 10)]
                        if state["step"] == buffered[0]:
                            N_sma, step_size = buffered[1], buffered[2]
                        else:
                            buffered[0] = state["step"]
                            beta2_t = beta2 ** state["step"]
                            N_sma_max = 2 / (1 - beta2) - 1
                            N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
                            buffered[1] = N_sma
    
                            # more conservative since it's an approximated value
                            if N_sma >= 5:
                                step_size = math.sqrt(
                                    (1 - beta2_t)
                                    * (N_sma - 4)
                                    / (N_sma_max - 4)
                                    * (N_sma - 2)
                                    / N_sma
                                    * N_sma_max
                                    / (N_sma_max - 2)
                                ) / (1 - beta1 ** state["step"])
                            elif self.degenerated_to_sgd:
                                step_size = 1.0 / (1 - beta1 ** state["step"])
                            else:
                                step_size = -1
                            buffered[2] = step_size
    
                        if N_sma >= 5:
                            denom = exp_avg_var.sqrt().add_(group["eps"])
                            p.data.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
                        elif step_size > 0:
                            p.data.add_(exp_avg, alpha=-step_size * group["lr"])
    
            return loss
    

     

    CIFAR example

    # You can reuse the same hyperparameters as for Adam.
    
    # ...
    
    optimizer = AdaBelief(
        model_params, args.lr, betas=(0.9, 0.999), weight_decay=5e-4, eps=1e-8, rectify=False
    )
    
    # ...
    
    
    def train(net, epoch, device, data_loader, optimizer, criterion, args):
        print("\nEpoch: %d" % epoch)
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
    
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
        accuracy = 100.0 * correct / total
        print("train acc %.3f" % accuracy)
    
        return accuracy
    
    
    # ...
    
    # ...
    
    
    def adjust_learning_rate(optimizer, epoch, step_size=150, gamma=0.1, reset=False):
        for param_group in optimizer.param_groups:
            if epoch % step_size == 0 and epoch > 0:
                param_group["lr"] *= gamma
    
        if epoch % step_size == 0 and epoch > 0 and reset:
            optimizer.reset()
    
    
    # ...
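
    # Hypothetical wiring of the helpers above into the training loop (a sketch; names
    # such as train_loader and args.epochs are assumed, not taken from the original script):
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, step_size=150, gamma=0.1)
        train(net, epoch, device, train_loader, optimizer, criterion, args)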
    

    Training accuracy comes out about the same.

     

    2. TensorFlow

    AdaBelief Optimizer

    """AdaBeliefOptimizer optimizer."""
    import tensorflow as tf
    from tensorflow_addons.utils.types import FloatTensorLike
    
    from typing import Union, Callable, Dict
    from typeguard import typechecked
    
    from tabulate import tabulate
    from colorama import Fore, Back, Style
    
    
    @tf.keras.utils.register_keras_serializable(package="Addons")
    class AdaBeliefOptimizer(tf.keras.optimizers.Optimizer):
        """
        It implements the AdaBeliefOptimizer proposed by
        Juntang Zhuang et al. in [AdaBeliefOptimizer Optimizer: Adapting stepsizes by the belief
        in observed gradients](https://arxiv.org/abs/2010.07468).
        Example of usage:
        ```python
        opt = tfa.optimizers.AdaBeliefOptimizer(lr=1e-3)
        ```
        Note: `amsgrad` is not described in the original paper. Use it with
              caution.
        AdaBeliefOptimizer is not a replacement for the heuristic warmup; the warmup settings should be
        kept if warmup has already been employed and tuned in the baseline method.
        You can enable warmup by setting `total_steps` and `warmup_proportion`:
        ```python
        opt = tfa.optimizers.AdaBeliefOptimizer(
            lr=1e-3,
            total_steps=10000,
            warmup_proportion=0.1,
            min_lr=1e-5,
        )
        ```
        In the above example, the learning rate will increase linearly
        from 0 to `lr` in 1000 steps, then decrease linearly from `lr` to `min_lr`
        in 9000 steps.
        Lookahead, proposed by Michael R. Zhang et al. in the paper
        [Lookahead Optimizer: k steps forward, 1 step back]
        (https://arxiv.org/abs/1907.08610v1), can be integrated with AdaBeliefOptimizer,
        which was announced by Less Wright, and the new combined optimizer can also
        be called "Ranger". The mechanism can be enabled by using the lookahead
        wrapper. For example:
        ```python
        adabelief = tfa.optimizers.AdaBeliefOptimizer()
        ranger = tfa.optimizers.Lookahead(adabelief, sync_period=6, slow_step_size=0.5)
        ```
        """
    
        @typechecked
        def __init__(
            self,
            learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
            beta_1: FloatTensorLike = 0.9,
            beta_2: FloatTensorLike = 0.999,
            epsilon: FloatTensorLike = 1e-14,
            weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
            rectify: bool = True,
            amsgrad: bool = False,
            sma_threshold: FloatTensorLike = 5.0,
            total_steps: int = 0,
            warmup_proportion: FloatTensorLike = 0.1,
            min_lr: FloatTensorLike = 0.0,
            name: str = "AdaBeliefOptimizer",
            **kwargs
        ):
            r"""Construct a new AdaBelief optimizer.
            Args:
                learning_rate: A `Tensor` or a floating point value, or a schedule
                    that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                    The learning rate.
                beta_1: A float value or a constant float tensor.
                    The exponential decay rate for the 1st moment estimates.
                beta_2: A float value or a constant float tensor.
                    The exponential decay rate for the 2nd moment estimates.
                epsilon: A small constant for numerical stability.
                weight_decay: A `Tensor` or a floating point value, or a schedule
                    that is a `tf.keras.optimizers.schedules.LearningRateSchedule`.
                    Weight decay for each parameter.
                rectify: boolean. Whether to enable rectification as in RectifiedAdam
                amsgrad: boolean. Whether to apply AMSGrad variant of this
                    algorithm from the paper "On the Convergence of Adam and
                    beyond".
            sma_threshold: A float value.
                The threshold for the simple moving average (SMA).
                total_steps: An integer. Total number of training steps.
                    Enable warmup by setting a positive value.
                warmup_proportion: A floating point value.
                    The proportion of increasing steps.
                min_lr: A floating point value. Minimum learning rate after warmup.
                name: Optional name for the operations created when applying
                    gradients. Defaults to "AdaBeliefOptimizer".
                **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                    `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients
                    by norm; `clipvalue` is clip gradients by value, `decay` is
                    included for backward compatibility to allow time inverse
                    decay of learning rate. `lr` is included for backward
                    compatibility, recommended to use `learning_rate` instead.
            """
            super().__init__(name, **kwargs)
    
            # ------------------------------------------------------------------------------
            # Print modifications to default arguments
            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-tf from version 0.0.1.')
            print(Fore.RED + 'Modifications to default arguments:')
            default_table = tabulate([
                    ['adabelief-tf=0.0.1','1e-8','Not supported','Not supported'],
                    ['Current version (0.1.0)','1e-14','supported','default: True']],
                    headers=['eps','weight_decouple','rectify'])
            print(Fore.RED + default_table)
    
            print(Fore.RED +'For a complete table of recommended hyperparameters, see')
            print(Fore.RED + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
    
            print(Style.RESET_ALL)
            # ------------------------------------------------------------------------------
    
            if isinstance(learning_rate, Dict):
                learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)
    
            if isinstance(weight_decay, Dict):
                weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)
    
            self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
            self._set_hyper("beta_1", beta_1)
            self._set_hyper("beta_2", beta_2)
            self._set_hyper("decay", self._initial_decay)
            self._set_hyper("weight_decay", weight_decay)
            self._set_hyper("sma_threshold", sma_threshold)
            self._set_hyper("total_steps", int(total_steps))
            self._set_hyper("warmup_proportion", warmup_proportion)
            self._set_hyper("min_lr", min_lr)
            self.epsilon = epsilon or tf.keras.backend.epsilon()
            self.amsgrad = amsgrad
            self.rectify = rectify
            self._has_weight_decay = weight_decay != 0.0
            self._initial_total_steps = total_steps
    
        def _create_slots(self, var_list):
            for var in var_list:
                self.add_slot(var, "m")
            for var in var_list:
                self.add_slot(var, "v")
            for var in var_list:
                self.add_slot(var, "grad_dif")
            if self.amsgrad:
                for var in var_list:
                    self.add_slot(var, "vhat")
    
        def set_weights(self, weights):
            params = self.weights
            num_vars = int((len(params) - 1) / 2)
            if len(weights) == 4 * num_vars + 1:
                weights = weights[: len(params)]
            super().set_weights(weights)
    
        def _decayed_wd(self, var_dtype):
            wd_t = self._get_hyper("weight_decay", var_dtype)
            if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
                wd_t = tf.cast(wd_t(self.iterations), var_dtype)
            return wd_t
    
        def _resource_apply_dense(self, grad, var):
            var_dtype = var.dtype.base_dtype
            lr_t = self._decayed_lr(var_dtype)
            wd_t = self._decayed_wd(var_dtype)
            m = self.get_slot(var, "m")
            v = self.get_slot(var, "v")
            beta_1_t = self._get_hyper("beta_1", var_dtype)
            beta_2_t = self._get_hyper("beta_2", var_dtype)
            epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
            local_step = tf.cast(self.iterations + 1, var_dtype)
            beta_1_power = tf.pow(beta_1_t, local_step)
            beta_2_power = tf.pow(beta_2_t, local_step)
    
            if self._initial_total_steps > 0:
                total_steps = self._get_hyper("total_steps", var_dtype)
                warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
                min_lr = self._get_hyper("min_lr", var_dtype)
                decay_steps = tf.maximum(total_steps - warmup_steps, 1)
                decay_rate = (min_lr - lr_t) / decay_steps
                lr_t = tf.where(
                    local_step <= warmup_steps,
                    lr_t * (local_step / warmup_steps),
                    lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
                )
    
            sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
            sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
    
            m_t = m.assign(
                beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking
            )
            m_corr_t = m_t / (1.0 - beta_1_power)
    
            grad_dif = self.get_slot(var,'grad_dif')
            grad_dif.assign( grad - m_t )
            v_t = v.assign(
                beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad_dif) + epsilon_t,
                use_locking=self._use_locking,
            )
            if self.amsgrad:
                vhat = self.get_slot(var, "vhat")
                vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
                v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
            else:
                vhat_t = None
                v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))
    
            r_t = tf.sqrt(
                (sma_t - 4.0)
                / (sma_inf - 4.0)
                * (sma_t - 2.0)
                / (sma_inf - 2.0)
                * sma_inf
                / sma_t
            )
    
            if self.rectify:
                sma_threshold = self._get_hyper("sma_threshold", var_dtype)
                var_t = tf.where(
                    sma_t >= sma_threshold, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t
                )
            else:
                var_t =  m_corr_t / (v_corr_t + epsilon_t)
    
            if self._has_weight_decay:
                var_t += wd_t * var
    
            var_update = var.assign_sub(lr_t * var_t, use_locking=self._use_locking)
    
            updates = [var_update, m_t, v_t]
            if self.amsgrad:
                updates.append(vhat_t)
            return tf.group(*updates)
    
        def _resource_apply_sparse(self, grad, var, indices):
            var_dtype = var.dtype.base_dtype
            lr_t = self._decayed_lr(var_dtype)
            wd_t = self._decayed_wd(var_dtype)
            beta_1_t = self._get_hyper("beta_1", var_dtype)
            beta_2_t = self._get_hyper("beta_2", var_dtype)
            epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
            local_step = tf.cast(self.iterations + 1, var_dtype)
            beta_1_power = tf.pow(beta_1_t, local_step)
            beta_2_power = tf.pow(beta_2_t, local_step)
    
            if self._initial_total_steps > 0:
                total_steps = self._get_hyper("total_steps", var_dtype)
                warmup_steps = total_steps * self._get_hyper("warmup_proportion", var_dtype)
                min_lr = self._get_hyper("min_lr", var_dtype)
                decay_steps = tf.maximum(total_steps - warmup_steps, 1)
                decay_rate = (min_lr - lr_t) / decay_steps
                lr_t = tf.where(
                    local_step <= warmup_steps,
                    lr_t * (local_step / warmup_steps),
                    lr_t + decay_rate * tf.minimum(local_step - warmup_steps, decay_steps),
                )
    
            sma_inf = 2.0 / (1.0 - beta_2_t) - 1.0
            sma_t = sma_inf - 2.0 * local_step * beta_2_power / (1.0 - beta_2_power)
    
            m = self.get_slot(var, "m")
            m_scaled_g_values = grad * (1 - beta_1_t)
            m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
            with tf.control_dependencies([m_t]):
                m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            m_corr_t = m_t / (1.0 - beta_1_power)
    
            grad_dif = self.get_slot(var,'grad_dif')
            grad_dif.assign(m_t)
            grad_dif = self._resource_scatter_add(grad_dif, indices, -1.0 * grad)
    
            v = self.get_slot(var, "v")
            v_scaled_g_values = grad_dif * grad_dif * (1 - beta_2_t)
            v_t = v.assign(v * beta_2_t + epsilon_t, use_locking=self._use_locking)
            v_t = v.assign(v_t + v_scaled_g_values)
    
            if self.amsgrad:
                vhat = self.get_slot(var, "vhat")
                vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
                v_corr_t = tf.sqrt(vhat_t / (1.0 - beta_2_power))
            else:
                vhat_t = None
                v_corr_t = tf.sqrt(v_t / (1.0 - beta_2_power))
    
            r_t = tf.sqrt(
                (sma_t - 4.0)
                / (sma_inf - 4.0)
                * (sma_t - 2.0)
                / (sma_inf - 2.0)
                * sma_inf
                / sma_t
            )
    
            if self.rectify:
                sma_threshold = self._get_hyper("sma_threshold", var_dtype)
                var_t = tf.where(
                sma_t >= sma_threshold, r_t * m_corr_t / (v_corr_t + epsilon_t), m_corr_t
                )
            else:
                var_t =  m_corr_t / (v_corr_t + epsilon_t)
    
            if self._has_weight_decay:
                var_t += wd_t * var
    
            with tf.control_dependencies([var_t]):
                var_update = self._resource_scatter_add(
                    var, indices, tf.gather(-lr_t * var_t, indices)
                )
    
            updates = [var_update, m_t, v_t]
            if self.amsgrad:
                updates.append(vhat_t)
            return tf.group(*updates)
    
        def get_config(self):
            config = super().get_config()
            config.update(
                {
                    "learning_rate": self._serialize_hyperparameter("learning_rate"),
                    "beta_1": self._serialize_hyperparameter("beta_1"),
                    "beta_2": self._serialize_hyperparameter("beta_2"),
                    "decay": self._serialize_hyperparameter("decay"),
                    "weight_decay": self._serialize_hyperparameter("weight_decay"),
                    "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                    "epsilon": self.epsilon,
                    "amsgrad": self.amsgrad,
                    "rectify": self.rectify,
                    "total_steps": self._serialize_hyperparameter("total_steps"),
                    "warmup_proportion": self._serialize_hyperparameter(
                        "warmup_proportion"
                    ),
                    "min_lr": self._serialize_hyperparameter("min_lr"),
                }
            )
            return config

     

    CIFAR example

    optimizer = AdaBeliefOptimizer(learning_rate=0.001, epsilon=1e-14)  # Default epsilon for Adam is 1e-7 in TensorFlow, 1e-8 in PyTorch
    
    # ...
    
    # Optimization process.
    def run_optimization(x, y):
        # Wrap computation inside a GradientTape for automatic differentiation.
        with tf.GradientTape() as g:
            # Forward pass.
            pred = conv_net(x, is_training=True)
            # Compute loss.
            loss = cross_entropy_loss(pred, y)
    
        # Variables to update, i.e. trainable variables.
        trainable_variables = conv_net.trainable_variables
    
        # Compute gradients.
        gradients = g.gradient(loss, trainable_variables)
    
        # Update W and b following gradients.
        optimizer.apply_gradients(zip(gradients, trainable_variables))
    
    # ...
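
    Since AdaBeliefOptimizer subclasses tf.keras.optimizers.Optimizer, it can also be used with the standard Keras training API instead of a manual GradientTape loop. A minimal sketch (the model and dataset names here are placeholders, not from the original example):

    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer=AdaBeliefOptimizer(learning_rate=1e-3, epsilon=1e-14, rectify=True),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    model.fit(x_train, y_train, epochs=10, batch_size=128)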

     

    6. References

    Paper (@pdf link)

    Paper authors' github.io (@link)

    Deep learning optimizer visualization (@GitHub link)

    Medium post by Frank Odom (@Medium link)

     
