import os
import time
import torch
import visdom
import argparse
import torch.nn as nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from timm.models.layers import trunc_normal_
from torchvision.datasets.cifar import CIFAR10
# performer #############################################
import math
import torch
import torch.nn as nn
from functools import partial
from einops import rearrange, repeat
from distutils.version import LooseVersion
TORCH_GE_1_8_0 = LooseVersion(torch.__version__) >= LooseVersion('1.8.0')
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim = -2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
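# Note on linear_attention: instead of materialising the N x N matrix softmax(QK^T), it
# forms context = K^T V ('...nd,...ne->...de', a d x e matrix per head) and a per-query
# normaliser D_inv from the sum of the key features over the sequence, so time and memory
# grow linearly in the sequence length N instead of quadratically.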
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    b, h, *_ = data.shape
    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
    ratio = (projection_matrix.shape[0] ** -0.5)
    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)
    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)
    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.amax(data_dash, dim=-1, keepdim=True).detach()) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()) + eps)
    return data_dash.type_as(data)
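# softmax_kernel builds the positive random features of FAVOR+ (Performer): after rescaling
# the input by d^(-1/4), each feature is roughly exp(w^T x - ||x||^2 / 2) / sqrt(m) for one of
# the m projection rows w, so that phi(q)^T phi(k) approximates the softmax kernel
# exp(q^T k / sqrt(d)) in expectation. The amax subtraction only keeps exp() from overflowing.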
def orthogonal_matrix_chunk(cols, device = None):
    unstructured_block = torch.randn((cols, cols), device = device)
    if TORCH_GE_1_8_0:
        q, r = torch.linalg.qr(unstructured_block.cpu(), mode = 'reduced')
    else:
        q, r = torch.qr(unstructured_block.cpu(), some = True)
    q, r = map(lambda t: t.to(device), (q, r))
    return q.t()
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []
    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q)
    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q[:remaining_rows])
    final_matrix = torch.cat(block_list)
    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')
    return torch.diag(multiplier) @ final_matrix
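# The projection rows are drawn block-wise orthogonal (the ORF trick from the Performer
# paper), which reduces the variance of the kernel estimate compared to i.i.d. Gaussian rows.
# scaling=0 re-draws chi-distributed row norms so the rows still look like Gaussian samples;
# scaling=1 fixes every row norm to sqrt(d).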
class FastAttention(nn.Module):
    def __init__(self,
                 dim_heads,
                 nb_features = None,
                 ortho_scaling = 0,
                 causal = False,
                 generalized_attention = False,
                 kernel_fn=nn.ReLU(),
                 no_projection = False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        self.create_projection = partial(gaussian_orthogonal_random_matrix,
                                         nb_rows=self.nb_features,
                                         nb_columns=dim_heads,
                                         scaling=ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection
        self.causal = causal
    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device=device)
        self.projection_matrix.copy_(projections)
        del projections
    def forward(self, q, k, v):
        device = q.device
        create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device)
        q = create_kernel(q, is_query=True)
        k = create_kernel(k, is_query=False)
        attn_fn = linear_attention
        out = attn_fn(q, k, v)
        return out
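# Note: this is a trimmed-down FastAttention (apparently adapted from lucidrains'
# performer-pytorch). forward() always uses the softmax kernel with non-causal
# linear_attention; the causal, generalized_attention, kernel_fn and no_projection
# arguments are stored but unused here. By default it draws nb_features = int(d * ln(d))
# random features per head.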
class SoftKernelAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        super(SoftKernelAttention, self).__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_k = nn.Linear(embed_dim, embed_dim)
        self.to_v = nn.Linear(embed_dim, embed_dim)
        self.to_out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attn = FastAttention(dim_head)
    def forward(self, query, key=None, value=None, attn_mask=None, key_padding_mask=None):
        if key is None and value is None:
            key = value = query
        # batch-first path: query/key/value are image tokens of shape [B, N, D]
        b, n, d, h = *query.shape, self.heads
        q, k, v = self.to_q(query), self.to_k(key), self.to_v(value)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        attn_outs = []
        out = self.attn(q, k, v)
        attn_outs.append(out)
        out = torch.cat(attn_outs, dim=1)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        out = self.dropout(out)
        # original sequence-first ([N, B, D], as in nn.MultiheadAttention) version, kept for reference
        # n, b, d, h = *query.shape, self.heads
        # q, k, v = self.to_q(query), self.to_k(key), self.to_v(value)
        # q, k, v = map(lambda t: rearrange(t, 'n b (h d) -> b h n d', h=h), (q, k, v))
        #
        # attn_outs = []
        # out = self.attn(q, k, v)
        # attn_outs.append(out)
        #
        # out = torch.cat(attn_outs, dim=1)
        # out = rearrange(out, 'b h n d -> b n (h d)')
        # out = self.to_out(out)
        # out = self.dropout(out)
        # out = rearrange(out, 'b n d -> n b d')
        return out
# if __name__ == '__main__':
#
# soft_kernel_attn = SoftKernelAttention(embed_dim=192, num_heads=8)
# q = torch.randn([16, 65, 192])
# attn = soft_kernel_attn(q)
# print(attn.size())
# batch, heads, length, dim
# q = torch.randn([1, 8, 100, 64])
# k = torch.randn([1, 8, 256, 64])
# v = torch.randn([1, 8, 256, 64])
# attn = FastAttention(dim_heads=64)
# out = attn(q, k, v)
# print(out.size())
# soft_kernel_attn = SoftKernelAttention(embed_dim=256, num_heads=8)
# q = torch.randn([100, 16, 256])
# k = torch.randn([20, 16, 256])
# v = torch.randn([20, 16, 256])
# attn = soft_kernel_attn(q, k, v)
# print(attn.size())
# # seq, length dim
# original_attn = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8)
# print(original_attn(q, k, v)[0].size())
# performer #############################################
class EmbeddingLayer(nn.Module):
    def __init__(self, in_chans, embed_dim, img_size, patch_size):
        super().__init__()
        self.num_tokens = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim
        self.project = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.num_tokens += 1
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.embed_dim))
        # init cls token and pos_embed -> refer timm vision transformer
        # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L391
        nn.init.normal_(self.cls_token, std=1e-6)
        trunc_normal_(self.pos_embed, std=.02)
    def forward(self, x):
        B, C, H, W = x.shape
        embedding = self.project(x)
        z = embedding.view(B, self.embed_dim, -1).permute(0, 2, 1)  # BCHW -> BNC
        # concat cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        z = torch.cat([cls_tokens, z], dim=1)
        # add position embedding
        z = z + self.pos_embed
        return z
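# Token-count check for the settings used in main() below (patch_size=2 with the ViT
# defaults img_size=32, embed_dim=192): (32 // 2) ** 2 + 1 = 257 tokens per image
# (256 patch tokens + 1 cls token), so pos_embed has shape [1, 257, 192].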
class MSA(nn.Module):
    def __init__(self, dim=192, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
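# MSA above is the standard O(N^2) multi-head self-attention of ViT. It is kept only for
# comparison: in Block below it is commented out in favour of SoftKernelAttention, so the
# model in this post actually runs with the Performer-style linear attention.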
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                 drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        # self.attn = MSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.attn = SoftKernelAttention(embed_dim=dim, num_heads=num_heads)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
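# Block uses the usual pre-norm transformer layout: LayerNorm -> attention -> residual add,
# then LayerNorm -> MLP -> residual add.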
class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=192, depth=12,
                 num_heads=12, mlp_ratio=2., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = nn.LayerNorm
        act_layer = nn.GELU
        self.patch_embed = EmbeddingLayer(in_chans, embed_dim, img_size, patch_size)
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        # final norm
        self.norm = norm_layer(embed_dim)
        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks(x)
        x = self.norm(x)
        x = self.head(x)[:, 0]
        return x
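# Quick shape sanity check (a minimal sketch, left commented out like the tests above; the
# classifier reads the cls token at index 0):
# if __name__ == '__main__':
#     model = ViT(patch_size=2)        # 257 tokens of dim 192 per image
#     x = torch.randn(2, 3, 32, 32)    # a fake CIFAR10 batch
#     print(model(x).size())           # expected: torch.Size([2, 10])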
def main():
    # 1. ** argparser **
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--step_size', type=int, default=100)
    parser.add_argument('--root', type=str, default='./CIFAR10')
    parser.add_argument('--log_dir', type=str, default='./log')
    parser.add_argument('--name', type=str, default='vit_cifar10')
    parser.add_argument('--rank', type=int, default=0)
    ops = parser.parse_args()
    # 2. ** device **
    device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
    # 3. ** visdom **
    vis = visdom.Visdom(port=8097)
    # 4. ** dataset / dataloader **
    transform_cifar = tfs.Compose([
        tfs.RandomCrop(32, padding=4),
        tfs.RandomHorizontalFlip(),
        tfs.ToTensor(),
        tfs.Normalize(mean=(0.4914, 0.4822, 0.4465),
                      std=(0.2023, 0.1994, 0.2010)),
    ])
    test_transform_cifar = tfs.Compose([tfs.ToTensor(),
                                        tfs.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                                      std=(0.2023, 0.1994, 0.2010)),
                                        ])
    train_set = CIFAR10(root=ops.root,
                        train=True,
                        download=True,
                        transform=transform_cifar)
    test_set = CIFAR10(root=ops.root,
                       train=False,
                       download=True,
                       transform=test_transform_cifar)
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=ops.batch_size)
    test_loader = DataLoader(dataset=test_set,
                             shuffle=False,
                             batch_size=ops.batch_size)
    # 5. ** model **
    model = ViT(patch_size=2).to(device)
    # 6. ** criterion **
    criterion = nn.CrossEntropyLoss()
    # 7. ** optimizer **
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=ops.lr,
                                 weight_decay=5e-5)
    # 8. ** scheduler **
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ops.epoch, eta_min=1e-5)
    # 9. ** logger **
    os.makedirs(ops.log_dir, exist_ok=True)
    # 10. ** training **
    print("training...")
    for epoch in range(ops.epoch):
        model.train()
        tic = time.time()
        for idx, (img, target) in enumerate(train_loader):
            img = img.to(device)  # [N, 3, 32, 32]
            target = target.to(device)  # [N]
            # output, attn_mask = model(img, True) # [N, 10]
            output = model(img)  # [N, 10]
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
            if idx % ops.step_size == 0:
                vis.line(X=torch.ones((1, 1)) * idx + epoch * len(train_loader),
                         Y=torch.Tensor([loss]).unsqueeze(0),
                         update='append',
                         win='training_loss',
                         opts=dict(x_label='step',
                                   y_label='loss',
                                   title='loss',
                                   legend=['total_loss']))
                print('Epoch : {}\t'
                      'step : [{}/{}]\t'
                      'loss : {}\t'
                      'lr : {}\t'
                      'time {}\t'
                      .format(epoch,
                              idx, len(train_loader),
                              loss,
                              lr,
                              time.time() - tic))
        # save
        save_path = os.path.join(ops.log_dir, ops.name, 'saves')
        os.makedirs(save_path, exist_ok=True)
        checkpoint = {'epoch': epoch,
                      'model_state_dict': model.state_dict(),
                      'optimizer_state_dict': optimizer.state_dict(),
                      'scheduler_state_dict': scheduler.state_dict()}
        torch.save(checkpoint, os.path.join(save_path, ops.name + '.{}.pth.tar'.format(epoch)))
        # 11. ** test **
        print('Validation of epoch [{}]'.format(epoch))
        model.eval()
        correct = 0
        val_avg_loss = 0
        total = 0
        with torch.no_grad():
            for idx, (img, target) in enumerate(test_loader):
                model.eval()
                img = img.to(device)  # [N, 3, 32, 32]
                target = target.to(device)  # [N]
                output = model(img)  # [N, 10]
                loss = criterion(output, target)
                output = torch.softmax(output, dim=1)
                # first eval
                pred, idx_ = output.max(-1)
                correct += torch.eq(target, idx_).sum().item()
                total += target.size(0)
                val_avg_loss += loss.item()
        print('Epoch {} test : '.format(epoch))
        accuracy = correct / total
        print("accuracy : {:.4f}%".format(accuracy * 100.))
        val_avg_loss = val_avg_loss / len(test_loader)
        print("avg_loss : {:.4f}".format(val_avg_loss))
        if vis is not None:
            vis.line(X=torch.ones((1, 2)) * epoch,
                     Y=torch.Tensor([accuracy, val_avg_loss]).unsqueeze(0),
                     update='append',
                     win='test_loss',
                     opts=dict(x_label='epoch',
                               y_label='test_',
                               title='test_loss',
                               legend=['accuracy', 'avg_loss']))
        scheduler.step()
if __name__ == '__main__':
    main()
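# How to run (assumptions: visdom and timm are installed, and your visdom version accepts
# the standard flags below; adjust if it differs):
#   python -m visdom.server -port 8097       # start the visdom server first
#   python vit_performer_cifar10.py --epoch 50 --batch_size 64 --lr 0.001
# The script name above is just a placeholder for this file; CIFAR10 is downloaded
# automatically into --root on the first run.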