Skip to content

Commit 4192135

Browse files
gaogaotiantian authored and HyukjinKwon committed
[SPARK-55020][PYTHON][FOLLOW-UP] Disable gc only when we communicate through gRPC for ExecutePlan
### What changes were proposed in this pull request?

Instead of disabling gc for the whole function (we did it wrong for generators), we precisely disable it when we do communications through gRPC with `ExecutePlan`.

### Why are the changes needed?

The previous implementation [SPARK-55020](https://issues.apache.org/jira/browse/SPARK-55020) (#53783) was wrong - the generator was only protected when it's built, but it will trigger communication when it's being drained. The context `disable_gc` provides a more precise way to disable gc during some operation. Also we should make generator work in a way that gc is not always disabled when generator is not drained.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #54248 from gaogaotiantian/fix-disable-gc.

Authored-by: Tian Gao <gaogaotiantian@hotmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 93acd1f commit 4192135

File tree

5 files changed

+52
-28
lines changed

5 files changed

+52
-28
lines changed

python/pyspark/sql/connect/client/core.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,7 +1462,6 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
14621462
except Exception as error:
14631463
self._handle_error(error)
14641464

1465-
@disable_gc
14661465
def _execute(self, req: pb2.ExecutePlanRequest) -> None:
14671466
"""
14681467
Execute the passed request `req` and drop all results.
@@ -1496,12 +1495,12 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
14961495
else:
14971496
for attempt in self._retrying():
14981497
with attempt:
1499-
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
1500-
handle_response(b)
1498+
with disable_gc():
1499+
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
1500+
handle_response(b)
15011501
except Exception as error:
15021502
self._handle_error(error)
15031503

1504-
@disable_gc
15051504
def _execute_and_fetch_as_iterator(
15061505
self,
15071506
req: pb2.ExecutePlanRequest,
@@ -1697,8 +1696,15 @@ def handle_response(
16971696
else:
16981697
for attempt in self._retrying():
16991698
with attempt:
1700-
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
1701-
yield from handle_response(b)
1699+
with disable_gc():
1700+
gen = self._stub.ExecutePlan(req, metadata=self._builder.metadata())
1701+
while True:
1702+
try:
1703+
with disable_gc():
1704+
b = next(gen)
1705+
yield from handle_response(b)
1706+
except StopIteration:
1707+
break
17021708
except KeyboardInterrupt as kb:
17031709
logger.debug(f"Interrupt request received for operation={req.operation_id}")
17041710
if progress is not None:

python/pyspark/sql/connect/client/reattach.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import pyspark.sql.connect.proto as pb2
3535
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
3636
from pyspark.errors import PySparkRuntimeError
37+
from pyspark.util import disable_gc
3738

3839

3940
class ExecutePlanResponseReattachableIterator(Generator):
@@ -108,9 +109,10 @@ def __init__(
108109
# Note: This is not retried, because no error would ever be thrown here, and GRPC will only
109110
# throw error on first self._has_next().
110111
self._metadata = metadata
111-
self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter(
112-
self._stub.ExecutePlan(self._initial_request, metadata=metadata)
113-
)
112+
with disable_gc():
113+
self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter(
114+
self._stub.ExecutePlan(self._initial_request, metadata=metadata)
115+
)
114116

115117
# Current item from this iterator.
116118
self._current: Optional[pb2.ExecutePlanResponse] = None
@@ -142,8 +144,9 @@ def shutdown_threadpool_if_idle(cls) -> None:
142144

143145
def send(self, value: Any) -> pb2.ExecutePlanResponse:
144146
# will trigger reattach in case the stream completed without result_complete
145-
if not self._has_next():
146-
raise StopIteration()
147+
with disable_gc():
148+
if not self._has_next():
149+
raise StopIteration()
147150

148151
ret = self._current
149152
assert ret is not None

python/pyspark/sql/tests/connect/client/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
158158
buf = sink.getvalue()
159159
resp.arrow_batch.data = buf.to_pybytes()
160160
resp.arrow_batch.row_count = 2
161-
return [resp]
161+
return iter([resp])
162162

163163
def Interrupt(self, req: proto.InterruptRequest, metadata):
164164
self.req = req

python/pyspark/tests/test_util.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
import gc
1718
import os
1819
import time
1920
import unittest
@@ -22,7 +23,7 @@
2223
from py4j.protocol import Py4JJavaError
2324

2425
from pyspark import keyword_only
25-
from pyspark.util import _parse_memory
26+
from pyspark.util import _parse_memory, disable_gc
2627
from pyspark.loose_version import LooseVersion
2728
from pyspark.testing.utils import PySparkTestCase, eventually, timeout
2829
from pyspark.find_spark_home import _find_spark_home
@@ -148,6 +149,12 @@ def test_parse_memory(self):
148149
with self.assertRaisesRegex(ValueError, "invalid format"):
149150
_parse_memory("2gs")
150151

152+
def test_disable_gc(self):
153+
self.assertTrue(gc.isenabled())
154+
with disable_gc():
155+
self.assertFalse(gc.isenabled())
156+
self.assertTrue(gc.isenabled())
157+
151158
@eventually(timeout=180, catch_timeout=True)
152159
@timeout(timeout=1)
153160
def test_retry_timeout_test(self):

python/pyspark/util.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# limitations under the License.
1717
#
1818

19+
import contextlib
1920
import copy
2021
import functools
2122
import faulthandler
@@ -32,7 +33,19 @@
3233
import warnings
3334
from contextlib import contextmanager
3435
from types import TracebackType
35-
from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, TypeVar, Union, cast
36+
from typing import (
37+
Any,
38+
Callable,
39+
Generator,
40+
IO,
41+
Iterator,
42+
List,
43+
Optional,
44+
TextIO,
45+
Tuple,
46+
TypeVar,
47+
Union,
48+
)
3649

3750
from pyspark.errors import PySparkRuntimeError
3851
from pyspark.serializers import (
@@ -860,21 +873,16 @@ def _do_server_auth(conn: "io.IOBase", auth_secret: str) -> None:
860873
)
861874

862875

863-
def disable_gc(f: FuncT) -> FuncT:
864-
"""Mark the function that should disable gc during execution"""
865-
866-
@functools.wraps(f)
867-
def wrapped(*args: Any, **kwargs: Any) -> Any:
868-
gc_enabled_originally = gc.isenabled()
876+
@contextlib.contextmanager
877+
def disable_gc() -> Generator[None, None, None]:
878+
gc_enabled_originally = gc.isenabled()
879+
if gc_enabled_originally:
880+
gc.disable()
881+
try:
882+
yield
883+
finally:
869884
if gc_enabled_originally:
870-
gc.disable()
871-
try:
872-
return f(*args, **kwargs)
873-
finally:
874-
if gc_enabled_originally:
875-
gc.enable()
876-
877-
return cast(FuncT, wrapped)
885+
gc.enable()
878886

879887

880888
_is_remote_only = None

0 commit comments

Comments (0)