Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ spark/benchmarks
comet-event-trace.json
__pycache__
output
.claude/
141 changes: 140 additions & 1 deletion spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
Expand All @@ -56,6 +58,14 @@ import org.apache.comet.serde.operator._

object CometExecRule {

/**
* Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because
* the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have
* incompatible intermediate buffer formats between Spark and Comet.
*
* The tag's String value is a human-readable fallback reason: `CometBaseAggregate.doConvert`
* reads it with `getTagValue` and reports it through `withInfo` before declining conversion.
*/
val COMET_UNSAFE_PARTIAL: TreeNodeTag[String] =
TreeNodeTag[String]("comet.unsafePartialAgg")

/**
* Fully native operators.
*/
Expand Down Expand Up @@ -388,6 +398,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
normalizedPlan
}

// Tag Partial aggregates that must not be converted to Comet because the
// corresponding Final aggregate cannot be converted and the intermediate buffer
// formats are incompatible. This runs before transform() so the tags are checked
// during the bottom-up conversion. Tags persist through AQE stage creation.
tagUnsafePartialAggregates(planWithJoinRewritten)

var newPlan = transform(planWithJoinRewritten)

// if the plan cannot be run fully natively then explain why (when appropriate
Expand Down Expand Up @@ -601,4 +617,127 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}

/**
 * Walk the plan looking for Final-mode aggregates that cannot be converted to Comet. When
 * such a Final is found and its aggregate functions use intermediate buffer formats that
 * differ between Spark and Comet, the matching Partial-mode aggregate is tagged so that
 * conversion skips it as well.
 *
 * This guards against the crash described in issue #1389, where a Comet Partial emits
 * intermediate data in a format the Spark Final cannot interpret.
 */
private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = {
  plan.foreach {
    // Restrict to single-mode Final aggregates. Multi-mode Finals arise from Spark's
    // distinct-aggregate rewrite, where any Comet partial feeds a Spark PartialMerge
    // rather than a Final directly -- a different code path than the
    // Comet-Partial -> Spark-Final crash scenario from issue #1389.
    case finalAgg: BaseAggregateExec
        if finalAgg.aggregateExpressions.map(_.mode).distinct == Seq(Final) &&
          !QueryPlanSerde.allAggsSupportMixedExecution(finalAgg.aggregateExpressions) &&
          !canAggregateBeConverted(finalAgg, Final) =>
      for {
        partial <- findPartialAggInPlan(finalAgg.child)
        // Tag only Partials that would otherwise have been converted. If the Partial
        // itself cannot convert (e.g. the aggregate function is incompatible for the
        // input type), there is no buffer-format mismatch to guard against, and tagging
        // would mask the natural, more specific fallback reason.
        if canAggregateBeConverted(partial, Partial)
      } {
        partial.setTagValue(
          CometExecRule.COMET_UNSAFE_PARTIAL,
          "Partial aggregate disabled: corresponding final aggregate " +
            "cannot be converted to Comet and intermediate buffer formats are incompatible")
      }
    case _ =>
  }
}

/**
 * Conservative check for whether an aggregate could be converted to Comet. Checks operator
 * enablement, grouping expressions, aggregate expressions, and result expressions.
 * Intentionally skips the sparkFinalMode / child-native checks since those depend on
 * transformation state.
 *
 * WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert`
 * (operators.scala). Any change to the convertibility rules there must be reflected here or
 * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
 * shared predicate helper would be preferable.
 *
 * @param agg the aggregate operator to probe (not modified)
 * @param expectedMode the single AggregateMode (Partial or Final) the operator must have
 * @return true if every mirrored doConvert predicate passes for this operator
 */
