diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 9d9e8f7017..086b5ff7fc 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -275,7 +275,7 @@
 - [x] map_contains_key
 - [ ] map_entries
 - [ ] map_from_arrays
-- [ ] map_from_entries
+- [x] map_from_entries
 - [x] map_keys
 - [ ] map_values
 - [ ] str_to_map
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index ceafc157c4..637c6eeb73 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -20,6 +20,7 @@
 package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
@@ -136,6 +137,8 @@ object CometMapContainsKey extends CometExpressionSerde[MapContainsKey] {
 object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
   val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
   val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"
+  val lastWinUnsupportedReason =
+    "spark.sql.mapKeyDedupPolicy=LAST_WIN is not yet supported natively for map_from_entries"
 
   private def containsBinary(dataType: DataType): Boolean = {
     dataType match {
@@ -153,6 +156,11 @@
     if (containsBinary(expr.dataType.valueType)) {
       return Incompatible(Some(valueUnsupportedReason))
     }
+    // Only the default EXCEPTION policy is supported natively; fall back otherwise.
+    if (!SQLConf.get.getConfString("spark.sql.mapKeyDedupPolicy", "EXCEPTION")
+        .equalsIgnoreCase("EXCEPTION")) {
+      return Incompatible(Some(lastWinUnsupportedReason))
+    }
     Compatible(None)
   }
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 03db26e566..4fd2829ecf 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -22,6 +22,7 @@ package org.apache.comet
 import scala.util.Random
 
 import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -236,4 +237,14 @@
     }
   }
 
+  test("map_from_entries - duplicate keys throw under default EXCEPTION policy") {
+    val df = sql("select map_from_entries(array(struct(1, 'a'), struct(1, 'b')))")
+    val ex = intercept[SparkException] {
+      df.collect()
+    }
+    assert(
+      ex.getMessage.contains("DUPLICATED_MAP_KEY"),
+      s"expected DUPLICATED_MAP_KEY error class, got: ${ex.getMessage}")
+  }
+
 }