• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 @file:OptIn(ExperimentalSerializationApi::class)
5 
6 package kotlinx.serialization.protobuf.internal
7 
8 import kotlinx.serialization.*
9 import kotlinx.serialization.protobuf.*
10 import kotlin.jvm.*
11 
/**
 * Low-level reader for the protobuf wire format on top of a [ByteArrayInput].
 *
 * The reader remembers the tag that was read last ([currentId] / [currentType]) together with the
 * tag before it, so a single tag can be "un-read" with [pushBackTag].
 *
 * Methods whose name ends in `NoTag` (and [objectTaglessInput]) read a raw value without looking at
 * [currentType]. The remaining `read*` methods first assert that [currentType] matches the wire type
 * of the value being read, and throw [ProtobufDecodingException] otherwise.
 */
internal class ProtobufReader(private val input: ByteArrayInput) {
    /** Field number of the most recently read tag, or -1 when no tag is available. */
    @JvmField
    public var currentId = -1

    /** Wire type of the most recently read tag, or [ProtoWireType.INVALID] when no tag is available. */
    @JvmField
    public var currentType = ProtoWireType.INVALID

    // Set by pushBackTag(): the next readTag() replays pushBackHeader instead of consuming input.
    private var pushBack = false

    // Encoded header ((id shl 3) or typeId) that the next pushBackTag()/readTag() swap switches to.
    private var pushBackHeader = 0

    /** `true` when no tag is pushed back and [input] has no bytes left. */
    public val eof
        get() = !pushBack && input.availableBytes == 0

    /**
     * Reads the next field header and updates [currentId] and [currentType].
     *
     * If a tag was pushed back with [pushBackTag], that tag is replayed instead of reading from [input].
     *
     * @return the field number of the tag, or -1 when the decoded header is -1 (end of input).
     */
    public fun readTag(): Int {
        if (pushBack) {
            pushBack = false
            // Remember the tag being replaced so that it can be pushed back again.
            val previousHeader = (currentId shl 3) or currentType.typeId
            return updateIdAndType(pushBackHeader).also {
                pushBackHeader = previousHeader
            }
        }
        // Header to use when pushed back is the old id/type
        pushBackHeader = (currentId shl 3) or currentType.typeId

        // NOTE(review): the `true` argument presumably lets the input report end-of-data as -1
        // instead of throwing — confirm against ByteArrayInput.readVarint64.
        val header = input.readVarint64(true).toInt()
        return updateIdAndType(header)
    }

    /**
     * Splits an encoded [header] into [currentId] (upper bits) and [currentType] (lowest 3 bits).
     * A header of -1 resets both to their "no tag" values.
     *
     * @return the new [currentId].
     */
    private fun updateIdAndType(header: Int): Int {
        return if (header == -1) {
            currentId = -1
            currentType = ProtoWireType.INVALID
            -1
        } else {
            // ushr keeps large field numbers non-negative even if the header overflowed into the sign bit.
            currentId = header ushr 3
            currentType = ProtoWireType.from(header and 0b111)
            currentId
        }
    }

    /**
     * Un-reads the current tag. [currentId]/[currentType] revert to the previous tag, and the next
     * [readTag] call returns the current tag again without consuming input.
     */
    public fun pushBackTag() {
        pushBack = true

        val nextHeader = (currentId shl 3) or currentType.typeId
        updateIdAndType(pushBackHeader)
        pushBackHeader = nextHeader
    }

    /**
     * Consumes and discards the value of the current field according to [currentType].
     *
     * @throws ProtobufDecodingException for any wire type other than VARINT, i64, SIZE_DELIMITED or i32.
     */
    fun skipElement() {
        when (currentType) {
            ProtoWireType.VARINT -> readInt(ProtoIntegerType.DEFAULT)
            ProtoWireType.i64 -> readLong(ProtoIntegerType.FIXED)
            ProtoWireType.SIZE_DELIMITED -> skipSizeDelimited()
            ProtoWireType.i32 -> readInt(ProtoIntegerType.FIXED)
            else -> throw ProtobufDecodingException("Unsupported start group or end group wire type: $currentType")
        }
    }

    /** Throws [ProtobufDecodingException] unless the current field has the [expected] wire type. */
    @Suppress("NOTHING_TO_INLINE")
    private inline fun assertWireType(expected: ProtoWireType) {
        if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found $currentType")
    }

    /** Reads a length-prefixed byte array after checking that the current wire type is SIZE_DELIMITED. */
    fun readByteArray(): ByteArray {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return readByteArrayNoTag()
    }

    /** Skips a length-prefixed value after checking that the current wire type is SIZE_DELIMITED. */
    fun skipSizeDelimited() {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        val length = decode32()
        checkLength(length)
        input.skipExactNBytes(length)
    }

    /** Reads a varint length followed by that many bytes, without checking the wire type. */
    fun readByteArrayNoTag(): ByteArray {
        val length = decode32()
        checkLength(length)
        return input.readExactNBytes(length)
    }

    /**
     * Returns a sub-input covering the current length-delimited field, after checking that the
     * current wire type is SIZE_DELIMITED.
     */
    fun objectInput(): ByteArrayInput {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return objectTaglessInput()
    }

    /** Reads a varint length and returns `input.slice(length)`, without checking the wire type. */
    fun objectTaglessInput(): ByteArrayInput {
        val length = decode32()
        checkLength(length)
        return input.slice(length)
    }

    /**
     * Reads a 32-bit integer encoded as [format]. FIXED requires the i32 wire type; every other
     * format requires VARINT.
     */
    fun readInt(format: ProtoIntegerType): Int {
        val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i32 else ProtoWireType.VARINT
        assertWireType(wireType)
        return decode32(format)
    }

