Skip to content

Commit 7f6bb71

Browse files
authored
[Azure AMQP] Remove Deprecated SSL.Wrap_Socket (#31524)
* remove deprecated code * some re-ordering * bring changes over to SB * fix order * update async side * sb async transport * comment for ordering * changes to sni wrap code * sync code * enable auto ssl handshake * fix type of ssl opts * bring change to SB * minor fix * refactor out server_side * remove comment * sync changes * remove whitespace * get rid of whitespace * remove whitespace * fix pylint
1 parent f8f8742 commit 7f6bb71

4 files changed

Lines changed: 116 additions & 112 deletions

File tree

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/_transport.py

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def __init__(
501501
self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs
502502
):
503503
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {}
504+
self.sslopts['server_hostname'] = host
504505
self._read_buffer = BytesIO()
505506
super(SSLTransport, self).__init__(
506507
host, port=port, socket_timeout=socket_timeout, **kwargs
@@ -509,7 +510,6 @@ def __init__(
509510
def _setup_transport(self):
510511
"""Wrap the socket in an SSL object."""
511512
self.sock = self._wrap_socket(self.sock, **self.sslopts)
512-
self.sock.do_handshake()
513513
self._quick_recv = self.sock.recv
514514

515515
def _wrap_socket(self, sock, context=None, **sslopts):
@@ -531,10 +531,9 @@ def _wrap_socket_sni(
531531
sock,
532532
keyfile=None,
533533
certfile=None,
534-
server_side=False,
535534
cert_reqs=ssl.CERT_REQUIRED,
536535
ca_certs=None,
537-
do_handshake_on_connect=False,
536+
do_handshake_on_connect=True,
538537
suppress_ragged_eofs=True,
539538
server_hostname=None,
540539
ciphers=None,
@@ -548,7 +547,6 @@ def _wrap_socket_sni(
548547
:param socket.socket sock: socket to wrap
549548
:param str or None keyfile: key file path
550549
:param str or None certfile: cert file path
551-
:param bool or None server_side: server side socket
552550
:param int cert_reqs: cert requirements
553551
:param str or None ca_certs: ca certs file path
554552
:param bool do_handshake_on_connect: do handshake on connect
@@ -562,44 +560,39 @@ def _wrap_socket_sni(
562560
# Setup the right SSL version; default to optimal versions across
563561
# ssl implementations
564562
if ssl_version is None:
565-
ssl_version = ssl.PROTOCOL_TLS
563+
ssl_version = ssl.PROTOCOL_TLS_CLIENT
564+
purpose = ssl.Purpose.SERVER_AUTH
566565

567566
opts = {
568567
"sock": sock,
569-
"keyfile": keyfile,
570-
"certfile": certfile,
571-
"server_side": server_side,
572-
"cert_reqs": cert_reqs,
573-
"ca_certs": ca_certs,
574568
"do_handshake_on_connect": do_handshake_on_connect,
575569
"suppress_ragged_eofs": suppress_ragged_eofs,
576-
"ciphers": ciphers,
577-
#'ssl_version': ssl_version
570+
"server_hostname": server_hostname,
578571
}
579572

580-
# TODO: We need to refactor this.
581-
try:
582-
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
583-
except FileNotFoundError as exc:
584-
# FileNotFoundError does not have missing filename info, so adding it below.
585-
# Assuming that this must be ca_certs, since this is the only file path that
586-
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
587-
# For uamqp exception parity. Remove later when resolving issue #27128.
588-
exc.filename = {"ca_certs": ca_certs}
589-
raise exc
590-
# Set SNI headers if supported
591-
if (
592-
(server_hostname is not None)
593-
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
594-
and (hasattr(ssl, "SSLContext"))
595-
):
596-
context = ssl.SSLContext(opts["ssl_version"])
573+
context = ssl.SSLContext(ssl_version)
574+
575+
if ca_certs is not None:
576+
try:
577+
context.load_verify_locations(ca_certs)
578+
except FileNotFoundError as exc:
579+
exc.filename = {"ca_certs": ca_certs}
580+
raise exc from None
581+
elif context.verify_mode != ssl.CERT_NONE:
582+
# load the default system root CA certs.
583+
context.load_default_certs(purpose=purpose)
584+
585+
if certfile is not None:
586+
context.load_cert_chain(certfile, keyfile)
587+
588+
if ciphers is not None:
589+
context.set_ciphers(ciphers)
590+
591+
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
592+
context.check_hostname = False
597593
context.verify_mode = cert_reqs
598-
if cert_reqs != ssl.CERT_NONE:
599-
context.check_hostname = True
600-
if (certfile is not None) and (keyfile is not None):
601-
context.load_cert_chain(certfile, keyfile)
602-
sock = context.wrap_socket(sock, server_hostname=server_hostname)
594+
595+
sock = context.wrap_socket(**opts)
603596
return sock
604597

605598
def _shutdown_transport(self):

sdk/eventhub/azure-eventhub/azure/eventhub/_pyamqp/aio/_transport_async.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts):
185185
return self._build_ssl_context(**sslopts.pop("context"))
186186
ssl_version = sslopts.get("ssl_version")
187187
if ssl_version is None:
188-
ssl_version = ssl.PROTOCOL_TLS
188+
ssl_version = ssl.PROTOCOL_TLS_CLIENT
189+
190+
context = ssl.SSLContext(ssl_version)
191+
192+
purpose = ssl.Purpose.SERVER_AUTH
193+
194+
ca_certs = sslopts.get("ca_certs")
195+
196+
if ca_certs is not None:
197+
try:
198+
context.load_verify_locations(ca_certs)
199+
except FileNotFoundError as exc:
200+
# FileNotFoundError does not have missing filename info, so adding it below.
201+
# since this is the only file path that users can pass in
202+
# (`connection_verify` in the EH/SB clients) through opts above.
203+
exc.filename = {"ca_certs": ca_certs}
204+
raise exc from None
205+
elif context.verify_mode != ssl.CERT_NONE:
206+
# load the default system root CA certs.
207+
context.load_default_certs(purpose=purpose)
208+
209+
certfile = sslopts.get("certfile")
210+
keyfile = sslopts.get("keyfile")
211+
if certfile is not None:
212+
context.load_cert_chain(certfile, keyfile)
213+
189214

190-
# Set SNI headers if supported
191215
server_hostname = sslopts.get("server_hostname")
192-
if (
193-
(server_hostname is not None)
194-
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
195-
and (hasattr(ssl, "SSLContext"))
196-
):
197-
context = ssl.SSLContext(ssl_version)
198-
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
199-
certfile = sslopts.get("certfile")
200-
keyfile = sslopts.get("keyfile")
216+
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
217+
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
218+
context.check_hostname = False
201219
context.verify_mode = cert_reqs
202-
if cert_reqs != ssl.CERT_NONE:
203-
context.check_hostname = True
204-
if (certfile is not None) and (keyfile is not None):
205-
context.load_cert_chain(certfile, keyfile)
206-
return context
207-
ca_certs = sslopts.get("ca_certs")
208-
if ca_certs:
209-
context = ssl.SSLContext(ssl_version)
210-
context.load_verify_locations(ca_certs)
211-
return context
212-
return True
220+
221+
return context
213222
except TypeError:
214223
raise TypeError(
215224
"SSL configuration must be a dictionary, or the value True."

sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/_transport.py

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def __init__(
501501
self, host, *, port=AMQPS_PORT, socket_timeout=None, ssl_opts=None, **kwargs
502502
):
503503
self.sslopts = ssl_opts if isinstance(ssl_opts, dict) else {}
504+
self.sslopts['server_hostname'] = host
504505
self._read_buffer = BytesIO()
505506
super(SSLTransport, self).__init__(
506507
host, port=port, socket_timeout=socket_timeout, **kwargs
@@ -509,7 +510,6 @@ def __init__(
509510
def _setup_transport(self):
510511
"""Wrap the socket in an SSL object."""
511512
self.sock = self._wrap_socket(self.sock, **self.sslopts)
512-
self.sock.do_handshake()
513513
self._quick_recv = self.sock.recv
514514

515515
def _wrap_socket(self, sock, context=None, **sslopts):
@@ -531,10 +531,9 @@ def _wrap_socket_sni(
531531
sock,
532532
keyfile=None,
533533
certfile=None,
534-
server_side=False,
535534
cert_reqs=ssl.CERT_REQUIRED,
536535
ca_certs=None,
537-
do_handshake_on_connect=False,
536+
do_handshake_on_connect=True,
538537
suppress_ragged_eofs=True,
539538
server_hostname=None,
540539
ciphers=None,
@@ -548,7 +547,6 @@ def _wrap_socket_sni(
548547
:param socket.socket sock: socket to wrap
549548
:param str or None keyfile: key file path
550549
:param str or None certfile: cert file path
551-
:param bool or None server_side: server side socket
552550
:param int cert_reqs: cert requirements
553551
:param str or None ca_certs: ca certs file path
554552
:param bool do_handshake_on_connect: do handshake on connect
@@ -562,44 +560,39 @@ def _wrap_socket_sni(
562560
# Setup the right SSL version; default to optimal versions across
563561
# ssl implementations
564562
if ssl_version is None:
565-
ssl_version = ssl.PROTOCOL_TLS
563+
ssl_version = ssl.PROTOCOL_TLS_CLIENT
564+
purpose = ssl.Purpose.SERVER_AUTH
566565

567566
opts = {
568567
"sock": sock,
569-
"keyfile": keyfile,
570-
"certfile": certfile,
571-
"server_side": server_side,
572-
"cert_reqs": cert_reqs,
573-
"ca_certs": ca_certs,
574568
"do_handshake_on_connect": do_handshake_on_connect,
575569
"suppress_ragged_eofs": suppress_ragged_eofs,
576-
"ciphers": ciphers,
577-
#'ssl_version': ssl_version
570+
"server_hostname": server_hostname,
578571
}
579572

580-
# TODO: We need to refactor this.
581-
try:
582-
sock = ssl.wrap_socket(**opts) # pylint: disable=deprecated-method
583-
except FileNotFoundError as exc:
584-
# FileNotFoundError does not have missing filename info, so adding it below.
585-
# Assuming that this must be ca_certs, since this is the only file path that
586-
# users can pass in (`connection_verify` in the EH/SB clients) through opts above.
587-
# For uamqp exception parity. Remove later when resolving issue #27128.
588-
exc.filename = {"ca_certs": ca_certs}
589-
raise exc
590-
# Set SNI headers if supported
591-
if (
592-
(server_hostname is not None)
593-
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
594-
and (hasattr(ssl, "SSLContext"))
595-
):
596-
context = ssl.SSLContext(opts["ssl_version"])
573+
context = ssl.SSLContext(ssl_version)
574+
575+
if ca_certs is not None:
576+
try:
577+
context.load_verify_locations(ca_certs)
578+
except FileNotFoundError as exc:
579+
exc.filename = {"ca_certs": ca_certs}
580+
raise exc from None
581+
elif context.verify_mode != ssl.CERT_NONE:
582+
# load the default system root CA certs.
583+
context.load_default_certs(purpose=purpose)
584+
585+
if certfile is not None:
586+
context.load_cert_chain(certfile, keyfile)
587+
588+
if ciphers is not None:
589+
context.set_ciphers(ciphers)
590+
591+
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
592+
context.check_hostname = False
597593
context.verify_mode = cert_reqs
598-
if cert_reqs != ssl.CERT_NONE:
599-
context.check_hostname = True
600-
if (certfile is not None) and (keyfile is not None):
601-
context.load_cert_chain(certfile, keyfile)
602-
sock = context.wrap_socket(sock, server_hostname=server_hostname)
594+
595+
sock = context.wrap_socket(**opts)
603596
return sock
604597

605598
def _shutdown_transport(self):

sdk/servicebus/azure-servicebus/azure/servicebus/_pyamqp/aio/_transport_async.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -185,31 +185,40 @@ def _build_ssl_opts(self, sslopts):
185185
return self._build_ssl_context(**sslopts.pop("context"))
186186
ssl_version = sslopts.get("ssl_version")
187187
if ssl_version is None:
188-
ssl_version = ssl.PROTOCOL_TLS
188+
ssl_version = ssl.PROTOCOL_TLS_CLIENT
189+
190+
context = ssl.SSLContext(ssl_version)
191+
192+
purpose = ssl.Purpose.SERVER_AUTH
193+
194+
ca_certs = sslopts.get("ca_certs")
195+
196+
if ca_certs is not None:
197+
try:
198+
context.load_verify_locations(ca_certs)
199+
except FileNotFoundError as exc:
200+
# FileNotFoundError does not have missing filename info, so adding it below.
201+
# since this is the only file path that users can pass in
202+
# (`connection_verify` in the EH/SB clients) through opts above.
203+
exc.filename = {"ca_certs": ca_certs}
204+
raise exc from None
205+
elif context.verify_mode != ssl.CERT_NONE:
206+
# load the default system root CA certs.
207+
context.load_default_certs(purpose=purpose)
208+
209+
certfile = sslopts.get("certfile")
210+
keyfile = sslopts.get("keyfile")
211+
if certfile is not None:
212+
context.load_cert_chain(certfile, keyfile)
213+
189214

190-
# Set SNI headers if supported
191215
server_hostname = sslopts.get("server_hostname")
192-
if (
193-
(server_hostname is not None)
194-
and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI)
195-
and (hasattr(ssl, "SSLContext"))
196-
):
197-
context = ssl.SSLContext(ssl_version)
198-
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
199-
certfile = sslopts.get("certfile")
200-
keyfile = sslopts.get("keyfile")
216+
cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED)
217+
if cert_reqs == ssl.CERT_NONE and server_hostname is None:
218+
context.check_hostname = False
201219
context.verify_mode = cert_reqs
202-
if cert_reqs != ssl.CERT_NONE:
203-
context.check_hostname = True
204-
if (certfile is not None) and (keyfile is not None):
205-
context.load_cert_chain(certfile, keyfile)
206-
return context
207-
ca_certs = sslopts.get("ca_certs")
208-
if ca_certs:
209-
context = ssl.SSLContext(ssl_version)
210-
context.load_verify_locations(ca_certs)
211-
return context
212-
return True
220+
221+
return context
213222
except TypeError:
214223
raise TypeError(
215224
"SSL configuration must be a dictionary, or the value True."

0 commit comments

Comments
 (0)