community: add SambaNova embeddings integration #21227

Merged
91 changes: 91 additions & 0 deletions docs/docs/integrations/text_embedding/sambanova.ipynb
@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SambaNova\n",
"\n",
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
"\n",
"This example goes over how to use LangChain to interact with SambaNova embedding models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SambaStudio\n",
"\n",
"**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register your environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n",
"# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call SambaStudio hosted embeddings directly from LangChain!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings.sambanova import SambaStudioEmbeddings\n",
"\n",
"embeddings = SambaStudioEmbeddings()\n",
"\n",
"text = \"Hello, this is a test\"\n",
"result = embeddings.embed_query(text)\n",
"print(result)\n",
"\n",
"texts = [\"Hello, this is a test\", \"Hello, this is another test\"]\n",
"results = embeddings.embed_documents(texts)\n",
"print(results)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
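For context, here is a minimal sketch (not part of this PR) of how the new class slots into a typical LangChain retrieval flow; the FAISS vector store, the sample texts, and the query are illustrative assumptions, and the optional faiss-cpu package is assumed to be installed.

# Illustrative sketch, not part of this PR: use the new embeddings class
# with a vector store. Assumes the SAMBASTUDIO_EMBEDDINGS_* environment
# variables shown in the notebook above are already set.
from langchain_community.embeddings.sambanova import SambaStudioEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = SambaStudioEmbeddings()

# Index a couple of sample documents and retrieve the closest match.
docs = [
    "SambaStudio hosts open-source models",
    "LangChain wraps embedding endpoints",
]
vectorstore = FAISS.from_texts(docs, embeddings)
print(vectorstore.similarity_search("Where do the models run?", k=1))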
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
@@ -178,6 +178,9 @@
from langchain_community.embeddings.sagemaker_endpoint import (
SagemakerEndpointEmbeddings,
)
from langchain_community.embeddings.sambanova import (
SambaStudioEmbeddings,
)
from langchain_community.embeddings.self_hosted import (
SelfHostedEmbeddings,
)
@@ -276,6 +279,7 @@
"QuantizedBgeEmbeddings",
"QuantizedBiEncoderEmbeddings",
"SagemakerEndpointEmbeddings",
"SambaStudioEmbeddings",
"SelfHostedEmbeddings",
"SelfHostedHuggingFaceEmbeddings",
"SelfHostedHuggingFaceInstructEmbeddings",
@@ -350,6 +354,7 @@
"QuantizedBiEncoderEmbeddings": "langchain_community.embeddings.optimum_intel",
"OracleEmbeddings": "langchain_community.embeddings.oracleai",
"SagemakerEndpointEmbeddings": "langchain_community.embeddings.sagemaker_endpoint",
"SambaStudioEmbeddings": "langchain_community.embeddings.sambanova",
"SelfHostedEmbeddings": "langchain_community.embeddings.self_hosted",
"SelfHostedHuggingFaceEmbeddings": "langchain_community.embeddings.self_hosted_hugging_face", # noqa: E501
"SelfHostedHuggingFaceInstructEmbeddings": "langchain_community.embeddings.self_hosted_hugging_face", # noqa: E501
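With the `__all__` and `_module_lookup` entries above, the class is also importable from the package root; a quick sketch, assuming langchain-community with this change is installed:

# Quick check that the new entries expose the class at the package level;
# the import is resolved lazily through _module_lookup.
from langchain_community.embeddings import SambaStudioEmbeddings

print(SambaStudioEmbeddings.__module__)  # expected: langchain_community.embeddings.sambanova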
142 changes: 142 additions & 0 deletions libs/community/langchain_community/embeddings/sambanova.py
@@ -0,0 +1,142 @@
from typing import Dict, Generator, List

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


class SambaStudioEmbeddings(BaseModel, Embeddings):
"""SambaNova embedding models.

To use, you should have the environment variables
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``,
``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, and ``SAMBASTUDIO_EMBEDDINGS_API_KEY``
set with your personal SambaStudio values, or pass them as named parameters
to the constructor.

Example:
.. code-block:: python

from langchain_community.embeddings import SambaStudioEmbeddings
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
sambastudio_embeddings_project_id=project_id,
sambastudio_embeddings_endpoint_id=endpoint_id,
sambastudio_embeddings_api_key=api_key)
# or
embeddings = SambaStudioEmbeddings()
"""

API_BASE_PATH = "/api/predict/nlp/"
"""Base path used for API requests"""

sambastudio_embeddings_base_url: str = ""
"""Base URL of the SambaStudio environment"""

sambastudio_embeddings_project_id: str = ""
"""SambaStudio project id for the embedding model"""

sambastudio_embeddings_endpoint_id: str = ""
"""SambaStudio endpoint id for the embedding model"""

sambastudio_embeddings_api_key: str = ""
"""SambaStudio endpoint API key"""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
)
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
values,
"sambastudio_embeddings_project_id",
"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
)
values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
values,
"sambastudio_embeddings_endpoint_id",
"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
)
values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
)
return values

def _get_full_url(self, path: str) -> str:
"""
Return the full API URL for a given path.

:param str path: the sub-path
:returns: the full API URL for the sub-path
:rtype: str
"""
return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}"

def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
"""Generator for creating batches in the embed documents method
Args:
texts (List[str]): list of strings to embed
batch_size (int, optional): batch size to be used for the embedding model.
Will depend on the RDU endpoint used.
Yields:
List[str]: list (batch) of strings of size batch size
"""
for i in range(0, len(texts), batch_size):
yield texts[i : i + batch_size]

def embed_documents(
self, texts: List[str], batch_size: int = 32
) -> List[List[float]]:
"""Returns a list of embeddings for the given sentences.
Args:
texts (`List[str]`): List of texts to encode
batch_size (`int`): Batch size for the encoding

Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings
for the given sentences
"""
http_session = requests.Session()
url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
)

embeddings = []

for batch in self._iterate_over_batches(texts, batch_size):
data = {"inputs": batch}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
embedding = response.json()["data"]
embeddings.extend(embedding)

return embeddings

def embed_query(self, text: str) -> List[float]:
"""Returns a list of embeddings for the given sentences.
Args:
sentences (`List[str]`): List of sentences to encode

Returns:
`List[np.ndarray]` or `List[tensor]`: List of embeddings
for the given sentences
"""
http_session = requests.Session()
url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
)

data = {"inputs": [text]}

response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
embedding = response.json()["data"][0]

return embedding
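To make the endpoint contract easier to review, here is a rough sketch of the raw HTTP call that `embed_query` wraps, reconstructed only from the code above; every value below is a placeholder, not a real credential or URL.

# Reconstruction of the request embed_query sends; placeholders only.
import requests

base_url = "<your SambaStudio environment URL>"  # SAMBASTUDIO_EMBEDDINGS_BASE_URL
project_id = "<project id>"                      # SAMBASTUDIO_EMBEDDINGS_PROJECT_ID
endpoint_id = "<endpoint id>"                    # SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID
api_key = "<endpoint API key>"                   # SAMBASTUDIO_EMBEDDINGS_API_KEY

url = f"{base_url}/api/predict/nlp/{project_id}/{endpoint_id}"
payload = {"inputs": ["Hello, this is a test"]}

response = requests.post(url, headers={"key": api_key}, json=payload)
embedding = response.json()["data"][0]  # one embedding vector per input string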
@@ -0,0 +1,22 @@
"""Test SambaNova Embeddings."""

from langchain_community.embeddings.sambanova import (
SambaStudioEmbeddings,
)


def test_embedding_documents() -> None:
"""Test embeddings for documents."""
documents = ["foo", "bar"]
embedding = SambaStudioEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024


def test_embedding_query() -> None:
"""Test embeddings for query."""
document = "foo bar"
embedding = SambaStudioEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -67,6 +67,7 @@
"QuantizedBiEncoderEmbeddings",
"NeMoEmbeddings",
"SparkLLMTextEmbeddings",
"SambaStudioEmbeddings",
"TitanTakeoffEmbed",
"QuantizedBgeEmbeddings",
"PremAIEmbeddings",