본문 바로가기
Network

[DNN] Swin Transformer 리뷰 및 구현 (ICCV 2021)

by pulluper 2023. 4. 12.
반응형

안녕하세요 pulluper 입니다. 

이번 포스팅에서는 ICCV2021 발표 후 많은 비전 모델의 백본으로 사용되고 있는

swin transformer 논문에 대하여 알아보겠습니다. 

https://arxiv.org/abs/2103.14030

 

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as

arxiv.org

 

논문의 전체 이름은 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 입니다. 

shifted window를 이용한다는 점과 계층적 구조를 이용한다는 것을 알 수 있습니다. 

 


문제제기

 

먼저 논문에서 제시하는 문제를 보겠습니다. 

abstract 에서 다음과 같이 언급되며, 

introduction에서는 아래와 같이 패러프레이징이 됩니다. 

 

즉, 정리하면 vision task는 다음 두가지 이유 때문에 Transformer를 적용하기 어렵다 합니다. 

 

1. 스케일의 큰 변동성

2. 높은 해상도의 이미지

 

논문에서는 작은 사이즈의 패치로 나누어 깊어질수록 머징을 하는 계층구조와,

높은 해상도때문에 생기는 quadratic computational cost 문제를

window 크기를 기준으로 줄일수 있다고 합니다. (hw -> M)

이후 나오는 다른 디테일들은 다음 전체 네트워크를 보며 설명해 보겠습니다. 


네트워크

 

네트워크는 위의 그림에 있는 요소들인 

1. Patch Partition

2. Linear Embedding

3. Swin Transformer Block

4. Patch Merging

이 네가지를 분석 및 구현을 하며 알아 보겠습니다. 


1. Patch Partition + Linear Embedding

 

기존의 VIT는 하나의 토큰을 16 x 16 의 크기로 나누었는데, Swin Transformer는

Tiny 기준 patch를 4 x 4로 설정합니다. 

그렇다면 다음그림과 같이 sequence 의 길이는 56 x 56 = 3146이 됩니다. 

코드는 다음과 같습니다.

VIT와 비슷하게 convolution의 kernel과 stride를 같게 하여 patch를 뽑습니다.

이때 convolution의 channel을 96(tiny 기준)으로 만들어 linear embedding까지 합니다.  

 

class PatchPartition(nn.Module):
    def __init__(self,
                 patch_size: int = 4,
                 ):
        """
        this patch partition + Linear Embedding
        :param patch_size:
        """
        super().__init__()
        self.proj = nn.Conv2d(3, 96, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(96)

    def forward(self, x):
        x = self.proj(x)                  # [B, 96, 56, 56]
        x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

 


2. Swin Transformer Block

 

Swin Transformer Block은 다음 두가지 attention을 이용합니다. 

 

1) W-MSA : window로 잘라서 window 내부의 sequence 끼리 attention 하는것 

2) SW-MSA : window를 움직여서 W-MSA 하는것 

 

2-1) W-MSA

W-MSA 부분의 구현에 대하여 설명하기 위해서 논문의 

Efficient batch computation for shifted configuration 챕터와 연관하여 설명해 보겠습니다. 

 

일단 현재 우리는 Patch Partition + Linear Embedding 으로 인해

[B, 96, 56, 56] 의 크기의 텐서를 가지고 있습니다. 

여리서 size가 7인 window로 자르면, [B, 96, 8, 7, 8, 7] 의 크기로 구분가능합니다. 

여기서 7 x 7의 크기를 갖는 area를 또다른 batch로 처리를 해서 다음의 shape으로 만듭니다.

[B', 96, 7, 7] (B' = B x 8 x 8) 로 만든후 연산을 진행하고 다시 원래의 크기로 돌려줍니다. 

이 방식을 Efficient batch computation 라고 하며 다음 그림으로 이해 할 수 있습니다.

 

W-MSA 구현은 다음과 같습니다.

(참고로 가독성을 위해 relative positional encoding 부분은 제외한 상태입니다.)

 

