4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260210011450472481.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add DataReader class for typed dataframe loading from TableProvider across indexing workflows and query CLI"
}
8 changes: 5 additions & 3 deletions packages/graphrag/graphrag/cli/query.py
@@ -15,6 +15,7 @@
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader

if TYPE_CHECKING:
import pandas as pd
@@ -375,20 +376,21 @@ def _resolve_output_files(
output_list: list[str],
optional_list: list[str] | None = None,
) -> dict[str, Any]:
"""Read indexing output files to a dataframe dict."""
"""Read indexing output files to a dataframe dict, with correct column types."""
dataframe_dict = {}
storage_obj = create_storage(config.output_storage)
table_provider = create_table_provider(config.table_provider, storage=storage_obj)
reader = DataReader(table_provider)
for name in output_list:
df_value = asyncio.run(table_provider.read_dataframe(name))
df_value = asyncio.run(getattr(reader, name)())
dataframe_dict[name] = df_value

# for optional output files, set the dict entry to None instead of erroring out if it does not exist
if optional_list:
for optional_file in optional_list:
file_exists = asyncio.run(table_provider.has(optional_file))
if file_exists:
df_value = asyncio.run(table_provider.read_dataframe(optional_file))
df_value = asyncio.run(getattr(reader, optional_file)())
dataframe_dict[optional_file] = df_value
else:
dataframe_dict[optional_file] = None
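
Aside (not part of the diff): in the loop above, each output table name doubles as the name of a DataReader accessor, so getattr resolves "entities" to reader.entities() and so on. A minimal sketch of that dispatch, with illustrative names:

import asyncio

from graphrag.data_model.data_reader import DataReader

async def load_tables(reader: DataReader, names: list[str]) -> dict:
    # Each table name ("entities", "relationships", ...) matches a DataReader
    # coroutine of the same name, so getattr dispatches to the typed loader.
    return {name: await getattr(reader, name)() for name in names}

# e.g. asyncio.run(load_tables(reader, ["entities", "relationships"]))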
4 changes: 4 additions & 0 deletions packages/graphrag/graphrag/data_model/__init__.py
@@ -2,3 +2,7 @@
# Licensed under the MIT License

"""Knowledge model package."""

from graphrag.data_model.data_reader import DataReader

__all__ = ["DataReader"]
71 changes: 71 additions & 0 deletions packages/graphrag/graphrag/data_model/data_reader.py
@@ -0,0 +1,71 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A DataReader that loads typed dataframes from a TableProvider."""

import pandas as pd
from graphrag_storage.tables import TableProvider

from graphrag.data_model.dfs import (
communities_typed,
community_reports_typed,
covariates_typed,
documents_typed,
entities_typed,
relationships_typed,
text_units_typed,
)


class DataReader:
"""Reads dataframes from a TableProvider and applies correct column types.

When loading from weakly-typed formats like CSV, list columns are stored as
plain strings. This class wraps a TableProvider, loading each table and
converting columns to their expected types before returning.
"""

def __init__(self, table_provider: TableProvider) -> None:
"""Initialize a DataReader with the given TableProvider.

Args
----
table_provider: TableProvider
The table provider to load dataframes from.
"""
self._table_provider = table_provider

async def entities(self) -> pd.DataFrame:
"""Load and return the entities dataframe with correct types."""
df = await self._table_provider.read_dataframe("entities")
return entities_typed(df)

async def relationships(self) -> pd.DataFrame:
"""Load and return the relationships dataframe with correct types."""
df = await self._table_provider.read_dataframe("relationships")
return relationships_typed(df)

async def communities(self) -> pd.DataFrame:
"""Load and return the communities dataframe with correct types."""
df = await self._table_provider.read_dataframe("communities")
return communities_typed(df)

async def community_reports(self) -> pd.DataFrame:
"""Load and return the community reports dataframe with correct types."""
df = await self._table_provider.read_dataframe("community_reports")
return community_reports_typed(df)

async def covariates(self) -> pd.DataFrame:
"""Load and return the covariates dataframe with correct types."""
df = await self._table_provider.read_dataframe("covariates")
return covariates_typed(df)

async def text_units(self) -> pd.DataFrame:
"""Load and return the text units dataframe with correct types."""
df = await self._table_provider.read_dataframe("text_units")
return text_units_typed(df)

async def documents(self) -> pd.DataFrame:
"""Load and return the documents dataframe with correct types."""
df = await self._table_provider.read_dataframe("documents")
return documents_typed(df)
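
A minimal usage sketch for the new class (the table-provider construction mirrors the query CLI change above; anything not shown in this diff is an assumption):

import asyncio

from graphrag.data_model.data_reader import DataReader

async def main(table_provider) -> None:
    # Wrap any TableProvider; each accessor reads one output table and
    # coerces its columns to the expected dtypes (ints, floats, lists).
    reader = DataReader(table_provider)
    entities = await reader.entities()
    print(entities.dtypes)

# e.g. with a provider built as in query.py:
# asyncio.run(main(create_table_provider(config.table_provider, storage=storage_obj)))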
129 changes: 129 additions & 0 deletions packages/graphrag/graphrag/data_model/dfs.py
@@ -0,0 +1,129 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A package containing dataframe processing utilities."""

from typing import Any

import pandas as pd

from graphrag.data_model.schemas import (
COMMUNITY_CHILDREN,
COMMUNITY_ID,
COMMUNITY_LEVEL,
COVARIATE_IDS,
EDGE_DEGREE,
EDGE_WEIGHT,
ENTITY_IDS,
FINDINGS,
N_TOKENS,
NODE_DEGREE,
NODE_FREQUENCY,
PERIOD,
RATING,
RELATIONSHIP_IDS,
SHORT_ID,
SIZE,
TEXT_UNIT_IDS,
)


def _split_list_column(value: Any) -> list[Any]:
"""Split a column containing a list string into an actual list."""
if isinstance(value, str):
return [item.strip("[] '") for item in value.split(",")] if value else []
return value


def entities_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the entities dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)
if NODE_FREQUENCY in df.columns:
df[NODE_FREQUENCY] = df[NODE_FREQUENCY].astype(int)
if NODE_DEGREE in df.columns:
df[NODE_DEGREE] = df[NODE_DEGREE].astype(int)

return df


def relationships_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the relationships dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
if EDGE_WEIGHT in df.columns:
df[EDGE_WEIGHT] = df[EDGE_WEIGHT].astype(float)
if EDGE_DEGREE in df.columns:
df[EDGE_DEGREE] = df[EDGE_DEGREE].astype(int)
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)

return df


def communities_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the communities dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
df[COMMUNITY_ID] = df[COMMUNITY_ID].astype(int)
df[COMMUNITY_LEVEL] = df[COMMUNITY_LEVEL].astype(int)
df[COMMUNITY_CHILDREN] = df[COMMUNITY_CHILDREN].apply(_split_list_column)
if ENTITY_IDS in df.columns:
df[ENTITY_IDS] = df[ENTITY_IDS].apply(_split_list_column)
if RELATIONSHIP_IDS in df.columns:
df[RELATIONSHIP_IDS] = df[RELATIONSHIP_IDS].apply(_split_list_column)
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)
df[PERIOD] = df[PERIOD].astype(str)
df[SIZE] = df[SIZE].astype(int)

return df


def community_reports_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the community reports dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
df[COMMUNITY_ID] = df[COMMUNITY_ID].astype(int)
df[COMMUNITY_LEVEL] = df[COMMUNITY_LEVEL].astype(int)
df[COMMUNITY_CHILDREN] = df[COMMUNITY_CHILDREN].apply(_split_list_column)
df[RATING] = df[RATING].astype(float)
df[FINDINGS] = df[FINDINGS].apply(_split_list_column)
df[SIZE] = df[SIZE].astype(int)

return df


def covariates_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the covariates dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)

return df


def text_units_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the text units dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
df[N_TOKENS] = df[N_TOKENS].astype(int)
if ENTITY_IDS in df.columns:
df[ENTITY_IDS] = df[ENTITY_IDS].apply(_split_list_column)
if RELATIONSHIP_IDS in df.columns:
df[RELATIONSHIP_IDS] = df[RELATIONSHIP_IDS].apply(_split_list_column)
if COVARIATE_IDS in df.columns:
df[COVARIATE_IDS] = df[COVARIATE_IDS].apply(_split_list_column)

return df


def documents_typed(df: pd.DataFrame) -> pd.DataFrame:
"""Return the documents dataframe with correct types, in case it was stored in a weakly-typed format."""
if SHORT_ID in df.columns:
df[SHORT_ID] = df[SHORT_ID].astype(int)
if TEXT_UNIT_IDS in df.columns:
df[TEXT_UNIT_IDS] = df[TEXT_UNIT_IDS].apply(_split_list_column)

return df
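
For intuition, a quick sketch of what the _split_list_column helper defined above does to a list column that was round-tripped through CSV (values are illustrative):

import pandas as pd

# CSV storage stringifies list columns; _split_list_column recovers them.
df = pd.DataFrame({"text_unit_ids": ["['a1', 'b2']", "", ["c3"]]})
df["text_unit_ids"] = df["text_unit_ids"].apply(_split_list_column)
# Result: [["a1", "b2"], [], ["c3"]] -- strings are split on commas and
# stripped of brackets/quotes, empty strings become [], and values that
# are already lists pass through unchanged.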
packages/graphrag/graphrag/index/workflows/create_base_text_units.py
@@ -15,6 +15,7 @@

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
@@ -30,7 +31,8 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform base text_units."""
logger.info("Workflow started: create_base_text_units")
documents = await context.output_table_provider.read_dataframe("documents")
reader = DataReader(context.output_table_provider)
documents = await reader.documents()

tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
packages/graphrag/graphrag/index/workflows/create_communities.py
@@ -12,6 +12,7 @@
import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
@@ -27,8 +28,9 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform final communities."""
logger.info("Workflow started: create_communities")
entities = await context.output_table_provider.read_dataframe("entities")
relationships = await context.output_table_provider.read_dataframe("relationships")
reader = DataReader(context.output_table_provider)
entities = await reader.entities()
relationships = await reader.relationships()

max_cluster_size = config.cluster_graph.max_cluster_size
use_lcc = config.cluster_graph.use_lcc
packages/graphrag/graphrag/index/workflows/create_community_reports.py
@@ -15,6 +15,7 @@
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@@ -43,15 +44,16 @@
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: create_community_reports")
edges = await context.output_table_provider.read_dataframe("relationships")
entities = await context.output_table_provider.read_dataframe("entities")
communities = await context.output_table_provider.read_dataframe("communities")
reader = DataReader(context.output_table_provider)
relationships = await reader.relationships()
entities = await reader.entities()
communities = await reader.communities()

claims = None
if config.extract_claims.enabled and await context.output_table_provider.has(
"covariates"
):
claims = await context.output_table_provider.read_dataframe("covariates")
claims = await reader.covariates()

model_config = config.get_completion_model_config(
config.community_reports.completion_model_id
@@ -67,7 +69,7 @@
tokenizer = model.tokenizer

output = await create_community_reports(
edges_input=edges,
relationships=relationships,
entities=entities,
communities=communities,
claims_input=claims,
@@ -88,7 +90,7 @@


async def create_community_reports(
edges_input: pd.DataFrame,
relationships: pd.DataFrame,
entities: pd.DataFrame,
communities: pd.DataFrame,
claims_input: pd.DataFrame | None,
@@ -105,7 +107,7 @@
nodes = explode_communities(communities, entities)

nodes = _prep_nodes(nodes)
edges = _prep_edges(edges_input)
edges = _prep_edges(relationships)

claims = None
if claims_input is not None:
packages/graphrag/graphrag/index/workflows/create_community_reports_text.py
@@ -14,6 +14,7 @@
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@@ -42,9 +43,10 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: create_community_reports_text")
entities = await context.output_table_provider.read_dataframe("entities")
communities = await context.output_table_provider.read_dataframe("communities")
text_units = await context.output_table_provider.read_dataframe("text_units")
reader = DataReader(context.output_table_provider)
entities = await reader.entities()
communities = await reader.communities()
text_units = await reader.text_units()

model_config = config.get_completion_model_config(
config.community_reports.completion_model_id
packages/graphrag/graphrag/index/workflows/create_final_documents.py
@@ -8,6 +8,7 @@
import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.data_reader import DataReader
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
@@ -21,8 +22,9 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform final documents."""
logger.info("Workflow started: create_final_documents")
documents = await context.output_table_provider.read_dataframe("documents")
text_units = await context.output_table_provider.read_dataframe("text_units")
reader = DataReader(context.output_table_provider)
documents = await reader.documents()
text_units = await reader.text_units()

output = create_final_documents(documents, text_units)
