반응형
안녕하세요 pulluper 입니다.
attention 을 사용한 모듈을 보면, 하나의 인풋이 들어와서 q, k, v 가 같은 length를 가지는 경우가 있고,
q의 길이와 k, v의 길이는 다른 경우를 왕왕 볼 수 있습니다.
예를들어 DETR의 decoder의 경우에서 사용하는 attention 중 하나는 q, k, v 의 길이가 모두 같지 않습니다.
이때 timm의 구현과 비슷하게 다음을 구현해 보겠습니다.
모두 같은 길이를 같는 모듈은 Multi-head Self Attention(MSA) 이라 하겠고,
그렇지 않으면 Multi-head Cross Attention(MCA) 이라 하겠습니다.
다음은 그것들의 구현입니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MSA(nn.Module):
def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MCA(nn.Module):
def __init__(self, dim=192, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self._reset_parameters()
def _reset_parameters(self):
torch.manual_seed(0)
nn.init.xavier_uniform_(self.q.weight)
nn.init.xavier_uniform_(self.k.weight)
nn.init.xavier_uniform_(self.v.weight)
nn.init.xavier_uniform_(self.proj.weight)
if self.k.bias is not None:
nn.init.xavier_normal_(self.k.bias)
if self.v.bias is not None:
nn.init.xavier_normal_(self.v.bias)
if self.proj.bias is not None:
nn.init.constant_(self.proj.bias, 0.)
def forward(self, x_q, x_k, x_v):
B, N_q, C = x_q.shape
_, N_kv, C = x_k.shape
_, N_kv, C = x_v.shape
# b, h, n, d
q = self.q(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
k = self.k(x_k).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.v(x_v).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# [b, h, n, d] * [b, h, d, m] -> [b, h, n, m]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N_q, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# self_attn = nn.MultiheadAttention(256, 8, dropout=0)
if __name__ == '__main__':
torch.manual_seed(0)
x = torch.randn([2, 1024, 256])
q = torch.randn([2, 100, 256])
k = torch.randn([2, 1024, 256])
v = torch.randn([2, 1024, 256])
msa = MSA(dim=256)
mca = MCA(dim=256)
print(msa(x).size())
print(mca(q, k, v).size())
이 출력을 살펴보면, [2, 1024, 256], [2, 100, 256] 으로 출력이 달라집니다.
그러나 MHA의 경우에는 q와 k, v의 sequence의 길이가 달라도
q의 길이로 바뀌어 학습되는 것을 볼 수 있습니다.
그림으로 보면 다음과 같이 볼 수 있겠습니다.
인풋 (q) 의 길이가 다를 때, (k, v) 의 길이가 같을 때 MCA를 사용하면 되겠습니다. 😀😀😀
감사합니다.
반응형
'Network' 카테고리의 다른 글
[DNN] torchvision module 이용해서 resnet-dc5 구현하기 (2) | 2023.03.21 |
---|---|
[DNN] timm을 이용한 VIT models ILSVRC classification 성능평가 (0) | 2023.03.17 |
performer 구현 (0) | 2022.12.12 |
[DNN] VIT(vision transformer) 리뷰 및 코드구현(CIFAR10) (ICLR2021) (1) | 2022.09.11 |
[DNN] Resnet 리뷰/구조 (CVPR2016) (2) | 2022.08.29 |
댓글