Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions optimizely/cmab/cmab_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DEFAULT_MAX_BACKOFF = 10 # in seconds
DEFAULT_BACKOFF_MULTIPLIER = 2.0
MAX_WAIT_TIME = 10.0
DEFAULT_PREDICTION_ENDPOINT = "https://prediction.cmab.optimizely.com/predict/{}"


class CmabRetryConfig:
Expand Down Expand Up @@ -52,17 +53,21 @@ class DefaultCmabClient:
"""
def __init__(self, http_client: Optional[requests.Session] = None,
retry_config: Optional[CmabRetryConfig] = None,
logger: Optional[_logging.Logger] = None):
logger: Optional[_logging.Logger] = None,
prediction_endpoint: Optional[str] = None):
"""Initialize the CMAB client.

Args:
http_client (Optional[requests.Session]): HTTP client for making requests.
retry_config (Optional[CmabRetryConfig]): Configuration for retry logic.
logger (Optional[_logging.Logger]): Logger for logging messages.
prediction_endpoint (Optional[str]): Custom prediction endpoint URL template.
Use {} as placeholder for rule_id.
"""
self.http_client = http_client or requests.Session()
self.retry_config = retry_config
self.logger = _logging.adapt_logger(logger or _logging.NoOpLogger())
self.prediction_endpoint = prediction_endpoint or DEFAULT_PREDICTION_ENDPOINT

def fetch_decision(
self,
Expand All @@ -84,7 +89,7 @@ def fetch_decision(
Returns:
str: The variation ID.
"""
url = f"https://prediction.cmab.optimizely.com/predict/{rule_id}"
url = self.prediction_endpoint.format(rule_id)
cmab_attributes = [
{"id": key, "value": value, "type": "custom_attribute"}
for key, value in attributes.items()
Expand Down
6 changes: 5 additions & 1 deletion optimizely/helpers/sdk_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
odp_event_manager: Optional[OdpEventManager] = None,
odp_segment_request_timeout: Optional[int] = None,
odp_event_request_timeout: Optional[int] = None,
odp_event_flush_interval: Optional[int] = None
odp_event_flush_interval: Optional[int] = None,
cmab_prediction_endpoint: Optional[str] = None
) -> None:
"""
Args:
Expand All @@ -52,6 +53,8 @@ def __init__(
send successfully (optional).
odp_event_request_timeout: Time to wait in seconds for send_odp_events request to send successfully.
odp_event_flush_interval: Time to wait for events to accumulate before sending a batch in seconds (optional).
cmab_prediction_endpoint: Custom CMAB prediction endpoint URL template (optional).
Use {} as placeholder for rule_id. Defaults to production endpoint if not provided.
"""

self.odp_disabled = odp_disabled
Expand All @@ -63,3 +66,4 @@ def __init__(
self.fetch_segments_timeout = odp_segment_request_timeout
self.odp_event_timeout = odp_event_request_timeout
self.odp_flush_interval = odp_event_flush_interval
self.cmab_prediction_endpoint = cmab_prediction_endpoint
7 changes: 6 additions & 1 deletion optimizely/optimizely.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,14 @@ def __init__(
self.event_builder = event_builder.EventBuilder()

# Initialize CMAB components
cmab_prediction_endpoint = None
if self.sdk_settings and self.sdk_settings.cmab_prediction_endpoint:
cmab_prediction_endpoint = self.sdk_settings.cmab_prediction_endpoint

self.cmab_client = DefaultCmabClient(
retry_config=CmabRetryConfig(),
logger=self.logger
logger=self.logger,
prediction_endpoint=cmab_prediction_endpoint
)
self.cmab_cache: LRUCache[str, CmabCacheValue] = LRUCache(DEFAULT_CMAB_CACHE_SIZE, DEFAULT_CMAB_CACHE_TIMEOUT)
self.cmab_service = DefaultCmabService(
Expand Down
1 change: 1 addition & 0 deletions requirements/core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ jsonschema>=3.2.0
pyrsistent>=0.16.0
requests>=2.21
idna>=2.10
rpds-py<0.20.0; python_version < '3.11'
3 changes: 2 additions & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mypy
types-jsonschema
types-requests
types-Flask
types-Flask
rpds-py<0.20.0; python_version < '3.11'
78 changes: 78 additions & 0 deletions tests/test_cmab_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,81 @@ def test_fetch_decision_exhausts_all_retry_attempts(self, mock_sleep):
self.mock_logger.error.assert_called_with(
Errors.CMAB_FETCH_FAILED.format('Exhausted all retries for CMAB request.')
)

def test_custom_prediction_endpoint(self):
"""Test that custom prediction endpoint is used correctly."""
custom_endpoint = "https://custom.endpoint.com/predict/{}"
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
prediction_endpoint=custom_endpoint
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'abc123'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'abc123')
expected_custom_url = custom_endpoint.format(self.rule_id)
self.mock_http_client.post.assert_called_once_with(
expected_custom_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)

def test_default_prediction_endpoint(self):
"""Test that default prediction endpoint is used when none is provided."""
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'def456'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'def456')
# Should use the default production endpoint
self.mock_http_client.post.assert_called_once_with(
self.expected_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)

def test_empty_prediction_endpoint_uses_default(self):
"""Test that empty string prediction endpoint falls back to default."""
client = DefaultCmabClient(
http_client=self.mock_http_client,
logger=self.mock_logger,
prediction_endpoint=""
)

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
'predictions': [{'variation_id': 'ghi789'}]
}
self.mock_http_client.post.return_value = mock_response

result = client.fetch_decision(self.rule_id, self.user_id, self.attributes, self.cmab_uuid)

self.assertEqual(result, 'ghi789')
# Should use the default production endpoint when empty string is provided
self.mock_http_client.post.assert_called_once_with(
self.expected_url,
data=json.dumps(self.expected_body),
headers=self.expected_headers,
timeout=10.0
)
6 changes: 4 additions & 2 deletions tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,10 @@ def test_fetch_datafile__exception_polling_thread_failed(self, _):
log_messages = [args[0] for args, _ in mock_logger.error.call_args_list]
for message in log_messages:
print(message)
if "Thread for background datafile polling failed. " \
"Error: timestamp too large to convert to C PyTime_t" not in message:
# Check for key parts of the error message (version-agnostic for Python 3.11+)
if not ("Thread for background datafile polling failed" in message and
"timestamp too large to convert to C" in message and
"PyTime_t" in message):
assert False

def test_is_running(self, _):
Expand Down