1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static io.grpc.alts.internal.AltsChannelCrypter.incrementCounter; 21 import static org.junit.Assert.fail; 22 23 import com.google.common.testing.GcFinalization; 24 import io.netty.util.ReferenceCounted; 25 import io.netty.util.ResourceLeakDetector; 26 import io.netty.util.ResourceLeakDetector.Level; 27 import java.security.GeneralSecurityException; 28 import java.util.Arrays; 29 import org.junit.After; 30 import org.junit.Before; 31 import org.junit.Test; 32 import org.junit.runner.RunWith; 33 import org.junit.runners.JUnit4; 34 35 /** Unit tests for {@link AltsChannelCrypter}. */ 36 @RunWith(JUnit4.class) 37 public final class AltsChannelCrypterTest extends ChannelCrypterNettyTestBase { 38 39 @Before setUp()40 public void setUp() throws GeneralSecurityException { 41 ResourceLeakDetector.setLevel(Level.PARANOID); 42 client = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true); 43 server = new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false); 44 } 45 46 @After tearDown()47 public void tearDown() throws GeneralSecurityException { 48 for (ReferenceCounted reference : references) { 49 reference.release(); 50 } 51 references.clear(); 52 client.destroy(); 53 server.destroy(); 54 // Increase our chances to detect ByteBuf leaks. 55 GcFinalization.awaitFullGc(); 56 } 57 58 @Test encryptDecryptKdfCounterIncr()59 public void encryptDecryptKdfCounterIncr() throws GeneralSecurityException { 60 AltsChannelCrypter client = 61 new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], true); 62 AltsChannelCrypter server = 63 new AltsChannelCrypter(new byte[AltsChannelCrypter.getKeyLength()], false); 64 65 String message = "Hello world"; 66 FrameEncrypt frameEncrypt1 = createFrameEncrypt(message); 67 68 client.encrypt(frameEncrypt1.out, frameEncrypt1.plain); 69 FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt1); 70 71 server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext); 72 assertThat(frameEncrypt1.plain.get(0).slice(0, frameDecrypt1.out.readableBytes())) 73 .isEqualTo(frameDecrypt1.out); 74 75 // Increase counters to get a new KDF counter value (first two bytes are skipped). 76 client.incrementOutCounterForTesting(1 << 17); 77 server.incrementInCounterForTesting(1 << 17); 78 79 FrameEncrypt frameEncrypt2 = createFrameEncrypt(message); 80 81 client.encrypt(frameEncrypt2.out, frameEncrypt2.plain); 82 FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt2); 83 84 server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext); 85 assertThat(frameEncrypt2.plain.get(0).slice(0, frameDecrypt2.out.readableBytes())) 86 .isEqualTo(frameDecrypt2.out); 87 } 88 89 @Test overflowsClient()90 public void overflowsClient() throws GeneralSecurityException { 91 byte[] maxFirst = 92 new byte[] { 93 (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 94 (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 95 (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00 96 }; 97 98 byte[] maxFirstPred = Arrays.copyOf(maxFirst, maxFirst.length); 99 maxFirstPred[0]--; 100 101 byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()]; 102 byte[] counter = Arrays.copyOf(maxFirstPred, maxFirstPred.length); 103 104 incrementCounter(counter, oldCounter); 105 106 assertThat(oldCounter).isEqualTo(maxFirstPred); 107 assertThat(counter).isEqualTo(maxFirst); 108 109 try { 110 incrementCounter(counter, oldCounter); 111 fail("Exception expected"); 112 } catch (GeneralSecurityException ex) { 113 assertThat(ex).hasMessageThat().contains("Counter has overflowed"); 114 } 115 116 assertThat(oldCounter).isEqualTo(maxFirst); 117 assertThat(counter).isEqualTo(maxFirst); 118 } 119 120 @Test overflowsServer()121 public void overflowsServer() throws GeneralSecurityException { 122 byte[] maxSecond = 123 new byte[] { 124 (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 125 (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, 126 (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x80 127 }; 128 129 byte[] maxSecondPred = Arrays.copyOf(maxSecond, maxSecond.length); 130 maxSecondPred[0]--; 131 132 byte[] oldCounter = new byte[AltsChannelCrypter.getCounterLength()]; 133 byte[] counter = Arrays.copyOf(maxSecondPred, maxSecondPred.length); 134 135 incrementCounter(counter, oldCounter); 136 137 assertThat(oldCounter).isEqualTo(maxSecondPred); 138 assertThat(counter).isEqualTo(maxSecond); 139 140 try { 141 incrementCounter(counter, oldCounter); 142 fail("Exception expected"); 143 } catch (GeneralSecurityException ex) { 144 assertThat(ex).hasMessageThat().contains("Counter has overflowed"); 145 } 146 147 assertThat(oldCounter).isEqualTo(maxSecond); 148 assertThat(counter).isEqualTo(maxSecond); 149 } 150 } 151