1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 //////////////////////////////////////////////////////////////////////////////// 16 17 package com.google.crypto.tink.subtle; 18 19 import com.google.errorprone.annotations.CanIgnoreReturnValue; 20 import java.io.IOException; 21 import java.nio.ByteBuffer; 22 import java.nio.channels.ClosedChannelException; 23 import java.nio.channels.NonWritableChannelException; 24 import java.nio.channels.SeekableByteChannel; 25 import java.security.GeneralSecurityException; 26 import java.util.Arrays; 27 28 /** 29 * An instance of {@link SeekableByteChannel} that allows random access to the plaintext of some 30 * ciphertext. 31 */ 32 class StreamingAeadSeekableDecryptingChannel implements SeekableByteChannel { 33 // Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains. 34 // This is a workaround for an incompatibility between Conscrypt and OpenJDK in their 35 // AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information. 36 // Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones 37 // running old versions of Conscrypt, so we decided to take matters into our own hands. 38 // Why 16? Actually any number larger than 16 should work. 16 is the lower bound because it's the 39 // size of the tags of each AES-GCM ciphertext segment. 40 private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16; 41 42 private final SeekableByteChannel ciphertextChannel; 43 private final ByteBuffer ciphertextSegment; 44 private final ByteBuffer plaintextSegment; 45 private final ByteBuffer header; 46 private final long ciphertextChannelSize; // unverified size of the ciphertext 47 private final int numberOfSegments; // unverified number of segments 48 private final int lastCiphertextSegmentSize; // unverified size of the last segment. 49 private final byte[] aad; 50 private final StreamSegmentDecrypter decrypter; 51 private long plaintextPosition; 52 private long plaintextSize; 53 private boolean headerRead; 54 private boolean isCurrentSegmentDecrypted; 55 private int currentSegmentNr; 56 private boolean isopen; 57 private final int plaintextSegmentSize; 58 private final int ciphertextSegmentSize; 59 private final int ciphertextOffset; 60 private final int firstSegmentOffset; 61 StreamingAeadSeekableDecryptingChannel( NonceBasedStreamingAead streamAead, SeekableByteChannel ciphertext, byte[] associatedData)62 public StreamingAeadSeekableDecryptingChannel( 63 NonceBasedStreamingAead streamAead, 64 SeekableByteChannel ciphertext, 65 byte[] associatedData) throws IOException, GeneralSecurityException { 66 decrypter = streamAead.newStreamSegmentDecrypter(); 67 ciphertextChannel = ciphertext; 68 header = ByteBuffer.allocate(streamAead.getHeaderLength()); 69 ciphertextSegmentSize = streamAead.getCiphertextSegmentSize(); 70 ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize); 71 plaintextSegmentSize = streamAead.getPlaintextSegmentSize(); 72 plaintextSegment = ByteBuffer.allocate(plaintextSegmentSize + PLAINTEXT_SEGMENT_EXTRA_SIZE); 73 plaintextPosition = 0; 74 headerRead = false; 75 currentSegmentNr = -1; 76 isCurrentSegmentDecrypted = false; 77 ciphertextChannelSize = ciphertextChannel.size(); 78 aad = Arrays.copyOf(associatedData, associatedData.length); 79 isopen = ciphertextChannel.isOpen(); 80 int fullSegments = (int) (ciphertextChannelSize / ciphertextSegmentSize); 81 int remainder = (int) (ciphertextChannelSize % ciphertextSegmentSize); 82 int ciphertextOverhead = streamAead.getCiphertextOverhead(); 83 if (remainder > 0) { 84 numberOfSegments = fullSegments + 1; 85 if (remainder < ciphertextOverhead) { 86 throw new IOException("Invalid ciphertext size"); 87 } 88 lastCiphertextSegmentSize = remainder; 89 } else { 90 numberOfSegments = fullSegments; 91 lastCiphertextSegmentSize = ciphertextSegmentSize; 92 } 93 ciphertextOffset = streamAead.getCiphertextOffset(); 94 firstSegmentOffset = ciphertextOffset - streamAead.getHeaderLength(); 95 if (firstSegmentOffset < 0) { 96 throw new IOException("Invalid ciphertext offset or header length"); 97 } 98 long overhead = (long) numberOfSegments * ciphertextOverhead + ciphertextOffset; 99 if (overhead > ciphertextChannelSize) { 100 throw new IOException("Ciphertext is too short"); 101 } 102 plaintextSize = ciphertextChannelSize - overhead; 103 } 104 105 /** 106 * A description of the state of this StreamingAeadSeekableDecryptingChannel. 107 * While this description does not contain plaintext or key material 108 * it contains length information that might leak some information. 109 */ 110 @Override toString()111 public synchronized String toString() { 112 StringBuilder res = 113 new StringBuilder(); 114 String ctChannel; 115 try { 116 ctChannel = "position:" + ciphertextChannel.position(); 117 } catch (IOException ex) { 118 ctChannel = "position: n/a"; 119 } 120 res.append("StreamingAeadSeekableDecryptingChannel") 121 .append("\nciphertextChannel").append(ctChannel) 122 .append("\nciphertextChannelSize:").append(ciphertextChannelSize) 123 .append("\nplaintextSize:").append(plaintextSize) 124 .append("\nciphertextSegmentSize:").append(ciphertextSegmentSize) 125 .append("\nnumberOfSegments:").append(numberOfSegments) 126 .append("\nheaderRead:").append(headerRead) 127 .append("\nplaintextPosition:").append(plaintextPosition) 128 .append("\nHeader") 129 .append(" position:").append(header.position()) 130 .append(" limit:").append(header.position()) 131 .append("\ncurrentSegmentNr:").append(currentSegmentNr) 132 .append("\nciphertextSgement") 133 .append(" position:").append(ciphertextSegment.position()) 134 .append(" limit:").append(ciphertextSegment.limit()) 135 .append("\nisCurrentSegmentDecrypted:").append(isCurrentSegmentDecrypted) 136 .append("\nplaintextSegment") 137 .append(" position:").append(plaintextSegment.position()) 138 .append(" limit:").append(plaintextSegment.limit()); 139 return res.toString(); 140 } 141 142 /** 143 * Returns the position of this channel. 144 * The position is relative to the plaintext. 145 */ 146 @Override position()147 public synchronized long position() { 148 return plaintextPosition; 149 } 150 151 /** 152 * Sets the position in the plaintext. Setting the position to a value greater than the plaintext 153 * size is legal. A later attempt to read byte will throw an IOException. 154 */ 155 @CanIgnoreReturnValue 156 @Override position(long newPosition)157 public synchronized SeekableByteChannel position(long newPosition) { 158 plaintextPosition = newPosition; 159 return this; 160 } 161 162 /** 163 * Tries to read the header of the ciphertext and derive the key used for the ciphertext from the 164 * information in the header. 165 * 166 * @return true if the header was fully read and has a correct format. Returns false if the header 167 * could not be read. 168 * @throws IOException if the header was incorrectly formatted or if there was an exception during 169 * the key derivation. 170 */ tryReadHeader()171 private boolean tryReadHeader() throws IOException { 172 ciphertextChannel.position(header.position() + firstSegmentOffset); 173 ciphertextChannel.read(header); 174 if (header.remaining() > 0) { 175 return false; 176 } else { 177 header.flip(); 178 try { 179 decrypter.init(header, aad); 180 headerRead = true; 181 } catch (GeneralSecurityException ex) { 182 // TODO(bleichen): Define the state of this. 183 throw new IOException(ex); 184 } 185 return true; 186 } 187 } 188 getSegmentNr(long plaintextPosition)189 private int getSegmentNr(long plaintextPosition) { 190 return (int) ((plaintextPosition + ciphertextOffset) / plaintextSegmentSize); 191 } 192 193 /** 194 * Tries to read and decrypt a ciphertext segment. 195 * @param segmentNr the number of the segment 196 * @return true if the segment was read and correctly decrypted. 197 * Returns false if the segment could not be fully read. 198 * @throws IOException if there was an exception reading the ciphertext, 199 * if the segment number was incorrect, or 200 * if there was an exception trying to decrypt the ciphertext segment. 201 */ tryLoadSegment(int segmentNr)202 private boolean tryLoadSegment(int segmentNr) throws IOException { 203 if (segmentNr < 0 || segmentNr >= numberOfSegments) { 204 throw new IOException("Invalid position"); 205 } 206 boolean isLast = segmentNr == numberOfSegments - 1; 207 if (segmentNr == currentSegmentNr) { 208 if (isCurrentSegmentDecrypted) { 209 return true; 210 } 211 } else { 212 // segmentNr != currentSegmentNr 213 long ciphertextPosition = (long) segmentNr * ciphertextSegmentSize; 214 int segmentSize = ciphertextSegmentSize; 215 if (isLast) { 216 segmentSize = lastCiphertextSegmentSize; 217 } 218 if (segmentNr == 0) { 219 segmentSize -= ciphertextOffset; 220 ciphertextPosition = ciphertextOffset; 221 } 222 ciphertextChannel.position(ciphertextPosition); 223 ciphertextSegment.clear(); 224 ciphertextSegment.limit(segmentSize); 225 currentSegmentNr = segmentNr; 226 isCurrentSegmentDecrypted = false; 227 } 228 if (ciphertextSegment.remaining() > 0) { 229 ciphertextChannel.read(ciphertextSegment); 230 } 231 if (ciphertextSegment.remaining() > 0) { 232 return false; 233 } 234 ciphertextSegment.flip(); 235 plaintextSegment.clear(); 236 try { 237 decrypter.decryptSegment(ciphertextSegment, segmentNr, isLast, plaintextSegment); 238 } catch (GeneralSecurityException ex) { 239 // The current segment did not validate. Ensure that this instance remains 240 // in a valid state. 241 currentSegmentNr = -1; 242 throw new IOException("Failed to decrypt", ex); 243 } 244 plaintextSegment.flip(); 245 isCurrentSegmentDecrypted = true; 246 return true; 247 } 248 249 /** 250 * Returns true if plaintextPositon is at the end of the file 251 * and this has been verified, by decrypting the last segment. 252 */ reachedEnd()253 private boolean reachedEnd() { 254 return (plaintextPosition == plaintextSize 255 && isCurrentSegmentDecrypted 256 && currentSegmentNr == numberOfSegments - 1 257 && plaintextSegment.remaining() == 0); 258 } 259 260 /** 261 * Atomic read from a given position. 262 * 263 * This method works in the same way as read(ByteBuffer), except that it starts at the given 264 * position and does not modify the channel's position. 265 */ read(ByteBuffer dst, long start)266 public synchronized int read(ByteBuffer dst, long start) throws IOException { 267 long oldPosition = position(); 268 try { 269 position(start); 270 return read(dst); 271 } finally { 272 position(oldPosition); 273 } 274 } 275 276 @Override read(ByteBuffer dst)277 public synchronized int read(ByteBuffer dst) throws IOException { 278 if (!isopen) { 279 throw new ClosedChannelException(); 280 } 281 if (!headerRead) { 282 if (!tryReadHeader()) { 283 return 0; 284 } 285 } 286 int startPos = dst.position(); 287 while (dst.remaining() > 0 && plaintextPosition < plaintextSize) { 288 // Determine segmentNr for the plaintext to read and the offset in 289 // the plaintext, where reading should start. 290 int segmentNr = getSegmentNr(plaintextPosition); 291 int segmentOffset; 292 if (segmentNr == 0) { 293 segmentOffset = (int) plaintextPosition; 294 } else { 295 segmentOffset = (int) ((plaintextPosition + ciphertextOffset) % plaintextSegmentSize); 296 } 297 298 if (tryLoadSegment(segmentNr)) { 299 plaintextSegment.position(segmentOffset); 300 if (plaintextSegment.remaining() <= dst.remaining()) { 301 plaintextPosition += plaintextSegment.remaining(); 302 dst.put(plaintextSegment); 303 } else { 304 int sliceSize = dst.remaining(); 305 ByteBuffer slice = plaintextSegment.duplicate(); 306 slice.limit(slice.position() + sliceSize); 307 dst.put(slice); 308 plaintextPosition += sliceSize; 309 plaintextSegment.position(plaintextSegment.position() + sliceSize); 310 } 311 } else { 312 break; 313 } 314 } 315 int read = dst.position() - startPos; 316 if (read == 0 && reachedEnd()) { 317 return -1; 318 } 319 return read; 320 } 321 322 /** 323 * Returns the expected size of the plaintext. 324 * Note that this implementation does not perform an integrity check on the size. 325 * I.e. if the file has been truncated then size() will return the wrong 326 * result. Reading the last block of the ciphertext will verify whether size() 327 * is correct. 328 */ 329 @Override size()330 public long size() { 331 return plaintextSize; 332 } 333 verifiedSize()334 public synchronized long verifiedSize() throws IOException { 335 if (tryLoadSegment(numberOfSegments - 1)) { 336 return plaintextSize; 337 } else { 338 throw new IOException("could not verify the size"); 339 } 340 } 341 342 @Override truncate(long size)343 public SeekableByteChannel truncate(long size) throws NonWritableChannelException { 344 throw new NonWritableChannelException(); 345 } 346 347 @Override write(ByteBuffer src)348 public int write(ByteBuffer src) throws NonWritableChannelException { 349 throw new NonWritableChannelException(); 350 } 351 352 @Override close()353 public synchronized void close() throws IOException { 354 ciphertextChannel.close(); 355 isopen = false; 356 } 357 358 @Override isOpen()359 public synchronized boolean isOpen() { 360 return isopen; 361 } 362 } 363