• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #region Copyright notice and license
2 // Protocol Buffers - Google's data interchange format
3 // Copyright 2008 Google Inc.  All rights reserved.
4 // https://developers.google.com/protocol-buffers/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 #endregion
32 
33 using System;
34 using System.Buffers;
35 using System.IO;
36 using Google.Protobuf.TestProtos;
37 using Proto2 = Google.Protobuf.TestProtos.Proto2;
38 using NUnit.Framework;
39 
40 namespace Google.Protobuf
41 {
42     public class CodedInputStreamTest
43     {
/// <summary>
/// Builds a byte array from integer arguments. Taking <c>int</c> lets callers write hex
/// literals such as <c>0xff</c> without an explicit cast on every element; each value is
/// narrowed to <see cref="byte"/> with a plain cast.
/// </summary>
/// <param name="bytesAsInts">The values to narrow, in order.</param>
/// <returns>A new array of the same length holding the narrowed values.</returns>
private static byte[] Bytes(params int[] bytesAsInts)
{
    var result = new byte[bytesAsInts.Length];
    int next = 0;
    foreach (int raw in bytesAsInts)
    {
        result[next++] = (byte) raw;
    }
    return result;
}
58 
/// <summary>
/// Decodes <paramref name="data"/> as a varint through every reader exposed to the tests and
/// checks both the decoded value and that the reader consumed exactly the supplied bytes.
/// The readers are:
/// CodedInputStream over the whole array;
/// ParseContext over a single-segment sequence and over a span;
/// CodedInputStream over a stream that yields 1, 2, 4, 8 or 16 bytes per read;
/// ParseContext over a sequence split into segments of those sizes;
/// the static Stream-based ReadRawVarint32 helper.
/// </summary>
/// <param name="data">The encoded varint.</param>
/// <param name="value">
/// The expected decoded value. The 32-bit readers are compared against its low 32 bits.
/// </param>
private static void AssertReadVarint(byte[] data, ulong value)
{
    CodedInputStream input = new CodedInputStream(data);
    Assert.AreEqual((uint) value, input.ReadRawVarint32());
    Assert.IsTrue(input.IsAtEnd);

    input = new CodedInputStream(data);
    Assert.AreEqual(value, input.ReadRawVarint64());
    Assert.IsTrue(input.IsAtEnd);

    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), (ref ParseContext ctx) =>
    {
        Assert.AreEqual((uint) value, ctx.ReadUInt32());
    }, true);

    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), (ref ParseContext ctx) =>
    {
        Assert.AreEqual(value, ctx.ReadUInt64());
    }, true);

    // Try different block sizes.
    for (int bufferSize = 1; bufferSize <= 16; bufferSize *= 2)
    {
        input = new CodedInputStream(new SmallBlockInputStream(data, bufferSize));
        Assert.AreEqual((uint) value, input.ReadRawVarint32());
        // The 32-bit reader must consume the whole encoding, including any upper bytes it
        // discards, even when the input arrives in small chunks. The whole-array read
        // above already checks this.
        Assert.IsTrue(input.IsAtEnd);

        input = new CodedInputStream(new SmallBlockInputStream(data, bufferSize));
        Assert.AreEqual(value, input.ReadRawVarint64());
        Assert.IsTrue(input.IsAtEnd);

        AssertReadFromParseContext(ReadOnlySequenceFactory.CreateWithContent(data, bufferSize), (ref ParseContext ctx) =>
        {
            Assert.AreEqual((uint) value, ctx.ReadUInt32());
        }, true);

        AssertReadFromParseContext(ReadOnlySequenceFactory.CreateWithContent(data, bufferSize), (ref ParseContext ctx) =>
        {
            Assert.AreEqual(value, ctx.ReadUInt64());
        }, true);
    }

    // Try reading directly from a MemoryStream. We want to verify that it
    // doesn't read past the end of the input, so write an extra byte - this
    // lets us test the position at the end.
    MemoryStream memoryStream = new MemoryStream();
    memoryStream.Write(data, 0, data.Length);
    memoryStream.WriteByte(0);
    memoryStream.Position = 0;
    Assert.AreEqual((uint) value, CodedInputStream.ReadRawVarint32(memoryStream));
    Assert.AreEqual(data.Length, memoryStream.Position);
}
113 
/// <summary>
/// Feeds <paramref name="data"/> to every varint reader and asserts that each one throws an
/// <see cref="InvalidProtocolBufferException"/> whose message equals that of
/// <paramref name="expected"/>. The readers are CodedInputStream (32- and 64-bit), ParseContext
/// (32- and 64-bit, over both a sequence and a span) and the static Stream-based
/// ReadRawVarint32 helper.
/// </summary>
/// <param name="expected">An exception whose message is the expected failure text.</param>
/// <param name="data">The malformed or truncated encoding.</param>
private static void AssertReadVarintFailure(InvalidProtocolBufferException expected, byte[] data)
{
    // CodedInputStream over the raw array: 32-bit first, then 64-bit, with a fresh reader each time.
    Action<CodedInputStream>[] codedReads =
    {
        cis => cis.ReadRawVarint32(),
        cis => cis.ReadRawVarint64(),
    };
    foreach (Action<CodedInputStream> read in codedReads)
    {
        var reader = new CodedInputStream(data);
        var thrown = Assert.Throws<InvalidProtocolBufferException>(() => read(reader));
        Assert.AreEqual(expected.Message, thrown.Message);
    }

    // The ParseContext is a by-ref parameter and cannot be captured by Assert.Throws, so use
    // try/catch. Assert.Fail's AssertionException is not caught by the handler.
    ParseContextAssertAction expectUInt32Failure = (ref ParseContext ctx) =>
    {
        try
        {
            ctx.ReadUInt32();
            Assert.Fail();
        }
        catch (InvalidProtocolBufferException failure)
        {
            Assert.AreEqual(expected.Message, failure.Message);
        }
    };
    ParseContextAssertAction expectUInt64Failure = (ref ParseContext ctx) =>
    {
        try
        {
            ctx.ReadUInt64();
            Assert.Fail();
        }
        catch (InvalidProtocolBufferException failure)
        {
            Assert.AreEqual(expected.Message, failure.Message);
        }
    };
    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), expectUInt32Failure, false);
    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), expectUInt64Failure, false);

    // The static Stream-based helper must report the same error.
    var streamFailure = Assert.Throws<InvalidProtocolBufferException>(() => CodedInputStream.ReadRawVarint32(new MemoryStream(data)));
    Assert.AreEqual(expected.Message, streamFailure.Message);
}
159 
/// <summary>
/// Callback that reads from a <see cref="ParseContext"/> and asserts on the result. The
/// context is passed by <c>ref</c> so the callback advances the caller's context in place,
/// which lets the caller inspect the final buffer/state afterwards.
/// </summary>
private delegate void ParseContextAssertAction(ref ParseContext ctx);

