Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT;
import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES;
import static datadog.trace.util.Strings.isBlank;
import static java.util.Collections.singletonMap;

import com.squareup.moshi.JsonAdapter;
import com.squareup.moshi.JsonReader;
Expand Down Expand Up @@ -69,7 +68,8 @@ public BadConfigurationException(final String message) {
static final String REASON_TAG = "ai_guard.reason";
static final String BLOCKED_TAG = "ai_guard.blocked";
static final String META_STRUCT_TAG = "ai_guard";
static final String META_STRUCT_KEY = "messages";
static final String META_STRUCT_MESSAGES = "messages";
static final String META_STRUCT_CATEGORIES = "attack_categories";

public static void install() {
final Config config = Config.get();
Expand Down Expand Up @@ -208,8 +208,8 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
} else {
span.setTag(TARGET_TAG, "prompt");
}
final Map<String, Object> metaStruct =
singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages));
final Map<String, Object> metaStruct = new HashMap<>(2);
metaStruct.put(META_STRUCT_MESSAGES, messagesForMetaStruct(messages));
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
final Request.Builder request =
new Request.Builder()
Expand All @@ -224,8 +224,15 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
}
final Action action = Action.valueOf(actionStr);
final String reason = (String) result.get("reason");
@SuppressWarnings("unchecked")
final List<String> tags = (List<String>) result.get("tags");
span.setTag(ACTION_TAG, action);
span.setTag(REASON_TAG, reason);
if (reason != null) {
span.setTag(REASON_TAG, reason);
}
if (tags != null && !tags.isEmpty()) {
metaStruct.put(META_STRUCT_CATEGORIES, tags);
}
final boolean shouldBlock =
isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW;
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,14 @@ class AIGuardInternalTests extends DDSpecification {
Request request = null
Throwable error = null
AIGuard.Evaluation eval = null
Map<String, Object> receivedMeta = null
final throwAbortError = suite.blocking && suite.action != ALLOW
final call = Mock(Call) {
execute() >> {
return mockResponse(
request,
200,
[data: [attributes: [action: suite.action, reason: suite.reason, is_blocking_enabled: suite.blocking]]]
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]]
)
}
}
Expand All @@ -189,11 +190,18 @@ class AIGuardInternalTests extends DDSpecification {
}
1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action)
1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason)
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, [messages: suite.messages])
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> {
receivedMeta = it[1] as Map<String, Object>
return span
}
if (throwAbortError) {
1 * span.addThrowable(_ as AIGuard.AIGuardAbortError)
}

receivedMeta.messages == suite.messages
if (suite.tags) {
receivedMeta.attack_categories == suite.tags
}
assertRequest(request, suite.messages)
if (throwAbortError) {
error instanceof AIGuard.AIGuardAbortError
Expand Down Expand Up @@ -497,22 +505,28 @@ class AIGuardInternalTests extends DDSpecification {
private static class TestSuite {
private final AIGuard.Action action
private final String reason
private final List<String> tags
private final boolean blocking
private final String description
private final String target
private final List<AIGuard.Message> messages

TestSuite(AIGuard.Action action, String reason, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
TestSuite(AIGuard.Action action, String reason, List<String> tags, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
this.action = action
this.reason = reason
this.tags = tags
this.blocking = blocking
this.description = description
this.target = target
this.messages = messages
}

static List<TestSuite> build() {
def actionValues = [[ALLOW, 'Go ahead'], [DENY, 'Nope'], [ABORT, 'Kill it with fire']]
def actionValues = [
[ALLOW, 'Go ahead', []],
[DENY, 'Nope', ['deny_everything', 'test_deny']],
[ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']]
]
def blockingValues = [true, false]
def suiteValues = [
['tool call', 'tool', TOOL_CALL],
Expand All @@ -521,7 +535,7 @@ class AIGuardInternalTests extends DDSpecification {
]
return combinations([actionValues, blockingValues, suiteValues] as Iterable)
.collect { action, blocking, suite ->
new TestSuite(action[0], action[1], blocking, suite[0], suite[1], suite[2])
new TestSuite(action[0], action[1], action[2], blocking, suite[0], suite[1], suite[2])
}
}

Expand Down