1 /*
2 * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3 */
4
5 @file:OptIn(ExperimentalSerializationApi::class)
6 @file:Suppress("UNCHECKED_CAST")
7
8 package kotlinx.serialization.protobuf.internal
9
10 import kotlinx.serialization.*
11 import kotlinx.serialization.builtins.*
12 import kotlinx.serialization.descriptors.*
13 import kotlinx.serialization.encoding.*
14 import kotlinx.serialization.internal.*
15 import kotlinx.serialization.modules.*
16 import kotlinx.serialization.protobuf.*
17 import kotlin.jvm.*
18
/**
 * Decoder that reads a protobuf message described by [descriptor] from [reader].
 *
 * On construction it builds a lookup from proto field numbers to element indices of [descriptor]
 * (see [populateCache]); [decodeElementIndex] then uses that lookup to translate every field number
 * read from the wire into the index expected by the generated deserializer.
 *
 * @property proto the format instance; supplies the [serializersModule].
 * @property reader the source of wire data for this message.
 * @property descriptor the descriptor of the structure being decoded by this instance.
 */
internal open class ProtobufDecoder(
    @JvmField protected val proto: ProtoBuf,
    @JvmField protected val reader: ProtobufReader,
    @JvmField protected val descriptor: SerialDescriptor
) : ProtobufTaggedDecoder() {
    override val serializersModule: SerializersModule
        get() = proto.serializersModule

    // Proto id -> index in serial descriptor cache.
    // Exactly one of the two is populated by populateCache: the array for the fast path,
    // the map otherwise.
    private var indexCache: IntArray? = null
    private var sparseIndexCache: MutableMap<Int, Int>? = null

    // Index -> proto id for oneof elements. A oneof element at a given index may refer to
    // different proto ids at runtime. Only allocated when the descriptor has oneof elements.
    private var index2IdMap: MutableMap<Int, Int>? = null

    // Set by readIfAbsent; reported (negated) by decodeNotNullMark.
    private var nullValue: Boolean = false
    private val elementMarker = ElementMarker(descriptor, ::readIfAbsent)

    init {
        populateCache(descriptor)
    }

    /**
     * Builds the proto id -> element index lookup for [descriptor].
     *
     * For descriptors with fewer than 32 elements an [IntArray] indexed by proto id is attempted;
     * as soon as an element's id does not fit into that array or is the oneof placeholder
     * ([ID_HOLDER_ONE_OF]), the method falls back to the map-based cache.
     */
    public fun populateCache(descriptor: SerialDescriptor) {
        val elements = descriptor.elementsCount
        if (elements < 32) {
            /*
             * If we have reasonably small count of elements, try to build sequential
             * array for the fast-path. Fast-path implies that elements are not marked with @ProtoId
             * explicitly or are monotonic and incremental (maybe, 1-indexed)
             *
             * Initialize all elements, because there will always be one extra element as arrays are numbered from 0
             * but in protobuf field number starts from 1.
             */
            val cache = IntArray(elements + 1) { -1 }
            for (i in 0 until elements) {
                val protoId = extractProtoId(descriptor, i, false)
                // If any element is marked as ProtoOneOf,
                // the fast path is not applicable
                // because it will contain more ids than elements
                if (protoId <= elements && protoId != ID_HOLDER_ONE_OF) {
                    cache[protoId] = i
                } else {
                    return populateCacheMap(descriptor, elements)
                }
            }
            indexCache = cache
        } else {
            populateCacheMap(descriptor, elements)
        }
    }

    /**
     * Slow-path cache: fills [sparseIndexCache] with every proto id of [descriptor].
     *
     * A oneof element contributes one entry per serializer returned by
     * `getAllOneOfSerializerOfField`, all pointing at the same element index. When at least one
     * oneof element exists, [index2IdMap] is allocated so [decodeElementIndex] can record the id
     * actually seen on the wire.
     */
    private fun populateCacheMap(descriptor: SerialDescriptor, elements: Int) {
        val map = HashMap<Int, Int>(elements, 1f)
        var oneOfCount = 0
        for (i in 0 until elements) {
            val id = extractProtoId(descriptor, i, false)
            if (id == ID_HOLDER_ONE_OF) {
                descriptor.getElementDescriptor(i)
                    .getAllOneOfSerializerOfField(serializersModule)
                    .map { it.extractParameters(0).protoId }
                    .forEach { map.putProtoId(it, i) }
                oneOfCount++
            } else {
                // Reuse the id extracted above instead of extracting it a second time.
                map.putProtoId(id, i)
            }
        }
        if (oneOfCount > 0) {
            index2IdMap = HashMap(oneOfCount, 1f)
        }
        sparseIndexCache = map
    }

    /** Registers [index] under [protoId]; a later registration for the same id replaces the earlier one. */
    private fun MutableMap<Int, Int>.putProtoId(protoId: Int, index: Int) {
        put(protoId, index)
    }

    /**
     * Returns the element index registered for field number [protoNum], or `-1` when unknown.
     * Uses the array cache when it was built, otherwise the map cache.
     */
    private fun getIndexByNum(protoNum: Int): Int {
        val array = indexCache
        if (array != null) {
            return array.getOrElse(protoNum) { -1 }
        }
        return getIndexByNumSlowPath(protoNum)
    }

    /**
     * Map-based lookup used when the array cache was not built.
     * Relies on [populateCache] having filled [sparseIndexCache] in that case.
     */
    private fun getIndexByNumSlowPath(
        protoTag: Int
    ): Int = sparseIndexCache!!.getOrElse(protoTag) { -1 }

    /**
     * Resolves the element index inside the enum [descriptor] whose proto id equals [protoTag].
     *
     * @throws ProtobufDecodingException (from the slow path) when no element has that id.
     */
    private fun findIndexByTag(descriptor: SerialDescriptor, protoTag: Int): Int {
        // Fast-path: the element located at index `protoTag` carries exactly that proto id
        if (protoTag < descriptor.elementsCount && protoTag >= 0) {
            val protoId = extractProtoId(descriptor, protoTag, true)
            if (protoId == protoTag) return protoTag
        }
        return findIndexByTagSlowPath(descriptor, protoTag)
    }

    /**
     * Linear scan over all elements of the enum descriptor [desc] looking for [protoTag].
     *
     * @throws ProtobufDecodingException when no element of [desc] has proto id [protoTag].
     */
    private fun findIndexByTagSlowPath(desc: SerialDescriptor, protoTag: Int): Int {
        for (i in 0 until desc.elementsCount) {
            val protoId = extractProtoId(desc, i, true)
            if (protoId == protoTag) return i
        }

        // Report the enum that was searched (desc), not the structure this decoder is reading.
        throw ProtobufDecodingException(
            "$protoTag is not among valid ${desc.serialName} enum proto numbers"
        )
    }

    /**
     * Picks the decoder for a nested structure according to the kind of [descriptor]:
     * lists get a repeated or packed decoder, classes/objects/polymorphic values get a nested
     * [ProtobufDecoder] (or a oneof reader), maps get a map entry reader.
     *
     * @throws SerializationException for any other kind.
     * @throws ProtobufDecodingException wrapping a [ProtobufDecodingException] raised while
     * setting the nested decoder up, with the structure names and the current proto number added.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        return try {
            when (descriptor.kind) {
                StructureKind.LIST -> {
                    val tag = currentTagOrDefault
                    return if (this.descriptor.kind == StructureKind.LIST && tag != MISSING_TAG && this.descriptor != descriptor) {
                        // A list nested directly inside another list
                        val reader = makeDelimited(reader, tag)
                        // repeated decoder expects the first tag to be read already
                        reader.readTag()
                        // all elements always have id = 1
                        RepeatedDecoder(proto, reader, ProtoDesc(1, ProtoIntegerType.DEFAULT), descriptor)

                    } else if (reader.currentType == ProtoWireType.SIZE_DELIMITED && descriptor.getElementDescriptor(0).isPackable) {
                        // Size-delimited payload of packable elements: read it as a packed array
                        val sliceReader = ProtobufReader(reader.objectInput())
                        PackedArrayDecoder(proto, sliceReader, descriptor)

                    } else {
                        RepeatedDecoder(proto, reader, tag, descriptor)
                    }
                }

                StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
                    val tag = currentTagOrDefault
                    // Do not create redundant copy
                    if (tag == MISSING_TAG && this.descriptor == descriptor) return this
                    if (tag.isOneOf) {
                        // If a tag is annotated as oneof
                        // [tag.protoId] here is overwritten with index-based default id in
                        // [kotlinx.serialization.protobuf.internal.HelpersKt.extractParameters]
                        // and restored the real id from index2IdMap, set by [decodeElementIndex]
                        val rawIndex = tag.protoId - 1
                        val restoredTag = index2IdMap?.get(rawIndex)?.let { tag.overrideId(it) } ?: tag
                        return OneOfPolymorphicReader(proto, reader, restoredTag, descriptor)
                    }
                    return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
                }

                StructureKind.MAP -> MapEntryReader(
                    proto,
                    makeDelimitedForced(reader, currentTagOrDefault),
                    currentTagOrDefault,
                    descriptor
                )

                else -> throw SerializationException("Primitives are not supported at top-level")
            }
        } catch (e: ProtobufDecodingException) {
            throw ProtobufDecodingException("Fail to begin structure for ${descriptor.serialName} in ${this.descriptor.serialName} at proto number ${currentTagOrDefault.protoId}", e)
        }
    }

    override fun endStructure(descriptor: SerialDescriptor) {
        // Nothing
    }

    /**
     * Decodes a boolean encoded as the integer `0` or `1`.
     *
     * @throws SerializationException for any other integer value.
     */
    override fun decodeTaggedBoolean(tag: ProtoDesc): Boolean = when (val value = decodeTaggedInt(tag)) {
        0 -> false
        1 -> true
        else -> throw SerializationException("Unexpected boolean value: $value")
    }

    // Bytes, shorts and chars are read as an Int and narrowed.
    override fun decodeTaggedByte(tag: ProtoDesc): Byte = decodeTaggedInt(tag).toByte()
    override fun decodeTaggedShort(tag: ProtoDesc): Short = decodeTaggedInt(tag).toShort()

    /** Reads an Int: tagless form for [MISSING_TAG], otherwise using the integer type carried by [tag]. */
    override fun decodeTaggedInt(tag: ProtoDesc): Int {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readInt32NoTag()
            } else {
                reader.readInt(tag.integerType)
            }
        }
    }

    /** Reads a Long: tagless form for [MISSING_TAG], otherwise using the integer type carried by [tag]. */
    override fun decodeTaggedLong(tag: ProtoDesc): Long {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readLongNoTag()
            } else {
                reader.readLong(tag.integerType)
            }
        }
    }

    /** Reads a Float, using the tagless reader call when [tag] is [MISSING_TAG]. */
    override fun decodeTaggedFloat(tag: ProtoDesc): Float {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readFloatNoTag()
            } else {
                reader.readFloat()
            }
        }
    }

    /** Reads a Double, using the tagless reader call when [tag] is [MISSING_TAG]. */
    override fun decodeTaggedDouble(tag: ProtoDesc): Double {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readDoubleNoTag()
            } else {
                reader.readDouble()
            }
        }
    }

    override fun decodeTaggedChar(tag: ProtoDesc): Char = decodeTaggedInt(tag).toChar()

    /** Reads a String, using the tagless reader call when [tag] is [MISSING_TAG]. */
    override fun decodeTaggedString(tag: ProtoDesc): String {
        return decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readStringNoTag()
            } else {
                reader.readString()
            }
        }
    }

    /**
     * Reads the enum's proto number as an Int and maps it to the element index in [enumDescription].
     *
     * @throws ProtobufDecodingException when the number matches no element of [enumDescription].
     */
    override fun decodeTaggedEnum(tag: ProtoDesc, enumDescription: SerialDescriptor): Int {
        return findIndexByTag(enumDescription, decodeTaggedInt(tag))
    }

    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T = decodeSerializableValue(deserializer, null)

    /**
     * Decodes a value with [deserializer], merging into [previousValue] for maps, byte arrays and
     * collections; every other deserializer is invoked directly.
     *
     * Any [ProtobufDecodingException] raised underneath is rethrown wrapped in a new one whose
     * message describes what was being decoded (repeated element, map key/value, nested class or self).
     */
    @Suppress("UNCHECKED_CAST")
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>, previousValue: T?): T = try {
        when {
            deserializer is MapLikeSerializer<*, *, *, *> -> {
                deserializeMap(deserializer as DeserializationStrategy<T>, previousValue)
            }

            deserializer.descriptor == ByteArraySerializer().descriptor -> deserializeByteArray(previousValue as ByteArray?) as T
            deserializer is AbstractCollectionSerializer<*, *, *> ->
                (deserializer as AbstractCollectionSerializer<*, T, *>).merge(this, previousValue)

            else -> deserializer.deserialize(this)
        }
    } catch (e: ProtobufDecodingException) {
        val currentTag = currentTagOrDefault
        val msg = if (descriptor != deserializer.descriptor) {
            // Decoding child element
            if (descriptor.kind == StructureKind.LIST && deserializer.descriptor.kind != StructureKind.MAP) {
                // Decoding repeated field
                "Error while decoding index ${currentTag.protoId - 1} in repeated field of ${deserializer.descriptor.serialName}"
            } else if (descriptor.kind == StructureKind.MAP) {
                // Decoding map field
                val index = (currentTag.protoId - 1) / 2
                val field = if ((currentTag.protoId - 1) % 2 == 0) { "key" } else "value"
                "Error while decoding $field of index $index in map field of ${deserializer.descriptor.serialName}"
            } else {
                // Decoding common class
                "Error while decoding ${deserializer.descriptor.serialName} at proto number ${currentTag.protoId} of ${descriptor.serialName}"
            }
        } else {
            // Decoding self
            "Error while decoding ${descriptor.serialName}"
        }
        throw ProtobufDecodingException(msg, e)
    }

    /**
     * Reads a byte array for the current tag and appends it to [previousValue] when one is given.
     */
    private fun deserializeByteArray(previousValue: ByteArray?): ByteArray {
        val tag = currentTagOrDefault
        val array = decodeOrThrow(tag) {
            if (tag == MISSING_TAG) {
                reader.readByteArrayNoTag()
            } else {
                reader.readByteArray()
            }
        }
        return if (previousValue == null) array else previousValue + array
    }

    /**
     * Decodes a map as a set of map entries (merged with the entries of [previousValue], if any)
     * and converts that set back into a map keyed by entry key.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T> deserializeMap(deserializer: DeserializationStrategy<T>, previousValue: T?): T {
        val serializer = (deserializer as MapLikeSerializer<Any?, Any?, T, *>)
        // Yeah thanks different resolution algorithms
        val mapEntrySerial =
            kotlinx.serialization.builtins.MapEntrySerializer(serializer.keySerializer, serializer.valueSerializer)
        val oldSet = (previousValue as? Map<Any?, Any?>)?.entries
        val setOfEntries = (SetSerializer(mapEntrySerial) as AbstractCollectionSerializer<Map.Entry<Any?, Any?>, Set<Map.Entry<Any?, Any?>>, *>).merge(this, oldSet)
        return setOfEntries.associateBy({ it.key }, { it.value }) as T
    }

    override fun SerialDescriptor.getTag(index: Int) = extractParameters(index)

    /**
     * Reads tags from the wire until one maps to a known element and returns that element's index.
     *
     * Unknown field numbers are skipped. At end of input the result comes from
     * `elementMarker.nextUnmarkedIndex()`. For oneof elements, the field number seen on the wire
     * is remembered in [index2IdMap] so [beginStructure] can restore it.
     *
     * @throws SerializationException when field number `0` is read.
     * @throws ProtobufDecodingException wrapping a [ProtobufDecodingException] raised while reading.
     */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        try {
            while (true) {
                val protoId = reader.readTag()
                if (protoId == -1) { // EOF
                    return elementMarker.nextUnmarkedIndex()
                }
                if (protoId == 0) {
                    throw SerializationException("0 is not allowed as the protobuf field number in ${descriptor.serialName}, the input bytes may have been corrupted")
                }
                val index = getIndexByNum(protoId)
                if (index == -1) { // not found
                    reader.skipElement()
                } else {
                    if (descriptor.extractParameters(index).isOneOf) {
                        /**
                         * While decoding message with one-of field,
                         * the proto id read from wire data cannot be easily found
                         * in the properties of this type,
                         * So the index of this one-of property and the id read from the wire
                         * are saved in this map, then restored in [beginStructure]
                         * and passed to [OneOfPolymorphicReader] to get the actual deserializer.
                         */
                        index2IdMap?.put(index, protoId)
                    }
                    elementMarker.mark(index)
                    return index
                }
            }
        } catch (e: ProtobufDecodingException) {
            throw ProtobufDecodingException("Fail to get element index for ${descriptor.serialName} in ${this.descriptor.serialName}", e)
        }
    }

    /** Returns `false` only when [readIfAbsent] has flagged the current element as an absent nullable. */
    override fun decodeNotNullMark(): Boolean {
        return !nullValue
    }

    /**
     * Callback handed to the element marker: decides whether the element at [index], which was not
     * marked as read, should still be decoded.
     *
     * Optional elements are never read. Non-optional map/list elements are read with the null flag
     * cleared; other non-optional nullable elements are read with the null flag set, so
     * [decodeNotNullMark] reports `false` for them.
     */
    private fun readIfAbsent(descriptor: SerialDescriptor, index: Int): Boolean {
        if (!descriptor.isElementOptional(index)) {
            val elementDescriptor = descriptor.getElementDescriptor(index)
            val kind = elementDescriptor.kind
            if (kind == StructureKind.MAP || kind == StructureKind.LIST) {
                nullValue = false
                return true
            } else if (elementDescriptor.isNullable) {
                nullValue = true
                return true
            }
        }
        return false
    }

    /**
     * Runs [action] and rethrows any [ProtobufDecodingException] wrapped with the proto number of
     * [tag] and the name of the structure being decoded.
     */
    private inline fun <T> decodeOrThrow(tag: ProtoDesc, action: (tag: ProtoDesc) -> T): T {
        try {
            return action(tag)
        } catch (e: ProtobufDecodingException) {
            rethrowException(tag, e)
        }
    }

    /** Always throws: wraps [e] in a [ProtobufDecodingException] that names the failing proto number. */
    @Suppress("NOTHING_TO_INLINE")
    private inline fun rethrowException(tag: ProtoDesc, e: ProtobufDecodingException): Nothing {
        throw ProtobufDecodingException("Error while decoding proto number ${tag.protoId} of ${descriptor.serialName}", e)
    }
}
372
/**
 * Decoder for repeated (list-like) content.
 *
 * Works in one of two modes, chosen once at construction from `currentTag`:
 * tagged (every element is preceded by the same tag) or tagless (the element count is read
 * up-front from the input).
 */
