scenic
```python
import clip
import torch
import jax
import numpy as np
from scenic.projects.baselines.clip import model as clip_scenic

# Same random NHWC input for both models.
inputs = np.random.randn(1, 336, 336, 3)

# PyTorch reference (OpenAI CLIP package); it expects NCHW input.
model, preprocess = clip.load("ViT-L/14@336px", device="cpu")
with torch.no_grad():
    image = torch.from_numpy(inputs.transpose(0, 3, 1, 2))
    image_features = model.encode_image(image).numpy()
print(image_features.shape)
temp = image_features[0, :4].flatten().tolist()
print(", ".join(f"{x:.4f}" for x in temp))

print("=====Printing JAX model=====")
# Scenic port of the same checkpoint; Flax models take NHWC input.
_CLIP_MODEL_NAME = 'vit_l14_336px'
_model = clip_scenic.MODELS[_CLIP_MODEL_NAME]()
_model_vars = clip_scenic.load_model_vars(_CLIP_MODEL_NAME)
images = jax.numpy.array(inputs)
# No text is passed, so the second output is discarded.
image_embs, _ = _model.apply(_model_vars, images, None)
print(image_embs.shape)
temp = np.asarray(image_embs[0, :4]).flatten().tolist()
print(", ".join(f"{x:.4f}" for x in temp))
```
Gives:
```
(1, 768)
-0.1827, 0.7319, 0.8779, 0.4829
=====Printing JAX model=====
(1, 768)
-0.0107, 0.0429, 0.0514, 0.0283
```
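A quick arithmetic check on the four printed values (a sketch that uses only the numbers in the output above, which are rounded to four decimals): dividing each PyTorch value by the corresponding JAX value gives a nearly constant ratio, roughly 17.06 to 17.08. A constant per-vector scale factor is what one would expect if one output were L2-normalized and the other were not; that is a hypothesis to verify, not something confirmed by either codebase here.

```python
import numpy as np

# First four features as printed in the issue (rounded to 4 decimals).
torch_vals = np.array([-0.1827, 0.7319, 0.8779, 0.4829])
jax_vals = np.array([-0.0107, 0.0429, 0.0514, 0.0283])

# Element-wise ratio; a (nearly) constant value suggests a pure rescaling.
ratios = torch_vals / jax_vals
print(np.round(ratios, 2))

# Relative spread of the ratios. Rounding the small JAX values to 4 decimals
# alone can shift each ratio by up to about half a percent.
spread = (ratios.max() - ratios.min()) / ratios.mean()
print(f"relative spread: {spread:.4f}")
```

The spread is far smaller than the rounding error allows, so the two outputs look like the same direction with different magnitudes rather than unrelated vectors.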
This is not expected, right?
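If the difference is only a per-vector rescaling, L2-normalizing the PyTorch features should reproduce the JAX embeddings. Below is a minimal sketch of that comparison, assuming `image_features` and `image_embs` from the script are converted to NumPy arrays of shape (batch, dim). The helpers `l2_normalize` and `matches_after_l2_norm` and the default tolerance are hypothetical illustrations, not part of CLIP or Scenic.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of x to unit L2 norm (eps guards against zero rows)."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def matches_after_l2_norm(torch_feats, jax_feats, atol: float = 1e-3) -> bool:
    """Hypothetical helper: True if jax_feats equals the row-wise
    L2-normalized torch_feats within an absolute tolerance of atol."""
    torch_feats = np.asarray(torch_feats, dtype=np.float64)
    jax_feats = np.asarray(jax_feats, dtype=np.float64)
    if torch_feats.shape != jax_feats.shape:
        return False
    return bool(np.allclose(l2_normalize(torch_feats), jax_feats, atol=atol))


# Usage with synthetic data standing in for the real model outputs.
rng = np.random.default_rng(0)
raw = rng.normal(size=(1, 768)) * 17.0
print(matches_after_l2_norm(raw, l2_normalize(raw)))  # True by construction
print(matches_after_l2_norm(raw, raw))                # False: raw is not unit-norm
```

With the real outputs, the call would be `matches_after_l2_norm(image_features, np.asarray(image_embs))`. A `True` result would point to a normalization difference between the two code paths rather than a weight-conversion bug.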
@rchen152?
I don't think I'm the right person to look at this; my only contribution to this project was adding some comments to disable type checker errors.