1- from unittest .mock import MagicMock
1+ from typing import Dict , List
2+ from unittest .mock import AsyncMock , MagicMock , patch
23
34import pytest
45
@@ -151,6 +152,100 @@ async def test_request_file_matcher(
151152 )
152153
153154
155+ # We mock PersonaManager because it's tested in /tests/persona/test_manager.py
156+ MOCK_PERSONA_MANAGER = AsyncMock ()
157+ MOCK_PERSONA_MANAGER .check_persona_match .return_value = True
158+
159+
160+ @pytest .mark .asyncio
161+ @pytest .mark .parametrize (
162+ "body, expected_queries" ,
163+ [
164+ ({"messages" : [{"role" : "system" , "content" : "Youre helpful" }]}, []),
165+ ({"messages" : [{"role" : "user" , "content" : "hello" }]}, ["hello" ]),
166+ (
167+ {"messages" : [{"role" : "user" , "content" : [{"type" : "text" , "text" : "hello_dict" }]}]},
168+ ["hello_dict" ],
169+ ),
170+ ],
171+ )
172+ async def test_user_msgs_persona_desc_matcher (body : Dict , expected_queries : List [str ]):
173+ mux_rule = mux_models .MuxRule (
174+ provider_id = "1" ,
175+ model = "fake-gpt" ,
176+ matcher_type = "user_messages_persona_desc" ,
177+ matcher = "foo_persona" ,
178+ )
179+ muxing_rule_matcher = rulematcher .UserMsgsPersonaDescMuxMatcher (mocked_route_openai , mux_rule )
180+
181+ mocked_thing_to_match = mux_models .ThingToMatchMux (
182+ body = body ,
183+ url_request_path = "/chat/completions" ,
184+ is_fim_request = False ,
185+ client_type = "generic" ,
186+ )
187+
188+ resulting_queries = muxing_rule_matcher ._get_queries_for_persona_match (body )
189+ assert set (resulting_queries ) == set (expected_queries )
190+
191+ with patch ("codegate.muxing.rulematcher.PersonaManager" , return_value = MOCK_PERSONA_MANAGER ):
192+ result = await muxing_rule_matcher .match (mocked_thing_to_match )
193+
194+ if expected_queries :
195+ assert result is True
196+ else :
197+ assert result is False
198+
199+
200+ @pytest .mark .asyncio
201+ @pytest .mark .parametrize (
202+ "body, expected_queries" ,
203+ [
204+ ({"messages" : [{"role" : "system" , "content" : "Youre helpful" }]}, ["Youre helpful" ]),
205+ ({"messages" : [{"role" : "user" , "content" : "hello" }]}, []),
206+ (
207+ {
208+ "messages" : [
209+ {"role" : "system" , "content" : "Youre helpful" },
210+ {"role" : "user" , "content" : "hello" },
211+ ]
212+ },
213+ ["Youre helpful" ],
214+ ),
215+ (
216+ {"messages" : [{"role" : "user" , "content" : "hello" }], "system" : "Anthropic system" },
217+ ["Anthropic system" ],
218+ ),
219+ ],
220+ )
221+ async def test_sys_prompt_persona_desc_matcher (body : Dict , expected_queries : List [str ]):
222+ mux_rule = mux_models .MuxRule (
223+ provider_id = "1" ,
224+ model = "fake-gpt" ,
225+ matcher_type = "sys_prompt_persona_desc" ,
226+ matcher = "foo_persona" ,
227+ )
228+ muxing_rule_matcher = rulematcher .SysPromptPersonaDescMuxMatcher (mocked_route_openai , mux_rule )
229+
230+ mocked_thing_to_match = mux_models .ThingToMatchMux (
231+ body = body ,
232+ url_request_path = "/chat/completions" ,
233+ is_fim_request = False ,
234+ client_type = "generic" ,
235+ )
236+
237+ resulting_queries = muxing_rule_matcher ._get_queries_for_persona_match (body )
238+ assert set (resulting_queries ) == set (expected_queries )
239+
240+ with patch ("codegate.muxing.rulematcher.PersonaManager" , return_value = MOCK_PERSONA_MANAGER ):
241+ result = await muxing_rule_matcher .match (mocked_thing_to_match )
242+
243+ if expected_queries :
244+ assert result is True
245+ else :
246+ assert result is False
247+
248+
154249@pytest .mark .parametrize (
155250 "matcher_type, expected_class" ,
156251 [
@@ -159,8 +254,12 @@ async def test_request_file_matcher(
159254 (mux_models .MuxMatcherType .fim_filename , rulematcher .RequestTypeAndFileMuxingRuleMatcher ),
160255 (mux_models .MuxMatcherType .chat_filename , rulematcher .RequestTypeAndFileMuxingRuleMatcher ),
161256 (
162- mux_models .MuxMatcherType .persona_description ,
163- rulematcher .PersonaDescriptionMuxingRuleMatcher ,
257+ mux_models .MuxMatcherType .user_messages_persona_desc ,
258+ rulematcher .UserMsgsPersonaDescMuxMatcher ,
259+ ),
260+ (
261+ mux_models .MuxMatcherType .sys_prompt_persona_desc ,
262+ rulematcher .SysPromptPersonaDescMuxMatcher ,
164263 ),
165264 ("invalid_matcher" , None ),
166265 ],
0 commit comments