1 /* 2 * Copyright 2014 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.protobuf.lite; 18 19 import static com.google.common.base.Preconditions.checkNotNull; 20 21 import com.google.common.annotations.VisibleForTesting; 22 import com.google.protobuf.CodedInputStream; 23 import com.google.protobuf.ExtensionRegistryLite; 24 import com.google.protobuf.InvalidProtocolBufferException; 25 import com.google.protobuf.MessageLite; 26 import com.google.protobuf.Parser; 27 import io.grpc.ExperimentalApi; 28 import io.grpc.KnownLength; 29 import io.grpc.Metadata; 30 import io.grpc.MethodDescriptor.Marshaller; 31 import io.grpc.MethodDescriptor.PrototypeMarshaller; 32 import io.grpc.Status; 33 import java.io.IOException; 34 import java.io.InputStream; 35 import java.io.OutputStream; 36 import java.lang.ref.Reference; 37 import java.lang.ref.WeakReference; 38 39 /** 40 * Utility methods for using protobuf with grpc. 41 */ 42 @ExperimentalApi("Experimental until Lite is stable in protobuf") 43 public final class ProtoLiteUtils { 44 45 // default visibility to avoid synthetic accessors 46 static volatile ExtensionRegistryLite globalRegistry = 47 ExtensionRegistryLite.getEmptyRegistry(); 48 49 private static final int BUF_SIZE = 8192; 50 51 /** 52 * The same value as {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}. 53 */ 54 @VisibleForTesting 55 static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024; 56 57 /** 58 * Sets the global registry for proto marshalling shared across all servers and clients. 59 * 60 * <p>Warning: This API will likely change over time. It is not possible to have separate 61 * registries per Process, Server, Channel, Service, or Method. This is intentional until there 62 * is a more appropriate API to set them. 63 * 64 * <p>Warning: Do NOT modify the extension registry after setting it. It is thread safe to call 65 * {@link #setExtensionRegistry}, but not to modify the underlying object. 66 * 67 * <p>If you need custom parsing behavior for protos, you will need to make your own 68 * {@code MethodDescriptor.Marshaller} for the time being. 69 * 70 * @since 1.0.0 71 */ 72 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1787") setExtensionRegistry(ExtensionRegistryLite newRegistry)73 public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) { 74 globalRegistry = checkNotNull(newRegistry, "newRegistry"); 75 } 76 77 /** 78 * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance}. 79 * 80 * @since 1.0.0 81 */ marshaller(T defaultInstance)82 public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) { 83 // TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe) 84 return new MessageMarshaller<>(defaultInstance, -1); 85 } 86 87 /** 88 * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a 89 * custom limit for the recursion depth. Any negative number will leave the limit to its default 90 * value as defined by the protobuf library. 91 * 92 * @since 1.56.0 93 */ 94 @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") marshallerWithRecursionLimit( T defaultInstance, int recursionLimit)95 public static <T extends MessageLite> Marshaller<T> marshallerWithRecursionLimit( 96 T defaultInstance, int recursionLimit) { 97 return new MessageMarshaller<>(defaultInstance, recursionLimit); 98 } 99 100 /** 101 * Produce a metadata marshaller for a protobuf type. 102 * 103 * @since 1.0.0 104 */ metadataMarshaller( T defaultInstance)105 public static <T extends MessageLite> Metadata.BinaryMarshaller<T> metadataMarshaller( 106 T defaultInstance) { 107 return new MetadataMarshaller<>(defaultInstance); 108 } 109 110 /** Copies the data from input stream to output stream. */ copy(InputStream from, OutputStream to)111 static long copy(InputStream from, OutputStream to) throws IOException { 112 // Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta) 113 checkNotNull(from, "inputStream cannot be null!"); 114 checkNotNull(to, "outputStream cannot be null!"); 115 byte[] buf = new byte[BUF_SIZE]; 116 long total = 0; 117 while (true) { 118 int r = from.read(buf); 119 if (r == -1) { 120 break; 121 } 122 to.write(buf, 0, r); 123 total += r; 124 } 125 return total; 126 } 127 ProtoLiteUtils()128 private ProtoLiteUtils() { 129 } 130 131 private static final class MessageMarshaller<T extends MessageLite> 132 implements PrototypeMarshaller<T> { 133 134 private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<>(); 135 136 private final Parser<T> parser; 137 private final T defaultInstance; 138 private final int recursionLimit; 139 140 @SuppressWarnings("unchecked") MessageMarshaller(T defaultInstance, int recursionLimit)141 MessageMarshaller(T defaultInstance, int recursionLimit) { 142 this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null"); 143 this.parser = (Parser<T>) defaultInstance.getParserForType(); 144 this.recursionLimit = recursionLimit; 145 } 146 147 @SuppressWarnings("unchecked") 148 @Override getMessageClass()149 public Class<T> getMessageClass() { 150 // Precisely T since protobuf doesn't let messages extend other messages. 151 return (Class<T>) defaultInstance.getClass(); 152 } 153 154 @Override getMessagePrototype()155 public T getMessagePrototype() { 156 return defaultInstance; 157 } 158 159 @Override stream(T value)160 public InputStream stream(T value) { 161 return new ProtoInputStream(value, parser); 162 } 163 164 @Override parse(InputStream stream)165 public T parse(InputStream stream) { 166 if (stream instanceof ProtoInputStream) { 167 ProtoInputStream protoStream = (ProtoInputStream) stream; 168 // Optimization for in-memory transport. Returning provided object is safe since protobufs 169 // are immutable. 170 // 171 // However, we can't assume the types match, so we have to verify the parser matches. 172 // Today the parser is always the same for a given proto, but that isn't guaranteed. Even 173 // if not, using the same MethodDescriptor would ensure the parser matches and permit us 174 // to enable this optimization. 175 if (protoStream.parser() == parser) { 176 try { 177 @SuppressWarnings("unchecked") 178 T message = (T) ((ProtoInputStream) stream).message(); 179 return message; 180 } catch (IllegalStateException ignored) { 181 // Stream must have been read from, which is a strange state. Since the point of this 182 // optimization is to be transparent, instead of throwing an error we'll continue, 183 // even though it seems likely there's a bug. 184 } 185 } 186 } 187 CodedInputStream cis = null; 188 try { 189 if (stream instanceof KnownLength) { 190 int size = stream.available(); 191 if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) { 192 Reference<byte[]> ref; 193 // buf should not be used after this method has returned. 194 byte[] buf; 195 if ((ref = bufs.get()) == null || (buf = ref.get()) == null || buf.length < size) { 196 buf = new byte[size]; 197 bufs.set(new WeakReference<>(buf)); 198 } 199 200 int remaining = size; 201 while (remaining > 0) { 202 int position = size - remaining; 203 int count = stream.read(buf, position, remaining); 204 if (count == -1) { 205 break; 206 } 207 remaining -= count; 208 } 209 210 if (remaining != 0) { 211 int position = size - remaining; 212 throw new RuntimeException("size inaccurate: " + size + " != " + position); 213 } 214 cis = CodedInputStream.newInstance(buf, 0, size); 215 } else if (size == 0) { 216 return defaultInstance; 217 } 218 } 219 } catch (IOException e) { 220 throw new RuntimeException(e); 221 } 222 if (cis == null) { 223 cis = CodedInputStream.newInstance(stream); 224 } 225 // Pre-create the CodedInputStream so that we can remove the size limit restriction 226 // when parsing. 227 cis.setSizeLimit(Integer.MAX_VALUE); 228 229 if (recursionLimit >= 0) { 230 cis.setRecursionLimit(recursionLimit); 231 } 232 233 try { 234 return parseFrom(cis); 235 } catch (InvalidProtocolBufferException ipbe) { 236 throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence") 237 .withCause(ipbe).asRuntimeException(); 238 } 239 } 240 parseFrom(CodedInputStream stream)241 private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException { 242 T message = parser.parseFrom(stream, globalRegistry); 243 try { 244 stream.checkLastTagWas(0); 245 return message; 246 } catch (InvalidProtocolBufferException e) { 247 e.setUnfinishedMessage(message); 248 throw e; 249 } 250 } 251 } 252 253 private static final class MetadataMarshaller<T extends MessageLite> 254 implements Metadata.BinaryMarshaller<T> { 255 256 private final T defaultInstance; 257 MetadataMarshaller(T defaultInstance)258 MetadataMarshaller(T defaultInstance) { 259 this.defaultInstance = defaultInstance; 260 } 261 262 @Override toBytes(T value)263 public byte[] toBytes(T value) { 264 return value.toByteArray(); 265 } 266 267 @Override 268 @SuppressWarnings("unchecked") parseBytes(byte[] serialized)269 public T parseBytes(byte[] serialized) { 270 try { 271 return (T) defaultInstance.getParserForType().parseFrom(serialized, globalRegistry); 272 } catch (InvalidProtocolBufferException ipbe) { 273 throw new IllegalArgumentException(ipbe); 274 } 275 } 276 } 277 } 278