|
3 | 3 | from .common import AimonClientSingleton |
4 | 4 |
|
5 | 5 |
|
6 | | -class DetectWithQueryFuncReturningContext(object): |
| 6 | +class Detect: |
7 | 7 | DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}} |
8 | 8 |
|
9 | | - def __init__(self, api_key=None, config=None): |
| 9 | + def __init__(self, values_returned, api_key=None, config=None): |
| 10 | + """ |
| 11 | + :param values_returned: A list of values in the order returned by the decorated function |
| 12 | + Acceptable values are 'generated_text', 'context', 'user_query', 'instructions' |
| 13 | + """ |
10 | 14 | self.client = AimonClientSingleton.get_instance(api_key) |
11 | 15 | self.config = config if config else self.DEFAULT_CONFIG |
| 16 | + self.values_returned = values_returned |
| 17 | + if self.values_returned is None or len(self.values_returned) == 0: |
| 18 | + raise ValueError("Values returned by the decorated function must be specified") |
12 | 19 |
|
13 | 20 | def __call__(self, func): |
14 | 21 | @wraps(func) |
15 | | - def wrapper(user_query, *args, **kwargs): |
16 | | - result, context = func(user_query, *args, **kwargs) |
| 22 | + def wrapper(*args, **kwargs): |
| 23 | + result = func(*args, **kwargs) |
17 | 24 |
|
18 | | - if result is None or context is None: |
19 | | - raise ValueError("Result and context must be returned by the decorated function") |
| 25 | + # Handle the case where the result is a single value |
| 26 | + if not isinstance(result, tuple): |
| 27 | + result = (result,) |
20 | 28 |
|
21 | | - data_to_send = [{ |
22 | | - "user_query": user_query, |
23 | | - "context": context, |
24 | | - "generated_text": result, |
25 | | - "config": self.config |
26 | | - }] |
| 29 | + # Create a dictionary mapping output names to results |
| 30 | + result_dict = {name: value for name, value in zip(self.values_returned, result)} |
27 | 31 |
|
28 | | - aimon_response = self.client.inference.detect(body=data_to_send)[0] |
29 | | - return result, context, aimon_response |
30 | | - |
31 | | - return wrapper |
32 | | - |
33 | | - |
34 | | -class DetectWithQueryInstructionsFuncReturningContext(DetectWithQueryFuncReturningContext): |
35 | | - def __call__(self, func): |
36 | | - @wraps(func) |
37 | | - def wrapper(user_query, instructions, *args, **kwargs): |
38 | | - result, context = func(user_query, instructions, *args, **kwargs) |
39 | | - |
40 | | - if result is None or context is None: |
41 | | - raise ValueError("Result and context must be returned by the decorated function") |
42 | | - |
43 | | - data_to_send = [{ |
44 | | - "user_query": user_query, |
45 | | - "context": context, |
46 | | - "generated_text": result, |
47 | | - "instructions": instructions, |
48 | | - "config": self.config |
49 | | - }] |
50 | | - |
51 | | - aimon_response = self.client.inference.detect(body=data_to_send)[0] |
52 | | - return result, context, aimon_response |
53 | | - |
54 | | - return wrapper |
55 | | - |
56 | | - |
57 | | -# Another class but does not include instructions in the wrapper call |
58 | | -class DetectWithContextQuery(object): |
59 | | - DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}} |
60 | | - |
61 | | - def __init__(self, api_key=None, config=None): |
62 | | - self.client = AimonClientSingleton.get_instance(api_key) |
63 | | - self.config = config if config else self.DEFAULT_CONFIG |
64 | | - |
65 | | - def __call__(self, func): |
66 | | - @wraps(func) |
67 | | - def wrapper(context, user_query, *args, **kwargs): |
68 | | - result = func(context, user_query, *args, **kwargs) |
| 32 | + aimon_payload = {} |
| 33 | + if 'generated_text' in result_dict: |
| 34 | + aimon_payload['generated_text'] = result_dict['generated_text'] |
| 35 | + if 'context' in result_dict: |
| 36 | + aimon_payload['context'] = result_dict['context'] |
| 37 | + if 'user_query' in result_dict: |
| 38 | + aimon_payload['user_query'] = result_dict['user_query'] |
| 39 | + if 'instructions' in result_dict: |
| 40 | + aimon_payload['instructions'] = result_dict['instructions'] |
69 | 41 |
|
70 | | - if result is None: |
71 | | - raise ValueError("Result must be returned by the decorated function") |
72 | | - |
73 | | - data_to_send = [{ |
74 | | - "context": context, |
75 | | - "user_query": user_query, |
76 | | - "generated_text": result, |
77 | | - "config": self.config |
78 | | - }] |
| 42 | + data_to_send = [aimon_payload] |
79 | 43 |
|
80 | 44 | aimon_response = self.client.inference.detect(body=data_to_send)[0] |
81 | | - return result, aimon_response |
| 45 | + return result + (aimon_response,) |
82 | 46 |
|
83 | 47 | return wrapper |
84 | | - |
85 | | - |
86 | | -class DetectWithContextQueryInstructions(DetectWithContextQuery): |
87 | | - def __call__(self, func): |
88 | | - @wraps(func) |
89 | | - def wrapper(context, user_query, instructions, *args, **kwargs): |
90 | | - result = func(context, user_query, instructions, *args, **kwargs) |
91 | | - |
92 | | - if result is None: |
93 | | - raise ValueError("Result must be returned by the decorated function") |
94 | | - |
95 | | - data_to_send = [{ |
96 | | - "context": context, |
97 | | - "user_query": user_query, |
98 | | - "generated_text": result, |
99 | | - "instructions": instructions, |
100 | | - "config": self.config |
101 | | - }] |
102 | | - |
103 | | - aimon_response = self.client.inference.detect(body=data_to_send)[0] |
104 | | - return result, aimon_response |
105 | | - |
106 | | - return wrapper |
107 | | - |
0 commit comments