• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.protobuf.lite;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 
21 import com.google.common.annotations.VisibleForTesting;
22 import com.google.protobuf.CodedInputStream;
23 import com.google.protobuf.ExtensionRegistryLite;
24 import com.google.protobuf.InvalidProtocolBufferException;
25 import com.google.protobuf.MessageLite;
26 import com.google.protobuf.Parser;
27 import io.grpc.ExperimentalApi;
28 import io.grpc.KnownLength;
29 import io.grpc.Metadata;
30 import io.grpc.MethodDescriptor.Marshaller;
31 import io.grpc.MethodDescriptor.PrototypeMarshaller;
32 import io.grpc.Status;
33 import java.io.IOException;
34 import java.io.InputStream;
35 import java.io.OutputStream;
36 import java.lang.ref.Reference;
37 import java.lang.ref.WeakReference;
38 
39 /**
40  * Utility methods for using protobuf with grpc.
41  */
42 @ExperimentalApi("Experimental until Lite is stable in protobuf")
43 public final class ProtoLiteUtils {
44 
45   // default visibility to avoid synthetic accessors
46   static volatile ExtensionRegistryLite globalRegistry =
47       ExtensionRegistryLite.getEmptyRegistry();
48 
49   private static final int BUF_SIZE = 8192;
50 
51   /**
52    * The same value as {@link io.grpc.internal.GrpcUtil#DEFAULT_MAX_MESSAGE_SIZE}.
53    */
54   @VisibleForTesting
55   static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
56 
57   /**
58    * Sets the global registry for proto marshalling shared across all servers and clients.
59    *
60    * <p>Warning:  This API will likely change over time.  It is not possible to have separate
61    * registries per Process, Server, Channel, Service, or Method.  This is intentional until there
62    * is a more appropriate API to set them.
63    *
64    * <p>Warning:  Do NOT modify the extension registry after setting it.  It is thread safe to call
65    * {@link #setExtensionRegistry}, but not to modify the underlying object.
66    *
67    * <p>If you need custom parsing behavior for protos, you will need to make your own
68    * {@code MethodDescriptor.Marshaller} for the time being.
69    *
70    * @since 1.0.0
71    */
72   @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1787")
setExtensionRegistry(ExtensionRegistryLite newRegistry)73   public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
74     globalRegistry = checkNotNull(newRegistry, "newRegistry");
75   }
76 
77   /**
78    * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance}.
79    *
80    * @since 1.0.0
81    */
marshaller(T defaultInstance)82   public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) {
83     // TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
84     return new MessageMarshaller<>(defaultInstance, -1);
85   }
86 
87   /**
88    * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
89    * custom limit for the recursion depth. Any negative number will leave the limit to its default
90    * value as defined by the protobuf library.
91    *
92    * @since 1.56.0
93    */
94   @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108")
marshallerWithRecursionLimit( T defaultInstance, int recursionLimit)95   public static <T extends MessageLite> Marshaller<T> marshallerWithRecursionLimit(
96       T defaultInstance, int recursionLimit) {
97     return new MessageMarshaller<>(defaultInstance, recursionLimit);
98   }
99 
100   /**
101    * Produce a metadata marshaller for a protobuf type.
102    *
103    * @since 1.0.0
104    */
metadataMarshaller( T defaultInstance)105   public static <T extends MessageLite> Metadata.BinaryMarshaller<T> metadataMarshaller(
106       T defaultInstance) {
107     return new MetadataMarshaller<>(defaultInstance);
108   }
109 
110   /** Copies the data from input stream to output stream. */
copy(InputStream from, OutputStream to)111   static long copy(InputStream from, OutputStream to) throws IOException {
112     // Copied from guava com.google.common.io.ByteStreams because its API is unstable (beta)
113     checkNotNull(from, "inputStream cannot be null!");
114     checkNotNull(to, "outputStream cannot be null!");
115     byte[] buf = new byte[BUF_SIZE];
116     long total = 0;
117     while (true) {
118       int r = from.read(buf);
119       if (r == -1) {
120         break;
121       }
122       to.write(buf, 0, r);
123       total += r;
124     }
125     return total;
126   }
127 
ProtoLiteUtils()128   private ProtoLiteUtils() {
129   }
130 
131   private static final class MessageMarshaller<T extends MessageLite>
132       implements PrototypeMarshaller<T> {
133 
134     private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<>();
135 
136     private final Parser<T> parser;
137     private final T defaultInstance;
138     private final int recursionLimit;
139 
140     @SuppressWarnings("unchecked")
MessageMarshaller(T defaultInstance, int recursionLimit)141     MessageMarshaller(T defaultInstance, int recursionLimit) {
142       this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null");
143       this.parser = (Parser<T>) defaultInstance.getParserForType();
144       this.recursionLimit = recursionLimit;
145     }
146 
147     @SuppressWarnings("unchecked")
148     @Override
getMessageClass()149     public Class<T> getMessageClass() {
150       // Precisely T since protobuf doesn't let messages extend other messages.
151       return (Class<T>) defaultInstance.getClass();
152     }
153 
154     @Override
getMessagePrototype()155     public T getMessagePrototype() {
156       return defaultInstance;
157     }
158 
159     @Override
stream(T value)160     public InputStream stream(T value) {
161       return new ProtoInputStream(value, parser);
162     }
163 
164     @Override
parse(InputStream stream)165     public T parse(InputStream stream) {
166       if (stream instanceof ProtoInputStream) {
167         ProtoInputStream protoStream = (ProtoInputStream) stream;
168         // Optimization for in-memory transport. Returning provided object is safe since protobufs
169         // are immutable.
170         //
171         // However, we can't assume the types match, so we have to verify the parser matches.
172         // Today the parser is always the same for a given proto, but that isn't guaranteed. Even
173         // if not, using the same MethodDescriptor would ensure the parser matches and permit us
174         // to enable this optimization.
175         if (protoStream.parser() == parser) {
176           try {
177             @SuppressWarnings("unchecked")
178             T message = (T) ((ProtoInputStream) stream).message();
179             return message;
180           } catch (IllegalStateException ignored) {
181             // Stream must have been read from, which is a strange state. Since the point of this
182             // optimization is to be transparent, instead of throwing an error we'll continue,
183             // even though it seems likely there's a bug.
184           }
185         }
186       }
187       CodedInputStream cis = null;
188       try {
189         if (stream instanceof KnownLength) {
190           int size = stream.available();
191           if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) {
192             Reference<byte[]> ref;
193             // buf should not be used after this method has returned.
194             byte[] buf;
195             if ((ref = bufs.get()) == null || (buf = ref.get()) == null || buf.length < size) {
196               buf = new byte[size];
197               bufs.set(new WeakReference<>(buf));
198             }
199 
200             int remaining = size;
201             while (remaining > 0) {
202               int position = size - remaining;
203               int count = stream.read(buf, position, remaining);
204               if (count == -1) {
205                 break;
206               }
207               remaining -= count;
208             }
209 
210             if (remaining != 0) {
211               int position = size - remaining;
212               throw new RuntimeException("size inaccurate: " + size + " != " + position);
213             }
214             cis = CodedInputStream.newInstance(buf, 0, size);
215           } else if (size == 0) {
216             return defaultInstance;
217           }
218         }
219       } catch (IOException e) {
220         throw new RuntimeException(e);
221       }
222       if (cis == null) {
223         cis = CodedInputStream.newInstance(stream);
224       }
225       // Pre-create the CodedInputStream so that we can remove the size limit restriction
226       // when parsing.
227       cis.setSizeLimit(Integer.MAX_VALUE);
228 
229       if (recursionLimit >= 0) {
230         cis.setRecursionLimit(recursionLimit);
231       }
232 
233       try {
234         return parseFrom(cis);
235       } catch (InvalidProtocolBufferException ipbe) {
236         throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence")
237             .withCause(ipbe).asRuntimeException();
238       }
239     }
240 
parseFrom(CodedInputStream stream)241     private T parseFrom(CodedInputStream stream) throws InvalidProtocolBufferException {
242       T message = parser.parseFrom(stream, globalRegistry);
243       try {
244         stream.checkLastTagWas(0);
245         return message;
246       } catch (InvalidProtocolBufferException e) {
247         e.setUnfinishedMessage(message);
248         throw e;
249       }
250     }
251   }
252 
253   private static final class MetadataMarshaller<T extends MessageLite>
254       implements Metadata.BinaryMarshaller<T> {
255 
256     private final T defaultInstance;
257 
MetadataMarshaller(T defaultInstance)258     MetadataMarshaller(T defaultInstance) {
259       this.defaultInstance = defaultInstance;
260     }
261 
262     @Override
toBytes(T value)263     public byte[] toBytes(T value) {
264       return value.toByteArray();
265     }
266 
267     @Override
268     @SuppressWarnings("unchecked")
parseBytes(byte[] serialized)269     public T parseBytes(byte[] serialized) {
270       try {
271         return (T) defaultInstance.getParserForType().parseFrom(serialized, globalRegistry);
272       } catch (InvalidProtocolBufferException ipbe) {
273         throw new IllegalArgumentException(ipbe);
274       }
275     }
276   }
277 }
278