diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index 33eb39692..83d227da8 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -72,6 +72,47 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]:
     return metadata
 
 
+def _sanitize_neo4j_value(value: Any) -> Any:
+    """Convert a value Neo4j cannot store as a property into a safe serialization.
+
+    Scalars (``None``, ``str``, ``int``, ``float``, ``bool``) and lists holding
+    only such scalars are returned unchanged. A dict becomes a JSON string. In
+    any other list, dict/list items become JSON strings and all remaining items
+    go through ``str()``. Any other value goes through ``str()``.
+    """
+    scalar_types = (str, int, float, bool)
+
+    def _to_json(obj: Any) -> str:
+        # sort_keys keeps the stored text deterministic; default=str stops
+        # nested non-JSON values (datetime, set, ...) from raising TypeError.
+        try:
+            return json.dumps(obj, ensure_ascii=False, sort_keys=True, default=str)
+        except (TypeError, ValueError):
+            # Unsupported or unsortable keys, or circular references.
+            return str(obj)
+
+    if value is None or isinstance(value, scalar_types):
+        return value
+
+    if isinstance(value, list):
+        if all(item is None or isinstance(item, scalar_types) for item in value):
+            return value
+        return [
+            _to_json(item) if isinstance(item, (dict, list)) else str(item)
+            for item in value
+        ]
+
+    if isinstance(value, dict):
+        return _to_json(value)
+
+    return str(value)
+
+
+def _sanitize_neo4j_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+    """Return a new dict with every value passed through ``_sanitize_neo4j_value``."""
+    return {key: _sanitize_neo4j_value(value) for key, value in metadata.items()}
+
+
 class Neo4jGraphDB(BaseGraphDB):
     """Neo4j-based implementation of a graph memory store."""
 
@@ -209,6 +250,9 @@ def add_node(
         # Flatten info fields to top level (for Neo4j flat structure)
         metadata = _flatten_info_fields(metadata)
 
+        # Ensure Neo4j property compatibility (no nested map/list-of-map values)
+        metadata = _sanitize_neo4j_metadata(metadata)
+
         # Initialize delete_time and delete_record_id fields
         metadata.setdefault("delete_time", "")
         metadata.setdefault("delete_record_id", "")
diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py
index 09ad46c42..a5bffd123 100644
--- a/src/memos/graph_dbs/neo4j_community.py
+++ b/src/memos/graph_dbs/neo4j_community.py
@@ -5,7 +5,12 @@
 from typing import Any
 
 from memos.configs.graph_db import Neo4jGraphDBConfig
-from memos.graph_dbs.neo4j import Neo4jGraphDB, _flatten_info_fields, _prepare_node_metadata
+from memos.graph_dbs.neo4j import (
+    Neo4jGraphDB,
+    _flatten_info_fields,
+    _prepare_node_metadata,
+    _sanitize_neo4j_metadata,
+)
 from memos.log import get_logger
 from memos.vec_dbs.factory import VecDBFactory
 from memos.vec_dbs.item import VecDBItem
@@ -55,6 +60,8 @@ def add_node(
 
         # Safely process metadata
         metadata = _prepare_node_metadata(metadata)
+        metadata = _flatten_info_fields(metadata)
+        metadata = _sanitize_neo4j_metadata(metadata)
 
         # Initialize delete_time and delete_record_id fields
         metadata.setdefault("delete_time", "")
@@ -134,6 +141,7 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N
 
                 metadata = _prepare_node_metadata(metadata)
                 metadata = _flatten_info_fields(metadata)
+                metadata = _sanitize_neo4j_metadata(metadata)
 
                 # Initialize delete_time and delete_record_id fields
                 metadata.setdefault("delete_time", "")
diff --git a/tests/graph_dbs/graph_dbs.py b/tests/graph_dbs/graph_dbs.py
index 2cc35a0ad..c834f65a9 100644
--- a/tests/graph_dbs/graph_dbs.py
+++ b/tests/graph_dbs/graph_dbs.py
@@ -105,3 +105,26 @@ def test_get_memory_count(graph_db):
     session_mock.run.return_value.single.return_value = {"count": 42}
     count = graph_db.get_memory_count("WorkingMemory")
     assert count == 42
+
+
+def test_add_node_sanitizes_nested_metadata(graph_db):
+    session_mock = graph_db.driver.session.return_value.__enter__.return_value
+    node_id = str(uuid.uuid4())
+    memory = "skill memory"
+    metadata = {
+        "memory_type": "SkillMemory",
+        "embedding": [0.1, 0.2, 0.3],
+        "tags": ["skill"],
+        "scripts": {"run.py": "print(1)"},
+        "others": {"README.md": "# demo"},
+        "info": {"nested": {"x": 1}, "arr_obj": [{"a": 1}]},
+    }
+
+    graph_db.add_node(node_id, memory, metadata)
+
+    _, kwargs = session_mock.run.call_args
+    sanitized = kwargs["metadata"]
+    assert isinstance(sanitized["scripts"], str)
+    assert isinstance(sanitized["others"], str)
+    assert isinstance(sanitized["nested"], str)
+    assert sanitized["arr_obj"] == ['{"a": 1}']