// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specified language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////

package com.google.crypto.tink.testing;

import static com.google.common.truth.Truth.assertThat;
import static java.lang.Math.min;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import com.google.crypto.tink.StreamingAead;
import com.google.crypto.tink.subtle.Hex;
import com.google.crypto.tink.subtle.Random;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.Reader;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NonWritableChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SeekableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.util.Arrays;

/** Helpers for streaming tests. */
public final class StreamingTestUtil {
  /**
   * A {@link SeekableByteChannel} for tests, backed by a {@link ByteBuffer}.
   *
   * <p>The channel's size is the limit of the backing buffer. Writes overwrite the buffer in place
   * and move at most {@code maxChunkSize} bytes per call. Truncation is not supported.
   */
  public static class SeekableByteBufferChannel extends ByteBufferChannel
      implements SeekableByteChannel {
    public SeekableByteBufferChannel(ByteBuffer buffer) {
      super(buffer);
    }

    public SeekableByteBufferChannel(ByteBuffer buffer, int maxChunkSize) {
      super(buffer, maxChunkSize);
    }

    public SeekableByteBufferChannel(byte[] bytes) {
      super(bytes);
    }

    public SeekableByteBufferChannel(byte[] bytes, int maxChunkSize) {
      super(bytes, maxChunkSize);
    }

    /** Returns the current position in the backing buffer. */
    @Override
    public long position() throws ClosedChannelException {
      checkIsOpen();
      return buffer.position();
    }

    /**
     * Moves to {@code newPosition}. A position beyond the end is clamped to the buffer's limit.
     *
     * @throws IllegalArgumentException if {@code newPosition} is negative
     */
    @CanIgnoreReturnValue
    @Override
    public synchronized SeekableByteBufferChannel position(long newPosition)
        throws ClosedChannelException {
      checkIsOpen();
      if (newPosition < 0) {
        throw new IllegalArgumentException("negative position");
      }
      long clamped = Math.min(newPosition, (long) buffer.limit());
      buffer.position((int) clamped);
      return this;
    }

    /** Copies up to {@code maxChunkSize} bytes from {@code src} into the backing buffer. */
    @Override
    public int write(ByteBuffer src) throws IOException {
      checkIsOpen();
      // Copies through a temporary array: simple rather than efficient.
      int count = Math.min(maxChunkSize, Math.min(buffer.remaining(), src.remaining()));
      byte[] staging = new byte[count];
      src.get(staging);
      buffer.put(staging);
      return count;
    }

    /** Returns the limit of the backing buffer. */
    @Override
    public long size() throws ClosedChannelException {
      checkIsOpen();
      return buffer.limit();
    }

    /** Always throws; this test channel does not support truncation. */
    @Override
    public SeekableByteChannel truncate(long size) {
      throw new NonWritableChannelException();
    }
  }

  /**
   * A {@link ReadableByteChannel} for tests, backed by a {@link ByteBuffer}.
   *
   * <p>Each read transfers at most {@code maxChunkSize} bytes. Optionally, every other read (starting
   * with the first) returns 0 bytes. This simulates a source where data is not always available.
   */
  public static class ByteBufferChannel implements ReadableByteChannel {
    /** Duplicate of the caller's buffer: shared content, but an independent position and limit. */
    final ByteBuffer buffer;
    private final boolean noDataEveryOtherRead;
    private boolean returnDataOnNextRead;

    /**
     * Upper bound on the number of bytes moved by a single read (and, in subclasses that write, by
     * a single write). This makes it possible to test streaming encryption against channels that
     * do not always transfer as many bytes as requested.
     */
    final int maxChunkSize;

    /** Whether the channel is currently open. */
    private boolean isopen;

    public ByteBufferChannel(ByteBuffer buffer) {
      this(buffer, Integer.MAX_VALUE);
    }

    public ByteBufferChannel(ByteBuffer buffer, int maxChunkSize) {
      this(buffer, maxChunkSize, /* noDataEveryOtherRead= */ false);
    }

    public ByteBufferChannel(ByteBuffer buffer, int maxChunkSize, boolean noDataEveryOtherRead) {
      this.buffer = buffer.duplicate();
      this.maxChunkSize = maxChunkSize;
      this.noDataEveryOtherRead = noDataEveryOtherRead;
      // If reads alternate, the very first read must already come back empty.
      this.returnDataOnNextRead = !noDataEveryOtherRead;
      this.isopen = true;
    }

    public ByteBufferChannel(byte[] bytes) {
      this(ByteBuffer.wrap(bytes));
    }

    public ByteBufferChannel(byte[] bytes, int maxChunkSize) {
      this(ByteBuffer.wrap(bytes), maxChunkSize);
    }

    public ByteBufferChannel(byte[] bytes, int maxChunkSize, boolean noDataEveryOtherRead) {
      this(ByteBuffer.wrap(bytes), maxChunkSize, noDataEveryOtherRead);
    }

    /** Throws {@link ClosedChannelException} if {@link #close} has been called. */
    void checkIsOpen() throws ClosedChannelException {
      if (isopen) {
        return;
      }
      throw new ClosedChannelException();
    }

    /**
     * Copies up to {@code maxChunkSize} bytes into {@code dst}.
     *
     * @return the number of bytes copied, 0 on a deliberately empty read, or -1 once the backing
     *     buffer is exhausted
     */
    @Override
    public synchronized int read(ByteBuffer dst) throws IOException {
      checkIsOpen();
      if (noDataEveryOtherRead) {
        boolean skipThisRead = !returnDataOnNextRead;
        returnDataOnNextRead = skipThisRead;
        if (skipThisRead) {
          return 0;
        }
      }
      if (!buffer.hasRemaining()) {
        return -1;
      }
      // Copies through a temporary array: simple rather than efficient.
      int count = Math.min(maxChunkSize, Math.min(buffer.remaining(), dst.remaining()));
      byte[] staging = new byte[count];
      buffer.get(staging);
      dst.put(staging);
      return count;
    }

    @Override
    public void close() throws IOException {
      isopen = false;
    }

    @Override
    public boolean isOpen() {
      return isopen;
    }

    /** Reopens the channel and moves back to the start of the backing buffer. */
    public void rewind() {
      isopen = true;
      buffer.rewind();
    }
  }

  /**
   * A {@link ReadableByteChannel} for tests that yields {@code size} bytes of deterministic data.
   *
   * <p>The data is one block of {@code BLOCK_SIZE} bytes (byte {@code i} of the block equals {@code
   * i % 253}), repeated as often as needed and cut off after {@code size} bytes.
   */
  public static class PseudorandomReadableByteChannel implements ReadableByteChannel {
    private final long size;
    private long position;
    private boolean open;
    private final byte[] repeatedBlock;
    public static final int BLOCK_SIZE = 1024;

    /** Returns {@code size} bytes where byte {@code i} equals {@code i % 253}. */
    private byte[] generatePlaintext(int size) {
      byte[] block = new byte[size];
      int i = 0;
      while (i < size) {
        block[i] = (byte) (i % 253);
        i++;
      }
      return block;
    }

    public PseudorandomReadableByteChannel(long size) {
      this.open = true;
      this.position = 0;
      this.size = size;
      this.repeatedBlock = generatePlaintext(BLOCK_SIZE);
    }

