Skip to content

Commit

Permalink
Add support for batching to embedder models (#503)
Browse files Browse the repository at this point in the history
Co-authored-by: Magdy Saleh <[email protected]>
  • Loading branch information
tgaddair and magdyksaleh committed Jun 8, 2024
1 parent 1b528e0 commit e8f3d33
Show file tree
Hide file tree
Showing 16 changed files with 913 additions and 199 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 13 additions & 5 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,24 @@ message DecodeResponse {
optional CachedBatch batch = 2;
}

message EmbedRequest {
string inputs = 1;
message Embedding {
/// Request ID
uint64 request_id = 1;

/// Embedding values
repeated float values = 2;
}

message Embedding {
repeated float values = 1;
message EmbedRequest {
/// Batch
Batch batch = 1;
}

message EmbedResponse {
Embedding embeddings = 1;
/// Embeddings
repeated Embedding embeddings = 1;

/// Error message on failure
string errorMsg = 2;
}

Expand Down
3 changes: 2 additions & 1 deletion router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
ngrok = { version = "0.12.3", features = ["axum"], optional = true }
once_cell = "1.19.0"
itertools = "0.12.1"
async-trait = "0.1.80"

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

[features]
default = ["ngrok"]
ngrok = ["dep:ngrok"]
ngrok = ["dep:ngrok"]
16 changes: 8 additions & 8 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ impl Client {
Ok(response)
}

/// Embed
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<EmbedResponse> {
let request = tonic::Request::new(EmbedRequest { inputs }).inject_context();
let response = self.stub.embed(request).await?.into_inner();
Ok(response)
}

/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
Expand Down Expand Up @@ -196,6 +188,14 @@ impl Client {
Ok((response.generations, response.batch))
}

/// Embed
#[instrument(skip(self))]
pub async fn embed(&mut self, batch: Batch) -> Result<Vec<Embedding>> {
let request = tonic::Request::new(EmbedRequest { batch: Some(batch) }).inject_context();
let response = self.stub.embed(request).await?.into_inner();
Ok(response.embeddings)
}

/// Downloads the weights for an adapter.
pub async fn download_adapter(
&mut self,
Expand Down
2 changes: 1 addition & 1 deletion router/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse,
AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding,
FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy,
NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters,
};
Expand Down
11 changes: 6 additions & 5 deletions router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pb::generate::v1::EmbedResponse;
use crate::pb::generate::v1::{EmbedResponse, Embedding};
/// Multi shard Client
use crate::{
AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
Expand Down Expand Up @@ -154,15 +154,16 @@ impl ShardedClient {
merge_generations(results?)
}

/// Get the model info
/// Embed the given batch
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<Vec<EmbedResponse>> {
pub async fn embed(&mut self, batch: Batch) -> Result<Vec<Embedding>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.embed(inputs.clone())))
.map(|client| Box::pin(client.embed(batch.clone())))
.collect();
join_all(futures).await.into_iter().collect()
let results: Result<Vec<Vec<Embedding>>> = join_all(futures).await.into_iter().collect();
Ok(results?.into_iter().flatten().collect())
}

pub async fn download_adapter(
Expand Down

0 comments on commit e8f3d33

Please sign in to comment.