1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkNotNull; 20 import static com.google.common.base.Preconditions.checkState; 21 22 import com.google.common.annotations.VisibleForTesting; 23 import io.grpc.alts.internal.TsiFrameProtector.Consumer; 24 import io.grpc.alts.internal.TsiHandshakeHandler.TsiHandshakeCompletionEvent; 25 import io.netty.buffer.ByteBuf; 26 import io.netty.channel.ChannelException; 27 import io.netty.channel.ChannelHandlerContext; 28 import io.netty.channel.ChannelOutboundHandler; 29 import io.netty.channel.ChannelPromise; 30 import io.netty.channel.PendingWriteQueue; 31 import io.netty.handler.codec.ByteToMessageDecoder; 32 import java.net.SocketAddress; 33 import java.security.GeneralSecurityException; 34 import java.util.ArrayList; 35 import java.util.List; 36 import java.util.concurrent.Future; 37 38 /** 39 * Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes 40 * must not be made before the TSI handshake is complete. 41 */ 42 public final class TsiFrameHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { 43 44 private TsiFrameProtector protector; 45 private PendingWriteQueue pendingUnprotectedWrites; 46 TsiFrameHandler()47 public TsiFrameHandler() {} 48 49 @Override handlerAdded(ChannelHandlerContext ctx)50 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 51 super.handlerAdded(ctx); 52 assert pendingUnprotectedWrites == null; 53 pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx)); 54 } 55 56 @Override userEventTriggered(ChannelHandlerContext ctx, Object event)57 public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception { 58 if (event instanceof TsiHandshakeCompletionEvent) { 59 TsiHandshakeCompletionEvent tsiEvent = (TsiHandshakeCompletionEvent) event; 60 if (tsiEvent.isSuccess()) { 61 setProtector(tsiEvent.protector()); 62 } 63 // Ignore errors. Another handler in the pipeline must handle TSI Errors. 64 } 65 // Keep propagating the message, as others may want to read it. 66 super.userEventTriggered(ctx, event); 67 } 68 69 @VisibleForTesting setProtector(TsiFrameProtector protector)70 void setProtector(TsiFrameProtector protector) { 71 checkState(this.protector == null); 72 this.protector = checkNotNull(protector); 73 } 74 75 @Override decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)76 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { 77 checkState(protector != null, "Cannot read frames while the TSI handshake is in progress"); 78 protector.unprotect(in, out, ctx.alloc()); 79 } 80 81 @Override write(ChannelHandlerContext ctx, Object message, ChannelPromise promise)82 public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) 83 throws Exception { 84 checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); 85 ByteBuf msg = (ByteBuf) message; 86 if (!msg.isReadable()) { 87 // Nothing to encode. 88 @SuppressWarnings("unused") // go/futurereturn-lsc 89 Future<?> possiblyIgnoredError = promise.setSuccess(); 90 return; 91 } 92 93 // Just add the message to the pending queue. We'll write it on the next flush. 94 pendingUnprotectedWrites.add(msg, promise); 95 } 96 97 @Override handlerRemoved0(ChannelHandlerContext ctx)98 public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { 99 if (!pendingUnprotectedWrites.isEmpty()) { 100 pendingUnprotectedWrites.removeAndFailAll( 101 new ChannelException("Pending write on removal of TSI handler")); 102 } 103 } 104 105 @Override exceptionCaught(ChannelHandlerContext ctx, Throwable cause)106 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { 107 pendingUnprotectedWrites.removeAndFailAll(cause); 108 super.exceptionCaught(ctx, cause); 109 } 110 111 @Override bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise)112 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { 113 ctx.bind(localAddress, promise); 114 } 115 116 @Override connect( ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise)117 public void connect( 118 ChannelHandlerContext ctx, 119 SocketAddress remoteAddress, 120 SocketAddress localAddress, 121 ChannelPromise promise) { 122 ctx.connect(remoteAddress, localAddress, promise); 123 } 124 125 @Override disconnect(ChannelHandlerContext ctx, ChannelPromise promise)126 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { 127 ctx.disconnect(promise); 128 } 129 130 @Override close(ChannelHandlerContext ctx, ChannelPromise promise)131 public void close(ChannelHandlerContext ctx, ChannelPromise promise) { 132 ctx.close(promise); 133 } 134 135 @Override deregister(ChannelHandlerContext ctx, ChannelPromise promise)136 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { 137 ctx.deregister(promise); 138 } 139 140 @Override read(ChannelHandlerContext ctx)141 public void read(ChannelHandlerContext ctx) { 142 ctx.read(); 143 } 144 145 @Override flush(final ChannelHandlerContext ctx)146 public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { 147 checkState(protector != null, "Cannot write frames while the TSI handshake is in progress"); 148 final ProtectedPromise aggregatePromise = 149 new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); 150 151 List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size()); 152 153 if (pendingUnprotectedWrites.isEmpty()) { 154 // Return early if there's nothing to write. Otherwise protector.protectFlush() below may 155 // not check for "no-data" and go on writing the 0-byte "data" to the socket with the 156 // protection framing. 157 return; 158 } 159 // Drain the unprotected writes. 160 while (!pendingUnprotectedWrites.isEmpty()) { 161 ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current(); 162 bufs.add(in.retain()); 163 // Remove and release the buffer and add its promise to the aggregate. 164 aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove()); 165 } 166 167 protector.protectFlush( 168 bufs, 169 new Consumer<ByteBuf>() { 170 @Override 171 public void accept(ByteBuf b) { 172 ctx.writeAndFlush(b, aggregatePromise.newPromise()); 173 } 174 }, 175 ctx.alloc()); 176 177 // We're done writing, start the flow of promise events. 178 @SuppressWarnings("unused") // go/futurereturn-lsc 179 Future<?> possiblyIgnoredError = aggregatePromise.doneAllocatingPromises(); 180 } 181 } 182