본문 바로가기
Network

[DNN] multi-head cross attention vs multi-head self attention 비교

by pulluper 2023. 2. 2.
반응형

안녕하세요 pulluper 입니다. 

 

attention 을 사용한 모듈을 보면, 하나의 인풋이 들어와서 q, k, v 가 같은 length를 가지는 경우가 있고, 

q의 길이와 k, v의 길이는 다른 경우를 왕왕 볼 수 있습니다. 

 

예를들어 DETR의  decoder의 경우에서 사용하는 attention 중 하나는 q, k, v 의 길이가 모두 같지 않습니다. 

이때 timm의 구현과 비슷하게 다음을 구현해 보겠습니다. 


모두 같은 길이를 같는 모듈은 Multi-head Self Attention(MSA) 이라 하겠고,

그렇지 않으면 Multi-head Cross Attention(MCA) 이라 하겠습니다. 

다음은 그것들의 구현입니다. 

 

import torch
import torch.nn as nn
import torch.nn.functional as F


class MSA(nn.Module):
    def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MCA(nn.Module):
    def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self._reset_parameters()

    def _reset_parameters(self):
        torch.manual_seed(0)
        nn.init.xavier_uniform_(self.q.weight)
        nn.init.xavier_uniform_(self.k.weight)
        nn.init.xavier_uniform_(self.v.weight)
        nn.init.xavier_uniform_(self.proj.weight)
        if self.k.bias is not None:
            nn.init.xavier_normal_(self.k.bias)
        if self.v.bias is not None:
            nn.init.xavier_normal_(self.v.bias)
        if self.proj.bias is not None:
            nn.init.constant_(self.proj.bias, 0.)

    def forward(self, x_q, x_k, x_v):
        B, N_q, C = x_q.shape
        _, N_kv, C = x_k.shape
        _, N_kv, C = x_v.shape

        # b, h, n, d
        q = self.q(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = self.k(x_k).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x_v).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # [b, h, n, d] * [b, h, d, m] -> [b, h, n, m]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# self_attn = nn.MultiheadAttention(256, 8, dropout=0)

if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn([2, 1024, 256])
    q = torch.randn([2, 100, 256]) 
    k = torch.randn([2, 1024, 256]) 
    v = torch.randn([2, 1024, 256]) 
    msa = MSA(dim=256)
    mca = MCA(dim=256)

    print(msa(x).size())
    print(mca(q, k, v).size())

 

이 출력을 살펴보면, [2, 1024, 256], [2, 100, 256] 으로 출력이 달라집니다. 

그러나 MHA의 경우에는 q와 k, v의 sequence의 길이가 달라도

q의 길이로 바뀌어 학습되는 것을 볼 수 있습니다. 

 

그림으로 보면 다음과 같이 볼 수 있겠습니다. 

msa vs mha

인풋 (q) 의 길이가 다를 때, (k, v) 의 길이가 같을 때 MCA를 사용하면 되겠습니다. 😀😀😀

감사합니다. 

반응형

댓글