Fix conflicts after the same issue was patched upstream
Cyrilvallez committed May 14, 2024
1 parent 0a0b81b commit b835d4d
Showing 8 changed files with 86 additions and 136 deletions.
4 changes: 2 additions & 2 deletions src/transformers/cache_utils.py
@@ -60,7 +60,7 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
@@ -183,7 +183,7 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens."""
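
The `reorder_cache` and `crop` helpers shown in these hunks operate directly on a `DynamicCache`. A minimal sketch of how they might be exercised, assuming a transformers version from around this commit (tensor shapes are purely illustrative):

import torch
from transformers import DynamicCache

cache = DynamicCache()
# Fake key/value states: (batch, num_heads, seq_len, head_dim)
key_states = torch.randn(2, 8, 5, 64)
value_states = torch.randn(2, 8, 5, 64)
cache.update(key_states, value_states, layer_idx=0)

print(cache.get_seq_length())              # 5 cached tokens
cache.crop(3)                              # keep only the first 3 tokens
print(cache.get_seq_length())              # 3
cache.reorder_cache(torch.tensor([1, 0]))  # swap the two batch/beam entries in place
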
110 changes: 32 additions & 78 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, EfficientDynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -1530,7 +1530,7 @@ def generate(
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_cache_class:
past = model_kwargs.get('past_key_values', None)
past = model_kwargs.get("past_key_values", None)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
elif isinstance(past, tuple):
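
This hunk makes `generate()` default to a `DynamicCache` when the model supports the cache class, and only convert when the caller passes a legacy tuple (the branch is cut off by the hunk boundary; the conversion presumably goes through `DynamicCache.from_legacy_cache`). A standalone sketch of that dispatch, with `model_kwargs` standing in for the real keyword dictionary:

from transformers import DynamicCache

def resolve_default_cache(model_kwargs: dict) -> dict:
    # Default to DynamicCache so generation never round-trips through the
    # legacy tuple-of-tuples format, which copies the cache at every step.
    past = model_kwargs.get("past_key_values", None)
    if past is None:
        model_kwargs["past_key_values"] = DynamicCache()
    elif isinstance(past, tuple):
        # Backwards compatibility for callers that still pass legacy tuples.
        model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
    return model_kwargs
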
@@ -1613,44 +1613,11 @@ def generate(
streamer=streamer,
**model_kwargs,
)
if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search
result = self._greedy_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

result = self._contrastive_search(
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
sequential=generation_config.low_memory,
**model_kwargs,
)

elif generation_mode == GenerationMode.SAMPLE:
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
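
With this change greedy search is routed through the sampling loop: when `generation_config.do_sample` is False no logits warper is prepared, and the token choice in `_sample` collapses to an argmax. A toy illustration of that degeneration (the helper below is not the library's `_sample`; `warpers` here are plain callables on the logits rather than transformers `LogitsWarper` objects):

import torch

def pick_next_token(logits: torch.Tensor, do_sample: bool, warpers=()):
    if not do_sample:
        # No warper is prepared: a plain argmax, i.e. greedy search.
        return torch.argmax(logits, dim=-1)
    for warp in warpers:            # e.g. top-k / top-p filtering
        logits = warp(logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
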
@@ -1660,57 +1627,43 @@ def generate(
**model_kwargs,
)

# 13. run sample
# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
result = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

elif generation_mode == GenerationMode.BEAM_SEARCH:
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run beam search
result = self._beam_search(
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

result = self._contrastive_search(
input_ids,
beam_scorer,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
sequential=generation_config.low_memory,
**model_kwargs,
)

elif generation_mode == GenerationMode.BEAM_SAMPLE:
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
prepared_logits_warper = (
self._get_logits_warper(generation_config) if generation_config.do_sample else None
)

# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@@ -1732,16 +1685,13 @@
)

# 14. run beam sample
result = self._beam_sample(
result = self._beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
pad_token_id=generation_config.pad_token_id,
output_scores=generation_config.output_scores,
output_logits=generation_config.output_logits,
return_dict_in_generate=generation_config.return_dict_in_generate,
generation_config=generation_config,
synced_gpus=synced_gpus,
**model_kwargs,
)
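
Beam search and beam sampling are likewise merged into a single `_beam_search` call that now receives the whole `generation_config` and an optional logits warper. From the user's side nothing changes; a quick usage sketch (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# do_sample=False -> deterministic beam search; do_sample=True -> beam sampling.
# After this commit both paths run through the same _beam_search implementation.
beams = model.generate(**inputs, num_beams=4, do_sample=False, max_new_tokens=20)
beam_samples = model.generate(**inputs, num_beams=4, do_sample=True, top_k=50, max_new_tokens=20)
print(tokenizer.batch_decode(beams, skip_special_tokens=True))
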
@@ -2166,7 +2116,7 @@ def _contrastive_search(
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))

past = tuple(new_key_values)

model_kwargs["past_key_values"] = past
@@ -2260,8 +2210,12 @@ def _contrastive_search(
# Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache):
for layer_idx in range(len(next_past_key_values)):
next_past_key_values.key_cache[layer_idx] = next_past_key_values.key_cache[layer_idx][augmented_idx, ...]
next_past_key_values.value_cache[layer_idx] = next_past_key_values.value_cache[layer_idx][augmented_idx, ...]
next_past_key_values.key_cache[layer_idx] = next_past_key_values.key_cache[layer_idx][
augmented_idx, ...
]
next_past_key_values.value_cache[layer_idx] = next_past_key_values.value_cache[layer_idx][
augmented_idx, ...
]
else:
new_key_values = []
for layer in next_past_key_values:
@@ -2273,7 +2227,6 @@ def _contrastive_search(

next_past_key_values = tuple(new_key_values)


logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]

# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
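
The two hunks above reorder the selected candidates in contrastive search: for a `DynamicCache` the per-layer tensors are replaced one layer at a time, so at most one layer is duplicated in memory, while a legacy tuple still has to be rebuilt wholesale. A condensed sketch of both branches (names and shapes illustrative, not the library code):

import torch
from transformers import DynamicCache

def reorder_selected(past, augmented_idx: torch.Tensor):
    if isinstance(past, DynamicCache):
        # Replace each layer's tensors in turn instead of building a full copy.
        for layer_idx in range(len(past)):
            past.key_cache[layer_idx] = past.key_cache[layer_idx][augmented_idx, ...]
            past.value_cache[layer_idx] = past.value_cache[layer_idx][augmented_idx, ...]
        return past
    # Legacy tuple-of-tuples: rebuild the whole nested structure.
    return tuple(
        tuple(item[augmented_idx, ...] for item in layer) for layer in past
    )
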
@@ -4110,6 +4063,7 @@ def _assisted_decoding(
else:
return input_ids


def _speculative_sampling(
candidate_input_ids,
candidate_logits,
19 changes: 9 additions & 10 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -896,15 +896,13 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -962,9 +960,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
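
The Cohere hunks above introduce the backwards-compatibility pattern that the DBRX, Gemma, Idefics2 and Llama diffs below repeat: a non-`Cache` `past_key_values` input is wrapped in a `DynamicCache` on entry and converted back with `to_legacy_cache()` on exit. A stripped-down sketch of the pattern outside of a model `forward`:

from transformers import DynamicCache
from transformers.cache_utils import Cache

def wrap_legacy_cache(past_key_values, use_cache: bool):
    # Kept for BC: accept legacy tuples but work on a DynamicCache internally.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy_cache

def unwrap_legacy_cache(next_cache, return_legacy_cache: bool):
    # Convert back only when the caller originally passed the legacy format.
    if return_legacy_cache and next_cache is not None:
        next_cache = next_cache.to_legacy_cache()
    return next_cache
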
20 changes: 9 additions & 11 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1131,16 +1131,13 @@ def forward(

inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -1203,9 +1200,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
v
18 changes: 9 additions & 9 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -877,14 +877,13 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -948,9 +947,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
11 changes: 7 additions & 4 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1596,10 +1596,10 @@ def forward(
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_seen_tokens = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_usable_length(seq_length)

if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1673,6 +1673,9 @@ def forward(
return_dict=return_dict,
)

if return_legacy_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

if not return_dict:
return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

20 changes: 9 additions & 11 deletions src/transformers/models/llama/modeling_llama.py
@@ -973,16 +973,13 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -1040,9 +1037,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