/// <summary>
/// Runs <paramref name="assertAction"/> twice. The first run uses a ParseContext initialized
/// from the (possibly multi-segment) <paramref name="input"/> sequence. The second uses a
/// ParseContext over a flattened single-span copy of the same bytes.
/// </summary>
/// <param name="input">The bytes to parse.</param>
/// <param name="assertAction">Reads from the context and asserts on what it reads.</param>
/// <param name="assertIsAtEnd">When true, also asserts after each run that all input was consumed.</param>
private static void AssertReadFromParseContext(ReadOnlySequence<byte> input, ParseContextAssertAction assertAction, bool assertIsAtEnd)
{
    // Pass 1: the sequence as given.
    ParseContext.Initialize(input, out ParseContext sequenceContext);
    RunAndVerify(ref sequenceContext);

    // Pass 2: the same bytes copied into one contiguous span.
    ParseContext.Initialize(input.ToArray().AsSpan(), out ParseContext spanContext);
    RunAndVerify(ref spanContext);

    void RunAndVerify(ref ParseContext ctx)
    {
        assertAction(ref ctx);
        if (assertIsAtEnd)
        {
            Assert.IsTrue(SegmentedBufferHelper.IsAtEnd(ref ctx.buffer, ref ctx.state));
        }
    }
}
180 
[Test]
public void ReadVarint()
{
    // One-byte encodings: values below 128 are stored as-is.
    AssertReadVarint(Bytes(0x00), 0UL);
    AssertReadVarint(Bytes(0x01), 1UL);
    AssertReadVarint(Bytes(0x7f), 127UL);

    // Multi-byte encodings. Each byte carries 7 payload bits, least significant group first,
    // and its high bit is set while more bytes follow.
    // 0x22 | 0x74 << 7
    AssertReadVarint(Bytes(0xa2, 0x74), 14882UL);
    // 0x3e | 0x77 << 7 | 0x12 << 14 | 0x04 << 21 | 0x0b << 28
    AssertReadVarint(Bytes(0xbe, 0xf7, 0x92, 0x84, 0x0b), 2961488830UL);

    // Values wider than 32 bits.
    // 0x3e | 0x77 << 7 | 0x12 << 14 | 0x04 << 21 | 0x1b << 28
    AssertReadVarint(Bytes(0xbe, 0xf7, 0x92, 0x84, 0x1b), 7256456126UL);
    // 0x00 | 0x66 << 7 | 0x6b << 14 | 0x1c << 21 | 0x43 << 28 | 0x49 << 35 | 0x24 << 42 | 0x49 << 49
    AssertReadVarint(Bytes(0x80, 0xe6, 0xeb, 0x9c, 0xc3, 0xc9, 0xa4, 0x49), 41256202580718336UL);
    // 0x1b | 0x28 << 7 | 0x79 << 14 | 0x42 << 21 | 0x3b << 28 | 0x56 << 35 | 0x00 << 42
    //      | 0x05 << 49 | 0x26 << 56 | 0x01 << 63
    AssertReadVarint(Bytes(0x9b, 0xa8, 0xf9, 0xc2, 0xbb, 0xd6, 0x80, 0x85, 0xa6, 0x01), 11964378330978735131UL);

    // Failures: ten continuation bytes before the terminator is too long to be a varint...
    AssertReadVarintFailure(
        InvalidProtocolBufferException.MalformedVarint(),
        Bytes(0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x00));
    // ...and a continuation byte with nothing after it is truncated input.
    AssertReadVarintFailure(
        InvalidProtocolBufferException.TruncatedMessage(),
        Bytes(0x80));
}
218 
/// <summary>
/// Checks that <paramref name="data"/> decodes to <paramref name="value"/> as a 32-bit
/// little-endian integer. It uses CodedInputStream (whole array and chunked stream) and
/// ParseContext (single segment and chunked sequence), and each reader must end exactly at
/// the end of the input.
/// </summary>
/// <param name="data">The four encoded bytes, least significant first.</param>
/// <param name="value">The expected value.</param>
private static void AssertReadLittleEndian32(byte[] data, uint value)
{
    var wholeInput = new CodedInputStream(data);
    Assert.AreEqual(value, wholeInput.ReadRawLittleEndian32());
    Assert.IsTrue(wholeInput.IsAtEnd);

    ParseContextAssertAction readFixed32 = (ref ParseContext ctx) => Assert.AreEqual(value, ctx.ReadFixed32());
    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), readFixed32, true);

    // Deliver the bytes in chunks of 1, 2, 4, 8 and 16.
    for (int chunk = 1; chunk <= 16; chunk *= 2)
    {
        var chunkedInput = new CodedInputStream(new SmallBlockInputStream(data, chunk));
        Assert.AreEqual(value, chunkedInput.ReadRawLittleEndian32());
        Assert.IsTrue(chunkedInput.IsAtEnd);

        AssertReadFromParseContext(ReadOnlySequenceFactory.CreateWithContent(data, chunk), readFixed32, true);
    }
}
248 
/// <summary>
/// Checks that <paramref name="data"/> decodes to <paramref name="value"/> as a 64-bit
/// little-endian integer. It uses CodedInputStream (whole array and chunked stream) and
/// ParseContext (single segment and chunked sequence), and each reader must end exactly at
/// the end of the input.
/// </summary>
/// <param name="data">The eight encoded bytes, least significant first.</param>
/// <param name="value">The expected value.</param>
private static void AssertReadLittleEndian64(byte[] data, ulong value)
{
    var wholeInput = new CodedInputStream(data);
    Assert.AreEqual(value, wholeInput.ReadRawLittleEndian64());
    Assert.IsTrue(wholeInput.IsAtEnd);

    ParseContextAssertAction readFixed64 = (ref ParseContext ctx) => Assert.AreEqual(value, ctx.ReadFixed64());
    AssertReadFromParseContext(new ReadOnlySequence<byte>(data), readFixed64, true);

    // Deliver the bytes in chunks of 1, 2, 4, 8 and 16.
    for (int chunk = 1; chunk <= 16; chunk *= 2)
    {
        var chunkedInput = new CodedInputStream(new SmallBlockInputStream(data, chunk));
        Assert.AreEqual(value, chunkedInput.ReadRawLittleEndian64());
        Assert.IsTrue(chunkedInput.IsAtEnd);

        AssertReadFromParseContext(ReadOnlySequenceFactory.CreateWithContent(data, chunk), readFixed64, true);
    }
}
278 
[Test]
public void ReadLittleEndian()
{
    // Fixed-width values are stored least significant byte first.
    AssertReadLittleEndian32(Bytes(0x78, 0x56, 0x34, 0x12), 0x12345678U);
    AssertReadLittleEndian32(Bytes(0xf0, 0xde, 0xbc, 0x9a), 0x9abcdef0U);

    AssertReadLittleEndian64(Bytes(0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12), 0x123456789abcdef0UL);
    AssertReadLittleEndian64(Bytes(0x78, 0x56, 0x34, 0x12, 0xf0, 0xde, 0xbc, 0x9a), 0x9abcdef012345678UL);
}
290 
[Test]
public void DecodeZigZag32()
{
    // ZigZag interleaves signs: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    void Check(int expected, uint encoded) => Assert.AreEqual(expected, ParsingPrimitives.DecodeZigZag32(encoded));

    Check(0, 0);
    Check(-1, 1);
    Check(1, 2);
    Check(-2, 3);
    Check(0x3FFFFFFF, 0x7FFFFFFE);
    Check(unchecked((int) 0xC0000000), 0x7FFFFFFF);
    Check(0x7FFFFFFF, 0xFFFFFFFE);
    Check(unchecked((int) 0x80000000), 0xFFFFFFFF);
}
303 
[Test]
public void DecodeZigZag64()
{
    // Same interleaving as the 32-bit variant, over 64-bit values.
    void Check(long expected, ulong encoded) => Assert.AreEqual(expected, ParsingPrimitives.DecodeZigZag64(encoded));

    Check(0, 0);
    Check(-1, 1);
    Check(1, 2);
    Check(-2, 3);
    Check(0x000000003FFFFFFFL, 0x000000007FFFFFFEUL);
    Check(unchecked((long) 0xFFFFFFFFC0000000UL), 0x000000007FFFFFFFUL);
    Check(0x000000007FFFFFFFL, 0x00000000FFFFFFFEUL);
    Check(unchecked((long) 0xFFFFFFFF80000000UL), 0x00000000FFFFFFFFUL);
    Check(0x7FFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFEUL);
    Check(unchecked((long) 0x8000000000000000UL), 0xFFFFFFFFFFFFFFFFUL);
}
318 
[Test]
public void ReadWholeMessage_VaryingBlockSizes()
{
    TestAllTypes original = SampleMessages.CreateFullTestAllTypes();
    byte[] serialized = original.ToByteArray();

    // Sanity checks: the computed size matches the serializer, and an in-memory parse round-trips.
    Assert.AreEqual(serialized.Length, original.CalculateSize());
    Assert.AreEqual(original, TestAllTypes.Parser.ParseFrom(serialized));

    // Re-parse from a stream that returns at most 1, 2, 4, ..., 128 bytes per read.
    for (int blockSize = 1; blockSize < 256; blockSize <<= 1)
    {
        TestAllTypes reparsed = TestAllTypes.Parser.ParseFrom(new SmallBlockInputStream(serialized, blockSize));
        Assert.AreEqual(original, reparsed);
    }
}
336 
[Test]
public void ReadWholeMessage_VaryingBlockSizes_FromSequence()
{
    TestAllTypes original = SampleMessages.CreateFullTestAllTypes();
    byte[] serialized = original.ToByteArray();

    // Sanity checks: the computed size matches the serializer, and an in-memory parse round-trips.
    Assert.AreEqual(serialized.Length, original.CalculateSize());
    Assert.AreEqual(original, TestAllTypes.Parser.ParseFrom(serialized));

    // Re-parse from a ReadOnlySequence split into segments of 1, 2, 4, ..., 128 bytes.
    for (int blockSize = 1; blockSize < 256; blockSize <<= 1)
    {
        TestAllTypes reparsed = TestAllTypes.Parser.ParseFrom(ReadOnlySequenceFactory.CreateWithContent(serialized, blockSize));
        Assert.AreEqual(original, reparsed);
    }
}
354 
[Test]
public void ReadInt32Wrapper_VariableBlockSizes()
{
    // Tag bytes 202, 1 (varint 202: field 25, wire type 2), length 11, then the wrapper body:
    // tag 8 (field 1, varint) followed by the ten-byte varint encoding of -2.
    byte[] rawBytes = { 202, 1, 11, 8, 254, 255, 255, 255, 255, 255, 255, 255, 255, 1 };

    // Split the input at every possible segment size, from one byte up to the whole payload.
    int blockSize = 1;
    while (blockSize <= rawBytes.Length)
    {
        ReadOnlySequence<byte> segmented = ReadOnlySequenceFactory.CreateWithContent(rawBytes, blockSize);
        AssertReadFromParseContext(segmented, (ref ParseContext ctx) =>
        {
            ctx.ReadTag();
            Assert.AreEqual(-2, ParsingPrimitivesWrappers.ReadInt32Wrapper(ref ctx));
        }, true);
        blockSize++;
    }
}
373 
[Test]
public void ReadHugeBlob()
{
    // Allocate and initialize a 1MB blob with a repeating 0..255 byte pattern.
    byte[] blob = new byte[1 << 20];
    for (int i = 0; i < blob.Length; i++)
    {
        blob[i] = (byte) i;
    }

    // Make a message containing it.
    var message = new TestAllTypes { SingleBytes = ByteString.CopyFrom(blob) };

    // Parse directly from the serialized ByteString.
    TestAllTypes fromByteString = TestAllTypes.Parser.ParseFrom(message.ToByteString());
    Assert.AreEqual(message, fromByteString);

    // Also parse from a Stream, so CodedInputStream reads the payload incrementally through
    // its buffered path rather than from one contiguous in-memory buffer. The previous version
    // only described this case in a comment and never exercised it.
    TestAllTypes fromStream = TestAllTypes.Parser.ParseFrom(new MemoryStream(message.ToByteArray()));
    Assert.AreEqual(message, fromStream);
}
394 
[Test]
public void ReadMaliciouslyLargeBlob()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint expectedTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A length-delimited field that claims 0x7FFFFFFF bytes but is followed by only
    // 32 zero bytes of padding.
    writer.WriteRawVarint32(expectedTag);
    writer.WriteRawVarint32(0x7FFFFFFF);
    writer.WriteRawBytes(new byte[32]);
    writer.Flush();
    buffer.Position = 0;

    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(expectedTag, reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.ReadBytes());
}
413 
[Test]
public void ReadBlobGreaterThanCurrentLimit()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint tag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A 4-byte length-delimited field (zero-filled payload).
    writer.WriteRawVarint32(tag);
    writer.WriteRawVarint32(4);
    writer.WriteRawBytes(new byte[4]);
    writer.Flush();
    buffer.Position = 0;

    // Stream-backed reader: pushing a limit of 3 (smaller than the declared length) must make ReadBytes fail.
    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(tag, reader.ReadTag());
    reader.PushLimit(3);
    Assert.Throws<InvalidProtocolBufferException>(() => reader.ReadBytes());

    // Same scenario through ParseContext.
    AssertReadFromParseContext(new ReadOnlySequence<byte>(buffer.ToArray()), (ref ParseContext ctx) =>
    {
        Assert.AreEqual(tag, ctx.ReadTag());
        SegmentedBufferHelper.PushLimit(ref ctx.state, 3);
        try
        {
            ctx.ReadBytes();
            Assert.Fail();
        }
        catch (InvalidProtocolBufferException)
        {
            // Expected: the declared length exceeds the pushed limit.
        }
    }, true);
}
445 
[Test]
public void ReadStringGreaterThanCurrentLimit()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint tag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A 4-byte length-delimited field (zero-filled payload).
    writer.WriteRawVarint32(tag);
    writer.WriteRawVarint32(4);
    writer.WriteRawBytes(new byte[4]);
    writer.Flush();
    byte[] payload = buffer.ToArray();

    // Array-backed reader: pushing a limit of 3 (smaller than the declared length) must make ReadString fail.
    var reader = new CodedInputStream(payload);
    Assert.AreEqual(tag, reader.ReadTag());
    reader.PushLimit(3);
    Assert.Throws<InvalidProtocolBufferException>(() => reader.ReadString());

    // Same scenario through ParseContext.
    AssertReadFromParseContext(new ReadOnlySequence<byte>(payload), (ref ParseContext ctx) =>
    {
        Assert.AreEqual(tag, ctx.ReadTag());
        SegmentedBufferHelper.PushLimit(ref ctx.state, 3);
        try
        {
            ctx.ReadString();
            Assert.Fail();
        }
        catch (InvalidProtocolBufferException)
        {
            // Expected: the declared length exceeds the pushed limit.
        }
    }, true);
}
477 
// Single-byte tags with field number 0, one per wire type 0..5 (tag = fieldNumber << 3 | wireType).
[Test]
[TestCase(0), TestCase(1), TestCase(2), TestCase(3), TestCase(4), TestCase(5)]
public void ReadTag_ZeroFieldRejected(byte tag)
{
    var input = new CodedInputStream(new[] { tag });
    Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
}
491 
/// <summary>
/// Builds a chain of <paramref name="depth"/> messages nested through field <c>A</c>. The
/// innermost message has <c>I = 5</c> and no <c>A</c>. <paramref name="depth"/> must be
/// non-negative, because a negative value never reaches the base case.
/// </summary>
internal static TestRecursiveMessage MakeRecursiveMessage(int depth) =>
    depth == 0
        ? new TestRecursiveMessage { I = 5 }
        : new TestRecursiveMessage { A = MakeRecursiveMessage(depth - 1) };
503 
/// <summary>
/// Asserts that <paramref name="message"/> has exactly the shape produced by
/// <c>MakeRecursiveMessage(depth)</c>: <paramref name="depth"/> levels of <c>A</c>, ending in
/// a message with no <c>A</c> and <c>I == 5</c>.
/// </summary>
internal static void AssertMessageDepth(TestRecursiveMessage message, int depth)
{
    if (depth != 0)
    {
        Assert.IsNotNull(message.A);
        AssertMessageDepth(message.A, depth - 1);
        return;
    }

    Assert.IsNull(message.A);
    Assert.AreEqual(5, message.I);
}
517 
[Test]
public void MaliciousRecursion()
{
    int limit = CodedInputStream.DefaultRecursionLimit;
    ByteString atLimit = MakeRecursiveMessage(limit).ToByteString();
    ByteString beyondLimit = MakeRecursiveMessage(limit + 1).ToByteString();

    // Nesting exactly at the default limit parses, and the full depth is preserved.
    AssertMessageDepth(TestRecursiveMessage.Parser.ParseFrom(atLimit), limit);

    // One level deeper is rejected.
    Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(beyondLimit));

    // Lowering the recursion limit by one makes the at-limit payload fail as well.
    CodedInputStream restricted = CodedInputStream.CreateWithLimits(new MemoryStream(atLimit.ToByteArray()), 1000000, limit - 1);
    Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(restricted));
}
531 
/// <summary>
/// Produces <paramref name="recursionDepth"/> nested StartGroup tags followed by the same number
/// of matching EndGroup tags, all for one field number, with no other content.
/// </summary>
private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
    // An unused field number, so the groups are parsed as unknown fields.
    int unknownFieldNumber = 14;
    uint openTag = WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup);
    uint closeTag = WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup);

    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    for (int level = 0; level < recursionDepth; level++)
    {
        writer.WriteTag(openTag);
    }
    for (int level = 0; level < recursionDepth; level++)
    {
        writer.WriteTag(closeTag);
    }
    writer.Flush();
    return buffer.ToArray();
}
549 
[Test]
public void MaliciousRecursion_UnknownFields()
{
    int limit = CodedInputStream.DefaultRecursionLimit;
    byte[] atLimit = MakeMaliciousRecursionUnknownFieldsPayload(limit);
    byte[] beyondLimit = MakeMaliciousRecursionUnknownFieldsPayload(limit + 1);

    // Unknown-field groups count against the recursion limit just like known submessages.
    Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(atLimit));
    Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(beyondLimit));
}
559 
[Test]
public void ReadGroup_WrongEndGroupTag()
{
    int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;

    // Serialize a Proto2.TestAllTypes with optional_group set, then close the group with an
    // EndGroup tag for the next field number instead of its own.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
    writer.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
    writer.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
    writer.Flush();

    byte[] payload = buffer.ToArray();
    Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}
577 
[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
    // Open a group on field 14 and close it with an EndGroup tag for field 15.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
    writer.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
    writer.Flush();

    byte[] payload = buffer.ToArray();
    Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}
591 
[Test]
public void SizeLimit()
{
    // The size limit applies to Stream input but not to ByteString.CreateCodedInput,
    // so the serialized message is wrapped in a MemoryStream.
    byte[] serialized = SampleMessages.CreateFullTestAllTypes().ToByteArray();
    var input = CodedInputStream.CreateWithLimits(new MemoryStream(serialized), 16, 100);
    Assert.Throws<InvalidProtocolBufferException>(() => TestAllTypes.Parser.ParseFrom(input));
}
601 
/// <summary>
/// Reading a string that contains invalid UTF-8 must not throw. Instead, the invalid bytes
/// are replaced with the Unicode replacement character U+FFFD.
/// </summary>
[Test]
public void ReadInvalidUtf8()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint stringTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A one-byte string whose only byte is 0x80, a continuation byte with no lead byte.
    writer.WriteRawVarint32(stringTag);
    writer.WriteRawVarint32(1);
    writer.WriteRawBytes(new byte[] { 0x80 });
    writer.Flush();
    buffer.Position = 0;

    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(stringTag, reader.ReadTag());
    string decoded = reader.ReadString();
    Assert.AreEqual('\ufffd', decoded[0]);
}
626 
[Test]
public void ReadNegativeSizedStringThrowsInvalidProtocolBufferException()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint stringTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A length-delimited field whose length prefix is -1.
    writer.WriteRawVarint32(stringTag);
    writer.WriteLength(-1);
    writer.Flush();
    buffer.Position = 0;

    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(stringTag, reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.ReadString());
}
644 
[Test]
public void ReadNegativeSizedBytesThrowsInvalidProtocolBufferException()
{
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    uint bytesTag = WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited);

    // A length-delimited field whose length prefix is -1.
    writer.WriteRawVarint32(bytesTag);
    writer.WriteLength(-1);
    writer.Flush();
    buffer.Position = 0;

    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(bytesTag, reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.ReadBytes());
}
662 
/// <summary>
/// A MemoryStream that hands out at most <c>blockSize</c> bytes per Read call.
/// Used to check that CodedInputStream behaves correctly when its input
/// arrives in small chunks.
/// </summary>
private sealed class SmallBlockInputStream : MemoryStream
{
    // Upper bound on the number of bytes returned by a single Read call.
    private readonly int maxBytesPerRead;

    /// <param name="data">The complete content of the stream.</param>
    /// <param name="blockSize">Maximum number of bytes served per Read call.</param>
    public SmallBlockInputStream(byte[] data, int blockSize) : base(data)
    {
        maxBytesPerRead = blockSize;
    }

    /// <summary>Reads like MemoryStream, but caps the request at the block size.</summary>
    public override int Read(byte[] buffer, int offset, int count)
    {
        int cappedCount = count < maxBytesPerRead ? count : maxBytesPerRead;
        return base.Read(buffer, offset, cappedCount);
    }
}
683 
[Test]
public void TestNegativeEnum()
{
    // These bytes are the ten-byte varint encoding of -2 (sign-extended to 64 bits).
    var input = new CodedInputStream(new byte[] { 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01 });
    int expected = (int) SampleEnum.NegativeValue;
    Assert.AreEqual(expected, input.ReadEnum());
    // All ten bytes must have been consumed.
    Assert.IsTrue(input.IsAtEnd);
}
692 
// Regression test for issue 71: CodedInputStream.ReadBytes went to the slow path unnecessarily.
[Test]
public void TestSlowPathAvoidance()
{
    using (var ms = new MemoryStream())
    {
        // Two length-delimited fields (1 and 2), each carrying 100 zero bytes.
        var output = new CodedOutputStream(ms);
        for (int field = 1; field <= 2; field++)
        {
            output.WriteTag(field, WireFormat.WireType.LengthDelimited);
            output.WriteBytes(ByteString.CopyFrom(new byte[100]));
        }
        output.Flush();

        ms.Position = 0;
        // Decode through a buffer only half the size of the encoded data, so the
        // whole payload cannot be held in the reader's buffer at once.
        var input = new CodedInputStream(ms, new byte[ms.Length / 2], 0, 0, false);
        for (int field = 1; field <= 2; field++)
        {
            uint tag = input.ReadTag();
            Assert.AreEqual(field, WireFormat.GetTagFieldNumber(tag));
            Assert.AreEqual(100, input.ReadBytes().Length);
        }
    }
}
718 
[Test]
public void MaximumFieldNumber()
{
    // 0x1FFFFFFF (2^29 - 1) is the largest field number: a tag packs it with 3 wire-type bits.
    const int fieldNumber = 0x1FFFFFFF;
    var ms = new MemoryStream();
    var output = new CodedOutputStream(ms);
    uint tag = WireFormat.MakeTag(fieldNumber, WireFormat.WireType.LengthDelimited);
    output.WriteRawVarint32(tag);
    output.WriteString("field 1");
    output.Flush();
    ms.Position = 0;

    var input = new CodedInputStream(ms);
    // The tag must survive the round trip, and the field number must be recoverable from it.
    Assert.AreEqual(tag, input.ReadTag());
    Assert.AreEqual(fieldNumber, WireFormat.GetTagFieldNumber(tag));
}
737 
[Test]
public void Tag0Throws()
{
    // A single zero byte decodes as tag 0, which ReadTag must reject.
    var input = new CodedInputStream(new byte[1]);
    Assert.Throws<InvalidProtocolBufferException>(() => input.ReadTag());
}
744 
[Test]
public void SkipGroup()
{
    // Layout written below:
    //   field 1: string "field 1"
    //   field 2: group containing
    //     field 1: fixed32 value 100
    //     field 2: string "ignore me"
    //     field 3: nested group containing
    //       field 1: fixed64 value 1000
    //   field 3: string "field 3"
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);

    writer.WriteTag(1, WireFormat.WireType.LengthDelimited);
    writer.WriteString("field 1");

    // Open the outer group (field 2) and write its scalar members.
    writer.WriteTag(2, WireFormat.WireType.StartGroup);
    writer.WriteTag(1, WireFormat.WireType.Fixed32);
    writer.WriteFixed32(100);
    writer.WriteTag(2, WireFormat.WireType.LengthDelimited);
    writer.WriteString("ignore me");

    // Nested group (field 3), closed with an end-group tag carrying the same field number.
    writer.WriteTag(3, WireFormat.WireType.StartGroup);
    writer.WriteTag(1, WireFormat.WireType.Fixed64);
    writer.WriteFixed64(1000);
    writer.WriteTag(3, WireFormat.WireType.EndGroup);

    // Close the outer group.
    writer.WriteTag(2, WireFormat.WireType.EndGroup);

    writer.WriteTag(3, WireFormat.WireType.LengthDelimited);
    writer.WriteString("field 3");
    writer.Flush();
    buffer.Position = 0;

    // Read it back as generated code would, skipping the group via SkipLastField.
    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), reader.ReadTag());
    Assert.AreEqual("field 1", reader.ReadString());
    Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), reader.ReadTag());
    // Must consume the whole outer group, nested group included.
    reader.SkipLastField();
    Assert.AreEqual(WireFormat.MakeTag(3, WireFormat.WireType.LengthDelimited), reader.ReadTag());
    Assert.AreEqual("field 3", reader.ReadString());
}
791 
[Test]
public void SkipGroup_WrongEndGroupTag()
{
    // Layout: field 1 string "field 1", then a group opened as field 2
    // (containing a fixed32 field 3) but closed as field 4.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(1, WireFormat.WireType.LengthDelimited);
    writer.WriteString("field 1");

    writer.WriteTag(2, WireFormat.WireType.StartGroup);
    writer.WriteTag(3, WireFormat.WireType.Fixed32);
    writer.WriteFixed32(100);
    // Mismatched end-group field number: skipping the group must fail.
    writer.WriteTag(4, WireFormat.WireType.EndGroup);
    writer.Flush();
    buffer.Position = 0;

    // Read it back as generated code would.
    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.LengthDelimited), reader.ReadTag());
    Assert.AreEqual("field 1", reader.ReadString());
    Assert.AreEqual(WireFormat.MakeTag(2, WireFormat.WireType.StartGroup), reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.SkipLastField());
}
820 
[Test]
public void RogueEndGroupTag()
{
    // Generated code reacts to an end-group tag without a preceding start-group tag
    // by calling SkipLastField, so SkipLastField itself has to reject it.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(1, WireFormat.WireType.EndGroup);
    writer.Flush();
    buffer.Position = 0;

    var reader = new CodedInputStream(buffer);
    uint expectedTag = WireFormat.MakeTag(1, WireFormat.WireType.EndGroup);
    Assert.AreEqual(expectedTag, reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.SkipLastField());
}
837 
[Test]
public void EndOfStreamReachedWhileSkippingGroup()
{
    // Group 1 is opened and contains a complete group 2, but group 1 is never closed.
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    writer.WriteTag(1, WireFormat.WireType.StartGroup);
    writer.WriteTag(2, WireFormat.WireType.StartGroup);
    writer.WriteTag(2, WireFormat.WireType.EndGroup);
    writer.Flush();
    buffer.Position = 0;

    // Consume the outer start-group tag as generated code would, then try to skip
    // the unterminated group; hitting end-of-stream must be reported as an error.
    var reader = new CodedInputStream(buffer);
    reader.ReadTag();
    Assert.Throws<InvalidProtocolBufferException>(() => reader.SkipLastField());
}
855 
[Test]
public void RecursionLimitAppliedWhileSkippingGroup()
{
    // Nest groups one level deeper than the default recursion limit.
    int depth = CodedInputStream.DefaultRecursionLimit + 1;
    var buffer = new MemoryStream();
    var writer = new CodedOutputStream(buffer);
    for (int level = 0; level < depth; level++)
    {
        writer.WriteTag(1, WireFormat.WireType.StartGroup);
    }
    for (int level = 0; level < depth; level++)
    {
        writer.WriteTag(1, WireFormat.WireType.EndGroup);
    }
    writer.Flush();
    buffer.Position = 0;

    // Read the first start-group tag as generated code would; skipping must hit the limit.
    var reader = new CodedInputStream(buffer);
    Assert.AreEqual(WireFormat.MakeTag(1, WireFormat.WireType.StartGroup), reader.ReadTag());
    Assert.Throws<InvalidProtocolBufferException>(() => reader.SkipLastField());
}
877 
[Test]
public void Construction_Invalid()
{
    // Null sources are rejected.
    Assert.Throws<ArgumentNullException>(() => new CodedInputStream((byte[]) null));
    Assert.Throws<ArgumentNullException>(() => new CodedInputStream(null, 0, 0));
    Assert.Throws<ArgumentNullException>(() => new CodedInputStream((Stream) null));

    // Offset/length pairs that fall outside a 10-byte array are rejected.
    var tenBytes = new byte[10];
    Assert.Throws<ArgumentOutOfRangeException>(() => new CodedInputStream(tenBytes, 100, 0));
    Assert.Throws<ArgumentOutOfRangeException>(() => new CodedInputStream(tenBytes, 5, 10));
}
887 
[Test]
public void CreateWithLimits_InvalidLimits()
{
    var stream = new MemoryStream();
    // A zero value in either limit argument must be rejected.
    Assert.Throws<ArgumentOutOfRangeException>(() => CodedInputStream.CreateWithLimits(stream, 0, 1));
    Assert.Throws<ArgumentOutOfRangeException>(() => CodedInputStream.CreateWithLimits(stream, 1, 0));
}
895 
[Test]
public void Dispose_DisposesUnderlyingStream()
{
    var memoryStream = new MemoryStream();
    Assert.IsTrue(memoryStream.CanRead);

    // Disposing the CodedInputStream without leaveOpen should close the wrapped stream.
    using (new CodedInputStream(memoryStream))
    {
    }

    // A disposed MemoryStream reports CanRead == false.
    Assert.IsFalse(memoryStream.CanRead);
}
906 
[Test]
public void Dispose_WithLeaveOpen()
{
    var memoryStream = new MemoryStream();
    Assert.IsTrue(memoryStream.CanRead);

    // leaveOpen: true — disposing the CodedInputStream must not close the wrapped stream.
    using (new CodedInputStream(memoryStream, true))
    {
    }

    Assert.IsTrue(memoryStream.CanRead);
}
917 
[Test]
public void Dispose_FromByteArray()
{
    // Disposing a byte-array-backed instance (no underlying Stream) must simply not throw.
    new CodedInputStream(new byte[10]).Dispose();
}
924 
[Test]
public void TestParseMessagesCloseTo2G()
{
    byte[] chunk = GenerateBigSerializedMessage();
    // Largest number of back-to-back copies whose total size still fits in Int32.MaxValue bytes.
    int repetitions = int.MaxValue / chunk.Length;
    // Replay the same buffer that many times, simulating almost 2GB of input
    // without actually allocating it.
    using (var stream = new RepeatingMemoryStream(chunk, repetitions))
    {
        Assert.DoesNotThrow(() => TestAllTypes.Parser.ParseFrom(stream));
    }
}
938 
[Test]
public void TestParseMessagesOver2G()
{
    byte[] chunk = GenerateBigSerializedMessage();
    // One copy more than fits in Int32.MaxValue bytes, so the total input exceeds the 2GB limit.
    int repetitions = int.MaxValue / chunk.Length + 1;
    using (var stream = new RepeatingMemoryStream(chunk, repetitions))
    {
        // NOTE: NUnit uses this string as the assertion-failure message; it does not
        // compare it against the thrown exception's Message.
        Assert.Throws<InvalidProtocolBufferException>(
            () => TestAllTypes.Parser.ParseFrom(stream),
            "Protocol message was too large.  May be malicious.  " +
            "Use CodedInputStream.SetSizeLimit() to increase the size limit.");
    }
}
956 
/// <summary>
/// Builds a fully populated TestAllTypes whose SingleBytes field holds 16MB of zero bytes.
/// </summary>
/// <returns>The serialized form of that message.</returns>
private static byte[] GenerateBigSerializedMessage()
{
    const int payloadSize = 16 * 1024 * 1024;
    var message = SampleMessages.CreateFullTestAllTypes();
    message.SingleBytes = ByteString.CopyFrom(new byte[payloadSize]);
    return message.ToByteArray();
}
965 
/// <summary>
/// A MemoryStream that serves the content of a byte array repeatedly, a fixed number of times.
/// Simulates really large input without consuming loads of memory. Used above to test the
/// parsing behavior when the input size reaches or exceeds 2GB.
/// </summary>
private class RepeatingMemoryStream : MemoryStream
{
    private readonly byte[] bytes;
    private readonly int maxIterations;
    // Number of complete passes over 'bytes' served so far.
    private int index = 0;