private def canAggregateBeConverted(
agg: BaseAggregateExec,
expectedMode: AggregateMode): Boolean = {
// No registered serde for this operator class means it can never be converted.
val handler = allExecs.get(agg.getClass)
if (handler.isEmpty) return false
val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false

// ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
agg match {
case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false
case _ =>
}

val aggregateExpressions = agg.aggregateExpressions
val groupingExpressions = agg.groupingExpressions

// An aggregate with neither grouping nor aggregate expressions is not convertible.
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false

// Grouping on any type that transitively contains a map is unsupported natively.
if (groupingExpressions.exists(e => QueryPlanSerde.containsMapType(e.dataType))) return false

// Every grouping expression must serialize against the child's output attributes.
if (!groupingExpressions.forall(e =>
QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) {
return false
}

if (aggregateExpressions.isEmpty) {
// Result expressions always checked when there are no aggregate expressions
val attributes =
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
return agg.resultExpressions.forall(e =>
QueryPlanSerde.exprToProto(e, attributes).isDefined)
}

// Mixed-mode aggregates are out of scope here: require exactly the expected single mode.
val modes = aggregateExpressions.map(_.mode).distinct
if (modes.size != 1 || modes.head != expectedMode) return false

// In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
// it must bind to input attributes. This mirrors the `binding` calculation in
// `CometBaseAggregate.doConvert`.
val binding = expectedMode != Final
if (!aggregateExpressions.forall(e =>
QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
return false
}

// doConvert only checks resultExpressions in Final mode when aggregate expressions exist
// (Partial emits the buffer directly). Mirror that here to avoid false negatives.
if (expectedMode == Final) {
val attributes =
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
} else {
true
}
}

/**
 * Locate a Partial-mode aggregate that feeds directly into the given plan (the child of a
 * Final). The search descends only through exchanges and AQE stages, stopping at anything
 * else including other aggregate stages. This avoids tagging unrelated Partials found deeper
 * in the plan (e.g. the non-distinct Partial in a distinct-aggregate rewrite, which is
 * separated from the Final by intermediate PartialMerge stages). Requires
 * `aggregateExpressions.nonEmpty` so that group-by-only dedup stages are not mistaken for
 * the partial we want to tag.
 */
private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = {
  plan match {
    case candidate: BaseAggregateExec =>
      // A matching candidate must carry at least one aggregate expression, all Partial.
      val exprs = candidate.aggregateExpressions
      if (exprs.nonEmpty && exprs.forall(_.mode == Partial)) Some(candidate) else None
    case read: AQEShuffleReadExec => findPartialAggInPlan(read.child)
    case stage: ShuffleQueryStageExec => findPartialAggInPlan(stage.plan)
    case exchange: ShuffleExchangeExec => findPartialAggInPlan(exchange.child)
    case _ => None
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
*/
def getSupportLevel(expr: T): SupportLevel = Compatible(None)

/**
* Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
* making it safe to run the Partial in one engine and the Final in the other. Aggregates with
* simple single-value buffers (MIN, MAX, COUNT, bitwise) are safe; those with complex or
* differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not.
*
* Defaults to false so that a serde is conservatively treated as unsafe for mixed
* Partial/Final execution until it explicitly opts in.
*/
def supportsMixedPartialFinal: Boolean = false

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
Expand Down
31 changes: 31 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,24 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[VariancePop] -> CometVariancePop,
classOf[VarianceSamp] -> CometVarianceSamp)

/**
 * Returns true when every aggregate expression in the list has an intermediate buffer format
 * that Spark and Comet agree on, so it is safe to run the Partial in one engine and the
 * Final in the other. An aggregate function with no registered serde is treated as unsafe.
 */
def allAggsSupportMixedExecution(aggExprs: Seq[AggregateExpression]): Boolean = {
  aggExprs.forall { aggExpr =>
    val fnClass = aggExpr.aggregateFunction.getClass
    // Missing handler => Option.exists is false => conservatively not mixed-safe.
    aggrSerdeMap.get(fnClass).exists { handler =>
      handler
        .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]]
        .supportsMixedPartialFinal
    }
  }
}

// A unique id for each expression, used to look up the QueryContext during error creation.
private val exprIdCounter = new AtomicLong(0)

Expand Down Expand Up @@ -354,6 +372,19 @@ object QueryPlanSerde extends Logging with CometExprShim {
false
}

/**
 * Returns true if the given data type is or contains a `MapType` at any nesting level.
 * Arrow's row format (used by DataFusion's grouped hash aggregate for composite group keys)
 * does not support `Map`, so grouping on any type that transitively contains a map would
 * crash in native execution.
 */
def containsMapType(dt: DataType): Boolean = {
  dt match {
    case _: MapType =>
      true
    case arrayType: ArrayType =>
      // Recurse into the element type; a map nested inside an array still counts.
      containsMapType(arrayType.elementType)
    case structType: StructType =>
      structType.fields.exists(field => containsMapType(field.dataType))
    case _ =>
      false
  }
}

/**
* Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
* doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
Expand Down
12 changes: 12 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.comet.shims.CometEvalModeUtil

object CometMin extends CometAggregateExpressionSerde[Min] {

// MIN uses a simple single-value intermediate buffer (see the trait scaladoc), so mixed
// Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Min,
Expand Down Expand Up @@ -81,6 +83,8 @@ object CometMin extends CometAggregateExpressionSerde[Min] {

object CometMax extends CometAggregateExpressionSerde[Max] {

// MAX uses a simple single-value intermediate buffer (see the trait scaladoc), so mixed
// Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Max,
Expand Down Expand Up @@ -127,6 +131,8 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
}

object CometCount extends CometAggregateExpressionSerde[Count] {
// COUNT uses a simple single-value intermediate buffer (see the trait scaladoc), so mixed
// Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Count,
Expand Down Expand Up @@ -306,6 +312,8 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
}

object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
// Bitwise AND uses a simple single-value intermediate buffer (see the trait scaladoc), so
// mixed Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitAnd: BitAndAgg,
Expand Down Expand Up @@ -340,6 +348,8 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
}

object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
// Bitwise OR uses a simple single-value intermediate buffer (see the trait scaladoc), so
// mixed Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitOr: BitOrAgg,
Expand Down Expand Up @@ -374,6 +384,8 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
}

object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
// Bitwise XOR uses a simple single-value intermediate buffer (see the trait scaladoc), so
// mixed Spark/Comet Partial/Final execution is safe.
override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitXor: BitXorAgg,
Expand Down
26 changes: 19 additions & 7 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ import com.google.protobuf.CodedOutputStream
import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo}
import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.rules.CometExecRule
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
import org.apache.comet.serde.QueryPlanSerde
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType}
import org.apache.comet.serde.operator.CometSink

Expand Down Expand Up @@ -1359,10 +1361,24 @@ trait CometBaseAggregate {
// In distinct aggregates there can be a combination of modes
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
// if there is Comet partial aggregation, unless all aggregates have compatible
// intermediate buffer formats (safe for mixed Spark/Comet execution).
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty

if (multiMode || sparkFinalMode) {
if (multiMode) {
return None
}

if (sparkFinalMode &&
!QueryPlanSerde.allAggsSupportMixedExecution(aggregate.aggregateExpressions)) {
return None
}

// Check if this aggregate has been tagged as unsafe for mixed execution
// (Comet partial + Spark final with incompatible intermediate buffers)
val unsafeReason = aggregate.getTagValue(CometExecRule.COMET_UNSAFE_PARTIAL)
if (unsafeReason.isDefined) {
withInfo(aggregate, unsafeReason.get)
return None
}

Expand All @@ -1377,11 +1393,7 @@ trait CometBaseAggregate {
return None
}

if (groupingExpressions.exists(expr =>
expr.dataType match {
case _: MapType => true
case _ => false
})) {
if (groupingExpressions.exists(expr => QueryPlanSerde.containsMapType(expr.dataType))) {
withInfo(aggregate, "Grouping on map types is not supported")
return None
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject
+- HashAggregate
+- CometNativeColumnarToRow
CometNativeColumnarToRow
+- CometTakeOrderedAndProject
+- CometHashAggregate
+- CometColumnarExchange
+- HashAggregate
+- Project
Expand Down Expand Up @@ -64,4 +64,4 @@ TakeOrderedAndProject
+- CometFilter
+- CometNativeScan parquet spark_catalog.default.customer_demographics

Comet accelerated 21 out of 54 eligible operators (38%). Final plan contains 11 transitions between Spark and Comet.
Comet accelerated 23 out of 54 eligible operators (42%). Final plan contains 11 transitions between Spark and Comet.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
TakeOrderedAndProject
+- HashAggregate
+- CometNativeColumnarToRow
CometNativeColumnarToRow
+- CometTakeOrderedAndProject
+- CometHashAggregate
+- CometColumnarExchange
+- HashAggregate
+- Project
Expand Down Expand Up @@ -60,4 +60,4 @@ TakeOrderedAndProject
+- CometFilter
+- CometScan [native_iceberg_compat] parquet spark_catalog.default.customer_demographics

Comet accelerated 35 out of 54 eligible operators (64%). Final plan contains 7 transitions between Spark and Comet.
Comet accelerated 37 out of 54 eligible operators (68%). Final plan contains 7 transitions between Spark and Comet.
Loading
Loading