/*
 * Copyright (c) 2023-2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ohos.hapsigntool.codesigning.fsverity;

import com.ohos.hapsigntool.codesigning.utils.DigestUtils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Merkle tree builder
 *
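 * <p>Example usage, as a minimal sketch: the file name and the
 * {@code FsVerityHashAlgorithm} constant below are illustrative
 * assumptions, not fixed by this class.
 * <pre>{@code
 * File input = new File("data.bin"); // hypothetical input file
 * try (MerkleTreeBuilder builder = new MerkleTreeBuilder();
 *      InputStream in = new FileInputStream(input)) {
 *     MerkleTree tree = builder.generateMerkleTree(in, input.length(),
 *         FsVerityHashAlgorithm.SHA256); // assumed enum constant
 * }
 * }</pre>
 *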
 * @since 2023/06/05
 */
public class MerkleTreeBuilder implements AutoCloseable {
    private static final int FSVERITY_HASH_PAGE_SIZE = 4096;

    // Maximum supported input size: 2^52 bytes (4 PiB)
    private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L;

    private static final int CHUNK_SIZE = 4096;

    // Bytes consumed from the input per hashing task: 4 MiB
    private static final long MAX_READ_SIZE = 4194304L;

    private static final int MAX_PROCESSORS = 32;

    // Capacity of the thread pool's work queue
    private static final int BLOCKINGQUEUE = 4;

    private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors());

    private String mAlgorithm = "SHA-256";

    private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS,
        new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy());

    /**
     * Shut down the thread pool immediately
     */
    public void close() {
        this.mPools.shutdownNow();
    }

    /**
     * set algorithm
     *
     * @param algorithm hash algorithm
     */
    private void setAlgorithm(String algorithm) {
        this.mAlgorithm = algorithm;
    }

    /**
     * transform the input stream into leaf-level hash data
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @throws IOException if error
     */
    private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)
        throws IOException {
        if (size == 0) {
            throw new IOException("Input size is empty");
        } else if (size > INPUTSTREAM_MAX_SIZE) {
            throw new IOException("Input size is too long");
        }
        int count = (int) getChunkCount(size, MAX_READ_SIZE);
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        long readOffset = 0L;
        Phaser tasks = new Phaser(1);
        for (int i = 0; i < count; i++) {
            long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
            int readSize = (int) (readLimit - readOffset);
            int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE);

            ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize);
            byte[] buffer = new byte[CHUNK_SIZE];
            int num;
            int offset = 0;
            int len = CHUNK_SIZE;
            while ((num = inputStream.read(buffer, 0, len)) > 0) {
                byteBuffer.put(buffer, 0, num);
                offset += num;
                len = Math.min(CHUNK_SIZE, readSize - offset);
                if (len <= 0 || offset == readSize) {
                    break;
                }
            }
            if (offset != readSize) {
                throw new IOException("IOException read buffer from input error.");
            }
            byteBuffer.flip();
            int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i);
            runHashTask(hashes, tasks, byteBuffer, readChunkIndex);
            readOffset += readSize;
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }
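
    // Worked example for the chunking above (illustrative figures, not from
    // the source): a 10 MiB (10485760-byte) stream yields
    // count = ceil(10485760 / 4194304) = 3 outer reads and
    // chunks = 10485760 / 4096 = 2560 leaf hashes, so each outer read hashes
    // up to 1024 four-KiB chunks in one pooled task.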

    /**
     * split buffer by begin and end information
     *
     * @param buffer original buffer
     * @param begin  begin position
     * @param end    end position
     * @return slice buffer
     */
    private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) {
        ByteBuffer tempBuffer = buffer.duplicate();
        tempBuffer.position(0);
        tempBuffer.limit(end);
        tempBuffer.position(begin);
        return tempBuffer.slice();
    }
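
    // For instance (illustrative): slice(buffer, 4096, 8192) returns an
    // independent 4096-byte view of the same backing data; reading it does
    // not move the position or limit of the original buffer.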

    /**
     * calculate merkle tree level and size by data size and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm data size
     * @return level offset list, contains the offset of
     * each level from the root node to the leaf node
     */
    private static int[] getOffsetArrays(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize);
        int[] levelOffset = new int[levelSize.size() + 1];
        levelOffset[0] = 0;
        for (int i = 0; i < levelSize.size(); i++) {
            levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1));
        }
        return levelOffset;
    }

    /**
     * calculate data size list by data size and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm data size
     * @return data size list, contains the size of
     * each level from the leaf node up to the root node
     */
    private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = new ArrayList<>();
        long fullChunkSize = 0L;
        long originalDataSize = dataSize;
        do {
            fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
            long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
            levelSize.add(size);
            originalDataSize = fullChunkSize;
        } while (fullChunkSize > CHUNK_SIZE);
        return levelSize;
    }
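
    // Worked example for the two methods above (illustrative numbers): for
    // dataSize = 8 MiB (8388608) and digestSize = 32, the leaf level holds
    // 2048 digests = 65536 bytes (a whole number of 4096-byte chunks) and the
    // next level holds 16 digests padded to one 4096-byte chunk, so
    // getLevelSize(...) returns [65536, 4096] and getOffsetArrays(...) returns
    // [0, 4096, 69632]: root level at [0, 4096), leaf level at [4096, 69632).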

    private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex) {
        Runnable task = () -> {
            int offset = 0;
            int bufferSize = buffer.capacity();
            int index = readChunkIndex;
            while (offset < bufferSize) {
                ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE);
                byte[] tempByte = new byte[CHUNK_SIZE];
                chunk.get(tempByte);
                try {
                    hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm);
                } catch (NoSuchAlgorithmException e) {
                    throw new IllegalStateException(e);
                }
                offset += CHUNK_SIZE;
            }
            tasks.arriveAndDeregister();
        };
        tasks.register();
        this.mPools.execute(task);
    }
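
    // Synchronization sketch: the Phaser starts with one registered party (the
    // caller). Each runHashTask call registers one more party before submitting
    // its task, the worker arrives and deregisters once its chunks are hashed,
    // and the caller's arriveAndAwaitAdvance() returns only after every
    // submitted task has finished.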

    /**
     * hash data of buffer
     *
     * @param inputBuffer  original data
     * @param outputBuffer hash data
     */
    private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) {
        long size = inputBuffer.capacity();
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        Phaser tasks = new Phaser(1);
        long readOffset = 0L;
        int startChunkIndex = 0;
        while (readOffset < size) {
            long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
            ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit);
            buffer.rewind();
            int readChunkIndex = startChunkIndex;
            runHashTask(hashes, tasks, buffer, readChunkIndex);
            int readSize = (int) (readLimit - readOffset);
            startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE);
            readOffset += readSize;
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

    /**
     * generate merkle tree of given input
     *
     * @param inputStream           input stream for generating merkle tree
     * @param size                  total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws IOException              if error
     * @throws NoSuchAlgorithmException if error
     */
    public MerkleTree generateMerkleTree(InputStream inputStream, long size,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException {
        setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm());
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        int[] offsetArrays = getOffsetArrays(size, digestSize);
        ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]);
        generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize);
        generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize);
        return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm);
    }
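
    // Pipeline sketch: level offsets are laid out root-first, leaf digests are
    // written into the last region by generateHashDataByInputData, upper levels
    // are then filled bottom-up by generateHashDataByHashData, and getMerkleTree
    // finally derives the root hash from the first hash page.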

    /**
     * transform the input stream into the leaf level of the hash buffer
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @param offsetArrays level offset
     * @param digestSize   algorithm output byte size
     * @throws IOException if error
     */
    private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer,
        int[] offsetArrays, int digestSize) throws IOException {
        int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2];
        int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1];
        ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd);
        transInputStreamToHashData(inputStream, size, hashBuffer);
        dataRoundupChunkSize(hashBuffer, size, digestSize);
    }

    /**
     * read the hash data of one level by its offsets, digest it,
     * and write the result into the region of the level above
     *
     * @param buffer       hash data
     * @param offsetArrays level offset
     * @param digestSize   algorithm output byte size
     */
    private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) {
        for (int i = offsetArrays.length - 3; i >= 0; i--) {
            ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]);
            ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]);
            transInputDataToHashData(originalHashBuffer, generateHashBuffer);
            dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize);
        }
    }
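
    // Continuing the worked example above (offsets [0, 4096, 69632]): the loop
    // runs once with i = 0, hashing the leaf region [4096, 69632) into the
    // root-level region [0, 4096) and zero-padding it to a whole chunk.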

    /**
     * build the merkle tree object from the computed hash data
     *
     * @param dataBuffer            tree data memory block
     * @param inputDataSize         total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws NoSuchAlgorithmException if error
     */
    private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException {
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        dataBuffer.flip();
        byte[] rootHash = null;
        byte[] tree = null;
        if (inputDataSize < FSVERITY_HASH_PAGE_SIZE) {
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize);
            rootHash = new byte[digestSize];
            fsVerityHashPageBuffer.get(rootHash);
        } else {
            tree = dataBuffer.array();
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE);
            byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE];
            fsVerityHashPageBuffer.get(fsVerityHashPage);
            rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm);
        }
        return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
    }
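
    // Root-hash rule above: input smaller than one 4096-byte page produces a
    // single leaf digest, which serves as the root hash directly; otherwise the
    // root hash is the digest of the tree's first 4096-byte page (the top level).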

    /**
     * zero-pad the hash data so that it fills a whole number of chunks
     *
     * @param data             original data
     * @param originalDataSize data size
     * @param digestSize       algorithm output byte size
     */
    private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) {
        long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
        int diffValue = (int) (fullChunkSize % CHUNK_SIZE);
        if (diffValue > 0) {
            byte[] padding = new byte[CHUNK_SIZE - diffValue];
            data.put(padding, 0, padding.length);
        }
    }
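
    // Padding example (illustrative): 512 bytes of digests occupy part of one
    // 4096-byte chunk, so 4096 - 512 = 3584 zero bytes are appended to round
    // the level up to a whole chunk.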

    /**
     * get count of chunks needed to store data
     *
     * @param dataSize   data size
     * @param divisor    split chunk size
     * @return chunk count
     */
    private static long getChunkCount(long dataSize, long divisor) {
        return (long) Math.ceil((double) dataSize / (double) divisor);
    }

    /**
     * get total size of chunk to store data
     *
     * @param dataSize   data size
     * @param divisor    split chunk size
     * @param multiplier chunk multiplier
     * @return chunk size
     */
    private static long getFullChunkSize(long dataSize, long divisor, long multiplier) {
        return getChunkCount(dataSize, divisor) * multiplier;
    }
}