Clippy cleanup
SilasMarvin committed May 20, 2024
1 parent b4e35a1 commit dad8c12
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pgml-sdks/pgml/src/builtins.rs
@@ -84,7 +84,7 @@ impl Builtins {
query.bind(task.0)
};
let results = query.bind(inputs).bind(args).fetch_all(&pool).await?;
- let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
+ let results = results.first().unwrap().get::<serde_json::Value, _>(0);
Ok(Json(results))
}
}
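The .get(0) to .first() swap here (and the matching one in transformer_pipeline.rs below) is the fix suggested by Clippy's get_first lint: both calls return an Option<&T> for the first element, but .first() states the intent directly. A minimal standalone sketch of the equivalence, using a plain Vec<i32> rather than the sqlx rows in the real code:

fn main() {
    let rows = vec![10, 20, 30];

    // What Clippy flags: reaching for the first element through .get(0)
    let via_get: Option<&i32> = rows.get(0);

    // The suggested spelling: .first() returns the same Option<&T>
    let via_first: Option<&i32> = rows.first();

    assert_eq!(via_get, via_first);
    println!("first element: {:?}", via_first);
}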
8 changes: 6 additions & 2 deletions pgml-sdks/pgml/src/collection.rs
@@ -1126,7 +1126,7 @@ impl Collection {
}

#[instrument(skip(self))]
- pub async fn rag(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result<Json> {
+ pub async fn rag(&self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result<Json> {
let pool = get_or_initialize_pool(&self.database_url).await?;
let (built_query, values) = build_rag_query(query.clone(), self, pipeline, false).await?;
let mut results: Vec<(Json,)> = sqlx::query_as_with(&built_query, values)
@@ -1136,7 +1136,11 @@ impl Collection {
}

#[instrument(skip(self))]
- pub async fn rag_stream(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result<RAGStream> {
+ pub async fn rag_stream(
+     &self,
+     query: Json,
+     pipeline: &mut Pipeline,
+ ) -> anyhow::Result<RAGStream> {
let pool = get_or_initialize_pool(&self.database_url).await?;

let (built_query, values) = build_rag_query(query.clone(), self, pipeline, true).await?;
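The rag and rag_stream signatures now take pipeline: &mut Pipeline instead of &Pipeline, so callers must hold the pipeline in a mutable binding and pass a mutable borrow. A rough caller-side illustration with stand-in types (not the actual pgml SDK structs or constructors):

// Stand-in types to show the borrow change only; not the real pgml API.
struct Pipeline {
    name: String,
}

struct Collection;

impl Collection {
    // Mirrors the new parameter shape: the pipeline comes in as &mut.
    fn rag(&self, _query: &str, pipeline: &mut Pipeline) -> String {
        // A mutable borrow lets the method update pipeline state if it needs to.
        pipeline.name.push_str(" (used)");
        format!("answer produced with {}", pipeline.name)
    }
}

fn main() {
    let collection = Collection;
    // The binding must now be `mut` so &mut pipeline can be taken below.
    let mut pipeline = Pipeline { name: "test_pipeline".to_string() };
    let answer = collection.rag(r#"{"query": "..."}"#, &mut pipeline);
    println!("{answer}");
}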
8 changes: 4 additions & 4 deletions pgml-sdks/pgml/src/transformer_pipeline.rs
@@ -32,7 +32,7 @@ impl TransformerPipeline {
a.insert("model".to_string(), model.into());

// We must convert any floating point values to integers or our extension will get angry
for field in vec!["gpu_layers"] {
for field in ["gpu_layers"] {
if let Some(v) = a.remove(field) {
let x: u64 = CustomU64Convertor(v).into();
a.insert(field.to_string(), x.into());
@@ -62,7 +62,7 @@ impl TransformerPipeline {
}

// We must convert any floating point values to integers or our extension will get angry
for field in vec!["max_tokens", "n"] {
for field in ["max_tokens", "n"] {
if let Some(v) = a.remove(field) {
let x: u64 = CustomU64Convertor(v).into();
a.insert(field.to_string(), x.into());
@@ -95,7 +95,7 @@ impl TransformerPipeline {
.fetch_all(&pool)
.await?
};
- let results = results.get(0).unwrap().get::<serde_json::Value, _>(0);
+ let results = results.first().unwrap().get::<serde_json::Value, _>(0);
Ok(Json(results))
}

@@ -121,7 +121,7 @@
}

// We must convert any floating point values to integers or our extension will get angry
for field in vec!["max_tokens", "n"] {
for field in ["max_tokens", "n"] {
if let Some(v) = a.remove(field) {
let x: u64 = CustomU64Convertor(v).into();
a.insert(field.to_string(), x.into());
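The vec![...] to [...] changes in this file line up with Clippy's useless_vec lint: building a Vec just to iterate it once allocates on the heap for no benefit, while a plain array literal iterates by value with no allocation. A small standalone sketch of the before/after pattern (the field names are borrowed from the diff; the loop body is illustrative only):

fn main() {
    // What Clippy flags: a Vec is allocated only to be looped over once.
    for field in vec!["max_tokens", "n"] {
        println!("checking {field} (Vec version)");
    }

    // The suggested form: iterate the array literal directly, no allocation.
    for field in ["max_tokens", "n"] {
        println!("checking {field} (array version)");
    }
}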
