• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# coding=utf-8
3
4#
5# Copyright (c) 2020-2022 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import os
20import hashlib
21
22from _core.constants import FilePermission
23from _core.error import ErrorMessage
24from _core.exception import ParamError
25from _core.logger import platform_logger
26
# Public API of this module.
__all__ = ["check_pub_key_exist", "do_rsa_encrypt", "do_rsa_decrypt",
           "generate_key_file", "get_file_summary"]

# Key file locations, relative to Variables.exec_dir or Variables.top_dir
# (check_pub_key_exist and do_rsa_decrypt search the former, then the latter).
PUBLIC_KEY_FILE = "config/pub.key"
PRIVATE_KEY_FILE = "config/pri.key"
# Module logger used for all error reporting in this file.
LOG = platform_logger("Encrypt")
33
34
def check_pub_key_exist():
    """Locate the RSA public key used for report encryption.

    Resolution order:
      1. an in-memory key already stored in ``report_vars.pub_key_string``;
      2. an explicitly configured ``report_vars.pub_key_file`` ("" means
         "no key"; a path that no longer exists is reset to ``None``);
      3. ``PUBLIC_KEY_FILE`` under ``Variables.exec_dir``;
      4. ``PUBLIC_KEY_FILE`` under ``Variables.top_dir``.

    Steps 3 and 4 cache the result in ``report_vars.pub_key_file``.

    Returns:
        A truthy value (the key string, ``True`` or the key file path)
        when a key is available, otherwise ``False`` or ``""``.
    """
    from xdevice import Variables
    report = Variables.report_vars

    key_string = report.pub_key_string
    if key_string:
        return key_string

    configured = report.pub_key_file
    if configured is not None:
        if configured == "":
            return False
        if os.path.exists(configured):
            return True
        # Forget the stale path so the next call searches the defaults again.
        report.pub_key_file = None
        return False

    exec_candidate = os.path.join(Variables.exec_dir, PUBLIC_KEY_FILE)
    if os.path.exists(exec_candidate):
        report.pub_key_file = exec_candidate
        return True

    top_candidate = os.path.join(Variables.top_dir, PUBLIC_KEY_FILE)
    report.pub_key_file = top_candidate if os.path.exists(top_candidate) else ""
    return report.pub_key_file
