1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Verify.verify; 21 22 import com.google.common.annotations.VisibleForTesting; 23 import io.netty.buffer.ByteBuf; 24 import java.nio.ByteBuffer; 25 import java.security.GeneralSecurityException; 26 import java.util.List; 27 28 /** Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible. */ 29 final class AltsChannelCrypter implements ChannelCrypterNetty { 30 private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength(); 31 private static final int COUNTER_LENGTH = 12; 32 // The counter will overflow after 2^64 operations and encryption/decryption will stop working. 33 private static final int COUNTER_OVERFLOW_LENGTH = 8; 34 private static final int TAG_LENGTH = 16; 35 36 private final AeadCrypter aeadCrypter; 37 38 private final byte[] outCounter = new byte[COUNTER_LENGTH]; 39 private final byte[] inCounter = new byte[COUNTER_LENGTH]; 40 private final byte[] oldCounter = new byte[COUNTER_LENGTH]; 41 AltsChannelCrypter(byte[] key, boolean isClient)42 AltsChannelCrypter(byte[] key, boolean isClient) { 43 checkArgument(key.length == KEY_LENGTH); 44 byte[] counter = isClient ? inCounter : outCounter; 45 counter[counter.length - 1] = (byte) 0x80; 46 this.aeadCrypter = new AesGcmHkdfAeadCrypter(key); 47 } 48 getKeyLength()49 static int getKeyLength() { 50 return KEY_LENGTH; 51 } 52 getCounterLength()53 static int getCounterLength() { 54 return COUNTER_LENGTH; 55 } 56 57 @SuppressWarnings("BetaApi") // verify is stable in Guava 58 @Override encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs)59 public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException { 60 checkArgument(outBuf.nioBufferCount() == 1); 61 // Copy plaintext buffers into outBuf for in-place encryption on single direct buffer. 62 ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes()); 63 plainBuf.writerIndex(0); 64 for (ByteBuf inBuf : plainBufs) { 65 plainBuf.writeBytes(inBuf); 66 } 67 68 verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH); 69 ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes()); 70 ByteBuffer plain = out.duplicate(); 71 plain.limit(out.limit() - TAG_LENGTH); 72 73 byte[] counter = incrementOutCounter(); 74 int outPosition = out.position(); 75 aeadCrypter.encrypt(out, plain, counter); 76 int bytesWritten = out.position() - outPosition; 77 outBuf.writerIndex(outBuf.writerIndex() + bytesWritten); 78 verify(!outBuf.isWritable()); 79 } 80 81 @Override decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)82 public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs) 83 throws GeneralSecurityException { 84 85 ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes()); 86 cipherTextAndTag.writerIndex(0); 87 88 for (ByteBuf inBuf : ciphertextBufs) { 89 cipherTextAndTag.writeBytes(inBuf); 90 } 91 cipherTextAndTag.writeBytes(tag); 92 93 decrypt(out, cipherTextAndTag); 94 } 95 96 @SuppressWarnings("BetaApi") // verify is stable in Guava 97 @Override decrypt(ByteBuf out, ByteBuf ciphertextAndTag)98 public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException { 99 int bytesRead = ciphertextAndTag.readableBytes(); 100 checkArgument(bytesRead == out.writableBytes()); 101 102 checkArgument(out.nioBufferCount() == 1); 103 ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes()); 104 105 checkArgument(ciphertextAndTag.nioBufferCount() == 1); 106 ByteBuffer ciphertextAndTagBuffer = 107 ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead); 108 109 byte[] counter = incrementInCounter(); 110 int outPosition = outBuffer.position(); 111 aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter); 112 int bytesWritten = outBuffer.position() - outPosition; 113 out.writerIndex(out.writerIndex() + bytesWritten); 114 ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead); 115 verify(out.writableBytes() == TAG_LENGTH); 116 } 117 118 @Override getSuffixLength()119 public int getSuffixLength() { 120 return TAG_LENGTH; 121 } 122 123 @Override destroy()124 public void destroy() { 125 // no destroy required 126 } 127 128 /** Increments {@code counter}, store the unincremented value in {@code oldCounter}. */ incrementCounter(byte[] counter, byte[] oldCounter)129 static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException { 130 System.arraycopy(counter, 0, oldCounter, 0, counter.length); 131 int i = 0; 132 for (; i < COUNTER_OVERFLOW_LENGTH; i++) { 133 counter[i]++; 134 if (counter[i] != (byte) 0x00) { 135 break; 136 } 137 } 138 139 if (i == COUNTER_OVERFLOW_LENGTH) { 140 // Restore old counter value to ensure that encrypt and decrypt keep failing. 141 System.arraycopy(oldCounter, 0, counter, 0, counter.length); 142 throw new GeneralSecurityException("Counter has overflowed."); 143 } 144 } 145 146 /** Increments the input counter, returning the previous (unincremented) value. */ incrementInCounter()147 private byte[] incrementInCounter() throws GeneralSecurityException { 148 incrementCounter(inCounter, oldCounter); 149 return oldCounter; 150 } 151 152 /** Increments the output counter, returning the previous (unincremented) value. */ incrementOutCounter()153 private byte[] incrementOutCounter() throws GeneralSecurityException { 154 incrementCounter(outCounter, oldCounter); 155 return oldCounter; 156 } 157 158 @VisibleForTesting incrementInCounterForTesting(int n)159 void incrementInCounterForTesting(int n) throws GeneralSecurityException { 160 for (int i = 0; i < n; i++) { 161 incrementInCounter(); 162 } 163 } 164 165 @VisibleForTesting incrementOutCounterForTesting(int n)166 void incrementOutCounterForTesting(int n) throws GeneralSecurityException { 167 for (int i = 0; i < n; i++) { 168 incrementOutCounter(); 169 } 170 } 171 } 172