From 5df6f0a61001f86e766c3890d4cab91bc7a2e351 Mon Sep 17 00:00:00 2001 From: cryscan Date: Mon, 10 Jun 2024 22:35:33 +0800 Subject: [PATCH] Implement speculative caching. --- crates/ai00-core/src/lib.rs | 5 +- crates/ai00-core/src/run.rs | 219 ++++++++++++++++++++++++------------ 2 files changed, 147 insertions(+), 77 deletions(-) diff --git a/crates/ai00-core/src/lib.rs b/crates/ai00-core/src/lib.rs index ce43b44..5472562 100644 --- a/crates/ai00-core/src/lib.rs +++ b/crates/ai00-core/src/lib.rs @@ -702,7 +702,7 @@ pub async fn model_route(receiver: Receiver) -> Result<()> { let context = GenerateContext { prompt_tokens: tokens.to_vec(), - prompt_cached: false, + prompt_cached: Default::default(), prefix: Default::default(), suffix: tokens, output: None, @@ -720,8 +720,9 @@ pub async fn model_route(receiver: Receiver) -> Result<()> { let queue = queue.clone(); let sender = sender.clone(); tokio::spawn(async move { + let context = &mut env.read().await.enqueue(context).await; let mut queue = queue.lock().await; - queue.append(&mut env.read().await.enqueue(context).await); + queue.append(context); let _ = sender.send(()); }); } diff --git a/crates/ai00-core/src/run.rs b/crates/ai00-core/src/run.rs index f308af8..da6a0d8 100644 --- a/crates/ai00-core/src/run.rs +++ b/crates/ai00-core/src/run.rs @@ -1,6 +1,7 @@ use std::{ cmp::Ordering, collections::HashMap, + ops::Deref, path::PathBuf, sync::Arc, time::{Duration, Instant}, @@ -33,7 +34,7 @@ use crate::{ }; const END_OF_LINE_TOKEN: u16 = 261; -const PROMPT_CACHE_TOKENS: usize = 32; +const MIN_PROMPT_CACHE_TOKENS: usize = 32; const MAX_CACHE_ITEMS: usize = 256; const SAMPLER_ARENA_CAPACITY: usize = 1048576; const GRAMMAR_ARENA_CAPACITY: usize = 1024; @@ -214,7 +215,7 @@ pub struct GenerateContext { /// Tokens that are provided at first. pub prompt_tokens: Vec, /// Whether the prompt has already been processed and cached. - pub prompt_cached: bool, + pub prompt_cached: CachedPrompt, /// Tokens that have been computed and cached. pub prefix: Tokens, /// Tokens to be computed. @@ -240,8 +241,17 @@ pub struct GenerateContext { pub sender: Sender, } +#[derive(Debug, Default, Clone)] +pub enum CachedPrompt { + #[default] + None, + Future(tokio::sync::watch::Sender>), + Done, +} + +/// An item that a cache slot holds, including a state, last model output and a timestamp. #[derive(Debug, Clone)] -struct CachedItem { +pub struct CachedItem { state: TensorCpu, output: TensorCpu, instant: Instant, @@ -274,7 +284,7 @@ struct CacheCheckout { #[derive(Debug, Default)] struct Cache { state: Option, - cache: Trie, + cache: Trie>>, } impl Cache { @@ -287,6 +297,7 @@ impl Cache { let mut remove = vec![]; for (tokens, _) in cache .iter() + .filter_map(|(tokens, item)| item.borrow().clone().map(|item| (tokens, item))) .sorted_unstable_by_key(|(_, item)| item.instant.elapsed()) .skip(MAX_CACHE_ITEMS) { @@ -514,7 +525,7 @@ impl Runtime { /// Search for the longest common prefix in the memory cache and checkout the state from that point. /// Should there be a cache miss, an initial state is returned. - async fn checkout(&self, id: StateId, tokens: &[u16], batch: usize) -> CacheCheckout { + async fn checkout(&self, id: StateId, tokens: &[u16]) -> CacheCheckout { let mut caches = self.caches.lock().await; let Cache { state, cache } = caches.fetch(id); @@ -523,16 +534,23 @@ impl Runtime { .rev() .find(|len| cache.contains_key(prefix[0..*len].as_token_slice())) .unwrap_or_default(); - log::info!("slot {} checks out backed cache of length {}", batch, len); - let prefix = prefix[0..len].to_vec(); - let state = state.clone().map(|state| state.data); - match cache.remove(prefix[..].as_token_slice()) { - Some(item) => { + let state = state.clone().map(|state| state.data); + let item = cache.get(prefix[..].as_token_slice()).cloned(); + drop(caches); + + match item { + Some(sender) => { + let mut receiver = sender.subscribe(); + let item = loop { + if let Some(item) = receiver.borrow_and_update().deref().clone() { + break item; + } + let _ = receiver.changed().await; + }; let item = CachedItem::update(item); - let key = Tokens(prefix.clone()); - cache.insert(key, item.clone()); + sender.send_replace(Some(item.clone())); CacheCheckout { prefix, state: item.state, @@ -540,6 +558,7 @@ impl Runtime { } } None => { + let prefix = vec![]; let state = state.unwrap_or_else(|| self.state.init()); CacheCheckout { prefix, @@ -618,10 +637,13 @@ impl Runtime { // back a non-relative and non-empty slot and use it for our new context Some(SlotChoice::Back(batch)) => { log::info!("start at non-empty slot {}", batch); - let checkout = self.checkout(context.request.state, &tokens, batch).await; + let checkout = self.checkout(context.request.state, &tokens).await; self.state.load(checkout.state, batch)?; let len = checkout.prefix.len(); + assert!(len == 0 || (len > 0 && checkout.output.is_some())); + log::info!("slot {} checks out cache of length {}", batch, len); + let mut state = SlotState::Wait( GenerateContext { prefix: Tokens(tokens[..len].to_vec()), @@ -639,10 +661,13 @@ impl Runtime { // directly occupy an empty slot so no need backing Some(SlotChoice::Empty(batch)) => { log::info!("start at empty slot {}", batch); - let checkout = self.checkout(context.request.state, &tokens, batch).await; + let checkout = self.checkout(context.request.state, &tokens).await; self.state.load(checkout.state, batch)?; let len = checkout.prefix.len(); + assert!(len == 0 || (len > 0 && checkout.output.is_some())); + log::info!("slot {} checks out cache of length {}", batch, len); + let state = SlotState::Wait( GenerateContext { prefix: Tokens(tokens[..len].to_vec()), @@ -654,15 +679,18 @@ impl Runtime { .into(), ); slots[batch] = state; - Ok(SlotResult::Fault(batch)) + Ok(SlotResult::Success(batch)) } // continue from an existing slot Some(SlotChoice::Continue(batch, ..)) => { log::info!("continue at slot {}", batch); - let checkout = self.checkout(context.request.state, &tokens, batch).await; + let checkout = self.checkout(context.request.state, &tokens).await; self.state.load(checkout.state, batch)?; let len = checkout.prefix.len(); + assert!(len == 0 || (len > 0 && checkout.output.is_some())); + log::info!("slot {} checks out cache of length {}", batch, len); + let state = SlotState::Wait( GenerateContext { prefix: Tokens(tokens[..len].to_vec()), @@ -680,10 +708,10 @@ impl Runtime { } /// This critical section synchronizes `slots` and fills `payloads`. - async fn prepare(&self, payloads: &mut [Payload]) -> Result<()> { + async fn synchronize(&self, payloads: &mut [Payload]) -> Result<()> { let mut slots = self.slots.lock().await; - // sync payloads and slots: kill dead payloads + // synchronize payloads and slots: kill dead payloads for (slot, payload) in slots.iter().zip(payloads.iter_mut()) { if !(payload.is_empty() || matches!(slot, SlotState::Busy)) { log::warn!("payload should either be empty or slot should be busy"); @@ -708,7 +736,9 @@ impl Runtime { if let Some(output) = context.output { let mut caches = self.caches.lock().await; let cache = &mut caches.fetch(context.request.state).cache; - cache.insert(context.prefix.clone(), CachedItem::new(backed, output)); + let item = CachedItem::new(backed, output); + let (item, _) = tokio::sync::watch::channel(Some(item)); + cache.insert(context.prefix.clone(), item); log::info!( "backed completed slot {} of length {}", batch, @@ -720,7 +750,7 @@ impl Runtime { slots[batch] = SlotState::Idle(context.prefix, Instant::now()); } - // take data from some waiting slots + // take data from some pending slots let occupancy = payloads .iter() .filter(|x| matches!(x, Payload::Busy(_))) @@ -733,51 +763,85 @@ impl Runtime { .take(remain) .map(|(batch, _)| batch) .collect_vec(); + for batch in batches { let mut slot = SlotState::Busy; std::mem::swap(&mut slots[batch], &mut slot); - match slot { - SlotState::Wait(context) => { - let _ = context.sender.send(Token::Start); - assert!(matches!(payloads[batch], Payload::Empty)); - payloads[batch] = Payload::Busy(*context); - } - _ => unreachable!(), + + let SlotState::Wait(context) = slot else { + unreachable!() }; + let mut context = *context; + + // allocate a future cache slot + let mut caches = self.caches.lock().await; + let cache = &mut caches.fetch(context.request.state).cache; + + let enable = context.prompt_tokens.len() > MIN_PROMPT_CACHE_TOKENS; + let enable = enable && !cache.contains_key(context.prompt_tokens.as_token_slice()); + if enable { + let (sender, _) = tokio::sync::watch::channel(None); + context.prompt_cached = CachedPrompt::Future(sender.clone()); + cache.insert(Tokens(context.prompt_tokens.clone()), sender); + + log::info!( + "slot {} schedules future back of length {}", + batch, + context.prompt_tokens.len() + ); + } + + log::info!( + "slot {}, suffix: {}, output: {}", + batch, + context.suffix.len(), + context.output.is_some() + ); + + let _ = context.sender.send(Token::Start); + assert!(matches!(payloads[batch], Payload::Empty)); + payloads[batch] = Payload::Busy(context); } Ok(()) } - async fn sample( - &self, - payloads: &mut [Payload], - outputs: Vec>>, - ) -> Result)>> { + async fn sample(&self, payloads: &mut [Payload]) -> Result)>> { // update raw outputs let mut set = tokio::task::JoinSet::new(); - for (batch, (payload, output)) in payloads.iter().zip(outputs.iter()).enumerate() { - if let (Payload::Busy(context), Some(output)) = (payload, output) { - let num_vocab = self.info.num_vocab; - let output = output.clone(); - let transformers = context.transformers.clone(); - let sampler = context.request.sampler.clone(); - let bias = context.request.bias.clone(); - set.spawn(async move { - let mut data = output.to_vec(); - assert_eq!(data.len(), num_vocab); - - sampler.read().await.transform(&mut data); - for (token, bias) in bias.iter() { - data[*token as usize] += *bias; - } - for transformer in transformers { - transformer.read().await.transform(&mut data); - } + for (batch, payload) in payloads.iter().enumerate() { + let Payload::Busy(context) = payload else { + continue; + }; - (batch, data) - }); + // in case that we have not yet read the whole prompt but still gets the output (from the cache) + if !context.suffix.is_empty() { + continue; } + + let Some(output) = context.output.clone() else { + continue; + }; + + let num_vocab = self.info.num_vocab; + let output = output; + let transformers = context.transformers.clone(); + let sampler = context.request.sampler.clone(); + let bias = context.request.bias.clone(); + set.spawn(async move { + let mut data = output.to_vec(); + assert_eq!(data.len(), num_vocab); + + sampler.read().await.transform(&mut data); + for (token, bias) in bias.iter() { + data[*token as usize] += *bias; + } + for transformer in transformers { + transformer.read().await.transform(&mut data); + } + + (batch, data) + }); } let mut outputs = HashMap::new(); while let Some(Ok((batch, data))) = set.join_next().await { @@ -822,22 +886,17 @@ impl Runtime { Ok(tokens) } - async fn process(&self, payloads: &mut [Payload]) -> Result<()> { - let outputs = payloads - .iter() - .map(|payload| match payload { - Payload::Busy(context) => context.output.clone(), - _ => None, - }) - .collect(); - let tokens = self.sample(payloads, outputs).await?; - + async fn finalize( + &self, + payloads: &mut [Payload], + tokens: HashMap)>, + ) -> Result<()> { for (batch, payload) in payloads.iter_mut().enumerate() { let Payload::Busy(context) = payload else { continue; }; - // in case that we are not actually generating but still gets the output (from the cache) + // in case that we have not yet read the whole prompt but still gets the output (from the cache) if !context.suffix.is_empty() { continue; } @@ -847,15 +906,13 @@ impl Runtime { }; // cache the prompt if it is too long. - let cache_prompt = !context.prompt_cached; - let cache_prompt = cache_prompt && context.prompt_tokens.len() > PROMPT_CACHE_TOKENS; - if let Some(output) = cache_prompt.then_some(()).and(context.output.clone()) { - let mut caches = self.caches.lock().await; - let cache = &mut caches.fetch(context.request.state).cache; + if let (CachedPrompt::Future(sender), Some(output)) = + (context.prompt_cached.clone(), context.output.clone()) + { + assert_eq!(context.prefix.len(), context.prompt_tokens.len()); let backed = self.state.back(batch).await?; - - cache.insert(context.prefix.clone(), CachedItem::new(backed, output)); - context.prompt_cached = true; + sender.send_replace(Some(CachedItem::new(backed, output))); + context.prompt_cached = CachedPrompt::Done; log::info!( "backed prompt of slot {} of length {}", @@ -880,7 +937,7 @@ impl Runtime { let instant = context.instant.get_or_insert(Instant::now()); let mut done = false; - let mut finish = |reason| { + let mut stop = |reason| { let counter = { let prompt = context.prompt_tokens.len(); let completion = context.model_tokens.len(); @@ -998,9 +1055,9 @@ impl Runtime { } else if exhausted || stop_matched { let output = String::from_utf8_lossy(head); let _ = context.sender.send(Token::Content(output.into())); - finish(FinishReason::Stop); + stop(FinishReason::Stop); } else if context.model_tokens.len() >= context.request.max_tokens { - finish(FinishReason::Length); + stop(FinishReason::Length); } else if let Ok(word) = String::from_utf8(head.to_vec()) { let _ = context.sender.send(Token::Content(word)); context.buffer = tail.to_vec(); @@ -1009,7 +1066,13 @@ impl Runtime { done.then(|| payload.finalize()); } - self.prepare(payloads).await?; + Ok(()) + } + + async fn process(&self, payloads: &mut [Payload]) -> Result<()> { + let tokens = self.sample(payloads).await?; + self.finalize(payloads, tokens).await?; + self.synchronize(payloads).await?; let option = InferOption::Last; let batches = payloads @@ -1041,6 +1104,12 @@ impl Runtime { let Payload::Busy(context) = payload else { continue; }; + + // if the suffix is empty, the output is read from the cache, and we don't want to override it. + if context.suffix.is_empty() { + continue; + } + context.output = match output.len() { 0 => None, x if x == self.info.num_vocab => Some(output.0.clone()),