Fix quant cache OOM (#494)
flozi00 committed May 30, 2024
1 parent 7d6b1d4 · commit 26e0982
Showing 1 changed file with 1 addition and 2 deletions.
server/lorax_server/models/flash_causal_lm.py: 1 addition & 2 deletions

@@ -813,8 +813,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        cache_dtype = torch.uint8 if fp8_supported else self.dtype
-        dtype_size = torch.tensor([], dtype=cache_dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
