Skip to content

Commit 13c4d70

Browse files
committed
Switch to using ConnectionManager
1 parent 4754bf3 commit 13c4d70

File tree

8 files changed

+88
-154
lines changed

8 files changed

+88
-154
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,10 @@ _build
4646
.idea
4747
.vscode
4848
*~
49+
50+
# tox-specific files
51+
.tox
52+
build
53+
54+
# coverage-specific files
55+
.coverage

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ repos:
3939
types: [python]
4040
files: "^tests/"
4141
args:
42-
- --disable=missing-docstring,consider-using-f-string,duplicate-code
42+
- --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code

adafruit_requests.py

Lines changed: 29 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
* Adafruit CircuitPython firmware for the supported boards:
3232
https://github.com/adafruit/circuitpython/releases
3333
34+
* Adafruit's Connection Manager library:
35+
https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager
36+
3437
"""
3538

3639
__version__ = "0.0.0+auto.0"
@@ -41,6 +44,9 @@
4144

4245
import json as json_module
4346

47+
from adafruit_connectionmanager import get_connection_manager
48+
49+
4450
if not sys.implementation.name == "circuitpython":
4551
from ssl import SSLContext
4652
from types import ModuleType, TracebackType
@@ -176,7 +182,7 @@ def __init__(self, sock: SocketType, session: Optional["Session"] = None) -> Non
176182
http = self._readto(b" ")
177183
if not http:
178184
if session:
179-
session._close_socket(self.socket)
185+
session._connection_manager.close_socket(self.socket)
180186
else:
181187
self.socket.close()
182188
raise RuntimeError("Unable to read HTTP response.")
@@ -320,7 +326,8 @@ def close(self) -> None:
320326
self._throw_away(chunk_size + 2)
321327
self._parse_headers()
322328
if self._session:
323-
self._session._free_socket(self.socket) # pylint: disable=protected-access
329+
# pylint: disable=protected-access
330+
self._session._connection_manager.free_socket(self.socket)
324331
else:
325332
self.socket.close()
326333
self.socket = None
@@ -429,105 +436,26 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt
429436
self.close()
430437

431438

439+
_global_session = None # pylint: disable=invalid-name
440+
441+
432442
class Session:
433443
"""HTTP session that shares sockets and ssl context."""
434444

435445
def __init__(
436446
self,
437447
socket_pool: SocketpoolModuleType,
438448
ssl_context: Optional[SSLContextType] = None,
449+
set_global_session: bool = True,
439450
) -> None:
440-
self._socket_pool = socket_pool
451+
self._connection_manager = get_connection_manager(socket_pool)
441452
self._ssl_context = ssl_context
442-
# Hang onto open sockets so that we can reuse them.
443-
self._open_sockets = {}
444-
self._socket_free = {}
445453
self._last_response = None
446454

447-
def _free_socket(self, socket: SocketType) -> None:
448-
if socket not in self._open_sockets.values():
449-
raise RuntimeError("Socket not from session")
450-
self._socket_free[socket] = True
451-
452-
def _close_socket(self, sock: SocketType) -> None:
453-
sock.close()
454-
del self._socket_free[sock]
455-
key = None
456-
for k in self._open_sockets: # pylint: disable=consider-using-dict-items
457-
if self._open_sockets[k] == sock:
458-
key = k
459-
break
460-
if key:
461-
del self._open_sockets[key]
462-
463-
def _free_sockets(self) -> None:
464-
free_sockets = []
465-
for sock, val in self._socket_free.items():
466-
if val:
467-
free_sockets.append(sock)
468-
for sock in free_sockets:
469-
self._close_socket(sock)
470-
471-
def _get_socket(
472-
self, host: str, port: int, proto: str, *, timeout: float = 1
473-
) -> CircuitPythonSocketType:
474-
# pylint: disable=too-many-branches
475-
key = (host, port, proto)
476-
if key in self._open_sockets:
477-
sock = self._open_sockets[key]
478-
if self._socket_free[sock]:
479-
self._socket_free[sock] = False
480-
return sock
481-
if proto == "https:" and not self._ssl_context:
482-
raise RuntimeError(
483-
"ssl_context must be set before using adafruit_requests for https"
484-
)
485-
addr_info = self._socket_pool.getaddrinfo(
486-
host, port, 0, self._socket_pool.SOCK_STREAM
487-
)[0]
488-
retry_count = 0
489-
sock = None
490-
last_exc = None
491-
while retry_count < 5 and sock is None:
492-
if retry_count > 0:
493-
if any(self._socket_free.items()):
494-
self._free_sockets()
495-
else:
496-
raise RuntimeError("Sending request failed") from last_exc
497-
retry_count += 1
498-
499-
try:
500-
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
501-
except OSError as exc:
502-
last_exc = exc
503-
continue
504-
except RuntimeError as exc:
505-
last_exc = exc
506-
continue
507-
508-
connect_host = addr_info[-1][0]
509-
if proto == "https:":
510-
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
511-
connect_host = host
512-
sock.settimeout(timeout) # socket read timeout
513-
514-
try:
515-
sock.connect((connect_host, port))
516-
except MemoryError as exc:
517-
last_exc = exc
518-
sock.close()
519-
sock = None
520-
except OSError as exc:
521-
last_exc = exc
522-
sock.close()
523-
sock = None
524-
525-
if sock is None:
526-
raise RuntimeError("Repeated socket failures") from last_exc
527-
528-
self._open_sockets[key] = sock
529-
self._socket_free[sock] = False
530-
return sock
455+
if set_global_session:
456+
# pylint: disable=global-statement
457+
global _global_session
458+
_global_session = self
531459

532460
@staticmethod
533461
def _send(socket: SocketType, data: bytes):
@@ -647,7 +575,9 @@ def request(
647575
last_exc = None
648576
while retry_count < 2:
649577
retry_count += 1
650-
socket = self._get_socket(host, port, proto, timeout=timeout)
578+
socket = self._connection_manager.get_socket(
579+
host, port, proto, timeout=timeout, ssl_context=self._ssl_context
580+
)
651581
ok = True
652582
try:
653583
self._send_request(socket, host, method, path, headers, data, json)
@@ -668,7 +598,7 @@ def request(
668598
if result == b"H":
669599
# Things seem to be ok so break with socket set.
670600
break
671-
self._close_socket(socket)
601+
self._connection_manager.close_socket(socket)
672602
socket = None
673603

674604
if not socket:
@@ -727,54 +657,6 @@ def delete(self, url: str, **kw) -> Response:
727657
return self.request("DELETE", url, **kw)
728658

729659

730-
# Backwards compatible API:
731-
732-
_default_session = None # pylint: disable=invalid-name
733-
734-
735-
class _FakeSSLSocket:
736-
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
737-
self._socket = socket
738-
self._mode = tls_mode
739-
self.settimeout = socket.settimeout
740-
self.send = socket.send
741-
self.recv = socket.recv
742-
self.close = socket.close
743-
self.recv_into = socket.recv_into
744-
745-
def connect(self, address: Tuple[str, int]) -> None:
746-
"""connect wrapper to add non-standard mode parameter"""
747-
try:
748-
return self._socket.connect(address, self._mode)
749-
except RuntimeError as error:
750-
raise OSError(errno.ENOMEM) from error
751-
752-
753-
class _FakeSSLContext:
754-
def __init__(self, iface: InterfaceType) -> None:
755-
self._iface = iface
756-
757-
def wrap_socket(
758-
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
759-
) -> _FakeSSLSocket:
760-
"""Return the same socket"""
761-
# pylint: disable=unused-argument
762-
return _FakeSSLSocket(socket, self._iface.TLS_MODE)
763-
764-
765-
def set_socket(
766-
sock: SocketpoolModuleType, iface: Optional[InterfaceType] = None
767-
) -> None:
768-
"""Legacy API for setting the socket and network interface. Use a `Session` instead."""
769-
global _default_session # pylint: disable=global-statement,invalid-name
770-
if not iface:
771-
# pylint: disable=protected-access
772-
_default_session = Session(sock, _FakeSSLContext(sock._the_interface))
773-
else:
774-
_default_session = Session(sock, _FakeSSLContext(iface))
775-
sock.set_interface(iface)
776-
777-
778660
def request(
779661
method: str,
780662
url: str,
@@ -786,7 +668,7 @@ def request(
786668
) -> None:
787669
"""Send HTTP request"""
788670
# pylint: disable=too-many-arguments
789-
_default_session.request(
671+
_global_session.request(
790672
method,
791673
url,
792674
data=data,
@@ -799,29 +681,29 @@ def request(
799681

800682
def head(url: str, **kw):
801683
"""Send HTTP HEAD request"""
802-
return _default_session.request("HEAD", url, **kw)
684+
return _global_session.request("HEAD", url, **kw)
803685

804686

805687
def get(url: str, **kw):
806688
"""Send HTTP GET request"""
807-
return _default_session.request("GET", url, **kw)
689+
return _global_session.request("GET", url, **kw)
808690

809691

810692
def post(url: str, **kw):
811693
"""Send HTTP POST request"""
812-
return _default_session.request("POST", url, **kw)
694+
return _global_session.request("POST", url, **kw)
813695

814696

815697
def put(url: str, **kw):
816698
"""Send HTTP PUT request"""
817-
return _default_session.request("PUT", url, **kw)
699+
return _global_session.request("PUT", url, **kw)
818700

819701

820702
def patch(url: str, **kw):
821703
"""Send HTTP PATCH request"""
822-
return _default_session.request("PATCH", url, **kw)
704+
return _global_session.request("PATCH", url, **kw)
823705

824706

825707
def delete(url: str, **kw):
826708
"""Send HTTP DELETE request"""
827-
return _default_session.request("DELETE", url, **kw)
709+
return _global_session.request("DELETE", url, **kw)

conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-FileCopyrightText: 2023 Justin Myers for Adafruit Industries
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
""" PyTest Setup """
6+
7+
import pytest
8+
import adafruit_connectionmanager
9+
10+
11+
@pytest.fixture(autouse=True)
12+
def reset_connection_manager(monkeypatch):
13+
"""Reset the ConnectionManager, since it's a singlton and will hold data"""
14+
monkeypatch.setattr(
15+
"adafruit_requests.get_connection_manager",
16+
adafruit_connectionmanager.ConnectionManager,
17+
)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
# SPDX-License-Identifier: Unlicense
44

