@@ -34,8 +34,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateStoreColumnFamilySchemaUtils
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{SCHEMA_ID_PREFIX_BYTES, STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
@@ -846,6 +847,8 @@ class AvroStateEncoder(
}
}
StructType(remainingSchema)
case _ =>
throw unsupportedOperationForKeyStateEncoder("createAvroEnc")
Contributor
Can we improve the passed arg/error message?

Contributor Author
def unsupportedOperationForKeyStateEncoder(
      operation: String
  ): UnsupportedOperationException = {
    new UnsupportedOperationException(
      s"Method $operation not supported for encoder spec type " +
        s"${keyStateEncoderSpec.getClass.getSimpleName}")
  }

I feel like this should be sufficient as long as this isn't a user-facing error?

}

// Handle suffix key schema for prefix scan case
@@ -1713,6 +1716,206 @@ class NoPrefixKeyStateEncoder(
}
}

/**
* Singleton object providing utility methods for key state encoders that include a timestamp,
* specifically [[TimestampAsPrefixKeyStateEncoder]] and [[TimestampAsPostfixKeyStateEncoder]].
*/
object TimestampKeyStateEncoder {
Contributor
Let's add some high-level comments here?

  private val INTERNAL_TIMESTAMP_COLUMN_NAME = "__event_time"

  def keySchemaWithTimestamp(keySchema: StructType): StructType = {
    StructType(keySchema.fields)
      .add(name = INTERNAL_TIMESTAMP_COLUMN_NAME, dataType = LongType, nullable = false)
  }

  def getAttachTimestampProjection(keyWithoutTimestampSchema: StructType): UnsafeProjection = {
    val refs = keyWithoutTimestampSchema.zipWithIndex.map(x =>
      BoundReference(x._2, x._1.dataType, x._1.nullable))
    UnsafeProjection.create(
      refs :+ Literal(0L), // placeholder for timestamp
      DataTypeUtils.toAttributes(StructType(keyWithoutTimestampSchema)))
  }

  def getDetachTimestampProjection(keyWithTimestampSchema: StructType): UnsafeProjection = {
    val refs = keyWithTimestampSchema.zipWithIndex.dropRight(1).map(x =>
      BoundReference(x._2, x._1.dataType, x._1.nullable))
    UnsafeProjection.create(refs)
  }

  def attachTimestamp(
      attachTimestampProjection: UnsafeProjection,
      keyWithTimestampSchema: StructType,
      key: UnsafeRow,
      timestamp: Long): UnsafeRow = {
    val rowWithTimestamp = attachTimestampProjection(key)
    rowWithTimestamp.setLong(keyWithTimestampSchema.length - 1, timestamp)
    rowWithTimestamp
  }

  def extractTimestamp(key: UnsafeRow): Long = {
    key.getLong(key.numFields - 1)
  }
}
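
For illustration, a minimal standalone sketch (outside this patch) of what keySchemaWithTimestamp produces; it assumes only the public StructType API, and the "groupKey" column name is hypothetical:

import org.apache.spark.sql.types.{LongType, StringType, StructType}

object KeySchemaWithTimestampExample {
  def main(args: Array[String]): Unit = {
    // A hypothetical grouping-key schema with a single string column.
    val userKeySchema = new StructType().add("groupKey", StringType, nullable = false)

    // Mirrors TimestampKeyStateEncoder.keySchemaWithTimestamp: the internal timestamp
    // column is appended as the last, non-nullable long field.
    val keySchema = StructType(userKeySchema.fields)
      .add("__event_time", LongType, nullable = false)

    println(keySchema.treeString)
    // root
    //  |-- groupKey: string (nullable = false)
    //  |-- __event_time: long (nullable = false)
  }
}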

/**
* Abstract base class for key state encoders that include a timestamp, specifically
* [[TimestampAsPrefixKeyStateEncoder]] and [[TimestampAsPostfixKeyStateEncoder]].
*/
abstract class TimestampKeyStateEncoder(
    dataEncoder: RocksDBDataEncoder,
    keySchema: StructType)
  extends RocksDBKeyStateEncoder with Logging {

  protected val detachTimestampProjection: UnsafeProjection =
    TimestampKeyStateEncoder.getDetachTimestampProjection(keySchema)

  protected val attachTimestampProjection: UnsafeProjection =
    TimestampKeyStateEncoder.getAttachTimestampProjection(
      StructType(keySchema.fields.dropRight(1)))

  protected def decodeKey(keyBytes: Array[Byte], startPos: Int): UnsafeRow = {
    val rowBytesLength = keyBytes.length - 8
    val rowBytes = new Array[Byte](rowBytesLength)
    Platform.copyMemory(
      keyBytes, Platform.BYTE_ARRAY_OFFSET + startPos,
      rowBytes, Platform.BYTE_ARRAY_OFFSET,
      rowBytesLength
    )
    // The encoded row does not include the timestamp (it's stored separately),
    // so decode with keySchema.length - 1 fields.
    dataEncoder.decodeToUnsafeRow(rowBytes, keySchema.length - 1)
  }

  // NOTE: We reuse the ByteBuffer to avoid allocating a new one for every encoding/decoding,
  // which means the encoder is not thread-safe. Built-in operators do not access the encoder
  // from multiple threads, but if thread-safety becomes a concern in the future, we can keep
  // the ByteBuffer in a thread-local to retain reusability while avoiding thread-safety issues.
  // We do not use the buffer's position - we always put/get at offset 0.
  private val buffForBigEndianLong = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)

  private val SIGN_MASK_FOR_LONG: Long = 0x8000000000000000L

  protected def encodeTimestamp(timestamp: Long): Array[Byte] = {
    // Flip the sign bit to ensure correct lexicographical ordering, even for negative timestamps.
    // We flip the sign bit back when decoding the timestamp.
    val signFlippedTimestamp = timestamp ^ SIGN_MASK_FOR_LONG
    buffForBigEndianLong.putLong(0, signFlippedTimestamp)
    buffForBigEndianLong.array()
  }

  protected def decodeTimestamp(keyBytes: Array[Byte], startPos: Int): Long = {
    buffForBigEndianLong.put(0, keyBytes, startPos, 8)
    val signFlippedTimestamp = buffForBigEndianLong.getLong(0)
    // Flip the sign bit back to get the original timestamp.
    signFlippedTimestamp ^ SIGN_MASK_FOR_LONG
  }

  protected def attachTimestamp(key: UnsafeRow, timestamp: Long): UnsafeRow = {
    TimestampKeyStateEncoder.attachTimestamp(attachTimestampProjection, keySchema, key, timestamp)
  }

  protected def detachTimestamp(key: UnsafeRow): UnsafeRow = {
    detachTimestampProjection(key)
  }

  def extractTimestamp(key: UnsafeRow): Long = {
    TimestampKeyStateEncoder.extractTimestamp(key)
  }
}
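
For illustration, a minimal standalone sketch (outside this patch, plain JDK/Scala only) showing why the sign-bit flip plus big-endian encoding makes unsigned byte-wise ordering match numeric timestamp ordering, including negative values:

import java.nio.{ByteBuffer, ByteOrder}
import java.util.Arrays

object SignFlipOrderingDemo {
  private val SignMask = 0x8000000000000000L

  // Same scheme as encodeTimestamp above: flip the sign bit, then write big-endian,
  // so that unsigned byte-wise comparison matches numeric comparison.
  def encode(ts: Long): Array[Byte] =
    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putLong(ts ^ SignMask).array()

  def main(args: Array[String]): Unit = {
    val timestamps = Seq(-1000L, -1L, 0L, 1L, 1000L)
    // Sort by the encoded bytes, the way RocksDB's default comparator would.
    val sortedByBytes =
      timestamps.sortWith((l, r) => Arrays.compareUnsigned(encode(l), encode(r)) < 0)
    // The byte-wise order matches the numeric order, even across negative timestamps.
    assert(sortedByBytes == timestamps.sorted)
    println(sortedByBytes) // List(-1000, -1, 0, 1, 1000)
  }
}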

/**
* Encodes a row with the timestamp as the prefix of the key, so that keys can be scanned in
* timestamp order.
*
* The encoder expects the provided key schema to have [original key fields..., timestamp field].
* The key has to conform to this schema when putting/getting from the state store. The schema
* needs to be built via calling [[TimestampKeyStateEncoder.keySchemaWithTimestamp()]].
*/
class TimestampAsPrefixKeyStateEncoder(
    dataEncoder: RocksDBDataEncoder,
    keySchema: StructType,
    useColumnFamilies: Boolean = false)
  extends TimestampKeyStateEncoder(dataEncoder, keySchema) with Logging {

  override def supportPrefixKeyScan: Boolean = false

  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
    throw new IllegalStateException("This encoder doesn't support key without event time!")
  }

  override def encodeKey(row: UnsafeRow): Array[Byte] = {
    val prefix = dataEncoder.encodeKey(detachTimestamp(row))
    val timestamp = extractTimestamp(row)

    val byteArray = new Array[Byte](prefix.length + 8)
    Platform.copyMemory(
      encodeTimestamp(timestamp), Platform.BYTE_ARRAY_OFFSET,
      byteArray, Platform.BYTE_ARRAY_OFFSET, 8)
    Platform.copyMemory(prefix, Platform.BYTE_ARRAY_OFFSET,
      byteArray, Platform.BYTE_ARRAY_OFFSET + 8, prefix.length)

    byteArray
  }

  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
    val timestamp = decodeTimestamp(keyBytes, 0)
    val row = decodeKey(keyBytes, 8)
    attachTimestamp(row, timestamp)
  }

  // TODO: [SPARK-55491] Revisit this to support delete range if needed.
  override def supportsDeleteRange: Boolean = false
}
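
For illustration, a minimal standalone sketch (outside this patch) of the [timestamp][encoded key] layout above; UTF-8 bytes stand in for dataEncoder.encodeKey, and sorting the encoded bytes orders entries by timestamp across different grouping keys:

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
import java.util.Arrays

object PrefixLayoutDemo {
  private val SignMask = 0x8000000000000000L

  // Stand-in for dataEncoder.encodeKey: here just the UTF-8 bytes of the grouping key.
  def encodeGroupKey(key: String): Array[Byte] = key.getBytes(StandardCharsets.UTF_8)

  // Same timestamp encoding as TimestampKeyStateEncoder: sign-flipped, big-endian.
  def encodeTimestamp(ts: Long): Array[Byte] =
    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putLong(ts ^ SignMask).array()

  // Prefix layout: [timestamp][encoded grouping key].
  def encodeKey(key: String, ts: Long): Array[Byte] = encodeTimestamp(ts) ++ encodeGroupKey(key)

  def main(args: Array[String]): Unit = {
    val entries = Seq(("a", 30L), ("b", 5L), ("c", -7L))
    // Sorting by the encoded bytes (as RocksDB does) orders entries by timestamp,
    // regardless of the grouping key.
    val sorted = entries.sortWith((l, r) =>
      Arrays.compareUnsigned(encodeKey(l._1, l._2), encodeKey(r._1, r._2)) < 0)
    println(sorted) // List((c,-7), (b,5), (a,30))
  }
}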

/**
* Encodes a row with the timestamp as the postfix of the key, so that a prefix scan over entries
* with the same grouping key but different timestamps is supported. In addition, the timestamp is
* stored in sort order, so the results of a prefix scan can be iterated in timestamp order.
*
* The encoder expects the provided key schema to have [original key fields..., timestamp field].
* The key has to conform to this schema when putting/getting from the state store. The schema
* needs to be built by calling [[TimestampKeyStateEncoder.keySchemaWithTimestamp()]].
*/
class TimestampAsPostfixKeyStateEncoder(
    dataEncoder: RocksDBDataEncoder,
    keySchema: StructType,
    useColumnFamilies: Boolean = false)
  extends TimestampKeyStateEncoder(dataEncoder, keySchema) with Logging {

  override def supportPrefixKeyScan: Boolean = true

  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
    dataEncoder.encodeKey(prefixKey)
  }

  override def encodeKey(row: UnsafeRow): Array[Byte] = {
    val prefix = dataEncoder.encodeKey(detachTimestamp(row))
    val timestamp = extractTimestamp(row)

    val byteArray = new Array[Byte](prefix.length + 8)

    Platform.copyMemory(prefix, Platform.BYTE_ARRAY_OFFSET,
      byteArray, Platform.BYTE_ARRAY_OFFSET, prefix.length)
    Platform.copyMemory(
      encodeTimestamp(timestamp), Platform.BYTE_ARRAY_OFFSET,
      byteArray, Platform.BYTE_ARRAY_OFFSET + prefix.length,
      8
    )

    byteArray
  }

  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
    val row = decodeKey(keyBytes, 0)
    val rowBytesLength = keyBytes.length - 8
    val timestamp = decodeTimestamp(keyBytes, rowBytesLength)
    attachTimestamp(row, timestamp)
  }

  override def supportsDeleteRange: Boolean = false
}
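
For illustration, a minimal standalone sketch (outside this patch) of the [encoded key][timestamp] layout above; a TreeMap with unsigned byte ordering stands in for RocksDB's sorted keyspace and UTF-8 bytes stand in for dataEncoder.encodeKey, showing a prefix scan returning entries for one grouping key in timestamp order:

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.charset.StandardCharsets
import java.util.Arrays
import scala.collection.immutable.TreeMap

object PostfixPrefixScanDemo {
  private val SignMask = 0x8000000000000000L

  // Stand-in for dataEncoder.encodeKey: here just the UTF-8 bytes of the grouping key.
  def encodeGroupKey(key: String): Array[Byte] = key.getBytes(StandardCharsets.UTF_8)

  def encodeTimestamp(ts: Long): Array[Byte] =
    ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN).putLong(ts ^ SignMask).array()

  // Postfix layout: [encoded grouping key][timestamp], as in TimestampAsPostfixKeyStateEncoder.
  def encodeKey(key: String, ts: Long): Array[Byte] = encodeGroupKey(key) ++ encodeTimestamp(ts)

  // Unsigned lexicographic byte ordering, like RocksDB's default comparator.
  implicit val bytesOrdering: Ordering[Array[Byte]] =
    (a: Array[Byte], b: Array[Byte]) => Arrays.compareUnsigned(a, b)

  def main(args: Array[String]): Unit = {
    // A sorted map stands in for RocksDB's ordered keyspace.
    val store = TreeMap(
      encodeKey("a", 30L) -> "v1",
      encodeKey("a", -10L) -> "v2",
      encodeKey("b", 5L) -> "v3")

    // Prefix scan on grouping key "a": matching entries share the byte prefix and come back
    // ordered by timestamp (-10 before 30). Note the UTF-8 stand-in is not prefix-free, so a
    // grouping key like "ab" would also match; the real key encoding avoids this.
    val prefix = encodeGroupKey("a")
    val timestamps = store.keys
      .filter(k => k.length >= prefix.length &&
        Arrays.equals(k, 0, prefix.length, prefix, 0, prefix.length))
      .map(k => ByteBuffer.wrap(k, prefix.length, 8).getLong ^ SignMask)
    println(timestamps.toList) // List(-10, 30)
  }
}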

/**
* Supports encoding multiple values per key in RocksDB.
* A single value is encoded in the format below, where first value is number of bytes
@@ -638,6 +638,36 @@ case class RangeKeyScanStateEncoderSpec(
}
}

/** The encoder specification for [[TimestampAsPrefixKeyStateEncoder]]. */
case class TimestampAsPrefixKeyStateEncoderSpec(keySchema: StructType)
  extends KeyStateEncoderSpec {

  override def toEncoder(
      dataEncoder: RocksDBDataEncoder,
      useColumnFamilies: Boolean): RocksDBKeyStateEncoder = {
    new TimestampAsPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies)
  }

  override def jsonValue: JValue = {
    "keyStateEncoderType" -> JString("TimestampAsPrefixKeyStateEncoderSpec")
  }
}

/** The encoder specification for [[TimestampAsPostfixKeyStateEncoder]]. */
case class TimestampAsPostfixKeyStateEncoderSpec(keySchema: StructType)
  extends KeyStateEncoderSpec {

  override def toEncoder(
      dataEncoder: RocksDBDataEncoder,
      useColumnFamilies: Boolean): RocksDBKeyStateEncoder = {
    new TimestampAsPostfixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies)
  }

  override def jsonValue: JValue = {
    "keyStateEncoderType" -> JString("TimestampAsPostfixKeyStateEncoderSpec")
  }
}

/**
* Trait representing a provider that provides [[StateStore]] instances representing
* versions of state data.
@@ -1081,7 +1111,6 @@ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
}
}


/**
* Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
* by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),