    /** Reads a plain varint as Int, without checking the wire type. */
    fun readInt32NoTag(): Int = decode32()

    /**
     * Reads a 64-bit integer encoded as [format]. FIXED requires the i64 wire type; every other
     * format requires VARINT.
     */
    fun readLong(format: ProtoIntegerType): Long {
        val wireType = if (format == ProtoIntegerType.FIXED) ProtoWireType.i64 else ProtoWireType.VARINT
        assertWireType(wireType)
        return decode64(format)
    }

    /** Reads a plain varint as Long, without checking the wire type. */
    fun readLongNoTag(): Long = decode64(ProtoIntegerType.DEFAULT)

    /** Reads a 4-byte little-endian IEEE-754 float after checking that the current wire type is i32. */
    fun readFloat(): Float {
        assertWireType(ProtoWireType.i32)
        return readFloatNoTag()
    }

    /** Reads a 4-byte little-endian IEEE-754 float, without checking the wire type. */
    fun readFloatNoTag(): Float = Float.fromBits(readIntLittleEndian())

    /**
     * Throws [ProtobufDecodingException] if fewer than [byteCount] bytes remain in [input].
     *
     * Fixed-width values are read byte by byte. Without this check, a truncated message could decode
     * a garbage value instead of failing.
     */
    private fun checkFixedWidthAvailable(byteCount: Int) {
        val available = input.availableBytes
        if (available < byteCount) {
            throw ProtobufDecodingException(
                "Unexpected EOF: expected $byteCount bytes of fixed-width value, but only $available available"
            )
        }
    }

    /** Reads 4 bytes as a little-endian Int. */
    private fun readIntLittleEndian(): Int {
        checkFixedWidthAvailable(4)
        // TODO this could be optimized by extracting method to the IS
        var result = 0
        for (i in 0..3) {
            val byte = input.read() and 0x000000FF
            result = result or (byte shl (i * 8))
        }
        return result
    }

    /** Reads 8 bytes as a little-endian Long. */
    private fun readLongLittleEndian(): Long {
        checkFixedWidthAvailable(8)
        // TODO this could be optimized by extracting method to the IS
        var result = 0L
        for (i in 0..7) {
            val byte = (input.read() and 0x000000FF).toLong()
            result = result or (byte shl (i * 8))
        }
        return result
    }

    /** Reads an 8-byte little-endian IEEE-754 double after checking that the current wire type is i64. */
    fun readDouble(): Double {
        assertWireType(ProtoWireType.i64)
        return readDoubleNoTag()
    }

    /** Reads an 8-byte little-endian IEEE-754 double, without checking the wire type. */
    fun readDoubleNoTag(): Double {
        return Double.fromBits(readLongLittleEndian())
    }

    /** Reads a length-prefixed string after checking that the current wire type is SIZE_DELIMITED. */
    fun readString(): String {
        assertWireType(ProtoWireType.SIZE_DELIMITED)
        return readStringNoTag()
    }

    /** Reads a varint length followed by a string of that many bytes, without checking the wire type. */
    fun readStringNoTag(): String {
        val length = decode32()
        checkLength(length)
        return input.readString(length)
    }

    /** Rejects negative lengths, which can come from corrupt or hostile length prefixes. */
    private fun checkLength(length: Int) {
        if (length < 0) {
            throw ProtobufDecodingException("Unexpected negative length: $length")
        }
    }

    /**
     * Decodes a 32-bit integer.
     *
     * DEFAULT reads a 64-bit varint and truncates it to Int. SIGNED reads a zigzag varint. FIXED reads
     * 4 little-endian bytes.
     */
    private fun decode32(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Int = when (format) {
        ProtoIntegerType.DEFAULT -> input.readVarint64(false).toInt()
        ProtoIntegerType.SIGNED -> decodeSignedVarintInt(input)
        ProtoIntegerType.FIXED -> readIntLittleEndian()
    }

    /**
     * Decodes a 64-bit integer.
     *
     * DEFAULT reads a plain varint. SIGNED reads a zigzag varint. FIXED reads 8 little-endian bytes.
     */
    private fun decode64(format: ProtoIntegerType = ProtoIntegerType.DEFAULT): Long = when (format) {
        ProtoIntegerType.DEFAULT -> input.readVarint64(false)
        ProtoIntegerType.SIGNED -> decodeSignedVarintLong(input)
        ProtoIntegerType.FIXED -> readLongLittleEndian()
    }

    /**
     *  Zigzag-decodes a 32-bit varint.
     *
     *  Source for all varint operations:
     *  https://github.com/addthis/stream-lib/blob/master/src/main/java/com/clearspring/analytics/util/Varint.java
     */
    private fun decodeSignedVarintInt(input: ByteArrayInput): Int {
        val raw = input.readVarint32()
        // Kotlin's named infix functions share one precedence level and associate left-to-right,
        // so the parentheses only make the original evaluation order explicit.
        val temp = (((raw shl 31) shr 31) xor raw) shr 1
        // This extra step lets us deal with the largest signed values by treating
        // negative results from read unsigned methods as like unsigned values.
        // Must re-flip the top bit if the original read value had it set.
        return temp xor (raw and (1 shl 31))
    }

    /** Zigzag-decodes a 64-bit varint. */
    private fun decodeSignedVarintLong(input: ByteArrayInput): Long {
        val raw = input.readVarint64(false)
        // Same left-to-right evaluation as the original infix chain, with explicit parentheses.
        val temp = (((raw shl 63) shr 63) xor raw) shr 1
        // This extra step lets us deal with the largest signed values by treating
        // negative results from read unsigned methods as like unsigned values
        // Must re-flip the top bit if the original read value had it set.
        return temp xor (raw and (1L shl 63))
    }
}
212