• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16 
17 package com.google.crypto.tink.subtle;
18 
19 import com.google.errorprone.annotations.CanIgnoreReturnValue;
20 import java.io.IOException;
21 import java.nio.ByteBuffer;
22 import java.nio.channels.ClosedChannelException;
23 import java.nio.channels.NonWritableChannelException;
24 import java.nio.channels.SeekableByteChannel;
25 import java.security.GeneralSecurityException;
26 import java.util.Arrays;
27 
28 /**
29  * An instance of {@link SeekableByteChannel} that allows random access to the plaintext of some
30  * ciphertext.
31  */
32 class StreamingAeadSeekableDecryptingChannel implements SeekableByteChannel {
33   // Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains.
34   // This is a workaround for an incompatibility between Conscrypt and OpenJDK in their
35   // AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information.
36   // Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones
37   // running old versions of Conscrypt, so we decided to take matters into our own hands.
38   // Why 16? Actually any number larger than 16 should work. 16 is the lower bound because it's the
39   // size of the tags of each AES-GCM ciphertext segment.
40   private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16;
41 
42   private final SeekableByteChannel ciphertextChannel;
43   private final ByteBuffer ciphertextSegment;
44   private final ByteBuffer plaintextSegment;
45   private final ByteBuffer header;
46   private final long ciphertextChannelSize;  // unverified size of the ciphertext
47   private final int numberOfSegments;  // unverified number of segments
48   private final int lastCiphertextSegmentSize;  // unverified size of the last segment.
49   private final byte[] aad;
50   private final StreamSegmentDecrypter decrypter;
51   private long plaintextPosition;
52   private long plaintextSize;
53   private boolean headerRead;
54   private boolean isCurrentSegmentDecrypted;
55   private int currentSegmentNr;
56   private boolean isopen;
57   private final int plaintextSegmentSize;
58   private final int ciphertextSegmentSize;
59   private final int ciphertextOffset;
60   private final int firstSegmentOffset;
61 
StreamingAeadSeekableDecryptingChannel( NonceBasedStreamingAead streamAead, SeekableByteChannel ciphertext, byte[] associatedData)62   public StreamingAeadSeekableDecryptingChannel(
63       NonceBasedStreamingAead streamAead,
64       SeekableByteChannel ciphertext,
65       byte[] associatedData) throws IOException, GeneralSecurityException {
66     decrypter = streamAead.newStreamSegmentDecrypter();
67     ciphertextChannel = ciphertext;
68     header = ByteBuffer.allocate(streamAead.getHeaderLength());
69     ciphertextSegmentSize = streamAead.getCiphertextSegmentSize();
70     ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize);
71     plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
72     plaintextSegment = ByteBuffer.allocate(plaintextSegmentSize + PLAINTEXT_SEGMENT_EXTRA_SIZE);
73     plaintextPosition = 0;
74     headerRead = false;
75     currentSegmentNr = -1;
76     isCurrentSegmentDecrypted = false;
77     ciphertextChannelSize = ciphertextChannel.size();
78     aad = Arrays.copyOf(associatedData, associatedData.length);
79     isopen = ciphertextChannel.isOpen();
80     int  fullSegments = (int) (ciphertextChannelSize / ciphertextSegmentSize);
81     int remainder = (int) (ciphertextChannelSize % ciphertextSegmentSize);
82     int ciphertextOverhead = streamAead.getCiphertextOverhead();
83     if (remainder > 0) {
84       numberOfSegments = fullSegments + 1;
85       if (remainder < ciphertextOverhead) {
86         throw new IOException("Invalid ciphertext size");
87       }
88       lastCiphertextSegmentSize = remainder;
89     } else {
90       numberOfSegments = fullSegments;
91       lastCiphertextSegmentSize = ciphertextSegmentSize;
92     }
93     ciphertextOffset = streamAead.getCiphertextOffset();
94     firstSegmentOffset = ciphertextOffset - streamAead.getHeaderLength();
95     if (firstSegmentOffset < 0) {
96       throw new IOException("Invalid ciphertext offset or header length");
97     }
98     long overhead = (long) numberOfSegments * ciphertextOverhead + ciphertextOffset;
99     if (overhead > ciphertextChannelSize) {
100       throw new IOException("Ciphertext is too short");
101     }
102     plaintextSize = ciphertextChannelSize - overhead;
103   }
104 
105   /**
106    * A description of the state of this StreamingAeadSeekableDecryptingChannel.
107    * While this description does not contain plaintext or key material
108    * it contains length information that might leak some information.
109    */
110   @Override
toString()111   public synchronized String toString() {
112     StringBuilder res =
113       new StringBuilder();
114     String ctChannel;
115     try {
116       ctChannel = "position:" + ciphertextChannel.position();
117     } catch (IOException ex) {
118       ctChannel = "position: n/a";
119     }
120     res.append("StreamingAeadSeekableDecryptingChannel")
121        .append("\nciphertextChannel").append(ctChannel)
122        .append("\nciphertextChannelSize:").append(ciphertextChannelSize)
123        .append("\nplaintextSize:").append(plaintextSize)
124        .append("\nciphertextSegmentSize:").append(ciphertextSegmentSize)
125        .append("\nnumberOfSegments:").append(numberOfSegments)
126        .append("\nheaderRead:").append(headerRead)
127        .append("\nplaintextPosition:").append(plaintextPosition)
128        .append("\nHeader")
129        .append(" position:").append(header.position())
130        .append(" limit:").append(header.position())
131        .append("\ncurrentSegmentNr:").append(currentSegmentNr)
132        .append("\nciphertextSgement")
133        .append(" position:").append(ciphertextSegment.position())
134        .append(" limit:").append(ciphertextSegment.limit())
135        .append("\nisCurrentSegmentDecrypted:").append(isCurrentSegmentDecrypted)
136        .append("\nplaintextSegment")
137        .append(" position:").append(plaintextSegment.position())
138        .append(" limit:").append(plaintextSegment.limit());
139     return res.toString();
140   }
141 
142   /**
143    * Returns the position of this channel.
144    * The position is relative to the plaintext.
145    */
146   @Override
position()147   public synchronized long position() {
148     return plaintextPosition;
149   }
150 
151   /**
152    * Sets the position in the plaintext. Setting the position to a value greater than the plaintext
153    * size is legal. A later attempt to read byte will throw an IOException.
154    */
155   @CanIgnoreReturnValue
156   @Override
position(long newPosition)157   public synchronized SeekableByteChannel position(long newPosition) {
158     plaintextPosition = newPosition;
159     return this;
160   }
161 
162   /**
163    * Tries to read the header of the ciphertext and derive the key used for the ciphertext from the
164    * information in the header.
165    *
166    * @return true if the header was fully read and has a correct format. Returns false if the header
167    *     could not be read.
168    * @throws IOException if the header was incorrectly formatted or if there was an exception during
169    *     the key derivation.
170    */
tryReadHeader()171   private boolean tryReadHeader() throws IOException {
172     ciphertextChannel.position(header.position() + firstSegmentOffset);
173     ciphertextChannel.read(header);
174     if (header.remaining() > 0) {
175       return false;
176     } else {
177       header.flip();
178       try {
179         decrypter.init(header, aad);
180         headerRead = true;
181       } catch (GeneralSecurityException ex) {
182         // TODO(bleichen): Define the state of this.
183         throw new IOException(ex);
184       }
185       return true;
186     }
187   }
188 
getSegmentNr(long plaintextPosition)189   private int getSegmentNr(long plaintextPosition) {
190     return (int) ((plaintextPosition + ciphertextOffset) / plaintextSegmentSize);
191   }
192 
193   /**
194    * Tries to read and decrypt a ciphertext segment.
195    * @param segmentNr the number of the segment
196    * @return true if the segment was read and correctly decrypted.
197    *          Returns false if the segment could not be fully read.
198    * @throws IOException if there was an exception reading the ciphertext,
199    *         if the segment number was incorrect, or
200    *         if there was an exception trying to decrypt the ciphertext segment.
201    */
tryLoadSegment(int segmentNr)202   private boolean tryLoadSegment(int segmentNr) throws IOException {
203     if (segmentNr < 0 || segmentNr >= numberOfSegments) {
204       throw new IOException("Invalid position");
205     }
206     boolean isLast = segmentNr == numberOfSegments - 1;
207     if (segmentNr == currentSegmentNr) {
208       if (isCurrentSegmentDecrypted) {
209         return true;
210       }
211     } else {
212       // segmentNr != currentSegmentNr
213       long ciphertextPosition = (long) segmentNr * ciphertextSegmentSize;
214       int segmentSize = ciphertextSegmentSize;
215       if (isLast) {
216         segmentSize = lastCiphertextSegmentSize;
217       }
218       if (segmentNr == 0) {
219         segmentSize -= ciphertextOffset;
220         ciphertextPosition = ciphertextOffset;
221       }
222       ciphertextChannel.position(ciphertextPosition);
223       ciphertextSegment.clear();
224       ciphertextSegment.limit(segmentSize);
225       currentSegmentNr = segmentNr;
226       isCurrentSegmentDecrypted = false;
227     }
228     if (ciphertextSegment.remaining() > 0) {
229       ciphertextChannel.read(ciphertextSegment);
230     }
231     if (ciphertextSegment.remaining() > 0) {
232       return false;
233     }
234     ciphertextSegment.flip();
235     plaintextSegment.clear();
236     try {
237       decrypter.decryptSegment(ciphertextSegment, segmentNr, isLast, plaintextSegment);
238     } catch (GeneralSecurityException ex) {
239       // The current segment did not validate. Ensure that this instance remains
240       // in a valid state.
241       currentSegmentNr = -1;
242       throw new IOException("Failed to decrypt", ex);
243     }
244     plaintextSegment.flip();
245     isCurrentSegmentDecrypted = true;
246     return true;
247   }
248 
249   /**
250    * Returns true if plaintextPositon is at the end of the file
251    * and this has been verified, by decrypting the last segment.
252    */
reachedEnd()253   private boolean reachedEnd() {
254     return (plaintextPosition == plaintextSize
255         && isCurrentSegmentDecrypted
256         && currentSegmentNr == numberOfSegments - 1
257         && plaintextSegment.remaining() == 0);
258   }
259 
260   /**
261    * Atomic read from a given position.
262    *
263    * This method works in the same way as read(ByteBuffer), except that it starts at the given
264    * position and does not modify the channel's position.
265    */
read(ByteBuffer dst, long start)266   public synchronized int read(ByteBuffer dst, long start) throws IOException {
267     long oldPosition = position();
268     try {
269       position(start);
270       return read(dst);
271     } finally {
272       position(oldPosition);
273     }
274   }
275 
276   @Override
read(ByteBuffer dst)277   public synchronized int read(ByteBuffer dst) throws IOException {
278     if (!isopen) {
279       throw new ClosedChannelException();
280     }
281     if (!headerRead) {
282       if (!tryReadHeader()) {
283         return 0;
284       }
285     }
286     int startPos = dst.position();
287     while (dst.remaining() > 0 && plaintextPosition < plaintextSize) {
288       // Determine segmentNr for the plaintext to read and the offset in
289       // the plaintext, where reading should start.
290       int segmentNr = getSegmentNr(plaintextPosition);
291       int segmentOffset;
292       if (segmentNr == 0) {
293          segmentOffset = (int) plaintextPosition;
294       } else {
295          segmentOffset = (int) ((plaintextPosition +  ciphertextOffset) % plaintextSegmentSize);
296       }
297 
298       if (tryLoadSegment(segmentNr)) {
299         plaintextSegment.position(segmentOffset);
300         if (plaintextSegment.remaining() <= dst.remaining()) {
301           plaintextPosition += plaintextSegment.remaining();
302           dst.put(plaintextSegment);
303         } else {
304           int sliceSize = dst.remaining();
305           ByteBuffer slice = plaintextSegment.duplicate();
306           slice.limit(slice.position() + sliceSize);
307           dst.put(slice);
308           plaintextPosition += sliceSize;
309           plaintextSegment.position(plaintextSegment.position() + sliceSize);
310         }
311       } else {
312         break;
313       }
314     }
315     int read = dst.position() - startPos;
316     if (read == 0 && reachedEnd()) {
317       return -1;
318     }
319     return read;
320   }
321 
322   /**
323    * Returns the expected size of the plaintext.
324    * Note that this implementation does not perform an integrity check on the size.
325    * I.e. if the file has been truncated then size() will return the wrong
326    * result. Reading the last block of the ciphertext will verify whether size()
327    * is correct.
328    */
329   @Override
size()330   public long size() {
331     return plaintextSize;
332   }
333 
verifiedSize()334   public synchronized long verifiedSize() throws IOException {
335     if (tryLoadSegment(numberOfSegments - 1)) {
336       return plaintextSize;
337     } else {
338       throw new IOException("could not verify the size");
339     }
340   }
341 
342   @Override
truncate(long size)343   public SeekableByteChannel truncate(long size) throws NonWritableChannelException {
344     throw new NonWritableChannelException();
345   }
346 
347   @Override
write(ByteBuffer src)348   public int write(ByteBuffer src) throws NonWritableChannelException {
349     throw new NonWritableChannelException();
350   }
351 
352   @Override
close()353   public synchronized void close() throws IOException {
354     ciphertextChannel.close();
355     isopen = false;
356   }
357 
358   @Override
isOpen()359   public synchronized boolean isOpen() {
360     return isopen;
361   }
362 }
363