@@ -899,6 +899,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
899899 self ._can_send_ext_info = False
900900 self ._extensions_to_send : 'OrderedDict[bytes, bytes]' = OrderedDict ()
901901
902+ self ._can_recv_ext_info = False
903+
902904 self ._server_sig_algs : Set [bytes ] = set ()
903905
904906 self ._next_service : Optional [bytes ] = None
@@ -908,6 +910,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
908910 self ._auth : Optional [Auth ] = None
909911 self ._auth_in_progress = False
910912 self ._auth_complete = False
913+ self ._auth_final = False
911914 self ._auth_methods = [b'none' ]
912915 self ._auth_was_trivial = True
913916 self ._username = ''
@@ -1538,15 +1541,25 @@ def _recv_packet(self) -> bool:
15381541 skip_reason = ''
15391542 exc_reason = ''
15401543
1541- if self ._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST :
1542- if self ._ignore_first_kex : # pragma: no cover
1543- skip_reason = 'ignored first kex'
1544- self ._ignore_first_kex = False
1544+ if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST :
1545+ if self ._kex :
1546+ if self ._ignore_first_kex : # pragma: no cover
1547+ skip_reason = 'ignored first kex'
1548+ self ._ignore_first_kex = False
1549+ else :
1550+ handler = self ._kex
15451551 else :
1546- handler = self ._kex
1547- elif (self ._auth and
1548- MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST ):
1549- handler = self ._auth
1552+ skip_reason = 'kex not in progress'
1553+ exc_reason = 'Key exchange not in progress'
1554+ elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST :
1555+ if self ._auth :
1556+ handler = self ._auth
1557+ else :
1558+ skip_reason = 'auth not in progress'
1559+ exc_reason = 'Authentication not in progress'
1560+ elif pkttype > MSG_KEX_LAST and not self ._recv_encryption :
1561+ skip_reason = 'invalid request before kex complete'
1562+ exc_reason = 'Invalid request before key exchange was complete'
15501563 elif pkttype > MSG_USERAUTH_LAST and not self ._auth_complete :
15511564 skip_reason = 'invalid request before auth complete'
15521565 exc_reason = 'Invalid request before authentication was complete'
@@ -1579,6 +1592,9 @@ def _recv_packet(self) -> bool:
15791592 if exc_reason :
15801593 raise ProtocolError (exc_reason )
15811594
1595+ if pkttype > MSG_USERAUTH_LAST :
1596+ self ._auth_final = True
1597+
15821598 if self ._transport :
15831599 self ._recv_seq = (seq + 1 ) & 0xffffffff
15841600 self ._recv_handler = self ._recv_pkthdr
@@ -1596,9 +1612,7 @@ def send_packet(self, pkttype: int, *args: bytes,
15961612 self ._send_kexinit ()
15971613 self ._kexinit_sent = True
15981614
1599- if (((pkttype in {MSG_SERVICE_REQUEST , MSG_SERVICE_ACCEPT } or
1600- pkttype > MSG_KEX_LAST ) and not self ._kex_complete ) or
1601- (pkttype == MSG_USERAUTH_BANNER and
1615+ if ((pkttype == MSG_USERAUTH_BANNER and
16021616 not (self ._auth_in_progress or self ._auth_complete )) or
16031617 (pkttype > MSG_USERAUTH_LAST and not self ._auth_complete )):
16041618 self ._deferred_packets .append ((pkttype , args ))
@@ -1810,9 +1824,11 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
18101824 not self ._waiter .cancelled ():
18111825 self ._waiter .set_result (None )
18121826 self ._wait = None
1813- else :
1814- self .send_service_request (_USERAUTH_SERVICE )
1827+ return
18151828 else :
1829+ self ._extensions_to_send [b'server-sig-algs' ] = \
1830+ b',' .join (self ._sig_algs )
1831+
18161832 self ._send_encryption = next_enc_sc
18171833 self ._send_enchdrlen = 1 if etm_sc else 5
18181834 self ._send_blocksize = max (8 , enc_blocksize_sc )
@@ -1833,17 +1849,18 @@ def send_newkeys(self, k: bytes, h: bytes) -> None:
18331849 recv_mac = self ._mac_alg_cs .decode ('ascii' ),
18341850 recv_compression = self ._cmp_alg_cs .decode ('ascii' ))
18351851
1836- if first_kex :
1837- self ._next_service = _USERAUTH_SERVICE
1838-
1839- self ._extensions_to_send [b'server-sig-algs' ] = \
1840- b',' .join (self ._sig_algs )
1841-
18421852 if self ._can_send_ext_info :
18431853 self ._send_ext_info ()
18441854 self ._can_send_ext_info = False
18451855
18461856 self ._kex_complete = True
1857+
1858+ if first_kex :
1859+ if self .is_client ():
1860+ self .send_service_request (_USERAUTH_SERVICE )
1861+ else :
1862+ self ._next_service = _USERAUTH_SERVICE
1863+
18471864 self ._send_deferred_packets ()
18481865
18491866 def send_service_request (self , service : bytes ) -> None :
@@ -2080,18 +2097,25 @@ def _process_service_request(self, _pkttype: int, _pktid: int,
20802097 service = packet .get_string ()
20812098 packet .check_end ()
20822099
2083- if service == self ._next_service :
2084- self . logger . debug2 ( 'Accepting request for service %s' , service )
2100+ if self .is_client () :
2101+ raise ProtocolError ( 'Unexpected service request received' )
20852102
2086- self .send_packet (MSG_SERVICE_ACCEPT , String (service ))
2103+ if not self ._recv_encryption :
2104+ raise ProtocolError ('Service request received before kex complete' )
20872105
2088- if (self .is_server () and # pragma: no branch
2089- not self ._auth_in_progress and
2090- service == _USERAUTH_SERVICE ):
2091- self ._auth_in_progress = True
2092- self ._send_deferred_packets ()
2093- else :
2094- raise ServiceNotAvailable ('Unexpected service request received' )
2106+ if service != self ._next_service :
2107+ raise ServiceNotAvailable ('Unexpected service in service request' )
2108+
2109+ self .logger .debug2 ('Accepting request for service %s' , service )
2110+
2111+ self .send_packet (MSG_SERVICE_ACCEPT , String (service ))
2112+
2113+ self ._next_service = None
2114+
2115+ if service == _USERAUTH_SERVICE : # pragma: no branch
2116+ self ._auth_in_progress = True
2117+ self ._can_recv_ext_info = False
2118+ self ._send_deferred_packets ()
20952119
20962120 def _process_service_accept (self , _pkttype : int , _pktid : int ,
20972121 packet : SSHPacket ) -> None :
@@ -2100,27 +2124,35 @@ def _process_service_accept(self, _pkttype: int, _pktid: int,
21002124 service = packet .get_string ()
21012125 packet .check_end ()
21022126
2103- if service == self ._next_service :
2104- self . logger . debug2 ( 'Request for service %s accepted' , service )
2127+ if self .is_server () :
2128+ raise ProtocolError ( 'Unexpected service accept received' )
21052129
2106- self ._next_service = None
2130+ if not self ._recv_encryption :
2131+ raise ProtocolError ('Service accept received before kex complete' )
21072132
2108- if (self .is_client () and # pragma: no branch
2109- service == _USERAUTH_SERVICE ):
2110- self .logger .info ('Beginning auth for user %s' , self ._username )
2133+ if service != self ._next_service :
2134+ raise ServiceNotAvailable ('Unexpected service in service accept' )
21112135
2112- self . _auth_in_progress = True
2136+ self . logger . debug2 ( 'Request for service %s accepted' , service )
21132137
2114- # This method is only in SSHClientConnection
2115- # pylint: disable=no-member
2116- cast ('SSHClientConnection' , self ).try_next_auth ()
2117- else :
2118- raise ServiceNotAvailable ('Unexpected service accept received' )
2138+ self ._next_service = None
2139+
2140+ if service == _USERAUTH_SERVICE : # pragma: no branch
2141+ self .logger .info ('Beginning auth for user %s' , self ._username )
2142+
2143+ self ._auth_in_progress = True
2144+
2145+ # This method is only in SSHClientConnection
2146+ # pylint: disable=no-member
2147+ cast ('SSHClientConnection' , self ).try_next_auth ()
21192148
21202149 def _process_ext_info (self , _pkttype : int , _pktid : int ,
21212150 packet : SSHPacket ) -> None :
21222151 """Process extension information"""
21232152
2153+ if not self ._can_recv_ext_info :
2154+ raise ProtocolError ('Unexpected ext_info received' )
2155+
21242156 extensions : Dict [bytes , bytes ] = {}
21252157
21262158 self .logger .debug2 ('Received extension info' )
@@ -2246,6 +2278,7 @@ def _process_newkeys(self, _pkttype: int, _pktid: int,
22462278 self ._decompress_after_auth = self ._next_decompress_after_auth
22472279
22482280 self ._next_recv_encryption = None
2281+ self ._can_recv_ext_info = True
22492282 else :
22502283 raise ProtocolError ('New keys not negotiated' )
22512284
@@ -2273,8 +2306,10 @@ def _process_userauth_request(self, _pkttype: int, _pktid: int,
22732306 if self .is_client ():
22742307 raise ProtocolError ('Unexpected userauth request' )
22752308 elif self ._auth_complete :
2276- # Silently ignore requests if we're already authenticated
2277- pass
2309+ # Silently ignore additional auth requests after auth succeeds,
2310+ # until the client sends a non-auth message
2311+ if self ._auth_final :
2312+ raise ProtocolError ('Unexpected userauth request' )
22782313 else :
22792314 if username != self ._username :
22802315 self .logger .info ('Beginning auth for user %s' , username )
@@ -2316,7 +2351,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes,
23162351 self ._auth = lookup_server_auth (cast (SSHServerConnection , self ),
23172352 self ._username , method , packet )
23182353
2319- def _process_userauth_failure (self , _pkttype : int , pktid : int ,
2354+ def _process_userauth_failure (self , _pkttype : int , _pktid : int ,
23202355 packet : SSHPacket ) -> None :
23212356 """Process a user authentication failure response"""
23222357
@@ -2356,10 +2391,9 @@ def _process_userauth_failure(self, _pkttype: int, pktid: int,
23562391 # pylint: disable=no-member
23572392 cast (SSHClientConnection , self ).try_next_auth ()
23582393 else :
2359- self .logger .debug2 ('Unexpected userauth failure response' )
2360- self .send_packet (MSG_UNIMPLEMENTED , UInt32 (pktid ))
2394+ raise ProtocolError ('Unexpected userauth failure response' )
23612395
2362- def _process_userauth_success (self , _pkttype : int , pktid : int ,
2396+ def _process_userauth_success (self , _pkttype : int , _pktid : int ,
23632397 packet : SSHPacket ) -> None :
23642398 """Process a user authentication success response"""
23652399
@@ -2385,6 +2419,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
23852419 self ._auth = None
23862420 self ._auth_in_progress = False
23872421 self ._auth_complete = True
2422+ self ._can_recv_ext_info = False
23882423
23892424 if self ._agent :
23902425 self ._agent .close ()
@@ -2412,8 +2447,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int,
24122447 self ._waiter .set_result (None )
24132448 self ._wait = None
24142449 else :
2415- self .logger .debug2 ('Unexpected userauth success response' )
2416- self .send_packet (MSG_UNIMPLEMENTED , UInt32 (pktid ))
2450+ raise ProtocolError ('Unexpected userauth success response' )
24172451
24182452 def _process_userauth_banner (self , _pkttype : int , _pktid : int ,
24192453 packet : SSHPacket ) -> None :
0 commit comments