embedding with graph_store #12860

Open · wants to merge 6 commits into main
43 changes: 25 additions & 18 deletions llama-index-core/llama_index/core/indices/knowledge_graph/base.py
@@ -213,12 +213,7 @@ def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> KG:

if self.include_embeddings:
triplet_texts = [str(t) for t in triplets]

embed_outputs = self._embed_model.get_text_embedding_batch(
triplet_texts, show_progress=self._show_progress
)
for rel_text, rel_embed in zip(triplet_texts, embed_outputs):
index_struct.add_to_embedding_dict(rel_text, rel_embed)
self.embed(triplet_texts)

return index_struct

@@ -234,16 +229,32 @@ def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
triplet_str = str(triplet)
self.upsert_triplet(triplet)
self._index_struct.add_node([subj, obj], n)
if (
self.include_embeddings
and triplet_str not in self._index_struct.embedding_dict
):
rel_embedding = self._embed_model.get_text_embedding(triplet_str)
self._index_struct.add_to_embedding_dict(triplet_str, rel_embedding)
if self.include_embeddings:
self.embed([triplet_str])

# Update the storage context's index_store
self._storage_context.index_store.add_index_struct(self._index_struct)

def embed(
    self,
    triplet_texts: Optional[List[str]] = None,
    refresh: bool = False,
    add_to_index_struct: bool = False,
    use_graph_store: bool = False,
) -> None:
    """Generate embeddings for triplet texts and store them in the index."""
    if not self.include_embeddings:
        logger.warning(
            "Embeddings are not enabled for this index. "
            "Set include_embeddings=True to use this method."
        )
        return

    # Copy the input to avoid mutating a caller-owned list.
    triplet_texts = list(triplet_texts or [])
    if use_graph_store:
        triplet_texts.extend(
            str(triplet)
            for triplet in self.graph_store.get_all_triplets(refresh=refresh)
        )

    # Only embed triplets that are not already in the embedding dict.
    filtered_triplet_texts = [
        text
        for text in triplet_texts
        if text not in self._index_struct.embedding_dict
    ]
    embed_outputs = self._embed_model.get_text_embedding_batch(
        filtered_triplet_texts, show_progress=self._show_progress
    )
    for rel_text, rel_embed in zip(filtered_triplet_texts, embed_outputs):
        self._index_struct.add_to_embedding_dict(rel_text, rel_embed)

    # Update the storage context's index_store.
    if add_to_index_struct or use_graph_store:
        self._storage_context.index_store.add_index_struct(self._index_struct)

def upsert_triplet(
self, triplet: Tuple[str, str, str], include_embeddings: bool = False
) -> None:
@@ -259,9 +270,7 @@ def upsert_triplet(
self._graph_store.upsert_triplet(*triplet)
triplet_str = str(triplet)
if include_embeddings:
set_embedding = self._embed_model.get_text_embedding(triplet_str)
self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
self._storage_context.index_store.add_index_struct(self._index_struct)
self.embed([triplet_str], add_to_index_struct=True)

def add_node(self, keywords: List[str], node: BaseNode) -> None:
"""Add node.
@@ -299,9 +308,7 @@ def upsert_triplet_and_node(
self.add_node([subj, obj], node)
triplet_str = str(triplet)
if include_embeddings:
set_embedding = self._embed_model.get_text_embedding(triplet_str)
self._index_struct.add_to_embedding_dict(str(triplet), set_embedding)
self._storage_context.index_store.add_index_struct(self._index_struct)
self.embed([triplet_str], add_to_index_struct=True)

def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
"""Delete a node."""
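A minimal usage sketch of the new public embed() API (the data directory, FalkorDB URL, and index setup here are hypothetical and not part of this PR). The remaining files below add the same two helpers, get_all_nodes() and get_all_triplets(), to what appear to be the FalkorDB, Kuzu, NebulaGraph, and Neo4j graph stores.

# Hypothetical usage sketch of the new KnowledgeGraphIndex.embed() API.
# Assumes a running FalkorDB instance and an embedding model configured via Settings.
from llama_index.core import KnowledgeGraphIndex, SimpleDirectoryReader, StorageContext
from llama_index.graph_stores.falkordb import FalkorDBGraphStore

graph_store = FalkorDBGraphStore(url="redis://localhost:6379")  # hypothetical URL
storage_context = StorageContext.from_defaults(graph_store=graph_store)

documents = SimpleDirectoryReader("./data").load_data()  # hypothetical data dir
index = KnowledgeGraphIndex.from_documents(
    documents,
    storage_context=storage_context,
    include_embeddings=True,
)

# Backfill embeddings for every triplet already stored in the graph store and
# persist the updated index struct to the index store.
index.embed(use_graph_store=True, refresh=True)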
@@ -34,7 +34,8 @@ def __init__(
self._driver.query(f"CREATE INDEX FOR (n:`{self._node_label}`) ON (n.id)")

self._database = database

self.nodes = []
self.triplets = []
self.schema = ""
self.get_query = f"""
MATCH (n1:`{self._node_label}`)-[r]->(n2:`{self._node_label}`)
@@ -179,3 +180,36 @@ def get_schema(self, refresh: bool = False) -> str:
def query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
result = self._driver.query(query, params=params)
return result.result_set

def get_all_nodes(self, refresh: bool = False) -> List[str]:
    """
    Get all nodes in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of nodes from the graph store.

    Returns:
    - List[str]: A list of node IDs/names.
    """
    if refresh or not self.nodes:
        query = f"MATCH (n:`{self._node_label}`) RETURN n.id AS node"
        # query() returns positional rows (result_set), so take the first
        # column of each row.
        self.nodes = [row[0] for row in self.query(query)]
    return self.nodes

def get_all_triplets(self, refresh: bool = False) -> List[str]:
    """
    Get all triplets in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of triplets from the graph store.

    Returns:
    - List[str]: A list of relationship descriptions in the format (start_node, rel_type, end_node).
    """
    if refresh or not self.triplets:
        query = """
        MATCH (start)-[rel]->(end)
        RETURN start.id AS start_node, type(rel) AS rel_type, end.id AS end_node
        """
        # query() returns positional rows, so unpack the three columns directly.
        # Relationship types are stored as UPPER_SNAKE_CASE; convert them to a
        # readable form, e.g. "WORKS_AT" -> "Works at".
        self.triplets = [
            (start, rel.replace("_", " ").capitalize(), end)
            for start, rel, end in self.query(query)
        ]
    return self.triplets
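A short sketch of using the two helpers above directly, assuming this is the FalkorDB store (the driver API above returns FalkorDB-style positional result rows) and a local instance at redis://localhost:6379; the triplet values are illustrative.

# Hypothetical direct use of the new graph store helpers.
from llama_index.graph_stores.falkordb import FalkorDBGraphStore

store = FalkorDBGraphStore(url="redis://localhost:6379")
store.upsert_triplet("Alice", "WORKS_AT", "Acme")

print(store.get_all_nodes(refresh=True))     # e.g. ['Alice', 'Acme']
print(store.get_all_triplets(refresh=True))  # e.g. [('Alice', 'Works at', 'Acme')]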
@@ -18,6 +18,8 @@ def __init__(
self.connection = kuzu.Connection(database)
self.node_table_name = node_table_name
self.rel_table_name = rel_table_name
self.nodes = []
self.triplets = []
self.init_schema()

def init_schema(self) -> None:
@@ -224,3 +226,36 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "KuzuGraphStore":
Graph store.
"""
return cls(**config_dict)

def get_all_nodes(self, refresh: bool = False) -> List[str]:
    """
    Get all nodes in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of nodes from the graph store.

    Returns:
    - List[str]: A list of node IDs/names.
    """
    if refresh or not self.nodes:
        query = f"MATCH (n:{self.node_table_name}) RETURN n.ID AS node"
        self.nodes = [record["node"] for record in self.query(query)]
    return self.nodes

def get_all_triplets(self, refresh: bool = False) -> List[str]:
    """
    Get all triplets in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of triplets from the graph store.

    Returns:
    - List[str]: A list of relationship descriptions in the format (start_node, rel_type, end_node).
    """
    if refresh or not self.triplets:
        query = """
        MATCH (start)-[rel]->(end)
        RETURN start.ID AS start_node, rel.predicate AS rel_type, end.ID AS end_node
        """
        self.triplets = [
            (record["start_node"], record["rel_type"].capitalize(), record["end_node"])
            for record in self.query(query)
        ]
    return self.triplets
@@ -202,6 +202,8 @@ def __init__(
)

self._include_vid = include_vid
self.nodes = []
self.triplets = []

def init_session_pool(self) -> Any:
"""Return NebulaGraph session pool."""
@@ -665,3 +667,36 @@ def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
col_list = result.column_values(col_name)
d[col_name] = [x.cast() for x in col_list]
return d

def get_all_nodes(self, refresh: bool = False) -> List[str]:
    """
    Get all nodes in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of nodes from the graph store.

    Returns:
    - List[str]: A list of node IDs/names.
    """
    if refresh or not self.nodes:
        query = "MATCH (n) RETURN id(n) AS node"
        # query() returns a dict mapping column names to lists of values,
        # so pull the "node" column directly.
        self.nodes = self.query(query).get("node", [])
    return self.nodes

def get_all_triplets(self, refresh: bool = False) -> List[str]:
    """
    Get all triplets in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of triplets from the graph store.

    Returns:
    - List[str]: A list of relationship descriptions in the format (start_node, rel_type, end_node).
    """
    if refresh or not self.triplets:
        query = """
        MATCH (start)-[rel]->(end)
        RETURN id(start) AS start_node, type(rel) AS rel_type, id(end) AS end_node
        """
        # query() returns a dict of column name -> list of values; zip the
        # columns back into per-row triplets and convert the edge type to a
        # readable form, e.g. "works_at" -> "Works at".
        result = self.query(query)
        self.triplets = [
            (start, rel.replace("_", " ").capitalize(), end)
            for start, rel, end in zip(
                result.get("start_node", []),
                result.get("rel_type", []),
                result.get("end_node", []),
            )
        ]
    return self.triplets
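Since the NebulaGraph query() implementation shown above returns a dict mapping column names to value lists, the new helpers have to zip the columns back into rows. A standalone sketch of that reshaping, with made-up data:

# Illustrative column-oriented result, mirroring the shape returned by the
# NebulaGraph store's query() for the triplet query above (values are made up).
result = {
    "start_node": ["alice", "bob"],
    "rel_type": ["knows", "works_at"],
    "end_node": ["bob", "acme"],
}

triplets = [
    (start, rel.replace("_", " ").capitalize(), end)
    for start, rel, end in zip(
        result["start_node"], result["rel_type"], result["end_node"]
    )
]
print(triplets)  # [('alice', 'Knows', 'bob'), ('bob', 'Works at', 'acme')]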
@@ -49,6 +49,8 @@ def __init__(
self._database = database
self.schema = ""
self.structured_schema: Dict[str, Any] = {}
self.nodes = []
self.triplets = []
# Verify connection
try:
with self._driver as driver:
@@ -253,3 +255,36 @@ def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
with self._driver.session(database=self._database) as session:
result = session.run(query, param_map)
return [d.data() for d in result]

def get_all_nodes(self, refresh: bool = False) -> List[str]:
    """
    Get all nodes in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of nodes from the graph store.

    Returns:
    - List[str]: A list of node IDs/names.
    """
    if refresh or not self.nodes:
        query = f"MATCH (n:{self.node_label}) RETURN n.id AS node"
        self.nodes = [record["node"] for record in self.query(query)]
    return self.nodes

def get_all_triplets(self, refresh: bool = False) -> List[str]:
    """
    Get all triplets in the graph store.

    Parameters:
    - refresh (bool): If True, refreshes the list of triplets from the graph store.

    Returns:
    - List[str]: A list of relationship descriptions in the format (start_node, rel_type, end_node).
    """
    if refresh or not self.triplets:
        query = """
        MATCH (start)-[rel]->(end)
        RETURN start.id AS start_node, type(rel) AS rel_type, end.id AS end_node
        """
        # Relationship types are stored as UPPER_SNAKE_CASE; convert them to a
        # readable form, e.g. "WORKS_AT" -> "Works at".
        self.triplets = [
            (
                record["start_node"],
                record["rel_type"].replace("_", " ").capitalize(),
                record["end_node"],
            )
            for record in self.query(query)
        ]
    return self.triplets
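All four stores share the same caching behaviour: results are memoized on self.nodes / self.triplets and only re-queried on the first call or when refresh=True is passed. A sketch against the Neo4j store (the connection details are hypothetical):

# Hypothetical Neo4j connection; any of the four graph stores behaves the same way.
from llama_index.graph_stores.neo4j import Neo4jGraphStore

store = Neo4jGraphStore(
    username="neo4j", password="password", url="bolt://localhost:7687"
)

triplets = store.get_all_triplets()           # first call queries the database
cached = store.get_all_triplets()             # served from the cached self.triplets
fresh = store.get_all_triplets(refresh=True)  # forces a fresh query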