[Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies #4696

Merged · 12 commits · May 14, 2024
Changes from 9 commits
63 changes: 63 additions & 0 deletions benchmarks/overheads/benchmark_hashing.py
@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # Analyze the runtime of the hashing function.
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of the hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
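For reference on the pstats indices used above: each entry in pstats.Stats.stats maps a (filename, lineno, function_name) key to a (primitive_calls, total_calls, internal_time, cumulative_time, callers) tuple, so index 3 is the cumulative time, index 0 is the primitive call count, and stats.total_tt is the overall measured time. A minimal standalone sketch of the same profiling pattern, with no vLLM dependency (toy_hash and run are made-up stand-ins for illustration):

import cProfile
import pstats


def toy_hash(tokens):
    # Stand-in for the hash_of_block call being measured above.
    return hash(tuple(tokens))


def run():
    for i in range(1000):
        toy_hash(list(range(i)))


profiler = cProfile.Profile()
profiler.runctx('run()', globals(), locals())

stats = pstats.Stats(profiler)
for func, (calls, _, _, cum_time, _) in stats.stats.items():
    if 'toy_hash' in func[2]:  # func is (filename, lineno, function_name)
        print(f"toy_hash: {calls} calls, {cum_time:.4f} s cumulative, "
              f"{cum_time / stats.total_tt * 100:.1f}% of total runtime")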
4 changes: 2 additions & 2 deletions tests/spec_decode/test_batch_expansion.py
@@ -88,8 +88,8 @@ def test_create_single_target_seq_group_metadata(k: int):

    assert output.request_id == input_seq_group_metadata.request_id
    assert len(output.seq_data) == 1
    assert output.seq_data[target_seq_id].get_prompt_token_ids(
    ) == prompt_tokens
    assert list(
        output.seq_data[target_seq_id].get_prompt_token_ids()) == prompt_tokens
    assert output.seq_data[target_seq_id].get_output_token_ids(
    ) == prev_output_tokens + token_ids

7 changes: 4 additions & 3 deletions vllm/model_executor/sampling_metadata.py
@@ -1,6 +1,6 @@
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch

@@ -436,7 +436,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   sampling_seeds: List[int], sample_indices: List[int],
                   prompt_tokens: List[List[int]],
                   prompt_tokens: Union[List[List[int]], List[Tuple[int,
                                                                    ...]]],
                   output_tokens: List[List[int]], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
@@ -446,7 +447,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
        prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
                             default=0)
        prompt_padded_tokens = [
            tokens + [vocab_size] * (prompt_max_len - len(tokens))
            list(tokens) + [vocab_size] * (prompt_max_len - len(tokens))
            for tokens in prompt_tokens
        ]
        output_max_len = max([len(tokens) for tokens in output_tokens],
6 changes: 3 additions & 3 deletions vllm/outputs.py
@@ -1,5 +1,5 @@
import time
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -75,7 +75,7 @@ def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_token_ids: Union[List[int], Tuple[int, ...]],
Collaborator: Let's remove tuple input?

        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
@@ -84,7 +84,7 @@ def __init__(
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_token_ids = list(prompt_token_ids)
@Yard1 (Collaborator, May 9, 2024): Do we need to do this? Maybe it would be better to change the tests instead.

@KuntaiDu (Collaborator, Author, May 9, 2024): The test case corresponding to this change is in tests/conftest.py:

        for req_output in req_outputs:
            ...
            prompt_ids = req_output.prompt_token_ids
            ...
            for sample in req_output.outputs:
                ...
                req_sample_output_ids.append(prompt_ids + output_ids)
                ...
            ...

And the code prompt_ids + output_ids requires prompt_ids to be a list.

I guess processing vLLM's output with code like prompt_ids + output_ids may be common in current vLLM-based apps, so keeping the prompt_token_ids attribute of RequestOutput as a list may be better for compatibility's sake.

Collaborator: Hmm, I think it would be fine to keep it as a tuple and cast it to a list in the test code, tbh.
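A quick standalone illustration of the compatibility point discussed above (not part of the PR): concatenating a tuple of prompt ids with a list of output ids raises a TypeError, which is why the attribute is cast back to a list here.

prompt_ids = (1, 2, 3)   # tuple, as SequenceData now stores prompt tokens
output_ids = [4, 5]      # list, as generated tokens are stored

# prompt_ids + output_ids   # TypeError: can only concatenate tuple (not "list") to tuple
combined = list(prompt_ids) + output_ids   # works: [1, 2, 3, 4, 5]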

        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
34 changes: 20 additions & 14 deletions vllm/sequence.py
@@ -2,7 +2,7 @@
import copy
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
@@ -112,13 +112,13 @@ class SequenceData:

    def __init__(
        self,
        prompt_token_ids: List[int],
        prompt_token_ids: Union[List[int], Tuple[int, ...]],
        output_token_ids: Optional[List[int]] = None,
    ) -> None:
        if output_token_ids is None:
            output_token_ids = []

        self.prompt_token_ids = prompt_token_ids
        self.prompt_token_ids: Tuple[int, ...] = tuple(prompt_token_ids)
Collaborator: Add a comment that this should not be changed?

Collaborator: Also consider making it a private attribute, e.g. self._prompt_token_id?

        self.output_token_ids = output_token_ids
        self.cumulative_logprob = 0.0
        # The number of tokens that are computed (that run against the model).
@@ -139,7 +139,18 @@ def get_output_len(self) -> int:
        return len(self.output_token_ids)

    def get_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.output_token_ids
        return list(self.prompt_token_ids) + self.output_token_ids
Collaborator: Remove tuple input and remove the conversion!

@KuntaiDu (Collaborator, Author): Agreed. For performance we need to avoid the tuple -> list conversion. How about storing the list version in prompt_token_ids for access, and storing the tuple version in _prompt_token_ids_tuple for the hashing speedup?

Collaborator: Yeah, I think that sounds good to me. (It should be fine, since the prompt tokens are not going to be changed.)

@rkooo567 (Collaborator, May 13, 2024): Revert this part? (Ping me when it is done!)
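A minimal sketch of the approach agreed on in this thread, assuming the attribute name _prompt_token_ids_tuple proposed above (the PR's final code may differ): keep a list for cheap access and concatenation, and a pre-built tuple only for hashing.

from typing import List, Optional, Tuple


class SequenceDataSketch:
    """Keeps a list for access/concatenation and a tuple for hashing."""

    def __init__(self,
                 prompt_token_ids: List[int],
                 output_token_ids: Optional[List[int]] = None) -> None:
        self.prompt_token_ids = list(prompt_token_ids)
        # Immutable, hashable copy used only for prefix-cache hashing.
        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
        self.output_token_ids = output_token_ids or []

    def get_token_ids(self) -> List[int]:
        # No tuple -> list conversion on the common accessor path.
        return self.prompt_token_ids + self.output_token_ids

    def get_prefix_token_ids(
        self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        # Hashable prefix built from the pre-computed tuple.
        prompt_length = len(self._prompt_token_ids_tuple)
        if num_tokens > prompt_length:
            return (self._prompt_token_ids_tuple,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        return (self._prompt_token_ids_tuple[:num_tokens], None)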


    def get_prefix_token_ids(
            self, num_tokens: int
    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
        """Get prefix tokens, and make the return value hashable"""
        prompt_length = len(self.prompt_token_ids)
        if num_tokens > prompt_length:
Collaborator: When does this happen?

Collaborator (Author): This happens when calculating hashes that cover both the user input (i.e. the prompt tokens) and the LLM-generated output (i.e. the output tokens).

Collaborator: Can it happen under normal circumstances, or only in the recomputation case?

@KuntaiDu (Collaborator, Author, May 14, 2024): Yes, it happens in normal cases (inside the function _allocate_last_physical_block in vllm/sequence.py).

            return (self.prompt_token_ids,
                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
        else:
            return (self.prompt_token_ids[:num_tokens], None)
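Using the SequenceDataSketch from the note above (its behavior matches the SequenceData changes in this diff), a small worked example of the two return shapes discussed in the thread:

data = SequenceDataSketch([10, 11, 12, 13])   # 4 prompt tokens
data.output_token_ids.extend([20, 21])        # 2 generated tokens

# num_tokens within the prompt: only a prompt prefix is returned.
assert data.get_prefix_token_ids(3) == ((10, 11, 12), None)

# num_tokens past the prompt (e.g. a block that also covers generated
# tokens): the full prompt tuple plus a tuple of output tokens.
assert data.get_prefix_token_ids(6) == ((10, 11, 12, 13), (20, 21))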

    def get_num_computed_tokens(self) -> int:
        """Return the number of prefill tokens that are already computed."""
@@ -174,7 +185,7 @@ def get_last_token_id(self) -> int:
            return self.prompt_token_ids[-1]
        return self.output_token_ids[-1]

    def get_prompt_token_ids(self) -> List[int]:
    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.prompt_token_ids

    def get_output_token_ids(self) -> List[int]:
@@ -245,14 +256,9 @@ def get_output_text_to_return(self, buffer_length: int):
            self.output_text)

    def hash_of_block(self, logical_idx: int) -> int:
        # TODO This can produce incorrect hash when block size > prompt size

        # Compute the number of tokens in the sequence
        # TODO: The current hashing function is O(L^2). We should optimize
        # this in the future.
        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
        return hash(
            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

    def num_hashed_tokens_of_block(self, logical_idx: int):
        return logical_idx * self.block_size + self.block_size
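To see where the savings come from, a rough standalone timing sketch (made-up sizes, not the PR's benchmark): the old hash_of_block rebuilds prompt + output as a fresh list and converts the slice to a tuple on every call, while the new version only slices an already-materialized prompt tuple.

import timeit

prompt = list(range(15_000))        # roughly the prompt length in the benchmark
prompt_tuple = tuple(prompt)
output = list(range(10))
lora_int_id = 0
block_size = 16
num_tokens = 8 * block_size         # hash of the 8th logical block


def old_hash():
    # Old path: list concatenation + slice + tuple conversion, then hash.
    return hash((tuple((prompt + output)[:num_tokens]), lora_int_id))


def new_hash():
    # New path: slice the pre-built prompt tuple; no full-list copy or conversion.
    return hash(((prompt_tuple[:num_tokens], None), lora_int_id))


print("old:", timeit.timeit(old_hash, number=1000))
print("new:", timeit.timeit(new_hash, number=1000))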
@@ -306,7 +312,7 @@ def get_output_len(self) -> int:
    def get_token_ids(self) -> List[int]:
        return self.data.get_token_ids()

    def get_prompt_token_ids(self) -> List[int]:
    def get_prompt_token_ids(self) -> Tuple[int, ...]:
        return self.data.get_prompt_token_ids()

    def get_last_token_id(self) -> int:
@@ -433,7 +439,7 @@ def prompt(self) -> str:
        return next(iter(self.seqs_dict.values())).prompt

    @property
    def prompt_token_ids(self) -> List[int]:
    def prompt_token_ids(self) -> Tuple[int, ...]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return next(iter(self.seqs_dict.values())).data.prompt_token_ids