Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Vertex Check Grounding API integration #186

Merged
merged 17 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions libs/community/langchain_google_community/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
VertexAISearchRetriever,
VertexAISearchSummaryTool,
)
from langchain_google_community.vertex_check_grounding import (
VertexCheckGroundingWrapper,
)
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_community.vision import CloudVisionLoader, CloudVisionParser

Expand Down Expand Up @@ -52,4 +55,5 @@
"VertexAISearchRetriever",
"VertexAISearchSummaryTool",
"VertexAIRank",
"VertexCheckGroundingWrapper",
]
241 changes: 241 additions & 0 deletions libs/community/langchain_google_community/vertex_check_grounding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from google.api_core import exceptions as core_exceptions # type: ignore
from google.auth.credentials import Credentials # type: ignore
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha


class VertexCheckGroundingWrapper(
    RunnableSerializable[str, "VertexCheckGroundingWrapper.CheckGroundingResponse"]
):
    """
    Runnable wrapper around the Vertex AI Check Grounding API.

    Calls the Check Grounding API to validate an answer candidate against a given
    set of documents and returns the citations that support its claims along with
    the cited chunks. Output is of the type CheckGroundingResponse.

    The documents are not part of the runnable input; they are read from
    ``config["configurable"]["documents"]`` (see ``extract_documents``).

    Attributes:
        project_id (str): Google Cloud project ID
        location_id (str): Location ID for the grounding service.
        grounding_config (str):
            Required. The resource name of the grounding config, such as
            ``default_grounding_config``.
            It is set to ``default_grounding_config`` by default if unspecified
        citation_threshold (float):
            The threshold (in [0,1]) used for determining whether a fact
            must be cited for a claim in the answer candidate. Choosing
            a higher threshold will lead to fewer but very strong
            citations, while choosing a lower threshold may lead to more
            but somewhat weaker citations. If unset, the threshold will
            default to 0.6.
        client (Any): Pre-built GroundedGenerationServiceClient. When omitted,
            one is created from ``credentials`` / ``credentials_path``.
        credentials (Optional[Credentials]): Google Cloud credentials object.
            Takes precedence over ``credentials_path`` when both are given.
        credentials_path (Optional[str]): Path to the Google Cloud service
            account credentials file.
    """

    project_id: str = Field(default=None)
    location_id: str = Field(default="global")
    grounding_config: str = Field(default="default_grounding_config")
    citation_threshold: Optional[float] = Field(default=0.6)
    client: Any
    credentials: Optional[Credentials] = Field(default=None)
    credentials_path: Optional[str] = Field(default=None)

    class CheckGroundingResponse(BaseModel):
        """Plain-data view of a Check Grounding API response."""

        support_score: float = 0.0
        cited_chunks: List[Dict[str, Any]] = []
        claims: List[Dict[str, Any]] = []
        answer_with_citations: str = ""

    def __init__(self, **kwargs: Any):
        """
        Constructor for VertexCheckGroundingWrapper.

        Populates the declared fields from ``kwargs`` and, when no ``client``
        was supplied, builds a grounding service client from the configured
        credentials.
        """
        super().__init__(**kwargs)
        # The ``client`` field is already populated by the base constructor;
        # only build one when the caller did not provide it.
        if self.client is None:
            self.client = self._get_check_grounding_service_client()

    def _get_check_grounding_service_client(
        self,
    ) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient":
        """
        Returns a GroundedGenerationServiceClient instance using provided credentials.

        Credential resolution order:
            1. ``self.credentials`` when set,
            2. a service-account file at ``self.credentials_path`` when set,
            3. ``None``, which lets the client fall back to its default lookup.

        Returns:
            A GroundedGenerationServiceClient instance.

        Raises:
            ImportError: If google-cloud-discoveryengine is not installed.
        """
        try:
            from google.cloud import discoveryengine_v1alpha
        except ImportError as exc:
            raise ImportError(
                "Could not import google-cloud-discoveryengine python package. "
                "Please install vertexaisearch dependency group: "
                "`pip install langchain-google-community[vertexaisearch]`"
            ) from exc

        # Resolve credentials explicitly. The previous one-line conditional
        # parsed as ``(credentials or from_file) if credentials_path else None``
        # and therefore discarded explicit credentials when no path was set.
        credentials = self.credentials
        if credentials is None and self.credentials_path:
            # ``from_service_account_file`` is defined on the service-account
            # credentials class, not on the abstract base ``Credentials``.
            from google.oauth2 import service_account

            credentials = service_account.Credentials.from_service_account_file(
                self.credentials_path
            )

        return discoveryengine_v1alpha.GroundedGenerationServiceClient(
            credentials=credentials
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> CheckGroundingResponse:
        """
        Calls the Vertex Check Grounding API for a given answer candidate and a list
        of documents (facts) to validate whether the facts support the answer
        candidate.

        Args:
            input (str): The candidate answer to be evaluated for grounding.
            config (Optional[RunnableConfig]): Must carry the documents to ground
                against under ``config["configurable"]["documents"]`` as a
                non-empty ``List[Document]``. Each document with non-empty
                ``page_content`` is converted to a GroundingFact; its metadata
                (minus ``id``, ``relevance_score`` and ``None`` values) becomes
                the fact attributes.
        Returns:
            Response of the type CheckGroundingResponse

            Attributes:
                support_score (float):
                    The support score for the input answer
                    candidate. Higher the score, higher is the
                    fraction of claims that are supported by the
                    provided facts. This is always set when a
                    response is returned.

                cited_chunks (List[Dict[str, Any]]):
                    One ``{"chunk_text", "source"}`` dict per fact chunk cited
                    across all claims, where ``source`` is the Document the
                    chunk was taken from.

                claims (List[Dict[str, Any]]):
                    One ``{"start_pos", "end_pos", "claim_text",
                    "citation_indices"}`` dict per claim in the answer
                    candidate.

                answer_with_citations (str):
                    Complete formed answer formatted with inline citations

        Raises:
            ValueError: If the config or documents are missing/invalid, if no
                document has usable content, or if the client is not initialized.
            RuntimeError: If the Check Grounding API call fails.
        """
        from google.cloud import discoveryengine_v1alpha

        answer_candidate = input
        documents = self.extract_documents(config)

        grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec(
            citation_threshold=self.citation_threshold,
        )

        # Only documents with content become facts. Keep this filtered list:
        # the API reports chunk sources as indices into the submitted facts,
        # so it (not the raw ``documents`` list) must be used for lookups.
        grounded_documents = [doc for doc in documents if doc.page_content]

        facts = [
            discoveryengine_v1alpha.GroundingFact(
                fact_text=doc.page_content,
                attributes={
                    key: value
                    for key, value in (doc.metadata or {}).items()
                    if key not in ["id", "relevance_score"] and value is not None
                },
            )
            for doc in grounded_documents
        ]

        if not facts:
            raise ValueError("No valid documents provided for grounding.")

        grounding_config_name = (
            f"projects/{self.project_id}/locations/{self.location_id}"
            f"/groundingConfigs/{self.grounding_config}"
        )
        request = discoveryengine_v1alpha.CheckGroundingRequest(
            grounding_config=grounding_config_name,
            answer_candidate=answer_candidate,
            facts=facts,
            grounding_spec=grounding_spec,
        )

        if self.client is None:
            raise ValueError("Client not initialized.")
        try:
            response = self.client.check_grounding(request=request)
        except core_exceptions.GoogleAPICallError as e:
            raise RuntimeError(
                f"Error in Vertex AI Check Grounding API call: {str(e)}"
            ) from e

        support_score = response.support_score
        cited_chunks = [
            {
                "chunk_text": chunk.chunk_text,
                # ``chunk.source`` is the index of the originating fact.
                "source": grounded_documents[int(chunk.source)],
            }
            for chunk in response.cited_chunks
        ]
        claims = [
            {
                "start_pos": claim.start_pos,
                "end_pos": claim.end_pos,
                "claim_text": claim.claim_text,
                "citation_indices": list(claim.citation_indices),
            }
            for claim in response.claims
        ]

        answer_with_citations = self.combine_claims_with_citations(claims)
        return self.CheckGroundingResponse(
            support_score=support_score,
            cited_chunks=cited_chunks,
            claims=claims,
            answer_with_citations=answer_with_citations,
        )

    def extract_documents(self, config: Optional[RunnableConfig]) -> List[Document]:
        """
        Reads the grounding documents from ``config["configurable"]["documents"]``.

        Args:
            config (Optional[RunnableConfig]): The runnable config passed to invoke.

        Returns:
            The non-empty list of Document objects found in the config.

        Raises:
            ValueError: If the config is missing/empty, if the value is not a
                list of Document instances, or if the list is empty.
        """
        if not config:
            raise ValueError("Configuration is required.")

        potential_documents = config.get("configurable", {}).get("documents", [])
        if not isinstance(potential_documents, list) or not all(
            isinstance(doc, Document) for doc in potential_documents
        ):
            raise ValueError("Invalid documents. Each must be an instance of Document.")

        if not potential_documents:
            raise ValueError("This wrapper requires documents for processing.")

        return potential_documents

    def combine_claims_with_citations(self, claims: List[Dict[str, Any]]) -> str:
        """
        Joins claim texts, ordered by ``start_pos``, into a single string and
        appends ``[idx]`` markers after each claim that has citation indices.

        Args:
            claims (List[Dict[str, Any]]): Claim dicts with ``start_pos``,
                ``claim_text`` and ``citation_indices`` keys.

        Returns:
            The space-joined, stripped answer text with inline citation markers.
        """
        sorted_claims = sorted(claims, key=lambda x: x["start_pos"])
        result = []
        for claim in sorted_claims:
            if claim["citation_indices"]:
                citations = "".join([f"[{idx}]" for idx in claim["citation_indices"]])
                claim_text = f"{claim['claim_text']}{citations}"
            else:
                claim_text = claim["claim_text"]
            result.append(claim_text)
        return " ".join(result).strip()

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Returns the LangChain namespace path of this object."""
        return ["langchain", "utilities", "check_grounding"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Reports that this wrapper is not LangChain-serializable."""
        return False

    class Config:
        extra = Extra.ignore
        arbitrary_types_allowed = True
116 changes: 116 additions & 0 deletions libs/community/tests/integration_tests/test_check_grounding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
from typing import List

import pytest
from google.cloud import discoveryengine_v1alpha
from langchain_core.documents import Document

from langchain_google_community.vertex_check_grounding import (
VertexCheckGroundingWrapper,
)


@pytest.fixture
def input_documents() -> List[Document]:
    """Provide two Wikipedia excerpts on Albert Einstein as grounding documents."""
    career_excerpt = (
        "Born in the German Empire, Einstein moved to Switzerland in 1895, "
        "forsaking his German citizenship (as a subject of the Kingdom of "
        "Württemberg)[note 1] the following year. In 1897, at the age of "
        "seventeen, he enrolled in the mathematics and physics teaching "
        "diploma program at the Swiss federal polytechnic school in Zürich, "
        "graduating in 1900. In 1901, he acquired Swiss citizenship, which "
        "he kept for the rest of his life. In 1903, he secured a permanent "
        "position at the Swiss Patent Office in Bern. In 1905, he submitted "
        "a successful PhD dissertation to the University of Zurich. In 1914, "
        "he moved to Berlin in order to join the Prussian Academy of Sciences "
        "and the Humboldt University of Berlin. In 1917, he became director "
        "of the Kaiser Wilhelm Institute for Physics; he also became a German "
        "citizen again, this time as a subject of the Kingdom of Prussia."
        "\nIn 1933, while he was visiting the United States, Adolf Hitler came "
        'to power in Germany. Horrified by the Nazi "war of extermination" '
        "against his fellow Jews,[12] Einstein decided to remain in the US, "
        "and was granted American citizenship in 1940.[13] On the eve of World "
        "War II, he endorsed a letter to President Franklin D. Roosevelt "
        "alerting him to the potential German nuclear weapons program and "
        "recommending that the US begin similar research. Einstein supported "
        "the Allies but generally viewed the idea of nuclear weapons with "
        "great dismay.[14]"
    )
    childhood_excerpt = (
        "Life and career\n"
        "Childhood, youth and education\n"
        "See also: Einstein family\n"
        "Einstein in 1882, age\xa03\n"
        "Albert Einstein was born in Ulm,[19] in the Kingdom of Württemberg "
        "in the German Empire, on 14 March 1879.[20][21] His parents, secular "
        "Ashkenazi Jews, were Hermann Einstein, a salesman and engineer, and "
        "Pauline Koch. In 1880, the family moved to Munich's borough of "
        "Ludwigsvorstadt-Isarvorstadt, where Einstein's father and his uncle "
        "Jakob founded Elektrotechnische Fabrik J. Einstein & Cie, a company "
        "that manufactured electrical equipment based on direct current.[19]\n"
        "Albert attended a Catholic elementary school in Munich from the age "
        "of five. When he was eight, he was transferred to the Luitpold "
        "Gymnasium, where he received advanced primary and then secondary "
        "school education.[22]"
    )
    # Both excerpts share the same metadata values; the dict literal inside the
    # comprehension gives every Document its own independent copy.
    return [
        Document(
            page_content=excerpt,
            metadata={
                "language": "en",
                "source": "https://en.wikipedia.org/wiki/Albert_Einstein",
                "title": "Albert Einstein - Wikipedia",
            },
        )
        for excerpt in (career_excerpt, childhood_excerpt)
    ]


@pytest.fixture
def grounded_generation_service_client() -> (
    discoveryengine_v1alpha.GroundedGenerationServiceClient
):
    """Provide a GroundedGenerationServiceClient built with no arguments."""
    service_client = discoveryengine_v1alpha.GroundedGenerationServiceClient()
    return service_client


@pytest.fixture
def output_parser(
    grounded_generation_service_client: (
        discoveryengine_v1alpha.GroundedGenerationServiceClient
    ),
) -> VertexCheckGroundingWrapper:
    """Provide a wrapper configured from environment variables.

    ``PROJECT_ID`` is mandatory (a missing value raises ``KeyError``);
    ``REGION`` and ``GROUNDING_CONFIG`` fall back to ``"global"`` and
    ``"default_grounding_config"`` respectively.
    """
    project = os.environ["PROJECT_ID"]
    region = os.environ.get("REGION", "global")
    config_name = os.environ.get("GROUNDING_CONFIG", "default_grounding_config")
    return VertexCheckGroundingWrapper(
        project_id=project,
        location_id=region,
        grounding_config=config_name,
        client=grounded_generation_service_client,
    )


def test_integration_parse(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should any additional APIs be enabled on the project where tests are running?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only dependency this integration requires is the google-cloud-discoveryengine package, which is installed automatically with langchain-google-community[vertexaisearch].

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, do any GCP APIs need to be enabled / whitelisted on the GCP project where this test runs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we would need to whitelist Discovery Engine API if it has not been whitelisted for this to run.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to mark it as extended then

output_parser: VertexCheckGroundingWrapper,
input_documents: List[Document],
) -> None:
answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire"
response = output_parser.with_config(
configurable={"documents": input_documents}
).invoke(answer_candidate)

assert isinstance(response, VertexCheckGroundingWrapper.CheckGroundingResponse)
assert response.support_score >= 0 and response.support_score <= 1
assert len(response.cited_chunks) > 0
for chunk in response.cited_chunks:
assert isinstance(chunk["chunk_text"], str)
assert isinstance(chunk["source"], Document)
assert len(response.claims) > 0
for claim in response.claims:
assert isinstance(claim["start_pos"], int)
assert isinstance(claim["end_pos"], int)
assert isinstance(claim["claim_text"], str)
assert isinstance(claim["citation_indices"], list)
assert isinstance(response.answer_with_citations, str)