diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index c6561510c0359..28498f18cb084 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3419,8 +3419,12 @@ class PlanGenerationTestSuite
     // Handle parameterized scala types e.g.: List, Seq and Map.
     fn.typedLit(Some(1)),
     fn.typedLit(Array(1, 2, 3)),
+    fn.typedLit[Array[Integer]](Array(null, null)),
+    fn.typedLit[Array[(Int, String)]](Array(null, null, (1, "a"), (2, null))),
+    fn.typedLit[Array[Option[(Int, String)]]](Array(None, None, Some((1, "a")))),
     fn.typedLit(Seq(1, 2, 3)),
     fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)),
+    fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)),
     fn.typedLit(("a", 2, 1.0)),
     fn.typedLit[Option[Int]](None),
     fn.typedLit[Array[Option[Int]]](Array(Some(1))),
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index e2213003656e5..db165c03ad359 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1785,6 +1785,13 @@ class ClientE2ETestSuite
     assert(observation.get.contains("map"))
     assert(observation.get("map") === Map("count" -> 10))
   }
+
+  test("SPARK-53553: null value handling in literals") {
+    val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col"))
+    val result = df.collect()
+    assert(result.length === 1)
+    assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 286b83d4eae92..16bbeb99557b4 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -40,6 +40,19 @@ import org.apache.spark.unsafe.types.CalendarInterval
 
 object LiteralValueProtoConverter {
 
+  private def setNullValue(
+      builder: proto.Expression.Literal.Builder,
+      dataType: DataType,
+      needDataType: Boolean): proto.Expression.Literal.Builder = {
+    if (needDataType) {
+      builder.setNull(toConnectProtoType(dataType))
+    } else {
+      // The data type is not needed here, but we still set a null type to
+      // indicate that the value is null.
+      builder.setNull(ProtoDataTypes.NullType)
+    }
+  }
+
   private def setArrayTypeAfterAddingElements(
       ab: proto.Expression.Literal.Array.Builder,
       elementType: DataType,
@@ -275,6 +288,14 @@ object LiteralValueProtoConverter {
     }
     (literal, dataType) match {
+      case (v: Option[_], _) =>
+        if (v.isDefined) {
+          toLiteralProtoBuilderInternal(v.get, dataType, options, needDataType)
+        } else {
+          setNullValue(builder, dataType, needDataType)
+        }
+      case (null, _) =>
+        setNullValue(builder, dataType, needDataType)
       case (v: mutable.ArraySeq[_], ArrayType(_, _)) =>
         toLiteralProtoBuilderInternal(v.array, dataType, options, needDataType)
       case (v: immutable.ArraySeq[_], ArrayType(_, _)) =>
         toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options, needDataType)
@@ -287,12 +308,6 @@ object LiteralValueProtoConverter {
         builder.setMap(mapBuilder(v, keyType, valueType, valueContainsNull))
       case (v, structType: StructType) =>
         builder.setStruct(structBuilder(v, structType))
-      case (v: Option[_], _: DataType) =>
-        if (v.isDefined) {
-          toLiteralProtoBuilderInternal(v.get, options, needDataType)
-        } else {
-          builder.setNull(toConnectProtoType(dataType))
-        }
       case (v: LocalTime, timeType: TimeType) =>
         builder.setTime(
           builder.getTimeBuilder
@@ -477,7 +492,7 @@ object LiteralValueProtoConverter {
   }
 
   private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
-    dataType.getKindCase match {
+    val converter: proto.Expression.Literal => Any = dataType.getKindCase match {
       case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
       case proto.DataType.KindCase.INTEGER => v => v.getInteger
       case proto.DataType.KindCase.LONG => v => v.getLong
@@ -513,6 +528,7 @@ object LiteralValueProtoConverter {
       case _ =>
         throw InvalidPlanInput(s"Unsupported Literal Type: ${dataType.getKindCase}")
     }
+    v => if (v.hasNull) null else converter(v)
   }
 
   private def getInferredDataType(
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 817b923202c58..5daa50bfe38a0 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, [18545,1677155519808000,12345000,1677184560000000,19411,200000000,0,86399999999999,2 months 20 days 0.0001 seconds] AS NAMED_STRUCT('_1', DATE '2020-10-10', '_2', TIMESTAMP '2023-02-23 04:31:59.808', '_3', TIMESTAMP '1969-12-31 16:00:12.345', '_4', TIMESTAMP_NTZ '2023-02-23 20:36:00', '_5', DATE '2023-02-23', '_6', INTERVAL '0 00:03:20' DAY TO SECOND, '_7', INTERVAL '0-0' YEAR TO MONTH, '_8', TIME '23:59:59.999999999', '_9', INTERVAL '2 months 20 days 0.0001 seconds')#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, map(keys: [1], values: [0]) AS MAP(1, 0)#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, [keys: [a,b], values: [[1,2],[3,4]],keys: [a,b], values: [[5,6],[7,8]],keys: [a,b], values: [[],[]]] AS ARRAY(MAP('a', ARRAY('1', '2'), 'b', ARRAY('3', '4')), MAP('a', ARRAY('5', '6'), 'b', ARRAY('7', '8')), MAP('a', ARRAY(), 'b', ARRAY()))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! 
AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, INTERVAL '0-0' YEAR TO MONTH AS INTERVAL '0-0' YEAR TO MONTH#0, 23:59:59.999999999 AS TIME '23:59:59.999999999'#0, 2 months 20 days 0.0001 seconds AS INTERVAL '2 months 20 days 0.0001 seconds'#0, [18545,1677155519808000,12345000,1677184560000000,19411,200000000,0,86399999999999,2 months 20 days 0.0001 seconds] AS NAMED_STRUCT('_1', DATE '2020-10-10', '_2', TIMESTAMP '2023-02-23 04:31:59.808', '_3', TIMESTAMP '1969-12-31 16:00:12.345', '_4', TIMESTAMP_NTZ '2023-02-23 20:36:00', '_5', DATE '2023-02-23', '_6', INTERVAL '0 00:03:20' DAY TO SECOND, '_7', INTERVAL '0-0' YEAR TO MONTH, '_8', TIME '23:59:59.999999999', '_9', INTERVAL '2 months 20 days 0.0001 seconds')#0, 1 AS 1#0, [1,2,3] AS ARRAY(1, 2, 3)#0, [null,null] AS ARRAY(CAST(NULL AS INT), CAST(NULL AS INT))#0, [null,null,[1,a],[2,null]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'), NAMED_STRUCT('_1', 2, '_2', CAST(NULL AS STRING)))#0, [null,null,[1,a]] AS ARRAY(NULL, NULL, NAMED_STRUCT('_1', 1, '_2', 'a'))#0, [1,2,3] AS ARRAY(1, 2, 3)#0, map(keys: [a,b], values: [1,2]) AS MAP('a', 1, 'b', 2)#0, map(keys: [a,b], values: [null,null]) AS MAP('a', CAST(NULL AS INT), 'b', CAST(NULL AS INT))#0, [a,2,1.0] AS NAMED_STRUCT('_1', 'a', '_2', 2, '_3', 1.0D)#0, null AS NULL#0, [1] AS ARRAY(1)#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, map(keys: [1], values: [null]) AS MAP(1, CAST(NULL AS INT))#0, [[1,2,3],[4,5,6],[7,8,9]] AS ARRAY(ARRAY(1, 2, 3), ARRAY(4, 5, 6), ARRAY(7, 8, 9))#0, [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4],keys: [a,b], values: [5,6]] AS ARRAY(MAP('a', 1, 'b', 2), MAP('a', 3, 'b', 4), MAP('a', 5, 'b', 6))#0, [keys: [a,b], values: [[1,2],[3,4]],keys: [a,b], values: [[5,6],[7,8]],keys: [a,b], values: [[],[]]] AS ARRAY(MAP('a', ARRAY('1', '2'), 'b', ARRAY('3', '4')), MAP('a', ARRAY('5', '6'), 'b', ARRAY('7', '8')), MAP('a', ARRAY(), 'b', ARRAY()))#0, map(keys: [1,2], values: [keys: [a,b], values: [1,2],keys: [a,b], values: [3,4]]) AS MAP(1, MAP('a', 1, 'b', 2), 2, MAP('a', 3, 'b', 4))#0, [[1,2,3],keys: [a,b], values: [1,2],[a,keys: [1,2], values: [a,b]]] AS NAMED_STRUCT('_1', ARRAY(1, 2, 3), '_2', MAP('a', 1, 'b', 2), '_3', NAMED_STRUCT('_1', 'a', '_2', MAP(1, 'a', 2, 'b')))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index 5869ec44789da..db7b2a992e945 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -77,7 +77,8 @@
   }, {
     "literal": {
       "null": {
-        "null": {
+        "string": {
+          "collation": "UTF8_BINARY"
         }
       }
     },
@@ -821,6 +822,206 @@
         }
       }
     }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "integer": {
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "struct": {
+              "fields": [{
+                "name": "_1",
+                "dataType": {
+                  "integer": {
+                  }
+                }
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 1
+            }, {
+              "string": "a"
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 2
+            }, {
+              "null": {
+                "string": {
+                  "collation": "UTF8_BINARY"
+                }
+              }
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "nullable": true
+              }]
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "struct": {
+              "fields": [{
+                "name": "_1",
+                "dataType": {
+                  "integer": {
+                  }
+                }
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 1
+            }, {
+              "string": "a"
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
   }, {
     "literal": {
       "array": {
@@ -891,6 +1092,52 @@
       }
     }
   }
+  }, {
+    "literal": {
+      "map": {
+        "keys": [{
+          "string": "a"
+        }, {
+          "string": "b"
+        }],
+        "values": [{
+          "null": {
+            "integer": {
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }],
+        "dataType": {
+          "keyType": {
+            "string": {
+              "collation": "UTF8_BINARY"
+            }
+          },
+          "valueContainsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
   }, {
     "literal": {
       "struct": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index 00f80df0e2292..6c5ea53d05a9c 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 80c185ee8b3c3..9a2827cf8b55f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -53,7 +53,11 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     }
   }
 
+  // This test checks that converting a Scala value -> proto -> Catalyst value
+  // produces the same result as converting the Scala value directly to a Catalyst value.
   Seq[(Any, DataType)](
+    (Array[String](null, "a", null), ArrayType(StringType)),
+    (Map[String, String]("a" -> null, "b" -> null), MapType(StringType, StringType)),
     (
       (1, "string", true),
       StructType(
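
Reviewer sketch (not part of the patch): the new ClientE2ETestSuite case, expanded into a standalone program. The remote address and session setup are illustrative assumptions; only the typedlit call and the expected result come from the patch itself.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.typedlit

object NullLiteralExample {
  def main(args: Array[String]): Unit = {
    // Assumes a Spark Connect server is reachable at this address.
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

    // An array literal containing a null element now round-trips through the
    // proto literal representation instead of failing or dropping the null.
    val df = spark.range(1).select(typedlit(Array[Integer](1, null)).as("arr_col"))
    val arr = df.collect().head.getAs[Array[Integer]]("arr_col")
    assert(arr.sameElements(Array[Integer](1, null)))

    spark.stop()
  }
}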
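
The getScalaConverter change is a small decorator: the per-kind converter is resolved once, and a single outer null check replaces a per-kind null branch. A minimal sketch of the same pattern with simplified stand-in types (Lit here is hypothetical, not the proto.Expression.Literal class):

// A literal is either a typed value or carries a null marker, as proto
// literals do after this patch.
final case class Lit(hasNull: Boolean, value: Any)

def scalaConverter(kind: String): Lit => Any = {
  // Resolve the type-specific conversion up front, exactly once.
  val converter: Lit => Any = kind match {
    case "integer" => l => l.value.asInstanceOf[Int]
    case "string" => l => l.value.asInstanceOf[String]
    case other => throw new IllegalArgumentException(s"Unsupported kind: $other")
  }
  // One outer null check covers every kind, mirroring the patch.
  l => if (l.hasNull) null else converter(l)
}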
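
The comment added to LiteralExpressionProtoConverterSuite states a round-trip property. Expressed as a helper, with the two converter directions passed in as parameters because the suite's actual entry points are not shown in this hunk (a hedged sketch, not the suite's code):

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DataType

def roundTripAgrees(
    value: Any,
    dataType: DataType,
    toProto: (Any, DataType) => Any, // stand-in: Scala value -> proto literal
    toCatalyst: Any => Any // stand-in: proto literal -> Catalyst value
): Boolean = {
  // Direct Scala -> Catalyst conversion, used as the reference result.
  val direct = CatalystTypeConverters.createToCatalystConverter(dataType)(value)
  // The same value sent through the proto representation first.
  val viaProto = toCatalyst(toProto(value, dataType))
  // Equality here is illustrative; the suite may compare Catalyst values
  // structurally rather than with ==.
  direct == viaProto
}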