• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# coding=utf-8
3
4#
5# Copyright (c) 2020-2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import os
20import hashlib
21
22from _core.logger import platform_logger
23from _core.exception import ParamError
24from _core.constants import FilePermission
25
# Public API of this module.
__all__ = ["check_pub_key_exist", "do_rsa_encrypt", "do_rsa_decrypt",
           "generate_key_file", "get_file_summary"]

# Key file locations; these relative paths are joined with
# Variables.exec_dir (searched first) and Variables.top_dir by the
# functions in this module.
PUBLIC_KEY_FILE = "config/pub.key"
PRIVATE_KEY_FILE = "config/pri.key"
LOG = platform_logger("Encrypt")
32
33
def check_pub_key_exist():
    """Check whether an RSA public key is available for report encryption.

    Lookup order:

    1. ``Variables.report_vars.pub_key_string`` -- a key already held in
       memory.
    2. ``Variables.report_vars.pub_key_file`` when it is not ``None``: an
       empty string is a cached "no key found" result; a path that no
       longer exists is reset to ``None`` so the next call searches again.
    3. ``PUBLIC_KEY_FILE`` under ``Variables.exec_dir``, then under
       ``Variables.top_dir``. The path found (or ``""`` when neither
       exists) is cached in ``pub_key_file``.

    Returns:
        bool: True if a key string or an existing key file is available,
        otherwise False.
    """
    from xdevice import Variables
    report_vars = Variables.report_vars
    if report_vars.pub_key_string:
        return True

    if report_vars.pub_key_file is not None:
        if report_vars.pub_key_file == "":
            # A previous search found nothing.
            return False
        if not os.path.exists(report_vars.pub_key_file):
            # Forget the stale path so the next call searches again.
            report_vars.pub_key_file = None
            return False
        return True

    for base_dir in (Variables.exec_dir, Variables.top_dir):
        pub_key_path = os.path.join(base_dir, PUBLIC_KEY_FILE)
        if os.path.exists(pub_key_path):
            report_vars.pub_key_file = pub_key_path
            return True

    # Cache the negative result so later calls skip the filesystem search.
    report_vars.pub_key_file = ""
    return False
