Remove _supports_dynamic_cache_class attribute after rebase
Cyrilvallez committed May 21, 2024
1 parent 7a3d939 commit 6608872
Showing 18 changed files with 3 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/transformers/generation/candidate_generator.py
@@ -120,7 +120,7 @@ def __init__(
         if "past_key_values" in assistant_kwargs.keys():
             if (
                 isinstance(assistant_kwargs["past_key_values"], DynamicCache)
-                and not self.assistant_model._supports_dynamic_cache_class
+                and not self.assistant_model._supports_cache_class
             ):
                 # Cache is empty -> remove it from kwargs
                 if len(assistant_kwargs["past_key_values"]) == 0:
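
For context, a minimal sketch of the assisted-generation call this hunk guards (checkpoint paths are placeholders, not part of this commit): the main model may hold a DynamicCache by default, while an assistant model that does not set _supports_cache_class simply has an empty cache dropped from its kwargs instead of raising.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints, for illustration only
tokenizer = AutoTokenizer.from_pretrained("path/to/main-model")
model = AutoModelForCausalLM.from_pretrained("path/to/main-model")
assistant = AutoModelForCausalLM.from_pretrained("path/to/small-assistant-model")

inputs = tokenizer("The cache flag rename means", return_tensors="pt")
# The assistant's kwargs are prepared in AssistedCandidateGenerator.__init__,
# where the _supports_cache_class check shown above is applied.
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))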
4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
@@ -1101,7 +1101,7 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
         """Validates model kwargs for generation. Generate argument typos will also be caught here."""
         # If a `Cache` instance is passed, checks whether the model is compatible with it
         past = model_kwargs.get("past_key_values", None)
-        if isinstance(past, DynamicCache) and not self._supports_dynamic_cache_class:
+        if isinstance(past, DynamicCache) and not self._supports_cache_class:
             raise ValueError(
                 f"{self.__class__.__name__} does not support an instance of `DynamicCache` as `past_key_values`. Please "
                 "check the model documentation for supported cache formats."
@@ -1639,7 +1639,7 @@ def generate(
             )
         # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
         # keeps copying the cache thus using much more memory
-        elif generation_config.cache_implementation is None and self._supports_dynamic_cache_class:
+        elif generation_config.cache_implementation is None and self._supports_cache_class:
             past = model_kwargs.get("past_key_values", None)
             if past is None:
                 model_kwargs["past_key_values"] = DynamicCache()
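
A minimal sketch of the user-facing behaviour these two hunks control, assuming a checkpoint whose model class sets _supports_cache_class = True (the checkpoint path is a placeholder): such a model gets a DynamicCache by default and also accepts one passed explicitly, whereas a model without the flag hits the ValueError in _validate_model_kwargs.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

# Placeholder checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained("path/to/cache-class-model")
model = AutoModelForCausalLM.from_pretrained("path/to/cache-class-model")

inputs = tokenizer("DynamicCache stores past key/values", return_tensors="pt")
# Explicitly passing a DynamicCache; _validate_model_kwargs only allows this
# when the model class sets _supports_cache_class = True.
out = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=10)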
4 changes: 0 additions & 4 deletions src/transformers/modeling_utils.py
@@ -1284,10 +1284,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     _supports_cache_class = False
     _supports_static_cache = False
 
-    # Has support for a `DynamicCache` instance as `past_key_values`. Some models support it but not other caches
-    # Using it is a big memory gain
-    _supports_dynamic_cache_class = False
-
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
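
After this commit, the single _supports_cache_class flag is what a model declares to opt in to Cache objects such as DynamicCache; a minimal sketch with a hypothetical subclass (not part of this commit):

from transformers import PreTrainedModel

class MyPreTrainedModel(PreTrainedModel):  # hypothetical model, for illustration only
    _supports_cache_class = True   # accept Cache instances (e.g. DynamicCache) as past_key_values
    _supports_static_cache = True  # optionally also advertise StaticCache support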
1 change: 0 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -713,7 +713,6 @@ class CoherePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
@@ -938,7 +938,6 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -699,7 +699,6 @@ class GemmaPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1342,7 +1342,6 @@ class Idefics2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Idefics2 isn't meant for training from scratch - only
1 change: 0 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -768,7 +768,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/mistral/modeling_mistral.py
@@ -771,7 +771,6 @@ class MistralPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -973,7 +973,6 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -746,7 +746,6 @@ class OlmoPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
@@ -463,7 +463,6 @@ class PersimmonPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["PersimmonDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -830,7 +830,6 @@ class PhiPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/phi3/modeling_phi3.py
@@ -910,7 +910,6 @@ class Phi3PreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = False
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     _version = "0.0.5"
 
1 change: 0 additions & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
@@ -811,7 +811,6 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -965,7 +965,6 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/stablelm/modeling_stablelm.py
@@ -831,7 +831,6 @@ class StableLmPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_cache_class = True
     _supports_sdpa = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -790,7 +790,6 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_dynamic_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
