Removing model layers throws an index error. #30508

Closed
candemircan opened this issue Apr 26, 2024 · 5 comments
Labels: Feature request (Request for a new feature)

candemircan commented Apr 26, 2024

Feature request

Hello,

When I try to remove a layer from the Llama models using the code snippet below, I get an index error (pasted below the snippet). From what I can tell, the layer_idx attribute of each layer's self_attn module is used for generation, and these indices are not updated automatically when layers are removed. I believe the same behaviour holds in other models (e.g. gemma-2b). Apologies if there is already an existing way to remove layers; I'm posting this after an extensive search.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "meta-llama/Meta-Llama-3-8B"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch_dtype,
)

# drop layer 16; the remaining layers keep their original (now non-contiguous) layer_idx
model.model.layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i != 16])
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
output = model(tokenized, return_dict=True, output_hidden_states=True)
IndexError                                Traceback (most recent call last)
Cell In[132], line 14
     12 prompt = "hello"
     13 tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
---> 14 output = model(tokenized, return_dict=True, output_hidden_states=True)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1208, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1205 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1207 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1208 outputs = self.model(
   1209     input_ids=input_ids,
   1210     attention_mask=attention_mask,
   1211     position_ids=position_ids,
   1212     past_key_values=past_key_values,
   1213     inputs_embeds=inputs_embeds,
   1214     use_cache=use_cache,
   1215     output_attentions=output_attentions,
   1216     output_hidden_states=output_hidden_states,
   1217     return_dict=return_dict,
   1218     cache_position=cache_position,
   1219 )
   1221 hidden_states = outputs[0]
   1222 if self.config.pretraining_tp > 1:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1018, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1007     layer_outputs = self._gradient_checkpointing_func(
   1008         decoder_layer.__call__,
   1009         hidden_states,
   (...)
   1015         cache_position,
   1016     )
   1017 else:
-> 1018     layer_outputs = decoder_layer(
   1019         hidden_states,
   1020         attention_mask=causal_mask,
   1021         position_ids=position_ids,
   1022         past_key_value=past_key_values,
   1023         output_attentions=output_attentions,
   1024         use_cache=use_cache,
   1025         cache_position=cache_position,
   1026     )
   1028 hidden_states = layer_outputs[0]
   1030 if use_cache:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:741, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    738 hidden_states = self.input_layernorm(hidden_states)
    740 # Self Attention
--> 741 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    742     hidden_states=hidden_states,
    743     attention_mask=attention_mask,
    744     position_ids=position_ids,
    745     past_key_value=past_key_value,
    746     output_attentions=output_attentions,
    747     use_cache=use_cache,
    748     cache_position=cache_position,
    749     **kwargs,
    750 )
    751 hidden_states = residual + hidden_states
    753 # Fully Connected

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:653, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    650 if past_key_value is not None:
    651     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    652     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 653     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    655 key_states = repeat_kv(key_states, self.num_key_value_groups)
    656 value_states = repeat_kv(value_states, self.num_key_value_groups)

File ~/.local/lib/python3.10/site-packages/transformers/cache_utils.py:149, in DynamicCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
    146     self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
    147     self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
--> 149 return self.key_cache[layer_idx], self.value_cache[layer_idx]

IndexError: list index out of range

Motivation

I think it'd be fantastic to decouple the layer_idx variable somehow to allow easy removal of entire blocks. I imagine this would be useful for the general research community to experiment with these models.
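For illustration, something roughly like the following could work as a user-side helper (hypothetical sketch, not an existing transformers API; it assumes a Llama-style model.model.layers / self_attn.layer_idx layout and that layer_idx only needs to track the layer's position):

import torch.nn as nn

def remove_layer(model, idx: int):
    # hypothetical helper: drop decoder block `idx` and re-number the remaining
    # attention modules so the KV cache is indexed consistently again
    kept = [layer for i, layer in enumerate(model.model.layers) if i != idx]
    model.model.layers = nn.ModuleList(kept)
    for i, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = i
    model.config.num_hidden_layers = len(model.model.layers)
    return model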

Your contribution

I'm not very familiar with the inner workings of the library, but I'd be happy to make a PR if you can give me some high-level suggestions on how to make this change. Thanks!

@amyeroberts (Collaborator)

cc @ArthurZucker

amyeroberts added the Feature request label Apr 26, 2024

grahamannett commented May 11, 2024

@candemircan I'm pretty sure I ran into something similar when trying to chop out / remove the majority of a model for local dev, and it makes what you are trying to do somewhat impossible out of the box.

There seem to be two possibilities. One is to figure out which args you need to pass so you can use the model without much modification; for instance, with what you posted, passing use_cache=False (but other models use layer_idx for the cache or other parts in ways that are also not decoupled, so you may need to figure out other kwargs as well):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
model.model.layers.pop(16)  # drop layer 16; the remaining layer_idx values are left stale
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
# use_cache=False avoids indexing the KV cache with the stale layer_idx values
output = model(tokenized, use_cache=False, return_dict=True, output_hidden_states=True)

should work. The other approach, which is less flexible about which layers you drop but seems less prone to breaking across the various forward implementations (and might not be helpful for what you are trying to do), is to modify a config before model creation: change the number of layers (e.g. config.num_hidden_layers) and pass that config to from_pretrained, as sketched below.
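Untested sketch of that config route (note it truncates the last layers rather than removing an arbitrary middle one, and assumes from_pretrained simply leaves the checkpoint weights for layers that no longer exist in the config unloaded):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
config = AutoConfig.from_pretrained(model_name)
config.num_hidden_layers -= 1  # e.g. 32 -> 31, dropping the final decoder block

# weights for the dropped layer(s) are not loaded; layer_idx stays contiguous
model = AutoModelForCausalLM.from_pretrained(
    model_name, config=config, torch_dtype=torch.bfloat16
)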

Kind of annoying, and I agree that model layers held in something like a ModuleList should be decoupled from the rest of the model, to allow easier local debugging/dev without having to wrap/subclass the Model/Config.

@ArthurZucker (Collaborator)

Layer index is mostly (and only) used for the cache.
And this is more a feature request than a bug: you are manually re-ordering the layers without updating the layer index.

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "meta-llama/Meta-Llama-3-8B"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch_dtype,
)
for i, layer in enumerate(model.model.layers[:-1]):
    if i < 16:
        continue  # layers before the removed one keep their index
    # shift the following layer down one slot and re-number its attention module
    model.model.layers[i] = model.model.layers[i + 1]
    model.model.layers[i].self_attn.layer_idx = i
del model.model.layers[-1]  # drop the now-duplicated last layer
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
output = model(tokenized, return_dict=True, output_hidden_states=True)

It's not really part of the API, so I'm not sure we want to add some kind of trick to automatically update the layer_idx.

@candemircan (Author)

hi Arthur,

thanks for the response. I think the workarounds you and @grahamannett suggested suit what I need.

@ArthurZucker (Collaborator)

Glad we could help! 🤗
