1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 //////////////////////////////////////////////////////////////////////////////// 16 17 package com.google.crypto.tink.subtle; 18 19 import static java.lang.Math.min; 20 21 import java.io.FilterInputStream; 22 import java.io.IOException; 23 import java.io.InputStream; 24 import java.nio.ByteBuffer; 25 import java.security.GeneralSecurityException; 26 import java.util.Arrays; 27 28 /** 29 * An instance of a InputStream that returns the plaintext for some ciphertext. 30 * 31 * <p>TODO(bleichen): define what the state is after an IOException. 32 */ 33 class StreamingAeadDecryptingStream extends FilterInputStream { 34 // Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains. 35 // This is a workaround for an incompatibility between Conscrypt and OpenJDK in their 36 // AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information. 37 // Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones 38 // running old versions of Conscrypt, so we decided to take matters into our own hands. 39 // Why 16? Actually any number larger than 16 should work. 16 is the lower bound because it's the 40 // size of the tags of each AES-GCM ciphertext segment. 41 private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16; 42 43 /** 44 * A buffer containing ciphertext that has not yet been decrypted. The limit of ciphertextSegment 45 * is set such that it can contain segment plus the first character of the next segment. It is 46 * necessary to read a segment plus one more byte to decrypt a segment, since the last segment of 47 * a ciphertext is encrypted differently. 48 */ 49 private final ByteBuffer ciphertextSegment; 50 51 /** 52 * A buffer containing a plaintext segment. The bytes in the range plaintexSegment.position() .. 53 * plaintextSegment.limit() - 1 are plaintext that have been decrypted but not yet read out of 54 * AesGcmInputStream. 55 */ 56 private final ByteBuffer plaintextSegment; 57 58 /* Header information */ 59 private final int headerLength; 60 private boolean headerRead; 61 62 /* Indicates whether the end of this InputStream has been reached. */ 63 private boolean endOfCiphertext; 64 65 /* Indicates whether the end of the plaintext has been reached. */ 66 private boolean endOfPlaintext; 67 68 /* Indicates whether a decyrption error has occured. */ 69 private boolean decryptionErrorOccured; 70 71 /** The additional data that is authenticated with the ciphertext. */ 72 private final byte[] aad; 73 74 /** The number of the current segment of ciphertext buffered in ciphertexSegment. */ 75 private int segmentNr; 76 77 private final StreamSegmentDecrypter decrypter; 78 private final int ciphertextSegmentSize; 79 private final int firstCiphertextSegmentSize; 80 StreamingAeadDecryptingStream( NonceBasedStreamingAead streamAead, InputStream ciphertextStream, byte[] associatedData)81 public StreamingAeadDecryptingStream( 82 NonceBasedStreamingAead streamAead, InputStream ciphertextStream, byte[] associatedData) 83 throws GeneralSecurityException, IOException { 84 super(ciphertextStream); 85 decrypter = streamAead.newStreamSegmentDecrypter(); 86 headerLength = streamAead.getHeaderLength(); 87 aad = Arrays.copyOf(associatedData, associatedData.length); 88 // ciphertextSegment is one byte longer than a ciphertext segment, 89 // so that the code can decide if the current segment is the last segment in the 90 // stream. 91 ciphertextSegmentSize = streamAead.getCiphertextSegmentSize(); 92 ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize + 1); 93 ciphertextSegment.limit(0); 94 firstCiphertextSegmentSize = ciphertextSegmentSize - streamAead.getCiphertextOffset(); 95 plaintextSegment = ByteBuffer.allocate(streamAead.getPlaintextSegmentSize() 96 + PLAINTEXT_SEGMENT_EXTRA_SIZE); 97 plaintextSegment.limit(0); 98 headerRead = false; 99 endOfCiphertext = false; 100 endOfPlaintext = false; 101 segmentNr = 0; 102 decryptionErrorOccured = false; 103 } 104 105 /** 106 * Reads the header of the ciphertext and sets headerRead = true. 107 * 108 * @throws IOException when an exception occurs while reading from {@code in} or when the header 109 * is too short. 110 */ readHeader()111 private void readHeader() throws IOException { 112 if (headerRead) { 113 setDecryptionErrorOccured(); 114 throw new IOException("Decryption failed."); 115 } 116 ByteBuffer header = ByteBuffer.allocate(headerLength); 117 while (header.remaining() > 0) { 118 int read = in.read(header.array(), header.position(), header.remaining()); 119 if (read == -1) { 120 setDecryptionErrorOccured(); 121 throw new IOException("Ciphertext is too short"); 122 } 123 if (read == 0) { 124 throw new IOException("Could not read bytes from the ciphertext stream"); 125 } 126 header.position(header.position() + read); 127 } 128 header.flip(); 129 try { 130 decrypter.init(header, aad); 131 } catch (GeneralSecurityException ex) { 132 throw new IOException(ex); 133 } 134 headerRead = true; 135 } 136 setDecryptionErrorOccured()137 private void setDecryptionErrorOccured() { 138 decryptionErrorOccured = true; 139 plaintextSegment.limit(0); 140 } 141 142 /** Loads the next plaintext segment. */ loadSegment()143 private void loadSegment() throws IOException { 144 // Try filling the ciphertextSegment 145 while (!endOfCiphertext && ciphertextSegment.remaining() > 0) { 146 int read = 147 in.read( 148 ciphertextSegment.array(), 149 ciphertextSegment.position(), 150 ciphertextSegment.remaining()); 151 if (read > 0) { 152 ciphertextSegment.position(ciphertextSegment.position() + read); 153 } else if (read == -1) { 154 endOfCiphertext = true; 155 } else if (read == 0) { 156 // We expect that read returns at least one byte. 157 throw new IOException("Could not read bytes from the ciphertext stream"); 158 } 159 } 160 byte lastByte = 0; 161 if (!endOfCiphertext) { 162 lastByte = ciphertextSegment.get(ciphertextSegment.position() - 1); 163 ciphertextSegment.position(ciphertextSegment.position() - 1); 164 } 165 ciphertextSegment.flip(); 166 plaintextSegment.clear(); 167 try { 168 decrypter.decryptSegment(ciphertextSegment, segmentNr, endOfCiphertext, plaintextSegment); 169 } catch (GeneralSecurityException ex) { 170 // The current segment did not validate. 171 // Currently this means that decryption cannot resume. 172 setDecryptionErrorOccured(); 173 throw new IOException( 174 ex.getMessage() 175 + "\n" 176 + toString() 177 + "\nsegmentNr:" 178 + segmentNr 179 + " endOfCiphertext:" 180 + endOfCiphertext, 181 ex); 182 } 183 segmentNr += 1; 184 plaintextSegment.flip(); 185 ciphertextSegment.clear(); 186 if (!endOfCiphertext) { 187 ciphertextSegment.clear(); 188 ciphertextSegment.limit(ciphertextSegmentSize + 1); 189 ciphertextSegment.put(lastByte); 190 } 191 } 192 193 @Override read()194 public int read() throws IOException { 195 byte[] oneByte = new byte[1]; 196 int ret = read(oneByte, 0, 1); 197 if (ret == 1) { 198 return oneByte[0] & 0xff; 199 } else if (ret == -1) { 200 return ret; 201 } else { 202 throw new IOException("Reading failed"); 203 } 204 } 205 206 @Override read(byte[] dst)207 public int read(byte[] dst) throws IOException { 208 return read(dst, 0, dst.length); 209 } 210 211 @Override read(byte[] dst, int offset, int length)212 public synchronized int read(byte[] dst, int offset, int length) throws IOException { 213 if (decryptionErrorOccured) { 214 throw new IOException("Decryption failed."); 215 } 216 if (!headerRead) { 217 readHeader(); 218 ciphertextSegment.clear(); 219 ciphertextSegment.limit(firstCiphertextSegmentSize + 1); 220 } 221 if (endOfPlaintext) { 222 return -1; 223 } 224 int bytesRead = 0; 225 while (bytesRead < length) { 226 if (plaintextSegment.remaining() == 0) { 227 if (endOfCiphertext) { 228 endOfPlaintext = true; 229 break; 230 } 231 loadSegment(); 232 } 233 int sliceSize = min(plaintextSegment.remaining(), length - bytesRead); 234 plaintextSegment.get(dst, bytesRead + offset, sliceSize); 235 bytesRead += sliceSize; 236 } 237 if (bytesRead == 0 && endOfPlaintext) { 238 return -1; 239 } else { 240 return bytesRead; 241 } 242 } 243 244 @Override close()245 public synchronized void close() throws IOException { 246 super.close(); 247 } 248 249 @Override available()250 public synchronized int available() { 251 return plaintextSegment.remaining(); 252 } 253 254 @Override mark(int readlimit)255 public synchronized void mark(int readlimit) { 256 // Mark is not supported. 257 } 258 259 @Override markSupported()260 public boolean markSupported() { 261 return false; 262 } 263 264 /** 265 * Skips over and discards <code>n</code> bytes of plaintext from the input stream. The 266 * implementation reads and decrypts the plaintext that is skipped. Hence skipping a large number 267 * of bytes is slow. 268 * 269 * <p>Returns the number of bytes skipped. This number can be smaller than the number of bytes 270 * requested. This can happend for a number of reasons: e.g., this happens when the underlying 271 * stream is non-blocking and not enough bytes are available or when the stream reaches the end of 272 * the stream. 273 * 274 * @throws IOException when an exception occurs while reading from {@code in} or when the 275 * ciphertext is corrupt. Currently all corrupt ciphertext will be detected. However this 276 * behaviour may change. 277 */ 278 @Override skip(long n)279 public long skip(long n) throws IOException { 280 long maxSkipBufferSize = ciphertextSegmentSize; 281 long remaining = n; 282 if (n <= 0) { 283 return 0; 284 } 285 int size = (int) min(maxSkipBufferSize, remaining); 286 byte[] skipBuffer = new byte[size]; 287 while (remaining > 0) { 288 int bytesRead = read(skipBuffer, 0, (int) min(size, remaining)); 289 if (bytesRead <= 0) { 290 break; 291 } 292 remaining -= bytesRead; 293 } 294 return n - remaining; 295 } 296 297 /* Returns the state of the channel. */ 298 @Override toString()299 public synchronized String toString() { 300 StringBuilder res = new StringBuilder(); 301 res.append("StreamingAeadDecryptingStream") 302 .append("\nsegmentNr:") 303 .append(segmentNr) 304 .append("\nciphertextSegmentSize:") 305 .append(ciphertextSegmentSize) 306 .append("\nheaderRead:") 307 .append(headerRead) 308 .append("\nendOfCiphertext:") 309 .append(endOfCiphertext) 310 .append("\nendOfPlaintext:") 311 .append(endOfPlaintext) 312 .append("\ndecryptionErrorOccured:") 313 .append(decryptionErrorOccured) 314 .append("\nciphertextSgement") 315 .append(" position:") 316 .append(ciphertextSegment.position()) 317 .append(" limit:") 318 .append(ciphertextSegment.limit()) 319 .append("\nplaintextSegment") 320 .append(" position:") 321 .append(plaintextSegment.position()) 322 .append(" limit:") 323 .append(plaintextSegment.limit()); 324 return res.toString(); 325 } 326 } 327