1 /* 2 * Copyright (c) 2023-2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 package com.ohos.hapsigntool.codesigning.fsverity; 17 18 import com.ohos.hapsigntool.codesigning.utils.DigestUtils; 19 20 import java.io.IOException; 21 import java.io.InputStream; 22 import java.nio.ByteBuffer; 23 import java.security.NoSuchAlgorithmException; 24 import java.util.ArrayList; 25 import java.util.concurrent.ArrayBlockingQueue; 26 import java.util.concurrent.ExecutorService; 27 import java.util.concurrent.Phaser; 28 import java.util.concurrent.ThreadPoolExecutor; 29 import java.util.concurrent.TimeUnit; 30 31 /** 32 * Merkle tree builder 33 * 34 * @since 2023/06/05 35 */ 36 public class MerkleTreeBuilder implements AutoCloseable { 37 private static final int FSVERITY_HASH_PAGE_SIZE = 4096; 38 39 private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L; 40 41 private static final int CHUNK_SIZE = 4096; 42 43 private static final long MAX_READ_SIZE = 4194304L; 44 45 private static final int MAX_PROCESSORS = 32; 46 47 private static final int BLOCKINGQUEUE = 4; 48 49 private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors()); 50 51 private String mAlgorithm = "SHA-256"; 52 53 private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS, 54 new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy()); 55 56 /** 57 * Turn off multitasking 58 */ close()59 public void close() { 60 this.mPools.shutdownNow(); 61 } 62 63 /** 64 * set algorithm 65 * 66 * @param algorithm hash algorithm 67 */ setAlgorithm(String algorithm)68 private void setAlgorithm(String algorithm) { 69 this.mAlgorithm = algorithm; 70 } 71 72 /** 73 * translation inputStream to hash data 74 * 75 * @param inputStream input stream for generating merkle tree 76 * @param size total size of input stream 77 * @param outputBuffer hash data 78 * @throws IOException if error 79 */ transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)80 private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer) 81 throws IOException { 82 if (size == 0) { 83 throw new IOException("Input size is empty"); 84 } else if (size > INPUTSTREAM_MAX_SIZE) { 85 throw new IOException("Input size is too long"); 86 } 87 int count = (int) getChunkCount(size, MAX_READ_SIZE); 88 int chunks = (int) getChunkCount(size, CHUNK_SIZE); 89 byte[][] hashes = new byte[chunks][]; 90 long readOffset = 0L; 91 Phaser tasks = new Phaser(1); 92 synchronized (MerkleTreeBuilder.class) { 93 for (int i = 0; i < count; i++) { 94 long readLimit = Math.min(readOffset + MAX_READ_SIZE, size); 95 int readSize = (int) (readLimit - readOffset); 96 int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE); 97 98 ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize); 99 int readDataLen = readIs(inputStream, byteBuffer, readSize); 100 if (readDataLen != readSize) { 101 throw new IOException("IOException read buffer from input errorLHJ."); 102 } 103 byteBuffer.flip(); 104 int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i); 105 runHashTask(hashes, tasks, byteBuffer, readChunkIndex); 106 readOffset += readSize; 107 } 108 } 109 tasks.arriveAndAwaitAdvance(); 110 for (byte[] hash : hashes) { 111 outputBuffer.put(hash, 0, hash.length); 112 } 113 } 114 readIs(InputStream inputStream, ByteBuffer byteBuffer, int readSize)115 private int readIs(InputStream inputStream, ByteBuffer byteBuffer, int readSize) throws IOException { 116 byte[] buffer = new byte[CHUNK_SIZE]; 117 int num; 118 int readDataLen = 0; 119 int len = CHUNK_SIZE; 120 while ((num = inputStream.read(buffer, 0, len)) > 0) { 121 byteBuffer.put(buffer, 0, num); 122 readDataLen += num; 123 len = Math.min(CHUNK_SIZE, readSize - readDataLen); 124 if (len <= 0 || readDataLen == readSize) { 125 break; 126 } 127 } 128 return readDataLen; 129 } 130 131 /** 132 * split buffer by begin and end information 133 * 134 * @param buffer original buffer 135 * @param begin begin position 136 * @param end end position 137 * @return slice buffer 138 */ slice(ByteBuffer buffer, int begin, int end)139 private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) { 140 ByteBuffer tempBuffer = buffer.duplicate(); 141 tempBuffer.position(0); 142 tempBuffer.limit(end); 143 tempBuffer.position(begin); 144 return tempBuffer.slice(); 145 } 146 147 /** 148 * calculate merkle tree level and size by data size and digest size 149 * 150 * @param dataSize original data size 151 * @param digestSize algorithm data size 152 * @return level offset list, contains the offset of 153 * each level from the root node to the leaf node 154 */ getOffsetArrays(long dataSize, int digestSize)155 private static int[] getOffsetArrays(long dataSize, int digestSize) { 156 ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize); 157 int[] levelOffset = new int[levelSize.size() + 1]; 158 levelOffset[0] = 0; 159 for (int i = 0; i < levelSize.size(); i++) { 160 levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1)); 161 } 162 return levelOffset; 163 } 164 165 /** 166 * calculate data size list by data size and digest size 167 * 168 * @param dataSize original data size 169 * @param digestSize algorithm data size 170 * @return data size list, contains the offset of 171 * each level from the root node to the leaf node 172 */ getLevelSize(long dataSize, int digestSize)173 private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) { 174 ArrayList<Long> levelSize = new ArrayList<>(); 175 long fullChunkSize = 0L; 176 long originalDataSize = dataSize; 177 do { 178 fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize); 179 long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE); 180 levelSize.add(size); 181 originalDataSize = fullChunkSize; 182 } while (fullChunkSize > CHUNK_SIZE); 183 return levelSize; 184 } 185 runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex)186 private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex) { 187 Runnable task = () -> { 188 int offset = 0; 189 int bufferSize = buffer.capacity(); 190 int index = readChunkIndex; 191 while (offset < bufferSize) { 192 ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE); 193 byte[] tempByte = new byte[CHUNK_SIZE]; 194 chunk.get(tempByte); 195 try { 196 hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm); 197 } catch (NoSuchAlgorithmException e) { 198 throw new IllegalStateException(e); 199 } 200 offset += CHUNK_SIZE; 201 } 202 tasks.arriveAndDeregister(); 203 }; 204 tasks.register(); 205 this.mPools.execute(task); 206 } 207 208 /** 209 * hash data of buffer 210 * 211 * @param inputBuffer original data 212 * @param outputBuffer hash data 213 */ transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer)214 private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) { 215 long size = inputBuffer.capacity(); 216 int chunks = (int) getChunkCount(size, CHUNK_SIZE); 217 byte[][] hashes = new byte[chunks][]; 218 Phaser tasks = new Phaser(1); 219 long readOffset = 0L; 220 int startChunkIndex = 0; 221 while (readOffset < size) { 222 long readLimit = Math.min(readOffset + MAX_READ_SIZE, size); 223 ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit); 224 buffer.rewind(); 225 int readChunkIndex = startChunkIndex; 226 runHashTask(hashes, tasks, buffer, readChunkIndex); 227 int readSize = (int) (readLimit - readOffset); 228 startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE); 229 readOffset += readSize; 230 } 231 tasks.arriveAndAwaitAdvance(); 232 for (byte[] hash : hashes) { 233 outputBuffer.put(hash, 0, hash.length); 234 } 235 } 236 237 /** 238 * generate merkle tree of given input 239 * 240 * @param inputStream input stream for generate merkle tree 241 * @param size total size of input stream 242 * @param fsVerityHashAlgorithm hash algorithm for FsVerity 243 * @return merkle tree 244 * @throws IOException if error 245 * @throws NoSuchAlgorithmException if error 246 */ generateMerkleTree(InputStream inputStream, long size, FsVerityHashAlgorithm fsVerityHashAlgorithm)247 public MerkleTree generateMerkleTree(InputStream inputStream, long size, 248 FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException { 249 setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm()); 250 int digestSize = fsVerityHashAlgorithm.getOutputByteSize(); 251 int[] offsetArrays = getOffsetArrays(size, digestSize); 252 ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]); 253 generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize); 254 generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize); 255 return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm); 256 } 257 258 /** 259 * translation inputBuffer arrays to hash ByteBuffer 260 * 261 * @param inputStream input stream for generate merkle tree 262 * @param size total size of input stream 263 * @param outputBuffer hash data 264 * @param offsetArrays level offset 265 * @param digestSize algorithm output byte size 266 * @throws IOException if error 267 */ generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer, int[] offsetArrays, int digestSize)268 private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer, 269 int[] offsetArrays, int digestSize) throws IOException { 270 int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2]; 271 int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1]; 272 ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd); 273 transInputStreamToHashData(inputStream, size, hashBuffer); 274 dataRoundupChunkSize(hashBuffer, size, digestSize); 275 } 276 277 /** 278 * get buffer data by level offset, transforms digest data, save in another 279 * memory 280 * 281 * @param buffer hash data 282 * @param offsetArrays level offset 283 * @param digestSize algorithm output byte size 284 */ generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize)285 private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) { 286 for (int i = offsetArrays.length - 3; i >= 0; i--) { 287 ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]); 288 ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]); 289 transInputDataToHashData(originalHashBuffer, generateHashBuffer); 290 dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize); 291 } 292 } 293 294 /** 295 * generate merkle tree of given input 296 * 297 * @param dataBuffer tree data memory block 298 * @param inputDataSize total size of input stream 299 * @param fsVerityHashAlgorithm hash algorithm for FsVerity 300 * @return merkle tree 301 * @throws NoSuchAlgorithmException if error 302 */ getMerkleTree(ByteBuffer dataBuffer, long inputDataSize, FsVerityHashAlgorithm fsVerityHashAlgorithm)303 private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize, 304 FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException { 305 int digestSize = fsVerityHashAlgorithm.getOutputByteSize(); 306 dataBuffer.flip(); 307 byte[] rootHash = null; 308 byte[] tree = null; 309 if (inputDataSize <= FSVERITY_HASH_PAGE_SIZE) { 310 ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize); 311 rootHash = new byte[digestSize]; 312 fsVerityHashPageBuffer.get(rootHash); 313 } else { 314 tree = dataBuffer.array(); 315 ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE); 316 byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE]; 317 fsVerityHashPageBuffer.get(fsVerityHashPage); 318 rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm); 319 } 320 return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm); 321 } 322 323 /** 324 * generate merkle tree of given input 325 * 326 * @param data original data 327 * @param originalDataSize data size 328 * @param digestSize algorithm output byte size 329 */ dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize)330 private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) { 331 long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize); 332 int diffValue = (int) (fullChunkSize % CHUNK_SIZE); 333 if (diffValue > 0) { 334 byte[] padding = new byte[CHUNK_SIZE - diffValue]; 335 data.put(padding, 0, padding.length); 336 } 337 } 338 339 /** 340 * get mount of chunks to store data 341 * 342 * @param dataSize data size 343 * @param divisor split chunk size 344 * @return chunk count 345 */ getChunkCount(long dataSize, long divisor)346 private static long getChunkCount(long dataSize, long divisor) { 347 return (long) Math.ceil((double) dataSize / (double) divisor); 348 } 349 350 /** 351 * get total size of chunk to store data 352 * 353 * @param dataSize data size 354 * @param divisor split chunk size 355 * @param multiplier chunk multiplier 356 * @return chunk size 357 */ getFullChunkSize(long dataSize, long divisor, long multiplier)358 private static long getFullChunkSize(long dataSize, long divisor, long multiplier) { 359 return getChunkCount(dataSize, divisor) * multiplier; 360 } 361 } 362