class W_MSA(nn.Module):
    def __init__(self,
                 dim, num_heads, head_dim=None, window_size=7,
                 qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # setting
        B, L, C = x.shape
        ws = self.window_size
        w = h = int(math.sqrt(L))
        h_ = int(h // ws)
        w_ = int(w // ws)

        # [B, 3136, C]
        # ----------- efficient batch computation for shifted configuration -----------
        x = x.view(B, h, w, C)                        # [B, H, W, C]
        x = x.view(B, h_, ws, w_, ws, C)              # [0, 1, 2, 3, 4, 5 ] -> [0, 1, 3, 2, 4, 5 ] - idx
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # [B, 8, 7, 8, 7, 96] -> [B, 8, 8, 7, 7, 96]
        x = x.view(B * h_ * w_, ws * ws, C)           # [B' = B x 8 x 8],   -> [B'         49, 96]

        # ------------------------------ attention ------------------------------
        B_, N, C = x.shape                            # [B_, 49, 96]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)                                                           # [B_, 49, 96]

        # ---------- make multi-batch tensor original batch tensor ----------v
        x = x.view(B, h_, w_, ws, ws, C)              # [B, 8, 8, 7, 7, 96]
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # [B, 8, 7, 8, 7, 96]
        x = x.view(B, h, w, -1)               # (roll)  [B, 56, 56, 96]
        x = x.view(B, h * w, C)                       # [B, 56, 56, 96]
        return x

 

2-2) SW-MSA

 

Shifted Window MSA입니다. 

W-MSA만 사용한다면 window간의 interection이 줄어듭니다.

따라서 window의 위치를 변경해서 W-MSA의 window간 

연결성을 늘려주는 방법을 제안합니다.

 

https://csm-kr.tistory.com/85

 

[pytorch] torch.roll 설명 및 예제

torch.roll 설명 https://pytorch.org/docs/stable/generated/torch.roll.html torch.roll — PyTorch 2.0 documentation Shortcuts pytorch.org 위의 그림과 같이 torch.roll 은 원하는 dim으로 tensor를 이동시키는 것이다. (굴린다고 표현)

csm-kr.tistory.com

 

다음과 같은 그림에서 1번에서 2번으로 torch.roll 을 사용하여 변경할 수 있습니다. 

논문에서는 Cyclic Shift라는 용어를 사용했습니다. 

 

그런데 그림 3과 같이 생각을 해 보면, (5, 8) 영역은 각각 (3, 6)의 영역과 관련이 없습니다. 

이미지의 입장에서 roll을 한다고 해서 그 연속성이나 연결성을 보장해주지 않기 때문입니다. 

 

따라서 관련된 부분들만의 연산을 위해서 mask를 이용해 가려줍니다. 

이를 Masked MSA라고 합니다. 

저 파런색 부분이 위에서 roll 되어온 부분이고 여기서 W-MSA를 수행하면 다음과 같습니다. 

 

1번 영역은 sequence로 만들었을 때 영역 모두가 서로 연관된 부분이기 때문에 mask는 모두 0입니다. 

2번 영역을 생각하면 이를 sequence로 폈을 때, 다음과 같습니다. 

서로 연관된 부분에 대하여 masking 을 해주면 격자의 모양이 나옵니다. 

3번 영역은 다음과 같은 sequence 가 되고 

4번 영역은 2번과 3번의 통합된 모습이 나올것 입니다. 

이에 대하여 mask들을 보면 각 영역에 대하여 다음과 같이 가려주는 부분이 생깁니다. 

여기서 어두운부분이 -100, -inf 등으로 가려주는 부분이고 밝은 부분은 0으로 masking 합니다. 

이렇게 W-MSA에 mask를 더해주어 attention을 하고 다시 원래자리로 돌리는 roll back을 합니다.

attention matrix에 큰 음수값을 더해주면, softmax의 특성상 그 그 부분은 매우작은 값이되어 무시됩니다.

정리하자면 다음의 1, 2, 3, 4단계를 거쳐서 SW-MSA가 완성됩니다. 

이제 다음 그림이 잘 이해가 되실것 입니다.

코드는 다음과 같습니다.  (mask 부분은 official code 참조했습니다.)

 

class SW_MSA(nn.Module):
    """
    need shift torch.roll and attention mask
    """
    def __init__(self,
                 dim, num_heads, head_dim=None, window_size=7,
                 qkv_bias=True, attn_drop=0., proj_drop=0.,
                 input_resolution: tuple = (56, 56)):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = head_dim or dim // num_heads
        attn_dim = head_dim * num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        # calculate attention mask for SW-MSA
        self.input_resolution = input_resolution
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        cnt = 0
        for h in (
                slice(0, -self.window_size),
                slice(-self.window_size, -3),
                slice(-3, None)):
            for w in (
                    slice(0, -self.window_size),
                    slice(-self.window_size, -3),
                    slice(-3, None)):
                img_mask[:, h, w, :] = cnt
                cnt += 1
        mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        self.attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x):
        # setting
        B, L, C = x.shape
        ws = self.window_size
        w = h = int(math.sqrt(L))
        h_ = int(h // ws)
        w_ = int(w // ws)

        # [B, 3136, C]
        # ----------- efficient batch computation for shifted configuration -----------
        x = x.view(B, h, w, C)                             # [B, H, W, C]
        x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))    # [B, H, W, C]
        x = x.view(B, h_, ws, w_, ws, C)                   # [0, 1, 2, 3, 4, 5 ] -> [0, 1, 3, 2, 4, 5 ] - idx
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()       # [B, 8, 7, 8, 7, 96] -> [B, 8, 8, 7, 7, 96]
        x = x.view(B * h_ * w_, ws * ws, C)                # [B' = B x 8 x 8],   -> [B'         49, 96]

        # ------------------------------ attention ------------------------------
        B_, N, C = x.shape  # [B_, 49, 96]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        num_win = self.attn_mask.shape[0]
        attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + self.attn_mask.to(torch.get_device(q)).\
            unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)                              # [B_, 49, 96]

        # ---------- make multi-batch tensor original batch tensor ----------v
        x = x.view(B, h_, w_, ws, ws, C)                   # [B, 8, 8, 7, 7, 96]
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()       # [B, 8, 7, 8, 7, 96]
        x = x.view(B, h, w, -1)                    # (roll)  [B, 56, 56, 96]
        x = torch.roll(x, shifts=(3, 3), dims=(1, 2))      # [B, 56, 56, 96]
        x = x.view(B, h * w, C)                            # [B, 3136, 96]
        return x

