/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.core.appdigest;

import org.jspecify.annotations.NonNull;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;

30 /**
31  * VerityTreeBuilder is used to generate the root hash of verity tree built from the input file.
32  * This version was adopted from VerityTreeBuilder.java in ApkSigner tool and changed to work on
33  * all target APIs.
34  */
35 class VerityTreeBuilder {
36     /**
37      * Maximum size (in bytes) of each node of the tree.
38      */
39     private static final int CHUNK_SIZE = 4096;
40 
41     /**
42      * Digest algorithm (JCA Digest algorithm name) used in the tree.
43      */
44     private static final String JCA_ALGORITHM = "SHA-256";
45     /**
46      * Typical prefetch size.
47      */
48     private static final int MAX_PREFETCH_CHUNKS = 1024;
49 
computeChunkVerityTreeAndDigest(@onNull String apkPath)50     static byte[] computeChunkVerityTreeAndDigest(@NonNull String apkPath)
51             throws IOException, NoSuchAlgorithmException {
52         RandomAccessFile apk = new RandomAccessFile(apkPath, "r");
53         try {
54             final MessageDigest md = getNewMessageDigest();
55             ByteBuffer tree = generateVerityTree(md, apk);
56             return getRootHashFromTree(md, tree);
57         } finally {
58             apk.close();
59         }
60     }
61 
VerityTreeBuilder()62     private VerityTreeBuilder() {
63     }
64 
65     private interface DataSource {
copyTo(long offset, int size, ByteBuffer dest)66         void copyTo(long offset, int size, ByteBuffer dest) throws IOException;
67     }
68 
69     /**
70      * Returns the digested root hash from the top level (only page) of a verity tree.
71      */
getRootHashFromTree(MessageDigest md, ByteBuffer verityBuffer)72     private static byte[] getRootHashFromTree(MessageDigest md, ByteBuffer verityBuffer) {
73         ByteBuffer firstPage = slice(verityBuffer.asReadOnlyBuffer(), 0, CHUNK_SIZE);
74         return digest(md, firstPage);
75     }
76 
77     /**
78      * Returns the byte buffer that contains the whole verity tree.
79      *
80      * The tree is built bottom up. The bottom level has 256-bit digest for each 4 KB block in the
81      * input file.  If the total size is larger than 4 KB, take this level as input and repeat the
82      * same procedure, until the level is within 4 KB.  If salt is given, it will apply to each
83      * digestion before the actual data.
84      *
85      * The returned root hash is calculated from the last level of 4 KB chunk, similarly with salt.
86      *
87      * The tree is currently stored only in memory and is never written out.  Nevertheless, it is
88      * the actual verity tree format on disk, and is supposed to be re-generated on device.
89      */
generateVerityTree(MessageDigest md, final RandomAccessFile file)90     private static ByteBuffer generateVerityTree(MessageDigest md, final RandomAccessFile file)
91             throws IOException {
92         int digestSize = md.getDigestLength();
93 
94         // Calculate the summed area table of level size. In other word, this is the offset
95         // table of each level, plus the next non-existing level.
96         int[] levelOffset = calculateLevelOffset(file.length(), digestSize);
97 
98         ByteBuffer verityBuffer = ByteBuffer.allocate(levelOffset[levelOffset.length - 1]).order(
99                 ByteOrder.LITTLE_ENDIAN);
100 
101         // Generate the hash tree bottom-up.
102         for (int i = levelOffset.length - 2; i >= 0; i--) {
103             ByteBuffer middleBuffer = slice(verityBuffer, levelOffset[i], levelOffset[i + 1]);
104             final long srcSize;
105             if (i == levelOffset.length - 2) {
106                 srcSize = file.length();
107                 final FileChannel channel = file.getChannel();
108                 digestDataByChunks(md, srcSize, new DataSource() {
109                     @Override
110                     public void copyTo(long offset, int size, ByteBuffer dest) throws IOException {
111                         if (size == 0) {
112                             return;
113                         }
114                         if (size > dest.remaining()) {
115                             throw new IOException();
116                         }
117 
118                         long offsetInFile = offset;
119                         int remaining = size;
120                         int prevLimit = dest.limit();
121                         try {
122                             // FileChannel.read(ByteBuffer) reads up to dest.remaining(). Thus,
123                             // we need to adjust the buffer's limit to avoid reading more than
124                             // size bytes.
125                             dest.limit(dest.position() + size);
126                             while (remaining > 0) {
127                                 int chunkSize;
128                                 synchronized (file) {
129                                     channel.position(offsetInFile);
130                                     chunkSize = channel.read(dest);
131                                 }
132                                 offsetInFile += chunkSize;
133                                 remaining -= chunkSize;
134                             }
135                         } finally {
136                             dest.limit(prevLimit);
137                         }
138                     }
139                 }, middleBuffer);
140             } else {
141                 srcSize = (long) levelOffset[i + 2] - levelOffset[i + 1];
142                 final ByteBuffer srcBuffer = slice(verityBuffer.asReadOnlyBuffer(),
143                         levelOffset[i + 1], levelOffset[i + 2]).asReadOnlyBuffer();
144                 digestDataByChunks(md, srcSize, new DataSource() {
145                     @Override
146                     public void copyTo(long offset, int size, ByteBuffer dest) throws IOException {
147                         int chunkPosition = (int) offset;
148                         int chunkLimit = chunkPosition + size;
149 
150                         final ByteBuffer slice;
151                         synchronized (srcBuffer) {
152                             srcBuffer.position(0);  // to ensure position <= limit invariant
153                             srcBuffer.limit(chunkLimit);
154                             srcBuffer.position(chunkPosition);
155                             slice = srcBuffer.slice();
156                         }
157 
158                         dest.put(slice);
159                     }
160                 }, middleBuffer);
161             }
162 
163             // If the output is not full chunk, pad with 0s.
164             long totalOutput = divideRoundup(srcSize, CHUNK_SIZE) * digestSize;
165             int incomplete = (int) (totalOutput % CHUNK_SIZE);
166             if (incomplete > 0) {
167                 byte[] padding = new byte[CHUNK_SIZE - incomplete];
168                 middleBuffer.put(padding, 0, padding.length);
169             }
170         }
171         return verityBuffer;
172     }
173 
174     /**
175      * Returns an array of summed area table of level size in the verity tree.  In other words, the
176      * returned array is offset of each level in the verity tree file format, plus an additional
177      * offset of the next non-existing level (i.e. end of the last level + 1).  Thus the array size
178      * is level + 1.
179      */
calculateLevelOffset(long dataSize, int digestSize)180     private static int[] calculateLevelOffset(long dataSize, int digestSize) {
181         // Compute total size of each level, bottom to top.
182         ArrayList<Long> levelSize = new ArrayList<>();
183         while (true) {
184             long chunkCount = divideRoundup(dataSize, CHUNK_SIZE);
185             long size = CHUNK_SIZE * divideRoundup(chunkCount * digestSize, CHUNK_SIZE);
186             levelSize.add(size);
187             if (chunkCount * digestSize <= CHUNK_SIZE) {
188                 break;
189             }
190             dataSize = chunkCount * digestSize;
191         }
192 
193         // Reverse and convert to summed area table.
194         int[] levelOffset = new int[levelSize.size() + 1];
195         levelOffset[0] = 0;
196         for (int i = 0; i < levelSize.size(); i++) {
197             final long size = levelSize.get(levelSize.size() - i - 1);
198             // We don't support verity tree if it is larger then Integer.MAX_VALUE.
199             levelOffset[i + 1] = levelOffset[i] + (int) size;
200         }
201         return levelOffset;
202     }
203 
204     /**
205      * Digest data source by chunks then feeds them to the sink one by one.  If the last unit is
206      * less than the chunk size and padding is desired, feed with extra padding 0 to fill up the
207      * chunk before digesting.
208      */
digestDataByChunks(MessageDigest md, long size, DataSource dataSource, ByteBuffer dataSink)209     private static void digestDataByChunks(MessageDigest md, long size,
210             DataSource dataSource, ByteBuffer dataSink) throws IOException {
211         final int chunks = (int) divideRoundup(size, CHUNK_SIZE);
212 
213         /* Single IO operation size, in chunks. */
214         final int ioSizeChunks = MAX_PREFETCH_CHUNKS;
215 
216         final byte[][] hashes = new byte[chunks][];
217 
218         // Reading the input file as fast as we can.
219         final long maxReadSize = ioSizeChunks * CHUNK_SIZE;
220 
221         long readOffset = 0;
222         int startChunkIndex = 0;
223         while (readOffset < size) {
224             final long readLimit = Math.min(readOffset + maxReadSize, size);
225             final int readSize = (int) (readLimit - readOffset);
226             final int bufferSizeChunks = (int) divideRoundup(readSize, CHUNK_SIZE);
227 
228             // Overllocating to zero-pad last chunk.
229             // With 4MiB block size, 32 threads and 4 queue size we might allocate up to 144MiB.
230             final ByteBuffer buffer = ByteBuffer.allocate(bufferSizeChunks * CHUNK_SIZE);
231             dataSource.copyTo(readOffset, readSize, buffer);
232             buffer.rewind();
233 
234             final int readChunkIndex = startChunkIndex;
235             for (int offset = 0, finish = buffer.capacity(), chunkIndex = readChunkIndex;
236                     offset < finish; offset += CHUNK_SIZE, ++chunkIndex) {
237                 ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE);
238                 hashes[chunkIndex] = digest(md, chunk);
239             }
240 
241             startChunkIndex += bufferSizeChunks;
242             readOffset += readSize;
243         }
244 
245         // Streaming hashes back.
246         for (byte[] hash : hashes) {
247             dataSink.put(hash, 0, hash.length);
248         }
249     }
250 
251     /**
252      * Obtains a new instance of the message digest algorithm.
253      */
getNewMessageDigest()254     private static MessageDigest getNewMessageDigest() throws NoSuchAlgorithmException {
255         return MessageDigest.getInstance(JCA_ALGORITHM);
256     }
257 
258     /** Returns the digest of data with salt prepended. */
digest(MessageDigest md, ByteBuffer data)259     private static byte[] digest(MessageDigest md, ByteBuffer data) {
260         md.reset();
261         md.update(data);
262         return md.digest();
263     }
264 
265     /** Divides a number and round up to the closest integer. */
divideRoundup(long dividend, long divisor)266     private static long divideRoundup(long dividend, long divisor) {
267         return (dividend + divisor - 1) / divisor;
268     }
269 
270     /** Returns a slice of the buffer with shared the content. */
slice(ByteBuffer buffer, int begin, int end)271     private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) {
272         ByteBuffer b = buffer.duplicate();
273         b.position(0);  // to ensure position <= limit invariant.
274         b.limit(end);
275         b.position(begin);
276         return b.slice();
277     }
278 }