Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/memos/graph_dbs/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:
return metadata


def _sanitize_neo4j_value(value: Any) -> Any:
"""Convert values unsupported by Neo4j properties into safe serializations."""
if value is None or isinstance(value, (str, int, float, bool)):
return value

if isinstance(value, list):
if all(item is None or isinstance(item, (str, int, float, bool)) for item in value):
return value
return [json.dumps(item, ensure_ascii=False) if isinstance(item, (dict, list)) else str(item) for item in value]

if isinstance(value, dict):
return json.dumps(value, ensure_ascii=False, sort_keys=True)

return str(value)


def _sanitize_neo4j_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
"""Ensure all metadata values are valid Neo4j property types."""
return {key: _sanitize_neo4j_value(value) for key, value in metadata.items()}


class Neo4jGraphDB(BaseGraphDB):
"""Neo4j-based implementation of a graph memory store."""

Expand Down Expand Up @@ -209,6 +230,9 @@ def add_node(
# Flatten info fields to top level (for Neo4j flat structure)
metadata = _flatten_info_fields(metadata)

# Ensure Neo4j property compatibility (no nested map/list-of-map values)
metadata = _sanitize_neo4j_metadata(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
metadata.setdefault("delete_record_id", "")
Expand Down
10 changes: 9 additions & 1 deletion src/memos/graph_dbs/neo4j_community.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from typing import Any

from memos.configs.graph_db import Neo4jGraphDBConfig
from memos.graph_dbs.neo4j import Neo4jGraphDB, _flatten_info_fields, _prepare_node_metadata
from memos.graph_dbs.neo4j import (
Neo4jGraphDB,
_flatten_info_fields,
_prepare_node_metadata,
_sanitize_neo4j_metadata,
)
from memos.log import get_logger
from memos.vec_dbs.factory import VecDBFactory
from memos.vec_dbs.item import VecDBItem
Expand Down Expand Up @@ -55,6 +60,8 @@ def add_node(

# Safely process metadata
metadata = _prepare_node_metadata(metadata)
metadata = _flatten_info_fields(metadata)
metadata = _sanitize_neo4j_metadata(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
Expand Down Expand Up @@ -134,6 +141,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N

metadata = _prepare_node_metadata(metadata)
metadata = _flatten_info_fields(metadata)
metadata = _sanitize_neo4j_metadata(metadata)

# Initialize delete_time and delete_record_id fields
metadata.setdefault("delete_time", "")
Expand Down
23 changes: 23 additions & 0 deletions tests/graph_dbs/graph_dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,26 @@ def test_get_memory_count(graph_db):
session_mock.run.return_value.single.return_value = {"count": 42}
count = graph_db.get_memory_count("WorkingMemory")
assert count == 42


def test_add_node_sanitizes_nested_metadata(graph_db):
session_mock = graph_db.driver.session.return_value.__enter__.return_value
node_id = str(uuid.uuid4())
memory = "skill memory"
metadata = {
"memory_type": "SkillMemory",
"embedding": [0.1, 0.2, 0.3],
"tags": ["skill"],
"scripts": {"run.py": "print(1)"},
"others": {"README.md": "# demo"},
"info": {"nested": {"x": 1}, "arr_obj": [{"a": 1}]},
}

graph_db.add_node(node_id, memory, metadata)

_, kwargs = session_mock.run.call_args
sanitized = kwargs["metadata"]
assert isinstance(sanitized["scripts"], str)
assert isinstance(sanitized["others"], str)
assert isinstance(sanitized["nested"], str)
assert sanitized["arr_obj"] == ['{"a": 1}']