
feat(tabby-inference, http-api-bindings): support llama.cpp server embedding interface. #2094

Merged
merged 4 commits on May 11, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions crates/http-api-bindings/Cargo.toml
@@ -18,3 +18,6 @@ serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 tracing.workspace = true
+
+[dev-dependencies]
+tokio = { workspace = true, features = ["rt", "macros"] }
23 changes: 23 additions & 0 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -0,0 +1,23 @@
mod openai_chat;

use std::sync::Arc;

use openai_chat::OpenAIChatEngine;
use tabby_inference::ChatCompletionStream;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> Arc<dyn ChatCompletionStream> {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "openai-chat" {
        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");

        let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
        Arc::new(engine)
    } else {
        panic!("Only openai-chat is supported for http chat");
    }
}
33 changes: 33 additions & 0 deletions crates/http-api-bindings/src/completion/mod.rs
@@ -0,0 +1,33 @@
mod llama;
mod openai;

use std::sync::Arc;

use llama::LlamaCppEngine;
use openai::OpenAIEngine;
use tabby_inference::CompletionStream;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "openai" {
        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let prompt_template = get_optional_param(&params, "prompt_template");
        let chat_template = get_optional_param(&params, "chat_template");
        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
        (Arc::new(engine), prompt_template, chat_template)
    } else if kind == "llama" {
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let prompt_template = get_optional_param(&params, "prompt_template");
        let chat_template = get_optional_param(&params, "chat_template");
        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
        (Arc::new(engine), prompt_template, chat_template)
    } else {
        panic!("Only openai and llama are supported for http completion");
    }
}
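As a usage sketch (not part of this diff): a caller selects a backend through the same JSON "model string" format the other factories use. The endpoint and prompt_template values below are illustrative placeholders, not defaults from the crate.

use http_api_bindings::create;

fn completion_example() {
    // "kind" routes to the engine; the template fields are optional.
    let model = r#"{
        "kind": "llama",
        "api_endpoint": "http://localhost:8080",
        "prompt_template": "<PRE> {prefix} <SUF>{suffix} <MID>"
    }"#;
    let (_engine, _prompt_template, _chat_template) = create(model);
}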
64 changes: 64 additions & 0 deletions crates/http-api-bindings/src/embedding/llama.rs
@@ -0,0 +1,64 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;

pub struct LlamaCppEngine {
    client: reqwest::Client,
    api_endpoint: String,
    api_key: Option<String>,
}

impl LlamaCppEngine {
    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
        let client = reqwest::Client::new();

        Self {
            client,
            api_endpoint: format!("{}/embeddings", api_endpoint),
            api_key,
        }
    }
}

#[derive(Serialize)]
struct EmbeddingRequest {
    content: String,
}

#[derive(Deserialize)]
struct EmbeddingResponse {
    embedding: Vec<f32>,
}

#[async_trait]
impl Embedding for LlamaCppEngine {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
        let request = EmbeddingRequest {
            content: prompt.to_owned(),
        };

        let mut request = self.client.post(&self.api_endpoint).json(&request);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let response = request.send().await?.json::<EmbeddingResponse>().await?;
        Ok(response.embedding)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// This unit test should only run manually when the server is running:
    /// curl -L https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q8_0.gguf -o ./models/nomic.gguf
    /// ./server -m ./models/nomic.gguf --port 8000 --embedding
    #[tokio::test]
    #[ignore]
    async fn test_embedding() {
        let engine = LlamaCppEngine::create("http://localhost:8000", None);
        let embedding = engine.embed("hello").await.unwrap();
        assert_eq!(embedding.len(), 768);
    }
}
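For reference, the wire format this engine assumes from the llama.cpp server's /embeddings route, sketched as serde_json values (the numbers are illustrative, not real model output):

use serde_json::json;

fn wire_format_sketch() {
    // What EmbeddingRequest serializes to:
    let _request = json!({ "content": "hello" });
    // The shape EmbeddingResponse deserializes from (vector truncated):
    let _response = json!({ "embedding": [0.013, -0.094, 0.027] });
}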
21 changes: 21 additions & 0 deletions crates/http-api-bindings/src/embedding/mod.rs
@@ -0,0 +1,21 @@
mod llama;

use std::sync::Arc;

use llama::LlamaCppEngine;
use tabby_inference::Embedding;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> Arc<dyn Embedding> {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "llama" {
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
        Arc::new(engine)
    } else {
        panic!("Only llama is supported for http embedding");
    }
}
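A minimal end-to-end sketch (assuming a llama.cpp server started with --embedding, as in the test above) of reaching this factory through the crate's public re-export:

use http_api_bindings::create_embedding;

async fn embedding_example() -> anyhow::Result<()> {
    // Same JSON params format as the completion and chat factories.
    let model = r#"{"kind": "llama", "api_endpoint": "http://localhost:8000"}"#;
    let embedding = create_embedding(model);
    let vector = embedding.embed("hello").await?;
    println!("embedding has {} dimensions", vector.len());
    Ok(())
}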
56 changes: 8 additions & 48 deletions crates/http-api-bindings/src/lib.rs
@@ -1,53 +1,13 @@
-mod llama;
-mod openai;
-mod openai_chat;
+mod chat;
+mod completion;
+mod embedding;

-use std::sync::Arc;
-
-use openai::OpenAIEngine;
-use openai_chat::OpenAIChatEngine;
+pub use chat::create as create_chat;
+pub use completion::create;
+pub use embedding::create as create_embedding;
 use serde_json::Value;
-use tabby_inference::{ChatCompletionStream, CompletionStream};
-
-pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
-    } else if kind == "llama" {
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = llama::LlamaCppEngine::create(&api_endpoint, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
-    } else {
-        panic!("Only openai are supported for http completion");
-    }
-}
-
-pub fn create_chat(model: &str) -> Arc<dyn ChatCompletionStream> {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai-chat" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-
-        let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
-        Arc::new(engine)
-    } else {
-        panic!("Only openai-chat are supported for http chat");
-    }
-}

-fn get_param(params: &Value, key: &str) -> String {
+pub(crate) fn get_param(params: &Value, key: &str) -> String {
     params
         .get(key)
         .unwrap_or_else(|| panic!("Missing {} field", key))
@@ -56,7 +16,7 @@
         .to_owned()
 }

-fn get_optional_param(params: &Value, key: &str) -> Option<String> {
+pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option<String> {
     params
         .get(key)
         .map(|x| x.as_str().expect("Type unmatched").to_owned())
6 changes: 6 additions & 0 deletions crates/tabby-inference/src/embedding.rs
@@ -0,0 +1,6 @@
use async_trait::async_trait;

#[async_trait]
pub trait Embedding: Sync + Send {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>>;
}
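Since this trait is the crate's new integration point, here is a minimal implementor sketch (illustrative only, not part of this diff) showing the shape an engine must satisfy:

use async_trait::async_trait;
use tabby_inference::Embedding;

struct DummyEmbedding;

#[async_trait]
impl Embedding for DummyEmbedding {
    async fn embed(&self, _prompt: &str) -> anyhow::Result<Vec<f32>> {
        // Return a fixed 768-dimensional zero vector regardless of input.
        Ok(vec![0.0; 768])
    }
}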
2 changes: 2 additions & 0 deletions crates/tabby-inference/src/lib.rs
@@ -3,10 +3,12 @@ mod chat;
 mod code;
 mod completion;
 mod decoding;
+mod embedding;

 pub use chat::{ChatCompletionOptions, ChatCompletionOptionsBuilder, ChatCompletionStream};
 pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
 pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream};
+pub use embedding::Embedding;

 fn default_seed() -> u64 {
     std::time::SystemTime::now()