diff --git a/src/ecco/lm.py b/src/ecco/lm.py index 05176f6..97d3330 100644 --- a/src/ecco/lm.py +++ b/src/ecco/lm.py @@ -86,7 +86,10 @@ def __init__(self, self.model_type = self.model_config['type'] embeddings_layer_name = self.model_config['embedding'] embed_retriever = attrgetter(embeddings_layer_name) - self.model_embeddings = embed_retriever(self.model) + if type(embed_retriever(self.model)) == torch.nn.Embedding: + self.model_embeddings = embed_retriever(self.model).weight + else: + self.model_embeddings = embed_retriever(self.model) self.collect_activations_layer_name_sig = self.model_config['activations'][0] except KeyError: raise ValueError(