Skip to content

Commit 2c0fc70

Browse files
committed
[SPARK-51891][SS] Squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark
### What changes were proposed in this pull request? This PR proposes to squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark, which will help a lot when dealing with a small list in ListState. Here are the changes: * ListState.get() no longer requires an additional request to notice there is no further data to read. * We inline the data into the proto message, to make it easy to determine whether the iterator has been fully consumed or not. * ListState.put() / ListState.appendList() no longer require an additional request to send the data separately. * We inline the data into the proto message if the length of the list we pass is small enough (for now it is "magically" set to 100 elements - need to look further). * If the length of the list is over 100, we fall back to the "old" Arrow send (rather than the custom protocol). This is because a pickled Python Row contains the schema information as a string, which is larger than we anticipated. So at some point, Arrow becomes more efficient. NOTE: 100 is a sort of "magic number", and we will need to improve this with more benchmarking. ### Why are the changes needed? To optimize further on ListState operations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50689 from HeartSaVioR/SPARK-51891. Authored-by: Jungtaek Lim <[email protected]> Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 83db398 commit 2c0fc70

File tree

9 files changed

+553
-115
lines changed

9 files changed

+553
-115
lines changed

python/pyspark/sql/streaming/list_state_client.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
self.schema = schema
3838
# A dictionary to store the mapping between list state name and a tuple of data batch
3939
# and the index of the last row that was read.
40-
self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}
40+
self.data_batch_dict: Dict[str, Tuple[Any, int, bool]] = {}
4141

4242
def exists(self, state_name: str) -> bool:
4343
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -61,12 +61,12 @@ def exists(self, state_name: str) -> bool:
6161
f"Error checking value state exists: " f"{response_message[1]}"
6262
)
6363

64-
def get(self, state_name: str, iterator_id: str) -> Tuple:
64+
def get(self, state_name: str, iterator_id: str) -> Tuple[Tuple, bool]:
6565
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
6666

6767
if iterator_id in self.data_batch_dict:
6868
# If the state is already in the dictionary, return the next row.
69-
data_batch, index = self.data_batch_dict[iterator_id]
69+
data_batch, index, require_next_fetch = self.data_batch_dict[iterator_id]
7070
else:
7171
# If the state is not in the dictionary, fetch the state from the server.
7272
get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -79,23 +79,35 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
7979
message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
8080

8181
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
82-
response_message = self._stateful_processor_api_client._receive_proto_message()
82+
response_message = (
83+
self._stateful_processor_api_client._receive_proto_message_with_list_get()
84+
)
8385
status = response_message[0]
8486
if status == 0:
85-
data_batch = self._stateful_processor_api_client._read_list_state()
87+
data_batch = list(
88+
map(
89+
lambda x: self._stateful_processor_api_client._deserialize_from_bytes(x),
90+
response_message[2],
91+
)
92+
)
93+
require_next_fetch = response_message[3]
8694
index = 0
8795
else:
8896
raise StopIteration()
8997

98+
is_last_row = False
9099
new_index = index + 1
91100
if new_index < len(data_batch):
92101
# Update the index in the dictionary.
93-
self.data_batch_dict[iterator_id] = (data_batch, new_index)
102+
self.data_batch_dict[iterator_id] = (data_batch, new_index, require_next_fetch)
94103
else:
95104
# If the index is at the end of the data batch, remove the state from the dictionary.
96105
self.data_batch_dict.pop(iterator_id, None)
106+
is_last_row = True
107+
108+
is_last_row_from_iterator = is_last_row and not require_next_fetch
97109
row = data_batch[index]
98-
return tuple(row)
110+
return (tuple(row), is_last_row_from_iterator)
99111

100112
def append_value(self, state_name: str, value: Tuple) -> None:
101113
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -118,7 +130,24 @@ def append_value(self, state_name: str, value: Tuple) -> None:
118130
def append_list(self, state_name: str, values: List[Tuple]) -> None:
119131
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
120132

