Skip to content

Commit afac83d

Browse files
Added support for must_compute flag (#67)
* Version bump to 10.5
* Added yml file to automatically run tests when a PR is opened, synchronized (updated), or reopened that targets the branch
* Updated pip install packages
* pip install .
* Uncommenting API Key env var
* Ignoring obsolete and low level tests
* Added descriptive logging to tests
* Debugging yml file and requirements modified for streamlit chatbot example
* Adjusted streamlit chatbot to be compatible with IA2
* Added `must_compute` flag to Detect
* Replaced PII with completeness
* Added must compute tests with the actual service
1 parent 4a0ac3e commit afac83d

File tree

3 files changed: +171 -4 lines changed

aimon/decorators/detect.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class Detect:
9191
The name of the application to use when publish is True.
9292
model_name : str, optional
9393
The name of the model to use when publish is True.
94+
must_compute : str, optional
95+
Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.
9496
9597
Example:
9698
--------
@@ -133,7 +135,7 @@ class Detect:
133135
"""
134136
DEFAULT_CONFIG = {'hallucination': {'detector_name': 'default'}}
135137

136-
def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None):
138+
def __init__(self, values_returned, api_key=None, config=None, async_mode=False, publish=False, application_name=None, model_name=None, must_compute='all_or_none'):
137139
"""
138140
:param values_returned: A list of values in the order returned by the decorated function
139141
Acceptable values are 'generated_text', 'context', 'user_query', 'instructions'
@@ -144,6 +146,7 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
144146
:param publish: Boolean, if True, the payload will be published to AIMon and can be viewed on the AIMon UI. Default is False.
145147
:param application_name: The name of the application to use when publish is True
146148
:param model_name: The name of the model to use when publish is True
149+
:param must_compute: String, indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'. Default is 'all_or_none'.
147150
"""
148151
api_key = os.getenv('AIMON_API_KEY') if not api_key else api_key
149152
if api_key is None:
@@ -163,8 +166,15 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
163166
if model_name is None:
164167
raise ValueError("Model name must be provided if publish is True")
165168

169+
# Validate must_compute parameter
170+
if not isinstance(must_compute, str):
171+
raise ValueError("`must_compute` must be a string value")
172+
if must_compute not in ['all_or_none', 'ignore_failures']:
173+
raise ValueError("`must_compute` must be either 'all_or_none' or 'ignore_failures'")
174+
self.must_compute = must_compute
175+
166176
self.application_name = application_name
167-
self.model_name = model_name
177+
self.model_name = model_name
168178

169179
def __call__(self, func):
170180
@wraps(func)
@@ -181,6 +191,7 @@ def wrapper(*args, **kwargs):
181191
aimon_payload['config'] = self.config
182192
aimon_payload['publish'] = self.publish
183193
aimon_payload['async_mode'] = self.async_mode
194+
aimon_payload['must_compute'] = self.must_compute
184195

185196
# Include application_name and model_name if publishing
186197
if self.publish:

aimon/types/inference_detect_params.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ class BodyConfigInstructionAdherence(TypedDict, total=False):
4747
class BodyConfigToxicity(TypedDict, total=False):
4848
detector_name: Literal["default"]
4949

50-
5150
class BodyConfig(TypedDict, total=False):
5251
completeness: BodyConfigCompleteness
5352

@@ -61,7 +60,6 @@ class BodyConfig(TypedDict, total=False):
6160

6261
toxicity: BodyConfigToxicity
6362

64-
6563
class Body(TypedDict, total=False):
6664
context: Required[Union[List[str], str]]
6765
"""Context as an array of strings or a single string"""
@@ -81,6 +79,9 @@ class Body(TypedDict, total=False):
8179
model_name: str
8280
"""The model name for publishing metrics for an application."""
8381

82+
must_compute: str
83+
"""Indicates the computation strategy. Must be either 'all_or_none' or 'ignore_failures'."""
84+
8485
publish: bool
8586
"""Indicates whether to publish metrics."""
8687

tests/test_detect.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,3 +824,158 @@ def test_evaluate_with_new_model(self):
824824
import os
825825
if os.path.exists(dataset_path):
826826
os.remove(dataset_path)
827+
828+
def test_must_compute_validation(self):
829+
"""Test that the must_compute parameter is properly validated."""
830+
print("\n=== Testing must_compute validation ===")
831+
832+
# Test config with both hallucination and completeness
833+
test_config = {
834+
"hallucination": {
835+
"detector_name": "default"
836+
},
837+
"completeness": {
838+
"detector_name": "default"
839+
}
840+
}
841+
print(f"Test Config: {test_config}")
842+
843+
# Test valid values
844+
valid_values = ['all_or_none', 'ignore_failures']
845+
print(f"Testing valid must_compute values: {valid_values}")
846+
847+
for value in valid_values:
848+
print(f"Testing valid must_compute value: {value}")
849+
detect = Detect(
850+
values_returned=["context", "generated_text"],
851+
api_key=self.api_key,
852+
config=test_config,
853+
must_compute=value
854+
)
855+
assert detect.must_compute == value
856+
print(f"✅ Successfully validated must_compute value: {value}")
857+
858+
# Test invalid string value
859+
invalid_string_value = "invalid_value"
860+
print(f"Testing invalid must_compute string value: {invalid_string_value}")
861+
try:
862+
Detect(
863+
values_returned=["context", "generated_text"],
864+
api_key=self.api_key,
865+
config=test_config,
866+
must_compute=invalid_string_value
867+
)
868+
print("❌ ERROR: Expected ValueError but none was raised - This should not happen")
869+
assert False, "Expected ValueError for invalid string value"
870+
except ValueError as e:
871+
print(f"✅ Successfully caught ValueError for invalid string: {str(e)}")
872+
assert "`must_compute` must be either 'all_or_none' or 'ignore_failures'" in str(e)
873+
874+
# Test non-string value
875+
non_string_value = 123
876+
print(f"Testing non-string must_compute value: {non_string_value}")
877+
try:
878+
Detect(
879+
values_returned=["context", "generated_text"],
880+
api_key=self.api_key,
881+
config=test_config,
882+
must_compute=non_string_value
883+
)
884+
print("❌ ERROR: Expected ValueError but none was raised - This should not happen")
885+
assert False, "Expected ValueError for non-string value"
886+
except ValueError as e:
887+
print(f"✅ Successfully caught ValueError for non-string: {str(e)}")
888+
assert "`must_compute` must be a string value" in str(e)
889+
890+
# Test default value
891+
print("Testing default must_compute value: default")
892+
detect_default = Detect(
893+
values_returned=["context", "generated_text"],
894+
api_key=self.api_key,
895+
config=test_config
896+
)
897+
assert detect_default.must_compute == 'all_or_none'
898+
print(f"✅ Successfully validated default must_compute value: {detect_default.must_compute}")
899+
900+
print("🎉 Result: must_compute validation working correctly")
901+
902+
def test_must_compute_with_actual_service(self):
903+
"""Test must_compute functionality with actual service calls."""
904+
print("\n=== Testing must_compute with actual service ===")
905+
906+
# Test config with both hallucination and completeness
907+
test_config = {
908+
"hallucination": {
909+
"detector_name": "default"
910+
},
911+
"completeness": {
912+
"detector_name": "default"
913+
}
914+
}
915+
print(f"Test Config: {test_config}")
916+
917+
# Test both must_compute values
918+
for must_compute_value in ['all_or_none', 'ignore_failures']:
919+
print(f"\n--- Testing must_compute: {must_compute_value} ---")
920+
921+
detect = Detect(
922+
values_returned=["context", "generated_text", "user_query"],
923+
api_key=self.api_key,
924+
config=test_config,
925+
must_compute=must_compute_value
926+
)
927+
928+
@detect
929+
def generate_summary(context, query):
930+
generated_text = f"Summary of {context} based on query: {query}"
931+
return context, generated_text, query
932+
933+
# Test data
934+
context = "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed."
935+
query = "What is machine learning?"
936+
937+
print(f"Input Context: {context}")
938+
print(f"Input Query: {query}")
939+
print(f"Must Compute: {must_compute_value}")
940+
941+
try:
942+
# Call the decorated function
943+
context_ret, generated_text, query_ret, result = generate_summary(context, query)
944+
945+
print(f"✅ API Call Successful!")
946+
print(f"Status Code: {result.status}")
947+
print(f"Generated Text: {generated_text}")
948+
949+
# Display response details
950+
if hasattr(result.detect_response, 'hallucination'):
951+
hallucination = result.detect_response.hallucination
952+
print(f"Hallucination Score: {hallucination.get('score', 'N/A')}")
953+
print(f"Is Hallucinated: {hallucination.get('is_hallucinated', 'N/A')}")
954+
955+
if hasattr(result.detect_response, 'completeness'):
956+
completeness = result.detect_response.completeness
957+
print(f"Completeness Score: {completeness.get('score', 'N/A')}")
958+
959+
# Show the full response structure
960+
print(f"Response Object Type: {type(result.detect_response)}")
961+
if hasattr(result.detect_response, '__dict__'):
962+
print(f"Response Attributes: {list(result.detect_response.__dict__.keys())}")
963+
964+
except Exception as e:
965+
error_message = str(e)
966+
print(f"API Call Result: {error_message}")
967+
print(f"Error Type: {type(e).__name__}")
968+
969+
# For all_or_none, 503 is expected when services are unavailable
970+
if must_compute_value == 'all_or_none' and '503' in error_message:
971+
print("✅ Expected behavior: all_or_none returns 503 when services unavailable")
972+
# For ignore_failures, we expect success or different error handling
973+
elif must_compute_value == 'ignore_failures':
974+
if '503' in error_message:
975+
print("❌ Unexpected: ignore_failures should handle service unavailability")
976+
else:
977+
print("✅ Expected behavior: ignore_failures handled the error appropriately")
978+
else:
979+
print(f"❌ Unexpected error for {must_compute_value}: {error_message}")
980+
981+
print("\n🎉 All must_compute service tests completed!")

0 commit comments

Comments
 (0)