1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkState; 21 import static com.google.common.base.Verify.verify; 22 23 import com.google.common.primitives.Ints; 24 import io.netty.buffer.ByteBuf; 25 import io.netty.buffer.ByteBufAllocator; 26 import java.security.GeneralSecurityException; 27 import java.util.ArrayList; 28 import java.util.List; 29 30 /** Frame protector that uses the ALTS framing. */ 31 public final class AltsTsiFrameProtector implements TsiFrameProtector { 32 private static final int HEADER_LEN_FIELD_BYTES = 4; 33 private static final int HEADER_TYPE_FIELD_BYTES = 4; 34 private static final int HEADER_BYTES = HEADER_LEN_FIELD_BYTES + HEADER_TYPE_FIELD_BYTES; 35 private static final int HEADER_TYPE_DEFAULT = 6; 36 // Total frame size including full header and tag. 37 private static final int MAX_ALLOWED_FRAME_BYTES = 16 * 1024; 38 private static final int LIMIT_MAX_ALLOWED_FRAME_BYTES = 1024 * 1024; 39 40 private final Protector protector; 41 private final Unprotector unprotector; 42 43 /** Create a new AltsTsiFrameProtector. */ AltsTsiFrameProtector( int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc)44 public AltsTsiFrameProtector( 45 int maxProtectedFrameBytes, ChannelCrypterNetty crypter, ByteBufAllocator alloc) { 46 checkArgument(maxProtectedFrameBytes > HEADER_BYTES + crypter.getSuffixLength()); 47 maxProtectedFrameBytes = Math.min(LIMIT_MAX_ALLOWED_FRAME_BYTES, maxProtectedFrameBytes); 48 protector = new Protector(maxProtectedFrameBytes, crypter); 49 unprotector = new Unprotector(crypter, alloc); 50 } 51 getHeaderLenFieldBytes()52 static int getHeaderLenFieldBytes() { 53 return HEADER_LEN_FIELD_BYTES; 54 } 55 getHeaderTypeFieldBytes()56 static int getHeaderTypeFieldBytes() { 57 return HEADER_TYPE_FIELD_BYTES; 58 } 59 getHeaderBytes()60 public static int getHeaderBytes() { 61 return HEADER_BYTES; 62 } 63 getHeaderTypeDefault()64 static int getHeaderTypeDefault() { 65 return HEADER_TYPE_DEFAULT; 66 } 67 getMaxAllowedFrameBytes()68 public static int getMaxAllowedFrameBytes() { 69 return MAX_ALLOWED_FRAME_BYTES; 70 } 71 getLimitMaxAllowedFrameBytes()72 static int getLimitMaxAllowedFrameBytes() { 73 return LIMIT_MAX_ALLOWED_FRAME_BYTES; 74 } 75 76 @Override protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)77 public void protectFlush( 78 List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) 79 throws GeneralSecurityException { 80 protector.protectFlush(unprotectedBufs, ctxWrite, alloc); 81 } 82 83 @Override unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)84 public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) 85 throws GeneralSecurityException { 86 unprotector.unprotect(in, out, alloc); 87 } 88 89 @Override destroy()90 public void destroy() { 91 try { 92 unprotector.destroy(); 93 } finally { 94 protector.destroy(); 95 } 96 } 97 98 static final class Protector { 99 private final int maxUnprotectedBytesPerFrame; 100 private final int suffixBytes; 101 private ChannelCrypterNetty crypter; 102 Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter)103 Protector(int maxProtectedFrameBytes, ChannelCrypterNetty crypter) { 104 this.suffixBytes = crypter.getSuffixLength(); 105 this.maxUnprotectedBytesPerFrame = maxProtectedFrameBytes - HEADER_BYTES - suffixBytes; 106 this.crypter = crypter; 107 } 108 destroy()109 void destroy() { 110 // Shared with Unprotector and destroyed there. 111 crypter = null; 112 } 113 protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)114 void protectFlush( 115 List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc) 116 throws GeneralSecurityException { 117 checkState(crypter != null, "Cannot protectFlush after destroy."); 118 ByteBuf protectedBuf; 119 try { 120 protectedBuf = handleUnprotected(unprotectedBufs, alloc); 121 } finally { 122 for (ByteBuf buf : unprotectedBufs) { 123 buf.release(); 124 } 125 } 126 if (protectedBuf != null) { 127 ctxWrite.accept(protectedBuf); 128 } 129 } 130 131 @SuppressWarnings("BetaApi") // verify is stable in Guava handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc)132 private ByteBuf handleUnprotected(List<ByteBuf> unprotectedBufs, ByteBufAllocator alloc) 133 throws GeneralSecurityException { 134 long unprotectedBytes = 0; 135 for (ByteBuf buf : unprotectedBufs) { 136 unprotectedBytes += buf.readableBytes(); 137 } 138 // Empty plaintext not allowed since this should be handled as no-op in layer above. 139 checkArgument(unprotectedBytes > 0); 140 141 // Compute number of frames and allocate a single buffer for all frames. 142 long frameNum = unprotectedBytes / maxUnprotectedBytesPerFrame + 1; 143 int lastFrameUnprotectedBytes = (int) (unprotectedBytes % maxUnprotectedBytesPerFrame); 144 if (lastFrameUnprotectedBytes == 0) { 145 frameNum--; 146 lastFrameUnprotectedBytes = maxUnprotectedBytesPerFrame; 147 } 148 long protectedBytes = frameNum * (HEADER_BYTES + suffixBytes) + unprotectedBytes; 149 150 ByteBuf protectedBuf = alloc.directBuffer(Ints.checkedCast(protectedBytes)); 151 try { 152 int bufferIdx = 0; 153 for (int frameIdx = 0; frameIdx < frameNum; ++frameIdx) { 154 int unprotectedBytesLeft = 155 (frameIdx == frameNum - 1) ? lastFrameUnprotectedBytes : maxUnprotectedBytesPerFrame; 156 // Write header (at most LIMIT_MAX_ALLOWED_FRAME_BYTES). 157 protectedBuf.writeIntLE(unprotectedBytesLeft + HEADER_TYPE_FIELD_BYTES + suffixBytes); 158 protectedBuf.writeIntLE(HEADER_TYPE_DEFAULT); 159 160 // Ownership of the backing buffer remains with protectedBuf. 161 ByteBuf frameOut = writeSlice(protectedBuf, unprotectedBytesLeft + suffixBytes); 162 List<ByteBuf> framePlain = new ArrayList<>(); 163 while (unprotectedBytesLeft > 0) { 164 // Ownership of the buffer backing in remains with unprotectedBufs. 165 ByteBuf in = unprotectedBufs.get(bufferIdx); 166 if (in.readableBytes() <= unprotectedBytesLeft) { 167 // The complete buffer belongs to this frame. 168 framePlain.add(in); 169 unprotectedBytesLeft -= in.readableBytes(); 170 bufferIdx++; 171 } else { 172 // The remainder of in will be part of the next frame. 173 framePlain.add(in.readSlice(unprotectedBytesLeft)); 174 unprotectedBytesLeft = 0; 175 } 176 } 177 crypter.encrypt(frameOut, framePlain); 178 verify(!frameOut.isWritable()); 179 } 180 protectedBuf.readerIndex(0); 181 protectedBuf.writerIndex(protectedBuf.capacity()); 182 return protectedBuf.retain(); 183 } finally { 184 protectedBuf.release(); 185 } 186 } 187 } 188 189 static final class Unprotector { 190 private final int suffixBytes; 191 private final ChannelCrypterNetty crypter; 192 193 private DeframerState state = DeframerState.READ_HEADER; 194 private int requiredProtectedBytes; 195 private ByteBuf header; 196 private ByteBuf firstFrameTag; 197 private int unhandledIdx = 0; 198 private long unhandledBytes = 0; 199 private List<ByteBuf> unhandledBufs = new ArrayList<>(16); 200 Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc)201 Unprotector(ChannelCrypterNetty crypter, ByteBufAllocator alloc) { 202 this.crypter = crypter; 203 this.suffixBytes = crypter.getSuffixLength(); 204 this.header = alloc.directBuffer(HEADER_BYTES); 205 this.firstFrameTag = alloc.directBuffer(suffixBytes); 206 } 207 addUnhandled(ByteBuf in)208 private void addUnhandled(ByteBuf in) { 209 if (in.isReadable()) { 210 ByteBuf buf = in.readRetainedSlice(in.readableBytes()); 211 unhandledBufs.add(buf); 212 unhandledBytes += buf.readableBytes(); 213 } 214 } 215 unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)216 void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc) 217 throws GeneralSecurityException { 218 checkState(header != null, "Cannot unprotect after destroy."); 219 addUnhandled(in); 220 decodeFrame(alloc, out); 221 } 222 223 @SuppressWarnings("fallthrough") decodeFrame(ByteBufAllocator alloc, List<Object> out)224 private void decodeFrame(ByteBufAllocator alloc, List<Object> out) 225 throws GeneralSecurityException { 226 switch (state) { 227 case READ_HEADER: 228 if (unhandledBytes < HEADER_BYTES) { 229 return; 230 } 231 handleHeader(); 232 // fall through 233 case READ_PROTECTED_PAYLOAD: 234 if (unhandledBytes < requiredProtectedBytes) { 235 return; 236 } 237 ByteBuf unprotectedBuf; 238 try { 239 unprotectedBuf = handlePayload(alloc); 240 } finally { 241 clearState(); 242 } 243 if (unprotectedBuf != null) { 244 out.add(unprotectedBuf); 245 } 246 break; 247 default: 248 throw new AssertionError("impossible enum value"); 249 } 250 } 251 handleHeader()252 private void handleHeader() { 253 while (header.isWritable()) { 254 ByteBuf in = unhandledBufs.get(unhandledIdx); 255 int headerBytesToRead = Math.min(in.readableBytes(), header.writableBytes()); 256 header.writeBytes(in, headerBytesToRead); 257 unhandledBytes -= headerBytesToRead; 258 if (!in.isReadable()) { 259 unhandledIdx++; 260 } 261 } 262 requiredProtectedBytes = header.readIntLE() - HEADER_TYPE_FIELD_BYTES; 263 checkArgument( 264 requiredProtectedBytes >= suffixBytes, "Invalid header field: frame size too small"); 265 checkArgument( 266 requiredProtectedBytes <= LIMIT_MAX_ALLOWED_FRAME_BYTES - HEADER_BYTES, 267 "Invalid header field: frame size too large"); 268 int frameType = header.readIntLE(); 269 checkArgument(frameType == HEADER_TYPE_DEFAULT, "Invalid header field: frame type"); 270 state = DeframerState.READ_PROTECTED_PAYLOAD; 271 } 272 273 @SuppressWarnings("BetaApi") // verify is stable in Guava handlePayload(ByteBufAllocator alloc)274 private ByteBuf handlePayload(ByteBufAllocator alloc) throws GeneralSecurityException { 275 int requiredCiphertextBytes = requiredProtectedBytes - suffixBytes; 276 int firstFrameUnprotectedLen = requiredCiphertextBytes; 277 278 // We get the ciphertexts of the first frame and copy over the tag into a single buffer. 279 List<ByteBuf> firstFrameCiphertext = new ArrayList<>(); 280 while (requiredCiphertextBytes > 0) { 281 ByteBuf buf = unhandledBufs.get(unhandledIdx); 282 if (buf.readableBytes() <= requiredCiphertextBytes) { 283 // We use the whole buffer. 284 firstFrameCiphertext.add(buf); 285 requiredCiphertextBytes -= buf.readableBytes(); 286 unhandledIdx++; 287 } else { 288 firstFrameCiphertext.add(buf.readSlice(requiredCiphertextBytes)); 289 requiredCiphertextBytes = 0; 290 } 291 } 292 int requiredSuffixBytes = suffixBytes; 293 while (true) { 294 ByteBuf buf = unhandledBufs.get(unhandledIdx); 295 if (buf.readableBytes() <= requiredSuffixBytes) { 296 // We use the whole buffer. 297 requiredSuffixBytes -= buf.readableBytes(); 298 firstFrameTag.writeBytes(buf); 299 if (requiredSuffixBytes == 0) { 300 break; 301 } 302 unhandledIdx++; 303 } else { 304 firstFrameTag.writeBytes(buf, requiredSuffixBytes); 305 break; 306 } 307 } 308 verify(unhandledIdx == unhandledBufs.size() - 1); 309 ByteBuf lastBuf = unhandledBufs.get(unhandledIdx); 310 311 // We get the remaining ciphertexts and tags contained in the last buffer. 312 List<ByteBuf> ciphertextsAndTags = new ArrayList<>(); 313 List<Integer> unprotectedLens = new ArrayList<>(); 314 long requiredUnprotectedBytesCompleteFrames = firstFrameUnprotectedLen; 315 while (lastBuf.readableBytes() >= HEADER_BYTES + suffixBytes) { 316 // Read frame size. 317 int frameSize = lastBuf.readIntLE(); 318 int payloadSize = frameSize - HEADER_TYPE_FIELD_BYTES - suffixBytes; 319 // Break and undo read if we don't have the complete frame yet. 320 if (lastBuf.readableBytes() < frameSize) { 321 lastBuf.readerIndex(lastBuf.readerIndex() - HEADER_LEN_FIELD_BYTES); 322 break; 323 } 324 // Check the type header. 325 checkArgument(lastBuf.readIntLE() == 6); 326 // Create a new frame (except for out buffer). 327 ciphertextsAndTags.add(lastBuf.readSlice(payloadSize + suffixBytes)); 328 // Update sizes for frame. 329 requiredUnprotectedBytesCompleteFrames += payloadSize; 330 unprotectedLens.add(payloadSize); 331 } 332 333 // We leave space for suffixBytes to allow for in-place encryption. This allows for calling 334 // doFinal in the JCE implementation which can be optimized better than update and doFinal. 335 ByteBuf unprotectedBuf = 336 alloc.directBuffer( 337 Ints.checkedCast(requiredUnprotectedBytesCompleteFrames + suffixBytes)); 338 try { 339 340 ByteBuf out = writeSlice(unprotectedBuf, firstFrameUnprotectedLen + suffixBytes); 341 crypter.decrypt(out, firstFrameTag, firstFrameCiphertext); 342 verify(out.writableBytes() == suffixBytes); 343 unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes); 344 345 for (int frameIdx = 0; frameIdx < ciphertextsAndTags.size(); ++frameIdx) { 346 out = writeSlice(unprotectedBuf, unprotectedLens.get(frameIdx) + suffixBytes); 347 crypter.decrypt(out, ciphertextsAndTags.get(frameIdx)); 348 verify(out.writableBytes() == suffixBytes); 349 unprotectedBuf.writerIndex(unprotectedBuf.writerIndex() - suffixBytes); 350 } 351 return unprotectedBuf.retain(); 352 } finally { 353 unprotectedBuf.release(); 354 } 355 } 356 clearState()357 private void clearState() { 358 int bufsSize = unhandledBufs.size(); 359 ByteBuf lastBuf = unhandledBufs.get(bufsSize - 1); 360 boolean keepLast = lastBuf.isReadable(); 361 for (int bufIdx = 0; bufIdx < (keepLast ? bufsSize - 1 : bufsSize); ++bufIdx) { 362 unhandledBufs.get(bufIdx).release(); 363 } 364 unhandledBufs.clear(); 365 unhandledBytes = 0; 366 unhandledIdx = 0; 367 if (keepLast) { 368 unhandledBufs.add(lastBuf); 369 unhandledBytes = lastBuf.readableBytes(); 370 } 371 state = DeframerState.READ_HEADER; 372 requiredProtectedBytes = 0; 373 header.clear(); 374 firstFrameTag.clear(); 375 } 376 destroy()377 void destroy() { 378 for (ByteBuf unhandledBuf : unhandledBufs) { 379 unhandledBuf.release(); 380 } 381 unhandledBufs.clear(); 382 if (header != null) { 383 header.release(); 384 header = null; 385 } 386 if (firstFrameTag != null) { 387 firstFrameTag.release(); 388 firstFrameTag = null; 389 } 390 crypter.destroy(); 391 } 392 } 393 394 private enum DeframerState { 395 READ_HEADER, 396 READ_PROTECTED_PAYLOAD 397 } 398 writeSlice(ByteBuf in, int len)399 private static ByteBuf writeSlice(ByteBuf in, int len) { 400 checkArgument(len <= in.writableBytes()); 401 ByteBuf out = in.slice(in.writerIndex(), len); 402 in.writerIndex(in.writerIndex() + len); 403 return out.writerIndex(0); 404 } 405 } 406