121-
append_list_call = stateMessage.AppendList()
133+
send_data_via_arrow = False
134+
135+
# To workaround mypy type assignment check.
136+
values_as_bytes: Any = []
137+
if len(values) == 100:
138+
# TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
139+
# value backed by various benchmarks.
140+
# Arrow codepath
141+
send_data_via_arrow = True
142+
else:
143+
values_as_bytes = map(
144+
lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
145+
values,
146+
)
147+
148+
append_list_call = stateMessage.AppendList(
149+
value=values_as_bytes, fetchWithArrow=send_data_via_arrow
150+
)
122151
list_state_call = stateMessage.ListStateCall(
123152
stateName=state_name, appendList=append_list_call
124153
)
@@ -127,7 +156,9 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
127156

128157
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
129158

130-
self._stateful_processor_api_client._send_list_state(self.schema, values)
159+
if send_data_via_arrow:
160+
self._stateful_processor_api_client._send_arrow_state(self.schema, values)
161+
131162
response_message = self._stateful_processor_api_client._receive_proto_message()
132163
status = response_message[0]
133164
if status != 0:
@@ -137,14 +168,32 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
137168
def put(self, state_name: str, values: List[Tuple]) -> None:
138169
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
139170

140-
put_call = stateMessage.ListStatePut()
171+
send_data_via_arrow = False
172+
# To workaround mypy type assignment check.
173+
values_as_bytes: Any = []
174+
if len(values) == 100:
175+
# TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
176+
# value backed by various benchmarks.
177+
send_data_via_arrow = True
178+
else:
179+
values_as_bytes = map(
180+
lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
181+
values,
182+
)
183+
184+
put_call = stateMessage.ListStatePut(
185+
value=values_as_bytes, fetchWithArrow=send_data_via_arrow
186+
)
187+
141188
list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
142189
state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
143190
message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
144191

145192
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
146193

147-
self._stateful_processor_api_client._send_list_state(self.schema, values)
194+
if send_data_via_arrow:
195+
self._stateful_processor_api_client._send_arrow_state(self.schema, values)
196+
148197
response_message = self._stateful_processor_api_client._receive_proto_message()
149198
status = response_message[0]
150199
if status != 0:
@@ -174,9 +223,17 @@ def __init__(self, list_state_client: ListStateClient, state_name: str):
174223
# Generate a unique identifier for the iterator to make sure iterators from the same
175224
# list state do not interfere with each other.
176225
self.iterator_id = str(uuid.uuid4())
226+
self.iterator_fully_consumed = False
177227

178228
def __iter__(self) -> Iterator[Tuple]:
179229
return self
180230

181231
def __next__(self) -> Tuple:
182-
return self.list_state_client.get(self.state_name, self.iterator_id)
232+
if self.iterator_fully_consumed:
233+
raise StopIteration()
234+
235+
row, is_last_row = self.list_state_client.get(self.state_name, self.iterator_id)
236+
if is_last_row:
237+
self.iterator_fully_consumed = True
238+
239+
return row

python/pyspark/sql/streaming/proto/StateMessage_pb2.py

Lines changed: 79 additions & 77 deletions
Large diffs are not rendered by default.