58
59
def do_rsa_encrypt(content):
    """Encrypt *content* with the configured RSA public key.

    The plain text is split into fragments no longer than the PKCS#1 v1.5
    limit (modulus size in bytes minus 11 bytes of padding). Each fragment
    is encrypted separately and the cipher blocks are concatenated.

    Args:
        content: bytes, or any object whose ``str()`` is encoded as UTF-8.

    Returns:
        bytes: the concatenated cipher text; or *content* unchanged when no
        public key is available or *content* is falsy.

    Raises:
        ParamError: (error_no "00113") wrapping any failure while loading
            the key or encrypting; the original exception is chained.
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        plain_text = content
        if not isinstance(plain_text, bytes):
            plain_text = str(content).encode(encoding='utf-8')

        import rsa
        from xdevice import Variables
        report_vars = Variables.report_vars
        if not report_vars.pub_key_string:
            # Cache the key file content so later calls skip the file read.
            with open(report_vars.pub_key_file, 'rb') as key_content:
                report_vars.pub_key_string = key_content.read()

        if isinstance(report_vars.pub_key_string, str):
            report_vars.pub_key_string = \
                bytes(report_vars.pub_key_string, "utf-8")

        public_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            report_vars.pub_key_string)

        # Round the modulus size up to whole bytes; PKCS#1 v1.5 padding
        # takes 11 bytes of every block.
        key_len = (public_key.n.bit_length() + 7) // 8
        max_encrypt_len = key_len - 11

        return b"".join(
            rsa.encrypt(frag, public_key)
            for frag in _get_frags(plain_text, max_encrypt_len))

    except Exception as error:
        # Boundary handler: every failure (missing rsa module, unreadable
        # or malformed key, oversized fragment, ...) reaches callers as
        # ParamError. Exceptions may carry no args, so guard args[0].
        detail = error.args[0] if error.args else repr(error)
        error_msg = "rsa encryption error occurs, %s" % (detail,)
        raise ParamError(error_msg, error_no="00113") from error
96
97
def do_rsa_decrypt(content):
    """Reverse ``do_rsa_encrypt``: decrypt concatenated RSA cipher blocks.

    Decryption runs only when a public key is available (see
    ``check_pub_key_exist``) and a private key exists at
    ``PRIVATE_KEY_FILE`` under ``Variables.exec_dir`` or, failing that,
    ``Variables.top_dir``.

    Args:
        content: cipher text as bytes, or an object whose ``str()`` is
            encoded with ``str.encode()`` defaults.

    Returns:
        str: the UTF-8 decoded plain text; *content* unchanged when no
        public key is available, no private key file exists or *content* is
        falsy; or an error-message string (also logged with error_no
        "00114") when decryption or decoding fails.

    Other exceptions (e.g. ``OSError`` while reading the key file) are not
    caught here and propagate to the caller.
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        cipher_text = content
        if not isinstance(cipher_text, bytes):
            cipher_text = str(content).encode()

        import rsa
        from xdevice import Variables
        pri_key_file = os.path.join(Variables.exec_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            pri_key_file = os.path.join(Variables.top_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            return content

        # Load the key and release the file before the decryption loop.
        with open(pri_key_file, "rb") as key_content:
            pri_key = rsa.PrivateKey.load_pkcs1(key_content.read())
        # Every cipher block is as long as the modulus in bytes; round up
        # so moduli whose bit length is not a multiple of 8 split correctly.
        block_len = (pri_key.n.bit_length() + 7) // 8

        try:
            plain_text = b"".join(
                rsa.decrypt(frag, pri_key)
                for frag in _get_frags(cipher_text, block_len))
        except rsa.pkcs1.CryptoError as error:
            detail = error.args[0] if error.args else repr(error)
            error_msg = "rsa decryption error occurs, %s" % (detail,)
            LOG.error(error_msg, error_no="00114")
            return error_msg
        return plain_text.decode(encoding='utf-8')

    except (ModuleNotFoundError, ValueError, TypeError, UnicodeError) as error:
        detail = error.args[0] if error.args else repr(error)
        error_msg = "rsa decryption error occurs, %s" % (detail,)
        LOG.error(error_msg, error_no="00114")
        return error_msg
137
138
def generate_key_file(length=2048):
    """Generate an RSA key pair and write it to ``pri.key`` and ``pub.key``.

    Both files are written to the current working directory (relative
    paths). Any existing content is replaced.

    Args:
        length: RSA modulus size in bits; anything ``int()`` accepts, and
            it must be one of 1024, 2048, 3072 or 4096.

    Returns:
        None. An unsupported length is logged and nothing is written. If
        the ``rsa`` package is not installed, the error is logged and the
        function returns.

    NOTE(review): keys are saved with ``save_pkcs1`` ("RSA PUBLIC KEY" PEM),
    whereas ``do_rsa_encrypt`` loads the public key with
    ``load_pkcs1_openssl_pem`` ("PUBLIC KEY" PEM) -- confirm whether
    ``pub.key`` needs converting before it is used as ``config/pub.key``.
    """
    try:
        from rsa import key
    except ModuleNotFoundError:
        LOG.error("Module 'rsa' is not installed, key files not generated")
        return

    key_length = int(length)
    if key_length not in (1024, 2048, 3072, 4096):
        LOG.error("Length should be 1024, 2048, 3072 or 4096")
        return

    pub_key, pri_key = key.newkeys(key_length)
    pub_key_pem = pub_key.save_pkcs1().decode()
    pri_key_pem = pri_key.save_pkcs1().decode()

    # O_TRUNC rather than O_APPEND: regenerating must replace the previous
    # key instead of stacking a second PEM block after it.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    # The private key must not be readable by other users. The mode only
    # applies when the file is created.
    with os.fdopen(os.open("pri.key", flags, 0o600), "w") as file_pri:
        file_pri.write(pri_key_pem)
    with os.fdopen(os.open("pub.key", flags, FilePermission.mode_755),
                   "w") as file_pub:
        file_pub.write(pub_key_pem)
163
164
def get_file_summary(src_file, algorithm="sha256", buffer_size=100 * 1024):
    """Return the hex digest of *src_file* computed with *algorithm*.

    The file is read in chunks of *buffer_size* bytes, so large files are
    never loaded into memory at once.

    Args:
        src_file: path of the file to hash.
        algorithm: any name accepted by ``hashlib.new`` that has a fixed
            digest size (e.g. "sha256", "sha512").
        buffer_size: chunk size in bytes; ``None`` or a non-positive value
            reads the whole file in a single call.

    Returns:
        str: the hex digest, or "" (after logging an error) when the file
        does not exist, the algorithm is unsupported, or reading fails.
    """
    if not os.path.exists(src_file):
        LOG.error("File '%s' not exists!" % src_file)
        return ""

    # Let hashlib decide which names are valid. A hasattr() check accepts
    # non-algorithms such as "new" and rejects valid OpenSSL-only names.
    try:
        algorithm_object = hashlib.new(algorithm)
    except (ValueError, TypeError):
        LOG.error("The algorithm '%s' not in hashlib!" % algorithm)
        return ""
    if algorithm_object.digest_size == 0:
        # Variable-length digests (shake_*) cannot produce hexdigest()
        # without an explicit length.
        LOG.error("The algorithm '%s' not in hashlib!" % algorithm)
        return ""

    # read(0) would return b"" at once and hash nothing; read(-1) reads
    # everything in one call instead.
    read_size = buffer_size if buffer_size and buffer_size > 0 else -1
    try:
        with open(file=src_file, mode="rb") as _file:
            for data in iter(lambda: _file.read(read_size), b""):
                algorithm_object.update(data)
    except (OSError, ValueError) as error:
        LOG.error("Read data from '%s' error: %s " % (
            src_file, error.args))
        return ""
    return algorithm_object.hexdigest()
192
193
194def _get_frags(text, max_len):
195    _text = text
196    while _text:
197        if len(_text) > max_len:
198            frag, _text = _text[:max_len], _text[max_len:]
199        else:
200            frag, _text = _text, ""
201        yield frag
202