• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // http://code.google.com/p/protobuf/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 package com.google.protobuf;
32 
33 import java.io.IOException;
34 import java.io.InputStream;
35 import java.util.ArrayList;
36 import java.util.List;
37 
38 /**
39  * Reads and decodes protocol message fields.
40  *
41  * This class contains two kinds of methods:  methods that read specific
42  * protocol message constructs and field types (e.g. {@link #readTag()} and
43  * {@link #readInt32()}) and methods that read low-level values (e.g.
44  * {@link #readRawVarint32()} and {@link #readRawBytes}).  If you are reading
45  * encoded protocol messages, you should use the former methods, but if you are
46  * reading some other format of your own design, use the latter.
47  *
48  * @author kenton@google.com Kenton Varda
49  */
50 public final class CodedInputStream {
51   /**
52    * Create a new CodedInputStream wrapping the given InputStream.
53    */
newInstance(final InputStream input)54   public static CodedInputStream newInstance(final InputStream input) {
55     return new CodedInputStream(input);
56   }
57 
58   /**
59    * Create a new CodedInputStream wrapping the given byte array.
60    */
newInstance(final byte[] buf)61   public static CodedInputStream newInstance(final byte[] buf) {
62     return newInstance(buf, 0, buf.length);
63   }
64 
65   /**
66    * Create a new CodedInputStream wrapping the given byte array slice.
67    */
newInstance(final byte[] buf, final int off, final int len)68   public static CodedInputStream newInstance(final byte[] buf, final int off,
69                                              final int len) {
70     return new CodedInputStream(buf, off, len);
71   }
72 
73   // -----------------------------------------------------------------
74 
75   /**
76    * Attempt to read a field tag, returning zero if we have reached EOF.
77    * Protocol message parsers use this to read tags, since a protocol message
78    * may legally end wherever a tag occurs, and zero is not a valid tag number.
79    */
readTag()80   public int readTag() throws IOException {
81     if (isAtEnd()) {
82       lastTag = 0;
83       return 0;
84     }
85 
86     lastTag = readRawVarint32();
87     if (WireFormat.getTagFieldNumber(lastTag) == 0) {
88       // If we actually read zero (or any tag number corresponding to field
89       // number zero), that's not a valid tag.
90       throw InvalidProtocolBufferException.invalidTag();
91     }
92     return lastTag;
93   }
94 
95   /**
96    * Verifies that the last call to readTag() returned the given tag value.
97    * This is used to verify that a nested group ended with the correct
98    * end tag.
99    *
100    * @throws InvalidProtocolBufferException {@code value} does not match the
101    *                                        last tag.
102    */
checkLastTagWas(final int value)103   public void checkLastTagWas(final int value)
104                               throws InvalidProtocolBufferException {
105     if (lastTag != value) {
106       throw InvalidProtocolBufferException.invalidEndTag();
107     }
108   }
109 
110   /**
111    * Reads and discards a single field, given its tag value.
112    *
113    * @return {@code false} if the tag is an endgroup tag, in which case
114    *         nothing is skipped.  Otherwise, returns {@code true}.
115    */
skipField(final int tag)116   public boolean skipField(final int tag) throws IOException {
117     switch (WireFormat.getTagWireType(tag)) {
118       case WireFormat.WIRETYPE_VARINT:
119         readInt32();
120         return true;
121       case WireFormat.WIRETYPE_FIXED64:
122         readRawLittleEndian64();
123         return true;
124       case WireFormat.WIRETYPE_LENGTH_DELIMITED:
125         skipRawBytes(readRawVarint32());
126         return true;
127       case WireFormat.WIRETYPE_START_GROUP:
128         skipMessage();
129         checkLastTagWas(
130           WireFormat.makeTag(WireFormat.getTagFieldNumber(tag),
131                              WireFormat.WIRETYPE_END_GROUP));
132         return true;
133       case WireFormat.WIRETYPE_END_GROUP:
134         return false;
135       case WireFormat.WIRETYPE_FIXED32:
136         readRawLittleEndian32();
137         return true;
138       default:
139         throw InvalidProtocolBufferException.invalidWireType();
140     }
141   }
142 
143   /**
144    * Reads and discards an entire message.  This will read either until EOF
145    * or until an endgroup tag, whichever comes first.
146    */
skipMessage()147   public void skipMessage() throws IOException {
148     while (true) {
149       final int tag = readTag();
150       if (tag == 0 || !skipField(tag)) {
151         return;
152       }
153     }
154   }
155 
156   // -----------------------------------------------------------------
157 
158   /** Read a {@code double} field value from the stream. */
readDouble()159   public double readDouble() throws IOException {
160     return Double.longBitsToDouble(readRawLittleEndian64());
161   }
162 
163   /** Read a {@code float} field value from the stream. */
readFloat()164   public float readFloat() throws IOException {
165     return Float.intBitsToFloat(readRawLittleEndian32());
166   }
167 
168   /** Read a {@code uint64} field value from the stream. */
readUInt64()169   public long readUInt64() throws IOException {
170     return readRawVarint64();
171   }
172 
173   /** Read an {@code int64} field value from the stream. */
readInt64()174   public long readInt64() throws IOException {
175     return readRawVarint64();
176   }
177 
178   /** Read an {@code int32} field value from the stream. */
readInt32()179   public int readInt32() throws IOException {
180     return readRawVarint32();
181   }
182 
183   /** Read a {@code fixed64} field value from the stream. */
readFixed64()184   public long readFixed64() throws IOException {
185     return readRawLittleEndian64();
186   }
187 
188   /** Read a {@code fixed32} field value from the stream. */
readFixed32()189   public int readFixed32() throws IOException {
190     return readRawLittleEndian32();
191   }
192 
193   /** Read a {@code bool} field value from the stream. */
readBool()194   public boolean readBool() throws IOException {
195     return readRawVarint32() != 0;
196   }
197 
198   /** Read a {@code string} field value from the stream. */
readString()199   public String readString() throws IOException {
200     final int size = readRawVarint32();
201     if (size <= (bufferSize - bufferPos) && size > 0) {
202       // Fast path:  We already have the bytes in a contiguous buffer, so
203       //   just copy directly from it.
204       final String result = new String(buffer, bufferPos, size, "UTF-8");
205       bufferPos += size;
206       return result;
207     } else {
208       // Slow path:  Build a byte array first then copy it.
209       return new String(readRawBytes(size), "UTF-8");
210     }
211   }
212 
213   /** Read a {@code group} field value from the stream. */
readGroup(final int fieldNumber, final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry)214   public void readGroup(final int fieldNumber,
215                         final MessageLite.Builder builder,
216                         final ExtensionRegistryLite extensionRegistry)
217       throws IOException {
218     if (recursionDepth >= recursionLimit) {
219       throw InvalidProtocolBufferException.recursionLimitExceeded();
220     }
221     ++recursionDepth;
222     builder.mergeFrom(this, extensionRegistry);
223     checkLastTagWas(
224       WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP));
225     --recursionDepth;
226   }
227 
228   /**
229    * Reads a {@code group} field value from the stream and merges it into the
230    * given {@link UnknownFieldSet}.
231    *
232    * @deprecated UnknownFieldSet.Builder now implements MessageLite.Builder, so
233    *             you can just call {@link #readGroup}.
234    */
235   @Deprecated
readUnknownGroup(final int fieldNumber, final MessageLite.Builder builder)236   public void readUnknownGroup(final int fieldNumber,
237                                final MessageLite.Builder builder)
238       throws IOException {
239     // We know that UnknownFieldSet will ignore any ExtensionRegistry so it
240     // is safe to pass null here.  (We can't call
241     // ExtensionRegistry.getEmptyRegistry() because that would make this
242     // class depend on ExtensionRegistry, which is not part of the lite
243     // library.)
244     readGroup(fieldNumber, builder, null);
245   }
246 
247   /** Read an embedded message field value from the stream. */
readMessage(final MessageLite.Builder builder, final ExtensionRegistryLite extensionRegistry)248   public void readMessage(final MessageLite.Builder builder,
249                           final ExtensionRegistryLite extensionRegistry)
250       throws IOException {
251     final int length = readRawVarint32();
252     if (recursionDepth >= recursionLimit) {
253       throw InvalidProtocolBufferException.recursionLimitExceeded();
254     }
255     final int oldLimit = pushLimit(length);
256     ++recursionDepth;
257     builder.mergeFrom(this, extensionRegistry);
258     checkLastTagWas(0);
259     --recursionDepth;
260     popLimit(oldLimit);
261   }
262 
263   /** Read a {@code bytes} field value from the stream. */
readBytes()264   public ByteString readBytes() throws IOException {
265     final int size = readRawVarint32();
266     if (size <= (bufferSize - bufferPos) && size > 0) {
267       // Fast path:  We already have the bytes in a contiguous buffer, so
268       //   just copy directly from it.
269       final ByteString result = ByteString.copyFrom(buffer, bufferPos, size);
270       bufferPos += size;
271       return result;
272     } else {
273       // Slow path:  Build a byte array first then copy it.
274       return ByteString.copyFrom(readRawBytes(size));
275     }
276   }
277 
278   /** Read a {@code uint32} field value from the stream. */
readUInt32()279   public int readUInt32() throws IOException {
280     return readRawVarint32();
281   }
282 
283   /**
284    * Read an enum field value from the stream.  Caller is responsible
285    * for converting the numeric value to an actual enum.
286    */
readEnum()287   public int readEnum() throws IOException {
288     return readRawVarint32();
289   }
290 
291   /** Read an {@code sfixed32} field value from the stream. */
readSFixed32()292   public int readSFixed32() throws IOException {
293     return readRawLittleEndian32();
294   }
295 
296   /** Read an {@code sfixed64} field value from the stream. */
readSFixed64()297   public long readSFixed64() throws IOException {
298     return readRawLittleEndian64();
299   }
300 
301   /** Read an {@code sint32} field value from the stream. */
readSInt32()302   public int readSInt32() throws IOException {
303     return decodeZigZag32(readRawVarint32());
304   }
305 
306   /** Read an {@code sint64} field value from the stream. */
readSInt64()307   public long readSInt64() throws IOException {
308     return decodeZigZag64(readRawVarint64());
309   }
310 
311   // =================================================================
312 
313   /**
314    * Read a raw Varint from the stream.  If larger than 32 bits, discard the
315    * upper bits.
316    */
readRawVarint32()317   public int readRawVarint32() throws IOException {
318     byte tmp = readRawByte();
319     if (tmp >= 0) {
320       return tmp;
321     }
322     int result = tmp & 0x7f;
323     if ((tmp = readRawByte()) >= 0) {
324       result |= tmp << 7;
325     } else {
326       result |= (tmp & 0x7f) << 7;
327       if ((tmp = readRawByte()) >= 0) {
328         result |= tmp << 14;
329       } else {
330         result |= (tmp & 0x7f) << 14;
331         if ((tmp = readRawByte()) >= 0) {
332           result |= tmp << 21;
333         } else {
334           result |= (tmp & 0x7f) << 21;
335           result |= (tmp = readRawByte()) << 28;
336           if (tmp < 0) {
337             // Discard upper 32 bits.
338             for (int i = 0; i < 5; i++) {
339               if (readRawByte() >= 0) {
340                 return result;
341               }
342             }
343             throw InvalidProtocolBufferException.malformedVarint();
344           }
345         }
346       }
347     }
348     return result;
349   }
350 
351   /**
352    * Reads a varint from the input one byte at a time, so that it does not
353    * read any bytes after the end of the varint.  If you simply wrapped the
354    * stream in a CodedInputStream and used {@link #readRawVarint32(InputStream)}
355    * then you would probably end up reading past the end of the varint since
356    * CodedInputStream buffers its input.
357    */
readRawVarint32(final InputStream input)358   static int readRawVarint32(final InputStream input) throws IOException {
359     final int firstByte = input.read();
360     if (firstByte == -1) {
361       throw InvalidProtocolBufferException.truncatedMessage();
362     }
363     return readRawVarint32(firstByte, input);
364   }
365 
366   /**
367    * Like {@link #readRawVarint32(InputStream)}, but expects that the caller
368    * has already read one byte.  This allows the caller to determine if EOF
369    * has been reached before attempting to read.
370    */
readRawVarint32(final int firstByte, final InputStream input)371   static int readRawVarint32(final int firstByte,
372                              final InputStream input) throws IOException {
373     if ((firstByte & 0x80) == 0) {
374       return firstByte;
375     }
376 
377     int result = firstByte & 0x7f;
378     int offset = 7;
379     for (; offset < 32; offset += 7) {
380       final int b = input.read();
381       if (b == -1) {
382         throw InvalidProtocolBufferException.truncatedMessage();
383       }
384       result |= (b & 0x7f) << offset;
385       if ((b & 0x80) == 0) {
386         return result;
387       }
388     }
389     // Keep reading up to 64 bits.
390     for (; offset < 64; offset += 7) {
391       final int b = input.read();
392       if (b == -1) {
393         throw InvalidProtocolBufferException.truncatedMessage();
394       }
395       if ((b & 0x80) == 0) {
396         return result;
397       }
398     }
399     throw InvalidProtocolBufferException.malformedVarint();
400   }
401 
402   /** Read a raw Varint from the stream. */
readRawVarint64()403   public long readRawVarint64() throws IOException {
404     int shift = 0;
405     long result = 0;
406     while (shift < 64) {
407       final byte b = readRawByte();
408       result |= (long)(b & 0x7F) << shift;
409       if ((b & 0x80) == 0) {
410         return result;
411       }
412       shift += 7;
413     }
414     throw InvalidProtocolBufferException.malformedVarint();
415   }
416 
417   /** Read a 32-bit little-endian integer from the stream. */
readRawLittleEndian32()418   public int readRawLittleEndian32() throws IOException {
419     final byte b1 = readRawByte();
420     final byte b2 = readRawByte();
421     final byte b3 = readRawByte();
422     final byte b4 = readRawByte();
423     return (((int)b1 & 0xff)      ) |
424            (((int)b2 & 0xff) <<  8) |
425            (((int)b3 & 0xff) << 16) |
426            (((int)b4 & 0xff) << 24);
427   }
428 
429   /** Read a 64-bit little-endian integer from the stream. */
readRawLittleEndian64()430   public long readRawLittleEndian64() throws IOException {
431     final byte b1 = readRawByte();
432     final byte b2 = readRawByte();
433     final byte b3 = readRawByte();
434     final byte b4 = readRawByte();
435     final byte b5 = readRawByte();
436     final byte b6 = readRawByte();
437     final byte b7 = readRawByte();
438     final byte b8 = readRawByte();
439     return (((long)b1 & 0xff)      ) |
440            (((long)b2 & 0xff) <<  8) |
441            (((long)b3 & 0xff) << 16) |
442            (((long)b4 & 0xff) << 24) |
443            (((long)b5 & 0xff) << 32) |
444            (((long)b6 & 0xff) << 40) |
445            (((long)b7 & 0xff) << 48) |
446            (((long)b8 & 0xff) << 56);
447   }
448 
449   /**
450    * Decode a ZigZag-encoded 32-bit value.  ZigZag encodes signed integers
451    * into values that can be efficiently encoded with varint.  (Otherwise,
452    * negative values must be sign-extended to 64 bits to be varint encoded,
453    * thus always taking 10 bytes on the wire.)
454    *
455    * @param n An unsigned 32-bit integer, stored in a signed int because
456    *          Java has no explicit unsigned support.
457    * @return A signed 32-bit integer.
458    */
decodeZigZag32(final int n)459   public static int decodeZigZag32(final int n) {
460     return (n >>> 1) ^ -(n & 1);
461   }
462 
463   /**
464    * Decode a ZigZag-encoded 64-bit value.  ZigZag encodes signed integers
465    * into values that can be efficiently encoded with varint.  (Otherwise,
466    * negative values must be sign-extended to 64 bits to be varint encoded,
467    * thus always taking 10 bytes on the wire.)
468    *
469    * @param n An unsigned 64-bit integer, stored in a signed int because
470    *          Java has no explicit unsigned support.
471    * @return A signed 64-bit integer.
472    */
decodeZigZag64(final long n)473   public static long decodeZigZag64(final long n) {
474     return (n >>> 1) ^ -(n & 1);
475   }
476 
477   // -----------------------------------------------------------------
478 
479   private final byte[] buffer;
480   private int bufferSize;
481   private int bufferSizeAfterLimit;
482   private int bufferPos;
483   private final InputStream input;
484   private int lastTag;
485 
486   /**
487    * The total number of bytes read before the current buffer.  The total
488    * bytes read up to the current position can be computed as
489    * {@code totalBytesRetired + bufferPos}.  This value may be negative if
490    * reading started in the middle of the current buffer (e.g. if the
491    * constructor that takes a byte array and an offset was used).
492    */
493   private int totalBytesRetired;
494 
495   /** The absolute position of the end of the current message. */
496   private int currentLimit = Integer.MAX_VALUE;
497 
498   /** See setRecursionLimit() */
499   private int recursionDepth;
500   private int recursionLimit = DEFAULT_RECURSION_LIMIT;
501 
502   /** See setSizeLimit() */
503   private int sizeLimit = DEFAULT_SIZE_LIMIT;
504 
505   private static final int DEFAULT_RECURSION_LIMIT = 64;
506   private static final int DEFAULT_SIZE_LIMIT = 64 << 20;  // 64MB
507   private static final int BUFFER_SIZE = 4096;
508 
/**
 * Builds a decoder over the slice of {@code buffer} that starts at
 * {@code off} and is {@code len} bytes long.  There is no backing stream,
 * so the buffer can never be refilled.
 */
