@@ -31,6 +31,7 @@ import kotlinx.serialization.modules.SerializersModule
3131import org.bson.AbstractBsonReader
3232import org.bson.BsonInvalidOperationException
3333import org.bson.BsonReader
34+ import org.bson.BsonReaderMark
3435import org.bson.BsonType
3536import org.bson.BsonValue
3637import org.bson.codecs.BsonValueCodec
@@ -68,6 +69,20 @@ internal open class DefaultBsonDecoder(
6869 val validKeyKinds = setOf (PrimitiveKind .STRING , PrimitiveKind .CHAR , SerialKind .ENUM )
6970 val bsonValueCodec = BsonValueCodec ()
7071 const val UNKNOWN_INDEX = - 10
72+ fun validateCurrentBsonType (
73+ reader : AbstractBsonReader ,
74+ expectedType : BsonType ,
75+ descriptor : SerialDescriptor ,
76+ actualType : (descriptor: SerialDescriptor ) -> String = { it.kind.toString() }
77+ ) {
78+ reader.currentBsonType?.let {
79+ if (it != expectedType) {
80+ throw SerializationException (
81+ " Invalid data for `${actualType(descriptor)} ` expected a bson " +
82+ " ${expectedType.name.lowercase()} found: ${reader.currentBsonType} " )
83+ }
84+ }
85+ }
7186 }
7287
7388 private fun initElementMetadata (descriptor : SerialDescriptor ) {
@@ -119,29 +134,14 @@ internal open class DefaultBsonDecoder(
119134
120135 @Suppress(" ReturnCount" )
121136 override fun beginStructure (descriptor : SerialDescriptor ): CompositeDecoder {
122- when (descriptor.kind) {
123- is StructureKind .LIST -> {
124- reader.readStartArray()
125- return BsonArrayDecoder (reader, serializersModule, configuration)
126- }
127- is PolymorphicKind -> {
128- reader.readStartDocument()
129- return PolymorphicDecoder (reader, serializersModule, configuration)
130- }
137+ return when (descriptor.kind) {
138+ is StructureKind .LIST -> BsonArrayDecoder (descriptor, reader, serializersModule, configuration)
139+ is PolymorphicKind -> PolymorphicDecoder (descriptor, reader, serializersModule, configuration)
131140 is StructureKind .CLASS ,
132- StructureKind .OBJECT -> {
133- val current = reader.currentBsonType
134- if (current == null || current == BsonType .DOCUMENT ) {
135- reader.readStartDocument()
136- }
137- }
138- is StructureKind .MAP -> {
139- reader.readStartDocument()
140- return BsonDocumentDecoder (reader, serializersModule, configuration)
141- }
141+ StructureKind .OBJECT -> BsonDocumentDecoder (descriptor, reader, serializersModule, configuration)
142+ is StructureKind .MAP -> MapDecoder (descriptor, reader, serializersModule, configuration)
142143 else -> throw SerializationException (" Primitives are not supported at top-level" )
143144 }
144- return DefaultBsonDecoder (reader, serializersModule, configuration)
145145 }
146146
147147 override fun endStructure (descriptor : SerialDescriptor ) {
@@ -194,10 +194,17 @@ internal open class DefaultBsonDecoder(
194194
195195@OptIn(ExperimentalSerializationApi ::class )
196196private class BsonArrayDecoder (
197+ descriptor : SerialDescriptor ,
197198 reader : AbstractBsonReader ,
198199 serializersModule : SerializersModule ,
199200 configuration : BsonConfiguration
200201) : DefaultBsonDecoder(reader, serializersModule, configuration) {
202+
203+ init {
204+ validateCurrentBsonType(reader, BsonType .ARRAY , descriptor)
205+ reader.readStartArray()
206+ }
207+
201208 private var index = 0
202209 override fun decodeElementIndex (descriptor : SerialDescriptor ): Int {
203210 val nextType = reader.readBsonType()
@@ -208,18 +215,46 @@ private class BsonArrayDecoder(
208215
209216@OptIn(ExperimentalSerializationApi ::class )
210217private class PolymorphicDecoder (
218+ descriptor : SerialDescriptor ,
211219 reader : AbstractBsonReader ,
212220 serializersModule : SerializersModule ,
213221 configuration : BsonConfiguration
214222) : DefaultBsonDecoder(reader, serializersModule, configuration) {
215223 private var index = 0
224+ private var mark: BsonReaderMark ?
216225
217- override fun <T > decodeSerializableValue (deserializer : DeserializationStrategy <T >): T =
218- deserializer.deserialize(DefaultBsonDecoder (reader, serializersModule, configuration))
226+ init {
227+ mark = reader.mark
228+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor) { it.serialName }
229+ reader.readStartDocument()
230+ }
231+
232+ override fun <T > decodeSerializableValue (deserializer : DeserializationStrategy <T >): T {
233+ mark?.let {
234+ it.reset()
235+ mark = null
236+ }
237+ return deserializer.deserialize(DefaultBsonDecoder (reader, serializersModule, configuration))
238+ }
219239
220240 override fun decodeElementIndex (descriptor : SerialDescriptor ): Int {
241+ var found = false
221242 return when (index) {
222- 0 -> index++
243+ 0 -> {
244+ while (reader.readBsonType() != BsonType .END_OF_DOCUMENT ) {
245+ if (reader.readName() == configuration.classDiscriminator) {
246+ found = true
247+ break
248+ }
249+ reader.skipValue()
250+ }
251+ if (! found) {
252+ throw SerializationException (
253+ " Missing required discriminator field `${configuration.classDiscriminator} ` " +
254+ " for polymorphic class: `${descriptor.serialName} `." )
255+ }
256+ index++
257+ }
223258 1 -> index++
224259 else -> DECODE_DONE
225260 }
@@ -228,6 +263,20 @@ private class PolymorphicDecoder(
228263
229264@OptIn(ExperimentalSerializationApi ::class )
230265private class BsonDocumentDecoder (
266+ descriptor : SerialDescriptor ,
267+ reader : AbstractBsonReader ,
268+ serializersModule : SerializersModule ,
269+ configuration : BsonConfiguration
270+ ) : DefaultBsonDecoder(reader, serializersModule, configuration) {
271+ init {
272+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor) { it.serialName }
273+ reader.readStartDocument()
274+ }
275+ }
276+
277+ @OptIn(ExperimentalSerializationApi ::class )
278+ private class MapDecoder (
279+ descriptor : SerialDescriptor ,
231280 reader : AbstractBsonReader ,
232281 serializersModule : SerializersModule ,
233282 configuration : BsonConfiguration
@@ -236,6 +285,11 @@ private class BsonDocumentDecoder(
236285 private var index = 0
237286 private var isKey = false
238287
288+ init {
289+ validateCurrentBsonType(reader, BsonType .DOCUMENT , descriptor)
290+ reader.readStartDocument()
291+ }
292+
239293 override fun decodeString (): String {
240294 return if (isKey) {
241295 reader.readName()
0 commit comments