Skip to content

Commit

Permalink
community: minor changes sambanova integration (#21231)
Browse files Browse the repository at this point in the history
- **Description:** fix: variable names in root validator not allowing
pass credentials as named parameters in llm instancing, also added
sambanova's sambaverse and sambastudio llms to __init__.py for module
import
  • Loading branch information
jhpiedrahitao committed May 6, 2024
1 parent d9a61c0 commit df1c102
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 16 deletions.
20 changes: 20 additions & 0 deletions libs/community/langchain_community/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,18 @@ def _import_sagemaker_endpoint() -> Type[BaseLLM]:
return SagemakerEndpoint


def _import_sambaverse() -> Type[BaseLLM]:
from langchain_community.llms.sambanova import Sambaverse

return Sambaverse


def _import_sambastudio() -> Type[BaseLLM]:
from langchain_community.llms.sambanova import SambaStudio

return SambaStudio


def _import_self_hosted() -> Type[BaseLLM]:
from langchain_community.llms.self_hosted import SelfHostedPipeline

Expand Down Expand Up @@ -793,6 +805,10 @@ def __getattr__(name: str) -> Any:
return _import_rwkv()
elif name == "SagemakerEndpoint":
return _import_sagemaker_endpoint()
elif name == "Sambaverse":
return _import_sambaverse()
elif name == "SambaStudio":
return _import_sambastudio()
elif name == "SelfHostedPipeline":
return _import_self_hosted()
elif name == "SelfHostedHuggingFaceLLM":
Expand Down Expand Up @@ -922,6 +938,8 @@ def __getattr__(name: str) -> Any:
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"SparkLLM",
Expand Down Expand Up @@ -1015,6 +1033,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"replicate": _import_replicate,
"rwkv": _import_rwkv,
"sagemaker_endpoint": _import_sagemaker_endpoint,
"sambaverse": _import_sambaverse,
"sambastudio": _import_sambastudio,
"self_hosted": _import_self_hosted,
"self_hosted_hugging_face": _import_self_hosted_hugging_face,
"stochasticai": _import_stochasticai,
Expand Down
40 changes: 24 additions & 16 deletions libs/community/langchain_community/llms/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,10 @@ class SambaStudio(LLM):
from langchain_community.llms.sambanova import Sambaverse
SambaStudio(
base_url="your SambaStudio environment URL",
project_id=set with your SambaStudio project ID.,
endpoint_id=set with your SambaStudio endpoint ID.,
api_token= set with your SambaStudio endpoint API key.,
sambastudio_base_url="your SambaStudio environment URL",
sambastudio_project_id=set with your SambaStudio project ID.,
sambastudio_endpoint_id=set with your SambaStudio endpoint ID.,
sambastudio_api_key= set with your SambaStudio endpoint API key.,
streaming=false
model_kwargs={
"do_sample": False,
Expand All @@ -634,16 +634,16 @@ class SambaStudio(LLM):
)
"""

base_url: str = ""
sambastudio_base_url: str = ""
"""Base url to use"""

project_id: str = ""
sambastudio_project_id: str = ""
"""Project id on sambastudio for model"""

endpoint_id: str = ""
sambastudio_endpoint_id: str = ""
"""endpoint id on sambastudio for model"""

api_key: str = ""
sambastudio_api_key: str = ""
"""sambastudio api key"""

model_kwargs: Optional[dict] = None
Expand Down Expand Up @@ -674,16 +674,16 @@ def _llm_type(self) -> str:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["base_url"] = get_from_dict_or_env(
values["sambastudio_base_url"] = get_from_dict_or_env(
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
)
values["project_id"] = get_from_dict_or_env(
values["sambastudio_project_id"] = get_from_dict_or_env(
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
)
values["endpoint_id"] = get_from_dict_or_env(
values["sambastudio_endpoint_id"] = get_from_dict_or_env(
values, "sambastudio_endpoint_id", "SAMBASTUDIO_ENDPOINT_ID"
)
values["api_key"] = get_from_dict_or_env(
values["sambastudio_api_key"] = get_from_dict_or_env(
values, "sambastudio_api_key", "SAMBASTUDIO_API_KEY"
)
return values
Expand Down Expand Up @@ -729,7 +729,11 @@ def _handle_nlp_predict(
ValueError: If the prediction fails.
"""
response = sdk.nlp_predict(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
)
if response["status_code"] != 200:
optional_detail = response["detail"]
Expand All @@ -755,7 +759,7 @@ def _handle_completion_requests(
Raises:
ValueError: If the prediction fails.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)

Expand All @@ -774,7 +778,11 @@ def _handle_nlp_predict_stream(
An iterator of GenerationChunks.
"""
for chunk in sdk.nlp_predict_stream(
self.project_id, self.endpoint_id, self.api_key, prompt, tuning_params
self.sambastudio_project_id,
self.sambastudio_endpoint_id,
self.sambastudio_api_key,
prompt,
tuning_params,
):
yield chunk

Expand All @@ -794,7 +802,7 @@ def _stream(
Returns:
The string generated by the model.
"""
ss_endpoint = SSEndpointHandler(self.base_url)
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
tuning_params = self._get_tuning_params(stop)
try:
if self.streaming:
Expand Down
2 changes: 2 additions & 0 deletions libs/community/tests/unit_tests/llms/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
"RWKV",
"Replicate",
"SagemakerEndpoint",
"Sambaverse",
"SambaStudio",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"StochasticAI",
Expand Down

0 comments on commit df1c102

Please sign in to comment.