59
60
def do_rsa_encrypt(content):
    """Encrypt *content* with the configured RSA public key.

    The plaintext is split into fragments small enough for PKCS#1 v1.5
    padding. Each fragment is encrypted separately and the ciphertext
    blocks are concatenated.

    Args:
        content: data to encrypt. Values that are not ``bytes`` are
            converted with ``str()`` and encoded as UTF-8.

    Returns:
        The concatenated ciphertext as ``bytes``, or *content* unchanged
        when it is empty or no public key is configured.

    Raises:
        ParamError: if the ``rsa`` module is missing, or the key or the
            content cannot be processed (ValueError/TypeError).
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        plain_text = content
        if not isinstance(plain_text, bytes):
            plain_text = str(content).encode(encoding='utf-8')

        import rsa
        from xdevice import Variables
        report_vars = Variables.report_vars
        if not report_vars.pub_key_string:
            # Cache the key contents so later calls skip the file read.
            with open(report_vars.pub_key_file, 'rb') as key_content:
                report_vars.pub_key_string = key_content.read()

        if isinstance(report_vars.pub_key_string, str):
            report_vars.pub_key_string = \
                bytes(report_vars.pub_key_string, "utf-8")

        public_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            report_vars.pub_key_string)

        # Modulus size in bytes (rounded up, as the rsa library computes it).
        # PKCS#1 v1.5 padding consumes 11 bytes of every block.
        key_bytes = (public_key.n.bit_length() + 7) // 8
        max_encrypt_len = key_bytes - 11

        # Join once: repeated bytes concatenation is quadratic in the
        # number of fragments.
        return b"".join(
            rsa.encrypt(frag, public_key)
            for frag in _get_frags(plain_text, max_encrypt_len))

    except (ModuleNotFoundError, ValueError, TypeError) as error:
        # Some exceptions carry no args; fall back to the exception itself
        # instead of raising IndexError here and hiding the real error.
        detail = error.args[0] if error.args else error
        raise ParamError(
            ErrorMessage.Common.Code_0101025.format(detail)) from error
95
96
def do_rsa_decrypt(content):
    """Decrypt *content* produced by ``do_rsa_encrypt``.

    Decryption is skipped, and *content* returned as-is, when *content*
    is empty, when no public key is configured, or when no private key
    file is found. The private key is read from ``PRIVATE_KEY_FILE`` under
    ``Variables.exec_dir``, falling back to ``Variables.top_dir``.

    Args:
        content: ciphertext. Values that are not ``bytes`` are converted
            with ``str()`` and encoded.

    Returns:
        The decrypted text as a UTF-8 ``str``. On any decryption, decoding
        or setup error, the formatted ``Code_0101026`` error message
        (which is also logged).
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        cipher_text = content
        if not isinstance(cipher_text, bytes):
            cipher_text = str(content).encode()

        import rsa
        from xdevice import Variables
        pri_key_file = os.path.join(Variables.exec_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            pri_key_file = os.path.join(Variables.top_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            return content

        # Only hold the file open while reading the key.
        with open(pri_key_file, "rb") as key_content:
            pri_key = rsa.PrivateKey.load_pkcs1(key_content.read())

        # Every ciphertext block is exactly the modulus size in bytes, rounded
        # up. Flooring would mis-split ciphertext from keys whose bit length
        # is not a multiple of 8.
        max_decrypt_len = (pri_key.n.bit_length() + 7) // 8

        try:
            plain_text = b"".join(
                rsa.decrypt(frag, pri_key)
                for frag in _get_frags(cipher_text, max_decrypt_len))
            return plain_text.decode(encoding='utf-8')
        except rsa.pkcs1.CryptoError as error:
            detail = error.args[0] if error.args else error
            error_msg = ErrorMessage.Common.Code_0101026.format(detail)
            LOG.error(error_msg)
            return error_msg

    except (ModuleNotFoundError, ValueError, TypeError) as error:
        # UnicodeDecodeError (a ValueError) from decode() also lands here.
        detail = error.args[0] if error.args else error
        error_msg = ErrorMessage.Common.Code_0101026.format(detail)
        LOG.error(error_msg)
        return error_msg
136
137
def generate_key_file(length=2048):
    """Generate an RSA key pair and write it to ``pri.key`` and ``pub.key``.

    Both files are written in the current working directory and any
    existing content is replaced. Both keys are saved with ``save_pkcs1()``.

    NOTE(review): ``save_pkcs1()`` emits a PKCS#1 "BEGIN RSA PUBLIC KEY"
    block, whereas ``do_rsa_encrypt`` loads the public key with
    ``load_pkcs1_openssl_pem()``, which expects an OpenSSL "BEGIN PUBLIC
    KEY" block. Confirm how the generated pub.key is meant to be consumed.

    Args:
        length: key size in bits. Must be 1024, 2048, 3072 or 4096; other
            values (including non-numeric ones) are logged and ignored.
    """
    try:
        from rsa import key
    except ModuleNotFoundError:
        LOG.error("Generate key file failed: module 'rsa' is not installed")
        return

    try:
        key_length = int(length)
    except (TypeError, ValueError):
        key_length = None
    if key_length not in (1024, 2048, 3072, 4096):
        LOG.error("Length should be 1024, 2048, 3072 or 4096")
        return

    pub_key, pri_key = key.newkeys(key_length)
    pub_key_pem = pub_key.save_pkcs1().decode()
    pri_key_pem = pri_key.save_pkcs1().decode()

    # Truncate rather than append. Appending would leave any previous key
    # ahead of the new one in the same file.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    # The private key must not be readable by other users. The mode only
    # applies when the file is newly created.
    with os.fdopen(os.open("pri.key", flags, 0o600), "w") as file_pri:
        file_pri.write(pri_key_pem)
    # Open the second file only after the first is closed, so a failure
    # here cannot leak the first descriptor.
    with os.fdopen(os.open("pub.key", flags, FilePermission.mode_755),
                   "w") as file_pub:
        file_pub.write(pub_key_pem)
162
163
def get_file_summary(src_file, algorithm="sha256", buffer_size=100 * 1024):
    """Return the hex digest of *src_file* computed with *algorithm*.

    The file is read in chunks of *buffer_size* bytes, so large files are
    never loaded into memory at once.

    Args:
        src_file: path of the file to hash.
        algorithm: any name accepted by ``hashlib.new`` (e.g. "sha256",
            "SHA256", "blake2b"). Variable-length digests (SHAKE) are
            rejected because ``hexdigest()`` would need an explicit length.
        buffer_size: number of bytes read per chunk.

    Returns:
        The hex digest string, or "" (after logging an error) when the file
        does not exist, the algorithm is unsupported, or reading fails.
    """
    if not os.path.exists(src_file):
        LOG.error("File '%s' not exists!" % src_file)
        return ""

    # Ask hashlib directly. hasattr(hashlib, ...) would accept non-hash
    # attributes such as "new" and reject valid names like "SHA256".
    try:
        algorithm_object = hashlib.new(algorithm)
    except (TypeError, ValueError):
        LOG.error("The algorithm '%s' not in hashlib!" % algorithm)
        return ""
    if algorithm_object.digest_size == 0:
        LOG.error("The algorithm '%s' has a variable-length digest and is "
                  "not supported!" % algorithm)
        return ""

    try:
        with open(file=src_file, mode="rb") as file_obj:
            while True:
                chunk = file_obj.read(buffer_size)
                if not chunk:
                    break
                algorithm_object.update(chunk)
    except (OSError, ValueError) as error:
        # OSError covers directories and permission problems.
        LOG.error("Read data from '%s' error: %s " % (src_file, error.args))
        return ""
    return algorithm_object.hexdigest()
191
192
193def _get_frags(text, max_len):
194    _text = text
195    while _text:
196        if len(_text) > max_len:
197            frag, _text = _text[:max_len], _text[max_len:]
198        else:
199            frag, _text = _text, ""
200        yield frag
201