Skip to content

Commit 413c3a0

Browse files
feat: retrying HTTP calls on errors (#440)
**Issue:** [ADDON-78610](https://splunk.atlassian.net/browse/ADDON-78610) Splunk REST Client now retries failed calls. Status codes that are retried are below. Default retries specified by urllib3: - 413 - 429 - 503 Additional codes specified by solnlib: - 500 - 502 - 504 The default number of retries is 5.
1 parent 790de5f commit 413c3a0

File tree

3 files changed

+170
-18
lines changed

3 files changed

+170
-18
lines changed

solnlib/splunk_rest_client.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def _request_handler(context):
7171
'cert_file': string
7272
'pool_connections', int,
7373
'pool_maxsize', int,
74+
'max_retries': int,
75+
'retry_status_codes': list,
7476
}
7577
:type content: dict
7678
"""
@@ -102,25 +104,32 @@ def _request_handler(context):
102104
else:
103105
cert = None
104106

105-
retries = Retry(
106-
total=MAX_REQUEST_RETRIES,
107-
backoff_factor=0.3,
108-
status_forcelist=[500, 502, 503, 504],
109-
allowed_methods=["GET", "POST", "PUT", "DELETE"],
110-
raise_on_status=False,
111-
)
112-
if context.get("pool_connections", 0):
113-
logging.info("Use HTTP connection pooling")
114-
session = requests.Session()
115-
adapter = requests.adapters.HTTPAdapter(
116-
max_retries=retries,
117-
pool_connections=context.get("pool_connections", 10),
118-
pool_maxsize=context.get("pool_maxsize", 10),
107+
def adapter():
108+
retries = Retry(
109+
total=context.get("max_retries", MAX_REQUEST_RETRIES),
110+
backoff_factor=0.3,
111+
status_forcelist=context.get("retry_status_codes", [500, 502, 503, 504]),
112+
allowed_methods=["GET", "POST", "PUT", "DELETE"],
113+
raise_on_status=False,
119114
)
120-
session.mount("https://", adapter)
121-
req_func = session.request
122-
else:
123-
req_func = requests.request
115+
116+
adapter_args = {
117+
"max_retries": retries,
118+
}
119+
120+
# By default, pool_connections and pool_maxsize are set to 10 in urllib3
121+
if "pool_connections" in context:
122+
adapter_args["pool_connections"] = context["pool_connections"]
123+
if "pool_maxsize" in context:
124+
adapter_args["pool_maxsize"] = context["pool_maxsize"]
125+
126+
return requests.adapters.HTTPAdapter(**adapter_args)
127+
128+
session = requests.Session()
129+
session.mount("http://", adapter())
130+
session.mount("https://", adapter())
131+
132+
req_func = session.request
124133

125134
def request(url, message, **kwargs):
126135
"""

tests/unit/conftest.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import json
2+
import socket
3+
from contextlib import closing
4+
from http.server import BaseHTTPRequestHandler, HTTPServer
5+
from threading import Thread
6+
7+
import pytest
8+
9+
10+
@pytest.fixture(scope="session")
11+
def http_mock_server():
12+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
13+
s.bind(("", 0))
14+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
15+
port = s.getsockname()[1]
16+
17+
class Mock:
18+
def __init__(self, host, port):
19+
self.host = host
20+
self.port = port
21+
self.get_func = None
22+
23+
def get(self, func):
24+
self.get_func = func
25+
return func
26+
27+
mock = Mock("localhost", port)
28+
29+
class RequestArg:
30+
def __init__(self):
31+
self.headers = {
32+
"Content-Type": "application/json",
33+
}
34+
self.response_code = 200
35+
36+
def send_header(self, key, value):
37+
self.headers[key] = value
38+
39+
def send_response(self, code):
40+
self.response_code = code
41+
42+
class Handler(BaseHTTPRequestHandler):
43+
def do_GET(self):
44+
if mock.get_func is None:
45+
self.send_response(404)
46+
self.send_header("Content-type", "application/json")
47+
self.end_headers()
48+
self.wfile.write(json.dumps({"error": "Not Found"}).encode("utf-8"))
49+
return
50+
51+
request = RequestArg()
52+
response = mock.get_func(request)
53+
54+
self.send_response(request.response_code)
55+
56+
for key, value in request.headers.items():
57+
self.send_header(key, value)
58+
59+
self.end_headers()
60+
61+
if isinstance(response, dict):
62+
response = json.dumps(response)
63+
64+
self.wfile.write(response.encode("utf-8"))
65+
66+
server_address = ("", mock.port)
67+
httpd = HTTPServer(server_address, Handler)
68+
69+
thread = Thread(target=httpd.serve_forever)
70+
thread.setDaemon(True)
71+
thread.start()
72+
73+
yield mock
74+
75+
httpd.shutdown()
76+
httpd.server_close()
77+
thread.join()

tests/unit/test_splunk_rest_client.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from unittest import mock
1818

1919
import pytest
20+
from splunklib.binding import HTTPError
21+
2022
from solnlib.splunk_rest_client import MAX_REQUEST_RETRIES
2123

2224
from requests.exceptions import ConnectionError
@@ -109,3 +111,67 @@ def test_request_retry(http_conn_pool, http_resp, mock_get_splunkd_access_info):
109111
http_conn_pool.side_effect = side_effects
110112
with pytest.raises(ConnectionError):
111113
rest_client.get("test")
114+
115+
116+
@pytest.mark.parametrize("error_code", [429, 500, 503])
117+
def test_request_throttling(http_mock_server, error_code):
118+
@http_mock_server.get
119+
def throttling(request):
120+
"""Mock endpoint to simulate request throttling.
121+
122+
The endpoint will return an error status code for the first 5
123+
requests, and a 200 status code for subsequent requests.
124+
"""
125+
number = getattr(throttling, "call_count", 0)
126+
throttling.call_count = number + 1
127+
128+
if number < 2:
129+
request.send_response(error_code)
130+
request.send_header("Retry-After", "1")
131+
return {"error": f"Error {number}"}
132+
133+
return {"content": "Success"}
134+
135+
rest_client = SplunkRestClient(
136+
"msg_name_1",
137+
"session_key",
138+
"_",
139+
scheme="http",
140+
host="localhost",
141+
port=http_mock_server.port,
142+
)
143+
144+
resp = rest_client.get("test")
145+
assert resp.status == 200
146+
assert resp.body.read().decode("utf-8") == '{"content": "Success"}'
147+
148+
149+
@pytest.mark.parametrize("error_code", [429, 500, 503])
150+
def test_request_throttling_exceeded(http_mock_server, error_code):
151+
@http_mock_server.get
152+
def throttling(request):
153+
"""Mock endpoint to simulate request throttling.
154+
155+
The endpoint will always return an error status code.
156+
"""
157+
number = getattr(throttling, "call_count", 0)
158+
throttling.call_count = number + 1
159+
160+
request.send_response(error_code)
161+
request.send_header("Retry-After", "1")
162+
return {"error": f"Error {number}"}
163+
164+
rest_client = SplunkRestClient(
165+
"msg_name_1",
166+
"session_key",
167+
"_",
168+
scheme="http",
169+
host="localhost",
170+
port=http_mock_server.port,
171+
)
172+
173+
with pytest.raises(HTTPError) as ex:
174+
rest_client.get("test")
175+
176+
assert ex.value.status == error_code
177+
assert ex.value.body.decode("utf-8") == '{"error": "Error 5"}'

0 commit comments

Comments
 (0)