1 /* 2 * Copyright 2015 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.protobuf.lite; 18 19 import static org.junit.Assert.assertArrayEquals; 20 import static org.junit.Assert.assertEquals; 21 import static org.junit.Assert.assertNotNull; 22 import static org.junit.Assert.assertSame; 23 import static org.junit.Assert.assertThrows; 24 import static org.junit.Assert.fail; 25 26 import com.google.common.io.ByteStreams; 27 import com.google.protobuf.ByteString; 28 import com.google.protobuf.Empty; 29 import com.google.protobuf.Enum; 30 import com.google.protobuf.InvalidProtocolBufferException; 31 import com.google.protobuf.Type; 32 import io.grpc.Drainable; 33 import io.grpc.KnownLength; 34 import io.grpc.Metadata; 35 import io.grpc.MethodDescriptor.Marshaller; 36 import io.grpc.MethodDescriptor.PrototypeMarshaller; 37 import io.grpc.Status; 38 import io.grpc.StatusRuntimeException; 39 import io.grpc.internal.GrpcUtil; 40 import io.grpc.testing.protobuf.SimpleRecursiveMessage; 41 import java.io.ByteArrayInputStream; 42 import java.io.ByteArrayOutputStream; 43 import java.io.IOException; 44 import java.io.InputStream; 45 import java.util.Arrays; 46 import org.junit.Rule; 47 import org.junit.Test; 48 import org.junit.rules.ExpectedException; 49 import org.junit.runner.RunWith; 50 import org.junit.runners.JUnit4; 51 52 /** Unit tests for {@link ProtoLiteUtils}. */ 53 @RunWith(JUnit4.class) 54 public class ProtoLiteUtilsTest { 55 56 @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 57 @Rule public final ExpectedException thrown = ExpectedException.none(); 58 59 private final Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); 60 private Type proto = Type.newBuilder().setName("name").build(); 61 62 @Test testPassthrough()63 public void testPassthrough() { 64 assertSame(proto, marshaller.parse(marshaller.stream(proto))); 65 } 66 67 @Test testRoundtrip()68 public void testRoundtrip() throws Exception { 69 InputStream is = marshaller.stream(proto); 70 is = new ByteArrayInputStream(ByteStreams.toByteArray(is)); 71 assertEquals(proto, marshaller.parse(is)); 72 } 73 74 @Test testInvalidatedMessage()75 public void testInvalidatedMessage() throws Exception { 76 InputStream is = marshaller.stream(proto); 77 // Invalidates message, and drains all bytes 78 byte[] unused = ByteStreams.toByteArray(is); 79 try { 80 ((ProtoInputStream) is).message(); 81 fail("Expected exception"); 82 } catch (IllegalStateException ex) { 83 // expected 84 } 85 // Zero bytes is the default message 86 assertEquals(Type.getDefaultInstance(), marshaller.parse(is)); 87 } 88 89 @Test parseInvalid()90 public void parseInvalid() { 91 InputStream is = new ByteArrayInputStream(new byte[] {-127}); 92 try { 93 marshaller.parse(is); 94 fail("Expected exception"); 95 } catch (StatusRuntimeException ex) { 96 assertEquals(Status.Code.INTERNAL, ex.getStatus().getCode()); 97 assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage()); 98 } 99 } 100 101 @Test testMismatch()102 public void testMismatch() { 103 Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance()); 104 // Enum's name and Type's name are both strings with tag 1. 105 Enum altProto = Enum.newBuilder().setName(proto.getName()).build(); 106 assertEquals(proto, marshaller.parse(enumMarshaller.stream(altProto))); 107 } 108 109 @Test introspection()110 public void introspection() { 111 Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance()); 112 PrototypeMarshaller<Enum> prototypeMarshaller = (PrototypeMarshaller<Enum>) enumMarshaller; 113 assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype()); 114 assertSame(Enum.class, prototypeMarshaller.getMessageClass()); 115 } 116 117 @Test marshallerShouldNotLimitProtoSize()118 public void marshallerShouldNotLimitProtoSize() throws Exception { 119 // The default limit is 64MB. Using a larger proto to verify that the limit is not enforced. 120 byte[] bigName = new byte[70 * 1024 * 1024]; 121 Arrays.fill(bigName, (byte) 32); 122 123 proto = Type.newBuilder().setNameBytes(ByteString.copyFrom(bigName)).build(); 124 125 // Just perform a round trip to verify that it works. 126 testRoundtrip(); 127 } 128 129 @Test testAvailable()130 public void testAvailable() throws Exception { 131 InputStream is = marshaller.stream(proto); 132 assertEquals(proto.getSerializedSize(), is.available()); 133 is.read(); 134 assertEquals(proto.getSerializedSize() - 1, is.available()); 135 while (is.read() != -1) {} 136 assertEquals(-1, is.read()); 137 assertEquals(0, is.available()); 138 } 139 140 @Test testEmpty()141 public void testEmpty() throws IOException { 142 Marshaller<Empty> marshaller = ProtoLiteUtils.marshaller(Empty.getDefaultInstance()); 143 InputStream is = marshaller.stream(Empty.getDefaultInstance()); 144 assertEquals(0, is.available()); 145 byte[] b = new byte[10]; 146 assertEquals(-1, is.read(b)); 147 assertArrayEquals(new byte[10], b); 148 // Do the same thing again, because the internal state may be different 149 assertEquals(-1, is.read(b)); 150 assertArrayEquals(new byte[10], b); 151 assertEquals(-1, is.read()); 152 assertEquals(0, is.available()); 153 } 154 155 @Test testDrainTo_all()156 public void testDrainTo_all() throws Exception { 157 byte[] golden = ByteStreams.toByteArray(marshaller.stream(proto)); 158 InputStream is = marshaller.stream(proto); 159 Drainable d = (Drainable) is; 160 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 161 int drained = d.drainTo(baos); 162 assertEquals(baos.size(), drained); 163 assertArrayEquals(golden, baos.toByteArray()); 164 assertEquals(0, is.available()); 165 } 166 167 @Test testDrainTo_partial()168 public void testDrainTo_partial() throws Exception { 169 final byte[] golden; 170 { 171 InputStream is = marshaller.stream(proto); 172 is.read(); 173 golden = ByteStreams.toByteArray(is); 174 } 175 InputStream is = marshaller.stream(proto); 176 is.read(); 177 Drainable d = (Drainable) is; 178 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 179 int drained = d.drainTo(baos); 180 assertEquals(baos.size(), drained); 181 assertArrayEquals(golden, baos.toByteArray()); 182 assertEquals(0, is.available()); 183 } 184 185 @Test testDrainTo_none()186 public void testDrainTo_none() throws Exception { 187 InputStream is = marshaller.stream(proto); 188 byte[] unused = ByteStreams.toByteArray(is); 189 Drainable d = (Drainable) is; 190 ByteArrayOutputStream baos = new ByteArrayOutputStream(); 191 assertEquals(0, d.drainTo(baos)); 192 assertArrayEquals(new byte[0], baos.toByteArray()); 193 assertEquals(0, is.available()); 194 } 195 196 @Test metadataMarshaller_roundtrip()197 public void metadataMarshaller_roundtrip() { 198 Metadata.BinaryMarshaller<Type> metadataMarshaller = 199 ProtoLiteUtils.metadataMarshaller(Type.getDefaultInstance()); 200 assertEquals(proto, metadataMarshaller.parseBytes(metadataMarshaller.toBytes(proto))); 201 } 202 203 @Test metadataMarshaller_invalid()204 public void metadataMarshaller_invalid() { 205 Metadata.BinaryMarshaller<Type> metadataMarshaller = 206 ProtoLiteUtils.metadataMarshaller(Type.getDefaultInstance()); 207 try { 208 metadataMarshaller.parseBytes(new byte[] {-127}); 209 fail("Expected exception"); 210 } catch (IllegalArgumentException ex) { 211 assertNotNull(((InvalidProtocolBufferException) ex.getCause()).getUnfinishedMessage()); 212 } 213 } 214 215 @Test extensionRegistry_notNull()216 public void extensionRegistry_notNull() { 217 thrown.expect(NullPointerException.class); 218 thrown.expectMessage("newRegistry"); 219 220 ProtoLiteUtils.setExtensionRegistry(null); 221 } 222 223 @Test parseFromKnowLengthInputStream()224 public void parseFromKnowLengthInputStream() { 225 Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); 226 Type expect = Type.newBuilder().setName("expected name").build(); 227 228 Type result = marshaller.parse(new CustomKnownLengthInputStream(expect.toByteArray())); 229 assertEquals(expect, result); 230 } 231 232 @Test defaultMaxMessageSize()233 public void defaultMaxMessageSize() { 234 assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE); 235 } 236 237 @Test testNullDefaultInstance()238 public void testNullDefaultInstance() { 239 String expectedMessage = "defaultInstance cannot be null"; 240 assertThrows(expectedMessage, NullPointerException.class, 241 () -> ProtoLiteUtils.marshaller(null)); 242 243 assertThrows(expectedMessage, NullPointerException.class, 244 () -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10) 245 ); 246 } 247 248 @Test givenPositiveLimit_testRecursionLimitExceeded()249 public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException { 250 Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( 251 SimpleRecursiveMessage.getDefaultInstance(), 10); 252 SimpleRecursiveMessage message = buildRecursiveMessage(12); 253 254 assertRecursionLimitExceeded(marshaller, message); 255 } 256 257 @Test givenZeroLimit_testRecursionLimitExceeded()258 public void givenZeroLimit_testRecursionLimitExceeded() throws IOException { 259 Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( 260 SimpleRecursiveMessage.getDefaultInstance(), 0); 261 SimpleRecursiveMessage message = buildRecursiveMessage(1); 262 263 assertRecursionLimitExceeded(marshaller, message); 264 } 265 266 @Test givenPositiveLimit_testRecursionLimitNotExceeded()267 public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException { 268 Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( 269 SimpleRecursiveMessage.getDefaultInstance(), 15); 270 SimpleRecursiveMessage message = buildRecursiveMessage(12); 271 272 assertRecursionLimitNotExceeded(marshaller, message); 273 } 274 275 @Test givenZeroLimit_testRecursionLimitNotExceeded()276 public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException { 277 Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( 278 SimpleRecursiveMessage.getDefaultInstance(), 0); 279 SimpleRecursiveMessage message = buildRecursiveMessage(0); 280 281 assertRecursionLimitNotExceeded(marshaller, message); 282 } 283 284 @Test testDefaultRecursionLimit()285 public void testDefaultRecursionLimit() throws IOException { 286 Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshaller( 287 SimpleRecursiveMessage.getDefaultInstance()); 288 SimpleRecursiveMessage message = buildRecursiveMessage(100); 289 290 assertRecursionLimitNotExceeded(marshaller, message); 291 } 292 assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller, SimpleRecursiveMessage message)293 private static void assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller, 294 SimpleRecursiveMessage message) throws IOException { 295 InputStream is = marshaller.stream(message); 296 ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is)); 297 298 assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais)); 299 } 300 assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller, SimpleRecursiveMessage message)301 private static void assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller, 302 SimpleRecursiveMessage message) throws IOException { 303 InputStream is = marshaller.stream(message); 304 ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is)); 305 306 assertEquals(message, marshaller.parse(bais)); 307 } 308 buildRecursiveMessage(int depth)309 private static SimpleRecursiveMessage buildRecursiveMessage(int depth) { 310 SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder() 311 .setValue("depth-" + depth); 312 for (int i = depth; i > 0; i--) { 313 builder = SimpleRecursiveMessage.newBuilder() 314 .setValue("depth-" + i) 315 .setMessage(builder.build()); 316 } 317 318 return builder.build(); 319 } 320 321 private static class CustomKnownLengthInputStream extends InputStream implements KnownLength { 322 323 private int position = 0; 324 private final byte[] source; 325 CustomKnownLengthInputStream(byte[] source)326 private CustomKnownLengthInputStream(byte[] source) { 327 this.source = source; 328 } 329 330 @Override available()331 public int available() { 332 return source.length - position; 333 } 334 335 @Override read()336 public int read() { 337 if (position == source.length) { 338 return -1; 339 } 340 341 return source[position++]; 342 } 343 } 344 } 345