/*
 * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
@file:OptIn(ExperimentalSerializationApi::class)

package kotlinx.serialization.protobuf.internal

import kotlinx.serialization.*
import kotlinx.serialization.protobuf.*
import kotlin.jvm.*

/**
 * Low-level reader for the protobuf wire format on top of a [ByteArrayInput].
 *
 * [readTag] decodes a tag header into [currentId] (field number) and [currentType] (wire type);
 * the `read*` methods then decode the value that follows. Methods without the `NoTag` suffix first
 * assert that [currentType] matches the wire type the value requires. A single tag can be pushed
 * back with [pushBackTag] so that the next [readTag] yields it again.
 */
internal class ProtobufReader(private val input: ByteArrayInput) {
    /** Field number decoded from the last tag header, or `-1` when the header was `-1`. */
    @JvmField
    public var currentId = -1

    /** Wire type decoded from the last tag header, or [ProtoWireType.INVALID] when the header was `-1`. */
    @JvmField
    public var currentType = ProtoWireType.INVALID

    // When true, the next readTag() restores pushBackHeader instead of consuming input.
    private var pushBack = false

    // Raw header ((id shl 3) or wireType) swapped in and out by readTag()/pushBackTag().
    private var pushBackHeader = 0

    /** `true` when no tag is pushed back and the underlying input has no bytes left. */
    public val eof
        get() = !pushBack && input.availableBytes == 0

    /**
     * Reads the next tag header and updates [currentId] / [currentType].
     *
     * If a tag was pushed back via [pushBackTag], that tag is restored instead of reading input.
     *
     * @return the decoded field number, or `-1` when the header is `-1`
     *   (NOTE(review): presumably the end-of-input value of `readVarint64(true)` — confirm).
     */
    public fun readTag(): Int {
        if (pushBack) {
            pushBack = false
            val previousHeader = (currentId shl 3) or currentType.typeId
            return updateIdAndType(pushBackHeader).also {
                pushBackHeader = previousHeader
            }
        }
        // Header to use when pushed back is the old id/type
        pushBackHeader = (currentId shl 3) or currentType.typeId

        val header = input.readVarint64(true).toInt()
        return updateIdAndType(header)
    }

    /** Splits [header] into field number (upper bits) and wire type (low 3 bits); `-1` resets both. */
    private fun updateIdAndType(header: Int): Int {
        return if (header == -1) {
            currentId = -1
            currentType = ProtoWireType.INVALID
            -1
        } else {
            currentId = header ushr 3
            currentType = ProtoWireType.from(header and 0b111)
            currentId
        }
    }

    /** Undoes the last [readTag]: restores the previous id/type and makes the next [readTag] return the current one. */
    public fun pushBackTag() {
        pushBack = true

        val nextHeader = (currentId shl 3) or currentType.typeId
        updateIdAndType(pushBackHeader)
        pushBackHeader = nextHeader
    }

    /**
     * Consumes and discards the value belonging to the current tag, based on [currentType].
     *
     * @throws ProtobufDecodingException for any wire type other than VARINT, i64, SIZE_DELIMITED or i32.
     */
    fun skipElement() {
        when (currentType) {
            ProtoWireType.VARINT -> readInt(ProtoIntegerType.DEFAULT)
            ProtoWireType.i64 -> readLong(ProtoIntegerType.FIXED)
            ProtoWireType.SIZE_DELIMITED -> skipSizeDelimited()
            ProtoWireType.i32 -> readInt(ProtoIntegerType.FIXED)
            else -> throw ProtobufDecodingException("Unsupported start group or end group wire type: $currentType")
        }
    }

    /** Throws [ProtobufDecodingException] unless [currentType] equals [expected]. */
    @Suppress("NOTHING_TO_INLINE")
    private inline fun assertWireType(expected: ProtoWireType) {
        if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found $currentType")
    }

    /** Reads a length-prefixed byte array for a SIZE_DELIMITED field. */
    fun readByteArray(): ByteArray {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return readByteArrayNoTag()
    }

    /** Skips over a length-prefixed payload for a SIZE_DELIMITED field. */
    fun skipSizeDelimited() {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        val length = decode32()
        checkLength(length)
        input.skipExactNBytes(length)
    }

    /** Reads a varint length followed by that many bytes, without checking the wire type. */
    fun readByteArrayNoTag(): ByteArray {
        val length = decode32()
        checkLength(length)
        return input.readExactNBytes(length)
    }

    /** Returns a slice of the input covering the nested message of a SIZE_DELIMITED field. */
    fun objectInput(): ByteArrayInput {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return objectTaglessInput()
    }

    /** Reads a varint length and returns a slice of that many bytes, without checking the wire type. */
    fun objectTaglessInput(): ByteArrayInput {
        val length = decode32()
        checkLength(length)
        return input.slice(length)
    }

    /** Reads a 32-bit integer; FIXED requires wire type i32, other formats require VARINT. */
    fun readInt(format: ProtoIntegerType): Int {
        val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i32 else ProtoWireType.VARINT
        assertWireType(wireType)
        return decode32(format)
    }

    /** Reads a plain (non-zigzag) varint truncated to [Int], without checking the wire type. */
    fun readInt32NoTag(): Int = decode32()

