88#
99# and added OCSP validator on the top.
1010import logging
11+ import os
12+ import ssl
1113import time
1214import weakref
1315from contextvars import ContextVar
1416from functools import wraps
15- from inspect import getfullargspec as get_args
17+ from inspect import signature as _sig
1618from socket import socket
1719from typing import Any
1820
3840log = logging .getLogger (__name__ )
3941
4042
43+ # Helper utilities (private)
44+ def _resolve_cafile (kwargs : dict [str , Any ]) -> str | None :
45+ """Resolve CA bundle path from kwargs or standard environment variables.
46+
47+ Precedence:
48+ 1) kwargs['ca_certs'] if provided by caller
49+ 2) REQUESTS_CA_BUNDLE
50+ 3) SSL_CERT_FILE
51+ """
52+ caf = kwargs .get ("ca_certs" )
53+ if caf :
54+ return caf
55+ return os .environ .get ("REQUESTS_CA_BUNDLE" ) or os .environ .get ("SSL_CERT_FILE" )
56+
57+
58+ def _ensure_partial_chain_on_context (ctx : PyOpenSSLContext , cafile : str | None ) -> None :
59+ """Load CA bundle (when provided) and enable OpenSSL partial-chain support on ctx."""
60+ if cafile :
61+ try :
62+ ctx .load_verify_locations (cafile = cafile , capath = None )
63+ except (ssl .SSLError , OSError , ValueError ):
64+ # Leave context unchanged; handshake/validation surfaces failures
65+ pass
66+ try :
67+ store = ctx ._ctx .get_cert_store ()
68+ from OpenSSL import crypto as _crypto
69+
70+ if hasattr (_crypto , "X509StoreFlags" ) and hasattr (
71+ _crypto .X509StoreFlags , "PARTIAL_CHAIN"
72+ ):
73+ store .set_flags (_crypto .X509StoreFlags .PARTIAL_CHAIN )
74+ except (AttributeError , ImportError , OpenSSL .SSL .Error , OSError , ValueError ):
75+ # Best-effort; if not available, default chain building applies
76+ pass
77+
78+
79+ def _build_context_with_partial_chain (cafile : str | None ) -> PyOpenSSLContext :
80+ """Create PyOpenSSL context configured for CERT_REQUIRED and partial-chain trust."""
81+ ctx = PyOpenSSLContext (ssl_ .PROTOCOL_TLS_CLIENT )
82+ try :
83+ ctx .verify_mode = ssl .CERT_REQUIRED
84+ except Exception :
85+ pass
86+ _ensure_partial_chain_on_context (ctx , cafile )
87+ return ctx
88+
89+
4190# Store a *weak* reference so that the context variable doesn’t prolong the
4291# lifetime of the SessionManager. Once all owning connections are GC-ed the
43- # weakref goes dead and OCSP will fall back to its local manager (but most likely won't be used ever again anyway).
92+ # weakref goes dead and OCSP will fall back to its local manager (but most
93+ # likely won't be used ever again anyway).
4494_CURRENT_SESSION_MANAGER : ContextVar [weakref .ref [SessionManager ] | None ] = ContextVar (
4595 "_CURRENT_SESSION_MANAGER" ,
4696 default = None ,
@@ -71,7 +121,10 @@ def set_current_session_manager(sm: SessionManager | None) -> Any:
71121 Called from SnowflakeConnection so that OCSP downloads
72122 use the same proxy / header configuration as the initiating connection.
73123
74- Alternative approach would be moving method inject_into_urllib3() inside connection initialization, but in case this delay (from module import time to connection initialization time) would cause some code to break we stayed with this approach, having in mind soon OCSP deprecation.
124+ Alternative approach would be moving method inject_into_urllib3() inside
125+ connection initialization, but in case this delay (from module import time
126+ to connection initialization time) would cause some code to break we stayed
127+ with this approach, having in mind soon OCSP deprecation.
75128 """
76129 return _CURRENT_SESSION_MANAGER .set (weakref .ref (sm ) if sm is not None else None )
77130
@@ -93,37 +146,29 @@ def inject_into_urllib3() -> None:
93146
94147@wraps (ssl_ .ssl_wrap_socket )
95148def ssl_wrap_socket_with_ocsp (* args : Any , ** kwargs : Any ) -> WrappedSocket :
96- # Extract host_name
97- hostname_index = get_args (ssl_ .ssl_wrap_socket ).args .index ("server_hostname" )
98- server_hostname = (
99- args [hostname_index ]
100- if len (args ) > hostname_index
101- else kwargs .get ("server_hostname" , None )
102- )
103- # Remove context if present
104- ssl_context_index = get_args (ssl_ .ssl_wrap_socket ).args .index ("ssl_context" )
105- context_in_args = len (args ) > ssl_context_index
106- ssl_context = (
107- args [hostname_index ] if context_in_args else kwargs .get ("ssl_context" , None )
108- )
109- if not isinstance (ssl_context , PyOpenSSLContext ):
110- # Create new default context
111- if context_in_args :
112- new_args = list (args )
113- new_args [ssl_context_index ] = None
114- args = tuple (new_args )
115- else :
116- del kwargs ["ssl_context" ]
117- # Fix ca certs location
118- ca_certs_index = get_args (ssl_ .ssl_wrap_socket ).args .index ("ca_certs" )
119- ca_certs_in_args = len (args ) > ca_certs_index
120- if not ca_certs_in_args and not kwargs .get ("ca_certs" ):
121- kwargs ["ca_certs" ] = certifi .where ()
122-
123- ret = ssl_ .ssl_wrap_socket (* args , ** kwargs )
149+ # Bind passed args/kwargs to the underlying signature to support both positional and keyword calls
150+ bound = _sig (ssl_ .ssl_wrap_socket ).bind_partial (* args , ** kwargs )
151+ params = bound .arguments
152+
153+ server_hostname = params .get ("server_hostname" )
154+
155+ # Ensure CA bundle default if not provided
156+ if not params .get ("ca_certs" ):
157+ params ["ca_certs" ] = certifi .where ()
158+
159+ # Ensure PyOpenSSL context with partial-chain is used if none or wrong type provided
160+ provided_ctx = params .get ("ssl_context" )
161+ if not isinstance (provided_ctx , PyOpenSSLContext ):
162+ cafile_for_ctx = _resolve_cafile (params )
163+ params ["ssl_context" ] = _build_context_with_partial_chain (cafile_for_ctx )
164+ else :
165+ # If a PyOpenSSLContext is provided, ensure it trusts the provided CA and partial-chain is enabled
166+ _ensure_partial_chain_on_context (provided_ctx , _resolve_cafile (params ))
167+
168+ ret = ssl_ .ssl_wrap_socket (** params )
124169
125170 log .debug (
126- "OCSP Mode: %s, " " OCSP response cache file name: %s" ,
171+ "OCSP Mode: %s, OCSP response cache file name: %s" ,
127172 FEATURE_OCSP_MODE .name ,
128173 FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME ,
129174 )
@@ -137,10 +182,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
137182 ).validate (server_hostname , ret .connection )
138183 if not v :
139184 raise OperationalError (
140- msg = (
141- "The certificate is revoked or "
142- "could not be validated: hostname={}" .format (server_hostname )
143- ),
185+ msg = f"The certificate is revoked or could not be validated: hostname={ server_hostname } " ,
144186 errno = ER_OCSP_RESPONSE_CERT_STATUS_REVOKED ,
145187 )
146188 else :
0 commit comments