Skip to content

Commit 61a024c

Browse files
committed
[SPARK-53916][PYTHON] Deduplicate the variables in PythonArrowInput
### What changes were proposed in this pull request?

Deduplicate the variables in PythonArrowInput

### Why are the changes needed?

Code clean-up.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52621 from zhengruifeng/unify_var_name.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 032dcf8 commit 61a024c

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 8 additions & 16 deletions

Original file line number | Diff line number | Diff line change
@@ -115,6 +115,13 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
115115
self: BasePythonRunner[Iterator[InternalRow], _] =>
116116
protected val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
117117

118+
protected val maxRecordsPerBatch: Int = {
119+
val v = SQLConf.get.arrowMaxRecordsPerBatch
120+
if (v > 0) v else Int.MaxValue
121+
}
122+
123+
protected val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
124+
118125
protected def writeNextBatchToArrowStream(
119126
root: VectorSchemaRoot,
120127
writer: ArrowStreamWriter,
@@ -145,13 +152,6 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
145152

146153
private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
147154
self: BasePythonRunner[Iterator[InternalRow], _] =>
148-
private val arrowMaxRecordsPerBatch = {
149-
val v = SQLConf.get.arrowMaxRecordsPerBatch
150-
if (v > 0) v else Int.MaxValue
151-
}
152-
153-
private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
154-
155155
// Marker inside the input iterator to indicate the start of the next batch.
156156
private var nextBatchStart: Iterator[InternalRow] = Iterator.empty
157157

@@ -169,7 +169,7 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
169169
val startData = dataOut.size()
170170

171171
val numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
172-
arrowWriter, writer, nextBatchStart, maxBytesPerBatch, arrowMaxRecordsPerBatch)
172+
arrowWriter, writer, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch)
173173

174174
val deltaData = dataOut.size() - startData
175175
pythonMetrics("pythonDataSent") += deltaData
@@ -234,14 +234,6 @@ private[python] object BatchedPythonArrowInput {
234234
* Enables an optimization that splits each group into the sized batches.
235235
*/
236236
private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner =>
237-
238-
val maxRecordsPerBatch: Int = {
239-
val v = SQLConf.get.arrowMaxRecordsPerBatch
240-
if (v > 0) v else Int.MaxValue
241-
}
242-
243-
val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
244-
245237
protected override def newWriter(
246238
env: SparkEnv,
247239
worker: PythonWorker,

0 commit comments

Comments (0)