Skip to content

Commit 2bcec24

Browse files
Stream dmypy output instead of dumping everything at the end (#16252)
This does 2 things: 1. It changes the IPC code to work with multiple messages. 2. It changes the dmypy client/server communication so that it streams stdout/stderr instead of dumping everything at the end. For 1, we have to provide a way to separate out different messages. I chose to frame messages as bytes separated by a whitespace character. That means we have to encode the message in a scheme that escapes whitespace. `codecs.encode(<bytes_data>, 'base64')` seems reasonable. It encodes more than needed, but the application is not IPC I/O limited, so it should be fine. With this convention in place, all we have to do is read from the socket stream until we have a whitespace character. The framing logic can be easily changed. For 2, since we communicate with JSON, it's easy to add a "final" key that tells us it's the final response from dmypy. Anything else is just stdout/stderr output. Note: the dmypy server also returns out/err, which is the output of the actual mypy type checking. Right now this change does not stream that output. We can stream that in a follow-up change. We just have to decide on how to differentiate the 4 text streams (stdout/stderr/out/err) that will now be interleaved. The WriteToConn class could use more love. I just put in the bare minimum. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e435594 commit 2bcec24

File tree

5 files changed

+155
-51
lines changed

5 files changed

+155
-51
lines changed

mypy/dmypy/client.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Callable, Mapping, NoReturn
1818

1919
from mypy.dmypy_os import alive, kill
20-
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
20+
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
2121
from mypy.ipc import IPCClient, IPCException
2222
from mypy.util import check_python_version, get_terminal_width, should_force_color
2323
from mypy.version import __version__
@@ -659,28 +659,29 @@ def request(
659659
# so that it can format the type checking output accordingly.
660660
args["is_tty"] = sys.stdout.isatty() or should_force_color()
661661
args["terminal_width"] = get_terminal_width()
662-
bdata = json.dumps(args).encode("utf8")
663662
_, name = get_status(status_file)
664663
try:
665664
with IPCClient(name, timeout) as client:
666-
client.write(bdata)
667-
response = receive(client)
665+
send(client, args)
666+
667+
final = False
668+
while not final:
669+
response = receive(client)
670+
final = bool(response.pop("final", False))
671+
# Display debugging output written to stdout/stderr in the server process for convenience.
672+
# This should not be confused with "out" and "err" fields in the response.
673+
# Those fields hold the output of the "check" command, and are handled in check_output().
674+
stdout = response.pop("stdout", None)
675+
if stdout:
676+
sys.stdout.write(stdout)
677+
stderr = response.pop("stderr", None)
678+
if stderr:
679+
sys.stderr.write(stderr)
668680
except (OSError, IPCException) as err:
669681
return {"error": str(err)}
670682
# TODO: Other errors, e.g. ValueError, UnicodeError
671-
else:
672-
# Display debugging output written to stdout/stderr in the server process for convenience.
673-
# This should not be confused with "out" and "err" fields in the response.
674-
# Those fields hold the output of the "check" command, and are handled in check_output().
675-
stdout = response.get("stdout")
676-
if stdout:
677-
sys.stdout.write(stdout)
678-
stderr = response.get("stderr")
679-
if stderr:
680-
print("-" * 79)
681-
print("stderr:")
682-
sys.stdout.write(stderr)
683-
return response
683+
684+
return response
684685

685686

686687
def get_status(status_file: str) -> tuple[int, str]:

mypy/dmypy_server.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import mypy.build
2424
import mypy.errors
2525
import mypy.main
26-
from mypy.dmypy_util import receive
26+
from mypy.dmypy_util import WriteToConn, receive, send
2727
from mypy.find_sources import InvalidSourceList, create_source_list
2828
from mypy.fscache import FileSystemCache
2929
from mypy.fswatcher import FileData, FileSystemWatcher
@@ -208,21 +208,21 @@ def _response_metadata(self) -> dict[str, str]:
208208

209209
def serve(self) -> None:
210210
"""Serve requests, synchronously (no thread or fork)."""
211+
211212
command = None
212213
server = IPCServer(CONNECTION_NAME, self.timeout)
213214
orig_stdout = sys.stdout
214215
orig_stderr = sys.stderr
216+
215217
try:
216218
with open(self.status_file, "w") as f:
217219
json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
218220
f.write("\n") # I like my JSON with a trailing newline
219221
while True:
220222
with server:
221223
data = receive(server)
222-
debug_stdout = io.StringIO()
223-
debug_stderr = io.StringIO()
224-
sys.stdout = debug_stdout
225-
sys.stderr = debug_stderr
224+
sys.stdout = WriteToConn(server, "stdout") # type: ignore[assignment]
225+
sys.stderr = WriteToConn(server, "stderr") # type: ignore[assignment]
226226
resp: dict[str, Any] = {}
227227
if "command" not in data:
228228
resp = {"error": "No command found in request"}
@@ -239,15 +239,13 @@ def serve(self) -> None:
239239
tb = traceback.format_exception(*sys.exc_info())
240240
resp = {"error": "Daemon crashed!\n" + "".join(tb)}
241241
resp.update(self._response_metadata())
242-
resp["stdout"] = debug_stdout.getvalue()
243-
resp["stderr"] = debug_stderr.getvalue()
244-
server.write(json.dumps(resp).encode("utf8"))
242+
resp["final"] = True
243+
send(server, resp)
245244
raise
246-
resp["stdout"] = debug_stdout.getvalue()
247-
resp["stderr"] = debug_stderr.getvalue()
245+
resp["final"] = True
248246
try:
249247
resp.update(self._response_metadata())
250-
server.write(json.dumps(resp).encode("utf8"))
248+
send(server, resp)
251249
except OSError:
252250
pass # Maybe the client hung up
253251
if command == "stop":

mypy/dmypy_util.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from __future__ import annotations
77

88
import json
9-
from typing import Any, Final
9+
from typing import Any, Final, Iterable
1010

1111
from mypy.ipc import IPCBase
1212

1313
DEFAULT_STATUS_FILE: Final = ".dmypy.json"
1414

1515

1616
def receive(connection: IPCBase) -> Any:
17-
"""Receive JSON data from a connection until EOF.
17+
"""Receive single JSON data frame from a connection.
1818
1919
Raise OSError if the data received is not valid JSON or if it is
2020
not a dict.
@@ -23,9 +23,36 @@ def receive(connection: IPCBase) -> Any:
2323
if not bdata:
2424
raise OSError("No data received")
2525
try:
26-
data = json.loads(bdata.decode("utf8"))
26+
data = json.loads(bdata)
2727
except Exception as e:
2828
raise OSError("Data received is not valid JSON") from e
2929
if not isinstance(data, dict):
3030
raise OSError(f"Data received is not a dict ({type(data)})")
3131
return data
32+
33+
34+
def send(connection: IPCBase, data: Any) -> None:
    """Serialize data to JSON and write it to a connection as one frame.

    The data must be JSON-serializable. Each call to send produces exactly
    one frame on the connection.
    """
    payload = json.dumps(data)
    connection.write(payload)
41+
42+
43+
class WriteToConn:
    """Minimal file-like object that writes to a connection instead of standard output.

    Every write() is sent immediately as its own frame: a one-item dict
    mapping output_key to the written text. Nothing is buffered locally.
    Besides write()/writelines(), the small set of stream methods that
    callers of a stdout/stderr replacement commonly rely on (flush, isatty,
    writable) is provided so that such calls do not raise AttributeError.
    """

    def __init__(self, server: IPCBase, output_key: str = "stdout"):
        # Connection the frames are sent over.
        self.server = server
        # Key under which the text is stored in each frame.
        self.output_key = output_key

    def write(self, output: str) -> int:
        """Send output as a single frame and return the number of characters written."""
        resp: dict[str, Any] = {self.output_key: output}
        send(self.server, resp)
        return len(output)

    def writelines(self, lines: Iterable[str]) -> None:
        """Send each string in lines as a separate frame."""
        for s in lines:
            self.write(s)

    def flush(self) -> None:
        """Do nothing: write() sends immediately, so there is nothing to flush."""

    def isatty(self) -> bool:
        """Return False: output goes to a connection, not an interactive terminal."""
        return False

    def writable(self) -> bool:
        """Return True: this object only supports writing."""
        return True

mypy/ipc.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import base64
10+
import codecs
1011
import os
1112
import shutil
1213
import sys
@@ -40,19 +41,41 @@ class IPCBase:
4041
4142
This contains logic shared between the client and server, such as reading
4243
and writing.
44+
We want to be able to send multiple "messages" over a single connection and
45+
to be able to separate the messages. We do this by encoding the messages
46+
in an alphabet that does not contain spaces, then adding a space for
47+
separation. The last framed message is also followed by a space.
4348
"""
4449

4550
connection: _IPCHandle
4651

4752
def __init__(self, name: str, timeout: float | None) -> None:
4853
self.name = name
4954
self.timeout = timeout
55+
self.buffer = bytearray()
5056

51-
def read(self, size: int = 100000) -> bytes:
52-
"""Read bytes from an IPC connection until its empty."""
53-
bdata = bytearray()
57+
def frame_from_buffer(self) -> bytearray | None:
    """Pop the first complete frame off the buffer, if there is one.

    A frame is everything up to the first space in self.buffer. When no
    space is present the buffer is left untouched and None is returned;
    otherwise the frame (without the space) is returned and the buffer
    keeps only the bytes after the separator.
    """
    frame, separator, remainder = self.buffer.partition(b" ")
    if not separator:
        # No separator yet, so the frame is still incomplete.
        return None
    self.buffer = remainder
    return frame
66+
67+
def read(self, size: int = 100000) -> str:
68+
"""Read bytes from an IPC connection until we have a full frame."""
69+
bdata: bytearray | None = bytearray()
5470
if sys.platform == "win32":
5571
while True:
72+
# Check if we already have a message in the buffer before
73+
# receiving any more data from the socket.
74+
bdata = self.frame_from_buffer()
75+
if bdata is not None:
76+
break
77+
78+
# Receive more data into the buffer.
5679
ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
5780
try:
5881
if err == _winapi.ERROR_IO_PENDING:
@@ -66,7 +89,10 @@ def read(self, size: int = 100000) -> bytes:
6689
_, err = ov.GetOverlappedResult(True)
6790
more = ov.getbuffer()
6891
if more:
69-
bdata.extend(more)
92+
self.buffer.extend(more)
93+
bdata = self.frame_from_buffer()
94+
if bdata is not None:
95+
break
7096
if err == 0:
7197
# we are done!
7298
break
@@ -77,17 +103,34 @@ def read(self, size: int = 100000) -> bytes:
77103
raise IPCException("ReadFile operation aborted.")
78104
else:
79105
while True:
106+
# Check if we already have a message in the buffer before
107+
# receiving any more data from the socket.
108+
bdata = self.frame_from_buffer()
109+
if bdata is not None:
110+
break
111+
112+
# Receive more data into the buffer.
80113
more = self.connection.recv(size)
81114
if not more:
115+
# Connection closed
82116
break
83-
bdata.extend(more)
84-
return bytes(bdata)
117+
self.buffer.extend(more)
118+
119+
if not bdata:
120+
# Socket was empty and we didn't get any frame.
121+
# This should only happen if the socket was closed.
122+
return ""
123+
return codecs.decode(bdata, "base64").decode("utf8")
124+
125+
def write(self, data: str) -> None:
126+
"""Write to an IPC connection."""
127+
128+
# Frame the data by base64-encoding it and separating by space.
129+
encoded_data = codecs.encode(data.encode("utf8"), "base64") + b" "
85130

86-
def write(self, data: bytes) -> None:
87-
"""Write bytes to an IPC connection."""
88131
if sys.platform == "win32":
89132
try:
90-
ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
133+
ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
91134
try:
92135
if err == _winapi.ERROR_IO_PENDING:
93136
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
@@ -101,12 +144,11 @@ def write(self, data: bytes) -> None:
101144
raise
102145
bytes_written, err = ov.GetOverlappedResult(True)
103146
assert err == 0, err
104-
assert bytes_written == len(data)
147+
assert bytes_written == len(encoded_data)
105148
except OSError as e:
106149
raise IPCException(f"Failed to write with error: {e.winerror}") from e
107150
else:
108-
self.connection.sendall(data)
109-
self.connection.shutdown(socket.SHUT_WR)
151+
self.connection.sendall(encoded_data)
110152

111153
def close(self) -> None:
112154
if sys.platform == "win32":

mypy/test/testipc.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,25 @@
1515
def server(msg: str, q: Queue[str]) -> None:
1616
server = IPCServer(CONNECTION_NAME)
1717
q.put(server.connection_name)
18-
data = b""
18+
data = ""
1919
while not data:
2020
with server:
21-
server.write(msg.encode())
21+
server.write(msg)
2222
data = server.read()
2323
server.cleanup()
2424

2525

26+
def server_multi_message_echo(q: Queue[str]) -> None:
    """Echo every message received on one connection until "quit" arrives.

    The connection name is published on q first so the client can connect.
    Each message read is written straight back; the "quit" message itself
    is echoed too before the loop ends and the server is cleaned up.
    """
    ipc_server = IPCServer(CONNECTION_NAME)
    q.put(ipc_server.connection_name)
    with ipc_server:
        while True:
            message = ipc_server.read()
            ipc_server.write(message)
            if message == "quit":
                break
    ipc_server.cleanup()
35+
36+
2637
class IPCTests(TestCase):
2738
def test_transaction_large(self) -> None:
2839
queue: Queue[str] = Queue()
@@ -31,8 +42,8 @@ def test_transaction_large(self) -> None:
3142
p.start()
3243
connection_name = queue.get()
3344
with IPCClient(connection_name, timeout=1) as client:
34-
assert client.read() == msg.encode()
35-
client.write(b"test")
45+
assert client.read() == msg
46+
client.write("test")
3647
queue.close()
3748
queue.join_thread()
3849
p.join()
@@ -44,12 +55,37 @@ def test_connect_twice(self) -> None:
4455
p.start()
4556
connection_name = queue.get()
4657
with IPCClient(connection_name, timeout=1) as client:
47-
assert client.read() == msg.encode()
48-
client.write(b"") # don't let the server hang up yet, we want to connect again.
58+
assert client.read() == msg
59+
client.write("") # don't let the server hang up yet, we want to connect again.
4960

5061
with IPCClient(connection_name, timeout=1) as client:
51-
assert client.read() == msg.encode()
52-
client.write(b"test")
62+
assert client.read() == msg
63+
client.write("test")
64+
queue.close()
65+
queue.join_thread()
66+
p.join()
67+
assert p.exitcode == 0
68+
69+
def test_multiple_messages(self) -> None:
70+
queue: Queue[str] = Queue()
71+
p = Process(target=server_multi_message_echo, args=(queue,), daemon=True)
72+
p.start()
73+
connection_name = queue.get()
74+
with IPCClient(connection_name, timeout=1) as client:
75+
# "foo bar" with extra accents on letters.
76+
# In UTF-8 encoding so we don't confuse editors opening this file.
77+
fancy_text = b"f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c"
78+
client.write(fancy_text.decode("utf-8"))
79+
assert client.read() == fancy_text.decode("utf-8")
80+
81+
client.write("Test with spaces")
82+
client.write("Test write before reading previous")
83+
time.sleep(0) # yield to the server to force reading of all messages by server.
84+
assert client.read() == "Test with spaces"
85+
assert client.read() == "Test write before reading previous"
86+
87+
client.write("quit")
88+
assert client.read() == "quit"
5389
queue.close()
5490
queue.join_thread()
5591
p.join()

0 commit comments

Comments
 (0)