Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: minor changes sambanova integration #21231

20 changes: 20 additions & 0 deletions libs/community/langchain_community/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
return SagemakerEndpoint


def _import_sambaverse() -> Type[BaseLLM]:
    """Lazily import the Sambaverse LLM class.

    Deferring the import keeps ``langchain_community.llms`` cheap to import;
    the sambanova module is only loaded when the class is first requested.
    """
    from langchain_community.llms import sambanova

    return sambanova.Sambaverse


def _import_sambastudio() -> Type[BaseLLM]:
    """Lazily import the SambaStudio LLM class.

    Deferring the import keeps ``langchain_community.llms`` cheap to import;
    the sambanova module is only loaded when the class is first requested.
    """
    from langchain_community.llms import sambanova

    return sambanova.SambaStudio


def _import_self_hosted() -> Type[BaseLLM]:
from langchain_community.llms.self_hosted import SelfHostedPipeline

Expand Down Expand Up @@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
return _import_rwkv()
elif name == "SagemakerEndpoint":
return _import_sagemaker_endpoint()
elif name == "Sambaverse":
return _import_sambaverse()
elif name == "SambaStudio":
return _import_sambastudio()
elif name == "SelfHostedPipeline":
return _import_self_hosted()
elif name == "SelfHostedHuggingFaceLLM":
Expand Down Expand Up @@ -922,6 +938,8 @@ def __getattr__(name: str) -> Any:
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"SparkLLM",
Expand Down Expand Up @@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"replicate": _import_replicate,
"rwkv": _import_rwkv,
"sagemaker_endpoint": _import_sagemaker_endpoint,
"sambaverse": _import_sambaverse,
"sambastudio": _import_sambastudio,
"self_hosted": _import_self_hosted,
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
"stochasticai": _import_stochasticai,
Expand Down
40 changes: 24 additions & 16 deletions libs/community/langchain_community/llms/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,10 @@ class SambaStudio(LLM):

    from langchain_community.llms.sambanova import SambaStudio
SambaStudio(
base_url="your SambaStudio environment URL",
project_id=set with your SambaStudio project ID.,
endpoint_id=set with your SambaStudio endpoint ID.,
api_token= set with your SambaStudio endpoint API key.,
sambastudio_base_url="your SambaStudio environment URL",
sambastudio_project_id=set with your SambaStudio project ID.,
sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
sambastudio_api_key= set with your SambaStudio endpoint API key.,
streaming=false
model_kwargs={
"do_sample": False,
Expand All @@ -634,16 +634,16 @@ class SambaStudio(LLM):
)
"""

base_url: str = ""
sambastudio_base_url: str = ""
"""Base url to use"""

project_id: str = ""
sambastudio_project_id: str = ""
"""Project id on sambastudio for model"""

endpoint_id: str = ""
sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model"""

api_key: str = ""
sambastudio_api_key: str = ""
"""sambastudio api key"""

model_kwargs: Optional[dict] = None
Expand Down Expand Up @@ -674,16 +674,16 @@ def _llm_type(self) -> str:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["base_url"] = get_from_dict_or_env(
values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
)
values["project_id"] = get_from_dict_or_env(
values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
)
values["endpoint_id"] = get_from_dict_or_env(
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
)
values["api_key"] = get_from_dict_or_env(
values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
)
return values
Expand Down Expand Up @@ -729,7 +729,11 @@ def _handle_nlp_predict(
ValueError: If the prediction fails.
"""
response = sdk.nlp_predict(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
)
if response["status_code"] != 200:
optional_detail = response["detail"]
Expand All @@ -755,7 +759,7 @@ def _handle_completion_requests(
Raises:
ValueError: If the prediction fails.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

Expand All @@ -774,7 +778,11 @@ def _handle_nlp_predict_stream(
An iterator of GenerationChunks.
"""
for chunk in sdk.nlp_predict_stream(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
):
yield chunk

Expand All @@ -794,7 +802,7 @@ def _stream(
Returns:
The string generated by the model.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
try:
if self.streaming:
Expand Down
2 changes: 2 additions & 0 deletions libs/community/tests/unit_tests/llms/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"StochasticAI",
Expand Down