@@ -14,12 +14,12 @@ class WebSocket:
1414 def __init__ (self , request , response ):
1515 self .request = request
1616 self .response = response
17+ self .max_payload_size = request .server .options ['ws_max_payload_size' ]
18+ self .ping_interval = request .server .options ['keepalive_timeout' ] / 2
19+
1720 self .fin = 1
1821 self .opcode = 0
1922
20- self ._max_payload_size = request .server .options ['ws_max_payload_size' ]
21- self ._receive_timeout = request .server .options ['keepalive_timeout' ] / 2
22-
2323 def __aiter__ (self ):
2424 return self
2525
@@ -44,18 +44,18 @@ async def accept(self):
4444 self .response .set_header (b'Sec-WebSocket-Accept' , accept_key )
4545 await self .response .write ()
4646
47- async def recv (self ):
47+ async def recv (self , timeout = None ):
4848 try :
49- first_byte , second_byte = await self .request .recv (2 )
49+ b1 , b2 = await self .request .recv (2 , timeout = timeout )
5050 except ValueError as exc :
5151 raise WebSocketClientClosed (
5252 'connection closed: recv failed'
5353 ) from exc
5454
55- fin = (first_byte & 0x80 ) >> 7
56- opcode = first_byte & 0x0f
57- is_masked = (second_byte & 0x80 ) >> 7
58- payload_length = second_byte & 0x7f
55+ fin = (b1 & 0x80 ) >> 7
56+ opcode = b1 & 0x0f
57+ is_masked = (b2 & 0x80 ) >> 7
58+ payload_length = b2 & 0x7f
5959
6060 if opcode != 0 :
6161 if self .fin != 1 :
@@ -80,10 +80,10 @@ async def recv(self):
8080 payload_length = int .from_bytes (await self .request .recv (8 ),
8181 byteorder = 'big' )
8282
83- if payload_length > self ._max_payload_size :
83+ if payload_length > self .max_payload_size :
8484 raise WebSocketServerClosed (
8585 '%d exceeds maximum payload size (%d)' %
86- (payload_length , self ._max_payload_size ),
86+ (payload_length , self .max_payload_size ),
8787 code = 1009
8888 )
8989
@@ -126,18 +126,18 @@ async def recv(self):
126126 code = 1008
127127 )
128128
129- async def receive (self ):
129+ async def receive (self , timeout = None ):
130130 payload = bytearray ()
131131
132- while True :
132+ while self . opcode != 8 :
133133 coro = self .ping ()
134134 timer = self .request .server .loop .call_at (
135- self .request .server .loop .time () + self ._receive_timeout ,
135+ self .request .server .loop .time () + self .ping_interval ,
136136 self .request .server .create_task , coro
137137 )
138138
139139 try :
140- frame = await self .recv ()
140+ frame = await self .recv (timeout )
141141 except TimeoutError as exc :
142142 raise WebSocketServerClosed ('receive timeout' ,
143143 code = 1000 ) from exc
@@ -151,12 +151,14 @@ async def receive(self):
151151
152152 payload .extend (frame )
153153
154- if len (payload ) > self ._max_payload_size :
154+ if len (payload ) > self .max_payload_size :
155155 raise WebSocketServerClosed ('maximum payload size exceeded' ,
156156 code = 1009 )
157157
158158 if self .fin == 1 :
159159 break
160+ else :
161+ raise WebSocketClientClosed ('connection already closed' )
160162
161163 if self .opcode == 1 :
162164 return payload .decode ('utf-8' )
@@ -174,24 +176,24 @@ def create_frame(payload_data, fin=1, opcode=None, mask=False):
174176 if opcode == 1 :
175177 payload_data = payload_data .encode ('utf-8' )
176178
177- first_byte = (fin << 7 ) | opcode
179+ b1 = (fin << 7 ) | opcode
178180 payload_length = len (payload_data )
179181
180182 if payload_length < 126 :
181- second_byte = payload_length
183+ b2 = payload_length
182184 payload_length_data = b''
183185 elif payload_length < 65536 :
184- second_byte = 126
186+ b2 = 126
185187 payload_length_data = payload_length .to_bytes (2 , byteorder = 'big' )
186188 else :
187- second_byte = 127
189+ b2 = 127
188190 payload_length_data = payload_length .to_bytes (8 , byteorder = 'big' )
189191
190192 if mask :
191- second_byte |= (1 << 7 )
193+ b2 |= (1 << 7 )
192194 masking_key = os .urandom (4 )
193195
194- frame_header = bytes ([first_byte , second_byte ]) + payload_length_data
196+ frame_header = bytes ([b1 , b2 ]) + payload_length_data
195197
196198 if mask :
197199 masked_payload_data = bytes (
@@ -216,9 +218,15 @@ async def ping(self, data=b''):
216218 async def pong (self , data = b'' ):
217219 await self .send (data , opcode = 10 )
218220
219- async def close (self , code = 1000 ):
221+ async def close (self , code = 1000 , * , timeout = None ):
220222 try :
221223 await self .send (code .to_bytes (2 , byteorder = 'big' ), opcode = 8 )
222- self .response .close (keepalive = True )
223224 except RuntimeError :
225+ return
226+
227+ try :
228+ await self .receive (timeout )
229+ except (WebSocketClientClosed , WebSocketServerClosed ):
224230 pass
231+ finally :
232+ self .response .close ()
0 commit comments