/*
 * Copyright (c) 2023-2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ohos.hapsigntool.codesigning.fsverity;

import com.ohos.hapsigntool.codesigning.utils.DigestUtils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Merkle tree builder
 *
 * @since 2023/06/05
 */
public class MerkleTreeBuilder implements AutoCloseable {
    // the root hash of a large tree is computed over the first 4096-byte page of the tree
    private static final int FSVERITY_HASH_PAGE_SIZE = 4096;

    // maximum supported input size: 2^52 bytes (4 PiB)
    private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L;

    // size of one merkle tree block; every tree level is padded to a multiple of this
    private static final int CHUNK_SIZE = 4096;

    // the input is consumed in 4 MiB windows, each hashed as one task
    private static final long MAX_READ_SIZE = 4194304L;

    private static final int MAX_PROCESSORS = 32;

    // capacity of the hash thread pool's work queue
    private static final int BLOCKINGQUEUE = 4;

    private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors());

    private String mAlgorithm = "SHA-256";

    private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS,
        new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy());

    /**
     * Shut down the thread pool and discard pending hash tasks
     */
    @Override
    public void close() {
        this.mPools.shutdownNow();
    }

    /**
     * set algorithm
     *
     * @param algorithm hash algorithm
     */
    private void setAlgorithm(String algorithm) {
        this.mAlgorithm = algorithm;
    }

    /**
     * translate the input stream into hash data
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @throws IOException if error
     */
    private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)
        throws IOException {
        if (size == 0) {
            throw new IOException("Input size is empty");
        } else if (size > INPUTSTREAM_MAX_SIZE) {
            throw new IOException("Input size is too long");
        }
        int count = (int) getChunkCount(size, MAX_READ_SIZE);
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        long readOffset = 0L;
        Phaser tasks = new Phaser(1);
        synchronized (MerkleTreeBuilder.class) {
            for (int i = 0; i < count; i++) {
                long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
                int readSize = (int) (readLimit - readOffset);
                int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE);

                ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize);
                int readDataLen = readIs(inputStream, byteBuffer, readSize);
                if (readDataLen != readSize) {
                    throw new IOException("IOException read buffer from input error.");
                }
                byteBuffer.flip();
                int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i);
                runHashTask(hashes, tasks, byteBuffer, readChunkIndex);
                readOffset += readSize;
            }
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

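    // Worked example (illustrative numbers only): for size = 10 MiB the stream is consumed in
    // count = ceil(10 MiB / MAX_READ_SIZE) = 3 windows and hashed into chunks = 2560 leaf digests;
    // window i writes its digests starting at chunk index i * (MAX_READ_SIZE / CHUNK_SIZE) = i * 1024.
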
    /**
     * read up to readSize bytes from the input stream into the byte buffer, one chunk at a time
     *
     * @param inputStream input stream to read from
     * @param byteBuffer  destination buffer
     * @param readSize    number of bytes to read
     * @return number of bytes actually read
     * @throws IOException if error
     */
    private int readIs(InputStream inputStream, ByteBuffer byteBuffer, int readSize) throws IOException {
        byte[] buffer = new byte[CHUNK_SIZE];
        int num;
        int readDataLen = 0;
        // never request more than the current read window needs
        int len = Math.min(CHUNK_SIZE, readSize);
        while ((num = inputStream.read(buffer, 0, len)) > 0) {
            byteBuffer.put(buffer, 0, num);
            readDataLen += num;
            len = Math.min(CHUNK_SIZE, readSize - readDataLen);
            if (len <= 0 || readDataLen == readSize) {
                break;
            }
        }
        return readDataLen;
    }

    /**
     * slice the buffer between the begin and end positions
     *
     * @param buffer original buffer
     * @param begin  begin position
     * @param end    end position
     * @return slice buffer
     */
    private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) {
        ByteBuffer tempBuffer = buffer.duplicate();
        tempBuffer.position(0);
        tempBuffer.limit(end);
        tempBuffer.position(begin);
        return tempBuffer.slice();
    }

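    // Example: for a buffer of capacity 8192, slice(buffer, 4096, 8192) yields a 4096-byte view of
    // the second half; the view has its own position and limit but shares the underlying bytes.
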
    /**
     * calculate merkle tree level and size by data size and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm data size
     * @return level offset list, contains the offset of
     * each level from the root node to the leaf node
     */
    private static int[] getOffsetArrays(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize);
        int[] levelOffset = new int[levelSize.size() + 1];
        levelOffset[0] = 0;
        for (int i = 0; i < levelSize.size(); i++) {
            levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1));
        }
        return levelOffset;
    }

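    // Worked example: for dataSize = 1 MiB and digestSize = 32 (SHA-256), getLevelSize returns
    // [8192, 4096] (leaf level first), so the offsets become [0, 4096, 12288]: the root level
    // occupies bytes [0, 4096) and the leaf hashes occupy bytes [4096, 12288).
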
    /**
     * calculate data size list by data size and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm data size
     * @return data size list, contains the padded size of
     * each level from the leaf node to the root node
     */
    private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = new ArrayList<>();
        long fullChunkSize = 0L;
        long originalDataSize = dataSize;
        do {
            fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
            long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
            levelSize.add(size);
            originalDataSize = fullChunkSize;
        } while (fullChunkSize > CHUNK_SIZE);
        return levelSize;
    }

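    // Worked example continued: for a 1 MiB input with 32-byte digests, pass one hashes 256 data
    // chunks into 256 * 32 = 8192 bytes (two full chunks); pass two hashes those two chunks into
    // 2 * 32 = 64 bytes, rounded up to one 4096-byte chunk, and 64 <= CHUNK_SIZE ends the loop,
    // giving [8192, 4096].
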
    /**
     * submit a task that digests every chunk of the buffer into its slot of the shared hash array
     *
     * @param hashes         shared array collecting one digest per chunk
     * @param tasks          phaser used to await completion of all submitted tasks
     * @param buffer         buffer whose chunks are hashed
     * @param readChunkIndex index of this buffer's first chunk in the hash array
     */
    private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex) {
        Runnable task = () -> {
            int offset = 0;
            int bufferSize = buffer.capacity();
            int index = readChunkIndex;
            while (offset < bufferSize) {
                ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE);
                byte[] tempByte = new byte[CHUNK_SIZE];
                chunk.get(tempByte);
                try {
                    hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm);
                } catch (NoSuchAlgorithmException e) {
                    throw new IllegalStateException(e);
                }
                offset += CHUNK_SIZE;
            }
            tasks.arriveAndDeregister();
        };
        tasks.register();
        this.mPools.execute(task);
    }

    /**
     * hash data of buffer
     *
     * @param inputBuffer  original data
     * @param outputBuffer hash data
     */
    private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) {
        long size = inputBuffer.capacity();
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        Phaser tasks = new Phaser(1);
        long readOffset = 0L;
        int startChunkIndex = 0;
        while (readOffset < size) {
            long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
            ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit);
            buffer.rewind();
            int readChunkIndex = startChunkIndex;
            runHashTask(hashes, tasks, buffer, readChunkIndex);
            int readSize = (int) (readLimit - readOffset);
            startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE);
            readOffset += readSize;
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

    /**
     * generate merkle tree of given input
     *
     * @param inputStream           input stream for generating merkle tree
     * @param size                  total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws IOException              if error
     * @throws NoSuchAlgorithmException if error
     */
    public MerkleTree generateMerkleTree(InputStream inputStream, long size,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException {
        setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm());
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        int[] offsetArrays = getOffsetArrays(size, digestSize);
        ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]);
        generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize);
        generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize);
        return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm);
    }

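    // Usage sketch (a minimal illustration only; the FsVerityHashAlgorithm constant name below is
    // an assumption, since that enum is defined elsewhere in this package):
    //
    //     java.nio.file.Path hap = java.nio.file.Paths.get("module.hap");
    //     try (MerkleTreeBuilder builder = new MerkleTreeBuilder();
    //          InputStream in = java.nio.file.Files.newInputStream(hap)) {
    //         long size = java.nio.file.Files.size(hap);
    //         MerkleTree tree = builder.generateMerkleTree(in, size, FsVerityHashAlgorithm.SHA256);
    //     }
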
    /**
     * translate the input stream into the leaf level of the hash buffer
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @param offsetArrays level offset
     * @param digestSize   algorithm output byte size
     * @throws IOException if error
     */
    private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer,
        int[] offsetArrays, int digestSize) throws IOException {
        int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2];
        int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1];
        ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd);
        transInputStreamToHashData(inputStream, size, hashBuffer);
        dataRoundupChunkSize(hashBuffer, size, digestSize);
    }

    /**
     * walk the tree levels bottom-up: read the hash data of one level
     * and digest it into the level above
     *
     * @param buffer       hash data
     * @param offsetArrays level offset
     * @param digestSize   algorithm output byte size
     */
    private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) {
        for (int i = offsetArrays.length - 3; i >= 0; i--) {
            ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]);
            ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]);
            transInputDataToHashData(originalHashBuffer, generateHashBuffer);
            dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize);
        }
    }

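    // Continuing the 1 MiB example: with offsetArrays = [0, 4096, 12288] the loop runs once with
    // i = 0, hashing the 8192 bytes of leaf digests at [4096, 12288) into the root level at
    // [0, 4096), padded to a full chunk.
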
    /**
     * generate merkle tree of given input
     *
     * @param dataBuffer            tree data memory block
     * @param inputDataSize         total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws NoSuchAlgorithmException if error
     */
    private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException {
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        dataBuffer.flip();
        byte[] rootHash = null;
        byte[] tree = null;
        if (inputDataSize <= FSVERITY_HASH_PAGE_SIZE) {
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize);
            rootHash = new byte[digestSize];
            fsVerityHashPageBuffer.get(rootHash);
        } else {
            tree = dataBuffer.array();
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE);
            byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE];
            fsVerityHashPageBuffer.get(fsVerityHashPage);
            rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm);
        }
        return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
    }

    /**
     * pad the remainder of the hash data with zeros so that it fills a whole chunk
     *
     * @param data             buffer holding the hash data to pad
     * @param originalDataSize size of the data that was hashed
     * @param digestSize       algorithm output byte size
     */
    private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) {
        long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
        int diffValue = (int) (fullChunkSize % CHUNK_SIZE);
        if (diffValue > 0) {
            byte[] padding = new byte[CHUNK_SIZE - diffValue];
            data.put(padding, 0, padding.length);
        }
    }

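    // Worked example: hashing a 100000-byte input yields ceil(100000 / 4096) = 25 digests of
    // 32 bytes = 800 bytes; 800 % 4096 = 800, so 3296 zero bytes are appended to fill the chunk.
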
    /**
     * get the number of chunks needed to store data
     *
     * @param dataSize data size
     * @param divisor  split chunk size
     * @return chunk count
     */
    private static long getChunkCount(long dataSize, long divisor) {
        return (long) Math.ceil((double) dataSize / (double) divisor);
    }

    /**
     * get total size of chunk to store data
     *
     * @param dataSize   data size
     * @param divisor    split chunk size
     * @param multiplier chunk multiplier
     * @return chunk size
     */
    private static long getFullChunkSize(long dataSize, long divisor, long multiplier) {
        return getChunkCount(dataSize, divisor) * multiplier;
    }
}