
Commit 310891f

harshmotw-db authored and huangxiaopingRD committed
[SPARK-54410][SQL] Fix read support for the variant logical type annotation
### What changes were proposed in this pull request?

[This PR](apache#53005) introduced a fix where the Spark Parquet writer would annotate variant columns with the Parquet variant logical type. That PR had an ad-hoc fix on the reader side for validation. This PR formally allows Spark to read Parquet files with the variant logical type. It also introduces an unrelated fix in ParquetRowConverter so that Spark can read variant columns regardless of the order in which the value and metadata fields are stored.

### Why are the changes needed?

The variant logical type annotation has formally been adopted as part of the Parquet spec and is part of the parquet-java 1.16.0 library. Therefore, Spark should be able to read files containing data annotated as such.

### Does this PR introduce _any_ user-facing change?

Yes, it allows users to read Parquet files with the variant logical type annotation.

### How was this patch tested?

Existing test from [this PR](apache#53005), where we wrote data with the variant logical type and tested reads using an ad-hoc solution.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#53120 from harshmotw-db/harshmotw-db/variant_annotation_write.

Authored-by: Harsh Motwani <harsh.motwani@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 019984b commit 310891f
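As a quick orientation, here is a minimal sketch (not part of the patch) of the round trip this change targets, written against the configs touched in the diffs below; the output path and data are illustrative only:

  // Hedged sketch (e.g. in spark-shell): write a variant column with the Parquet variant
  // logical type annotation enabled, then read it back as VARIANT. Path and data are hypothetical.
  import org.apache.spark.sql.internal.SQLConf

  spark.conf.set(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key, "true")
  spark.sql("select to_variant_object(named_struct('id', id)) as v from range(3)")
    .write.mode("overwrite").parquet("/tmp/variant_annotated")
  // With this change, the annotated file is read back as a variant column.
  spark.read.parquet("/tmp/variant_annotated").selectExpr("v:id::int").show()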

File tree: 5 files changed, 182 additions and 84 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 1 deletion
@@ -1600,6 +1600,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val PARQUET_IGNORE_VARIANT_ANNOTATION =
+    buildConf("spark.sql.parquet.ignoreVariantAnnotation")
+      .internal()
+      .doc("When true, ignore the variant logical type annotation and treat the Parquet " +
+        "column in the same way as the underlying struct type")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PARQUET_FIELD_ID_READ_ENABLED =
     buildConf("spark.sql.parquet.fieldId.read.enabled")
       .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " +
@@ -5592,7 +5601,7 @@ object SQLConf {
         "When false, it only reads unshredded variant.")
       .version("4.0.0")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)
 
   val PUSH_VARIANT_INTO_SCAN =
     buildConf("spark.sql.variant.pushVariantIntoScan")
@@ -7811,6 +7820,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def parquetAnnotateVariantLogicalType: Boolean = getConf(PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE)
 
+  def parquetIgnoreVariantAnnotation: Boolean = getConf(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION)
+
   def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID)
 
   def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG)
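The new spark.sql.parquet.ignoreVariantAnnotation flag above is internal and defaults to false. A hedged sketch of its effect, mirroring the test added later in this commit (the path is hypothetical): with the flag enabled, a variant-annotated column can still be read back with a plain struct schema.

  // Sketch only: read the value/metadata pair of an annotated variant column as a raw struct.
  spark.conf.set("spark.sql.parquet.ignoreVariantAnnotation", "true")
  spark.read
    .schema("v struct<value binary, metadata binary>")
    .parquet("/tmp/variant_annotated")
    .show()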

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala

Lines changed: 19 additions & 17 deletions
@@ -876,7 +876,11 @@ private[parquet] class ParquetRowConverter(
     }
   }
 
-  /** Parquet converter for unshredded Variant */
+  /**
+   * Parquet converter for unshredded Variant. We use this converter when the
+   * `spark.sql.variant.allowReadingShredded` config is set to false. This option just exists to
+   * fall back to legacy logic which will eventually be removed.
+   */
   private final class ParquetUnshreddedVariantConverter(
       parquetType: GroupType,
       updater: ParentContainerUpdater)
@@ -890,29 +894,27 @@ private[parquet] class ParquetRowConverter(
       // We may allow more than two children in the future, so consider this unsupported.
       throw QueryCompilationErrors.invalidVariantWrongNumFieldsError()
     }
-    val valueAndMetadata = Seq("value", "metadata").map { colName =>
+    val Seq(value, metadata) = Seq("value", "metadata").map { colName =>
       val idx = (0 until parquetType.getFieldCount())
-        .find(parquetType.getFieldName(_) == colName)
-      if (idx.isEmpty) {
-        throw QueryCompilationErrors.invalidVariantMissingFieldError(colName)
-      }
-      val child = parquetType.getType(idx.get)
+        .find(parquetType.getFieldName(_) == colName)
+        .getOrElse(throw QueryCompilationErrors.invalidVariantMissingFieldError(colName))
+      val child = parquetType.getType(idx)
       if (!child.isPrimitive || child.getRepetition != Type.Repetition.REQUIRED ||
-        child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
+          child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
         throw QueryCompilationErrors.invalidVariantNullableOrNotBinaryFieldError(colName)
       }
-      child
+      idx
     }
-    Array(
-      // Converter for value
-      newConverter(valueAndMetadata(0), BinaryType, new ParentContainerUpdater {
+    val result = new Array[Converter with HasParentContainerUpdater](2)
+    result(value) =
+      newConverter(parquetType.getType(value), BinaryType, new ParentContainerUpdater {
        override def set(value: Any): Unit = currentValue = value
-      }),
-
-      // Converter for metadata
-      newConverter(valueAndMetadata(1), BinaryType, new ParentContainerUpdater {
+      })
+    result(metadata) =
+      newConverter(parquetType.getType(metadata), BinaryType, new ParentContainerUpdater {
         override def set(value: Any): Unit = currentMetadata = value
-      }))
+      })
+    result
   }
 
   override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
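The rewritten initializer above no longer assumes that value precedes metadata in the file: it resolves each field's physical index and stores the matching converter at that index. A standalone toy sketch of the same wiring (plain Scala, illustrative names only, not the actual converter code):

  // Toy illustration: place per-field handlers at the physical positions of "value" and
  // "metadata", so either on-disk field order works. "handlers" stands in for the converters.
  val fileFields = Seq("metadata", "value") // hypothetical on-disk order
  val Seq(valueIdx, metadataIdx) = Seq("value", "metadata").map { name =>
    val idx = fileFields.indexOf(name)
    require(idx >= 0, s"missing required variant field: $name")
    idx
  }
  val handlers = new Array[String](2)
  handlers(valueIdx) = "value converter"       // index 1 for this field order
  handlers(metadataIdx) = "metadata converter" // index 0 for this field order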

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala

Lines changed: 29 additions & 9 deletions
@@ -58,15 +58,18 @@ class ParquetToSparkSchemaConverter(
     caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get,
     inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
     nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get,
-    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) {
+    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get,
+    val ignoreVariantAnnotation: Boolean =
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get) {
 
   def this(conf: SQLConf) = this(
     assumeBinaryIsString = conf.isParquetBinaryAsString,
     assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp,
     caseSensitive = conf.caseSensitiveAnalysis,
     inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
     nanosAsLong = conf.legacyParquetNanosAsLong,
-    useFieldId = conf.parquetFieldIdReadEnabled)
+    useFieldId = conf.parquetFieldIdReadEnabled,
+    ignoreVariantAnnotation = conf.parquetIgnoreVariantAnnotation)
 
   def this(conf: Configuration) = this(
     assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
@@ -75,7 +78,9 @@ class ParquetToSparkSchemaConverter(
     inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
     nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean,
     useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key,
-      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get))
+      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get),
+    ignoreVariantAnnotation = conf.getBoolean(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get))
 
   /**
    * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
@@ -202,15 +207,17 @@ class ParquetToSparkSchemaConverter(
       case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)
       case groupColumn: GroupColumnIO if targetType.contains(VariantType) =>
         if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
-          val col = convertGroupField(groupColumn)
+          // We need the underlying file type regardless of the config.
+          val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
           col.copy(sparkType = VariantType, variantFileType = Some(col))
         } else {
           convertVariantField(groupColumn)
         }
       case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
-        val col = convertGroupField(groupColumn)
+        val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
         col.copy(sparkType = targetType.get, variantFileType = Some(col))
-      case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
+      case groupColumn: GroupColumnIO =>
+        convertGroupField(groupColumn, ignoreVariantAnnotation, targetType)
     }
   }
 
@@ -349,6 +356,7 @@ class ParquetToSparkSchemaConverter(
 
   private def convertGroupField(
       groupColumn: GroupColumnIO,
+      ignoreVariantAnnotation: Boolean,
       sparkReadType: Option[DataType] = None): ParquetColumn = {
     val field = groupColumn.getType.asGroupType()
 
@@ -373,9 +381,21 @@ class ParquetToSparkSchemaConverter(
 
     Option(field.getLogicalTypeAnnotation).fold(
       convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
-      // Temporary workaround to read Shredded variant data
-      case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 && sparkReadType.isEmpty =>
-        convertInternal(groupColumn, None)
+      case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 =>
+        if (ignoreVariantAnnotation) {
+          convertInternal(groupColumn)
+        } else {
+          ParquetSchemaConverter.checkConversionRequirement(
+            sparkReadType.forall(_.isInstanceOf[VariantType]),
+            s"Invalid Spark read type: expected $field to be variant type but found " +
+              s"${if (sparkReadType.isEmpty) { "None" } else {sparkReadType.get.sql} }")
+          if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
+            val col = convertInternal(groupColumn)
+            col.copy(sparkType = VariantType, variantFileType = Some(col))
+          } else {
+            convertVariantField(groupColumn)
+          }
+        }
 
       // A Parquet list is represented as a 3-level structure:
       //
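The new VariantLogicalTypeAnnotation branch above keys off the annotation's spec version. As a minimal standalone sketch of that check against parquet-java (assuming parquet-java 1.16.0, which ships the variant annotation; the helper name is hypothetical):

  import org.apache.parquet.schema.LogicalTypeAnnotation.VariantLogicalTypeAnnotation
  import org.apache.parquet.schema.Type

  // True when the field carries a spec-version-1 variant logical type annotation,
  // mirroring the pattern matched in convertGroupField above.
  def isSpecV1Variant(parquetField: Type): Boolean =
    parquetField.getLogicalTypeAnnotation match {
      case v: VariantLogicalTypeAnnotation => v.getSpecVersion == 1
      case _ => false
    }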

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala

Lines changed: 3 additions & 1 deletion
@@ -646,7 +646,9 @@ case object SparkShreddingUtils {
   def parquetTypeToSparkType(parquetType: ParquetType): DataType = {
     val messageType = ParquetTypes.buildMessage().addField(parquetType).named("foo")
     val column = new ColumnIOFactory().getColumnIO(messageType)
-    new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
+    // We need the underlying file type regardless of the ignoreVariantAnnotation config.
+    val converter = new ParquetToSparkSchemaConverter(ignoreVariantAnnotation = true)
+    converter.convertField(column.getChild(0)).sparkType
   }
 
   class SparkShreddedResult(schema: VariantSchema) extends VariantShreddingWriter.ShreddedResult {

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala

Lines changed: 119 additions & 56 deletions
@@ -28,7 +28,8 @@ import org.apache.parquet.hadoop.util.HadoopInputFile
 import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Type}
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
@@ -160,64 +161,126 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
     Seq(false, true).foreach { annotateVariantLogicalType =>
       Seq(false, true).foreach { shredVariant =>
         Seq(false, true).foreach { allowReadingShredded =>
-          withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> shredVariant.toString,
-            SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> shredVariant.toString,
-            SQLConf.VARIANT_ALLOW_READING_SHREDDED.key ->
-              (allowReadingShredded || shredVariant).toString,
-            SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key ->
-              annotateVariantLogicalType.toString) {
-            def validateAnnotation(g: Type): Unit = {
-              if (annotateVariantLogicalType) {
-                assert(g.getLogicalTypeAnnotation == LogicalTypeAnnotation.variantType(1))
-              } else {
-                assert(g.getLogicalTypeAnnotation == null)
+          Seq(false, true).foreach { ignoreVariantAnnotation =>
+            withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> shredVariant.toString,
+              SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> shredVariant.toString,
+              SQLConf.VARIANT_ALLOW_READING_SHREDDED.key ->
+                (allowReadingShredded || shredVariant).toString,
+              SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key ->
+                annotateVariantLogicalType.toString,
+              SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString) {
+              def validateAnnotation(g: Type): Unit = {
+                if (annotateVariantLogicalType) {
+                  assert(g.getLogicalTypeAnnotation == LogicalTypeAnnotation.variantType(1))
+                } else {
+                  assert(g.getLogicalTypeAnnotation == null)
+                }
+              }
+              withTempDir { dir =>
+                // write parquet file
+                val df = spark.sql(
+                  """
+                    | select
+                    |   id * 2 i,
+                    |   to_variant_object(named_struct('id', id)) v,
+                    |   named_struct('i', (id * 2)::string,
+                    |     'nv', to_variant_object(named_struct('id', 30 + id))) ns,
+                    |   array(to_variant_object(named_struct('id', 10 + id))) av,
+                    |   map('v2', to_variant_object(named_struct('id', 20 + id))) mv
+                    | from range(0,3,1,1)""".stripMargin)
+                df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+                val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
+                val parquetFilePath = file.getAbsolutePath
+                val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath),
+                  new Configuration())
+                val reader = ParquetFileReader.open(inputFile)
+                val footer = reader.getFooter
+                val schema = footer.getFileMetaData.getSchema
+                val vGroup = schema.getType(schema.getFieldIndex("v"))
+                validateAnnotation(vGroup)
+                assert(vGroup.asGroupType().getFields.asScala.toSeq
+                  .exists(_.getName == "typed_value") == shredVariant)
+                val nsGroup = schema.getType(schema.getFieldIndex("ns")).asGroupType()
+                val nvGroup = nsGroup.getType(nsGroup.getFieldIndex("nv"))
+                validateAnnotation(nvGroup)
+                val avGroup = schema.getType(schema.getFieldIndex("av")).asGroupType()
+                val avList = avGroup.getType(avGroup.getFieldIndex("list")).asGroupType()
+                val avElement = avList.getType(avList.getFieldIndex("element"))
+                validateAnnotation(avElement)
+                val mvGroup = schema.getType(schema.getFieldIndex("mv")).asGroupType()
+                val mvList = mvGroup.getType(mvGroup.getFieldIndex("key_value")).asGroupType()
+                val mvValue = mvList.getType(mvList.getFieldIndex("value"))
+                validateAnnotation(mvValue)
+                // verify result
+                val result = spark.read.format("parquet")
+                  .schema("v variant, ns struct<nv variant>, av array<variant>, " +
+                    "mv map<string, variant>")
+                  .load(dir.getAbsolutePath)
+                  .selectExpr("v:id::int i1", "ns.nv:id::int i2", "av[0]:id::int i3",
+                    "mv['v2']:id::int i4")
+                checkAnswer(result, Array(Row(0, 30, 10, 20), Row(1, 31, 11, 21),
+                  Row(2, 32, 12, 22)))
+                reader.close()
               }
             }
-          withTempDir { dir =>
-            // write parquet file
-            val df = spark.sql(
-              """
-                | select
-                |   id * 2 i,
-                |   to_variant_object(named_struct('id', id)) v,
-                |   named_struct('i', (id * 2)::string,
-                |     'nv', to_variant_object(named_struct('id', 30 + id))) ns,
-                |   array(to_variant_object(named_struct('id', 10 + id))) av,
-                |   map('v2', to_variant_object(named_struct('id', 20 + id))) mv
-                | from range(0,3,1,1)""".stripMargin)
-            df.write.mode("overwrite").parquet(dir.getAbsolutePath)
-            val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
-            val parquetFilePath = file.getAbsolutePath
-            val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath),
-              new Configuration())
-            val reader = ParquetFileReader.open(inputFile)
-            val footer = reader.getFooter
-            val schema = footer.getFileMetaData.getSchema
-            val vGroup = schema.getType(schema.getFieldIndex("v"))
-            validateAnnotation(vGroup)
-            assert(vGroup.asGroupType().getFields.asScala.toSeq
-              .exists(_.getName == "typed_value") == shredVariant)
-            val nsGroup = schema.getType(schema.getFieldIndex("ns")).asGroupType()
-            val nvGroup = nsGroup.getType(nsGroup.getFieldIndex("nv"))
-            validateAnnotation(nvGroup)
-            val avGroup = schema.getType(schema.getFieldIndex("av")).asGroupType()
-            val avList = avGroup.getType(avGroup.getFieldIndex("list")).asGroupType()
-            val avElement = avList.getType(avList.getFieldIndex("element"))
-            validateAnnotation(avElement)
-            val mvGroup = schema.getType(schema.getFieldIndex("mv")).asGroupType()
-            val mvList = mvGroup.getType(mvGroup.getFieldIndex("key_value")).asGroupType()
-            val mvValue = mvList.getType(mvList.getFieldIndex("value"))
-            validateAnnotation(mvValue)
-            // verify result
-            val result = spark.read.format("parquet")
-              .schema("v variant, ns struct<nv variant>, av array<variant>, " +
-                "mv map<string, variant>")
-              .load(dir.getAbsolutePath)
-              .selectExpr("v:id::int i1", "ns.nv:id::int i2", "av[0]:id::int i3",
-                "mv['v2']:id::int i4")
-            checkAnswer(result, Array(Row(0, 30, 10, 20), Row(1, 31, 11, 21), Row(2, 32, 12, 22)))
-            reader.close()
+          }
+        }
+      }
+    }
+  }
+
+  test("variant logical type annotation - ignore variant annotation") {
+    Seq(true, false).foreach { ignoreVariantAnnotation =>
+      withSQLConf(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "true",
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString
+      ) {
+        withTempDir { dir =>
+          // write parquet file
+          val df = spark.sql(
+            """
+              | select
+              |   id * 2 i,
+              |   1::variant v,
+              |   named_struct('i', (id * 2)::string, 'nv', 1::variant) ns,
+              |   array(1::variant) av,
+              |   map('v2', 1::variant) mv
+              | from range(0,1,1,1)""".stripMargin)
+          df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+          // verify result
+          val normal_result = spark.read.format("parquet")
+            .schema("v variant, ns struct<nv variant>, av array<variant>, " +
+              "mv map<string, variant>")
+            .load(dir.getAbsolutePath)
+            .selectExpr("v::int i1", "ns.nv::int i2", "av[0]::int i3",
+              "mv['v2']::int i4")
+          checkAnswer(normal_result, Array(Row(1, 1, 1, 1)))
+          val struct_result = spark.read.format("parquet")
+            .schema("v struct<value binary, metadata binary>, " +
+              "ns struct<nv struct<value binary, metadata binary>>, " +
+              "av array<struct<value binary, metadata binary>>, " +
+              "mv map<string, struct<value binary, metadata binary>>")
+            .load(dir.getAbsolutePath)
+            .selectExpr("v", "ns.nv", "av[0]", "mv['v2']")
+          if (ignoreVariantAnnotation) {
+            checkAnswer(
+              struct_result,
+              Seq(Row(
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0))
+              ))
+            )
+          } else {
+            val exception = intercept[SparkException]{
+              struct_result.collect()
             }
+            checkError(
+              exception = exception.getCause.asInstanceOf[AnalysisException],
+              condition = "_LEGACY_ERROR_TEMP_3071",
+              parameters = Map("msg" -> "Invalid Spark read type[\\s\\S]*"),
+              matchPVals = true
+            )
           }
         }
       }