python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ See the License for the specific language governing permissions and
3434
limitations under the License.
3535
"""
3636
import builtins
37+
import collections.abc
3738
import google.protobuf.descriptor
39+
import google.protobuf.internal.containers
3840
import google.protobuf.internal.enum_type_wrapper
3941
import google.protobuf.message
4042
import sys
@@ -229,6 +231,44 @@ class StateResponseWithStringTypeVal(google.protobuf.message.Message):
229231

230232
global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal
231233

234+
class StateResponseWithListGet(google.protobuf.message.Message):
235+
DESCRIPTOR: google.protobuf.descriptor.Descriptor
236+
237+
STATUSCODE_FIELD_NUMBER: builtins.int
238+
ERRORMESSAGE_FIELD_NUMBER: builtins.int
239+
VALUE_FIELD_NUMBER: builtins.int
240+
REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int
241+
statusCode: builtins.int
242+
errorMessage: builtins.str
243+
@property
244+
def value(
245+
self,
246+
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
247+
requireNextFetch: builtins.bool
248+
def __init__(
249+
self,
250+
*,
251+
statusCode: builtins.int = ...,
252+
errorMessage: builtins.str = ...,
253+
value: collections.abc.Iterable[builtins.bytes] | None = ...,
254+
requireNextFetch: builtins.bool = ...,
255+
) -> None: ...
256+
def ClearField(
257+
self,
258+
field_name: typing_extensions.Literal[
259+
"errorMessage",
260+
b"errorMessage",
261+
"requireNextFetch",
262+
b"requireNextFetch",
263+
"statusCode",
264+
b"statusCode",
265+
"value",
266+
b"value",
267+
],
268+
) -> None: ...
269+
270+
global___StateResponseWithListGet = StateResponseWithListGet
271+
232272
class StatefulProcessorCall(google.protobuf.message.Message):
233273
DESCRIPTOR: google.protobuf.descriptor.Descriptor
234274

@@ -1042,8 +1082,24 @@ global___ListStateGet = ListStateGet
10421082
class ListStatePut(google.protobuf.message.Message):
10431083
DESCRIPTOR: google.protobuf.descriptor.Descriptor
10441084

1085+
VALUE_FIELD_NUMBER: builtins.int
1086+
FETCHWITHARROW_FIELD_NUMBER: builtins.int
1087+
@property
1088+
def value(
1089+
self,
1090+
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
1091+
fetchWithArrow: builtins.bool
10451092
def __init__(
10461093
self,
1094+
*,
1095+
value: collections.abc.Iterable[builtins.bytes] | None = ...,
1096+
fetchWithArrow: builtins.bool = ...,
1097+
) -> None: ...
1098+
def ClearField(
1099+
self,
1100+
field_name: typing_extensions.Literal[
1101+
"fetchWithArrow", b"fetchWithArrow", "value", b"value"
1102+
],
10471103
) -> None: ...
10481104

10491105
global___ListStatePut = ListStatePut
@@ -1065,8 +1121,24 @@ global___AppendValue = AppendValue
10651121
class AppendList(google.protobuf.message.Message):
10661122
DESCRIPTOR: google.protobuf.descriptor.Descriptor
10671123

1124+
VALUE_FIELD_NUMBER: builtins.int
1125+
FETCHWITHARROW_FIELD_NUMBER: builtins.int
1126+
@property
1127+
def value(
1128+
self,
1129+
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
1130+
fetchWithArrow: builtins.bool
10681131
def __init__(
10691132
self,
1133+
*,
1134+
value: collections.abc.Iterable[builtins.bytes] | None = ...,
1135+
fetchWithArrow: builtins.bool = ...,
1136+
) -> None: ...
1137+
def ClearField(
1138+
self,
1139+
field_name: typing_extensions.Literal[
1140+
"fetchWithArrow", b"fetchWithArrow", "value", b"value"
1141+
],
10701142
) -> None: ...
10711143

10721144
global___AppendList = AppendList

python/pyspark/sql/streaming/stateful_processor_api_client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,18 @@ def _receive_proto_message_with_string_value(self) -> Tuple[int, str, str]:
425425
message.ParseFromString(bytes)
426426
return message.statusCode, message.errorMessage, message.value
427427

428+
# The third return type is RepeatedScalarFieldContainer[bytes], which is protobuf's container
429+
# type. We simplify it to Any here to avoid unnecessary complexity.
430+
def _receive_proto_message_with_list_get(self) -> Tuple[int, str, Any, bool]:
431+
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
432+
433+
length = read_int(self.sockfile)
434+
bytes = self.sockfile.read(length)
435+
message = stateMessage.StateResponseWithListGet()
436+
message.ParseFromString(bytes)
437+
438+
return message.statusCode, message.errorMessage, message.value, message.requireNextFetch
439+
428440
def _receive_str(self) -> str:
429441
return self.utf8_deserializer.loads(self.sockfile)
430442

0 commit comments

Comments
 (0)