Skip to content

Commit 83e43f5

Browse files
committed
Harden AsyncSSH state machine against message injection during handshake
This commit puts additional restrictions on when messages are accepted during the SSH handshake to avoid message injection attacks from a rogue client or server. More detailed information will be available in CVE-2023-46445 and CVE-2023-46446, to be published shortly. Thanks go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for identifying and reporting these vulnerabilities and providing detailed analysis and suggestions for how to protect against them, as well as review comments on the proposed fix.
1 parent f67234f commit 83e43f5

2 files changed

Lines changed: 207 additions & 76 deletions

File tree

asyncssh/connection.py

Lines changed: 83 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
899899
self._can_send_ext_info = False
900900
self._extensions_to_send: 'OrderedDict[bytes, bytes]' = OrderedDict()
901901

902+
self._can_recv_ext_info = False
903+
902904
self._server_sig_algs: Set[bytes] = set()
903905

904906
self._next_service: Optional[bytes] = None
@@ -908,6 +910,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
908910
self._auth: Optional[Auth] = None
909911
self._auth_in_progress = False
910912
self._auth_complete = False
913+
self._auth_final = False
911914
self._auth_methods = [b'none']
912915
self._auth_was_trivial = True
913916
self._username = ''
@@ -1538,15 +1541,25 @@ def _recv_packet(self) -> bool:
15381541
skip_reason = ''
15391542
exc_reason = ''
15401543

1541-
if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
1542-
if self._ignore_first_kex: # pragma: no cover
1543-
skip_reason = 'ignored first kex'
1544-
self._ignore_first_kex = False
1544+
if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
1545+
if self._kex:
1546+
if self._ignore_first_kex: # pragma: no cover
1547+
skip_reason = 'ignored first kex'
1548+
self._ignore_first_kex = False
1549+
else:
1550+
handler = self._kex
15451551
else:
1546-
handler = self._kex
1547-
elif (self._auth and
1548-
MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST):
1549-
handler = self._auth
1552+
skip_reason = 'kex not in progress'
1553+
exc_reason = 'Key exchange not in progress'
1554+
elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST:
1555+
if self._auth:
1556+
handler = self._auth
1557+
else:
1558+
skip_reason = 'auth not in progress'
1559+
exc_reason = 'Authentication not in progress'
1560+
elif pkttype > MSG_KEX_LAST and not self._recv_encryption:
1561+
skip_reason = 'invalid request before kex complete'
1562+
exc_reason = 'Invalid request before key exchange was complete'
15501563
elif pkttype > MSG_USERAUTH_LAST and not self._auth_complete:
15511564
skip_reason = 'invalid request before auth complete'
15521565
exc_reason = 'Invalid request before authentication was complete'
@@ -1579,6 +1592,9 @@ def _recv_packet(self) -> bool:
15791592
if exc_reason:
15801593
raise ProtocolError(exc_reason)
15811594

