• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# coding=utf-8
3
4#
5# Copyright (c) 2020-2021 Huawei Device Co., Ltd.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19import os
20import hashlib
21
22from _core.logger import platform_logger
23from _core.exception import ParamError
24
__all__ = [
    "check_pub_key_exist",
    "do_rsa_encrypt",
    "do_rsa_decrypt",
    "generate_key_file",
    "get_file_summary",
]

# Key file locations; they are joined onto Variables.exec_dir or
# Variables.top_dir when looked up.
PUBLIC_KEY_FILE = "config/pub.key"
PRIVATE_KEY_FILE = "config/pri.key"

# Logger for this module's messages.
LOG = platform_logger("Encrypt")
31
32
def check_pub_key_exist():
    """Find the RSA public key used to encrypt reports.

    The lookup proceeds in this order:

    1. A non-empty ``Variables.report_vars.pub_key_string`` is returned
       as-is.
    2. If ``pub_key_file`` is already set (not None):

       - an empty string (e.g. left by an earlier unsuccessful search)
         yields False;
       - a path that no longer exists is reset to None and yields False;
       - any other path yields True.

    3. Otherwise ``<exec_dir>/config/pub.key`` and then
       ``<top_dir>/config/pub.key`` are probed. The first hit is stored on
       ``pub_key_file``. If neither exists, ``""`` is stored so later calls
       return False without searching again.

    :return: a truthy value when a key is available, falsy otherwise. The
        concrete type (key string/bytes, bool or path str) depends on which
        step decided.
    """
    from xdevice import Variables
    report_vars = Variables.report_vars

    key_string = report_vars.pub_key_string
    if key_string:
        return key_string

    cached_file = report_vars.pub_key_file
    if cached_file is not None:
        if cached_file == "":
            return False
        if os.path.exists(cached_file):
            return True
        # The remembered file vanished; forget it so a later call searches again.
        report_vars.pub_key_file = None
        return False

    exec_candidate = os.path.join(Variables.exec_dir, PUBLIC_KEY_FILE)
    if os.path.exists(exec_candidate):
        report_vars.pub_key_file = exec_candidate
        return True

    top_candidate = os.path.join(Variables.top_dir, PUBLIC_KEY_FILE)
    report_vars.pub_key_file = (
        top_candidate if os.path.exists(top_candidate) else "")
    return report_vars.pub_key_file
