• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.alts.internal.ByteBufTestUtils.getDirectBuffer;
21 import static io.grpc.alts.internal.ByteBufTestUtils.getRandom;
22 import static java.nio.charset.StandardCharsets.UTF_8;
23 import static org.junit.Assert.fail;
24 
25 import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
26 import io.netty.buffer.ByteBuf;
27 import io.netty.buffer.Unpooled;
28 import io.netty.util.ReferenceCounted;
29 import java.security.GeneralSecurityException;
30 import java.util.ArrayList;
31 import java.util.Collections;
32 import java.util.List;
33 import javax.crypto.AEADBadTagException;
34 import org.junit.Test;
35 
36 /** Abstract class for unit tests of {@link ChannelCrypterNetty}. */
37 public abstract class ChannelCrypterNettyTestBase {
38   private static final String DECRYPTION_FAILURE_MESSAGE = "Tag mismatch";
39 
40   protected final List<ReferenceCounted> references = new ArrayList<>();
41   public ChannelCrypterNetty client;
42   public ChannelCrypterNetty server;
43   private final RegisterRef ref =
44       new RegisterRef() {
45         @Override
46         public ByteBuf register(ByteBuf buf) {
47           if (buf != null) {
48             references.add(buf);
49           }
50           return buf;
51         }
52       };
53 
54   static final class FrameEncrypt {
55     List<ByteBuf> plain;
56     ByteBuf out;
57   }
58 
59   static final class FrameDecrypt {
60     List<ByteBuf> ciphertext;
61     ByteBuf out;
62     ByteBuf tag;
63   }
64 
createFrameEncrypt(String message)65   FrameEncrypt createFrameEncrypt(String message) {
66     byte[] messageBytes = message.getBytes(UTF_8);
67     FrameEncrypt frame = new FrameEncrypt();
68     ByteBuf plain = getDirectBuffer(messageBytes.length, ref);
69     plain.writeBytes(messageBytes);
70     frame.plain = Collections.singletonList(plain);
71     frame.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
72     return frame;
73   }
74 
frameDecryptOfEncrypt(FrameEncrypt frameEncrypt)75   FrameDecrypt frameDecryptOfEncrypt(FrameEncrypt frameEncrypt) {
76     int tagLen = client.getSuffixLength();
77     FrameDecrypt frameDecrypt = new FrameDecrypt();
78     ByteBuf out = frameEncrypt.out;
79     frameDecrypt.ciphertext =
80         Collections.singletonList(out.slice(out.readerIndex(), out.readableBytes() - tagLen));
81     frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
82     frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
83     return frameDecrypt;
84   }
85 
86   @Test
encryptDecrypt()87   public void encryptDecrypt() throws GeneralSecurityException {
88     String message = "Hello world";
89     FrameEncrypt frameEncrypt = createFrameEncrypt(message);
90     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
91     FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
92 
93     server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
94     assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
95         .isEqualTo(frameDecrypt.out);
96   }
97 
98   @Test
encryptDecryptLarge()99   public void encryptDecryptLarge() throws GeneralSecurityException {
100     FrameEncrypt frameEncrypt = new FrameEncrypt();
101     ByteBuf plain = getRandom(17 * 1024, ref);
102     frameEncrypt.plain = Collections.singletonList(plain);
103     frameEncrypt.out = getDirectBuffer(plain.readableBytes() + client.getSuffixLength(), ref);
104 
105     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
106     FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
107 
108     // Call decrypt overload that takes ciphertext and tag.
109     server.decrypt(frameDecrypt.out, frameEncrypt.out);
110     assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
111         .isEqualTo(frameDecrypt.out);
112   }
113 
114   @Test
encryptDecryptMultiple()115   public void encryptDecryptMultiple() throws GeneralSecurityException {
116     String message = "Hello world";
117     for (int i = 0; i < 512; ++i) {
118       FrameEncrypt frameEncrypt = createFrameEncrypt(message);
119       client.encrypt(frameEncrypt.out, frameEncrypt.plain);
120       FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
121 
122       server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
123       assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
124           .isEqualTo(frameDecrypt.out);
125     }
126   }
127 
128   @Test
encryptDecryptComposite()129   public void encryptDecryptComposite() throws GeneralSecurityException {
130     String message = "Hello world";
131     int lastLen = 2;
132     byte[] messageBytes = message.getBytes(UTF_8);
133     FrameEncrypt frameEncrypt = new FrameEncrypt();
134     ByteBuf plain1 = getDirectBuffer(messageBytes.length - lastLen, ref);
135     ByteBuf plain2 = getDirectBuffer(lastLen, ref);
136     plain1.writeBytes(messageBytes, 0, messageBytes.length - lastLen);
137     plain2.writeBytes(messageBytes, messageBytes.length - lastLen, lastLen);
138     ByteBuf plain = Unpooled.wrappedBuffer(plain1, plain2);
139     frameEncrypt.plain = Collections.singletonList(plain);
140     frameEncrypt.out = getDirectBuffer(messageBytes.length + client.getSuffixLength(), ref);
141 
142     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
143 
144     int tagLen = client.getSuffixLength();
145     FrameDecrypt frameDecrypt = new FrameDecrypt();
146     ByteBuf out = frameEncrypt.out;
147     int outLen = out.readableBytes();
148     ByteBuf cipher1 = getDirectBuffer(outLen - lastLen - tagLen, ref);
149     ByteBuf cipher2 = getDirectBuffer(lastLen, ref);
150     cipher1.writeBytes(out, 0, outLen - lastLen - tagLen);
151     cipher2.writeBytes(out, outLen - tagLen - lastLen, lastLen);
152     ByteBuf cipher = Unpooled.wrappedBuffer(cipher1, cipher2);
153     frameDecrypt.ciphertext = Collections.singletonList(cipher);
154     frameDecrypt.tag = out.slice(out.readerIndex() + out.readableBytes() - tagLen, tagLen);
155     frameDecrypt.out = getDirectBuffer(out.readableBytes(), ref);
156 
157     server.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
158     assertThat(frameEncrypt.plain.get(0).slice(0, frameDecrypt.out.readableBytes()))
159         .isEqualTo(frameDecrypt.out);
160   }
161 
162   @Test
reflection()163   public void reflection() throws GeneralSecurityException {
164     String message = "Hello world";
165     FrameEncrypt frameEncrypt = createFrameEncrypt(message);
166     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
167     FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
168     try {
169       client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
170       fail("Exception expected");
171     } catch (AEADBadTagException ex) {
172       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
173     }
174   }
175 
176   @Test
skipMessage()177   public void skipMessage() throws GeneralSecurityException {
178     String message = "Hello world";
179     FrameEncrypt frameEncrypt1 = createFrameEncrypt(message);
180     client.encrypt(frameEncrypt1.out, frameEncrypt1.plain);
181     FrameEncrypt frameEncrypt2 = createFrameEncrypt(message);
182     client.encrypt(frameEncrypt2.out, frameEncrypt2.plain);
183     FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt2);
184 
185     try {
186       client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
187       fail("Exception expected");
188     } catch (AEADBadTagException ex) {
189       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
190     }
191   }
192 
193   @Test
corruptMessage()194   public void corruptMessage() throws GeneralSecurityException {
195     String message = "Hello world";
196     FrameEncrypt frameEncrypt = createFrameEncrypt(message);
197     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
198     FrameDecrypt frameDecrypt = frameDecryptOfEncrypt(frameEncrypt);
199     frameEncrypt.out.setByte(3, frameEncrypt.out.getByte(3) + 1);
200 
201     try {
202       client.decrypt(frameDecrypt.out, frameDecrypt.tag, frameDecrypt.ciphertext);
203       fail("Exception expected");
204     } catch (AEADBadTagException ex) {
205       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
206     }
207   }
208 
209   @Test
replayMessage()210   public void replayMessage() throws GeneralSecurityException {
211     String message = "Hello world";
212     FrameEncrypt frameEncrypt = createFrameEncrypt(message);
213     client.encrypt(frameEncrypt.out, frameEncrypt.plain);
214     FrameDecrypt frameDecrypt1 = frameDecryptOfEncrypt(frameEncrypt);
215     FrameDecrypt frameDecrypt2 = frameDecryptOfEncrypt(frameEncrypt);
216 
217     server.decrypt(frameDecrypt1.out, frameDecrypt1.tag, frameDecrypt1.ciphertext);
218 
219     try {
220       server.decrypt(frameDecrypt2.out, frameDecrypt2.tag, frameDecrypt2.ciphertext);
221       fail("Exception expected");
222     } catch (AEADBadTagException ex) {
223       assertThat(ex).hasMessageThat().contains(DECRYPTION_FAILURE_MESSAGE);
224     }
225   }
226 }
227