1595+
if pkttype > MSG_USERAUTH_LAST:
1596+
self._auth_final = True
1597+
15821598
if self._transport:
15831599
self._recv_seq = (seq + 1) & 0xffffffff
15841600
self._recv_handler = self._recv_pkthdr
@@ -1596,9 +1612,7 @@ def send_packet(self, pkttype: int, *args: bytes,
15961612
self._send_kexinit()
15971613
self._kexinit_sent = True
15981614

1599-
if (((pkttype in {MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or
1600-
pkttype > MSG_KEX_LAST) and not self._kex_complete) or
1601-
(pkttype == MSG_USERAUTH_BANNER and
1615+
if ((pkttype == MSG_USERAUTH_BANNER and
16021616
not (self._auth_in_progress or self._auth_complete)) or
16031617
(pkttype > MSG_USERAUTH_LAST and not self._auth_complete)):
16041618
self._deferred_packets.append((pkttype, args))
@@ -1810,9 +1824,11 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
18101824
not self._waiter.cancelled():
18111825
self._waiter.set_result(None)
18121826
self._wait = None
1813-
else:
1814-
self.send_service_request(_USERAUTH_SERVICE)
1827+
return
18151828
else:
1829+
self._extensions_to_send[b'server-sig-algs'] = \
1830+
b','.join(self._sig_algs)
1831+
18161832
self._send_encryption = next_enc_sc
18171833
self._send_enchdrlen = 1 if etm_sc else 5
18181834
self._send_blocksize = max(8, enc_blocksize_sc)
@@ -1833,17 +1849,18 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
18331849
recv_mac=self._mac_alg_cs.decode('ascii'),
18341850
recv_compression=self._cmp_alg_cs.decode('ascii'))
18351851

1836-
if first_kex:
1837-
self._next_service = _USERAUTH_SERVICE
1838-
1839-
self._extensions_to_send[b'server-sig-algs'] = \
1840-
b','.join(self._sig_algs)
1841-
18421852
if self._can_send_ext_info:
18431853
self._send_ext_info()
18441854
self._can_send_ext_info = False
18451855

18461856
self._kex_complete = True
1857+
1858+
if first_kex:
1859+
if self.is_client():
1860+
self.send_service_request(_USERAUTH_SERVICE)
1861+
else:
1862+
self._next_service = _USERAUTH_SERVICE
1863+
18471864
self._send_deferred_packets()
18481865

18491866
def send_service_request(self, service: bytes) -> None:
@@ -2080,18 +2097,25 @@ def _process_service_request(self, _pkttype: int, _pktid: int,
20802097
service = packet.get_string()
20812098
packet.check_end()
20822099

2083-
if service == self._next_service:
2084-
self.logger.debug2('Accepting request for service %s', service)
2100+
if self.is_client():
2101+
raise ProtocolError('Unexpected service request received')
20852102

2086-
self.send_packet(MSG_SERVICE_ACCEPT, String(service))
2103+
if not self._recv_encryption:
2104+
raise ProtocolError('Service request received before kex complete')
20872105

2088-
if (self.is_server() and # pragma: no branch
2089-
not self._auth_in_progress and
2090-
service == _USERAUTH_SERVICE):
2091-
self._auth_in_progress = True
2092-
self._send_deferred_packets()
2093-
else:
2094-
raise ServiceNotAvailable('Unexpected service request received')
2106+
if service != self._next_service:
2107+
raise ServiceNotAvailable('Unexpected service in service request')
2108+
2109+
self.logger.debug2('Accepting request for service %s', service)
2110+
2111+
self.send_packet(MSG_SERVICE_ACCEPT, String(service))
2112+
2113+
self._next_service = None
2114+
2115+
if service == _USERAUTH_SERVICE: # pragma: no branch
2116+
self._auth_in_progress = True
2117+
self._can_recv_ext_info = False
2118+
self._send_deferred_packets()
20952119

20962120
def _process_service_accept(self, _pkttype: int, _pktid: int,
20972121
packet: SSHPacket) -> None:
@@ -2100,27 +2124,35 @@ def _process_service_accept(self, _pkttype: int, _pktid: int,
21002124
service = packet.get_string()
21012125
packet.check_end()
21022126

2103-
if service == self._next_service:
2104-
self.logger.debug2('Request for service %s accepted', service)
2127+
if self.is_server():
2128+
raise ProtocolError('Unexpected service accept received')
21052129

2106-
self._next_service = None
2130+
if not self._recv_encryption:
2131+
raise ProtocolError('Service accept received before kex complete')
21072132

2108-
if (self.is_client() and # pragma: no branch
2109-
service == _USERAUTH_SERVICE):
2110-
self.logger.info('Beginning auth for user %s', self._username)
2133+
if service != self._next_service:
2134+
raise ServiceNotAvailable('Unexpected service in service accept')
21112135

2112-
self._auth_in_progress = True
2136+
self.logger.debug2('Request for service %s accepted', service)
21132137

2114-
# This method is only in SSHClientConnection
2115-
# pylint: disable=no-member
2116-
cast('SSHClientConnection', self).try_next_auth()
2117-
else:
2118-
raise ServiceNotAvailable('Unexpected service accept received')
2138+
self._next_service = None
2139+
2140+
if service == _USERAUTH_SERVICE: # pragma: no branch
2141+
self.logger.info('Beginning auth for user %s', self._username)
2142+
2143+
self._auth_in_progress = True
2144+
2145+
# This method is only in SSHClientConnection
2146+
# pylint: disable=no-member
2147+
cast('SSHClientConnection', self).try_next_auth()
21192148

21202149
def _process_ext_info(self, _pkttype: int, _pktid: int,
21212150
packet: SSHPacket) -> None:
21222151
"""Process extension information"""
21232152

2153+
if not self._can_recv_ext_info:
2154+
raise ProtocolError('Unexpected ext_info received')
2155+
21242156
extensions: Dict[bytes, bytes] = {}
21252157

21262158
self.logger.debug2('Received extension info')
@@ -2246,6 +2278,7 @@ def _process_newkeys(self, _pkttype: int, _pktid: int,
22462278
self._decompress_after_auth = self._next_decompress_after_auth
22472279

22482280
self._next_recv_encryption = None
2281+
self._can_recv_ext_info = True
22492282
else:
22502283
raise ProtocolError('New keys not negotiated')
22512284

@@ -2273,8 +2306,10 @@ def _process_userauth_request(self, _pkttype: int, _pktid: int,
22732306
if self.is_client():
22742307
raise ProtocolError('Unexpected userauth request')
22752308
elif self._auth_complete:
2276-
# Silently ignore requests if we're already authenticated
2277-
pass
2309+
# Silently ignore additional auth requests after auth succeeds,
2310+
# until the client sends a non-auth message
2311+
if self._auth_final:
2312+
raise ProtocolError('Unexpected userauth request')
22782313
else:
22792314
if username != self._username:
22802315
self.logger.info('Beginning auth for user %s', username)
@@ -2316,7 +2351,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes,
23162351
self._auth = lookup_server_auth(cast(SSHServerConnection, self),
23172352
self._username, method, packet)
23182353

2319-
def _process_userauth_failure(self, _pkttype: int, pktid: int,
2354+
def _process_userauth_failure(self, _pkttype: int, _pktid: int,
23202355
packet: SSHPacket) -> None:
23212356
"""Process a user authentication failure response"""
23222357

@@ -2356,10 +2391,9 @@ def _process_userauth_failure(self, _pkttype: int, pktid: int,
23562391
# pylint: disable=no-member
23572392
cast(SSHClientConnection, self).try_next_auth()
23582393
else:
2359-
self.logger.debug2('Unexpected userauth failure response')
2360-
self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid))
2394+
raise ProtocolError('Unexpected userauth failure response')
23612395

2362-
def _process_userauth_success(self, _pkttype: int, pktid: int,
2396+
def _process_userauth_success(self, _pkttype: int, _pktid: int,
23632397
packet: SSHPacket) -> None:
23642398
"""Process a user authentication success response"""
23652399

@@ -2385,6 +2419,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
23852419
self._auth = None
23862420
self._auth_in_progress = False
23872421
self._auth_complete = True
2422+
self._can_recv_ext_info = False
23882423

23892424
if self._agent:
23902425
self._agent.close()
@@ -2412,8 +2447,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
24122447
self._waiter.set_result(None)
24132448
self._wait = None
24142449
else:
2415-
self.logger.debug2('Unexpected userauth success response')
2416-
self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid))
2450+
raise ProtocolError('Unexpected userauth success response')
24172451

24182452
def _process_userauth_banner(self, _pkttype: int, _pktid: int,
24192453
packet: SSHPacket) -> None:

0 commit comments

Comments
 (0)