Commit 61fb41f

nicholaschew11 authored and HeartSaVioR committed
[SPARK-55147][SS] Scope timestamp range for time-interval join retrieval in V4 state format
### What changes were proposed in this pull request?

This PR improves the retrieval operation in the V4 stream-stream join state manager to scope the timestamp range for time-interval joins. Instead of scanning all timestamps for a given key during prefix scan, V4 now extracts constant interval offsets from the join condition and computes a `(minTs, maxTs)` range per input row, enabling the prefix scan to skip entries before `minTs` and terminate early past `maxTs`.

- Add `scanRangeOffsets` and `computeTimestampRange` to `OneSideHashJoiner`, using `StreamingJoinHelper.getStateValueWatermark(eventWatermark=0)` to extract interval bounds from the join condition
- Add a `timestampRange` parameter to `getJoinedRows` in the state manager trait, the V4 implementation, and the V1-V3 base class (ignored by V1-V3)
- Add `getValuesInRange` to `KeyWithTsToValuesStore` that filters by range and stops early past the upper bound
- `getValues` now delegates to `getValuesInRange(Long.MinValue, Long.MaxValue)`

### Why are the changes needed?

For time-interval joins, the V4 state format stores values indexed by `(key, timestamp)`. Without range scoping, retrieving matches requires scanning all timestamps for a key via prefix scan, even though the join condition constrains matching to a specific time window. With this change, the scan is bounded to only the relevant timestamp range, reducing I/O proportionally to the ratio of the interval width to the total timestamp span in state.

### Does this PR introduce _any_ user-facing change?

No. The V4 state format is experimental and gated behind `spark.sql.streaming.join.stateFormatV4.enabled`.

### How was this patch tested?

New unit tests in `SymmetricHashJoinStateManagerEventTimeInValueSuite`:

- `getJoinedRows with timestampRange`: boundary conditions, exact matches, empty ranges, full range
- `timestampRange with multiple values per timestamp`: multiple values at the same timestamp

Existing V4 join suites (Inner, Outer, FullOuter, LeftSemi) all pass.

### Was this patch authored or co-authored using generative AI tooling?

Yes. (Claude Opus 4.6)

Closes #54879 from nicholaschew11/SPARK-55147-range-scan-v4.

Authored-by: Nicholas Chew <chew.nicky@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
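As a worked illustration of the per-row `(minTs, maxTs)` computation described above, here is a minimal standalone sketch. The names `RangeSketch` and `ScanRangeOffsets` are illustrative only (not Spark's actual API); offsets are in microseconds, matching the V4 code's ms-to-us conversion.

```scala
// Illustrative sketch: given constant interval offsets extracted from a join
// condition, compute the inclusive (minTs, maxTs) scan range for one input row.
object RangeSketch {
  // Offsets in microseconds relative to this row's event time (hypothetical type).
  final case class ScanRangeOffsets(lowerUs: Long, upperUs: Long)

  // None means the offsets could not be extracted, so the scan stays unbounded.
  def timestampRange(
      eventTimeUs: Long,
      offsets: Option[ScanRangeOffsets]): Option[(Long, Long)] =
    offsets.map(o => (eventTimeUs + o.lowerUs, eventTimeUs + o.upperUs))

  def main(args: Array[String]): Unit = {
    // e.g. a condition like "otherTime BETWEEN thisTime - 2s AND thisTime + 1s"
    // would yield offsets of (-2000000, +1000000) microseconds.
    val offsets = Some(ScanRangeOffsets(-2000000L, 1000000L))
    println(timestampRange(10000000L, offsets)) // Some((8000000,11000000))
    println(timestampRange(10000000L, None))    // None -> fall back to full scan
  }
}
```

When the range cannot be derived (non-constant offsets, non-interval join), the retrieval degrades gracefully to the previous full prefix scan.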
1 parent e7a9976 commit 61fb41f

File tree

3 files changed: +151 −50 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala

Lines changed: 48 additions & 2 deletions

```diff
@@ -23,7 +23,8 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -682,6 +683,50 @@ case class StreamingSymmetricHashJoinExec(
     private[this] val allowMultipleStatefulOperators: Boolean =
       conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)
 
+    // V4 range scan for time-interval joins (SPARK-55147). Extracts constant interval
+    // offsets from the join condition using getStateValueWatermark(eventWatermark=0).
+    // The -1 eviction adjustment widens range by ~1ms/side; postJoinFilter handles exact bounds.
+    private[this] val scanRangeOffsets: Option[(Long, Long)] = {
+      val isV4TimeIntervalJoin = stateFormatVersion >= 4 && (stateWatermarkPredicate match {
+        case Some(_: JoinStateValueWatermarkPredicate) => true
+        case _ => false
+      })
+
+      if (!isV4TimeIntervalJoin) {
+        None
+      } else {
+        val (thisSideAttrs, otherSideAttrs) = joinSide match {
+          case LeftSide => (left.output, right.output)
+          case RightSide => (right.output, left.output)
+        }
+
+        val lowerBoundMs = StreamingJoinHelper.getStateValueWatermark(
+          AttributeSet(otherSideAttrs), AttributeSet(thisSideAttrs), condition.full, Some(0L))
+        val upperBoundMs = StreamingJoinHelper.getStateValueWatermark(
+          AttributeSet(thisSideAttrs), AttributeSet(otherSideAttrs), condition.full, Some(0L))
+
+        (lowerBoundMs, upperBoundMs) match {
+          case (Some(lower), Some(upper)) =>
+            Some((lower * 1000L, -upper * 1000L)) // ms -> us
+          case _ => None
+        }
+      }
+    }
+
+    private[this] val eventTimeIdxForRangeScan: Int = scanRangeOffsets.map { _ =>
+      WatermarkSupport.findEventTimeColumnIndex(
+        inputAttributes, !allowMultipleStatefulOperators).getOrElse(-1)
+    }.getOrElse(-1)
+
+    private def computeTimestampRange(thisRow: UnsafeRow): Option[(Long, Long)] = {
+      scanRangeOffsets match {
+        case Some((lowerOffset, upperOffset)) if eventTimeIdxForRangeScan >= 0 =>
+          val eventTimeUs = thisRow.getLong(eventTimeIdxForRangeScan)
+          Some((eventTimeUs + lowerOffset, eventTimeUs + upperOffset))
+        case _ => None
+      }
+    }
+
     /**
      * Generate joined rows by consuming input from this side, and matching it with the buffered
      * rows (i.e. state) of the other side.
@@ -758,7 +803,8 @@ case class StreamingSymmetricHashJoinExec(
         otherSideJoiner.joinStateManager.getJoinedRows(
           key,
           thatRow => generateJoinedRow(thisRow, thatRow),
-          postJoinFilter)
+          postJoinFilter,
+          timestampRange = computeTimestampRange(thisRow))
       }
       val outputIter = generateOutputIter(thisRow, joinedRowIter)
       new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
```
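To make the offset-extraction step above concrete: the real code delegates to `StreamingJoinHelper.getStateValueWatermark`, but the essential idea is that both bounds of the interval condition must be constants. The following toy model (names `IntervalBounds`, `GreaterEq`, `LessEq` are invented for illustration, not Spark's expression tree) sketches that requirement under the assumption that the condition has already been split into simple bounds:

```scala
// Toy illustration (not Spark's StreamingJoinHelper): given a time-interval join
// condition of the form  otherTs >= thisTs + lower  AND  otherTs <= thisTs + upper,
// with constant offsets, the scan range per row is just (lower, upper) around its
// event time. If either bound is missing, we cannot scope the scan and return None.
object IntervalBounds {
  sealed trait Bound
  final case class GreaterEq(offsetMs: Long) extends Bound // otherTs >= thisTs + offset
  final case class LessEq(offsetMs: Long) extends Bound    // otherTs <= thisTs + offset

  def extractOffsets(conds: Seq[Bound]): Option[(Long, Long)] = {
    val lower = conds.collectFirst { case GreaterEq(o) => o }
    val upper = conds.collectFirst { case LessEq(o) => o }
    for (l <- lower; u <- upper) yield (l, u)
  }

  def main(args: Array[String]): Unit = {
    // e.g. "otherTs BETWEEN thisTs - 2s AND thisTs + 1s"
    println(extractOffsets(Seq(GreaterEq(-2000L), LessEq(1000L)))) // Some((-2000,1000))
    println(extractOffsets(Seq(GreaterEq(-2000L))))                // None: one-sided only
  }
}
```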

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 52 additions & 48 deletions

```diff
@@ -67,11 +67,16 @@ trait SymmetricHashJoinStateManager {
    * required to do so.
    *
    * It is caller's responsibility to consume the whole iterator.
+   *
+   * @param timestampRange Optional optimization hint as (minTimestamp, maxTimestamp), both
+   *   inclusive. Derived classes may use it to reduce scan scope but are free to ignore it.
+   *   The predicate must produce correct output regardless of whether this hint is leveraged.
    */
   def getJoinedRows(
       key: UnsafeRow,
       generateJoinedRow: InternalRow => JoinedRow,
-      predicate: JoinedRow => Boolean): Iterator[JoinedRow]
+      predicate: JoinedRow => Boolean,
+      timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow]
 
   /**
    * Retrieve all joined rows for the given key and remove the matched rows from state. The joined
@@ -343,9 +348,8 @@ class SymmetricHashJoinStateManagerV4(
   override def getJoinedRows(
       key: UnsafeRow,
       generateJoinedRow: InternalRow => JoinedRow,
-      predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
-    // TODO: [SPARK-55147] We could improve this method to get the scope of timestamp and scan keys
-    // more efficiently. For now, we just get all values for the key.
+      predicate: JoinedRow => Boolean,
+      timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = {
     def getJoinedRowsFromTsAndValues(
         ts: Long,
         valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
@@ -399,7 +403,8 @@ class SymmetricHashJoinStateManagerV4(
         getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)
 
       case _ =>
-        keyWithTsToValues.getValues(key).flatMap { result =>
+        val (minTs, maxTs) = timestampRange.getOrElse((Long.MinValue, Long.MaxValue))
+        keyWithTsToValues.getValuesInRange(key, minTs, maxTs).flatMap { result =>
          val ts = result.timestamp
          val valuesAndMatched = result.values.toArray
          getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
@@ -626,66 +631,64 @@ class SymmetricHashJoinStateManagerV4(
 
    // NOTE: This assumes we consume the whole iterator to trigger completion.
    def getValues(key: UnsafeRow): Iterator[GetValuesResult] = {
+      getValuesInRange(key, Long.MinValue, Long.MaxValue)
+    }
+
+    /**
+     * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
+     * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
+     */
+    def getValuesInRange(
+        key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
      val reusableGetValuesResult = new GetValuesResult()
 
      new NextIterator[GetValuesResult] {
        private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
 
        private var currentTs = -1L
+        private var pastUpperBound = false
        private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
 
+        private def flushAccumulated(): GetValuesResult = {
+          if (valueAndMatchPairs.nonEmpty) {
+            val result = reusableGetValuesResult.withNew(
+              currentTs, valueAndMatchPairs.toList)
+            currentTs = -1L
+            valueAndMatchPairs.clear()
+            result
+          } else {
+            finished = true
+            null
+          }
+        }
+
        @tailrec
        override protected def getNext(): GetValuesResult = {
-          if (iter.hasNext) {
+          if (pastUpperBound || !iter.hasNext) {
+            flushAccumulated()
+          } else {
            val unsafeRowPair = iter.next()
-
            val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
 
-            if (currentTs == -1L) {
-              // First time
+            if (ts > maxTs) {
+              pastUpperBound = true
+              getNext()
+            } else if (ts < minTs) {
+              getNext()
+            } else if (currentTs == -1L || currentTs == ts) {
              currentTs = ts
-            }
-
-            if (currentTs != ts) {
-              assert(valueAndMatchPairs.nonEmpty,
-                "timestamp has changed but no values collected from previous timestamp! " +
-                s"This should not happen. currentTs: $currentTs, new ts: $ts")
-
-              // Return previous batch
-              val result = reusableGetValuesResult.withNew(
-                currentTs, valueAndMatchPairs.toSeq)
+              valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
+              getNext()
+            } else {
+              // Timestamp changed -- flush previous group before starting new one
+              val prevTs = currentTs
+              val prevValues = valueAndMatchPairs.toList
 
-              // Reset for new timestamp
              currentTs = ts
              valueAndMatchPairs.clear()
+              valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
 
-              // Add current value
-              val value = valueRowConverter.convertValue(unsafeRowPair.value)
-              valueAndMatchPairs += value
-              result
-            } else {
-              // Same timestamp, accumulate values
-              val value = valueRowConverter.convertValue(unsafeRowPair.value)
-              valueAndMatchPairs += value
-
-              // Continue to next
-              getNext()
-            }
-          } else {
-            if (currentTs != -1L) {
-              assert(valueAndMatchPairs.nonEmpty)
-
-              // Return last batch
-              val result = reusableGetValuesResult.withNew(
-                currentTs, valueAndMatchPairs.toSeq)
-
-              // Mark as finished
-              currentTs = -1L
-              valueAndMatchPairs.clear()
-              result
-            } else {
-              finished = true
-              null
+              reusableGetValuesResult.withNew(prevTs, prevValues)
            }
          }
        }
@@ -1051,7 +1054,8 @@ abstract class SymmetricHashJoinStateManagerBase(
   def getJoinedRows(
       key: UnsafeRow,
       generateJoinedRow: InternalRow => JoinedRow,
-      predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
+      predicate: JoinedRow => Boolean,
+      timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = {
    val numValues = keyToNumValues.get(key)
    keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
      val joinedRow = generateJoinedRow(keyIdxToValue.value)
```
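The scan loop in `getValuesInRange` above can be modeled more simply on plain collections. The sketch below (assumed names; the real code drives a state store prefix scan via `NextIterator`) shows the three behaviors the diff adds: skip entries below `minTs`, stop at the first entry above `maxTs` since timestamps are sorted, and group the surviving values by timestamp:

```scala
// Simplified, self-contained model of getValuesInRange's scan over entries
// sorted by timestamp. Not the production code -- just its control flow.
object RangeScanSketch {
  def valuesInRange(
      sorted: Seq[(Long, String)], // (timestamp, value), sorted by timestamp
      minTs: Long,
      maxTs: Long): Seq[(Long, Seq[String])] = {
    sorted.iterator
      .dropWhile { case (ts, _) => ts < minTs }  // skip entries before the lower bound
      .takeWhile { case (ts, _) => ts <= maxTs } // early termination past the upper bound
      .toSeq
      .groupBy(_._1).toSeq.sortBy(_._1)          // one group per timestamp
      .map { case (ts, vs) => (ts, vs.map(_._2)) }
  }

  def main(args: Array[String]): Unit = {
    val entries = Seq((10L, "a"), (20L, "b"), (20L, "c"), (30L, "d"), (50L, "e"))
    println(valuesInRange(entries, 20L, 30L)) // groups for ts 20 and 30 only
  }
}
```

The production version additionally reuses a single `GetValuesResult` object per group and flushes the final accumulated group when the iterator is exhausted or the upper bound is crossed.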

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala

Lines changed: 51 additions & 0 deletions

```diff
@@ -1009,4 +1009,55 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite
       }
     }
   }
+
+  // NOTE: In practice, the predicate should contain the condition matching timestampRange.
+  // Here we intentionally use a pass-all predicate to test timestampRange filtering directly.
+  private def getJoinedRowTimestamps(
+      key: Int,
+      range: Option[(Long, Long)])(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
+    val dummyRow = new GenericInternalRow(0)
+    manager.getJoinedRows(
+      toJoinKeyRow(key),
+      row => new JoinedRow(row, dummyRow),
+      _ => true,
+      timestampRange = range
+    ).map(_.getInt(1)).toSeq.sorted
+  }
+
+  test("StreamingJoinStateManager V4 - getJoinedRows with timestampRange") {
+    withJoinStateManager(
+      inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
+      implicit val mgr = manager
+
+      Seq(10, 20, 30, 40, 50).foreach(append(40, _))
+
+      assert(getJoinedRowTimestamps(40, Some((20L, 40L))) === Seq(20, 30, 40))
+      assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20))
+      assert(getJoinedRowTimestamps(40, Some((25L, 35L))) === Seq(30))
+      assert(getJoinedRowTimestamps(40, Some((0L, 100L))) === Seq(10, 20, 30, 40, 50))
+      assert(getJoinedRowTimestamps(40, Some((10L, 30L))) === Seq(10, 20, 30))
+      assert(getJoinedRowTimestamps(40, Some((50L, 100L))) === Seq(50))
+      assert(getJoinedRowTimestamps(40, Some((60L, 100L))) === Seq.empty)
+      assert(getJoinedRowTimestamps(40, Some((0L, 5L))) === Seq.empty)
+      assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50))
+    }
+  }
+
+  test("StreamingJoinStateManager V4 - timestampRange with multiple values per timestamp") {
+    withJoinStateManager(
+      inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
+      implicit val mgr = manager
+
+      append(40, 10)
+      append(40, 10)
+      append(40, 20)
+      append(40, 20)
+      append(40, 20)
+      append(40, 30)
+
+      assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20, 20, 20))
+      assert(getJoinedRowTimestamps(40, Some((10L, 20L))) === Seq(10, 10, 20, 20, 20))
+      assert(getJoinedRowTimestamps(40, Some((10L, 30L))) === Seq(10, 10, 20, 20, 20, 30))
+    }
+  }
 }
```
