• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Preconditions.checkState;
21 import static com.google.common.base.Verify.verify;
22 
23 import com.google.common.primitives.Ints;
24 import io.netty.buffer.ByteBuf;
25 import io.netty.buffer.ByteBufAllocator;
26 import java.security.GeneralSecurityException;
27 import java.util.ArrayList;
28 import java.util.List;
29 
30 /** Frame protector that uses the ALTS framing. */
31 public final class AltsTsiFrameProtector implements TsiFrameProtector {
32   private static final int HEADER_LEN_FIELD_BYTES = 4;
33   private static final int HEADER_TYPE_FIELD_BYTES = 4;
34   private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES;
35   private static final int HEADER_TYPE_DEFAULT = 6;
36   private static final int LIMIT_MAX_ALLOWED_FRAME_SIZE = 1024 * 1024;
37   // Frame size negotiation extends frame size range to [MIN_FRAME_SIZE, MAX_FRAME_SIZE].
38   private static final int MIN_FRAME_SIZE = 16 * 1024;
39   private static final int MAX_FRAME_SIZE = 128 * 1024;
40 
41   private final Protector protector;
42   private final Unprotector unprotector;
43 
44   /** Create a new AltsTsiFrameProtector. */
AltsTsiFrameProtector( int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc)45   public AltsTsiFrameProtector(
46       int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
47     checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength());
48     maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_SIZE, maxProtectedFrameBytes);
49     protector = new Protector(maxProtectedFrameBytes, crypter);
50     unprotector = new Unprotector(crypter, alloc);
51   }
52 
getHeaderLenFieldBytes()53   static int getHeaderLenFieldBytes() {
54     return HEADER_LEN_FIELD_BYTES;
55   }
56 
getHeaderTypeFieldBytes()57   static int getHeaderTypeFieldBytes() {
58     return HEADER_TYPE_FIELD_BYTES;
59   }
60 
getHeaderBytes()61   public static int getHeaderBytes() {
62     return HEADER_BYTES;
63   }
64 
getHeaderTypeDefault()65   static int getHeaderTypeDefault() {
66     return HEADER_TYPE_DEFAULT;
67   }
68 
getLimitMaxAllowedFrameSize()69   static int getLimitMaxAllowedFrameSize() {
70     return LIMIT_MAX_ALLOWED_FRAME_SIZE;
71   }
72 
getMinFrameSize()73   public static int getMinFrameSize() {
74     return MIN_FRAME_SIZE;
75   }
76 
getMaxFrameSize()77   public static int getMaxFrameSize() {
78     return MAX_FRAME_SIZE;
79   }
80 
81   @Override
protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)82   public void protectFlush(
83       List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
84       throws GeneralSecurityException {
85     protector.protectFlush(unprotectedBufs, ctxWrite, alloc);
86   }
87 
88   @Override
unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)89   public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
90       throws GeneralSecurityException {
91     unprotector.unprotect(in, out, alloc);
92   }
93 
94   @Override
destroy()95   public void destroy() {
96     try {
97       unprotector.destroy();
98     } finally {
99       protector.destroy();
100     }
101   }
102 
103   static final class Protector {
104     private final int maxUnprotectedBytesPerFrame;
105     private final int suffixBytes;
106     private ChannelCrypterNetty crypter;
107 
Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter)108     Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) {
109       this.suffixBytes = crypter.getSuffixLength();
110       this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes;
111       this.crypter = crypter;
112     }
113 
destroy()114     void destroy() {
115       // Shared with Unprotector and destroyed there.
116       crypter = null;
117     }
118 
protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)119     void protectFlush(
120         List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
121         throws GeneralSecurityException {
122       checkState(crypter != null, "Cannot protectFlush after destroy.");
123       ByteBuf protectedBuf;
124       try {
125         protectedBuf = handleUnprotected(unprotectedBufs, alloc);
126       } finally {
127         for (ByteBuf buf : unprotectedBufs) {
128           buf.release();
129         }
130       }
131       if (protectedBuf != null) {
132         ctxWrite.accept(protectedBuf);
133       }
134     }
135 
handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)136     private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)
137         throws GeneralSecurityException {
138       long unprotectedBytes = 0;
139       for (ByteBuf buf : unprotectedBufs) {
140         unprotectedBytes += buf.readableBytes();
141       }
142       // Empty plaintext not allowed since this should be handled as no-op in layer above.
143       checkArgument(unprotectedBytes > 0);
144 
145       // Compute number of frames and allocate a single buffer for all frames.
146       long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1;
147       int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame);
148       if (lastFrameUnprotectedBytes == 0) {
149         frameNum--;
150         lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame;
151       }
152       long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes;
153 
154       ByteBuf protectedBuf = alloc.directBuffer(Ints.checkedCast(protectedBytes));
155       try {
156         int bufferIdx = 0;
157         for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) {
158           int unprotectedBytesLeft =
159               (frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame;
160           // Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES).
161           protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes);
162           protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT);
163 
164           // Ownership of the backing buffer remains with protectedBuf.
165           ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes);
166           List<ByteBuf> framePlain = new ArrayList<>();
167           while (unprotectedBytesLeft > 0) {
168             // Ownership of the buffer backing in remains with unprotectedBufs.
169             ByteBuf in = unprotectedBufs.get(bufferIdx);
170             if (in.readableBytes() <= unprotectedBytesLeft) {
171               // The complete buffer belongs to this frame.
172               framePlain.add(in);
173               unprotectedBytesLeft -= in.readableBytes();
174               bufferIdx++;
175             } else {
176               // The remainder of in will be part of the next frame.
177               framePlain.add(in.readSlice(unprotectedBytesLeft));
178               unprotectedBytesLeft = 0;
179             }
180           }
181           crypter.encrypt(frameOut, framePlain);
182           verify(!frameOut.isWritable());
183         }
184         protectedBuf.readerIndex(0);
185         protectedBuf.writerIndex(protectedBuf.capacity());
186         return protectedBuf.retain();
187       } finally {
188         protectedBuf.release();
189       }
190     }
191   }
192 
193   static final class Unprotector {
194     private final int suffixBytes;
195     private final ChannelCrypterNetty crypter;
196 
197     private DeframerState state = DeframerState.READ_HEADER;
198     private int requiredProtectedBytes;
199     private ByteBuf header;
200     private ByteBuf firstFrameTag;
201     private int unhandledIdx = 0;
202     private long unhandledBytes = 0;
203     private List<ByteBuf> unhandledBufs = new ArrayList<>(16);
204 
Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc)205     Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) {
206       this.crypter = crypter;
207       this.suffixBytes = crypter.getSuffixLength();
208       this.header = alloc.directBuffer(HEADER_BYTES);
209       this.firstFrameTag = alloc.directBuffer(suffixBytes);
210     }
211 
addUnhandled(ByteBuf in)212     private void addUnhandled(ByteBuf in) {
213       if (in.isReadable()) {
214         ByteBuf buf = in.readRetainedSlice(in.readableBytes());
215         unhandledBufs.add(buf);
216         unhandledBytes += buf.readableBytes();
217       }
218     }
219 
unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)220     void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
221         throws GeneralSecurityException {
222       checkState(header != null, "Cannot unprotect after destroy.");
223       addUnhandled(in);
224       decodeFrame(alloc, out);
225     }
226 
227     @SuppressWarnings("fallthrough")
decodeFrame(ByteBufAllocator alloc, List<Object> out)228     private void decodeFrame(ByteBufAllocator alloc, List<Object> out)
229         throws GeneralSecurityException {
230       switch (state) {
231         case READ_HEADER:
232           if (unhandledBytes < HEADER_BYTES) {
233             return;
234           }
235           handleHeader();
236           // fall through
237         case READ_PROTECTED_PAYLOAD:
238           if (unhandledBytes < requiredProtectedBytes) {
239             return;
240           }
241           ByteBuf unprotectedBuf;
242           try {
243             unprotectedBuf = handlePayload(alloc);
244           } finally {
245             clearState();
246           }
247           if (unprotectedBuf != null) {
248             out.add(unprotectedBuf);
249           }
250           break;
251         default:
252           throw new AssertionError("impossible enum value");
253       }
254     }
255 
handleHeader()256     private void handleHeader() {
257       while (header.isWritable()) {
258         ByteBuf in = unhandledBufs.get(unhandledIdx);
259         int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes());
260         header.writeBytes(in, headerBytesToRead);
261         unhandledBytes -= headerBytesToRead;
262         if (!in.isReadable()) {
263           unhandledIdx++;
264         }
265       }
266       requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES;
267       checkArgument(
268           requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small");
269       checkArgument(
270           requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_SIZE - HEADER_BYTES,
271           "Invalid header field: frame size too large");
272       int frameType = header.readIntLE();
273       checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type");
274       state = DeframerState.READ_PROTECTED_PAYLOAD;
275     }
276 
handlePayload(ByteBufAllocator alloc)277     private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException {
278       int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes;
279       int firstFrameUnprotectedLen = requiredCiphertextBytes;
280 
281       // We get the ciphertexts of the first frame and copy over the tag into a single buffer.
282       List<ByteBuf> firstFrameCiphertext = new ArrayList<>();
283       while (requiredCiphertextBytes > 0) {
284         ByteBuf buf = unhandledBufs.get(unhandledIdx);
285         if (buf.readableBytes() <= requiredCiphertextBytes) {
286           // We use the whole buffer.
287           firstFrameCiphertext.add(buf);
288           requiredCiphertextBytes -= buf.readableBytes();
289           unhandledIdx++;
290         } else {
291           firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes));
292           requiredCiphertextBytes = 0;
293         }
294       }
295       int requiredSuffixBytes = suffixBytes;
296       while (true) {
297         ByteBuf buf = unhandledBufs.get(unhandledIdx);
298         if (buf.readableBytes() <= requiredSuffixBytes) {
299           // We use the whole buffer.
300           requiredSuffixBytes -= buf.readableBytes();
301           firstFrameTag.writeBytes(buf);
302           if (requiredSuffixBytes == 0) {
303             break;
304           }
305           unhandledIdx++;
306         } else {
307           firstFrameTag.writeBytes(buf, requiredSuffixBytes);
308           break;
309         }
310       }
311       verify(unhandledIdx == unhandledBufs.size() - 1);
312       ByteBuf lastBuf = unhandledBufs.get(unhandledIdx);
313 
314       // We get the remaining ciphertexts and tags contained in the last buffer.
315       List<ByteBuf> ciphertextsAndTags = new ArrayList<>();
316       List<Integer> unprotectedLens = new ArrayList<>();
317       long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen;
318       while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) {
319         // Read frame size.
320         int frameSize = lastBuf.readIntLE();
321         int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes;
322         // Break and undo read if we don't have the complete frame yet.
323         if (lastBuf.readableBytes() < frameSize) {
324           lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES);
325           break;
326         }
327         // Check the type header.
328         checkArgument(lastBuf.readIntLE() == 6);
329         // Create a new frame (except for out buffer).
330         ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes));
331         // Update sizes for frame.
332         requiredUnprotectedBytesCompleteFrames += payloadSize;
333         unprotectedLens.add(payloadSize);
334       }
335 
336       // We leave space for suffixBytes to allow for in-place encryption. This allows for calling
337       // doFinal in the JCE implementation which can be optimized better than update and doFinal.
338       ByteBuf unprotectedBuf =
339           alloc.directBuffer(
340               Ints.checkedCast(requiredUnprotectedBytesCompleteFrames + suffixBytes));
341       try {
342 
343         ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes);
344         crypter.decrypt(out, firstFrameTag, firstFrameCiphertext);
345         verify(out.writableBytes() == suffixBytes);
346         unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
347 
348         for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) {
349           out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes);
350           crypter.decrypt(out, ciphertextsAndTags.get(frameIdx));
351           verify(out.writableBytes() == suffixBytes);
352           unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes);
353         }
354         return unprotectedBuf.retain();
355       } finally {
356         unprotectedBuf.release();
357       }
358     }
359 
clearState()360     private void clearState() {
361       int bufsSize = unhandledBufs.size();
362       ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1);
363       boolean keepLast = lastBuf.isReadable();
364       for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) {
365         unhandledBufs.get(bufIdx).release();
366       }
367       unhandledBufs.clear();
368       unhandledBytes = 0;
369       unhandledIdx = 0;
370       if (keepLast) {
371         unhandledBufs.add(lastBuf);
372         unhandledBytes = lastBuf.readableBytes();
373       }
374       state = DeframerState.READ_HEADER;
375       requiredProtectedBytes = 0;
376       header.clear();
377       firstFrameTag.clear();
378     }
379 
destroy()380     void destroy() {
381       for (ByteBuf unhandledBuf : unhandledBufs) {
382         unhandledBuf.release();
383       }
384       unhandledBufs.clear();
385       if (header != null) {
386         header.release();
387         header = null;
388       }
389       if (firstFrameTag != null) {
390         firstFrameTag.release();
391         firstFrameTag = null;
392       }
393       crypter.destroy();
394     }
395   }
396 
397   private enum DeframerState {
398     READ_HEADER,
399     READ_PROTECTED_PAYLOAD
400   }
401 
writeSlice(ByteBuf in, int len)402   private static ByteBuf writeSlice(ByteBuf in, int len) {
403     checkArgument(len <= in.writableBytes());
404     ByteBuf out = in.slice(in.writerIndex(), len);
405     in.writerIndex(in.writerIndex() + len);
406     return out.writerIndex(0);
407   }
408 }
409