    /**
     * Fills {@code dst} with the next bytes of the stream, up to the end of the stream.
     *
     * @return the number of bytes written, or -1 once {@code size} bytes have been returned
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
      if (!open) {
        throw new ClosedChannelException();
      }
      if (position == size) {
        return -1;
      }
      final long from = position;
      final long to = Math.min(size, from + dst.remaining());
      final int fromOffset = (int) (from % BLOCK_SIZE);
      final int toOffset = (int) (to % BLOCK_SIZE);
      final long blocksCrossed = to / BLOCK_SIZE - from / BLOCK_SIZE;
      if (blocksCrossed == 0) {
        // The whole range lies inside a single copy of the block.
        dst.put(repeatedBlock, fromOffset, toOffset - fromOffset);
      } else {
        // Tail of the first block, any whole blocks in between, then the head of the last block.
        dst.put(repeatedBlock, fromOffset, BLOCK_SIZE - fromOffset);
        for (long i = 1; i < blocksCrossed; i++) {
          dst.put(repeatedBlock);
        }
        dst.put(repeatedBlock, 0, toOffset);
      }
      position = to;
      return (int) (to - from);
    }

    @Override
    public void close() {
      open = false;
    }

    @Override
    public boolean isOpen() {
      return open;
    }
  }

  /**
   * A {@link ByteArrayInputStream} for tests that returns at most {@code maxChunkSize} bytes per
   * bulk read and reports at most that many bytes as available.
   */
  public static class SmallChunksByteArrayInputStream extends ByteArrayInputStream {
    final int maxChunkSize;

    SmallChunksByteArrayInputStream(byte[] data, int maxChunkSize) {
      super(data);
      this.maxChunkSize = maxChunkSize;
    }

    @Override
    public synchronized int available() {
      int all = super.available();
      return all < maxChunkSize ? all : maxChunkSize;
    }

    @Override
    public synchronized int read(byte[] b) {
      // Delegates to the chunk-limited three-argument overload below.
      return this.read(b, 0, b.length);
    }

    @Override
    public synchronized int read(byte[] b, int off, int len) {
      int limited = min(len, maxChunkSize);
      return super.read(b, off, limited);
    }
  }

  /**
   * Returns a deterministic plaintext of {@code size} bytes in which byte {@code i} equals {@code
   * i % 253}.
   */
  public static byte[] generatePlaintext(int size) {
    byte[] result = new byte[size];
    for (int idx = 0; idx < result.length; ++idx) {
      result[idx] = (byte) (idx % 253);
    }
    return result;
  }

  /** Returns a new array holding the bytes of {@code first} followed by those of {@code last}. */
  public static byte[] concatBytes(byte[] first, byte[] last) {
    byte[] joined = Arrays.copyOf(first, first.length + last.length);
    System.arraycopy(last, 0, joined, first.length, last.length);
    return joined;
  }

  /**
   * Tests encryption and decryption functionalities using {@code encryptionStreamingAead} for
   * encryption and {@code decryptionStreamingAead} for decryption.
   *
   * <p>Round-trips random plaintexts of 10, 1100 and 16000 bytes, followed by an empty plaintext.
   * All runs share the same random 15-byte associated data.
   */
  public static void testEncryptionAndDecryption(
      StreamingAead encryptionStreamingAead, StreamingAead decryptionStreamingAead)
      throws Exception {
    byte[] associatedData = Random.randBytes(15);
    // Short, long, and longer than the 8 kB buffer typically used by classes such as
    // BufferedInputStream.
    int[] randomPlaintextSizes = {10, 1100, 16000};
    for (int plaintextSize : randomPlaintextSizes) {
      byte[] plaintext = Random.randBytes(plaintextSize);
      testEncryptionAndDecryption(
          encryptionStreamingAead, decryptionStreamingAead, plaintext, associatedData);
    }
    // Finally, the empty plaintext.
    testEncryptionAndDecryption(
        encryptionStreamingAead, decryptionStreamingAead, new byte[0], associatedData);
  }

  /**
   * Tests encryption and decryption functionalities of {@code streamingAead}, using the same
   * instance for both directions.
   */
  public static void testEncryptionAndDecryption(StreamingAead streamingAead) throws Exception {
    final StreamingAead aead = streamingAead;
    testEncryptionAndDecryption(
        /* encryptionStreamingAead= */ aead, /* decryptionStreamingAead= */ aead);
  }

  /**
   * Tests encryption and decryption functionalities using {@code encryptionStreamingAead} for
   * encryption and {@code decryptionStreamingAead} for decryption on inputs {@code plaintext} and
   * {@code associatedData}.
   *
   * <p>The plaintext is encrypted once through {@link StreamingAead#newEncryptingChannel} and once
   * through {@link StreamingAead#newEncryptingStream}. The first ciphertext is decrypted in several
   * ways:
   *
   * <ul>
   *   <li>a {@link ReadableByteChannel};
   *   <li>a {@link SeekableByteChannel}, including after seeking to plaintext position 5;
   *   <li>an {@link InputStream}, using each of its read methods.
   * </ul>
   *
   * <p>The second ciphertext is decrypted through a {@link ReadableByteChannel}.
   */
  public static void testEncryptionAndDecryption(
      StreamingAead encryptionStreamingAead,
      StreamingAead decryptionStreamingAead,
      byte[] plaintext,
      byte[] associatedData)
      throws Exception {

    // Encrypt plaintext.
    ByteArrayOutputStream ciphertext = new ByteArrayOutputStream();
    try (WritableByteChannel encChannel =
        encryptionStreamingAead.newEncryptingChannel(
            Channels.newChannel(ciphertext), associatedData)) {
      encChannel.write(ByteBuffer.wrap(plaintext));
    }

    // Decrypt ciphertext via ReadableByteChannel with a single read.
    {
      ByteBufferChannel ciphertextChannel = new ByteBufferChannel(ciphertext.toByteArray());
      ReadableByteChannel decChannel =
          decryptionStreamingAead.newDecryptingChannel(ciphertextChannel, associatedData);
      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length);
      int unused = decChannel.read(decrypted);

      // Compare results.
      TestUtil.assertByteArrayEquals(plaintext, decrypted.array());
    }

    // Decrypt ciphertext via ReadableByteChannel. The ciphertext channel uses a very small chunk
    // size and returns no data on every other read.
    {
      ByteBufferChannel ciphertextChannel =
          new ByteBufferChannel(
              ciphertext.toByteArray(), /* maxChunkSize= */ 10, /* noDataEveryOtherRead= */ true);
      ReadableByteChannel decChannel =
          decryptionStreamingAead.newDecryptingChannel(ciphertextChannel, associatedData);
      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length);
      readUntilBufferIsFull(decChannel, decrypted);
      // Compare results.
      TestUtil.assertByteArrayEquals(plaintext, decrypted.array());
    }

    // Decrypt ciphertext via SeekableByteChannel with a single read.
    {
      SeekableByteChannel ciphertextChannel =
          new SeekableByteBufferChannel(ciphertext.toByteArray());
      SeekableByteChannel decChannel =
          decryptionStreamingAead.newSeekableDecryptingChannel(ciphertextChannel, associatedData);
      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length);
      int unused = decChannel.read(decrypted);

