• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.internal;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 import static com.google.common.base.Preconditions.checkState;
22 
23 import com.google.common.annotations.VisibleForTesting;
24 import io.grpc.Codec;
25 import io.grpc.Decompressor;
26 import io.grpc.Status;
27 import java.io.Closeable;
28 import java.io.FilterInputStream;
29 import java.io.IOException;
30 import java.io.InputStream;
31 import java.util.Locale;
32 import java.util.zip.DataFormatException;
33 import javax.annotation.Nullable;
34 import javax.annotation.concurrent.NotThreadSafe;
35 
36 /**
37  * Deframer for GRPC frames.
38  *
39  * <p>This class is not thread-safe. Unless otherwise stated, all calls to public methods should be
40  * made in the deframing thread.
41  */
42 @NotThreadSafe
43 public class MessageDeframer implements Closeable, Deframer {
44   private static final int HEADER_LENGTH = 5;
45   private static final int COMPRESSED_FLAG_MASK = 1;
46   private static final int RESERVED_MASK = 0xFE;
47   private static final int MAX_BUFFER_SIZE = 1024 * 1024 * 2;
48 
49   /**
50    * A listener of deframing events. These methods will be invoked from the deframing thread.
51    */
52   public interface Listener {
53 
54     /**
55      * Called when the given number of bytes has been read from the input source of the deframer.
56      * This is typically used to indicate to the underlying transport that more data can be
57      * accepted.
58      *
59      * @param numBytes the number of bytes read from the deframer's input source.
60      */
bytesRead(int numBytes)61     void bytesRead(int numBytes);
62 
63     /**
64      * Called to deliver the next complete message.
65      *
66      * @param producer single message producer wrapping the message.
67      */
messagesAvailable(StreamListener.MessageProducer producer)68     void messagesAvailable(StreamListener.MessageProducer producer);
69 
70     /**
71      * Called when the deframer closes.
72      *
73      * @param hasPartialMessage whether the deframer contained an incomplete message at closing.
74      */
deframerClosed(boolean hasPartialMessage)75     void deframerClosed(boolean hasPartialMessage);
76 
77     /**
78      * Called when a {@link #deframe(ReadableBuffer)} operation failed.
79      *
80      * @param cause the actual failure
81      */
deframeFailed(Throwable cause)82     void deframeFailed(Throwable cause);
83   }
84 
85   private enum State {
86     HEADER, BODY
87   }
88 
89   private Listener listener;
90   private int maxInboundMessageSize;
91   private final StatsTraceContext statsTraceCtx;
92   private final TransportTracer transportTracer;
93   private Decompressor decompressor;
94   private GzipInflatingBuffer fullStreamDecompressor;
95   private byte[] inflatedBuffer;
96   private int inflatedIndex;
97   private State state = State.HEADER;
98   private int requiredLength = HEADER_LENGTH;
99   private boolean compressedFlag;
100   private CompositeReadableBuffer nextFrame;
101   private CompositeReadableBuffer unprocessed = new CompositeReadableBuffer();
102   private long pendingDeliveries;
103   private boolean inDelivery = false;
104   private int currentMessageSeqNo = -1;
105   private int inboundBodyWireSize;
106 
107   private boolean closeWhenComplete = false;
108   private volatile boolean stopDelivery = false;
109 
110   /**
111    * Create a deframer.
112    *
113    * @param listener listener for deframer events.
114    * @param decompressor the compression used if a compressed frame is encountered, with
115    *  {@code NONE} meaning unsupported
116    * @param maxMessageSize the maximum allowed size for received messages.
117    */
MessageDeframer( Listener listener, Decompressor decompressor, int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer)118   public MessageDeframer(
119       Listener listener,
120       Decompressor decompressor,
121       int maxMessageSize,
122       StatsTraceContext statsTraceCtx,
123       TransportTracer transportTracer) {
124     this.listener = checkNotNull(listener, "sink");
125     this.decompressor = checkNotNull(decompressor, "decompressor");
126     this.maxInboundMessageSize = maxMessageSize;
127     this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx");
128     this.transportTracer = checkNotNull(transportTracer, "transportTracer");
129   }
130 
setListener(Listener listener)131   void setListener(Listener listener) {
132     this.listener = listener;
133   }
134 
135   @Override
setMaxInboundMessageSize(int messageSize)136   public void setMaxInboundMessageSize(int messageSize) {
137     maxInboundMessageSize = messageSize;
138   }
139 
140   @Override
setDecompressor(Decompressor decompressor)141   public void setDecompressor(Decompressor decompressor) {
142     checkState(fullStreamDecompressor == null, "Already set full stream decompressor");
143     this.decompressor = checkNotNull(decompressor, "Can't pass an empty decompressor");
144   }
145 
146   @Override
setFullStreamDecompressor(GzipInflatingBuffer fullStreamDecompressor)147   public void setFullStreamDecompressor(GzipInflatingBuffer fullStreamDecompressor) {
148     checkState(decompressor == Codec.Identity.NONE, "per-message decompressor already set");
149     checkState(this.fullStreamDecompressor == null, "full stream decompressor already set");
150     this.fullStreamDecompressor =
151         checkNotNull(fullStreamDecompressor, "Can't pass a null full stream decompressor");
152     unprocessed = null;
153   }
154 
155   @Override
request(int numMessages)156   public void request(int numMessages) {
157     checkArgument(numMessages > 0, "numMessages must be > 0");
158     if (isClosed()) {
159       return;
160     }
161     pendingDeliveries += numMessages;
162     deliver();
163   }
164 
165   @Override
deframe(ReadableBuffer data)166   public void deframe(ReadableBuffer data) {
167     checkNotNull(data, "data");
168     boolean needToCloseData = true;
169     try {
170       if (!isClosedOrScheduledToClose()) {
171         if (fullStreamDecompressor != null) {
172           fullStreamDecompressor.addGzippedBytes(data);
173         } else {
174           unprocessed.addBuffer(data);
175         }
176         needToCloseData = false;
177 
178         deliver();
179       }
180     } finally {
181       if (needToCloseData) {
182         data.close();
183       }
184     }
185   }
186 
187   @Override
closeWhenComplete()188   public void closeWhenComplete() {
189     if (isClosed()) {
190       return;
191     } else if (isStalled()) {
192       close();
193     } else {
194       closeWhenComplete = true;
195     }
196   }
197 
198   /**
199    * Sets a flag to interrupt delivery of any currently queued messages. This may be invoked outside
200    * of the deframing thread, and must be followed by a call to {@link #close()} in the deframing
201    * thread. Without a subsequent call to {@link #close()}, the deframer may hang waiting for
202    * additional messages before noticing that the {@code stopDelivery} flag has been set.
203    */
stopDelivery()204   void stopDelivery() {
205     stopDelivery = true;
206   }
207 
hasPendingDeliveries()208   boolean hasPendingDeliveries() {
209     return pendingDeliveries != 0;
210   }
211 
212   @Override
close()213   public void close() {
214     if (isClosed()) {
215       return;
216     }
217     boolean hasPartialMessage = nextFrame != null && nextFrame.readableBytes() > 0;
218     try {
219       if (fullStreamDecompressor != null) {
220         hasPartialMessage = hasPartialMessage || fullStreamDecompressor.hasPartialData();
221         fullStreamDecompressor.close();
222       }
223       if (unprocessed != null) {
224         unprocessed.close();
225       }
226       if (nextFrame != null) {
227         nextFrame.close();
228       }
229     } finally {
230       fullStreamDecompressor = null;
231       unprocessed = null;
232       nextFrame = null;
233     }
234     listener.deframerClosed(hasPartialMessage);
235   }
236 
237   /**
238    * Indicates whether or not this deframer has been closed.
239    */
isClosed()240   public boolean isClosed() {
241     return unprocessed == null && fullStreamDecompressor == null;
242   }
243 
244   /** Returns true if this deframer has already been closed or scheduled to close. */
isClosedOrScheduledToClose()245   private boolean isClosedOrScheduledToClose() {
246     return isClosed() || closeWhenComplete;
247   }
248 
isStalled()249   private boolean isStalled() {
250     if (fullStreamDecompressor != null) {
251       return fullStreamDecompressor.isStalled();
252     } else {
253       return unprocessed.readableBytes() == 0;
254     }
255   }
256 
257   /**
258    * Reads and delivers as many messages to the listener as possible.
259    */
deliver()260   private void deliver() {
261     // We can have reentrancy here when using a direct executor, triggered by calls to
262     // request more messages. This is safe as we simply loop until pendingDelivers = 0
263     if (inDelivery) {
264       return;
265     }
266     inDelivery = true;
267     try {
268       // Process the uncompressed bytes.
269       while (!stopDelivery && pendingDeliveries > 0 && readRequiredBytes()) {
270         switch (state) {
271           case HEADER:
272             processHeader();
273             break;
274           case BODY:
275             // Read the body and deliver the message.
276             processBody();
277 
278             // Since we've delivered a message, decrement the number of pending
279             // deliveries remaining.
280             pendingDeliveries--;
281             break;
282           default:
283             throw new AssertionError("Invalid state: " + state);
284         }
285       }
286 
287       if (stopDelivery) {
288         close();
289         return;
290       }
291 
292       /*
293        * We are stalled when there are no more bytes to process. This allows delivering errors as
294        * soon as the buffered input has been consumed, independent of whether the application
295        * has requested another message.  At this point in the function, either all frames have been
296        * delivered, or unprocessed is empty.  If there is a partial message, it will be inside next
297        * frame and not in unprocessed.  If there is extra data but no pending deliveries, it will
298        * be in unprocessed.
299        */
300       if (closeWhenComplete && isStalled()) {
301         close();
302       }
303     } finally {
304       inDelivery = false;
305     }
306   }
307 
308   /**
309    * Attempts to read the required bytes into nextFrame.
310    *
311    * @return {@code true} if all of the required bytes have been read.
312    */
readRequiredBytes()313   private boolean readRequiredBytes() {
314     int totalBytesRead = 0;
315     int deflatedBytesRead = 0;
316     try {
317       if (nextFrame == null) {
318         nextFrame = new CompositeReadableBuffer();
319       }
320 
321       // Read until the buffer contains all the required bytes.
322       int missingBytes;
323       while ((missingBytes = requiredLength - nextFrame.readableBytes()) > 0) {
324         if (fullStreamDecompressor != null) {
325           try {
326             if (inflatedBuffer == null || inflatedIndex == inflatedBuffer.length) {
327               inflatedBuffer = new byte[Math.min(missingBytes, MAX_BUFFER_SIZE)];
328               inflatedIndex = 0;
329             }
330             int bytesToRead = Math.min(missingBytes, inflatedBuffer.length - inflatedIndex);
331             int n = fullStreamDecompressor.inflateBytes(inflatedBuffer, inflatedIndex, bytesToRead);
332             totalBytesRead += fullStreamDecompressor.getAndResetBytesConsumed();
333             deflatedBytesRead += fullStreamDecompressor.getAndResetDeflatedBytesConsumed();
334             if (n == 0) {
335               // No more inflated data is available.
336               return false;
337             }
338             nextFrame.addBuffer(ReadableBuffers.wrap(inflatedBuffer, inflatedIndex, n));
339             inflatedIndex += n;
340           } catch (IOException e) {
341             throw new RuntimeException(e);
342           } catch (DataFormatException e) {
343             throw new RuntimeException(e);
344           }
345         } else {
346           if (unprocessed.readableBytes() == 0) {
347             // No more data is available.
348             return false;
349           }
350           int toRead = Math.min(missingBytes, unprocessed.readableBytes());
351           totalBytesRead += toRead;
352           nextFrame.addBuffer(unprocessed.readBytes(toRead));
353         }
354       }
355       return true;
356     } finally {
357       if (totalBytesRead > 0) {
358         listener.bytesRead(totalBytesRead);
359         if (state == State.BODY) {
360           if (fullStreamDecompressor != null) {
361             // With compressed streams, totalBytesRead can include gzip header and trailer metadata
362             statsTraceCtx.inboundWireSize(deflatedBytesRead);
363             inboundBodyWireSize += deflatedBytesRead;
364           } else {
365             statsTraceCtx.inboundWireSize(totalBytesRead);
366             inboundBodyWireSize += totalBytesRead;
367           }
368         }
369       }
370     }
371   }
372 
373   /**
374    * Processes the GRPC compression header which is composed of the compression flag and the outer
375    * frame length.
376    */
processHeader()377   private void processHeader() {
378     int type = nextFrame.readUnsignedByte();
379     if ((type & RESERVED_MASK) != 0) {
380       throw Status.INTERNAL.withDescription(
381           "gRPC frame header malformed: reserved bits not zero")
382           .asRuntimeException();
383     }
384     compressedFlag = (type & COMPRESSED_FLAG_MASK) != 0;
385 
386     // Update the required length to include the length of the frame.
387     requiredLength = nextFrame.readInt();
388     if (requiredLength < 0 || requiredLength > maxInboundMessageSize) {
389       throw Status.RESOURCE_EXHAUSTED.withDescription(
390           String.format(Locale.US, "gRPC message exceeds maximum size %d: %d",
391               maxInboundMessageSize, requiredLength))
392           .asRuntimeException();
393     }
394 
395     currentMessageSeqNo++;
396     statsTraceCtx.inboundMessage(currentMessageSeqNo);
397     transportTracer.reportMessageReceived();
398     // Continue reading the frame body.
399     state = State.BODY;
400   }
401 
402   /**
403    * Processes the GRPC message body, which depending on frame header flags may be compressed.
404    */
processBody()405   private void processBody() {
406     // There is no reliable way to get the uncompressed size per message when it's compressed,
407     // because the uncompressed bytes are provided through an InputStream whose total size is
408     // unknown until all bytes are read, and we don't know when it happens.
409     statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, -1);
410     inboundBodyWireSize = 0;
411     InputStream stream = compressedFlag ? getCompressedBody() : getUncompressedBody();
412     nextFrame = null;
413     listener.messagesAvailable(new SingleMessageProducer(stream));
414 
415     // Done with this frame, begin processing the next header.
416     state = State.HEADER;
417     requiredLength = HEADER_LENGTH;
418   }
419 
getUncompressedBody()420   private InputStream getUncompressedBody() {
421     statsTraceCtx.inboundUncompressedSize(nextFrame.readableBytes());
422     return ReadableBuffers.openStream(nextFrame, true);
423   }
424 
getCompressedBody()425   private InputStream getCompressedBody() {
426     if (decompressor == Codec.Identity.NONE) {
427       throw Status.INTERNAL.withDescription(
428           "Can't decode compressed gRPC message as compression not configured")
429           .asRuntimeException();
430     }
431 
432     try {
433       // Enforce the maxMessageSize limit on the returned stream.
434       InputStream unlimitedStream =
435           decompressor.decompress(ReadableBuffers.openStream(nextFrame, true));
436       return new SizeEnforcingInputStream(
437           unlimitedStream, maxInboundMessageSize, statsTraceCtx);
438     } catch (IOException e) {
439       throw new RuntimeException(e);
440     }
441   }
442 
443   /**
444    * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames.
445    */
446   @VisibleForTesting
447   static final class SizeEnforcingInputStream extends FilterInputStream {
448     private final int maxMessageSize;
449     private final StatsTraceContext statsTraceCtx;
450     private long maxCount;
451     private long count;
452     private long mark = -1;
453 
SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx)454     SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx) {
455       super(in);
456       this.maxMessageSize = maxMessageSize;
457       this.statsTraceCtx = statsTraceCtx;
458     }
459 
460     @Override
read()461     public int read() throws IOException {
462       int result = in.read();
463       if (result != -1) {
464         count++;
465       }
466       verifySize();
467       reportCount();
468       return result;
469     }
470 
471     @Override
read(byte[] b, int off, int len)472     public int read(byte[] b, int off, int len) throws IOException {
473       int result = in.read(b, off, len);
474       if (result != -1) {
475         count += result;
476       }
477       verifySize();
478       reportCount();
479       return result;
480     }
481 
482     @Override
skip(long n)483     public long skip(long n) throws IOException {
484       long result = in.skip(n);
485       count += result;
486       verifySize();
487       reportCount();
488       return result;
489     }
490 
491     @Override
mark(int readlimit)492     public synchronized void mark(int readlimit) {
493       in.mark(readlimit);
494       mark = count;
495       // it's okay to mark even if mark isn't supported, as reset won't work
496     }
497 
498     @Override
reset()499     public synchronized void reset() throws IOException {
500       if (!in.markSupported()) {
501         throw new IOException("Mark not supported");
502       }
503       if (mark == -1) {
504         throw new IOException("Mark not set");
505       }
506 
507       in.reset();
508       count = mark;
509     }
510 
reportCount()511     private void reportCount() {
512       if (count > maxCount) {
513         statsTraceCtx.inboundUncompressedSize(count - maxCount);
514         maxCount = count;
515       }
516     }
517 
verifySize()518     private void verifySize() {
519       if (count > maxMessageSize) {
520         throw Status.RESOURCE_EXHAUSTED
521             .withDescription("Decompressed gRPC message exceeds maximum size " + maxMessageSize)
522             .asRuntimeException();
523       }
524     }
525   }
526 
527   private static class SingleMessageProducer implements StreamListener.MessageProducer {
528     private InputStream message;
529 
SingleMessageProducer(InputStream message)530     private SingleMessageProducer(InputStream message) {
531       this.message = message;
532     }
533 
534     @Nullable
535     @Override
next()536     public InputStream next() {
537       InputStream messageToReturn = message;
538       message = null;
539       return messageToReturn;
540     }
541   }
542 }
543