Skip to content

Commit

Permalink
embed batch
Browse files Browse the repository at this point in the history
  • Loading branch information
levkk committed May 22, 2024
1 parent 51d0057 commit 8b0b9ac
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions pgml-sdks/pgml/src/builtins.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use anyhow::Context;
use rust_bridge::{alias, alias_methods};
use sqlx::Row;
use tracing::instrument;
Expand All @@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
#[cfg(feature = "python")]
use crate::{query_runner::QueryRunnerPython, types::JsonPython};

#[alias_methods(new, query, transform, embed)]
#[alias_methods(new, query, transform, embed, embed_batch)]
impl Builtins {
pub fn new(database_url: Option<String>) -> Self {
Self { database_url }
Expand Down Expand Up @@ -97,12 +98,45 @@ impl Builtins {
///
pub async fn embed(&self, model: &str, text: &str) -> anyhow::Result<Json> {
    let pool = get_or_initialize_pool(&self.database_url).await?;
    // Select the `embed` column from the function call's result set; the earlier
    // `SELECT pgml.embed($1, $2)` form was immediately shadowed and never executed.
    let row = sqlx::query("SELECT embed FROM pgml.embed($1, $2)")
        .bind(model)
        .bind(text)
        .fetch_one(&pool)
        .await?;
    // `try_get` surfaces a column decode failure as an error, whereas `get` would panic.
    let embedding = row.try_get::<Vec<f32>, _>(0)?;
    Ok(Json(serde_json::to_value(embedding)?))
}

/// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
///
/// # Arguments
///
/// * `model` - The model to use.
/// * `texts` - The texts to embed.
///
pub async fn embed_batch(&self, model: &str, texts: Json) -> anyhow::Result<Json> {
let texts = texts
.0
.as_array()
.with_context(|| "embed_batch takes an array of texts")?
.into_iter()
.map(|v| {
v.as_str()
.with_context(|| "only text embeddings are supported")
.unwrap()
.to_string()
})
.collect::<Vec<String>>();
let pool = get_or_initialize_pool(&self.database_url).await?;
let query = sqlx::query("SELECT embed AS embed_batch FROM pgml.embed($1, $2)");
let results = query
.bind(model)
.bind(texts)
.fetch_all(&pool)
.await?
.into_iter()
.map(|embeddings| embeddings.get::<Vec<f32>, _>(0))
.collect::<Vec<Vec<f32>>>();
Ok(Json(serde_json::to_value(results)?))
}
}

#[cfg(test)]
Expand Down

0 comments on commit 8b0b9ac

Please sign in to comment.