• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.base.Preconditions.checkArgument;
20 import static com.google.common.base.Verify.verify;
21 
22 import com.google.common.annotations.VisibleForTesting;
23 import io.netty.buffer.ByteBuf;
24 import java.nio.ByteBuffer;
25 import java.security.GeneralSecurityException;
26 import java.util.List;
27 
28 /** Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible. */
29 final class AltsChannelCrypter implements ChannelCrypterNetty {
30   private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength();
31   private static final int COUNTER_LENGTH = 12;
32   // The counter will overflow after 2^64 operations and encryption/decryption will stop working.
33   private static final int COUNTER_OVERFLOW_LENGTH = 8;
34   private static final int TAG_LENGTH = 16;
35 
36   private final AeadCrypter aeadCrypter;
37 
38   private final byte[] outCounter = new byte[COUNTER_LENGTH];
39   private final byte[] inCounter = new byte[COUNTER_LENGTH];
40   private final byte[] oldCounter = new byte[COUNTER_LENGTH];
41 
AltsChannelCrypter(byte[] key, boolean isClient)42   AltsChannelCrypter(byte[] key, boolean isClient) {
43     checkArgument(key.length == KEY_LENGTH);
44     byte[] counter = isClient ? inCounter : outCounter;
45     counter[counter.length - 1] = (byte) 0x80;
46     this.aeadCrypter = new AesGcmHkdfAeadCrypter(key);
47   }
48 
getKeyLength()49   static int getKeyLength() {
50     return KEY_LENGTH;
51   }
52 
getCounterLength()53   static int getCounterLength() {
54     return COUNTER_LENGTH;
55   }
56 
57   @SuppressWarnings("BetaApi") // verify is stable in Guava
58   @Override
encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs)59   public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
60     checkArgument(outBuf.nioBufferCount() == 1);
61     // Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
62     ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
63     plainBuf.writerIndex(0);
64     for (ByteBuf inBuf : plainBufs) {
65       plainBuf.writeBytes(inBuf);
66     }
67 
68     verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
69     ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
70     ByteBuffer plain = out.duplicate();
71     plain.limit(out.limit() - TAG_LENGTH);
72 
73     byte[] counter = incrementOutCounter();
74     int outPosition = out.position();
75     aeadCrypter.encrypt(out, plain, counter);
76     int bytesWritten = out.position() - outPosition;
77     outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
78     verify(!outBuf.isWritable());
79   }
80 
81   @Override
decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)82   public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
83       throws GeneralSecurityException {
84 
85     ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
86     cipherTextAndTag.writerIndex(0);
87 
88     for (ByteBuf inBuf : ciphertextBufs) {
89       cipherTextAndTag.writeBytes(inBuf);
90     }
91     cipherTextAndTag.writeBytes(tag);
92 
93     decrypt(out, cipherTextAndTag);
94   }
95 
96   @SuppressWarnings("BetaApi") // verify is stable in Guava
97   @Override
decrypt(ByteBuf out, ByteBuf ciphertextAndTag)98   public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
99     int bytesRead = ciphertextAndTag.readableBytes();
100     checkArgument(bytesRead == out.writableBytes());
101 
102     checkArgument(out.nioBufferCount() == 1);
103     ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());
104 
105     checkArgument(ciphertextAndTag.nioBufferCount() == 1);
106     ByteBuffer ciphertextAndTagBuffer =
107         ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);
108 
109     byte[] counter = incrementInCounter();
110     int outPosition = outBuffer.position();
111     aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
112     int bytesWritten = outBuffer.position() - outPosition;
113     out.writerIndex(out.writerIndex() + bytesWritten);
114     ciphertextAndTag.readerIndex(out.readerIndex() + bytesRead);
115     verify(out.writableBytes() == TAG_LENGTH);
116   }
117 
118   @Override
getSuffixLength()119   public int getSuffixLength() {
120     return TAG_LENGTH;
121   }
122 
123   @Override
destroy()124   public void destroy() {
125     // no destroy required
126   }
127 
128   /** Increments {@code counter}, store the unincremented value in {@code oldCounter}. */
incrementCounter(byte[] counter, byte[] oldCounter)129   static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException {
130     System.arraycopy(counter, 0, oldCounter, 0, counter.length);
131     int i = 0;
132     for (; i < COUNTER_OVERFLOW_LENGTH; i++) {
133       counter[i]++;
134       if (counter[i] != (byte) 0x00) {
135         break;
136       }
137     }
138 
139     if (i == COUNTER_OVERFLOW_LENGTH) {
140       // Restore old counter value to ensure that encrypt and decrypt keep failing.
141       System.arraycopy(oldCounter, 0, counter, 0, counter.length);
142       throw new GeneralSecurityException("Counter has overflowed.");
143     }
144   }
145 
146   /** Increments the input counter, returning the previous (unincremented) value. */
incrementInCounter()147   private byte[] incrementInCounter() throws GeneralSecurityException {
148     incrementCounter(inCounter, oldCounter);
149     return oldCounter;
150   }
151 
152   /** Increments the output counter, returning the previous (unincremented) value. */
incrementOutCounter()153   private byte[] incrementOutCounter() throws GeneralSecurityException {
154     incrementCounter(outCounter, oldCounter);
155     return oldCounter;
156   }
157 
158   @VisibleForTesting
incrementInCounterForTesting(int n)159   void incrementInCounterForTesting(int n) throws GeneralSecurityException {
160     for (int i = 0; i < n; i++) {
161       incrementInCounter();
162     }
163   }
164 
165   @VisibleForTesting
incrementOutCounterForTesting(int n)166   void incrementOutCounterForTesting(int n) throws GeneralSecurityException {
167     for (int i = 0; i < n; i++) {
168       incrementOutCounter();
169     }
170   }
171 }
172