Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions nemoguardrails/library/embedding_topic_detector/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Embedding Topic Detector

Embedding-based topic detection for NeMo Guardrails. Blocks off-topic queries using semantic similarity.

## Quick Start

```yaml
rails:
  config:
    embedding_topic_detector:
      embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
      embedding_engine: "SentenceTransformers"
      threshold: 0.5
      top_k: 3
      examples:
        coffee:
          - "how to brew the perfect cup of coffee"
          - "best coffee beans for espresso"

  input:
    flows:
      - embedding topic check
  output:
    flows:
      - embedding topic check output
```

## How It Works

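1. Pre-computes embeddings for your example queries (once, on the first check; the detector is cached afterwards)
2. Embeds the incoming user message (input rail) or bot message (output rail)
3. Compares it against every example using cosine similarity and keeps the `top_k` best matches
4. Averages the kept similarities per category and returns `on_topic: true` when the best category score reaches `threshold`, otherwise `on_topic: false`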

**On-topic:** "How do I make espresso?" -> similarity 0.85
**Off-topic:** "Who won the Super Bowl?" -> similarity 0.04

## Configuration

| Parameter | Default | Description |
|-----------|---------|-------------|
| `embedding_model` | Required | Model name (e.g., `all-MiniLM-L6-v2`) |
| `embedding_engine` | Required | Engine (e.g., `SentenceTransformers`) |
| `threshold` | `0.75` | Min similarity to be on-topic (0-1) |
| `top_k` | `3` | Number of most similar examples (across all categories) kept; their similarities are averaged per category |
| `examples` | Required | Dict of `{category: [example queries]}` |
16 changes: 16 additions & 0 deletions nemoguardrails/library/embedding_topic_detector/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import actions
119 changes: 119 additions & 0 deletions nemoguardrails/library/embedding_topic_detector/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import logging
from typing import Dict, List, Optional

import numpy as np

from nemoguardrails.actions import action
from nemoguardrails.embeddings.providers import init_embedding_model

log = logging.getLogger(__name__)

_detector_cache: Dict[str, "EmbeddingTopicDetector"] = {}


class EmbeddingTopicDetector:
    """Classify text as on- or off-topic by embedding similarity to examples.

    Example queries are embedded once in the constructor. ``detect`` embeds
    the incoming text, ranks every example by cosine similarity, keeps the
    ``top_k`` best matches overall and averages the kept similarities per
    category. The text is on-topic when the best category score reaches
    ``threshold``.
    """

    def __init__(
        self,
        embedding_model: str,
        embedding_engine: str,
        examples: Dict[str, List[str]],
        threshold: float,
        top_k: int,
    ):
        """Load the embedding model and pre-compute the example embeddings.

        Args:
            embedding_model: Model name passed to ``init_embedding_model``.
            embedding_engine: Engine name passed to ``init_embedding_model``.
            examples: Mapping of category name to example queries.
            threshold: Minimum category score for a text to be on-topic.
            top_k: Number of best-matching examples kept for scoring.
        """
        self.threshold = threshold
        self.top_k = top_k
        self.model = init_embedding_model(embedding_model, embedding_engine)
        # Categories with an empty example list are dropped entirely, so they
        # never show up in ``category_scores``.
        self.embeddings = {
            cat: [np.array(e) for e in self.model.encode(queries)]
            for cat, queries in examples.items()
            if queries
        }

    @staticmethod
    def _cosine(a, b) -> float:
        """Return the cosine similarity of two vectors (0.0 for a zero vector)."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        # A zero-norm vector would divide by zero; fall back to a tiny
        # denominator so the similarity degrades to 0 instead of NaN.
        return float(np.dot(a, b) / (denom or 1e-10))

    async def detect(self, query: str) -> Dict:
        """Score ``query`` against the examples and decide whether it is on-topic.

        Args:
            query: Text to classify.

        Returns:
            A dict with ``on_topic`` (bool), ``confidence`` (best category
            score), ``top_category`` (name of the best category, or ``None``
            when there are no categories) and ``category_scores`` (score per
            category).
        """
        query_emb = np.array((await self.model.encode_async([query]))[0])

        # Rank every (category, similarity) pair and keep the global top-k.
        ranked = sorted(
            (
                (cat, self._cosine(query_emb, emb))
                for cat, embs in self.embeddings.items()
                for emb in embs
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )[: self.top_k]

        hits: Dict[str, List[float]] = {cat: [] for cat in self.embeddings}
        for cat, sim in ranked:
            hits[cat].append(sim)

        # A category with no example in the top-k scores 0.0. Taking the mean
        # of an empty list would yield NaN, which poisons the max() below and
        # can flip an on-topic query to off-topic.
        scores = {
            cat: float(np.mean(sims)) if sims else 0.0
            for cat, sims in hits.items()
        }
        max_score = max(scores.values(), default=0.0)

        return {
            "on_topic": max_score >= self.threshold,
            "confidence": max_score,
            "top_category": max(scores, key=scores.get, default=None),
            "category_scores": scores,
        }


async def _check(context: Optional[dict], llm_task_manager, message_key: str) -> dict:
config = llm_task_manager.config.rails.config.embedding_topic_detector
examples_hash = hashlib.sha256(
json.dumps(config["examples"], sort_keys=True).encode("utf-8")
).hexdigest()
cache_key = f"{config['embedding_model']}_{config['embedding_engine']}_{config.get('threshold', 0.75)}_{config.get('top_k', 3)}_{examples_hash}"

if cache_key not in _detector_cache:
_detector_cache[cache_key] = EmbeddingTopicDetector(
config["embedding_model"],
config["embedding_engine"],
config["examples"],
config.get("threshold", 0.75),
config.get("top_k", 3),
)

query = context.get(message_key) if context else None
if not query:
return {
"on_topic": True,
"confidence": 0.0,
"top_category": None,
"category_scores": {},
}

return await _detector_cache[cache_key].detect(query)


@action(is_system_action=True)
async def embedding_topic_check(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Topic-check the user message stored under ``user_message`` in the context."""
    verdict = await _check(context, llm_task_manager, message_key="user_message")
    return verdict


@action(is_system_action=True)
async def embedding_topic_check_output(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Topic-check the bot message stored under ``bot_message`` in the context."""
    verdict = await _check(context, llm_task_manager, message_key="bot_message")
    return verdict
19 changes: 19 additions & 0 deletions nemoguardrails/library/embedding_topic_detector/flows.co
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Input rail: runs EmbeddingTopicCheckAction and stores its result in $result.
# When the result is not on-topic, the flow either sends an
# OffTopicRailException event (if enable_rails_exceptions is set in the
# config) or has the bot refuse to respond, and then aborts.
flow embedding topic check
$result = await EmbeddingTopicCheckAction

if not $result.on_topic
if $system.config.enable_rails_exceptions
send OffTopicRailException(message="Off-topic input blocked.")
else
bot refuse to respond
abort

# Output rail: runs EmbeddingTopicCheckOutputAction and stores its result in
# $result. When the result is not on-topic, the flow either sends an
# OffTopicOutputRailException event (if enable_rails_exceptions is set in the
# config) or has the bot refuse to respond, and then aborts.
flow embedding topic check output
$result = await EmbeddingTopicCheckOutputAction

if not $result.on_topic
if $system.config.enable_rails_exceptions
send OffTopicOutputRailException(message="Off-topic output blocked.")
else
bot refuse to respond
abort
22 changes: 22 additions & 0 deletions nemoguardrails/library/embedding_topic_detector/flows.v1.co
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Canned bot message emitted by `bot refuse to respond`.
define bot refuse to respond
"I'm sorry, I can't respond to that."

# Input rail (Colang 1.0 syntax): executes the embedding_topic_check action and
# stores its result in $result. When result["on_topic"] is falsy, the flow
# either creates an OffTopicRailException event (if enable_rails_exceptions is
# set in the config) or has the bot refuse to respond, and then stops.
define flow embedding topic check
$result = execute embedding_topic_check

if not $result["on_topic"]
if $config.enable_rails_exceptions
create event OffTopicRailException(message="Off-topic input blocked.")
else
bot refuse to respond
stop

# Output rail (Colang 1.0 syntax): executes the embedding_topic_check_output
# action and stores its result in $result. When result["on_topic"] is falsy,
# the flow either creates an OffTopicOutputRailException event (if
# enable_rails_exceptions is set in the config) or has the bot refuse to
# respond, and then stops.
define flow embedding topic check output
$result = execute embedding_topic_check_output

if not $result["on_topic"]
if $config.enable_rails_exceptions
create event OffTopicOutputRailException(message="Off-topic output blocked.")
else
bot refuse to respond
stop
2 changes: 2 additions & 0 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,8 @@ class RailsConfigData(BaseModel):
description="Configuration for Cisco AI Defense.",
)

model_config = ConfigDict(extra="allow")


class Rails(BaseModel):
"""Configuration of specific rails."""
Expand Down
153 changes: 153 additions & 0 deletions tests/test_embedding_topic_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from types import SimpleNamespace

from nemoguardrails import RailsConfig
from tests.utils import TestChat

# Rails configuration used by the input-rail test: a fake main model, the
# embedding topic detector configured with FastEmbed and three "coffee"
# examples, and the "embedding topic check" flow registered as an input flow.
YAML_CONFIG = """
models:
- type: main
engine: fake
model: test

rails:
config:
embedding_topic_detector:
embedding_model: "BAAI/bge-small-en-v1.5"
embedding_engine: "FastEmbed"
threshold: 0.5
top_k: 3
examples:
coffee:
- "how to brew v60"
- "best light-roast espresso beans"
- "is soup an espresso type?"

input:
flows:
- embedding topic check
"""

# Colang content defining the refusal message that the tests expect when a
# message is blocked.
COLANG_CONFIG = """
define bot refuse to respond
"I'm sorry, I can't respond to that."
"""


def test_off_topic_blocked():
    """An unrelated question must be answered with the refusal message."""
    rails_config = RailsConfig.from_content(
        yaml_content=YAML_CONFIG,
        colang_content=COLANG_CONFIG,
    )
    session = TestChat(rails_config, llm_completions=[])

    # No LLM completions are supplied, so the expected reply is the refusal
    # text defined in COLANG_CONFIG.
    session >> "Who won the Super Bowl?"
    session << "I'm sorry, I can't respond to that."


def test_detector_logic():
    """Exercise EmbeddingTopicDetector directly with one on-topic and one off-topic query."""
    from nemoguardrails.library.embedding_topic_detector.actions import (
        EmbeddingTopicDetector,
    )

    coffee_examples = {"coffee": ["how to brew coffee", "best espresso beans"]}
    detector = EmbeddingTopicDetector(
        embedding_model="BAAI/bge-small-en-v1.5",
        embedding_engine="FastEmbed",
        examples=coffee_examples,
        threshold=0.5,
        top_k=3,
    )

    # A coffee question should clear the threshold and be assigned to "coffee".
    espresso_verdict = asyncio.run(detector.detect("How do I make espresso?"))
    assert espresso_verdict["on_topic"] is True
    assert espresso_verdict["confidence"] > 0.5
    assert espresso_verdict["top_category"] == "coffee"

    # A sports question should stay below the threshold.
    sports_verdict = asyncio.run(detector.detect("Who won the Super Bowl?"))
    assert sports_verdict["on_topic"] is False
    assert sports_verdict["confidence"] < 0.5


def test_empty_query_handling():
    """A missing context or missing message yields the permissive default result."""
    from nemoguardrails.library.embedding_topic_detector.actions import _check

    detector_settings = {
        "embedding_model": "BAAI/bge-small-en-v1.5",
        "embedding_engine": "FastEmbed",
        "examples": {"coffee": ["espresso"]},
        "threshold": 0.5,
        "top_k": 3,
    }
    # Stand-in exposing only the attribute chain that _check reads:
    # llm_task_manager.config.rails.config.embedding_topic_detector
    rails_config_data = SimpleNamespace(embedding_topic_detector=detector_settings)
    rails = SimpleNamespace(config=rails_config_data)
    llm_task_manager = SimpleNamespace(config=SimpleNamespace(rails=rails))

    # No context at all.
    no_context_result = asyncio.run(_check(None, llm_task_manager, "user_message"))
    assert no_context_result == {
        "on_topic": True,
        "confidence": 0.0,
        "top_category": None,
        "category_scores": {},
    }

    # Empty context dict, i.e. the message key is absent.
    empty_context_result = asyncio.run(_check({}, llm_task_manager, "user_message"))
    assert empty_context_result["on_topic"] is True
    assert empty_context_result["confidence"] == 0.0


def test_output_rail():
    """An off-topic bot completion must be replaced by the refusal message."""
    output_rail_yaml = """
models:
- type: main
engine: fake
model: test

rails:
config:
embedding_topic_detector:
embedding_model: "BAAI/bge-small-en-v1.5"
embedding_engine: "FastEmbed"
threshold: 0.5
top_k: 3
examples:
coffee:
- "how to brew coffee"
- "espresso tips"

output:
flows:
- embedding topic check output
"""

    rails_config = RailsConfig.from_content(
        yaml_content=output_rail_yaml,
        colang_content=COLANG_CONFIG,
    )
    # The fake LLM answers with an unrelated sentence, which the output rail
    # is expected to block.
    session = TestChat(
        rails_config, llm_completions=["Who won the Super Bowl yesterday?"]
    )

    session >> "Hello"
    session << "I'm sorry, I can't respond to that."