private class RepeatedDecoder(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    currentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    // Index of the most recently returned element; -1 until the first one is requested
    private var index = -1

    /*
     * For regular messages, it is always a tag.
     * For out-of-spec top-level lists (and maps) the very first varint
     * represents this list size. It is stored in a single variable
     * as negative value and branched based on that fact.
     */
    private val tagOrSize: Long = if (currentTag == MISSING_TAG) {
        val length = reader.readInt32NoTag()
        require(length >= 0) { "Expected positive length for $descriptor, but got $length" }
        -length.toLong()
    } else {
        currentTag
    }

    /** Dispatches to the tagged or the tagless strategy depending on the sign of [tagOrSize]. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int =
        if (tagOrSize > 0) decodeTaggedListIndex() else decodeListIndexNoTag()

    /**
     * Tagless mode: advances the index and stops once it reaches the stored size or the reader
     * reports end of input.
     */
    private fun decodeListIndexNoTag(): Int {
        index += 1
        // Check for eof is here for the case that it is an out-of-spec packed array
        // where size is bytesize not list length.
        val finished = index.toLong() == -tagOrSize || reader.eof
        return if (finished) CompositeDecoder.DECODE_DONE else index
    }

    /**
     * Tagged mode: continues while the next tag carries the same proto id as this list.
     */
    private fun decodeTaggedListIndex(): Int {
        val protoId = when (index) {
            // For the very first element tag is already read by the parent
            -1 -> reader.currentId
            else -> reader.readTag()
        }

        if (protoId != tagOrSize.protoId) {
            // If we read tag of a different message, push it back to the reader and bail out
            reader.pushBackTag()
            return CompositeDecoder.DECODE_DONE
        }
        return ++index
    }

    /** Every element shares the list's own tag in tagged mode and has no tag in tagless mode. */
    override fun SerialDescriptor.getTag(index: Int): ProtoDesc =
        if (tagOrSize > 0) tagOrSize else MISSING_TAG
}
437
/**
 * Decoder for a single map entry: even element indices are tagged with field number 1,
 * odd ones with field number 2, both using the integer type of [parentTag].
 */
