Commit 3115669

gaogaotiantian authored and zhengruifeng committed
[SPARK-54925][PYTHON] Add the capability to dump threads for pyspark
### What changes were proposed in this pull request?

Add an optional capability to dump thread info of *all* pyspark processes. It is intentionally hidden for now because it is not fully polished. It can be used as `python -m pyspark.threaddump -p <pid>`. It requires `pystack` and `psutil`; without these libraries the command will fail. For now it is only used when a test hangs. The result looks like:

```
Thread dump:
Dumping threads for process 1904
Traceback for thread 2175 (python3.12) [] (most recent call last):
    (Python) File "/usr/lib/python3.12/threading.py", line 1032, in _bootstrap
        self._bootstrap_inner()
    (Python) File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
        self.run()
    (Python) File "/usr/lib/python3.12/threading.py", line 1012, in run
        self._target(*self._args, **self._kwargs)
    (Python) File "/usr/lib/python3.12/socketserver.py", line 240, in serve_forever
        self._handle_request_noblock()
    (Python) File "/usr/lib/python3.12/socketserver.py", line 318, in _handle_request_noblock
        self.process_request(request, client_address)
    (Python) File "/usr/lib/python3.12/socketserver.py", line 349, in process_request
        self.finish_request(request, client_address)
    (Python) File "/usr/lib/python3.12/socketserver.py", line 362, in finish_request
        self.RequestHandlerClass(request, client_address, self)
    (Python) File "/usr/lib/python3.12/socketserver.py", line 766, in __init__
        self.handle()
    (Python) File "/workspaces/spark/python/pyspark/accumulators.py", line 327, in handle
        poll(accum_updates)
    (Python) File "/workspaces/spark/python/pyspark/accumulators.py", line 281, in poll
        for fd, event in poller.poll(1000):
Traceback for thread 2034 (python3.12) [] (most recent call last):
    (Python) File "/usr/lib/python3.12/threading.py", line 1032, in _bootstrap
        self._bootstrap_inner()
    (Python) File "/usr/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
        self.run()
    (Python) File "/workspaces/spark/python/lib/py4j-0.10.9.9-src.zip/py4j/clientserver.py", line 58, in run
Traceback for thread 1904 (python3.12) [] (most recent call last):
    (Python) File "<frozen runpy>", line 198, in _run_module_as_main
    (Python) File "<frozen runpy>", line 88, in _run_code
    (Python) File "/workspaces/spark/python/pyspark/sql/tests/test_udf.py", line 1790, in <module>
        unittest.main(testRunner=testRunner, verbosity=2)
    (Python) File "/workspaces/spark/python/pyspark/testing/__init__.py", line 30, in unittest_main
        res = _unittest_main(*args, **kwargs)
    (Python) File "/usr/lib/python3.12/unittest/main.py", line 105, in __init__
        self.runTests()
    (Python) File "/usr/lib/python3.12/unittest/main.py", line 281, in runTests
        self.result = testRunner.run(self.test)
    (Python) File "/usr/local/lib/python3.12/dist-packages/xmlrunner/runner.py", line 67, in run
        test(result)
    (Python) File "/usr/lib/python3.12/unittest/suite.py", line 84, in __call__
        return self.run(*args, **kwds)
    (Python) File "/usr/lib/python3.12/unittest/suite.py", line 122, in run
        test(result)
    (Python) File "/usr/lib/python3.12/unittest/suite.py", line 84, in __call__
        return self.run(*args, **kwds)
    (Python) File "/usr/lib/python3.12/unittest/suite.py", line 122, in run
        test(result)
    (Python) File "/usr/lib/python3.12/unittest/case.py", line 690, in __call__
        return self.run(*args, **kwds)
    (Python) File "/usr/lib/python3.12/unittest/case.py", line 634, in run
        self._callTestMethod(testMethod)
    (Python) File "/usr/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
        if method() is not None:
    (Python) File "/workspaces/spark/python/pyspark/sql/tests/test_udf.py", line 212, in test_chained_udf
        [row] = self.spark.sql("SELECT double_int(double_int(1) + 1)").collect()
    (Python) File "/workspaces/spark/python/pyspark/sql/classic/dataframe.py", line 469, in collect
        sock_info = self._jdf.collectToPython()
    (Python) File "/workspaces/spark/python/lib/py4j-0.10.9.9-src.zip/py4j/java_gateway.py", line 1361, in __call__
    (Python) File "/workspaces/spark/python/lib/py4j-0.10.9.9-src.zip/py4j/java_gateway.py", line 1038, in send_command
    (Python) File "/workspaces/spark/python/lib/py4j-0.10.9.9-src.zip/py4j/clientserver.py", line 535, in send_command
    (Python) File "/usr/lib/python3.12/socket.py", line 720, in readinto
        return self._sock.recv_into(b)
Dumping threads for process 2191
Traceback for thread 2191 (python3.12) [] (most recent call last):
    (Python) File "<frozen runpy>", line 198, in _run_module_as_main
    (Python) File "<frozen runpy>", line 88, in _run_code
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 287, in <module>
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 180, in manager
Dumping threads for process 2198
Traceback for thread 2198 (python3.12) [] (most recent call last):
    (Python) File "<frozen runpy>", line 198, in _run_module_as_main
    (Python) File "<frozen runpy>", line 88, in _run_code
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 287, in <module>
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 259, in manager
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 88, in worker
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/util.py", line 981, in wrapper
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/worker.py", line 3439, in main
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 583, in read_int
Dumping threads for process 2257
Traceback for thread 2257 (python3.12) [] (most recent call last):
    (Python) File "<frozen runpy>", line 198, in _run_module_as_main
    (Python) File "<frozen runpy>", line 88, in _run_code
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 287, in <module>
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 259, in manager
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/daemon.py", line 88, in worker
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/util.py", line 981, in wrapper
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/worker.py", line 3439, in main
    (Python) File "/workspaces/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 583, in read_int
```

Notice that it covers not only the driver process but also the daemon and worker processes. The plan is to incorporate this into our existing debug framework so that the `threaddump` button will return both JVM executor threads and Python worker threads.

### Why are the changes needed?

We need insights into the Python worker/daemon processes. We already have some thread dump capability in our tests, but it is not stable: `SIGTERM` is sometimes hooked, so `faulthandler` cannot work properly, and it cannot dump subprocesses.

### Does this PR introduce _any_ user-facing change?

Yes, but it is hidden for now. A new command entry is introduced.

### How was this patch tested?

It works locally.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53705 from gaogaotiantian/python-threaddump.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 3dc641b commit 3115669
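
As context for the `faulthandler` limitation cited in the description, here is a minimal sketch (not part of this commit) of the in-process approach it replaces. `faulthandler` can only dump the threads of its own process, and test code that later installs its own `SIGTERM` handler silently displaces it:

```python
# A minimal sketch (not from this commit) of the older in-process approach:
# dump all threads of *this* process when SIGTERM arrives (Unix only).
# If test code later calls signal.signal(signal.SIGTERM, ...), this handler
# is displaced, and child daemon/worker processes are never covered; that is
# the gap the pystack + psutil dumper fills.
import faulthandler
import signal

faulthandler.register(signal.SIGTERM, all_threads=True)
```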

File tree

4 files changed: +79, -2 lines


dev/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -77,6 +77,8 @@ graphviz==0.20.3
 flameprof==0.4
 viztracer
 debugpy
+pystack>=1.5.1; python_version!='3.13' and sys_platform=='linux' # no 3.13t wheels
+psutil
 
 # TorchDistributor dependencies
 torch
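
The clause after the semicolon on the `pystack` line is a standard PEP 508 environment marker, so pip installs the package only on Linux and skips Python 3.13 (the comment notes there are no 3.13t wheels). For illustration, a marker can be evaluated against the current interpreter with the third-party `packaging` library (an assumption here, not something this commit uses):

```python
# Evaluate the requirement's environment marker on this interpreter.
# Requires the 'packaging' library; shown only to illustrate PEP 508 markers.
from packaging.markers import Marker

marker = Marker("python_version != '3.13' and sys_platform == 'linux'")
print(marker.evaluate())  # True only on Linux with a non-3.13 Python
```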

dev/spark-test-image/python-311/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ RUN apt-get update && apt-get install -y \
     && rm -rf /var/lib/apt/lists/*
 
 
-ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 pystack psutil"
 # Python deps for Spark Connect
 ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

python/pyspark/threaddump.py

Lines changed: 62 additions & 0 deletions

+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Dump threads of a process and its children")
+    parser.add_argument("-p", "--pid", type=int, required=True, help="The PID to dump")
+    return parser
+
+
+def main() -> int:
+    try:
+        import psutil
+        from pystack.__main__ import main as pystack_main  # type: ignore
+    except ImportError:
+        print("pystack and psutil are not installed")
+        return 1
+
+    parser = build_parser()
+    args = parser.parse_args()
+
+    try:
+        pids = [args.pid] + [
+            child.pid
+            for child in psutil.Process(args.pid).children(recursive=True)
+            if "python" in child.exe()
+        ]
+    except Exception as e:
+        print(f"Error getting children of process {args.pid}: {e}")
+        return 2
+
+    for pid in pids:
+        sys.argv = ["pystack", "remote", str(pid)]
+        try:
+            print(f"Dumping threads for process {pid}")
+            pystack_main()
+        except Exception:
+            # We might have tried to dump a process that is not a Python process
+            pass
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
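
As a usage note, the module can be driven programmatically as well as from the shell; the sketch below wraps the `python -m pyspark.threaddump -p <pid>` invocation from the description. The PID is illustrative, and since `pystack remote` attaches via ptrace, the caller needs tracing permission for the target on Linux (e.g., run as root or relax `ptrace_scope`):

```python
# Usage sketch: invoke the (currently hidden) dumper against a driver PID.
# 1904 is an illustrative PID; pystack and psutil must be installed, and the
# caller needs ptrace permission for the target process.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-m", "pyspark.threaddump", "-p", "1904"],
    capture_output=True,
    text=True,
)
print(result.stdout)
```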

python/run-tests.py

Lines changed: 14 additions & 1 deletion

@@ -113,7 +113,7 @@ def run(self):
         try:
             asyncio.run(self.handle_inout())
         except subprocess.TimeoutExpired:
-            LOGGER.error(f"Test {self.test_name} timed out")
+            LOGGER.error(f"Test {self.test_name} timed out after {self.timeout} seconds")
             try:
                 return self.p.wait(timeout=30)
             except subprocess.TimeoutExpired:
@@ -204,9 +204,22 @@ async def check_timeout(self):
             # We don't want to kill the process if it's in pdb mode
             return
         if self.p.poll() is None:
+            if sys.platform == "linux":
+                self.thread_dump(self.p.pid)
             self.p.terminate()
             raise subprocess.TimeoutExpired(self.cmd, self.timeout)
 
+    def thread_dump(self, pid):
+        pyspark_python = self.env['PYSPARK_PYTHON']
+        p = subprocess.run(
+            [pyspark_python, "-m", "pyspark.threaddump", "-p", str(pid)],
+            env={**self.env, "PYTHONPATH": f"{os.path.join(SPARK_HOME, 'python')}:{os.environ.get('PYTHONPATH', '')}"},
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+        )
+        if p.returncode == 0:
+            LOGGER.error(f"Thread dump:\n{p.stdout.decode('utf-8')}")
 
 
 def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_output):
     """
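
Note that `thread_dump` prepends `SPARK_HOME/python` to `PYTHONPATH` so `pyspark.threaddump` resolves from source even when the test subprocess imports pyspark from `pyspark.zip`. When debugging a hang by hand, roughly the same call can be reproduced standalone; in this sketch the Spark checkout path and PID are illustrative:

```python
# Manual reproduction sketch of what check_timeout triggers on Linux:
# run the dumper with SPARK_HOME/python on PYTHONPATH and capture its output.
import os
import subprocess

SPARK_HOME = "/workspaces/spark"  # illustrative; run-tests.py computes this itself

env = dict(os.environ)
env["PYTHONPATH"] = os.path.join(SPARK_HOME, "python") + ":" + env.get("PYTHONPATH", "")

p = subprocess.run(
    ["python3", "-m", "pyspark.threaddump", "-p", "1904"],  # 1904 = hung driver PID
    env=env,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
print(p.stdout.decode("utf-8"))
```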
