#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) 2023 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The tool for making module package.

positional arguments:
  target_package        Target package dir path. Must contain config.json;
                        module.img and pub_key.pem are added when present.
  update_package        Output dir path. Created if it does not exist.

required options:
  -pn PACKAGE_NAME, --package_name PACKAGE_NAME
                        Module package name.
  -pk PRIVATE_KEY, --private_key PRIVATE_KEY
                        Private key file path.
  -sc SIGN_CERT, --sign_cert SIGN_CERT
                        Sign cert file path.
"""
import os
import sys
import argparse
import subprocess
import hashlib
import zipfile
import io
import struct
import logging

from asn1crypto import cms
from asn1crypto import pem
from asn1crypto import util
from asn1crypto import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes


# 1000000: max number of function recursion depth
# NOTE(review): nothing in this file recurses deeply; confirm whether this
# (process-wide) limit change is still needed.
MAXIMUM_RECURSION_DEPTH = 1000000
sys.setrecursionlimit(MAXIMUM_RECURSION_DEPTH)

# Chunk size used when streaming package files.
BLCOK_SIZE = 8192
# Size of the footer appended after the signature (3 x uint16).
FOOTER_LENGTH = 6
# Size of a zip end-of-central-directory record with an empty comment.
ZIP_ECOD_LENGTH = 22
# Digest-type tag packed into the content header; presumably an identifier
# the on-device verifier expects for SHA-256 — confirm before changing.
DIGEST_SHA256 = 672
SHA256_HASH_LEN = 32

# Content header: two little-endian uint16 (digest type, digest length)
# followed by the 32-byte digest.
CONTENT_INFO_FORMAT = "<2H32s"
# the length of zip eocd comment
ZIP_EOCD_COMMENT_LEN_FORMAT = "<H"
# signed package footer
SIGANTURE_FOOTER_FORMAT = "<3H"

SIGN_TOOL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'zipalign.jar')


def _argument_error(msg):
    """
    Log msg as an error and return it wrapped for argparse to report.
    :param msg: error message.
    :return: argparse.ArgumentTypeError carrying msg (caller raises it).
    """
    UPDATE_LOGGER.print_log(msg, UPDATE_LOGGER.ERROR_LOG)
    return argparse.ArgumentTypeError(msg)


def target_package_check(arg):
    """
    Argument check, which is used to check whether the specified arg is an
    existing directory.
    :param arg: the arg to check.
    :return: the arg itself when it is valid.
    :raises argparse.ArgumentTypeError: if the arg is not a directory.
    """
    if not os.path.isdir(arg):
        raise _argument_error("Target package error, path: %s" % arg)
    return arg


def package_name_check(arg):
    """
    Argument check, which is used to check whether the specified arg is
    empty. argparse only calls type checks on strings actually supplied, so
    an absent -pn is rejected by required=True in main, not here.
    :param arg: the arg to check.
    :return: the arg itself when it is valid.
    :raises argparse.ArgumentTypeError: if the arg is None or empty.
    """
    if not arg:
        raise _argument_error("Package name error: %r" % arg)
    return arg


def private_key_check(arg):
    """
    Argument check, which is used to check whether
    the specified arg is a private key file (or the "ON_SERVER" marker).
    :param arg: the arg to check.
    :return: the arg itself when it is valid.
    :raises argparse.ArgumentTypeError: if the file does not exist.
    """
    if arg != "ON_SERVER" and not os.path.isfile(arg):
        raise _argument_error("FileNotFoundError, path: %s" % arg)
    return arg


def sign_cert_check(arg):
    """
    Argument check, which is used to check whether
    the specified arg is a sign cert file (or the "ON_SERVER" marker).
    :param arg: the arg to check.
    :return: the arg itself when it is valid.
    :raises argparse.ArgumentTypeError: if the file does not exist.
    """
    if arg != "ON_SERVER" and not os.path.isfile(arg):
        raise _argument_error("FileNotFoundError, path: %s" % arg)
    return arg


def check_update_package(arg):
    """
    Argument check, which is used to check whether the update package path
    is a directory, creating it when it does not exist yet.
    :param arg: The arg to check.
    :return: the arg itself when it is (or has been made) a directory.
    :raises argparse.ArgumentTypeError: if the path exists but is not a
        directory, or the directory cannot be created.
    """
    if os.path.exists(arg):
        if not os.path.isdir(arg):
            raise _argument_error(
                "Update package must be a dir path, not a file path. "
                "path: %s" % arg)
        return arg

    UPDATE_LOGGER.print_log(
        "Update package path does not exist. The dir will be created! "
        "path: %s" % arg, UPDATE_LOGGER.WARNING_LOG)
    try:
        os.makedirs(arg)
    except OSError as err:
        raise _argument_error(
            "Make update package path dir failed! "
            "path: %s" % arg) from err
    return arg


class UpdateToolLogger:
    """
    Global log class
    """
    INFO_LOG = 'INFO_LOG'
    WARNING_LOG = 'WARNING_LOG'
    ERROR_LOG = 'ERROR_LOG'
    LOG_TYPE = (INFO_LOG, WARNING_LOG, ERROR_LOG)

    def __init__(self, output_type='console'):
        self.__logger_obj = self.__get_logger_obj(output_type=output_type)

    @staticmethod
    def __get_logger_obj(output_type='console'):
        # NOTE(review): every instance adds a handler to the same module
        # logger, so creating several instances duplicates log lines.
        ota_logger = logging.getLogger(__name__)
        ota_logger.setLevel(level=logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s : %(message)s',
            "%Y-%m-%d %H:%M:%S")
        if output_type == 'console':
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            ota_logger.addHandler(console_handler)
        elif output_type == 'file':
            file_handler = logging.FileHandler("UpdateToolLog.txt")
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(formatter)
            ota_logger.addHandler(file_handler)
        return ota_logger

    def print_log(self, msg, log_type=INFO_LOG):
        """
        Print log information.
        :param msg: log information
        :param log_type: log type, one of LOG_TYPE
        :return: True if the log type is known, otherwise False
        """
        if log_type == self.LOG_TYPE[0]:
            self.__logger_obj.info(msg)
        elif log_type == self.LOG_TYPE[1]:
            self.__logger_obj.warning(msg)
        elif log_type == self.LOG_TYPE[2]:
            self.__logger_obj.error(msg)
        else:
            self.__logger_obj.error("Unknown log type! %s", log_type)
            return False
        return True

    def print_uncaught_exception_msg(self, msg, exc_info):
        """
        Print log when an uncaught exception occurs.
        :param msg: Uncaught exception
        :param exc_info: information about the uncaught exception
        """
        self.__logger_obj.error(msg, exc_info=exc_info)


UPDATE_LOGGER = UpdateToolLogger()


def load_public_cert(sign_cert):
    """
    Load an X.509 certificate from a PEM or DER file.
    :param sign_cert: certificate file path.
    :return: asn1crypto x509.Certificate.
    """
    with open(sign_cert, 'rb') as cert_file:
        der_bytes = cert_file.read()
    if pem.detect(der_bytes):
        _, _, der_bytes = pem.unarmor(der_bytes)

    return x509.Certificate.load(der_bytes)


def calculate_package_hash(package_path):
    """
    Compute the SHA-256 digest of a zip package, excluding its last
    ZIP_ECOD_LENGTH bytes (the end-of-central-directory record, assuming the
    package has no zip comment).
    :param package_path: package file path.
    :return: (hash) for path using hashlib.sha256(), as raw bytes.
    """
    hash_sha256 = hashlib.sha256()
    remain_len = os.path.getsize(package_path) - ZIP_ECOD_LENGTH
    with open(package_path, 'rb') as package_file:
        while remain_len > 0:
            chunk = package_file.read(min(remain_len, BLCOK_SIZE))
            if not chunk:
                break
            hash_sha256.update(chunk)
            remain_len -= len(chunk)

    return hash_sha256.digest()


def _load_private_key(private_key_file):
    """
    Read and parse an unencrypted PEM private key.
    :param private_key_file: PEM private key file path.
    :return: cryptography private key object.
    """
    with open(private_key_file, 'rb') as f_r:
        key_data = f_r.read()
    return serialization.load_pem_private_key(
        key_data,
        password=None,
        backend=default_backend())


def sign_digest_with_pss(digset, private_key_file):
    """
    Sign data with RSA-PSS (MGF1-SHA256, maximum salt length).
    cryptography's sign() hashes its input with SHA-256 first, so the
    signature covers SHA-256(digset).
    :param digset: bytes to sign.
    :param private_key_file: PEM private key file path.
    :return: signature bytes, or False on failure.
    """
    try:
        private_key = _load_private_key(private_key_file)
        pad = padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH)
        return private_key.sign(digset, pad, hashes.SHA256())
    except (OSError, ValueError, TypeError):
        # TypeError: missing path, encrypted key without password, or a
        # key type whose sign() does not take a padding argument.
        return False


def sign_digest(digset, private_key_file):
    """
    Sign data with RSA PKCS#1 v1.5. cryptography's sign() hashes its input
    with SHA-256 first, so the signature covers SHA-256(digset).
    :param digset: bytes to sign.
    :param private_key_file: PEM private key file path.
    :return: signature bytes, or False on failure.
    """
    try:
        private_key = _load_private_key(private_key_file)
        return private_key.sign(digset, padding.PKCS1v15(), hashes.SHA256())
    except (OSError, ValueError, TypeError):
        return False


def create_encap_content_info(diget):
    """
    Pack the digest into the content header (CONTENT_INFO_FORMAT).
    :param diget: 32-byte SHA-256 digest.
    :return: packed header bytes, or False if diget is empty.
    """
    if not diget:
        UPDATE_LOGGER.print_log("calc package hash failed!",
                                log_type=UPDATE_LOGGER.ERROR_LOG)
        return False
    return struct.pack(CONTENT_INFO_FORMAT, DIGEST_SHA256,
                       SHA256_HASH_LEN, diget)


def _write_binary_file(path, data):
    """
    Create or overwrite path with data (mode 0o755 when created).
    :param path: output file path.
    :param data: bytes to write.
    """
    file_fd = os.open(path, os.O_RDWR | os.O_CREAT | os.O_TRUNC, 0o755)
    with os.fdopen(file_fd, 'wb') as file_obj:
        file_obj.write(data)


def write_signed_package(unsigned_package, signature, signed_package):
    """
    Write signature to signed package.

    The unsigned package is copied without its final 2 bytes (the EOCD
    comment-length field, assuming an empty comment). A new comment length
    is written, then the signature and the footer are appended as the zip
    comment.
    :param unsigned_package: aligned, unsigned zip package path.
    :param signature: DER-encoded signature blob.
    :param signed_package: output path; overwritten if it exists.
    """
    signature_total_size = len(signature) + FOOTER_LENGTH
    remain_len = os.path.getsize(unsigned_package) - 2

    package_fd = os.open(signed_package,
                         os.O_RDWR | os.O_CREAT | os.O_TRUNC, 0o755)
    with os.fdopen(package_fd, 'wb') as f_signed, \
            open(unsigned_package, 'rb') as f_unsign:
        while remain_len > 0:
            chunk = f_unsign.read(min(remain_len, BLCOK_SIZE))
            if not chunk:
                break
            f_signed.write(chunk)
            remain_len -= len(chunk)

        # uint16: struct.error if signature_total_size exceeds 0xFFFF.
        f_signed.write(struct.pack(ZIP_EOCD_COMMENT_LEN_FORMAT,
                                   signature_total_size))
        f_signed.write(signature)
        # Footer: total size, 0xffff marker, total size — presumably parsed
        # by the on-device verifier; confirm before changing.
        f_signed.write(struct.pack(SIGANTURE_FOOTER_FORMAT,
                                   signature_total_size, 0xffff,
                                   signature_total_size))


def sign_ota_package(package_path, signed_package, private_key, sign_cert):
    """
    Sign an aligned zip package and write the signed copy.

    The package digest is signed with sign_digest. The result is wrapped,
    together with the certificate, in a CMS SignedData structure that
    write_signed_package appends to the package. As a side effect, the raw
    digest and signature are written to the files "digest" and "signature"
    in the current working directory.
    :param package_path: aligned, unsigned zip package path.
    :param signed_package: output path of the signed package.
    :param private_key: PEM private key file path.
    :param sign_cert: signing certificate file path (PEM or DER).
    :return: True on success, False if hashing, signing or certificate
        loading failed.
    """
    digest = calculate_package_hash(package_path)
    data = create_encap_content_info(digest)
    if not data:
        return False
    # NOTE(review): the signature covers SHA-256(digest) (sign() hashes its
    # input), not the encapsulated content bytes; presumably this matches
    # the on-device verifier — confirm before changing.
    signature = sign_digest(digest, private_key)

    _write_binary_file("digest", digest)
    if not signature:
        UPDATE_LOGGER.print_log(
            "Sign package digest failed, private key: %s" % private_key,
            UPDATE_LOGGER.ERROR_LOG)
        return False
    _write_binary_file("signature", signature)

    try:
        cert = load_public_cert(sign_cert)
    except (OSError, ValueError) as err:
        UPDATE_LOGGER.print_log(
            "Load sign cert failed, path: %s, error: %s" % (sign_cert, err),
            UPDATE_LOGGER.ERROR_LOG)
        return False

    # Creating a SignedData object from cms
    signed_data = cms.SignedData()
    signed_data['version'] = 'v1'
    signed_data['encap_content_info'] = util.OrderedDict([
        ('content_type', 'data'),
        ('content', data)])

    signed_data['digest_algorithms'] = [util.OrderedDict([
        ('algorithm', 'sha256'),
        ('parameters', None)])]

    # Adding this certificate to SignedData object
    signed_data['certificates'] = [cert]

    # Setting signer info section
    signer_info = cms.SignerInfo()
    signer_info['version'] = 'v1'
    signer_info['digest_algorithm'] = util.OrderedDict([
        ('algorithm', 'sha256'),
        ('parameters', None)])
    signer_info['signature_algorithm'] = util.OrderedDict([
        ('algorithm', 'sha256_rsa'),
        ('parameters', None)])

    issuer_and_serial = cms.IssuerAndSerialNumber()
    issuer_and_serial['issuer'] = cert.issuer
    issuer_and_serial['serial_number'] = cert.serial_number

    signer_info['sid'] = cms.SignerIdentifier({
        'issuer_and_serial_number': issuer_and_serial})

    signer_info['signature'] = signature
    # Adding SignerInfo object to SignedData object
    signed_data['signer_infos'] = [signer_info]

    # Writing everything into ASN.1 object
    asn1obj = cms.ContentInfo()
    asn1obj['content_type'] = 'signed_data'
    asn1obj['content'] = signed_data

    # dump() serializes the ContentInfo to DER.
    write_signed_package(package_path, asn1obj.dump(), signed_package)
    return True


def build_module_package(package_name, target_package, update_package, private_key, sign_cert):
    """
    Zip, align and sign the module package.

    Produces <update_package>/<package_name>.zip. The intermediate
    *_unsigned.zip and *_align.zip files are removed only on success.
    :param package_name: module package name (output file stem).
    :param target_package: dir containing config.json and, optionally,
        module.img and pub_key.pem.
    :param update_package: output dir.
    :param private_key: PEM private key file path.
    :param sign_cert: signing certificate file path.
    :return: True on success, False on failure.
    """
    unsigned_package = os.path.join(
        update_package, '%s_unsigned.zip' % package_name)

    config_file_path = os.path.join(target_package, 'config.json')
    if not os.path.isfile(config_file_path):
        UPDATE_LOGGER.print_log(
            "config.json not found, path: %s" % config_file_path,
            UPDATE_LOGGER.ERROR_LOG)
        return False

    # Entry order is module.img, config.json, pub_key.pem; module.img and
    # pub_key.pem are optional, config.json was checked above.
    with zipfile.ZipFile(unsigned_package, 'w') as zip_file:
        for file_name in ('module.img', 'config.json', 'pub_key.pem'):
            file_path = os.path.join(target_package, file_name)
            if os.path.exists(file_path):
                zip_file.write(file_path, file_name)

    # align package
    align_package = os.path.join(
        update_package, '%s_align.zip' % package_name)
    # A stale aligned package from an earlier run would otherwise hide an
    # alignment failure from the existence check below.
    if os.path.exists(align_package):
        os.remove(align_package)
    align_cmd = ['java', '-jar', SIGN_TOOL_PATH, unsigned_package, align_package, '4096']
    try:
        align_ret = subprocess.call(align_cmd, shell=False)
    except OSError as err:
        UPDATE_LOGGER.print_log("align package failed: %s" % err,
                                log_type=UPDATE_LOGGER.ERROR_LOG)
        return False
    if align_ret != 0 or not os.path.exists(align_package):
        UPDATE_LOGGER.print_log("align package failed", log_type=UPDATE_LOGGER.ERROR_LOG)
        return False

    # sign package
    signed_package = os.path.join(
        update_package, '%s.zip' % package_name)
    if os.path.exists(signed_package):
        os.remove(signed_package)

    sign_result = sign_ota_package(
        align_package,
        signed_package,
        private_key,
        sign_cert)

    if not sign_result:
        UPDATE_LOGGER.print_log("Sign module package fail", UPDATE_LOGGER.ERROR_LOG)
        return False
    for temp_path in (align_package, unsigned_package):
        if os.path.exists(temp_path):
            os.remove(temp_path)

    return True


def main(argv):
    """
    Entry function.
    :param argv: command-line arguments, without the program name.
    :return: process exit status, 0 on success and 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="The tool for making module package.")
    parser.add_argument("target_package", type=target_package_check,
                        help="Target package dir path.")
    parser.add_argument("update_package", type=check_update_package,
                        help="Update package dir path.")
    parser.add_argument("-pk", "--private_key", type=private_key_check,
                        required=True, help="Private key file path.")
    parser.add_argument("-sc", "--sign_cert", type=sign_cert_check,
                        required=True, help="Sign cert file path.")
    parser.add_argument("-pn", "--package_name", type=package_name_check,
                        required=True, help="Package name.")

    args = parser.parse_args(argv)

    # Generate the module package.
    build_re = build_module_package(args.package_name, args.target_package,
                                    args.update_package, args.private_key, args.sign_cert)
    return 0 if build_re else 1


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))