Skip to content

Commit

Permalink
Make max_new_tokens optional, default to max_total_tokens - input_length (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Mar 22, 2024
1 parent f474be2 commit 8ff0bf5
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 51 deletions.
16 changes: 8 additions & 8 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def generate(
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
max_new_tokens: Optional[int] = None,
ignore_eos_token: bool = False,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
Expand Down Expand Up @@ -101,7 +101,7 @@ def generate(
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
max_new_tokens (`Optional[int]`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
Expand Down Expand Up @@ -201,7 +201,7 @@ def generate_stream(
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
max_new_tokens: Optional[int] = None,
ignore_eos_token: bool = False,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
Expand Down Expand Up @@ -232,7 +232,7 @@ def generate_stream(
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
max_new_tokens (`Optional[int]`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
Expand Down Expand Up @@ -388,7 +388,7 @@ async def generate(
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
max_new_tokens: Optional[int] = None,
ignore_eos_token: bool = False,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
Expand Down Expand Up @@ -422,7 +422,7 @@ async def generate(
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
max_new_tokens (`Optional[int]`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
Expand Down Expand Up @@ -517,7 +517,7 @@ async def generate_stream(
merged_adapters: Optional[MergedAdapters] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
max_new_tokens: Optional[int] = None,
ignore_eos_token: bool = False,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
Expand Down Expand Up @@ -550,7 +550,7 @@ async def generate_stream(
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
max_new_tokens (`Optional[int]`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
Expand Down
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
# Maximum number of generated tokens
max_new_tokens: int = 20
max_new_tokens: Optional[int] = None
# Whether to ignore the EOS token during generation
ignore_eos_token: bool = False
# The parameter for repetition penalty. 1.0 means no penalty.
Expand Down
4 changes: 2 additions & 2 deletions docs/reference/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,9 @@
"max_new_tokens": {
"type": "integer",
"format": "int32",
"default": "20",
"default": "null",
"nullable": true,
"minimum": 0.0,
"exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0
},
"ignore_eos_token": {
Expand Down
22 changes: 6 additions & 16 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(default = "false", example = true)]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(exclusive_minimum = 0, default = "null")]
pub max_new_tokens: Option<u32>,
#[serde(default)]
#[schema(default = "false", example = true)]
pub ignore_eos_token: bool,
Expand Down Expand Up @@ -267,10 +267,6 @@ pub(crate) struct GenerateParameters {
pub response_format: Option<ResponseFormat>,
}

fn default_max_new_tokens() -> u32 {
20
}

fn default_parameters() -> GenerateParameters {
GenerateParameters {
adapter_id: None,
Expand All @@ -284,7 +280,7 @@ fn default_parameters() -> GenerateParameters {
top_p: None,
typical_p: None,
do_sample: false,
max_new_tokens: default_max_new_tokens(),
max_new_tokens: None,
ignore_eos_token: false,
return_full_text: None,
stop: Vec::new(),
Expand Down Expand Up @@ -621,10 +617,7 @@ impl From<CompletionRequest> for CompatGenerateRequest {
top_p: req.top_p,
typical_p: None,
do_sample: !req.n.is_none(),
max_new_tokens: req
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
max_new_tokens: req.max_tokens.map(|x| x as u32),
ignore_eos_token: req.ignore_eos_token.unwrap_or(false),
return_full_text: req.echo,
stop: req.stop,
Expand Down Expand Up @@ -658,10 +651,7 @@ impl From<ChatCompletionRequest> for CompatGenerateRequest {
top_p: req.top_p,
typical_p: None,
do_sample: !req.n.is_none(),
max_new_tokens: req
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
max_new_tokens: req.max_tokens.map(|x| x as u32),
ignore_eos_token: req.ignore_eos_token.unwrap_or(false),
return_full_text: None,
stop: req.stop,
Expand Down
57 changes: 33 additions & 24 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Validation {
&self,
inputs: String,
truncate: Option<usize>,
max_new_tokens: u32,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
Expand All @@ -81,16 +81,18 @@ impl Validation {
// Unwrap is safe here
let (inputs, input_length) = response_receiver.await.unwrap()?;

// Get total tokens
let total_tokens = input_length + max_new_tokens as usize;

// Validate MaxTotalTokens
if total_tokens > self.max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
self.max_total_tokens,
input_length,
max_new_tokens,
));
if let Some(max_new_tokens) = max_new_tokens {
// Get total tokens
let total_tokens = input_length + max_new_tokens as usize;

// Validate MaxTotalTokens
if total_tokens > self.max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
self.max_total_tokens,
input_length,
max_new_tokens,
));
}
}

// Validate InputLength
Expand All @@ -111,12 +113,13 @@ impl Validation {
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
let input_length = truncate.unwrap_or(self.max_input_length);

// Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
return Err(ValidationError::MaxNewTokens(
self.max_total_tokens - self.max_input_length,
max_new_tokens,
));
if let Some(max_new_tokens) = max_new_tokens {
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
return Err(ValidationError::MaxNewTokens(
self.max_total_tokens - self.max_input_length,
max_new_tokens,
));
}
}

Ok((inputs, input_length))
Expand Down Expand Up @@ -231,7 +234,7 @@ impl Validation {
})
.unwrap_or(Ok(0))?;

if max_new_tokens == 0 {
if max_new_tokens.is_some() && max_new_tokens.unwrap() == 0 {
return Err(ValidationError::NegativeMaxNewTokens);
}

Expand Down Expand Up @@ -294,13 +297,19 @@ impl Validation {
schema,
return_k_alternatives,
};

let effective_max_new_tokens =
max_new_tokens.unwrap_or((self.max_total_tokens - input_length) as u32);
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
max_new_tokens: effective_max_new_tokens,
stop_sequences,
ignore_eos_token,
};

metrics::histogram!("lorax_request_max_new_tokens", max_new_tokens as f64);
metrics::histogram!(
"lorax_request_max_new_tokens",
effective_max_new_tokens as f64
);

Ok(ValidGenerateRequest {
inputs,
Expand Down Expand Up @@ -461,7 +470,7 @@ mod tests {
max_total_tokens,
);

let max_new_tokens = 10;
let max_new_tokens = Some(10);
match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
Expand All @@ -488,7 +497,7 @@ mod tests {
max_total_tokens,
);

let max_new_tokens = 10;
let max_new_tokens = Some(10);
match validation
.validate_input("Hello".to_string(), None, max_new_tokens)
.await
Expand Down Expand Up @@ -588,7 +597,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: Some(0.99),
max_new_tokens: 1,
max_new_tokens: Some(1),
..default_parameters()
},
},
Expand All @@ -614,7 +623,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: None,
max_new_tokens: 1,
max_new_tokens: Some(1),
..default_parameters()
},
},
Expand Down

0 comments on commit 8ff0bf5

Please sign in to comment.