From e48416cea0dbd202a0d9316baabf6e92ae21d6d2 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Thu, 6 Nov 2025 17:37:53 +0000 Subject: [PATCH] :sparkles: added a topic detector rail based on embeddings models Signed-off-by: mmisiura --- .../embedding_topic_detector/README.md | 46 ++++++ .../embedding_topic_detector/__init__.py | 16 ++ .../embedding_topic_detector/actions.py | 119 ++++++++++++++ .../library/embedding_topic_detector/flows.co | 19 +++ .../embedding_topic_detector/flows.v1.co | 22 +++ nemoguardrails/rails/llm/config.py | 2 + tests/test_embedding_topic_detector.py | 153 ++++++++++++++++++ 7 files changed, 377 insertions(+) create mode 100644 nemoguardrails/library/embedding_topic_detector/README.md create mode 100644 nemoguardrails/library/embedding_topic_detector/__init__.py create mode 100644 nemoguardrails/library/embedding_topic_detector/actions.py create mode 100644 nemoguardrails/library/embedding_topic_detector/flows.co create mode 100644 nemoguardrails/library/embedding_topic_detector/flows.v1.co create mode 100644 tests/test_embedding_topic_detector.py diff --git a/nemoguardrails/library/embedding_topic_detector/README.md b/nemoguardrails/library/embedding_topic_detector/README.md new file mode 100644 index 000000000..00ec0836a --- /dev/null +++ b/nemoguardrails/library/embedding_topic_detector/README.md @@ -0,0 +1,46 @@ +# Embedding Topic Detector + +Embedding-based topic detection for NeMo Guardrails. Blocks off-topic queries using semantic similarity. + +## Quick Start + +```yaml +rails: + config: + embedding_topic_detector: + embedding_model: "sentence-transformers/all-MiniLM-L6-v2" + embedding_engine: "SentenceTransformers" + threshold: 0.5 + top_k: 3 + examples: + coffee: + - "how to brew the perfect cup of coffee" + - "best coffee beans for espresso" + + input: + flows: + - embedding topic check + output: + flows: + - embedding topic check output +``` + +## How It Works + +1. 
Pre-computes embeddings for your example queries (once, on first use; the detector is then cached) +2. Embeds the incoming user message (or the bot message, for the output rail) +3. Compares against examples using cosine similarity +4. Returns `on_topic: true/false` based on threshold + +**On-topic:** "How do I make espresso?" -> similarity 0.85 +**Off-topic:** "Who won the Super Bowl?" -> similarity 0.04 + +## Configuration + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `embedding_model` | Required | Model name (e.g., `all-MiniLM-L6-v2`) | +| `embedding_engine` | Required | Engine (e.g., `SentenceTransformers`) | +| `threshold` | `0.75` | Min similarity to be on-topic (0-1) | +| `top_k` | `3` | Number of most similar examples whose similarities are averaged | +| `examples` | Required | Dict of `{category: [example queries]}` | diff --git a/nemoguardrails/library/embedding_topic_detector/__init__.py b/nemoguardrails/library/embedding_topic_detector/__init__.py new file mode 100644 index 000000000..1c0ccb995 --- /dev/null +++ b/nemoguardrails/library/embedding_topic_detector/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Actions for the embedding topic detector rail.

