
Reduce by 2 the memory requirement in generate() 🔥🔥🔥 #30536

Open · wants to merge 28 commits into main

Conversation

@Cyrilvallez Cyrilvallez commented Apr 29, 2024

What does this PR do?

Change the data structure and implementation of DynamicCache to halve the memory requirement in generate(), with no noticeable speed degradation.

Reason

I was working on precise memory estimation for generate() and noticed that the observed memory peak is much higher than the expected one. I was able to track the problem down to the current implementation of DynamicCache.

Since torch.cat() creates a copy, and since during generate() we loop over each token while the previous cache is still referenced in the inputs (so the old cache cannot be garbage collected), the current implementation effectively holds 2 copies of the full cache at every iteration. By changing the data structure from Tuple[Tuple[Tensor]] to Tuple[Tuple[List[Tensor]]], I was able to avoid this copy and halve the memory footprint of generate().

However, that means that potentially thousands of tensors must be cat()'ed when cache.update() is called to feed the correct tensor to the attention layer, which causes a speed degradation. I was able to mitigate this by periodically cat()'ing the tensors in the cache whenever there are more than N of them (I currently use N=50), which incurs a negligible memory increase: the cache for a sequence length of 50 is completely negligible compared to the full cache (hundreds, or more frequently thousands, of tokens). This strategy almost completely removes the speed degradation, giving us the best of both worlds; see the sketch below. N could even be chosen dynamically in generate() depending on the input length and max new tokens, but 50 seemed like a good heuristic to start with.
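For illustration, here is a minimal single-layer sketch of the idea (not the PR's exact code: ListKVCache and consolidate_every are made-up names, and the real logic lives inside DynamicCache.update()):

import torch
from typing import List, Tuple

class ListKVCache:
    """Toy per-layer K-V cache kept as a list of tensors instead of one big tensor."""

    def __init__(self, consolidate_every: int = 50):
        self.key_cache: List[torch.Tensor] = []    # entries of shape [batch, heads, chunk_len, head_dim]
        self.value_cache: List[torch.Tensor] = []
        self.consolidate_every = consolidate_every

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Appending mutates the cache in place: no new full-size tensor is allocated.
        # With the legacy tuple format, torch.cat() built a brand-new full-size tensor
        # every step while the old tuple was still referenced by generate()'s inputs,
        # so two full copies of the cache coexisted at the memory peak.
        self.key_cache.append(key_states)
        self.value_cache.append(value_states)

        # Periodically merge the trailing single-token entries into one block so the
        # final cat() below never has to merge thousands of tiny tensors. Only the
        # last ~consolidate_every tokens get copied, which is negligible compared to
        # the full cache.
        single_token_idx = [i for i, k in enumerate(self.key_cache) if k.shape[-2] == 1]
        if len(single_token_idx) > self.consolidate_every:
            start = single_token_idx[0]
            self.key_cache = self.key_cache[:start] + [torch.cat(self.key_cache[start:], dim=-2)]
            self.value_cache = self.value_cache[:start] + [torch.cat(self.value_cache[start:], dim=-2)]

        # The attention layer still needs contiguous tensors; this transient copy only
        # exists for the current layer while it runs, never for the whole cache at once.
        return torch.cat(self.key_cache, dim=-2), torch.cat(self.value_cache, dim=-2)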

Basically, at a very small performance penalty that is only visible for very large sequence lengths, we halve the memory footprint, which in turn allows the batch size to be doubled (so the process may actually end up faster overall: processing twice the batch at a slightly lower speed should still beat running two loops at the faster speed).

Of course, the best would still be to use a StaticCache, as I saw you started implementing, but a DynamicCache is still very useful, and it should not have to double its effective memory footprint when used in a loop that references itself.

Benchmark

Here is the benchmark I ran using an RTX 4090 (24 GB) with Mistral 7B, Llama2 7B and Llama3 8B.

Fixed batch size of 1 and input length of 300, variable number of new tokens

Mistral-7B-v0.1_memory_fix_batch.pdf
Mistral-7B-v0.1_time_fix_batch.pdf
Llama-2-7b-hf_memory_fix_batch.pdf
Llama-2-7b-hf_time_fix_batch.pdf
Meta-Llama-3-8B_memory_fix_batch.pdf
Meta-Llama-3-8B_time_fix_batch.pdf

Fixed number of new tokens of 2000 and input length of 300, variable batch size

Mistral-7B-v0.1_memory_fix_length.pdf
Mistral-7B-v0.1_time_fix_length.pdf
Llama-2-7b-hf_memory_fix_length.pdf
Llama-2-7b-hf_time_fix_length.pdf
Meta-Llama-3-8B_memory_fix_length.pdf
Meta-Llama-3-8B_time_fix_length.pdf

Integration in Transformers

As I am changing the data structure of DynamicCache, I had to modify how it is used in the modeling code. For now, I only modified LlamaForCausalLM and MistralForCausalLM, which use DynamicCache by default, to test my implementation. Also, the change of data structure may have impacts elsewhere that I overlooked (if other code relies on attributes of the cache, e.g. a call to cache[0][0].shape would now raise an AttributeError). If so, do not hesitate to point me towards those parts and I can modify them.
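For example, downstream code that reads the past length could handle both formats along these lines (get_past_length is an illustrative helper, not the PR's API):

def get_past_length(past_key_values):
    # past_key_values: either the legacy Tuple[Tuple[Tensor]] or the new Tuple[Tuple[List[Tensor]]].
    layer0_keys = past_key_values[0][0]
    if isinstance(layer0_keys, list):
        # New structure: cache[0][0].shape would raise AttributeError, so sum the chunk lengths.
        return sum(chunk.shape[-2] for chunk in layer0_keys)
    # Legacy structure: a single [batch, heads, past_len, head_dim] tensor per layer.
    return layer0_keys.shape[-2]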

I would be happy to help integrate the change in all models if you decide to move forward with the PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker and @younesbelkada and @gante

@ArthurZucker ArthurZucker (Collaborator) left a comment

Great PR and great in-depth analysis.
One thing we could / should also do is remove this reference, no?
For now we pass the cache as an input to every single layer, which also returns it.
I suppose that having less of this could solve what you mention about having 2 copies of the cache.

Anyhow if we can fix the memory consumption let's go 🔥
Changing the cache format is super breaking and cannot be done easily but we could do our best to avoid copies!

@Cyrilvallez (Author)

Hi @ArthurZucker, thanks for the feedback!

Unfortunately, I don't think we can handle all cases while keeping the same data structure. For example, if someone passes past_key_values to generate() to resume generation, there is no way to avoid having the reference around, as it will be referenced outside transformers' scope (e.g. in the global scope). In this case, the current data structure will always create a copy and return it while still having the old one in memory, because we cannot dynamically grow an already existing tensor. With the new data structure, however, we can simply append to the old cache in-place. I just tested it and it works nicely, without allocating any more memory (except, of course, for the memory needed to add a few tokens to the already existing cache)!

I checked the rest of the code base and I agree that it is a big change, but it is very much worth it. I'm definitely down to collaborate on it if you wish!

@Cyrilvallez Cyrilvallez (Author) commented May 6, 2024

Hi @ArthurZucker,

I am done with the work. Could you please review it? At this point, it should be 100% backward compatible. The only change is that past_key_values will now be returned as a Tuple[Tuple[List[Tensor]]], and will be modified in-place if directly supplied to model.generate(), for any of the model architectures that use DynamicCache by default (i.e. Cohere, Dbrx, Gemma, Idefics2, Llama, Mistral, Mixtral, Olmo, Persimmon, Phi, Phi3, Qwen2, Qwen2Moe, StableLM, StarCoder2 as of now).
Every other model will still use the old format Tuple[Tuple[Tensor]] (or its own custom format, e.g. Bloom) but will NOT benefit from the improved memory performance. However, if you make older models also use DynamicCache by default, they will automatically benefit from it without further modifications on your part.
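To make the relationship between the two formats concrete, here is a hypothetical helper (to_legacy_format is an illustrative name, not the PR's API) that collapses the new structure back into the legacy one:

import torch

def to_legacy_format(new_cache):
    # New format: outer tuple over layers, each layer a (key_chunks, value_chunks) pair of List[Tensor].
    # Legacy format: one (key, value) pair of full tensors per layer.
    legacy = []
    for key_chunks, value_chunks in new_cache:
        legacy.append((torch.cat(key_chunks, dim=-2), torch.cat(value_chunks, dim=-2)))
    return tuple(legacy)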

Here is a final benchmark of the improved performance for every decoding strategy available in Transformers. It was again run on my RTX 4090, with Mistral-7B-v0.1:

contrastive_search.pdf
greedy.pdf
sample.pdf
beam_search.pdf
beam_sample.pdf
constrained_beam_search.pdf
group_beam_search.pdf
assisted.pdf

Note that for assisted decoding, for 1000 tokens and more I ran into a weird CUDA error ([/opt/conda/conda-bld/pytorch_1704987289929/work/aten/src/ATen/native/cuda/Indexing.cu:1237]: indexSelectSmallIndex: block: [29,0,0], thread: [84,0,0] Assertion 'srcIndex < srcSelectDimSize' failed. and RuntimeError: CUDA error: device-side assert triggered) with BOTH the current Transformers version and my modified branch, which is why this particular benchmark stops after 500 tokens. I don't know exactly what is happening there, but it is not due to my changes, as it also happens in the current version.

As you can see, for non-beam decoding methods the memory is divided by 2 as advertised (minus some small overhead, so around an 86% improvement in practice for Mistral, and between 91% and 98% for Llama2 7B and Llama3 8B). For beam methods, as I still rely on index_select() to reorder the cache, which performs a copy, the gain is smaller but still very significant (I measured an improvement of 46%, i.e. memory divided by about 1.5).
The speed impact is only noticeable when the number of new tokens generated in one go becomes very large (more than 4000), which should be rare in practice anyway.

Finally, note that when the input size is large compared to the number of new tokens generated (say an input size of 4000 and 5 new tokens), the memory bottleneck becomes the first model forward pass rather than the cache size.
In that case, the memory improvement will still be at least 25% for models with small caches such as Mistral, and usually around 60-70% for models with bigger caches such as Llama2 in my experiments, because we still save a full copy of the cache right after the first forward pass.
This is illustrated in the following 2 figures (here the input size is already 1000):

scaling_example_beam.pdf
scaling_example_non_beam.pdf

In those figures (the first is for beam-based methods, the second for non-beam-based), we can observe 3 zones. First, while the number of new tokens generated is very small compared to the input size, we sit on a plateau of "minimal memory improvement"; then, as the number of new tokens increases, there is a sharp increase in memory improvement, until the size of the cache becomes similar to the memory footprint of the first forward pass. At that point, we plateau in the zone of "maximum memory improvement", which is roughly 2 times more efficient for all non-beam-based methods, and roughly 1.5 times more efficient for all beam-based decoding strategies.

I checked, and we ALWAYS reach at least the zone of minimal improvement, which means (see figure) that ANY call to a non-beam-based method will yield a memory improvement of at least ~1.22x for Mistral and ~1.62x for Llama2 (even when generating only 2 new tokens, for any input size), with even better efficiency (up to between 1.85x and 2x) over a large range of input sizes and max new tokens. The same is true for beam-based methods, with a minimal improvement of ~1.3x for Mistral and more for Llama2 (up to ~1.5x over a large range of inputs).

I hope this is clear and can be integrated as soon as possible! Don't hesitate to tell me if I am missing something that should be done before merging.

@Cyrilvallez Cyrilvallez (Author) commented May 7, 2024

Hi @ArthurZucker,

I was investigating why we observe those "minimal improvements" even with very large input sizes and just 2 new tokens. I found out that the reason is that the logits of every iteration leaked into the next iteration through some hanging references to the outputs dictionary. This was true for every decoding strategy except assisted decoding. I fixed it in my last 2 commits.

Non-beam methods

That means that for any input size, the memory peak was not caused by the forward pass of the first iteration with the large input, but happened during the 2nd iteration, when the full cache was being copied. At that point, the peak consisted of 2*past_KV + first_iteration_logits (the logits of the first iteration, when all tokens are fed to the model, can quickly grow very large). After my first modifications we no longer keep the 2nd cache copy, so in the second iteration the peak becomes past_KV + first_iteration_logits, which is always smaller than or equal to the memory needed for the forward pass of the first iteration (that pass creates those values, so it needs at least that much memory, and in practice more). Thus the leaking of the logits was no longer a memory bottleneck once the cache copy was removed in this case. I still fixed it though, because it does not hurt performance and would be odd to keep. For beam methods, however, removing it is a big improvement (see below).
Anyway, that means that we can compute the "minimal improvement" as:

minimal_improvement = (2*past_KV + first_iteration_logits) / memory(first_forward_pass)

Now, both the cache and the logits trivially scale linearly with the input size. With flash attention, the memory needed for the forward pass also scales linearly with the input size, so minimal_improvement is a constant across input sizes. This is exactly what I observed in my previous comment, with this constant being ~1.22 for Mistral and ~1.62 for Llama2 (so we always divide the memory requirements by at least those factors). I computed the ratio above and obtained the same values, which confirms the analysis.
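The two terms of the numerator can be estimated from the model config, while the forward-pass peak has to be measured; here is a back-of-the-envelope sketch (illustrative helper names, not code from the PR):

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    # Keys + values for every layer, assuming fp16/bf16 elements (2 bytes each).
    return 2 * num_layers * batch_size * num_kv_heads * head_dim * seq_len * bytes_per_elem

def logits_bytes(vocab_size, seq_len, batch_size=1, bytes_per_elem=2):
    # First-iteration logits cover every input position.
    return batch_size * seq_len * vocab_size * bytes_per_elem

def minimal_improvement(past_kv_bytes, first_logits_bytes, first_forward_peak_bytes):
    # Ratio derived above for non-beam decoding: old peak / new peak.
    # first_forward_peak_bytes must be measured, it cannot be derived from the config alone.
    return (2 * past_kv_bytes + first_logits_bytes) / first_forward_peak_bytes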

Now, as we generate more tokens after the first iteration and the cache size exceeds memory(first_forward_pass), we tend towards maximal_improvement = 2, because we removed the perpetual copy of the cache (the logits are negligible here, since after the first pass they always cover only 1 token and are therefore small).

Illustration:
non_beam-based.pdf

Beam methods

For beam methods, the story is slightly different. Because reordering the cache allocates a full copy of it, the leak of the cache and of the logits caused 3*past_KV + first_iteration_logits to be allocated during the 2nd iteration. After my work, it is now only 2*past_KV (I believe we could reorder in a smarter way that avoids that 2nd copy as well, but that would be future work). That means, as before, that the minimal_improvement is:

minimal_improvement = (3*past_KV + first_iteration_logits) / max{memory(first_forward_pass), 2*past_KV}

and when generating enough new tokens, the peak goes from 3*past_KV to 2*past_KV, thus dividing the memory requirement by 1.5. However, in this case minimal_improvement is always bigger than 1.5, meaning that if we generate few tokens compared to the size of the input we benefit from a memory improvement larger than 1.5, and if we generate enough new tokens compared to the input size, the improvement tends to 1.5.

Illustration:
beam-based.pdf

Contrastive search

Contrastive search is a bit of a special case. As a non-beam strategy it benefits from the 2x memory savings, but it does not scale in the same way. Because we artificially multiply the size of the cache by top_k at each iteration to run the forward with an artificially increased batch size, if we neglect the memory needed for the initial forward pass (which is done at the input size, without the increased batch size), we immediately benefit from a memory improvement of (2 * top_k + 1) / (top_k + 1), which approaches 2 very quickly. So, independently of the number of new tokens, we get the 2x improvement.
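Evaluating that factor for a few top_k values shows how quickly it approaches 2:

for top_k in (2, 4, 8, 16, 32):
    print(top_k, (2 * top_k + 1) / (top_k + 1))
# 2 -> 1.67, 4 -> 1.80, 8 -> 1.89, 16 -> 1.94, 32 -> 1.97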

Illustration:
contrastive_top_k.pdf

As you can see in the figure, as top_k increases we quickly tend to 2x for any number of new tokens. Note that this was obtained with Mistral, which has a small cache (so the cost of the initial forward is not completely negligible); other models such as Llama2 would get to 2x even faster (i.e. for smaller top_k), but I couldn't really show it as I would immediately OOM with the old version.

Restarting from previous cache

Now, because greedy and sample search never perform any copies of the cache anymore, restarting from a previous cache such as:

out = model.generate(inputs, max_new_tokens=100, return_dict_in_generate=True, do_sample=False)
cache = out.past_key_values

# Do whatever here

outputs = model.generate(out.sequences, past_key_values=cache, max_new_tokens=100, do_sample=False)

will actually benefit from a 3x memory improvement (that is, we divide the memory footprint by 3). This is because, before, we would always have 2 copies of the cache internally during the second call to generate(), plus the copy outside (the cache variable). Now, no copies are performed and cache is modified in-place (we just append to it), which results in the 3x improvement. If the number of new tokens in the second call to generate() becomes very large, then the size of the cache variable becomes somewhat negligible and we are back to our 2x memory savings. So this setting always yields between 2x and 3x savings.

As most chat applications will use greedy or sample decoding, and will (or at least should) restart every new conversation turn from the previous cache, this means that such chat applications will always save between 2x and 3x memory.

TL;DR

This PR significantly reduces the memory footprint of generate() for all input combinations and models.
For decoding strategies that do not rely on beams (including what I think are the two most popular ones, greedy and sample), most models will benefit from at least a ~1.6x memory reduction independently of the input size. Models with smaller caches such as Mistral will benefit from at least a ~1.2x memory decrease for all input sizes. Moreover, as max_new_tokens increases, these numbers quickly climb to a 2x improvement. In most practical cases, we will actually hit the 2x zone (that is, the necessary memory is divided by 2).
For decoding strategies relying on beams, the improvement is always AT LEAST 1.5x, with even better performance for small max_new_tokens.
Restarting generation from a previous cache with greedy or sample decoding will ALWAYS (independently of input size and number of new tokens generated) yield memory savings between 2x and 3x.
This comes essentially for free ("free lunch"): the speed penalty is only noticeable when generating more than 6k-7k new tokens in one go (around a 15%-20% penalty in the worst case of 7k tokens in one go, which may still be acceptable), which in practice should (except in extremely niche cases) never happen. Moreover, since the memory savings allow a much larger batch size, the actual processing speed may even improve.
Nowadays, as model sizes increase, memory is (almost) always the bottleneck. This work tries to alleviate the pressure on GPUs as much as possible.

Collaborator

cc @gante super interesting commit!

@ArthurZucker ArthurZucker (Collaborator) left a comment

@Cyrilvallez kudos for this really high-quality contribution 🤗 🚀
Your in-depth explanations are very useful, and all of it makes sense.
I think at some point we did want to change the way the cache is stored, to also remove the two transpose operations that are always needed in the modeling code (which is not the case for the JAX code).

@gante one potential break is the cache format, but I think we can add an EfficientDynamicCache class that would be used upon activation. @Cyrilvallez this would, for us, be better than having two code paths, and makes more sense for isolating bugs / maintenance!

PS: do you have any gist on how to reproduce memory benchmarks and etc?

WDYT @gante

Comment on lines 156 to 324
        # Whenever we have more than N new K-V value, cat() them. That way, we keep a relatively low number
        # of tensors in self.key_cache[layer_idx], which is more efficient to later cat() them all, and we only
        # copy a small subset into memory whenever we cat() the last N K-V states
        N = 50
        index = None
        for i, x in enumerate(self.key_cache[layer_idx]):
            if x.shape[-2] == 1:
                index = i
                break
        if index is not None and len(self.key_cache[layer_idx]) - 1 - index > N:
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:index] + [
                torch.cat(self.key_cache[layer_idx][index:], dim=-2)
            ]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:index] + [
                torch.cat(self.value_cache[layer_idx][index:], dim=-2)
            ]
Collaborator

an interesting way to do some bucketing

        # Whenever we have more than N new K-V value, cat() them. That way, we keep a relatively low number
        # of tensors in self.key_cache[layer_idx], which is more efficient to later cat() them all, and we only
        # copy a small subset into memory whenever we cat() the last N K-V states
        N = 50
Collaborator

can be a generation_config argument

@@ -1302,7 +1302,10 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
        if isinstance(model_kwargs["past_key_values"], Cache):
            past_length = model_kwargs["past_key_values"].get_seq_length()
        else:
            past_length = model_kwargs["past_key_values"][0][0].shape[2]
            if isinstance(model_kwargs["past_key_values"][0][0], list):
                past_length = sum(x.shape[-2] for x in model_kwargs["past_key_values"][0][0])
Collaborator

if the past key values is an EfficientCache class, this should be easier to extract. + @gante using the cache_positions would be the best here!

@Cyrilvallez (Author)


Thanks for reviewing @ArthurZucker! Here is a link to the benchmarks I ran: https://gist.github.com/Cyrilvallez/ce1adfad1d561c1e8dc92666ab5a9e8c

It is a bit messy, but you should find everything. Each benchmark was first run on the official version 4.40.1, saving outputs with legacy in the filenames, and then on my current branch, using patched in the filenames. You should change the filenames manually depending on which version you use if you re-run the benchmarks. You may also need to adjust the paths to those files when creating the figures. Apart from that, the link should contain everything you need to reproduce the results.

@ArthurZucker (Collaborator)

@Cyrilvallez what do you think about creating a new cache class?

@Cyrilvallez Cyrilvallez (Author) commented May 9, 2024

@ArthurZucker I agree that a new class would be clearer and more maintainable. I'm not sure what you meant by "that would be used upon activation", though? I would use the new EfficientDynamicCache as the new default for all models that currently use DynamicCache by default when no cache is provided. As I cannot think of any backward-compatibility issues (even if some people save the cache to file to restart from it, the new class will work), doing otherwise would be counter-productive I think (it would be weird to have to pass an argument to activate a feature that does the same as before, only more efficiently in all cases). I don't have your experience maintaining such a widely used codebase, however, so if you think people may run into problems, we could raise a warning pointing them to the old DynamicCache in case they hit any issue?

@ArthurZucker (Collaborator)

By default we don't change anything, and just passing cache_implementation="efficient" would use your implementation! We could warn with something like "you are using the non-efficient version of the DynamicCache", for example 🤗

@Cyrilvallez Cyrilvallez (Author) commented May 9, 2024

Ok, then by default we will still benefit from removing the leak of the logits, which is already a big gain. I will make the necessary changes next Monday 👌🏻

@gante gante (Member) commented May 9, 2024

@Cyrilvallez very cool in-depth exploration 😮🔥 And also very impactful consequences of the suggested changes!

Reading the discussion, from a usage perspective, I agree with Arthur: a separate class would be ideal! With a separate cache class, we:

  1. Can keep a simple reference cache class in DynamicCache;
  2. The default class used and returned in forward stays as simple as possible;
  3. [after a proper deprecation cycle] Can use the new efficient class by default in generate, getting all the benefits showcased in the benchmarks above.

I'll review the full code after it is implemented as a separate class, as we all seem aligned on what should be done 🤗 @Cyrilvallez ping me when it's ready :D

@gante gante (Member) commented May 13, 2024

@Cyrilvallez out of curiosity: have you considered/explored expanding the cache with fixed-size blocks every time we hit the limit, similarly to paged attention?

@Cyrilvallez Cyrilvallez (Author) commented May 14, 2024

@ArthurZucker @gante I realized yesterday that what actually creates the copies is not the current DynamicCache itself, but the back-and-forth from_legacy_cache and to_legacy_cache calls (which create tuples, which are immutable). Thus, forcing the use of a proper DynamicCache in generate() before entering the specialized decoding strategies would retain all the above benefits while removing the need for a new class. Something like:

if use_cache and self._supports_cache_class:
    if generation_config.cache_implementation is None and not isinstance(model_kwargs.get('past_key_values', None), Cache):
        model_kwargs['past_key_values'] = DynamicCache()  # or DynamicCache.from_legacy_cache() depending on model_kwargs['past_key_values'] type

Doing so avoids the back-and-forth switch between the legacy format and the Cache, and thus all references in the decoding-strategy functions will point to the same Cache object, which is modified in-place for all of them because the underlying caches are List[Tensor] (mutable, rather than tuples).

This would greatly simplify the approach above while retaining all the benefits (even more of them, as reordering would also happen in-place, layer after layer, thus not allocating new memory -> beam decoding would use 3x less memory instead of 1.5x). The small speed decrease would also disappear.
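For reference, such an in-place, layer-by-layer reorder could look roughly like this (a sketch, not the PR's code; index_select() still copies one layer at a time, but never the whole cache at once):

import torch

def reorder_cache_in_place(cache, beam_idx: torch.LongTensor):
    # Reorder each layer independently and rebind it inside the existing lists.
    # index_select() allocates a copy of a single layer, which immediately replaces
    # the old one, so the peak is "full cache + one layer" instead of the full-cache
    # duplicate created when the legacy tuples were rebuilt in one go.
    for layer_idx in range(len(cache.key_cache)):
        device = cache.key_cache[layer_idx].device
        cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, beam_idx.to(device))
        cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, beam_idx.to(device))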

@Cyrilvallez Cyrilvallez (Author) commented May 15, 2024

@ArthurZucker @gante The work is ready for final review!

As previously said, EfficientDynamicCache was not needed in the end. This makes the changes more natural given the current state of the library, and much easier to maintain. I rebased my branch to remove the previous commits and make everything clearer. Models with support for DynamicCache now use it from start to end, without switching back and forth to the old legacy format. However, generate() will return the legacy format by default for backward compatibility. This can be controlled with the generation config argument return_legacy_cache if this is not wanted (e.g. in chat applications, when restarting from a previous cache, it is more efficient to always keep a DynamicCache to avoid any copying and get between 2x and 3x memory improvement even for sample and greedy decoding, as previously mentioned).
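For a chat-style application, usage could then look roughly like this (a sketch, not code from the PR; it assumes the new return_legacy_cache option is accepted by generate() like other generation-config fields, and uses Mistral-7B-v0.1 purely for illustration):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)

# First turn: keep the DynamicCache instead of converting back to the legacy tuples,
# so that the next turn can append to it in-place (assumption: return_legacy_cache is
# forwarded to the generation config by generate()).
out = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,
    return_dict_in_generate=True,
    return_legacy_cache=False,
)

# Next turn: resume from the same cache object; no cache copies are made.
next_out = model.generate(
    out.sequences,
    past_key_values=out.past_key_values,
    max_new_tokens=100,
    do_sample=False,
)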

Models without support for DynamicCache will still benefit from improvements as well, since I removed all the leaked references to the logits and, as much as possible, the leaked copies of past_key_values. This is illustrated in the following for greedy search:
old_models.pdf
As you can see, models with large logits (i.e. a large vocabulary size) such as Bloom benefit a lot from removing the leaked logits, as long as the number of new tokens generated is not too big. This is because we divide memory consumption by the factor (2*cache + logits_first_iteration) / max{memory_first_forward, 2*cache}, and for large logits the denominator is usually memory_first_forward. For smaller logits, the denominator will be 2*cache and we "only" save the full logits size.

Moreover, this work paves the way for the full adoption of the new Cache format everywhere, as all generate functions can now accommodate a DynamicCache instance all the way through (some functions and/or some keywords were not completely supported before).

Finally, here is the final benchmark of performances:
beam-based.pdf
non_beam-based.pdf
contrastive.pdf
contrastive_top_k.pdf

Beam-based methods now benefit from using 3x less memory, because even reorder() is done in-place. Non-beam-based methods are the same as before (except contrastive search). Contrastive search now always benefits from at least 2x less memory; it is even better for small top_k and scales to 2x as top_k increases, because I changed how the cache is expanded (without a temporary copy, so we only have top_k cache replications instead of top_k + 1, which is a nice gain for small top_k compared to the previous benchmarks -> the memory gain factor is now (2*top_k + 1) / top_k).

Finally, as we are now simply using DynamicCache, we no longer have any speed degradation. All those memory improvements are free lunch all the way.

@Cyrilvallez (Author)

New idea to further improve memory: #30860

@ArthurZucker (Collaborator)

Damn that's impressive! Reviewing now!

@ArthurZucker ArthurZucker (Collaborator) left a comment

Alright! In general 🔥 looks great

  • I think splitting the PR into 3 would be the best!
  1. PR with the clone and del of the output
  2. PR with the removal of the support for tuples
  3. PR with the new efficient cache.
    Unless 2 and 3 have to go together!

@gante will be the one reviewing, and deciding so let's get his opinion! 🤗

In any case, amazing work! 🚀

        if past_key_values.value_cache[idx].shape[-1] != 0:
            past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
            past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
    past_key_values.crop(maximum_length)
Collaborator

this is super nice

Member

indeed, well done!

@@ -382,6 +384,7 @@ def __init__(self, **kwargs):

        # Cache implementation
        self.cache_implementation = kwargs.pop("cache_implementation", None)
        self.return_legacy_cache = kwargs.pop("cache_implementation", True)
Collaborator

bit weird to fetch cache_implementation

@gante gante (Member) May 20, 2024

I'm assuming return_legacy_cache should be fetched instead (and not cache_implementation) :)

Author

Indeed, should be return_legacy_cache, this one is clearly my mistake!

Comment on lines +2030 to +2036
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
logit_for_next_step = outputs.logits[:, -1, :].clone()
Collaborator

💘

Comment on lines 3805 to 3807
        # New efficient cache
        elif isinstance(data, DynamicCache):
            return data.split(full_batch_size, split_size)
Collaborator

that's a nice way to put it IMO

Comment on lines +1169 to +1171
        # We may have an initialized but empty DynamicCache during first iteration
        past_exist = past_key_values is not None and not (
            isinstance(past_key_values, DynamicCache) and len(past_key_values) == 0
        )
        if past_exist:
Collaborator

if we are in generate, this should not even exist anyway and cache positions should be used!

@gante gante (Member) left a comment

Very very cool developments! It's amazing to see that we could get the memory benefits without adding a new class 💛 Really good work, @Cyrilvallez 🔥

I've added a few comments, mostly related to long-term maintenance

@ArthurZucker @Cyrilvallez -- I'm happy with this being done in a single PR :) After we 3 are happy with the changes, before merging, I will run a few slow tests on my end to confirm the PR is fully backwards compatible!

            self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]

    def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
Member

Suggested change
-    def split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:

Suggestion: split can mean many things, as the object contains 2 lists of 4D tensors (so 6 possible "dimensions" to split). A more precise name helps with readability 🤗

Comment on lines 217 to 218
    def from_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
        """This is the opposite of the above `split()` method. This will be used by `stack_model_outputs` in
Member

Suggested change
-    def from_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
-        """This is the opposite of the above `split()` method. This will be used by `stack_model_outputs` in
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in

(continuation of the suggestion above)

Comment on lines 1106 to 1108
        if isinstance(past, DynamicCache) and not self._supports_dynamic_cache_class:
            raise ValueError(
                f"{self.__class__.__name__} does not support an instance of `DynamicCache` as `past_key_values`. Please "
                "check the model documentation for supported cache formats."
            )
Member

I believe this new if and the _supports_dynamic_cache_class model attribute are redundant: all models with self._supports_cache_class = True should support DynamicCache. As such, both this if and the new model attribute can be removed.

The exception is Jamba, which has a custom Cache (and no support for the legacy class). lmk if custom logic is needed for Jamba, so we can find a solution that doesn't require a new model attribute 🤗

Author

Ha yes, this has indeed become redundant since the clarification of the cache-support attributes in 9d889f8! No need for this attribute anymore then 👌 I rebased my branch to apply the latest changes from main and removed the attribute.

Comment on lines 1691 to 1698
        # Remove potential default DynamicCache if assistant does not support it
        assistant_kwargs = copy.copy(model_kwargs)
        if assistant_model is not None:
            if use_dynamic_cache_by_default and not assistant_model._supports_dynamic_cache_class:
                if len(assistant_kwargs["past_key_values"]) == 0:
                    del assistant_kwargs["past_key_values"]
                else:
                    assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
Member

request: can we move this logic to AssistedCandidateGenerator.__init__? That way, all logic regarding the initialization of the assistant attributes stays in one place :)

Author

Yes! This will indeed be much cleaner

@Cyrilvallez (Author)

@ArthurZucker @gante I applied all changes following your comments!

The repo consistency and code quality errors do not come from files I modified (I think the code quality errors come from ruff now being version 0.4.4, previously 0.1.x).
The same goes for tests_tf; I did not touch any TF files.
