1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static io.grpc.alts.internal.ByteBufTestUtils.getDirectBuffer; 21 import static io.grpc.alts.internal.ByteBufTestUtils.getRandom; 22 import static java.nio.charset.StandardCharsets.UTF_8; 23 import static org.junit.Assert.fail; 24 25 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef; 26 import io.netty.buffer.ByteBuf; 27 import io.netty.buffer.Unpooled; 28 import io.netty.util.ReferenceCounted; 29 import java.security.GeneralSecurityException; 30 import java.util.ArrayList; 31 import java.util.Collections; 32 import java.util.List; 33 import javax.crypto.AEADBadTagException; 34 import org.junit.Test; 35 36 /** Abstract class for unit tests of {@link ChannelCrypterNetty}. */ 37 public abstract class ChannelCrypterNettyTestBase { 38 private static final String DECRYPTION_FAILURE_MESSAGE = "Tag mismatch"; 39 40 protected final List<ReferenceCounted> references = new ArrayList<>(); 41 public ChannelCrypterNetty client; 42 public ChannelCrypterNetty server; 43 private final RegisterRef ref = 44 new RegisterRef() { 45 @Override 46 public ByteBuf register(ByteBuf buf) { 47 if (buf != null) { 48 references.add(buf); 49 } 50 return buf; 51 } 52 }; 53 54 static final class FrameEncrypt { 55 List<ByteBuf> plain; 56 ByteBuf out; 57 } 58 59 static final class FrameDecrypt { 60 List<ByteBuf> ciphertext; 61 ByteBuf out; 62 ByteBuf tag; 63 } 64 createFrameEncrypt(String message)65 FrameEncrypt createFrameEncrypt(String message) { 66 byte[] messageBytes = message.getBytes(UTF_8); 67 FrameEncrypt frame = new FrameEncrypt(); 68 ByteBuf plain = getDirectBuffer(messageBytes.length, ref); 69 plain.writeBytes(messageBytes); 70 frame.plain = Collections.singletonList(plain); 71 frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref); 72 return frame; 73 } 74 frameDecryptOfEncrypt(FrameEncrypt frameEncrypt)75 FrameDecrypt frameDecryptOfEncrypt(FrameEncrypt frameEncrypt) { 76 int tagLen = client.getSuffixLength(); 77 FrameDecrypt frameDecrypt = new FrameDecrypt(); 78 ByteBuf out = frameEncrypt.out; 79 frameDecrypt.ciphertext = 80 Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen)); 81 frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); 82 frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref); 83 return frameDecrypt; 84 } 85 86 @Test encryptDecrypt()87 public void encryptDecrypt() throws GeneralSecurityException { 88 String message = "Hello world"; 89 FrameEncrypt frameEncrypt = createFrameEncrypt(message); 90 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 91 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); 92 93 server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 94 assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) 95 .isEqualTo(frameDecrypt.out); 96 } 97 98 @Test encryptDecryptLarge()99 public void encryptDecryptLarge() throws GeneralSecurityException { 100 FrameEncrypt frameEncrypt = new FrameEncrypt(); 101 ByteBuf plain = getRandom(17 * 1024, ref); 102 frameEncrypt.plain = Collections.singletonList(plain); 103 frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref); 104 105 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 106 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); 107 108 // Call decrypt overload that takes ciphertext and tag. 109 server.decrypt(frameDecrypt.out, frameEncrypt.out); 110 assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) 111 .isEqualTo(frameDecrypt.out); 112 } 113 114 @Test encryptDecryptMultiple()115 public void encryptDecryptMultiple() throws GeneralSecurityException { 116 String message = "Hello world"; 117 for (int i = 0; i < 512; ++i) { 118 FrameEncrypt frameEncrypt = createFrameEncrypt(message); 119 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 120 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); 121 122 server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 123 assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) 124 .isEqualTo(frameDecrypt.out); 125 } 126 } 127 128 @Test encryptDecryptComposite()129 public void encryptDecryptComposite() throws GeneralSecurityException { 130 String message = "Hello world"; 131 int lastLen = 2; 132 byte[] messageBytes = message.getBytes(UTF_8); 133 FrameEncrypt frameEncrypt = new FrameEncrypt(); 134 ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref); 135 ByteBuf plain2 = getDirectBuffer(lastLen, ref); 136 plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen); 137 plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen); 138 ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2); 139 frameEncrypt.plain = Collections.singletonList(plain); 140 frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref); 141 142 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 143 144 int tagLen = client.getSuffixLength(); 145 FrameDecrypt frameDecrypt = new FrameDecrypt(); 146 ByteBuf out = frameEncrypt.out; 147 int outLen = out.readableBytes(); 148 ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref); 149 ByteBuf cipher2 = getDirectBuffer(lastLen, ref); 150 cipher1.writeBytes(out, 0, outLen - lastLen - tagLen); 151 cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen); 152 ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2); 153 frameDecrypt.ciphertext = Collections.singletonList(cipher); 154 frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen); 155 frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref); 156 157 server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 158 assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes())) 159 .isEqualTo(frameDecrypt.out); 160 } 161 162 @Test reflection()163 public void reflection() throws GeneralSecurityException { 164 String message = "Hello world"; 165 FrameEncrypt frameEncrypt = createFrameEncrypt(message); 166 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 167 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); 168 try { 169 client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 170 fail("Exception expected"); 171 } catch (AEADBadTagException ex) { 172 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); 173 } 174 } 175 176 @Test skipMessage()177 public void skipMessage() throws GeneralSecurityException { 178 String message = "Hello world"; 179 FrameEncrypt frameEncrypt1 = createFrameEncrypt(message); 180 client.encrypt(frameEncrypt1.out, frameEncrypt1.plain); 181 FrameEncrypt frameEncrypt2 = createFrameEncrypt(message); 182 client.encrypt(frameEncrypt2.out, frameEncrypt2.plain); 183 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt2); 184 185 try { 186 client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 187 fail("Exception expected"); 188 } catch (AEADBadTagException ex) { 189 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); 190 } 191 } 192 193 @Test corruptMessage()194 public void corruptMessage() throws GeneralSecurityException { 195 String message = "Hello world"; 196 FrameEncrypt frameEncrypt = createFrameEncrypt(message); 197 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 198 FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt); 199 frameEncrypt.out.setByte(3, frameEncrypt.out.getByte(3) + 1); 200 201 try { 202 client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext); 203 fail("Exception expected"); 204 } catch (AEADBadTagException ex) { 205 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); 206 } 207 } 208 209 @Test replayMessage()210 public void replayMessage() throws GeneralSecurityException { 211 String message = "Hello world"; 212 FrameEncrypt frameEncrypt = createFrameEncrypt(message); 213 client.encrypt(frameEncrypt.out, frameEncrypt.plain); 214 FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt); 215 FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt); 216 217 server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext); 218 219 try { 220 server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext); 221 fail("Exception expected"); 222 } catch (AEADBadTagException ex) { 223 assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE); 224 } 225 } 226 } 227