Skip to content

Commit bdc3ebc

Browse files
committed
feat: filter all detected categories when contentChecks is empty
When no specific content checks are configured in the AI Prompt Guard Rails policy, the policy now blocks/logs requests for all detected categories from the model, instead of allowing them to pass through. This implements the default behavior requested in APIM-11208 where an empty contentChecks configuration means "all tags".

Changes:
- Update policy logic to filter all categories when contentChecks is empty
- Add documentation clarifying "Keep empty for all" behavior
- Mark promptLocation and requestPolicy as required fields
- Update default sensitivity threshold from 0.5 to 0.8 with proper constraints
- Add integration test for empty contentChecks scenario
- Simplify configuration code (use primitive double instead of Double)

https://gravitee.atlassian.net/browse/APIM-11208
1 parent 3c4a251 commit bdc3ebc

File tree

6 files changed

+166
-26
lines changed

6 files changed

+166
-26
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ Strikethrough text indicates that a version is deprecated.
8686
####
8787
| Name <br>`json name` | Type <br>`constraint` | Mandatory | Default | Description |
8888
|:----------------------|:-----------------------|:----------:|:---------|:-------------|
89-
| Content Checks<br>`contentChecks`| string| | | Comma-separated list of model labels (e.g., TOXIC,OBSCENE)|
90-
| Prompt Location<br>`promptLocation`| string| | | Prompt Location|
91-
| Request Policy<br>`requestPolicy`| enum (string)| | `LOG_REQUEST`| Request Policy<br>Values: `BLOCK_REQUEST` `LOG_REQUEST`|
89+
| Content Checks<br>`contentChecks`| string| | | Comma-separated list of model labels (e.g., TOXIC,OBSCENE). Keep empty for all|
90+
| Prompt Location<br>`promptLocation`| string| | | Prompt Location|
91+
| Request Policy<br>`requestPolicy`| enum (string)| | `LOG_REQUEST`| Request Policy<br>Values: `BLOCK_REQUEST` `LOG_REQUEST`|
9292
| Resource Name<br>`resourceName`| string| | | The resource name loading the Text Classification model|
93-
| Sensitivity threshold<br>`sensitivityThreshold`| number| | `0.5`| |
93+
| Sensitivity threshold<br>`sensitivityThreshold`| number<br>`[0.1, 1)`| | `0.8`| |
9494

9595

9696

