diff --git a/README.md b/README.md
index 70562fec7..45a372ed1 100644
--- a/README.md
+++ b/README.md
@@ -345,10 +345,10 @@ url = {https://global-sci.com/article/91443/memory3-language-modeling-with-expli
## 🙌 Contributing
-We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/contribution/overview) to get started.
+We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/open_source/contribution/overview/) to get started.
## 📄 License
-MemOS is licensed under the [Apache 2.0 License](./LICENSE).
\ No newline at end of file
+MemOS is licensed under the [Apache 2.0 License](./LICENSE).
diff --git a/pyproject.toml b/pyproject.toml
index b4b01e0e1..4a9ea8852 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@
##############################################################################
name = "MemoryOS"
-version = "2.0.6"
+version = "2.0.7"
description = "Intelligence Begins with Memory"
license = {text = "Apache-2.0"}
readme = "README.md"
diff --git a/src/memos/__init__.py b/src/memos/__init__.py
index b568ae0c2..fefa3b2ab 100644
--- a/src/memos/__init__.py
+++ b/src/memos/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.0.6"
+__version__ = "2.0.7"
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 267d1bb28..8e7785ad5 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -64,7 +64,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
-        # Expand top_k for deduplication (5x to ensure enough candidates)
+        # Expand top_k for deduplication (3x to ensure enough candidates)
if search_req_local.dedup in ("sim", "mmr"):
- search_req_local.top_k = search_req_local.top_k * 5
+ search_req_local.top_k = search_req_local.top_k * 3
# Search and deduplicate
cube_view = self._build_cube_view(search_req_local)
@@ -152,9 +152,6 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di
return results
embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
- if embeddings is None:
- documents = [mem.get("memory", "") for _, mem, _ in flat]
- embeddings = self.searcher.embedder.embed(documents)
similarity_matrix = cosine_similarity_matrix(embeddings)
@@ -235,12 +232,39 @@ def _mmr_dedup_text_memories(
if len(flat) <= 1:
return results
+ total_by_type: dict[str, int] = {"text": 0, "preference": 0}
+ existing_by_type: dict[str, int] = {"text": 0, "preference": 0}
+ missing_by_type: dict[str, int] = {"text": 0, "preference": 0}
+ missing_indices: list[int] = []
+ for idx, (mem_type, _, mem, _) in enumerate(flat):
+ if mem_type not in total_by_type:
+ total_by_type[mem_type] = 0
+ existing_by_type[mem_type] = 0
+ missing_by_type[mem_type] = 0
+ total_by_type[mem_type] += 1
+
+ embedding = mem.get("metadata", {}).get("embedding")
+ if embedding:
+ existing_by_type[mem_type] += 1
+ else:
+ missing_by_type[mem_type] += 1
+ missing_indices.append(idx)
+
+ self.logger.info(
+ "[SearchHandler] MMR embedding metadata scan: total=%s total_by_type=%s existing_by_type=%s missing_by_type=%s",
+ len(flat),
+ total_by_type,
+ existing_by_type,
+ missing_by_type,
+ )
+ if missing_indices:
+ self.logger.warning(
+ "[SearchHandler] MMR embedding metadata missing; will compute missing embeddings: missing_total=%s",
+ len(missing_indices),
+ )
+
# Get or compute embeddings
embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat])
- if embeddings is None:
- self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings")
- documents = [mem.get("memory", "") for _, _, mem, _ in flat]
- embeddings = self.searcher.embedder.embed(documents)
# Compute similarity matrix using NumPy-optimized method
# Returns numpy array but compatible with list[i][j] indexing
@@ -404,14 +428,32 @@ def _max_similarity(
return 0.0
return max(similarity_matrix[index][j] for j in selected_indices)
- @staticmethod
- def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
+ def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float]]:
embeddings: list[list[float]] = []
- for mem in memories:
- embedding = mem.get("metadata", {}).get("embedding")
- if not embedding:
- return None
- embeddings.append(embedding)
+ missing_indices: list[int] = []
+ missing_documents: list[str] = []
+
+ for idx, mem in enumerate(memories):
+ metadata = mem.get("metadata")
+ if not isinstance(metadata, dict):
+ metadata = {}
+ mem["metadata"] = metadata
+
+ embedding = metadata.get("embedding")
+ if embedding:
+ embeddings.append(embedding)
+ continue
+
+ embeddings.append([])
+ missing_indices.append(idx)
+ missing_documents.append(mem.get("memory", ""))
+
+ if missing_indices:
+ computed = self.searcher.embedder.embed(missing_documents)
+ for idx, embedding in zip(missing_indices, computed, strict=False):
+ embeddings[idx] = embedding
+ memories[idx]["metadata"]["embedding"] = embedding
+
return embeddings
@staticmethod
diff --git a/src/memos/api/middleware/__init__.py b/src/memos/api/middleware/__init__.py
index 64cbc5c60..fd39252f5 100644
--- a/src/memos/api/middleware/__init__.py
+++ b/src/memos/api/middleware/__init__.py
@@ -1,13 +1,14 @@
"""Krolik middleware extensions for MemOS."""
-from .auth import verify_api_key, require_scope, require_admin, require_read, require_write
+from .auth import require_admin, require_read, require_scope, require_write, verify_api_key
from .rate_limit import RateLimitMiddleware
+
__all__ = [
- "verify_api_key",
- "require_scope",
+ "RateLimitMiddleware",
"require_admin",
"require_read",
+ "require_scope",
"require_write",
- "RateLimitMiddleware",
+ "verify_api_key",
]
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 6fc03e735..5bf27e985 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -99,12 +99,12 @@ class ChatRequest(BaseRequest):
manager_user_id: str | None = Field(None, description="Manager User ID")
project_id: str | None = Field(None, description="Project ID")
relativity: float = Field(
- 0.0,
+ 0.45,
ge=0,
description=(
"Relevance threshold for recalled memories. "
"Only memories with metadata.relativity >= relativity will be returned. "
- "Use 0 to disable threshold filtering. Default: 0.3."
+ "Use 0 to disable threshold filtering. Default: 0.45."
),
)
@@ -339,12 +339,12 @@ class APISearchRequest(BaseRequest):
)
relativity: float = Field(
- 0.0,
+ 0.45,
ge=0,
description=(
"Relevance threshold for recalled memories. "
"Only memories with metadata.relativity >= relativity will be returned. "
- "Use 0 to disable threshold filtering. Default: 0.3."
+ "Use 0 to disable threshold filtering. Default: 0.45."
),
)
@@ -785,12 +785,12 @@ class APIChatCompleteRequest(BaseRequest):
manager_user_id: str | None = Field(None, description="Manager User ID")
project_id: str | None = Field(None, description="Project ID")
relativity: float = Field(
- 0.0,
+ 0.45,
ge=0,
description=(
"Relevance threshold for recalled memories. "
"Only memories with metadata.relativity >= relativity will be returned. "
- "Use 0 to disable threshold filtering. Default: 0.3."
+ "Use 0 to disable threshold filtering. Default: 0.45."
),
)
diff --git a/src/memos/api/utils/api_keys.py b/src/memos/api/utils/api_keys.py
index 559ddd355..29b493fd0 100644
--- a/src/memos/api/utils/api_keys.py
+++ b/src/memos/api/utils/api_keys.py
@@ -5,8 +5,8 @@
"""
import hashlib
-import os
import secrets
+
from dataclasses import dataclass
from datetime import datetime, timedelta
diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index 538d913ea..2b3bd0967 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -73,7 +73,6 @@ async def _create_embeddings():
)
)
logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds")
- logger.info(f"Embeddings request response: {response}")
return [r.embedding for r in response.data]
except Exception as e:
if self.use_backup_client:
diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py
index 130b66a3d..0bc4a54f8 100644
--- a/src/memos/graph_dbs/base.py
+++ b/src/memos/graph_dbs/base.py
@@ -1,12 +1,35 @@
+import re
+
from abc import ABC, abstractmethod
from typing import Any, Literal
+# Pattern for valid field names: alphanumeric and underscores, must start with letter or underscore
+_VALID_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+
+
class BaseGraphDB(ABC):
"""
Abstract base class for a graph database interface used in a memory-augmented RAG system.
"""
+ @staticmethod
+ def _validate_return_fields(return_fields: list[str] | None) -> list[str]:
+ """Validate and sanitize return_fields to prevent query injection.
+
+ Only allows alphanumeric characters and underscores in field names.
+ Silently drops invalid field names.
+
+ Args:
+ return_fields: List of field names to validate.
+
+ Returns:
+ List of valid field names.
+ """
+ if not return_fields:
+ return []
+ return [f for f in return_fields if _VALID_FIELD_NAME_RE.match(f)]
+
# Node (Memory) Management
@abstractmethod
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
@@ -144,16 +167,23 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
# Search / recall operations
@abstractmethod
- def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]:
+ def search_by_embedding(
+ self, vector: list[float], top_k: int = 5, return_fields: list[str] | None = None, **kwargs
+ ) -> list[dict]:
"""
Retrieve node IDs based on vector similarity.
Args:
vector (list[float]): The embedding vector representing query semantics.
top_k (int): Number of top similar nodes to retrieve.
+ return_fields (list[str], optional): Additional node fields to include in results
+ (e.g., ["memory", "status", "tags"]). When provided, each result dict will
+ contain these fields in addition to 'id' and 'score'.
+ Defaults to None (only 'id' and 'score' are returned).
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+ If return_fields is specified, each dict also includes the requested fields.
Notes:
- This method may internally call a VecDB (e.g., Qdrant) or store embeddings in the graph DB itself.
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index 746051187..33eb39692 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -818,6 +818,7 @@ def search_by_embedding(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -832,9 +833,14 @@ def search_by_embedding(
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
search_filter (dict, optional): Additional metadata filters for search results.
Keys should match node properties, values are the expected values.
+ return_fields (list[str], optional): Additional node fields to include in results
+ (e.g., ["memory", "status", "tags"]). When provided, each result
+ dict will contain these fields in addition to 'id' and 'score'.
+ Defaults to None (only 'id' and 'score' are returned).
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+ If return_fields is specified, each dict also includes the requested fields.
Notes:
- This method uses Neo4j native vector indexing to search for similar nodes.
@@ -886,11 +892,20 @@ def search_by_embedding(
if where_clauses:
where_clause = "WHERE " + " AND ".join(where_clauses)
+ return_clause = "RETURN node.id AS id, score"
+ if return_fields:
+ validated_fields = self._validate_return_fields(return_fields)
+ extra_fields = ", ".join(
+ f"node.{field} AS {field}" for field in validated_fields if field != "id"
+ )
+ if extra_fields:
+ return_clause = f"RETURN node.id AS id, score, {extra_fields}"
+
query = f"""
CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding)
YIELD node, score
{where_clause}
- RETURN node.id AS id, score
+ {return_clause}
"""
parameters = {"embedding": vector, "k": top_k}
@@ -920,7 +935,15 @@ def search_by_embedding(
print(f"[search_by_embedding] query: {query},parameters: {parameters}")
with self.driver.session(database=self.db_name) as session:
result = session.run(query, parameters)
- records = [{"id": record["id"], "score": record["score"]} for record in result]
+ records = []
+ for record in result:
+ item = {"id": record["id"], "score": record["score"]}
+ if return_fields:
+ record_keys = record.keys()
+ for field in return_fields:
+ if field != "id" and field in record_keys:
+ item[field] = record[field]
+ records.append(item)
# Threshold filtering after retrieval
if threshold is not None:
@@ -943,8 +966,8 @@ def search_by_fulltext(
**kwargs,
) -> list[dict]:
"""
- TODO: 实现 Neo4j 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径.
- 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错.
+ TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path.
+ Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j.
"""
return []
diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py
index cae7d6ca5..09ad46c42 100644
--- a/src/memos/graph_dbs/neo4j_community.py
+++ b/src/memos/graph_dbs/neo4j_community.py
@@ -246,6 +246,39 @@ def get_children_with_embeddings(
return child_nodes
+ def _fetch_return_fields(
+ self,
+ ids: list[str],
+ score_map: dict[str, float],
+ return_fields: list[str],
+ ) -> list[dict]:
+ """Fetch additional fields from Neo4j for given node IDs."""
+ validated_fields = self._validate_return_fields(return_fields)
+ extra_fields = ", ".join(
+ f"n.{field} AS {field}" for field in validated_fields if field != "id"
+ )
+ return_clause = "RETURN n.id AS id"
+ if extra_fields:
+ return_clause = f"RETURN n.id AS id, {extra_fields}"
+
+ query = f"""
+ MATCH (n:Memory)
+ WHERE n.id IN $ids
+ {return_clause}
+ """
+ with self.driver.session(database=self.db_name) as session:
+ neo4j_results = session.run(query, {"ids": ids})
+ results = []
+ for record in neo4j_results:
+ node_id = record["id"]
+ item = {"id": node_id, "score": score_map.get(node_id)}
+ record_keys = record.keys()
+ for field in return_fields:
+ if field != "id" and field in record_keys:
+ item[field] = record[field]
+ results.append(item)
+ return results
+
# Search / recall operations
def search_by_embedding(
self,
@@ -258,6 +291,7 @@ def search_by_embedding(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -273,9 +307,14 @@ def search_by_embedding(
filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results.
Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]}
knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by.
+ return_fields (list[str], optional): Additional node fields to include in results
+ (e.g., ["memory", "status", "tags"]). When provided, each result dict will
+ contain these fields in addition to 'id' and 'score'.
+ Defaults to None (only 'id' and 'score' are returned).
Returns:
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
+ If return_fields is specified, each dict also includes the requested fields.
Notes:
- This method uses an external vector database (not Neo4j) to perform the search.
@@ -320,7 +359,14 @@ def search_by_embedding(
# If no filter or knowledgebase_ids provided, return vector search results directly
if not filter and not knowledgebase_ids:
- return [{"id": r.id, "score": r.score} for r in vec_results]
+ if not return_fields:
+ return [{"id": r.id, "score": r.score} for r in vec_results]
+ # Need to fetch additional fields from Neo4j
+ vec_ids = [r.id for r in vec_results]
+ if not vec_ids:
+ return []
+ score_map = {r.id: r.score for r in vec_results}
+ return self._fetch_return_fields(vec_ids, score_map, return_fields)
# Extract IDs from vector search results
vec_ids = [r.id for r in vec_results]
@@ -363,22 +409,49 @@ def search_by_embedding(
if filter_params:
params.update(filter_params)
+ # Build RETURN clause with optional extra fields
+ return_clause = "RETURN n.id AS id"
+ if return_fields:
+ validated_fields = self._validate_return_fields(return_fields)
+ extra_fields = ", ".join(
+ f"n.{field} AS {field}" for field in validated_fields if field != "id"
+ )
+ if extra_fields:
+ return_clause = f"RETURN n.id AS id, {extra_fields}"
+
# Query Neo4j to filter results
query = f"""
MATCH (n:Memory)
{where_clause}
- RETURN n.id AS id
+ {return_clause}
"""
logger.info(f"[search_by_embedding] query: {query}, params: {params}")
with self.driver.session(database=self.db_name) as session:
neo4j_results = session.run(query, params)
- filtered_ids = {record["id"] for record in neo4j_results}
+ if return_fields:
+ # Build a map of id -> extra fields from Neo4j results
+ neo4j_data = {}
+ for record in neo4j_results:
+ node_id = record["id"]
+ record_keys = record.keys()
+ neo4j_data[node_id] = {
+ field: record[field]
+ for field in return_fields
+ if field != "id" and field in record_keys
+ }
+ filtered_ids = set(neo4j_data.keys())
+ else:
+ filtered_ids = {record["id"] for record in neo4j_results}
# Filter vector results by Neo4j filtered IDs and return with scores
- filtered_results = [
- {"id": r.id, "score": r.score} for r in vec_results if r.id in filtered_ids
- ]
+ filtered_results = []
+ for r in vec_results:
+ if r.id in filtered_ids:
+ item = {"id": r.id, "score": r.score}
+ if return_fields and r.id in neo4j_data:
+ item.update(neo4j_data[r.id])
+ filtered_results.append(item)
return filtered_results
@@ -397,8 +470,8 @@ def search_by_fulltext(
**kwargs,
) -> list[dict]:
"""
- TODO: 实现 Neo4j Community 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径.
- 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错.
+ TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path.
+ Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j.
"""
return []
@@ -1122,7 +1195,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]]
# Merge embeddings into parsed nodes
for parsed_node in parsed_nodes:
node_id = parsed_node["id"]
- parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None)
+ parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id)
return parsed_nodes
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index f0a23e39b..592f45a7f 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -204,21 +204,6 @@ def _get_connection_old(self):
return conn
def _get_connection(self):
- """
- Get a connection from the pool.
-
- This function:
- 1. Gets a connection from ThreadedConnectionPool
- 2. Checks if connection is closed or unhealthy
- 3. Returns healthy connection or retries (max 3 times)
- 4. Handles connection pool exhaustion gracefully
-
- Returns:
- psycopg2 connection object
-
- Raises:
- RuntimeError: If connection pool is closed or exhausted after retries
- """
logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'")
if self._pool_closed:
raise RuntimeError("Connection pool has been closed")
@@ -229,13 +214,9 @@ def _get_connection(self):
for attempt in range(max_retries):
conn = None
try:
- # Try to get connection from pool
- # This may raise PoolError if pool is exhausted
conn = self.connection_pool.getconn()
- # Check if connection is closed
if conn.closed != 0:
- # Connection is closed, return it to pool with close flag and try again
logger.warning(
f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}"
)
@@ -295,19 +276,17 @@ def _get_connection(self):
return conn
except psycopg2.pool.PoolError as pool_error:
- # Pool exhausted or other pool-related error
- # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly
error_msg = str(pool_error).lower()
if "exhausted" in error_msg or "pool" in error_msg:
# Log pool status for debugging
try:
# Try to get pool stats if available
pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}"
- logger.error(
- f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}"
+                        logger.warning(
+                            f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}"
)
except Exception:
- logger.error(
+ logger.warning(
f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})"
)
@@ -323,7 +302,6 @@ def _get_connection(self):
raise RuntimeError(
f"Connection pool exhausted after {max_retries} attempts. "
f"This usually means connections are not being returned to the pool. "
- f"Check for connection leaks in your code."
) from pool_error
else:
# Other pool errors - retry with normal backoff
@@ -337,12 +315,8 @@ def _get_connection(self):
) from pool_error
except Exception as e:
- # Other exceptions (not pool-related)
- # Only try to return connection if we actually got one
- # If getconn() failed (e.g., pool exhausted), conn will be None
if conn is not None:
try:
- # Return connection to pool if it's valid
self.connection_pool.putconn(conn, close=True)
except Exception as putconn_error:
logger.warning(
@@ -363,20 +337,7 @@ def _get_connection(self):
raise RuntimeError("Failed to get connection after all retries")
def _return_connection(self, connection):
- """
- Return a connection to the pool.
-
- This function safely returns a connection to the pool, handling:
- - Closed connections (close them instead of returning)
- - Pool closed state (close connection directly)
- - None connections (no-op)
- - putconn() failures (close connection as fallback)
-
- Args:
- connection: psycopg2 connection object or None
- """
if self._pool_closed:
- # Pool is closed, just close the connection if it exists
if connection:
try:
connection.close()
@@ -388,13 +349,10 @@ def _return_connection(self, connection):
return
if not connection:
- # No connection to return - this is normal if _get_connection() failed
return
try:
- # Check if connection is closed
if hasattr(connection, "closed") and connection.closed != 0:
- # Connection is closed, just close it explicitly and don't return to pool
logger.debug(
"[_return_connection] Connection is closed, closing it instead of returning to pool"
)
@@ -404,12 +362,9 @@ def _return_connection(self, connection):
logger.warning(f"[_return_connection] Failed to close closed connection: {e}")
return
- # Connection is valid, return to pool
self.connection_pool.putconn(connection)
logger.debug("[_return_connection] Successfully returned connection to pool")
except Exception as e:
- # If putconn fails, try to close the connection
- # This prevents connection leaks if putconn() fails
logger.error(
f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True
)
@@ -841,8 +796,8 @@ def add_edge(
start_time = time.time()
if not source_id or not target_id:
- logger.warning(f"Edge '{source_id}' and '{target_id}' are both None")
- raise ValueError("[add_edge] source_id and target_id must be provided")
+ logger.error(f"Edge '{source_id}' and '{target_id}' are both None")
+ return
source_exists = self.get_node(source_id) is not None
target_exists = self.get_node(target_id) is not None
@@ -851,7 +806,7 @@ def add_edge(
logger.warning(
"[add_edge] Source %s or target %s does not exist.", source_exists, target_exists
)
- raise ValueError("[add_edge] source_id and target_id must be provided")
+ return
properties = {}
if user_name is not None:
@@ -1116,9 +1071,7 @@ def get_node(
self._return_connection(conn)
@timed
- def get_nodes(
- self, ids: list[str], user_name: str | None = None, **kwargs
- ) -> list[dict[str, Any]]:
+ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]:
"""
Retrieve the metadata and memory of a list of nodes.
Args:
@@ -1690,6 +1643,36 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
"""Get the ordered context chain starting from a node."""
raise NotImplementedError
+ def _extract_fields_from_properties(
+ self, properties: Any, return_fields: list[str]
+ ) -> dict[str, Any]:
+ """Extract requested fields from a PolarDB properties agtype/JSON value.
+
+ Args:
+ properties: The raw properties value from a PolarDB row (agtype or JSON string).
+ return_fields: List of field names to extract.
+
+ Returns:
+ dict with field_name -> value for each requested field found in properties.
+ """
+ result = {}
+ return_fields = self._validate_return_fields(return_fields)
+ if not properties or not return_fields:
+ return result
+ try:
+ if isinstance(properties, str):
+ props = json.loads(properties)
+ elif isinstance(properties, dict):
+ props = properties
+ else:
+ props = json.loads(str(properties))
+ except (json.JSONDecodeError, TypeError, ValueError):
+ return result
+ for field in return_fields:
+ if field != "id" and field in props:
+ result[field] = props[field]
+ return result
+
@timed
def search_by_keywords_like(
self,
@@ -1700,6 +1683,7 @@ def search_by_keywords_like(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
where_clauses = []
@@ -1751,10 +1735,14 @@ def search_by_keywords_like(
where_clauses.append("""(properties -> '"memory"')::text LIKE %s""")
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
- query = f"""
- SELECT
+ select_clause = """SELECT
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
- agtype_object_field_text(properties, 'memory') as memory_text
+ agtype_object_field_text(properties, 'memory') as memory_text"""
+ if return_fields:
+ select_clause += ", properties"
+
+ query = f"""
+ {select_clause}
FROM "{self.db_name}_graph"."Memory"
{where_clause}
"""
@@ -1775,7 +1763,11 @@ def search_by_keywords_like(
id_val = str(oldid)
if id_val.startswith('"') and id_val.endswith('"'):
id_val = id_val[1:-1]
- output.append({"id": id_val})
+ item = {"id": id_val}
+ if return_fields:
+ properties = row[2] # properties column
+ item.update(self._extract_fields_from_properties(properties, return_fields))
+ output.append(item)
logger.info(
f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
)
@@ -1795,6 +1787,7 @@ def search_by_keywords_tfidf(
knowledgebase_ids: list[str] | None = None,
tsvector_field: str = "properties_tsvector_zh",
tsquery_config: str = "jiebaqry",
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
where_clauses = []
@@ -1850,10 +1843,14 @@ def search_by_keywords_tfidf(
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
# Build fulltext search query
- query = f"""
- SELECT
+ select_clause = """SELECT
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
- agtype_object_field_text(properties, 'memory') as memory_text
+ agtype_object_field_text(properties, 'memory') as memory_text"""
+ if return_fields:
+ select_clause += ", properties"
+
+ query = f"""
+ {select_clause}
FROM "{self.db_name}_graph"."Memory"
{where_clause}
"""
@@ -1874,7 +1871,11 @@ def search_by_keywords_tfidf(
id_val = str(oldid)
if id_val.startswith('"') and id_val.endswith('"'):
id_val = id_val[1:-1]
- output.append({"id": id_val})
+ item = {"id": id_val}
+ if return_fields:
+ properties = row[2] # properties column
+ item.update(self._extract_fields_from_properties(properties, return_fields))
+ output.append(item)
logger.info(
f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}"
@@ -1897,6 +1898,7 @@ def search_by_fulltext(
knowledgebase_ids: list[str] | None = None,
tsvector_field: str = "properties_tsvector_zh",
tsquery_config: str = "jiebacfg",
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
"""
@@ -1914,15 +1916,16 @@ def search_by_fulltext(
filter: filter conditions with 'and' or 'or' logic for search results.
tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1
tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation)
+ return_fields: additional node fields to include in results
**kwargs: other parameters (e.g. cube_name)
Returns:
- list[dict]: result list containing id and score
+ list[dict]: result list containing id and score.
+ If return_fields is specified, each dict also includes the requested fields.
"""
logger.info(
f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}"
)
- # Build WHERE clause dynamically, same as search_by_embedding
start_time = time.time()
where_clauses = []
@@ -1966,13 +1969,10 @@ def search_by_fulltext(
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
)
- # Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}")
where_clauses.extend(filter_conditions)
- # Add fulltext search condition
- # Convert query_text to OR query format: "word1 | word2 | word3"
tsquery_string = " | ".join(query_words)
where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
@@ -1981,19 +1981,31 @@ def search_by_fulltext(
logger.info(f"[search_by_fulltext] where_clause: {where_clause}")
- # Build fulltext search query
+ select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id,
+ ts_rank(m.{tsvector_field}, q.fq) AS rank"""
+ if return_fields:
+ select_cols += ", m.properties"
+ where_with_q = []
+ for w in where_clauses:
+ if f"{tsvector_field} @@ to_tsquery(" in w:
+ where_with_q.append(f"m.{tsvector_field} @@ q.fq")
+ else:
+ where_with_q.append(
+ w.replace("(properties,", "(m.properties,")
+ .replace("(properties)", "(m.properties)")
+ .replace("ARRAY[properties,", "ARRAY[m.properties,")
+ )
+ where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else ""
query = f"""
- SELECT
- ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
- agtype_object_field_text(properties, 'memory') as memory_text,
- ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank
- FROM "{self.db_name}_graph"."Memory"
- {where_clause}
+ WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq)
+ SELECT {select_cols}
+ FROM "{self.db_name}_graph"."Memory" m
+ CROSS JOIN q
+ {where_clause_cte}
ORDER BY rank DESC
LIMIT {top_k};
"""
-
- params = [tsquery_string, tsquery_string]
+ params = [tsquery_string]
logger.info(f"[search_by_fulltext] query: {query}, params: {params}")
conn = None
try:
@@ -2004,7 +2016,7 @@ def search_by_fulltext(
output = []
for row in results:
oldid = row[0] # old_id
- rank = row[2] # rank score
+ rank = row[1] # rank score (no memory_text column)
id_val = str(oldid)
if id_val.startswith('"') and id_val.endswith('"'):
@@ -2013,10 +2025,16 @@ def search_by_fulltext(
# Apply threshold filter if specified
if threshold is None or score_val >= threshold:
- output.append({"id": id_val, "score": score_val})
+ item = {"id": id_val, "score": score_val}
+ if return_fields:
+ properties = row[2] # properties column
+ item.update(
+ self._extract_fields_from_properties(properties, return_fields)
+ )
+ output.append(item)
elapsed_time = time.time() - start_time
logger.info(
- f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s"
+                    f"[search_by_fulltext] query completed in {elapsed_time:.2f}s"
)
return output[:top_k]
finally:
@@ -2026,23 +2044,21 @@ def search_by_fulltext(
def search_by_embedding(
self,
vector: list[float],
+ user_name: str,
top_k: int = 5,
scope: str | None = None,
status: str | None = None,
threshold: float | None = None,
search_filter: dict | None = None,
- user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
+ return_fields: list[str] | None = None,
**kwargs,
) -> list[dict]:
- """
- Retrieve node IDs based on vector similarity using PostgreSQL vector operations.
- """
- # Build WHERE clause dynamically like nebular.py
logger.info(
- f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}"
+ f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}"
)
+ start_time = time.time()
where_clauses = []
if scope:
where_clauses.append(
@@ -2057,31 +2073,18 @@ def search_by_embedding(
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
)
where_clauses.append("embedding is not null")
- # Add user_name filter like nebular.py
-
- """
- # user_name = self._get_config_value("user_name")
- # if not self.config.use_multi_db and user_name:
- # if kwargs.get("cube_name"):
- # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype")
- # else:
- # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype")
- """
- # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self.config.user_name,
)
- # Add OR condition if we have any user_name conditions
if user_name_conditions:
if len(user_name_conditions) == 1:
where_clauses.append(user_name_conditions[0])
else:
where_clauses.append(f"({' OR '.join(user_name_conditions)})")
- # Add search_filter conditions like nebular.py
if search_filter:
for key, value in search_filter.items():
if isinstance(value, str):
@@ -2093,14 +2096,12 @@ def search_by_embedding(
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
)
- # Build filter conditions using common method
filter_conditions = self._build_filter_conditions_sql(filter)
logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}")
where_clauses.extend(filter_conditions)
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
- # Keep original simple query structure but add dynamic WHERE clause
query = f"""
WITH t AS (
SELECT id,
@@ -2117,19 +2118,12 @@ def search_by_embedding(
FROM t
WHERE scope > 0.1;
"""
- # Convert vector to string format for PostgreSQL vector type
- # PostgreSQL vector type expects a string format like '[1,2,3]'
vector_str = convert_to_vector(vector)
- # Use string format directly in query instead of parameterized query
- # Replace %s with the vector string, but need to quote it properly
- # PostgreSQL vector type needs the string to be quoted
query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)")
params = []
- # Split query by lines and wrap long lines to prevent terminal truncation
query_lines = query.strip().split("\n")
for line in query_lines:
- # Wrap lines longer than 200 characters to prevent terminal truncation
if len(line) > 200:
wrapped_lines = textwrap.wrap(
line, width=200, break_long_words=False, break_on_hyphens=False
@@ -2145,28 +2139,13 @@ def search_by_embedding(
try:
conn = self._get_connection()
with conn.cursor() as cursor:
- try:
- # If params is empty, execute query directly without parameters
- if params:
- cursor.execute(query, params)
- else:
- cursor.execute(query)
- except Exception as e:
- logger.error(f"[search_by_embedding] Error executing query: {e}")
- logger.error(f"[search_by_embedding] Query length: {len(query)}")
- logger.error(
- f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}"
- )
- logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}")
- raise
+ if params:
+ cursor.execute(query, params)
+ else:
+ cursor.execute(query)
results = cursor.fetchall()
output = []
for row in results:
- """
- polarId = row[0] # id
- properties = row[1] # properties
- # embedding = row[3] # embedding
- """
if len(row) < 5:
logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}")
continue
@@ -2178,7 +2157,17 @@ def search_by_embedding(
score_val = float(score)
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
if threshold is None or score_val >= threshold:
- output.append({"id": id_val, "score": score_val})
+ item = {"id": id_val, "score": score_val}
+ if return_fields:
+ properties = row[1] # properties column
+ item.update(
+ self._extract_fields_from_properties(properties, return_fields)
+ )
+ output.append(item)
+ elapsed_time = time.time() - start_time
+ logger.info(
+ f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s"
+ )
return output[:top_k]
finally:
self._return_connection(conn)
@@ -2187,7 +2176,7 @@ def search_by_embedding(
def get_by_metadata(
self,
filters: list[dict[str, Any]],
- user_name: str | None = None,
+ user_name: str,
filter: dict | None = None,
knowledgebase_ids: list | None = None,
user_name_flag: bool = True,
@@ -2209,7 +2198,9 @@ def get_by_metadata(
Returns:
list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
"""
- logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}")
+ logger.info(
+ f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}"
+ )
user_name = user_name if user_name else self._get_config_value("user_name")
@@ -2264,9 +2255,6 @@ def get_by_metadata(
else:
raise ValueError(f"Unsupported operator: {op}")
- # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
- # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
- # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
@@ -2306,7 +2294,7 @@ def get_by_metadata(
results = cursor.fetchall()
ids = [str(item[0]).strip('"') for item in results]
except Exception as e:
- logger.error(f"Failed to get metadata: {e}, query is {cypher_query}")
+ logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}")
finally:
self._return_connection(conn)
@@ -2536,8 +2524,8 @@ def clear(self, user_name: str | None = None) -> None:
@timed
def export_graph(
self,
+ user_name: str,
include_embedding: bool = False,
- user_name: str | None = None,
user_id: str | None = None,
page: int | None = None,
page_size: int | None = None,
@@ -2576,7 +2564,7 @@ def export_graph(
}
"""
logger.info(
- f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}"
+ f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}"
)
user_id = user_id if user_id else self._get_config_value("user_id")
@@ -2724,159 +2712,7 @@ def export_graph(
finally:
self._return_connection(conn)
- conn = None
- try:
- conn = self._get_connection()
- # Build Cypher WHERE conditions for edges
- cypher_where_conditions = []
- if user_name:
- cypher_where_conditions.append(f"a.user_name = '{user_name}'")
- cypher_where_conditions.append(f"b.user_name = '{user_name}'")
- if user_id:
- cypher_where_conditions.append(f"a.user_id = '{user_id}'")
- cypher_where_conditions.append(f"b.user_id = '{user_id}'")
-
- # Add memory_type filter condition for edges (apply to both source and target nodes)
- if memory_type and isinstance(memory_type, list) and len(memory_type) > 0:
- # Escape single quotes in memory_type values for Cypher
- escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type]
- memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types])
- # Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory']
- cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]")
- cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]")
-
- # Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list
- if status is None:
- # Default behavior: exclude deleted entries
- cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'")
- elif isinstance(status, list) and len(status) > 0:
- escaped_statuses = [st.replace("'", "\\'") for st in status]
- status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses])
- cypher_where_conditions.append(f"a.status IN [{status_list_str}]")
- cypher_where_conditions.append(f"b.status IN [{status_list_str}]")
-
- # Build filter conditions for edges (apply to both source and target nodes)
- filter_where_clause = self._build_filter_conditions_cypher(filter)
- logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}")
- if filter_where_clause:
- # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists
- # Remove the leading " AND " and replace n. with a. for source node and b. for target node
- filter_clause = filter_where_clause.strip()
- if filter_clause.startswith("AND "):
- filter_clause = filter_clause[4:].strip()
- # Replace n. with a. for source node and create a copy for target node
- source_filter = filter_clause.replace("n.", "a.")
- target_filter = filter_clause.replace("n.", "b.")
- # Combine source and target filters with AND
- combined_filter = f"({source_filter}) AND ({target_filter})"
- cypher_where_conditions.append(combined_filter)
-
- cypher_where_clause = ""
- if cypher_where_conditions:
- cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}"
-
- # Get total count of edges before pagination
- count_edge_query = f"""
- SELECT COUNT(*)
- FROM (
- SELECT * FROM cypher('{self.db_name}_graph', $$
- MATCH (a:Memory)-[r]->(b:Memory)
- {cypher_where_clause}
- RETURN a.id AS source, b.id AS target, type(r) as edge
- $$) AS (source agtype, target agtype, edge agtype)
- ) AS edges
- """
- logger.info(f"[export_graph edges count] Query: {count_edge_query}")
- with conn.cursor() as cursor:
- cursor.execute(count_edge_query)
- total_edges = cursor.fetchone()[0]
-
- # Export edges using cypher query
- # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery
- # Build pagination clause if needed
- edge_pagination_clause = ""
- if use_pagination:
- edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}"
-
- edge_query = f"""
- SELECT source, target, edge FROM (
- SELECT * FROM cypher('{self.db_name}_graph', $$
- MATCH (a:Memory)-[r]->(b:Memory)
- {cypher_where_clause}
- RETURN a.id AS source, b.id AS target, type(r) as edge
- ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC,
- COALESCE(b.created_at, '1970-01-01T00:00:00') DESC,
- a.id DESC, b.id DESC
- $$) AS (source agtype, target agtype, edge agtype)
- ) AS edges
- {edge_pagination_clause}
- """
- logger.info(f"[export_graph edges] Query: {edge_query}")
- with conn.cursor() as cursor:
- cursor.execute(edge_query)
- edge_results = cursor.fetchall()
- edges = []
-
- for row in edge_results:
- source_agtype, target_agtype, edge_agtype = row
-
- # Extract and clean source
- source_raw = (
- source_agtype.value
- if hasattr(source_agtype, "value")
- else str(source_agtype)
- )
- if (
- isinstance(source_raw, str)
- and source_raw.startswith('"')
- and source_raw.endswith('"')
- ):
- source = source_raw[1:-1]
- else:
- source = str(source_raw)
-
- # Extract and clean target
- target_raw = (
- target_agtype.value
- if hasattr(target_agtype, "value")
- else str(target_agtype)
- )
- if (
- isinstance(target_raw, str)
- and target_raw.startswith('"')
- and target_raw.endswith('"')
- ):
- target = target_raw[1:-1]
- else:
- target = str(target_raw)
-
- # Extract and clean edge type
- type_raw = (
- edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype)
- )
- if (
- isinstance(type_raw, str)
- and type_raw.startswith('"')
- and type_raw.endswith('"')
- ):
- edge_type = type_raw[1:-1]
- else:
- edge_type = str(type_raw)
-
- edges.append(
- {
- "source": source,
- "target": target,
- "type": edge_type,
- }
- )
-
- except Exception as e:
- logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True)
- raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
- finally:
- self._return_connection(conn)
-
+ edges = []
return {
"nodes": nodes,
"edges": edges,
@@ -2908,8 +2744,8 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int:
def get_all_memory_items(
self,
scope: str,
+ user_name: str,
include_embedding: bool = False,
- user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list | None = None,
status: str | None = None,
@@ -2930,14 +2766,13 @@ def get_all_memory_items(
list[dict]: Full list of memory items under this scope.
"""
logger.info(
- f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}"
+ f"[get_all_memory_items] user_name: {user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status},scope:{scope}"
)
user_name = user_name if user_name else self._get_config_value("user_name")
if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
raise ValueError(f"Unsupported memory type scope: {scope}")
- # Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
@@ -3015,7 +2850,7 @@ def get_all_memory_items(
node_ids.add(node_id)
except Exception as e:
- logger.error(f"Failed to get memories: {e}", exc_info=True)
+ logger.warning(f"Failed to get memories: {e}", exc_info=True)
finally:
self._return_connection(conn)
@@ -4199,34 +4034,47 @@ def get_edges(
...
]
"""
+ start_time = time.time()
+ logger.info(f" get_edges id:{id},type:{type},direction:{direction},user_name:{user_name}")
user_name = user_name if user_name else self._get_config_value("user_name")
-
- if direction == "OUTGOING":
- pattern = "(a:Memory)-[r]->(b:Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "INCOMING":
- pattern = "(a:Memory)<-[r]-(b:Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "ANY":
- pattern = "(a:Memory)-[r]-(b:Memory)"
- where_clause = f"a.id = '{id}' OR b.id = '{id}'"
- else:
+ if direction not in ("OUTGOING", "INCOMING", "ANY"):
raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
- # Add type filter
- if type != "ANY":
- where_clause += f" AND type(r) = '{type}'"
-
- # Add user filter
- where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
+ # Escape single quotes for safe embedding in Cypher string
+ id_esc = (id or "").replace("'", "''")
+ user_esc = (user_name or "").replace("'", "''")
+ type_esc = (type or "").replace("'", "''")
+ type_filter = f" AND type(r) = '{type_esc}'" if type != "ANY" else ""
+ logger.info(f"type_filter:{type_filter}")
+ if direction == "OUTGOING":
+ cypher_body = f"""
+ MATCH (a:Memory)-[r:{type}]->(b:Memory)
+ WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
+ RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+ """
+ elif direction == "INCOMING":
+ cypher_body = f"""
+ MATCH (b:Memory)<-[r:{type}]-(a:Memory)
+ WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'
+ RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+ """
+ else: # ANY: union of OUTGOING and INCOMING
+ cypher_body = f"""
+ MATCH (a:Memory)-[r]->(b:Memory)
+ WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
+ RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+ UNION ALL
+ MATCH (b:Memory)<-[r]-(a:Memory)
+ WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter}
+ RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+ """
query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
- MATCH {pattern}
- WHERE {where_clause}
- RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
+ {cypher_body.strip()}
$$) AS (from_id agtype, to_id agtype, edge_type agtype)
"""
+ logger.info(f"get_edges query:{query}")
conn = None
try:
conn = self._get_connection()
@@ -4270,6 +4118,8 @@ def get_edges(
edge_type = str(edge_type_raw)
edges.append({"from": from_id, "to": to_id, "type": edge_type})
+ elapsed_time = time.time() - start_time
+ logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s")
return edges
except Exception as e:
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index a27d64758..e3d2bece9 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -721,6 +721,7 @@ def _process_one_item(
m_maybe_merged.get("memory_type", "LongTermMemory")
.replace("长期记忆", "LongTermMemory")
.replace("用户记忆", "UserMemory")
+ .replace("pref", "UserMemory")
)
node = self._make_memory_item(
value=m_maybe_merged.get("value", ""),
diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
index 2b49d63ba..1b4add398 100644
--- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py
@@ -50,7 +50,9 @@
class FileContentParser(BaseMessageParser):
"""Parser for file content parts."""
- def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = None) -> dict:
+ def _get_doc_llm_response(
+ self, chunk_text: str, custom_tags: list[str] | None = None
+ ) -> dict | list:
"""
Call LLM to extract memory from document chunk.
Uses doc prompts from DOC_PROMPT_DICT.
@@ -60,7 +62,7 @@ def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None =
custom_tags: Optional list of custom tags for LLM extraction
Returns:
- Parsed JSON response from LLM or empty dict if failed
+ Parsed JSON response from LLM (dict or list) or empty dict if failed
"""
if not self.llm:
logger.warning("[FileContentParser] LLM not available for fine mode")
@@ -777,35 +779,49 @@ def _make_fallback(
return [_make_fallback(idx, text, "no_llm") for idx, text in valid_chunks]
# Process single chunk with LLM extraction (worker function)
- def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem:
- """Process chunk with LLM, fallback to raw on failure."""
+ def _process_chunk(chunk_idx: int, chunk_text: str) -> list[TextualMemoryItem]:
+ """Process chunk with LLM, fallback to raw on failure. Returns list of memory items."""
try:
response_json = self._get_doc_llm_response(chunk_text, custom_tags)
if response_json:
- value = response_json.get("value", "").strip()
- if value:
- tags = response_json.get("tags", [])
- tags = tags if isinstance(tags, list) else []
- tags.extend(["mode:fine", "multimodal:file"])
-
- llm_mem_type = response_json.get("memory_type", memory_type)
- if llm_mem_type not in ["LongTermMemory", "UserMemory"]:
- llm_mem_type = memory_type
-
- return _make_memory_item(
- value=value,
- mem_type=llm_mem_type,
- tags=tags,
- key=response_json.get("key"),
- chunk_idx=chunk_idx,
- chunk_content=chunk_text,
- )
+ # Handle list format response
+ response_list = response_json.get("memory list", [])
+ memory_items = []
+ for item_data in response_list:
+ if not isinstance(item_data, dict):
+ continue
+
+ value = item_data.get("value", "").strip()
+ if value:
+ tags = item_data.get("tags", [])
+ tags = tags if isinstance(tags, list) else []
+ tags.extend(["mode:fine", "multimodal:file"])
+ key_str = item_data.get("key", "")
+
+ llm_mem_type = item_data.get("memory_type", memory_type)
+ if llm_mem_type not in ["LongTermMemory", "UserMemory"]:
+ llm_mem_type = memory_type
+
+ memory_item = _make_memory_item(
+ value=value,
+ mem_type=llm_mem_type,
+ tags=tags,
+ key=key_str,
+ chunk_idx=chunk_idx,
+ chunk_content=chunk_text,
+ )
+ memory_items.append(memory_item)
+
+ if memory_items:
+ return memory_items
+ else:
+ return [_make_fallback(chunk_idx, chunk_text)]
except Exception as e:
logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}")
# Fallback to raw chunk
logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}")
- return _make_fallback(chunk_idx, chunk_text)
+ return [_make_fallback(chunk_idx, chunk_text)]
def _relate_chunks(items: list[TextualMemoryItem]) -> None:
"""
@@ -853,30 +869,37 @@ def get_chunk_idx(item: TextualMemoryItem) -> int:
):
chunk_idx = futures[future]
try:
- node = future.result()
- memory_items.append(node)
-
- # Check if this node is a fallback by checking tags
- is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags)
- if is_fallback:
- fallback_count += 1
-
- # save raw file
- node_id = node.id
- if node.memory != node.metadata.sources[0].content:
- chunk_node = _make_memory_item(
- value=node.metadata.sources[0].content,
- mem_type="RawFileMemory",
- tags=[
- "mode:fine",
- "multimodal:file",
- f"chunk:{chunk_idx + 1}/{total_chunks}",
- ],
- chunk_idx=chunk_idx,
- chunk_content="",
- )
- chunk_node.metadata.summary_ids = [node_id]
- memory_items.append(chunk_node)
+ nodes = future.result()
+ memory_items.extend(nodes)
+
+ # Check if any node is a fallback by checking tags
+ has_fallback = False
+ for node in nodes:
+ is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags)
+ if is_fallback:
+ fallback_count += 1
+ has_fallback = True
+
+ # save raw file only if no fallback (all nodes are LLM-extracted)
+ if not has_fallback and nodes:
+ # Use first node's source info for raw file
+ first_node = nodes[0]
+ if first_node.metadata.sources and len(first_node.metadata.sources) > 0:
+ # Collect all node IDs for summary_ids
+ node_ids = [node.id for node in nodes]
+ chunk_node = _make_memory_item(
+ value=first_node.metadata.sources[0].content,
+ mem_type="RawFileMemory",
+ tags=[
+ "mode:fine",
+ "multimodal:file",
+ f"chunk:{chunk_idx + 1}/{total_chunks}",
+ ],
+ chunk_idx=chunk_idx,
+ chunk_content="",
+ )
+ chunk_node.metadata.summary_ids = node_ids
+ memory_items.append(chunk_node)
except Exception as e:
tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}")
diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
index d39955ac2..a9a727b08 100644
--- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
+++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py
@@ -1019,7 +1019,9 @@ def process_skill_memory_fine(
**kwargs,
) -> list[TextualMemoryItem]:
skills_repo_backend = _get_skill_file_storage_location()
- oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config)
+ oss_client, _missing_keys, flag = _skill_init(
+ skills_repo_backend, oss_config, skills_dir_config
+ )
if not flag:
return []
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
index 63718fd92..e4a88a635 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py
@@ -68,33 +68,35 @@ def log_add_messages(self, msg: ScheduleMessageItem):
mem_item = mem_cube.text_mem.get(memory_id=memory_id, user_name=msg.mem_cube_id)
if mem_item is None:
raise ValueError(f"Memory {memory_id} not found after retries")
- key = getattr(mem_item.metadata, "key", None) or transform_name_to_key(
- name=mem_item.memory
- )
- exists = False
original_content = None
original_item_id = None
- if key and hasattr(mem_cube.text_mem, "graph_store"):
- candidates = mem_cube.text_mem.graph_store.get_by_metadata(
- [
- {"field": "key", "op": "=", "value": key},
- {
- "field": "memory_type",
- "op": "=",
- "value": mem_item.metadata.memory_type,
- },
- ]
+ # Determine add vs update from the merged_from field set by the upstream
+ # mem_reader during fine extraction. When the LLM merges a new memory with
+ # existing ones it writes their IDs into metadata.info["merged_from"].
+ # This avoids an extra graph DB query and the self-match / cross-user
+ # matching bugs that came with the old get_by_metadata approach.
+ merged_from = (getattr(mem_item.metadata, "info", None) or {}).get("merged_from")
+ if merged_from:
+ merged_ids = (
+ merged_from
+ if isinstance(merged_from, list | tuple | set)
+ else [merged_from]
)
- if candidates:
- exists = True
- original_item_id = candidates[0]
+ original_item_id = merged_ids[0]
+ try:
original_mem_item = mem_cube.text_mem.get(
memory_id=original_item_id, user_name=msg.mem_cube_id
)
- original_content = original_mem_item.memory
+ original_content = original_mem_item.memory if original_mem_item else None
+ except Exception as e:
+ logger.warning(
+ "Failed to fetch original memory %s for update log: %s",
+ original_item_id,
+ e,
+ )
- if exists:
+ if merged_from:
prepared_update_items_with_original.append(
{
"new_item": mem_item,
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
index 5d86c5589..20dbb63b2 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
@@ -259,13 +259,19 @@ def _process_memories_with_reader(
source_doc_id = (
file_ids[0] if isinstance(file_ids, list) and file_ids else None
)
+ # Use merged_from to determine ADD vs UPDATE.
+ # The upstream mem_reader sets this during fine extraction when
+ # the new memory was merged with an existing one.
+ item_merged_from = (getattr(item.metadata, "info", None) or {}).get(
+ "merged_from"
+ )
kb_log_content.append(
{
"log_source": "KNOWLEDGE_BASE_LOG",
"trigger_source": info.get("trigger_source", "Messages")
if info
else "Messages",
- "operation": "ADD",
+ "operation": "UPDATE" if item_merged_from else "ADD",
"memory_id": item.id,
"content": item.memory,
"original_content": None,
@@ -302,29 +308,39 @@ def _process_memories_with_reader(
else:
add_content_legacy: list[dict] = []
add_meta_legacy: list[dict] = []
+ update_content_legacy: list[dict] = []
+ update_meta_legacy: list[dict] = []
for item_id, item in zip(
enhanced_mem_ids, flattened_memories, strict=False
):
key = getattr(item.metadata, "key", None) or transform_name_to_key(
name=item.memory
)
- add_content_legacy.append(
- {"content": f"{key}: {item.memory}", "ref_id": item_id}
- )
- add_meta_legacy.append(
- {
- "ref_id": item_id,
- "id": item_id,
- "key": item.metadata.key,
- "memory": item.memory,
- "memory_type": item.metadata.memory_type,
- "status": item.metadata.status,
- "confidence": item.metadata.confidence,
- "tags": item.metadata.tags,
- "updated_at": getattr(item.metadata, "updated_at", None)
- or getattr(item.metadata, "update_at", None),
- }
+ item_merged_from = (getattr(item.metadata, "info", None) or {}).get(
+ "merged_from"
)
+ meta_entry = {
+ "ref_id": item_id,
+ "id": item_id,
+ "key": item.metadata.key,
+ "memory": item.memory,
+ "memory_type": item.metadata.memory_type,
+ "status": item.metadata.status,
+ "confidence": item.metadata.confidence,
+ "tags": item.metadata.tags,
+ "updated_at": getattr(item.metadata, "updated_at", None)
+ or getattr(item.metadata, "update_at", None),
+ }
+ if item_merged_from:
+ update_content_legacy.append(
+ {"content": f"{key}: {item.memory}", "ref_id": item_id}
+ )
+ update_meta_legacy.append(meta_entry)
+ else:
+ add_content_legacy.append(
+ {"content": f"{key}: {item.memory}", "ref_id": item_id}
+ )
+ add_meta_legacy.append(meta_entry)
if add_content_legacy:
event = self.scheduler_context.services.create_event_log(
label="addMemory",
@@ -342,6 +358,23 @@ def _process_memories_with_reader(
)
event.task_id = task_id
self.scheduler_context.services.submit_web_logs([event])
+ if update_content_legacy:
+ event = self.scheduler_context.services.create_event_log(
+ label="updateMemory",
+ from_memory_type=USER_INPUT_TYPE,
+ to_memory_type=LONG_TERM_MEMORY_TYPE,
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ mem_cube=self.scheduler_context.get_mem_cube(),
+ memcube_log_content=update_content_legacy,
+ metadata=update_meta_legacy,
+ memory_len=len(update_content_legacy),
+ memcube_name=self.scheduler_context.services.map_memcube_name(
+ mem_cube_id
+ ),
+ )
+ event.task_id = task_id
+ self.scheduler_context.services.submit_web_logs([event])
else:
logger.info("No enhanced memories generated by mem_reader")
else:
diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py
index 6352d5840..8483a5151 100644
--- a/src/memos/memories/textual/prefer_text_memory/retrievers.py
+++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py
@@ -124,25 +124,45 @@ def retrieve(
explicit_prefs.sort(key=lambda x: x.score, reverse=True)
implicit_prefs.sort(key=lambda x: x.score, reverse=True)
- explicit_prefs_mem = [
- TextualMemoryItem(
- id=pref.id,
- memory=pref.memory,
- metadata=PreferenceTextualMemoryMetadata(**pref.payload),
+ explicit_prefs_mem = []
+ for pref in explicit_prefs:
+ if not pref.payload.get("preference", None):
+ continue
+ if "embedding" in pref.payload:
+ payload = pref.payload
+ else:
+ pref_vector = getattr(pref, "vector", None)
+ if pref_vector is None:
+ payload = pref.payload
+ else:
+ payload = {**pref.payload, "embedding": pref_vector}
+ explicit_prefs_mem.append(
+ TextualMemoryItem(
+ id=pref.id,
+ memory=pref.memory,
+ metadata=PreferenceTextualMemoryMetadata(**payload),
+ )
)
- for pref in explicit_prefs
- if pref.payload.get("preference", None)
- ]
- implicit_prefs_mem = [
- TextualMemoryItem(
- id=pref.id,
- memory=pref.memory,
- metadata=PreferenceTextualMemoryMetadata(**pref.payload),
+ implicit_prefs_mem = []
+ for pref in implicit_prefs:
+ if not pref.payload.get("preference", None):
+ continue
+ if "embedding" in pref.payload:
+ payload = pref.payload
+ else:
+ pref_vector = getattr(pref, "vector", None)
+ if pref_vector is None:
+ payload = pref.payload
+ else:
+ payload = {**pref.payload, "embedding": pref_vector}
+ implicit_prefs_mem.append(
+ TextualMemoryItem(
+ id=pref.id,
+ memory=pref.memory,
+ metadata=PreferenceTextualMemoryMetadata(**payload),
+ )
)
- for pref in implicit_prefs
- if pref.payload.get("preference", None)
- ]
reranker_map = {
"naive": self._naive_reranker,
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 5faf8aa09..5b210ba61 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -404,10 +404,10 @@ def delete_by_memory_ids(self, memory_ids: list[str]) -> None:
except Exception as e:
logger.error(f"An error occurred while deleting memories by memory_ids: {e}")
- def delete_all(self) -> None:
+ def delete_all(self, user_name: str | None = None) -> None:
"""Delete all memories and their relationships from the graph store."""
try:
- self.graph_store.clear()
+ self.graph_store.clear(user_name=user_name)
logger.info("All memories and edges have been deleted from the graph.")
except Exception as e:
logger.error(f"An error occurred while deleting all memories: {e}")
@@ -424,7 +424,7 @@ def delete_by_filter(
writable_cube_ids=writable_cube_ids, file_ids=file_ids, filter=filter
)
- def load(self, dir: str) -> None:
+ def load(self, dir: str, user_name: str | None = None) -> None:
try:
memory_file = os.path.join(dir, self.config.memory_filename)
@@ -435,7 +435,7 @@ def load(self, dir: str) -> None:
with open(memory_file, encoding="utf-8") as f:
memories = json.load(f)
- self.graph_store.import_graph(memories)
+ self.graph_store.import_graph(memories, user_name=user_name)
logger.info(f"Loaded {len(memories)} memories from {memory_file}")
except FileNotFoundError:
@@ -445,10 +445,12 @@ def load(self, dir: str) -> None:
except Exception as e:
logger.error(f"An error occurred while loading memories: {e}")
- def dump(self, dir: str, include_embedding: bool = False) -> None:
+ def dump(self, dir: str, include_embedding: bool = False, user_name: str | None = None) -> None:
"""Dump memories to os.path.join(dir, self.config.memory_filename)"""
try:
- json_memories = self.graph_store.export_graph(include_embedding=include_embedding)
+ json_memories = self.graph_store.export_graph(
+ include_embedding=include_embedding, user_name=user_name
+ )
os.makedirs(dir, exist_ok=True)
memory_file = os.path.join(dir, self.config.memory_filename)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py
index 595cf099c..2d776912b 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/handler.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py
@@ -27,18 +27,24 @@ def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedd
self.llm = llm
self.embedder = embedder
- def detect(self, memory, top_k: int = 5, scope=None):
+ def detect(self, memory, top_k: int = 5, scope=None, user_name: str | None = None):
# 1. Search for similar memories based on embedding
embedding = memory.metadata.embedding
embedding_candidates_info = self.graph_store.search_by_embedding(
- embedding, top_k=top_k, scope=scope, threshold=self.EMBEDDING_THRESHOLD
+ embedding,
+ top_k=top_k,
+ scope=scope,
+ threshold=self.EMBEDDING_THRESHOLD,
+ user_name=user_name,
)
# 2. Filter based on similarity threshold
embedding_candidates_ids = [
info["id"] for info in embedding_candidates_info if info["id"] != memory.id
]
# 3. Judge conflicts using LLM
- embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids)
+ embedding_candidates = self.graph_store.get_nodes(
+ embedding_candidates_ids, user_name=user_name
+ )
detected_relationships = []
for embedding_candidate in embedding_candidates:
embedding_candidate = TextualMemoryItem.from_dict(embedding_candidate)
@@ -67,13 +73,20 @@ def detect(self, memory, top_k: int = 5, scope=None):
pass
return detected_relationships
- def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, relation) -> None:
+ def resolve(
+ self,
+ memory_a: TextualMemoryItem,
+ memory_b: TextualMemoryItem,
+ relation,
+ user_name: str | None = None,
+ ) -> None:
"""
Resolve detected conflicts between two memory items using LLM fusion.
Args:
memory_a: The first conflicting memory item.
memory_b: The second conflicting memory item.
relation: relation
+ user_name: Optional user name for multi-tenant isolation.
Returns:
A fused TextualMemoryItem representing the resolved memory.
"""
@@ -105,17 +118,22 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, rela
logger.warning(
f"{relation} between {memory_a.id} and {memory_b.id} could not be resolved. "
)
- self._hard_update(memory_a, memory_b)
+ self._hard_update(memory_a, memory_b, user_name=user_name)
# —————— 2.2 Conflict resolved, update metadata and memory ————
else:
fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata)
merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata)
logger.info(f"Resolved result: {merged_memory}")
- self._resolve_in_graph(memory_a, memory_b, merged_memory)
+ self._resolve_in_graph(memory_a, memory_b, merged_memory, user_name=user_name)
except json.decoder.JSONDecodeError:
logger.error(f"Failed to parse LLM response: {response}")
- def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem):
+ def _hard_update(
+ self,
+ memory_a: TextualMemoryItem,
+ memory_b: TextualMemoryItem,
+ user_name: str | None = None,
+ ):
"""
Hard update: compare updated_at, keep the newer one, overwrite the older one's metadata.
"""
@@ -125,7 +143,7 @@ def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem)
newer_mem = memory_a if time_a >= time_b else memory_b
older_mem = memory_b if time_a >= time_b else memory_a
- self.graph_store.delete_node(older_mem.id)
+ self.graph_store.delete_node(older_mem.id, user_name=user_name)
logger.warning(
f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>"
)
@@ -135,13 +153,21 @@ def _resolve_in_graph(
conflict_a: TextualMemoryItem,
conflict_b: TextualMemoryItem,
merged: TextualMemoryItem,
+ user_name: str | None = None,
):
- edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY")
- edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY")
+ edges_a = self.graph_store.get_edges(
+ conflict_a.id, type="ANY", direction="ANY", user_name=user_name
+ )
+ edges_b = self.graph_store.get_edges(
+ conflict_b.id, type="ANY", direction="ANY", user_name=user_name
+ )
all_edges = edges_a + edges_b
self.graph_store.add_node(
- merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True)
+ merged.id,
+ merged.memory,
+ merged.metadata.model_dump(exclude_none=True),
+ user_name=user_name,
)
for edge in all_edges:
@@ -150,13 +176,15 @@ def _resolve_in_graph(
if new_from == new_to:
continue
# Check if the edge already exists before adding
- if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"):
- self.graph_store.add_edge(new_from, new_to, edge["type"])
-
- self.graph_store.update_node(conflict_a.id, {"status": "archived"})
- self.graph_store.update_node(conflict_b.id, {"status": "archived"})
- self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO")
- self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO")
+ if not self.graph_store.edge_exists(
+ new_from, new_to, edge["type"], direction="ANY", user_name=user_name
+ ):
+ self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name)
+
+ self.graph_store.update_node(conflict_a.id, {"status": "archived"}, user_name=user_name)
+ self.graph_store.update_node(conflict_b.id, {"status": "archived"}, user_name=user_name)
+ self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO", user_name=user_name)
+ self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO", user_name=user_name)
logger.debug(
f"Archive {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}."
)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
index 1afdc9281..132582a0d 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py
@@ -141,6 +141,7 @@ def mark_memory_status(
self,
memory_items: list[TextualMemoryItem],
status: Literal["activated", "resolving", "archived", "deleted"],
+ user_name: str | None = None,
) -> None:
"""
Support status marking operations during history management. Common usages are:
@@ -157,6 +158,7 @@ def mark_memory_status(
self.graph_db.update_node,
id=mem.id,
fields={"status": status},
+ user_name=user_name,
)
)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index cbc349d67..4ca30c7b8 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -238,7 +238,9 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None:
_submit_batches(graph_nodes, "graph memory")
if graph_node_ids and self.is_reorganize:
- self.reorganizer.add_message(QueueMessage(op="add", after_node=graph_node_ids))
+ self.reorganizer.add_message(
+ QueueMessage(op="add", after_node=graph_node_ids, user_name=user_name)
+ )
return added_ids
@@ -411,16 +413,19 @@ def _add_to_graph_memory(
QueueMessage(
op="add",
after_node=[node_id],
+ user_name=user_name,
)
)
return node_id
- def _inherit_edges(self, from_id: str, to_id: str) -> None:
+ def _inherit_edges(self, from_id: str, to_id: str, user_name: str | None = None) -> None:
"""
Migrate all non-lineage edges from `from_id` to `to_id`,
and remove them from `from_id` after copying.
"""
- edges = self.graph_store.get_edges(from_id, type="ANY", direction="ANY")
+ edges = self.graph_store.get_edges(
+ from_id, type="ANY", direction="ANY", user_name=user_name
+ )
for edge in edges:
if edge["type"] == "MERGED_TO":
@@ -433,20 +438,29 @@ def _inherit_edges(self, from_id: str, to_id: str) -> None:
continue
# Add edge to merged node if it doesn't already exist
- if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"):
- self.graph_store.add_edge(new_from, new_to, edge["type"])
+ if not self.graph_store.edge_exists(
+ new_from, new_to, edge["type"], direction="ANY", user_name=user_name
+ ):
+ self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name)
# Remove original edge if it involved the archived node
- self.graph_store.delete_edge(edge["from"], edge["to"], edge["type"])
+ self.graph_store.delete_edge(
+ edge["from"], edge["to"], edge["type"], user_name=user_name
+ )
def _ensure_structure_path(
- self, memory_type: str, metadata: TreeNodeTextualMemoryMetadata
+ self,
+ memory_type: str,
+ metadata: TreeNodeTextualMemoryMetadata,
+ user_name: str | None = None,
) -> str:
"""
Ensure structural path exists (ROOT → ... → final node), return last node ID.
Args:
- path: like ["hobby", "photography"]
+ memory_type: Memory type for the structure node.
+ metadata: Metadata containing key and other fields.
+ user_name: Optional user name for multi-tenant isolation.
Returns:
Final node ID of the structure path.
@@ -456,7 +470,8 @@ def _ensure_structure_path(
[
{"field": "memory", "op": "=", "value": metadata.key},
{"field": "memory_type", "op": "=", "value": memory_type},
- ]
+ ],
+ user_name=user_name,
)
if existing:
node_id = existing[0] # Use the first match
@@ -479,14 +494,16 @@ def _ensure_structure_path(
),
)
self.graph_store.add_node(
- id=new_node.id,
- memory=new_node.memory,
- metadata=new_node.metadata.model_dump(exclude_none=True),
+ new_node.id,
+ new_node.memory,
+ new_node.metadata.model_dump(exclude_none=True),
+ user_name=user_name,
)
self.reorganizer.add_message(
QueueMessage(
op="add",
after_node=[new_node.id],
+ user_name=user_name,
)
)
diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
index ea06a7c60..b7fb6b1a0 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py
@@ -52,12 +52,14 @@ def __init__(
before_edge: list[str] | list[GraphDBEdge] | None = None,
after_node: list[str] | list[GraphDBNode] | None = None,
after_edge: list[str] | list[GraphDBEdge] | None = None,
+ user_name: str | None = None,
):
self.op = op
self.before_node = before_node
self.before_edge = before_edge
self.after_node = after_node
self.after_edge = after_edge
+ self.user_name = user_name
def __str__(self) -> str:
return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})"
@@ -191,11 +193,15 @@ def handle_add(self, message: QueueMessage):
logger.debug(f"Handling add operation: {str(message)[:500]}")
added_node = message.after_node[0]
detected_relationships = self.resolver.detect(
- added_node, scope=added_node.metadata.memory_type
+ added_node,
+ scope=added_node.metadata.memory_type,
+ user_name=message.user_name,
)
if detected_relationships:
for added_node, existing_node, relation in detected_relationships:
- self.resolver.resolve(added_node, existing_node, relation)
+ self.resolver.resolve(
+ added_node, existing_node, relation, user_name=message.user_name
+ )
self._reorganize_needed = True
@@ -209,6 +215,7 @@ def optimize_structure(
min_cluster_size: int = 4,
min_group_size: int = 20,
max_duration_sec: int = 600,
+ user_name: str | None = None,
):
"""
Periodically reorganize the graph:
@@ -232,7 +239,7 @@ def _check_deadline(where: str):
logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.")
return
- if self.graph_store.node_not_exist(scope):
+ if self.graph_store.node_not_exist(scope, user_name=user_name):
logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.")
return
@@ -244,12 +251,14 @@ def _check_deadline(where: str):
logger.debug(
f"[GraphStructureReorganize] Num of scope in self.graph_store is"
- f" {self.graph_store.get_memory_count(scope)}"
+ f" {self.graph_store.get_memory_count(scope, user_name=user_name)}"
)
# Load candidate nodes
if _check_deadline("[GraphStructureReorganize] Before loading candidates"):
return
- raw_nodes = self.graph_store.get_structure_optimization_candidates(scope)
+ raw_nodes = self.graph_store.get_structure_optimization_candidates(
+ scope, user_name=user_name
+ )
nodes = [GraphDBNode(**n) for n in raw_nodes]
if not nodes:
@@ -281,6 +290,7 @@ def _check_deadline(where: str):
scope,
local_tree_threshold,
min_cluster_size,
+ user_name,
)
)
@@ -307,6 +317,7 @@ def _process_cluster_and_write(
scope: str,
local_tree_threshold: int,
min_cluster_size: int,
+ user_name: str | None = None,
):
if len(cluster_nodes) <= min_cluster_size:
return
@@ -319,15 +330,17 @@ def _process_cluster_and_write(
if len(sub_nodes) < min_cluster_size:
continue # Skip tiny noise
sub_parent_node = self._summarize_cluster(sub_nodes, scope)
- self._create_parent_node(sub_parent_node)
- self._link_cluster_nodes(sub_parent_node, sub_nodes)
+ self._create_parent_node(sub_parent_node, user_name=user_name)
+ self._link_cluster_nodes(sub_parent_node, sub_nodes, user_name=user_name)
sub_parents.append(sub_parent_node)
if sub_parents and len(sub_parents) >= min_cluster_size:
cluster_parent_node = self._summarize_cluster(cluster_nodes, scope)
- self._create_parent_node(cluster_parent_node)
+ self._create_parent_node(cluster_parent_node, user_name=user_name)
for sub_parent in sub_parents:
- self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT")
+ self.graph_store.add_edge(
+ cluster_parent_node.id, sub_parent.id, "PARENT", user_name=user_name
+ )
logger.info("Adding relations/reasons")
nodes_to_check = cluster_nodes
@@ -351,10 +364,16 @@ def _process_cluster_and_write(
# 1) Add pairwise relations
for rel in results["relations"]:
if not self.graph_store.edge_exists(
- rel["source_id"], rel["target_id"], rel["relation_type"]
+ rel["source_id"],
+ rel["target_id"],
+ rel["relation_type"],
+ user_name=user_name,
):
self.graph_store.add_edge(
- rel["source_id"], rel["target_id"], rel["relation_type"]
+ rel["source_id"],
+ rel["target_id"],
+ rel["relation_type"],
+ user_name=user_name,
)
# 2) Add inferred nodes and link to sources
@@ -363,14 +382,21 @@ def _process_cluster_and_write(
inf_node.id,
inf_node.memory,
inf_node.metadata.model_dump(exclude_none=True),
+ user_name=user_name,
)
for src_id in inf_node.metadata.sources:
- self.graph_store.add_edge(src_id, inf_node.id, "INFERS")
+ self.graph_store.add_edge(
+ src_id, inf_node.id, "INFERS", user_name=user_name
+ )
# 3) Add sequence links
for seq in results["sequence_links"]:
- if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"):
- self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS")
+ if not self.graph_store.edge_exists(
+ seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name
+ ):
+ self.graph_store.add_edge(
+ seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name
+ )
# 4) Add aggregate concept nodes
for agg_node in results["aggregate_nodes"]:
@@ -378,9 +404,12 @@ def _process_cluster_and_write(
agg_node.id,
agg_node.memory,
agg_node.metadata.model_dump(exclude_none=True),
+ user_name=user_name,
)
for child_id in agg_node.metadata.sources:
- self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO")
+ self.graph_store.add_edge(
+ agg_node.id, child_id, "AGGREGATE_TO", user_name=user_name
+ )
logger.info("[Reorganizer] Cluster relation/reasoning done.")
@@ -577,7 +606,7 @@ def _parse_json_result(self, response_text):
)
return {}
- def _create_parent_node(self, parent_node: GraphDBNode) -> None:
+ def _create_parent_node(self, parent_node: GraphDBNode, user_name: str | None = None) -> None:
"""
Create a new parent node for the cluster.
"""
@@ -585,17 +614,23 @@ def _create_parent_node(self, parent_node: GraphDBNode) -> None:
parent_node.id,
parent_node.memory,
parent_node.metadata.model_dump(exclude_none=True),
+ user_name=user_name,
)
- def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]):
+ def _link_cluster_nodes(
+ self,
+ parent_node: GraphDBNode,
+ child_nodes: list[GraphDBNode],
+ user_name: str | None = None,
+ ):
"""
Add PARENT edges from the parent node to all nodes in the cluster.
"""
for child in child_nodes:
if not self.graph_store.edge_exists(
- parent_node.id, child.id, "PARENT", direction="OUTGOING"
+ parent_node.id, child.id, "PARENT", direction="OUTGOING", user_name=user_name
):
- self.graph_store.add_edge(parent_node.id, child.id, "PARENT")
+ self.graph_store.add_edge(parent_node.id, child.id, "PARENT", user_name=user_name)
def _preprocess_message(self, message: QueueMessage) -> bool:
message = self._convert_id_to_node(message)
@@ -613,7 +648,9 @@ def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage:
for i, node in enumerate(message.after_node or []):
if not isinstance(node, str):
continue
- raw_node = self.graph_store.get_node(node, include_embedding=True)
+ raw_node = self.graph_store.get_node(
+ node, include_embedding=True, user_name=message.user_name
+ )
if raw_node is None:
logger.debug(f"Node with ID {node} not found in the graph store.")
message.after_node[i] = None
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 9dcbe8c56..cc269e8c4 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -524,7 +524,7 @@ def _retrieve_from_keyword(
user_name=user_name,
tsquery_config="jiebaqry",
)
- except Exception as e:
+ except Exception:
logger.warning(
f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}"
)
diff --git a/src/memos/memos_tools/thread_safe_dict_segment.py b/src/memos/memos_tools/thread_safe_dict_segment.py
index c1c10e3e1..bf918889f 100644
--- a/src/memos/memos_tools/thread_safe_dict_segment.py
+++ b/src/memos/memos_tools/thread_safe_dict_segment.py
@@ -71,7 +71,7 @@ def acquire_write(self) -> bool:
self._waiting_writers -= 1
self._last_write_time = time.time()
return True
- except:
+ except Exception:
self._waiting_writers -= 1
raise
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index 1678d9d15..d890c77bf 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -443,7 +443,10 @@ def _search_pref(
},
search_filter=search_req.filter,
)
- formatted_results = self._postformat_memories(results, user_context.mem_cube_id)
+ include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true"
+ formatted_results = self._postformat_memories(
+ results, user_context.mem_cube_id, include_embedding=include_embedding
+ )
# For each returned item, tackle with metadata.info project_id /
# operation / manager_user_id
diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py
index e4f1ca334..f431bd041 100644
--- a/src/memos/templates/mem_reader_prompts.py
+++ b/src/memos/templates/mem_reader_prompts.py
@@ -244,12 +244,17 @@
Return a single valid JSON object with the following structure:
-Return valid JSON:
{
- "key": ,
- "memory_type": "LongTermMemory",
- "value": ,
- "tags":
+ "memory list": [
+ {
+ "key": ,
+ "memory_type": "LongTermMemory",
+ "value": ,
+ "tags":
+ }
+ ...
+ ],
+ "summary":
}
Language rules:
@@ -264,7 +269,7 @@
Your Output:"""
SIMPLE_STRUCT_DOC_READER_PROMPT_ZH = """您是搜索与检索系统的文本分析专家。
-您的任务是处理文档片段,并生成一个结构化的 JSON 对象。
+您的任务是处理文档片段,并生成一个结构化的 JSON 列表对象。
请执行以下操作:
1. 识别反映文档中事实内容、见解、决策或含义的关键信息——包括任何显著的主题、结论或数据点,使读者无需阅读原文即可充分理解该片段的核心内容。
@@ -281,14 +286,19 @@
- 优先考虑完整性和保真度,而非简洁性。
- 不要泛化或跳过可能具有上下文意义的细节。
-返回一个有效的 JSON 对象,结构如下:
+返回有效的 JSON 对象:
-返回有效的 JSON:
{
- "key": <字符串,`value` 字段的简洁标题>,
- "memory_type": "LongTermMemory",
- "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>,
- "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])>
+ "memory list": [
+ {
+ "key": <字符串,`value` 字段的简洁标题>,
+ "memory_type": "LongTermMemory",
+ "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>,
+ "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])>
+ }
+ ...
+ ],
+ "summary": <简洁总结原文内容,与输入语言一致>
}
语言规则:
diff --git a/tests/graph_dbs/test_search_return_fields.py b/tests/graph_dbs/test_search_return_fields.py
new file mode 100644
index 000000000..82a50308b
--- /dev/null
+++ b/tests/graph_dbs/test_search_return_fields.py
@@ -0,0 +1,306 @@
+"""
+Regression tests for issue #955: search methods support specifying return fields.
+
+Tests that search_by_embedding (and other search methods) accept a `return_fields`
+parameter and include the requested fields in the result dicts, eliminating the
+need for N+1 get_node() calls.
+"""
+
+import uuid
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from memos.configs.graph_db import Neo4jGraphDBConfig
+
+
+@pytest.fixture
+def neo4j_config():
+ return Neo4jGraphDBConfig(
+ uri="bolt://localhost:7687",
+ user="neo4j",
+ password="test",
+ db_name="test_memory_db",
+ auto_create=False,
+ embedding_dimension=3,
+ )
+
+
+@pytest.fixture
+def neo4j_db(neo4j_config):
+ with patch("neo4j.GraphDatabase") as mock_gd:
+ mock_driver = MagicMock()
+ mock_gd.driver.return_value = mock_driver
+ from memos.graph_dbs.neo4j import Neo4jGraphDB
+
+ db = Neo4jGraphDB(neo4j_config)
+ db.driver = mock_driver
+ yield db
+
+
+class TestNeo4jSearchReturnFields:
+ """Tests for Neo4jGraphDB.search_by_embedding with return_fields."""
+
+ def test_return_fields_included_in_results(self, neo4j_db):
+ """return_fields values are present in each result dict."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ node_id = str(uuid.uuid4())
+ session_mock.run.return_value = [
+ {"id": node_id, "score": 0.95, "memory": "hello", "status": "activated"},
+ ]
+
+ results = neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ return_fields=["memory", "status"],
+ )
+
+ assert len(results) == 1
+ assert results[0]["id"] == node_id
+ assert results[0]["score"] == 0.95
+ assert results[0]["memory"] == "hello"
+ assert results[0]["status"] == "activated"
+
+ def test_backward_compatible_without_return_fields(self, neo4j_db):
+ """Without return_fields, only id and score are returned (old behavior)."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = [
+ {"id": str(uuid.uuid4()), "score": 0.9},
+ ]
+
+ results = neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ )
+
+ assert len(results) == 1
+ assert set(results[0].keys()) == {"id", "score"}
+
+ def test_cypher_return_clause_includes_fields(self, neo4j_db):
+ """Cypher RETURN clause contains the requested fields."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ return_fields=["memory", "tags"],
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "node.memory AS memory" in query
+ assert "node.tags AS tags" in query
+
+ def test_cypher_return_clause_default(self, neo4j_db):
+ """Without return_fields, RETURN clause only has id and score."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "RETURN node.id AS id, score" in query
+ assert "node.memory" not in query
+
+ def test_return_fields_skips_id_field(self, neo4j_db):
+ """Passing 'id' in return_fields does not duplicate it in RETURN clause."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ return_fields=["id", "memory"],
+ )
+
+ query = session_mock.run.call_args[0][0]
+ # 'id' should appear only once (as node.id AS id), not duplicated
+ assert query.count("node.id AS id") == 1
+ assert "node.memory AS memory" in query
+
+ def test_threshold_filtering_still_works_with_return_fields(self, neo4j_db):
+ """Threshold filtering works correctly when return_fields is specified."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = [
+ {"id": str(uuid.uuid4()), "score": 0.9, "memory": "high score"},
+ {"id": str(uuid.uuid4()), "score": 0.3, "memory": "low score"},
+ ]
+
+ results = neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ threshold=0.5,
+ return_fields=["memory"],
+ )
+
+ assert len(results) == 1
+ assert results[0]["memory"] == "high score"
+
+
+class TestPolarDBExtractFieldsFromProperties:
+ """Tests for PolarDBGraphDB._extract_fields_from_properties helper."""
+
+ @pytest.fixture
+ def polardb_instance(self):
+ """Create a minimal PolarDB instance for testing the helper method."""
+ with patch("memos.graph_dbs.polardb.PolarDBGraphDB.__init__", return_value=None):
+ from memos.graph_dbs.polardb import PolarDBGraphDB
+
+ db = PolarDBGraphDB.__new__(PolarDBGraphDB)
+ yield db
+
+ def test_extract_from_json_string(self, polardb_instance):
+ """Extract fields from a JSON string properties value."""
+ props = '{"id": "abc", "memory": "hello", "status": "activated", "tags": ["a"]}'
+ result = polardb_instance._extract_fields_from_properties(
+ props, ["memory", "status", "tags"]
+ )
+ assert result == {"memory": "hello", "status": "activated", "tags": ["a"]}
+
+ def test_extract_from_dict(self, polardb_instance):
+ """Extract fields from a dict properties value."""
+ props = {"id": "abc", "memory": "hello", "status": "activated"}
+ result = polardb_instance._extract_fields_from_properties(props, ["memory", "status"])
+ assert result == {"memory": "hello", "status": "activated"}
+
+ def test_extract_skips_id(self, polardb_instance):
+ """'id' field is skipped even if requested."""
+ props = '{"id": "abc", "memory": "hello"}'
+ result = polardb_instance._extract_fields_from_properties(props, ["id", "memory"])
+ assert result == {"memory": "hello"}
+
+ def test_extract_missing_fields(self, polardb_instance):
+ """Missing fields are silently skipped."""
+ props = '{"id": "abc", "memory": "hello"}'
+ result = polardb_instance._extract_fields_from_properties(props, ["memory", "nonexistent"])
+ assert result == {"memory": "hello"}
+
+ def test_extract_empty_properties(self, polardb_instance):
+ """Empty/None properties return empty dict."""
+ assert polardb_instance._extract_fields_from_properties(None, ["memory"]) == {}
+ assert polardb_instance._extract_fields_from_properties("", ["memory"]) == {}
+
+ def test_extract_invalid_json(self, polardb_instance):
+ """Invalid JSON returns empty dict without raising."""
+ result = polardb_instance._extract_fields_from_properties("not-json", ["memory"])
+ assert result == {}
+
+
+class TestFieldNameValidation:
+ """Tests for _validate_return_fields injection prevention."""
+
+ def test_valid_field_names_pass(self):
+ from memos.graph_dbs.base import BaseGraphDB
+
+ result = BaseGraphDB._validate_return_fields(["memory", "status", "tags", "user_name"])
+ assert result == ["memory", "status", "tags", "user_name"]
+
+ def test_invalid_field_names_rejected(self):
+ from memos.graph_dbs.base import BaseGraphDB
+
+ # Cypher injection attempts
+ result = BaseGraphDB._validate_return_fields(
+ [
+ "memory} RETURN n //",
+ "status; DROP",
+ "valid_field",
+ "a.b",
+ "field name",
+ "",
+ ]
+ )
+ assert result == ["valid_field"]
+
+ def test_none_returns_empty(self):
+ from memos.graph_dbs.base import BaseGraphDB
+
+ assert BaseGraphDB._validate_return_fields(None) == []
+
+ def test_empty_list_returns_empty(self):
+ from memos.graph_dbs.base import BaseGraphDB
+
+ assert BaseGraphDB._validate_return_fields([]) == []
+
+ def test_injection_in_cypher_query_prevented(self, neo4j_db):
+ """Malicious field names should not appear in the Cypher query."""
+ session_mock = neo4j_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ neo4j_db.search_by_embedding(
+ vector=[0.1, 0.2, 0.3],
+ top_k=5,
+ user_name="test_user",
+ return_fields=["memory} RETURN n //", "valid_field"],
+ )
+
+ query = session_mock.run.call_args[0][0]
+ # Injection attempt should NOT appear in query
+ assert "memory}" not in query
+ assert "RETURN n //" not in query
+ # Valid field should appear
+ assert "node.valid_field AS valid_field" in query
+
+
+class TestNeo4jCommunitySearchReturnFields:
+ """Tests for Neo4jCommunityGraphDB._fetch_return_fields with return_fields."""
+
+ @pytest.fixture
+ def neo4j_community_db(self):
+ """Create a minimal Neo4jCommunityGraphDB instance by patching __init__."""
+ with patch(
+ "memos.graph_dbs.neo4j_community.Neo4jCommunityGraphDB.__init__", return_value=None
+ ):
+ from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
+
+ db = Neo4jCommunityGraphDB.__new__(Neo4jCommunityGraphDB)
+ db.driver = MagicMock()
+ db.db_name = "test_memory_db"
+ yield db
+
+ def test_fetch_return_fields_queries_neo4j(self, neo4j_community_db):
+ """_fetch_return_fields builds correct Cypher and returns fields."""
+ session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = [
+ {"id": "node-1", "memory": "hello", "status": "activated"},
+ ]
+
+ results = neo4j_community_db._fetch_return_fields(
+ ids=["node-1"],
+ score_map={"node-1": 0.95},
+ return_fields=["memory", "status"],
+ )
+
+ assert len(results) == 1
+ assert results[0]["id"] == "node-1"
+ assert results[0]["score"] == 0.95
+ assert results[0]["memory"] == "hello"
+ assert results[0]["status"] == "activated"
+
+ query = session_mock.run.call_args[0][0]
+ assert "n.memory AS memory" in query
+ assert "n.status AS status" in query
+
+ def test_fetch_return_fields_validates_names(self, neo4j_community_db):
+ """_fetch_return_fields rejects invalid field names."""
+ session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value
+ session_mock.run.return_value = []
+
+ neo4j_community_db._fetch_return_fields(
+ ids=["node-1"],
+ score_map={"node-1": 0.95},
+ return_fields=["memory} RETURN n //", "valid_field"],
+ )
+
+ query = session_mock.run.call_args[0][0]
+ assert "memory}" not in query
+ assert "n.valid_field AS valid_field" in query
diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py
index 46cf3a1f6..a6ac186b7 100644
--- a/tests/memories/textual/test_history_manager.py
+++ b/tests/memories/textual/test_history_manager.py
@@ -131,7 +131,7 @@ def test_mark_memory_status(history_manager, mock_graph_db):
# Assert
assert mock_graph_db.update_node.call_count == 3
- # Verify we called it correctly
- mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status})
- mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status})
- mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status})
+ # Verify we called it correctly (user_name=None is passed by mark_memory_status)
+ mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None)
+ mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None)
+ mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None)