A message is considered on-topic when its embedding is similar enough
(cosine similarity) to the configured example queries.
"""

import hashlib
import json
import logging
from typing import Dict, List, Optional

import numpy as np

from nemoguardrails.actions import action
from nemoguardrails.embeddings.providers import init_embedding_model

log = logging.getLogger(__name__)

# Detectors already built in this process, keyed by a fingerprint of the
# settings they were built from (see `_check`). Building a detector embeds
# every example, so it is done once per distinct configuration.
_detector_cache: Dict[str, "EmbeddingTopicDetector"] = {}


class EmbeddingTopicDetector:
    """Scores a text against example queries using embedding cosine similarity.

    Attributes:
        threshold: Minimum best-category score for a text to be on-topic.
        top_k: Number of most similar examples (across all categories) that
            contribute to the category scores.
        model: Embedding model returned by `init_embedding_model`.
        embeddings: Mapping of category name to the list of embedding vectors
            of its examples. Categories with no examples are left out.
    """

    def __init__(
        self,
        embedding_model: str,
        embedding_engine: str,
        examples: Dict[str, List[str]],
        threshold: float,
        top_k: int,
    ):
        """Load the embedding model and embed all example queries.

        Args:
            embedding_model: Model name passed to `init_embedding_model`.
            embedding_engine: Engine name passed to `init_embedding_model`.
            examples: Mapping of category name to example queries.
            threshold: Minimum score for a text to count as on-topic.
            top_k: How many of the most similar examples to keep; must be >= 1.

        Raises:
            ValueError: If `top_k` is smaller than 1.
        """
        # A non-positive top_k would make the `[: top_k]` slice in `detect`
        # empty (0) or drop matches from the end (negative), so reject it.
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

        self.threshold = threshold
        self.top_k = top_k
        self.model = init_embedding_model(embedding_model, embedding_engine)
        # NOTE: `encode` is the synchronous API, so the examples are embedded
        # in a blocking call while the detector is being constructed.
        self.embeddings = {
            cat: [np.array(e) for e in self.model.encode(queries)]
            for cat, queries in examples.items()
            if queries
        }

    async def detect(self, query: str) -> Dict:
        """Score `query` against the examples and decide whether it is on-topic.

        The cosine similarity between the query and every example is computed,
        the `top_k` highest similarities overall are kept, and each category is
        scored with the mean of its similarities among those kept. A category
        with no example among the kept ones scores 0.0.

        Args:
            query: The text to classify.

        Returns:
            A dict with the keys:
              - "on_topic": True when the best category score >= `threshold`.
              - "confidence": the best category score (0.0 without categories).
              - "top_category": the category with the best score, or None when
                there are no categories.
              - "category_scores": mapping of every category to its score.
        """
        query_emb = np.array((await self.model.encode_async([query]))[0])
        query_norm = np.linalg.norm(query_emb)

        similarities = [
            (
                cat,
                # `or 1e-10` guards against division by zero for zero vectors.
                float(
                    np.dot(query_emb, emb)
                    / ((query_norm * np.linalg.norm(emb)) or 1e-10)
                ),
            )
            for cat, embs in self.embeddings.items()
            for emb in embs
        ]
        top_matches = sorted(similarities, key=lambda pair: pair[1], reverse=True)[
            : self.top_k
        ]

        # Group the kept similarities by category. Every category gets an
        # entry, so categories absent from the top-k end up with an empty list.
        matched: Dict[str, List[float]] = {cat: [] for cat in self.embeddings}
        for cat, similarity in top_matches:
            matched[cat].append(similarity)

        # The mean of an empty list is NaN, and NaN is truthy, so an
        # `np.mean(...) or 0.0` fallback never fires; a NaN score then makes
        # the max()/threshold comparison below unreliable. Test for emptiness
        # explicitly instead.
        scores = {
            cat: float(np.mean(values)) if values else 0.0
            for cat, values in matched.items()
        }
        max_score = max(scores.values(), default=0.0)

        return {
            "on_topic": max_score >= self.threshold,
            "confidence": max_score,
            "top_category": max(scores, key=scores.get, default=None),
            "category_scores": scores,
        }


async def _check(context: Optional[dict], llm_task_manager, message_key: str) -> dict:
    """Run topic detection on one message taken from the context.

    Args:
        context: Action context; the text is read from `context[message_key]`.
        llm_task_manager: Object whose
            `config.rails.config.embedding_topic_detector` holds the detector
            settings ("embedding_model", "embedding_engine", "examples" and,
            optionally, "threshold" and "top_k").
        message_key: Context key of the message to check.

    Returns:
        The result of `EmbeddingTopicDetector.detect`, or an on-topic result
        with confidence 0.0 when there is no context or no message to check.
    """
    config = llm_task_manager.config.rails.config.embedding_topic_detector
    threshold = config.get("threshold", 0.75)
    top_k = config.get("top_k", 3)

    # The examples can be large, so only their hash goes into the cache key.
    examples_hash = hashlib.sha256(
        json.dumps(config["examples"], sort_keys=True).encode("utf-8")
    ).hexdigest()
    cache_key = f"{config['embedding_model']}_{config['embedding_engine']}_{threshold}_{top_k}_{examples_hash}"

    if cache_key not in _detector_cache:
        _detector_cache[cache_key] = EmbeddingTopicDetector(
            config["embedding_model"],
            config["embedding_engine"],
            config["examples"],
            threshold,
            top_k,
        )

    query = context.get(message_key) if context else None
    if not query:
        # Nothing to classify: let the message through.
        return {
            "on_topic": True,
            "confidence": 0.0,
            "top_category": None,
            "category_scores": {},
        }

    return await _detector_cache[cache_key].detect(query)


@action(is_system_action=True)
async def embedding_topic_check(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Check the "user_message" entry of the context against the topics."""
    return await _check(context, llm_task_manager, "user_message")


