[The code for computing image-text similarity with SkyCLIP does not run] #3

Open
randomtutu opened this issue Jul 18, 2023 · 0 comments

Comments

@randomtutu

randomtutu commented Jul 18, 2023

from PIL import Image
import requests
import torch
from transformers import BertTokenizer
from transformers import CLIPProcessor, CLIPModel, CLIPTextModel
import numpy as np

query_texts = ['一个人', '一辆汽车', '两个男人', '两个女人']  # Input prompts; replace them with anything you like.
# Load the SkyCLIP Chinese-English bilingual text_encoder
text_tokenizer = BertTokenizer.from_pretrained("./tokenizer")
text_encoder = CLIPTextModel.from_pretrained("./text_encoder").eval()
text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']

url = "http://images.cocodataset.org/val2017/000000040083.jpg"  # Any image URL can be used here
# Load CLIP's image encoder
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_text_proj = clip_model.text_projection
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors="pt")

with torch.no_grad():
    image_features = clip_model.get_image_features(**image)
    text_features = text_encoder(text)[0]
    # sep_token plays the role of openai-clip's eot_token
    # (the original line referenced an undefined student_tokenizer)
    sep_index = torch.nonzero(text == text_tokenizer.sep_token_id)
    text_features = text_features[torch.arange(text.shape[0]), sep_index[:, 1]]
    # Apply the text projection matrix
    text_features = clip_text_proj(text_features)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    # Compute cosine similarity; logit_scale is the scaling factor
    logit_scale = clip_model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    print(np.around(probs, 3))
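
For reference, the pooling and similarity steps in the snippet above can be reproduced without any model weights. The following is only a rough sketch in plain Python: the helper names (pool_at_sep, l2_normalize, image_text_probs), the SEP id 102, the 2-dimensional vectors and the scale of 10.0 are all made up for illustration and do not come from SkyCLIP or CLIP. It also picks the first SEP position in each sequence, which matches the snippet only if every sequence contains exactly one SEP token.

```python
import math

def pool_at_sep(hidden_states, input_ids, sep_token_id):
    """For each sequence, pick the hidden vector at the first SEP position."""
    pooled = []
    for states, ids in zip(hidden_states, input_ids):
        sep_pos = ids.index(sep_token_id)  # raises ValueError if SEP is missing
        pooled.append(states[sep_pos])
    return pooled

def l2_normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def image_text_probs(image_vec, text_vecs, logit_scale):
    """Softmax over scaled cosine similarities between one image and several texts."""
    img = l2_normalize(image_vec)
    logits = []
    for t in text_vecs:
        txt = l2_normalize(t)
        logits.append(logit_scale * sum(a * b for a, b in zip(img, txt)))
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy data (hypothetical): two padded token sequences and their per-token states.
SEP_ID = 102  # placeholder SEP id, not read from any real tokenizer
input_ids = [[101, 7, 102, 0], [101, 8, 9, 102]]
hidden_states = [
    [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [9.0, 9.0]],  # padding state is ignored
    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]],
]

text_vecs = pool_at_sep(hidden_states, input_ids, SEP_ID)
probs = image_text_probs([2.0, 0.0], text_vecs, logit_scale=10.0)
print([round(p, 3) for p in probs])
```

Here the first text vector points in the same direction as the toy image vector, so it should receive almost all of the probability mass, and the state at the padding position never enters the result.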

Could you tell me what the tokenizer and text_encoder in this code are?
