Skip to content

Commit 51b1f52

Browse files
committed
✨ added a topic detector rail based on embeddings models
Signed-off-by: mmisiura <mmisiura@redhat.com>
1 parent afb1bf0 commit 51b1f52

File tree

7 files changed

+374
-0
lines changed

7 files changed

+374
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Embedding Topic Detector
2+
3+
Embedding-based topic detection for NeMo Guardrails. Blocks off-topic queries using semantic similarity.
4+
5+
## Quick Start
6+
7+
```yaml
8+
rails:
9+
config:
10+
embedding_topic_detector:
11+
embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
12+
embedding_engine: "SentenceTransformers"
13+
threshold: 0.5
14+
top_k: 3
15+
examples:
16+
coffee:
17+
- "how to brew the perfect cup of coffee"
18+
- "best coffee beans for espresso"
19+
20+
input:
21+
flows:
22+
- embedding topic check
23+
output:
24+
flows:
25+
- embedding topic check output
26+
```
27+
28+
## How It Works
29+
30+
1. Pre-computes embeddings for your example queries (once at startup)
31+
2. Embeds the incoming message (the user query for the input rail, the bot response for the output rail)
32+
3. Compares against examples using cosine similarity
33+
4. Returns `on_topic: true/false` based on threshold
34+
35+
**On-topic:** "How do I make espresso?" -> similarity 0.85
36+
**Off-topic:** "Who won the Super Bowl?" -> similarity 0.04
37+
38+
## Configuration
39+
40+
| Parameter | Default | Description |
41+
|-----------|---------|-------------|
42+
| `embedding_model` | Required | Model name (e.g., `all-MiniLM-L6-v2`) |
43+
| `embedding_engine` | Required | Engine (e.g., `SentenceTransformers`) |
44+
| `threshold` | `0.75` | Min similarity to be on-topic (0-1) |
45+
| `top_k` | `3` | Number of most similar examples kept (across all categories); their similarities are averaged per category |
46+
| `examples` | Required | Dict of `{category: [example queries]}` |
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from . import actions
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
from typing import Dict, List, Optional
18+
19+
import numpy as np
20+
21+
from nemoguardrails.actions import action
22+
from nemoguardrails.embeddings.providers import init_embedding_model
23+
24+
log = logging.getLogger(__name__)
25+
26+
_detector_cache: Dict[str, "EmbeddingTopicDetector"] = {}
27+
28+
29+
class EmbeddingTopicDetector:
    """Decide whether a query is on-topic by cosine similarity to example queries.

    Example queries are embedded once in the constructor. ``detect`` embeds the
    incoming query, keeps the ``top_k`` most similar examples across all
    categories, averages the kept similarities per category and compares the
    best category score against ``threshold``.
    """

    def __init__(
        self,
        embedding_model: str,
        embedding_engine: str,
        examples: Dict[str, List[str]],
        threshold: float,
        top_k: int,
    ):
        """Load the embedding model and pre-compute the example embeddings.

        Args:
            embedding_model: Model name passed to ``init_embedding_model``.
            embedding_engine: Engine name passed to ``init_embedding_model``.
            examples: Mapping of category name to its example queries.
            threshold: Minimum best-category score for a query to be on-topic.
            top_k: Number of most similar examples kept when scoring.
        """
        self.threshold = threshold
        self.top_k = top_k
        self.model = init_embedding_model(embedding_model, embedding_engine)
        # Categories with an empty/missing example list are dropped entirely.
        self.embeddings = {
            cat: [np.array(e) for e in self.model.encode(queries)]
            for cat, queries in examples.items()
            if queries
        }

    async def detect(self, query: str) -> Dict:
        """Score ``query`` against the examples and report whether it is on-topic.

        Args:
            query: The text to classify.

        Returns:
            A dict with ``on_topic`` (best score >= threshold), ``confidence``
            (the best category score), ``top_category`` (the category with the
            best score, or ``None`` when there are no categories) and
            ``category_scores`` (score per category; a category with no example
            among the kept top-K scores ``0.0``).
        """
        query_emb = np.array((await self.model.encode_async([query]))[0])
        # The query norm is the same for every example, so compute it once.
        query_norm = np.linalg.norm(query_emb)

        sims = [
            (
                cat,
                float(
                    np.dot(query_emb, emb)
                    # ``or 1`` avoids a division by zero for zero-length vectors.
                    / (query_norm * np.linalg.norm(emb) or 1)
                ),
            )
            for cat, embs in self.embeddings.items()
            for emb in embs
        ]
        top_sims = sorted(sims, key=lambda x: x[1], reverse=True)[: self.top_k]

        kept: Dict[str, List[float]] = {cat: [] for cat in self.embeddings}
        for cat, sim in top_sims:
            kept[cat].append(sim)

        # np.mean([]) is NaN (and NaN is truthy, so ``or 0.0`` does not help);
        # a category with nothing in the top-K must score exactly 0.0.
        scores = {
            cat: float(np.mean(values)) if values else 0.0
            for cat, values in kept.items()
        }
        max_score = max(scores.values(), default=0.0)

        return {
            "on_topic": max_score >= self.threshold,
            "confidence": max_score,
            "top_category": max(scores, key=scores.get, default=None),
            "category_scores": scores,
        }