57
58
def do_rsa_encrypt(content):
    """Encrypt ``content`` with the configured RSA public key.

    The PEM key is taken from ``Variables.report_vars.pub_key_string``. If
    that is empty, the key is read from ``pub_key_file`` and cached on
    ``pub_key_string``. The plaintext is split into fragments that fit a
    single PKCS#1 v1.5 block. Each fragment is encrypted on its own, and the
    cipher blocks are concatenated.

    :param content: data to encrypt. Bytes are used as-is; anything else is
        converted with ``str()`` and UTF-8 encoded.
    :return: the ciphertext as bytes, or ``content`` unchanged when no public
        key is available or ``content`` is empty.
    :raises ParamError: (error_no "00113") on any failure, e.g. the ``rsa``
        package is missing or the key cannot be read or parsed.
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        plain_text = content
        if not isinstance(plain_text, bytes):
            plain_text = str(content).encode(encoding='utf-8')

        import rsa
        from xdevice import Variables
        if not Variables.report_vars.pub_key_string:
            with open(Variables.report_vars.pub_key_file,
                      'rb') as key_content:
                Variables.report_vars.pub_key_string = key_content.read()

        if isinstance(Variables.report_vars.pub_key_string, str):
            Variables.report_vars.pub_key_string = \
                bytes(Variables.report_vars.pub_key_string, "utf-8")

        public_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            Variables.report_vars.pub_key_string)

        # Modulus size in whole bytes (rounded up, as python-rsa computes
        # it). PKCS#1 v1.5 padding takes 11 bytes of every block.
        key_size = (public_key.n.bit_length() + 7) // 8
        max_encrypt_len = key_size - 11

        return b"".join(
            rsa.encrypt(frag, public_key)
            for frag in _get_frags(plain_text, max_encrypt_len))

    except Exception as error:
        # Some exceptions carry no args; indexing args[0] blindly would
        # replace the intended ParamError with an IndexError.
        detail = error.args[0] if error.args else type(error).__name__
        error_msg = "rsa encryption error occurs, %s" % detail
        raise ParamError(error_msg, error_no="00113") from error
95
96
def do_rsa_decrypt(content):
    """Decrypt ``content`` produced by ``do_rsa_encrypt``.

    Decryption is skipped, and ``content`` returned unchanged, in any of
    these cases:

    - no public key is configured;
    - ``content`` is empty;
    - no ``config/pri.key`` exists under ``Variables.exec_dir`` or
      ``Variables.top_dir``.

    :param content: ciphertext. Bytes are used as-is; anything else is
        converted with ``str()`` and encoded.
    :return: the plaintext decoded as UTF-8; ``content`` unchanged when
        decryption is skipped; or an error message string when decryption
        fails. Errors are logged with error_no "00114", not raised.
    """
    try:
        if not check_pub_key_exist() or not content:
            return content

        cipher_text = content
        if not isinstance(cipher_text, bytes):
            cipher_text = str(content).encode()

        import rsa
        from xdevice import Variables
        pri_key_file = os.path.join(Variables.exec_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            pri_key_file = os.path.join(Variables.top_dir, PRIVATE_KEY_FILE)
        if not os.path.exists(pri_key_file):
            return content

        # Only load the key while the file is open; decrypt afterwards.
        with open(pri_key_file, "rb") as key_content:
            pri_key = rsa.PrivateKey.load_pkcs1(key_content.read())

        # Each cipher block is exactly the modulus size in bytes, rounded
        # up. Flooring here would mis-split the ciphertext for keys whose
        # bit length is not a multiple of 8.
        block_len = (pri_key.n.bit_length() + 7) // 8

        try:
            plain_text = b"".join(
                rsa.decrypt(frag, pri_key)
                for frag in _get_frags(cipher_text, block_len))
        except rsa.pkcs1.CryptoError as error:
            detail = error.args[0] if error.args else type(error).__name__
            error_msg = "rsa decryption error occurs, %s" % detail
            LOG.error(error_msg, error_no="00114")
            return error_msg
        return plain_text.decode(encoding='utf-8')

    except (ModuleNotFoundError, ValueError, TypeError, UnicodeError) as error:
        detail = error.args[0] if error.args else type(error).__name__
        error_msg = "rsa decryption error occurs, %s" % detail
        LOG.error(error_msg, error_no="00114")
        return error_msg
136
137
def generate_key_file(length=2048):
    """Generate an RSA key pair and save it as PEM files.

    The private key is written to ``pri.key`` and the public key to
    ``pub.key`` in the current working directory, both PKCS#1 PEM encoded.
    Existing files are overwritten.

    :param length: modulus size in bits. It must convert with ``int()`` to
        one of 1024, 2048, 3072 or 4096.
    :return: None. An invalid length or a missing ``rsa`` package is
        logged, and no file is written.
    """
    length_error = "length should be 1024, 2048, 3072 or 4096"
    try:
        key_length = int(length)
    except (TypeError, ValueError):
        LOG.error(length_error)
        return
    if key_length not in (1024, 2048, 3072, 4096):
        LOG.error(length_error)
        return

    try:
        from rsa import key
    except ModuleNotFoundError:
        LOG.error("rsa module is not installed, can not generate key file")
        return

    pub_key, pri_key = key.newkeys(key_length)

    # Use O_TRUNC rather than O_APPEND: regenerating must replace an old
    # key, not append a second PEM block after it. The private key is
    # created owner-only. The mode applies only when the file is created
    # and is further masked by the umask.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    key_files = (
        ("pri.key", pri_key.save_pkcs1().decode(), 0o600),
        ("pub.key", pub_key.save_pkcs1().decode(), 0o644),
    )
    for file_name, key_pem, file_mode in key_files:
        # Wrap each descriptor immediately so it is closed even if a later
        # open fails.
        with os.fdopen(os.open(file_name, flags, file_mode), "w") as key_file:
            key_file.write(key_pem)
162
163
def get_file_summary(src_file, algorithm="sha256", buffer_size=100 * 1024):
    """Return the hex digest of a file's contents.

    The file is hashed in ``buffer_size`` chunks, so large files are never
    fully loaded into memory.

    :param src_file: path of the file to hash.
    :param algorithm: any name accepted by ``hashlib.new``, e.g. "sha256".
        SHAKE variants are rejected because they need an explicit digest
        length.
    :param buffer_size: number of bytes read per chunk.
    :return: the hex digest string, or "" (after logging) when the file is
        missing or unreadable, or the algorithm is unsupported.
    """
    if not src_file or not os.path.exists(src_file):
        LOG.error("file '%s' not exists!" % src_file)
        return ""

    # hashlib.new() is the authority on supported names. hasattr(hashlib,
    # ...) also matches helpers such as "new" or "file_digest" and misses
    # algorithms that are only provided by OpenSSL.
    try:
        algorithm_object = hashlib.new(algorithm)
    except (TypeError, ValueError):
        LOG.error("the algorithm '%s' not in hashlib!" % algorithm)
        return ""

    try:
        with open(file=src_file, mode="rb") as src:
            for chunk in iter(lambda: src.read(buffer_size), b""):
                algorithm_object.update(chunk)
    except (OSError, ValueError) as error:
        # OSError covers directories, permission problems and I/O errors.
        LOG.error("read data from '%s' error: %s " % (
            src_file, error.args))
        return ""

    try:
        return algorithm_object.hexdigest()
    except TypeError:
        # Variable-length digests (shake_128/shake_256) need a length.
        LOG.error("the algorithm '%s' not in hashlib!" % algorithm)
        return ""
191
192
193def _get_frags(text, max_len):
194    _text = text
195    while _text:
196        if len(_text) > max_len:
197            frag, _text = _text[:max_len], _text[max_len:]
198        else:
199            frag, _text = _text, ""
200        yield frag
201