• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 @file:OptIn(ExperimentalSerializationApi::class)
6 @file:Suppress("UNCHECKED_CAST")
7 
8 package kotlinx.serialization.protobuf.internal
9 
10 import kotlinx.serialization.*
11 import kotlinx.serialization.builtins.*
12 import kotlinx.serialization.descriptors.*
13 import kotlinx.serialization.encoding.*
14 import kotlinx.serialization.internal.*
15 import kotlinx.serialization.modules.*
16 import kotlinx.serialization.protobuf.*
17 import kotlin.jvm.*
18 
/**
 * Decoder for one protobuf message (or top-level value) described by [descriptor].
 *
 * Proto field numbers read from [reader] are translated into element indices of [descriptor] via a
 * lookup built once at construction by [populateCache]:
 * - a dense [IntArray], when the descriptor has fewer than 32 elements and every proto number fits into it;
 * - a [HashMap] otherwise.
 *
 * Nested structures (lists, messages, maps, oneof fields) are handed off to dedicated decoders
 * created in [beginStructure].
 */
internal open class ProtobufDecoder(
    @JvmField protected val proto: ProtoBuf,
    @JvmField protected val reader: ProtobufReader,
    @JvmField protected val descriptor: SerialDescriptor
) : ProtobufTaggedDecoder() {
    override val serializersModule: SerializersModule
        get() = proto.serializersModule

    // Proto id -> index in the serial descriptor. populateCache fills one of these two caches.
    private var indexCache: IntArray? = null
    private var sparseIndexCache: MutableMap<Int, Int>? = null

    // Index -> proto id for oneof elements. A oneof element at a given index may refer to different
    // proto ids at runtime. Only allocated when the descriptor has at least one oneof element.
    private var index2IdMap: MutableMap<Int, Int>? = null

    // Written by readIfAbsent; decodeNotNullMark reports its negation.
    private var nullValue: Boolean = false
    private val elementMarker = ElementMarker(descriptor, ::readIfAbsent)

    init {
        populateCache(descriptor)
    }

    /**
     * Builds the proto-number -> element-index lookup for [descriptor].
     *
     * With fewer than 32 elements it first tries a dense array indexed by proto number. Fast-path
     * implies that elements are not marked with @ProtoId explicitly, or are monotonic and incremental
     * (maybe 1-indexed). It falls back to [populateCacheMap] as soon as one proto number cannot be
     * stored in that array: a oneof placeholder, or a number outside `0..elementsCount`.
     */
    public fun populateCache(descriptor: SerialDescriptor) {
        val elements = descriptor.elementsCount
        if (elements < 32) {
            /*
             * Allocate one extra slot: arrays are numbered from 0, but protobuf field numbers
             * start from 1, so the largest sequential id equals `elements`.
             */
            val cache = IntArray(elements + 1) { -1 }
            for (i in 0 until elements) {
                val protoId = extractProtoId(descriptor, i, false)
                // A oneof element expands into several proto ids for one index, and a negative or
                // too large id does not fit into the array: both need the map-based cache.
                if (protoId in 0..elements && protoId != ID_HOLDER_ONE_OF) {
                    cache[protoId] = i
                } else {
                    return populateCacheMap(descriptor, elements)
                }
            }
            indexCache = cache
        } else {
            populateCacheMap(descriptor, elements)
        }
    }

    /**
     * Fills [sparseIndexCache] with proto number -> element index for all [elements] of [descriptor].
     *
     * A oneof element contributes one entry per candidate subtype. The key is that subtype's
     * element-0 proto number, and every such key maps to the oneof element's index. [index2IdMap] is
     * allocated when at least one oneof element exists.
     */
    private fun populateCacheMap(descriptor: SerialDescriptor, elements: Int) {
        val map = HashMap<Int, Int>(elements, 1f)
        var oneOfCount = 0
        for (i in 0 until elements) {
            val id = extractProtoId(descriptor, i, false)
            if (id == ID_HOLDER_ONE_OF) {
                descriptor.getElementDescriptor(i)
                    .getAllOneOfSerializerOfField(serializersModule)
                    .forEach { map.putProtoId(it.extractParameters(0).protoId, i) }
                oneOfCount++
            } else {
                map.putProtoId(id, i)
            }
        }
        if (oneOfCount > 0) {
            index2IdMap = HashMap(oneOfCount, 1f)
        }
        sparseIndexCache = map
    }

    private fun MutableMap<Int, Int>.putProtoId(protoId: Int, index: Int) {
        put(protoId, index)
    }

    /** Returns the element index for proto number [protoNum], or -1 when the field is unknown. */
    private fun getIndexByNum(protoNum: Int): Int {
        val array = indexCache ?: return getIndexByNumSlowPath(protoNum)
        return array.getOrElse(protoNum) { -1 }
    }

    /** Map-based lookup used when the dense [indexCache] was not built. */
    private fun getIndexByNumSlowPath(protoTag: Int): Int =
        checkNotNull(sparseIndexCache) { "Proto number cache was not initialized for ${descriptor.serialName}" }
            .getOrElse(protoTag) { -1 }

    /**
     * Maps an enum proto number [protoTag] read from the wire to the ordinal of the matching entry
     * in the enum [descriptor].
     */
    private fun findIndexByTag(descriptor: SerialDescriptor, protoTag: Int): Int {
        // Fast path: the entry at position `protoTag` carries exactly that (zero-based default) number.
        if (protoTag < descriptor.elementsCount && protoTag >= 0) {
            val protoId = extractProtoId(descriptor, protoTag, true)
            if (protoId == protoTag) return protoTag
        }
        return findIndexByTagSlowPath(descriptor, protoTag)
    }

    /**
     * Linear scan over the enum entries of [desc].
     *
     * @throws ProtobufDecodingException when no entry has proto number [protoTag].
     */
    private fun findIndexByTagSlowPath(desc: SerialDescriptor, protoTag: Int): Int {
        for (i in 0 until desc.elementsCount) {
            val protoId = extractProtoId(desc, i, true)
            if (protoId == protoTag) return i
        }
        // Report the enum being decoded, not the enclosing message.
        throw ProtobufDecodingException(
            "$protoTag is not among valid ${desc.serialName} enum proto numbers"
        )
    }

    /**
     * Chooses the decoder for a nested structure based on its kind and the current tag:
     * - LIST: a repeated decoder over a delimited slice, a packed-array decoder for a size-delimited
     *   packable list, or a repeated decoder over the current reader;
     * - CLASS/OBJECT/polymorphic: this decoder itself (untagged self), a oneof reader, or a nested
     *   decoder over a delimited slice;
     * - MAP: a map-entry reader.
     *
     * Any [ProtobufDecodingException] is rethrown wrapped with the structure and proto number context.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        return try {
            when (descriptor.kind) {
                StructureKind.LIST -> {
                    val tag = currentTagOrDefault
                    return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
                        val reader = makeDelimited(reader, tag)
                        // repeated decoder expects the first tag to be read already
                        reader.readTag()
                        // all elements always have id = 1
                        RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)

                    } else if (reader.currentType == ProtoWireType.SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
                        val sliceReader = ProtobufReader(reader.objectInput())
                        PackedArrayDecoder(proto, sliceReader, descriptor)

                    } else {
                        RepeatedDecoder(proto, reader, tag, descriptor)
                    }
                }

                StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
                    val tag = currentTagOrDefault
                    // Do not create redundant copy
                    if (tag == MISSING_TAG && this.descriptor == descriptor) return this
                    if (tag.isOneOf) {
                        // For a oneof tag, [tag.protoId] is the index-based default id assigned by
                        // extractParameters. The real id read from the wire was stored in index2IdMap
                        // by decodeElementIndex and is restored here.
                        val rawIndex = tag.protoId - 1
                        val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
                        return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)
                    }
                    return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
                }

                StructureKind.MAP -> {
                    val tag = currentTagOrDefault
                    MapEntryReader(proto, makeDelimitedForced(reader, tag), tag, descriptor)
                }

                else -> throw SerializationException("Primitives are not supported at top-level")
            }
        } catch (e: ProtobufDecodingException) {
            throw ProtobufDecodingException("Fail to begin structure for ${descriptor.serialName} in ${this.descriptor.serialName} at proto number ${currentTagOrDefault.protoId}", e)
        }
    }

    override fun endStructure(descriptor: SerialDescriptor) {
        // Nothing
    }

    /** Booleans are varints restricted to 0 or 1; any other value is rejected. */
    override fun decodeTaggedBoolean(tag: ProtoDesc): Boolean = when (val value = decodeTaggedInt(tag)) {
        0 -> false
        1 -> true
        else -> throw SerializationException("Unexpected boolean value: $value")
    }

    override fun decodeTaggedByte(tag: ProtoDesc): Byte = decodeTaggedInt(tag).toByte()
    override fun decodeTaggedShort(tag: ProtoDesc): Short = decodeTaggedInt(tag).toShort()

    /** Untagged values use the plain int32 varint; tagged ones honour the tag's integer encoding. */
    override fun decodeTaggedInt(tag: ProtoDesc): Int {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readInt32NoTag()
            } else {
                reader.readInt(tag.integerType)
            }
        }
    }

    override fun decodeTaggedLong(tag: ProtoDesc): Long {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readLongNoTag()
            } else {
                reader.readLong(tag.integerType)
            }
        }
    }

    override fun decodeTaggedFloat(tag: ProtoDesc): Float {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readFloatNoTag()
            } else {
                reader.readFloat()
            }
        }
    }

    override fun decodeTaggedDouble(tag: ProtoDesc): Double {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readDoubleNoTag()
            } else {
                reader.readDouble()
            }
        }
    }

    override fun decodeTaggedChar(tag: ProtoDesc): Char = decodeTaggedInt(tag).toChar()

    override fun decodeTaggedString(tag: ProtoDesc): String {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readStringNoTag()
            } else {
                reader.readString()
            }
        }
    }

    /** Enums are encoded as their proto number; translate it into the entry ordinal. */
    override fun decodeTaggedEnum(tag: ProtoDesc, enumDescription: SerialDescriptor): Int {
        return findIndexByTag(enumDescription, decodeTaggedInt(tag))
    }

    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T = decodeSerializableValue(deserializer, null)

    /**
     * Special-cases maps, byte arrays and collections, which merge into [previousValue]. Other
     * values go to [deserializer].
     *
     * A [ProtobufDecodingException] is rethrown with a message describing where decoding failed:
     * a repeated index, a map key/value, a class field, or this structure itself.
     */
    @Suppress("UNCHECKED_CAST")
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = try {
        when {
            deserializer is MapLikeSerializer<*, *, *, *> -> {
                deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
            }

            deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
            deserializer is AbstractCollectionSerializer<*, *, *> ->
                (deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)

            else -> deserializer.deserialize(this)
        }
    } catch (e: ProtobufDecodingException) {
        val currentTag = currentTagOrDefault
        val msg = if (descriptor != deserializer.descriptor) {
            // Decoding child element
            if (descriptor.kind == StructureKind.LIST && deserializer.descriptor.kind != StructureKind.MAP) {
                // Decoding repeated field
                "Error while decoding index ${currentTag.protoId - 1} in repeated field of ${deserializer.descriptor.serialName}"
            } else if (descriptor.kind == StructureKind.MAP) {
                // Decoding map field
                val index = (currentTag.protoId - 1) / 2
                val field = if ((currentTag.protoId - 1) % 2 == 0) { "key" } else "value"
                "Error while decoding $field of index $index in map field of ${deserializer.descriptor.serialName}"
            } else {
                // Decoding common class
                "Error while decoding ${deserializer.descriptor.serialName} at proto number ${currentTag.protoId} of ${descriptor.serialName}"
            }
        } else {
            // Decoding self
            "Error while decoding ${descriptor.serialName}"
        }
        throw ProtobufDecodingException(msg, e)
    }

    /** Reads a bytes field and appends it to [previousValue] when one exists. */
    private fun deserializeByteArray(previousValue: ByteArray?): ByteArray {
        val tag = currentTagOrDefault
        val array = decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readByteArrayNoTag()
            } else {
                reader.readByteArray()
            }
        }
        return if (previousValue == null) array else previousValue + array
    }

    /**
     * Decodes a map field as a set of map entries, merged into the entries of [previousValue],
     * and rebuilds a map from them.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T> deserializeMap(deserializer: DeserializationStrategy<T>, previousValue: T?): T {
        val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
        // Fully qualified to pick the builtins factory over other same-named candidates in scope.
        val mapEntrySerial =
            kotlinx.serialization.builtins.MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
        val oldSet = (previousValue as? Map<Any?, Any?>)?.entries
        val setOfEntries = (SetSerializer(mapEntrySerial) as AbstractCollectionSerializer<Map.Entry<Any?, Any?>, Set<Map.Entry<Any?, Any?>>, *>).merge(this, oldSet)
        return setOfEntries.associateBy({ it.key }, { it.value }) as T
    }

    override fun SerialDescriptor.getTag(index: Int): ProtoDesc = extractParameters(index)

    /**
     * Reads tags until one maps to a known element of [descriptor], skipping unknown fields.
     * At end of input, it returns whatever the element marker reports for elements not seen yet.
     *
     * @throws SerializationException when field number 0 is read.
     */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        try {
            while (true) {
                val protoId = reader.readTag()
                if (protoId == -1) { // EOF
                    return elementMarker.nextUnmarkedIndex()
                }
                if (protoId == 0) {
                    throw SerializationException("0 is not allowed as the protobuf field number in ${descriptor.serialName}, the input bytes may have been corrupted")
                }
                val index = getIndexByNum(protoId)
                if (index == -1) { // not found
                    reader.skipElement()
                } else {
                    if (descriptor.extractParameters(index).isOneOf) {
                        /*
                         * The proto id of a oneof field cannot be looked up from the properties of
                         * this type. Remember which id was read for this index so that beginStructure
                         * can restore it and pass it to OneOfPolymorphicReader, which picks the
                         * actual deserializer.
                         */
                        index2IdMap?.put(index, protoId)
                    }
                    elementMarker.mark(index)
                    return index
                }
            }
        } catch (e: ProtobufDecodingException) {
            throw ProtobufDecodingException("Fail to get element index for ${descriptor.serialName} in ${this.descriptor.serialName}", e)
        }
    }

    override fun decodeNotNullMark(): Boolean {
        return !nullValue
    }

    /**
     * Callback given to [elementMarker] for an element that was not read.
     *
     * Returns true, meaning the element should still be decoded, for a non-optional element that is
     * either a MAP/LIST (decoded as non-null) or nullable (decoded as null via [nullValue]).
     * NOTE(review): presumably the marker calls this for fields missing from the input — confirm in ElementMarker.
     */
    private fun readIfAbsent(descriptor: SerialDescriptor, index: Int): Boolean {
        if (!descriptor.isElementOptional(index)) {
            val elementDescriptor = descriptor.getElementDescriptor(index)
            val kind = elementDescriptor.kind
            if (kind == StructureKind.MAP || kind == StructureKind.LIST) {
                nullValue = false
                return true
            } else if (elementDescriptor.isNullable) {
                nullValue = true
                return true
            }
        }
        return false
    }

    /** Runs [action] and wraps any [ProtobufDecodingException] with the proto number of [tag]. */
    private inline fun <T> decodeOrThrow(tag: ProtoDesc, action: (tag: ProtoDesc) -> T): T {
        try {
            return action(tag)
        } catch (e: ProtobufDecodingException) {
            rethrowException(tag, e)
        }
    }

    @Suppress("NOTHING_TO_INLINE")
    private inline fun rethrowException(tag: ProtoDesc, e: ProtobufDecodingException): Nothing {
        throw ProtobufDecodingException("Error while decoding proto number ${tag.protoId} of ${descriptor.serialName}", e)
    }
}
372 
/**
 * Decodes the elements of a repeated field one at a time.
 *
 * The sign of [tagOrSize] selects between two layouts:
 * - positive: spec-compliant input where every element is preceded by the field's tag. Decoding
 *   stops at the first tag with another proto id, which is pushed back to the reader.
 * - zero or negative: out-of-spec top-level lists (and maps) that start with a varint element
 *   count, stored here negated.
 */
