scenic CLIP doesn't match the outputs of the original CLIP implementation #991

Open
sayakpaul opened this issue Jan 31, 2024 · 1 comment

Comments

@sayakpaul

```python
import clip
import jax
import numpy as np
import torch
from scenic.projects.baselines.clip import model as clip_scenic

# Random input in NHWC layout (batch, height, width, channels) at 336x336.
inputs = np.random.randn(1, 336, 336, 3)

# --- Original OpenAI CLIP (PyTorch) ---
model, preprocess = clip.load("ViT-L/14@336px", device="cpu")

with torch.no_grad():
    # PyTorch expects NCHW, so move the channel axis before converting.
    image = torch.from_numpy(inputs.transpose(0, 3, 1, 2))
    image_features = model.encode_image(image).numpy()
    print(image_features.shape)

# Print the first four components of the PyTorch embedding.
temp = image_features[0, :4].flatten().tolist()
print(", ".join(f"{x:.4f}" for x in temp))
print("=====Printing JAX model=====")

# --- scenic CLIP (JAX/Flax) ---
_CLIP_MODEL_NAME = 'vit_l14_336px'
_model = clip_scenic.MODELS[_CLIP_MODEL_NAME]()
_model_vars = clip_scenic.load_model_vars(_CLIP_MODEL_NAME)

# Same array, kept in NHWC; the text argument is None, so only the image embedding is used.
images = jax.numpy.array(inputs)
image_embs, _ = _model.apply(_model_vars, images, None)
print(image_embs.shape)
# Print the first four components of the scenic embedding.
temp = np.asarray(image_embs[0, :4]).flatten().tolist()
print(", ".join(f"{x:.4f}" for x in temp))
```

Gives:

```
(1, 768)
-0.1827, 0.7319, 0.8779, 0.4829
=====Printing JAX model=====
(1, 768)
-0.0107, 0.0429, 0.0514, 0.0283
```

This is not expected, right?

@rchen152?
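The two printed rows differ, but they may be proportional rather than unrelated. A quick check using only the eight values printed above divides each PyTorch entry by the matching scenic entry; if the vectors differ only by a positive scale factor, every ratio should be roughly the same.

```python
import numpy as np

# First four components of each embedding, copied from the printed output above.
torch_vals = np.array([-0.1827, 0.7319, 0.8779, 0.4829])   # OpenAI CLIP, model.encode_image
scenic_vals = np.array([-0.0107, 0.0429, 0.0514, 0.0283])  # scenic CLIP, _model.apply

# Element-wise ratio; a (near-)constant value means the vectors differ only by scale.
ratios = torch_vals / scenic_vals
print(np.round(ratios, 2))

# Values were printed with 4 decimals, so the small scenic entries carry
# up to ~0.5% rounding error; the spread should be compared against that.
spread = ratios.max() - ratios.min()
print(f"spread = {spread:.3f}")
```

All four ratios land between about 17.06 and 17.08, well within the rounding error of 4-decimal printing. A constant positive factor is what one would expect if one output were the other rescaled to unit L2 norm, for example if the scenic model returns L2-normalized embeddings while `model.encode_image` returns unnormalized ones. That is a hypothesis inferred from four numbers, not something this thread confirms.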

@rchen152
Contributor

I don't think I'm the right person to look at this; my only contribution to this project was adding some comments to disable type checker errors.
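To test that hypothesis on the full 768-dimensional vectors rather than four printed entries, one could L2-normalize both outputs and compare them directly. The helpers below (`l2_normalize`, `compare_embeddings`) are hypothetical names for illustration, not part of the `clip` or scenic APIs. A cosine similarity close to 1 for each row, with `match_up_to_scale` true, would mean the models agree up to a per-row scale; a cosine clearly below 1 would point to a real mismatch (weights, preprocessing, or layout).

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Scale each vector along `axis` to unit Euclidean length."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def compare_embeddings(a: np.ndarray, b: np.ndarray, atol: float = 1e-3) -> dict:
    """Compare two (batch, dim) embedding arrays while ignoring per-row scale.

    Returns the per-row cosine similarity, the per-row norms of `a`, and
    whether the unit-normalized vectors agree within `atol`.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a_unit, b_unit = l2_normalize(a), l2_normalize(b)
    cosine = np.sum(a_unit * b_unit, axis=-1)
    return {
        "cosine": cosine,
        "norm_a": np.linalg.norm(a, axis=-1),
        "match_up_to_scale": bool(np.allclose(a_unit, b_unit, atol=atol)),
    }


# Intended usage with the arrays from the reproduction script (not run here):
#   report = compare_embeddings(image_features, np.asarray(image_embs))
#   print(report["cosine"], report["norm_a"], report["match_up_to_scale"])

# Self-contained demo: a random vector and its normalized copy match up to scale.
rng = np.random.default_rng(0)
feats = rng.normal(size=(1, 768))
report = compare_embeddings(feats, l2_normalize(feats))
print(report)
```

Comparing unit vectors rather than raw values separates a harmless difference in output scaling from a genuine numerical discrepancy; if the vectors do match, `norm_a` from the PyTorch side would be the constant factor seen in the printed ratios.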
