diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java index fc2d2dafb64..4c7604444b7 100644 --- a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -4,7 +4,6 @@ import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.CONTENT; import static datadog.trace.api.telemetry.WafMetricCollector.AIGuardTruncationType.MESSAGES; import static datadog.trace.util.Strings.isBlank; -import static java.util.Collections.singletonMap; import com.squareup.moshi.JsonAdapter; import com.squareup.moshi.JsonReader; @@ -69,7 +68,8 @@ public BadConfigurationException(final String message) { static final String REASON_TAG = "ai_guard.reason"; static final String BLOCKED_TAG = "ai_guard.blocked"; static final String META_STRUCT_TAG = "ai_guard"; - static final String META_STRUCT_KEY = "messages"; + static final String META_STRUCT_MESSAGES = "messages"; + static final String META_STRUCT_CATEGORIES = "attack_categories"; public static void install() { final Config config = Config.get(); @@ -208,8 +208,8 @@ public Evaluation evaluate(final List messages, final Options options) } else { span.setTag(TARGET_TAG, "prompt"); } - final Map metaStruct = - singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages)); + final Map metaStruct = new HashMap<>(2); + metaStruct.put(META_STRUCT_MESSAGES, messagesForMetaStruct(messages)); span.setMetaStruct(META_STRUCT_TAG, metaStruct); final Request.Builder request = new Request.Builder() @@ -224,8 +224,15 @@ public Evaluation evaluate(final List messages, final Options options) } final Action action = Action.valueOf(actionStr); final String reason = (String) result.get("reason"); + @SuppressWarnings("unchecked") + final List tags = (List) result.get("tags"); span.setTag(ACTION_TAG, action); - span.setTag(REASON_TAG, reason); + if (reason != null) { + span.setTag(REASON_TAG, reason); + } + if (tags != null && !tags.isEmpty()) { + metaStruct.put(META_STRUCT_CATEGORIES, tags); + } final boolean shouldBlock = isBlockingEnabled(options, result.get("is_blocking_enabled")) && action != Action.ALLOW; WafMetricCollector.get().aiGuardRequest(action, shouldBlock); diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy index 366b977a7a3..77a27dbc61f 100644 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -157,13 +157,14 @@ class AIGuardInternalTests extends DDSpecification { Request request = null Throwable error = null AIGuard.Evaluation eval = null + Map receivedMeta = null final throwAbortError = suite.blocking && suite.action != ALLOW final call = Mock(Call) { execute() >> { return mockResponse( request, 200, - [data: [attributes: [action: suite.action, reason: suite.reason, is_blocking_enabled: suite.blocking]]] + [data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]] ) } } @@ -189,11 +190,18 @@ class AIGuardInternalTests extends DDSpecification { } 1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action) 1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason) - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, [messages: suite.messages]) + 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> { + receivedMeta = it[1] as Map + return span + } if (throwAbortError) { 1 * span.addThrowable(_ as AIGuard.AIGuardAbortError) } + receivedMeta.messages == suite.messages + if (suite.tags) { + receivedMeta.attack_categories == suite.tags + } assertRequest(request, suite.messages) if (throwAbortError) { error instanceof AIGuard.AIGuardAbortError @@ -497,14 +505,16 @@ class AIGuardInternalTests extends DDSpecification { private static class TestSuite { private final AIGuard.Action action private final String reason + private final List tags private final boolean blocking private final String description private final String target private final List messages - TestSuite(AIGuard.Action action, String reason, boolean blocking, String description, String target, List messages) { + TestSuite(AIGuard.Action action, String reason, List tags, boolean blocking, String description, String target, List messages) { this.action = action this.reason = reason + this.tags = tags this.blocking = blocking this.description = description this.target = target @@ -512,7 +522,11 @@ class AIGuardInternalTests extends DDSpecification { } static List build() { - def actionValues = [[ALLOW, 'Go ahead'], [DENY, 'Nope'], [ABORT, 'Kill it with fire']] + def actionValues = [ + [ALLOW, 'Go ahead', []], + [DENY, 'Nope', ['deny_everything', 'test_deny']], + [ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']] + ] def blockingValues = [true, false] def suiteValues = [ ['tool call', 'tool', TOOL_CALL], @@ -521,7 +535,7 @@ class AIGuardInternalTests extends DDSpecification { ] return combinations([actionValues, blockingValues, suiteValues] as Iterable) .collect { action, blocking, suite -> - new TestSuite(action[0], action[1], blocking, suite[0], suite[1], suite[2]) + new TestSuite(action[0], action[1], action[2], blocking, suite[0], suite[1], suite[2]) } }