  • Deep Learning: Mish Activation Function, Loading a Model
    Computer/Python · 2020. 10. 29. 23:53

    Mish: A Self Regularized Non-Monotonic Neural Activation Function (github.com)

    BMVC 2020 (@official paper PDF link)

     

    1. Introduction

    (Figure: the Mish activation function)

    Mish is an activation function that, overall, gives better results than Swish and ReLU.

    (The paper reports final-accuracy gains of +0.494% over Swish and +1.671% over ReLU.)

     

    Mish is defined by the formula below (the forward pass), which produces the graph shown below.

    (For reference: ReLU = $\max(0, x)$ | Swish = $x \cdot \mathrm{sigmoid}(x)$)

    # PyTorch
    y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

    $\mathrm{Mish}(x) = x \cdot \tanh(\mathrm{softplus}(x)) = x \cdot \tanh\bigl(\ln(1 + e^x)\bigr)$
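
    To compare the three functions visually, here is a minimal plotting sketch (my own addition, not from the original post; it assumes matplotlib is installed):

    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt

    x = torch.linspace(-5, 5, 500)

    relu = torch.relu(x)                  # max(0, x)
    swish = x * torch.sigmoid(x)          # x * sigmoid(x)
    mish = x * torch.tanh(F.softplus(x))  # x * tanh(ln(1 + exp(x)))

    plt.plot(x.numpy(), relu.numpy(), label="ReLU")
    plt.plot(x.numpy(), swish.numpy(), label="Swish")
    plt.plot(x.numpy(), mish.numpy(), label="Mish")
    plt.legend()
    plt.grid(True)
    plt.show()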

     

    2. PyTorch Usage

     

    The code below contains both Swish and Mish (save it as, e.g., activations.py so the ResNet code later can import it).

    import torch
    from torch import nn as nn
    from torch.nn import functional as F
    
    
    __all__ = ["swish_auto", "SwishAuto", "mish_auto", "MishAuto"]
    
    
    class SwishAutoFn(torch.autograd.Function):
        """Swish - Described in: https://arxiv.org/abs/1710.05941
        Memory efficient variant from:
         https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
        """
    
        @staticmethod
        def forward(ctx, x):
            result = x.mul(torch.sigmoid(x))
            ctx.save_for_backward(x)
            return result
    
        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            x_sigmoid = torch.sigmoid(x)
            return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))
    
    
    def swish_auto(x, inplace=False):
        # inplace ignored
        return SwishAutoFn.apply(x)
    
    
    class SwishAuto(nn.Module):
        def __init__(self, inplace: bool = False):
            super(SwishAuto, self).__init__()
            self.inplace = inplace
    
        def forward(self, x):
            return SwishAutoFn.apply(x)
    
    
    class MishAutoFn(torch.autograd.Function):
        """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
        Experimental memory-efficient variant
        """
    
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
            return y
    
        @staticmethod
        def backward(ctx, grad_output):
            x = ctx.saved_tensors[0]
            x_sigmoid = torch.sigmoid(x)
            x_tanh_sp = F.softplus(x).tanh()
            return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
    
    
    def mish_auto(x, inplace=False):
        # inplace ignored
        return MishAutoFn.apply(x)
    
    
    class MishAuto(nn.Module):
        def __init__(self, inplace: bool = False):
            super(MishAuto, self).__init__()
            self.inplace = inplace
    
        def forward(self, x):
            return MishAutoFn.apply(x)
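
    For reference, the formula in MishAutoFn.backward() above follows directly from differentiating the forward expression. With $s = \mathrm{softplus}(x) = \ln(1 + e^x)$ and $s'(x) = \mathrm{sigmoid}(x)$:

    $\frac{d}{dx}\,\mathrm{Mish}(x) = \frac{d}{dx}\bigl[x\tanh(s)\bigr] = \tanh(s) + x \cdot \mathrm{sigmoid}(x)\bigl(1 - \tanh^2(s)\bigr)$

    which is exactly x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp) in the code.

    As a quick sanity check (my own sketch, assuming the code above is saved as activations.py), the memory-efficient variant should match the direct formula, and the hand-written backward should pass gradcheck:

    import torch
    import torch.nn.functional as F

    from activations import MishAuto, mish_auto

    x = torch.randn(8, 16)

    # The memory-efficient forward should match the direct formula
    y_auto = MishAuto()(x)
    y_ref = x * torch.tanh(F.softplus(x))
    print(torch.allclose(y_auto, y_ref))  # True (up to floating-point error)

    # Compare the analytic backward against numerical gradients
    xd = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    print(torch.autograd.gradcheck(mish_auto, (xd,), eps=1e-6, atol=1e-4))  # True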
    

     

    Applying it to ResNet

    MXResNet version link (@GitHub)

    To use ReLU or SwishAuto instead, just change Act.

    """
    .. Deep Residual Learning for Image Recognition:
        https://arxiv.org/abs/1512.03385
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    from activations import MishAuto, SwishAuto
    
    # Act = nn.ReLU
    Act = MishAuto
    # Act = SwishAuto
    
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, in_planes, planes, stride=1):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(
                in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
            )
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu = Act(inplace=True)
            self.conv2 = nn.Conv2d(
                planes, planes, kernel_size=3, stride=1, padding=1, bias=False
            )
            self.bn2 = nn.BatchNorm2d(planes)
    
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion * planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.expansion * planes),
                )
    
        def forward(self, x):
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(x)
            out = self.relu(out)
            return out
    
    
    class Bottleneck(nn.Module):
        expansion = 4
    
        def __init__(self, in_planes, planes, stride=1):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(
                planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
            )
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(
                planes, self.expansion * planes, kernel_size=1, bias=False
            )
            self.bn3 = nn.BatchNorm2d(self.expansion * planes)
            self.relu = Act(inplace=True)
    
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion * planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.expansion * planes),
                )
    
        def forward(self, x):
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.relu(self.bn2(self.conv2(out)))
            out = self.bn3(self.conv3(out))
            out += self.shortcut(x)
            out = self.relu(out)
            return out
    
    
    class ResNetMish(nn.Module):
        def __init__(self, block, num_blocks, num_classes=10):
            super(ResNetMish, self).__init__()
            self.in_planes = 64
    
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = Act(inplace=True)
    
            self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
            self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
            self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
            self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
            self.linear = nn.Linear(512 * block.expansion, num_classes)
    
        def _make_layer(self, block, planes, num_blocks, stride):
            strides = [stride] + [1] * (num_blocks - 1)
            layers = []
            for stride in strides:
                layers.append(block(self.in_planes, planes, stride))
                self.in_planes = planes * block.expansion
            return nn.Sequential(*layers)
    
        def forward(self, x):
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = F.avg_pool2d(out, 4)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
            return out
    
    
    def ResNetMish18():
        return ResNetMish(BasicBlock, [2, 2, 2, 2])
    
    
    def ResNetMish34():
        return ResNetMish(BasicBlock, [3, 4, 6, 3])
    
    
    def ResNetMish50():
        return ResNetMish(Bottleneck, [3, 4, 6, 3])
    
    
    def ResNetMish101():
        return ResNetMish(Bottleneck, [3, 4, 23, 3])
    
    
    def ResNetMish152():
        return ResNetMish(Bottleneck, [3, 8, 36, 3])
    
    
    def test():
        net = ResNetMish18()
        y = net(torch.randn(1, 3, 32, 32))
        print(y.size())
    
    
    # test()
    

     

    Training (ResNetMish + AdaBelief)

    Optimizer = @AdaBelief

     

    Deep Learning Optimizer: AdaBelief Optimizer

    Adapting Stepsizes by the Belief in Observed Gradients

    choiseokwon.tistory.com

    learning_rate = 1e-3

    eps = 1e-8 (the AdaBelief default is 1e-16)

    total_epoch = 50

    python main.py --model resnet_mish --optim adabelief --lr 1e-3 --eps 1e-8 --beta1 0.9 --beta2 0.999 --momentum 0.9 --total_epoch 50 --rectify False
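
    As a rough illustration of how these flags map onto the optimizer object (my own sketch, assuming the adabelief_pytorch package; argument names can vary between versions):

    import torch
    from adabelief_pytorch import AdaBelief

    net = ResNetMish18().to("cuda" if torch.cuda.is_available() else "cpu")

    # Corresponds to --lr 1e-3 --eps 1e-8 --beta1 0.9 --beta2 0.999 --rectify False
    optimizer = AdaBelief(
        net.parameters(),
        lr=1e-3,
        eps=1e-8,
        betas=(0.9, 0.999),
        rectify=False,
    )
    criterion = torch.nn.CrossEntropyLoss()

    # One training step (inputs, targets come from the CIFAR-10 train loader):
    # outputs = net(inputs)
    # loss = criterion(outputs, targets)
    # optimizer.zero_grad(); loss.backward(); optimizer.step()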

     

    I used the 45th of the 50 checkpoints (accuracy: 92.7).

    The checkpoint stores net, optimizer, acc, and epoch.

    import os

    # Inside the training loop, after evaluating test_acc for this epoch:
    state = {
        "net": net.state_dict(),
        "optimizer": optimizer.state_dict(),
        "acc": test_acc,
        "epoch": epoch,
    }

    torch.save(state, os.path.join("checkpoint", ckpt_name + ".tar"))
    
    Epoch: #45
    
    train acc 98.816
     test acc 92.700
     
    Time: 129.36990s
    
    Saving...
    Model: resnet_mish, Optimizer: adabelief, ACC: 92.7%, Epoch at 45 saved

     

    Loading the checkpoint

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine
    checkpoint = torch.load("mish_adabelief.tar", map_location=device)

    net = ResNetMish34()  # num_blocks = [3, 4, 6, 3]
    net = net.to(device)
    if device == "cuda":
        net = torch.nn.DataParallel(net)

    net.load_state_dict(checkpoint["net"])  # restore the model weights

    accuracy = checkpoint["acc"]  # 92.7

    net.eval()
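
    One caveat (my own note): if the checkpoint was saved from a DataParallel-wrapped model, as the code above suggests, its keys carry a module. prefix. To load it into a plain, unwrapped model, strip the prefix first:

    state_dict = checkpoint["net"]

    # Drop the "module." prefix added by nn.DataParallel, if present
    state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

    plain_net = ResNetMish34()
    plain_net.load_state_dict(state_dict)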
    

     

    Loading test images (CIFAR)

    import torchvision
    import torchvision.transforms as transforms
    
    transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2  # batches of 4 images
    )
    

     

    Viewing test images

    import matplotlib.pyplot as plt
    import numpy as np
    
    from random import randint
    
    def imshow(img):
        img = img / 2 + 0.5  # approximate un-normalization for display
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

    dataiter = iter(test_loader)
    for _ in range(randint(0, len(test_loader) - 2)):  # skip ahead to a random batch
        next(dataiter)
    images, labels = next(dataiter)
    
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    imshow(torchvision.utils.make_grid(images))
    print('Ground Truth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    plt.show()
    

     

    Testing the model

    The predictions look reasonable.

    with torch.no_grad():
        outputs = net(images.to(device))

    _, predicted = torch.max(outputs, 1)

    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                                  for j in range(4)))
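
    To verify the stored 92.7% over the whole test set rather than a single batch, a minimal evaluation loop (my own sketch, reusing net, device, and test_loader from above):

    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test accuracy: {100.0 * correct / total:.2f}%")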
    

     

    3. References

    @Mish official GitHub page

     