@action(is_system_action=True)
async def embedding_topic_check_output(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Check the "bot_message" entry of the context against the topics."""
    return await _check(context, llm_task_manager, "bot_message")
# ---------------------------------------------------------------------------
# Colang 2.x flows (nemoguardrails/library/embedding_topic_detector/flows.co)
# ---------------------------------------------------------------------------
# NOTE(review): indentation was not recoverable from this view; `abort` is
# placed at the level of the outer `if` (runs for both branches) - confirm.

# Input rail: runs EmbeddingTopicCheckAction and, when the result is not
# on-topic, either sends OffTopicRailException (if enable_rails_exceptions is
# set) or has the bot refuse to respond, then aborts.
flow embedding topic check
  $result = await EmbeddingTopicCheckAction

  if not $result.on_topic
    if $system.config.enable_rails_exceptions
      send OffTopicRailException(message="Off-topic input blocked.")
    else
      bot refuse to respond
    abort

# Output rail: same logic as above, using EmbeddingTopicCheckOutputAction and
# OffTopicOutputRailException.
flow embedding topic check output
  $result = await EmbeddingTopicCheckOutputAction

  if not $result.on_topic
    if $system.config.enable_rails_exceptions
      send OffTopicOutputRailException(message="Off-topic output blocked.")
    else
      bot refuse to respond
    abort

# ---------------------------------------------------------------------------
# Colang 1.0 flows (nemoguardrails/library/embedding_topic_detector/flows.v1.co)
# ---------------------------------------------------------------------------
# NOTE(review): as above, the nesting of `stop` is reconstructed - confirm.

# Message the bot uses when it refuses to respond.
define bot refuse to respond
  "I'm sorry, I can't respond to that."

# Input rail: executes embedding_topic_check and, when the result is not
# on-topic, either creates an OffTopicRailException event (if
# enable_rails_exceptions is set) or has the bot refuse to respond, then stops.
define flow embedding topic check
  $result = execute embedding_topic_check

  if not $result["on_topic"]
    if $config.enable_rails_exceptions
      create event OffTopicRailException(message="Off-topic input blocked.")
    else
      bot refuse to respond
    stop

# Output rail: same logic as above, using embedding_topic_check_output and
# OffTopicOutputRailException.
define flow embedding topic check output
  $result = execute embedding_topic_check_output

  if not $result["on_topic"]
    if $config.enable_rails_exceptions
      create event OffTopicOutputRailException(message="Off-topic output blocked.")
    else
      bot refuse to respond
    stop

# ---------------------------------------------------------------------------
# NOTE(review): the same patch also adds one line to class RailsConfigData in
# nemoguardrails/rails/llm/config.py (class body not visible here):
#     model_config = ConfigDict(extra="allow")
# Presumably this is what allows an undeclared `embedding_topic_detector`
# entry under `rails.config` - confirm, and consider a declared field instead.
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from types import SimpleNamespace

from nemoguardrails import RailsConfig
from tests.utils import TestChat

# Rails configuration with the detector wired in as an input rail.
YAML_CONFIG = """
models:
  - type: main
    engine: fake
    model: test

rails:
  config:
    embedding_topic_detector:
      embedding_model: "BAAI/bge-small-en-v1.5"
      embedding_engine: "FastEmbed"
      threshold: 0.5
      top_k: 3
      examples:
        coffee:
          - "how to brew v60"
          - "best light-roast espresso beans"
          - "is soup an espresso type?"

  input:
    flows:
      - embedding topic check
"""

# Refusal message the rails fall back to when a message is blocked.
COLANG_CONFIG = """
define bot refuse to respond
  "I'm sorry, I can't respond to that."
"""


def test_off_topic_blocked():
    """An off-topic user message is answered with the refusal message."""
    rails_config = RailsConfig.from_content(
        yaml_content=YAML_CONFIG, colang_content=COLANG_CONFIG
    )
    session = TestChat(rails_config, llm_completions=[])

    session >> "Who won the Super Bowl?"
    session << "I'm sorry, I can't respond to that."


def test_detector_logic():
    """The detector separates an on-topic query from an off-topic one."""
    from nemoguardrails.library.embedding_topic_detector.actions import (
        EmbeddingTopicDetector,
    )

    coffee_examples = {"coffee": ["how to brew coffee", "best espresso beans"]}
    detector = EmbeddingTopicDetector(
        embedding_model="BAAI/bge-small-en-v1.5",
        embedding_engine="FastEmbed",
        examples=coffee_examples,
        threshold=0.5,
        top_k=3,
    )

    # A coffee question must land in the "coffee" category above the threshold.
    espresso_result = asyncio.run(detector.detect("How do I make espresso?"))
    assert espresso_result["on_topic"] is True
    assert espresso_result["confidence"] > 0.5
    assert espresso_result["top_category"] == "coffee"

    # A sports question must stay below the threshold.
    football_result = asyncio.run(detector.detect("Who won the Super Bowl?"))
    assert football_result["on_topic"] is False
    assert football_result["confidence"] < 0.5


def test_empty_query_handling():
    """A missing context or missing message yields the permissive default."""
    from nemoguardrails.library.embedding_topic_detector.actions import _check

    detector_settings = {
        "embedding_model": "BAAI/bge-small-en-v1.5",
        "embedding_engine": "FastEmbed",
        "examples": {"coffee": ["espresso"]},
        "threshold": 0.5,
        "top_k": 3,
    }
    # Minimal stand-in exposing only config.rails.config.embedding_topic_detector.
    rails_config_data = SimpleNamespace(embedding_topic_detector=detector_settings)
    rails = SimpleNamespace(config=rails_config_data)
    llm_task_manager = SimpleNamespace(config=SimpleNamespace(rails=rails))

    # No context at all.
    no_context_result = asyncio.run(_check(None, llm_task_manager, "user_message"))
    expected_default = {
        "on_topic": True,
        "confidence": 0.0,
        "top_category": None,
        "category_scores": {},
    }
    assert no_context_result == expected_default

    # A context that does not contain the message key.
    no_message_result = asyncio.run(_check({}, llm_task_manager, "user_message"))
    assert no_message_result["on_topic"] is True
    assert no_message_result["confidence"] == 0.0


def test_output_rail():
    """An off-topic bot message is replaced with the refusal message."""
    yaml_with_output = """
models:
  - type: main
    engine: fake
    model: test

rails:
  config:
    embedding_topic_detector:
      embedding_model: "BAAI/bge-small-en-v1.5"
      embedding_engine: "FastEmbed"
      threshold: 0.5
      top_k: 3
      examples:
        coffee:
          - "how to brew coffee"
          - "espresso tips"

  output:
    flows:
      - embedding topic check output
"""

    rails_config = RailsConfig.from_content(
        yaml_content=yaml_with_output, colang_content=COLANG_CONFIG
    )
    # The fake LLM answers with an off-topic sentence that the rail must block.
    off_topic_completion = "Who won the Super Bowl yesterday?"
    session = TestChat(rails_config, llm_completions=[off_topic_completion])

    session >> "Hello"
    session << "I'm sorry, I can't respond to that."