Skip to content

Commit

Permalink
community: add SambaNova embeddings integration (langchain-ai#21227)
Browse files Browse the repository at this point in the history
- **Description:**  SambaNova hosted embeddings integration
  • Loading branch information
jhpiedrahitao authored and kyle-cassidy committed May 10, 2024
1 parent 96e7d02 commit c756818
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 0 deletions.
91 changes: 91 additions & 0 deletions docs/docs/integrations/text_embedding/sambanova.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SambaNova\n",
"\n",
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n",
"\n",
"This example goes over how to use LangChain to interact with SambaNova embedding models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SambaStudio\n",
"\n",
"**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register your environment variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
"\n",
"# Set the environment variables\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Call SambaStudio hosted embeddings directly from LangChain!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings.sambanova import SambaStudioEmbeddings\n",
"\n",
"embeddings = SambaStudioEmbeddings()\n",
"\n",
"text = \"Hello, this is a test\"\n",
"result = embeddings.embed_query(text)\n",
"print(result)\n",
"\n",
"texts = [\"Hello, this is a test\", \"Hello, this is another test\"]\n",
"results = embeddings.embed_documents(texts)\n",
"print(results)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@
from langchain_community.embeddings.sagemaker_endpoint import (
SagemakerEndpointEmbeddings,
)
from langchain_community.embeddings.sambanova import (
SambaStudioEmbeddings,
)
from langchain_community.embeddings.self_hosted import (
SelfHostedEmbeddings,
)
Expand Down Expand Up @@ -276,6 +279,7 @@
"QuantizedBgeEmbeddings",
"QuantizedBiEncoderEmbeddings",
"SagemakerEndpointEmbeddings",
"SambaStudioEmbeddings",
"SelfHostedEmbeddings",
"SelfHostedHuggingFaceEmbeddings",
"SelfHostedHuggingFaceInstructEmbeddings",
Expand Down Expand Up @@ -350,6 +354,7 @@
"QuantizedBiEncoderEmbeddings": "langchain_community.embeddings.optimum_intel",
"OracleEmbeddings": "langchain_community.embeddings.oracleai",
"SagemakerEndpointEmbeddings": "langchain_community.embeddings.sagemaker_endpoint",
"SambaStudioEmbeddings": "langchain_community.embeddings.sambanova",
"SelfHostedEmbeddings": "langchain_community.embeddings.self_hosted",
"SelfHostedHuggingFaceEmbeddings": "langchain_community.embeddings.self_hosted_hugging_face", # noqa: E501
"SelfHostedHuggingFaceInstructEmbeddings": "langchain_community.embeddings.self_hosted_hugging_face", # noqa: E501
Expand Down
142 changes: 142 additions & 0 deletions libs/community/langchain_community/embeddings/sambanova.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from typing import Dict, Generator, List

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


class SambaStudioEmbeddings(BaseModel, Embeddings):
    """SambaNova embedding models.

    To use, you should have the environment variables
    ``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``,
    ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``,
    set with your personal sambastudio variable or pass it as a named parameter
    to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import SambaStudioEmbeddings

            embeddings = SambaStudioEmbeddings(
                sambastudio_embeddings_base_url=base_url,
                sambastudio_embeddings_project_id=project_id,
                sambastudio_embeddings_endpoint_id=endpoint_id,
                sambastudio_embeddings_api_key=api_key,
            )
            (or)
            embeddings = SambaStudioEmbeddings()
    """

    API_BASE_PATH = "/api/predict/nlp/"
    """Base path to use for the API usage"""

    sambastudio_embeddings_base_url: str = ""
    """Base url to use"""

    sambastudio_embeddings_project_id: str = ""
    """Project id on sambastudio for model"""

    sambastudio_embeddings_endpoint_id: str = ""
    """endpoint id on sambastudio for model"""

    sambastudio_embeddings_api_key: str = ""
    """sambastudio api key"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the SambaStudio connection settings.

        Each of the four settings is looked up with ``get_from_dict_or_env``
        using the constructor value's key first and the matching
        ``SAMBASTUDIO_EMBEDDINGS_*`` environment variable name second.

        Args:
            values: the field values collected for this model.

        Returns:
            The same mapping with the four settings filled in.
        """
        values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
        )
        values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_project_id",
            "SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
        )
        values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
            values,
            "sambastudio_embeddings_endpoint_id",
            "SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
        )
        values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
            values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
        )
        return values

    def _get_full_url(self, path: str) -> str:
        """Return the full API URL for a given path.

        Args:
            path: the sub-path appended after ``API_BASE_PATH``.

        Returns:
            The base URL, the API base path and ``path`` concatenated.
        """
        return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}"

    def _get_endpoint_url(self) -> str:
        """Return the URL of the configured ``<project id>/<endpoint id>``."""
        return self._get_full_url(
            f"{self.sambastudio_embeddings_project_id}/"
            f"{self.sambastudio_embeddings_endpoint_id}"
        )

    def _iterate_over_batches(
        self, texts: List[str], batch_size: int
    ) -> Generator[List[str], None, None]:
        """Yield consecutive slices of ``texts`` of at most ``batch_size`` items.

        Args:
            texts: list of strings to embed.
            batch_size: maximum number of strings per yielded batch.

        Yields:
            Successive sub-lists of ``texts``; the last one may be shorter.
        """
        for i in range(0, len(texts), batch_size):
            yield texts[i : i + batch_size]

    def _request_embeddings(
        self, http_session: requests.Session, url: str, inputs: List[str]
    ) -> List[List[float]]:
        """POST one batch of texts to the endpoint and return its embeddings.

        Args:
            http_session: session used to send the request.
            url: full endpoint URL.
            inputs: texts sent under the ``"inputs"`` key of the JSON payload.

        Returns:
            The value stored under ``"data"`` in the JSON response.

        Raises:
            RuntimeError: if the endpoint answers with an error status code.
        """
        response = http_session.post(
            url,
            headers={"key": self.sambastudio_embeddings_api_key},
            json={"inputs": inputs},
        )
        # Fail loudly here; otherwise an error payload would only surface as
        # an opaque KeyError / JSON decoding error on the line below.
        if not response.ok:
            raise RuntimeError(
                "SambaStudio embeddings request failed with status code "
                f"{response.status_code}: {response.text}"
            )
        return response.json()["data"]

    def embed_documents(
        self, texts: List[str], batch_size: int = 32
    ) -> List[List[float]]:
        """Return one embedding per input text.

        The texts are sent to the endpoint in batches of ``batch_size``.

        Args:
            texts: list of texts to encode.
            batch_size: number of texts sent per request; must be >= 1.

        Returns:
            The embeddings returned by the endpoint for every batch,
            concatenated in request order.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
            RuntimeError: if the endpoint answers with an error status code.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        url = self._get_endpoint_url()
        embeddings: List[List[float]] = []

        # The context manager closes the session (and its pooled connections)
        # even when a request fails.
        with requests.Session() as http_session:
            for batch in self._iterate_over_batches(texts, batch_size):
                embeddings.extend(self._request_embeddings(http_session, url, batch))

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text.

        Args:
            text: the text to encode.

        Returns:
            The first (and only requested) embedding returned by the endpoint.

        Raises:
            RuntimeError: if the endpoint answers with an error status code.
        """
        url = self._get_endpoint_url()

        with requests.Session() as http_session:
            return self._request_embeddings(http_session, url, [text])[0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Test SambaNova Embeddings."""

from langchain_community.embeddings.sambanova import (
SambaStudioEmbeddings,
)


def test_embedding_documents() -> None:
    """Embed two short documents and check the shape of the result."""
    texts = ["foo", "bar"]
    vectors = SambaStudioEmbeddings().embed_documents(texts)
    # One vector per document.
    assert len(vectors) == 2
    # The test asserts a 1024-dimensional embedding for the first document.
    first_vector = vectors[0]
    assert len(first_vector) == 1024


def test_embedding_query() -> None:
    """Embed a single query string and check the vector length."""
    query = "foo bar"
    model = SambaStudioEmbeddings()
    vector = model.embed_query(query)
    # The test asserts a 1024-dimensional embedding for the query.
    assert len(vector) == 1024
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"QuantizedBiEncoderEmbeddings",
"NeMoEmbeddings",
"SparkLLMTextEmbeddings",
"SambaStudioEmbeddings",
"TitanTakeoffEmbed",
"QuantizedBgeEmbeddings",
"PremAIEmbeddings",
Expand Down

0 comments on commit c756818

Please sign in to comment.