/*
 * Copyright (c) 2023-2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ohos.hapsigntool.codesigning.fsverity;

import com.ohos.hapsigntool.codesigning.exception.CodeSignErrMsg;
import com.ohos.hapsigntool.codesigning.utils.DigestUtils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Merkle tree builder
 *
 * @since 2023/06/05
 */
public class MerkleTreeBuilder implements AutoCloseable {
    private static final int FSVERITY_HASH_PAGE_SIZE = 4096;

    // maximum supported input size: 2^52 bytes (4 PiB)
    private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L;

    private static final int CHUNK_SIZE = 4096;

    // data is read and hashed in batches of 4 MiB (1024 chunks)
    private static final long MAX_READ_SIZE = 4194304L;

    private static final int MAX_PROCESSORS = 32;

    // capacity of the thread pool's task queue
    private static final int BLOCKINGQUEUE = 4;

    private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors());

    private String mAlgorithm = "SHA-256";

    private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS,
        new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy());

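    /*
     * Usage sketch (illustrative only, not part of the original source): this
     * assumes a SHA256 constant on the FsVerityHashAlgorithm enum defined
     * elsewhere in this package, and java.nio.file.Files/Paths in the caller:
     *
     *   try (MerkleTreeBuilder builder = new MerkleTreeBuilder();
     *        InputStream in = Files.newInputStream(Paths.get("module.hap"))) {
     *       long size = Files.size(Paths.get("module.hap"));
     *       MerkleTree tree = builder.generateMerkleTree(in, size, FsVerityHashAlgorithm.SHA256);
     *   }
     */
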
    /**
     * shut down the worker thread pool
     */
    public void close() {
        this.mPools.shutdownNow();
    }

    /**
     * set hash algorithm
     *
     * @param algorithm hash algorithm
     */
    private void setAlgorithm(String algorithm) {
        this.mAlgorithm = algorithm;
    }

    /**
     * transform the input stream into leaf-level hash data
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @throws IOException if error
     */
    private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)
        throws IOException {
        if (size == 0) {
            throw new IOException(CodeSignErrMsg.CODE_SIGN_INTERNAL_ERROR.toString("Input size is empty"));
        } else if (size > INPUTSTREAM_MAX_SIZE) {
            throw new IOException(CodeSignErrMsg.CODE_SIGN_INTERNAL_ERROR.toString("Input size is too long"));
        }
        int count = (int) getChunkCount(size, MAX_READ_SIZE);
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        long readOffset = 0L;
        Phaser tasks = new Phaser(1);
        synchronized (MerkleTreeBuilder.class) {
            for (int i = 0; i < count; i++) {
                long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
                int readSize = (int) (readLimit - readOffset);
                int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE);

                ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize);
                int readDataLen = readIs(inputStream, byteBuffer, readSize);
                if (readDataLen != readSize) {
                    throw new IOException(CodeSignErrMsg.READ_INPUT_STREAM_ERROR.toString());
                }
                byteBuffer.flip();
                int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i);
                runHashTask(hashes, tasks, byteBuffer, readChunkIndex);
                readOffset += readSize;
            }
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

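    /*
     * Illustrative arithmetic for the batching above (added for clarity, not
     * from the original source): for size = 10 MiB, count = ceil(10 MiB /
     * 4 MiB) = 3 read batches and chunks = ceil(10485760 / 4096) = 2560 leaf
     * hashes; batch i starts at chunk index (4 MiB / 4096) * i = 1024 * i.
     */
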
    /**
     * read at most readSize bytes from the input stream into byteBuffer
     *
     * @param inputStream input stream
     * @param byteBuffer  destination buffer
     * @param readSize    expected number of bytes to read
     * @return number of bytes actually read
     * @throws IOException if error
     */
    private int readIs(InputStream inputStream, ByteBuffer byteBuffer, int readSize) throws IOException {
        byte[] buffer = new byte[CHUNK_SIZE];
        int num;
        int readDataLen = 0;
        // never request more than readSize bytes, even on the first read
        int len = Math.min(CHUNK_SIZE, readSize);
        while ((num = inputStream.read(buffer, 0, len)) > 0) {
            byteBuffer.put(buffer, 0, num);
            readDataLen += num;
            len = Math.min(CHUNK_SIZE, readSize - readDataLen);
            if (len <= 0 || readDataLen == readSize) {
                break;
            }
        }
        return readDataLen;
    }

    /**
     * create a slice of the buffer between the begin and end positions
     *
     * @param buffer original buffer
     * @param begin  begin position
     * @param end    end position
     * @return sliced buffer
     */
    private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) {
        ByteBuffer tempBuffer = buffer.duplicate();
        tempBuffer.position(0);
        tempBuffer.limit(end);
        tempBuffer.position(begin);
        return tempBuffer.slice();
    }

    /**
     * calculate merkle tree level offsets from the data size and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm output byte size
     * @return level offset list, containing the offset of
     * each level from the root node to the leaf nodes
     */
    private static int[] getOffsetArrays(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize);
        int[] levelOffset = new int[levelSize.size() + 1];
        levelOffset[0] = 0;
        for (int i = 0; i < levelSize.size(); i++) {
            levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1));
        }
        return levelOffset;
    }

    /**
     * calculate the chunk-padded size of each tree level from the data size
     * and digest size
     *
     * @param dataSize   original data size
     * @param digestSize algorithm output byte size
     * @return level size list, ordered from the leaf level up towards the root
     */
    private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = new ArrayList<>();
        long fullChunkSize = 0L;
        long originalDataSize = dataSize;
        do {
            fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
            long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
            levelSize.add(size);
            originalDataSize = fullChunkSize;
        } while (fullChunkSize > CHUNK_SIZE);
        return levelSize;
    }
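
    /*
     * Worked example (illustrative, not from the original source): for
     * dataSize = 8 MiB with SHA-256 (digestSize = 32), the leaf level holds
     * ceil(8388608 / 4096) * 32 = 65536 bytes of digests, and the next level
     * ceil(65536 / 4096) * 32 = 512 bytes, padded to 4096. getLevelSize
     * returns [65536, 4096] and getOffsetArrays returns [0, 4096, 69632]:
     * the root level occupies bytes [0, 4096) and the leaf level bytes
     * [4096, 69632).
     */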

    /**
     * submit a task to the thread pool that hashes the buffer chunk by chunk,
     * writing each digest into hashes starting at readChunkIndex
     *
     * @param hashes         shared digest array
     * @param tasks          phaser used to await completion of all tasks
     * @param buffer         data to be hashed
     * @param readChunkIndex index of the first chunk in this buffer
     */
    private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex) {
        Runnable task = () -> {
            int offset = 0;
            int bufferSize = buffer.capacity();
            int index = readChunkIndex;
            while (offset < bufferSize) {
                ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE);
                byte[] tempByte = new byte[CHUNK_SIZE];
                chunk.get(tempByte);
                try {
                    hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm);
                } catch (NoSuchAlgorithmException e) {
                    throw new IllegalStateException(
                        CodeSignErrMsg.ALGORITHM_NOT_SUPPORT_ERROR.toString(this.mAlgorithm), e);
                }
                offset += CHUNK_SIZE;
            }
            tasks.arriveAndDeregister();
        };
        tasks.register();
        this.mPools.execute(task);
    }

    /**
     * hash the data of a buffer into the output buffer
     *
     * @param inputBuffer  original data
     * @param outputBuffer hash data
     */
    private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) {
        long size = inputBuffer.capacity();
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        Phaser tasks = new Phaser(1);
        long readOffset = 0L;
        int startChunkIndex = 0;
        while (readOffset < size) {
            long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
            ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit);
            buffer.rewind();
            int readChunkIndex = startChunkIndex;
            runHashTask(hashes, tasks, buffer, readChunkIndex);
            int readSize = (int) (readLimit - readOffset);
            startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE);
            readOffset += readSize;
        }
        tasks.arriveAndAwaitAdvance();
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }
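
    /*
     * Note (added for clarity): each full 4 MiB slice above contributes
     * MAX_READ_SIZE / CHUNK_SIZE = 1024 chunk digests, except possibly the
     * last slice, so startChunkIndex advances by 1024 per full batch.
     */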

    /**
     * generate merkle tree of given input
     *
     * @param inputStream           input stream for generating merkle tree
     * @param size                  total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws IOException              if error
     * @throws NoSuchAlgorithmException if error
     */
    public MerkleTree generateMerkleTree(InputStream inputStream, long size,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException {
        setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm());
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        int[] offsetArrays = getOffsetArrays(size, digestSize);
        ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]);
        generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize);
        generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize);
        return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm);
    }

    /**
     * transform the input stream into the leaf level of the hash ByteBuffer
     *
     * @param inputStream  input stream for generating merkle tree
     * @param size         total size of input stream
     * @param outputBuffer hash data
     * @param offsetArrays level offsets
     * @param digestSize   algorithm output byte size
     * @throws IOException if error
     */
    private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer,
        int[] offsetArrays, int digestSize) throws IOException {
        int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2];
        int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1];
        ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd);
        transInputStreamToHashData(inputStream, size, hashBuffer);
        dataRoundupChunkSize(hashBuffer, size, digestSize);
    }

    /**
     * hash each level of the buffer, from the level above the leaves up to
     * the root, writing the digests of one level into the next
     *
     * @param buffer       hash data
     * @param offsetArrays level offsets
     * @param digestSize   algorithm output byte size
     */
    private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) {
        for (int i = offsetArrays.length - 3; i >= 0; i--) {
            ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]);
            ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]);
            transInputDataToHashData(originalHashBuffer, generateHashBuffer);
            dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize);
        }
    }
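
    /*
     * Continuing the worked example above (offsets [0, 4096, 69632]): the
     * loop runs once, reading the leaf digests in [4096, 69632) and writing
     * their digests into the root level at [0, 4096).
     */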

    /**
     * build the merkle tree result from the computed hash data
     *
     * @param dataBuffer            tree data memory block
     * @param inputDataSize         total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws NoSuchAlgorithmException if error
     */
    private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException {
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        dataBuffer.flip();
        byte[] rootHash = null;
        byte[] tree = null;
        if (inputDataSize <= FSVERITY_HASH_PAGE_SIZE) {
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize);
            rootHash = new byte[digestSize];
            fsVerityHashPageBuffer.get(rootHash);
        } else {
            tree = dataBuffer.array();
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE);
            byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE];
            fsVerityHashPageBuffer.get(fsVerityHashPage);
            rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm);
        }
        return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
    }

    /**
     * pad the hash data of a level with zeros up to a multiple of CHUNK_SIZE
     *
     * @param data             hash data buffer to be padded
     * @param originalDataSize size of the data that was hashed
     * @param digestSize       algorithm output byte size
     */
    private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) {
        long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
        int diffValue = (int) (fullChunkSize % CHUNK_SIZE);
        if (diffValue > 0) {
            byte[] padding = new byte[CHUNK_SIZE - diffValue];
            data.put(padding, 0, padding.length);
        }
    }
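
    /*
     * Illustrative arithmetic (added for clarity): for originalDataSize =
     * 10000 bytes and digestSize = 32, fullChunkSize = ceil(10000 / 4096) * 32
     * = 96 bytes of digests, so 4096 - 96 = 4000 zero bytes are appended.
     */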

    /**
     * get the number of chunks needed to store the data
     *
     * @param dataSize data size
     * @param divisor  split chunk size
     * @return chunk count
     */
    private static long getChunkCount(long dataSize, long divisor) {
        return (long) Math.ceil((double) dataSize / (double) divisor);
    }

    /**
     * get the total size occupied by the chunks, multiplier bytes per chunk
     *
     * @param dataSize   data size
     * @param divisor    split chunk size
     * @param multiplier bytes contributed per chunk
     * @return total chunk size
     */
    private static long getFullChunkSize(long dataSize, long divisor, long multiplier) {
        return getChunkCount(dataSize, divisor) * multiplier;
    }
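
    /*
     * Illustrative arithmetic (added for clarity): getChunkCount(10000, 4096)
     * = ceil(10000 / 4096) = 3, and getFullChunkSize(10000, 4096, 32)
     * = 3 * 32 = 96.
     */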
}