1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkState; 21 import static com.google.common.base.Verify.verify; 22 23 import com.google.common.primitives.Ints; 24 import io.netty.buffer.ByteBuf; 25 import io.netty.buffer.ByteBufAllocator; 26 import java.security.GeneralSecurityException; 27 import java.util.ArrayList; 28 import java.util.List; 29 30 /** Frame protector that uses the ALTS framing. */ 31 public final class AltsTsiFrameProtector implements TsiFrameProtector { 32 private static final int HEADER_LEN_FIELD_BYTES = 4; 33 private static final int HEADER_TYPE_FIELD_BYTES = 4; 34 private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES; 35 private static final int HEADER_TYPE_DEFAULT = 6; 36 private static final int LIMIT_MAX_ALLOWED_FRAME_SIZE = 1024 * 1024; 37 // Frame size negotiation extends frame size range to [MIN_FRAME_SIZE, MAX_FRAME_SIZE]. 38 private static final int MIN_FRAME_SIZE = 16 * 1024; 39 private static final int MAX_FRAME_SIZE = 128 * 1024; 40 41 private final Protector protector; 42 private final Unprotector unprotector; 43 44 /** Create a new AltsTsiFrameProtector. */ AltsTsiFrameProtector( int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc)45 public AltsTsiFrameProtector( 46 int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) { 47 checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength()); 48 maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_SIZE, maxProtectedFrameBytes); 49 protector = new Protector(maxProtectedFrameBytes, crypter); 50 unprotector = new Unprotector(crypter, alloc); 51 } 52 getHeaderLenFieldBytes()53 static int getHeaderLenFieldBytes() { 54 return HEADER_LEN_FIELD_BYTES; 55 } 56 getHeaderTypeFieldBytes()57 static int getHeaderTypeFieldBytes() { 58 return HEADER_TYPE_FIELD_BYTES; 59 } 60 getHeaderBytes()61 public static int getHeaderBytes() { 62 return HEADER_BYTES; 63 } 64 getHeaderTypeDefault()65 static int getHeaderTypeDefault() { 66 return HEADER_TYPE_DEFAULT; 67 } 68 getLimitMaxAllowedFrameSize()69 static int getLimitMaxAllowedFrameSize() { 70 return LIMIT_MAX_ALLOWED_FRAME_SIZE; 71 } 72 getMinFrameSize()73 public static int getMinFrameSize() { 74 return MIN_FRAME_SIZE; 75 } 76 getMaxFrameSize()77 public static int getMaxFrameSize() { 78 return MAX_FRAME_SIZE; 79 } 80 81 @Override protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)82 public void protectFlush( 83 List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) 84 throws GeneralSecurityException { 85 protector.protectFlush(unprotectedBufs, ctxWrite, alloc); 86 } 87 88 @Override unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)89 public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) 90 throws GeneralSecurityException { 91 unprotector.unprotect(in, out, alloc); 92 } 93 94 @Override destroy()95 public void destroy() { 96 try { 97 unprotector.destroy(); 98 } finally { 99 protector.destroy(); 100 } 101 } 102 103 static final class Protector { 104 private final int maxUnprotectedBytesPerFrame; 105 private final int suffixBytes; 106 private ChannelCrypterNetty crypter; 107 Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter)108 Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) { 109 this.suffixBytes = crypter.getSuffixLength(); 110 this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes; 111 this.crypter = crypter; 112 } 113 destroy()114 void destroy() { 115 // Shared with Unprotector and destroyed there. 116 crypter = null; 117 } 118 protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)119 void protectFlush( 120 List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) 121 throws GeneralSecurityException { 122 checkState(crypter != null, "Cannot protectFlush after destroy."); 123 ByteBuf protectedBuf; 124 try { 125 protectedBuf = handleUnprotected(unprotectedBufs, alloc); 126 } finally { 127 for (ByteBuf buf : unprotectedBufs) { 128 buf.release(); 129 } 130 } 131 if (protectedBuf != null) { 132 ctxWrite.accept(protectedBuf); 133 } 134 } 135 handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)136 private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc) 137 throws GeneralSecurityException { 138 long unprotectedBytes = 0; 139 for (ByteBuf buf : unprotectedBufs) { 140 unprotectedBytes += buf.readableBytes(); 141 } 142 // Empty plaintext not allowed since this should be handled as no-op in layer above. 143 checkArgument(unprotectedBytes > 0); 144 145 // Compute number of frames and allocate a single buffer for all frames. 146 long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1; 147 int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame); 148 if (lastFrameUnprotectedBytes == 0) { 149 frameNum--; 150 lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame; 151 } 152 long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes; 153 154 ByteBuf protectedBuf = alloc.directBuffer(Ints.checkedCast(protectedBytes)); 155 try { 156 int bufferIdx = 0; 157 for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) { 158 int unprotectedBytesLeft = 159 (frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame; 160 // Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES). 161 protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes); 162 protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT); 163 164 // Ownership of the backing buffer remains with protectedBuf. 165 ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes); 166 List<ByteBuf> framePlain = new ArrayList<>(); 167 while (unprotectedBytesLeft > 0) { 168 // Ownership of the buffer backing in remains with unprotectedBufs. 169 ByteBuf in = unprotectedBufs.get(bufferIdx); 170 if (in.readableBytes() <= unprotectedBytesLeft) { 171 // The complete buffer belongs to this frame. 172 framePlain.add(in); 173 unprotectedBytesLeft -= in.readableBytes(); 174 bufferIdx++; 175 } else { 176 // The remainder of in will be part of the next frame. 177 framePlain.add(in.readSlice(unprotectedBytesLeft)); 178 unprotectedBytesLeft = 0; 179 } 180 } 181 crypter.encrypt(frameOut, framePlain); 182 verify(!frameOut.isWritable()); 183 } 184 protectedBuf.readerIndex(0); 185 protectedBuf.writerIndex(protectedBuf.capacity()); 186 return protectedBuf.retain(); 187 } finally { 188 protectedBuf.release(); 189 } 190 } 191 } 192 193 static final class Unprotector { 194 private final int suffixBytes; 195 private final ChannelCrypterNetty crypter; 196 197 private DeframerState state = DeframerState.READ_HEADER; 198 private int requiredProtectedBytes; 199 private ByteBuf header; 200 private ByteBuf firstFrameTag; 201 private int unhandledIdx = 0; 202 private long unhandledBytes = 0; 203 private List<ByteBuf> unhandledBufs = new ArrayList<>(16); 204 Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc)205 Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) { 206 this.crypter = crypter; 207 this.suffixBytes = crypter.getSuffixLength(); 208 this.header = alloc.directBuffer(HEADER_BYTES); 209 this.firstFrameTag = alloc.directBuffer(suffixBytes); 210 } 211 addUnhandled(ByteBuf in)212 private void addUnhandled(ByteBuf in) { 213 if (in.isReadable()) { 214 ByteBuf buf = in.readRetainedSlice(in.readableBytes()); 215 unhandledBufs.add(buf); 216 unhandledBytes += buf.readableBytes(); 217 } 218 } 219 unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)220 void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) 221 throws GeneralSecurityException { 222 checkState(header != null, "Cannot unprotect after destroy."); 223 addUnhandled(in); 224 decodeFrame(alloc, out); 225 } 226 227 @SuppressWarnings("fallthrough") decodeFrame(ByteBufAllocator alloc, List<Object> out)228 private void decodeFrame(ByteBufAllocator alloc, List<Object> out) 229 throws GeneralSecurityException { 230 switch (state) { 231 case READ_HEADER: 232 if (unhandledBytes < HEADER_BYTES) { 233 return; 234 } 235 handleHeader(); 236 // fall through 237 case READ_PROTECTED_PAYLOAD: 238 if (unhandledBytes < requiredProtectedBytes) { 239 return; 240 } 241 ByteBuf unprotectedBuf; 242 try { 243 unprotectedBuf = handlePayload(alloc); 244 } finally { 245 clearState(); 246 } 247 if (unprotectedBuf != null) { 248 out.add(unprotectedBuf); 249 } 250 break; 251 default: 252 throw new AssertionError("impossible enum value"); 253 } 254 } 255 handleHeader()256 private void handleHeader() { 257 while (header.isWritable()) { 258 ByteBuf in = unhandledBufs.get(unhandledIdx); 259 int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes()); 260 header.writeBytes(in, headerBytesToRead); 261 unhandledBytes -= headerBytesToRead; 262 if (!in.isReadable()) { 263 unhandledIdx++; 264 } 265 } 266 requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES; 267 checkArgument( 268 requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small"); 269 checkArgument( 270 requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_SIZE - HEADER_BYTES, 271 "Invalid header field: frame size too large"); 272 int frameType = header.readIntLE(); 273 checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type"); 274 state = DeframerState.READ_PROTECTED_PAYLOAD; 275 } 276 handlePayload(ByteBufAllocator alloc)277 private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException { 278 int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes; 279 int firstFrameUnprotectedLen = requiredCiphertextBytes; 280 281 // We get the ciphertexts of the first frame and copy over the tag into a single buffer. 282 List<ByteBuf> firstFrameCiphertext = new ArrayList<>(); 283 while (requiredCiphertextBytes > 0) { 284 ByteBuf buf = unhandledBufs.get(unhandledIdx); 285 if (buf.readableBytes() <= requiredCiphertextBytes) { 286 // We use the whole buffer. 287 firstFrameCiphertext.add(buf); 288 requiredCiphertextBytes -= buf.readableBytes(); 289 unhandledIdx++; 290 } else { 291 firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes)); 292 requiredCiphertextBytes = 0; 293 } 294 } 295 int requiredSuffixBytes = suffixBytes; 296 while (true) { 297 ByteBuf buf = unhandledBufs.get(unhandledIdx); 298 if (buf.readableBytes() <= requiredSuffixBytes) { 299 // We use the whole buffer. 300 requiredSuffixBytes -= buf.readableBytes(); 301 firstFrameTag.writeBytes(buf); 302 if (requiredSuffixBytes == 0) { 303 break; 304 } 305 unhandledIdx++; 306 } else { 307 firstFrameTag.writeBytes(buf, requiredSuffixBytes); 308 break; 309 } 310 } 311 verify(unhandledIdx == unhandledBufs.size() - 1); 312 ByteBuf lastBuf = unhandledBufs.get(unhandledIdx); 313 314 // We get the remaining ciphertexts and tags contained in the last buffer. 315 List<ByteBuf> ciphertextsAndTags = new ArrayList<>(); 316 List<Integer> unprotectedLens = new ArrayList<>(); 317 long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen; 318 while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) { 319 // Read frame size. 320 int frameSize = lastBuf.readIntLE(); 321 int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes; 322 // Break and undo read if we don't have the complete frame yet. 323 if (lastBuf.readableBytes() < frameSize) { 324 lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES); 325 break; 326 } 327 // Check the type header. 328 checkArgument(lastBuf.readIntLE() == 6); 329 // Create a new frame (except for out buffer). 330 ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes)); 331 // Update sizes for frame. 332 requiredUnprotectedBytesCompleteFrames += payloadSize; 333 unprotectedLens.add(payloadSize); 334 } 335 336 // We leave space for suffixBytes to allow for in-place encryption. This allows for calling 337 // doFinal in the JCE implementation which can be optimized better than update and doFinal. 338 ByteBuf unprotectedBuf = 339 alloc.directBuffer( 340 Ints.checkedCast(requiredUnprotectedBytesCompleteFrames + suffixBytes)); 341 try { 342 343 ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes); 344 crypter.decrypt(out, firstFrameTag, firstFrameCiphertext); 345 verify(out.writableBytes() == suffixBytes); 346 unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes); 347 348 for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) { 349 out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes); 350 crypter.decrypt(out, ciphertextsAndTags.get(frameIdx)); 351 verify(out.writableBytes() == suffixBytes); 352 unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes); 353 } 354 return unprotectedBuf.retain(); 355 } finally { 356 unprotectedBuf.release(); 357 } 358 } 359 clearState()360 private void clearState() { 361 int bufsSize = unhandledBufs.size(); 362 ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1); 363 boolean keepLast = lastBuf.isReadable(); 364 for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) { 365 unhandledBufs.get(bufIdx).release(); 366 } 367 unhandledBufs.clear(); 368 unhandledBytes = 0; 369 unhandledIdx = 0; 370 if (keepLast) { 371 unhandledBufs.add(lastBuf); 372 unhandledBytes = lastBuf.readableBytes(); 373 } 374 state = DeframerState.READ_HEADER; 375 requiredProtectedBytes = 0; 376 header.clear(); 377 firstFrameTag.clear(); 378 } 379 destroy()380 void destroy() { 381 for (ByteBuf unhandledBuf : unhandledBufs) { 382 unhandledBuf.release(); 383 } 384 unhandledBufs.clear(); 385 if (header != null) { 386 header.release(); 387 header = null; 388 } 389 if (firstFrameTag != null) { 390 firstFrameTag.release(); 391 firstFrameTag = null; 392 } 393 crypter.destroy(); 394 } 395 } 396 397 private enum DeframerState { 398 READ_HEADER, 399 READ_PROTECTED_PAYLOAD 400 } 401 writeSlice(ByteBuf in, int len)402 private static ByteBuf writeSlice(ByteBuf in, int len) { 403 checkArgument(len <= in.writableBytes()); 404 ByteBuf out = in.slice(in.writerIndex(), len); 405 in.writerIndex(in.writerIndex() + len); 406 return out.writerIndex(0); 407 } 408 } 409