• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ////////////////////////////////////////////////////////////////////////////////
16 
17 package com.google.crypto.tink.subtle;
18 
19 import static java.lang.Math.min;
20 
21 import java.io.FilterInputStream;
22 import java.io.IOException;
23 import java.io.InputStream;
24 import java.nio.ByteBuffer;
25 import java.security.GeneralSecurityException;
26 import java.util.Arrays;
27 
28 /**
29  * An instance of a InputStream that returns the plaintext for some ciphertext.
30  *
31  * <p>TODO(bleichen): define what the state is after an IOException.
32  */
33 class StreamingAeadDecryptingStream extends FilterInputStream {
34   // Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains.
35   // This is a workaround for an incompatibility between Conscrypt and OpenJDK in their
36   // AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information.
37   // Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones
38   // running old versions of Conscrypt, so we decided to take matters into our own hands.
39   // Why 16? Actually any number larger than 16 should work. 16 is the lower bound because it's the
40   // size of the tags of each AES-GCM ciphertext segment.
41   private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16;
42 
43   /**
44    * A buffer containing ciphertext that has not yet been decrypted. The limit of ciphertextSegment
45    * is set such that it can contain segment plus the first character of the next segment. It is
46    * necessary to read a segment plus one more byte to decrypt a segment, since the last segment of
47    * a ciphertext is encrypted differently.
48    */
49   private final ByteBuffer ciphertextSegment;
50 
51   /**
52    * A buffer containing a plaintext segment. The bytes in the range plaintexSegment.position() ..
53    * plaintextSegment.limit() - 1 are plaintext that have been decrypted but not yet read out of
54    * AesGcmInputStream.
55    */
56   private final ByteBuffer plaintextSegment;
57 
58   /* Header information */
59   private final int headerLength;
60   private boolean headerRead;
61 
62   /* Indicates whether the end of this InputStream has been reached. */
63   private boolean endOfCiphertext;
64 
65   /* Indicates whether the end of the plaintext has been reached. */
66   private boolean endOfPlaintext;
67 
68   /* Indicates whether a decyrption error has occured. */
69   private boolean decryptionErrorOccured;
70 
71   /** The additional data that is authenticated with the ciphertext. */
72   private final byte[] aad;
73 
74   /** The number of the current segment of ciphertext buffered in ciphertexSegment. */
75   private int segmentNr;
76 
77   private final StreamSegmentDecrypter decrypter;
78   private final int ciphertextSegmentSize;
79   private final int firstCiphertextSegmentSize;
80 
StreamingAeadDecryptingStream( NonceBasedStreamingAead streamAead, InputStream ciphertextStream, byte[] associatedData)81   public StreamingAeadDecryptingStream(
82       NonceBasedStreamingAead streamAead, InputStream ciphertextStream, byte[] associatedData)
83       throws GeneralSecurityException, IOException {
84     super(ciphertextStream);
85     decrypter = streamAead.newStreamSegmentDecrypter();
86     headerLength = streamAead.getHeaderLength();
87     aad = Arrays.copyOf(associatedData, associatedData.length);
88     // ciphertextSegment is one byte longer than a ciphertext segment,
89     // so that the code can decide if the current segment is the last segment in the
90     // stream.
91     ciphertextSegmentSize = streamAead.getCiphertextSegmentSize();
92     ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize + 1);
93     ciphertextSegment.limit(0);
94     firstCiphertextSegmentSize = ciphertextSegmentSize - streamAead.getCiphertextOffset();
95     plaintextSegment = ByteBuffer.allocate(streamAead.getPlaintextSegmentSize()
96         + PLAINTEXT_SEGMENT_EXTRA_SIZE);
97     plaintextSegment.limit(0);
98     headerRead = false;
99     endOfCiphertext = false;
100     endOfPlaintext = false;
101     segmentNr = 0;
102     decryptionErrorOccured = false;
103   }
104 
105   /**
106    * Reads the header of the ciphertext and sets headerRead = true.
107    *
108    * @throws IOException when an exception occurs while reading from {@code in} or when the header
109    *     is too short.
110    */
readHeader()111   private void readHeader() throws IOException {
112     if (headerRead) {
113       setDecryptionErrorOccured();
114       throw new IOException("Decryption failed.");
115     }
116     ByteBuffer header = ByteBuffer.allocate(headerLength);
117     while (header.remaining() > 0) {
118       int read = in.read(header.array(), header.position(), header.remaining());
119       if (read == -1) {
120         setDecryptionErrorOccured();
121         throw new IOException("Ciphertext is too short");
122       }
123       if (read == 0) {
124         throw new IOException("Could not read bytes from the ciphertext stream");
125       }
126       header.position(header.position() + read);
127     }
128     header.flip();
129     try {
130       decrypter.init(header, aad);
131     } catch (GeneralSecurityException ex) {
132       throw new IOException(ex);
133     }
134     headerRead = true;
135   }
136 
setDecryptionErrorOccured()137   private void setDecryptionErrorOccured() {
138     decryptionErrorOccured = true;
139     plaintextSegment.limit(0);
140   }
141 
142   /** Loads the next plaintext segment. */
loadSegment()143   private void loadSegment() throws IOException {
144     // Try filling the ciphertextSegment
145     while (!endOfCiphertext && ciphertextSegment.remaining() > 0) {
146       int read =
147           in.read(
148               ciphertextSegment.array(),
149               ciphertextSegment.position(),
150               ciphertextSegment.remaining());
151       if (read > 0) {
152         ciphertextSegment.position(ciphertextSegment.position() + read);
153       } else if (read == -1) {
154         endOfCiphertext = true;
155       } else if (read == 0) {
156         // We expect that read returns at least one byte.
157         throw new IOException("Could not read bytes from the ciphertext stream");
158       }
159     }
160     byte lastByte = 0;
161     if (!endOfCiphertext) {
162       lastByte = ciphertextSegment.get(ciphertextSegment.position() - 1);
163       ciphertextSegment.position(ciphertextSegment.position() - 1);
164     }
165     ciphertextSegment.flip();
166     plaintextSegment.clear();
167     try {
168       decrypter.decryptSegment(ciphertextSegment, segmentNr, endOfCiphertext, plaintextSegment);
169     } catch (GeneralSecurityException ex) {
170       // The current segment did not validate.
171       // Currently this means that decryption cannot resume.
172       setDecryptionErrorOccured();
173       throw new IOException(
174           ex.getMessage()
175               + "\n"
176               + toString()
177               + "\nsegmentNr:"
178               + segmentNr
179               + " endOfCiphertext:"
180               + endOfCiphertext,
181           ex);
182     }
183     segmentNr += 1;
184     plaintextSegment.flip();
185     ciphertextSegment.clear();
186     if (!endOfCiphertext) {
187       ciphertextSegment.clear();
188       ciphertextSegment.limit(ciphertextSegmentSize + 1);
189       ciphertextSegment.put(lastByte);
190     }
191   }
192 
193   @Override
read()194   public int read() throws IOException {
195     byte[] oneByte = new byte[1];
196     int ret = read(oneByte, 0, 1);
197     if (ret == 1) {
198       return oneByte[0] & 0xff;
199     } else if (ret == -1) {
200       return ret;
201     } else {
202       throw new IOException("Reading failed");
203     }
204   }
205 
206   @Override
read(byte[] dst)207   public int read(byte[] dst) throws IOException {
208     return read(dst, 0, dst.length);
209   }
210 
211   @Override
read(byte[] dst, int offset, int length)212   public synchronized int read(byte[] dst, int offset, int length) throws IOException {
213     if (decryptionErrorOccured) {
214       throw new IOException("Decryption failed.");
215     }
216     if (!headerRead) {
217       readHeader();
218       ciphertextSegment.clear();
219       ciphertextSegment.limit(firstCiphertextSegmentSize + 1);
220     }
221     if (endOfPlaintext) {
222       return -1;
223     }
224     int bytesRead = 0;
225     while (bytesRead < length) {
226       if (plaintextSegment.remaining() == 0) {
227         if (endOfCiphertext) {
228           endOfPlaintext = true;
229           break;
230         }
231         loadSegment();
232       }
233       int sliceSize = min(plaintextSegment.remaining(), length - bytesRead);
234       plaintextSegment.get(dst, bytesRead + offset, sliceSize);
235       bytesRead += sliceSize;
236     }
237     if (bytesRead == 0 && endOfPlaintext) {
238       return -1;
239     } else {
240       return bytesRead;
241     }
242   }
243 
244   @Override
close()245   public synchronized void close() throws IOException {
246     super.close();
247   }
248 
249   @Override
available()250   public synchronized int available() {
251     return plaintextSegment.remaining();
252   }
253 
254   @Override
mark(int readlimit)255   public synchronized void mark(int readlimit) {
256     // Mark is not supported.
257   }
258 
259   @Override
markSupported()260   public boolean markSupported() {
261     return false;
262   }
263 
264   /**
265    * Skips over and discards <code>n</code> bytes of plaintext from the input stream. The
266    * implementation reads and decrypts the plaintext that is skipped. Hence skipping a large number
267    * of bytes is slow.
268    *
269    * <p>Returns the number of bytes skipped. This number can be smaller than the number of bytes
270    * requested. This can happend for a number of reasons: e.g., this happens when the underlying
271    * stream is non-blocking and not enough bytes are available or when the stream reaches the end of
272    * the stream.
273    *
274    * @throws IOException when an exception occurs while reading from {@code in} or when the
275    *     ciphertext is corrupt. Currently all corrupt ciphertext will be detected. However this
276    *     behaviour may change.
277    */
278   @Override
skip(long n)279   public long skip(long n) throws IOException {
280     long maxSkipBufferSize = ciphertextSegmentSize;
281     long remaining = n;
282     if (n <= 0) {
283       return 0;
284     }
285     int size = (int) min(maxSkipBufferSize, remaining);
286     byte[] skipBuffer = new byte[size];
287     while (remaining > 0) {
288       int bytesRead = read(skipBuffer, 0, (int) min(size, remaining));
289       if (bytesRead <= 0) {
290         break;
291       }
292       remaining -= bytesRead;
293     }
294     return n - remaining;
295   }
296 
297   /* Returns the state of the channel. */
298   @Override
toString()299   public synchronized String toString() {
300     StringBuilder res = new StringBuilder();
301     res.append("StreamingAeadDecryptingStream")
302         .append("\nsegmentNr:")
303         .append(segmentNr)
304         .append("\nciphertextSegmentSize:")
305         .append(ciphertextSegmentSize)
306         .append("\nheaderRead:")
307         .append(headerRead)
308         .append("\nendOfCiphertext:")
309         .append(endOfCiphertext)
310         .append("\nendOfPlaintext:")
311         .append(endOfPlaintext)
312         .append("\ndecryptionErrorOccured:")
313         .append(decryptionErrorOccured)
314         .append("\nciphertextSgement")
315         .append(" position:")
316         .append(ciphertextSegment.position())
317         .append(" limit:")
318         .append(ciphertextSegment.limit())
319         .append("\nplaintextSegment")
320         .append(" position:")
321         .append(plaintextSegment.position())
322         .append(" limit:")
323         .append(plaintextSegment.limit());
324     return res.toString();
325   }
326 }
327