1 /* 2 * Copyright (c) 2023-2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 package com.ohos.hapsigntool.codesigning.fsverity; 17 18 import com.ohos.hapsigntool.codesigning.utils.DigestUtils; 19 20 import java.io.IOException; 21 import java.io.InputStream; 22 import java.nio.ByteBuffer; 23 import java.security.NoSuchAlgorithmException; 24 import java.util.ArrayList; 25 import java.util.concurrent.ArrayBlockingQueue; 26 import java.util.concurrent.ExecutorService; 27 import java.util.concurrent.Phaser; 28 import java.util.concurrent.ThreadPoolExecutor; 29 import java.util.concurrent.TimeUnit; 30 31 /** 32 * Merkle tree builder 33 * 34 * @since 2023/06/05 35 */ 36 public class MerkleTreeBuilder implements AutoCloseable { 37 private static final int FSVERITY_HASH_PAGE_SIZE = 4096; 38 39 private static final long INPUTSTREAM_MAX_SIZE = 4503599627370496L; 40 41 private static final int CHUNK_SIZE = 4096; 42 43 private static final long MAX_READ_SIZE = 4194304L; 44 45 private static final int MAX_PROCESSORS = 32; 46 47 private static final int BLOCKINGQUEUE = 4; 48 49 private static final int POOL_SIZE = Math.min(MAX_PROCESSORS, Runtime.getRuntime().availableProcessors()); 50 51 private String mAlgorithm = "SHA-256"; 52 53 private final ExecutorService mPools = new ThreadPoolExecutor(POOL_SIZE, POOL_SIZE, 0L, TimeUnit.MILLISECONDS, 54 new ArrayBlockingQueue<>(BLOCKINGQUEUE), new ThreadPoolExecutor.CallerRunsPolicy()); 55 56 /** 57 * Turn off multitasking 58 */ close()59 public void close() { 60 this.mPools.shutdownNow(); 61 } 62 63 /** 64 * set algorithm 65 * 66 * @param algorithm hash algorithm 67 */ setAlgorithm(String algorithm)68 private void setAlgorithm(String algorithm) { 69 this.mAlgorithm = algorithm; 70 } 71 72 /** 73 * translation inputStream to hash data 74 * 75 * @param inputStream input stream for generating merkle tree 76 * @param size total size of input stream 77 * @param outputBuffer hash data 78 * @throws IOException if error 79 */ transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer)80 private void transInputStreamToHashData(InputStream inputStream, long size, ByteBuffer outputBuffer) 81 throws IOException { 82 if (size == 0) { 83 throw new IOException("Input size is empty"); 84 } else if (size > INPUTSTREAM_MAX_SIZE) { 85 throw new IOException("Input size is too long"); 86 } 87 int count = (int) getChunkCount(size, MAX_READ_SIZE); 88 int chunks = (int) getChunkCount(size, CHUNK_SIZE); 89 byte[][] hashes = new byte[chunks][]; 90 long readOffset = 0L; 91 Phaser tasks = new Phaser(1); 92 for (int i = 0; i < count; i++) { 93 long readLimit = Math.min(readOffset + MAX_READ_SIZE, size); 94 int readSize = (int) (readLimit - readOffset); 95 int fullChunkSize = (int) getFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE); 96 97 ByteBuffer byteBuffer = ByteBuffer.allocate(fullChunkSize); 98 byte[] buffer = new byte[CHUNK_SIZE]; 99 int num; 100 int offset = 0; 101 int len = CHUNK_SIZE; 102 while ((num = inputStream.read(buffer, 0, len)) > 0) { 103 byteBuffer.put(buffer, 0, num); 104 offset += num; 105 len = Math.min(CHUNK_SIZE, readSize - offset); 106 if (len <= 0 || offset == readSize) { 107 break; 108 } 109 } 110 if (offset != readSize) { 111 throw new IOException("IOException read buffer from input errorLHJ."); 112 } 113 byteBuffer.flip(); 114 int readChunkIndex = (int) getFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i); 115 runHashTask(hashes, tasks, byteBuffer, readChunkIndex); 116 readOffset += readSize; 117 } 118 tasks.arriveAndAwaitAdvance(); 119 for (byte[] hash : hashes) { 120 outputBuffer.put(hash, 0, hash.length); 121 } 122 } 123 124 /** 125 * split buffer by begin and end information 126 * 127 * @param buffer original buffer 128 * @param begin begin position 129 * @param end end position 130 * @return slice buffer 131 */ slice(ByteBuffer buffer, int begin, int end)132 private static ByteBuffer slice(ByteBuffer buffer, int begin, int end) { 133 ByteBuffer tempBuffer = buffer.duplicate(); 134 tempBuffer.position(0); 135 tempBuffer.limit(end); 136 tempBuffer.position(begin); 137 return tempBuffer.slice(); 138 } 139 140 /** 141 * calculate merkle tree level and size by data size and digest size 142 * 143 * @param dataSize original data size 144 * @param digestSize algorithm data size 145 * @return level offset list, contains the offset of 146 * each level from the root node to the leaf node 147 */ getOffsetArrays(long dataSize, int digestSize)148 private static int[] getOffsetArrays(long dataSize, int digestSize) { 149 ArrayList<Long> levelSize = getLevelSize(dataSize, digestSize); 150 int[] levelOffset = new int[levelSize.size() + 1]; 151 levelOffset[0] = 0; 152 for (int i = 0; i < levelSize.size(); i++) { 153 levelOffset[i + 1] = levelOffset[i] + Math.toIntExact(levelSize.get(levelSize.size() - i - 1)); 154 } 155 return levelOffset; 156 } 157 158 /** 159 * calculate data size list by data size and digest size 160 * 161 * @param dataSize original data size 162 * @param digestSize algorithm data size 163 * @return data size list, contains the offset of 164 * each level from the root node to the leaf node 165 */ getLevelSize(long dataSize, int digestSize)166 private static ArrayList<Long> getLevelSize(long dataSize, int digestSize) { 167 ArrayList<Long> levelSize = new ArrayList<>(); 168 long fullChunkSize = 0L; 169 long originalDataSize = dataSize; 170 do { 171 fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize); 172 long size = getFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE); 173 levelSize.add(size); 174 originalDataSize = fullChunkSize; 175 } while (fullChunkSize > CHUNK_SIZE); 176 return levelSize; 177 } 178 runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex)179 private void runHashTask(byte[][] hashes, Phaser tasks, ByteBuffer buffer, int readChunkIndex) { 180 Runnable task = () -> { 181 int offset = 0; 182 int bufferSize = buffer.capacity(); 183 int index = readChunkIndex; 184 while (offset < bufferSize) { 185 ByteBuffer chunk = slice(buffer, offset, offset + CHUNK_SIZE); 186 byte[] tempByte = new byte[CHUNK_SIZE]; 187 chunk.get(tempByte); 188 try { 189 hashes[index++] = DigestUtils.computeDigest(tempByte, this.mAlgorithm); 190 } catch (NoSuchAlgorithmException e) { 191 throw new IllegalStateException(e); 192 } 193 offset += CHUNK_SIZE; 194 } 195 tasks.arriveAndDeregister(); 196 }; 197 tasks.register(); 198 this.mPools.execute(task); 199 } 200 201 /** 202 * hash data of buffer 203 * 204 * @param inputBuffer original data 205 * @param outputBuffer hash data 206 */ transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer)207 private void transInputDataToHashData(ByteBuffer inputBuffer, ByteBuffer outputBuffer) { 208 long size = inputBuffer.capacity(); 209 int chunks = (int) getChunkCount(size, CHUNK_SIZE); 210 byte[][] hashes = new byte[chunks][]; 211 Phaser tasks = new Phaser(1); 212 long readOffset = 0L; 213 int startChunkIndex = 0; 214 while (readOffset < size) { 215 long readLimit = Math.min(readOffset + MAX_READ_SIZE, size); 216 ByteBuffer buffer = slice(inputBuffer, (int) readOffset, (int) readLimit); 217 buffer.rewind(); 218 int readChunkIndex = startChunkIndex; 219 runHashTask(hashes, tasks, buffer, readChunkIndex); 220 int readSize = (int) (readLimit - readOffset); 221 startChunkIndex += (int) getChunkCount(readSize, CHUNK_SIZE); 222 readOffset += readSize; 223 } 224 tasks.arriveAndAwaitAdvance(); 225 for (byte[] hash : hashes) { 226 outputBuffer.put(hash, 0, hash.length); 227 } 228 } 229 230 /** 231 * generate merkle tree of given input 232 * 233 * @param inputStream input stream for generate merkle tree 234 * @param size total size of input stream 235 * @param fsVerityHashAlgorithm hash algorithm for FsVerity 236 * @return merkle tree 237 * @throws IOException if error 238 * @throws NoSuchAlgorithmException if error 239 */ generateMerkleTree(InputStream inputStream, long size, FsVerityHashAlgorithm fsVerityHashAlgorithm)240 public MerkleTree generateMerkleTree(InputStream inputStream, long size, 241 FsVerityHashAlgorithm fsVerityHashAlgorithm) throws IOException, NoSuchAlgorithmException { 242 setAlgorithm(fsVerityHashAlgorithm.getHashAlgorithm()); 243 int digestSize = fsVerityHashAlgorithm.getOutputByteSize(); 244 int[] offsetArrays = getOffsetArrays(size, digestSize); 245 ByteBuffer allHashBuffer = ByteBuffer.allocate(offsetArrays[offsetArrays.length - 1]); 246 generateHashDataByInputData(inputStream, size, allHashBuffer, offsetArrays, digestSize); 247 generateHashDataByHashData(allHashBuffer, offsetArrays, digestSize); 248 return getMerkleTree(allHashBuffer, size, fsVerityHashAlgorithm); 249 } 250 251 /** 252 * translation inputBuffer arrays to hash ByteBuffer 253 * 254 * @param inputStream input stream for generate merkle tree 255 * @param size total size of input stream 256 * @param outputBuffer hash data 257 * @param offsetArrays level offset 258 * @param digestSize algorithm output byte size 259 * @throws IOException if error 260 */ generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer, int[] offsetArrays, int digestSize)261 private void generateHashDataByInputData(InputStream inputStream, long size, ByteBuffer outputBuffer, 262 int[] offsetArrays, int digestSize) throws IOException { 263 int inputDataOffsetBegin = offsetArrays[offsetArrays.length - 2]; 264 int inputDataOffsetEnd = offsetArrays[offsetArrays.length - 1]; 265 ByteBuffer hashBuffer = slice(outputBuffer, inputDataOffsetBegin, inputDataOffsetEnd); 266 transInputStreamToHashData(inputStream, size, hashBuffer); 267 dataRoundupChunkSize(hashBuffer, size, digestSize); 268 } 269 270 /** 271 * get buffer data by level offset, transforms digest data, save in another 272 * memory 273 * 274 * @param buffer hash data 275 * @param offsetArrays level offset 276 * @param digestSize algorithm output byte size 277 */ generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize)278 private void generateHashDataByHashData(ByteBuffer buffer, int[] offsetArrays, int digestSize) { 279 for (int i = offsetArrays.length - 3; i >= 0; i--) { 280 ByteBuffer generateHashBuffer = slice(buffer, offsetArrays[i], offsetArrays[i + 1]); 281 ByteBuffer originalHashBuffer = slice(buffer.asReadOnlyBuffer(), offsetArrays[i + 1], offsetArrays[i + 2]); 282 transInputDataToHashData(originalHashBuffer, generateHashBuffer); 283 dataRoundupChunkSize(generateHashBuffer, originalHashBuffer.capacity(), digestSize); 284 } 285 } 286 287 /** 288 * generate merkle tree of given input 289 * 290 * @param dataBuffer tree data memory block 291 * @param inputDataSize total size of input stream 292 * @param fsVerityHashAlgorithm hash algorithm for FsVerity 293 * @return merkle tree 294 * @throws NoSuchAlgorithmException if error 295 */ getMerkleTree(ByteBuffer dataBuffer, long inputDataSize, FsVerityHashAlgorithm fsVerityHashAlgorithm)296 private MerkleTree getMerkleTree(ByteBuffer dataBuffer, long inputDataSize, 297 FsVerityHashAlgorithm fsVerityHashAlgorithm) throws NoSuchAlgorithmException { 298 int digestSize = fsVerityHashAlgorithm.getOutputByteSize(); 299 dataBuffer.flip(); 300 byte[] rootHash = null; 301 byte[] tree = null; 302 if (inputDataSize < FSVERITY_HASH_PAGE_SIZE) { 303 ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer, 0, digestSize); 304 rootHash = new byte[digestSize]; 305 fsVerityHashPageBuffer.get(rootHash); 306 } else { 307 tree = dataBuffer.array(); 308 ByteBuffer fsVerityHashPageBuffer = slice(dataBuffer.asReadOnlyBuffer(), 0, FSVERITY_HASH_PAGE_SIZE); 309 byte[] fsVerityHashPage = new byte[FSVERITY_HASH_PAGE_SIZE]; 310 fsVerityHashPageBuffer.get(fsVerityHashPage); 311 rootHash = DigestUtils.computeDigest(fsVerityHashPage, this.mAlgorithm); 312 } 313 return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm); 314 } 315 316 /** 317 * generate merkle tree of given input 318 * 319 * @param data original data 320 * @param originalDataSize data size 321 * @param digestSize algorithm output byte size 322 */ dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize)323 private void dataRoundupChunkSize(ByteBuffer data, long originalDataSize, int digestSize) { 324 long fullChunkSize = getFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize); 325 int diffValue = (int) (fullChunkSize % CHUNK_SIZE); 326 if (diffValue > 0) { 327 byte[] padding = new byte[CHUNK_SIZE - diffValue]; 328 data.put(padding, 0, padding.length); 329 } 330 } 331 332 /** 333 * get mount of chunks to store data 334 * 335 * @param dataSize data size 336 * @param divisor split chunk size 337 * @return chunk count 338 */ getChunkCount(long dataSize, long divisor)339 private static long getChunkCount(long dataSize, long divisor) { 340 return (long) Math.ceil((double) dataSize / (double) divisor); 341 } 342 343 /** 344 * get total size of chunk to store data 345 * 346 * @param dataSize data size 347 * @param divisor split chunk size 348 * @param multiplier chunk multiplier 349 * @return chunk size 350 */ getFullChunkSize(long dataSize, long divisor, long multiplier)351 private static long getFullChunkSize(long dataSize, long divisor, long multiplier) { 352 return getChunkCount(dataSize, divisor) * multiplier; 353 } 354 } 355