diff --git a/README.md b/README.md index 70562fec7..45a372ed1 100644 --- a/README.md +++ b/README.md @@ -345,10 +345,10 @@ url = {https://global-sci.com/article/91443/memory3-language-modeling-with-expli ## 🙌 Contributing -We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/contribution/overview) to get started. +We welcome contributions from the community! Please read our [contribution guidelines](https://memos-docs.openmem.net/open_source/contribution/overview/) to get started.
## 📄 License -MemOS is licensed under the [Apache 2.0 License](./LICENSE). \ No newline at end of file +MemOS is licensed under the [Apache 2.0 License](./LICENSE). diff --git a/pyproject.toml b/pyproject.toml index b4b01e0e1..4a9ea8852 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.6" +version = "2.0.7" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index b568ae0c2..fefa3b2ab 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.6" +__version__ = "2.0.7" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 267d1bb28..8e7785ad5 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -64,7 +64,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse # Expand top_k for deduplication (5x to ensure enough candidates) if search_req_local.dedup in ("sim", "mmr"): - search_req_local.top_k = search_req_local.top_k * 5 + search_req_local.top_k = search_req_local.top_k * 3 # Search and deduplicate cube_view = self._build_cube_view(search_req_local) @@ -152,9 +152,6 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di return results embeddings = self._extract_embeddings([mem for _, mem, _ in flat]) - if embeddings is None: - documents = [mem.get("memory", "") for _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(embeddings) @@ -235,12 +232,39 @@ def _mmr_dedup_text_memories( if len(flat) <= 1: return results + total_by_type: dict[str, int] = {"text": 0, "preference": 0} + existing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_by_type: dict[str, int] = {"text": 0, "preference": 0} + missing_indices: list[int] = [] + for idx, (mem_type, _, mem, _) in enumerate(flat): + if mem_type not in total_by_type: + total_by_type[mem_type] = 0 + existing_by_type[mem_type] = 0 + missing_by_type[mem_type] = 0 + total_by_type[mem_type] += 1 + + embedding = mem.get("metadata", {}).get("embedding") + if embedding: + existing_by_type[mem_type] += 1 + else: + missing_by_type[mem_type] += 1 + missing_indices.append(idx) + + self.logger.info( + "[SearchHandler] MMR embedding metadata scan: total=%s total_by_type=%s existing_by_type=%s missing_by_type=%s", + len(flat), + total_by_type, + existing_by_type, + missing_by_type, + ) + if missing_indices: + self.logger.warning( + "[SearchHandler] MMR embedding metadata missing; will compute missing embeddings: missing_total=%s", + len(missing_indices), + ) + # Get or compute embeddings embeddings = self._extract_embeddings([mem for _, _, mem, _ in flat]) - if embeddings is None: - self.logger.warning("[SearchHandler] Embedding is missing; recomputing embeddings") - documents = [mem.get("memory", "") for _, _, mem, _ in flat] - embeddings = self.searcher.embedder.embed(documents) # Compute similarity matrix using NumPy-optimized method # Returns numpy array but compatible with list[i][j] indexing @@ -404,14 +428,32 @@ def _max_similarity( return 0.0 return max(similarity_matrix[index][j] for j in selected_indices) - @staticmethod - def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None: + def _extract_embeddings(self, memories: list[dict[str, Any]]) -> list[list[float]]: embeddings: list[list[float]] = [] - for mem in memories: - embedding = mem.get("metadata", {}).get("embedding") - if not embedding: - return None - embeddings.append(embedding) + missing_indices: list[int] = [] + missing_documents: list[str] = [] + + for idx, mem in enumerate(memories): + metadata = mem.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + mem["metadata"] = metadata + + embedding = metadata.get("embedding") + if embedding: + embeddings.append(embedding) + continue + + embeddings.append([]) + missing_indices.append(idx) + missing_documents.append(mem.get("memory", "")) + + if missing_indices: + computed = self.searcher.embedder.embed(missing_documents) + for idx, embedding in zip(missing_indices, computed, strict=False): + embeddings[idx] = embedding + memories[idx]["metadata"]["embedding"] = embedding + return embeddings @staticmethod diff --git a/src/memos/api/middleware/__init__.py b/src/memos/api/middleware/__init__.py index 64cbc5c60..fd39252f5 100644 --- a/src/memos/api/middleware/__init__.py +++ b/src/memos/api/middleware/__init__.py @@ -1,13 +1,14 @@ """Krolik middleware extensions for MemOS.""" -from .auth import verify_api_key, require_scope, require_admin, require_read, require_write +from .auth import require_admin, require_read, require_scope, require_write, verify_api_key from .rate_limit import RateLimitMiddleware + __all__ = [ - "verify_api_key", - "require_scope", + "RateLimitMiddleware", "require_admin", "require_read", + "require_scope", "require_write", - "RateLimitMiddleware", + "verify_api_key", ] diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 6fc03e735..5bf27e985 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -99,12 +99,12 @@ class ChatRequest(BaseRequest): manager_user_id: str | None = Field(None, description="Manager User ID") project_id: str | None = Field(None, description="Project ID") relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) @@ -339,12 +339,12 @@ class APISearchRequest(BaseRequest): ) relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) @@ -785,12 +785,12 @@ class APIChatCompleteRequest(BaseRequest): manager_user_id: str | None = Field(None, description="Manager User ID") project_id: str | None = Field(None, description="Project ID") relativity: float = Field( - 0.0, + 0.45, ge=0, description=( "Relevance threshold for recalled memories. " "Only memories with metadata.relativity >= relativity will be returned. " - "Use 0 to disable threshold filtering. Default: 0.3." + "Use 0 to disable threshold filtering. Default: 0.45." ), ) diff --git a/src/memos/api/utils/api_keys.py b/src/memos/api/utils/api_keys.py index 559ddd355..29b493fd0 100644 --- a/src/memos/api/utils/api_keys.py +++ b/src/memos/api/utils/api_keys.py @@ -5,8 +5,8 @@ """ import hashlib -import os import secrets + from dataclasses import dataclass from datetime import datetime, timedelta diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 538d913ea..2b3bd0967 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -73,7 +73,6 @@ async def _create_embeddings(): ) ) logger.info(f"Embeddings request succeeded with {time.time() - init_time} seconds") - logger.info(f"Embeddings request response: {response}") return [r.embedding for r in response.data] except Exception as e: if self.use_backup_client: diff --git a/src/memos/graph_dbs/base.py b/src/memos/graph_dbs/base.py index 130b66a3d..0bc4a54f8 100644 --- a/src/memos/graph_dbs/base.py +++ b/src/memos/graph_dbs/base.py @@ -1,12 +1,35 @@ +import re + from abc import ABC, abstractmethod from typing import Any, Literal +# Pattern for valid field names: alphanumeric and underscores, must start with letter or underscore +_VALID_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + class BaseGraphDB(ABC): """ Abstract base class for a graph database interface used in a memory-augmented RAG system. """ + @staticmethod + def _validate_return_fields(return_fields: list[str] | None) -> list[str]: + """Validate and sanitize return_fields to prevent query injection. + + Only allows alphanumeric characters and underscores in field names. + Silently drops invalid field names. + + Args: + return_fields: List of field names to validate. + + Returns: + List of valid field names. + """ + if not return_fields: + return [] + return [f for f in return_fields if _VALID_FIELD_NAME_RE.match(f)] + # Node (Memory) Management @abstractmethod def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @@ -144,16 +167,23 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: # Search / recall operations @abstractmethod - def search_by_embedding(self, vector: list[float], top_k: int = 5, **kwargs) -> list[dict]: + def search_by_embedding( + self, vector: list[float], top_k: int = 5, return_fields: list[str] | None = None, **kwargs + ) -> list[dict]: """ Retrieve node IDs based on vector similarity. Args: vector (list[float]): The embedding vector representing query semantics. top_k (int): Number of top similar nodes to retrieve. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method may internally call a VecDB (e.g., Qdrant) or store embeddings in the graph DB itself. diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 746051187..33eb39692 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -818,6 +818,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -832,9 +833,14 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result + dict will contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses Neo4j native vector indexing to search for similar nodes. @@ -886,11 +892,20 @@ def search_by_embedding( if where_clauses: where_clause = "WHERE " + " AND ".join(where_clauses) + return_clause = "RETURN node.id AS id, score" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"node.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN node.id AS id, score, {extra_fields}" + query = f""" CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) YIELD node, score {where_clause} - RETURN node.id AS id, score + {return_clause} """ parameters = {"embedding": vector, "k": top_k} @@ -920,7 +935,15 @@ def search_by_embedding( print(f"[search_by_embedding] query: {query},parameters: {parameters}") with self.driver.session(database=self.db_name) as session: result = session.run(query, parameters) - records = [{"id": record["id"], "score": record["score"]} for record in result] + records = [] + for record in result: + item = {"id": record["id"], "score": record["score"]} + if return_fields: + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + records.append(item) # Threshold filtering after retrieval if threshold is not None: @@ -943,8 +966,8 @@ def search_by_fulltext( **kwargs, ) -> list[dict]: """ - TODO: 实现 Neo4j 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径. - 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错. + TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path. + Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j. """ return [] diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index cae7d6ca5..09ad46c42 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -246,6 +246,39 @@ def get_children_with_embeddings( return child_nodes + def _fetch_return_fields( + self, + ids: list[str], + score_map: dict[str, float], + return_fields: list[str], + ) -> list[dict]: + """Fetch additional fields from Neo4j for given node IDs.""" + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + return_clause = "RETURN n.id AS id" + if extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + + query = f""" + MATCH (n:Memory) + WHERE n.id IN $ids + {return_clause} + """ + with self.driver.session(database=self.db_name) as session: + neo4j_results = session.run(query, {"ids": ids}) + results = [] + for record in neo4j_results: + node_id = record["id"] + item = {"id": node_id, "score": score_map.get(node_id)} + record_keys = record.keys() + for field in return_fields: + if field != "id" and field in record_keys: + item[field] = record[field] + results.append(item) + return results + # Search / recall operations def search_by_embedding( self, @@ -258,6 +291,7 @@ def search_by_embedding( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -273,9 +307,14 @@ def search_by_embedding( filter (dict, optional): Filter conditions with 'and' or 'or' logic for search results. Example: {"and": [{"id": "xxx"}, {"A": "yyy"}]} or {"or": [{"id": "xxx"}, {"A": "yyy"}]} knowledgebase_ids (list[str], optional): List of knowledgebase IDs to filter by. + return_fields (list[str], optional): Additional node fields to include in results + (e.g., ["memory", "status", "tags"]). When provided, each result dict will + contain these fields in addition to 'id' and 'score'. + Defaults to None (only 'id' and 'score' are returned). Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. + If return_fields is specified, each dict also includes the requested fields. Notes: - This method uses an external vector database (not Neo4j) to perform the search. @@ -320,7 +359,14 @@ def search_by_embedding( # If no filter or knowledgebase_ids provided, return vector search results directly if not filter and not knowledgebase_ids: - return [{"id": r.id, "score": r.score} for r in vec_results] + if not return_fields: + return [{"id": r.id, "score": r.score} for r in vec_results] + # Need to fetch additional fields from Neo4j + vec_ids = [r.id for r in vec_results] + if not vec_ids: + return [] + score_map = {r.id: r.score for r in vec_results} + return self._fetch_return_fields(vec_ids, score_map, return_fields) # Extract IDs from vector search results vec_ids = [r.id for r in vec_results] @@ -363,22 +409,49 @@ def search_by_embedding( if filter_params: params.update(filter_params) + # Build RETURN clause with optional extra fields + return_clause = "RETURN n.id AS id" + if return_fields: + validated_fields = self._validate_return_fields(return_fields) + extra_fields = ", ".join( + f"n.{field} AS {field}" for field in validated_fields if field != "id" + ) + if extra_fields: + return_clause = f"RETURN n.id AS id, {extra_fields}" + # Query Neo4j to filter results query = f""" MATCH (n:Memory) {where_clause} - RETURN n.id AS id + {return_clause} """ logger.info(f"[search_by_embedding] query: {query}, params: {params}") with self.driver.session(database=self.db_name) as session: neo4j_results = session.run(query, params) - filtered_ids = {record["id"] for record in neo4j_results} + if return_fields: + # Build a map of id -> extra fields from Neo4j results + neo4j_data = {} + for record in neo4j_results: + node_id = record["id"] + record_keys = record.keys() + neo4j_data[node_id] = { + field: record[field] + for field in return_fields + if field != "id" and field in record_keys + } + filtered_ids = set(neo4j_data.keys()) + else: + filtered_ids = {record["id"] for record in neo4j_results} # Filter vector results by Neo4j filtered IDs and return with scores - filtered_results = [ - {"id": r.id, "score": r.score} for r in vec_results if r.id in filtered_ids - ] + filtered_results = [] + for r in vec_results: + if r.id in filtered_ids: + item = {"id": r.id, "score": r.score} + if return_fields and r.id in neo4j_data: + item.update(neo4j_data[r.id]) + filtered_results.append(item) return filtered_results @@ -397,8 +470,8 @@ def search_by_fulltext( **kwargs, ) -> list[dict]: """ - TODO: 实现 Neo4j Community 的关键词检索, 以兼容 TreeTextMemory 的 keyword/fulltext 召回路径. - 目前先返回空列表, 避免切换到 Neo4j 后因缺失方法导致运行时报错. + TODO: Implement fulltext search for Neo4j to be compatible with TreeTextMemory's keyword/fulltext recall path. + Currently, return an empty list to avoid runtime errors due to missing methods when switching to Neo4j. """ return [] @@ -1122,7 +1195,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]] # Merge embeddings into parsed nodes for parsed_node in parsed_nodes: node_id = parsed_node["id"] - parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id, None) + parsed_node["metadata"]["embedding"] = vec_items_map.get(node_id) return parsed_nodes diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index f0a23e39b..592f45a7f 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -204,21 +204,6 @@ def _get_connection_old(self): return conn def _get_connection(self): - """ - Get a connection from the pool. - - This function: - 1. Gets a connection from ThreadedConnectionPool - 2. Checks if connection is closed or unhealthy - 3. Returns healthy connection or retries (max 3 times) - 4. Handles connection pool exhaustion gracefully - - Returns: - psycopg2 connection object - - Raises: - RuntimeError: If connection pool is closed or exhausted after retries - """ logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") if self._pool_closed: raise RuntimeError("Connection pool has been closed") @@ -229,13 +214,9 @@ def _get_connection(self): for attempt in range(max_retries): conn = None try: - # Try to get connection from pool - # This may raise PoolError if pool is exhausted conn = self.connection_pool.getconn() - # Check if connection is closed if conn.closed != 0: - # Connection is closed, return it to pool with close flag and try again logger.warning( f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" ) @@ -295,19 +276,17 @@ def _get_connection(self): return conn except psycopg2.pool.PoolError as pool_error: - # Pool exhausted or other pool-related error - # Don't retry immediately for pool exhaustion - it's unlikely to resolve quickly error_msg = str(pool_error).lower() if "exhausted" in error_msg or "pool" in error_msg: # Log pool status for debugging try: # Try to get pool stats if available pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.error( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" + logger.info( + f" polardb get_connection Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" ) except Exception: - logger.error( + logger.warning( f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" ) @@ -323,7 +302,6 @@ def _get_connection(self): raise RuntimeError( f"Connection pool exhausted after {max_retries} attempts. " f"This usually means connections are not being returned to the pool. " - f"Check for connection leaks in your code." ) from pool_error else: # Other pool errors - retry with normal backoff @@ -337,12 +315,8 @@ def _get_connection(self): ) from pool_error except Exception as e: - # Other exceptions (not pool-related) - # Only try to return connection if we actually got one - # If getconn() failed (e.g., pool exhausted), conn will be None if conn is not None: try: - # Return connection to pool if it's valid self.connection_pool.putconn(conn, close=True) except Exception as putconn_error: logger.warning( @@ -363,20 +337,7 @@ def _get_connection(self): raise RuntimeError("Failed to get connection after all retries") def _return_connection(self, connection): - """ - Return a connection to the pool. - - This function safely returns a connection to the pool, handling: - - Closed connections (close them instead of returning) - - Pool closed state (close connection directly) - - None connections (no-op) - - putconn() failures (close connection as fallback) - - Args: - connection: psycopg2 connection object or None - """ if self._pool_closed: - # Pool is closed, just close the connection if it exists if connection: try: connection.close() @@ -388,13 +349,10 @@ def _return_connection(self, connection): return if not connection: - # No connection to return - this is normal if _get_connection() failed return try: - # Check if connection is closed if hasattr(connection, "closed") and connection.closed != 0: - # Connection is closed, just close it explicitly and don't return to pool logger.debug( "[_return_connection] Connection is closed, closing it instead of returning to pool" ) @@ -404,12 +362,9 @@ def _return_connection(self, connection): logger.warning(f"[_return_connection] Failed to close closed connection: {e}") return - # Connection is valid, return to pool self.connection_pool.putconn(connection) logger.debug("[_return_connection] Successfully returned connection to pool") except Exception as e: - # If putconn fails, try to close the connection - # This prevents connection leaks if putconn() fails logger.error( f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True ) @@ -841,8 +796,8 @@ def add_edge( start_time = time.time() if not source_id or not target_id: - logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") - raise ValueError("[add_edge] source_id and target_id must be provided") + logger.error(f"Edge '{source_id}' and '{target_id}' are both None") + return source_exists = self.get_node(source_id) is not None target_exists = self.get_node(target_id) is not None @@ -851,7 +806,7 @@ def add_edge( logger.warning( "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists ) - raise ValueError("[add_edge] source_id and target_id must be provided") + return properties = {} if user_name is not None: @@ -1116,9 +1071,7 @@ def get_node( self._return_connection(conn) @timed - def get_nodes( - self, ids: list[str], user_name: str | None = None, **kwargs - ) -> list[dict[str, Any]]: + def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: @@ -1690,6 +1643,36 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + def _extract_fields_from_properties( + self, properties: Any, return_fields: list[str] + ) -> dict[str, Any]: + """Extract requested fields from a PolarDB properties agtype/JSON value. + + Args: + properties: The raw properties value from a PolarDB row (agtype or JSON string). + return_fields: List of field names to extract. + + Returns: + dict with field_name -> value for each requested field found in properties. + """ + result = {} + return_fields = self._validate_return_fields(return_fields) + if not properties or not return_fields: + return result + try: + if isinstance(properties, str): + props = json.loads(properties) + elif isinstance(properties, dict): + props = properties + else: + props = json.loads(str(properties)) + except (json.JSONDecodeError, TypeError, ValueError): + return result + for field in return_fields: + if field != "id" and field in props: + result[field] = props[field] + return result + @timed def search_by_keywords_like( self, @@ -1700,6 +1683,7 @@ def search_by_keywords_like( user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1751,10 +1735,14 @@ def search_by_keywords_like( where_clauses.append("""(properties -> '"memory"')::text LIKE %s""") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1775,7 +1763,11 @@ def search_by_keywords_like( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" ) @@ -1795,6 +1787,7 @@ def search_by_keywords_tfidf( knowledgebase_ids: list[str] | None = None, tsvector_field: str = "properties_tsvector_zh", tsquery_config: str = "jiebaqry", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: where_clauses = [] @@ -1850,10 +1843,14 @@ def search_by_keywords_tfidf( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" # Build fulltext search query - query = f""" - SELECT + select_clause = """SELECT ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text + agtype_object_field_text(properties, 'memory') as memory_text""" + if return_fields: + select_clause += ", properties" + + query = f""" + {select_clause} FROM "{self.db_name}_graph"."Memory" {where_clause} """ @@ -1874,7 +1871,11 @@ def search_by_keywords_tfidf( id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): id_val = id_val[1:-1] - output.append({"id": id_val}) + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) logger.info( f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" @@ -1897,6 +1898,7 @@ def search_by_fulltext( knowledgebase_ids: list[str] | None = None, tsvector_field: str = "properties_tsvector_zh", tsquery_config: str = "jiebacfg", + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: """ @@ -1914,15 +1916,16 @@ def search_by_fulltext( filter: filter conditions with 'and' or 'or' logic for search results. tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + return_fields: additional node fields to include in results **kwargs: other parameters (e.g. cube_name) Returns: - list[dict]: result list containing id and score + list[dict]: result list containing id and score. + If return_fields is specified, each dict also includes the requested fields. """ logger.info( f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" ) - # Build WHERE clause dynamically, same as search_by_embedding start_time = time.time() where_clauses = [] @@ -1966,13 +1969,10 @@ def search_by_fulltext( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) - # Add fulltext search condition - # Convert query_text to OR query format: "word1 | word2 | word3" tsquery_string = " | ".join(query_words) where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") @@ -1981,19 +1981,31 @@ def search_by_fulltext( logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - # Build fulltext search query + select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id, + ts_rank(m.{tsvector_field}, q.fq) AS rank""" + if return_fields: + select_cols += ", m.properties" + where_with_q = [] + for w in where_clauses: + if f"{tsvector_field} @@ to_tsquery(" in w: + where_with_q.append(f"m.{tsvector_field} @@ q.fq") + else: + where_with_q.append( + w.replace("(properties,", "(m.properties,") + .replace("(properties)", "(m.properties)") + .replace("ARRAY[properties,", "ARRAY[m.properties,") + ) + where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" - SELECT - ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, - agtype_object_field_text(properties, 'memory') as memory_text, - ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank - FROM "{self.db_name}_graph"."Memory" - {where_clause} + WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) + SELECT {select_cols} + FROM "{self.db_name}_graph"."Memory" m + CROSS JOIN q + {where_clause_cte} ORDER BY rank DESC LIMIT {top_k}; """ - - params = [tsquery_string, tsquery_string] + params = [tsquery_string] logger.info(f"[search_by_fulltext] query: {query}, params: {params}") conn = None try: @@ -2004,7 +2016,7 @@ def search_by_fulltext( output = [] for row in results: oldid = row[0] # old_id - rank = row[2] # rank score + rank = row[1] # rank score (no memory_text column) id_val = str(oldid) if id_val.startswith('"') and id_val.endswith('"'): @@ -2013,10 +2025,16 @@ def search_by_fulltext( # Apply threshold filter if specified if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[2] # properties column + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) elapsed_time = time.time() - start_time logger.info( - f" polardb [search_by_fulltext] query completed time in {elapsed_time:.2f}s" + f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" ) return output[:top_k] finally: @@ -2026,23 +2044,21 @@ def search_by_fulltext( def search_by_embedding( self, vector: list[float], + user_name: str, top_k: int = 5, scope: str | None = None, status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list[str] | None = None, + return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity using PostgreSQL vector operations. - """ - # Build WHERE clause dynamically like nebular.py logger.info( - f"[search_by_embedding] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}" + f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}" ) + start_time = time.time() where_clauses = [] if scope: where_clauses.append( @@ -2057,31 +2073,18 @@ def search_by_embedding( "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" ) where_clauses.append("embedding is not null") - # Add user_name filter like nebular.py - - """ - # user_name = self._get_config_value("user_name") - # if not self.config.use_multi_db and user_name: - # if kwargs.get("cube_name"): - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{kwargs['cube_name']}\"'::agtype") - # else: - # where_clauses.append(f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype") - """ - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( user_name=user_name, knowledgebase_ids=knowledgebase_ids, default_user_name=self.config.user_name, ) - # Add OR condition if we have any user_name conditions if user_name_conditions: if len(user_name_conditions) == 1: where_clauses.append(user_name_conditions[0]) else: where_clauses.append(f"({' OR '.join(user_name_conditions)})") - # Add search_filter conditions like nebular.py if search_filter: for key, value in search_filter.items(): if isinstance(value, str): @@ -2093,14 +2096,12 @@ def search_by_embedding( f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" ) - # Build filter conditions using common method filter_conditions = self._build_filter_conditions_sql(filter) logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - # Keep original simple query structure but add dynamic WHERE clause query = f""" WITH t AS ( SELECT id, @@ -2117,19 +2118,12 @@ def search_by_embedding( FROM t WHERE scope > 0.1; """ - # Convert vector to string format for PostgreSQL vector type - # PostgreSQL vector type expects a string format like '[1,2,3]' vector_str = convert_to_vector(vector) - # Use string format directly in query instead of parameterized query - # Replace %s with the vector string, but need to quote it properly - # PostgreSQL vector type needs the string to be quoted query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)") params = [] - # Split query by lines and wrap long lines to prevent terminal truncation query_lines = query.strip().split("\n") for line in query_lines: - # Wrap lines longer than 200 characters to prevent terminal truncation if len(line) > 200: wrapped_lines = textwrap.wrap( line, width=200, break_long_words=False, break_on_hyphens=False @@ -2145,28 +2139,13 @@ def search_by_embedding( try: conn = self._get_connection() with conn.cursor() as cursor: - try: - # If params is empty, execute query directly without parameters - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - except Exception as e: - logger.error(f"[search_by_embedding] Error executing query: {e}") - logger.error(f"[search_by_embedding] Query length: {len(query)}") - logger.error( - f"[search_by_embedding] Params type: {type(params)}, length: {len(params)}" - ) - logger.error(f"[search_by_embedding] Query contains %s: {'%s' in query}") - raise + if params: + cursor.execute(query, params) + else: + cursor.execute(query) results = cursor.fetchall() output = [] for row in results: - """ - polarId = row[0] # id - properties = row[1] # properties - # embedding = row[3] # embedding - """ if len(row) < 5: logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") continue @@ -2178,7 +2157,17 @@ def search_by_embedding( score_val = float(score) score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update( + self._extract_fields_from_properties(properties, return_fields) + ) + output.append(item) + elapsed_time = time.time() - start_time + logger.info( + f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" + ) return output[:top_k] finally: self._return_connection(conn) @@ -2187,7 +2176,7 @@ def search_by_embedding( def get_by_metadata( self, filters: list[dict[str, Any]], - user_name: str | None = None, + user_name: str, filter: dict | None = None, knowledgebase_ids: list | None = None, user_name_flag: bool = True, @@ -2209,7 +2198,9 @@ def get_by_metadata( Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). """ - logger.info(f"[get_by_metadata] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}") + logger.info( + f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}" + ) user_name = user_name if user_name else self._get_config_value("user_name") @@ -2264,9 +2255,6 @@ def get_by_metadata( else: raise ValueError(f"Unsupported operator: {op}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, @@ -2306,7 +2294,7 @@ def get_by_metadata( results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: - logger.error(f"Failed to get metadata: {e}, query is {cypher_query}") + logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") finally: self._return_connection(conn) @@ -2536,8 +2524,8 @@ def clear(self, user_name: str | None = None) -> None: @timed def export_graph( self, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, user_id: str | None = None, page: int | None = None, page_size: int | None = None, @@ -2576,7 +2564,7 @@ def export_graph( } """ logger.info( - f"[export_graph] include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" + f" export_graph include_embedding: {include_embedding}, user_name: {user_name}, user_id: {user_id}, page: {page}, page_size: {page_size}, filter: {filter}, memory_type: {memory_type}, status: {status}" ) user_id = user_id if user_id else self._get_config_value("user_id") @@ -2724,159 +2712,7 @@ def export_graph( finally: self._return_connection(conn) - conn = None - try: - conn = self._get_connection() - # Build Cypher WHERE conditions for edges - cypher_where_conditions = [] - if user_name: - cypher_where_conditions.append(f"a.user_name = '{user_name}'") - cypher_where_conditions.append(f"b.user_name = '{user_name}'") - if user_id: - cypher_where_conditions.append(f"a.user_id = '{user_id}'") - cypher_where_conditions.append(f"b.user_id = '{user_id}'") - - # Add memory_type filter condition for edges (apply to both source and target nodes) - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape single quotes in memory_type values for Cypher - escaped_memory_types = [mt.replace("'", "\\'") for mt in memory_type] - memory_type_list_str = ", ".join([f"'{mt}'" for mt in escaped_memory_types]) - # Cypher IN syntax: a.memory_type IN ['LongTermMemory', 'WorkingMemory'] - cypher_where_conditions.append(f"a.memory_type IN [{memory_type_list_str}]") - cypher_where_conditions.append(f"b.memory_type IN [{memory_type_list_str}]") - - # Add status filter for edges: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - cypher_where_conditions.append("a.status <> 'deleted' AND b.status <> 'deleted'") - elif isinstance(status, list) and len(status) > 0: - escaped_statuses = [st.replace("'", "\\'") for st in status] - status_list_str = ", ".join([f"'{st}'" for st in escaped_statuses]) - cypher_where_conditions.append(f"a.status IN [{status_list_str}]") - cypher_where_conditions.append(f"b.status IN [{status_list_str}]") - - # Build filter conditions for edges (apply to both source and target nodes) - filter_where_clause = self._build_filter_conditions_cypher(filter) - logger.info(f"[export_graph edges] filter_where_clause: {filter_where_clause}") - if filter_where_clause: - # _build_filter_conditions_cypher returns a string that starts with " AND " if filter exists - # Remove the leading " AND " and replace n. with a. for source node and b. for target node - filter_clause = filter_where_clause.strip() - if filter_clause.startswith("AND "): - filter_clause = filter_clause[4:].strip() - # Replace n. with a. for source node and create a copy for target node - source_filter = filter_clause.replace("n.", "a.") - target_filter = filter_clause.replace("n.", "b.") - # Combine source and target filters with AND - combined_filter = f"({source_filter}) AND ({target_filter})" - cypher_where_conditions.append(combined_filter) - - cypher_where_clause = "" - if cypher_where_conditions: - cypher_where_clause = f"WHERE {' AND '.join(cypher_where_conditions)}" - - # Get total count of edges before pagination - count_edge_query = f""" - SELECT COUNT(*) - FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - """ - logger.info(f"[export_graph edges count] Query: {count_edge_query}") - with conn.cursor() as cursor: - cursor.execute(count_edge_query) - total_edges = cursor.fetchone()[0] - - # Export edges using cypher query - # Note: Apache AGE Cypher may not support SKIP, so we use SQL LIMIT/OFFSET on the subquery - # Build pagination clause if needed - edge_pagination_clause = "" - if use_pagination: - edge_pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - edge_query = f""" - SELECT source, target, edge FROM ( - SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (a:Memory)-[r]->(b:Memory) - {cypher_where_clause} - RETURN a.id AS source, b.id AS target, type(r) as edge - ORDER BY COALESCE(a.created_at, '1970-01-01T00:00:00') DESC, - COALESCE(b.created_at, '1970-01-01T00:00:00') DESC, - a.id DESC, b.id DESC - $$) AS (source agtype, target agtype, edge agtype) - ) AS edges - {edge_pagination_clause} - """ - logger.info(f"[export_graph edges] Query: {edge_query}") - with conn.cursor() as cursor: - cursor.execute(edge_query) - edge_results = cursor.fetchall() - edges = [] - - for row in edge_results: - source_agtype, target_agtype, edge_agtype = row - - # Extract and clean source - source_raw = ( - source_agtype.value - if hasattr(source_agtype, "value") - else str(source_agtype) - ) - if ( - isinstance(source_raw, str) - and source_raw.startswith('"') - and source_raw.endswith('"') - ): - source = source_raw[1:-1] - else: - source = str(source_raw) - - # Extract and clean target - target_raw = ( - target_agtype.value - if hasattr(target_agtype, "value") - else str(target_agtype) - ) - if ( - isinstance(target_raw, str) - and target_raw.startswith('"') - and target_raw.endswith('"') - ): - target = target_raw[1:-1] - else: - target = str(target_raw) - - # Extract and clean edge type - type_raw = ( - edge_agtype.value if hasattr(edge_agtype, "value") else str(edge_agtype) - ) - if ( - isinstance(type_raw, str) - and type_raw.startswith('"') - and type_raw.endswith('"') - ): - edge_type = type_raw[1:-1] - else: - edge_type = str(type_raw) - - edges.append( - { - "source": source, - "target": target, - "type": edge_type, - } - ) - - except Exception as e: - logger.error(f"[EXPORT GRAPH - EDGES] Exception: {e}", exc_info=True) - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - finally: - self._return_connection(conn) - + edges = [] return { "nodes": nodes, "edges": edges, @@ -2908,8 +2744,8 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: def get_all_memory_items( self, scope: str, + user_name: str, include_embedding: bool = False, - user_name: str | None = None, filter: dict | None = None, knowledgebase_ids: list | None = None, status: str | None = None, @@ -2930,14 +2766,13 @@ def get_all_memory_items( list[dict]: Full list of memory items under this scope. """ logger.info( - f"[get_all_memory_items] filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status}" + f"[get_all_memory_items] user_name: {user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids}, status: {status},scope:{scope}" ) user_name = user_name if user_name else self._get_config_value("user_name") if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_cypher( user_name=user_name, knowledgebase_ids=knowledgebase_ids, @@ -3015,7 +2850,7 @@ def get_all_memory_items( node_ids.add(node_id) except Exception as e: - logger.error(f"Failed to get memories: {e}", exc_info=True) + logger.warning(f"Failed to get memories: {e}", exc_info=True) finally: self._return_connection(conn) @@ -4199,34 +4034,47 @@ def get_edges( ... ] """ + start_time = time.time() + logger.info(f" get_edges id:{id},type:{type},direction:{direction},user_name:{user_name}") user_name = user_name if user_name else self._get_config_value("user_name") - - if direction == "OUTGOING": - pattern = "(a:Memory)-[r]->(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = "(a:Memory)<-[r]-(b:Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = "(a:Memory)-[r]-(b:Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: + if direction not in ("OUTGOING", "INCOMING", "ANY"): raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - # Add type filter - if type != "ANY": - where_clause += f" AND type(r) = '{type}'" - - # Add user filter - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + # Escape single quotes for safe embedding in Cypher string + id_esc = (id or "").replace("'", "''") + user_esc = (user_name or "").replace("'", "''") + type_esc = (type or "").replace("'", "''") + type_filter = f" AND type(r) = '{type_esc}'" if type != "ANY" else "" + logger.info(f"type_filter:{type_filter}") + if direction == "OUTGOING": + cypher_body = f""" + MATCH (a:Memory)-[r:{type}]->(b:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}' + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ + elif direction == "INCOMING": + cypher_body = f""" + MATCH (b:Memory)<-[r:{type}]-(a:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}' + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ + else: # ANY: union of OUTGOING and INCOMING + cypher_body = f""" + MATCH (a:Memory)-[r]->(b:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + UNION ALL + MATCH (b:Memory)<-[r]-(a:Memory) + WHERE a.id = '{id_esc}' AND a.user_name = '{user_esc}'{type_filter} + RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + """ query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type + {cypher_body.strip()} $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ + logger.info(f"get_edges query:{query}") conn = None try: conn = self._get_connection() @@ -4270,6 +4118,8 @@ def get_edges( edge_type = str(edge_type_raw) edges.append({"from": from_id, "to": to_id, "type": edge_type}) + elapsed_time = time.time() - start_time + logger.info(f"polardb get_edges query completed time in {elapsed_time:.2f}s") return edges except Exception as e: diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index a27d64758..e3d2bece9 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -721,6 +721,7 @@ def _process_one_item( m_maybe_merged.get("memory_type", "LongTermMemory") .replace("长期记忆", "LongTermMemory") .replace("用户记忆", "UserMemory") + .replace("pref", "UserMemory") ) node = self._make_memory_item( value=m_maybe_merged.get("value", ""), diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 2b49d63ba..1b4add398 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -50,7 +50,9 @@ class FileContentParser(BaseMessageParser): """Parser for file content parts.""" - def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = None) -> dict: + def _get_doc_llm_response( + self, chunk_text: str, custom_tags: list[str] | None = None + ) -> dict | list: """ Call LLM to extract memory from document chunk. Uses doc prompts from DOC_PROMPT_DICT. @@ -60,7 +62,7 @@ def _get_doc_llm_response(self, chunk_text: str, custom_tags: list[str] | None = custom_tags: Optional list of custom tags for LLM extraction Returns: - Parsed JSON response from LLM or empty dict if failed + Parsed JSON response from LLM (dict or list) or empty dict if failed """ if not self.llm: logger.warning("[FileContentParser] LLM not available for fine mode") @@ -777,35 +779,49 @@ def _make_fallback( return [_make_fallback(idx, text, "no_llm") for idx, text in valid_chunks] # Process single chunk with LLM extraction (worker function) - def _process_chunk(chunk_idx: int, chunk_text: str) -> TextualMemoryItem: - """Process chunk with LLM, fallback to raw on failure.""" + def _process_chunk(chunk_idx: int, chunk_text: str) -> list[TextualMemoryItem]: + """Process chunk with LLM, fallback to raw on failure. Returns list of memory items.""" try: response_json = self._get_doc_llm_response(chunk_text, custom_tags) if response_json: - value = response_json.get("value", "").strip() - if value: - tags = response_json.get("tags", []) - tags = tags if isinstance(tags, list) else [] - tags.extend(["mode:fine", "multimodal:file"]) - - llm_mem_type = response_json.get("memory_type", memory_type) - if llm_mem_type not in ["LongTermMemory", "UserMemory"]: - llm_mem_type = memory_type - - return _make_memory_item( - value=value, - mem_type=llm_mem_type, - tags=tags, - key=response_json.get("key"), - chunk_idx=chunk_idx, - chunk_content=chunk_text, - ) + # Handle list format response + response_list = response_json.get("memory list", []) + memory_items = [] + for item_data in response_list: + if not isinstance(item_data, dict): + continue + + value = item_data.get("value", "").strip() + if value: + tags = item_data.get("tags", []) + tags = tags if isinstance(tags, list) else [] + tags.extend(["mode:fine", "multimodal:file"]) + key_str = item_data.get("key", "") + + llm_mem_type = item_data.get("memory_type", memory_type) + if llm_mem_type not in ["LongTermMemory", "UserMemory"]: + llm_mem_type = memory_type + + memory_item = _make_memory_item( + value=value, + mem_type=llm_mem_type, + tags=tags, + key=key_str, + chunk_idx=chunk_idx, + chunk_content=chunk_text, + ) + memory_items.append(memory_item) + + if memory_items: + return memory_items + else: + return [_make_fallback(chunk_idx, chunk_text)] except Exception as e: logger.error(f"[FileContentParser] LLM error for chunk {chunk_idx}: {e}") # Fallback to raw chunk logger.warning(f"[FileContentParser] Fallback to raw for chunk {chunk_idx}") - return _make_fallback(chunk_idx, chunk_text) + return [_make_fallback(chunk_idx, chunk_text)] def _relate_chunks(items: list[TextualMemoryItem]) -> None: """ @@ -853,30 +869,37 @@ def get_chunk_idx(item: TextualMemoryItem) -> int: ): chunk_idx = futures[future] try: - node = future.result() - memory_items.append(node) - - # Check if this node is a fallback by checking tags - is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags) - if is_fallback: - fallback_count += 1 - - # save raw file - node_id = node.id - if node.memory != node.metadata.sources[0].content: - chunk_node = _make_memory_item( - value=node.metadata.sources[0].content, - mem_type="RawFileMemory", - tags=[ - "mode:fine", - "multimodal:file", - f"chunk:{chunk_idx + 1}/{total_chunks}", - ], - chunk_idx=chunk_idx, - chunk_content="", - ) - chunk_node.metadata.summary_ids = [node_id] - memory_items.append(chunk_node) + nodes = future.result() + memory_items.extend(nodes) + + # Check if any node is a fallback by checking tags + has_fallback = False + for node in nodes: + is_fallback = any(tag.startswith("fallback:") for tag in node.metadata.tags) + if is_fallback: + fallback_count += 1 + has_fallback = True + + # save raw file only if no fallback (all nodes are LLM-extracted) + if not has_fallback and nodes: + # Use first node's source info for raw file + first_node = nodes[0] + if first_node.metadata.sources and len(first_node.metadata.sources) > 0: + # Collect all node IDs for summary_ids + node_ids = [node.id for node in nodes] + chunk_node = _make_memory_item( + value=first_node.metadata.sources[0].content, + mem_type="RawFileMemory", + tags=[ + "mode:fine", + "multimodal:file", + f"chunk:{chunk_idx + 1}/{total_chunks}", + ], + chunk_idx=chunk_idx, + chunk_content="", + ) + chunk_node.metadata.summary_ids = node_ids + memory_items.append(chunk_node) except Exception as e: tqdm.write(f"[ERROR] Chunk {chunk_idx} failed: {e}") diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index d39955ac2..a9a727b08 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -1019,7 +1019,9 @@ def process_skill_memory_fine( **kwargs, ) -> list[TextualMemoryItem]: skills_repo_backend = _get_skill_file_storage_location() - oss_client, missing_keys, flag = _skill_init(skills_repo_backend, oss_config, skills_dir_config) + oss_client, _missing_keys, flag = _skill_init( + skills_repo_backend, oss_config, skills_dir_config + ) if not flag: return [] diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py index 63718fd92..e4a88a635 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/add_handler.py @@ -68,33 +68,35 @@ def log_add_messages(self, msg: ScheduleMessageItem): mem_item = mem_cube.text_mem.get(memory_id=memory_id, user_name=msg.mem_cube_id) if mem_item is None: raise ValueError(f"Memory {memory_id} not found after retries") - key = getattr(mem_item.metadata, "key", None) or transform_name_to_key( - name=mem_item.memory - ) - exists = False original_content = None original_item_id = None - if key and hasattr(mem_cube.text_mem, "graph_store"): - candidates = mem_cube.text_mem.graph_store.get_by_metadata( - [ - {"field": "key", "op": "=", "value": key}, - { - "field": "memory_type", - "op": "=", - "value": mem_item.metadata.memory_type, - }, - ] + # Determine add vs update from the merged_from field set by the upstream + # mem_reader during fine extraction. When the LLM merges a new memory with + # existing ones it writes their IDs into metadata.info["merged_from"]. + # This avoids an extra graph DB query and the self-match / cross-user + # matching bugs that came with the old get_by_metadata approach. + merged_from = (getattr(mem_item.metadata, "info", None) or {}).get("merged_from") + if merged_from: + merged_ids = ( + merged_from + if isinstance(merged_from, list | tuple | set) + else [merged_from] ) - if candidates: - exists = True - original_item_id = candidates[0] + original_item_id = merged_ids[0] + try: original_mem_item = mem_cube.text_mem.get( memory_id=original_item_id, user_name=msg.mem_cube_id ) - original_content = original_mem_item.memory + original_content = original_mem_item.memory if original_mem_item else None + except Exception as e: + logger.warning( + "Failed to fetch original memory %s for update log: %s", + original_item_id, + e, + ) - if exists: + if merged_from: prepared_update_items_with_original.append( { "new_item": mem_item, diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 5d86c5589..20dbb63b2 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -259,13 +259,19 @@ def _process_memories_with_reader( source_doc_id = ( file_ids[0] if isinstance(file_ids, list) and file_ids else None ) + # Use merged_from to determine ADD vs UPDATE. + # The upstream mem_reader sets this during fine extraction when + # the new memory was merged with an existing one. + item_merged_from = (getattr(item.metadata, "info", None) or {}).get( + "merged_from" + ) kb_log_content.append( { "log_source": "KNOWLEDGE_BASE_LOG", "trigger_source": info.get("trigger_source", "Messages") if info else "Messages", - "operation": "ADD", + "operation": "UPDATE" if item_merged_from else "ADD", "memory_id": item.id, "content": item.memory, "original_content": None, @@ -302,29 +308,39 @@ def _process_memories_with_reader( else: add_content_legacy: list[dict] = [] add_meta_legacy: list[dict] = [] + update_content_legacy: list[dict] = [] + update_meta_legacy: list[dict] = [] for item_id, item in zip( enhanced_mem_ids, flattened_memories, strict=False ): key = getattr(item.metadata, "key", None) or transform_name_to_key( name=item.memory ) - add_content_legacy.append( - {"content": f"{key}: {item.memory}", "ref_id": item_id} - ) - add_meta_legacy.append( - { - "ref_id": item_id, - "id": item_id, - "key": item.metadata.key, - "memory": item.memory, - "memory_type": item.metadata.memory_type, - "status": item.metadata.status, - "confidence": item.metadata.confidence, - "tags": item.metadata.tags, - "updated_at": getattr(item.metadata, "updated_at", None) - or getattr(item.metadata, "update_at", None), - } + item_merged_from = (getattr(item.metadata, "info", None) or {}).get( + "merged_from" ) + meta_entry = { + "ref_id": item_id, + "id": item_id, + "key": item.metadata.key, + "memory": item.memory, + "memory_type": item.metadata.memory_type, + "status": item.metadata.status, + "confidence": item.metadata.confidence, + "tags": item.metadata.tags, + "updated_at": getattr(item.metadata, "updated_at", None) + or getattr(item.metadata, "update_at", None), + } + if item_merged_from: + update_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + update_meta_legacy.append(meta_entry) + else: + add_content_legacy.append( + {"content": f"{key}: {item.memory}", "ref_id": item_id} + ) + add_meta_legacy.append(meta_entry) if add_content_legacy: event = self.scheduler_context.services.create_event_log( label="addMemory", @@ -342,6 +358,23 @@ def _process_memories_with_reader( ) event.task_id = task_id self.scheduler_context.services.submit_web_logs([event]) + if update_content_legacy: + event = self.scheduler_context.services.create_event_log( + label="updateMemory", + from_memory_type=USER_INPUT_TYPE, + to_memory_type=LONG_TERM_MEMORY_TYPE, + user_id=user_id, + mem_cube_id=mem_cube_id, + mem_cube=self.scheduler_context.get_mem_cube(), + memcube_log_content=update_content_legacy, + metadata=update_meta_legacy, + memory_len=len(update_content_legacy), + memcube_name=self.scheduler_context.services.map_memcube_name( + mem_cube_id + ), + ) + event.task_id = task_id + self.scheduler_context.services.submit_web_logs([event]) else: logger.info("No enhanced memories generated by mem_reader") else: diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 6352d5840..8483a5151 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -124,25 +124,45 @@ def retrieve( explicit_prefs.sort(key=lambda x: x.score, reverse=True) implicit_prefs.sort(key=lambda x: x.score, reverse=True) - explicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + explicit_prefs_mem = [] + for pref in explicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + explicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in explicit_prefs - if pref.payload.get("preference", None) - ] - implicit_prefs_mem = [ - TextualMemoryItem( - id=pref.id, - memory=pref.memory, - metadata=PreferenceTextualMemoryMetadata(**pref.payload), + implicit_prefs_mem = [] + for pref in implicit_prefs: + if not pref.payload.get("preference", None): + continue + if "embedding" in pref.payload: + payload = pref.payload + else: + pref_vector = getattr(pref, "vector", None) + if pref_vector is None: + payload = pref.payload + else: + payload = {**pref.payload, "embedding": pref_vector} + implicit_prefs_mem.append( + TextualMemoryItem( + id=pref.id, + memory=pref.memory, + metadata=PreferenceTextualMemoryMetadata(**payload), + ) ) - for pref in implicit_prefs - if pref.payload.get("preference", None) - ] reranker_map = { "naive": self._naive_reranker, diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5faf8aa09..5b210ba61 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -404,10 +404,10 @@ def delete_by_memory_ids(self, memory_ids: list[str]) -> None: except Exception as e: logger.error(f"An error occurred while deleting memories by memory_ids: {e}") - def delete_all(self) -> None: + def delete_all(self, user_name: str | None = None) -> None: """Delete all memories and their relationships from the graph store.""" try: - self.graph_store.clear() + self.graph_store.clear(user_name=user_name) logger.info("All memories and edges have been deleted from the graph.") except Exception as e: logger.error(f"An error occurred while deleting all memories: {e}") @@ -424,7 +424,7 @@ def delete_by_filter( writable_cube_ids=writable_cube_ids, file_ids=file_ids, filter=filter ) - def load(self, dir: str) -> None: + def load(self, dir: str, user_name: str | None = None) -> None: try: memory_file = os.path.join(dir, self.config.memory_filename) @@ -435,7 +435,7 @@ def load(self, dir: str) -> None: with open(memory_file, encoding="utf-8") as f: memories = json.load(f) - self.graph_store.import_graph(memories) + self.graph_store.import_graph(memories, user_name=user_name) logger.info(f"Loaded {len(memories)} memories from {memory_file}") except FileNotFoundError: @@ -445,10 +445,12 @@ def load(self, dir: str) -> None: except Exception as e: logger.error(f"An error occurred while loading memories: {e}") - def dump(self, dir: str, include_embedding: bool = False) -> None: + def dump(self, dir: str, include_embedding: bool = False, user_name: str | None = None) -> None: """Dump memories to os.path.join(dir, self.config.memory_filename)""" try: - json_memories = self.graph_store.export_graph(include_embedding=include_embedding) + json_memories = self.graph_store.export_graph( + include_embedding=include_embedding, user_name=user_name + ) os.makedirs(dir, exist_ok=True) memory_file = os.path.join(dir, self.config.memory_filename) diff --git a/src/memos/memories/textual/tree_text_memory/organize/handler.py b/src/memos/memories/textual/tree_text_memory/organize/handler.py index 595cf099c..2d776912b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/handler.py +++ b/src/memos/memories/textual/tree_text_memory/organize/handler.py @@ -27,18 +27,24 @@ def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedd self.llm = llm self.embedder = embedder - def detect(self, memory, top_k: int = 5, scope=None): + def detect(self, memory, top_k: int = 5, scope=None, user_name: str | None = None): # 1. Search for similar memories based on embedding embedding = memory.metadata.embedding embedding_candidates_info = self.graph_store.search_by_embedding( - embedding, top_k=top_k, scope=scope, threshold=self.EMBEDDING_THRESHOLD + embedding, + top_k=top_k, + scope=scope, + threshold=self.EMBEDDING_THRESHOLD, + user_name=user_name, ) # 2. Filter based on similarity threshold embedding_candidates_ids = [ info["id"] for info in embedding_candidates_info if info["id"] != memory.id ] # 3. Judge conflicts using LLM - embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids) + embedding_candidates = self.graph_store.get_nodes( + embedding_candidates_ids, user_name=user_name + ) detected_relationships = [] for embedding_candidate in embedding_candidates: embedding_candidate = TextualMemoryItem.from_dict(embedding_candidate) @@ -67,13 +73,20 @@ def detect(self, memory, top_k: int = 5, scope=None): pass return detected_relationships - def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, relation) -> None: + def resolve( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + relation, + user_name: str | None = None, + ) -> None: """ Resolve detected conflicts between two memory items using LLM fusion. Args: memory_a: The first conflicting memory item. memory_b: The second conflicting memory item. relation: relation + user_name: Optional user name for multi-tenant isolation. Returns: A fused TextualMemoryItem representing the resolved memory. """ @@ -105,17 +118,22 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, rela logger.warning( f"{relation} between {memory_a.id} and {memory_b.id} could not be resolved. " ) - self._hard_update(memory_a, memory_b) + self._hard_update(memory_a, memory_b, user_name=user_name) # —————— 2.2 Conflict resolved, update metadata and memory ———— else: fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata) merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata) logger.info(f"Resolved result: {merged_memory}") - self._resolve_in_graph(memory_a, memory_b, merged_memory) + self._resolve_in_graph(memory_a, memory_b, merged_memory, user_name=user_name) except json.decoder.JSONDecodeError: logger.error(f"Failed to parse LLM response: {response}") - def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem): + def _hard_update( + self, + memory_a: TextualMemoryItem, + memory_b: TextualMemoryItem, + user_name: str | None = None, + ): """ Hard update: compare updated_at, keep the newer one, overwrite the older one's metadata. """ @@ -125,7 +143,7 @@ def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) newer_mem = memory_a if time_a >= time_b else memory_b older_mem = memory_b if time_a >= time_b else memory_a - self.graph_store.delete_node(older_mem.id) + self.graph_store.delete_node(older_mem.id, user_name=user_name) logger.warning( f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>" ) @@ -135,13 +153,21 @@ def _resolve_in_graph( conflict_a: TextualMemoryItem, conflict_b: TextualMemoryItem, merged: TextualMemoryItem, + user_name: str | None = None, ): - edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY") - edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY") + edges_a = self.graph_store.get_edges( + conflict_a.id, type="ANY", direction="ANY", user_name=user_name + ) + edges_b = self.graph_store.get_edges( + conflict_b.id, type="ANY", direction="ANY", user_name=user_name + ) all_edges = edges_a + edges_b self.graph_store.add_node( - merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True) + merged.id, + merged.memory, + merged.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for edge in all_edges: @@ -150,13 +176,15 @@ def _resolve_in_graph( if new_from == new_to: continue # Check if the edge already exists before adding - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) - - self.graph_store.update_node(conflict_a.id, {"status": "archived"}) - self.graph_store.update_node(conflict_b.id, {"status": "archived"}) - self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO") - self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO") + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) + + self.graph_store.update_node(conflict_a.id, {"status": "archived"}, user_name=user_name) + self.graph_store.update_node(conflict_b.id, {"status": "archived"}, user_name=user_name) + self.graph_store.add_edge(conflict_a.id, merged.id, type="MERGED_TO", user_name=user_name) + self.graph_store.add_edge(conflict_b.id, merged.id, type="MERGED_TO", user_name=user_name) logger.debug( f"Archive {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}." ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 1afdc9281..132582a0d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -141,6 +141,7 @@ def mark_memory_status( self, memory_items: list[TextualMemoryItem], status: Literal["activated", "resolving", "archived", "deleted"], + user_name: str | None = None, ) -> None: """ Support status marking operations during history management. Common usages are: @@ -157,6 +158,7 @@ def mark_memory_status( self.graph_db.update_node, id=mem.id, fields={"status": status}, + user_name=user_name, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index cbc349d67..4ca30c7b8 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -238,7 +238,9 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: - self.reorganizer.add_message(QueueMessage(op="add", after_node=graph_node_ids)) + self.reorganizer.add_message( + QueueMessage(op="add", after_node=graph_node_ids, user_name=user_name) + ) return added_ids @@ -411,16 +413,19 @@ def _add_to_graph_memory( QueueMessage( op="add", after_node=[node_id], + user_name=user_name, ) ) return node_id - def _inherit_edges(self, from_id: str, to_id: str) -> None: + def _inherit_edges(self, from_id: str, to_id: str, user_name: str | None = None) -> None: """ Migrate all non-lineage edges from `from_id` to `to_id`, and remove them from `from_id` after copying. """ - edges = self.graph_store.get_edges(from_id, type="ANY", direction="ANY") + edges = self.graph_store.get_edges( + from_id, type="ANY", direction="ANY", user_name=user_name + ) for edge in edges: if edge["type"] == "MERGED_TO": @@ -433,20 +438,29 @@ def _inherit_edges(self, from_id: str, to_id: str) -> None: continue # Add edge to merged node if it doesn't already exist - if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"): - self.graph_store.add_edge(new_from, new_to, edge["type"]) + if not self.graph_store.edge_exists( + new_from, new_to, edge["type"], direction="ANY", user_name=user_name + ): + self.graph_store.add_edge(new_from, new_to, edge["type"], user_name=user_name) # Remove original edge if it involved the archived node - self.graph_store.delete_edge(edge["from"], edge["to"], edge["type"]) + self.graph_store.delete_edge( + edge["from"], edge["to"], edge["type"], user_name=user_name + ) def _ensure_structure_path( - self, memory_type: str, metadata: TreeNodeTextualMemoryMetadata + self, + memory_type: str, + metadata: TreeNodeTextualMemoryMetadata, + user_name: str | None = None, ) -> str: """ Ensure structural path exists (ROOT → ... → final node), return last node ID. Args: - path: like ["hobby", "photography"] + memory_type: Memory type for the structure node. + metadata: Metadata containing key and other fields. + user_name: Optional user name for multi-tenant isolation. Returns: Final node ID of the structure path. @@ -456,7 +470,8 @@ def _ensure_structure_path( [ {"field": "memory", "op": "=", "value": metadata.key}, {"field": "memory_type", "op": "=", "value": memory_type}, - ] + ], + user_name=user_name, ) if existing: node_id = existing[0] # Use the first match @@ -479,14 +494,16 @@ def _ensure_structure_path( ), ) self.graph_store.add_node( - id=new_node.id, - memory=new_node.memory, - metadata=new_node.metadata.model_dump(exclude_none=True), + new_node.id, + new_node.memory, + new_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( op="add", after_node=[new_node.id], + user_name=user_name, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py index ea06a7c60..b7fb6b1a0 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py +++ b/src/memos/memories/textual/tree_text_memory/organize/reorganizer.py @@ -52,12 +52,14 @@ def __init__( before_edge: list[str] | list[GraphDBEdge] | None = None, after_node: list[str] | list[GraphDBNode] | None = None, after_edge: list[str] | list[GraphDBEdge] | None = None, + user_name: str | None = None, ): self.op = op self.before_node = before_node self.before_edge = before_edge self.after_node = after_node self.after_edge = after_edge + self.user_name = user_name def __str__(self) -> str: return f"QueueMessage(op={self.op}, before_node={self.before_node if self.before_node is None else len(self.before_node)}, after_node={self.after_node if self.after_node is None else len(self.after_node)})" @@ -191,11 +193,15 @@ def handle_add(self, message: QueueMessage): logger.debug(f"Handling add operation: {str(message)[:500]}") added_node = message.after_node[0] detected_relationships = self.resolver.detect( - added_node, scope=added_node.metadata.memory_type + added_node, + scope=added_node.metadata.memory_type, + user_name=message.user_name, ) if detected_relationships: for added_node, existing_node, relation in detected_relationships: - self.resolver.resolve(added_node, existing_node, relation) + self.resolver.resolve( + added_node, existing_node, relation, user_name=message.user_name + ) self._reorganize_needed = True @@ -209,6 +215,7 @@ def optimize_structure( min_cluster_size: int = 4, min_group_size: int = 20, max_duration_sec: int = 600, + user_name: str | None = None, ): """ Periodically reorganize the graph: @@ -232,7 +239,7 @@ def _check_deadline(where: str): logger.info(f"[GraphStructureReorganize] Already optimizing for {scope}. Skipping.") return - if self.graph_store.node_not_exist(scope): + if self.graph_store.node_not_exist(scope, user_name=user_name): logger.debug(f"[GraphStructureReorganize] No nodes for scope={scope}. Skip.") return @@ -244,12 +251,14 @@ def _check_deadline(where: str): logger.debug( f"[GraphStructureReorganize] Num of scope in self.graph_store is" - f" {self.graph_store.get_memory_count(scope)}" + f" {self.graph_store.get_memory_count(scope, user_name=user_name)}" ) # Load candidate nodes if _check_deadline("[GraphStructureReorganize] Before loading candidates"): return - raw_nodes = self.graph_store.get_structure_optimization_candidates(scope) + raw_nodes = self.graph_store.get_structure_optimization_candidates( + scope, user_name=user_name + ) nodes = [GraphDBNode(**n) for n in raw_nodes] if not nodes: @@ -281,6 +290,7 @@ def _check_deadline(where: str): scope, local_tree_threshold, min_cluster_size, + user_name, ) ) @@ -307,6 +317,7 @@ def _process_cluster_and_write( scope: str, local_tree_threshold: int, min_cluster_size: int, + user_name: str | None = None, ): if len(cluster_nodes) <= min_cluster_size: return @@ -319,15 +330,17 @@ def _process_cluster_and_write( if len(sub_nodes) < min_cluster_size: continue # Skip tiny noise sub_parent_node = self._summarize_cluster(sub_nodes, scope) - self._create_parent_node(sub_parent_node) - self._link_cluster_nodes(sub_parent_node, sub_nodes) + self._create_parent_node(sub_parent_node, user_name=user_name) + self._link_cluster_nodes(sub_parent_node, sub_nodes, user_name=user_name) sub_parents.append(sub_parent_node) if sub_parents and len(sub_parents) >= min_cluster_size: cluster_parent_node = self._summarize_cluster(cluster_nodes, scope) - self._create_parent_node(cluster_parent_node) + self._create_parent_node(cluster_parent_node, user_name=user_name) for sub_parent in sub_parents: - self.graph_store.add_edge(cluster_parent_node.id, sub_parent.id, "PARENT") + self.graph_store.add_edge( + cluster_parent_node.id, sub_parent.id, "PARENT", user_name=user_name + ) logger.info("Adding relations/reasons") nodes_to_check = cluster_nodes @@ -351,10 +364,16 @@ def _process_cluster_and_write( # 1) Add pairwise relations for rel in results["relations"]: if not self.graph_store.edge_exists( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ): self.graph_store.add_edge( - rel["source_id"], rel["target_id"], rel["relation_type"] + rel["source_id"], + rel["target_id"], + rel["relation_type"], + user_name=user_name, ) # 2) Add inferred nodes and link to sources @@ -363,14 +382,21 @@ def _process_cluster_and_write( inf_node.id, inf_node.memory, inf_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for src_id in inf_node.metadata.sources: - self.graph_store.add_edge(src_id, inf_node.id, "INFERS") + self.graph_store.add_edge( + src_id, inf_node.id, "INFERS", user_name=user_name + ) # 3) Add sequence links for seq in results["sequence_links"]: - if not self.graph_store.edge_exists(seq["from_id"], seq["to_id"], "FOLLOWS"): - self.graph_store.add_edge(seq["from_id"], seq["to_id"], "FOLLOWS") + if not self.graph_store.edge_exists( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ): + self.graph_store.add_edge( + seq["from_id"], seq["to_id"], "FOLLOWS", user_name=user_name + ) # 4) Add aggregate concept nodes for agg_node in results["aggregate_nodes"]: @@ -378,9 +404,12 @@ def _process_cluster_and_write( agg_node.id, agg_node.memory, agg_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) for child_id in agg_node.metadata.sources: - self.graph_store.add_edge(agg_node.id, child_id, "AGGREGATE_TO") + self.graph_store.add_edge( + agg_node.id, child_id, "AGGREGATE_TO", user_name=user_name + ) logger.info("[Reorganizer] Cluster relation/reasoning done.") @@ -577,7 +606,7 @@ def _parse_json_result(self, response_text): ) return {} - def _create_parent_node(self, parent_node: GraphDBNode) -> None: + def _create_parent_node(self, parent_node: GraphDBNode, user_name: str | None = None) -> None: """ Create a new parent node for the cluster. """ @@ -585,17 +614,23 @@ def _create_parent_node(self, parent_node: GraphDBNode) -> None: parent_node.id, parent_node.memory, parent_node.metadata.model_dump(exclude_none=True), + user_name=user_name, ) - def _link_cluster_nodes(self, parent_node: GraphDBNode, child_nodes: list[GraphDBNode]): + def _link_cluster_nodes( + self, + parent_node: GraphDBNode, + child_nodes: list[GraphDBNode], + user_name: str | None = None, + ): """ Add PARENT edges from the parent node to all nodes in the cluster. """ for child in child_nodes: if not self.graph_store.edge_exists( - parent_node.id, child.id, "PARENT", direction="OUTGOING" + parent_node.id, child.id, "PARENT", direction="OUTGOING", user_name=user_name ): - self.graph_store.add_edge(parent_node.id, child.id, "PARENT") + self.graph_store.add_edge(parent_node.id, child.id, "PARENT", user_name=user_name) def _preprocess_message(self, message: QueueMessage) -> bool: message = self._convert_id_to_node(message) @@ -613,7 +648,9 @@ def _convert_id_to_node(self, message: QueueMessage) -> QueueMessage: for i, node in enumerate(message.after_node or []): if not isinstance(node, str): continue - raw_node = self.graph_store.get_node(node, include_embedding=True) + raw_node = self.graph_store.get_node( + node, include_embedding=True, user_name=message.user_name + ) if raw_node is None: logger.debug(f"Node with ID {node} not found in the graph store.") message.after_node[i] = None diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 9dcbe8c56..cc269e8c4 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -524,7 +524,7 @@ def _retrieve_from_keyword( user_name=user_name, tsquery_config="jiebaqry", ) - except Exception as e: + except Exception: logger.warning( f"[PATH-KEYWORD] search_by_fulltext failed, scope={scope}, user_name={user_name}" ) diff --git a/src/memos/memos_tools/thread_safe_dict_segment.py b/src/memos/memos_tools/thread_safe_dict_segment.py index c1c10e3e1..bf918889f 100644 --- a/src/memos/memos_tools/thread_safe_dict_segment.py +++ b/src/memos/memos_tools/thread_safe_dict_segment.py @@ -71,7 +71,7 @@ def acquire_write(self) -> bool: self._waiting_writers -= 1 self._last_write_time = time.time() return True - except: + except Exception: self._waiting_writers -= 1 raise diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1678d9d15..d890c77bf 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -443,7 +443,10 @@ def _search_pref( }, search_filter=search_req.filter, ) - formatted_results = self._postformat_memories(results, user_context.mem_cube_id) + include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" + formatted_results = self._postformat_memories( + results, user_context.mem_cube_id, include_embedding=include_embedding + ) # For each returned item, tackle with metadata.info project_id / # operation / manager_user_id diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index e4f1ca334..f431bd041 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -244,12 +244,17 @@ Return a single valid JSON object with the following structure: -Return valid JSON: { - "key": , - "memory_type": "LongTermMemory", - "value": , - "tags": + "memory list": [ + { + "key": , + "memory_type": "LongTermMemory", + "value": , + "tags": + } + ... + ], + "summary": } Language rules: @@ -264,7 +269,7 @@ Your Output:""" SIMPLE_STRUCT_DOC_READER_PROMPT_ZH = """您是搜索与检索系统的文本分析专家。 -您的任务是处理文档片段,并生成一个结构化的 JSON 对象。 +您的任务是处理文档片段,并生成一个结构化的 JSON 列表对象。 请执行以下操作: 1. 识别反映文档中事实内容、见解、决策或含义的关键信息——包括任何显著的主题、结论或数据点,使读者无需阅读原文即可充分理解该片段的核心内容。 @@ -281,14 +286,19 @@ - 优先考虑完整性和保真度,而非简洁性。 - 不要泛化或跳过可能具有上下文意义的细节。 -返回一个有效的 JSON 对象,结构如下: +返回有效的 JSON 对象: -返回有效的 JSON: { - "key": <字符串,`value` 字段的简洁标题>, - "memory_type": "LongTermMemory", - "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>, - "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + "memory list": [ + { + "key": <字符串,`value` 字段的简洁标题>, + "memory_type": "LongTermMemory", + "value": <一段清晰准确的段落,全面总结文档片段中的主要观点、论据和信息——若输入摘要为英文,则用英文;若为中文,则用中文>, + "tags": <相关主题关键词列表(例如,["截止日期", "团队", "计划"])> + } + ... + ], + "summary": <简洁总结原文内容,与输入语言一致> } 语言规则: diff --git a/tests/graph_dbs/test_search_return_fields.py b/tests/graph_dbs/test_search_return_fields.py new file mode 100644 index 000000000..82a50308b --- /dev/null +++ b/tests/graph_dbs/test_search_return_fields.py @@ -0,0 +1,306 @@ +""" +Regression tests for issue #955: search methods support specifying return fields. + +Tests that search_by_embedding (and other search methods) accept a `return_fields` +parameter and include the requested fields in the result dicts, eliminating the +need for N+1 get_node() calls. +""" + +import uuid + +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.graph_db import Neo4jGraphDBConfig + + +@pytest.fixture +def neo4j_config(): + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_memory_db", + auto_create=False, + embedding_dimension=3, + ) + + +@pytest.fixture +def neo4j_db(neo4j_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(neo4j_config) + db.driver = mock_driver + yield db + + +class TestNeo4jSearchReturnFields: + """Tests for Neo4jGraphDB.search_by_embedding with return_fields.""" + + def test_return_fields_included_in_results(self, neo4j_db): + """return_fields values are present in each result dict.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + node_id = str(uuid.uuid4()) + session_mock.run.return_value = [ + {"id": node_id, "score": 0.95, "memory": "hello", "status": "activated"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == node_id + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + def test_backward_compatible_without_return_fields(self, neo4j_db): + """Without return_fields, only id and score are returned (old behavior).""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + assert len(results) == 1 + assert set(results[0].keys()) == {"id", "score"} + + def test_cypher_return_clause_includes_fields(self, neo4j_db): + """Cypher RETURN clause contains the requested fields.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory", "tags"], + ) + + query = session_mock.run.call_args[0][0] + assert "node.memory AS memory" in query + assert "node.tags AS tags" in query + + def test_cypher_return_clause_default(self, neo4j_db): + """Without return_fields, RETURN clause only has id and score.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + ) + + query = session_mock.run.call_args[0][0] + assert "RETURN node.id AS id, score" in query + assert "node.memory" not in query + + def test_return_fields_skips_id_field(self, neo4j_db): + """Passing 'id' in return_fields does not duplicate it in RETURN clause.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["id", "memory"], + ) + + query = session_mock.run.call_args[0][0] + # 'id' should appear only once (as node.id AS id), not duplicated + assert query.count("node.id AS id") == 1 + assert "node.memory AS memory" in query + + def test_threshold_filtering_still_works_with_return_fields(self, neo4j_db): + """Threshold filtering works correctly when return_fields is specified.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": str(uuid.uuid4()), "score": 0.9, "memory": "high score"}, + {"id": str(uuid.uuid4()), "score": 0.3, "memory": "low score"}, + ] + + results = neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + threshold=0.5, + return_fields=["memory"], + ) + + assert len(results) == 1 + assert results[0]["memory"] == "high score" + + +class TestPolarDBExtractFieldsFromProperties: + """Tests for PolarDBGraphDB._extract_fields_from_properties helper.""" + + @pytest.fixture + def polardb_instance(self): + """Create a minimal PolarDB instance for testing the helper method.""" + with patch("memos.graph_dbs.polardb.PolarDBGraphDB.__init__", return_value=None): + from memos.graph_dbs.polardb import PolarDBGraphDB + + db = PolarDBGraphDB.__new__(PolarDBGraphDB) + yield db + + def test_extract_from_json_string(self, polardb_instance): + """Extract fields from a JSON string properties value.""" + props = '{"id": "abc", "memory": "hello", "status": "activated", "tags": ["a"]}' + result = polardb_instance._extract_fields_from_properties( + props, ["memory", "status", "tags"] + ) + assert result == {"memory": "hello", "status": "activated", "tags": ["a"]} + + def test_extract_from_dict(self, polardb_instance): + """Extract fields from a dict properties value.""" + props = {"id": "abc", "memory": "hello", "status": "activated"} + result = polardb_instance._extract_fields_from_properties(props, ["memory", "status"]) + assert result == {"memory": "hello", "status": "activated"} + + def test_extract_skips_id(self, polardb_instance): + """'id' field is skipped even if requested.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["id", "memory"]) + assert result == {"memory": "hello"} + + def test_extract_missing_fields(self, polardb_instance): + """Missing fields are silently skipped.""" + props = '{"id": "abc", "memory": "hello"}' + result = polardb_instance._extract_fields_from_properties(props, ["memory", "nonexistent"]) + assert result == {"memory": "hello"} + + def test_extract_empty_properties(self, polardb_instance): + """Empty/None properties return empty dict.""" + assert polardb_instance._extract_fields_from_properties(None, ["memory"]) == {} + assert polardb_instance._extract_fields_from_properties("", ["memory"]) == {} + + def test_extract_invalid_json(self, polardb_instance): + """Invalid JSON returns empty dict without raising.""" + result = polardb_instance._extract_fields_from_properties("not-json", ["memory"]) + assert result == {} + + +class TestFieldNameValidation: + """Tests for _validate_return_fields injection prevention.""" + + def test_valid_field_names_pass(self): + from memos.graph_dbs.base import BaseGraphDB + + result = BaseGraphDB._validate_return_fields(["memory", "status", "tags", "user_name"]) + assert result == ["memory", "status", "tags", "user_name"] + + def test_invalid_field_names_rejected(self): + from memos.graph_dbs.base import BaseGraphDB + + # Cypher injection attempts + result = BaseGraphDB._validate_return_fields( + [ + "memory} RETURN n //", + "status; DROP", + "valid_field", + "a.b", + "field name", + "", + ] + ) + assert result == ["valid_field"] + + def test_none_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields(None) == [] + + def test_empty_list_returns_empty(self): + from memos.graph_dbs.base import BaseGraphDB + + assert BaseGraphDB._validate_return_fields([]) == [] + + def test_injection_in_cypher_query_prevented(self, neo4j_db): + """Malicious field names should not appear in the Cypher query.""" + session_mock = neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + user_name="test_user", + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + # Injection attempt should NOT appear in query + assert "memory}" not in query + assert "RETURN n //" not in query + # Valid field should appear + assert "node.valid_field AS valid_field" in query + + +class TestNeo4jCommunitySearchReturnFields: + """Tests for Neo4jCommunityGraphDB._fetch_return_fields with return_fields.""" + + @pytest.fixture + def neo4j_community_db(self): + """Create a minimal Neo4jCommunityGraphDB instance by patching __init__.""" + with patch( + "memos.graph_dbs.neo4j_community.Neo4jCommunityGraphDB.__init__", return_value=None + ): + from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB + + db = Neo4jCommunityGraphDB.__new__(Neo4jCommunityGraphDB) + db.driver = MagicMock() + db.db_name = "test_memory_db" + yield db + + def test_fetch_return_fields_queries_neo4j(self, neo4j_community_db): + """_fetch_return_fields builds correct Cypher and returns fields.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [ + {"id": "node-1", "memory": "hello", "status": "activated"}, + ] + + results = neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory", "status"], + ) + + assert len(results) == 1 + assert results[0]["id"] == "node-1" + assert results[0]["score"] == 0.95 + assert results[0]["memory"] == "hello" + assert results[0]["status"] == "activated" + + query = session_mock.run.call_args[0][0] + assert "n.memory AS memory" in query + assert "n.status AS status" in query + + def test_fetch_return_fields_validates_names(self, neo4j_community_db): + """_fetch_return_fields rejects invalid field names.""" + session_mock = neo4j_community_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + neo4j_community_db._fetch_return_fields( + ids=["node-1"], + score_map={"node-1": 0.95}, + return_fields=["memory} RETURN n //", "valid_field"], + ) + + query = session_mock.run.call_args[0][0] + assert "memory}" not in query + assert "n.valid_field AS valid_field" in query diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py index 46cf3a1f6..a6ac186b7 100644 --- a/tests/memories/textual/test_history_manager.py +++ b/tests/memories/textual/test_history_manager.py @@ -131,7 +131,7 @@ def test_mark_memory_status(history_manager, mock_graph_db): # Assert assert mock_graph_db.update_node.call_count == 3 - # Verify we called it correctly - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}) + # Verify we called it correctly (user_name=None is passed by mark_memory_status) + mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) + mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None)