• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Copyright (c) 2023 Huawei Device Co., Ltd.
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
"""
The tool for making module package.

positional arguments:
  target_package        Target package directory path (must exist).
  update_package        Update package output directory path
                        (created if it does not exist).

options:
  -pn PACKAGE_NAME, --package_name PACKAGE_NAME
                        Module package name.
  -pk PRIVATE_KEY, --private_key PRIVATE_KEY
                        Private key file path.
  -sc SIGN_CERT, --sign_cert SIGN_CERT
                        Sign cert file path.
"""
30import os
31import sys
32import argparse
33import subprocess
34import hashlib
35import zipfile
36import io
37import struct
38import logging
39
40from asn1crypto import cms
41from asn1crypto import pem
42from asn1crypto import util
43from asn1crypto import x509
44from cryptography.hazmat.backends import default_backend
45from cryptography.hazmat.primitives import serialization
46from cryptography.hazmat.primitives.asymmetric import padding
47from cryptography.hazmat.primitives import hashes
48
49
# Interpreter recursion ceiling applied as a side effect of importing this
# module (1000000 frames).
# NOTE(review): nothing in this file recurses; the call is kept because it is
# a process-wide setting other code may rely on -- confirm before removing.
MAXIMUM_RECURSION_DEPTH = 10 ** 6
sys.setrecursionlimit(MAXIMUM_RECURSION_DEPTH)

# Chunk size (bytes) used when streaming package files.
BLCOK_SIZE = 8 * 1024
# Fixed size (bytes) of a ZIP end-of-central-directory record with no comment.
ZIP_ECOD_LENGTH = 22
# Digest-algorithm identifier stored in the content header.
# NOTE(review): 672 presumably mirrors OpenSSL's NID_sha256 -- confirm against
# the on-device verifier.
DIGEST_SHA256 = 672
# Size (bytes) of a SHA-256 digest.
SHA256_HASH_LEN = hashlib.sha256().digest_size

# Content header: algorithm id (uint16), digest length (uint16), 32-byte digest.
CONTENT_INFO_FORMAT = "<2H32s"
# The length field of the zip EOCD comment (little-endian uint16).
ZIP_EOCD_COMMENT_LEN_FORMAT = "<H"
# Footer written after the signature: three little-endian uint16 values.
SIGANTURE_FOOTER_FORMAT = "<3H"
# Size (bytes) of that footer.
FOOTER_LENGTH = struct.calcsize(SIGANTURE_FOOTER_FORMAT)

# zipalign.jar is expected to sit next to this script.
SIGN_TOOL_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'zipalign.jar')
67
68
def target_package_check(arg):
    """
    Argparse converter for the target package argument.

    The target package must be an existing directory; the module files
    (module.img, config.json, pub_key.pem) are read from it.
    :param arg: the command-line value to validate.
    :return: ``arg`` unchanged when it names a directory, otherwise False.
    """
    if os.path.isdir(arg):
        return arg
    UPDATE_LOGGER.print_log(
        "Target package error, path: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
    return False
80
81
def package_name_check(arg):
    """
    Argparse converter for the module package name.

    The name becomes the file-name stem of the generated packages
    (``<name>_unsigned.zip``, ``<name>_align.zip``, ``<name>.zip``), so a
    missing, empty or whitespace-only value is rejected.
    :param arg: the value to validate.
    :return: ``arg`` unchanged when valid, otherwise False.
    """
    # argparse never calls a converter with None, so the empty/blank check
    # is what actually guards command-line input.
    if arg is None or not str(arg).strip():
        UPDATE_LOGGER.print_log(
            "Package name error: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
        return False
    return arg
93
94
def private_key_check(arg):
    """
    Argparse converter for the private key option.

    Accepts the literal placeholder ``"ON_SERVER"`` or the path of an
    existing regular file.
    :param arg: the command-line value to validate.
    :return: ``arg`` unchanged when acceptable, otherwise False.
    """
    acceptable = arg == "ON_SERVER" or os.path.isfile(arg)
    if not acceptable:
        UPDATE_LOGGER.print_log(
            "FileNotFoundError, path: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
        return False
    return arg
107
108
def sign_cert_check(arg):
    """
    Argparse converter for the signing certificate option.

    Valid values are the literal placeholder ``"ON_SERVER"`` and paths of
    existing regular files.
    :param arg: the command-line value to validate.
    :return: ``arg`` unchanged when valid, otherwise False.
    """
    if arg == "ON_SERVER" or os.path.isfile(arg):
        return arg
    UPDATE_LOGGER.print_log(
        "FileNotFoundError, path: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
    return False
121
122
def check_update_package(arg):
    """
    Argparse converter for the update package (output) directory.

    An existing directory is accepted as-is. A missing path is created,
    including parent directories. Any other existing path, such as a
    regular file, is rejected.
    :param arg: the command-line value to validate.
    :return: ``arg`` when it is (or has just been made) a directory,
             otherwise False.
    """
    if os.path.exists(arg):
        if not os.path.isdir(arg):
            UPDATE_LOGGER.print_log(
                "Update package must be a dir path, not a file path. "
                "path: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
            return False
        return arg

    UPDATE_LOGGER.print_log(
        "Update package path does not exist. The dir will be created! "
        "path: %s" % arg, UPDATE_LOGGER.WARNING_LOG)
    try:
        os.makedirs(arg)
    except OSError:
        UPDATE_LOGGER.print_log(
            "Make update package path dir failed! "
            "path: %s" % arg, UPDATE_LOGGER.ERROR_LOG)
        return False
    return arg
150
151
class UpdateToolLogger:
    """
    Global log class.

    Thin wrapper around the ``logging`` logger named after this module.
    Messages are routed by a log-type tag (INFO_LOG / WARNING_LOG /
    ERROR_LOG) to the matching logging level.
    """
    INFO_LOG = 'INFO_LOG'
    WARNING_LOG = 'WARNING_LOG'
    ERROR_LOG = 'ERROR_LOG'
    LOG_TYPE = (INFO_LOG, WARNING_LOG, ERROR_LOG)

    def __init__(self, output_type='console'):
        """
        :param output_type: 'console' attaches a StreamHandler (stderr),
            'file' attaches a FileHandler writing "UpdateToolLog.txt" in the
            current directory; any other value attaches no handler.
        """
        self.__logger_obj = self.__get_logger_obj(output_type=output_type)

    @staticmethod
    def __get_logger_obj(output_type='console'):
        """
        Return the shared module logger, attaching a handler for
        ``output_type`` unless an equivalent one is already attached.
        """
        ota_logger = logging.getLogger(__name__)
        ota_logger.setLevel(level=logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s : %(message)s',
            "%Y-%m-%d %H:%M:%S")
        # logging.getLogger() returns the same object for the same name, so
        # every instance shares one logger. Without this de-duplication each
        # extra UpdateToolLogger() would make every message print once more.
        handler = None
        if output_type == 'console':
            # Exact type check on purpose: FileHandler subclasses
            # StreamHandler and must not count as a console handler.
            if not any(type(h) is logging.StreamHandler
                       for h in ota_logger.handlers):
                handler = logging.StreamHandler()
        elif output_type == 'file':
            log_path = os.path.abspath("UpdateToolLog.txt")
            if not any(isinstance(h, logging.FileHandler)
                       and h.baseFilename == log_path
                       for h in ota_logger.handlers):
                handler = logging.FileHandler("UpdateToolLog.txt")
        if handler is not None:
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            ota_logger.addHandler(handler)
        return ota_logger

    def print_log(self, msg, log_type=INFO_LOG):
        """
        Print log information.
        :param msg: log information
        :param log_type: one of LOG_TYPE (INFO_LOG, WARNING_LOG, ERROR_LOG)
        :return: True if logged at the requested level; False if
                 ``log_type`` is unknown (an error is logged instead).
        """
        if log_type == self.LOG_TYPE[0]:
            self.__logger_obj.info(msg)
        elif log_type == self.LOG_TYPE[1]:
            self.__logger_obj.warning(msg)
        elif log_type == self.LOG_TYPE[2]:
            self.__logger_obj.error(msg)
        else:
            self.__logger_obj.error("Unknown log type! %s", log_type)
            return False
        return True

    def print_uncaught_exception_msg(self, msg, exc_info):
        """
        Print log when an uncaught exception occurs.
        :param msg: Uncaught exception
        :param exc_info: information about the uncaught exception
        """
        self.__logger_obj.error(msg, exc_info=exc_info)
208
209
# Module-wide logger instance used by every function in this file; logs to the console.
UPDATE_LOGGER = UpdateToolLogger(output_type='console')
211
212
def load_public_cert(sign_cert):
    """
    Load an X.509 certificate from a file.

    The file may contain raw DER bytes or a PEM-armored certificate; PEM
    input is unarmored to DER before parsing.
    :param sign_cert: path of the certificate file.
    :return: the parsed ``asn1crypto.x509.Certificate``.
    """
    with open(sign_cert, 'rb') as cert_file:
        cert_bytes = cert_file.read()
    if pem.detect(cert_bytes):
        # unarmor() returns (type_name, headers, der_bytes); only DER is needed.
        _, _, cert_bytes = pem.unarmor(cert_bytes)
    return x509.Certificate.load(cert_bytes)
220
221
def calculate_package_hash(package_path):
    """
    Compute the SHA-256 digest of a package, excluding its last
    ZIP_ECOD_LENGTH bytes.

    NOTE(review): skipping exactly ZIP_ECOD_LENGTH trailing bytes assumes the
    file ends with a ZIP end-of-central-directory record with an empty
    comment -- confirm for every input.
    A file shorter than ZIP_ECOD_LENGTH yields the digest of empty input.
    :param package_path: path of the package file to hash.
    :return: the 32-byte SHA-256 digest (bytes).
    """
    hash_sha256 = hashlib.sha256()
    remain_len = os.path.getsize(package_path) - ZIP_ECOD_LENGTH
    with open(package_path, 'rb') as package_file:
        while remain_len > 0:
            chunk = package_file.read(min(BLCOK_SIZE, remain_len))
            if not chunk:
                # File shrank after getsize(); stop rather than loop.
                break
            hash_sha256.update(chunk)
            remain_len -= len(chunk)
    return hash_sha256.digest()
238
239
def _load_private_key(private_key_file):
    """
    Read and parse an unencrypted PEM private key.

    :param private_key_file: path of the PEM file.
    :return: the key object returned by ``load_pem_private_key``.
    :raises OSError: if the file cannot be read.
    :raises ValueError: if the data cannot be parsed as a PEM private key.
    :raises TypeError: if the key is password-protected (no password given).
    """
    with open(private_key_file, 'rb') as f_r:
        key_data = f_r.read()
    return serialization.load_pem_private_key(
        key_data,
        password=None,
        backend=default_backend())


def sign_digest_with_pss(digset, private_key_file):
    """
    Sign ``digset`` using RSA-PSS (MGF1 with SHA-256, maximum salt length)
    and SHA-256.

    NOTE(review): ``sign()`` hashes its input, so the value actually signed
    is SHA-256(digset); ``digset`` is not treated as a pre-computed hash.
    Confirm this matches the verifier.
    :param digset: bytes to sign.
    :param private_key_file: path of an unencrypted PEM private key.
    :return: the signature bytes, or False if the key could not be loaded
             or used (the failure is logged).
    """
    pss = padding.PSS(
        mgf=padding.MGF1(hashes.SHA256()),
        salt_length=padding.PSS.MAX_LENGTH)
    try:
        private_key = _load_private_key(private_key_file)
        # TypeError also covers non-RSA keys, whose sign() takes no padding.
        return private_key.sign(digset, pss, hashes.SHA256())
    except (OSError, ValueError, TypeError) as err:
        UPDATE_LOGGER.print_log(
            "Sign digest failed! key: %s, error: %s" % (private_key_file, err),
            UPDATE_LOGGER.ERROR_LOG)
        return False


def sign_digest(digset, private_key_file):
    """
    Sign ``digset`` using RSA PKCS#1 v1.5 padding and SHA-256.

    NOTE(review): ``sign()`` hashes its input, so the value actually signed
    is SHA-256(digset); ``digset`` is not treated as a pre-computed hash.
    Confirm this matches the verifier.
    :param digset: bytes to sign.
    :param private_key_file: path of an unencrypted PEM private key.
    :return: the signature bytes, or False if the key could not be loaded
             or used (the failure is logged).
    """
    try:
        private_key = _load_private_key(private_key_file)
        # TypeError also covers non-RSA keys, whose sign() takes no padding.
        return private_key.sign(digset, padding.PKCS1v15(), hashes.SHA256())
    except (OSError, ValueError, TypeError) as err:
        UPDATE_LOGGER.print_log(
            "Sign digest failed! key: %s, error: %s" % (private_key_file, err),
            UPDATE_LOGGER.ERROR_LOG)
        return False
283
284
def create_encap_content_info(diget):
    """
    Build the encapsulated content header.

    The header is packed with CONTENT_INFO_FORMAT ("<2H32s"): two
    little-endian uint16 fields (DIGEST_SHA256, SHA256_HASH_LEN) followed
    by the digest bytes.
    :param diget: the package SHA-256 digest.
    :return: the packed header bytes, or False if ``diget`` is empty or not
             SHA256_HASH_LEN bytes long (the failure is logged).
    """
    if not diget:
        UPDATE_LOGGER.print_log("calc package hash failed!",
            log_type=UPDATE_LOGGER.ERROR_LOG)
        return False
    if len(diget) != SHA256_HASH_LEN:
        # struct's "32s" would silently pad or truncate a wrong-size digest.
        UPDATE_LOGGER.print_log(
            "invalid package hash length: %d, expected: %d"
            % (len(diget), SHA256_HASH_LEN),
            log_type=UPDATE_LOGGER.ERROR_LOG)
        return False
    return struct.pack(CONTENT_INFO_FORMAT, DIGEST_SHA256,
        SHA256_HASH_LEN, diget)
293
294
def write_signed_package(unsigned_package, signature, signed_package):
    """
    Write ``signed_package``: a copy of ``unsigned_package`` whose ZIP
    comment carries ``signature`` followed by a footer.

    Output layout:
      * every byte of ``unsigned_package`` except the last two;
      * the new comment length, packed with ZIP_EOCD_COMMENT_LEN_FORMAT;
      * ``signature``;
      * a footer packed with SIGANTURE_FOOTER_FORMAT as
        (total, 0xffff, total),
    where total = len(signature) + FOOTER_LENGTH.

    NOTE(review): replacing the last two bytes assumes they are the EOCD
    comment-length field, i.e. the input has an empty ZIP comment.
    :param unsigned_package: path of the package to copy.
    :param signature: signature bytes to embed.
    :param signed_package: output path; created or truncated (mode 0o755).
    :raises struct.error: if total does not fit in an unsigned 16-bit field.
    """
    signature_total_size = len(signature) + FOOTER_LENGTH
    # Pack the trailer first so an oversized signature fails before the
    # output file is created or truncated.
    zip_comment_len = struct.pack(ZIP_EOCD_COMMENT_LEN_FORMAT,
            signature_total_size)
    footer = struct.pack(SIGANTURE_FOOTER_FORMAT, signature_total_size,
            0xffff, signature_total_size)

    # O_TRUNC: without it, rewriting an existing longer file would leave
    # stale bytes after the footer. O_BINARY (Windows only) avoids newline
    # translation on the raw descriptor.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, 'O_BINARY', 0)
    package_fd = os.open(signed_package, flags, 0o755)
    with os.fdopen(package_fd, 'wb') as f_signed:
        remain_len = os.path.getsize(unsigned_package) - 2
        with open(unsigned_package, 'rb') as f_unsign:
            while remain_len > 0:
                chunk = f_unsign.read(min(BLCOK_SIZE, remain_len))
                if not chunk:
                    break
                f_signed.write(chunk)
                remain_len -= len(chunk)
        f_signed.write(zip_comment_len)
        f_signed.write(signature)
        f_signed.write(footer)
322
323
def _write_binary_file(path, payload):
    """Create or truncate ``path`` (mode 0o755) and write ``payload`` to it."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, 'O_BINARY', 0)
    with os.fdopen(os.open(path, flags, 0o755), 'wb') as out_file:
        out_file.write(payload)


def sign_ota_package(package_path, signed_package, private_key, sign_cert):
    """
    Sign ``package_path`` and write the signed copy to ``signed_package``.

    Steps:
      1. Hash the package (calculate_package_hash).
      2. Sign the digest with sign_digest.
      3. Wrap it in a CMS ContentInfo/SignedData structure. The structure
         holds the encapsulated content header (create_encap_content_info),
         the signer certificate and the signature.
      4. Embed the DER encoding into the output via write_signed_package.

    Side effect: the raw digest and signature are also written to the files
    "digest" and "signature" in the current working directory.
    :param package_path: path of the package to sign.
    :param signed_package: output path of the signed package.
    :param private_key: path of the PEM private key.
    :param sign_cert: path of the signer certificate (PEM or DER).
    :return: True on success; False if the content header or the signature
             could not be produced.
    """
    digest = calculate_package_hash(package_path)
    data = create_encap_content_info(digest)
    if not data:
        return False
    signature = sign_digest(digest, private_key)
    if not signature:
        # sign_digest returns False on failure; writing it would crash.
        return False

    _write_binary_file("digest", digest)
    _write_binary_file("signature", signature)

    # Creating a SignedData object from cms
    signed_data = cms.SignedData()
    signed_data['version'] = 'v1'
    signed_data['encap_content_info'] = util.OrderedDict([
        ('content_type', 'data'),
        ('content', data)])

    signed_data['digest_algorithms'] = [util.OrderedDict([
        ('algorithm', 'sha256'),
        ('parameters', None)])]

    cert = load_public_cert(sign_cert)

    # Adding this certificate to SignedData object
    signed_data['certificates'] = [cert]

    # Setting signer info section
    signer_info = cms.SignerInfo()
    signer_info['version'] = 'v1'
    signer_info['digest_algorithm'] = util.OrderedDict([
                ('algorithm', 'sha256'),
                ('parameters', None)])
    signer_info['signature_algorithm'] = util.OrderedDict([
                ('algorithm', 'sha256_rsa'),
                ('parameters', None)])

    # The signer is identified by issuer + serial number. The subject key
    # identifier is not read: it is optional in a certificate.
    issuer_and_serial = cms.IssuerAndSerialNumber()
    issuer_and_serial['issuer'] = cert.issuer
    issuer_and_serial['serial_number'] = cert.serial_number
    signer_info['sid'] = cms.SignerIdentifier({
        'issuer_and_serial_number': issuer_and_serial})

    signer_info['signature'] = signature
    # Adding SignerInfo object to SignedData object
    signed_data['signer_infos'] = [signer_info]

    # Writing everything into ASN.1 object
    asn1obj = cms.ContentInfo()
    asn1obj['content_type'] = 'signed_data'
    asn1obj['content'] = signed_data

    # dump() yields the DER encoding that is embedded into the package.
    write_signed_package(package_path, asn1obj.dump(), signed_package)
    return True
387
388
def build_module_package(package_name, target_package, update_package, private_key, sign_cert):
    """
    Build and sign a module package.

    1. Zip module.img (optional), config.json (required) and pub_key.pem
       (optional) from ``target_package`` into ``<name>_unsigned.zip``.
    2. Align that zip with zipalign.jar (argument '4096') into
       ``<name>_align.zip``.
    3. Sign the aligned zip into ``<name>.zip``, then delete the two
       intermediate files.
    All outputs are written to ``update_package``.
    :return: True on success, False on any failure (the failure is logged).
    """
    config_file_path = os.path.join(target_package, 'config.json')
    if not os.path.isfile(config_file_path):
        UPDATE_LOGGER.print_log(
            "config.json not found, path: %s" % config_file_path,
            log_type=UPDATE_LOGGER.ERROR_LOG)
        return False

    unsigned_package = os.path.join(
        update_package, '%s_unsigned.zip' % package_name)
    # The context manager closes the zip even if a write fails.
    with zipfile.ZipFile(unsigned_package, 'w') as zip_file:
        # Order matters for the archive layout; module.img and pub_key.pem
        # are optional, config.json was verified above.
        for file_name in ('module.img', 'config.json', 'pub_key.pem'):
            file_path = os.path.join(target_package, file_name)
            if os.path.exists(file_path):
                zip_file.write(file_path, file_name)

    # align package
    align_package = os.path.join(
        update_package, '%s_align.zip' % package_name)
    # Drop a leftover from an earlier run so that the existence check below
    # reflects this alignment only.
    if os.path.exists(align_package):
        os.remove(align_package)
    align_cmd = ['java', '-jar', SIGN_TOOL_PATH, unsigned_package, align_package, '4096']
    try:
        align_ret = subprocess.call(align_cmd, shell=False)
    except OSError as err:
        # e.g. no `java` executable on PATH
        UPDATE_LOGGER.print_log("align package failed: %s" % err,
            log_type=UPDATE_LOGGER.ERROR_LOG)
        return False
    if align_ret != 0 or not os.path.exists(align_package):
        UPDATE_LOGGER.print_log("align package failed", log_type=UPDATE_LOGGER.ERROR_LOG)
        return False

    # sign package
    signed_package = os.path.join(
        update_package, '%s.zip' % package_name)
    if os.path.exists(signed_package):
        os.remove(signed_package)

    sign_result = sign_ota_package(
        align_package,
        signed_package,
        private_key,
        sign_cert)

    if not sign_result:
        UPDATE_LOGGER.print_log("Sign module package fail", UPDATE_LOGGER.ERROR_LOG)
        return False

    for intermediate in (align_package, unsigned_package):
        if os.path.exists(intermediate):
            os.remove(intermediate)

    return True
437
438
def main(argv):
    """
    Entry function.

    Parse ``argv`` and build the signed module package.
    :param argv: command-line arguments, without the program name.
    :return: 0 if the package was built, 1 otherwise. Invalid or missing
             arguments exit through ``parser.error`` (SystemExit).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("target_package", type=target_package_check,
                        help="Target package file path.")
    parser.add_argument("update_package", type=check_update_package,
                        help="Update package file path.")
    parser.add_argument("-pk", "--private_key", type=private_key_check,
                        default=None, help="Private key file path.")
    parser.add_argument("-sc", "--sign_cert", type=sign_cert_check,
                        default=None, help="Sign cert file path.")
    parser.add_argument("-pn", "--package_name", type=package_name_check,
                        default=None, help="Package name.")

    args = parser.parse_args(argv)

    # The *_check converters log and return False instead of raising, and
    # argparse never runs a converter on a None default. Without this check,
    # an invalid or omitted option would reach build_module_package.
    for option in ("target_package", "update_package", "package_name",
                   "private_key", "sign_cert"):
        if not getattr(args, option):
            parser.error("argument %s is missing or invalid" % option)

    # Generate the module package.
    build_re = build_module_package(args.package_name, args.target_package,
        args.update_package, args.private_key, args.sign_cert)
    return 0 if build_re else 1


if __name__ == '__main__':
    # Propagate the result as the process exit status.
    sys.exit(main(sys.argv[1:]))
463