Skip to content

Commit 60bfae8

Browse files
Add test
1 parent 4483e3f commit 60bfae8

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

test/test_blocks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,47 @@ async def test_every_does_not_block_queue(self):
979979
else:
980980
break
981981

982+
@pytest.mark.asyncio
983+
async def test_generating_event_cancelled_if_ws_closed(self, capsys):
984+
def generation():
985+
for i in range(10):
986+
time.sleep(0.1)
987+
print(f"At step {i}")
988+
yield i
989+
return "Hello!"
990+
991+
with gr.Blocks() as demo:
992+
greeting = gr.Textbox()
993+
button = gr.Button(value="Greet")
994+
button.click(generation, None, greeting)
995+
996+
app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
997+
998+
async with websockets.connect(
999+
f"{demo.local_url.replace('http', 'ws')}queue/join"
1000+
) as ws:
1001+
completed = False
1002+
n_steps = 0
1003+
while not completed:
1004+
msg = json.loads(await ws.recv())
1005+
if msg["msg"] == "send_data":
1006+
await ws.send(json.dumps({"data": [0], "fn_index": 0}))
1007+
elif msg["msg"] == "send_hash":
1008+
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
1009+
elif msg["msg"] == "process_generating":
1010+
if n_steps == 2:
1011+
# Close the websocket
1012+
break
1013+
n_steps += 1
1014+
else:
1015+
continue
1016+
await asyncio.sleep(1)
1017+
# If the generation function did not get cancelled
1018+
# it would have finished running and `At step 9` would
1019+
# have been printed
1020+
captured = capsys.readouterr()
1021+
assert "At step 9" not in captured.out
1022+
9821023

9831024
class TestAddRequests:
9841025
def test_no_type_hints(self):

0 commit comments

Comments
 (0)