[SPARK-50654][SS] CommitMetadata should set stateUniqueIds to None in V1 #49278

Open · wants to merge 3 commits into base: master
CommitLog.scala:
@@ -51,6 +51,8 @@ class CommitLog(sparkSession: SparkSession, path: String)

import CommitLog._

+ private val VERSION = SQLConf.get.stateStoreCheckpointFormatVersion

override protected[sql] def deserialize(in: InputStream): CommitMetadata = {
// called inside a try-finally where the underlying stream is closed in the caller
val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
@@ -74,7 +76,6 @@ class CommitLog(sparkSession: SparkSession, path: String)
}

object CommitLog {
- private val VERSION = SQLConf.get.stateStoreCheckpointFormatVersion
private val EMPTY_JSON = "{}"
}

@@ -104,7 +105,7 @@ object CommitLog {

case class CommitMetadata(
nextBatchWatermarkMs: Long = 0,
- stateUniqueIds: Map[Long, Array[Array[String]]] = Map.empty) {
+ stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None) {
def json: String = Serialization.write(this)(CommitMetadata.format)
}

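A side note on what the Option change means on disk: CommitMetadata is serialized with json4s (Serialization.write), which by default omits fields whose value is None, so a V1 commit log entry should no longer carry a stateUniqueIds key at all, matching the updated test resources further down. A minimal standalone sketch of that behavior; CommitMetadataSketch is a stand-in name, not Spark code:

  import org.json4s.{Formats, NoTypeHints}
  import org.json4s.jackson.Serialization

  object CommitMetadataJsonSketch {
    // Stand-in for CommitMetadata: under V1 the Option field stays None.
    case class CommitMetadataSketch(
        nextBatchWatermarkMs: Long = 0,
        stateUniqueIds: Option[Map[Long, Array[Array[String]]]] = None)

    implicit val format: Formats = Serialization.formats(NoTypeHints)

    def main(args: Array[String]): Unit = {
      // None is dropped from the JSON output: {"nextBatchWatermarkMs":1}
      println(Serialization.write(CommitMetadataSketch(1)))
      // Some(...) keeps the map, as in the V2 test resources below.
      println(Serialization.write(CommitMetadataSketch(
        0, Some(Map(0L -> Array(Array("unique_id1", "unique_id2")))))))
    }
  }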
IncrementalExecution.scala:
@@ -64,7 +64,8 @@ class IncrementalExecution(
val watermarkPropagator: WatermarkPropagator,
val isFirstBatch: Boolean,
val currentStateStoreCkptId:
- MutableMap[Long, Array[Array[String]]] = MutableMap[Long, Array[Array[String]]]())
+ Option[MutableMap[Long, Array[Array[String]]]] =
+ Option(MutableMap[Long, Array[Array[String]]]()))
extends QueryExecution(sparkSession, logicalPlan) with Logging {

// Modified planner with stateful operations.
@@ -142,7 +143,7 @@
operatorId,
currentBatchId,
numStateStores,
- currentStateStoreCkptId.get(operatorId))
+ currentStateStoreCkptId.flatMap(_.get(operatorId)))
ret
}

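For orientation (not part of the PR itself): currentStateStoreCkptId is now an Option wrapped around the mutable map, so the per-operator lookup above composes the two layers with flatMap instead of calling get on the map directly. A small sketch of that access pattern with made-up IDs:

  import scala.collection.mutable.{Map => MutableMap}

  object CkptIdLookupSketch {
    // operatorId -> (partitionId-indexed array of state store unique IDs)
    val ckptIds: Option[MutableMap[Long, Array[Array[String]]]] =
      Some(MutableMap(0L -> Array(Array("unique_id1", "unique_id2"))))

    // Option.flatMap collapses "no map at all" and "no entry for this operator" into None.
    val forOperator0: Option[Array[Array[String]]] = ckptIds.flatMap(_.get(0L))   // Some(...)
    val forOperator9: Option[Array[Array[String]]] = ckptIds.flatMap(_.get(9L))   // None

    val disabled: Option[MutableMap[Long, Array[Array[String]]]] = None
    val whenDisabled: Option[Array[Array[String]]] = disabled.flatMap(_.get(0L))  // None
  }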
MicroBatchExecution.scala:
@@ -133,7 +133,7 @@ class MicroBatchExecution(
// Store checkpointIDs for state store checkpoints to be committed or have been committed to
// the commit log.
// operatorID -> (partitionID -> array of uniqueID)
- private val currentStateStoreCkptId = MutableMap[Long, Array[Array[String]]]()
+ private var currentStateStoreCkptId: Option[MutableMap[Long, Array[Array[String]]]] = None

override lazy val logicalPlan: LogicalPlan = {
assert(queryExecutionThread eq Thread.currentThread,
@@ -513,7 +513,14 @@
execCtx.startOffsets ++= execCtx.endOffsets
watermarkTracker.setWatermark(
math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs))
- currentStateStoreCkptId ++= commitMetadata.stateUniqueIds
+ if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(
+ sparkSessionForStream.sessionState.conf)) {
+ if (currentStateStoreCkptId.isEmpty) {
+ currentStateStoreCkptId = Some(MutableMap[Long, Array[Array[String]]]())
+ }
+ currentStateStoreCkptId.get.addAll(
+ commitMetadata.stateUniqueIds.getOrElse(Map.empty))
+ }
} else if (latestCommittedBatchId == latestBatchId - 1) {
execCtx.endOffsets.foreach {
case (source: Source, end: Offset) =>
@@ -920,6 +927,8 @@
execCtx: MicroBatchExecutionContext,
opId: Long,
checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+ assert(StatefulOperatorStateInfo.enableStateStoreCheckpointIds(
+ sparkSessionForStream.sessionState.conf))
// TODO validate baseStateStoreCkptId
checkpointInfo.map(_.batchVersion).foreach { v =>
assert(
@@ -931,7 +940,10 @@
assert(info.stateStoreCkptId.isDefined)
info.stateStoreCkptId.get
}
- currentStateStoreCkptId.put(opId, ckptIds)
+ if (currentStateStoreCkptId.isEmpty) {
+ currentStateStoreCkptId = Some(MutableMap[Long, Array[Array[String]]]())
+ }
+ currentStateStoreCkptId.get.put(opId, ckptIds)
}

/**
@@ -967,7 +979,7 @@
}
execCtx.reportTimeTaken("commitOffsets") {
if (!commitLog.add(execCtx.batchId,
- CommitMetadata(watermarkTracker.currentWatermark, currentStateStoreCkptId.toMap))) {
+ CommitMetadata(watermarkTracker.currentWatermark, currentStateStoreCkptId.map(_.toMap)))) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
}
}
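Again as orientation rather than PR content: within MicroBatchExecution the map is only materialized once checkpoint IDs are actually recorded or restored with the V2 format enabled, and the commit path converts it back to an immutable map, so a V1 run writes stateUniqueIds = None. A rough sketch of that lifecycle, with the surrounding method names omitted:

  import scala.collection.mutable.{Map => MutableMap}

  object CkptIdLifecycleSketch {
    var currentStateStoreCkptId: Option[MutableMap[Long, Array[Array[String]]]] = None

    // Mirrors the update path above: create the map lazily, then record per-operator IDs.
    def record(opId: Long, ckptIds: Array[Array[String]]): Unit = {
      if (currentStateStoreCkptId.isEmpty) {
        currentStateStoreCkptId = Some(MutableMap[Long, Array[Array[String]]]())
      }
      currentStateStoreCkptId.get.put(opId, ckptIds)
    }

    // Mirrors the commitLog.add call: None under V1, Some(immutable snapshot) under V2.
    def toCommitValue: Option[Map[Long, Array[Array[String]]]] =
      currentStateStoreCkptId.map(_.toMap)
  }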
Test resource (likely structured-streaming/testCommitLogV1):
@@ -1,2 +1,2 @@
v1
{"nextBatchWatermarkMs":1,"stateUniqueIds":{}}
{"nextBatchWatermarkMs":1}
New test resource (likely structured-streaming/testCommitLogV2-empty-unique-id):
@@ -0,0 +1,2 @@
+ v2
+ {"nextBatchWatermarkMs":0,"stateUniqueIds":{}}
Test resource (likely structured-streaming/testCommitLogV2):
@@ -1,2 +1,2 @@
- v1
+ v2
{"nextBatchWatermarkMs":0,"stateUniqueIds":{"0":[["unique_id1","unique_id2"],["unique_id3","unique_id4"]],"1":[["unique_id5","unique_id6"],["unique_id7","unique_id8"]]}}
CommitLogSuite.scala:
@@ -17,18 +17,12 @@

package org.apache.spark.sql.streaming

- import java.io.{ByteArrayInputStream, FileInputStream, FileOutputStream, InputStream, OutputStream}
- import java.nio.charset.StandardCharsets.UTF_8
+ import java.io.{ByteArrayInputStream, FileInputStream, FileOutputStream}
import java.nio.file.Path

- import scala.io.{Source => IOSource}

- import org.json4s.{Formats, NoTypeHints}
- import org.json4s.jackson.Serialization

import org.apache.spark.SparkFunSuite
- import org.apache.spark.sql.SparkSession
- import org.apache.spark.sql.execution.streaming.{CommitLog, CommitMetadata, HDFSMetadataLog}
+ import org.apache.spark.sql.execution.streaming.{CommitLog, CommitMetadata}
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class CommitLogSuite extends SparkFunSuite with SharedSparkSession {
@@ -45,6 +39,18 @@ class CommitLogSuite extends SparkFunSuite with SharedSparkSession {
)
}

+ private def testCommitLogV2FilePathEmptyUniqueId: Path = {
+ getWorkspaceFilePath(
+ "sql",
+ "core",
+ "src",
+ "test",
+ "resources",
+ "structured-streaming",
+ "testCommitLogV2-empty-unique-id"
+ )
+ }

private def testCommitLogV1FilePath: Path = {
getWorkspaceFilePath(
"sql",
@@ -68,30 +74,45 @@ class CommitLogSuite extends SparkFunSuite with SharedSparkSession {
val metadata = commitLog.deserialize(inputStream)
// Array comparison are reference based, so we need to compare the elements
assert(metadata.nextBatchWatermarkMs == commitMetadata.nextBatchWatermarkMs)
- assert(metadata.stateUniqueIds.size == commitMetadata.stateUniqueIds.size)
- commitMetadata.stateUniqueIds.foreach { case (operatorId, uniqueIds) =>
- assert(metadata.stateUniqueIds.contains(operatorId))
- assert(metadata.stateUniqueIds(operatorId).length == uniqueIds.length)
- assert(metadata.stateUniqueIds(operatorId).zip(uniqueIds).forall {
- case (a, b) => a.sameElements(b)
- })
+ if (metadata.stateUniqueIds.isEmpty) {
+ assert(commitMetadata.stateUniqueIds.isEmpty)
+ } else {
+ assert(metadata.stateUniqueIds.get.size == commitMetadata.stateUniqueIds.get.size)
+ commitMetadata.stateUniqueIds.get.foreach { case (operatorId, uniqueIds) =>
+ assert(metadata.stateUniqueIds.get.contains(operatorId))
+ assert(metadata.stateUniqueIds.get(operatorId).length == uniqueIds.length)
+ assert(metadata.stateUniqueIds.get(operatorId).zip(uniqueIds).forall {
+ case (a, b) => a.sameElements(b)
+ })
+ }
+ }
}
}

test("Basic Commit Log V1 SerDe") {
- val testMetadataV1 = CommitMetadata(1)
- testSerde(testMetadataV1, testCommitLogV1FilePath)
+ withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1") {
+ val testMetadataV1 = CommitMetadata(1)
+ testSerde(testMetadataV1, testCommitLogV1FilePath)
+ }
}

test("Basic Commit Log V2 SerDe") {
val testStateUniqueIds: Map[Long, Array[Array[String]]] =
Map(
0L -> Array(Array("unique_id1", "unique_id2"), Array("unique_id3", "unique_id4")),
1L -> Array(Array("unique_id5", "unique_id6"), Array("unique_id7", "unique_id8"))
)
val testMetadataV2 = CommitMetadata(0, testStateUniqueIds)
testSerde(testMetadataV2, testCommitLogV2FilePath)
test("Basic Commit Log V2 SerDe - nonempty stateUniqueIds") {
withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
val testStateUniqueIds: Map[Long, Array[Array[String]]] =
Map(
0L -> Array(Array("unique_id1", "unique_id2"), Array("unique_id3", "unique_id4")),
1L -> Array(Array("unique_id5", "unique_id6"), Array("unique_id7", "unique_id8"))
)
val testMetadataV2 = CommitMetadata(0, Some(testStateUniqueIds))
testSerde(testMetadataV2, testCommitLogV2FilePath)
}
}

test("Basic Commit Log V2 SerDe - empty stateUniqueIds") {
withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
val testMetadataV2 = CommitMetadata(0, Some(Map[Long, Array[Array[String]]]()))
testSerde(testMetadataV2, testCommitLogV2FilePathEmptyUniqueId)
}
}

// Old metadata structure with no state unique ids should not affect the deserialization
@@ -103,59 +124,6 @@ class CommitLogSuite extends SparkFunSuite with SharedSparkSession {
val commitMetadata: CommitMetadata = new CommitLog(
spark, testCommitLogV1FilePath.toString).deserialize(inputStream)
assert(commitMetadata.nextBatchWatermarkMs === 233)
- assert(commitMetadata.stateUniqueIds === Map.empty)
- }

- // Test an old version of Spark can ser-de the new version of commit log,
- // but running under V1 (i.e. no stateUniqueIds)
- test("v1 Serde backward compatibility") {
- // This is the json created by a V1 commit log
- val commitLogV1WithStateUniqueId = """v1
- |{"nextBatchWatermarkMs":1,"stateUniqueIds":{}}""".stripMargin
- val inputStream: ByteArrayInputStream =
- new ByteArrayInputStream(commitLogV1WithStateUniqueId.getBytes("UTF-8"))
- val commitMetadata: CommitMetadataLegacy = new CommitLogLegacy(
- spark, testCommitLogV1FilePath.toString).deserialize(inputStream)
- assert(commitMetadata.nextBatchWatermarkMs === 1)
- }
- }

- // DO-NOT-MODIFY-THE-CODE-BELOW
- // Below are the legacy commit log code carbon copied from Spark branch-3.5, except
- // adding a "Legacy" to the class names.
- case class CommitMetadataLegacy(nextBatchWatermarkMs: Long = 0) {
- def json: String = Serialization.write(this)(CommitMetadataLegacy.format)
- }

- object CommitMetadataLegacy {
- implicit val format: Formats = Serialization.formats(NoTypeHints)

- def apply(json: String): CommitMetadataLegacy = Serialization.read[CommitMetadataLegacy](json)
- }

- class CommitLogLegacy(sparkSession: SparkSession, path: String)
- extends HDFSMetadataLog[CommitMetadataLegacy](sparkSession, path) {

- private val VERSION = 1
- private val EMPTY_JSON = "{}"

- override def deserialize(in: InputStream): CommitMetadataLegacy = {
- // called inside a try-finally where the underlying stream is closed in the caller
- val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines()
- if (!lines.hasNext) {
- throw new IllegalStateException("Incomplete log file in the offset commit log")
- }
- validateVersion(lines.next().trim, VERSION)
- val metadataJson = if (lines.hasNext) lines.next() else EMPTY_JSON
- CommitMetadataLegacy(metadataJson)
- }

- override def serialize(metadata: CommitMetadataLegacy, out: OutputStream): Unit = {
- // called inside a try-finally where the underlying stream is closed in the caller
- out.write(s"v${VERSION}".getBytes(UTF_8))
- out.write('\n')

- // write metadata
- out.write(metadata.json.getBytes(UTF_8))
+ assert(commitMetadata.stateUniqueIds === None)
}
}