diff --git a/mem0/memory/utils.py b/mem0/memory/utils.py
index 8c11705c87..e6c29b8007 100644
--- a/mem0/memory/utils.py
+++ b/mem0/memory/utils.py
@@ -7,14 +7,63 @@
     AGENT_MEMORY_EXTRACTION_PROMPT,
 )
+
+_char_map = {
+    "...": "_ellipsis_",
+    "…": "_ellipsis_",
+    "。": "_period_",
+    ",": "_comma_",
+    ";": "_semicolon_",
+    ":": "_colon_",
+    "!": "_exclamation_",
+    "?": "_question_",
+    "(": "_lparen_",
+    ")": "_rparen_",
+    "【": "_lbracket_",
+    "】": "_rbracket_",
+    "《": "_langle_",
+    "》": "_rangle_",
+    "'": "_apostrophe_",
+    '"': "_quote_",
+    "\\": "_backslash_",
+    "/": "_slash_",
+    "|": "_pipe_",
+    "&": "_ampersand_",
+    "=": "_equals_",
+    "+": "_plus_",
+    "*": "_asterisk_",
+    "^": "_caret_",
+    "%": "_percent_",
+    "$": "_dollar_",
+    "#": "_hash_",
+    "@": "_at_",
+    "!": "_bang_",
+    "?": "_question_",
+    "(": "_lparen_",
+    ")": "_rparen_",
+    "[": "_lbracket_",
+    "]": "_rbracket_",
+    "{": "_lbrace_",
+    "}": "_rbrace_",
+    "<": "_langle_",
+    ">": "_rangle_",
+}
+
+_multi_keys = [k for k in _char_map if len(k) > 1]
+
+_single_keys = [k for k in _char_map if len(k) == 1]
+
+_translation_table = str.maketrans({k: v for k, v in _char_map.items() if len(k) == 1})
+
+_re_sub_underscores = re.compile(r"_+")
 
 
 def get_fact_retrieval_messages(message, is_agent_memory=False):
     """Get fact retrieval messages based on the memory type. 
-    
+
     Args:
         message: The message content to extract facts from
         is_agent_memory: If True, use agent memory extraction prompt, else use user memory extraction prompt
-    
+
     Returns:
         tuple: (system_prompt, user_prompt)
     """
@@ -64,11 +113,10 @@ def remove_code_blocks(content: str) -> str:
     """
     pattern = r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$"
     match = re.match(pattern, content.strip())
-    match_res=match.group(1).strip() if match else content.strip()
+    match_res = match.group(1).strip() if match else content.strip()
     return re.sub(r"<think>.*?</think>", "", match_res, flags=re.DOTALL).strip()
 
-
 def extract_json(text):
     """
     Extracts JSON content from a string, removing enclosing triple backticks and optional 'json' tag if present.
 
@@ -158,51 +206,15 @@ def process_telemetry_filters(filters):
 
 
 def sanitize_relationship_for_cypher(relationship) -> str:
     """Sanitize relationship text for Cypher queries by replacing problematic characters."""
-    char_map = {
-        "...": "_ellipsis_",
-        "…": "_ellipsis_",
-        "。": "_period_",
-        ",": "_comma_",
-        ";": "_semicolon_",
-        ":": "_colon_",
-        "!": "_exclamation_",
-        "?": "_question_",
-        "(": "_lparen_",
-        ")": "_rparen_",
-        "【": "_lbracket_",
-        "】": "_rbracket_",
-        "《": "_langle_",
-        "》": "_rangle_",
-        "'": "_apostrophe_",
-        '"': "_quote_",
-        "\\": "_backslash_",
-        "/": "_slash_",
-        "|": "_pipe_",
-        "&": "_ampersand_",
-        "=": "_equals_",
-        "+": "_plus_",
-        "*": "_asterisk_",
-        "^": "_caret_",
-        "%": "_percent_",
-        "$": "_dollar_",
-        "#": "_hash_",
-        "@": "_at_",
-        "!": "_bang_",
-        "?": "_question_",
-        "(": "_lparen_",
-        ")": "_rparen_",
-        "[": "_lbracket_",
-        "]": "_rbracket_",
-        "{": "_lbrace_",
-        "}": "_rbrace_",
-        "<": "_langle_",
-        ">": "_rangle_",
-    }
     # Apply replacements and clean up
     sanitized = relationship
-    for old, new in char_map.items():
-        sanitized = sanitized.replace(old, new)
-    return re.sub(r"_+", "_", sanitized).strip("_")
+    # First handle all multi-character replacements
+    for old in _multi_keys:
+        sanitized = sanitized.replace(old, _char_map[old])
+
+    # Next, handle all single-character replacements in one pass
+    sanitized = sanitized.translate(_translation_table)
+    return _re_sub_underscores.sub("_", sanitized).strip("_")
 
 