• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.alts.internal.AltsChannelCrypter.incrementCounter;
21 import static org.junit.Assert.fail;
22 
23 import com.google.common.testing.GcFinalization;
24 import io.netty.util.ReferenceCounted;
25 import io.netty.util.ResourceLeakDetector;
26 import io.netty.util.ResourceLeakDetector.Level;
27 import java.security.GeneralSecurityException;
28 import java.util.Arrays;
29 import org.junit.After;
30 import org.junit.Before;
31 import org.junit.Test;
32 import org.junit.runner.RunWith;
33 import org.junit.runners.JUnit4;
34 
35 /** Unit tests for {@link AltsChannelCrypter}. */
36 @RunWith(JUnit4.class)
37 public final class AltsChannelCrypterTest extends ChannelCrypterNettyTestBase {
38 
39   @Before
setUp()40   public void setUp() throws GeneralSecurityException {
41     ResourceLeakDetector.setLevel(Level.PARANOID);
42     client = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
43     server = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
44   }
45 
46   @After
tearDown()47   public void tearDown() throws GeneralSecurityException {
48     for (ReferenceCounted reference : references) {
49       reference.release();
50     }
51     references.clear();
52     client.destroy();
53     server.destroy();
54     // Increase our chances to detect ByteBuf leaks.
55     GcFinalization.awaitFullGc();
56   }
57 
58   @Test
encryptDecryptKdfCounterIncr()59   public void encryptDecryptKdfCounterIncr() throws GeneralSecurityException {
60     AltsChannelCrypter client =
61         new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true);
62     AltsChannelCrypter server =
63         new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false);
64 
65     String message = "Hello world";
66     FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
67 
68     client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
69     FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt1);
70 
71     server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
72     assertThat(frameEncrypt1.plain.get(0).slice(0, frameDecrypt1.out.readableBytes()))
73         .isEqualTo(frameDecrypt1.out);
74 
75     // Increase counters to get a new KDF counter value (first two bytes are skipped).
76     client.incrementOutCounterForTesting(1 << 17);
77     server.incrementInCounterForTesting(1 << 17);
78 
79     FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
80 
81     client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
82     FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt2);
83 
84     server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
85     assertThat(frameEncrypt2.plain.get(0).slice(0, frameDecrypt2.out.readableBytes()))
86         .isEqualTo(frameDecrypt2.out);
87   }
88 
89   @Test
overflowsClient()90   public void overflowsClient() throws GeneralSecurityException {
91     byte[] maxFirst =
92         new byte[] {
93           (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
94           (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
95           (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
96         };
97 
98     byte[] maxFirstPred = Arrays.copyOf(maxFirst, maxFirst.length);
99     maxFirstPred[0]--;
100 
101     byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
102     byte[] counter = Arrays.copyOf(maxFirstPred, maxFirstPred.length);
103 
104     incrementCounter(counter, oldCounter);
105 
106     assertThat(oldCounter).isEqualTo(maxFirstPred);
107     assertThat(counter).isEqualTo(maxFirst);
108 
109     try {
110       incrementCounter(counter, oldCounter);
111       fail("Exception expected");
112     } catch (GeneralSecurityException ex) {
113       assertThat(ex).hasMessageThat().contains("Counter has overflowed");
114     }
115 
116     assertThat(oldCounter).isEqualTo(maxFirst);
117     assertThat(counter).isEqualTo(maxFirst);
118   }
119 
120   @Test
overflowsServer()121   public void overflowsServer() throws GeneralSecurityException {
122     byte[] maxSecond =
123         new byte[] {
124           (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
125           (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF,
126           (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x80
127         };
128 
129     byte[] maxSecondPred = Arrays.copyOf(maxSecond, maxSecond.length);
130     maxSecondPred[0]--;
131 
132     byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()];
133     byte[] counter = Arrays.copyOf(maxSecondPred, maxSecondPred.length);
134 
135     incrementCounter(counter, oldCounter);
136 
137     assertThat(oldCounter).isEqualTo(maxSecondPred);
138     assertThat(counter).isEqualTo(maxSecond);
139 
140     try {
141       incrementCounter(counter, oldCounter);
142       fail("Exception expected");
143     } catch (GeneralSecurityException ex) {
144       assertThat(ex).hasMessageThat().contains("Counter has overflowed");
145     }
146 
147     assertThat(oldCounter).isEqualTo(maxSecond);
148     assertThat(counter).isEqualTo(maxSecond);
149   }
150 }
151