      // Compare results.
      TestUtil.assertByteArrayEquals(plaintext, decrypted.array());
    }

    // Decrypt ciphertext via SeekableByteChannel, using a very small chunk size.
    {
      SeekableByteChannel ciphertextChannel =
          new SeekableByteBufferChannel(ciphertext.toByteArray(), 10);
      SeekableByteChannel decChannel =
          decryptionStreamingAead.newSeekableDecryptingChannel(ciphertextChannel, associatedData);
      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length);
      readUntilBufferIsFull(decChannel, decrypted);
      // Compare results.
      TestUtil.assertByteArrayEquals(plaintext, decrypted.array());
    }

    // Decrypt ciphertext via SeekableByteChannel after seeking to plaintext position 5.
    if (plaintext.length > 5) {
      SeekableByteChannel ciphertextChannel =
          new SeekableByteBufferChannel(ciphertext.toByteArray(), 10);
      SeekableByteChannel decChannel =
          decryptionStreamingAead.newSeekableDecryptingChannel(ciphertextChannel, associatedData);
      decChannel.position(5);
      assertEquals(5, decChannel.position());

      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length - 5);
      readUntilBufferIsFull(decChannel, decrypted);
      // Compare results.
      TestUtil.assertByteArrayEquals(
          Arrays.copyOfRange(plaintext, 5, plaintext.length),
          decrypted.array());
    }

    // Decrypt ciphertext via InputStream using read(byte[])
    {
      InputStream ctStream = new ByteArrayInputStream(ciphertext.toByteArray());
      InputStream decStream = decryptionStreamingAead.newDecryptingStream(ctStream, associatedData);
      byte[] decrypted = new byte[plaintext.length];
      int decryptedLength = decStream.read(decrypted);

      assertEquals(
          "Decrypted length should be equal to plaintext length",
          plaintext.length,
          decryptedLength);
      TestUtil.assertByteArrayEquals(plaintext, decrypted);

      byte[] buf = new byte[1];
      int n = decStream.read(buf);
      assertThat(n).isEqualTo(-1);
    }

    // Decrypt ciphertext via InputStream using read()
    {
      InputStream ctStream = new ByteArrayInputStream(ciphertext.toByteArray());
      InputStream decStream = decryptionStreamingAead.newDecryptingStream(ctStream, associatedData);
      byte[] decrypted = new byte[plaintext.length];
      for (int i = 0; i < plaintext.length; i++) {
        int b = decStream.read();
        assertThat(b).isAtLeast(0);
        assertThat(b).isAtMost(255);
        decrypted[i] = (byte) b;
      }
      assertThat(decrypted).isEqualTo(plaintext);

      int b = decStream.read();
      assertThat(b).isEqualTo(-1);
    }

    // Decrypt ciphertext via InputStream using read(byte[], int, int)
    {
      InputStream ctStream = new ByteArrayInputStream(ciphertext.toByteArray());
      InputStream decStream = decryptionStreamingAead.newDecryptingStream(ctStream, associatedData);
      byte[] decrypted = new byte[plaintext.length];
      for (int i = 0; i < plaintext.length; i++) {
        int n = decStream.read(decrypted, i, 1);
        assertThat(n).isEqualTo(1);
      }
      assertThat(decrypted).isEqualTo(plaintext);

      byte[] buf = new byte[1];
      int n = decStream.read(buf, 0, 1);
      assertThat(n).isEqualTo(-1);
    }

    // Decrypt ciphertext via SmallChunksByteArrayInputStream.
    {
      InputStream ctStream = new SmallChunksByteArrayInputStream(ciphertext.toByteArray(), 10);
      InputStream decStream = decryptionStreamingAead.newDecryptingStream(ctStream, associatedData);
      byte[] decrypted = new byte[plaintext.length];
      int decryptedLength = decStream.read(decrypted);

      // Compare results.
      assertEquals(
          "Decrypted length should be equal to plaintext length",
          plaintext.length,
          decryptedLength);
      TestUtil.assertByteArrayEquals(plaintext, decrypted);
    }

    // Encrypt with an OutputStream.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (OutputStream encStream =
        encryptionStreamingAead.newEncryptingStream(bos, associatedData)) {
      encStream.write(plaintext);
    }
    byte[] ciphertext2 = bos.toByteArray();

    // Check that the stream encrypted ciphertext is correct.
    {
      ByteBufferChannel ciphertextChannel = new ByteBufferChannel(ciphertext2);
      ReadableByteChannel decChannel =
          decryptionStreamingAead.newDecryptingChannel(ciphertextChannel, associatedData);
      ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length);
      int unused = decChannel.read(decrypted);

      // Compare results.
      TestUtil.assertByteArrayEquals(plaintext, decrypted.array());
    }
  }

  /**
   * Reads from {@code channel} into {@code buffer} until {@code buffer} has no space left.
   *
   * <p>At least one read is always issued. Reads that return 0 bytes are retried. If the channel
   * reports end of stream while the buffer still has space, the test fails immediately instead of
   * looping forever.
   */
  private static void readUntilBufferIsFull(ReadableByteChannel channel, ByteBuffer buffer)
      throws IOException {
    do {
      int read = channel.read(buffer);
      if (read == -1 && buffer.hasRemaining()) {
        fail("Reached end of stream with " + buffer.remaining() + " bytes still expected");
      }
    } while (buffer.hasRemaining());
  }

  // Methods for testEncryptDecrypt.

  /**
   * Encrypts {@code plaintext} through {@link StreamingAead#newEncryptingChannel} and returns the
   * resulting bytes.
   *
   * @param ags the streaming primitive
   * @param plaintext the plaintext to encrypt
   * @param associatedData the additional data to authenticate
   * @param firstSegmentOffset number of zero bytes written before the ciphertext
   * @return {@code firstSegmentOffset} zero bytes followed by the ciphertext
   */
  public static byte[] encryptWithChannel(
      StreamingAead ags, byte[] plaintext, byte[] associatedData, int firstSegmentOffset)
      throws Exception {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    WritableByteChannel sink = Channels.newChannel(out);
    ByteBuffer zeroPrefix = ByteBuffer.allocate(firstSegmentOffset);
    sink.write(zeroPrefix);
    try (WritableByteChannel encrypting = ags.newEncryptingChannel(sink, associatedData)) {
      encrypting.write(ByteBuffer.wrap(plaintext));
    }
    return out.toByteArray();
  }

  // Methods for testEncryptDecryptLong.

  /**
   * Reads everything from {@code plaintext}, encrypts it, and writes the result to {@code
   * ciphertext}. This method is used to test asynchronous encryption.
   *
   * @param ags the streaming encryption
   * @param plaintext the channel containing the plaintext
   * @param ciphertext the channel to which the ciphertext is written
   * @param associatedData the additional data to authenticate
   * @param chunkSize the size of blocks that are read and written. This size determines the
   *     temporary memory used in this method but is independent of the streaming encryption.
   * @throws RuntimeException if something goes wrong.
   */
  private static void encryptWithChannel(
      StreamingAead ags,
      ReadableByteChannel plaintext,
      WritableByteChannel ciphertext,
      byte[] associatedData,
      int chunkSize) {
    try (WritableByteChannel encChannel = ags.newEncryptingChannel(ciphertext, associatedData)) {
      ByteBuffer chunk = ByteBuffer.allocate(chunkSize);
      int read;
      do {
        chunk.clear();
        read = plaintext.read(chunk);
        if (read > 0) {
          chunk.flip();
          // WritableByteChannel.write may consume only part of the buffer. Keep writing until
          // the chunk is drained; otherwise the chunk.clear() above would silently drop plaintext.
          while (chunk.hasRemaining()) {
            encChannel.write(chunk);
          }
        }
      } while (read != -1);
    } catch (Exception ex) {
      // TODO(bleichen): What is the best way to catch exceptions in threads?
      throw new RuntimeException(ex);
    }
  }

  /**
   * Encrypts {@code plaintext} through {@link StreamingAead#newEncryptingStream} and returns the
   * resulting bytes.
   *
   * @return {@code firstSegmentOffset} zero bytes followed by the ciphertext
   */
  private static byte[] encryptWithStream(
      StreamingAead ags, byte[] plaintext, byte[] associatedData, int firstSegmentOffset)
      throws Exception {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    byte[] zeroPrefix = new byte[firstSegmentOffset];
    out.write(zeroPrefix);
    try (OutputStream encStream = ags.newEncryptingStream(out, associatedData)) {
      encStream.write(plaintext);
    }
    return out.toByteArray();
  }

  /**
   * Encrypts a plaintext through a channel, then decrypts it through a channel in chunks of {@code
   * chunkSize} bytes, checking every chunk against the plaintext.
   *
   * <p>Every chunk must be completely filled unless the end of the plaintext has been reached.
   */
  private static void testEncryptDecryptWithChannel(
      StreamingAead encryptionStreamingAead,
      StreamingAead decryptionStreamingAead,
      int firstSegmentOffset,
      int plaintextSize,
      int chunkSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    byte[] ciphertext =
        encryptWithChannel(encryptionStreamingAead, plaintext, associatedData, firstSegmentOffset);

    // Ciphertext channel positioned just past the firstSegmentOffset prefix bytes.
    ReadableByteChannel ctChannel =
        new SeekableByteBufferChannel(ciphertext).position(firstSegmentOffset);
    // Channel that yields the decrypted plaintext.
    ReadableByteChannel ptChannel =
        decryptionStreamingAead.newDecryptingChannel(ctChannel, associatedData);

    ByteBuffer chunk = ByteBuffer.allocate(chunkSize);
    int decryptedSize = 0;
    for (int read = ptChannel.read(chunk); read != -1; read = ptChannel.read(chunk)) {
      assertEquals(read, chunk.position());
      byte[] expectedPlaintext = Arrays.copyOfRange(plaintext, decryptedSize, decryptedSize + read);
      TestUtil.assertByteArrayEquals(expectedPlaintext, Arrays.copyOf(chunk.array(), read));
      decryptedSize += read;
      boolean moreToCome = decryptedSize < plaintextSize;
      if (moreToCome) {
        assertEquals(
            "Decrypted chunk is shorter than expected\n" + ptChannel.toString(),
            chunk.limit(),
            chunk.position());
      }
      chunk.clear();
    }
    assertEquals(plaintext.length, decryptedSize);
  }

  /**
   * Encrypts a plaintext through an {@link OutputStream}, then decrypts it through an {@link
   * InputStream} in chunks of {@code chunkSize} bytes, checking every chunk against the plaintext.
   *
   * @param encryptionStreamingAead the StreamingAead test object that encrypts.
   * @param decryptionStreamingAead the StreamingAead test object that decrypts (can be the same
   *     object as {@code encryptionStreamingAead}).
   * @param firstSegmentOffset number of bytes prepended to the ciphertext stream.
   * @param plaintextSize the size of the plaintext
   * @param chunkSize decryption read chunks of this size.
   */
  private static void testEncryptDecryptWithStream(
      StreamingAead encryptionStreamingAead,
      StreamingAead decryptionStreamingAead,
      int firstSegmentOffset,
      int plaintextSize,
      int chunkSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    byte[] ciphertext =
        encryptWithStream(encryptionStreamingAead, plaintext, associatedData, firstSegmentOffset);

    // Ciphertext stream with the firstSegmentOffset prefix bytes already consumed.
    InputStream ctStream = new ByteArrayInputStream(ciphertext);
    ctStream.read(new byte[firstSegmentOffset]);

    // Stream that yields the decrypted plaintext.
    InputStream ptStream = decryptionStreamingAead.newDecryptingStream(ctStream, associatedData);
    byte[] chunk = new byte[chunkSize];
    int decryptedSize = 0;
    for (int read = ptStream.read(chunk); read != -1; read = ptStream.read(chunk)) {
      byte[] expected = Arrays.copyOfRange(plaintext, decryptedSize, decryptedSize + read);
      TestUtil.assertByteArrayEquals(expected, Arrays.copyOf(chunk, read));
      decryptedSize += read;
      // A read must block until the chunk is full, the stream ends, or an error occurs.
      boolean shortRead = read < chunkSize;
      if (shortRead && decryptedSize < plaintextSize) {
        fail("read did not return enough bytes");
      }
    }
    assertEquals("Size of decryption does not match plaintext", plaintextSize, decryptedSize);
  }

  /**
   * Round-trips a plaintext through {@code ags}, using first the channel interface and then the
   * stream interface.
   *
   * @param ags the StreamingAead used both to encrypt and to decrypt.
   * @param firstSegmentOffset number of bytes prepended to the ciphertext.
   * @param plaintextSize the size of the plaintext.
   * @param chunkSize decryption reads chunks of this size.
   */
  public static void testEncryptDecrypt(
      StreamingAead ags, int firstSegmentOffset, int plaintextSize, int chunkSize)
      throws Exception {
    // The same instance plays both roles.
    testEncryptDecryptDifferentInstances(
        /* ags= */ ags, /* other= */ ags, firstSegmentOffset, plaintextSize, chunkSize);
  }

  // Methods for testEncryptDecryptDifferentInstances

  /**
   * Round-trips a plaintext by encrypting with {@code ags} and decrypting with {@code other}. The
   * channel interface is used first, then the stream interface.
   *
   * @param ags the StreamingAead that encrypts.
   * @param other the StreamingAead that decrypts.
   * @param firstSegmentOffset number of bytes prepended to the ciphertext.
   * @param plaintextSize the size of the plaintext.
   * @param chunkSize decryption reads chunks of this size.
   */
  public static void testEncryptDecryptDifferentInstances(
      StreamingAead ags,
      StreamingAead other,
      int firstSegmentOffset,
      int plaintextSize,
      int chunkSize)
      throws Exception {
    StreamingAead enc = ags;
    StreamingAead dec = other;
    testEncryptDecryptWithChannel(enc, dec, firstSegmentOffset, plaintextSize, chunkSize);
    testEncryptDecryptWithStream(enc, dec, firstSegmentOffset, plaintextSize, chunkSize);
  }

  // Methods for testEncryptDecryptRandomAccess.

  /**
   * Encrypts a plaintext, then decrypts many ranges of it through a seekable channel and checks
   * each range against the plaintext.
   *
   * <p>Both the start positions and the lengths grow by {@code 1 + x / 2} per step, roughly a
   * factor of 1.5. This covers small and large values without testing every combination.
   */
  public static void testEncryptDecryptRandomAccess(
      StreamingAead ags, int firstSegmentOffset, int plaintextSize) throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    byte[] ciphertext = encryptWithChannel(ags, plaintext, associatedData, firstSegmentOffset);

    // Seekable view of the whole ciphertext, including its firstSegmentOffset prefix.
    SeekableByteChannel ctChannel = new SeekableByteBufferChannel(ciphertext);
    SeekableByteChannel ptChannel = ags.newSeekableDecryptingChannel(ctChannel, associatedData);

    int start = 0;
    while (start < plaintextSize) {
      int length = 1;
      while (length < plaintextSize) {
        ByteBuffer pt = ByteBuffer.allocate(length);
        ptChannel.position(start);
        int read = ptChannel.read(pt);
        int got = pt.position();
        // The buffer must be full unless the read stopped at the end of the plaintext.
        boolean fullOrAtEnd = !pt.hasRemaining() || start + got == plaintext.length;
        assertTrue("start:" + start + " read:" + read + " length:" + length, fullOrAtEnd);
        String expected = Hex.encode(Arrays.copyOfRange(plaintext, start, start + got));
        String actual = Hex.encode(Arrays.copyOf(pt.array(), got));
        assertEquals("start: " + start, expected, actual);
        length += 1 + length / 2;
      }
      start += 1 + start / 2;
    }
  }

  /**
   * Encrypts and decrypts some plaintext in a stream using skips and checks that the expected
   * plaintext is returned for the parts not skipped.
   *
   * <p>Also checks that skipping through a ciphertext truncated by one byte leads to an {@link
   * IOException}.
   *
   * @param ags the StreamingAead test object.
   * @param firstSegmentOffset number of bytes prepended to the ciphertext stream.
   * @param plaintextSize the size of the plaintext
   * @param chunkSize decryption skips and reads chunks of this size.
   */
  public static void testSkipWithStream(
      StreamingAead ags, int firstSegmentOffset, int plaintextSize, int chunkSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    byte[] ciphertext = encryptWithStream(ags, plaintext, associatedData, firstSegmentOffset);

    // Runs the decryption twice. In each pass, chunk number i is skipped if
    // i % 2 == skipChunk; all other chunks are read and compared with the plaintext.
    for (int skipChunk = 0; skipChunk < 2; skipChunk++) {
      // Construct an InputStream from the ciphertext where the first
      // firstSegmentOffset bytes have already been read.
      InputStream ctStream = new ByteArrayInputStream(ciphertext);
      ctStream.read(new byte[firstSegmentOffset]);

      // Construct an InputStream that returns the plaintext.
      InputStream ptStream = ags.newDecryptingStream(ctStream, associatedData);
      int decryptedSize = 0;
      int chunkNumber = 0;
      while (true) {
        if (chunkNumber % 2 == skipChunk) {
          int bytesSkipped = (int) ptStream.skip(chunkSize);
          if (bytesSkipped < 0) {
            fail("skip must not return a negative integer (not even at eof).");
          }
          if (bytesSkipped == 0) {
            // The implementation here is blocking. Hence getting 0 here implies that
            // the end of the stream has been reached. However, this has not been
            // verified yet.
            assertEquals("Expecting end of stream after a 0-byte skip.", -1, ptStream.read());
            break;
          }
          decryptedSize += bytesSkipped;
          if (decryptedSize < plaintextSize) {
            // The stream is blocking. Hence we expect the number of requested
            // bytes unless the end of the stream has been reached.
            assertEquals("Size of skipped chunk is invalid", chunkSize, bytesSkipped);
          }
        } else {
          byte[] chunk = new byte[chunkSize];
          int read = ptStream.read(chunk);
          if (read == -1) {
            break;
          }
          byte[] expected = Arrays.copyOfRange(plaintext, decryptedSize, decryptedSize + read);
          TestUtil.assertByteArrayEquals(expected, Arrays.copyOf(chunk, read));
          decryptedSize += read;
          if (read < chunkSize && decryptedSize < plaintextSize) {
            // read should block until either all requested bytes are read, the end of the stream
            // has been reached or an error occurred.
            fail("read did not return enough bytes");
          }
        }
        chunkNumber += 1;
      }
      assertEquals("Size of decryption does not match plaintext", plaintextSize, decryptedSize);
    }

    // Checks that skipping to the end of a ciphertext truncated by one byte is detected.
    InputStream brokenCtStream = new ByteArrayInputStream(ciphertext, 0, ciphertext.length - 1);
    brokenCtStream.read(new byte[firstSegmentOffset]);
    InputStream brokenPtStream = ags.newDecryptingStream(brokenCtStream, associatedData);
    try {
      // Long arithmetic: 2 * plaintextSize would overflow int before being widened for skip().
      brokenPtStream.skip(2L * plaintextSize);
      brokenPtStream.read();
      fail("Failed to detect invalid ciphertext");
    } catch (IOException ex) {
      // expected
    }
  }

  // Methods for testEncryptSingleBytes.

  /**
   * Encrypts a plaintext one byte at a time through an {@link OutputStream} that wraps the channel
   * returned by {@link StreamingAead#newEncryptingChannel}. The result is then passed to {@code
   * isValidCiphertext}.
   */
  private static void testEncryptSingleBytesWithChannel(StreamingAead ags, int plaintextSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    WritableByteChannel encChannel =
        ags.newEncryptingChannel(Channels.newChannel(sink), associatedData);
    try (OutputStream encStream = Channels.newOutputStream(encChannel)) {
      for (byte b : plaintext) {
        encStream.write(b);
      }
    }
    isValidCiphertext(ags, plaintext, associatedData, sink.toByteArray());
  }

  /**
   * Encrypts a plaintext one byte at a time, writing directly to the {@link OutputStream} returned
   * by {@link StreamingAead#newEncryptingStream}. The result is then passed to {@code
   * isValidCiphertext}.
   */
  private static void testEncryptSingleBytesWithStream(StreamingAead ags, int plaintextSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(plaintextSize);
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    // Use the stream API itself; the channel-based path is covered by
    // testEncryptSingleBytesWithChannel.
    try (OutputStream encStream = ags.newEncryptingStream(bos, associatedData)) {
      for (int i = 0; i < plaintext.length; i++) {
        encStream.write(plaintext[i]);
      }
    }
    isValidCiphertext(ags, plaintext, associatedData, bos.toByteArray());
  }

  /**
   * Checks single-byte-at-a-time encryption of {@code plaintextSize} bytes, first with the
   * channel-based helper and then with the stream-based helper.
   */
  public static void testEncryptSingleBytes(StreamingAead ags, int plaintextSize)
      throws Exception {
    testEncryptSingleBytesWithChannel(ags, plaintextSize); // channel-based variant
    testEncryptSingleBytesWithStream(ags, plaintextSize); // stream-based variant
  }

  // Methods for testEncryptDecryptString.

  /**
   * Encrypts and decrypts a string with non-ASCII characters using UTF-8 CharsetEncoders and
   * CharsetDecoders (via {@link Channels#newWriter} and {@link Channels#newReader}).
   *
   * <p>The string is written {@code repetitions} times, then read back one repetition at a time,
   * and finally the reader must report end of input.
   */
  public static void testEncryptDecryptString(StreamingAead ags) throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    String stringWithNonAsciiChars = "αβγδ áéíóúý ∀∑∊∫≅⊕⊄";
    int repetitions = 1000;

    // Encrypts a sequence of strings.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    WritableByteChannel ctChannel = Channels.newChannel(bos);
    try (Writer writer =
        Channels.newWriter(ags.newEncryptingChannel(ctChannel, associatedData), "UTF-8")) {
      for (int i = 0; i < repetitions; i++) {
        writer.write(stringWithNonAsciiChars);
      }
    }
    byte[] ciphertext = bos.toByteArray();

    // Decrypts a sequence of strings.
    // Channels.newReader does not always return the requested number of characters, so each
    // repetition is read in a loop until it is complete.
    SeekableByteChannel ctBuffer = new SeekableByteBufferChannel(ByteBuffer.wrap(ciphertext));
    try (Reader reader =
        Channels.newReader(ags.newSeekableDecryptingChannel(ctBuffer, associatedData), "UTF-8")) {
      for (int i = 0; i < repetitions; i++) {
        char[] chunk = new char[stringWithNonAsciiChars.length()];
        int position = 0;
        while (position < stringWithNonAsciiChars.length()) {
          int read = reader.read(chunk, position, stringWithNonAsciiChars.length() - position);
          assertTrue("read:" + read, read > 0);
          position += read;
        }
        assertEquals("i:" + i, stringWithNonAsciiChars, new String(chunk));
      }
      int res = reader.read();
      assertEquals(-1, res);
    }
  }

  /**
   * Asserts that {@code ciphertext} decrypts, under {@code ags} and {@code associatedData}, to
   * {@code plaintext}.
   *
   * <p>The destination buffer is one byte larger than the plaintext so that any surplus decrypted
   * byte also lands in the buffer that is compared. NOTE(review): this assumes {@code
   * TestUtil.assertByteBufferContains} compares the buffer's full remaining contents; confirm.
   */
  public static void isValidCiphertext(
      StreamingAead ags, byte[] plaintext, byte[] associatedData, byte[] ciphertext)
      throws Exception {
    ByteBufferChannel ctChannel = new ByteBufferChannel(ciphertext);
    ReadableByteChannel ptChannel = ags.newDecryptingChannel(ctChannel, associatedData);
    ByteBuffer decrypted = ByteBuffer.allocate(plaintext.length + 1);
    // A single read() may legitimately return fewer bytes than requested, so keep reading until
    // the buffer is full, the end of the stream is reached, or a read makes no progress (the
    // last condition prevents spinning forever on a channel that keeps returning 0).
    while (decrypted.hasRemaining()) {
      int read = ptChannel.read(decrypted);
      if (read <= 0) {
        break;
      }
    }
    decrypted.flip();
    TestUtil.assertByteBufferContains(plaintext, decrypted);
  }

  // Methods for testModifiedCiphertext.

  /**
   * Decrypts {@code modifiedCiphertext}, starting at {@code firstSegmentOffset}, in chunks of
   * {@code chunkSize} bytes. Every chunk returned must equal the original plaintext at the same
   * position (e.g. when the modification has not been read yet). The modification must be reported
   * by an IOException before the end of the plaintext stream; reaching the end fails the check.
   */
  private static void tryDecryptModifiedCiphertext(
      StreamingAead ags,
      int firstSegmentOffset,
      byte[] modifiedCiphertext,
      byte[] associatedData,
      int chunkSize,
      byte[] plaintext)
      throws Exception {
    SeekableByteChannel ciphertextChannel = new SeekableByteBufferChannel(modifiedCiphertext);
    ciphertextChannel.position(firstSegmentOffset);
    ReadableByteChannel decryptingChannel =
        ags.newDecryptingChannel(ciphertextChannel, associatedData);
    int offset = 0;
    while (true) {
      ByteBuffer buffer = ByteBuffer.allocate(chunkSize);
      int count;
      try {
        count = decryptingChannel.read(buffer);
      } catch (IOException expected) {
        // The modification was detected, which is the desired outcome.
        // TODO(bleichen): Maybe check that the stream cannot longer be accessed.
        return;
      }
      if (count < 0) {
        break; // End of stream without detecting the modification.
      }
      if (count == 0) {
        continue;
      }
      assertTrue("Read more plaintext than expected", offset + count <= plaintext.length);
      // Whatever was decrypted must equal the original plaintext at the same position.
      TestUtil.assertByteArrayEquals(
          "Returned modified plaintext position:" + offset + " size:" + count,
          Arrays.copyOf(buffer.array(), count),
          Arrays.copyOfRange(plaintext, offset, offset + count));
      offset += count;
    }
    fail("Reached end of plaintext.");
  }

  /**
   * Checks that manipulated ciphertexts never decrypt to modified plaintext through {@code
   * ags.newDecryptingChannel}. The manipulations are: truncation (every 8 bytes), appending zero
   * bytes, flipping the lowest bit of each byte at or after {@code firstSegmentOffset}, deleting
   * one segment, duplicating one segment, and flipping the lowest bit of each associated-data
   * byte in turn.
   */
  public static void testModifiedCiphertext(
      StreamingAead ags, int segmentSize, int firstSegmentOffset) throws Exception {
    final int readChunkSize = 128;
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(512);
    byte[] ciphertext = encryptWithChannel(ags, plaintext, associatedData, firstSegmentOffset);
    int ctLength = ciphertext.length;

    // Truncated ciphertexts.
    for (int newLength = 0; newLength < ctLength; newLength += 8) {
      tryDecryptModifiedCiphertext(
          ags,
          firstSegmentOffset,
          Arrays.copyOf(ciphertext, newLength),
          associatedData,
          readChunkSize,
          plaintext);
    }

    // Ciphertexts with zero bytes appended: one byte, enough bytes to end on a segment boundary,
    // and one whole segment.
    for (int extra : new int[] {1, segmentSize - ctLength % segmentSize, segmentSize}) {
      tryDecryptModifiedCiphertext(
          ags,
          firstSegmentOffset,
          concatBytes(ciphertext, new byte[extra]),
          associatedData,
          readChunkSize,
          plaintext);
    }

    // Ciphertexts with a single bit flipped.
    for (int pos = firstSegmentOffset; pos < ctLength; pos++) {
      byte[] flipped = ciphertext.clone();
      flipped[pos] ^= (byte) 1;
      tryDecryptModifiedCiphertext(
          ags, firstSegmentOffset, flipped, associatedData, readChunkSize, plaintext);
    }

    int segmentCount = ctLength / segmentSize;

    // Ciphertexts with one segment removed.
    for (int s = 0; s < segmentCount; s++) {
      byte[] withoutSegment =
          concatBytes(
              Arrays.copyOf(ciphertext, s * segmentSize),
              Arrays.copyOfRange(ciphertext, (s + 1) * segmentSize, ctLength));
      tryDecryptModifiedCiphertext(
          ags, firstSegmentOffset, withoutSegment, associatedData, readChunkSize, plaintext);
    }

    // Ciphertexts with one segment repeated.
    for (int s = 0; s < segmentCount; s++) {
      byte[] withRepeatedSegment =
          concatBytes(
              Arrays.copyOf(ciphertext, (s + 1) * segmentSize),
              Arrays.copyOfRange(ciphertext, s * segmentSize, ctLength));
      tryDecryptModifiedCiphertext(
          ags, firstSegmentOffset, withRepeatedSegment, associatedData, readChunkSize, plaintext);
    }

    // Modified associated data: no plaintext at all may be returned.
    for (int pos = 0; pos < associatedData.length; pos++) {
      byte[] modifiedAd = associatedData.clone();
      modifiedAd[pos] ^= (byte) 1;
      tryDecryptModifiedCiphertext(
          ags, firstSegmentOffset, ciphertext, modifiedAd, readChunkSize, new byte[0]);
    }
  }

  // Methods for testModifiedCiphertextWithSeekableByteChannel.

  /**
   * Tries to decrypt a modified ciphertext using a SeekableByteChannel. Each call to read must
   * either return the original plaintext (e.g. when the modification in the ciphertext does not
   * affect the plaintext) or it must throw an IOException.
   *
   * <p>Reads of growing lengths are attempted at growing start positions. When {@code plaintext}
   * is empty (meaning no plaintext may be returned at all), a single one-byte read at position 0
   * is still attempted, so that the check is not vacuous.
   */
  private static void tryDecryptModifiedCiphertextWithSeekableByteChannel(
      StreamingAead ags, byte[] modifiedCiphertext, byte[] associatedData, byte[] plaintext)
      throws Exception {

    SeekableByteChannel bbc = new SeekableByteBufferChannel(modifiedCiphertext);
    SeekableByteChannel ptChannel;
    // Rejecting the modified ciphertext already when constructing the channel is valid, so such
    // exceptions end the check successfully.
    try {
      ptChannel = ags.newSeekableDecryptingChannel(bbc, associatedData);
    } catch (IOException | GeneralSecurityException ex) {
      return;
    }
    // Without the lower bound of 1, an empty expected plaintext would skip every read below and
    // a channel returning forged plaintext would go unnoticed.
    int maxLength = Math.max(plaintext.length, 1);
    for (int start = 0; start <= plaintext.length; start += 1 + start / 2) {
      for (int length = 1; length <= maxLength; length += 1 + length / 2) {
        ByteBuffer pt = ByteBuffer.allocate(length);
        ptChannel.position(start);
        int read;
        try {
          read = ptChannel.read(pt);
        } catch (IOException ex) {
          // Modified ciphertext was found.
          // TODO(bleichen): Currently it is undefined whether we should be able to read
          //   more plaintext from the stream (i.e. unmodified segments).
          //   However, if later calls return plaintext this has to be valid plaintext.
          continue;
        }
        if (read == -1) {
          // ptChannel claims that we reached the end of the plaintext.
          assertTrue("Incorrect truncation: ", start == plaintext.length);
        } else {
          // Expect the decrypted plaintext not to be longer than the expected plaintext.
          assertTrue(
              "start:" + start + " read:" + read + " length:" + length,
              start + read <= plaintext.length);
          // Check that the decrypted plaintext matches the original plaintext.
          String expected = Hex.encode(Arrays.copyOfRange(plaintext, start, start + pt.position()));
          String actual = Hex.encode(Arrays.copyOf(pt.array(), pt.position()));
          assertEquals("start: " + start, expected, actual);
        }
      }
    }
  }

  /**
   * Checks that manipulated ciphertexts never decrypt to modified plaintext through {@code
   * ags.newSeekableDecryptingChannel}. The manipulations are: truncation (every 64 bytes),
   * appending zero bytes, flipping the lowest bit of each byte at or after {@code
   * firstSegmentOffset}, deleting one segment, duplicating one segment, and flipping the lowest
   * bit of each associated-data byte in turn.
   */
  public static void testModifiedCiphertextWithSeekableByteChannel(
      StreamingAead ags, int segmentSize, int firstSegmentOffset) throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] plaintext = generatePlaintext(2000);
    byte[] ciphertext = encryptWithChannel(ags, plaintext, associatedData, firstSegmentOffset);
    int ctLength = ciphertext.length;

    // Truncated ciphertexts.
    for (int newLength = 0; newLength < ctLength; newLength += 64) {
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, Arrays.copyOf(ciphertext, newLength), associatedData, plaintext);
    }

    // Ciphertexts with zero bytes appended: one byte, enough bytes to end on a segment boundary,
    // and one whole segment.
    for (int extra : new int[] {1, segmentSize - ctLength % segmentSize, segmentSize}) {
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, concatBytes(ciphertext, new byte[extra]), associatedData, plaintext);
    }

    // Ciphertexts with a single bit flipped.
    for (int pos = firstSegmentOffset; pos < ctLength; pos++) {
      byte[] flipped = ciphertext.clone();
      flipped[pos] ^= (byte) 1;
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, flipped, associatedData, plaintext);
    }

    int segmentCount = ctLength / segmentSize;

    // Ciphertexts with one segment removed.
    for (int s = 0; s < segmentCount; s++) {
      byte[] withoutSegment =
          concatBytes(
              Arrays.copyOf(ciphertext, s * segmentSize),
              Arrays.copyOfRange(ciphertext, (s + 1) * segmentSize, ctLength));
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, withoutSegment, associatedData, plaintext);
    }

    // Ciphertexts with one segment repeated.
    for (int s = 0; s < segmentCount; s++) {
      byte[] withRepeatedSegment =
          concatBytes(
              Arrays.copyOf(ciphertext, (s + 1) * segmentSize),
              Arrays.copyOfRange(ciphertext, s * segmentSize, ctLength));
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, withRepeatedSegment, associatedData, plaintext);
    }

    // Modified associated data: no plaintext at all may be returned.
    for (int pos = 0; pos < associatedData.length; pos++) {
      byte[] modifiedAad = associatedData.clone();
      modifiedAad[pos] ^= (byte) 1;
      tryDecryptModifiedCiphertextWithSeekableByteChannel(
          ags, ciphertext, modifiedAad, new byte[0]);
    }
  }

  /**
   * Returns a channel that yields the encryption of the {@code plaintext} channel. The encryption
   * runs on a newly started thread that writes into a pipe; the returned channel wraps the pipe's
   * read end. TODO(bleichen): Using PipedInputStream may have performance problems.
   */
  private static ReadableByteChannel createCiphertextChannel(
      final StreamingAead ags,
      final ReadableByteChannel plaintext,
      final byte[] associatedData,
      final int chunkSize)
      throws Exception {
    PipedOutputStream pipeSink = new PipedOutputStream();
    PipedInputStream pipeSource = new PipedInputStream(pipeSink);
    final WritableByteChannel ciphertext = Channels.newChannel(pipeSink);
    Runnable encryptTask =
        new Runnable() {
          @Override
          public void run() {
            encryptWithChannel(ags, plaintext, ciphertext, associatedData, chunkSize);
          }
        };
    Thread encryptThread = new Thread(encryptTask);
    encryptThread.start();
    return Channels.newChannel(pipeSource);
  }

  /**
   * Encrypts a pseudorandom plaintext of {@code plaintextSize} bytes through a piped ciphertext
   * channel and checks that decrypting it reproduces the same pseudorandom byte sequence and
   * length.
   */
  public static void testEncryptDecryptLong(StreamingAead ags, long plaintextSize)
      throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    ReadableByteChannel source = new PseudorandomReadableByteChannel(plaintextSize);
    // A second, identically constructed channel supplies the expected plaintext bytes.
    ReadableByteChannel reference = new PseudorandomReadableByteChannel(plaintextSize);
    ReadableByteChannel ctChannel =
        createCiphertextChannel(ags, source, associatedData, 1 << 20);
    ReadableByteChannel ptChannel = ags.newDecryptingChannel(ctChannel, associatedData);
    byte[] buffer = new byte[1 << 15];
    long total = 0;
    int n;
    while ((n = ptChannel.read(ByteBuffer.wrap(buffer))) != -1) {
      if (n <= 0) {
        continue;
      }
      ByteBuffer expected = ByteBuffer.allocate(n);
      int unused = reference.read(expected);
      total += n;
      TestUtil.assertByteArrayEquals(expected.array(), Arrays.copyOf(buffer, n));
    }
    assertEquals(plaintextSize, total);
  }

  // Methods for testFileEncryption.

  /**
   * Encrypts a generated plaintext of {@code plaintextSize} bytes to {@code tmpFile} through a
   * {@link java.nio.channels.FileChannel}, then decrypts the file twice and compares the result
   * with the plaintext:
   *
   * <ul>
   *   <li>sequentially, with {@code newDecryptingChannel};
   *   <li>at 100 random positions, with {@code newSeekableDecryptingChannel}.
   * </ul>
   *
   * <p>Every file stream opened here is closed via try-with-resources, even when an assertion
   * fails.
   */
  private static void testFileEncryptionWithChannel(
      StreamingAead ags, File tmpFile, int plaintextSize) throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    SeekableByteBufferChannel plaintext =
        new SeekableByteBufferChannel(generatePlaintext(plaintextSize));

    // Encrypt to file in chunks of chunkSize bytes.
    int chunkSize = 1000;
    try (FileOutputStream ctOut = new FileOutputStream(tmpFile);
        WritableByteChannel bc = ags.newEncryptingChannel(ctOut.getChannel(), associatedData)) {
      ByteBuffer chunk = ByteBuffer.allocate(chunkSize);
      int read;
      do {
        chunk.clear();
        read = plaintext.read(chunk);
        if (read > 0) {
          chunk.flip();
          bc.write(chunk);
        }
      } while (read != -1);
    }

    // Decrypt the whole file and compare to plaintext.
    plaintext.rewind();
    int decryptedSize = 0;
    try (FileInputStream ctIn = new FileInputStream(tmpFile);
        ReadableByteChannel ptStream =
            ags.newDecryptingChannel(ctIn.getChannel(), associatedData)) {
      int read;
      do {
        ByteBuffer decrypted = ByteBuffer.allocate(512);
        read = ptStream.read(decrypted);
        if (read > 0) {
          ByteBuffer expected = ByteBuffer.allocate(read);
          assertEquals(read, plaintext.read(expected));
          decrypted.flip();
          TestUtil.assertByteBufferContains(expected.array(), decrypted);
          decryptedSize += read;
        }
      } while (read != -1);
    }
    assertEquals(plaintextSize, decryptedSize);

    // Random positions need a non-empty plaintext: Random.nextInt(0) throws.
    if (plaintextSize == 0) {
      return;
    }

    // Decrypt file partially using FileChannel and compare to plaintext.
    plaintext.rewind();
    SecureRandom random = new SecureRandom();
    try (FileInputStream ctIn = new FileInputStream(tmpFile);
        SeekableByteChannel ptChannel =
            ags.newSeekableDecryptingChannel(ctIn.getChannel(), associatedData)) {
      for (int samples = 0; samples < 100; samples++) {
        int start = random.nextInt(plaintextSize);
        int length = random.nextInt(plaintextSize / 100 + 1);
        ByteBuffer decrypted = ByteBuffer.allocate(length);
        ptChannel.position(start);
        int read = ptChannel.read(decrypted);
        // We expect that every read of the ciphertext file channel returns the requested number
        // of bytes. Hence we also expect that ptChannel returns the maximal number of bytes.
        if (read < length && read + start < plaintextSize) {
          fail(
              "Plaintext size is smaller than expected; read:"
                  + read
                  + " position:"
                  + start
                  + " length:"
                  + length);
        }
        byte[] expected = new byte[read];
        plaintext.position(start);
        assertEquals(read, plaintext.read(ByteBuffer.wrap(expected)));
        decrypted.flip();
        TestUtil.assertByteBufferContains(expected, decrypted);
      }
    }
  }

  /**
   * Encrypts some plaintext to a file using FileOutputStream, then decrypts it with a
   * FileInputStream. Both directions go through {@link Channels} stream adapters wrapped around
   * the encrypting and decrypting channels, and reading and writing is done byte by byte.
   *
   * <p>All streams are closed via try-with-resources, even when an assertion fails.
   */
  private static void testFileEncryptionWithStream(
      StreamingAead ags, File tmpFile, int plaintextSize) throws Exception {
    byte[] associatedData = Hex.decode("aabbccddeeff");
    byte[] pt = generatePlaintext(plaintextSize);

    try (FileOutputStream ctStream = new FileOutputStream(tmpFile);
        OutputStream encStream =
            Channels.newOutputStream(
                ags.newEncryptingChannel(Channels.newChannel(ctStream), associatedData))) {
      // Writing single bytes appears to be the most troubling case.
      for (byte b : pt) {
        encStream.write(b);
      }
    }

    int decryptedSize = 0;
    try (FileInputStream inpStream = new FileInputStream(tmpFile);
        InputStream decrypted =
            Channels.newInputStream(
                ags.newDecryptingChannel(Channels.newChannel(inpStream), associatedData))) {
      int read;
      while ((read = decrypted.read()) != -1) {
        if (decryptedSize >= pt.length) {
          fail("Decrypted more bytes than the plaintext size " + pt.length);
        }
        // InputStream.read() returns an unsigned value, so compare against the unsigned byte.
        int expected = pt[decryptedSize] & 0xff;
        if (read != expected) {
          fail(
              "Incorrect decryption at position "
                  + decryptedSize
                  + " expected: "
                  + expected
                  + " read:"
                  + read);
        }
        decryptedSize += 1;
      }
    }
    assertEquals(plaintextSize, decryptedSize);
  }

  /**
   * Runs the file encryption round trip against {@code tmpFile} twice: once with the channel-based
   * helper and once with the stream-based helper. Both helpers use the same {@code tmpFile}.
   */
  public static void testFileEncryption(
      StreamingAead ags, File tmpFile, int plaintextSize) throws Exception {
    testFileEncryptionWithChannel(ags, tmpFile, plaintextSize); // channel round trip
    testFileEncryptionWithStream(ags, tmpFile, plaintextSize); // stream round trip
  }

  /** Holder of static test helpers; not meant to be instantiated. */
  private StreamingTestUtil() {
    // Intentionally empty.
  }
}
