1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package androidx.core.appdigest; 18 19 import org.jspecify.annotations.NonNull; 20 21 import java.io.IOException; 22 import java.io.RandomAccessFile; 23 import java.nio.ByteBuffer; 24 import java.nio.ByteOrder; 25 import java.nio.channels.FileChannel; 26 import java.security.MessageDigest; 27 import java.security.NoSuchAlgorithmException; 28 import java.util.ArrayList; 29 30 /** 31 * VerityTreeBuilder is used to generate the root hash of verity tree built from the input file. 32 * This version was adopted from VerityTreeBuilder.java in ApkSigner tool and changed to work on 33 * all target APIs. 34 */ 35 class VerityTreeBuilder { 36 /** 37 * Maximum size (in bytes) of each node of the tree. 38 */ 39 private static final int CHUNK_SIZE = 4096; 40 41 /** 42 * Digest algorithm (JCA Digest algorithm name) used in the tree. 43 */ 44 private static final String JCA_ALGORITHM = "SHA-256"; 45 /** 46 * Typical prefetch size. 47 */ 48 private static final int MAX_PREFETCH_CHUNKS = 1024; 49 computeChunkVerityTreeAndDigest(@onNull String apkPath)50 static byte[] computeChunkVerityTreeAndDigest(@NonNull String apkPath) 51 throws IOException, NoSuchAlgorithmException { 52 RandomAccessFile apk = new RandomAccessFile(apkPath, "r"); 53 try { 54 final MessageDigest md = getNewMessageDigest(); 55 ByteBuffer tree = generateVerityTree(md, apk); 56 return getRootHashFromTree(md, tree); 57 } finally { 58 apk.close(); 59 } 60 } 61 VerityTreeBuilder()62 private VerityTreeBuilder() { 63 } 64 65 private interface DataSource { copyTo(long offset, int size, ByteBuffer dest)66 void copyTo(long offset, int size, ByteBuffer dest) throws IOException; 67 } 68 69 /** 70 * Returns the digested root hash from the top level (only page) of a verity tree. 71 */ getRootHashFromTree(MessageDigest md, ByteBuffer verityBuffer)72 private static byte[] getRootHashFromTree(MessageDigest md, ByteBuffer verityBuffer) { 73 ByteBuffer firstPage = slice(verityBuffer.asReadOnlyBuffer(), 0, CHUNK_SIZE); 74 return digest(md, firstPage); 75 } 76 77 /** 78 * Returns the byte buffer that contains the whole verity tree. 79 * 80 * The tree is built bottom up. The bottom level has 256-bit digest for each 4 KB block in the 81 * input file. If the total size is larger than 4 KB, take this level as input and repeat the 82 * same procedure, until the level is within 4 KB. If salt is given, it will apply to each 83 * digestion before the actual data. 84 * 85 * The returned root hash is calculated from the last level of 4 KB chunk, similarly with salt. 86 * 87 * The tree is currently stored only in memory and is never written out. Nevertheless, it is 88 * the actual verity tree format on disk, and is supposed to be re-generated on device. 89 */ generateVerityTree(MessageDigest md, final RandomAccessFile file)90 private static ByteBuffer generateVerityTree(MessageDigest md, final RandomAccessFile file) 91 throws IOException { 92 int digestSize = md.getDigestLength(); 93 94 // Calculate the summed area table of level size. In other word, this is the offset 95 // table of each level, plus the next non-existing level. 96 int[] levelOffset = calculateLevelOffset(file.length(), digestSize); 97 98 ByteBuffer verityBuffer = ByteBuffer.allocate(levelOffset[levelOffset.length - 1]).order( 99 ByteOrder.LITTLE_ENDIAN); 100 101 // Generate the hash tree bottom-up. 102 for (int i = levelOffset.length - 2; i >= 0; i--) { 103 ByteBuffer middleBuffer = slice(verityBuffer, levelOffset[i], levelOffset[i + 1]); 104 final long srcSize; 105 if (i == levelOffset.length - 2) { 106 srcSize = file.length(); 107 final FileChannel channel = file.getChannel(); 108 digestDataByChunks(md, srcSize, new DataSource() { 109 @Override 110 public void copyTo(long offset, int size, ByteBuffer dest) throws IOException { 111 if (size == 0) { 112 return; 113 } 114 if (size > dest.remaining()) { 115 throw new IOException(); 116 } 117 118 long offsetInFile = offset; 119 int remaining = size; 120 int prevLimit = dest.limit(); 121 try { 122 // FileChannel.read(ByteBuffer) reads up to dest.remaining(). Thus, 123 // we need to adjust the buffer's limit to avoid reading more than 124 // size bytes. 125 dest.limit(dest.position() + size); 126 while (remaining > 0) { 127 int chunkSize; 128 synchronized (file) { 129 channel.position(offsetInFile); 130 chunkSize = channel.read(dest); 131 } 132 offsetInFile += chunkSize; 133 remaining -= chunkSize; 134 } 135 } finally { 136 dest.limit(prevLimit); 137 } 138 } 139 }, middleBuffer); 140 } else { 141 srcSize = (long) levelOffset[i + 2] - levelOffset[i + 1]; 142 final ByteBuffer srcBuffer = slice(verityBuffer.asReadOnlyBuffer(), 143 levelOffset[i + 1], levelOffset[i + 2]).asReadOnlyBuffer(); 144 digestDataByChunks(md, srcSize, new DataSource() { 145 @Override 146 public void copyTo(long offset, int size, ByteBuffer dest) throws IOException { 147 int chunkPosition = (int) offset; 148 int chunkLimit = chunkPosition + size; 149 150 final ByteBuffer slice; 151 synchronized (srcBuffer) { 152 srcBuffer.position(0); // to ensure position <= limit invariant 153 srcBuffer.limit(chunkLimit); 154 srcBuffer.position(chunkPosition); 155 slice = srcBuffer.slice(); 156 } 157 158 dest.put(slice); 159 } 160 }, middleBuffer); 161 } 162 163 // If the output is not full chunk, pad with 0s. 164 long totalOutput = divideRoundup(srcSize, CHUNK_SIZE) * digestSize; 165 int incomplete = (int) (totalOutput % CHUNK_SIZE); 166 if (incomplete > 0) { 167 byte[] padding = new byte[CHUNK_SIZE - incomplete]; 168 middleBuffer.put(padding, 0, padding.length); 169 } 170 } 171 return verityBuffer; 172 } 173 174 /** 175 * Returns an array of summed area table of level size in the verity tree. In other words, the 176 * returned array is offset of each level in the verity tree file format, plus an additional 177 * offset of the next non-existing level (i.e. end of the last level + 1). Thus the array size 178 * is level + 1. 179 */ calculateLevelOffset(long dataSize, int digestSize)180 private static int[] calculateLevelOffset(long dataSize, int digestSize) { 181 // Compute total size of each level, bottom to top. 182 ArrayList<Long> levelSize = new ArrayList<>(); 183 while (true) { 184 long chunkCount = divideRoundup(dataSize, CHUNK_SIZE); 185 long size = CHUNK_SIZE * divideRoundup(chunkCount * digestSize, CHUNK_SIZE); 186 levelSize.add(size); 187 if (chunkCount * digestSize <= CHUNK_SIZE) { 188 break; 189 } 190 dataSize = chunkCount * digestSize; 191 } 192 193 // Reverse and convert to summed area table. 194 int[] levelOffset = new int[levelSize.size() + 1]; 195 levelOffset[0] = 0; 196 for (int i = 0; i < levelSize.size(); i++) { 197 final long size = levelSize.get(levelSize.size() - i - 1); 198 // We don't support verity tree if it is larger then Integer.MAX_VALUE. 199 levelOffset[i + 1] = levelOffset[i] + (int) size; 200 } 201 return levelOffset; 202 } 203 204 /** 205 * Digest data source by chunks then feeds them to the sink one by one. If the last unit is 206 * less than the chunk size and padding is desired, feed with extra padding 0 to fill up the 207 * chunk before digesting. 208 */ digestDataByChunks(MessageDigest md, long size, DataSource dataSource, ByteBuffer dataSink)209 private static void digestDataByChunks(MessageDigest md, long size, 210 DataSource dataSource, ByteBuffer dataSink) throws IOException { 211 final int chunks = (int) divideRoundup(size, CHUNK_SIZE); 212 213 /* Single IO operation size, in chunks. */ 214 final int ioSizeChunks = MAX_PREFETCH_CHUNKS; 215 216 final byte[][] hashes = new byte[chunks][]; 217 218 // Reading the input file as fast as we can. 219 final long maxReadSize = ioSizeChunks * CHUNK_SIZE; 220 221 long readOffset = 0; 222 int startChunkIndex = 0; 223 while (readOffset < size) { 224 final long readLimit = Math.min(readOffset + maxReadSize, size); 225 final int readSize = (int) (readLimit - readOffset); 226 final int bufferSizeChunks = (int) divideRoundup(readSize, CHUNK_SIZE); 227 228 // Overllocating to zero-pad last chunk. 229 // With 4MiB block size, 32 threads and 4 queue size we might allocate up to 144MiB. 230 final ByteBuffer buffer = ByteBuffer.allocate(bufferSizeChunks * CHUNK_SIZE); 231 dataSource.copyTo(readOffset, readSize, buffer); 232 buffer.rewind(); 233 234 final int readChunkIndex = startChunkIndex; 235 for (int offset = 0, finish = buffer.capacity(), chunkIndex = readChunkIndex; 236 offset < finish; offset += CHUNK_SIZE, ++chunkIndex) { 237 ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE); 238 hashes[chunkIndex] = digest(md, chunk); 239 } 240 241 startChunkIndex += bufferSizeChunks; 242 readOffset += readSize; 243 } 244 245 // Streaming hashes back. 246 for (byte[] hash : hashes) { 247 dataSink.put(hash, 0, hash.length); 248 } 249 } 250 251 /** 252 * Obtains a new instance of the message digest algorithm. 253 */ getNewMessageDigest()254 private static MessageDigest getNewMessageDigest() throws NoSuchAlgorithmException { 255 return MessageDigest.getInstance(JCA_ALGORITHM); 256 } 257 258 /** Returns the digest of data with salt prepended. */ digest(MessageDigest md, ByteBuffer data)259 private static byte[] digest(MessageDigest md, ByteBuffer data) { 260 md.reset(); 261 md.update(data); 262 return md.digest(); 263 } 264 265 /** Divides a number and round up to the closest integer. */ divideRoundup(long dividend, long divisor)266 private static long divideRoundup(long dividend, long divisor) { 267 return (dividend + divisor - 1) / divisor; 268 } 269 270 /** Returns a slice of the buffer with shared the content. */ slice(ByteBuffer buffer, int begin, int end)271 private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) { 272 ByteBuffer b = buffer.duplicate(); 273 b.position(0); // to ensure position <= limit invariant. 274 b.limit(end); 275 b.position(begin); 276 return b.slice(); 277 } 278 } 279