본문 바로가기
카테고리 없음

clip 데모

by pulluper 2025. 6. 23.
반응형

1. transformers 이용

 

import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

# 1. Model + processor from the Hugging Face hub (safetensors weights).
model_id = "openai/clip-vit-base-patch32"

# Pick the compute device first so the model can be staged in one chain.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained(model_id, use_safetensors=True).to(device).eval()
processor = CLIPProcessor.from_pretrained(model_id, use_safetensors=True)

# 2.-3. Source/target image paths, loaded as RGB PIL images.
src_path = "rabit1.png"
tgt_paths = ["rabit2.png", "spongebob.jpg"]

src_image = Image.open(src_path).convert("RGB")
tgt_images = [Image.open(p).convert("RGB") for p in tgt_paths]

# 4. 피처 추출
def encode_image(image):
    """Return the L2-normalized CLIP embedding for a single PIL image.

    Uses the module-level `processor`, `model`, and `device`. The unit-norm
    output makes a plain dot product equal to cosine similarity.
    """
    batch = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = model.get_image_features(**batch)
    # emb was produced under no_grad, so this division is grad-free too.
    return emb / emb.norm(p=2, dim=-1, keepdim=True)

# 4.-5. Embed source/targets; features are unit-norm, so the dot product
# below is already the cosine similarity.
src_feat = encode_image(src_image)
tgt_feats = [encode_image(img) for img in tgt_images]

similarities = [float(src_feat @ feat.T) for feat in tgt_feats]

# 6. Report one similarity per target path.
for path, sim in zip(tgt_paths, similarities):
    print(f"Cosine Similarity (src ↔ {path}): {sim:.4f}")

# 7. Show source and targets side by side, titles annotated with similarity.
all_images = [src_image] + tgt_images
titles = ["Source"] + [f"{p}\n(sim: {s:.2f})" for p, s in zip(tgt_paths, similarities)]

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, img, title in zip(axs, all_images, titles):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

 

 

2. open_clip 이용

import torch
from PIL import Image
import open_clip

# 1. Load the CLIP model and its matching image preprocessing transforms.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()

# 2. Text tokenizer.
# Fix: the tokenizer must match the loaded architecture. The original code
# built a 'ViT-B-32' tokenizer and then immediately overwrote it with a
# 'ViT-L-14' one, pairing a mismatched tokenizer with the ViT-B-32 model.
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# 3. Preprocess images into batched tensors (unsqueeze adds the batch dim).
src_image = preprocess(Image.open("rabit1.png")).unsqueeze(0)
tgt_image1 = preprocess(Image.open("rabit2.png")).unsqueeze(0)
tgt_image2 = preprocess(Image.open("spongebob.jpg")).unsqueeze(0)

# Put the model and all inputs on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

src_image = src_image.to(device)
tgt_image1 = tgt_image1.to(device)
tgt_image2 = tgt_image2.to(device)

# Encode all three images in one pass (mixed precision where supported).
with torch.no_grad(), torch.autocast(device_type=device):
    feats = [model.encode_image(t) for t in (src_image, tgt_image1, tgt_image2)]

# L2-normalize so the dot products below are cosine similarities.
src_feat, tgt_feat1, tgt_feat2 = (f / f.norm(dim=-1, keepdim=True) for f in feats)

sim1 = float(src_feat @ tgt_feat1.T)
sim2 = float(src_feat @ tgt_feat2.T)

print(f"Cosine Similarity (src ↔ rabit2.png):     {sim1:.4f}")
print(f"Cosine Similarity (src ↔ spongebob.jpg):  {sim2:.4f}")

from matplotlib import pyplot as plt

# Show the three images, each title annotated with its similarity to source.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
paths = ["rabit1.png", "rabit2.png", "spongebob.jpg"]
titles = ["Source", f"rabit2.png\n(sim: {sim1:.2f})", f"spongebob.jpg\n(sim: {sim2:.2f})"]
for ax, path, title in zip(axs, paths, titles):
    ax.imshow(Image.open(path))
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()









# # text = tokenizer(["a diagram", "a rabit", "a dog", "a cat"])

# # 4. 특징 추출 및 정규화
# with torch.no_grad(), torch.autocast("cuda"):
#     image_features = model.encode_image(src_image)
#     image_features = model.encode_image(tgt_image1)
#     image_features = model.encode_image(tgt_image2)
#     # text_features = model.encode_text(text)
#     image_features /= image_features.norm(dim=-1, keepdim=True)
#     # text_features /= text_features.norm(dim=-1, keepdim=True)

#     # 유사도 계산 후 softmax (확률 분포)
#     text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# print("Label probs:", text_probs)

 

 

반응형

댓글