1 /* 2 * Copyright 2014 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.internal; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 import static com.google.common.base.Preconditions.checkState; 22 23 import com.google.common.annotations.VisibleForTesting; 24 import io.grpc.Codec; 25 import io.grpc.Decompressor; 26 import io.grpc.Status; 27 import java.io.Closeable; 28 import java.io.FilterInputStream; 29 import java.io.IOException; 30 import java.io.InputStream; 31 import java.util.Locale; 32 import java.util.zip.DataFormatException; 33 import javax.annotation.Nullable; 34 import javax.annotation.concurrent.NotThreadSafe; 35 36 /** 37 * Deframer for GRPC frames. 38 * 39 * <p>This class is not thread-safe. Unless otherwise stated, all calls to public methods should be 40 * made in the deframing thread. 41 */ 42 @NotThreadSafe 43 public class MessageDeframer implements Closeable, Deframer { 44 private static final int HEADER_LENGTH = 5; 45 private static final int COMPRESSED_FLAG_MASK = 1; 46 private static final int RESERVED_MASK = 0xFE; 47 private static final int MAX_BUFFER_SIZE = 1024 * 1024 * 2; 48 49 /** 50 * A listener of deframing events. These methods will be invoked from the deframing thread. 51 */ 52 public interface Listener { 53 54 /** 55 * Called when the given number of bytes has been read from the input source of the deframer. 56 * This is typically used to indicate to the underlying transport that more data can be 57 * accepted. 58 * 59 * @param numBytes the number of bytes read from the deframer's input source. 60 */ bytesRead(int numBytes)61 void bytesRead(int numBytes); 62 63 /** 64 * Called to deliver the next complete message. 65 * 66 * @param producer single message producer wrapping the message. 67 */ messagesAvailable(StreamListener.MessageProducer producer)68 void messagesAvailable(StreamListener.MessageProducer producer); 69 70 /** 71 * Called when the deframer closes. 72 * 73 * @param hasPartialMessage whether the deframer contained an incomplete message at closing. 74 */ deframerClosed(boolean hasPartialMessage)75 void deframerClosed(boolean hasPartialMessage); 76 77 /** 78 * Called when a {@link #deframe(ReadableBuffer)} operation failed. 79 * 80 * @param cause the actual failure 81 */ deframeFailed(Throwable cause)82 void deframeFailed(Throwable cause); 83 } 84 85 private enum State { 86 HEADER, BODY 87 } 88 89 private Listener listener; 90 private int maxInboundMessageSize; 91 private final StatsTraceContext statsTraceCtx; 92 private final TransportTracer transportTracer; 93 private Decompressor decompressor; 94 private GzipInflatingBuffer fullStreamDecompressor; 95 private byte[] inflatedBuffer; 96 private int inflatedIndex; 97 private State state = State.HEADER; 98 private int requiredLength = HEADER_LENGTH; 99 private boolean compressedFlag; 100 private CompositeReadableBuffer nextFrame; 101 private CompositeReadableBuffer unprocessed = new CompositeReadableBuffer(); 102 private long pendingDeliveries; 103 private boolean inDelivery = false; 104 private int currentMessageSeqNo = -1; 105 private int inboundBodyWireSize; 106 107 private boolean closeWhenComplete = false; 108 private volatile boolean stopDelivery = false; 109 110 /** 111 * Create a deframer. 112 * 113 * @param listener listener for deframer events. 114 * @param decompressor the compression used if a compressed frame is encountered, with 115 * {@code NONE} meaning unsupported 116 * @param maxMessageSize the maximum allowed size for received messages. 117 */ MessageDeframer( Listener listener, Decompressor decompressor, int maxMessageSize, StatsTraceContext statsTraceCtx, TransportTracer transportTracer)118 public MessageDeframer( 119 Listener listener, 120 Decompressor decompressor, 121 int maxMessageSize, 122 StatsTraceContext statsTraceCtx, 123 TransportTracer transportTracer) { 124 this.listener = checkNotNull(listener, "sink"); 125 this.decompressor = checkNotNull(decompressor, "decompressor"); 126 this.maxInboundMessageSize = maxMessageSize; 127 this.statsTraceCtx = checkNotNull(statsTraceCtx, "statsTraceCtx"); 128 this.transportTracer = checkNotNull(transportTracer, "transportTracer"); 129 } 130 setListener(Listener listener)131 void setListener(Listener listener) { 132 this.listener = listener; 133 } 134 135 @Override setMaxInboundMessageSize(int messageSize)136 public void setMaxInboundMessageSize(int messageSize) { 137 maxInboundMessageSize = messageSize; 138 } 139 140 @Override setDecompressor(Decompressor decompressor)141 public void setDecompressor(Decompressor decompressor) { 142 checkState(fullStreamDecompressor == null, "Already set full stream decompressor"); 143 this.decompressor = checkNotNull(decompressor, "Can't pass an empty decompressor"); 144 } 145 146 @Override setFullStreamDecompressor(GzipInflatingBuffer fullStreamDecompressor)147 public void setFullStreamDecompressor(GzipInflatingBuffer fullStreamDecompressor) { 148 checkState(decompressor == Codec.Identity.NONE, "per-message decompressor already set"); 149 checkState(this.fullStreamDecompressor == null, "full stream decompressor already set"); 150 this.fullStreamDecompressor = 151 checkNotNull(fullStreamDecompressor, "Can't pass a null full stream decompressor"); 152 unprocessed = null; 153 } 154 155 @Override request(int numMessages)156 public void request(int numMessages) { 157 checkArgument(numMessages > 0, "numMessages must be > 0"); 158 if (isClosed()) { 159 return; 160 } 161 pendingDeliveries += numMessages; 162 deliver(); 163 } 164 165 @Override deframe(ReadableBuffer data)166 public void deframe(ReadableBuffer data) { 167 checkNotNull(data, "data"); 168 boolean needToCloseData = true; 169 try { 170 if (!isClosedOrScheduledToClose()) { 171 if (fullStreamDecompressor != null) { 172 fullStreamDecompressor.addGzippedBytes(data); 173 } else { 174 unprocessed.addBuffer(data); 175 } 176 needToCloseData = false; 177 178 deliver(); 179 } 180 } finally { 181 if (needToCloseData) { 182 data.close(); 183 } 184 } 185 } 186 187 @Override closeWhenComplete()188 public void closeWhenComplete() { 189 if (isClosed()) { 190 return; 191 } else if (isStalled()) { 192 close(); 193 } else { 194 closeWhenComplete = true; 195 } 196 } 197 198 /** 199 * Sets a flag to interrupt delivery of any currently queued messages. This may be invoked outside 200 * of the deframing thread, and must be followed by a call to {@link #close()} in the deframing 201 * thread. Without a subsequent call to {@link #close()}, the deframer may hang waiting for 202 * additional messages before noticing that the {@code stopDelivery} flag has been set. 203 */ stopDelivery()204 void stopDelivery() { 205 stopDelivery = true; 206 } 207 hasPendingDeliveries()208 boolean hasPendingDeliveries() { 209 return pendingDeliveries != 0; 210 } 211 212 @Override close()213 public void close() { 214 if (isClosed()) { 215 return; 216 } 217 boolean hasPartialMessage = nextFrame != null && nextFrame.readableBytes() > 0; 218 try { 219 if (fullStreamDecompressor != null) { 220 hasPartialMessage = hasPartialMessage || fullStreamDecompressor.hasPartialData(); 221 fullStreamDecompressor.close(); 222 } 223 if (unprocessed != null) { 224 unprocessed.close(); 225 } 226 if (nextFrame != null) { 227 nextFrame.close(); 228 } 229 } finally { 230 fullStreamDecompressor = null; 231 unprocessed = null; 232 nextFrame = null; 233 } 234 listener.deframerClosed(hasPartialMessage); 235 } 236 237 /** 238 * Indicates whether or not this deframer has been closed. 239 */ isClosed()240 public boolean isClosed() { 241 return unprocessed == null && fullStreamDecompressor == null; 242 } 243 244 /** Returns true if this deframer has already been closed or scheduled to close. */ isClosedOrScheduledToClose()245 private boolean isClosedOrScheduledToClose() { 246 return isClosed() || closeWhenComplete; 247 } 248 isStalled()249 private boolean isStalled() { 250 if (fullStreamDecompressor != null) { 251 return fullStreamDecompressor.isStalled(); 252 } else { 253 return unprocessed.readableBytes() == 0; 254 } 255 } 256 257 /** 258 * Reads and delivers as many messages to the listener as possible. 259 */ deliver()260 private void deliver() { 261 // We can have reentrancy here when using a direct executor, triggered by calls to 262 // request more messages. This is safe as we simply loop until pendingDelivers = 0 263 if (inDelivery) { 264 return; 265 } 266 inDelivery = true; 267 try { 268 // Process the uncompressed bytes. 269 while (!stopDelivery && pendingDeliveries > 0 && readRequiredBytes()) { 270 switch (state) { 271 case HEADER: 272 processHeader(); 273 break; 274 case BODY: 275 // Read the body and deliver the message. 276 processBody(); 277 278 // Since we've delivered a message, decrement the number of pending 279 // deliveries remaining. 280 pendingDeliveries--; 281 break; 282 default: 283 throw new AssertionError("Invalid state: " + state); 284 } 285 } 286 287 if (stopDelivery) { 288 close(); 289 return; 290 } 291 292 /* 293 * We are stalled when there are no more bytes to process. This allows delivering errors as 294 * soon as the buffered input has been consumed, independent of whether the application 295 * has requested another message. At this point in the function, either all frames have been 296 * delivered, or unprocessed is empty. If there is a partial message, it will be inside next 297 * frame and not in unprocessed. If there is extra data but no pending deliveries, it will 298 * be in unprocessed. 299 */ 300 if (closeWhenComplete && isStalled()) { 301 close(); 302 } 303 } finally { 304 inDelivery = false; 305 } 306 } 307 308 /** 309 * Attempts to read the required bytes into nextFrame. 310 * 311 * @return {@code true} if all of the required bytes have been read. 312 */ readRequiredBytes()313 private boolean readRequiredBytes() { 314 int totalBytesRead = 0; 315 int deflatedBytesRead = 0; 316 try { 317 if (nextFrame == null) { 318 nextFrame = new CompositeReadableBuffer(); 319 } 320 321 // Read until the buffer contains all the required bytes. 322 int missingBytes; 323 while ((missingBytes = requiredLength - nextFrame.readableBytes()) > 0) { 324 if (fullStreamDecompressor != null) { 325 try { 326 if (inflatedBuffer == null || inflatedIndex == inflatedBuffer.length) { 327 inflatedBuffer = new byte[Math.min(missingBytes, MAX_BUFFER_SIZE)]; 328 inflatedIndex = 0; 329 } 330 int bytesToRead = Math.min(missingBytes, inflatedBuffer.length - inflatedIndex); 331 int n = fullStreamDecompressor.inflateBytes(inflatedBuffer, inflatedIndex, bytesToRead); 332 totalBytesRead += fullStreamDecompressor.getAndResetBytesConsumed(); 333 deflatedBytesRead += fullStreamDecompressor.getAndResetDeflatedBytesConsumed(); 334 if (n == 0) { 335 // No more inflated data is available. 336 return false; 337 } 338 nextFrame.addBuffer(ReadableBuffers.wrap(inflatedBuffer, inflatedIndex, n)); 339 inflatedIndex += n; 340 } catch (IOException e) { 341 throw new RuntimeException(e); 342 } catch (DataFormatException e) { 343 throw new RuntimeException(e); 344 } 345 } else { 346 if (unprocessed.readableBytes() == 0) { 347 // No more data is available. 348 return false; 349 } 350 int toRead = Math.min(missingBytes, unprocessed.readableBytes()); 351 totalBytesRead += toRead; 352 nextFrame.addBuffer(unprocessed.readBytes(toRead)); 353 } 354 } 355 return true; 356 } finally { 357 if (totalBytesRead > 0) { 358 listener.bytesRead(totalBytesRead); 359 if (state == State.BODY) { 360 if (fullStreamDecompressor != null) { 361 // With compressed streams, totalBytesRead can include gzip header and trailer metadata 362 statsTraceCtx.inboundWireSize(deflatedBytesRead); 363 inboundBodyWireSize += deflatedBytesRead; 364 } else { 365 statsTraceCtx.inboundWireSize(totalBytesRead); 366 inboundBodyWireSize += totalBytesRead; 367 } 368 } 369 } 370 } 371 } 372 373 /** 374 * Processes the GRPC compression header which is composed of the compression flag and the outer 375 * frame length. 376 */ processHeader()377 private void processHeader() { 378 int type = nextFrame.readUnsignedByte(); 379 if ((type & RESERVED_MASK) != 0) { 380 throw Status.INTERNAL.withDescription( 381 "gRPC frame header malformed: reserved bits not zero") 382 .asRuntimeException(); 383 } 384 compressedFlag = (type & COMPRESSED_FLAG_MASK) != 0; 385 386 // Update the required length to include the length of the frame. 387 requiredLength = nextFrame.readInt(); 388 if (requiredLength < 0 || requiredLength > maxInboundMessageSize) { 389 throw Status.RESOURCE_EXHAUSTED.withDescription( 390 String.format(Locale.US, "gRPC message exceeds maximum size %d: %d", 391 maxInboundMessageSize, requiredLength)) 392 .asRuntimeException(); 393 } 394 395 currentMessageSeqNo++; 396 statsTraceCtx.inboundMessage(currentMessageSeqNo); 397 transportTracer.reportMessageReceived(); 398 // Continue reading the frame body. 399 state = State.BODY; 400 } 401 402 /** 403 * Processes the GRPC message body, which depending on frame header flags may be compressed. 404 */ processBody()405 private void processBody() { 406 // There is no reliable way to get the uncompressed size per message when it's compressed, 407 // because the uncompressed bytes are provided through an InputStream whose total size is 408 // unknown until all bytes are read, and we don't know when it happens. 409 statsTraceCtx.inboundMessageRead(currentMessageSeqNo, inboundBodyWireSize, -1); 410 inboundBodyWireSize = 0; 411 InputStream stream = compressedFlag ? getCompressedBody() : getUncompressedBody(); 412 nextFrame = null; 413 listener.messagesAvailable(new SingleMessageProducer(stream)); 414 415 // Done with this frame, begin processing the next header. 416 state = State.HEADER; 417 requiredLength = HEADER_LENGTH; 418 } 419 getUncompressedBody()420 private InputStream getUncompressedBody() { 421 statsTraceCtx.inboundUncompressedSize(nextFrame.readableBytes()); 422 return ReadableBuffers.openStream(nextFrame, true); 423 } 424 getCompressedBody()425 private InputStream getCompressedBody() { 426 if (decompressor == Codec.Identity.NONE) { 427 throw Status.INTERNAL.withDescription( 428 "Can't decode compressed gRPC message as compression not configured") 429 .asRuntimeException(); 430 } 431 432 try { 433 // Enforce the maxMessageSize limit on the returned stream. 434 InputStream unlimitedStream = 435 decompressor.decompress(ReadableBuffers.openStream(nextFrame, true)); 436 return new SizeEnforcingInputStream( 437 unlimitedStream, maxInboundMessageSize, statsTraceCtx); 438 } catch (IOException e) { 439 throw new RuntimeException(e); 440 } 441 } 442 443 /** 444 * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames. 445 */ 446 @VisibleForTesting 447 static final class SizeEnforcingInputStream extends FilterInputStream { 448 private final int maxMessageSize; 449 private final StatsTraceContext statsTraceCtx; 450 private long maxCount; 451 private long count; 452 private long mark = -1; 453 SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx)454 SizeEnforcingInputStream(InputStream in, int maxMessageSize, StatsTraceContext statsTraceCtx) { 455 super(in); 456 this.maxMessageSize = maxMessageSize; 457 this.statsTraceCtx = statsTraceCtx; 458 } 459 460 @Override read()461 public int read() throws IOException { 462 int result = in.read(); 463 if (result != -1) { 464 count++; 465 } 466 verifySize(); 467 reportCount(); 468 return result; 469 } 470 471 @Override read(byte[] b, int off, int len)472 public int read(byte[] b, int off, int len) throws IOException { 473 int result = in.read(b, off, len); 474 if (result != -1) { 475 count += result; 476 } 477 verifySize(); 478 reportCount(); 479 return result; 480 } 481 482 @Override skip(long n)483 public long skip(long n) throws IOException { 484 long result = in.skip(n); 485 count += result; 486 verifySize(); 487 reportCount(); 488 return result; 489 } 490 491 @Override mark(int readlimit)492 public synchronized void mark(int readlimit) { 493 in.mark(readlimit); 494 mark = count; 495 // it's okay to mark even if mark isn't supported, as reset won't work 496 } 497 498 @Override reset()499 public synchronized void reset() throws IOException { 500 if (!in.markSupported()) { 501 throw new IOException("Mark not supported"); 502 } 503 if (mark == -1) { 504 throw new IOException("Mark not set"); 505 } 506 507 in.reset(); 508 count = mark; 509 } 510 reportCount()511 private void reportCount() { 512 if (count > maxCount) { 513 statsTraceCtx.inboundUncompressedSize(count - maxCount); 514 maxCount = count; 515 } 516 } 517 verifySize()518 private void verifySize() { 519 if (count > maxMessageSize) { 520 throw Status.RESOURCE_EXHAUSTED 521 .withDescription("Decompressed gRPC message exceeds maximum size " + maxMessageSize) 522 .asRuntimeException(); 523 } 524 } 525 } 526 527 private static class SingleMessageProducer implements StreamListener.MessageProducer { 528 private InputStream message; 529 SingleMessageProducer(InputStream message)530 private SingleMessageProducer(InputStream message) { 531 this.message = message; 532 } 533 534 @Nullable 535 @Override next()536 public InputStream next() { 537 InputStream messageToReturn = message; 538 message = null; 539 return messageToReturn; 540 } 541 } 542 } 543