    /** Reads a 64-bit integer; FIXED requires wire type i64, other formats require VARINT. */
    fun readLong(format: ProtoIntegerType): Long {
        val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i64 else ProtoWireType.VARINT
        assertWireType(wireType)
        return decode64(format)
    }

    /** Reads a plain (non-zigzag) varint as [Long], without checking the wire type. */
    fun readLongNoTag(): Long = decode64(ProtoIntegerType.DEFAULT)

    /** Reads a little-endian IEEE-754 float for an i32 field. */
    fun readFloat(): Float {
        assertWireType(ProtoWireType.i32)
        return Float.fromBits(readIntLittleEndian())
    }

    /** Reads a little-endian IEEE-754 float without checking the wire type. */
    fun readFloatNoTag(): Float = Float.fromBits(readIntLittleEndian())

    /**
     * Reads 4 bytes as a little-endian [Int].
     *
     * @throws ProtobufDecodingException if fewer than 4 bytes remain.
     */
    private fun readIntLittleEndian(): Int {
        // TODO this could be optimized by extracting method to the IS
        checkFixedWidthAvailable(4)
        var result = 0
        for (i in 0..3) {
            val byte = input.read() and 0x000000FF
            result = result or (byte shl (i * 8))
        }
        return result
    }

    /**
     * Reads 8 bytes as a little-endian [Long].
     *
     * @throws ProtobufDecodingException if fewer than 8 bytes remain.
     */
    private fun readLongLittleEndian(): Long {
        // TODO this could be optimized by extracting method to the IS
        checkFixedWidthAvailable(8)
        var result = 0L
        for (i in 0..7) {
            val byte = (input.read() and 0x000000FF).toLong()
            result = result or (byte shl (i * 8))
        }
        return result
    }

    /**
     * Ensures [byteCount] bytes remain before a fixed-width value is read byte by byte.
     *
     * Without this check, whatever `input.read()` yields past the end would be masked with `0xFF`
     * and folded into the result as a data byte. A truncated fixed32/fixed64 would then decode
     * silently to a wrong value.
     * NOTE(review): assumes `read()` returns a sentinel (e.g. -1) rather than throwing at EOF.
     */
    private fun checkFixedWidthAvailable(byteCount: Int) {
        val available = input.availableBytes
        if (available < byteCount) {
            throw ProtobufDecodingException(
                "Unexpected EOF: expected $byteCount bytes of a fixed-width value, but only $available available"
            )
        }
    }

    /** Reads a little-endian IEEE-754 double for an i64 field. */
    fun readDouble(): Double {
        assertWireType(ProtoWireType.i64)
        return Double.fromBits(readLongLittleEndian())
    }

    /** Reads a little-endian IEEE-754 double without checking the wire type. */
    fun readDoubleNoTag(): Double {
        return Double.fromBits(readLongLittleEndian())
    }

    /** Reads a length-prefixed string for a SIZE_DELIMITED field. */
    fun readString(): String {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return readStringNoTag()
    }

    /** Reads a varint length followed by a string of that many bytes, without checking the wire type. */
    fun readStringNoTag(): String {
        val length = decode32()
        checkLength(length)
        return input.readString(length)
    }

    /** Rejects negative length prefixes (e.g. a varint that overflowed [Int]). */
    private fun checkLength(length: Int) {
        if (length < 0) {
            throw ProtobufDecodingException("Unexpected negative length: $length")
        }
    }

    /** Decodes a 32-bit value: plain varint (DEFAULT), zigzag varint (SIGNED) or 4-byte little-endian (FIXED). */
    private fun decode32(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Int = when (format) {
        ProtoIntegerType.DEFAULT -> input.readVarint64(false).toInt()
        ProtoIntegerType.SIGNED -> decodeSignedVarintInt(input)
        ProtoIntegerType.FIXED -> readIntLittleEndian()
    }

    /** Decodes a 64-bit value: plain varint (DEFAULT), zigzag varint (SIGNED) or 8-byte little-endian (FIXED). */
    private fun decode64(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Long = when (format) {
        ProtoIntegerType.DEFAULT -> input.readVarint64(false)
        ProtoIntegerType.SIGNED -> decodeSignedVarintLong(input)
        ProtoIntegerType.FIXED -> readLongLittleEndian()
    }

    /**
     * Reads a varint and reverses ZigZag encoding (`sint32`).
     *
     * ZigZag maps 0, -1, 1, -2, ... to 0, 1, 2, 3, ... The low bit carries the sign: shifting it out
     * (unsigned) gives the magnitude, and XOR with `-(raw and 1)` (all ones when odd, zero when even)
     * restores negatives. Full-range values such as `Int.MIN_VALUE` need no special casing.
     */
    private fun decodeSignedVarintInt(input: ByteArrayInput): Int {
        val raw = input.readVarint32()
        return (raw ushr 1) xor -(raw and 1)
    }

    /** Reads a varint and reverses ZigZag encoding (`sint64`); see [decodeSignedVarintInt]. */
    private fun decodeSignedVarintLong(input: ByteArrayInput): Long {
        val raw = input.readVarint64(false)
        return (raw ushr 1) xor -(raw and 1L)
    }
}