Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added features for May 2024 Embeddings Models #205

Merged
merged 19 commits into from
Jun 8, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 31 additions & 7 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GoogleEmbeddingModelType(str, Enum):
def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]:
if value.lower().startswith("text"):
return GoogleEmbeddingModelType.TEXT
elif "multimodalembedding" in value.lower():
if "multimodalembedding" in value.lower():
return GoogleEmbeddingModelType.MULTIMODAL
return None

Expand Down Expand Up @@ -109,6 +109,9 @@ def __init__(
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")
self.instance["dimensionality_supported"] = self.client._endpoint_name.endswith(
holtskinner marked this conversation as resolved.
Show resolved Hide resolved
"preview-0409"
)

retry_errors: List[Type[BaseException]] = [
ResourceExhausted,
Expand Down Expand Up @@ -188,25 +191,29 @@ def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
return batches

def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
self,
texts: List[str],
embeddings_type: Optional[str] = None,
dimensions: Optional[int] = None,
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""
with telemetry.tool_context_manager(self._user_agent):
if self.model_type == GoogleEmbeddingModelType.MULTIMODAL:
return self._get_multimodal_embeddings_with_retry(texts)
return self._get_multimodal_embeddings_with_retry(texts, dimensions)
return self._get_text_embeddings_with_retry(
texts, embeddings_type=embeddings_type
texts, embeddings_type=embeddings_type, dimensions=dimensions
)

def _get_multimodal_embeddings_with_retry(
self, texts: List[str]
self, texts: List[str], dimensions: Optional[int] = None
) -> List[List[float]]:
tasks = []
for text in texts:
tasks.append(
self.instance["task_executor"].submit(
self.instance["get_embeddings_with_retry"],
contextual_text=text,
dimensions=dimensions,
)
)
if len(tasks) > 0:
Expand All @@ -215,7 +222,10 @@ def _get_multimodal_embeddings_with_retry(
return embeddings

def _get_text_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
self,
texts: List[str],
embeddings_type: Optional[str] = None,
dimensions: Optional[int] = None,
holtskinner marked this conversation as resolved.
Show resolved Hide resolved
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""

Expand All @@ -225,7 +235,12 @@ def _get_text_embeddings_with_retry(
]
else:
requests = texts
embeddings = self.instance["get_embeddings_with_retry"](requests)

kwargs = {}
if dimensions and self.instance["dimensionality_supported"]:
kwargs["output_dimensionality"] = dimensions

embeddings = self.instance["get_embeddings_with_retry"](requests, **kwargs)
return [embedding.values for embedding in embeddings]

def _prepare_and_validate_batches(
Expand Down Expand Up @@ -309,8 +324,11 @@ def embed(
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
dimensions: Optional[int] = None,
) -> List[List[float]]:
"""Embed a list of strings.

Expand All @@ -329,6 +347,11 @@ def embed(
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
The following are only supported on preview models:
QUESTION_ANSWERING
FACT_VERIFICATION
dimensions: [int] optional. Output embeddings dimensions.
Only supported on preview models.

Returns:
List of embeddings, one for each text.
Expand All @@ -355,6 +378,7 @@ def embed(
self._get_embeddings_with_retry,
texts=batch,
embeddings_type=embeddings_task_type,
dimensions=dimensions,
)
)
if len(tasks) > 0:
Expand Down