1 /*
2 * Copyright (c) 2025-2025 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include <cmath>
16
17 #include "merkle_tree_builder.h"
18
19 using namespace OHOS::SignatureTools::Uscript;
20 namespace OHOS {
21 namespace SignatureTools {
22
23 const int MerkleTreeBuilder::FSVERITY_HASH_PAGE_SIZE = 4096;
24 const int64_t MerkleTreeBuilder::INPUTSTREAM_MAX_SIZE = 4503599627370496L;
25 const int MerkleTreeBuilder::CHUNK_SIZE = 4096;
26 const long MerkleTreeBuilder::MAX_READ_SIZE = 4194304L;
27 const int MerkleTreeBuilder::MAX_PROCESSORS = 32;
28 const int MerkleTreeBuilder::BLOCKINGQUEUE = 4;
29
SetAlgorithm(const std::string & algorithm)30 void MerkleTreeBuilder::SetAlgorithm(const std::string& algorithm)
31 {
32 mAlgorithm = algorithm;
33 }
34
TransInputStreamToHashData(std::istream & inputStream,long size,ByteBuffer * outputBuffer,int bufStartIdx)35 void MerkleTreeBuilder::TransInputStreamToHashData(std::istream& inputStream,
36 long size, ByteBuffer* outputBuffer, int bufStartIdx)
37 {
38 if (size == 0 || static_cast<int64_t>(size) > INPUTSTREAM_MAX_SIZE) {
39 SIGNATURE_TOOLS_LOGE("invalid input stream size");
40 CheckCalculateHashResult = false;
41 return;
42 }
43 std::vector<std::vector<int8_t>> hashes = GetDataHashes(inputStream, size);
44 int32_t writeSize = 0;
45 for (const auto& hash : hashes) {
46 outputBuffer->PutData(writeSize + bufStartIdx, hash.data(), hash.size());
47 writeSize += hash.size();
48 }
49 outputBuffer->SetLimit(outputBuffer->GetCapacity() - bufStartIdx);
50 outputBuffer->SetCapacity(outputBuffer->GetCapacity() - bufStartIdx);
51 outputBuffer->SetPosition(writeSize);
52 }
53
GetDataHashes(std::istream & inputStream,long size)54 std::vector<std::vector<int8_t>> MerkleTreeBuilder::GetDataHashes(std::istream& inputStream, long size)
55 {
56 int count = (int)(GetChunkCount(size, MAX_READ_SIZE));
57 int chunks = (int)(GetChunkCount(size, CHUNK_SIZE));
58 std::vector<std::vector<int8_t>> hashes(chunks);
59 std::vector<std::future<void>> thread_results;
60 long readOffset = 0L;
61 for (int i = 0; i < count; i++) {
62 long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
63 int readSize = (int)((readLimit - readOffset));
64 int fullChunkSize = (int)(GetFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE));
65 ByteBuffer* byteBuffer(new ByteBuffer(fullChunkSize));
66 std::vector<char> buffer(CHUNK_SIZE);
67 int num = 0;
68 int offset = 0;
69 int flag = 0;
70 int len = CHUNK_SIZE;
71 while (num > 0 || flag == 0) {
72 inputStream.read((buffer.data()), len);
73 if (inputStream.fail() && !inputStream.eof()) {
74 PrintErrorNumberMsg("IO_ERROR", IO_ERROR, "Error occurred while reading data");
75 CheckCalculateHashResult = false;
76 return std::vector<std::vector<int8_t>>();
77 }
78 num = inputStream.gcount();
79 byteBuffer->PutData(buffer.data(), num);
80 offset += num;
81 len = std::min(CHUNK_SIZE, readSize - offset);
82 if (len <= 0 || offset == readSize) {
83 break;
84 }
85 flag = 1;
86 }
87 if (offset != readSize) {
88 PrintErrorNumberMsg("READ_FILE_ERROR", IO_ERROR, "Error reading buffer from input");
89 CheckCalculateHashResult = false;
90 return std::vector<std::vector<int8_t>>();
91 }
92 byteBuffer->Flip();
93 int readChunkIndex = (int)(GetFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i));
94 thread_results.push_back(mPools->Enqueue(&MerkleTreeBuilder::RunHashTask, this, std::ref(hashes),
95 byteBuffer, readChunkIndex, 0));
96 readOffset += readSize;
97 }
98 for (auto& thread_result : thread_results) {
99 thread_result.wait();
100 }
101 return hashes;
102 }
103
Slice(ByteBuffer * buffer,int begin,int end)104 ByteBuffer* MerkleTreeBuilder::Slice(ByteBuffer* buffer, int begin, int end)
105 {
106 ByteBuffer* tmpBuffer = buffer->Duplicate();
107 tmpBuffer->SetPosition(0);
108 tmpBuffer->SetLimit(end);
109 tmpBuffer->SetPosition(begin);
110 return &tmpBuffer->slice_for_codesigning();
111 }
112
GetOffsetArrays(long dataSize,int digestSize)113 std::vector<int64_t> MerkleTreeBuilder::GetOffsetArrays(long dataSize, int digestSize)
114 {
115 std::vector<long> levelSize = GetLevelSize(dataSize, digestSize);
116 std::vector<int64_t> levelOffset(levelSize.size() + 1);
117 levelOffset[0] = 0;
118 for (long i = 0; i < levelSize.size(); i++) {
119 levelOffset[i + 1] = levelOffset[i] + levelSize[levelSize.size() - i - 1];
120 }
121 return levelOffset;
122 }
123
GetLevelSize(long dataSize,int digestSize)124 std::vector<long> MerkleTreeBuilder::GetLevelSize(long dataSize, int digestSize)
125 {
126 std::vector<long> levelSize;
127
128 long fullChunkSize = 0L;
129 long originalDataSize = dataSize;
130 do {
131 fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
132 int size = GetFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
133 levelSize.push_back(size);
134 originalDataSize = fullChunkSize;
135 } while (fullChunkSize > CHUNK_SIZE);
136 return levelSize;
137 }
138
RunHashTask(std::vector<std::vector<int8_t>> & hashes,ByteBuffer * buffer,int readChunkIndex,int bufStartIdx)139 void MerkleTreeBuilder::RunHashTask(std::vector<std::vector<int8_t>>& hashes,
140 ByteBuffer* buffer, int readChunkIndex, int bufStartIdx)
141 {
142 int offset = 0;
143
144 std::shared_ptr<ByteBuffer> bufferPtr(buffer);
145 int bufferSize = bufferPtr->GetCapacity();
146 int index = readChunkIndex;
147 while (offset < bufferSize) {
148 ByteBuffer* chunk = Slice(bufferPtr.get(), offset, offset + CHUNK_SIZE);
149 std::vector<int8_t> tmpByte(CHUNK_SIZE);
150 chunk->GetData(offset + bufStartIdx, tmpByte.data(), CHUNK_SIZE);
151 bool isCsSection = (csOffset != 0 && csOffset / 4096 == index);
152 if (isCsSection) {
153 SIGNATURE_TOOLS_LOGI("CsSection index = %d", index);
154 }
155 DigestUtils digestUtils(HASH_SHA256);
156 std::string tmpByteStr(tmpByte.begin(), tmpByte.end());
157 digestUtils.AddData(tmpByteStr);
158 std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
159 std::vector<int8_t> hashEle;
160 for (long i = 0; i < result.size(); i++) {
161 if (isCsSection) {
162 hashEle.push_back(0);
163 } else {
164 hashEle.push_back(result[i]);
165 }
166 }
167 hashes[index++] = hashEle;
168 offset += CHUNK_SIZE;
169 delete chunk;
170 }
171 }
172
TransInputDataToHashData(ByteBuffer * inputBuffer,ByteBuffer * outputBuffer,int64_t inputStartIdx,int64_t outputStartIdx)173 void MerkleTreeBuilder::TransInputDataToHashData(ByteBuffer* inputBuffer, ByteBuffer* outputBuffer,
174 int64_t inputStartIdx, int64_t outputStartIdx)
175 {
176 long size = inputBuffer->GetCapacity();
177 int chunks = (int)GetChunkCount(size, CHUNK_SIZE);
178 std::vector<std::vector<int8_t>> hashes(chunks);
179 long readOffset = 0L;
180 int startChunkIndex = 0;
181 while (readOffset < size) {
182 long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
183 ByteBuffer* buffer = Slice(inputBuffer, (int)readOffset, (int)readLimit);
184 buffer->SetPosition(0);
185 int readChunkIndex = startChunkIndex;
186 RunHashTask(hashes, buffer, readChunkIndex, inputStartIdx);
187 int readSize = (int)(readLimit - readOffset);
188 startChunkIndex += (int)GetChunkCount(readSize, CHUNK_SIZE);
189 readOffset += readSize;
190 inputStartIdx += readSize;
191 }
192 int32_t writeSize = 0;
193 for (const auto& hash : hashes) {
194 outputBuffer->PutData(writeSize + outputStartIdx, hash.data(), hash.size());
195 writeSize += hash.size();
196 }
197 }
198
MerkleTreeBuilder()199 OHOS::SignatureTools::MerkleTreeBuilder::MerkleTreeBuilder() :mPools(new Uscript::ThreadPool(POOL_SIZE))
200 {
201 CheckCalculateHashResult = true;
202 }
203
GenerateMerkleTree(std::istream & inputStream,long size,const FsVerityHashAlgorithm & fsVerityHashAlgorithm)204 MerkleTree* MerkleTreeBuilder::GenerateMerkleTree(std::istream& inputStream, long size,
205 const FsVerityHashAlgorithm& fsVerityHashAlgorithm)
206 {
207 SetAlgorithm(fsVerityHashAlgorithm.GetHashAlgorithm());
208 int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
209 std::vector<int64_t> offsetArrays = GetOffsetArrays(size, digestSize);
210 std::unique_ptr<ByteBuffer> allHashBuffer = std::make_unique<ByteBuffer>
211 (ByteBuffer(offsetArrays[offsetArrays.size() - 1]));
212 GenerateHashDataByInputData(inputStream, size, allHashBuffer.get(), offsetArrays, digestSize);
213 GenerateHashDataByHashData(allHashBuffer.get(), offsetArrays, digestSize);
214
215 if (CheckCalculateHashResult) {
216 return GetMerkleTree(allHashBuffer.get(), size, fsVerityHashAlgorithm);
217 }
218 return nullptr;
219 }
220
GenerateHashDataByInputData(std::istream & inputStream,long size,ByteBuffer * outputBuffer,std::vector<int64_t> & offsetArrays,int digestSize)221 void MerkleTreeBuilder::GenerateHashDataByInputData(std::istream& inputStream, long size, ByteBuffer* outputBuffer,
222 std::vector<int64_t>& offsetArrays, int digestSize)
223 {
224 int64_t inputDataOffsetBegin = offsetArrays[offsetArrays.size() - 2];
225 int64_t inputDataOffsetEnd = offsetArrays[offsetArrays.size() - 1];
226 ByteBuffer* hashBuffer = Slice(outputBuffer, 0, inputDataOffsetEnd);
227 TransInputStreamToHashData(inputStream, size, hashBuffer, inputDataOffsetBegin);
228 DataRoundupChunkSize(hashBuffer, size, digestSize);
229 delete hashBuffer;
230 }
231
GenerateHashDataByHashData(ByteBuffer * buffer,std::vector<int64_t> & offsetArrays,int digestSize)232 void MerkleTreeBuilder::GenerateHashDataByHashData(ByteBuffer* buffer,
233 std::vector<int64_t>& offsetArrays, int digestSize)
234 {
235 for (int i = offsetArrays.size() - 3; i >= 0; i--) {
236 int64_t generateOffset = offsetArrays[i];
237 int64_t originalOffset = offsetArrays[i + 1];
238 ByteBuffer* generateHashBuffer = Slice(buffer, offsetArrays[i], offsetArrays[i + 1] + offsetArrays[i]);
239 ByteBuffer* readOnlyBuffer = buffer->Duplicate();
240 ByteBuffer* originalHashBuffer = Slice(readOnlyBuffer, offsetArrays[i + 1], offsetArrays[i + 2]);
241 TransInputDataToHashData(originalHashBuffer, generateHashBuffer, originalOffset, generateOffset);
242 DataRoundupChunkSize(generateHashBuffer, originalHashBuffer->GetCapacity(), digestSize);
243 delete originalHashBuffer;
244 delete readOnlyBuffer;
245 delete generateHashBuffer;
246 }
247 }
248
GetMerkleTree(ByteBuffer * dataBuffer,long inputDataSize,FsVerityHashAlgorithm fsVerityHashAlgorithm)249 MerkleTree* MerkleTreeBuilder::GetMerkleTree(ByteBuffer* dataBuffer, long inputDataSize,
250 FsVerityHashAlgorithm fsVerityHashAlgorithm)
251 {
252 int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
253 dataBuffer->Flip();
254 std::vector<int8_t> rootHash;
255 std::vector<int8_t> tree;
256 if (inputDataSize <= FSVERITY_HASH_PAGE_SIZE) {
257 ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, digestSize);
258 if (fsVerityHashPageBuffer != nullptr) {
259 rootHash = std::vector<int8_t>(digestSize);
260 fsVerityHashPageBuffer->GetByte(rootHash.data(), digestSize);
261 delete fsVerityHashPageBuffer;
262 fsVerityHashPageBuffer = nullptr;
263 }
264 } else {
265 tree = std::vector<int8_t>(dataBuffer->GetBufferPtr(), dataBuffer->GetBufferPtr() + dataBuffer->GetCapacity());
266 ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, FSVERITY_HASH_PAGE_SIZE);
267 if (fsVerityHashPageBuffer != nullptr) {
268 std::vector<int8_t> fsVerityHashPage(FSVERITY_HASH_PAGE_SIZE);
269 fsVerityHashPageBuffer->GetData(0, fsVerityHashPage.data(), FSVERITY_HASH_PAGE_SIZE);
270 DigestUtils digestUtils(HASH_SHA256);
271 std::string fsVerityHashPageStr(fsVerityHashPage.begin(), fsVerityHashPage.end());
272 digestUtils.AddData(fsVerityHashPageStr);
273 std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
274 for (int i = 0; i < static_cast<int>(result.size()); i++) {
275 rootHash.push_back(result[i]);
276 }
277 delete fsVerityHashPageBuffer;
278 fsVerityHashPageBuffer = nullptr;
279 }
280 }
281
282 return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
283 }
284
DataRoundupChunkSize(ByteBuffer * data,long originalDataSize,int digestSize)285 void MerkleTreeBuilder::DataRoundupChunkSize(ByteBuffer* data, long originalDataSize, int digestSize)
286 {
287 long fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
288 int diffValue = (int)(fullChunkSize % CHUNK_SIZE);
289 if (diffValue > 0) {
290 data->SetPosition(data->GetPosition() + (CHUNK_SIZE - diffValue));
291 }
292 }
293
GetChunkCount(long dataSize,long divisor)294 long MerkleTreeBuilder::GetChunkCount(long dataSize, long divisor)
295 {
296 return (long)std::ceil((double)dataSize / (double)divisor);
297 }
298
GetFullChunkSize(long dataSize,long divisor,long multiplier)299 long MerkleTreeBuilder::GetFullChunkSize(long dataSize, long divisor, long multiplier)
300 {
301 return GetChunkCount(dataSize, divisor) * multiplier;
302 }
303 } // namespace SignatureTools
304 } // namespace OHOS