1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkNotNull; 20 import static com.google.common.base.Preconditions.checkState; 21 import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY; 22 import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY; 23 24 import io.grpc.Attributes; 25 import io.grpc.ChannelLogger; 26 import io.grpc.ChannelLogger.ChannelLogLevel; 27 import io.grpc.InternalChannelz.Security; 28 import io.grpc.SecurityLevel; 29 import io.grpc.alts.internal.TsiHandshakeHandler.HandshakeValidator.SecurityDetails; 30 import io.grpc.internal.GrpcAttributes; 31 import io.grpc.netty.InternalProtocolNegotiationEvent; 32 import io.grpc.netty.ProtocolNegotiationEvent; 33 import io.netty.buffer.ByteBuf; 34 import io.netty.channel.ChannelFuture; 35 import io.netty.channel.ChannelFutureListener; 36 import io.netty.channel.ChannelHandler; 37 import io.netty.channel.ChannelHandlerContext; 38 import io.netty.handler.codec.ByteToMessageDecoder; 39 import java.security.GeneralSecurityException; 40 import java.util.List; 41 import javax.annotation.Nullable; 42 43 /** 44 * Performs The TSI Handshake. 45 */ 46 public final class TsiHandshakeHandler extends ByteToMessageDecoder { 47 /** 48 * Validates a Tsi Peer object. 49 */ 50 public abstract static class HandshakeValidator { 51 52 public static final class SecurityDetails { 53 54 private final SecurityLevel securityLevel; 55 private final Security security; 56 57 /** 58 * Constructs SecurityDetails. 59 */ SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security)60 public SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security) { 61 this.securityLevel = checkNotNull(securityLevel, "securityLevel"); 62 this.security = security; 63 } 64 getSecurity()65 public Security getSecurity() { 66 return security; 67 } 68 getSecurityLevel()69 public SecurityLevel getSecurityLevel() { 70 return securityLevel; 71 } 72 } 73 74 /** 75 * Validates a Tsi Peer object. 76 */ validatePeerObject(Object peerObject)77 public abstract SecurityDetails validatePeerObject(Object peerObject) 78 throws GeneralSecurityException; 79 } 80 81 private static final int HANDSHAKE_FRAME_SIZE = 1024; 82 83 private final NettyTsiHandshaker handshaker; 84 private final HandshakeValidator handshakeValidator; 85 private final ChannelHandler next; 86 private final AsyncSemaphore semaphore; 87 88 private ProtocolNegotiationEvent pne; 89 private boolean semaphoreAcquired; 90 private final ChannelLogger negotiationLogger; 91 92 /** 93 * Constructs a TsiHandshakeHandler. 94 */ TsiHandshakeHandler( ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, ChannelLogger negotiationLogger)95 public TsiHandshakeHandler( 96 ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, 97 ChannelLogger negotiationLogger) { 98 this(next, handshaker, handshakeValidator, null, negotiationLogger); 99 } 100 101 /** 102 * Constructs a TsHandshakeHandler. If a semaphore is provided, a permit from the semaphore is 103 * required to start the handshake and is returned when the handshake ends. 104 */ TsiHandshakeHandler( ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, AsyncSemaphore semaphore, ChannelLogger negotiationLogger)105 public TsiHandshakeHandler( 106 ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, 107 AsyncSemaphore semaphore, ChannelLogger negotiationLogger) { 108 this.handshaker = checkNotNull(handshaker, "handshaker"); 109 this.handshakeValidator = checkNotNull(handshakeValidator, "handshakeValidator"); 110 this.next = checkNotNull(next, "next"); 111 this.semaphore = semaphore; 112 this.negotiationLogger = negotiationLogger; 113 } 114 115 @Override decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)116 protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) 117 throws Exception { 118 // TODO: Not sure why override is needed. Investigate if it can be removed. 119 decode(ctx, in, out); 120 } 121 122 @Override decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)123 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { 124 // Process the data. If we need to send more data, do so now. 125 if (handshaker.processBytesFromPeer(in) && handshaker.isInProgress()) { 126 sendHandshake(ctx); 127 } 128 129 // If the handshake is complete, transition to the framing state. 130 if (!handshaker.isInProgress()) { 131 TsiPeer peer = handshaker.extractPeer(); 132 Object authContext = handshaker.extractPeerObject(); 133 SecurityDetails details = handshakeValidator.validatePeerObject(authContext); 134 // createFrameProtector must be called last. 135 TsiFrameProtector protector = handshaker.createFrameProtector(ctx.alloc()); 136 TsiFrameHandler framer; 137 boolean success = false; 138 try { 139 framer = new TsiFrameHandler(protector); 140 // adding framer and next handler after this handler before removing Decoder (current 141 // handler). This will prevents any missing read from decoder and/or unframed write from 142 // next handler. 143 ctx.pipeline().addAfter(ctx.name(), null, framer); 144 ctx.pipeline().addAfter(ctx.pipeline().context(framer).name(), null, next); 145 ctx.pipeline().remove(ctx.name()); 146 fireProtocolNegotiationEvent(ctx, peer, authContext, details); 147 success = true; 148 } finally { 149 if (!success && protector != null) { 150 protector.destroy(); 151 } 152 } 153 } 154 } 155 156 @Override userEventTriggered(final ChannelHandlerContext ctx, Object evt)157 public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) throws Exception { 158 if (evt instanceof ProtocolNegotiationEvent) { 159 checkState(pne == null, "negotiation already started"); 160 pne = (ProtocolNegotiationEvent) evt; 161 negotiationLogger.log(ChannelLogLevel.INFO, "TsiHandshake started"); 162 ChannelFuture acquire = semaphoreAcquire(ctx); 163 if (acquire.isSuccess()) { 164 semaphoreAcquired = true; 165 sendHandshake(ctx); 166 } else { 167 acquire.addListener(new ChannelFutureListener() { 168 @Override public void operationComplete(ChannelFuture future) { 169 if (!future.isSuccess()) { 170 ctx.fireExceptionCaught(future.cause()); 171 return; 172 } 173 if (ctx.isRemoved()) { 174 semaphoreRelease(); 175 return; 176 } 177 semaphoreAcquired = true; 178 try { 179 sendHandshake(ctx); 180 } catch (Exception ex) { 181 ctx.fireExceptionCaught(ex); 182 } 183 ctx.flush(); 184 } 185 }); 186 } 187 } else { 188 super.userEventTriggered(ctx, evt); 189 } 190 } 191 fireProtocolNegotiationEvent( ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details)192 private void fireProtocolNegotiationEvent( 193 ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) { 194 checkState(pne != null, "negotiation not yet complete"); 195 negotiationLogger.log(ChannelLogLevel.INFO, "TsiHandshake finished"); 196 ProtocolNegotiationEvent localPne = pne; 197 Attributes.Builder attrs = InternalProtocolNegotiationEvent.getAttributes(localPne).toBuilder() 198 .set(TSI_PEER_KEY, peer) 199 .set(AUTH_CONTEXT_KEY, authContext) 200 .set(GrpcAttributes.ATTR_SECURITY_LEVEL, details.getSecurityLevel()); 201 localPne = InternalProtocolNegotiationEvent.withAttributes(localPne, attrs.build()); 202 localPne = InternalProtocolNegotiationEvent.withSecurity(localPne, details.getSecurity()); 203 ctx.fireUserEventTriggered(localPne); 204 } 205 206 /** Sends as many bytes as are available from the handshaker to the remote peer. */ 207 @SuppressWarnings("FutureReturnValueIgnored") // for addListener sendHandshake(ChannelHandlerContext ctx)208 private void sendHandshake(ChannelHandlerContext ctx) throws GeneralSecurityException { 209 while (true) { 210 boolean written = false; 211 ByteBuf buf = ctx.alloc().buffer(HANDSHAKE_FRAME_SIZE).retain(); // refcnt = 2 212 try { 213 handshaker.getBytesToSendToPeer(buf); 214 if (buf.isReadable()) { 215 ctx.writeAndFlush(buf).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); 216 written = true; 217 } else { 218 break; 219 } 220 } catch (GeneralSecurityException e) { 221 throw new GeneralSecurityException("TsiHandshakeHandler encountered exception", e); 222 } finally { 223 buf.release(written ? 1 : 2); 224 } 225 } 226 } 227 228 @Override handlerRemoved0(ChannelHandlerContext ctx)229 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { 230 if (semaphoreAcquired) { 231 semaphoreRelease(); 232 semaphoreAcquired = false; 233 } 234 handshaker.close(); 235 } 236 semaphoreAcquire(ChannelHandlerContext ctx)237 private ChannelFuture semaphoreAcquire(ChannelHandlerContext ctx) { 238 if (semaphore == null) { 239 return ctx.newSucceededFuture(); 240 } else { 241 return semaphore.acquire(ctx); 242 } 243 } 244 semaphoreRelease()245 private void semaphoreRelease() { 246 if (semaphore != null) { 247 semaphore.release(); 248 } 249 } 250 } 251