private CodedInputStream(final byte[] buffer, final int off, final int len) {
  this.input = null;
  this.buffer = buffer;
  this.bufferPos = off;
  this.bufferSize = off + len;
  // Negative offset so that totalBytesRetired + bufferPos starts at zero.
  this.totalBytesRetired = -off;
}
516 
/**
 * Builds a decoder that pulls its bytes from {@code input} through an
 * internal buffer of {@code BUFFER_SIZE} bytes, which starts out empty.
 */
private CodedInputStream(final InputStream input) {
  this.input = input;
  this.buffer = new byte[BUFFER_SIZE];
  this.bufferPos = 0;
  this.bufferSize = 0;
  this.totalBytesRetired = 0;
}
524 
525   /**
526    * Set the maximum message recursion depth.  In order to prevent malicious
527    * messages from causing stack overflows, {@code CodedInputStream} limits
528    * how deeply messages may be nested.  The default limit is 64.
529    *
530    * @return the old limit.
531    */
setRecursionLimit(final int limit)532   public int setRecursionLimit(final int limit) {
533     if (limit < 0) {
534       throw new IllegalArgumentException(
535         "Recursion limit cannot be negative: " + limit);
536     }
537     final int oldLimit = recursionLimit;
538     recursionLimit = limit;
539     return oldLimit;
540   }
541 
542   /**
543    * Set the maximum message size.  In order to prevent malicious
544    * messages from exhausting memory or causing integer overflows,
545    * {@code CodedInputStream} limits how large a message may be.
546    * The default limit is 64MB.  You should set this limit as small
547    * as you can without harming your app's functionality.  Note that
548    * size limits only apply when reading from an {@code InputStream}, not
549    * when constructed around a raw byte array (nor with
550    * {@link ByteString#newCodedInput}).
551    * <p>
552    * If you want to read several messages from a single CodedInputStream, you
553    * could call {@link #resetSizeCounter()} after each one to avoid hitting the
554    * size limit.
555    *
556    * @return the old limit.
557    */
setSizeLimit(final int limit)558   public int setSizeLimit(final int limit) {
559     if (limit < 0) {
560       throw new IllegalArgumentException(
561         "Size limit cannot be negative: " + limit);
562     }
563     final int oldLimit = sizeLimit;
564     sizeLimit = limit;
565     return oldLimit;
566   }
567 
568   /**
569    * Resets the current size counter to zero (see {@link #setSizeLimit(int)}).
570    */
resetSizeCounter()571   public void resetSizeCounter() {
572     totalBytesRetired = -bufferPos;
573   }
574 
575   /**
576    * Sets {@code currentLimit} to (current position) + {@code byteLimit}.  This
577    * is called when descending into a length-delimited embedded message.
578    *
579    * <p>Note that {@code pushLimit()} does NOT affect how many bytes the
580    * {@code CodedInputStream} reads from an underlying {@code InputStream} when
581    * refreshing its buffer.  If you need to prevent reading past a certain
582    * point in the underlying {@code InputStream} (e.g. because you expect it to
583    * contain more data after the end of the message which you need to handle
584    * differently) then you must place a wrapper around you {@code InputStream}
585    * which limits the amount of data that can be read from it.
586    *
587    * @return the old limit.
588    */
pushLimit(int byteLimit)589   public int pushLimit(int byteLimit) throws InvalidProtocolBufferException {
590     if (byteLimit < 0) {
591       throw InvalidProtocolBufferException.negativeSize();
592     }
593     byteLimit += totalBytesRetired + bufferPos;
594     final int oldLimit = currentLimit;
595     if (byteLimit > oldLimit) {
596       throw InvalidProtocolBufferException.truncatedMessage();
597     }
598     currentLimit = byteLimit;
599 
600     recomputeBufferSizeAfterLimit();
601 
602     return oldLimit;
603   }
604 
recomputeBufferSizeAfterLimit()605   private void recomputeBufferSizeAfterLimit() {
606     bufferSize += bufferSizeAfterLimit;
607     final int bufferEnd = totalBytesRetired + bufferSize;
608     if (bufferEnd > currentLimit) {
609       // Limit is in current buffer.
610       bufferSizeAfterLimit = bufferEnd - currentLimit;
611       bufferSize -= bufferSizeAfterLimit;
612     } else {
613       bufferSizeAfterLimit = 0;
614     }
615   }
616 
617   /**
618    * Discards the current limit, returning to the previous limit.
619    *
620    * @param oldLimit The old limit, as returned by {@code pushLimit}.
621    */
popLimit(final int oldLimit)622   public void popLimit(final int oldLimit) {
623     currentLimit = oldLimit;
624     recomputeBufferSizeAfterLimit();
625   }
626 
627   /**
628    * Returns the number of bytes to be read before the current limit.
629    * If no limit is set, returns -1.
630    */
getBytesUntilLimit()631   public int getBytesUntilLimit() {
632     if (currentLimit == Integer.MAX_VALUE) {
633       return -1;
634     }
635 
636     final int currentAbsolutePosition = totalBytesRetired + bufferPos;
637     return currentLimit - currentAbsolutePosition;
638   }
639 
640   /**
641    * Returns true if the stream has reached the end of the input.  This is the
642    * case if either the end of the underlying input source has been reached or
643    * if the stream has reached a limit created using {@link #pushLimit(int)}.
644    */
isAtEnd()645   public boolean isAtEnd() throws IOException {
646     return bufferPos == bufferSize && !refillBuffer(false);
647   }
648 
649   /**
650    * The total bytes read up to the current position. Calling
651    * {@link #resetSizeCounter()} resets this value to zero.
652    */
getTotalBytesRead()653   public int getTotalBytesRead() {
654       return totalBytesRetired + bufferPos;
655   }
656 
657   /**
658    * Called with {@code this.buffer} is empty to read more bytes from the
659    * input.  If {@code mustSucceed} is true, refillBuffer() gurantees that
660    * either there will be at least one byte in the buffer when it returns
661    * or it will throw an exception.  If {@code mustSucceed} is false,
662    * refillBuffer() returns false if no more bytes were available.
663    */
refillBuffer(final boolean mustSucceed)664   private boolean refillBuffer(final boolean mustSucceed) throws IOException {
665     if (bufferPos < bufferSize) {
666       throw new IllegalStateException(
667         "refillBuffer() called when buffer wasn't empty.");
668     }
669 
670     if (totalBytesRetired + bufferSize == currentLimit) {
671       // Oops, we hit a limit.
672       if (mustSucceed) {
673         throw InvalidProtocolBufferException.truncatedMessage();
674       } else {
675         return false;
676       }
677     }
678 
679     totalBytesRetired += bufferSize;
680 
681     bufferPos = 0;
682     bufferSize = (input == null) ? -1 : input.read(buffer);
683     if (bufferSize == 0 || bufferSize < -1) {
684       throw new IllegalStateException(
685           "InputStream#read(byte[]) returned invalid result: " + bufferSize +
686           "\nThe InputStream implementation is buggy.");
687     }
688     if (bufferSize == -1) {
689       bufferSize = 0;
690       if (mustSucceed) {
691         throw InvalidProtocolBufferException.truncatedMessage();
692       } else {
693         return false;
694       }
695     } else {
696       recomputeBufferSizeAfterLimit();
697       final int totalBytesRead =
698         totalBytesRetired + bufferSize + bufferSizeAfterLimit;
699       if (totalBytesRead > sizeLimit || totalBytesRead < 0) {
700         throw InvalidProtocolBufferException.sizeLimitExceeded();
701       }
702       return true;
703     }
704   }
705 
706   /**
707    * Read one byte from the input.
708    *
709    * @throws InvalidProtocolBufferException The end of the stream or the current
710    *                                        limit was reached.
711    */
readRawByte()712   public byte readRawByte() throws IOException {
713     if (bufferPos == bufferSize) {
714       refillBuffer(true);
715     }
716     return buffer[bufferPos++];
717   }
718 
719   /**
720    * Read a fixed size of bytes from the input.
721    *
722    * @throws InvalidProtocolBufferException The end of the stream or the current
723    *                                        limit was reached.
724    */
readRawBytes(final int size)725   public byte[] readRawBytes(final int size) throws IOException {
726     if (size < 0) {
727       throw InvalidProtocolBufferException.negativeSize();
728     }
729 
730     if (totalBytesRetired + bufferPos + size > currentLimit) {
731       // Read to the end of the stream anyway.
732       skipRawBytes(currentLimit - totalBytesRetired - bufferPos);
733       // Then fail.
734       throw InvalidProtocolBufferException.truncatedMessage();
735     }
736 
737     if (size <= bufferSize - bufferPos) {
738       // We have all the bytes we need already.
739       final byte[] bytes = new byte[size];
740       System.arraycopy(buffer, bufferPos, bytes, 0, size);
741       bufferPos += size;
742       return bytes;
743     } else if (size < BUFFER_SIZE) {
744       // Reading more bytes than are in the buffer, but not an excessive number
745       // of bytes.  We can safely allocate the resulting array ahead of time.
746 
747       // First copy what we have.
748       final byte[] bytes = new byte[size];
749       int pos = bufferSize - bufferPos;
750       System.arraycopy(buffer, bufferPos, bytes, 0, pos);
751       bufferPos = bufferSize;
752 
753       // We want to use refillBuffer() and then copy from the buffer into our
754       // byte array rather than reading directly into our byte array because
755       // the input may be unbuffered.
756       refillBuffer(true);
757 
758       while (size - pos > bufferSize) {
759         System.arraycopy(buffer, 0, bytes, pos, bufferSize);
760         pos += bufferSize;
761         bufferPos = bufferSize;
762         refillBuffer(true);
763       }
764 
765       System.arraycopy(buffer, 0, bytes, pos, size - pos);
766       bufferPos = size - pos;
767 
768       return bytes;
769     } else {
770       // The size is very large.  For security reasons, we can't allocate the
771       // entire byte array yet.  The size comes directly from the input, so a
772       // maliciously-crafted message could provide a bogus very large size in
773       // order to trick the app into allocating a lot of memory.  We avoid this
774       // by allocating and reading only a small chunk at a time, so that the
775       // malicious message must actually *be* extremely large to cause
776       // problems.  Meanwhile, we limit the allowed size of a message elsewhere.
777 
778       // Remember the buffer markers since we'll have to copy the bytes out of
779       // it later.
780       final int originalBufferPos = bufferPos;
781       final int originalBufferSize = bufferSize;
782 
783       // Mark the current buffer consumed.
784       totalBytesRetired += bufferSize;
785       bufferPos = 0;
786       bufferSize = 0;
787 
788       // Read all the rest of the bytes we need.
789       int sizeLeft = size - (originalBufferSize - originalBufferPos);
790       final List<byte[]> chunks = new ArrayList<byte[]>();
791 
792       while (sizeLeft > 0) {
793         final byte[] chunk = new byte[Math.min(sizeLeft, BUFFER_SIZE)];
794         int pos = 0;
795         while (pos < chunk.length) {
796           final int n = (input == null) ? -1 :
797             input.read(chunk, pos, chunk.length - pos);
798           if (n == -1) {
799             throw InvalidProtocolBufferException.truncatedMessage();
800           }
801           totalBytesRetired += n;
802           pos += n;
803         }
804         sizeLeft -= chunk.length;
805         chunks.add(chunk);
806       }
807 
808       // OK, got everything.  Now concatenate it all into one buffer.
809       final byte[] bytes = new byte[size];
810 
811       // Start by copying the leftover bytes from this.buffer.
812       int pos = originalBufferSize - originalBufferPos;
813       System.arraycopy(buffer, originalBufferPos, bytes, 0, pos);
814 
815       // And now all the chunks.
816       for (final byte[] chunk : chunks) {
817         System.arraycopy(chunk, 0, bytes, pos, chunk.length);
818         pos += chunk.length;
819       }
820 
821       // Done.
822       return bytes;
823     }
824   }
825 
826   /**
827    * Reads and discards {@code size} bytes.
828    *
829    * @throws InvalidProtocolBufferException The end of the stream or the current
830    *                                        limit was reached.
831    */
skipRawBytes(final int size)832   public void skipRawBytes(final int size) throws IOException {
833     if (size < 0) {
834       throw InvalidProtocolBufferException.negativeSize();
835     }
836 
837     if (totalBytesRetired + bufferPos + size > currentLimit) {
838       // Read to the end of the stream anyway.
839       skipRawBytes(currentLimit - totalBytesRetired - bufferPos);
840       // Then fail.
841       throw InvalidProtocolBufferException.truncatedMessage();
842     }
843 
844     if (size <= bufferSize - bufferPos) {
845       // We have all the bytes we need already.
846       bufferPos += size;
847     } else {
848       // Skipping more bytes than are in the buffer.  First skip what we have.
849       int pos = bufferSize - bufferPos;
850       totalBytesRetired += bufferSize;
851       bufferPos = 0;
852       bufferSize = 0;
853 
854       // Then skip directly from the InputStream for the rest.
855       while (pos < size) {
856         final int n = (input == null) ? -1 : (int) input.skip(size - pos);
857         if (n <= 0) {
858           throw InvalidProtocolBufferException.truncatedMessage();
859         }
860         pos += n;
861         totalBytesRetired += n;
862       }
863     }
864   }
865 }
866