Added prompt_tokens to the response #165

Merged · 8 commits · Jan 9, 2024
Changes from all commits

4 changes: 4 additions & 0 deletions clients/python/README.md
@@ -181,6 +181,8 @@ class BestOfSequence:
 class Details:
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
@@ -205,6 +207,8 @@ class Response:
 class StreamDetails:
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
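
For illustration, a minimal usage sketch of the new field through the Python client. This is hedged, not part of the diff: it assumes a LoRAX server reachable at http://127.0.0.1:8080 and uses the Client.generate call documented in this README.

from lorax import Client

# Assumed local deployment; adjust the endpoint as needed.
client = Client("http://127.0.0.1:8080")
response = client.generate("What is deep learning?", max_new_tokens=32)

# After this change, Details reports the prompt length next to the generated count.
print(response.details.prompt_tokens, response.details.generated_tokens)
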
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
@@ -198,6 +198,8 @@ class BestOfSequence(BaseModel):
 class Details(BaseModel):
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
@@ -222,6 +224,8 @@ class Response(BaseModel):
 class StreamDetails(BaseModel):
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
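
The streaming path gains the same counter on StreamDetails. A sketch of how a caller might read it, assuming the streaming client mirrors the types above (generate_stream yields token messages and attaches details only to the final one):

from lorax import Client

client = Client("http://127.0.0.1:8080")  # assumed local server

text = ""
for response in client.generate_stream("What is deep learning?", max_new_tokens=32):
    if not response.token.special:
        text += response.token.text
    # StreamDetails arrives with the last streamed message only.
    if response.details is not None:
        print(response.details.prompt_tokens, response.details.generated_tokens)
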
51 changes: 0 additions & 51 deletions clients/python/tests/conftest.py
@@ -1,51 +0,0 @@
-import pytest
-
-from lorax import __version__
-from huggingface_hub.utils import build_hf_headers
-
-
-@pytest.fixture
-def flan_t5_xxl():
-    return "google/flan-t5-xxl"
-
-
-@pytest.fixture
-def fake_model():
-    return "fake/model"
-
-
-@pytest.fixture
-def unsupported_model():
-    return "gpt2"
-
-
-@pytest.fixture
-def base_url():
-    return "https://api-inference.huggingface.co/models"
-
-
-@pytest.fixture
-def bloom_url(base_url, bloom_model):
-    return f"{base_url}/{bloom_model}"
-
-
-@pytest.fixture
-def flan_t5_xxl_url(base_url, flan_t5_xxl):
-    return f"{base_url}/{flan_t5_xxl}"
-
-
-@pytest.fixture
-def fake_url(base_url, fake_model):
-    return f"{base_url}/{fake_model}"
-
-
-@pytest.fixture
-def unsupported_url(base_url, unsupported_model):
-    return f"{base_url}/{unsupported_model}"
-
-
-@pytest.fixture(scope="session")
-def hf_headers():
-    return build_hf_headers(
-        library_name="lorax-tests", library_version=__version__
-    )
150 changes: 0 additions & 150 deletions clients/python/tests/test_client.py

This file was deleted.

14 changes: 14 additions & 0 deletions docs/reference/openapi.json
@@ -413,6 +413,7 @@
       "type": "object",
       "required": [
         "finish_reason",
+        "prompt_tokens",
         "generated_tokens",
         "prefill",
         "tokens"
@@ -428,6 +429,12 @@
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"
         },
+        "prompt_tokens": {
+          "type": "integer",
+          "format": "int32",
+          "example": 1,
+          "minimum": 0.0
+        },
         "generated_tokens": {
           "type": "integer",
           "format": "int32",
@@ -773,12 +780,19 @@
       "type": "object",
       "required": [
         "finish_reason",
+        "prompt_tokens",
         "generated_tokens"
       ],
       "properties": {
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"
         },
+        "prompt_tokens": {
+          "type": "integer",
+          "format": "int32",
+          "example": 1,
+          "minimum": 0.0
+        },
         "generated_tokens": {
           "type": "integer",
           "format": "int32",
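
At the REST level, the same counter appears in the details object of a generate response. A hedged sketch of a raw request; the /generate route and the details flag are assumed from the LoRAX API convention this schema describes:

import requests

# Hypothetical local deployment; adjust host and port as needed.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 32, "details": True},
    },
    timeout=60,
)
resp.raise_for_status()
details = resp.json()["details"]
print(details["prompt_tokens"], details["generated_tokens"])
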
4 changes: 4 additions & 0 deletions docs/reference/python_client.md
@@ -181,6 +181,8 @@ class BestOfSequence:
 class Details:
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
@@ -205,6 +207,8 @@ class Response:
 class StreamDetails:
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
2 changes: 2 additions & 0 deletions proto/generate.proto
@@ -167,6 +167,8 @@ message Generation {
     bool token_is_special = 6;
     /// Complete generated text
     optional GeneratedText generated_text = 7;
+    /// Prefill tokens length
+    uint32 prefill_tokens_length = 8;
 }
 
 message FilterBatchRequest {
37 changes: 26 additions & 11 deletions router/src/infer.rs
@@ -189,6 +189,7 @@ impl Infer {
         // Return values
         let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
+        let mut result_prefill_length = 0;
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -197,16 +198,22 @@
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(tokens) => {
+                InferStreamResponse::Prefill {
+                    tokens,
+                    tokens_length,
+                } => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    if let Some(tokens_val) = tokens {
+                        result_prefill = tokens_val
+                            .ids
+                            .into_iter()
+                            .zip(tokens_val.logprobs.into_iter())
+                            .zip(tokens_val.texts.into_iter())
+                            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+                            .collect();
+                    }
+                    result_prefill_length = tokens_length;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -233,6 +240,7 @@
         Ok(InferResponse {
             prefill: result_prefill,
             tokens: result_tokens,
+            prompt_tokens: result_prefill_length,
             generated_text,
             queued,
             start,
@@ -569,10 +577,13 @@ fn send_responses(
 
     let mut stopped = false;
 
-    if let Some(prefill_tokens) = generation.prefill_tokens {
+    if generation.prefill_tokens_length > 0 {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
+            Ok(InferStreamResponse::Prefill {
+                tokens: generation.prefill_tokens,
+                tokens_length: generation.prefill_tokens_length,
+            }),
             Duration::from_millis(10),
         )?;
     }
@@ -629,7 +640,10 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(PrefillTokens),
+    Prefill {
+        tokens: Option<PrefillTokens>,
+        tokens_length: u32,
+    },
     // Intermediate messages
     Token(Token),
     // Last message
@@ -645,6 +659,7 @@ pub(crate) enum InferStreamResponse {
 pub(crate) struct InferResponse {
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
+    pub(crate) prompt_tokens: u32,
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,