forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community: add SambaNova embeddings integration (langchain-ai#21227)
- **Description:** SambaNova hosted embeddings integration
- Loading branch information
1 parent
d8e0e2d
commit d37586e
Showing
5 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# SambaNova\n", | ||
"\n", | ||
"**[SambaNova](https://sambanova.ai/)'s** [Sambastudio](https://sambanova.ai/technology/full-stack-ai-platform) is a platform for running your own open-source models\n", | ||
"\n", | ||
"This example goes over how to use LangChain to interact with SambaNova embedding models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## SambaStudio\n", | ||
"\n", | ||
"**SambaStudio** allows you to train, run batch inference jobs, and deploy online inference endpoints to run open source models that you fine tuned yourself." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"A SambaStudio environment is required to deploy a model. Get more information at [sambanova.ai/products/enterprise-ai-platform-sambanova-suite](https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Register your environment variables:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n", | ||
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n", | ||
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n", | ||
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n", | ||
"\n", | ||
"# Set the environment variables\n", | ||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_BASE_URL\"] = sambastudio_base_url\n", | ||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID\"] = sambastudio_project_id\n", | ||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID\"] = sambastudio_endpoint_id\n", | ||
"os.environ[\"SAMBASTUDIO_EMBEDDINGS_API_KEY\"] = sambastudio_api_key" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Call SambaStudio hosted embeddings directly from LangChain!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.embeddings.sambanova import SambaStudioEmbeddings\n", | ||
"\n", | ||
"embeddings = SambaStudioEmbeddings()\n", | ||
"\n", | ||
"text = \"Hello, this is a test\"\n", | ||
"result = embeddings.embed_query(text)\n", | ||
"print(result)\n", | ||
"\n", | ||
"texts = [\"Hello, this is a test\", \"Hello, this is another test\"]\n", | ||
"results = embeddings.embed_documents(texts)\n", | ||
"print(results)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
142 changes: 142 additions & 0 deletions
142
libs/community/langchain_community/embeddings/sambanova.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from typing import Dict, Generator, List | ||
|
||
import requests | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, root_validator | ||
from langchain_core.utils import get_from_dict_or_env | ||
|
||
|
||
class SambaStudioEmbeddings(BaseModel, Embeddings): | ||
"""SambaNova embedding models. | ||
To use, you should have the environment variables | ||
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, | ||
``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``, ``SAMBASTUDIO_EMBEDDINGS_API_KEY``, | ||
set with your personal sambastudio variable or pass it as a named parameter | ||
to the constructor. | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.embeddings import SambaStudioEmbeddings | ||
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url, | ||
sambastudio_embeddings_project_id=project_id, | ||
sambastudio_embeddings_endpoint_id=endpoint_id, | ||
sambastudio_embeddings_api_key=api_key) | ||
(or) | ||
embeddings = SambaStudioEmbeddings() | ||
""" | ||
|
||
API_BASE_PATH = "/api/predict/nlp/" | ||
"""Base path to use for the API usage""" | ||
|
||
sambastudio_embeddings_base_url: str = "" | ||
"""Base url to use""" | ||
|
||
sambastudio_embeddings_project_id: str = "" | ||
"""Project id on sambastudio for model""" | ||
|
||
sambastudio_embeddings_endpoint_id: str = "" | ||
"""endpoint id on sambastudio for model""" | ||
|
||
sambastudio_embeddings_api_key: str = "" | ||
"""sambastudio api key""" | ||
|
||
@root_validator() | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key and python package exists in environment.""" | ||
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env( | ||
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL" | ||
) | ||
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env( | ||
values, | ||
"sambastudio_embeddings_project_id", | ||
"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID", | ||
) | ||
values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env( | ||
values, | ||
"sambastudio_embeddings_endpoint_id", | ||
"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID", | ||
) | ||
values["sambastudio_embeddings_api_key"] = get_from_dict_or_env( | ||
values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY" | ||
) | ||
return values | ||
|
||
def _get_full_url(self, path: str) -> str: | ||
""" | ||
Return the full API URL for a given path. | ||
:param str path: the sub-path | ||
:returns: the full API URL for the sub-path | ||
:rtype: str | ||
""" | ||
return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}" | ||
|
||
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator: | ||
"""Generator for creating batches in the embed documents method | ||
Args: | ||
texts (List[str]): list of strings to embed | ||
batch_size (int, optional): batch size to be used for the embedding model. | ||
Will depend on the RDU endpoint used. | ||
Yields: | ||
List[str]: list (batch) of strings of size batch size | ||
""" | ||
for i in range(0, len(texts), batch_size): | ||
yield texts[i : i + batch_size] | ||
|
||
def embed_documents( | ||
self, texts: List[str], batch_size: int = 32 | ||
) -> List[List[float]]: | ||
"""Returns a list of embeddings for the given sentences. | ||
Args: | ||
texts (`List[str]`): List of texts to encode | ||
batch_size (`int`): Batch size for the encoding | ||
Returns: | ||
`List[np.ndarray]` or `List[tensor]`: List of embeddings | ||
for the given sentences | ||
""" | ||
http_session = requests.Session() | ||
url = self._get_full_url( | ||
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" | ||
) | ||
|
||
embeddings = [] | ||
|
||
for batch in self._iterate_over_batches(texts, batch_size): | ||
data = {"inputs": batch} | ||
response = http_session.post( | ||
url, | ||
headers={"key": self.sambastudio_embeddings_api_key}, | ||
json=data, | ||
) | ||
embedding = response.json()["data"] | ||
embeddings.extend(embedding) | ||
|
||
return embeddings | ||
|
||
def embed_query(self, text: str) -> List[float]: | ||
"""Returns a list of embeddings for the given sentences. | ||
Args: | ||
sentences (`List[str]`): List of sentences to encode | ||
Returns: | ||
`List[np.ndarray]` or `List[tensor]`: List of embeddings | ||
for the given sentences | ||
""" | ||
http_session = requests.Session() | ||
url = self._get_full_url( | ||
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" | ||
) | ||
|
||
data = {"inputs": [text]} | ||
|
||
response = http_session.post( | ||
url, | ||
headers={"key": self.sambastudio_embeddings_api_key}, | ||
json=data, | ||
) | ||
embedding = response.json()["data"][0] | ||
|
||
return embedding |
22 changes: 22 additions & 0 deletions
22
libs/community/tests/integration_tests/embeddings/test_sambanova.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
"""Test SambaNova Embeddings.""" | ||
|
||
from langchain_community.embeddings.sambanova import ( | ||
SambaStudioEmbeddings, | ||
) | ||
|
||
|
||
def test_embedding_documents() -> None: | ||
"""Test embeddings for documents.""" | ||
documents = ["foo", "bar"] | ||
embedding = SambaStudioEmbeddings() | ||
output = embedding.embed_documents(documents) | ||
assert len(output) == 2 | ||
assert len(output[0]) == 1024 | ||
|
||
|
||
def test_embedding_query() -> None: | ||
"""Test embeddings for query.""" | ||
document = "foo bar" | ||
embedding = SambaStudioEmbeddings() | ||
output = embedding.embed_query(document) | ||
assert len(output) == 1024 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters