@@ -446,6 +446,54 @@ async def test_handle_rejected_prompt(self):
446446
447447 self .assertDictEqual (expected_response_metadata , response_metadata )
448448
449+ async def test_handle_rejected_prompt_with_result (self ):
450+ prompt = TEST_REQ_INPUT .model_dump_json ()
451+ metadata = {"key" : "value" , "step_id" : "12345" , "request_id" : "09876" }
452+ request = fake_multipart_request (
453+ prompt = prompt ,
454+ metadata = metadata ,
455+ parameters = {"reject" : True , "annotate" : True },
456+ )
457+ result = Reject (
458+ metadata = metadata ,
459+ code = RejectCode .POLICY_VIOLATION ,
460+ detail = "dangerous question asked" ,
461+ tags = FAKE_TAGS ,
462+ processor_result = {"confidence" : 0.99 },
463+ )
464+ processor = fake_processor (result = result )
465+
466+ response = await processor .handle_request (request )
467+
468+ self .assertStatusCodeEqual (response , HTTP_200_OK )
469+
470+ content = await self .buffer_response (response )
471+ multipart = MultipartDecoderHelper (
472+ content = content , content_type = response .headers ["Content-Type" ]
473+ )
474+
475+ self .assertFalse (
476+ multipart .has_prompt (), "the rejected prompt should not be in the response"
477+ )
478+
479+ multipart_metadata = multipart .metadata
480+ self .assertEqual (
481+ MultipartResponse .JSON_CONTENT_TYPE , multipart_metadata .content_type ()
482+ )
483+ response_metadata = multipart_metadata .as_json ()
484+
485+ expected_response_metadata = dict (
486+ app_details = APP_DETAILS ,
487+ processor_id = processor .id (),
488+ processor_result = {"confidence" : 0.99 },
489+ processor_version = processor .version ,
490+ tags = {"test1" : ["a" , "b" ]},
491+ )
492+ for k , v in metadata .items ():
493+ expected_response_metadata [k ] = v
494+
495+ self .assertDictEqual (expected_response_metadata , response_metadata )
496+
449497 async def test_handle_modified_prompt (self ):
450498 prompt = TEST_REQ_INPUT .model_dump_json ()
451499 metadata = {"key" : "value" }
0 commit comments