private class RepeatedDecoder(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    currentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    // Index of the element most recently handed out; -1 until the first one.
    private var index = -1

    // Either the field tag (tagged layout) or the negated element count (tagless layout).
    private val tagOrSize: Long = if (currentTag != MISSING_TAG) {
        currentTag
    } else {
        val count = reader.readInt32NoTag()
        require(count >= 0) { "Expected positive length for $descriptor, but got $count" }
        -count.toLong()
    }

    override fun decodeElementIndex(descriptor: SerialDescriptor): Int =
        if (tagOrSize > 0) decodeTaggedListIndex() else decodeListIndexNoTag()

    private fun decodeListIndexNoTag(): Int {
        val expectedSize = -tagOrSize
        index += 1
        // The eof check covers out-of-spec packed arrays whose stored size is a byte size rather than a list length.
        val finished = index.toLong() == expectedSize || reader.eof
        return if (finished) CompositeDecoder.DECODE_DONE else index
    }

    private fun decodeTaggedListIndex(): Int {
        // The parent has already consumed the tag of the very first element.
        val protoId = if (index != -1) reader.readTag() else reader.currentId
        if (protoId != tagOrSize.protoId) {
            // This tag belongs to another field: hand it back to the reader and end the list.
            reader.pushBackTag()
            return CompositeDecoder.DECODE_DONE
        }
        index += 1
        return index
    }

    override fun SerialDescriptor.getTag(index: Int): ProtoDesc =
        if (tagOrSize > 0) tagOrSize else MISSING_TAG
}
437 
/**
 * Decoder for a single map entry. Even element indices are read as proto field 1 and odd ones as
 * field 2. Both use the integer encoding of the enclosing map field's tag ([parentTag]).
 */
