Skip to content

Commit c160551

Browse files
committed
Stream dmypy output instead of dumping everything at the end
This does 2 things: 1. It changes the IPC code to work with multiple messages. 2. It changes the dmypy client/server communication so that it streams stdout/stderr instead of dumping everything at the end. For 1, we have to provide a way to separate out different messages. We can frame messages as bytes separated by whitespace character. That means we have to encode the message in a scheme that escapes whitespace. The urllib.parse quote/unquote seems reasonable. It encodes more than needed but the application is not IPC IO limited so it should be fine. With this convention in place, all we have to do is read from the socket stream until we have a whitespace character. For 2, since we communicate with JSONs, it's easy to add a "finished" key that tells us it's the final response from dmypy. Anything else is just stdout/stderr output. Note: dmypy server also returns out/err which is the output of actual mypy type checking. Right now this change does not stream that output. We can stream that in a followup change. We just have to decide on how to differenciate the 4 text streams (stdout/stderr/out/err) that will now be interleaved. Played around with it on Linux quite a bit. Will also test it some more on Windows. The WriteToConn class could use more love. I just put a bare minimum to test the rest.
1 parent 8b6d213 commit c160551

File tree

5 files changed

+118
-49
lines changed

5 files changed

+118
-49
lines changed

mypy/dmypy/client.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Callable, Mapping, NoReturn
1818

