• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static com.google.common.base.Preconditions.checkState;
21 import static io.grpc.alts.internal.AltsProtocolNegotiator.AUTH_CONTEXT_KEY;
22 import static io.grpc.alts.internal.AltsProtocolNegotiator.TSI_PEER_KEY;
23 
24 import io.grpc.Attributes;
25 import io.grpc.ChannelLogger;
26 import io.grpc.ChannelLogger.ChannelLogLevel;
27 import io.grpc.InternalChannelz.Security;
28 import io.grpc.SecurityLevel;
29 import io.grpc.alts.internal.TsiHandshakeHandler.HandshakeValidator.SecurityDetails;
30 import io.grpc.internal.GrpcAttributes;
31 import io.grpc.netty.InternalProtocolNegotiationEvent;
32 import io.grpc.netty.ProtocolNegotiationEvent;
33 import io.netty.buffer.ByteBuf;
34 import io.netty.channel.ChannelFuture;
35 import io.netty.channel.ChannelFutureListener;
36 import io.netty.channel.ChannelHandler;
37 import io.netty.channel.ChannelHandlerContext;
38 import io.netty.handler.codec.ByteToMessageDecoder;
39 import java.security.GeneralSecurityException;
40 import java.util.List;
41 import javax.annotation.Nullable;
42 
43 /**
44  * Performs The TSI Handshake.
45  */
46 public final class TsiHandshakeHandler extends ByteToMessageDecoder {
47   /**
48    * Validates a Tsi Peer object.
49    */
50   public abstract static class HandshakeValidator {
51 
52     public static final class SecurityDetails {
53 
54       private final SecurityLevel securityLevel;
55       private final Security security;
56 
57       /**
58        * Constructs SecurityDetails.
59        */
SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security)60       public SecurityDetails(io.grpc.SecurityLevel securityLevel, @Nullable Security security) {
61         this.securityLevel = checkNotNull(securityLevel, "securityLevel");
62         this.security = security;
63       }
64 
getSecurity()65       public Security getSecurity() {
66         return security;
67       }
68 
getSecurityLevel()69       public SecurityLevel getSecurityLevel() {
70         return securityLevel;
71       }
72     }
73 
74     /**
75      * Validates a Tsi Peer object.
76      */
validatePeerObject(Object peerObject)77     public abstract SecurityDetails validatePeerObject(Object peerObject)
78         throws GeneralSecurityException;
79   }
80 
81   private static final int HANDSHAKE_FRAME_SIZE = 1024;
82 
83   private final NettyTsiHandshaker handshaker;
84   private final HandshakeValidator handshakeValidator;
85   private final ChannelHandler next;
86   private final AsyncSemaphore semaphore;
87 
88   private ProtocolNegotiationEvent pne;
89   private boolean semaphoreAcquired;
90   private final ChannelLogger negotiationLogger;
91 
92   /**
93    * Constructs a TsiHandshakeHandler.
94    */
TsiHandshakeHandler( ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, ChannelLogger negotiationLogger)95   public TsiHandshakeHandler(
96       ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator,
97       ChannelLogger negotiationLogger) {
98     this(next, handshaker, handshakeValidator, null, negotiationLogger);
99   }
100 
101   /**
102    * Constructs a TsHandshakeHandler. If a semaphore is provided, a permit from the semaphore is
103    * required to start the handshake and is returned when the handshake ends.
104    */
TsiHandshakeHandler( ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator, AsyncSemaphore semaphore, ChannelLogger negotiationLogger)105   public TsiHandshakeHandler(
106       ChannelHandler next, NettyTsiHandshaker handshaker, HandshakeValidator handshakeValidator,
107       AsyncSemaphore semaphore, ChannelLogger negotiationLogger) {
108     this.handshaker = checkNotNull(handshaker, "handshaker");
109     this.handshakeValidator = checkNotNull(handshakeValidator, "handshakeValidator");
110     this.next = checkNotNull(next, "next");
111     this.semaphore = semaphore;
112     this.negotiationLogger = negotiationLogger;
113   }
114 
115   @Override
decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)116   protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
117       throws Exception {
118     // TODO: Not sure why override is needed. Investigate if it can be removed.
119     decode(ctx, in, out);
120   }
121 
122   @Override
decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)123   protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
124     // Process the data. If we need to send more data, do so now.
125     if (handshaker.processBytesFromPeer(in) && handshaker.isInProgress()) {
126       sendHandshake(ctx);
127     }
128 
129     // If the handshake is complete, transition to the framing state.
130     if (!handshaker.isInProgress()) {
131       TsiPeer peer = handshaker.extractPeer();
132       Object authContext = handshaker.extractPeerObject();
133       SecurityDetails details = handshakeValidator.validatePeerObject(authContext);
134       // createFrameProtector must be called last.
135       TsiFrameProtector protector = handshaker.createFrameProtector(ctx.alloc());
136       TsiFrameHandler framer;
137       boolean success = false;
138       try {
139         framer = new TsiFrameHandler(protector);
140         // adding framer and next handler after this handler before removing Decoder (current
141         // handler). This will prevents any missing read from decoder and/or unframed write from
142         // next handler.
143         ctx.pipeline().addAfter(ctx.name(), null, framer);
144         ctx.pipeline().addAfter(ctx.pipeline().context(framer).name(), null, next);
145         ctx.pipeline().remove(ctx.name());
146         fireProtocolNegotiationEvent(ctx, peer, authContext, details);
147         success = true;
148       } finally {
149         if (!success && protector != null) {
150           protector.destroy();
151         }
152       }
153     }
154   }
155 
156   @Override
userEventTriggered(final ChannelHandlerContext ctx, Object evt)157   public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) throws Exception {
158     if (evt instanceof ProtocolNegotiationEvent) {
159       checkState(pne == null, "negotiation already started");
160       pne = (ProtocolNegotiationEvent) evt;
161       negotiationLogger.log(ChannelLogLevel.INFO, "TsiHandshake started");
162       ChannelFuture acquire = semaphoreAcquire(ctx);
163       if (acquire.isSuccess()) {
164         semaphoreAcquired = true;
165         sendHandshake(ctx);
166       } else {
167         acquire.addListener(new ChannelFutureListener() {
168           @Override public void operationComplete(ChannelFuture future) {
169             if (!future.isSuccess()) {
170               ctx.fireExceptionCaught(future.cause());
171               return;
172             }
173             if (ctx.isRemoved()) {
174               semaphoreRelease();
175               return;
176             }
177             semaphoreAcquired = true;
178             try {
179               sendHandshake(ctx);
180             } catch (Exception ex) {
181               ctx.fireExceptionCaught(ex);
182             }
183             ctx.flush();
184           }
185         });
186       }
187     } else {
188       super.userEventTriggered(ctx, evt);
189     }
190   }
191 
fireProtocolNegotiationEvent( ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details)192   private void fireProtocolNegotiationEvent(
193       ChannelHandlerContext ctx, TsiPeer peer, Object authContext, SecurityDetails details) {
194     checkState(pne != null, "negotiation not yet complete");
195     negotiationLogger.log(ChannelLogLevel.INFO, "TsiHandshake finished");
196     ProtocolNegotiationEvent localPne = pne;
197     Attributes.Builder attrs = InternalProtocolNegotiationEvent.getAttributes(localPne).toBuilder()
198         .set(TSI_PEER_KEY, peer)
199         .set(AUTH_CONTEXT_KEY, authContext)
200         .set(GrpcAttributes.ATTR_SECURITY_LEVEL, details.getSecurityLevel());
201     localPne = InternalProtocolNegotiationEvent.withAttributes(localPne, attrs.build());
202     localPne = InternalProtocolNegotiationEvent.withSecurity(localPne, details.getSecurity());
203     ctx.fireUserEventTriggered(localPne);
204   }
205 
206   /** Sends as many bytes as are available from the handshaker to the remote peer. */
207   @SuppressWarnings("FutureReturnValueIgnored") // for addListener
sendHandshake(ChannelHandlerContext ctx)208   private void sendHandshake(ChannelHandlerContext ctx) throws GeneralSecurityException {
209     while (true) {
210       boolean written = false;
211       ByteBuf buf = ctx.alloc().buffer(HANDSHAKE_FRAME_SIZE).retain(); // refcnt = 2
212       try {
213         handshaker.getBytesToSendToPeer(buf);
214         if (buf.isReadable()) {
215           ctx.writeAndFlush(buf).addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE);
216           written = true;
217         } else {
218           break;
219         }
220       } catch (GeneralSecurityException e) {
221         throw new GeneralSecurityException("TsiHandshakeHandler encountered exception", e);
222       } finally {
223         buf.release(written ? 1 : 2);
224       }
225     }
226   }
227 
228   @Override
handlerRemoved0(ChannelHandlerContext ctx)229   protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
230     if (semaphoreAcquired) {
231       semaphoreRelease();
232       semaphoreAcquired = false;
233     }
234     handshaker.close();
235   }
236 
semaphoreAcquire(ChannelHandlerContext ctx)237   private ChannelFuture semaphoreAcquire(ChannelHandlerContext ctx) {
238     if (semaphore == null) {
239       return ctx.newSucceededFuture();
240     } else {
241       return semaphore.acquire(ctx);
242     }
243   }
244 
semaphoreRelease()245   private void semaphoreRelease() {
246     if (semaphore != null) {
247       semaphore.release();
248     }
249   }
250 }
251