3030
3131if SSL_AVAILABLE :
3232 import ssl
33- from ssl import SSLContext , TLSVersion
33+ from ssl import SSLContext , TLSVersion , VerifyFlags
3434else :
3535 ssl = None
3636 TLSVersion = None
3737 SSLContext = None
38+ VerifyFlags = None
3839
3940from ..auth .token import TokenInterface
4041from ..event import AsyncAfterConnectionReleasedEvent , EventDispatcher
@@ -793,6 +794,8 @@ def __init__(
793794 ssl_keyfile : Optional [str ] = None ,
794795 ssl_certfile : Optional [str ] = None ,
795796 ssl_cert_reqs : Union [str , ssl .VerifyMode ] = "required" ,
797+ ssl_include_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
798+ ssl_exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
796799 ssl_ca_certs : Optional [str ] = None ,
797800 ssl_ca_data : Optional [str ] = None ,
798801 ssl_check_hostname : bool = True ,
@@ -807,6 +810,8 @@ def __init__(
807810 keyfile = ssl_keyfile ,
808811 certfile = ssl_certfile ,
809812 cert_reqs = ssl_cert_reqs ,
813+ include_verify_flags = ssl_include_verify_flags ,
814+ exclude_verify_flags = ssl_exclude_verify_flags ,
810815 ca_certs = ssl_ca_certs ,
811816 ca_data = ssl_ca_data ,
812817 check_hostname = ssl_check_hostname ,
@@ -832,6 +837,14 @@ def certfile(self):
832837 def cert_reqs (self ):
833838 return self .ssl_context .cert_reqs
834839
840+ @property
841+ def include_verify_flags (self ):
842+ return self .ssl_context .include_verify_flags
843+
844+ @property
845+ def exclude_verify_flags (self ):
846+ return self .ssl_context .exclude_verify_flags
847+
835848 @property
836849 def ca_certs (self ):
837850 return self .ssl_context .ca_certs
@@ -854,6 +867,8 @@ class RedisSSLContext:
854867 "keyfile" ,
855868 "certfile" ,
856869 "cert_reqs" ,
870+ "include_verify_flags" ,
871+ "exclude_verify_flags" ,
857872 "ca_certs" ,
858873 "ca_data" ,
859874 "context" ,
@@ -867,6 +882,8 @@ def __init__(
867882 keyfile : Optional [str ] = None ,
868883 certfile : Optional [str ] = None ,
869884 cert_reqs : Optional [Union [str , ssl .VerifyMode ]] = None ,
885+ include_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
886+ exclude_verify_flags : Optional [List ["ssl.VerifyFlags" ]] = None ,
870887 ca_certs : Optional [str ] = None ,
871888 ca_data : Optional [str ] = None ,
872889 check_hostname : bool = False ,
@@ -892,6 +909,8 @@ def __init__(
892909 )
893910 cert_reqs = CERT_REQS [cert_reqs ]
894911 self .cert_reqs = cert_reqs
912+ self .include_verify_flags = include_verify_flags
913+ self .exclude_verify_flags = exclude_verify_flags
895914 self .ca_certs = ca_certs
896915 self .ca_data = ca_data
897916 self .check_hostname = (
@@ -906,6 +925,12 @@ def get(self) -> SSLContext:
906925 context = ssl .create_default_context ()
907926 context .check_hostname = self .check_hostname
908927 context .verify_mode = self .cert_reqs
928+ if self .include_verify_flags :
929+ for flag in self .include_verify_flags :
930+ context .verify_flags |= flag
931+ if self .exclude_verify_flags :
932+ for flag in self .exclude_verify_flags :
933+ context .verify_flags &= ~ flag
909934 if self .certfile and self .keyfile :
910935 context .load_cert_chain (certfile = self .certfile , keyfile = self .keyfile )
911936 if self .ca_certs or self .ca_data :
@@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
953978 return bool (value )
954979
955980
981+ def parse_ssl_verify_flags (value ):
982+ # flags are passed in as a string representation of a list,
983+ # e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
984+ verify_flags_str = value .replace ("[" , "" ).replace ("]" , "" )
985+
986+ verify_flags = []
987+ for flag in verify_flags_str .split ("," ):
988+ flag = flag .strip ()
989+ if not hasattr (VerifyFlags , flag ):
990+ raise ValueError (f"Invalid ssl verify flag: { flag } " )
991+ verify_flags .append (getattr (VerifyFlags , flag ))
992+ return verify_flags
993+
994+
956995URL_QUERY_ARGUMENT_PARSERS : Mapping [str , Callable [..., object ]] = MappingProxyType (
957996 {
958997 "db" : int ,
@@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
9631002 "max_connections" : int ,
9641003 "health_check_interval" : int ,
9651004 "ssl_check_hostname" : to_bool ,
1005+ "ssl_include_verify_flags" : parse_ssl_verify_flags ,
1006+ "ssl_exclude_verify_flags" : parse_ssl_verify_flags ,
9661007 "timeout" : float ,
9671008 }
9681009)
@@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:
10211062
10221063 if parsed .scheme == "rediss" :
10231064 kwargs ["connection_class" ] = SSLConnection
1065+
10241066 else :
10251067 valid_schemes = "redis://, rediss://, unix://"
10261068 raise ValueError (
0 commit comments