/*
 * Copyright (c) 2023-2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ohos.hapsigntool.codesigning.fsverity;

import com.ohos.hapsigntool.codesigning.exception.CodeSignErrMsg;
import com.ohos.hapsigntool.codesigning.utils.DigestUtils;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Phaser;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Merkle tree builder.
 * <p>
 * Splits the input into {@code CHUNK_SIZE}-byte chunks, hashes every chunk on an internal thread pool and
 * repeats the process on the resulting digests until a level fits into a single chunk. All levels are stored
 * in one buffer, ordered from the top level down to the level that holds the digests of the input data.
 * <p>
 * The builder owns a thread pool, so it must be closed after use.
 *
 * @since 2023/06/05
 */
public class MerkleTreeBuilder implements AutoCloseable {
    private static final int FSVERITY_HASH_PAGE_SIZE = 4096;

    private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L;

    private static final int CHUNK_SIZE = 4096;

    private static final long MAX_READ_SIZE = 4194304L;

    private static final int MAX_PROCESSORS = 32;

    private static final int BLOCKINGQUEUE = 4;

    private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors());

    private String mAlgorithm = "SHA-256";

    // Bounded queue + CallerRunsPolicy: when the queue is full the submitting thread hashes the batch itself.
    private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS,
        new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy());

    /**
     * Turn off multitasking by shutting down the internal thread pool.
     */
    @Override
    public void close() {
        this.mPools.shutdownNow();
    }

    /**
     * set algorithm
     *
     * @param algorithm hash algorithm name handed to {@code DigestUtils.computeDigest}
     */
    private void setAlgorithm(String algorithm) {
        this.mAlgorithm = algorithm;
    }

    /**
     * Read the input stream in batches of at most {@code MAX_READ_SIZE} bytes, hash every
     * {@code CHUNK_SIZE}-byte chunk of each batch on the thread pool and append the digests, in chunk order,
     * to the output buffer. The last chunk is zero-padded to a full chunk before hashing.
     *
     * @param inputStream input stream for generating merkle tree
     * @param size total size of input stream
     * @param outputBuffer hash data
     * @throws IOException if size is zero or too large, or the stream yields fewer bytes than size
     */
    private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)
        throws IOException {
        if (size == 0) {
            throw new IOException(CodeSignErrMsg.CODE_SIGN_INTERNAL_ERROR.toString("Input size is empty"));
        } else if (size > INPUTSTREAM_MAX_SIZE) {
            throw new IOException(CodeSignErrMsg.CODE_SIGN_INTERNAL_ERROR.toString("Input size is too long"));
        }
        int count = (int) getChunkCount(size, MAX_READ_SIZE);
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        long readOffset = 0L;
        // The calling thread is the first registered party; every hash task registers one more.
        Phaser tasks = new Phaser(1);
        AtomicReference<RuntimeException> failure = new AtomicReference<>();
        // Class-wide lock: only one builder at a time reads its stream and allocates batch buffers.
        synchronized (MerkleTreeBuilder.class) {
            for (int i = 0; i < count; i++) {
                long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
                int readSize = (int) (readLimit - readOffset);
                // Batch buffer is rounded up to whole chunks; the unread tail stays zero-filled.
                int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE);

                ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize);
                int readDataLen = readIs(inputStream, byteBuffer, readSize);
                if (readDataLen != readSize) {
                    throw new IOException(CodeSignErrMsg.READ_INPUT_STREAM_ERROR.toString());
                }
                byteBuffer.flip();
                // Index of the first chunk of batch i: (MAX_READ_SIZE / CHUNK_SIZE) * i.
                int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i);
                runHashTask(hashes, tasks, byteBuffer, readChunkIndex, failure);
                readOffset += readSize;
            }
        }
        awaitHashTasks(tasks, failure);
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

    /**
     * Read up to {@code readSize} bytes from the stream into the buffer, at most {@code CHUNK_SIZE} bytes
     * per read call. Never requests more than the bytes still missing, so data that follows the requested
     * range is left in the stream.
     *
     * @param inputStream stream to read from
     * @param byteBuffer destination buffer
     * @param readSize number of bytes wanted
     * @return number of bytes actually read; smaller than readSize if the stream ended early
     * @throws IOException if reading the stream fails
     */
    private int readIs(InputStream inputStream, ByteBuffer byteBuffer, int readSize) throws IOException {
        byte[] buffer = new byte[CHUNK_SIZE];
        int num;
        int readDataLen = 0;
        // Bound the very first request as well, otherwise a final batch shorter than one chunk over-reads.
        int len = Math.min(CHUNK_SIZE, readSize);
        while ((num = inputStream.read(buffer, 0, len)) > 0) {
            byteBuffer.put(buffer, 0, num);
            readDataLen += num;
            len = Math.min(CHUNK_SIZE, readSize - readDataLen);
            if (len <= 0 || readDataLen == readSize) {
                break;
            }
        }
        return readDataLen;
    }

    /**
     * split buffer by begin and end information
     *
     * @param buffer original buffer, its position and limit are left untouched
     * @param begin begin position
     * @param end end position
     * @return slice buffer sharing the content of buffer in the range [begin, end)
     */
    private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) {
        ByteBuffer tempBuffer = buffer.duplicate();
        // Reset position first so that the new limit can never fall below the current position.
        tempBuffer.position(0);
        tempBuffer.limit(end);
        tempBuffer.position(begin);
        return tempBuffer.slice();
    }

    /**
     * calculate merkle tree level and size by data size and digest size
     *
     * @param dataSize original data size
     * @param digestSize algorithm data size
     * @return level offset list, contains the offset of
     *         each level from the root node to the leaf node;
     *         the last element is the total size of all levels
     */
    private static int[] getOffsetArrays(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize);
        int[] levelOffset = new int[levelSize.size() + 1];
        levelOffset[0] = 0;
        // levelSize is ordered from the data level upwards, offsets are wanted from the top level down.
        for (int i = 0; i < levelSize.size(); i++) {
            levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1));
        }
        return levelOffset;
    }

    /**
     * calculate data size list by data size and digest size
     *
     * @param dataSize original data size
     * @param digestSize algorithm data size
     * @return chunk-aligned size of every level, starting with the level that
     *         stores the digests of the original data and ending with the top level
     */
    private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) {
        ArrayList<Long> levelSize = new ArrayList<>();
        long fullChunkSize = 0L;
        long originalDataSize = dataSize;
        do {
            // Bytes of digests produced by hashing every chunk of the level below.
            fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
            // Same amount rounded up to whole chunks, which is what the level occupies.
            long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
            levelSize.add(size);
            originalDataSize = fullChunkSize;
        } while (fullChunkSize > CHUNK_SIZE);
        return levelSize;
    }

    /**
     * Submit a task that hashes every {@code CHUNK_SIZE}-byte chunk of the buffer and stores the digests in
     * {@code hashes}, starting at {@code readChunkIndex}.
     * <p>
     * The task always deregisters from the phaser, even when hashing fails; otherwise the thread waiting in
     * {@link #awaitHashTasks} would block forever. The first failure is kept in {@code failure}.
     *
     * @param hashes digest table shared by all tasks of one level
     * @param tasks phaser used to wait for all tasks
     * @param buffer data to hash, its capacity is expected to be a multiple of the chunk size
     * @param readChunkIndex index in hashes of the first chunk of buffer
     * @param failure receives the first exception thrown by any task
     */
    private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex,
        AtomicReference<RuntimeException> failure) {
        Runnable task = () -> {
            try {
                int offset = 0;
                int bufferSize = buffer.capacity();
                int index = readChunkIndex;
                while (offset < bufferSize) {
                    ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE);
                    byte[] tempByte = new byte[CHUNK_SIZE];
                    chunk.get(tempByte);
                    try {
                        hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm);
                    } catch (NoSuchAlgorithmException e) {
                        throw new IllegalStateException(
                            CodeSignErrMsg.ALGORITHM_NOT_SUPPORT_ERROR.toString(this.mAlgorithm), e);
                    }
                    offset += CHUNK_SIZE;
                }
            } catch (RuntimeException e) {
                // Remember only the first failure; it is rethrown on the waiting thread.
                failure.compareAndSet(null, e);
            } finally {
                tasks.arriveAndDeregister();
            }
        };
        tasks.register();
        this.mPools.execute(task);
    }

    /**
     * Wait until every submitted hash task has finished and rethrow the first failure, if any.
     *
     * @param tasks phaser the hash tasks are registered with
     * @param failure holder filled by failed hash tasks
     */
    private static void awaitHashTasks(Phaser tasks, AtomicReference<RuntimeException> failure) {
        tasks.arriveAndAwaitAdvance();
        RuntimeException error = failure.get();
        if (error != null) {
            throw error;
        }
    }

    /**
     * hash data of buffer
     * <p>
     * The input is cut into batches of at most {@code MAX_READ_SIZE} bytes that are hashed on the thread
     * pool; the digests are appended to the output buffer in chunk order.
     *
     * @param inputBuffer original data
     * @param outputBuffer hash data
     */
    private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) {
        long size = inputBuffer.capacity();
        int chunks = (int) getChunkCount(size, CHUNK_SIZE);
        byte[][] hashes = new byte[chunks][];
        Phaser tasks = new Phaser(1);
        AtomicReference<RuntimeException> failure = new AtomicReference<>();
        long readOffset = 0L;
        int startChunkIndex = 0;
        while (readOffset < size) {
            long readLimit = Math.min(readOffset + MAX_READ_SIZE, size);
            ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit);
            buffer.rewind();
            int readChunkIndex = startChunkIndex;
            runHashTask(hashes, tasks, buffer, readChunkIndex, failure);
            int readSize = (int) (readLimit - readOffset);
            startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE);
            readOffset += readSize;
        }
        awaitHashTasks(tasks, failure);
        for (byte[] hash : hashes) {
            outputBuffer.put(hash, 0, hash.length);
        }
    }

    /**
     * generate merkle tree of given input
     *
     * @param inputStream input stream for generate merkle tree
     * @param size total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws IOException if error
     * @throws NoSuchAlgorithmException if error
     */
    public MerkleTree generateMerkleTree(InputStream inputStream, long size,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException {
        setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm());
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        int[] offsetArrays = getOffsetArrays(size, digestSize);
        // One buffer holds every level; the last offset is the total size.
        ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]);
        generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize);
        generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize);
        return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm);
    }

    /**
     * translation inputBuffer arrays to hash ByteBuffer
     * <p>
     * Fills the last level of the tree buffer with the digests of the input data.
     *
     * @param inputStream input stream for generate merkle tree
     * @param size total size of input stream
     * @param outputBuffer hash data
     * @param offsetArrays level offset
     * @param digestSize algorithm output byte size
     * @throws IOException if error
     */
    private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer,
        int[] offsetArrays, int digestSize) throws IOException {
        int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2];
        int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1];
        ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd);
        transInputStreamToHashData(inputStream, size, hashBuffer);
        dataRoundupChunkSize(hashBuffer, size, digestSize);
    }

    /**
     * get buffer data by level offset, transforms digest data, save in another
     * memory
     * <p>
     * Walks from the second-to-last level up to the first one, hashing the level below into each.
     *
     * @param buffer hash data
     * @param offsetArrays level offset
     * @param digestSize algorithm output byte size
     */
    private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) {
        for (int i = offsetArrays.length - 3; i >= 0; i--) {
            ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]);
            ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]);
            transInputDataToHashData(originalHashBuffer, generateHashBuffer);
            dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize);
        }
    }

    /**
     * Wrap the finished tree buffer into a {@code MerkleTree}.
     * <p>
     * If the input fits into one page, the root hash is the first digest in the buffer and no tree is
     * stored. Otherwise the tree is the whole backing array and the root hash is the digest of its first
     * page.
     *
     * @param dataBuffer tree data memory block
     * @param inputDataSize total size of input stream
     * @param fsVerityHashAlgorithm hash algorithm for FsVerity
     * @return merkle tree
     * @throws NoSuchAlgorithmException if error
     */
    private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize,
        FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException {
        int digestSize = fsVerityHashAlgorithm.getOutputByteSize();
        dataBuffer.flip();
        byte[] rootHash = null;
        byte[] tree = null;
        if (inputDataSize <= FSVERITY_HASH_PAGE_SIZE) {
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize);
            rootHash = new byte[digestSize];
            fsVerityHashPageBuffer.get(rootHash);
        } else {
            tree = dataBuffer.array();
            ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE);
            byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE];
            fsVerityHashPageBuffer.get(fsVerityHashPage);
            rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm);
        }
        return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
    }

    /**
     * Pad the digests written for one level with zero bytes up to the next chunk boundary.
     *
     * @param data buffer positioned right after the digests of the level
     * @param originalDataSize size of the data the digests were computed from
     * @param digestSize algorithm output byte size
     */
    private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) {
        long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
        int diffValue = (int) (fullChunkSize % CHUNK_SIZE);
        if (diffValue > 0) {
            byte[] padding = new byte[CHUNK_SIZE - diffValue];
            data.put(padding, 0, padding.length);
        }
    }

    /**
     * get mount of chunks to store data
     * <p>
     * Ceiling division done in exact integer arithmetic.
     *
     * @param dataSize data size
     * @param divisor split chunk size
     * @return chunk count
     */
    private static long getChunkCount(long dataSize, long divisor) {
        // ceil(a / b) == -floor(-a / b); avoids the precision loss of double division for large sizes.
        return -Math.floorDiv(-dataSize, divisor);
    }

    /**
     * get total size of chunk to store data
     *
     * @param dataSize data size
     * @param divisor split chunk size
     * @param multiplier chunk multiplier
     * @return chunk size
     */
    private static long getFullChunkSize(long dataSize, long divisor, long multiplier) {
        return getChunkCount(dataSize, divisor) * multiplier;
    }
}