55
Adafruit-Blinka
6+
Adafruit-Circuitpython-ConnectionManager@git+https://github.com/justmobilize/Adafruit_CircuitPython_ConnectionManager

tests/concurrent_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RESPONSE = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + TEXT
1818

1919

20-
def test_second_connect_fails_memoryerror(): # pylint: disable=invalid-name
20+
def test_second_connect_fails_memoryerror():
2121
pool = mocket.MocketPool()
2222
pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),)
2323
sock = mocket.Mocket(RESPONSE)
@@ -59,7 +59,7 @@ def test_second_connect_fails_memoryerror(): # pylint: disable=invalid-name
5959
assert pool.socket.call_count == 3
6060

6161

62-
def test_second_connect_fails_oserror(): # pylint: disable=invalid-name
62+
def test_second_connect_fails_oserror():
6363
pool = mocket.MocketPool()
6464
pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),)
6565
sock = mocket.Mocket(RESPONSE)

tests/reuse_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_second_send_fails():
202202
assert pool.socket.call_count == 2
203203

204204

205-
def test_second_send_lies_recv_fails(): # pylint: disable=invalid-name
205+
def test_second_send_lies_recv_fails():
206206
pool = mocket.MocketPool()
207207
pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),)
208208
sock = mocket.Mocket(RESPONSE)

tox.ini

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,36 @@
33
# SPDX-License-Identifier: MIT
44

55
[tox]
6-
envlist = py38
6+
envlist = py311
77

88
[testenv]
9-
changedir = {toxinidir}/tests
10-
deps = pytest==6.2.5
9+
description = run tests
10+
deps =
11+
pytest==7.4.3
1112
commands = pytest
13+
14+
[testenv:coverage]
15+
description = run coverage
16+
deps =
17+
pytest==7.4.3
18+
pytest-cov==4.1.0
19+
package = editable
20+
commands =
21+
coverage run --source=. --omit=tests/* --branch {posargs} -m pytest
22+
coverage report
23+
coverage html
24+
25+
[testenv:lint]
26+
description = run linters
27+
deps =
28+
pre-commit==3.6.0
29+
skip_install = true
30+
commands = pre-commit run {posargs}
31+
32+
[testenv:docs]
33+
description = build docs
34+
deps =
35+
-r requirements.txt
36+
-r docs/requirements.txt
37+
skip_install = true
38+
commands = sphinx-build -E -W -b html docs/. _build/html

0 commit comments

Comments
 (0)