feat: Vertex Check Grounding API integration #186
Merged
17 commits
04e928c feat: Vertex Check Grounding API integration (Abhishekbhagwat)
4b693d9 add more descriptive docstrings (Abhishekbhagwat)
165697d remove hanging stdout (Abhishekbhagwat)
6dfa19e Merge branch 'main' into check_grounding (Abhishekbhagwat)
9fea0b8 fix failing spellcheck test (Abhishekbhagwat)
ebfe02a Merge branch 'main' into check_grounding (Abhishekbhagwat)
db78786 fix import dependencies (Abhishekbhagwat)
6c38321 change type to a Runnable, fix empty metadata ignore (Abhishekbhagwat)
39cbb17 Merge branch 'main' into check_grounding (Abhishekbhagwat)
ba3d6f2 standardize naming convention (Abhishekbhagwat)
fb7f528 add ignore flag for lint (Abhishekbhagwat)
dc16ff9 fix failing lint w/ import (Abhishekbhagwat)
ca681da Merge branch 'main' into check_grounding (Abhishekbhagwat)
fb57dd6 Merge branch 'main' into check_grounding (Abhishekbhagwat)
110bb7b Merge branch 'main' into check_grounding (Abhishekbhagwat)
ebb2204 marking integration test as extended (Abhishekbhagwat)
460ae1d ignore imports lint (Abhishekbhagwat)
241 changes: 241 additions & 0 deletions
libs/community/langchain_google_community/vertex_check_grounding.py
```python
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from google.api_core import exceptions as core_exceptions  # type: ignore
from google.auth.credentials import Credentials  # type: ignore
from google.oauth2 import service_account  # type: ignore
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
    from google.cloud import discoveryengine_v1alpha


class VertexCheckGroundingWrapper(
    RunnableSerializable[str, "VertexCheckGroundingWrapper.CheckGroundingResponse"]
):
    """
    Initializes the Vertex AI Check Grounding wrapper with configurable parameters.

    Calls the Check Grounding API to validate the response against a given set of
    documents and returns citations that support the claims along with the cited
    chunks. Output is of the type CheckGroundingResponse.

    Attributes:
        project_id (str): Google Cloud project ID
        location_id (str): Location ID for the grounding service.
        grounding_config (str):
            Required. The resource name of the grounding config, such as
            ``default_grounding_config``.
            It is set to ``default_grounding_config`` by default if unspecified
        citation_threshold (float):
            The threshold (in [0,1]) used for determining whether a fact
            must be cited for a claim in the answer candidate. Choosing
            a higher threshold will lead to fewer but very strong
            citations, while choosing a lower threshold may lead to more
            but somewhat weaker citations. If unset, the threshold will
            default to 0.6.
        credentials (Optional[Credentials]): Google Cloud credentials object.
        credentials_path (Optional[str]): Path to the Google Cloud service
            account credentials file.
    """

    project_id: str = Field(default=None)
    location_id: str = Field(default="global")
    grounding_config: str = Field(default="default_grounding_config")
    citation_threshold: Optional[float] = Field(default=0.6)
    client: Any
    credentials: Optional[Credentials] = Field(default=None)
    credentials_path: Optional[str] = Field(default=None)

    class CheckGroundingResponse(BaseModel):
        support_score: float = 0.0
        cited_chunks: List[Dict[str, Any]] = []
        claims: List[Dict[str, Any]] = []
        answer_with_citations: str = ""

    def __init__(self, **kwargs: Any):
        """
        Constructor for VertexCheckGroundingWrapper.
        Initializes the grounding check service client with necessary credentials
        and configurations.
        """
        super().__init__(**kwargs)
        self.client = kwargs.get("client")
        if not self.client:
            self.client = self._get_check_grounding_service_client()

    def _get_check_grounding_service_client(
        self,
    ) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient":
        """
        Returns a GroundedGenerationServiceClient instance using provided credentials.
        Raises ImportError if necessary packages are not installed.

        Returns:
            A GroundedGenerationServiceClient instance.
        """
        try:
            from google.cloud import discoveryengine_v1alpha
        except ImportError as exc:
            raise ImportError(
                "Could not import google-cloud-discoveryengine python package. "
                "Please install vertexaisearch dependency group: "
                "`pip install langchain-google-community[vertexaisearch]`"
            ) from exc
        # Explicit credentials take precedence; otherwise load them from the
        # service account file if a path was given. With neither set, None lets
        # the client fall back to Application Default Credentials.
        credentials = self.credentials
        if credentials is None and self.credentials_path:
            # from_service_account_file lives on the service-account class,
            # not on the abstract google.auth.credentials.Credentials base.
            credentials = service_account.Credentials.from_service_account_file(
                self.credentials_path
            )
        return discoveryengine_v1alpha.GroundedGenerationServiceClient(
            credentials=credentials
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None
    ) -> CheckGroundingResponse:
        """
        Calls the Vertex Check Grounding API for a given answer candidate and a list
        of documents (facts) to validate whether the facts support the answer
        candidate.

        Args:
            input (str): The candidate answer to be evaluated for grounding.
            config (Optional[RunnableConfig]): Must carry the documents against
                which grounding is checked under
                ``config["configurable"]["documents"]`` (List[Document]).
                These are converted to facts:
                facts (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
                    GroundingFact]):
                    List of facts for the grounding check.
                    We support up to 200 facts.
        Returns:
            Response of the type CheckGroundingResponse

            Attributes:
            support_score (float):
                The support score for the input answer
                candidate. The higher the score, the higher the
                fraction of claims that are supported by the
                provided facts. This is always set when a
                response is returned.

            cited_chunks (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
                FactChunk]):
                List of facts cited across all claims in the
                answer candidate. These are derived from the
                facts supplied in the request.

            claims (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
                CheckGroundingResponse.Claim]):
                Claim texts and citation info across all
                claims in the answer candidate.

            answer_with_citations (str):
                Complete formed answer formatted with inline citations
        """
        from google.cloud import discoveryengine_v1alpha

        answer_candidate = input
        # Keep only documents with non-empty page_content, so that the fact
        # indices returned in `chunk.source` line up with this list.
        documents = [
            doc for doc in self.extract_documents(config) if doc.page_content
        ]

        grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec(
            citation_threshold=self.citation_threshold,
        )

        facts = [
            discoveryengine_v1alpha.GroundingFact(
                fact_text=doc.page_content,
                attributes={
                    key: value
                    for key, value in (
                        doc.metadata or {}
                    ).items()  # Use an empty dict if metadata is None
                    if key not in ["id", "relevance_score"] and value is not None
                },
            )
            for doc in documents
        ]

        if not facts:
            raise ValueError("No valid documents provided for grounding.")

        request = discoveryengine_v1alpha.CheckGroundingRequest(
            grounding_config=(
                f"projects/{self.project_id}/locations/{self.location_id}"
                f"/groundingConfigs/{self.grounding_config}"
            ),
            answer_candidate=answer_candidate,
            facts=facts,
            grounding_spec=grounding_spec,
        )

        if self.client is None:
            raise ValueError("Client not initialized.")
        try:
            response = self.client.check_grounding(request=request)
        except core_exceptions.GoogleAPICallError as e:
            raise RuntimeError(
                f"Error in Vertex AI Check Grounding API call: {str(e)}"
            ) from e

        support_score = response.support_score
        cited_chunks = [
            {
                "chunk_text": chunk.chunk_text,
                "source": documents[int(chunk.source)],
            }
            for chunk in response.cited_chunks
        ]
        claims = [
            {
                "start_pos": claim.start_pos,
                "end_pos": claim.end_pos,
                "claim_text": claim.claim_text,
                "citation_indices": list(claim.citation_indices),
            }
            for claim in response.claims
        ]

        answer_with_citations = self.combine_claims_with_citations(claims)
        return self.CheckGroundingResponse(
            support_score=support_score,
            cited_chunks=cited_chunks,
            claims=claims,
            answer_with_citations=answer_with_citations,
        )

    def extract_documents(self, config: Optional[RunnableConfig]) -> List[Document]:
        if not config:
            raise ValueError("Configuration is required.")

        potential_documents = config.get("configurable", {}).get("documents", [])
        if not isinstance(potential_documents, list) or not all(
            isinstance(doc, Document) for doc in potential_documents
        ):
            raise ValueError("Invalid documents. Each must be an instance of Document.")

        if not potential_documents:
            raise ValueError("This wrapper requires documents for processing.")

        return potential_documents

    def combine_claims_with_citations(self, claims: List[Dict[str, Any]]) -> str:
        sorted_claims = sorted(claims, key=lambda x: x["start_pos"])
        result = []
        for claim in sorted_claims:
            if claim["citation_indices"]:
                citations = "".join([f"[{idx}]" for idx in claim["citation_indices"]])
                claim_text = f"{claim['claim_text']}{citations}"
            else:
                claim_text = claim["claim_text"]
            result.append(claim_text)
        return " ".join(result).strip()

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        return ["langchain", "utilities", "check_grounding"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    class Config:
        extra = Extra.ignore
        arbitrary_types_allowed = True
```
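When falling back from an explicit credentials object to a credentials file, as `_get_check_grounding_service_client` does, it helps to remember that Python's conditional expression binds more loosely than `or`: `a or b if c else None` is evaluated as `(a or b) if c else None`. The sketch below uses plain strings as stand-ins for credential objects (the helper names are hypothetical and not part of this PR) to show how the grouping decides whether explicitly passed credentials survive when no `credentials_path` is set.

```python
from typing import Optional


def load_from_file(path: str) -> str:
    """Hypothetical stand-in for loading service-account credentials from a file."""
    return f"creds-from:{path}"


def resolve_ungrouped(credentials: Optional[str], path: Optional[str]) -> Optional[str]:
    # Parsed as: (credentials or load_from_file(path)) if path else None
    return credentials or load_from_file(path) if path else None


def resolve_grouped(credentials: Optional[str], path: Optional[str]) -> Optional[str]:
    # Explicit credentials win; the file is consulted only when a path is given.
    return credentials or (load_from_file(path) if path else None)


print(resolve_ungrouped("explicit", None))  # None: explicit credentials are dropped
print(resolve_grouped("explicit", None))    # explicit
print(resolve_grouped(None, "sa.json"))     # creds-from:sa.json
print(resolve_grouped(None, None))          # None
```

With the grouped form, passing only `credentials` behaves as a caller would expect, and passing neither still yields `None`, which lets the client fall back to its default credential lookup.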
116 changes: 116 additions & 0 deletions
libs/community/tests/integration_tests/test_check_grounding.py
```python
import os
from typing import List

import pytest
from google.cloud import discoveryengine_v1alpha
from langchain_core.documents import Document

from langchain_google_community.vertex_check_grounding import (
    VertexCheckGroundingWrapper,
)


@pytest.fixture
def input_documents() -> List[Document]:
    return [
        Document(
            page_content=(
                "Born in the German Empire, Einstein moved to Switzerland in 1895, "
                "forsaking his German citizenship (as a subject of the Kingdom of "
                "Württemberg)[note 1] the following year. In 1897, at the age of "
                "seventeen, he enrolled in the mathematics and physics teaching "
                "diploma program at the Swiss federal polytechnic school in Zürich, "
                "graduating in 1900. In 1901, he acquired Swiss citizenship, which "
                "he kept for the rest of his life. In 1903, he secured a permanent "
                "position at the Swiss Patent Office in Bern. In 1905, he submitted "
                "a successful PhD dissertation to the University of Zurich. In 1914, "
                "he moved to Berlin in order to join the Prussian Academy of Sciences "
                "and the Humboldt University of Berlin. In 1917, he became director "
                "of the Kaiser Wilhelm Institute for Physics; he also became a German "
                "citizen again, this time as a subject of the Kingdom of Prussia."
                "\nIn 1933, while he was visiting the United States, Adolf Hitler came "
                'to power in Germany. Horrified by the Nazi "war of extermination" '
                "against his fellow Jews,[12] Einstein decided to remain in the US, "
                "and was granted American citizenship in 1940.[13] On the eve of World "
                "War II, he endorsed a letter to President Franklin D. Roosevelt "
                "alerting him to the potential German nuclear weapons program and "
                "recommending that the US begin similar research. Einstein supported "
                "the Allies but generally viewed the idea of nuclear weapons with "
                "great dismay.[14]"
            ),
            metadata={
                "language": "en",
                "source": "https://en.wikipedia.org/wiki/Albert_Einstein",
                "title": "Albert Einstein - Wikipedia",
            },
        ),
        Document(
            page_content=(
                "Life and career\n"
                "Childhood, youth and education\n"
                "See also: Einstein family\n"
                "Einstein in 1882, age\xa03\n"
                "Albert Einstein was born in Ulm,[19] in the Kingdom of Württemberg "
                "in the German Empire, on 14 March 1879.[20][21] His parents, secular "
                "Ashkenazi Jews, were Hermann Einstein, a salesman and engineer, and "
                "Pauline Koch. In 1880, the family moved to Munich's borough of "
                "Ludwigsvorstadt-Isarvorstadt, where Einstein's father and his uncle "
                "Jakob founded Elektrotechnische Fabrik J. Einstein & Cie, a company "
                "that manufactured electrical equipment based on direct current.[19]\n"
                "Albert attended a Catholic elementary school in Munich from the age "
                "of five. When he was eight, he was transferred to the Luitpold "
                "Gymnasium, where he received advanced primary and then secondary "
                "school education.[22]"
            ),
            metadata={
                "language": "en",
                "source": "https://en.wikipedia.org/wiki/Albert_Einstein",
                "title": "Albert Einstein - Wikipedia",
            },
        ),
    ]


@pytest.fixture
def grounded_generation_service_client() -> (
    discoveryengine_v1alpha.GroundedGenerationServiceClient
):
    return discoveryengine_v1alpha.GroundedGenerationServiceClient()


@pytest.fixture
def output_parser(
    grounded_generation_service_client: (
        discoveryengine_v1alpha.GroundedGenerationServiceClient
    ),
) -> VertexCheckGroundingWrapper:
    return VertexCheckGroundingWrapper(
        project_id=os.environ["PROJECT_ID"],
        location_id=os.environ.get("REGION", "global"),
        grounding_config=os.environ.get("GROUNDING_CONFIG", "default_grounding_config"),
        client=grounded_generation_service_client,
    )


def test_integration_parse(
    output_parser: VertexCheckGroundingWrapper,
    input_documents: List[Document],
) -> None:
    answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire"
    response = output_parser.with_config(
        configurable={"documents": input_documents}
    ).invoke(answer_candidate)

    assert isinstance(response, VertexCheckGroundingWrapper.CheckGroundingResponse)
    assert response.support_score >= 0 and response.support_score <= 1
    assert len(response.cited_chunks) > 0
    for chunk in response.cited_chunks:
        assert isinstance(chunk["chunk_text"], str)
        assert isinstance(chunk["source"], Document)
    assert len(response.claims) > 0
    for claim in response.claims:
        assert isinstance(claim["start_pos"], int)
        assert isinstance(claim["end_pos"], int)
        assert isinstance(claim["claim_text"], str)
        assert isinstance(claim["citation_indices"], list)
    assert isinstance(response.answer_with_citations, str)
```
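The test above only checks that `answer_with_citations` is a string. As a rough illustration of what that string looks like, the sketch below reimplements the joining logic of `combine_claims_with_citations` as a standalone function and runs it on hand-written claims. The sample claims, positions, and citation indices are invented for illustration and are not real API output.

```python
from typing import Any, Dict, List


def combine_claims_with_citations(claims: List[Dict[str, Any]]) -> str:
    """Join claim texts in answer order, appending [i] markers for cited chunks."""
    parts = []
    for claim in sorted(claims, key=lambda c: c["start_pos"]):
        markers = "".join(f"[{idx}]" for idx in claim["citation_indices"])
        parts.append(f"{claim['claim_text']}{markers}")
    return " ".join(parts).strip()


# Deliberately out of order, to show that claims are sorted by start_pos.
sample_claims = [
    {
        "start_pos": 26,
        "end_pos": 46,
        "claim_text": "He was born in 1879.",
        "citation_indices": [],
    },
    {
        "start_pos": 0,
        "end_pos": 25,
        "claim_text": "Einstein was born in Ulm.",
        "citation_indices": [0, 1],
    },
]

print(combine_claims_with_citations(sample_claims))
# Einstein was born in Ulm.[0][1] He was born in 1879.
```

Each bracketed number is an index into `cited_chunks`; a claim without supporting citations is emitted unchanged.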
should any additional APIs be enabled on the project where tests are running?
the only API this integration requires is google-cloud-discoveryengine, which should automatically be installed from langchain-google-community[vertexaisearch]
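A quick way to confirm that the client library is importable before running the integration test is sketched below. The helper name `is_importable` is hypothetical; it relies only on the standard-library `importlib`, and the module path is the one the wrapper in this PR imports.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if `module_name` can be located (parent packages are imported)."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (e.g. `google.cloud`) is missing.
        return False


if not is_importable("google.cloud.discoveryengine_v1alpha"):
    print("Install with: pip install langchain-google-community[vertexaisearch]")
```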
I mean, should any GCP APIs be enabled / whitelisted for the GCP project where this test is running at?
Yes, we would need to whitelist Discovery Engine API if it has not been whitelisted for this to run.
we need to mark it as `extended` then
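For reference, marking a test this way is usually done with a pytest marker. The sketch below assumes the repository defines a custom `extended` marker, as the commit message "marking integration test as extended" suggests; the test body is a placeholder, not the PR's actual test.

```python
import pytest


# Assumed custom marker; registering it in the pytest configuration avoids
# unknown-marker warnings.
@pytest.mark.extended
def test_integration_parse_placeholder() -> None:
    # Placeholder body; the real test calls the Check Grounding API.
    assert True
```

Tests marked like this can then be selected with `pytest -m extended` or left out with `pytest -m "not extended"`.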