private class MapEntryReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    @JvmField val parentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    override fun SerialDescriptor.getTag(index: Int): ProtoDesc {
        val fieldNumber = if (index % 2 == 0) 1 else 2
        return ProtoDesc(fieldNumber, parentTag.integerType)
    }
}
448 
/**
 * Presents a oneof field as a two-element polymorphic structure.
 *
 * Element 0 is the subclass serial name. It is not read from the wire; it is resolved from the proto
 * number carried by [parentTag]. Element 1 is the content, decoded by [OneOfElementReader].
 */
private class OneOfPolymorphicReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    private val parentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    private var serialNameDecoded = false
    private var contentDecoded = false

    override fun SerialDescriptor.getTag(index: Int): ProtoDesc =
        if (index != 0) extractParameters(0) else POLYMORPHIC_NAME_TAG

    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder =
        if (descriptor != this.descriptor) OneOfElementReader(proto, reader, descriptor) else this

    /** Yields index 0 (serial name), then index 1 (content), then DECODE_DONE. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int = when {
        !serialNameDecoded -> {
            serialNameDecoded = true
            0
        }
        !contentDecoded -> {
            contentDecoded = true
            1
        }
        else -> CompositeDecoder.DECODE_DONE
    }

    override fun decodeTaggedString(tag: ProtoDesc): String {
        if (tag != POLYMORPHIC_NAME_TAG) return super.decodeTaggedString(tag)
        // This exception should never be thrown: a oneof subclass without a matching @ProtoNumber
        // is skipped by the outer decodeElementIndex, which leads to a MissingFieldException instead.
        return descriptor.getActualOneOfSerializer(serializersModule, parentTag.protoId)?.serialName
            ?: throw SerializationException(
                "Cannot find a subclass of ${descriptor.serialName} annotated with @ProtoNumber(${parentTag.protoId})."
            )
    }
}
494 
/**
 * Reads the wrapper class that implements one oneof subtype.
 *
 * The wrapper must declare exactly one element, and that element must carry a single @ProtoNumber.
 * Both constraints are checked at construction and throw [IllegalArgumentException] when violated.
 * The element is reported at most once by [decodeElementIndex].
 */