2-3) Patch merging

 

이부분은 resolution을 줄이고, channel을 줄이는 부분입니다. 

여기는 torch.Unfold와 nn.Linear로 쉽게 구현 가능합니다.  

https://csm-kr.tistory.com/79

 

[pytorch] torch.nn.Unfold

torch.nn.Unfold 는 다음과 같다. batched tensor에 대하여 마치 convolution 처럼 sliding 움직이면서 그 local block을 구하는 것이다. 다음 사이트에서 자세한 설명을 볼 수 있다. https://pytorch.org/docs/stable/generated

csm-kr.tistory.com

 

코드는 다음과 같습니다.

 

class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, input_resolution, downscaling_factor=2):
        super().__init__()
        self.input_resolution = input_resolution
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, l, c = x.shape
        h, w = self.input_resolution
        x = x.view(b, h, w, c)
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        x = x.view(-1, new_h * new_w, c * self.downscaling_factor ** 2)
        x = self.linear(x)
        return x

 


학습 및 결과

 

1) 학습 환경 (분류)

 

[ Imagenet 1k ]

batch size : 1024
learning rate : 0.001
epoch : 300
optimizer : AdamW
weight_decay : 0.05
lr scheduler : cosine scheduler (20 epoch warm up)
loss : LabelSmoothingCrossEntropy (num_classes : 1000)
Augmentation : like Deit (except for RA(repeated augmentation) and EMA(exponential moving average)

 

[ Imagenet 22k for pre-training ]

batch size : 4096
learning rate : 0.001
epoch : 90
optimizer : AdamW
weight_decay : 0.01
lr scheduler : cosine scheduler (5 epoch warm up)
loss : LabelSmoothingCrossEntropy (num_classes : 21,841 classes)
Augmentation : like Deit (except for RA(repeated augmentation) and EMA(exponential moving average)
(1400만개, 21,841 classes)

이후 Imagenet 1k에서 추가적으로 30 epoch의 fine-tuning 을 진행합니다. 
epoch : 30 
batch size : 1024
learning rate : 0.0001
weight_decay : 10^-8

 

참고로 DeiT에서의 augmentation 은 다음과 같습니다. 

 

 

2) 학습결과 (분류)

 

SWIN transformer는 다른 convolution based 네크워크에 비해서도 좋고 VIT와 DeiT에 대비하여

더 적은 파라미터와 flops에도 더 좋은 성능을 내는 모습을 보여줍니다. 

 

 

Imagenet-22k pre-trained model에 대하여도 기존의 vit보다 더 좋은 성능을 보여줍니다. 

 

 

3) 학습결과 (detection/segmentation)

 

다른 vision task에서 많이 쓰이는 backbone 인 resnet등과의 비교입니다. 

 

[object detection]

 

여러 디텍션 알고리즘 (Mask-RCNN, ATSS, RepPoints, Sparse R-CNN) 등의 backbone의 성능을 바꾸어도 

약간의 fps하락이 있지만 성능이 올라가는 모습을 보입니다. 

 

[semantic segmentation]

 

세그멘테이션의 경우에도 특히 UperNet을 보면, ResNet-101에서 44.9의 성능을 갖지만, 

Swin-T, S, B, L 로 넘어갈수록 더 좋은 성능을 보임을 알 수 있습니다. 

 


결론 및 코드구현 

 

Swin Transformer는 local attention을 최대한 window로 나누어 효율적인 계산과

함께 다중 스케일을 고려한 계층적으로 잘 구성된 모델을 만들었습니다.

이는 다양한 모델의 backbone이 될수 있음을 증명했습니다. 

 

더 발전될 가능성은 local attention을 어떻게 global attention과 합치느냐 입니다. 

다음에는 이를 위해서 연구된 InternImage(CVPR2023)도 다뤄볼 예정입니다. 

 

전체 코드는 다음에서 참조 가능합니다.

 

↓ 구현 코드 ↓

https://github.com/csm-kr/swin_transformer_pytorch

 

GitHub - csm-kr/swin_transformer_pytorch

Contribute to csm-kr/swin_transformer_pytorch development by creating an account on GitHub.

github.com

 

질문과 토론은 언제든지 환영합니다.

감사합니다. 😄


Reference

https://github.com/microsoft/Swin-Transformer

https://github.com/berniwal/swin-transformer-pytorch

https://github.com/microsoft/Swin-Transformer/issues/38

 

반응형

댓글