diff --git a/.semversioner/next-release/minor-20250410183424623609.json b/.semversioner/next-release/minor-20250410183424623609.json
new file mode 100644
index 0000000000..18cd3825f5
--- /dev/null
+++ b/.semversioner/next-release/minor-20250410183424623609.json
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add merge_entities workflow"
+}
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index 3977ed5820..d54289682e 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -399,7 +399,13 @@ class VectorStoreDefaults:
     audience: None = None
     database_name: None = None
 
-
+@dataclass
+class MergeEntitiesDefaults:
+    """Default values for the merge entities workflow."""
+
+    enabled: bool = True
+    eps: float = 0.2
+    min_samples: int = 2
 
 @dataclass
 class GraphRagConfigDefaults:
     """Default values for GraphRAG."""
@@ -439,7 +445,9 @@ class GraphRagConfigDefaults:
     vector_store: dict[str, VectorStoreDefaults] = field(
         default_factory=lambda: {DEFAULT_VECTOR_STORE_ID: VectorStoreDefaults()}
     )
+    merge_entities: MergeEntitiesDefaults = field(default_factory=MergeEntitiesDefaults)
     workflows: None = None
+    language_model_defaults = LanguageModelDefaults()
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index c4c5b780c3..73aece48a2 100644
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -33,6 +33,7 @@
 from graphrag.config.models.summarize_descriptions_config import (
     SummarizeDescriptionsConfig,
 )
+from graphrag.config.models.merge_entities_config import MergeEntitiesConfig
 from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
 from graphrag.config.models.umap_config import UmapConfig
 from graphrag.config.models.vector_store_config import VectorStoreConfig
@@ -281,7 +282,10 @@ def _validate_reporting_base_dir(self) -> None:
         description="The basic search configuration.", default=BasicSearchConfig()
     )
     """The basic search configuration."""
-
+    merge_entities: MergeEntitiesConfig = Field(
+        description="The merge entities workflow configuration.",
+        default=MergeEntitiesConfig(),
+    )
+    """The merge entities workflow configuration."""
+
     def _validate_vector_store_db_uri(self) -> None:
         """Validate the vector store configuration."""
         for store in self.vector_store.values():
diff --git a/graphrag/config/models/merge_entities_config.py b/graphrag/config/models/merge_entities_config.py
new file mode 100644
index 0000000000..722f037646
--- /dev/null
+++ b/graphrag/config/models/merge_entities_config.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pydantic import BaseModel, Field
+
+from graphrag.config.defaults import graphrag_config_defaults
+
+
+class MergeEntitiesConfig(BaseModel):
+    """The default configuration section for the merge entities workflow."""
+
+    enabled: bool = Field(
+        description="A flag indicating whether to enable the merge entities workflow.",
+        default=graphrag_config_defaults.merge_entities.enabled,
+    )
+    eps: float = Field(
+        description="eps for the DBSCAN clustering algorithm.",
+        default=graphrag_config_defaults.merge_entities.eps,
+    )
+    min_samples: int = Field(
+        description="min_samples for the DBSCAN clustering algorithm.",
+        default=graphrag_config_defaults.merge_entities.min_samples,
+    )
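As a quick sanity check for reviewers, here is a minimal sketch of how the new config section resolves, assuming the models above land as written (the override values are illustrative, not defaults):

```python
from graphrag.config.models.merge_entities_config import MergeEntitiesConfig

# Defaults flow in from MergeEntitiesDefaults: enabled=True, eps=0.2, min_samples=2.
cfg = MergeEntitiesConfig()
assert cfg.enabled and cfg.eps == 0.2 and cfg.min_samples == 2

# Lowering eps requires higher cosine similarity (>= 1 - eps) before two titles
# can cluster, so fewer, more conservative merge candidates reach the LLM.
strict = MergeEntitiesConfig(eps=0.1, min_samples=3)  # illustrative override
```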
diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py
index 425639be0b..16c67b99d0 100644
--- a/graphrag/index/workflows/__init__.py
+++ b/graphrag/index/workflows/__init__.py
@@ -42,7 +42,9 @@
 from .prune_graph import (
     run_workflow as run_prune_graph,
 )
-
+from .merge_entities import (
+    run_workflow as run_merge_entities,
+)
 # register all of our built-in workflows at once
 PipelineFactory.register_all({
     "create_base_text_units": run_create_base_text_units,
@@ -57,4 +59,5 @@
     "finalize_graph": run_finalize_graph,
     "generate_text_embeddings": run_generate_text_embeddings,
     "prune_graph": run_prune_graph,
+    "merge_entities": run_merge_entities,
 })
diff --git a/graphrag/index/workflows/factory.py b/graphrag/index/workflows/factory.py
index b68ccf55e0..c64c48c4eb 100644
--- a/graphrag/index/workflows/factory.py
+++ b/graphrag/index/workflows/factory.py
@@ -48,6 +48,7 @@ def _get_workflows_list(
         "create_base_text_units",
         "create_final_documents",
         "extract_graph",
+        *(["merge_entities"] if config.merge_entities.enabled else []),
         "finalize_graph",
         *(["extract_covariates"] if config.extract_claims.enabled else []),
         "create_communities",
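For orientation, the splice in `_get_workflows_list` places the new step between extraction and finalization, so every downstream workflow sees only the deduplicated graph. With the flag enabled, the default order resolves to roughly the following (names taken from the hunk above):

```python
# Sketch of the default workflow order with merge_entities.enabled = True.
workflows = [
    "create_base_text_units",
    "create_final_documents",
    "extract_graph",
    "merge_entities",  # dedupes entities before the graph is finalized
    "finalize_graph",
    # ... "extract_covariates" (if enabled), "create_communities", and the rest
]
```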
diff --git a/graphrag/index/workflows/merge_entities.py b/graphrag/index/workflows/merge_entities.py
new file mode 100644
index 0000000000..0cc2461c50
--- /dev/null
+++ b/graphrag/index/workflows/merge_entities.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import itertools
+import logging
+from json import dump
+from uuid import uuid4
+
+import numpy as np
+import pandas as pd
+from json_repair import loads
+from sklearn.cluster import DBSCAN
+
+from graphrag.config.embeddings import (
+    entity_title_embedding,
+    get_embedding_settings,
+)
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings
+from graphrag.language_model.manager import ModelManager
+from graphrag.prompts.index.merge_entities import (
+    MERGE_ENTITIES_INPUT,
+    MERGE_ENTITIES_SYSTEM,
+)
+from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
+
+log = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Merge near-duplicate entities and rewire their relationships."""
+    llm_config = config.models["default_chat_model"]
+    llm = ModelManager().get_or_create_chat_model(
+        name="merge_entities",
+        model_type=llm_config.type,
+        config=llm_config,
+        callbacks=context.callbacks,
+        cache=context.cache,
+    )
+    # Load entities and relationships
+    entities = await load_table_from_storage("entities", context.storage)
+    relationships = await load_table_from_storage("relationships", context.storage)
+
+    if "human_readable_id" not in entities.columns:
+        entities["human_readable_id"] = entities.index
+    if "id" not in entities.columns:
+        entities["id"] = entities["human_readable_id"].apply(lambda _x: str(uuid4()))
+
+    # Embed entity titles
+    embeddings_df = await create_entity_title_embedding(entities, config, context)
+    embeddings_df = embeddings_df["entity.title"]
+
+    # Calculate pairwise cosine similarity of the title embeddings
+    embeddings_numpy = np.stack(embeddings_df["embedding"].values, axis=0)
+    similarity_matrix = cosine_similarity_matrix(embeddings_numpy)
+
+    # Cluster candidate duplicates with DBSCAN
+    embeddings_df["cluster"] = get_dbscan_cluster_labels(
+        similarity_matrix,
+        eps=config.merge_entities.eps,
+        min_samples=config.merge_entities.min_samples,
+    )
+
+    # Ask the LLM which clustered entities are true duplicates
+    llm_json = find_duplicate_entities(llm, embeddings_df, entities)
+
+    # Persist the raw merge decisions for inspection
+    with open(f"{config.output.base_dir}/merged_entities.json", "w") as f:
+        dump({"length": len(llm_json), "llm_json": llm_json}, f)
+
+    # Apply the merges
+    entities, relationships = update_entities_relationships(
+        entities, relationships, llm_json
+    )
+
+    # Save
+    await write_table_to_storage(relationships, "relationships", context.storage)
+    await write_table_to_storage(entities, "entities", context.storage)
+
+    return WorkflowFunctionOutput(
+        result={
+            "entities": entities,
+            "relationships": relationships,
+        }
+    )
+
+
+def find_duplicate_entities(llm, embeddings: pd.DataFrame, entities: pd.DataFrame):
+    """Prompt the LLM with each cluster and parse its merge decisions."""
+    prompt_input = get_input_for_prompt(embeddings=embeddings, entities=entities)
+    prompt = MERGE_ENTITIES_SYSTEM + MERGE_ENTITIES_INPUT.format(input=prompt_input)
+    response = llm.chat(prompt).output.content
+    return loads(response)
+
+
+async def create_entity_title_embedding(
+    entities: pd.DataFrame, config: GraphRagConfig, context: PipelineRunContext
+):
+    """Embed entity titles using the existing text-embedding workflow."""
+    embedded_fields = {entity_title_embedding}
+    # Deep-copy so overriding the embedding target does not mutate the caller's config.
+    config_copy = config.model_copy(deep=True)
+    config_copy.embed_text.target = "selected"
+    text_embed = get_embedding_settings(config_copy)
+
+    return await generate_text_embeddings(
+        documents=None,
+        relationships=None,
+        text_units=None,
+        community_reports=None,
+        entities=entities,
+        callbacks=context.callbacks,
+        cache=context.cache,
+        text_embed_config=text_embed,
+        embedded_fields=embedded_fields,
+    )
+
+
+def update_entities_relationships(
+    entities: pd.DataFrame, relationships: pd.DataFrame, response
+):
+    """Collapse each merge group into a single entity and retarget its relationships."""
+    entities.index = entities.index.astype(int)
+    all_ids = []
+    new_entities_list = []
+
+    # entities columns: ['title', 'type', 'text_unit_ids', 'frequency', 'description']
+    # relationships columns: ['source', 'target', 'text_unit_ids', 'weight', 'description']
+    for item in response:
+        # Each item is one LLM merge decision, for example:
+        # {
+        #     "ids": [4, 13],
+        #     "entities": ["PCB", "PRINTED CIRCUIT BOARD"],
+        #     "final_entity": "PRINTED CIRCUIT BOARD",
+        #     "final_description": "A printed circuit board (PCB), also known as a
+        #         printed wiring board (PWB) or printed board, is a thin board of
+        #         insulating material used in electronics assembly to hold and
+        #         connect electronic components. The PCB serves as a substrate,
+        #         typically made from thermosetting or thermoplastic plastics,
+        #         reinforced with materials like paper, glass fiber, cotton, or
+        #         nylon. It features conductive pathways (usually copper) printed
+        #         on one or both sides, which interconnect components via soldering
+        #         to lands (pads). These connections are made either through plated
+        #         through-holes for leaded components or directly onto the surface
+        #         for surface-mount components. PCBs are manufactured using printing
+        #         techniques, and the conductive tracks can be created additively
+        #         (adding tracks) or subtractively (removing excess material from a
+        #         pre-coated base). They are available in single-sided, double-sided,
+        #         and multi-layered configurations, and are essential in all
+        #         electronic assemblies, providing support and pathways for
+        #         components during the soldering process.",
+        #     "final_type": "MATERIAL"
+        # }
+        item["ids"] = list(map(int, item["ids"]))
+        old_rows = entities.loc[item["ids"], :]
+
+        new_title = item["final_entity"]
+        new_type = item["final_type"]
+        new_description = item["final_description"]
+
+        frequency = old_rows["frequency"].sum()
+        text_unit_ids = list(itertools.chain.from_iterable(old_rows["text_unit_ids"]))
+        row = {
+            "title": new_title,
+            "type": new_type,
+            "description": new_description,
+            "text_unit_ids": text_unit_ids,
+            "frequency": frequency,
+        }
+        # Point edges at the merged entity's canonical title
+        relationships.loc[
+            relationships["source"].isin(item["entities"]), "source"
+        ] = new_title
+        relationships.loc[
+            relationships["target"].isin(item["entities"]), "target"
+        ] = new_title
+
+        new_entities_list.append(row)
+        all_ids.extend(item["ids"])
+
+    entities = entities.drop(all_ids)
+    entities = entities.drop(columns=["human_readable_id", "id"])
+    entities = pd.concat([entities, pd.DataFrame(new_entities_list)]).reset_index(
+        drop=True
+    )
+    return entities, relationships
+
+
+def cosine_similarity_matrix(X):
+    """Compute the pairwise cosine similarity of the rows of X."""
+    norms = np.linalg.norm(X, axis=1, keepdims=True)
+    norms = np.where(norms == 0, 1e-12, norms)  # guard against zero-norm rows
+    X_normalized = X / norms
+    sim_matrix = np.dot(X_normalized, X_normalized.T)
+    return np.clip(sim_matrix, -1.0, 1.0)
+
+
+def get_dbscan_cluster_labels(similarity_matrix, eps=0.2, min_samples=2):
+    """Cluster entities by cosine distance; eps=0.2 corresponds to 80% similarity."""
+    dbscan = DBSCAN(metric="precomputed", eps=eps, min_samples=min_samples)
+    cosine_distance = 1 - similarity_matrix
+    return dbscan.fit_predict(cosine_distance)
+
+
+def get_input_for_prompt(embeddings: pd.DataFrame, entities: pd.DataFrame):
+    """Render each DBSCAN cluster (noise label -1 excluded) as a JSON-like block."""
+    text = ""
+    for cluster_id, group_df in embeddings.groupby("cluster"):
+        if cluster_id == -1:
+            continue
+        rows = entities.loc[
+            group_df.index, ["human_readable_id", "title", "type", "description"]
+        ]
+        text += "[" + "\n"
+        for _, r in rows.iterrows():
+            text += (
+                f"{{'entity': '{r['title']}', 'type': '{r['type']}', "
+                f"'description': '{r['description']}', 'id': {r['human_readable_id']} }},\n"
+            )
+        text += "]" + "\n"
+
+    return text
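The clustering math above is easy to verify in isolation. A self-contained sketch with toy 2-D embeddings (real numpy/scikit-learn APIs; the vectors are invented):

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Toy title embeddings: rows 0 and 1 are nearly parallel, row 2 is orthogonal.
X = np.array([[1.0, 0.0], [0.98, 0.05], [0.0, 1.0]])

norms = np.linalg.norm(X, axis=1, keepdims=True)
sim = np.clip((X / norms) @ (X / norms).T, -1.0, 1.0)

# eps is a cosine *distance* threshold: eps=0.2 means >= 80% similarity.
labels = DBSCAN(metric="precomputed", eps=0.2, min_samples=2).fit_predict(1 - sim)
print(labels)  # [0, 0, -1] -> rows 0 and 1 form one merge candidate; row 2 is noise
```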
diff --git a/graphrag/prompts/index/merge_entities.py b/graphrag/prompts/index/merge_entities.py
new file mode 100644
index 0000000000..2cfa2ad940
--- /dev/null
+++ b/graphrag/prompts/index/merge_entities.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Prompts for the merge entities workflow."""
+
+MERGE_ENTITIES_SYSTEM = """
+You are tasked with analyzing pairs or groups of entities and determining whether they should be merged into a single, canonical entity. Your goal is to merge entities that represent the **exact same real-world concept** but differ due to minor linguistic variations.
+
+**Merging Rules:**
+
+1. **Primary Condition:** Merge only if the entities definitively refer to the *same specific thing* (e.g., the same specific car model, the same person, the same city).
+2. **Allowed Variations for Merging:**
+    * **Plural vs. Singular:** e.g., `CAR` vs. `CARS` (if descriptions refer generally), `BUILDING` vs. `BUILDINGS`.
+    * **Acronyms vs. Full Names:** e.g., `NYC` vs. `NEW YORK CITY`, `USA` vs. `UNITED STATES OF AMERICA`.
+    * **Minor Spelling/Formatting Variations:** e.g., `Dr John Smith` vs. `Dr. John Smith`.
+    * **Well-Known Aliases/Nicknames (Use with Caution):** e.g., `THE BIG APPLE` vs. `NEW YORK CITY`, but *only if* descriptions strongly confirm they refer to the identical entity.
+    * **Titles/No Titles (for People):** e.g., `PRESIDENT BIDEN` vs. `JOE BIDEN`, *if descriptions confirm it's the same individual*.
+3. **Mandatory Requirements for Merging:**
+    * **Identical Type:** The `type` attribute MUST be the same for all entities being merged (e.g., all must be `LOCATION`, or `PERSON`, or `VEHICLE`).
+    * **Compatible Descriptions:** The `description` fields must be consistent and describe the same entity. One description might be more detailed, but they should not contradict each other or describe fundamentally different aspects. Merged descriptions should combine the information logically.
+4. **DO NOT MERGE IF:**
+    * **Different Types:** e.g., `BERLIN` (LOCATION) vs. `BERLINER` (PERSON or FOOD).
+    * **Different Concepts:** Entities represent distinct things, even if related. e.g., `CAR` vs. `TRUCK`, `PARIS` (City) vs. `FRANCE` (Country), `ENGINE` (Component) vs. `CAR` (Vehicle).
+    * **Different Specific Entities:** Names are similar but descriptions indicate different individuals, places, or things. e.g., Two people named `John Smith` with different professions described.
+    * **Significant Name Differences:** Unless it's a confirmed acronym or very well-known alias with supporting descriptions.
+    * **Contradictory Descriptions:** Descriptions point to different characteristics or facts about the entity.
+
+**Output Format:**
+
+If a merge is warranted, output a JSON list containing one object per merge group. Each object should have the following structure:
+
+```json
+[{
+    "ids": [list_of_original_entity_ids],
+    "entities": [list_of_original_entity_names],
+    "final_entity": "chosen_canonical_entity_name",
+    "final_description": "combined_and_refined_description",
+    "final_type": "common_entity_type"
+}]
+```
+
+If no entities in the input should be merged, output an empty list [].
+###
+Examples
+###
+input:
+[
+    {'entity': 'CAR', 'type': 'VEHICLE', 'description': 'A four-wheeled road vehicle powered by an engine, able to carry a small number of people.', 'id': 10},
+    {'entity': 'CARS', 'type': 'VEHICLE', 'description': 'road vehicles designed for transporting people.', 'id': 11}
+]
+output:
+[{
+    "ids": [10, 11],
+    "entities": ["CAR", "CARS"],
+    "final_entity": "CAR",
+    "final_description": "A four-wheeled road vehicle powered by an engine, designed for transporting a small number of people.",
+    "final_type": "VEHICLE"
+}]
+###
+input:
+[
+    {'entity': 'LOS ANGELES', 'type': 'LOCATION', 'description': 'A sprawling Southern California city and the center of the nation’s film and television industry.', 'id': 101},
+    {'entity': 'CALIFORNIA', 'type': 'LOCATION', 'description': 'A state in the Western United States, known for its diverse terrain including cliff-lined beaches, redwood forests, the Sierra Nevada Mountains, Central Valley farmland and the Mojave Desert.', 'id': 103},
+    {'entity': 'LA', 'type': 'LOCATION', 'description': 'Common abbreviation for Los Angeles, a major city on the Pacific Coast of the USA.', 'id': 102},
+    {'entity': 'SAN FRANCISCO', 'type': 'LOCATION', 'description': 'A hilly city on a peninsula between the Pacific Ocean and San Francisco Bay in Northern California.', 'id': 104},
+    {'entity': 'HOLLYWOOD SIGN', 'type': 'LANDMARK', 'description': 'An American landmark and cultural icon overlooking Hollywood, Los Angeles, California.', 'id': 105}
+]
+output:
+[{
+    "ids": [101, 102],
+    "entities": ["LOS ANGELES", "LA"],
+    "final_entity": "LOS ANGELES",
+    "final_description": "Los Angeles (LA) is a sprawling Southern California city, a major city on the Pacific Coast of the USA, and the center of the nation’s film and television industry.",
+    "final_type": "LOCATION"
+}]
+###
+input:
+[
+    {'entity': 'DR. ANGELA MERKEL', 'type': 'PERSON', 'description': 'German politician who served as Chancellor of Germany from 2005 to 2021. Holds a doctorate in quantum chemistry.', 'id': 30},
+    {'entity': 'ANGELA MERKEL', 'type': 'PERSON', 'description': 'Former Chancellor of Germany, leader of the Christian Democratic Union.', 'id': 31},
+    {'entity': 'BARACK OBAMA', 'type': 'PERSON', 'description': '44th President of the United States, first African-American president, Nobel Peace Prize laureate.', 'id': 32}
+]
+output:
+[{
+    "ids": [30, 31],
+    "entities": ["DR. ANGELA MERKEL", "ANGELA MERKEL"],
+    "final_entity": "ANGELA MERKEL",
+    "final_description": "Angela Merkel (Dr.) is a German politician and former Chancellor of Germany (2005-2021), former leader of the Christian Democratic Union, holding a doctorate in quantum chemistry.",
+    "final_type": "PERSON"
+}]
+###
+input:
+[
+    {'entity': 'BERLIN', 'type': 'LOCATION', 'description': 'The capital and largest city of Germany.', 'id': 40},
+    {'entity': 'BERLINER', 'type': 'FOOD', 'description': 'A type of German doughnut often filled with jam.', 'id': 41}
+]
+output:
+[]
+###
+input:
+[
+    {'entity': 'JOHN SMITH', 'type': 'PERSON', 'description': 'Captain John Smith was an English explorer important to the establishment of Jamestown.', 'id': 50},
+    {'entity': 'JOHN SMITH', 'type': 'PERSON', 'description': 'John Smith is a contemporary musician known for his folk guitar style.', 'id': 51}
+]
+output:
+[]
+###
+input:
+[
+    {'entity': 'GERMANY', 'type': 'LOCATION', 'description': 'A country in Central Europe, member of the EU.', 'id': 60},
+    {'entity': 'BAVARIA', 'type': 'LOCATION', 'description': 'A state (Bundesland) in the southeast of Germany.', 'id': 61}
+]
+output:
+[]
+###
+"""
+
+MERGE_ENTITIES_INPUT = """
+Real input:
+{input}
+output:
+"""
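Finally, a note on why the workflow parses replies with `json_repair.loads` rather than `json.loads`: model output is often almost-JSON. A small illustration (the truncated reply string is invented):

```python
from json_repair import loads

# A truncated, trailing-comma reply that stock json.loads would reject.
reply = '[{"ids": [10, 11], "entities": ["CAR", "CARS"], "final_entity": "CAR",'

merges = loads(reply)  # best-effort repair: drops the dangling comma, closes brackets
print(merges)  # [{'ids': [10, 11], 'entities': ['CAR', 'CARS'], 'final_entity': 'CAR'}]
```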