private class OneOfElementReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    init {
        require(descriptor.elementsCount == 1) {
            "Implementation of oneOf type ${descriptor.serialName} should contain only 1 element, but get ${descriptor.elementsCount}"
        }
        // Only validated; the number itself is not needed by this reader.
        requireNotNull(descriptor.getElementAnnotations(0).filterIsInstance<ProtoNumber>().singleOrNull()) {
            "Implementation of oneOf type ${descriptor.serialName} should have @ProtoNumber annotation"
        }
    }

    private var contentDecoded: Boolean = false

    /**
     * Only message-like kinds (class, object, polymorphic) may be nested directly in a oneof element,
     * and never another oneof.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        return when (descriptor.kind) {
            StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
                val tag = currentTagOrDefault
                // Do not create redundant copy
                if (tag == MISSING_TAG && this.descriptor == descriptor) return this
                if (tag.isOneOf) throw SerializationException("An oneof element cannot be directly child of another oneof element")
                ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
            }
            else -> {
                throw SerializationException("Type ${descriptor.kind} cannot be directly child of oneof element")
            }
        }
    }

    /** Reports the single element (index 0) once, then DECODE_DONE. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        if (contentDecoded) return CompositeDecoder.DECODE_DONE
        contentDecoded = true
        return 0
    }
}
539 
/**
 * Wraps the next delimited object of [decoder] in a fresh [ProtobufReader].
 * Uses `objectTaglessInput()` when [parentTag] is [MISSING_TAG], and `objectInput()` otherwise.
 */
private fun makeDelimited(decoder: ProtobufReader, parentTag: ProtoDesc): ProtobufReader =
    ProtobufReader(
        if (parentTag != MISSING_TAG) decoder.objectInput() else decoder.objectTaglessInput()
    )
545 
/**
 * Wraps the next delimited object of [decoder] in a fresh [ProtobufReader].
 *
 * Its body was a byte-for-byte copy of [makeDelimited], so it now delegates to keep the two in sync.
 * It stays a separate function so that call sites that want forced delimiting (map entries) can
 * diverge later without touching [makeDelimited].
 */
private fun makeDelimitedForced(decoder: ProtobufReader, parentTag: ProtoDesc): ProtobufReader =
    makeDelimited(decoder, parentTag)
551