private class MapEntryReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    @JvmField val parentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    override fun SerialDescriptor.getTag(index: Int): ProtoDesc {
        val fieldNumber = if (index % 2 == 0) 1 else 2
        return ProtoDesc(fieldNumber, parentTag.integerType)
    }
}
448
/**
 * Decoder for a oneof field, presented to the deserializer as a two-element structure:
 * element 0 is the serial name of the concrete subclass, element 1 is its content.
 *
 * The serial name is not read from the input; it is looked up from the proto id held in `parentTag`.
 */
private class OneOfPolymorphicReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    private val parentTag: ProtoDesc,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    private var serialNameDecoded = false
    private var contentDecoded = false

    /** Index 0 gets the dedicated name tag; any other index gets the parameters of element 0. */
    override fun SerialDescriptor.getTag(index: Int): ProtoDesc = when (index) {
        0 -> POLYMORPHIC_NAME_TAG
        else -> extractParameters(0)
    }

    /** Reuses this reader for its own descriptor, otherwise hands over to an element reader. */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder =
        if (descriptor == this.descriptor) this else OneOfElementReader(proto, reader, descriptor)

    /** Yields index 0, then index 1, then signals completion. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int = when {
        !serialNameDecoded -> {
            serialNameDecoded = true
            0
        }
        !contentDecoded -> {
            contentDecoded = true
            1
        }
        else -> CompositeDecoder.DECODE_DONE
    }

    /**
     * For the name tag, returns the serial name of the subclass registered for the parent's
     * proto id; every other tag is decoded by the superclass.
     */
    override fun decodeTaggedString(tag: ProtoDesc): String {
        if (tag != POLYMORPHIC_NAME_TAG) return super.decodeTaggedString(tag)

        val actualSerializer = descriptor.getActualOneOfSerializer(serializersModule, parentTag.protoId)
        // This exception will never be thrown:
        // a subclass of a oneof field without a matching ProtoNum annotation is skipped in the
        // outer [decodeElementIndex], which raises a [MissingFieldException] instead
        return actualSerializer?.serialName ?: throw SerializationException(
            "Cannot find a subclass of ${descriptor.serialName} annotated with @ProtoNumber(${parentTag.protoId})."
        )
    }
}
494
/**
 * Decoder for the concrete implementation of a oneof field.
 *
 * The implementation's descriptor must have exactly one element, and that element must carry
 * a single [ProtoNumber] annotation; both conditions are validated on construction.
 */
