diff --git a/.gitignore b/.gitignore index 1a9c5653f..ece7e45ba 100644 --- a/.gitignore +++ b/.gitignore @@ -230,3 +230,5 @@ cython_debug/ outputs evaluation/data/temporal_locomo +test_add_pipeline.py +test_file_pipeline.py diff --git a/README.md b/README.md index db0c5cb44..a7b05d683 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Awesome AI Memory - +

@@ -55,7 +55,7 @@ -->

- + @@ -154,7 +154,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem -- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release** +- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release** First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch.
@@ -192,11 +192,11 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
-- **2025-07-07** · 🎉 **MemOS v1.0: Stellar (星河) Preview Release** +- **2025-07-07** · 🎉 **MemOS v1.0: Stellar (星河) Preview Release** A SOTA Memory OS for LLMs is now open-sourced. -- **2025-07-04** · 🎉 **MemOS Paper Release** +- **2025-07-04** · 🎉 **MemOS Paper Release** [MemOS: A Memory OS for AI System](https://arxiv.org/abs/2507.03724) is available on arXiv. -- **2024-07-04** · 🎉 **Memory3 Model Release at WAIC 2024** +- **2024-07-04** · 🎉 **Memory3 Model Release at WAIC 2024** The Memory3 model, featuring a memory-layered architecture, was unveiled at the 2024 World Artificial Intelligence Conference.
@@ -209,9 +209,9 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - Go to **API Keys** and copy your key #### Next Steps -- [MemOS Cloud Getting Started](https://memos-docs.openmem.net/memos_cloud/quick_start/) +- [MemOS Cloud Getting Started](https://memos-docs.openmem.net/memos_cloud/quick_start/) Connect to MemOS Cloud and enable memory in minutes. -- [MemOS Cloud Platform](https://memos.openmem.net/?from=/quickstart/) +- [MemOS Cloud Platform](https://memos.openmem.net/?from=/quickstart/) Explore the Cloud dashboard, features, and workflows. ### 🖥️ 2、Self-Hosted (Local/Private) @@ -249,7 +249,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ```python import requests import json - + data = { "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc", "mem_cube_id": "b32d0977-435d-4828-a86f-4f47f8b55bca", @@ -265,7 +265,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem "Content-Type": "application/json" } url = "http://localhost:8000/product/add" - + res = requests.post(url=url, headers=headers, data=json.dumps(data)) print(f"result: {res.json()}") ``` @@ -273,7 +273,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ```python import requests import json - + data = { "query": "What do I like", "user_id": "8736b16e-1d20-4163-980b-a5063c3facdc", @@ -283,7 +283,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem "Content-Type": "application/json" } url = "http://localhost:8000/product/search" - + res = requests.post(url=url, headers=headers, data=json.dumps(data)) print(f"result: {res.json()}") ``` @@ -292,8 +292,8 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ## 📚 Resources -- **Awesome-AI-Memory** - This is a curated repository dedicated to resources on memory and memory systems for large language models. It systematically collects relevant research papers, frameworks, tools, and practical insights. The repository aims to organize and present the rapidly evolving research landscape of LLM memory, bridging multiple research directions including natural language processing, information retrieval, agentic systems, and cognitive science. +- **Awesome-AI-Memory** + This is a curated repository dedicated to resources on memory and memory systems for large language models. It systematically collects relevant research papers, frameworks, tools, and practical insights. The repository aims to organize and present the rapidly evolving research landscape of LLM memory, bridging multiple research directions including natural language processing, information retrieval, agentic systems, and cognitive science. - **Get started** 👉 [IAAR-Shanghai/Awesome-AI-Memory](https://github.com/IAAR-Shanghai/Awesome-AI-Memory) - **MemOS Cloud OpenClaw Plugin** Official OpenClaw lifecycle plugin for MemOS Cloud. It automatically recalls context from MemOS before the agent starts and saves the conversation back to MemOS after the agent finishes. diff --git a/docker/Dockerfile.krolik b/docker/Dockerfile.krolik index c475a6d30..dcae7e0d9 100644 --- a/docker/Dockerfile.krolik +++ b/docker/Dockerfile.krolik @@ -1,5 +1,5 @@ # MemOS with Krolik Security Extensions -# +# # This Dockerfile builds MemOS with authentication, rate limiting, and admin API. # It uses the overlay pattern to keep customizations separate from base code. diff --git a/examples/core_memories/general_textual_memory.py b/examples/core_memories/general_textual_memory.py index d5c765b01..007736a6e 100644 --- a/examples/core_memories/general_textual_memory.py +++ b/examples/core_memories/general_textual_memory.py @@ -68,21 +68,9 @@ example_id = "a19b6caa-5d59-42ad-8c8a-e4f7118435b4" -print("===== Extract memories =====") -# Extract memories from a conversation -# The extractor LLM processes the conversation to identify relevant information. -memories = m.extract( - [ - {"role": "user", "content": "I love tomatoes."}, - {"role": "assistant", "content": "Great! Tomatoes are delicious."}, - ] -) -pprint.pprint(memories) -print() - print("==== Add memories ====") -# Add the extracted memories to the memory store -m.add(memories) +# Add example memories to the memory store +m.add(example_memories) # Add a manually created memory item m.add( [ diff --git a/examples/core_memories/naive_textual_memory.py b/examples/core_memories/naive_textual_memory.py index ab73060c7..1e7901e0f 100644 --- a/examples/core_memories/naive_textual_memory.py +++ b/examples/core_memories/naive_textual_memory.py @@ -1,20 +1,11 @@ -import json import os +import pprint import uuid from memos.configs.memory import MemoryConfigFactory from memos.memories.factory import MemoryFactory -def print_result(title, result): - """Helper function: Pretty print the result.""" - print(f"\n{'=' * 10} {title} {'=' * 10}") - if isinstance(result, list | dict): - print(json.dumps(result, indent=2, ensure_ascii=False, default=str)) - else: - print(result) - - # Configure memory backend with OpenAI extractor config = MemoryConfigFactory( backend="naive_text", @@ -38,39 +29,55 @@ def print_result(title, result): # Create memory instance m = MemoryFactory.from_config(config) +example_memories = [ + { + "memory": "I'm a RUCer, I'm happy.", + "metadata": { + "type": "event", + }, + }, + { + "memory": "MemOS is awesome!", + "metadata": { + "type": "opinion", + }, + }, +] + +example_id = str(uuid.uuid4()) -# Extract memories from a simulated conversation -memories = m.extract( +print("==== Add memories ====") +# Add example memories to the memory store +m.add(example_memories) +# Manually create a memory item and add it +m.add( [ - {"role": "user", "content": "I love tomatoes."}, - {"role": "assistant", "content": "Great! Tomatoes are delicious."}, + { + "id": example_id, + "memory": "User is Chinese.", + "metadata": {"type": "opinion"}, + } ] ) -print_result("Extract memories", memories) - - -# Add the extracted memories to storage -m.add(memories) - -# Manually create a memory item and add it -example_id = str(uuid.uuid4()) -manual_memory = [{"id": example_id, "memory": "User is Chinese.", "metadata": {"type": "opinion"}}] -m.add(manual_memory) - -# Print all current memories -print_result("Add memories (Check all after adding)", m.get_all()) - +print("All memories after addition:") +pprint.pprint(m.get_all()) +print() -# Search for relevant memories based on the query +print("==== Search memories ====") +# Search for memories related to a query search_results = m.search("Tell me more about the user", top_k=2) -print_result("Search memories", search_results) - +pprint.pprint(search_results) +print() +print("==== Get memories ====") # Get specific memory item by ID -memory_item = m.get(example_id) -print_result("Get memory", memory_item) - +print(f"Memory with ID {example_id}:") +pprint.pprint(m.get(example_id)) +print(f"Memories by IDs [{example_id}]:") +pprint.pprint(m.get_by_ids([example_id])) +print() +print("==== Update memories ====") # Update the memory content for the specified ID m.update( example_id, @@ -80,9 +87,9 @@ def print_result(title, result): "metadata": {"type": "opinion", "confidence": 85}, }, ) -updated_memory = m.get(example_id) -print_result("Update memory", updated_memory) - +print(f"Memory after update (ID {example_id}):") +pprint.pprint(m.get(example_id)) +print() print("==== Dump memory ====") # Dump the current state of memory to a file @@ -90,12 +97,16 @@ def print_result(title, result): print("Memory dumped to 'tmp/naive_mem'.") print() - +print("==== Delete memories ====") # Delete memory with the specified ID m.delete([example_id]) -print_result("Delete memory (Check all after deleting)", m.get_all()) - +print("All memories after deletion:") +pprint.pprint(m.get_all()) +print() +print("==== Delete all memories ====") # Delete all memories in storage m.delete_all() -print_result("Delete all", m.get_all()) +print("All memories after delete_all:") +pprint.pprint(m.get_all()) +print() diff --git a/examples/data/mem_cube_tree/textual_memory.json b/examples/data/mem_cube_tree/textual_memory.json index 91f426ca2..97a2b1dd0 100644 --- a/examples/data/mem_cube_tree/textual_memory.json +++ b/examples/data/mem_cube_tree/textual_memory.json @@ -4216,4 +4216,4 @@ "edges": [], "total_nodes": 4, "total_edges": 0 -} \ No newline at end of file +} diff --git a/examples/mem_feedback/example_feedback.py b/examples/mem_feedback/example_feedback.py index 8f4446863..794ddf111 100644 --- a/examples/mem_feedback/example_feedback.py +++ b/examples/mem_feedback/example_feedback.py @@ -144,7 +144,7 @@ def init_components(): mem_reader=mem_reader, searcher=searcher, reranker=mem_reranker, - pref_mem=None, + pref_feedback=True, ) return feedback_server, memory_manager, embedder diff --git a/pyproject.toml b/pyproject.toml index 4a9ea8852..9f17c0000 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.7" +version = "2.0.8" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index fefa3b2ab..36cc0b5b5 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.7" +__version__ = "2.0.8" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 65049b0c2..fa12bcf55 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -675,7 +675,17 @@ def get_polardb_config(user_id: str | None = None) -> dict[str, Any]: "user_name": user_name, "use_multi_db": use_multi_db, "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "1024")), + # .env: CONNECTION_WAIT_TIMEOUT, SKIP_CONNECTION_HEALTH_CHECK, WARM_UP_ON_STARTUP_BY_FULL, WARM_UP_ON_STARTUP_BY_ALL + "connection_wait_timeout": int(os.getenv("CONNECTION_WAIT_TIMEOUT", "60")), + "skip_connection_health_check": os.getenv( + "SKIP_CONNECTION_HEALTH_CHECK", "false" + ).lower() + == "true", + "warm_up_on_startup_by_full": os.getenv("WARM_UP_ON_STARTUP_BY_FULL", "false").lower() + == "true", + "warm_up_on_startup_by_all": os.getenv("WARM_UP_ON_STARTUP_BY_ALL", "false").lower() + == "true", } @staticmethod diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py index 90347768c..bd4c9f4b0 100644 --- a/src/memos/api/handlers/__init__.py +++ b/src/memos/api/handlers/__init__.py @@ -32,7 +32,6 @@ ) from memos.api.handlers.formatters_handler import ( format_memory_item, - post_process_pref_mem, to_iter, ) @@ -54,7 +53,6 @@ "formatters_handler", "init_server", "memory_handler", - "post_process_pref_mem", "scheduler_handler", "search_handler", "suggestion_handler", diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 3cdbedabf..e9ed4f955 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -22,7 +22,7 @@ class AddHandler(BaseHandler): """ Handler for memory addition operations. - Handles both text and preference memory additions with sync/async support. + Handles text memory additions with sync/async support. """ def __init__(self, dependencies: HandlerDependencies): @@ -41,7 +41,7 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. - Orchestrates the addition of both text and preference memories, + Orchestrates the addition of text memories, supporting concurrent processing. Args: diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index ba527d602..aa2525878 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -19,11 +19,7 @@ build_llm_config, build_mem_reader_config, build_nli_client_config, - build_pref_adder_config, - build_pref_extractor_config, - build_pref_retriever_config, build_reranker_config, - build_vec_db_config, ) from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.embedders.factory import EmbedderFactory @@ -36,12 +32,6 @@ from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager @@ -56,7 +46,6 @@ InternetRetrieverFactory, ) from memos.reranker.factory import RerankerFactory -from memos.vec_dbs.factory import VecDBFactory if TYPE_CHECKING: @@ -125,7 +114,7 @@ def init_server() -> dict[str, Any]: required by the MemOS server, including: - Database connections (graph DB, vector DB) - Language models and embedders - - Memory systems (text, preference) + - Memory systems (text) - Scheduler and related modules Returns: @@ -169,20 +158,11 @@ def init_server() -> dict[str, Any]: reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() - vector_db_config = build_vec_db_config() - pref_extractor_config = build_pref_extractor_config() - pref_adder_config = build_pref_adder_config() - pref_retriever_config = build_pref_retriever_config() logger.debug("Component configurations built successfully") # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = ( - VecDBFactory.from_config(vector_db_config) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) llm = LLMFactory.from_config(llm_config) chat_llms = ( _init_chat_llms(chat_llm_config) @@ -231,61 +211,6 @@ def init_server() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize preference memory components - pref_extractor = ( - ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_adder = ( - AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_retriever = ( - RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=feedback_reranker, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory components initialized") - - # Initialize preference memory - pref_mem = ( - SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=feedback_reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory initialized") - # Initialize MOS Server mos_server = MOSServer( mem_reader=mem_reader, @@ -298,7 +223,6 @@ def init_server() -> dict[str, Any]: # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=pref_mem, act_mem=None, para_mem=None, ) @@ -325,7 +249,7 @@ def init_server() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, - pref_mem=pref_mem, + pref_feedback=True, ) # Initialize Scheduler @@ -384,12 +308,7 @@ def init_server() -> dict[str, Any]: "naive_mem_cube": naive_mem_cube, "searcher": searcher, "api_module": api_module, - "vector_db": vector_db, - "pref_extractor": pref_extractor, - "pref_adder": pref_adder, - "pref_retriever": pref_retriever, "text_mem": text_mem, - "pref_mem": pref_mem, "online_bot": online_bot, "feedback_server": feedback_server, "redis_client": redis_client, diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index 06c4fd223..ee88ae639 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -65,49 +65,14 @@ def format_memory_item( return memory -def post_process_pref_mem( - memories_result: dict[str, Any], - pref_formatted_mem: list[dict[str, Any]], - mem_cube_id: str, - include_preference: bool, -) -> dict[str, Any]: - """ - Post-process preference memory results. - - Adds formatted preference memories to the result dictionary and generates - instruction completion strings if preferences are included. - - Args: - memories_result: Result dictionary to update - pref_formatted_mem: List of formatted preference memories - mem_cube_id: Memory cube ID - include_preference: Whether to include preferences in result - - Returns: - Updated memories_result dictionary - """ - if include_preference: - memories_result["pref_mem"].append( - { - "cube_id": mem_cube_id, - "memories": pref_formatted_mem, - "total_nodes": len(pref_formatted_mem), - } - ) - pref_instruction, pref_note = instruct_completion(pref_formatted_mem) - memories_result["pref_string"] = pref_instruction - memories_result["pref_note"] = pref_note - - return memories_result - - def post_process_textual_mem( memories_result: dict[str, Any], text_formatted_mem: list[dict[str, Any]], mem_cube_id: str, ) -> dict[str, Any]: """ - Post-process text and tool memory results. + Post-process text, tool, skill and preference memory results. + Now automatically handles preference memories. """ fact_mem = [ mem @@ -124,6 +89,11 @@ def post_process_textual_mem( mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "SkillMemory" ] + # Extract preference memories + pref_mem = [ + mem for mem in text_formatted_mem if mem["metadata"]["memory_type"] == "PreferenceMemory" + ] + memories_result["text_mem"].append( { "cube_id": mem_cube_id, @@ -145,6 +115,19 @@ def post_process_textual_mem( "total_nodes": len(skill_mem), } ) + + memories_result["pref_mem"].append( + { + "cube_id": mem_cube_id, + "memories": pref_mem, + "total_nodes": len(pref_mem), + } + ) + if pref_mem: + pref_instruction, pref_note = instruct_completion(pref_mem) + memories_result["pref_string"] = pref_instruction + memories_result["pref_note"] = pref_note + return memories_result diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index ef56c7489..2ab8f50c7 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -4,12 +4,8 @@ This module handles retrieving all memories or specific subgraphs based on queries. """ -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal -from memos.api.handlers.formatters_handler import ( - format_memory_item, - post_process_pref_mem, -) from memos.api.product_models import ( DeleteMemoryRequest, DeleteMemoryResponse, @@ -29,10 +25,6 @@ ) -if TYPE_CHECKING: - from memos.memories.textual.preference import TextualMemoryItem - - logger = get_logger(__name__) @@ -171,8 +163,7 @@ def handle_get_subgraph( def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemoryResponse: """ Handler for getting a single memory by its ID. - - Tries to retrieve from text memory first, then preference memory if not found. + Now unified to retrieve from text_mem only (includes preferences). Args: memory_id: The ID of the memory to retrieve @@ -184,37 +175,12 @@ def handle_get_memory(memory_id: str, naive_mem_cube: NaiveMemCube) -> GetMemory try: memory = naive_mem_cube.text_mem.get(memory_id) - except Exception: + except Exception as e: + logger.error(f"Failed to get memory {memory_id}: {e}") memory = None - # If not found in text memory, try preference memory - pref = None - if memory is None and naive_mem_cube.pref_mem is not None: - collection_names = ["explicit_preference", "implicit_preference"] - for collection_name in collection_names: - try: - pref = naive_mem_cube.pref_mem.get_with_collection_name(collection_name, memory_id) - if pref is not None: - break - except Exception: - continue - - # Get the data from whichever memory source succeeded - data = (memory or pref).model_dump() if (memory or pref) else None - - if data is not None: - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - metadata = data.get("metadata", None) - if metadata is not None and isinstance(metadata, dict): - info = metadata.get("info", None) - if info is not None and isinstance(info, dict): - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value + # Get the data + data = memory.model_dump() if memory else None return GetMemoryResponse( message="Memory retrieved successfully" @@ -230,50 +196,20 @@ def handle_get_memory_by_ids( ) -> GetMemoryResponse: """ Handler for getting multiple memories by their IDs. + Now unified to retrieve from text_mem only (includes preferences). Retrieves multiple memories and formats them as a list of dictionaries. """ try: memories = naive_mem_cube.text_mem.get_by_ids(memory_ids=memory_ids) - except Exception: + except Exception as e: + logger.error(f"Failed to get memories: {e}") memories = [] # Ensure memories is not None if memories is None: memories = [] - if naive_mem_cube.pref_mem is not None: - collection_names = ["explicit_preference", "implicit_preference"] - for collection_name in collection_names: - try: - result = naive_mem_cube.pref_mem.get_by_ids_with_collection_name( - collection_name, memory_ids - ) - if result is not None: - result = [format_memory_item(item, save_sources=False) for item in result] - memories.extend(result) - except Exception: - continue - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in memories: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - return GetMemoryResponse( message="Memories retrieved successfully", code=200, data={"memories": memories} ) @@ -343,67 +279,31 @@ def handle_get_memories( "total_nodes": total_skill_nodes, } ] - preferences: list[TextualMemoryItem] = [] - total_preference_nodes = 0 - format_preferences = [] - if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - if get_mem_req.filter is not None: - # Check and remove user_id/mem_cube_id from filter if present - filter_copy = get_mem_req.filter.copy() - removed_fields = [] - - if "user_id" in filter_copy: - filter_copy.pop("user_id") - removed_fields.append("user_id") - if "mem_cube_id" in filter_copy: - filter_copy.pop("mem_cube_id") - removed_fields.append("mem_cube_id") - - if removed_fields: - logger.warning( - f"Fields {removed_fields} found in filter will be ignored. " - f"Use request-level user_id/mem_cube_id parameters instead." - ) - - filter_params.update(filter_copy) - - preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( - filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + # Get preference memories (same pattern as other memory types) + if get_mem_req.include_preference: + pref_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["PreferenceMemory"], ) - format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in format_preferences: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - - results = post_process_pref_mem( - results, format_preferences, get_mem_req.mem_cube_id, get_mem_req.include_preference - ) - if total_preference_nodes > 0 and results.get("pref_mem", []): - results["pref_mem"][0]["total_nodes"] = total_preference_nodes + pref_memories, total_pref_nodes = ( + pref_memories_info["nodes"], + pref_memories_info["total_nodes"], + ) + + results["pref_mem"] = [ + { + "cube_id": get_mem_req.mem_cube_id, + "memories": pref_memories, + "total_nodes": total_pref_nodes, + } + ] - # Filter to only keep text_mem, pref_mem, tool_mem + # Filter to only keep text_mem, pref_mem, tool_mem, skill_mem filtered_results = { "text_mem": results.get("text_mem", []), "pref_mem": results.get("pref_mem", []), @@ -415,6 +315,10 @@ def handle_get_memories( def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube): + """ + Handler for deleting memories. + Now unified to delete from text_mem only (includes preferences). + """ logger.info( f"[Delete memory request] writable_cube_ids: {delete_mem_req.writable_cube_ids}, memory_ids: {delete_mem_req.memory_ids}" ) @@ -432,17 +336,14 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: try: if delete_mem_req.memory_ids is not None: + # Unified deletion from text_mem (includes preferences) naive_mem_cube.text_mem.delete_by_memory_ids(delete_mem_req.memory_ids) - if naive_mem_cube.pref_mem is not None: - naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) elif delete_mem_req.file_ids is not None: naive_mem_cube.text_mem.delete_by_filter( writable_cube_ids=delete_mem_req.writable_cube_ids, file_ids=delete_mem_req.file_ids ) elif delete_mem_req.filter is not None: naive_mem_cube.text_mem.delete_by_filter(filter=delete_mem_req.filter) - if naive_mem_cube.pref_mem is not None: - naive_mem_cube.pref_mem.delete_by_filter(filter=delete_mem_req.filter) except Exception as e: logger.error(f"Failed to delete memories: {e}", exc_info=True) return DeleteMemoryResponse( @@ -572,49 +473,29 @@ def handle_get_memories_dashboard( for cube_id, memories in skill_mem_by_cube.items() ] - preferences: list[TextualMemoryItem] = [] - - format_preferences = [] - if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - if get_mem_req.filter is not None: - # Check and remove user_id/mem_cube_id from filter if present - filter_copy = get_mem_req.filter.copy() - removed_fields = [] - - if "user_id" in filter_copy: - filter_copy.pop("user_id") - removed_fields.append("user_id") - if "mem_cube_id" in filter_copy: - filter_copy.pop("mem_cube_id") - removed_fields.append("mem_cube_id") - - if removed_fields: - logger.warning( - f"Fields {removed_fields} found in filter will be ignored. " - f"Use request-level user_id/mem_cube_id parameters instead." - ) - - filter_params.update(filter_copy) - - preferences, total_preference_nodes = naive_mem_cube.pref_mem.get_memory_by_filter( - filter_params, page=get_mem_req.page, page_size=get_mem_req.page_size + if get_mem_req.include_preference: + pref_memories_info = naive_mem_cube.text_mem.get_all( + user_name=get_mem_req.mem_cube_id, + user_id=get_mem_req.user_id, + page=get_mem_req.page, + page_size=get_mem_req.page_size, + filter=get_mem_req.filter, + memory_type=["PreferenceMemory"], + ) + pref_memories, total_preference_nodes = ( + pref_memories_info["nodes"], + pref_memories_info["total_nodes"], ) - format_preferences = [format_memory_item(item, save_sources=False) for item in preferences] - # Group preferences by cube_id from metadata.mem_cube_id + # Group preference memories by cube_id from metadata.user_name pref_mem_by_cube: dict[str, list] = {} - for pref in format_preferences: - cube_id = pref.get("metadata", {}).get("mem_cube_id", get_mem_req.mem_cube_id) + for memory in pref_memories: + cube_id = memory.get("metadata", {}).get("user_name", get_mem_req.mem_cube_id) if cube_id not in pref_mem_by_cube: pref_mem_by_cube[cube_id] = [] - pref_mem_by_cube[cube_id].append(pref) + pref_mem_by_cube[cube_id].append(memory) - # If no preferences found, create a default entry with the requested cube_id + # If no memories found, create a default entry with the requested cube_id if not pref_mem_by_cube and get_mem_req.mem_cube_id: pref_mem_by_cube[get_mem_req.mem_cube_id] = [] diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 8e7785ad5..ba1c50b07 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -49,7 +49,7 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse Main handler for search memories endpoint. Orchestrates the search process based on the requested search mode, - supporting both text and preference memory searches. + supporting text memory searches. Args: search_req: Search request containing query and parameters @@ -120,10 +120,7 @@ def _apply_relativity_threshold(results: dict[str, Any], relativity: float) -> d if not isinstance(mem, dict): continue meta = mem.get("metadata", {}) - if key == "text_mem": - score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 - else: - score = meta.get("score", 1.0) if isinstance(meta, dict) else 1.0 + score = meta.get("relativity", 1.0) if isinstance(meta, dict) else 1.0 try: score_val = float(score) if score is not None else 1.0 except (TypeError, ValueError): diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 5bf27e985..6f112b9a7 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -434,7 +434,7 @@ class APISearchRequest(BaseRequest): # Internal field for search memory type search_memory_type: str = Field( "All", - description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory", + description="Type of memory to search: All, WorkingMemory, LongTermMemory, UserMemory, OuterMemory, ToolSchemaMemory, ToolTrajectoryMemory, RawFileMemory, AllSummaryMemory, SkillMemory, PreferenceMemory", ) # ==== Context ==== diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index af6ae4fe5..fa8a0b396 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -94,7 +94,6 @@ redis_client = components["redis_client"] status_tracker = TaskStatusTracker(redis_client=redis_client) graph_db = components["graph_db"] -vector_db = components["vector_db"] # ============================================================================= @@ -369,15 +368,9 @@ def feedback_memories(feedback_req: APIFeedbackRequest): response_model=GetUserNamesByMemoryIdsResponse, ) def get_user_names_by_memory_ids(request: GetUserNamesByMemoryIdsRequest): - """Get user names by memory ids.""" + """Get user names by memory ids. Now unified to query from graph_db only.""" result = graph_db.get_user_names_by_memory_ids(memory_ids=request.memory_ids) - if vector_db: - prefs = [] - for collection_name in ["explicit_preference", "implicit_preference"]: - prefs.extend( - vector_db.get_by_ids(collection_name=collection_name, ids=request.memory_ids) - ) - result.update({pref.id: pref.payload.get("mem_cube_id", None) for pref in prefs}) + return GetUserNamesByMemoryIdsResponse( code=200, message="Successfully", diff --git a/src/memos/chunkers/base.py b/src/memos/chunkers/base.py index c2a783baa..e858132e1 100644 --- a/src/memos/chunkers/base.py +++ b/src/memos/chunkers/base.py @@ -1,3 +1,5 @@ +import re + from abc import ABC, abstractmethod from memos.configs.chunker import BaseChunkerConfig @@ -22,3 +24,42 @@ def __init__(self, config: BaseChunkerConfig): @abstractmethod def chunk(self, text: str) -> list[Chunk]: """Chunk the given text into smaller chunks.""" + + def protect_urls(self, text: str) -> tuple[str, dict[str, str]]: + """ + Protect URLs in text from being split during chunking. + + Args: + text: Text to process + + Returns: + tuple: (Text with URLs replaced by placeholders, URL mapping dictionary) + """ + url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+' + url_map = {} + + def replace_url(match): + url = match.group(0) + placeholder = f"__URL_{len(url_map)}__" + url_map[placeholder] = url + return placeholder + + protected_text = re.sub(url_pattern, replace_url, text) + return protected_text, url_map + + def restore_urls(self, text: str, url_map: dict[str, str]) -> str: + """ + Restore protected URLs in text back to their original form. + + Args: + text: Text with URL placeholders + url_map: URL mapping dictionary from protect_urls + + Returns: + str: Text with URLs restored + """ + restored_text = text + for placeholder, url in url_map.items(): + restored_text = restored_text.replace(placeholder, url) + + return restored_text diff --git a/src/memos/chunkers/charactertext_chunker.py b/src/memos/chunkers/charactertext_chunker.py index 15c0958ba..25739d96f 100644 --- a/src/memos/chunkers/charactertext_chunker.py +++ b/src/memos/chunkers/charactertext_chunker.py @@ -36,6 +36,8 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chunks = self.chunker.split_text(text) + protected_text, url_map = self.protect_urls(text) + chunks = self.chunker.split_text(protected_text) + chunks = [self.restore_urls(chunk, url_map) for chunk in chunks] logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks diff --git a/src/memos/chunkers/markdown_chunker.py b/src/memos/chunkers/markdown_chunker.py index b7771ac35..a37370200 100644 --- a/src/memos/chunkers/markdown_chunker.py +++ b/src/memos/chunkers/markdown_chunker.py @@ -1,3 +1,5 @@ +import re + from memos.configs.chunker import MarkdownChunkerConfig from memos.dependency import require_python_package from memos.log import get_logger @@ -22,6 +24,7 @@ def __init__( chunk_size: int = 1000, chunk_overlap: int = 200, recursive: bool = False, + auto_fix_headers: bool = True, ): from langchain_text_splitters import ( MarkdownHeaderTextSplitter, @@ -29,6 +32,7 @@ def __init__( ) self.config = config + self.auto_fix_headers = auto_fix_headers self.chunker = MarkdownHeaderTextSplitter( headers_to_split_on=config.headers_to_split_on if config @@ -46,17 +50,110 @@ def __init__( def chunk(self, text: str, **kwargs) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - md_header_splits = self.chunker.split_text(text) + # Protect URLs first + protected_text, url_map = self.protect_urls(text) + # Auto-detect and fix malformed header hierarchy if enabled + if self.auto_fix_headers and self._detect_malformed_headers(protected_text): + logger.info("[Chunker:] detected malformed header hierarchy, attempting to fix...") + protected_text = self._fix_header_hierarchy(protected_text) + logger.info("[Chunker:] Header hierarchy fix completed") + + md_header_splits = self.chunker.split_text(protected_text) chunks = [] if self.chunker_recursive: md_header_splits = self.chunker_recursive.split_documents(md_header_splits) for doc in md_header_splits: try: chunk = " ".join(list(doc.metadata.values())) + "\n" + doc.page_content + chunk = self.restore_urls(chunk, url_map) chunks.append(chunk) except Exception as e: logger.warning(f"warning chunking document: {e}") - chunks.append(doc.page_content) + restored_chunk = self.restore_urls(doc.page_content, url_map) + chunks.append(restored_chunk) logger.info(f"Generated chunks: {chunks[:5]}") logger.debug(f"Generated {len(chunks)} chunks from input text") return chunks + + def _detect_malformed_headers(self, text: str) -> bool: + """Detect if markdown has improper header hierarchy usage.""" + # Extract all valid markdown header lines + header_levels = [] + pattern = re.compile(r"^#{1,6}\s+.+") + for line in text.split("\n"): + stripped_line = line.strip() + if pattern.match(stripped_line): + hash_match = re.match(r"^(#+)", stripped_line) + if hash_match: + level = len(hash_match.group(1)) + header_levels.append(level) + + total_headers = len(header_levels) + if total_headers == 0: + logger.debug("No valid headers detected, skipping check") + return False + + # Calculate level-1 header ratio + level1_count = sum(1 for level in header_levels if level == 1) + + # Determine if malformed: >90% are level-1 when total > 5 + # OR all headers are level-1 when total ≤ 5 + if total_headers > 5: + level1_ratio = level1_count / total_headers + if level1_ratio > 0.9: + logger.warning( + f"Detected header hierarchy issue: {level1_count}/{total_headers} " + f"({level1_ratio:.1%}) of headers are level 1" + ) + return True + elif total_headers <= 5 and level1_count == total_headers: + logger.warning( + f"Detected header hierarchy issue: all {total_headers} headers are level 1" + ) + return True + return False + + def _fix_header_hierarchy(self, text: str) -> str: + """ + Fix markdown header hierarchy by adjusting levels. + + Strategy: + 1. Keep the first header unchanged as level-1 parent + 2. Increment all subsequent headers by 1 level (max level 6) + """ + header_pattern = re.compile(r"^(#{1,6})\s+(.+)$") + lines = text.split("\n") + fixed_lines = [] + first_valid_header = False + + for line in lines: + stripped_line = line.strip() + # Match valid header lines (invalid # lines kept as-is) + header_match = header_pattern.match(stripped_line) + if header_match: + current_hashes, title_content = header_match.groups() + current_level = len(current_hashes) + + if not first_valid_header: + # First valid header: keep original level unchanged + fixed_line = f"{current_hashes} {title_content}" + first_valid_header = True + logger.debug( + f"Keep first header at level {current_level}: {title_content[:50]}..." + ) + else: + # Subsequent headers: increment by 1, cap at level 6 + new_level = min(current_level + 1, 6) + new_hashes = "#" * new_level + fixed_line = f"{new_hashes} {title_content}" + logger.debug( + f"Adjust header level: {current_level} -> {new_level}: {title_content[:50]}..." + ) + fixed_lines.append(fixed_line) + else: + fixed_lines.append(line) + + # Join with newlines to preserve original formatting + fixed_text = "\n".join(fixed_lines) + logger.info(f"[Chunker:] Header hierarchy fix completed: {fixed_text[:50]}...") + return fixed_text diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index f39dfb8e2..e695d0d9a 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -43,11 +43,13 @@ def __init__(self, config: SentenceChunkerConfig): def chunk(self, text: str) -> list[str] | list[Chunk]: """Chunk the given text into smaller chunks based on sentences.""" - chonkie_chunks = self.chunker.chunk(text) + protected_text, url_map = self.protect_urls(text) + chonkie_chunks = self.chunker.chunk(protected_text) chunks = [] for c in chonkie_chunks: chunk = Chunk(text=c.text, token_count=c.token_count, sentences=c.sentences) + chunk = self.restore_urls(chunk.text, url_map) chunks.append(chunk) logger.debug(f"Generated {len(chunks)} chunks from input text") diff --git a/src/memos/chunkers/simple_chunker.py b/src/memos/chunkers/simple_chunker.py index cc0dc40d0..58e12e2f1 100644 --- a/src/memos/chunkers/simple_chunker.py +++ b/src/memos/chunkers/simple_chunker.py @@ -20,12 +20,15 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> Returns: List of text chunks """ - if not text or len(text) <= chunk_size: - return [text] if text.strip() else [] + protected_text, url_map = self.protect_urls(text) + + if not protected_text or len(protected_text) <= chunk_size: + chunks = [protected_text] if protected_text.strip() else [] + return [self.restore_urls(chunk, url_map) for chunk in chunks] chunks = [] start = 0 - text_len = len(text) + text_len = len(protected_text) while start < text_len: # Calculate end position @@ -35,16 +38,16 @@ def _simple_split_text(self, text: str, chunk_size: int, chunk_overlap: int) -> if end < text_len: # Try to break at newline, sentence end, or space for separator in ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " "]: - last_sep = text.rfind(separator, start, end) + last_sep = protected_text.rfind(separator, start, end) if last_sep != -1: end = last_sep + len(separator) break - chunk = text[start:end].strip() + chunk = protected_text[start:end].strip() if chunk: chunks.append(chunk) # Move start position with overlap start = max(start + 1, end - chunk_overlap) - return chunks + return [self.restore_urls(chunk, url_map) for chunk in chunks] diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 9b1ce7f9d..5900d2357 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -202,6 +202,33 @@ class PolarDBGraphDBConfig(BaseConfig): default=100, description="Maximum number of connections in the connection pool", ) + connection_wait_timeout: int = Field( + default=30, + ge=1, + le=3600, + description="Max seconds to wait for a connection slot before raising (0 = wait forever, not recommended)", + ) + skip_connection_health_check: bool = Field( + default=False, + description=( + "If True, skip SELECT 1 health check when getting connections (~1-2ms saved per request). " + "Use only when pool/network is reliable." + ), + ) + warm_up_on_startup_by_full: bool = Field( + default=True, + description=( + "If True, run search_by_fulltext warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." + ), + ) + warm_up_on_startup_by_all: bool = Field( + default=False, + description=( + "If True, run all connection warm-up on pool connections at init to reduce " + "first-query latency (~200ms planning). Requires user_name in config." + ), + ) @model_validator(mode="after") def validate_config(self): diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index a28f3bdce..9807f42c3 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -250,8 +250,12 @@ def validate_partial_initialization(self) -> "AuthConfig": "All configuration components are None. This may indicate missing environment variables or configuration files." ) elif failed_components: - logger.warning( - f"Failed to initialize components: {', '.join(failed_components)}. Successfully initialized: {', '.join(initialized_components)}" + # Use info level: individual from_local_env() methods already log + # warnings for actual initialization failures. Components that are + # simply not configured (no env vars) are not errors. + logger.info( + f"Components not configured: {', '.join(failed_components)}. " + f"Successfully initialized: {', '.join(initialized_components)}" ) return self diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 2b3bd0967..c71ed6b5a 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -14,6 +14,21 @@ logger = get_logger(__name__) +def _sanitize_unicode(text: str) -> str: + """ + Remove Unicode surrogates and other problematic characters. + Surrogates (U+D800-U+DFFF) cause UnicodeEncodeError with some APIs. + """ + try: + # Encode with 'surrogatepass' then decode, replacing invalid chars + cleaned = text.encode("utf-8", errors="surrogatepass").decode("utf-8", errors="replace") + # Replace replacement char with empty string for cleaner output + return cleaned.replace("\ufffd", "") + except Exception: + # Fallback: remove all non-BMP characters + return "".join(c for c in text if ord(c) < 0x10000) + + class UniversalAPIEmbedder(BaseEmbedder): def __init__(self, config: UniversalAPIEmbedderConfig): self.provider = config.provider @@ -54,6 +69,8 @@ def __init__(self, config: UniversalAPIEmbedderConfig): def embed(self, texts: list[str]) -> list[list[float]]: if isinstance(texts, str): texts = [texts] + # Sanitize Unicode to prevent encoding errors with emoji/surrogates + texts = [_sanitize_unicode(t) for t in texts] # Truncate texts if max_tokens is configured texts = self._truncate_texts(texts) logger.info(f"Embeddings request with input: {texts}") diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 09ad46c42..470d8cd8e 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -140,8 +140,6 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N metadata.setdefault("delete_record_id", "") embedding = metadata.pop("embedding", None) - if embedding is None: - raise ValueError(f"Missing 'embedding' in metadata for node {node_id}") vector_sync_status = "success" vec_items.append( diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 592f45a7f..ad75f4b65 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1,9 +1,10 @@ import json import random import textwrap +import threading import time -from contextlib import suppress +from contextlib import contextmanager from datetime import datetime from typing import Any, Literal @@ -136,7 +137,11 @@ def __init__(self, config: PolarDBGraphDBConfig): port = config.get("port") user = config.get("user") password = config.get("password") - maxconn = config.get("maxconn", 100) # De + maxconn = config.get("maxconn", 100) + self._connection_wait_timeout = config.get("connection_wait_timeout", 60) + self._skip_connection_health_check = config.get("skip_connection_health_check", False) + self._warm_up_on_startup_by_full = config.get("warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = config.get("warm_up_on_startup_by_all", False) else: self.db_name = config.db_name self.user_name = config.user_name @@ -145,13 +150,19 @@ def __init__(self, config: PolarDBGraphDBConfig): user = config.user password = config.password maxconn = config.maxconn if hasattr(config, "maxconn") else 100 - """ - # Create connection - self.connection = psycopg2.connect( - host=host, port=port, user=user, password=password, dbname=self.db_name,minconn=10, maxconn=2000 + self._connection_wait_timeout = getattr(config, "connection_wait_timeout", 60) + self._skip_connection_health_check = getattr( + config, "skip_connection_health_check", False + ) + self._warm_up_on_startup_by_full = getattr(config, "warm_up_on_startup_by_full", False) + self._warm_up_on_startup_by_all = getattr(config, "warm_up_on_startup_by_all", False) + logger.info( + f"polardb init config connection_wait_timeout:{self._connection_wait_timeout},_skip_connection_health_check:{self._skip_connection_health_check},warm_up_on_startup_by_full:{self._warm_up_on_startup_by_full},warm_up_on_startup_by_all:{self._warm_up_on_startup_by_all}" + ) + + logger.info( + f" db_name: {self.db_name} maxconn: {maxconn} connection_wait_timeout: {self._connection_wait_timeout}s" ) - """ - logger.info(f" db_name: {self.db_name} current maxconn is:'{maxconn}'") # Create connection pool self.connection_pool = psycopg2.pool.ThreadedConnectionPool( @@ -162,14 +173,17 @@ def __init__(self, config: PolarDBGraphDBConfig): user=user, password=password, dbname=self.db_name, - connect_timeout=60, # Connection timeout in seconds - keepalives_idle=40, # Seconds of inactivity before sending keepalive (should be < server idle timeout) + connect_timeout=10, # Connection timeout in seconds + keepalives_idle=120, # Seconds of inactivity before sending keepalive (should be < server idle timeout) keepalives_interval=15, # Seconds between keepalive retries keepalives_count=5, # Number of keepalive retries before considering connection dead ) - # Keep a reference to the pool for cleanup - self._pool_closed = False + self._semaphore = threading.BoundedSemaphore(maxconn) + if self._warm_up_on_startup_by_full: + self._warm_up_search_connections_by_full() + if self._warm_up_on_startup_by_all: + self._warm_up_connections_by_all() """ # Handle auto_create @@ -194,194 +208,81 @@ def _get_config_value(self, key: str, default=None): else: return getattr(self.config, key, default) - def _get_connection_old(self): - """Get a connection from the pool.""" - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - conn = self.connection_pool.getconn() - # Set autocommit for PolarDB compatibility - conn.autocommit = True - return conn - - def _get_connection(self): - logger.info(f" db_name: {self.db_name} pool maxconn is:'{self.connection_pool.maxconn}'") - if self._pool_closed: - raise RuntimeError("Connection pool has been closed") - - max_retries = 500 - import psycopg2.pool - - for attempt in range(max_retries): - conn = None + def _warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + logger.info("--warm_up_search_connections_by_full--start-up----") + user_name = user_name or self.user_name + if not user_name: + logger.debug("[warm_up] Skipped: no user_name for warm-up") + return + warm_count = min(5, self.connection_pool.minconn) + for _ in range(warm_count): try: - conn = self.connection_pool.getconn() - - if conn.closed != 0: - logger.warning( - f"[_get_connection] Got closed connection, attempt {attempt + 1}/{max_retries}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as e: - logger.warning( - f"[_get_connection] Failed to return closed connection to pool: {e}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError("Pool returned a closed connection after all retries") - - # Set autocommit for PolarDB compatibility - conn.autocommit = True - - # Test connection health with SELECT 1 - try: - cursor = conn.cursor() - cursor.execute("SELECT 1") - cursor.fetchone() - cursor.close() - except Exception as health_check_error: - # Connection is not usable, return it to pool with close flag and try again - logger.warning( - f"[_get_connection] Connection health check failed (attempt {attempt + 1}/{max_retries}): {health_check_error}" - ) - try: - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return unhealthy connection to pool: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - conn = None - if attempt < max_retries - 1: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" - ) from health_check_error - - # Connection is healthy, return it - return conn - - except psycopg2.pool.PoolError as pool_error: - error_msg = str(pool_error).lower() - if "exhausted" in error_msg or "pool" in error_msg: - # Log pool status for debugging - try: - # Try to get pool stats if available - pool_info = f"Pool config: minconn={self.connection_pool.minconn}, maxconn={self.connection_pool.maxconn}" - logger.info( - f" polardb get_connection Connection pool exhausted (attempt {attempt + 1}/{max_retries}). {pool_info}" - ) - except Exception: - logger.warning( - f"[_get_connection] Connection pool exhausted (attempt {attempt + 1}/{max_retries})" - ) - - # For pool exhaustion, wait longer before retry (connections may be returned) - if attempt < max_retries - 1: - # Longer backoff for pool exhaustion: 0.5s, 1.0s, 2.0s - wait_time = 0.5 * (2**attempt) - logger.info(f"[_get_connection] Waiting {wait_time}s before retry...") - """time.sleep(wait_time)""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Connection pool exhausted after {max_retries} attempts. " - f"This usually means connections are not being returned to the pool. " - ) from pool_error - else: - # Other pool errors - retry with normal backoff - if attempt < max_retries - 1: - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) - continue - else: - raise RuntimeError( - f"Failed to get connection from pool: {pool_error}" - ) from pool_error - + self.search_by_fulltext( + query_words=["warmup"], + top_k=1, + user_name=user_name, + ) except Exception as e: - if conn is not None: - try: - self.connection_pool.putconn(conn, close=True) - except Exception as putconn_error: - logger.warning( - f"[_get_connection] Failed to return connection after error: {putconn_error}" - ) - with suppress(Exception): - conn.close() - - if attempt >= max_retries - 1: - raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e - else: - # Exponential backoff: 0.1s, 0.2s, 0.4s - """time.sleep(0.1 * (2**attempt))""" - time.sleep(0.003) + logger.debug(f"[warm_up] Warm-up query failed (non-fatal): {e}") + break + logger.info(f"[warm_up] Pre-warmed {warm_count} connections for search_by_fulltext") + + def warm_up_search_connections_by_full(self, user_name: str | None = None) -> None: + self._warm_up_search_connections_by_full(user_name) + + def _warm_up_connections_by_all(self): + logger.info("--_warm_up_connections_by_all--start-up") + warm_count = self.connection_pool.minconn + preheated = 0 + logger.info(f"[warm_up] Pre-warming {warm_count} connections...") + for _ in range(warm_count): + try: + with self._get_connection() as conn, conn.cursor() as cur: + cur.execute("SELECT 1") + preheated += 1 + except Exception as e: + logger.warning(f"[warm_up] Failed to pre-warm connection: {e}") continue + logger.info(f"[warm_up] Pre-warmed {preheated}/{warm_count} connections") - # Should never reach here, but just in case - raise RuntimeError("Failed to get connection after all retries") - - def _return_connection(self, connection): - if self._pool_closed: - if connection: - try: - connection.close() - logger.debug("[_return_connection] Closed connection (pool is closed)") - except Exception as e: - logger.warning( - f"[_return_connection] Failed to close connection after pool closed: {e}" - ) - return - - if not connection: - return + @contextmanager + def _get_connection(self): + timeout = self._connection_wait_timeout + if timeout <= 0: + self._semaphore.acquire() + else: + if not self._semaphore.acquire(timeout=timeout): + logger.warning(f"Timeout waiting for connection slot ({timeout}s)") + raise RuntimeError( + f"Connection pool busy: acquire a slot within {timeout}s (all connections in use)." + ) + logger.info( + "Connection pool usage: %s/%s", + self.connection_pool.maxconn - self._semaphore._value, + self.connection_pool.maxconn, + ) + conn = None + broken = False try: - if hasattr(connection, "closed") and connection.closed != 0: - logger.debug( - "[_return_connection] Connection is closed, closing it instead of returning to pool" - ) + conn = self.connection_pool.getconn() + logger.debug(f"Acquired connection {id(conn)} from pool") + conn.autocommit = True + with conn.cursor() as cur: + cur.execute(f'SET search_path = {self.db_name}_graph, ag_catalog, "$user", public;') + yield conn + except Exception as e: + broken = True + logger.exception(f"Connection failed or broken: {e}") + raise + finally: + if conn: try: - connection.close() + self.connection_pool.putconn(conn, close=broken) + logger.debug(f"Returned connection {id(conn)} to pool (broken={broken})") except Exception as e: - logger.warning(f"[_return_connection] Failed to close closed connection: {e}") - return - - self.connection_pool.putconn(connection) - logger.debug("[_return_connection] Successfully returned connection to pool") - except Exception as e: - logger.error( - f"[_return_connection] Failed to return connection to pool: {e}", exc_info=True - ) - try: - connection.close() - logger.debug( - "[_return_connection] Closed connection as fallback after putconn failure" - ) - except Exception as close_error: - logger.warning( - f"[_return_connection] Failed to close connection after putconn error: {close_error}" - ) - - def _return_connection_old(self, connection): - """Return a connection to the pool.""" - if not self._pool_closed and connection: - self.connection_pool.putconn(connection) + logger.warning(f"Failed to return connection to pool: {e}") + self._semaphore.release() def _ensure_database_exists(self): """Create database if it doesn't exist.""" @@ -396,11 +297,8 @@ def _ensure_database_exists(self): @timed def _create_graph(self): """Create PostgreSQL schema and table for graph storage.""" - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Create schema if it doesn't exist cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.db_name}_graph";') logger.info(f"Schema '{self.db_name}_graph' ensured.") @@ -448,8 +346,6 @@ def _create_graph(self): except Exception as e: logger.error(f"Failed to create graph schema: {e}") raise e - finally: - self._return_connection(conn) def create_index( self, @@ -462,11 +358,8 @@ def create_index( Create indexes for embedding and other fields. Note: This creates PostgreSQL indexes on the underlying tables. """ - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" @@ -486,8 +379,6 @@ def create_index( logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") - finally: - self._return_connection(conn) def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: """Get count of memory nodes by type.""" @@ -500,19 +391,14 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params = [self.format_param_value(memory_type), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return result[0] if result else 0 except Exception as e: logger.error(f"[get_memory_count] Failed: {e}") return -1 - finally: - self._return_connection(conn) @timed def node_not_exist(self, scope: str, user_name: str | None = None) -> int: @@ -527,19 +413,14 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: query += "\nLIMIT 1" params = [self.format_param_value(scope), self.format_param_value(user_name)] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() return 1 if result else 0 except Exception as e: logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def remove_oldest_memory( @@ -569,10 +450,8 @@ def remove_oldest_memory( self.format_param_value(user_name), keep_latest, ] - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Execute query to get IDs to delete cursor.execute(select_query, select_params) ids_to_delete = [row[0] for row in cursor.fetchall()] @@ -584,9 +463,9 @@ def remove_oldest_memory( # Build delete query placeholders = ",".join(["%s"] * len(ids_to_delete)) delete_query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" - WHERE id IN ({placeholders}) - """ + DELETE FROM "{self.db_name}_graph"."Memory" + WHERE id IN ({placeholders}) + """ delete_params = ids_to_delete # Execute deletion @@ -600,8 +479,6 @@ def remove_oldest_memory( except Exception as e: logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: @@ -663,17 +540,12 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[update_node] Failed to update node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_node(self, id: str, user_name: str | None = None) -> None: @@ -694,26 +566,18 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" params.append(self.format_param_value(user_name)) - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) except Exception as e: logger.error(f"[delete_node] Failed to delete node '{id}': {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def create_extension(self): extensions = [("polar_age", "Graph engine"), ("vector", "Vector engine")] - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Ensure in the correct database context cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] @@ -736,20 +600,15 @@ def create_extension(self): except Exception as e: logger.warning(f"Failed to access database context: {e}") logger.error(f"Failed to access database context: {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_graph(self): - # Get a connection from the pool - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph - WHERE name = '{self.db_name}_graph'; - """) + SELECT COUNT(*) FROM ag_catalog.ag_graph + WHERE name = '{self.db_name}_graph'; + """) graph_exists = cursor.fetchone()[0] > 0 if graph_exists: @@ -760,8 +619,6 @@ def create_graph(self): except Exception as e: logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}") logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def create_edge(self): @@ -770,11 +627,9 @@ def create_edge(self): valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"} for label_name in valid_rel_types: - conn = None logger.info(f"Creating elabel: {label_name}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');") logger.info(f"Successfully created elabel: {label_name}") except Exception as e: @@ -783,8 +638,6 @@ def create_edge(self): else: logger.warning(f"Failed to create label {label_name}: {e}") logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True) - finally: - self._return_connection(conn) @timed def add_edge( @@ -825,10 +678,8 @@ def add_edge( ); """ logger.info(f"polardb [add_edge] query: {query}, properties: {json.dumps(properties)}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, (source_id, target_id, type, json.dumps(properties))) logger.info(f"Edge created: {source_id} -[{type}]-> {target_id}") @@ -837,8 +688,6 @@ def add_edge( except Exception as e: logger.error(f"Failed to insert edge: {e}", exc_info=True) raise - finally: - self._return_connection(conn) @timed def delete_edge(self, source_id: str, target_id: str, type: str) -> None: @@ -853,14 +702,9 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: DELETE FROM "{self.db_name}_graph"."Edges" WHERE source_id = %s AND target_id = %s AND edge_type = %s """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, (source_id, target_id, type)) - logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, (source_id, target_id, type)) + logger.info(f"Edge deleted: {source_id} -[{type}]-> {target_id}") @timed def edge_exists_old( @@ -915,15 +759,10 @@ def edge_exists_old( WHERE {where_clause} LIMIT 1 """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - result = cursor.fetchone() - return result is not None - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + result = cursor.fetchone() + return result is not None @timed def edge_exists( @@ -971,15 +810,10 @@ def edge_exists( query += "\nRETURN r" query += "\n$$) AS (r agtype)" - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchone() - return result is not None and result[0] is not None - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchone() + return result is not None and result[0] is not None @timed def get_node( @@ -1015,10 +849,8 @@ def get_node( params.append(self.format_param_value(user_name)) logger.info(f"polardb [get_node] query: {query},params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) result = cursor.fetchone() @@ -1067,8 +899,6 @@ def get_node( except Exception as e: logger.error(f"[get_node] Failed to retrieve node '{id}': {e}", exc_info=True) return None - finally: - self._return_connection(conn) @timed def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, Any]]: @@ -1105,50 +935,45 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str, logger.info(f"get_nodes query:{query},params:{params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() - nodes = [] - for row in results: - node_id, properties_json, embedding_json = row - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse properties for node {node_id}") - properties = {} - else: - properties = properties_json if properties_json else {} + nodes = [] + for row in results: + node_id, properties_json, embedding_json = row + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse properties for node {node_id}") + properties = {} + else: + properties = properties_json if properties_json else {} - # Parse embedding from JSONB if it exists - if embedding_json is not None and kwargs.get("include_embedding"): - try: - # remove embedding - embedding = ( - json.loads(embedding_json) - if isinstance(embedding_json, str) - else embedding_json - ) - properties["embedding"] = embedding - except (json.JSONDecodeError, TypeError): - logger.warning(f"Failed to parse embedding for node {node_id}") - nodes.append( - self._parse_node( - { - "id": properties.get("id", node_id), - "memory": properties.get("memory", ""), - "metadata": properties, - } + # Parse embedding from JSONB if it exists + if embedding_json is not None and kwargs.get("include_embedding"): + try: + # remove embedding + embedding = ( + json.loads(embedding_json) + if isinstance(embedding_json, str) + else embedding_json ) + properties["embedding"] = embedding + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse embedding for node {node_id}") + nodes.append( + self._parse_node( + { + "id": properties.get("id", node_id), + "memory": properties.get("memory", ""), + "metadata": properties, + } ) - return nodes - finally: - self._return_connection(conn) + ) + return nodes @timed def get_edges_old( @@ -1366,10 +1191,8 @@ def get_children_with_embeddings( WHERE t.cid::graphid = m.id; """ - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1424,8 +1247,6 @@ def get_children_with_embeddings( except Exception as e: logger.error(f"[get_children_with_embeddings] Failed: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: """Get the path of nodes from source to target within a limited depth.""" @@ -1507,11 +1328,9 @@ def get_subgraph( RETURN collect(DISTINCT center), collect(DISTINCT neighbor), collect(DISTINCT r1) $$ ) as (centers agtype, neighbors agtype, rels agtype); """ - conn = None logger.info(f"[get_subgraph] Query: {query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -1636,8 +1455,6 @@ def get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) return {"core_node": None, "neighbors": [], "edges": []} - finally: - self._return_connection(conn) def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" @@ -1751,29 +1568,24 @@ def search_by_keywords_like( logger.info( f"[search_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - logger.info( - f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + logger.info( + f"[search_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_keywords_tfidf( @@ -1859,30 +1671,25 @@ def search_by_keywords_tfidf( logger.info( f"[search_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - item = {"id": id_val} - if return_fields: - properties = row[2] # properties column - item.update(self._extract_fields_from_properties(properties, return_fields)) - output.append(item) - - logger.info( - f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" - ) - return output - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + item = {"id": id_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + + logger.info( + f"[search_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) + return output @timed def search_by_fulltext( @@ -1901,32 +1708,19 @@ def search_by_fulltext( return_fields: list[str] | None = None, **kwargs, ) -> list[dict]: - """ - Full-text search functionality using PostgreSQL's full-text search capabilities. - - Args: - query_text: query text - top_k: maximum number of results to return - scope: memory type filter (memory_type) - status: status filter, defaults to "activated" - threshold: similarity threshold filter - search_filter: additional property filter conditions - user_name: username filter - knowledgebase_ids: knowledgebase ids filter - filter: filter conditions with 'and' or 'or' logic for search results. - tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 - tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) - return_fields: additional node fields to include in results - **kwargs: other parameters (e.g. cube_name) - - Returns: - list[dict]: result list containing id and score. - If return_fields is specified, each dict also includes the requested fields. - """ + start_time = time.perf_counter() logger.info( - f"[search_by_fulltext] query_words: {query_words},top_k:{top_k},scope:{scope},status:{status},threshold:{threshold},search_filter:{search_filter},user_name:{user_name},knowledgebase_ids:{knowledgebase_ids},filter:{filter}" + " search_by_fulltext query_words=%s top_k=%s scope=%s status=%s threshold=%s search_filter=%s user_name=%s knowledgebase_ids=%s filter=%s", + query_words, + top_k, + scope, + status, + threshold, + search_filter, + user_name, + knowledgebase_ids, + filter, ) - start_time = time.time() where_clauses = [] if scope: @@ -1942,22 +1736,18 @@ def search_by_fulltext( "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" ) - # Build user_name filter with knowledgebase_ids support (OR relationship) using common method user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql( user_name=user_name, knowledgebase_ids=knowledgebase_ids, default_user_name=self.config.user_name, ) - logger.info(f"[search_by_fulltext] user_name_conditions: {user_name_conditions}") - # Add OR condition if we have any user_name conditions if user_name_conditions: if len(user_name_conditions) == 1: where_clauses.append(user_name_conditions[0]) else: where_clauses.append(f"({' OR '.join(user_name_conditions)})") - # Add search_filter conditions if search_filter: for key, value in search_filter.items(): if isinstance(value, str): @@ -1970,17 +1760,12 @@ def search_by_fulltext( ) filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_fulltext] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) tsquery_string = " | ".join(query_words) where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - logger.info(f"[search_by_fulltext] where_clause: {where_clause}") - select_cols = f"""ag_catalog.agtype_access_operator(m.properties, '"id"'::agtype) AS old_id, ts_rank(m.{tsvector_field}, q.fq) AS rank""" if return_fields: @@ -1997,6 +1782,7 @@ def search_by_fulltext( ) where_clause_cte = f"WHERE {' AND '.join(where_with_q)}" if where_with_q else "" query = f""" + /*+ Set(max_parallel_workers_per_gather 0) */ WITH q AS (SELECT to_tsquery('{tsquery_config}', %s) AS fq) SELECT {select_cols} FROM "{self.db_name}_graph"."Memory" m @@ -2006,39 +1792,31 @@ def search_by_fulltext( LIMIT {top_k}; """ params = [tsquery_string] - logger.info(f"[search_by_fulltext] query: {query}, params: {params}") - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query, params) - results = cursor.fetchall() - output = [] - for row in results: - oldid = row[0] # old_id - rank = row[1] # rank score (no memory_text column) - - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(rank) - - # Apply threshold filter if specified - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[2] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_fulltext query completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] - finally: - self._return_connection(conn) + logger.info("search_by_fulltext query=%s params=%s", query, params) + + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[1] # rank score (no memory_text column) + + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[2] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed = (time.perf_counter() - start_time) * 1000 + logger.info("search_by_fulltext internal took %.1f ms", elapsed) + return output[:top_k] @timed def search_by_embedding( @@ -2056,9 +1834,18 @@ def search_by_embedding( **kwargs, ) -> list[dict]: logger.info( - f"search_by_embedding user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},scope:{scope},status:{status},search_filter:{search_filter},filter:{filter},knowledgebase_ids:{knowledgebase_ids},return_fields:{return_fields}" + "search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s", + user_name, + filter, + knowledgebase_ids, + scope, + status, + search_filter, + filter, + knowledgebase_ids, + return_fields, ) - start_time = time.time() + start_time = time.perf_counter() where_clauses = [] if scope: where_clauses.append( @@ -2097,7 +1884,6 @@ def search_by_embedding( ) filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[search_by_embedding] filter_conditions: {filter_conditions}") where_clauses.extend(filter_conditions) where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -2133,44 +1919,37 @@ def search_by_embedding( else: pass - logger.info(f"[search_by_embedding] query: {query}, params: {params}") + logger.info(" search_by_embedding query: %s", query) - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - if params: - cursor.execute(query, params) - else: - cursor.execute(query) - results = cursor.fetchall() - output = [] - for row in results: - if len(row) < 5: - logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") - continue - oldid = row[3] # old_id - score = row[4] # scope - id_val = str(oldid) - if id_val.startswith('"') and id_val.endswith('"'): - id_val = id_val[1:-1] - score_val = float(score) - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - item = {"id": id_val, "score": score_val} - if return_fields: - properties = row[1] # properties column - item.update( - self._extract_fields_from_properties(properties, return_fields) - ) - output.append(item) - elapsed_time = time.time() - start_time - logger.info( - f" polardb search_by_embedding query embedding completed time in {elapsed_time:.2f}s" - ) - return output[:top_k] - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + results = cursor.fetchall() + output = [] + for row in results: + if len(row) < 5: + logger.warning(f"Row has {len(row)} columns, expected 5. Row: {row}") + continue + oldid = row[3] # old_id + score = row[4] # scope + id_val = str(oldid) + if id_val.startswith('"') and id_val.endswith('"'): + id_val = id_val[1:-1] + score_val = float(score) + score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score + if threshold is None or score_val >= threshold: + item = {"id": id_val, "score": score_val} + if return_fields: + properties = row[1] # properties column + item.update(self._extract_fields_from_properties(properties, return_fields)) + output.append(item) + elapsed_time = time.perf_counter() - start_time + logger.info( + "search_by_embedding query embedding completed time took %.1f ms", elapsed_time + ) + return output[:top_k] @timed def get_by_metadata( @@ -2285,18 +2064,14 @@ def get_by_metadata( """ ids = [] - conn = None logger.info(f"[get_by_metadata] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() ids = [str(item[0]).strip('"') for item in results] except Exception as e: logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}") - finally: - self._return_connection(conn) return ids @@ -2448,10 +2223,8 @@ def get_grouped_counts( {where_clause} GROUP BY {", ".join(group_by_fields)} """ - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Handle parameterized query if params and isinstance(params, list): cursor.execute(query, params) @@ -2476,8 +2249,6 @@ def get_grouped_counts( except Exception as e: logger.error(f"Failed to get grouped counts: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def deduplicate_nodes(self) -> None: """Deduplicate redundant or semantically similar nodes.""" @@ -2509,14 +2280,9 @@ def clear(self, user_name: str | None = None) -> None: DETACH DELETE n $$) AS (result agtype) """ - conn = None - try: - conn = self._get_connection() - with conn.cursor() as cursor: - cursor.execute(query) - logger.info("Cleared all nodes from database.") - finally: - self._return_connection(conn) + with self._get_connection() as conn, conn.cursor() as cursor: + cursor.execute(query) + logger.info("Cleared all nodes from database.") except Exception as e: logger.error(f"[ERROR] Failed to clear database: {e}") @@ -2585,132 +2351,129 @@ def export_graph( else: offset = None - conn = None try: - conn = self._get_connection() - # Build WHERE conditions - where_conditions = [] - if user_name: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" - ) - if user_id: - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" - ) + with self._get_connection() as conn: + # Build WHERE conditions + where_conditions = [] + if user_name: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + if user_id: + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_id\"'::agtype) = '\"{user_id}\"'::agtype" + ) - # Add memory_type filter condition - if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: - # Escape memory_type values and build IN clause - memory_type_values = [] - for mt in memory_type: - # Escape single quotes in memory_type value - escaped_memory_type = str(mt).replace("'", "''") - memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") - memory_type_in_clause = ", ".join(memory_type_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" - ) + # Add memory_type filter condition + if memory_type and isinstance(memory_type, list) and len(memory_type) > 0: + # Escape memory_type values and build IN clause + memory_type_values = [] + for mt in memory_type: + # Escape single quotes in memory_type value + escaped_memory_type = str(mt).replace("'", "''") + memory_type_values.append(f"'\"{escaped_memory_type}\"'::agtype") + memory_type_in_clause = ", ".join(memory_type_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) IN ({memory_type_in_clause})" + ) - # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list - if status is None: - # Default behavior: exclude deleted entries - where_conditions.append( - "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" - ) - elif isinstance(status, list) and len(status) > 0: - # status IN (list) - status_values = [] - for st in status: - escaped_status = str(st).replace("'", "''") - status_values.append(f"'\"{escaped_status}\"'::agtype") - status_in_clause = ", ".join(status_values) - where_conditions.append( - f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" - ) + # Add status filter condition: if not passed, exclude deleted; otherwise filter by IN list + if status is None: + # Default behavior: exclude deleted entries + where_conditions.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) <> '\"deleted\"'::agtype" + ) + elif isinstance(status, list) and len(status) > 0: + # status IN (list) + status_values = [] + for st in status: + escaped_status = str(st).replace("'", "''") + status_values.append(f"'\"{escaped_status}\"'::agtype") + status_in_clause = ", ".join(status_values) + where_conditions.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) IN ({status_in_clause})" + ) - # Build filter conditions using common method - filter_conditions = self._build_filter_conditions_sql(filter) - logger.info(f"[export_graph] filter_conditions: {filter_conditions}") - if filter_conditions: - where_conditions.extend(filter_conditions) - - where_clause = "" - if where_conditions: - where_clause = f"WHERE {' AND '.join(where_conditions)}" - - # Get total count of nodes before pagination - count_node_query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" - {where_clause} - """ - logger.info(f"[export_graph nodes count] Query: {count_node_query}") - with conn.cursor() as cursor: - cursor.execute(count_node_query) - total_nodes = cursor.fetchone()[0] - - # Export nodes - # Build pagination clause if needed - pagination_clause = "" - if use_pagination: - pagination_clause = f"LIMIT {page_size} OFFSET {offset}" - - if include_embedding: - node_query = f""" - SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" - {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} - """ - else: - node_query = f""" - SELECT id, properties + # Build filter conditions using common method + filter_conditions = self._build_filter_conditions_sql(filter) + logger.info(f"[export_graph] filter_conditions: {filter_conditions}") + if filter_conditions: + where_conditions.extend(filter_conditions) + + where_clause = "" + if where_conditions: + where_clause = f"WHERE {' AND '.join(where_conditions)}" + + # Get total count of nodes before pagination + count_node_query = f""" + SELECT COUNT(*) FROM "{self.db_name}_graph"."Memory" {where_clause} - ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, - id DESC - {pagination_clause} """ - logger.info(f"[export_graph nodes] Query: {node_query}") - with conn.cursor() as cursor: - cursor.execute(node_query) - node_results = cursor.fetchall() - nodes = [] - - for row in node_results: - if include_embedding: - """row is (id, properties, embedding)""" - _, properties_json, embedding_json = row - else: - """row is (id, properties)""" - _, properties_json = row - embedding_json = None + logger.info(f"[export_graph nodes count] Query: {count_node_query}") + with conn.cursor() as cursor: + cursor.execute(count_node_query) + total_nodes = cursor.fetchone()[0] + + # Export nodes + # Build pagination clause if needed + pagination_clause = "" + if use_pagination: + pagination_clause = f"LIMIT {page_size} OFFSET {offset}" + + if include_embedding: + node_query = f""" + SELECT id, properties, embedding + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + else: + node_query = f""" + SELECT id, properties + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST, + id DESC + {pagination_clause} + """ + logger.info(f"[export_graph nodes] Query: {node_query}") + with conn.cursor() as cursor: + cursor.execute(node_query) + node_results = cursor.fetchall() + nodes = [] + + for row in node_results: + if include_embedding: + """row is (id, properties, embedding)""" + _, properties_json, embedding_json = row + else: + """row is (id, properties)""" + _, properties_json = row + embedding_json = None - # Parse properties from JSONB if it's a string - if isinstance(properties_json, str): - try: - properties = json.loads(properties_json) - except json.JSONDecodeError: - properties = {} - else: - properties = properties_json if properties_json else {} + # Parse properties from JSONB if it's a string + if isinstance(properties_json, str): + try: + properties = json.loads(properties_json) + except json.JSONDecodeError: + properties = {} + else: + properties = properties_json if properties_json else {} - # Remove embedding field if include_embedding is False - if not include_embedding: - properties.pop("embedding", None) - elif include_embedding and embedding_json is not None: - properties["embedding"] = embedding_json + # Remove embedding field if include_embedding is False + if not include_embedding: + properties.pop("embedding", None) + elif include_embedding and embedding_json is not None: + properties["embedding"] = embedding_json - nodes.append(self._parse_node(properties)) + nodes.append(self._parse_node(properties)) except Exception as e: logger.error(f"[EXPORT GRAPH - NODES] Exception: {e}", exc_info=True) raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - finally: - self._return_connection(conn) edges = [] return { @@ -2732,13 +2495,9 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: RETURN count(n) $$) AS (count agtype) """ - conn = None - try: - conn = self._get_connection() + with self._get_connection() as conn: result = self.execute_query(query, conn) return int(result.one_or_none()["count"].value) - finally: - self._return_connection(conn) @timed def get_all_memory_items( @@ -2825,18 +2584,16 @@ def get_all_memory_items( """ nodes = [] node_ids = set() - conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() for row in results: """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ if isinstance(row, list | tuple) and len(row) >= 2: embedding_val, node_val = row[0], row[1] else: @@ -2851,8 +2608,6 @@ def get_all_memory_items( except Exception as e: logger.warning(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) return nodes else: @@ -2879,29 +2634,25 @@ def get_all_memory_items( """ nodes = [] - conn = None logger.info(f"[get_all_memory_items] cypher_query: {cypher_query}") try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() for row in results: """ - if isinstance(row[0], str): - memory_data = json.loads(row[0]) - else: - memory_data = row[0] # 如果已经是字典,直接使用 - nodes.append(self._parse_node(memory_data)) - """ + if isinstance(row[0], str): + memory_data = json.loads(row[0]) + else: + memory_data = row[0] # 如果已经是字典,直接使用 + nodes.append(self._parse_node(memory_data)) + """ memory_data = json.loads(row[0]) if isinstance(row[0], str) else row[0] nodes.append(self._parse_node(memory_data)) except Exception as e: logger.error(f"Failed to get memories: {e}", exc_info=True) - finally: - self._return_connection(conn) return nodes @@ -3104,10 +2855,8 @@ def get_structure_optimization_candidates( candidates = [] node_ids = set() - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(cypher_query) results = cursor.fetchall() logger.info(f"Found {len(results)} structure optimization candidates") @@ -3115,8 +2864,8 @@ def get_structure_optimization_candidates( if include_embedding: # When include_embedding=True, return full node object """ - if isinstance(row, (list, tuple)) and len(row) >= 2: - """ + if isinstance(row, (list, tuple)) and len(row) >= 2: + """ if isinstance(row, list | tuple) and len(row) >= 2: embedding_val, node_val = row[0], row[1] else: @@ -3184,8 +2933,6 @@ def get_structure_optimization_candidates( except Exception as e: logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True) - finally: - self._return_connection(conn) return candidates @@ -3355,60 +3102,59 @@ def add_node( elif len(embedding_vector) == 768: embedding_column = "embedding_768" - conn = None insert_query = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Delete existing record first (if any) - delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" - WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(delete_query, (id,)) - # - get_graph_id_query = f""" - SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) - """ - cursor.execute(get_graph_id_query, (id,)) - graph_id = cursor.fetchone()[0] - properties["graph_id"] = str(graph_id) - - # Then insert new record - if embedding_vector: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s, - %s - ) + with self._get_connection() as conn: + with conn.cursor() as cursor: + # Delete existing record first (if any) + delete_query = f""" + DELETE FROM {self.db_name}_graph."Memory" + WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ - cursor.execute( - insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) - ) - logger.info( - f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" - ) - else: - insert_query = f""" - INSERT INTO {self.db_name}_graph."Memory"(id, properties) - VALUES ( - ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), - %s + cursor.execute(delete_query, (id,)) + # + get_graph_id_query = f""" + SELECT ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) + """ + cursor.execute(get_graph_id_query, (id,)) + graph_id = cursor.fetchone()[0] + properties["graph_id"] = str(graph_id) + + # Then insert new record + if embedding_vector: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s, + %s + ) + """ + cursor.execute( + insert_query, (id, json.dumps(properties), json.dumps(embedding_vector)) ) - """ - cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] [embedding_vector-true] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + else: + insert_query = f""" + INSERT INTO {self.db_name}_graph."Memory"(id, properties) + VALUES ( + ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring), + %s + ) + """ + cursor.execute(insert_query, (id, json.dumps(properties))) + logger.info( + f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + ) + if insert_query: logger.info( - f"[add_node] [embedding_vector-false] insert_query: {insert_query}, properties: {json.dumps(properties)}" + f"In add node polardb: id-{id} memory-{memory} query-{insert_query}" ) except Exception as e: logger.error(f"[add_node] Failed to add node: {e}", exc_info=True) raise - finally: - if insert_query: - logger.info(f"In add node polardb: id-{id} memory-{memory} query-{insert_query}") - self._return_connection(conn) @timed def add_nodes_batch( @@ -3416,27 +3162,15 @@ def add_nodes_batch( nodes: list[dict[str, Any]], user_name: str | None = None, ) -> None: - """ - Batch add multiple memory nodes to the graph. + logger.info(f" add_nodes_batch Processing only first node (total nodes: {len(nodes)})") - Args: - nodes: List of node dictionaries, each containing: - - id: str - Node ID - - memory: str - Memory content - - metadata: dict[str, Any] - Node metadata - user_name: Optional user name (will use config default if not provided) - """ - batch_start_time = time.time() + batch_start_time = time.perf_counter() if not nodes: logger.warning("[add_nodes_batch] Empty nodes list, skipping") return - logger.info(f"[add_nodes_batch] Processing only first node (total nodes: {len(nodes)})") - - # user_name comes from parameter; fallback to config if missing effective_user_name = user_name if user_name else self.config.user_name - # Prepare all nodes prepared_nodes = [] for node_data in nodes: try: @@ -3446,16 +3180,13 @@ def add_nodes_batch( logger.debug(f"[add_nodes_batch] Processing node id: {id}") - # Set user_name in metadata metadata["user_name"] = effective_user_name metadata = _prepare_node_metadata(metadata) - # Merge node and set metadata created_at = metadata.pop("created_at", datetime.utcnow().isoformat()) updated_at = metadata.pop("updated_at", datetime.utcnow().isoformat()) - # Prepare properties properties = { "id": id, "memory": memory, @@ -3466,32 +3197,26 @@ def add_nodes_batch( **metadata, } - # Generate embedding if not provided if "embedding" not in properties or not properties["embedding"]: properties["embedding"] = generate_vector( self._get_config_value("embedding_dimension", 1024) ) - # Serialization - JSON-serialize sources and usage fields for field_name in ["sources", "usage"]: if properties.get(field_name): if isinstance(properties[field_name], list): for idx in range(len(properties[field_name])): - # Serialize only when element is not a string if not isinstance(properties[field_name][idx], str): properties[field_name][idx] = json.dumps( properties[field_name][idx] ) elif isinstance(properties[field_name], str): - # If already a string, leave as-is pass - # Extract embedding for separate column embedding_vector = properties.pop("embedding", []) if not isinstance(embedding_vector, list): embedding_vector = [] - # Select column name based on embedding dimension embedding_column = "embedding" # default column if len(embedding_vector) == 3072: embedding_column = "embedding_3072" @@ -3514,14 +3239,12 @@ def add_nodes_batch( f"[add_nodes_batch] Failed to prepare node {node_data.get('id', 'unknown')}: {e}", exc_info=True, ) - # Continue with other nodes continue if not prepared_nodes: logger.warning("[add_nodes_batch] No valid nodes to insert after preparation") return - # Group nodes by embedding column to optimize batch inserts nodes_by_embedding_column = {} for node in prepared_nodes: col = node["embedding_column"] @@ -3529,13 +3252,9 @@ def add_nodes_batch( nodes_by_embedding_column[col] = [] nodes_by_embedding_column[col].append(node) - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: - # Process each group separately + with self._get_connection() as conn, conn.cursor() as cursor: for embedding_column, nodes_group in nodes_by_embedding_column.items(): - # Batch delete existing records using IN clause ids_to_delete = [node["id"] for node in nodes_group] if ids_to_delete: delete_query = f""" @@ -3546,7 +3265,6 @@ def add_nodes_batch( """ cursor.execute(delete_query, (ids_to_delete,)) - # Batch get graph_ids for all nodes get_graph_ids_query = f""" SELECT id_val, @@ -3556,21 +3274,16 @@ def add_nodes_batch( cursor.execute(get_graph_ids_query, (ids_to_delete,)) graph_id_map = {row[0]: row[1] for row in cursor.fetchall()} - # Add graph_id to properties for node in nodes_group: graph_id = graph_id_map.get(node["id"]) if graph_id: node["properties"]["graph_id"] = str(graph_id) - # Use PREPARE/EXECUTE for efficient batch insert - # Generate unique prepare statement name to avoid conflicts prepare_name = f"insert_mem_{embedding_column or 'no_embedding'}_{int(time.time() * 1000000)}" - try: if embedding_column and any( node["embedding_vector"] for node in nodes_group ): - # PREPARE statement for insert with embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties, {embedding_column}) @@ -3580,16 +3293,9 @@ def add_nodes_batch( $3::vector ) """ - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) embedding_json = ( @@ -3603,7 +3309,6 @@ def add_nodes_batch( (node["id"], properties_json, embedding_json), ) else: - # PREPARE statement for insert without embedding prepare_query = f""" PREPARE {prepare_name} AS INSERT INTO {self.db_name}_graph."Memory"(id, properties) @@ -3612,46 +3317,30 @@ def add_nodes_batch( $2::text::agtype ) """ - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_name: {prepare_name}" - ) - logger.info( - f"[add_nodes_batch] without embedding Preparing prepare_query: {prepare_query}" - ) cursor.execute(prepare_query) - # Execute prepared statement for each node for node in nodes_group: properties_json = json.dumps(node["properties"]) - cursor.execute( - f"EXECUTE {prepare_name}(%s, %s)", (node["id"], properties_json) + f"EXECUTE {prepare_name}(%s, %s)", + (node["id"], properties_json), ) finally: - # DEALLOCATE prepared statement (always execute, even on error) try: cursor.execute(f"DEALLOCATE {prepare_name}") - logger.info( - f"[add_nodes_batch] Deallocated prepared statement: {prepare_name}" - ) except Exception as dealloc_error: logger.warning( f"[add_nodes_batch] Failed to deallocate {prepare_name}: {dealloc_error}" ) - - logger.info( - f"[add_nodes_batch] Inserted {len(nodes_group)} nodes with embedding_column={embedding_column}" - ) - elapsed_time = time.time() - batch_start_time + elapsed_time = time.perf_counter() - batch_start_time logger.info( - f"[add_nodes_batch] PREPARE/EXECUTE batch insert completed successfully in {elapsed_time:.2f}s" + "add_nodes_batch batch insert completed successfully in took %.1f ms", + elapsed_time, ) except Exception as e: logger.error(f"[add_nodes_batch] Failed to add nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) def _build_node_from_agtype(self, node_agtype, embedding=None): """ @@ -3763,10 +3452,8 @@ def get_neighbors_by_tag( logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -3815,8 +3502,6 @@ def get_neighbors_by_tag( except Exception as e: logger.error(f"Failed to get neighbors by tag: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def get_neighbors_by_tag_ccl( self, @@ -4075,10 +3760,8 @@ def get_edges( $$) AS (from_id agtype, to_id agtype, edge_type agtype) """ logger.info(f"get_edges query:{query}") - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -4125,8 +3808,6 @@ def get_edges( except Exception as e: logger.error(f"Failed to get edges: {e}", exc_info=True) return [] - finally: - self._return_connection(conn) def _convert_graph_edges(self, core_node: dict) -> dict: import copy @@ -5132,11 +4813,9 @@ def delete_node_by_prams( ) return 0 - conn = None total_deleted_count = 0 try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: # Build WHERE conditions list where_conditions = [] @@ -5197,9 +4876,6 @@ def delete_node_by_prams( except Exception as e: logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True) raise - finally: - self._return_connection(conn) - logger.info(f"[delete_node_by_prams] Successfully deleted {total_deleted_count} nodes") return total_deleted_count @@ -5263,11 +4939,9 @@ def escape_memory_id(mid: str) -> str: """ logger.info(f"[get_user_names_by_memory_ids] query: {query}") - conn = None result_dict = {} try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() @@ -5307,8 +4981,6 @@ def escape_memory_id(mid: str) -> str: f"[get_user_names_by_memory_ids] Failed to get user names: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) def exist_user_name(self, user_name: str) -> dict[str, bool]: """Check if user name exists in the graph. @@ -5342,10 +5014,8 @@ def escape_user_name(un: str) -> str: """ logger.info(f"[exist_user_name] query: {query}") result_dict = {} - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: cursor.execute(query) count = cursor.fetchone()[0] result = count > 0 @@ -5356,8 +5026,6 @@ def escape_user_name(un: str) -> str: f"[exist_user_name] Failed to check user_name existence: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) @timed def delete_node_by_mem_cube_id( @@ -5381,10 +5049,8 @@ def delete_node_by_mem_cube_id( ) return 0 - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" user_name_param = self.format_param_value(mem_cube_id) @@ -5434,7 +5100,11 @@ def delete_node_by_mem_cube_id( logger.info( f"delete_node_by_mem_cube_id Soft delete update_query:{update_query},update_properties:{update_properties},deletetime:{current_time}" ) - update_params = [json.dumps(update_properties), current_time, user_name_param] + update_params = [ + json.dumps(update_properties), + current_time, + user_name_param, + ] cursor.execute(update_query, update_params) updated_count = cursor.rowcount @@ -5448,8 +5118,6 @@ def delete_node_by_mem_cube_id( f"[delete_node_by_mem_cube_id] Failed to delete/update nodes: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) @timed def recover_memory_by_mem_cube_id( @@ -5476,10 +5144,8 @@ def recover_memory_by_mem_cube_id( f"delete_record_id={delete_record_id}" ) - conn = None try: - conn = self._get_connection() - with conn.cursor() as cursor: + with self._get_connection() as conn, conn.cursor() as cursor: user_name_condition = "ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" delete_record_id_condition = "ag_catalog.agtype_access_operator(properties, '\"delete_record_id\"'::agtype) = %s::agtype" where_clause = f"{user_name_condition} AND {delete_record_id_condition}" @@ -5523,5 +5189,3 @@ def recover_memory_by_mem_cube_id( f"[recover_memory_by_mem_cube_id] Failed to recover nodes: {e}", exc_info=True ) raise - finally: - self._return_connection(conn) diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 3afa78bab..b9395ea0d 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -20,7 +20,6 @@ class NaiveMemCube(BaseMemCube): def __init__( self, text_mem: BaseTextMemory | None = None, - pref_mem: BaseTextMemory | None = None, act_mem: BaseActMemory | None = None, para_mem: BaseParaMemory | None = None, ): @@ -28,19 +27,20 @@ def __init__( self._text_mem: BaseTextMemory = text_mem self._act_mem: BaseActMemory | None = act_mem self._para_mem: BaseParaMemory | None = para_mem - self._pref_mem: BaseTextMemory | None = pref_mem + # pref_mem removed - now handled by text_mem def load( self, dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] + Options: ["text_mem", "act_mem", "para_mem"] + Note: pref_mem is now integrated into text_mem """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: @@ -51,7 +51,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] + memory_types = ["text_mem", "act_mem", "para_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -66,23 +66,20 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") - if "pref_mem" in memory_types and self.pref_mem: - self.pref_mem.load(dir) - logger.info(f"Loaded pref_mem from {dir}") - logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( self, dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. - Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] + Options: ["text_mem", "act_mem", "para_mem"] + Note: pref_mem is now integrated into text_mem """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -94,7 +91,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] + memory_types = ["text_mem", "act_mem", "para_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -109,10 +106,6 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") - if "pref_mem" in memory_types and self.pref_mem: - self.pref_mem.dump(dir) - logger.info(f"Dumped pref_mem to {dir}") - logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @property @@ -157,16 +150,4 @@ def para_mem(self, value: BaseParaMemory) -> None: raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value - @property - def pref_mem(self) -> "BaseTextMemory | None": - """Get the preference memory.""" - if self._pref_mem is None: - logger.warning("Preference memory is not initialized. Returning None.") - return self._pref_mem - - @pref_mem.setter - def pref_mem(self, value: BaseTextMemory) -> None: - """Set the preference memory.""" - if not isinstance(value, BaseTextMemory): - raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") - self._pref_mem = value + # pref_mem property removed - preferences now handled by text_mem diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 6c6d1821f..b8019004d 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -2,7 +2,6 @@ import difflib import json import re -import uuid from datetime import datetime from typing import TYPE_CHECKING, Any, Literal @@ -36,7 +35,6 @@ if TYPE_CHECKING: - from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_feedback_prompts import ( FEEDBACK_ANSWER_PROMPT, @@ -95,7 +93,6 @@ def __init__(self, config: MemFeedbackConfig): self.stopword_manager = StopwordManager self.searcher: Searcher = None self.reranker = None - self.pref_mem: SimplePreferenceTextMemory = None self.pref_feedback: bool = False self.DB_IDX_READY = False @@ -239,6 +236,9 @@ def _single_add_operation( else: to_add_memory = new_memory_item.model_copy(deep=True) + if to_add_memory.metadata.memory_type == "PreferenceMemory": + to_add_memory.metadata.preference = new_memory_item.memory + to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( datetime.now().isoformat() ) @@ -274,13 +274,6 @@ def _single_update_operation( """ Individual update operations """ - if "preference" in old_memory_item.metadata.__dict__: - logger.info( - f"[0107 Feedback Core: _single_update_operation] pref_memory: {old_memory_item.id}" - ) - return self._single_update_pref( - old_memory_item, new_memory_item, user_id, user_name, operation - ) memory_type = old_memory_item.metadata.memory_type source_doc_id = ( @@ -329,68 +322,6 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } - def _single_update_pref( - self, - old_memory_item: TextualMemoryItem, - new_memory_item: TextualMemoryItem, - user_id: str, - user_name: str, - operation: dict, - ): - """update preference memory""" - - feedback_context = new_memory_item.memory - if operation and "text" in operation and operation["text"]: - new_memory_item.memory = operation["text"] - new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] - - to_add_memory = old_memory_item.model_copy(deep=True) - to_add_memory.metadata.key = new_memory_item.metadata.key - to_add_memory.metadata.tags = new_memory_item.metadata.tags - to_add_memory.memory = new_memory_item.memory - to_add_memory.metadata.preference = new_memory_item.memory - to_add_memory.metadata.embedding = new_memory_item.metadata.embedding - - to_add_memory.metadata.user_id = new_memory_item.metadata.user_id - to_add_memory.metadata.original_text = old_memory_item.memory - to_add_memory.metadata.covered_history = old_memory_item.id - - to_add_memory.metadata.created_at = to_add_memory.metadata.updated_at = ( - datetime.now().isoformat() - ) - to_add_memory.metadata.context_summary = ( - old_memory_item.metadata.context_summary + " \n" + feedback_context - ) - - # add new memory - to_add_memory.id = str(uuid.uuid4()) - added_ids = self._retry_db_operation(lambda: self.pref_mem.add([to_add_memory])) - # delete - deleted_id = old_memory_item.id - collection_name = old_memory_item.metadata.preference_type - self._retry_db_operation( - lambda: self.pref_mem.delete_with_collection_name(collection_name, [deleted_id]) - ) - # add archived - old_memory_item.metadata.status = "archived" - old_memory_item.metadata.original_text = "archived" - old_memory_item.metadata.embedding = [0.0] * 1024 - - archived_ids = self._retry_db_operation(lambda: self.pref_mem.add([old_memory_item])) - - logger.info( - f"[Memory Feedback UPDATE Pref] New Add:{added_ids!s} | Set archived:{archived_ids!s}" - ) - - return { - "id": to_add_memory.id, - "text": new_memory_item.memory, - "source_doc_id": "", - "archived_id": old_memory_item.id, - "origin_memory": old_memory_item.memory, - "type": "preference", - } - def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -460,7 +391,7 @@ def semantics_feedback( for chunk in memory_chunks: chunk_list = [] for item in chunk: - if "preference" in item.metadata.__dict__: + if item.metadata.memory_type == "PreferenceMemory": chunk_list.append(f"{item.id}: {item.metadata.preference}") else: chunk_list.append(f"{item.id}: {item.memory}") @@ -628,16 +559,30 @@ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, boo edges = self.searcher.graph_store.get_edges(mem_item.id, user_name=user_name) return (mem_item, len(edges) == 0) + logger.info(f"[feedback _retrieve] query: {query}, user_name: {user_name}") text_mems = self.searcher.search( - query, + query=query, + top_k=top_k, info=info, memory_type="AllSummaryMemory", user_name=user_name, - top_k=top_k, full_recall=True, ) text_mems = [item[0] for item in text_mems if float(item[1]) > 0.01] + if self.pref_feedback: + pref_mems = self.searcher.search( + query=query, + top_k=top_k, + info=info, + memory_type="PreferenceMemory", + user_name=user_name, + include_preference_memory=True, + full_recall=True, + ) + pref_mems = [item[0] for item in pref_mems if float(item[1]) > 0.01] + text_mems.extend(pref_mems) + # Memory with edges is not modified by feedback retrieved_mems = [] with ContextThreadPoolExecutor(max_workers=10) as executor: @@ -656,14 +601,7 @@ def check_has_edges(mem_item: TextualMemoryItem) -> tuple[TextualMemoryItem, boo f"text memories are not modified by feedback due to edges." ) - if self.pref_feedback: - pref_info = {} - if "user_id" in info: - pref_info = {"user_id": info["user_id"]} - retrieved_prefs = self.pref_mem.search(query, top_k, pref_info) - return retrieved_mems + retrieved_prefs - else: - return retrieved_mems + return retrieved_mems def _vec_query(self, new_memories_embedding: list[float], user_name=None): """Vector retrieval query""" diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 2ac0a0a39..dfc9b9fdf 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -4,7 +4,6 @@ from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.mem_feedback.feedback import MemFeedback from memos.mem_reader.simple_struct import SimpleStructMemReader -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher @@ -24,7 +23,6 @@ def __init__( mem_reader: SimpleStructMemReader, searcher: Searcher, reranker: BaseReranker, - pref_mem: SimplePreferenceTextMemory, pref_feedback: bool = False, ): self.llm = llm @@ -34,7 +32,6 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager - self.pref_mem = pref_mem self.reranker = reranker self.DB_IDX_READY = False self.pref_feedback = pref_feedback diff --git a/src/memos/mem_os/utils/default_config.py b/src/memos/mem_os/utils/default_config.py index edb7875d4..de79d535d 100644 --- a/src/memos/mem_os/utils/default_config.py +++ b/src/memos/mem_os/utils/default_config.py @@ -3,6 +3,8 @@ Provides simplified configuration generation for users. """ +import logging + from typing import Literal from memos.configs.mem_cube import GeneralMemCubeConfig @@ -10,6 +12,9 @@ from memos.mem_cube.general import GeneralMemCube +logger = logging.getLogger(__name__) + + def get_default_config( openai_api_key: str, openai_api_base: str = "https://api.openai.com/v1", @@ -116,20 +121,9 @@ def get_default_config( }, } - # Add activation memory if enabled - if config_dict.get("enable_activation_memory", False): - config_dict["act_mem"] = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, - }, - }, - } + # Note: act_mem configuration belongs in MemCube config (get_default_cube_config), + # not in MOSConfig which doesn't have an act_mem field (extra="forbid"). + # The enable_activation_memory flag above is sufficient for MOSConfig. return MOSConfig(**config_dict) @@ -237,21 +231,33 @@ def get_default_cube_config( }, } - # Configure activation memory if enabled + # Configure activation memory if enabled. + # KV cache activation memory requires a local HuggingFace/vLLM model (it + # extracts internal attention KV tensors via build_kv_cache), so it cannot + # work with remote API backends like OpenAI. + # Only create act_mem when activation_memory_backend is explicitly provided. act_mem_config = {} if kwargs.get("enable_activation_memory", False): - act_mem_config = { - "backend": "kv_cache", - "config": { - "memory_filename": kwargs.get( - "activation_memory_filename", "activation_memory.pickle" - ), - "extractor_llm": { - "backend": "openai", - "config": openai_config, + extractor_backend = kwargs.get("activation_memory_backend") + if extractor_backend in ("huggingface", "huggingface_singleton", "vllm"): + act_mem_config = { + "backend": "kv_cache", + "config": { + "memory_filename": kwargs.get( + "activation_memory_filename", "activation_memory.pickle" + ), + "extractor_llm": { + "backend": extractor_backend, + "config": kwargs.get("activation_memory_llm_config", {}), + }, }, - }, - } + } + else: + logger.info( + "Activation memory (kv_cache) requires a local model backend " + "(huggingface/vllm) via activation_memory_backend kwarg. " + "Skipping act_mem in MemCube config." + ) # Create MemCube configuration cube_config_dict = { diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index e3d2bece9..0b3e19208 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -10,6 +10,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang from memos.mem_reader.read_multi_modal.base import _derive_key +from memos.mem_reader.read_pref_memory.process_preference_memory import process_preference_fine from memos.mem_reader.read_skill_memory.process_skill_memory import process_skill_memory_fine from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result @@ -189,8 +190,16 @@ def _concat_multi_modal_memories( else: processed_items.append(item) - # If only one item after processing, return as-is + # If only one item after processing, compute embedding and return if len(processed_items) == 1: + single_item = processed_items[0] + if single_item and single_item.memory: + try: + single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0] + except Exception as e: + logger.error( + f"[MultiModalStruct] Error computing embedding for single item: {e}" + ) return processed_items windows = [] @@ -288,7 +297,6 @@ def _build_window_from_items( # Collect all memory texts and sources memory_texts = [] all_sources = [] - seen_content = set() # Track seen source content to avoid duplicates roles = set() aggregated_file_ids: list[str] = [] @@ -302,18 +310,8 @@ def _build_window_from_items( item_sources = [item_sources] for source in item_sources: - # Get content from source for deduplication - source_content = None - if isinstance(source, dict): - source_content = source.get("content", "") - else: - source_content = getattr(source, "content", "") or "" - - # Only add if content is different (empty content is considered unique) - content_key = source_content if source_content else None - if content_key and content_key not in seen_content: - seen_content.add(content_key) - all_sources.append(source) + # Add source to all_sources + all_sources.append(source) # Extract role from source if hasattr(source, "role") and source.role: @@ -993,7 +991,7 @@ def _process_multi_modal_data( # Part A: call llm in parallel using thread pool fine_memory_items = [] - with ContextThreadPoolExecutor(max_workers=3) as executor: + with ContextThreadPoolExecutor(max_workers=4) as executor: future_string = executor.submit( self._process_string_fine, fast_memory_items, info, custom_tags, **kwargs ) @@ -1012,15 +1010,25 @@ def _process_multi_modal_data( skills_dir_config=self.skills_dir_config, **kwargs, ) + future_pref = executor.submit( + process_preference_fine, + fast_memory_items, + info, + self.llm, + self.embedder, + **kwargs, + ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() fine_memory_items_skill_memory_parser = future_skill.result() + fine_memory_items_pref_parser = future_pref.result() fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) fine_memory_items.extend(fine_memory_items_skill_memory_parser) + fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: get fine multimodal items for fast_item in fast_memory_items: @@ -1060,7 +1068,7 @@ def _process_transfer_multi_modal_data( fine_memory_items = [] # Part A: call llm in parallel using thread pool - with ContextThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=4) as executor: future_string = executor.submit( self._process_string_fine, raw_nodes, info, custom_tags, **kwargs ) @@ -1079,14 +1087,21 @@ def _process_transfer_multi_modal_data( skills_dir_config=self.skills_dir_config, **kwargs, ) + # Add preference memory extraction + future_pref = executor.submit( + process_preference_fine, raw_nodes, info, self.llm, self.embedder, **kwargs + ) # Collect results fine_memory_items_string_parser = future_string.result() fine_memory_items_tool_trajectory_parser = future_tool.result() fine_memory_items_skill_memory_parser = future_skill.result() + fine_memory_items_pref_parser = future_pref.result() + fine_memory_items.extend(fine_memory_items_string_parser) fine_memory_items.extend(fine_memory_items_tool_trajectory_parser) fine_memory_items.extend(fine_memory_items_skill_memory_parser) + fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: get fine multimodal items for raw_node in raw_nodes: diff --git a/src/memos/mem_reader/read_multi_modal/file_content_parser.py b/src/memos/mem_reader/read_multi_modal/file_content_parser.py index 1b4add398..00e02abda 100644 --- a/src/memos/mem_reader/read_multi_modal/file_content_parser.py +++ b/src/memos/mem_reader/read_multi_modal/file_content_parser.py @@ -51,8 +51,11 @@ class FileContentParser(BaseMessageParser): """Parser for file content parts.""" def _get_doc_llm_response( - self, chunk_text: str, custom_tags: list[str] | None = None - ) -> dict | list: + self, + chunk_text: str, + custom_tags: list[str] | None = None, + message_text_context: str | None = None, + ) -> dict: """ Call LLM to extract memory from document chunk. Uses doc prompts from DOC_PROMPT_DICT. @@ -60,6 +63,8 @@ def _get_doc_llm_response( Args: chunk_text: Text chunk to extract memory from custom_tags: Optional list of custom tags for LLM extraction + message_text_context: Optional text from the same message that + provides user intent / context for understanding this document Returns: Parsed JSON response from LLM (dict or list) or empty dict if failed @@ -79,6 +84,10 @@ def _get_doc_llm_response( ) prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt) + # Inject sibling text context into prompt placeholder + context_text = message_text_context.strip() if message_text_context else "" + prompt = prompt.replace("{context}", context_text) + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) @@ -109,14 +118,25 @@ def _handle_url(self, url_str: str, filename: str) -> tuple[str, str | None, boo return response.text, None, True file_ext = os.path.splitext(filename)[1].lower() - if file_ext in [".md", ".markdown", ".txt"]: + if file_ext in [".md", ".markdown", ".txt"] or self._is_oss_md(url_str): return response.text, None, True with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_ext) as temp_file: temp_file.write(response.content) return "", temp_file.name, False except Exception as e: logger.error(f"[FileContentParser] URL processing error: {e}") - return f"[File URL download failed: {url_str}]", None + return f"[File URL download failed: {url_str}]", None, False + + def _is_oss_md(self, url: str) -> bool: + """Check if URL is an OSS markdown file based on pattern.""" + loose_pattern = re.compile(r"^https?://[^/]*\.aliyuncs\.com/.*/([^/?#]+)") + match = loose_pattern.search(url) + if not match: + return False + + file_name = match.group(1) + lower_name = file_name.lower() + return lower_name.endswith((".md", ".markdown", ".txt")) def _is_base64(self, data: str) -> bool: """Quick heuristic to check base64-like string.""" @@ -139,7 +159,12 @@ def _handle_local(self, data: str) -> str: return "" def _process_single_image( - self, image_url: str, original_ref: str, info: dict[str, Any], **kwargs + self, + image_url: str, + original_ref: str, + info: dict[str, Any], + header_context: list[str] | None = None, + **kwargs, ) -> tuple[str, str]: """ Process a single image and return (original_ref, replacement_text). @@ -148,6 +173,7 @@ def _process_single_image( image_url: URL of the image to process original_ref: Original markdown image reference to replace info: Dictionary containing user_id and session_id + header_context: Optional list of header titles providing context for the image **kwargs: Additional parameters for ImageParser Returns: @@ -173,20 +199,31 @@ def _process_single_image( if hasattr(item, "memory") and item.memory: extracted_texts.append(str(item.memory)) + # Prepare header context string if available + header_context_str = "" + if header_context: + # Join headers with " > " to show hierarchy + header_hierarchy = " > ".join(header_context) + header_context_str = f"[Section: {header_hierarchy}]\n\n" + if extracted_texts: # Combine all extracted texts extracted_content = "\n".join(extracted_texts) + # build final replacement text + replacement_text = ( + f"{header_context_str}[Image Content from {image_url}]:\n{extracted_content}\n" + ) # Replace image with extracted content return ( original_ref, - f"\n[Image Content from {image_url}]:\n{extracted_content}\n", + replacement_text, ) else: # If no content extracted, keep original with a note logger.warning(f"[FileContentParser] No content extracted from image: {image_url}") return ( original_ref, - f"\n[Image: {image_url} - No content extracted]\n", + f"{header_context_str}[Image: {image_url} - No content extracted]\n", ) except Exception as e: @@ -194,7 +231,9 @@ def _process_single_image( # On error, keep original image reference return (original_ref, original_ref) - def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) -> str: + def _extract_and_process_images( + self, text: str, info: dict[str, Any], headers: dict[int, dict] | None = None, **kwargs + ) -> str: """ Extract all images from markdown text and process them using ImageParser in parallel. Replaces image references with extracted text content. @@ -202,6 +241,7 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) Args: text: Markdown text containing image references info: Dictionary containing user_id and session_id + headers: Optional dictionary mapping line numbers to header info **kwargs: Additional parameters for ImageParser Returns: @@ -225,7 +265,13 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) for match in image_matches: image_url = match.group(2) original_ref = match.group(0) - tasks.append((image_url, original_ref)) + image_position = match.start() + + header_context = None + if headers: + header_context = self._get_header_context(text, image_position, headers) + + tasks.append((image_url, original_ref, header_context)) # Process images in parallel replacements = {} @@ -234,9 +280,14 @@ def _extract_and_process_images(self, text: str, info: dict[str, Any], **kwargs) with ContextThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit( - self._process_single_image, image_url, original_ref, info, **kwargs + self._process_single_image, + image_url, + original_ref, + info, + header_context, + **kwargs, ): (image_url, original_ref) - for image_url, original_ref in tasks + for image_url, original_ref, header_context in tasks } # Collect results with progress tracking @@ -603,6 +654,18 @@ def parse_fine( # Extract custom_tags from kwargs (for LLM extraction) custom_tags = kwargs.get("custom_tags") + # Extract sibling text context . + message_text_context = None + context_items = kwargs.get("context_items") + if context_items: + sibling_texts = [] + for ctx_item in context_items: + for src in getattr(ctx_item.metadata, "sources", None) or []: + if src.type == "chat" and src.content: + sibling_texts.append(src.content.strip()) + if sibling_texts: + message_text_context = "\n".join(sibling_texts) + # Use parser from utils parser = self.parser or get_parser() if not parser: @@ -663,9 +726,20 @@ def parse_fine( ) if not parsed_text: return [] + + # Extract markdown headers if applicable + headers = {} + if is_markdown: + headers = self._extract_markdown_headers(parsed_text) + logger.info( + f"[Chunker: FileContentParser] Extracted {len(headers)} headers from markdown" + ) + # Extract and process images from parsed_text if is_markdown and parsed_text and self.image_parser: - parsed_text = self._extract_and_process_images(parsed_text, info, **kwargs) + parsed_text = self._extract_and_process_images( + parsed_text, info, headers=headers if headers else None, **kwargs + ) # Extract info fields if not info: @@ -782,7 +856,9 @@ def _make_fallback( def _process_chunk(chunk_idx: int, chunk_text: str) -> list[TextualMemoryItem]: """Process chunk with LLM, fallback to raw on failure. Returns list of memory items.""" try: - response_json = self._get_doc_llm_response(chunk_text, custom_tags) + response_json = self._get_doc_llm_response( + chunk_text, custom_tags, message_text_context=message_text_context + ) if response_json: # Handle list format response response_list = response_json.get("memory list", []) @@ -932,3 +1008,94 @@ def get_chunk_idx(item: TextualMemoryItem) -> int: chunk_idx=None, ) ] + + def _extract_markdown_headers(self, text: str) -> dict[int, dict]: + """ + Extract markdown headers and their positions. + + Args: + text: Markdown text to parse + """ + if not text: + return {} + + headers = {} + # Pattern to match markdown headers: # Title, ## Title, etc. + header_pattern = r"^(#{1,6})\s+(.+)$" + + lines = text.split("\n") + char_position = 0 + + for line_num, line in enumerate(lines): + # Match header pattern (must be at start of line) + match = re.match(header_pattern, line.strip()) + if match: + level = len(match.group(1)) # Number of # symbols (1-6) + title = match.group(2).strip() # Extract title text + + # Store header info with its position + headers[line_num] = {"level": level, "title": title, "position": char_position} + + logger.debug(f"[FileContentParser] Found H{level} at line {line_num}: {title}") + + # Update character position for next line (+1 for newline character) + char_position += len(line) + 1 + + logger.info(f"[Chunker: FileContentParser] Extracted {len(headers)} headers from markdown") + return headers + + def _get_header_context( + self, text: str, image_position: int, headers: dict[int, dict] + ) -> list[str]: + """ + Get all header levels above an image position in hierarchical order. + + Finds the image's line number, then identifies all preceding headers + and constructs the hierarchical path to the image location. + + Args: + text: Full markdown text + image_position: Character position of the image in text + headers: Dict of headers from _extract_markdown_headers + """ + if not headers: + return [] + + # Find the line number corresponding to the image position + lines = text.split("\n") + char_count = 0 + image_line = 0 + + for i, line in enumerate(lines): + if char_count >= image_position: + image_line = i + break + char_count += len(line) + 1 # +1 for newline + + # Filter headers that appear before the image + preceding_headers = { + line_num: info for line_num, info in headers.items() if line_num < image_line + } + + if not preceding_headers: + return [] + + # Build hierarchical header stack + header_stack = [] + + for line_num in sorted(preceding_headers.keys()): + header = preceding_headers[line_num] + level = header["level"] + title = header["title"] + + # Pop headers of same or lower level + while header_stack and header_stack[-1]["level"] >= level: + removed = header_stack.pop() + logger.debug(f"[FileContentParser] Popped H{removed['level']}: {removed['title']}") + + # Push current header onto stack + header_stack.append({"level": level, "title": title}) + + # Return titles in order + result = [h["title"] for h in header_stack] + return result diff --git a/src/memos/mem_reader/read_multi_modal/image_parser.py b/src/memos/mem_reader/read_multi_modal/image_parser.py index 97400ca26..0d5e8bcc2 100644 --- a/src/memos/mem_reader/read_multi_modal/image_parser.py +++ b/src/memos/mem_reader/read_multi_modal/image_parser.py @@ -137,13 +137,14 @@ def parse_fine( # Get context items if available context_items = kwargs.get("context_items") - # Determine language: prioritize lang from source (passed via kwargs), - # fallback to detecting from context_items if lang not provided + # Determine language: prioritize lang from context_items, + # fallback to kwargs lang = kwargs.get("lang") - if lang is None and context_items: + if context_items: for item in context_items: if hasattr(item, "memory") and item.memory: lang = detect_lang(item.memory) + source.lang = lang break if not lang: lang = "en" diff --git a/src/memos/mem_reader/read_multi_modal/utils.py b/src/memos/mem_reader/read_multi_modal/utils.py index a6d910e54..96918589b 100644 --- a/src/memos/mem_reader/read_multi_modal/utils.py +++ b/src/memos/mem_reader/read_multi_modal/utils.py @@ -341,12 +341,32 @@ def detect_lang(text): if not text or not isinstance(text, str): return "en" cleaned_text = text - # remove role and timestamp + # remove role and timestamp-like prefixes cleaned_text = re.sub( r"\b(user|assistant|query|answer)\s*:", "", cleaned_text, flags=re.IGNORECASE ) + # timestamps like [11:32 AM on 04 March, 2026] + cleaned_text = re.sub( + r"\[\s*\d{1,2}:\d{2}\s*(?:AM|PM)\s+on\s+\d{2}\s+[A-Za-z]+\s*,\s*\d{4}\s*\]", + "", + cleaned_text, + flags=re.IGNORECASE, + ) + # purely numeric timestamps like [2025-01-01 10:00] cleaned_text = re.sub(r"\[[\d\-:\s]+\]", "", cleaned_text) - + # remove URLs to prevent the dilution of Chinese characters + cleaned_text = re.sub(r'https?://[^\s<>"{}|\\^`\[\]]+', "", cleaned_text) + # remove MessageType schema keywords (multimodal JSON noise) + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url)\b", "", cleaned_text, flags=re.IGNORECASE + ) + # remove schema keywords like text / type / image_url / url + cleaned_text = re.sub( + r"\b(text|type|image_url|imageurl|url|file|file_id)\b", + "", + cleaned_text, + flags=re.IGNORECASE, + ) # extract chinese characters chinese_pattern = r"[\u4e00-\u9fff\u3400-\u4dbf\U00020000-\U0002a6df\U0002a700-\U0002b73f\U0002b740-\U0002b81f\U0002b820-\U0002ceaf\uf900-\ufaff]" chinese_chars = re.findall(chinese_pattern, cleaned_text) diff --git a/src/memos/mem_reader/read_pref_memory/process_preference_memory.py b/src/memos/mem_reader/read_pref_memory/process_preference_memory.py new file mode 100644 index 000000000..1ff1fba52 --- /dev/null +++ b/src/memos/mem_reader/read_pref_memory/process_preference_memory.py @@ -0,0 +1,296 @@ +"""Preference memory extractor.""" + +import json +import os +import uuid + +from concurrent.futures import as_completed +from typing import TYPE_CHECKING, Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_reader.read_multi_modal import detect_lang +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, +) + + +if TYPE_CHECKING: + from memos.types.general_types import UserContext + + +logger = get_logger(__name__) + + +def _extract_explicit_preference(qa_pair_str: str, llm) -> list[dict[str, Any]] | None: + """Extract explicit preference from a QA pair string.""" + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) + + try: + response = llm.generate([{"role": "user", "content": prompt}]) + if not response: + logger.info( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting explicit preference" + ) + return None + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + for d in result: + d["preference"] = d.pop("explicit_preference") + return result + except Exception as e: + logger.info(f"Error extracting explicit preference: {e}, return None") + return None + + +def _extract_implicit_preference(qa_pair_str: str, llm) -> list[dict[str, Any]] | None: + """Extract implicit preferences from a QA pair string.""" + if not qa_pair_str: + return None + + lang = detect_lang(qa_pair_str) + _map = { + "zh": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT_ZH, + "en": NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, + } + prompt = _map[lang].replace("{qa_pair}", qa_pair_str) + + try: + response = llm.generate([{"role": "user", "content": prompt}]) + if not response: + logger.info( + f"[prefer_extractor]: (Error) LLM response content is {response} when extracting implicit preference" + ) + return None + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + for d in result: + d["preference"] = d.pop("implicit_preference") + return result + except Exception as e: + logger.info(f"Error extracting implicit preferences: {e}, return None") + return None + + +def _create_preference_memory_item( + preference_data: dict[str, Any], + preference_type: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + embedder, + **kwargs, +) -> TextualMemoryItem: + """ + Create a preference memory item with proper metadata. + + Args: + preference_data: Dictionary containing preference, context_summary, reasoning, topic + preference_type: "explicit_preference" or "implicit_preference" + fast_item: Original fast memory item (for extracting sources and other metadata) + info: Dictionary containing user_id, session_id, etc. + embedder: Embedder instance + kwargs: Additional parameters including user_context + + Returns: + TextualMemoryItem with TreeNodeTextualMemoryMetadata + """ + # Make a copy of info to avoid modifying the original + info_ = info.copy() + + # Extract fields that should be at metadata level + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + + # Extract manager_user_id, project_id, and operation from user_context + user_context: UserContext | None = kwargs.get("user_context") + manager_user_id = user_context.manager_user_id if user_context else None + project_id = user_context.project_id if user_context else None + + # Generate embedding for context_summary + context_summary = preference_data.get("context_summary", "") + embedding = embedder.embed([context_summary])[0] if embedder and context_summary else None + + # Extract sources from fast_item + sources = getattr(fast_item.metadata, "sources", []) if fast_item else [] + + # Create metadata + metadata = TreeNodeTextualMemoryMetadata( + memory_type="PreferenceMemory", + embedding=embedding, + user_id=user_id, + session_id=session_id, + status="activated", + tags=[], + type="chat", + info=info_, + sources=sources, + usage=[], + background="", + # Preference-specific fields + preference_type=preference_type, + preference=preference_data.get("preference", ""), + reasoning=preference_data.get("reasoning", ""), + topic=preference_data.get("topic", ""), + # User-specific fields + manager_user_id=manager_user_id, + project_id=project_id, + ) + + # Create and return memory item + return TextualMemoryItem(id=str(uuid.uuid4()), memory=context_summary, metadata=metadata) + + +def _process_single_chunk_explicit( + original_text: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + llm, + embedder, + **kwargs, +) -> list[TextualMemoryItem]: + """Process a single chunk for explicit preferences.""" + if not original_text.strip(): + return [] + + explicit_pref = _extract_explicit_preference(original_text, llm) + if not explicit_pref: + return [] + + memories = [] + for pref in explicit_pref: + memory = _create_preference_memory_item( + preference_data=pref, + preference_type="explicit_preference", + fast_item=fast_item, + info=info, + embedder=embedder, + **kwargs, + ) + memories.append(memory) + + return memories + + +def _process_single_chunk_implicit( + original_text: str, + fast_item: TextualMemoryItem | None, + info: dict[str, Any], + llm, + embedder, + **kwargs, +) -> list[TextualMemoryItem]: + """Process a single chunk for implicit preferences.""" + if not original_text.strip(): + return [] + + implicit_pref = _extract_implicit_preference(original_text, llm) + if not implicit_pref: + return [] + + memories = [] + for pref in implicit_pref: + memory = _create_preference_memory_item( + preference_data=pref, + preference_type="implicit_preference", + fast_item=fast_item, + info=info, + embedder=embedder, + **kwargs, + ) + memories.append(memory) + + return memories + + +def process_preference_fine( + fast_memory_items: list[TextualMemoryItem], + info: dict[str, Any], + llm=None, + embedder=None, + **kwargs, +) -> list[TextualMemoryItem]: + """ + Extract preference memories from fast_memory_items (for fine mode processing). + + Args: + fast_memory_items: List of TextualMemoryItem from fast parsing + info: Dictionary containing user_id and session_id + llm: LLM instance + embedder: Embedder instance + kwargs: Additional parameters (including user_context) + + Returns: + List of preference memory items + """ + + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + if not fast_memory_items or not llm: + return [] + + try: + # Convert fast_memory_items to messages format + chunks = [] + for fast_item in fast_memory_items: + mem_str = fast_item.memory or "" + if not mem_str.strip(): + continue + chunks.append((mem_str, fast_item)) + + if not chunks: + return [] + + # Process chunks in parallel + memories = [] + with ContextThreadPoolExecutor(max_workers=min(10, len(chunks))) as executor: + futures = {} + + # Submit explicit extraction tasks + for chunk, fast_item in chunks: + future = executor.submit( + _process_single_chunk_explicit, chunk, fast_item, info, llm, embedder, **kwargs + ) + futures[future] = ("explicit_preference", chunk) + + # Submit implicit extraction tasks + for chunk, fast_item in chunks: + future = executor.submit( + _process_single_chunk_implicit, chunk, fast_item, info, llm, embedder, **kwargs + ) + futures[future] = ("implicit_preference", chunk) + + # Collect results + for future in as_completed(futures): + try: + memory = future.result() + if memory: + if isinstance(memory, list): + memories.extend(memory) + else: + memories.append(memory) + except Exception as e: + task_type, chunk = futures[future] + logger.warning( + f"[process_preference_fine] Error processing {task_type} chunk, original text: {chunk}: {e}" + ) + continue + + if memories: + logger.info(f"[process_preference_fine] Extracted {len(memories)} preference memories") + + return memories + except Exception as e: + logger.warning( + f"[process_preference_fine] Failed to extract preferences: {e}", exc_info=True + ) + return [] diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index b103acf3a..8777b9f2e 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -18,17 +18,6 @@ from memos.mem_cube.navie import NaiveMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback from memos.mem_reader.factory import MemReaderFactory -from memos.memories.textual.prefer_text_memory.config import ( - AdderConfigFactory, - ExtractorConfigFactory, - RetrieverConfigFactory, -) -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( @@ -40,7 +29,6 @@ if TYPE_CHECKING: from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.factory import RerankerFactory -from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) @@ -182,36 +170,6 @@ def build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) -def build_pref_extractor_config() -> dict[str, Any]: - """ - Build preference memory extractor configuration. - - Returns: - Validated extractor configuration dictionary - """ - return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def build_pref_adder_config() -> dict[str, Any]: - """ - Build preference memory adder configuration. - - Returns: - Validated adder configuration dictionary - """ - return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def build_pref_retriever_config() -> dict[str, Any]: - """ - Build preference memory retriever configuration. - - Returns: - Validated retriever configuration dictionary - """ - return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) - - def _get_default_memory_size(cube_config: Any) -> dict[str, int]: """ Get default memory size configuration. @@ -291,20 +249,11 @@ def init_components() -> dict[str, Any]: reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() internet_retriever_config = build_internet_retriever_config() - vector_db_config = build_vec_db_config() - pref_extractor_config = build_pref_extractor_config() - pref_adder_config = build_pref_adder_config() - pref_retriever_config = build_pref_retriever_config() logger.debug("Component configurations built successfully") # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = ( - VecDBFactory.from_config(vector_db_config) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) @@ -345,63 +294,9 @@ def init_components() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize preference memory components - pref_extractor = ( - ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_adder = ( - AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - pref_retriever = ( - RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=feedback_reranker, - vector_db=vector_db, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - - logger.debug("Preference memory components initialized") - - # Initialize preference memory - pref_mem = ( - SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=feedback_reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" - else None - ) - # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, - pref_mem=pref_mem, act_mem=None, para_mem=None, ) @@ -421,7 +316,7 @@ def init_components() -> dict[str, Any]: mem_reader=mem_reader, searcher=searcher, reranker=feedback_reranker, - pref_mem=pref_mem, + pref_feedback=True, ) # Return all components as a dictionary for easy access and extension return {"naive_mem_cube": naive_mem_cube, "feedback_server": feedback_server} diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 7e40f1d50..60af67830 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -171,6 +171,7 @@ class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata): "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ] = Field(default="WorkingMemory", description="Memory lifecycle type.") sources: list[SourceMessage] | None = Field( default=None, description="Multiple origins of the memory (e.g., URLs, notes)." @@ -337,8 +338,6 @@ def _coerce_metadata(cls, v: Any): if v.get("relativity") is not None: return SearchedTreeNodeTextualMemoryMetadata(**v) - if v.get("preference_type") is not None: - return PreferenceTextualMemoryMetadata(**v) if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")): return TreeNodeTextualMemoryMetadata(**v) return TextualMemoryMetadata(**v) diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index dba321f55..0cc6d1930 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -74,6 +74,7 @@ def get_memory( messages (list[MessageList]): The messages to get memory from. type (str): The type of memory to get. info (dict[str, Any]): The info to get memory. + **kwargs: Additional keyword arguments to pass to the extractor. """ return self.extractor.extract(messages, type, info, **kwargs) @@ -91,7 +92,6 @@ def search( if not isinstance(search_filter, dict): search_filter = {} search_filter.update({"status": "activated"}) - logger.info(f"search_filter for preference memory: {search_filter}") return self.retriever.retrieve(query, top_k, info, search_filter) def load(self, dir: str) -> None: diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index db7101744..51523d364 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -1,5 +1,3 @@ -from typing import Any - from memos.embedders.factory import ( ArkEmbedder, OllamaEmbedder, @@ -8,9 +6,7 @@ ) from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM from memos.log import get_logger -from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.preference import PreferenceTextMemory -from memos.types import MessageList from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB @@ -38,125 +34,3 @@ def __init__( self.extractor = extractor self.adder = adder self.retriever = retriever - - def get_memory( - self, messages: list[MessageList], type: str, info: dict[str, Any], **kwargs - ) -> list[TextualMemoryItem]: - """Get memory based on the messages. - Args: - messages (MessageList): The messages to get memory from. - type (str): The type of memory to get. - info (dict[str, Any]): The info to get memory. - **kwargs: Additional keyword arguments to pass to the extractor. - """ - return self.extractor.extract(messages, type, info, **kwargs) - - def search( - self, query: str, top_k: int, info=None, search_filter=None, **kwargs - ) -> list[TextualMemoryItem]: - """Search for memories based on a query. - Args: - query (str): The query to search for. - top_k (int): The number of top results to return. - info (dict): Leave a record of memory consumption. - Returns: - list[TextualMemoryItem]: List of matching memories. - """ - if not isinstance(search_filter, dict): - search_filter = {} - search_filter.update({"status": "activated"}) - return self.retriever.retrieve(query, top_k, info, search_filter) - - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: - """Add memories. - - Args: - memories: List of TextualMemoryItem objects or dictionaries to add. - """ - return self.adder.add(memories) - - def get_with_collection_name( - self, collection_name: str, memory_id: str - ) -> TextualMemoryItem | None: - """Get a memory by its ID and collection name. - Args: - memory_id (str): The ID of the memory to retrieve. - collection_name (str): The name of the collection to retrieve the memory from. - Returns: - TextualMemoryItem: The memory with the given ID and collection name. - """ - try: - res = self.vector_db.get_by_id(collection_name, memory_id) - if res is None: - return None - return TextualMemoryItem( - id=res.id, - memory=res.memory, - metadata=PreferenceTextualMemoryMetadata(**res.payload), - ) - except Exception as e: - # Convert any other exception to ValueError for consistent error handling - raise ValueError( - f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" - ) from e - - def get_by_ids_with_collection_name( - self, collection_name: str, memory_ids: list[str] - ) -> list[TextualMemoryItem]: - """Get memories by their IDs and collection name. - Args: - collection_name (str): The name of the collection to retrieve the memory from. - memory_ids (list[str]): List of memory IDs to retrieve. - Returns: - list[TextualMemoryItem]: List of memories with the specified IDs and collection name. - """ - try: - res = self.vector_db.get_by_ids(collection_name, memory_ids) - if not res: - return [] - return [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), - ) - for memo in res - ] - except Exception as e: - # Convert any other exception to ValueError for consistent error handling - raise ValueError( - f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" - ) from e - - def get_all(self) -> list[TextualMemoryItem]: - """Get all memories. - Returns: - list[TextualMemoryItem]: List of all memories. - """ - all_collections = ["explicit_preference", "implicit_preference"] - all_memories = {} - for collection_name in all_collections: - items = self.vector_db.get_all(collection_name) - all_memories[collection_name] = [ - TextualMemoryItem( - id=memo.id, - memory=memo.memory, - metadata=PreferenceTextualMemoryMetadata(**memo.payload), - ) - for memo in items - ] - return all_memories - - def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: - """Delete memories by their IDs and collection name. - Args: - collection_name (str): The name of the collection to delete the memory from. - memory_ids (list[str]): List of memory IDs to delete. - """ - self.vector_db.delete(collection_name, memory_ids) - - def delete_all(self) -> None: - """Delete all memories.""" - for collection_name in self.vector_db.config.collection_name: - self.vector_db.delete_collection(collection_name) - self.vector_db.create_collection() diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5b210ba61..8c896f538 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -169,6 +169,8 @@ def search( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, include_embedding: bool | None = None, **kwargs, @@ -222,6 +224,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, dedup=dedup, **kwargs, ) diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py index 132582a0d..98094877c 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py @@ -128,7 +128,7 @@ def resolve_history_via_nli( ) new_item.metadata.history.append(archived) logger.info( - f"[MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" + f"[Chunker: MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" ) # 3. Concat duplicate/conflict memories to new_item.memory diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 4ca30c7b8..df419f0c1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -185,6 +185,7 @@ def _add_memories_batch( "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ): graph_node_id = ( memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) @@ -341,6 +342,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ): f_graph = ex.submit( self._add_to_graph_memory, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index e5e96dd58..dd90b8932 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -69,6 +69,7 @@ def retrieve( "ToolTrajectoryMemory", "RawFileMemory", "SkillMemory", + "PreferenceMemory", ]: raise ValueError(f"Unsupported memory scope: {memory_scope}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index cc269e8c4..eb15b48ed 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -87,10 +87,12 @@ def retrieve( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: logger.info( - f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" + f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}, user_name={user_name}" ) parsed_goal, query_embedding, _context, query = self._parse_task( query, @@ -116,6 +118,8 @@ def retrieve( tool_mem_top_k, include_skill_memory, skill_mem_top_k, + include_preference_memory, + pref_mem_top_k, ) return results @@ -129,6 +133,8 @@ def post_retrieve( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, plugin=False, ): @@ -144,6 +150,8 @@ def post_retrieve( tool_mem_top_k, include_skill_memory, skill_mem_top_k, + include_preference_memory, + pref_mem_top_k, ) self._update_usage_history(final_results, info, user_name) return final_results @@ -163,6 +171,8 @@ def search( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, dedup: str | None = None, **kwargs, ) -> list[TextualMemoryItem]: @@ -212,6 +222,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, **kwargs, ) @@ -229,6 +241,8 @@ def search( tool_mem_top_k=tool_mem_top_k, include_skill_memory=include_skill_memory, skill_mem_top_k=skill_mem_top_k, + include_preference_memory=include_preference_memory, + pref_mem_top_k=pref_mem_top_k, dedup=dedup, ) @@ -329,8 +343,10 @@ def _retrieve_paths( tool_mem_top_k: int = 6, include_skill_memory: bool = False, skill_mem_top_k: int = 3, + include_preference_memory: bool = False, + pref_mem_top_k: int = 6, ): - """Run A/B/C/D/E retrieval paths in parallel""" + """Run A/B/C/D/E/F retrieval paths in parallel""" tasks = [] id_filter = { "user_id": info.get("user_id", None), @@ -428,6 +444,22 @@ def _retrieve_paths( mode=mode, ) ) + if include_preference_memory: + tasks.append( + executor.submit( + self._retrieve_from_preference_memory, + query, + parsed_goal, + query_embedding, + pref_mem_top_k, + memory_type, + search_filter, + search_priority, + user_name, + id_filter, + mode=mode, + ) + ) results = [] for t in tasks: results.extend(t.result()) @@ -827,6 +859,7 @@ def _retrieve_from_skill_memory( mode: str = "fast", ): """Retrieve and rerank from SkillMemory""" + if memory_type not in ["All", "SkillMemory"]: logger.info(f"[PATH-E] '{query}' Skipped (memory_type does not match)") return [] @@ -863,6 +896,57 @@ def _retrieve_from_skill_memory( search_filter=search_filter, ) + @timed + def _retrieve_from_preference_memory( + self, + query, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter: dict | None = None, + search_priority: dict | None = None, + user_name: str | None = None, + id_filter: dict | None = None, + mode: str = "fast", + ): + """Retrieve and rerank from PreferenceMemory""" + if memory_type not in ["All", "PreferenceMemory"]: + logger.info(f"[PATH-F] '{query}' Skipped (memory_type does not match)") + return [] + + # chain of thinking + cot_embeddings = [] + if self.vec_cot: + queries = self._cot_query(query, mode=mode, context=parsed_goal.context) + if len(queries) > 1: + cot_embeddings = self.embedder.embed(queries) + cot_embeddings.extend(query_embedding) + else: + cot_embeddings = query_embedding + + items = self.graph_retriever.retrieve( + query=query, + parsed_goal=parsed_goal, + query_embedding=cot_embeddings, + top_k=top_k * 2, + memory_scope="PreferenceMemory", + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, + id_filter=id_filter, + use_fast_graph=self.use_fast_graph, + ) + + return self.reranker.rerank( + query=query, + query_embedding=query_embedding[0], + graph_results=items, + top_k=top_k, + parsed_goal=parsed_goal, + search_filter=search_filter, + ) + @timed def _retrieve_simple( self, @@ -933,6 +1017,8 @@ def _sort_and_trim( tool_mem_top_k=6, include_skill_memory=False, skill_mem_top_k=3, + include_preference_memory=False, + pref_mem_top_k=6, ): """Sort results by score and trim to top_k""" final_items = [] @@ -1000,6 +1086,28 @@ def _sort_and_trim( ) ) + if include_preference_memory: + pref_results = [ + (item, score) + for item, score in results + if item.metadata.memory_type == "PreferenceMemory" + ] + sorted_pref_results = sorted(pref_results, key=lambda pair: pair[1], reverse=True)[ + :pref_mem_top_k + ] + for item, score in sorted_pref_results: + if plugin and round(score, 2) == 0.00: + continue + meta_data = item.metadata.model_dump() + meta_data["relativity"] = score + final_items.append( + TextualMemoryItem( + id=item.id, + memory=item.memory, + metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data), + ) + ) + # separate textual results results = [ (item, score) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index f4d6c4847..3b160a56e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -72,7 +72,7 @@ def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal: else: return ParsedTaskGoal( memories=[task_description], - keys=[task_description], + keys=[], tags=[], goal_type="default", rephrased_query=task_description, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index d890c77bf..6df410c19 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os import time import traceback @@ -11,10 +10,8 @@ from memos.api.handlers.formatters_handler import ( format_memory_item, - post_process_pref_mem, post_process_textual_mem, ) -from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_reader.utils import parse_keep_filter_response from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -22,7 +19,6 @@ ADD_TASK_LABEL, MEM_FEEDBACK_TASK_LABEL, MEM_READ_TASK_LABEL, - PREF_ADD_TASK_LABEL, ) from memos.memories.textual.item import TextualMemoryItem from memos.multi_mem_cube.views import MemCubeView @@ -78,38 +74,23 @@ def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: ) target_session_id = add_req.session_id or "default_session" - sync_mode = add_req.async_mode or self._get_sync_mode() - self.logger.info( f"[SingleCubeView] cube={self.cube_id} " f"Processing add with mode={sync_mode}, session={target_session_id}" ) - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._process_text_mem, add_req, user_context, sync_mode) - pref_future = executor.submit(self._process_pref_mem, add_req, user_context, sync_mode) - - text_results = text_future.result() - pref_results = pref_future.result() - - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} text_results={len(text_results)}, " - f"pref_results={len(pref_results)}" - ) - - for item in text_results: - item["cube_id"] = self.cube_id - for item in pref_results: - item["cube_id"] = self.cube_id - - all_memories = text_results + pref_results + all_memories = self._process_text_mem(add_req, user_context, sync_mode) - # TODO: search existing memories and compare + self.logger.info(f"[SingleCubeView] cube={self.cube_id} total_results={len(all_memories)}") return all_memories @timed def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: + """ + Unified memory search handling (text + preference memories). + Preference memories are now searched through the same _search_text flow. + """ # Create UserContext object user_context = UserContext( user_id=search_req.user_id, @@ -131,28 +112,16 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: # Determine search mode search_mode = self._get_search_mode(search_req.mode) - # Execute search in parallel for text and preference memories - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(self._search_text, search_req, user_context, search_mode) - pref_future = executor.submit(self._search_pref, search_req, user_context) - - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() + # Unified search through _search_text (includes all memory types) + all_formatted_memories = self._search_text(search_req, user_context, search_mode) - # Build result + # Build result with unified processing memories_result = post_process_textual_mem( memories_result, - text_formatted_memories, + all_formatted_memories, self.cube_id, ) - memories_result = post_process_pref_mem( - memories_result, - pref_formatted_memories, - self.cube_id, - search_req.include_preference, - ) - self.logger.info(f"Search memories result: {memories_result}") self.logger.info(f"Search {len(memories_result)} memories.") return memories_result @@ -407,71 +376,6 @@ def _dedup_by_content(memories: list) -> list: return formatted_memories - @timed - def _search_pref( - self, - search_req: APISearchRequest, - user_context: UserContext, - ) -> list[dict[str, Any]]: - """ - Search preference memories. - - Args: - search_req: Search request - user_context: User context - - Returns: - List of formatted preference memory items - TODO: ADD CUBE ID IN PREFERENCE MEMORY - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - if not search_req.include_preference: - return [] - - logger.info(f"search_req.filter for preference memory: {search_req.filter}") - logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") - try: - results = self.naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "mem_cube_id": user_context.mem_cube_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - search_filter=search_req.filter, - ) - include_embedding = os.getenv("INCLUDE_EMBEDDING", "false") == "true" - formatted_results = self._postformat_memories( - results, user_context.mem_cube_id, include_embedding=include_embedding - ) - - # For each returned item, tackle with metadata.info project_id / - # operation / manager_user_id - for item in formatted_results: - if not isinstance(item, dict): - continue - metadata = item.get("metadata") - if not isinstance(metadata, dict): - continue - info = metadata.get("info") - if not isinstance(info, dict): - continue - - for key in ("project_id", "operation", "manager_user_id"): - if key not in info: - continue - value = info.pop(key) - if key not in metadata: - metadata[key] = value - - return formatted_results - except Exception as e: - self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - def _fast_search( self, search_req: APISearchRequest, @@ -645,89 +549,6 @@ def _schedule_memory_tasks( ) self.mem_scheduler.submit_messages(messages=[message_item_add]) - @timed - def _process_pref_mem( - self, - add_req: APIADDRequest, - user_context: UserContext, - sync_mode: str, - ) -> list[dict[str, Any]]: - """ - Process and add preference memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. - - Args: - add_req: Add memory request - user_context: User context with IDs - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - if add_req.messages is None or isinstance(add_req.messages, str): - return [] - - for message in add_req.messages: - if isinstance(message, dict) and message.get("role", None) is None: - return [] - - target_session_id = add_req.session_id or "default_session" - - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=user_context.mem_cube_id, - mem_cube=self.naive_mem_cube, - label=PREF_ADD_TASK_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - info=add_req.info, - user_name=self.cube_id, - task_id=add_req.task_id, - user_context=user_context, - ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) - self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") - except Exception as e: - self.logger.error( - f"[SingleCubeView] cube={self.cube_id} Failed to submit PREF_ADD: {e}", - exc_info=True, - ) - return [] - else: - pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - **(add_req.info or {}), - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": user_context.mem_cube_id, - }, - user_context=user_context, - ) - pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) - self.logger.info( - f"[SingleCubeView] cube={self.cube_id} " - f"added {len(pref_ids_local)} preferences for user {add_req.user_id}: {pref_ids_local}" - ) - - return [ - { - "memory": memory.metadata.preference, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - def add_before_search( self, messages: list[dict], @@ -834,7 +655,7 @@ def _process_text_mem( sync_mode: str, ) -> list[dict[str, Any]]: """ - Process and add text memories. + Process and add text memories (including preference memories). Extracts memories from messages and adds them to the text memory system. Handles both sync and async modes. @@ -959,13 +780,15 @@ def _process_text_mem( "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." ) + # Format results uniformly text_memories = [ { "memory": memory.memory, "memory_id": memory_id, "memory_type": memory.metadata.memory_type, + "cube_id": self.cube_id, } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + for memory_id, memory in zip(mem_ids_local, mem_group, strict=False) ] return text_memories diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index 6d57e3605..fa713a7d1 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -62,6 +62,8 @@ def search_text_memories( tool_mem_top_k=search_req.tool_mem_top_k, include_skill_memory=search_req.include_skill_memory, skill_mem_top_k=search_req.skill_mem_top_k, + include_preference_memory=search_req.include_preference, + pref_mem_top_k=search_req.pref_top_k, dedup=search_req.dedup, include_embedding=include_embedding, ) diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index f431bd041..63e4c1538 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -263,6 +263,10 @@ {custom_tags_prompt} +If given context, use it as a supplement to the document information extraction; if no context is given, directly process the document information. +Reference context: +{context} + Document chunk: {chunk_text} @@ -307,6 +311,10 @@ {custom_tags_prompt} +如果给定了上下文,就结合上下文信息作为文档信息提取的补充,如果没有给定上下文,请直接处理文档信息。 +参考的上下文: +{context} + 示例: 输入的文本片段: 在Kalamang语中,亲属名词在所有格构式中的行为并不一致。名词 esa“父亲”和 ema“母亲”只能在技术称谓(teknonym)中与第三人称所有格后缀共现,而在非技术称谓用法中,带有所有格后缀是不合语法的。相比之下,大多数其他亲属名词并不允许所有格构式,只有极少数例外。 diff --git a/tests/chunkers/test_sentence_chunker.py b/tests/chunkers/test_sentence_chunker.py index 28aaeabb9..7ff6b2ccd 100644 --- a/tests/chunkers/test_sentence_chunker.py +++ b/tests/chunkers/test_sentence_chunker.py @@ -47,6 +47,17 @@ def test_sentence_chunker(self): self.assertEqual(len(chunks), 2) # Validate the properties of the first chunk mock_chunker.chunk.assert_called_once_with(text) - self.assertEqual(chunks[0].text, "This is the first sentence.") - self.assertEqual(chunks[0].token_count, 6) - self.assertEqual(chunks[0].sentences, ["This is the first sentence."]) + + # Handle both return types: list[str] | list[Chunk] + if isinstance(chunks[0], str): + # If returns list[str], check the string value + self.assertEqual(chunks[0], "This is the first sentence.") + self.assertEqual(chunks[1], "This is the second sentence.") + else: + # If returns list[Chunk], check the Chunk properties + from memos.chunkers.base import Chunk + + self.assertIsInstance(chunks[0], Chunk) + self.assertEqual(chunks[0].text, "This is the first sentence.") + self.assertEqual(chunks[0].token_count, 6) + self.assertEqual(chunks[0].sentences, ["This is the first sentence."]) diff --git a/tests/utils.py b/tests/utils.py index 132cd7138..ec8a32799 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,8 @@ def check_module_base_class(cls: Any) -> None: General function to test the correctness of an abstract base class. - It should inherit from ABC. - It should define at least one method. - - All methods should be marked as @abstractmethod. + - It should have at least one abstract method. + - Abstract methods (those in __abstractmethods__) should be marked as @abstractmethod. - It should not be instantiable. - All methods should have docstrings. @@ -31,14 +32,25 @@ def check_module_base_class(cls: Any) -> None: assert all_class_methods, f"{cls.__name__} should define at least one method" # Check 3: Verify abstract methods + # Get the set of abstract methods from the class + abstract_methods = getattr(cls, "__abstractmethods__", set()) + + # Ensure there is at least one abstract method + assert len(abstract_methods) > 0, f"{cls.__name__} should have at least one abstract method" + + # Verify that all methods in __abstractmethods__ are actually marked as abstract for method_name in all_class_methods: method = getattr(cls, method_name) # Skip private methods (starting with _) as they are typically helper methods if method_name.startswith("_") and method_name != "__init__": continue - assert getattr(method, "__isabstractmethod__", False), ( - f"The method '{method_name}' in {cls.__name__} should be marked as @abstractmethod" - ) + + # If the method is in __abstractmethods__, it must be marked as abstract + if method_name in abstract_methods: + assert getattr(method, "__isabstractmethod__", False), ( + f"The method '{method_name}' in {cls.__name__} is in __abstractmethods__ " + f"but should be marked as @abstractmethod" + ) # Check 4: Test that the class cannot be instantiated directly with pytest.raises(TypeError) as excinfo: