• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2015 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc.protobuf.lite;
18 
19 import static org.junit.Assert.assertArrayEquals;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertNotNull;
22 import static org.junit.Assert.assertSame;
23 import static org.junit.Assert.assertThrows;
24 import static org.junit.Assert.fail;
25 
26 import com.google.common.io.ByteStreams;
27 import com.google.protobuf.ByteString;
28 import com.google.protobuf.Empty;
29 import com.google.protobuf.Enum;
30 import com.google.protobuf.InvalidProtocolBufferException;
31 import com.google.protobuf.Type;
32 import io.grpc.Drainable;
33 import io.grpc.KnownLength;
34 import io.grpc.Metadata;
35 import io.grpc.MethodDescriptor.Marshaller;
36 import io.grpc.MethodDescriptor.PrototypeMarshaller;
37 import io.grpc.Status;
38 import io.grpc.StatusRuntimeException;
39 import io.grpc.internal.GrpcUtil;
40 import io.grpc.testing.protobuf.SimpleRecursiveMessage;
41 import java.io.ByteArrayInputStream;
42 import java.io.ByteArrayOutputStream;
43 import java.io.IOException;
44 import java.io.InputStream;
45 import java.util.Arrays;
46 import org.junit.Rule;
47 import org.junit.Test;
48 import org.junit.rules.ExpectedException;
49 import org.junit.runner.RunWith;
50 import org.junit.runners.JUnit4;
51 
52 /** Unit tests for {@link ProtoLiteUtils}. */
53 @RunWith(JUnit4.class)
54 public class ProtoLiteUtilsTest {
55 
56   @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
57   @Rule public final ExpectedException thrown = ExpectedException.none();
58 
59   private final Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
60   private Type proto = Type.newBuilder().setName("name").build();
61 
62   @Test
testPassthrough()63   public void testPassthrough() {
64     assertSame(proto, marshaller.parse(marshaller.stream(proto)));
65   }
66 
67   @Test
testRoundtrip()68   public void testRoundtrip() throws Exception {
69     InputStream is = marshaller.stream(proto);
70     is = new ByteArrayInputStream(ByteStreams.toByteArray(is));
71     assertEquals(proto, marshaller.parse(is));
72   }
73 
74   @Test
testInvalidatedMessage()75   public void testInvalidatedMessage() throws Exception {
76     InputStream is = marshaller.stream(proto);
77     // Invalidates message, and drains all bytes
78     byte[] unused = ByteStreams.toByteArray(is);
79     try {
80       ((ProtoInputStream) is).message();
81       fail("Expected exception");
82     } catch (IllegalStateException ex) {
83       // expected
84     }
85     // Zero bytes is the default message
86     assertEquals(Type.getDefaultInstance(), marshaller.parse(is));
87   }
88 
89   @Test
parseInvalid()90   public void parseInvalid() {
91     InputStream is = new ByteArrayInputStream(new byte[] {-127});
92     try {
93       marshaller.parse(is);
94       fail("Expected exception");
95     } catch (StatusRuntimeException ex) {
96       assertEquals(Status.Code.INTERNAL, ex.getStatus().getCode());
97       assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage());
98     }
99   }
100 
101   @Test
testMismatch()102   public void testMismatch() {
103     Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
104     // Enum's name and Type's name are both strings with tag 1.
105     Enum altProto = Enum.newBuilder().setName(proto.getName()).build();
106     assertEquals(proto, marshaller.parse(enumMarshaller.stream(altProto)));
107   }
108 
109   @Test
introspection()110   public void introspection() {
111     Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
112     PrototypeMarshaller<Enum> prototypeMarshaller = (PrototypeMarshaller<Enum>) enumMarshaller;
113     assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype());
114     assertSame(Enum.class, prototypeMarshaller.getMessageClass());
115   }
116 
117   @Test
marshallerShouldNotLimitProtoSize()118   public void marshallerShouldNotLimitProtoSize() throws Exception {
119     // The default limit is 64MB. Using a larger proto to verify that the limit is not enforced.
120     byte[] bigName = new byte[70 * 1024 * 1024];
121     Arrays.fill(bigName, (byte) 32);
122 
123     proto = Type.newBuilder().setNameBytes(ByteString.copyFrom(bigName)).build();
124 
125     // Just perform a round trip to verify that it works.
126     testRoundtrip();
127   }
128 
129   @Test
testAvailable()130   public void testAvailable() throws Exception {
131     InputStream is = marshaller.stream(proto);
132     assertEquals(proto.getSerializedSize(), is.available());
133     is.read();
134     assertEquals(proto.getSerializedSize() - 1, is.available());
135     while (is.read() != -1) {}
136     assertEquals(-1, is.read());
137     assertEquals(0, is.available());
138   }
139 
140   @Test
testEmpty()141   public void testEmpty() throws IOException {
142     Marshaller<Empty> marshaller = ProtoLiteUtils.marshaller(Empty.getDefaultInstance());
143     InputStream is = marshaller.stream(Empty.getDefaultInstance());
144     assertEquals(0, is.available());
145     byte[] b = new byte[10];
146     assertEquals(-1, is.read(b));
147     assertArrayEquals(new byte[10], b);
148     // Do the same thing again, because the internal state may be different
149     assertEquals(-1, is.read(b));
150     assertArrayEquals(new byte[10], b);
151     assertEquals(-1, is.read());
152     assertEquals(0, is.available());
153   }
154 
155   @Test
testDrainTo_all()156   public void testDrainTo_all() throws Exception {
157     byte[] golden = ByteStreams.toByteArray(marshaller.stream(proto));
158     InputStream is = marshaller.stream(proto);
159     Drainable d = (Drainable) is;
160     ByteArrayOutputStream baos = new ByteArrayOutputStream();
161     int drained = d.drainTo(baos);
162     assertEquals(baos.size(), drained);
163     assertArrayEquals(golden, baos.toByteArray());
164     assertEquals(0, is.available());
165   }
166 
167   @Test
testDrainTo_partial()168   public void testDrainTo_partial() throws Exception {
169     final byte[] golden;
170     {
171       InputStream is = marshaller.stream(proto);
172       is.read();
173       golden = ByteStreams.toByteArray(is);
174     }
175     InputStream is = marshaller.stream(proto);
176     is.read();
177     Drainable d = (Drainable) is;
178     ByteArrayOutputStream baos = new ByteArrayOutputStream();
179     int drained = d.drainTo(baos);
180     assertEquals(baos.size(), drained);
181     assertArrayEquals(golden, baos.toByteArray());
182     assertEquals(0, is.available());
183   }
184 
185   @Test
testDrainTo_none()186   public void testDrainTo_none() throws Exception {
187     InputStream is = marshaller.stream(proto);
188     byte[] unused = ByteStreams.toByteArray(is);
189     Drainable d = (Drainable) is;
190     ByteArrayOutputStream baos = new ByteArrayOutputStream();
191     assertEquals(0, d.drainTo(baos));
192     assertArrayEquals(new byte[0], baos.toByteArray());
193     assertEquals(0, is.available());
194   }
195 
196   @Test
metadataMarshaller_roundtrip()197   public void metadataMarshaller_roundtrip() {
198     Metadata.BinaryMarshaller<Type> metadataMarshaller =
199         ProtoLiteUtils.metadataMarshaller(Type.getDefaultInstance());
200     assertEquals(proto, metadataMarshaller.parseBytes(metadataMarshaller.toBytes(proto)));
201   }
202 
203   @Test
metadataMarshaller_invalid()204   public void metadataMarshaller_invalid() {
205     Metadata.BinaryMarshaller<Type> metadataMarshaller =
206         ProtoLiteUtils.metadataMarshaller(Type.getDefaultInstance());
207     try {
208       metadataMarshaller.parseBytes(new byte[] {-127});
209       fail("Expected exception");
210     } catch (IllegalArgumentException ex) {
211       assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage());
212     }
213   }
214 
215   @Test
extensionRegistry_notNull()216   public void extensionRegistry_notNull() {
217     thrown.expect(NullPointerException.class);
218     thrown.expectMessage("newRegistry");
219 
220     ProtoLiteUtils.setExtensionRegistry(null);
221   }
222 
223   @Test
parseFromKnowLengthInputStream()224   public void parseFromKnowLengthInputStream() {
225     Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
226     Type expect = Type.newBuilder().setName("expected name").build();
227 
228     Type result = marshaller.parse(new CustomKnownLengthInputStream(expect.toByteArray()));
229     assertEquals(expect, result);
230   }
231 
232   @Test
defaultMaxMessageSize()233   public void defaultMaxMessageSize() {
234     assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE);
235   }
236 
237   @Test
testNullDefaultInstance()238   public void testNullDefaultInstance() {
239     String expectedMessage = "defaultInstance cannot be null";
240     assertThrows(expectedMessage, NullPointerException.class,
241         () -> ProtoLiteUtils.marshaller(null));
242 
243     assertThrows(expectedMessage, NullPointerException.class,
244         () -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10)
245     );
246   }
247 
248   @Test
givenPositiveLimit_testRecursionLimitExceeded()249   public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException {
250     Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
251         SimpleRecursiveMessage.getDefaultInstance(), 10);
252     SimpleRecursiveMessage message = buildRecursiveMessage(12);
253 
254     assertRecursionLimitExceeded(marshaller, message);
255   }
256 
257   @Test
givenZeroLimit_testRecursionLimitExceeded()258   public void givenZeroLimit_testRecursionLimitExceeded() throws IOException {
259     Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
260         SimpleRecursiveMessage.getDefaultInstance(), 0);
261     SimpleRecursiveMessage message = buildRecursiveMessage(1);
262 
263     assertRecursionLimitExceeded(marshaller, message);
264   }
265 
266   @Test
givenPositiveLimit_testRecursionLimitNotExceeded()267   public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException {
268     Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
269         SimpleRecursiveMessage.getDefaultInstance(), 15);
270     SimpleRecursiveMessage message = buildRecursiveMessage(12);
271 
272     assertRecursionLimitNotExceeded(marshaller, message);
273   }
274 
275   @Test
givenZeroLimit_testRecursionLimitNotExceeded()276   public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException {
277     Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
278         SimpleRecursiveMessage.getDefaultInstance(), 0);
279     SimpleRecursiveMessage message = buildRecursiveMessage(0);
280 
281     assertRecursionLimitNotExceeded(marshaller, message);
282   }
283 
284   @Test
testDefaultRecursionLimit()285   public void testDefaultRecursionLimit() throws IOException {
286     Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshaller(
287         SimpleRecursiveMessage.getDefaultInstance());
288     SimpleRecursiveMessage message = buildRecursiveMessage(100);
289 
290     assertRecursionLimitNotExceeded(marshaller, message);
291   }
292 
assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller, SimpleRecursiveMessage message)293   private static void assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
294       SimpleRecursiveMessage message) throws IOException {
295     InputStream is = marshaller.stream(message);
296     ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
297 
298     assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais));
299   }
300 
assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller, SimpleRecursiveMessage message)301   private static void assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
302       SimpleRecursiveMessage message) throws IOException {
303     InputStream is = marshaller.stream(message);
304     ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));
305 
306     assertEquals(message, marshaller.parse(bais));
307   }
308 
buildRecursiveMessage(int depth)309   private static SimpleRecursiveMessage buildRecursiveMessage(int depth) {
310     SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder()
311         .setValue("depth-" + depth);
312     for (int i = depth; i > 0; i--) {
313       builder = SimpleRecursiveMessage.newBuilder()
314           .setValue("depth-" + i)
315           .setMessage(builder.build());
316     }
317 
318     return builder.build();
319   }
320 
321   private static class CustomKnownLengthInputStream extends InputStream implements KnownLength {
322 
323     private int position = 0;
324     private final byte[] source;
325 
CustomKnownLengthInputStream(byte[] source)326     private CustomKnownLengthInputStream(byte[] source) {
327       this.source = source;
328     }
329 
330     @Override
available()331     public int available() {
332       return source.length - position;
333     }
334 
335     @Override
read()336     public int read() {
337       if (position == source.length) {
338         return -1;
339       }
340 
341       return source[position++];
342     }
343   }
344 }
345