11package com .exasol .common .avro
22
33import java .nio .ByteBuffer
4+ import java .util .{Map => JMap }
5+ import java .util .Collection
46
57import com .exasol .common .data .Row
8+ import com .exasol .common .json .JsonMapper
69
710import org .apache .avro .Schema
811import org .apache .avro .generic .GenericFixed
912import org .apache .avro .generic .GenericRecord
13+ import org .apache .avro .generic .IndexedRecord
1014import org .apache .avro .util .Utf8
1115
1216/**
@@ -15,40 +19,63 @@ import org.apache.avro.util.Utf8
1519 */
1620object AvroRow {
1721
22+ /**
23+ * Converts an Avro record into an internal [[com.exasol.common.data.Row ]].
24+ *
25+ * @param avroRecord a generic Avro record
26+ * @return a Row representation of the given Avro record
27+ */
1828 def apply (avroRecord : GenericRecord ): Row = {
19- val size = avroRecord.getSchema.getFields.size
20- val values = Array .ofDim[Any ](size)
2129 val fields = avroRecord.getSchema().getFields()
22- for { index <- 0 until fields.size } {
23- values.update(index, getAvroRecordValue(avroRecord.get(index), fields.get(index).schema))
30+ val size = fields.size()
31+ val values = Array .ofDim[Any ](size)
32+ for { i <- 0 until size } {
33+ values.update(i, getAvroFieldValue(fields.get(i).schema(), avroRecord.get(i)))
2434 }
2535 Row (values.toSeq)
2636 }
2737
38+ private [this ] def getAvroFieldValue (schema : Schema , value : Any ): Any = {
39+ val fieldValue = getAvroValue(value, schema)
40+ if (isPrimitiveAvroType(schema.getType())) {
41+ fieldValue
42+ } else {
43+ JsonMapper .toJson(fieldValue)
44+ }
45+ }
46+
47+ private [this ] def isPrimitiveAvroType (avroType : Schema .Type ): Boolean =
48+ avroType match {
49+ case Schema .Type .ARRAY => false
50+ case Schema .Type .MAP => false
51+ case Schema .Type .RECORD => false
52+ case _ => true
53+ }
54+
2855 @ SuppressWarnings (Array (" org.wartremover.warts.Return" , " org.wartremover.warts.ToString" ))
29- private [this ] def getAvroRecordValue (value : Any , field : Schema ): Any = {
56+ private [this ] def getAvroValue (value : Any , field : Schema ): Any = {
3057 if (value == null ) {
3158 return null // scalastyle:ignore return
32-
3359 }
34- field.getType match {
60+ field.getType() match {
3561 case Schema .Type .NULL => value
3662 case Schema .Type .BOOLEAN => value
3763 case Schema .Type .INT => value
3864 case Schema .Type .LONG => value
3965 case Schema .Type .FLOAT => value
4066 case Schema .Type .DOUBLE => value
41- case Schema .Type .STRING => getAvroValueAsString (value, field)
42- case Schema .Type .FIXED => getAvroValueAsString (value, field)
43- case Schema .Type .BYTES => getAvroValueAsString (value, field)
67+ case Schema .Type .STRING => getStringValue (value, field)
68+ case Schema .Type .FIXED => getStringValue (value, field)
69+ case Schema .Type .BYTES => getStringValue (value, field)
4470 case Schema .Type .ENUM => value.toString
45- case Schema .Type .UNION => getAvroUnionValue(value, field)
46- case field =>
47- throw new IllegalArgumentException (s " Avro ${field.getName} type is not supported! " )
71+ case Schema .Type .UNION => getUnionValue(value, field)
72+ case Schema .Type .ARRAY => getArrayValue(value, field)
73+ case Schema .Type .MAP => getMapValue(value, field)
74+ case Schema .Type .RECORD => getRecordValue(value)
4875 }
4976 }
5077
51- private [this ] def getAvroValueAsString (value : Any , field : Schema ): String =
78+ private [this ] def getStringValue (value : Any , field : Schema ): String =
5279 value match {
5380 case str : String => str
5481 case utf : Utf8 => utf.toString
@@ -61,16 +88,16 @@ object AvroRow {
6188 )
6289 }
6390
64- private [this ] def getAvroUnionValue (value : Any , field : Schema ): Any = {
91+ private [this ] def getUnionValue (value : Any , field : Schema ): Any = {
6592 val types = field.getTypes()
6693 val typesSize = types.size()
6794 typesSize match {
68- case 1 => getAvroRecordValue (value, types.get(0 ))
95+ case 1 => getAvroValue (value, types.get(0 ))
6996 case 2 =>
7097 if (types.get(0 ).getType() == Schema .Type .NULL ) {
71- getAvroRecordValue (value, types.get(1 ))
98+ getAvroValue (value, types.get(1 ))
7299 } else if (types.get(1 ).getType() == Schema .Type .NULL ) {
73- getAvroRecordValue (value, types.get(0 ))
100+ getAvroValue (value, types.get(0 ))
74101 } else {
75102 throw new IllegalArgumentException (
76103 " Avro Union type should contain a primitive and null!"
@@ -81,4 +108,46 @@ object AvroRow {
81108 }
82109 }
83110
111+ private [this ] def getArrayValue (value : Any , field : Schema ): Array [Any ] = value match {
112+ case array : Array [_] => array.map(getAvroValue(_, field.getElementType()))
113+ case list : Collection [_] =>
114+ val result = new Array [Any ](list.size)
115+ var i = 0
116+ list.stream().forEach { element =>
117+ val _ = result.update(i, getAvroValue(element, field.getElementType()))
118+ i += 1
119+ }
120+ result
121+ case other =>
122+ throw new IllegalArgumentException (
123+ s " Unsupported Avro Array type ' ${other.getClass.getName()}'. "
124+ )
125+ }
126+
127+ private [this ] def getMapValue (map : Any , field : Schema ): JMap [String , Any ] = {
128+ val result = new java.util.HashMap [String , Any ]()
129+ map.asInstanceOf [JMap [String , _]].forEach { (key, value) =>
130+ val _ = result.put(key, getAvroValue(value, field.getValueType()))
131+ }
132+ result
133+ }
134+
135+ private [this ] def getRecordValue (value : Any ): JMap [String , Any ] = value match {
136+ case record : IndexedRecord =>
137+ val size = record.getSchema().getFields().size
138+ val fields = record.getSchema().getFields()
139+ val result = new java.util.HashMap [String , Any ]()
140+ var i = 0
141+ while (i < size) {
142+ val _ =
143+ result.put(fields.get(i).name, getAvroValue(record.get(i), fields.get(i).schema))
144+ i += 1
145+ }
146+ result
147+ case other =>
148+ throw new IllegalArgumentException (
149+ s " Unsupported Avro Record type ' ${other.getClass.getName()}'. "
150+ )
151+ }
152+
84153}
0 commit comments