78+
79+
80+
async def _check(context: Optional[dict], llm_task_manager, message_key: str) -> dict:
81+
config = llm_task_manager.config.rails.config.embedding_topic_detector
82+
cache_key = f"{config['embedding_model']}_{config['embedding_engine']}_{config.get('threshold', 0.75)}"
83+
84+
if cache_key not in _detector_cache:
85+
_detector_cache[cache_key] = EmbeddingTopicDetector(
86+
config["embedding_model"],
87+
config["embedding_engine"],
88+
config["examples"],
89+
config.get("threshold", 0.75),
90+
config.get("top_k", 3),
91+
)
92+
93+
query = context.get(message_key) if context else None
94+
if not query:
95+
return {
96+
"on_topic": True,
97+
"confidence": 0.0,
98+
"top_category": None,
99+
"category_scores": {},
100+
}
101+
102+
return await _detector_cache[cache_key].detect(query)
103+
104+
105+
@action(is_system_action=True)
async def embedding_topic_check(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Input rail action: topic-check the ``user_message`` entry of the context."""
    result = await _check(context, llm_task_manager, message_key="user_message")
    return result
110+
111+
112+
@action(is_system_action=True)
async def embedding_topic_check_output(
    context: Optional[dict] = None, llm_task_manager=None
) -> dict:
    """Output rail action: topic-check the ``bot_message`` entry of the context."""
    result = await _check(context, llm_task_manager, message_key="bot_message")
    return result
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
flow embedding topic check
2+
$result = await EmbeddingTopicCheckAction
3+
4+
if not $result.on_topic
5+
if $system.config.enable_rails_exceptions
6+
send OffTopicRailException(message="Off-topic input blocked.")
7+
else
8+
bot refuse to respond
9+
abort
10+
11+
flow embedding topic check output
12+
$result = await EmbeddingTopicCheckOutputAction
13+
14+
if not $result.on_topic
15+
if $system.config.enable_rails_exceptions
16+
send OffTopicOutputRailException(message="Off-topic output blocked.")
17+
else
18+
bot refuse to respond
19+
abort
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
define bot refuse to respond
2+
"I'm sorry, I can't respond to that."
3+
4+
define flow embedding topic check
5+
$result = execute embedding_topic_check
6+
7+
if not $result["on_topic"]
8+
if $config.enable_rails_exceptions
9+
create event OffTopicRailException(message="Off-topic input blocked.")
10+
else
11+
bot refuse to respond
12+
stop
13+
14+
define flow embedding topic check output
15+
$result = execute embedding_topic_check_output
16+
17+
if not $result["on_topic"]
18+
if $config.enable_rails_exceptions
19+
create event OffTopicOutputRailException(message="Off-topic output blocked.")
20+
else
21+
bot refuse to respond
22+
stop

nemoguardrails/rails/llm/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@ class RailsConfigData(BaseModel):
974974
description="Configuration for Cisco AI Defense.",
975975
)
976976

977+
model_config = ConfigDict(extra="allow")
978+
977979

978980
class Rails(BaseModel):
979981
"""Configuration of specific rails."""
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
from types import SimpleNamespace
18+
19+
from nemoguardrails import RailsConfig
20+
from tests.utils import TestChat
21+
22+
YAML_CONFIG = """
23+
models:
24+
- type: main
25+
engine: fake
26+
model: test
27+
28+
rails:
29+
config:
30+
embedding_topic_detector:
31+
embedding_model: "BAAI/bge-small-en-v1.5"
32+
embedding_engine: "FastEmbed"
33+
threshold: 0.5
34+
top_k: 3
35+
examples:
36+
coffee:
37+
- "how to brew v60"
38+
- "best light-roast espresso beans"
39+
- "is soup an espresso type?"
40+
41+
input:
42+
flows:
43+
- embedding topic check
44+
"""
45+
46+
COLANG_CONFIG = """
47+
define bot refuse to respond
48+
"I'm sorry, I can't respond to that."
49+
"""
50+
51+
52+
def test_off_topic_blocked():
    """An off-topic user message is answered with the refusal message."""
    rails_config = RailsConfig.from_content(
        yaml_content=YAML_CONFIG,
        colang_content=COLANG_CONFIG,
    )
    # No LLM completions are supplied for this conversation.
    chat = TestChat(rails_config, llm_completions=[])

    chat >> "Who won the Super Bowl?"
    chat << "I'm sorry, I can't respond to that."
61+
62+
63+
def test_detector_logic():
    """Exercise the detector directly on one on-topic and one off-topic query."""
    from nemoguardrails.library.embedding_topic_detector.actions import (
        EmbeddingTopicDetector,
    )

    coffee_examples = {"coffee": ["how to brew coffee", "best espresso beans"]}
    topic_detector = EmbeddingTopicDetector(
        "BAAI/bge-small-en-v1.5",
        "FastEmbed",
        coffee_examples,
        0.5,
        3,
    )

    # A coffee question must be accepted and attributed to the only category.
    coffee_result = asyncio.run(topic_detector.detect("How do I make espresso?"))
    assert coffee_result["on_topic"] is True
    assert coffee_result["confidence"] > 0.5
    assert coffee_result["top_category"] == "coffee"

    # An unrelated question must fall below the threshold.
    sports_result = asyncio.run(topic_detector.detect("Who won the Super Bowl?"))
    assert sports_result["on_topic"] is False
    assert sports_result["confidence"] < 0.5
85+
86+
87+
def test_empty_query_handling():
    """A missing context or missing message yields a neutral on-topic result."""
    from nemoguardrails.library.embedding_topic_detector.actions import _check

    # Minimal stand-in for the attribute chain
    # llm_task_manager.config.rails.config.embedding_topic_detector.
    detector_settings = {
        "embedding_model": "BAAI/bge-small-en-v1.5",
        "embedding_engine": "FastEmbed",
        "examples": {"coffee": ["espresso"]},
        "threshold": 0.5,
        "top_k": 3,
    }
    rails_data = SimpleNamespace(embedding_topic_detector=detector_settings)
    rails = SimpleNamespace(config=rails_data)
    llm_task_manager = SimpleNamespace(config=SimpleNamespace(rails=rails))

    # Test with None context
    none_result = asyncio.run(_check(None, llm_task_manager, "user_message"))
    assert none_result == {
        "on_topic": True,
        "confidence": 0.0,
        "top_category": None,
        "category_scores": {},
    }

    # Test with empty message in context
    empty_result = asyncio.run(_check({}, llm_task_manager, "user_message"))
    assert empty_result["on_topic"] is True
    assert empty_result["confidence"] == 0.0
120+
121+
122+
def test_output_rail():
    """An off-topic bot message is replaced by the refusal via the output rail."""
    # Same detector settings as the input-rail tests, wired to the output rail.
    yaml_with_output = """
models:
  - type: main
    engine: fake
    model: test

rails:
  config:
    embedding_topic_detector:
      embedding_model: "BAAI/bge-small-en-v1.5"
      embedding_engine: "FastEmbed"
      threshold: 0.5
      top_k: 3
      examples:
        coffee:
          - "how to brew coffee"
          - "espresso tips"

  output:
    flows:
      - embedding topic check output
"""

    rails_config = RailsConfig.from_content(
        yaml_content=yaml_with_output,
        colang_content=COLANG_CONFIG,
    )
    # The fake LLM answers with an off-topic sentence that the rail must block.
    off_topic_reply = "Who won the Super Bowl yesterday?"
    chat = TestChat(rails_config, llm_completions=[off_topic_reply])

    chat >> "Hello"
    chat << "I'm sorry, I can't respond to that."

0 commit comments

Comments
 (0)