1919
from mypy.dmypy_os import alive, kill
20-
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
20+
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
2121
from mypy.ipc import IPCClient, IPCException
2222
from mypy.util import check_python_version, get_terminal_width, should_force_color
2323
from mypy.version import __version__
@@ -659,28 +659,29 @@ def request(
659659
# so that it can format the type checking output accordingly.
660660
args["is_tty"] = sys.stdout.isatty() or should_force_color()
661661
args["terminal_width"] = get_terminal_width()
662-
bdata = json.dumps(args).encode("utf8")
663662
_, name = get_status(status_file)
664663
try:
665664
with IPCClient(name, timeout) as client:
666-
client.write(bdata)
667-
response = receive(client)
665+
send(client, args)
666+
667+
final = False
668+
while not final:
669+
response = receive(client)
670+
final = bool(response.pop("final", False))
671+
# Display debugging output written to stdout/stderr in the server process for convenience.
672+
# This should not be confused with "out" and "err" fields in the response.
673+
# Those fields hold the output of the "check" command, and are handled in check_output().
674+
stdout = response.pop("stdout", None)
675+
if stdout:
676+
sys.stdout.write(stdout)
677+
stderr = response.pop("stderr", None)
678+
if stderr:
679+
sys.stderr.write(stderr)
668680
except (OSError, IPCException) as err:
669681
return {"error": str(err)}
670682
# TODO: Other errors, e.g. ValueError, UnicodeError
671-
else:
672-
# Display debugging output written to stdout/stderr in the server process for convenience.
673-
# This should not be confused with "out" and "err" fields in the response.
674-
# Those fields hold the output of the "check" command, and are handled in check_output().
675-
stdout = response.get("stdout")
676-
if stdout:
677-
sys.stdout.write(stdout)
678-
stderr = response.get("stderr")
679-
if stderr:
680-
print("-" * 79)
681-
print("stderr:")
682-
sys.stdout.write(stderr)
683-
return response
683+
684+
return response
684685

685686

686687
def get_status(status_file: str) -> tuple[int, str]:

mypy/dmypy_server.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@
1717
import time
1818
import traceback
1919
from contextlib import redirect_stderr, redirect_stdout
20-
from typing import AbstractSet, Any, Callable, Final, List, Sequence, Tuple
20+
from typing import AbstractSet, Any, Callable, Final, Iterable, List, Sequence, Tuple
2121
from typing_extensions import TypeAlias as _TypeAlias
2222

2323
import mypy.build
2424
import mypy.errors
2525
import mypy.main
26-
from mypy.dmypy_util import receive
26+
from mypy.dmypy_util import receive, send
2727
from mypy.find_sources import InvalidSourceList, create_source_list
2828
from mypy.fscache import FileSystemCache
2929
from mypy.fswatcher import FileData, FileSystemWatcher
3030
from mypy.inspections import InspectionEngine
31-
from mypy.ipc import IPCServer
31+
from mypy.ipc import IPCBase, IPCServer
3232
from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths, compute_search_paths
3333
from mypy.options import Options
3434
from mypy.server.update import FineGrainedBuildManager, refresh_suppressed_submodules
@@ -208,19 +208,34 @@ def _response_metadata(self) -> dict[str, str]:
208208

209209
def serve(self) -> None:
210210
"""Serve requests, synchronously (no thread or fork)."""
211+
212+
class WriteToConn(object):
213+
def __init__(self, server: IPCBase, output_key: str = "stdout"):
214+
self.server = server
215+
self.output_key = output_key
216+
217+
def write(self, output: str) -> int:
218+
resp: dict[str, Any] = {}
219+
resp[self.output_key] = output
220+
send(server, resp)
221+
return len(output)
222+
223+
def writelines(self, lines: Iterable[str]) -> None:
224+
for s in lines:
225+
self.write(s)
226+
211227
command = None
212228
server = IPCServer(CONNECTION_NAME, self.timeout)
229+
213230
try:
214231
with open(self.status_file, "w") as f:
215232
json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
216233
f.write("\n") # I like my JSON with a trailing newline
217234
while True:
218235
with server:
219236
data = receive(server)
220-
debug_stdout = io.StringIO()
221-
debug_stderr = io.StringIO()
222-
sys.stdout = debug_stdout
223-
sys.stderr = debug_stderr
237+
sys.stdout = WriteToConn(server, "stdout") # type: ignore[assignment]
238+
sys.stderr = WriteToConn(server, "stderr") # type: ignore[assignment]
224239
resp: dict[str, Any] = {}
225240
if "command" not in data:
226241
resp = {"error": "No command found in request"}
@@ -237,15 +252,13 @@ def serve(self) -> None:
237252
tb = traceback.format_exception(*sys.exc_info())
238253
resp = {"error": "Daemon crashed!\n" + "".join(tb)}
239254
resp.update(self._response_metadata())
240-
resp["stdout"] = debug_stdout.getvalue()
241-
resp["stderr"] = debug_stderr.getvalue()
242-
server.write(json.dumps(resp).encode("utf8"))
255+
resp["final"] = True
256+
send(server, resp)
243257
raise
244-
resp["stdout"] = debug_stdout.getvalue()
245-
resp["stderr"] = debug_stderr.getvalue()
258+
resp["final"] = True
246259
try:
247260
resp.update(self._response_metadata())
248-
server.write(json.dumps(resp).encode("utf8"))
261+
send(server, resp)
249262
except OSError:
250263
pass # Maybe the client hung up
251264
if command == "stop":

mypy/dmypy_util.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def receive(connection: IPCBase) -> Any:
17-
"""Receive JSON data from a connection until EOF.
17+
"""Receive single JSON data frame from a connection.
1818
1919
Raise OSError if the data received is not valid JSON or if it is
2020
not a dict.
@@ -23,9 +23,19 @@ def receive(connection: IPCBase) -> Any:
2323
if not bdata:
2424
raise OSError("No data received")
2525
try:
26-
data = json.loads(bdata.decode("utf8"))
26+
data = json.loads(bdata)
2727
except Exception as e:
2828
raise OSError("Data received is not valid JSON") from e
2929
if not isinstance(data, dict):
3030
raise OSError(f"Data received is not a dict ({type(data)})")
3131
return data
32+
33+
def send(connection: IPCBase, data: Any) -> None:
34+
"""Send data to a connection encoded and framed.
35+
36+
The data must be JSON-serializable. We assume that a single send call is a
37+
single frame to be sent on the connect.
38+
As an easy way to separate frames, we urlencode them and separate by space.
39+
Last frame also has a trailing space.
40+
"""
41+
connection.write(json.dumps(data))

mypy/ipc.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import tempfile
1414
from types import TracebackType
1515
from typing import Callable, Final
16+
from urllib.parse import quote,unquote
1617

1718
if sys.platform == "win32":
1819
# This may be private, but it is needed for IPC on Windows, and is basically stable
@@ -47,8 +48,9 @@ class IPCBase:
4748
def __init__(self, name: str, timeout: float | None) -> None:
4849
self.name = name
4950
self.timeout = timeout
51+
self.buffer = bytearray()
5052

51-
def read(self, size: int = 100000) -> bytes:
53+
def read(self, size: int = 100000) -> str:
5254
"""Read bytes from an IPC connection until its empty."""
5355
bdata = bytearray()
5456
if sys.platform == "win32":
@@ -66,7 +68,13 @@ def read(self, size: int = 100000) -> bytes:
6668
_, err = ov.GetOverlappedResult(True)
6769
more = ov.getbuffer()
6870
if more:
69-
bdata.extend(more)
71+
self.buffer.extend(more)
72+
space_pos = self.buffer.find(b" ")
73+
if space_pos != -1:
74+
# We have a full frame
75+
bdata = self.buffer[: space_pos]
76+
self.buffer = self.buffer[space_pos + 1 :]
77+
break
7078
if err == 0:
7179
# we are done!
7280
break
@@ -79,15 +87,26 @@ def read(self, size: int = 100000) -> bytes:
7987
while True:
8088
more = self.connection.recv(size)
8189
if not more:
90+
# Connection closed
8291
break
83-
bdata.extend(more)
84-
return bytes(bdata)
92+
self.buffer.extend(more)
93+
space_pos = self.buffer.find(b" ")
94+
if space_pos != -1:
95+
# We have a full frame
96+
bdata = self.buffer[: space_pos]
97+
self.buffer = self.buffer[space_pos + 1 :]
98+
break
99+
return unquote(bytes(bdata).decode("utf8"))
85100

86-
def write(self, data: bytes) -> None:
101+
def write(self, data: str) -> None:
87102
"""Write bytes to an IPC connection."""
103+
104+
# Frame the data by urlencoding it and separating by space.
105+
encoded_data = (quote(data) + " ").encode("utf8")
106+
88107
if sys.platform == "win32":
89108
try:
90-
ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
109+
ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
91110
try:
92111
if err == _winapi.ERROR_IO_PENDING:
93112
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
@@ -101,12 +120,11 @@ def write(self, data: bytes) -> None:
101120
raise
102121
bytes_written, err = ov.GetOverlappedResult(True)
103122
assert err == 0, err
104-
assert bytes_written == len(data)
123+
assert bytes_written == len(encoded_data)
105124
except OSError as e:
106125
raise IPCException(f"Failed to write with error: {e.winerror}") from e
107126
else:
108-
self.connection.sendall(data)
109-
self.connection.shutdown(socket.SHUT_WR)
127+
self.connection.sendall(encoded_data)
110128

111129
def close(self) -> None:
112130
if sys.platform == "win32":

mypy/test/testipc.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,23 @@
1515
def server(msg: str, q: Queue[str]) -> None:
1616
server = IPCServer(CONNECTION_NAME)
1717
q.put(server.connection_name)
18-
data = b""
18+
data = ""
1919
while not data:
2020
with server:
21-
server.write(msg.encode())
21+
server.write(msg)
2222
data = server.read()
2323
server.cleanup()
2424

25+
def server_multi_message_echo(q: Queue[str]) -> None:
26+
server = IPCServer(CONNECTION_NAME)
27+
q.put(server.connection_name)
28+
data = ""
29+
with server:
30+
while data != "quit":
31+
data = server.read()
32+
server.write(data)
33+
server.cleanup()
34+
2535

2636
class IPCTests(TestCase):
2737
def test_transaction_large(self) -> None:
@@ -31,8 +41,8 @@ def test_transaction_large(self) -> None:
3141
p.start()
3242
connection_name = queue.get()
3343
with IPCClient(connection_name, timeout=1) as client:
34-
assert client.read() == msg.encode()
35-
client.write(b"test")
44+
assert client.read() == msg
45+
client.write("test")
3646
queue.close()
3747
queue.join_thread()
3848
p.join()
@@ -44,12 +54,29 @@ def test_connect_twice(self) -> None:
4454
p.start()
4555
connection_name = queue.get()
4656
with IPCClient(connection_name, timeout=1) as client:
47-
assert client.read() == msg.encode()
48-
client.write(b"") # don't let the server hang up yet, we want to connect again.
57+
assert client.read() == msg
58+
client.write("") # don't let the server hang up yet, we want to connect again.
59+
60+
with IPCClient(connection_name, timeout=1) as client:
61+
assert client.read() == msg
62+
client.write("test")
63+
queue.close()
64+
queue.join_thread()
65+
p.join()
66+
assert p.exitcode == 0
4967

68+
def test_multiple_messages(self) -> None:
69+
queue: Queue[str] = Queue()
70+
p = Process(target=server_multi_message_echo, args=(queue,), daemon=True)
71+
p.start()
72+
connection_name = queue.get()
5073
with IPCClient(connection_name, timeout=1) as client:
51-
assert client.read() == msg.encode()
52-
client.write(b"test")
74+
client.write("test1")
75+
assert client.read() == "test1"
76+
client.write("test2")
77+
assert client.read() == "test2"
78+
client.write("quit")
79+
assert client.read() == "quit"
5380
queue.close()
5481
queue.join_thread()
5582
p.join()

0 commit comments

Comments
 (0)