    /// <param name="bytes">The content to repeat.</param>
    /// <param name="maxIterations">How many times the content is served before end-of-stream.</param>
    public RepeatingMemoryStream(byte[] bytes, int maxIterations)
    {
        this.bytes = bytes;
        this.maxIterations = maxIterations;
    }

    /// <summary>
    /// Copies up to <paramref name="count"/> bytes into <paramref name="buffer"/> starting at
    /// <paramref name="offset"/>, wrapping back to the start of the content until every
    /// iteration has been served.
    /// </summary>
    /// <returns>The number of bytes copied; 0 once all iterations are exhausted.</returns>
    public override int Read(byte[] buffer, int offset, int count)
    {
        if (bytes.Length == 0)
        {
            return 0;
        }
        // The base MemoryStream is constructed empty; its Position is reused purely
        // as the cursor into 'bytes'.
        int copied = 0;
        while (copied < count && index < maxIterations)
        {
            // Never copy past the end of the content or past the space the caller asked for.
            // (Previously 'count' was decremented by the running total rather than the chunk,
            // which caused spurious short reads.)
            int chunk = Math.Min(bytes.Length - (int)Position, count - copied);
            Array.Copy(bytes, (int)Position, buffer, offset + copied, chunk);
            copied += chunk;
            Position += chunk;
            if (Position >= bytes.Length)
            {
                Position = 0;
                index++;
            }
        }
        return copied;
    }
}
1007     }
1008 }
1009