private class OneOfElementReader(
    proto: ProtoBuf,
    decoder: ProtobufReader,
    descriptor: SerialDescriptor
) : ProtobufDecoder(proto, decoder, descriptor) {
    // Proto number declared on the only element of the implementation
    private val classId: Int

    init {
        require(descriptor.elementsCount == 1) {
            "Implementation of oneOf type ${descriptor.serialName} should contain only 1 element, but get ${descriptor.elementsCount}"
        }
        val annotation = requireNotNull(
            descriptor.getElementAnnotations(0).filterIsInstance<ProtoNumber>().singleOrNull()
        ) {
            "Implementation of oneOf type ${descriptor.serialName} should have @ProtoNumber annotation"
        }
        classId = annotation.number
    }

    private var contentDecoded: Boolean = false

    /**
     * Only class, object and polymorphic kinds may sit directly under a oneof element.
     *
     * @throws SerializationException for any other kind, or when the current tag is itself a oneof.
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        val kind = descriptor.kind
        val supported = kind == StructureKind.CLASS || kind == StructureKind.OBJECT || kind is PolymorphicKind
        if (!supported) {
            throw SerializationException("Type ${descriptor.kind} cannot be directly child of oneof element")
        }

        val tag = currentTagOrDefault
        // Do not create redundant copy
        if (tag == MISSING_TAG && this.descriptor == descriptor) return this
        if (tag.isOneOf) throw SerializationException("An oneof element cannot be directly child of another oneof element")
        return ProtobufDecoder(proto, makeDelimited(reader, tag), descriptor)
    }

    /** Yields index 0 exactly once, then signals completion. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
        if (contentDecoded) return CompositeDecoder.DECODE_DONE
        contentDecoded = true
        return 0
    }
}
539
makeDelimitednull540 private fun makeDelimited(decoder: ProtobufReader, parentTag: ProtoDesc): ProtobufReader {
541 val tagless = parentTag == MISSING_TAG
542 val input = if (tagless) decoder.objectTaglessInput() else decoder.objectInput()
543 return ProtobufReader(input)
544 }
545
makeDelimitedForcednull546 private fun makeDelimitedForced(decoder: ProtobufReader, parentTag: ProtoDesc): ProtobufReader {
547 val tagless = parentTag == MISSING_TAG
548 val input = if (tagless) decoder.objectTaglessInput() else decoder.objectInput()
549 return ProtobufReader(input)
550 }
551