src/main/java/io/gravitee/policy/ai/prompt/guard/rails/AiPromptGuardRailsPolicy.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ private CompletableSource checkContent(HttpPlainExecutionContext ctx) {
7777
.invokeModel(new PromptInput(prompt))
7878
.flatMapCompletable(classifierResults -> {
7979
var detectedContentTypes = detectClassifierResultContentTypes(classifierResults, sensitivityThreshold);
80-
if (configuration.parseContentChecks().stream().anyMatch(detectedContentTypes::contains)) {
80+
if (
81+
configuration.parseContentChecks().isEmpty() ||
82+
configuration.parseContentChecks().stream().anyMatch(detectedContentTypes::contains)
83+
) {
8184
logMetrics(detectedContentTypes, ctx, configuration.requestPolicy().getAction());
8285
if (RequestPolicy.BLOCK_REQUEST.equals(configuration.requestPolicy())) {
8386
return Completable.error(new BlockQueryException(detectedContentTypes));

src/main/java/io/gravitee/policy/ai/prompt/guard/rails/configuration/AiPromptGuardRailsConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public List<String> parseContentChecks() {
3939
return Arrays.stream(contentChecks.split(",")).map(String::trim).filter(s -> !s.isEmpty()).toList();
4040
}
4141

42-
public Double getSensitivityThreshold() {
42+
public double getSensitivityThreshold() {
4343
return sensitivityThreshold != null ? sensitivityThreshold : DEFAULT_SENSITIVITY_THRESHOLD;
4444
}
4545
}

src/main/resources/schemas/schema-form.json

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@
55
"resourceName": {
66
"title": "Resource Name",
77
"description": "The resource name loading the Text Classification model",
8-
"type": "string"
8+
"type": "string",
9+
"x-schema-form": {
10+
"event": {
11+
"name": "fetch-resources",
12+
"regexTypes": "^ai-model-text-classification"
13+
}
14+
},
15+
"gioConfig": {
16+
"uiType": "resource-type",
17+
"uiTypeProps": { "resourceType": "ai-model-text-classification" }
18+
}
919
},
1020
"promptLocation": {
1121
"title": "Prompt Location",
@@ -14,13 +24,16 @@
1424
},
1525
"contentChecks": {
1626
"title": "Content Checks",
17-
"description": "Comma-separated list of model labels (e.g., TOXIC,OBSCENE)",
27+
"description": "Comma-separated list of model labels (e.g., TOXIC,OBSCENE). Keep empty for all",
1828
"type": "string"
1929
},
2030
"sensitivityThreshold": {
2131
"title": "Sensitivity threshold",
2232
"type": "number",
23-
"default": 0.5
33+
"default": 0.8,
34+
"minimum": 0.1,
35+
"exclusiveMaximum": 1,
36+
"multipleOf": 0.01
2437
},
2538
"requestPolicy": {
2639
"title": "Request Policy",
@@ -29,5 +42,6 @@
2942
"default": "LOG_REQUEST",
3043
"enum": ["BLOCK_REQUEST", "LOG_REQUEST"]
3144
}
32-
}
45+
},
46+
"required": ["promptLocation", "requestPolicy"]
3347
}

src/test/java/io/gravitee/policy/ai/prompt/guard/rails/AiPromptGuardRailsPolicyIntegrationTest.java

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,21 @@
2525
import io.gravitee.apim.gateway.tests.sdk.annotations.GatewayTest;
2626
import io.gravitee.apim.gateway.tests.sdk.resource.ResourceBuilder;
2727
import io.gravitee.definition.model.ExecutionMode;
28-
import io.gravitee.gateway.core.component.ComponentProvider;
2928
import io.gravitee.plugin.resource.ResourcePlugin;
3029
import io.gravitee.policy.ai.prompt.guard.rails.configuration.AiPromptGuardRailsConfiguration;
31-
import io.gravitee.policy.ai.prompt.guard.rails.model.AiModelResourceProvider;
3230
import io.gravitee.reporter.api.v4.metric.AdditionalMetric;
3331
import io.gravitee.reporter.api.v4.metric.Metrics;
3432
import io.gravitee.resource.ai_model.TextClassificationAiModelResource;
35-
import io.gravitee.resource.ai_model.api.AiTextModelResource;
3633
import io.gravitee.resource.ai_model.configuration.TextClassificationAiModelConfiguration;
37-
import io.gravitee.resource.api.ResourceManager;
3834
import io.reactivex.rxjava3.core.Completable;
3935
import io.reactivex.rxjava3.core.Observable;
40-
import io.reactivex.rxjava3.core.Single;
4136
import io.vertx.core.http.HttpMethod;
4237
import io.vertx.junit5.VertxTestContext;
4338
import io.vertx.rxjava3.core.http.HttpClient;
4439
import java.util.Map;
4540
import lombok.extern.slf4j.Slf4j;
4641
import org.assertj.core.api.InstanceOfAssertFactories;
47-
import org.junit.jupiter.api.DisplayNameGeneration;
48-
import org.junit.jupiter.api.DisplayNameGenerator;
49-
import org.junit.jupiter.api.Nested;
50-
import org.junit.jupiter.api.Test;
51-
import reactor.util.function.Tuple2;
42+
import org.junit.jupiter.api.*;
5243
import reactor.util.function.Tuples;
5344

5445
@Slf4j
@@ -304,8 +295,13 @@ void should_return_an_error_when_inference_fail(HttpClient client) {
304295
}
305296

306297
@Nested
298+
@DeployApi(
299+
{ "/apis/block_request_policy_empty_contentChecks.json", "/apis/block_request_policy.json", "/apis/log_request_policy.json" }
300+
)
307301
class WithRealAiResource extends AbstractAiPromptGuardRailsPolicyIntegrationTest {
308302

303+
Observable<Long> timer;
304+
309305
@Override
310306
public void configureResources(Map<String, ResourcePlugin> resources) {
311307
super.configureResources(resources);
@@ -320,13 +316,17 @@ public void configureResources(Map<String, ResourcePlugin> resources) {
320316
);
321317
}
322318

319+
@BeforeAll
320+
public void setup() {
321+
// add delay because the model is load asynchronously
322+
timer = Observable.timer(DELAY_BEFORE_REQUEST, SECONDS);
323+
}
324+
323325
@Test
324-
@DeployApi({ "/apis/log_request_policy.json" })
325326
void should_flag_request_if_prompt_violation_detected(HttpClient client, VertxTestContext context) {
326327
wiremock.stubFor(get("/endpoint").willReturn(aResponse().withStatus(200)));
327328

328-
Observable
329-
.timer(DELAY_BEFORE_REQUEST, SECONDS)
329+
timer
330330
.flatMapSingle(v -> client.rxRequest(HttpMethod.GET, "/log-request"))
331331
.firstOrError()
332332
.flatMap(request ->
@@ -361,7 +361,6 @@ void should_flag_request_if_prompt_violation_detected(HttpClient client, VertxTe
361361
}
362362

363363
@Test
364-
@DeployApi({ "/apis/block_request_policy.json" })
365364
void should_block_request_if_prompt_violation_detected(HttpClient client, VertxTestContext context) {
366365
wiremock.stubFor(get("/endpoint").willReturn(aResponse().withStatus(200)));
367366

@@ -377,8 +376,8 @@ void should_block_request_if_prompt_violation_detected(HttpClient client, VertxT
377376
)
378377
.ignoreElements();
379378

380-
var clientAsserts = Completable
381-
.fromObservable(Observable.timer(DELAY_BEFORE_REQUEST, SECONDS))
379+
var clientAsserts = timer
380+
.ignoreElements()
382381
.andThen(
383382
client
384383
.rxRequest(HttpMethod.GET, "/block-request")
@@ -402,6 +401,48 @@ void should_block_request_if_prompt_violation_detected(HttpClient client, VertxT
402401

403402
finalAssert(context, metricsAsserts, clientAsserts);
404403
}
404+
405+
@Test
406+
void should_block_request_if_prompt_violation_detected_and_empty_contentChecks(HttpClient client, VertxTestContext context) {
407+
wiremock.stubFor(get("/endpoint").willReturn(aResponse().withStatus(200)));
408+
409+
var metricsAsserts = metricsSubject
410+
.doOnNext(metrics ->
411+
assertThat(metrics)
412+
.extracting(Metrics::getAdditionalMetrics)
413+
.asInstanceOf(InstanceOfAssertFactories.SET)
414+
.containsExactlyInAnyOrder(
415+
new AdditionalMetric.KeywordMetric("keyword_action", "request-blocked"),
416+
new AdditionalMetric.KeywordMetric("keyword_content_violations", "toxic")
417+
)
418+
)
419+
.ignoreElements();
420+
421+
var clientAsserts = timer
422+
.ignoreElements()
423+
.andThen(
424+
client
425+
.rxRequest(HttpMethod.GET, "/block-request-empty-contentChecks")
426+
.flatMap(request ->
427+
request.rxSend(
428+
"""
429+
{
430+
"model": "GPT-2000",
431+
"date": "01-01-2025",
432+
"prompt": "Nobody asked for your bullsh*t response."
433+
}"""
434+
)
435+
)
436+
.flatMapPublisher(response -> {
437+
assertThat(response.statusCode()).isEqualTo(400);
438+
return response.toFlowable();
439+
})
440+
)
441+
.map(responseBody -> assertThat(responseBody).hasToString("AI prompt validation detected. Reason: [toxic]"))
442+
.ignoreElements();
443+
444+
finalAssert(context, metricsAsserts, clientAsserts);
445+
}
405446
}
406447

407448
private static void finalAssert(VertxTestContext context, Completable metricsAsserts, Completable clientAsserts) {
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
{
2+
"id": "v4-ai-prompt-guard-rails-block-request-empty-contentChecks",
3+
"name": "v4-ai-prompt-guard-rails-block-request-empty-contentChecks",
4+
"apiVersion": "1.0",
5+
"definitionVersion": "4.0.0",
6+
"type": "proxy",
7+
"analytics": {},
8+
"listeners": [
9+
{
10+
"type": "http",
11+
"paths": [
12+
{
13+
"path": "/block-request-empty-contentChecks"
14+
}
15+
],
16+
"entrypoints": [
17+
{
18+
"type": "http-proxy"
19+
}
20+
]
21+
}
22+
],
23+
"endpointGroups": [
24+
{
25+
"name": "default",
26+
"type": "http-proxy",
27+
"endpoints": [
28+
{
29+
"name": "default",
30+
"type": "http-proxy",
31+
"weight": 1,
32+
"inheritConfiguration": false,
33+
"configuration": {
34+
"target": "http://localhost:8080/endpoint"
35+
}
36+
}
37+
]
38+
}
39+
],
40+
"resources": [
41+
{
42+
"name": "ai-model-text-classification-resource",
43+
"type": "ai-model-text-classification",
44+
"configuration": {
45+
"model": {
46+
"type": "MINILMV2_TOXIC_JIGSAW_MODEL"
47+
}
48+
},
49+
"enabled": true
50+
}
51+
],
52+
"flows": [
53+
{
54+
"name": "flow-1",
55+
"enabled": true,
56+
"selectors": [
57+
{
58+
"type": "http",
59+
"path": "/",
60+
"pathOperator": "STARTS_WITH"
61+
}
62+
],
63+
"request": [
64+
{
65+
"name": "AI Prompt Guard Rails",
66+
"description": "",
67+
"enabled": true,
68+
"policy": "ai-prompt-guard-rails",
69+
"configuration": {
70+
"resourceName": "ai-model-text-classification-resource",
71+
"promptLocation": "{#request.jsonContent.prompt}",
72+
"contentChecks": "",
73+
"requestPolicy": "BLOCK_REQUEST"
74+
}
75+
}
76+
],
77+
"response": [],
78+
"subscribe": [],
79+
"publish": []
80+
}
81+
]
82+
}

0 commit comments

Comments
 (0)