• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# encoding=utf-8
3# ============================================================================
4# @brief    riscv ROM Patch File
5# Copyright (c) 2020 HiSilicon (Shanghai) Technologies CO., LIMITED.
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17# ============================================================================
18import struct
19import ctypes
20import sys
21import os
22import shutil
23import traceback
24
25dir_name = os.path.dirname(os.path.realpath(__file__))
26info_item = [
27    "Device_Code_Version",
28    "Patch_Cpu_Core",
29    "Patch_File_Address",
30    "Patch_TBL_Address",
31    "Patch_TBL_Run_Address",
32    "Table_Max_Size",
33    "Table_Reg_Size",
34    "ROM_Address",
35    "ROM_Size",
36    "CMP_Bin_File",
37    "TBL_Bin_File",
38    "RW_Bin_File",
39    "TABLE_REG_CONUT",
40    ]
41
42# The default value is 4.
43# The value will be set based on long jump or short jump for linx131.
44CMP_HEAD_LEN = 3
45g_cmp_total_len = 131 # 128个比较表 + 3个头部信息
46CMP_REG_SIZE = 4
47PATCH_COUNT_REG_INDEX = 2
48DATA_PATCH_COUNT = 2
49pid = str(os.getpid())
50
51
52# 目录转换
53def get_dir(dir_in):
54    return os.path.join(dir_name, dir_in)
55
56def remove_bin_file():
57    os.remove(get_dir(pid + "cmp.bin")) if os.path.exists(get_dir(pid + "cmp.bin")) else None
58    os.remove(get_dir(pid + "tbl.bin")) if os.path.exists(get_dir(pid + "tbl.bin")) else None
59    os.remove(get_dir(pid + "rw.bin")) if os.path.exists(get_dir(pid + "rw.bin")) else None
60
61# 转换成bin文件
62def copy_bin_file(str_dst, str_src):
63    try:
64        with open(str_src, "rb")as file_src:
65            try:
66                with open(str_dst, "wb+")as file_dst:
67                    byte = file_src.read(1)
68                    while byte:
69                        file_dst.write(byte)
70                        byte = file_src.read(1)
71                    file_dst.close()
72            except Exception as e:
73                print("Error: %s Can't Open!" % file_dst)
74                remove_bin_file()
75                sys.exit(1)
76            file_src.close()
77    except Exception as e:
78        print("Error: %s Can't Open!" % str_src)
79        remove_bin_file()
80        sys.exit(1)
81
82
83# 生成bin文件
84def merge_output_file(files):
85    try:
86        reg_size = int(files['Table_Reg_Size'])
87        max_size = int(files['Table_Max_Size'])
88        with open(get_dir(files['RW_Bin_File']), "rb+") as file_rw:
89            try:
90                with open(get_dir(files['TBL_Bin_File']), "rb+")as file_table:
91                    data_num = int(files['TABLE_REG_CONUT']) * reg_size
92                    data_table = []
93                    for num in range(data_num):
94                        data_table.append(0)
95                    buff = file_table.read(1)
96                    j = 0
97                    while buff:
98                        data_table[j] = struct.unpack('<B', buff)[0]
99                        buff = file_table.read(1)
100                        j += 1
101                    file_table.close()
102                    offset_addr = int(files['Patch_TBL_Address'], 16) - int(files['Patch_File_Address'], 16) +\
103                                      DATA_PATCH_COUNT * reg_size
104                    file_rw.seek(offset_addr, 0)
105                    i = 0
106                    while i < len(data_table):
107                        byte1 = struct.pack('B', data_table[i])
108                        file_rw.write(byte1)
109                        i += 1
110            except Exception as e:
111                print("Error: %s Can't Open!" % file_table)
112                remove_bin_file()
113                sys.exit(1)
114            try:
115                with open(get_dir(files['CMP_Bin_File']), "rb+")as file_cmp:
116                    data_num = g_cmp_total_len * CMP_REG_SIZE
117                    data_cmp = []
118                    for num1 in range(data_num):
119                        data_cmp.append(0)
120                    l = 0
121                    buff1 = file_cmp.read(1)
122                    while buff1:
123                        data_cmp[l] = struct.unpack('<B', buff1)[0]
124                        buff1 = file_cmp.read(1)
125                        l += 1
126                    file_cmp.close()
127                    offset_addr = int(files['Patch_TBL_Address'], 16) - int(files['Patch_File_Address'], 16) +\
128                                      int(files['TABLE_REG_CONUT']) * max_size
129                    file_rw.seek(offset_addr, 0)
130                    i = 0
131                    while i < len(data_cmp):
132                        byte2 = struct.pack('B', data_cmp[i])
133                        file_rw.write(byte2)
134                        i += 1
135                    file_rw.close()
136            except Exception as e:
137                print("Error: %s Can't Open!" % file_cmp)
138                remove_bin_file()
139                sys.exit(1)
140    except Exception as e:
141        print("Error: %s Can't Open!" % file_rw)
142        remove_bin_file()
143        sys.exit(1)
144
145
146# Table or Cmp
147def creat_bin_file(file_name, contents, type_in):
148    cmp_len = len(contents)
149    try:
150        with open(file_name, "wb")as file_o:
151            for i in range(cmp_len):
152                byte = struct.pack('I', contents[i])
153                file_o.write(byte)
154        file_o.close()
155    except Exception as e:
156        print("Error: %s Can't Open!" % file_name)
157        remove_bin_file()
158        sys.exit(1)
159
160
161# 获取函数对应的二进制文件
162def get_func_code_in_file(files, func_addr, bt_rom_file_in):
163    rom_pack = ['', 0xffffffff]
164    temp_rom_begin = int(files['ROM_Address'], 16)
165    temp_rom_size = int(files['ROM_Size'], 16)
166    if (temp_rom_begin <= func_addr) and (
167            func_addr < (temp_rom_begin + temp_rom_size)):
168        rom_pack[0] = temp_rom_begin
169        rom_pack[1] = bt_rom_file_in
170    return rom_pack
171
172
173# 获取函数对应的二进制
174def get_func_code_in_bin(rom_addr, file_name, func_addr, align):
175    bin_content = [0, 0, 0, 0]
176    func_addr_offset = func_addr - rom_addr - align
177    try:
178        with open(file_name, "rb") as file_o:
179            file_o.seek(func_addr_offset, 0)
180            for i in range(4):
181                readonce = file_o.read(4)
182                bin_content[i] = struct.unpack('<I', readonce)[0]
183            file_o.close()
184    except Exception as e:
185        print("Error: %s Can't Open!" % file_name)
186        remove_bin_file()
187        sys.exit(1)
188    return bin_content
189
190
191# 获取patch相关信息
192def get_patch_info(patchinfo):
193    all_name = {}
194    try:
195        with open(str(patchinfo), "r", encoding="utf-8") as file_o:
196            source_in_lines = file_o.readlines()
197            for line in source_in_lines:
198                if '[Function]' in line:
199                    break
200                if '=' in line:
201                    config_key= line.split('=')[0].strip()
202                    config_value= line.split('=')[1].strip()
203                    if config_key in info_item:
204                        all_name[config_key] = config_value
205        # 在bin文件名上加上进程id前缀,防止多进程时出错
206        all_name["CMP_Bin_File"] = pid + all_name["CMP_Bin_File"]
207        all_name["TBL_Bin_File"] = pid + all_name["TBL_Bin_File"]
208        all_name["RW_Bin_File"]  = pid + all_name["RW_Bin_File"]
209        return all_name
210    except Exception:
211        print("Error: %s Can't Open!" % patchinfo)
212        remove_bin_file()
213        sys.exit(1)
214
215# 获取patch函数名
216def get_func_name(func_file_name, index):
217    func_names = []
218    try:
219        with open(func_file_name, "r", encoding='utf-8') as file_o:
220            is_func = False
221            for line in file_o:
222                line = line.strip()   # 去掉每行头尾空白
223                if '[Function]' in line:
224                    is_func = True
225                    continue
226                if not len(line) or line.startswith('#') or not is_func:
227                    continue
228                temp = line.split()
229                if len(temp) != 2:
230                    print(
231                        "Error format file_o:%s,line:%s," %
232                        (func_file_name, temp))
233                func_names.append(temp[index])
234    except UnicodeDecodeError as e:
235        print("get_func_name catch UnicodeDecodeError: %s" % func_file_name)
236        remove_bin_file()
237        sys.exit(1)
238    except Exception as e:
239        print(traceback.format_exc())
240        print("Error: %s Can't Open!" % func_file_name)
241        remove_bin_file()
242        sys.exit(1)
243    return func_names
244
245
246# gcc编译器nm文件中所有行有效
247def get_nm_content(file_name):
248    nm_ontent = []
249    try:
250        with open(file_name, "r", encoding='utf-8') as file_o:
251            nm_lines = file_o.readlines()
252            for nm_line in nm_lines:
253                if nm_line.find(" T ") != -1 or nm_line.find(" t ") != -1 or nm_line.find(" A "):
254                    nm_line = nm_line.strip('\n')
255                    nm_ontent.append(nm_line)
256    except Exception as e:
257        print("Error: %s Can't Open!" % file_name)
258        remove_bin_file()
259        sys.exit(1)
260    return nm_ontent
261
262def get_func_addr(func_names, m_contents, compiler_name):
263    index = 0
264    func_addrs = []
265    while index < len(func_names):
266        index_temp = 0
267        for m_content in m_contents:
268            temp = m_content.split('|')
269            if compiler_name == "GCC" and func_names[index] == temp[0].strip():
270                func_addrs.append(temp[1].strip())
271                if int(temp[1], 16) < 4:
272                    print(
273                        "Error: %s Length is %s , it can not be less than 4 "
274                        "bytes!" % (temp[3], temp[1]))
275                    sys.exit(1)
276                elif int(temp[1], 16) == 4:
277                    print(
278                        "Warning: %s Length is %s bytes, it may be unsafe!" %
279                        (temp[3], temp[1]))
280                break
281            index_temp += 1
282            if index_temp == len(m_contents):
283                print(
284                    "Error: %s Function Can't Find in Map File !" %
285                    func_names[index])
286                sys.exit(1)
287        index += 1
288    return func_addrs
289
290
291# Ctrl寄存器 +  Remap寄存器 + CMP数量 + CMP表
292def get_cmp_content(func_addrs, patch_tbl_addr, version):
293    cmp_content = []
294    if version == "Version1":
295        cmp_content.append(0)
296        cmp_content.append(int(patch_tbl_addr, 16))
297        cmp_content.append(0)
298    index = 0
299
300    while index < len(func_addrs):
301        func_addr = int(func_addrs[index], 16)
302        func_addr = func_addr & (~0x01)  # 地址最后一位为指令标记位,清除
303        if version == "Version1":
304            func_addr |= 0x1
305            cmp_content.append(func_addr)
306            index += 1
307    cmp_count = len(cmp_content) - CMP_HEAD_LEN
308    if version == "Version1" or version == "Version2":
309        if version == "Version1":
310            cmp_content[PATCH_COUNT_REG_INDEX] = cmp_count
311        if len(cmp_content) > g_cmp_total_len:  # 128个比较表 + 头部信息
312            print("Error: CMP Packet is larger than CMP Reg Capacitance")
313        while cmp_count < g_cmp_total_len - 3:
314            cmp_content.append(0)
315            cmp_count += 1
316
317    return cmp_content
318
319
320def get_table_content_for_short_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
321    table_content = []
322    bit0_to6 = 0x6F
323    bit7_to11 = 0x5 << 7
324
325    func_num = len(func_addrs)
326    index = 0
327    while index < func_num:
328        func_addr = int(func_addrs[index], 16)
329        func_patch_addr = int(func_patch_addrs[index], 16)
330
331        off_addr = func_patch_addr - func_addr
332
333        off_bit1_to10 = (off_addr & 0x7fe) >> 1
334        off_bit12_to19 = (off_addr & 0xff000) >> 12
335        off_bit11 = (off_addr & 0x800) >> 11
336        off_bit20 = (off_addr & 0x100000) >> 20
337
338        bit_code = bit0_to6 + bit7_to11 + (off_bit12_to19 << 12) + (off_bit11 << 20) + (off_bit1_to10 << 21) + (off_bit20 << 31)
339        table_content.append(bit_code)
340        index += 1
341
342    table_count = len(table_content)
343    if table_count > int(files['TABLE_REG_CONUT']):  # 128个比较表
344        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
345        sys.exit(1)
346    while table_count < int(files['TABLE_REG_CONUT']):
347        table_content.append(0)
348        table_count += 1
349
350    return table_content
351
352def get_table_content_for_long_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
353    table_content = []
354    auipc_opt_bits = 0x17
355    jalr_opt_bits = 0x67
356    base_addr_bits = 0x6 # x6
357    jalr_bit12_to14 = 0x0 << 12
358    jalr_bit7_to11 = 0x0 << 7
359
360    func_num = len(func_addrs)
361    index = 0
362    while index < func_num:
363        func_addr = int(func_addrs[index], 16)
364        func_patch_addr = int(func_patch_addrs[index], 16)
365
366        off_addr = func_patch_addr - func_addr
367        off_bit12_to31 = off_addr & 0xfffff000
368        off_bit0_to11 = off_addr & 0xfff
369        if off_bit0_to11 > 0x7FF:
370            off_bit12_to31 = off_bit12_to31 + 0x1000
371            off_bit0_to11 = 0x1000 - off_bit0_to11
372            off_bit0_to11 = (~off_bit0_to11 + 1) & 0xfff
373
374        auipc_bit_code = auipc_opt_bits + (base_addr_bits << 7) + off_bit12_to31
375        table_content.append(auipc_bit_code)
376
377        jalr_bit_code = jalr_opt_bits + jalr_bit7_to11 + jalr_bit12_to14 + (base_addr_bits << 15) + (off_bit0_to11 << 20)
378        table_content.append(jalr_bit_code)
379
380        index += 1
381
382    table_count = len(table_content)
383    if table_count > 2 * int(files['TABLE_REG_CONUT']):  # 128个比较表
384        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
385        sys.exit(1)
386    while table_count < 2 * int(files['TABLE_REG_CONUT']):
387        table_content.append(0)
388        table_content.append(0)
389        table_count += 2
390
391    return table_content
392
393def create_patch(patch_info, nm_file_in, rom_bin_file):
394    global g_cmp_total_len
395    file_all = get_patch_info(patch_info)
396    g_cmp_total_len = int(file_all['TABLE_REG_CONUT']) + 3
397    core = file_all['Patch_Cpu_Core']
398    funs = get_func_name(patch_info, 0)
399    funs_patch = get_func_name(patch_info, 1)
400    nm_contents = get_nm_content(nm_file_in)
401    func_addrs = get_func_addr(funs, nm_contents, "GCC")
402    func_patch_addrs = get_func_addr(funs_patch, nm_contents, "GCC")
403    cmp_contents = get_cmp_content(func_addrs, file_all['Patch_TBL_Run_Address'], file_all['Device_Code_Version'])
404    reg_size = int(file_all['Table_Reg_Size'])
405    if reg_size == 4:
406        table_contents = get_table_content_for_short_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file)
407    elif reg_size == 8:
408        table_contents = get_table_content_for_long_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file)
409    else:
410        print("ErrorCore %s for rom patch" % core)
411    creat_bin_file(get_dir(file_all['CMP_Bin_File']), cmp_contents, file_all['CMP_Bin_File'])
412    creat_bin_file(get_dir(file_all['TBL_Bin_File']), table_contents, file_all['TBL_Bin_File'])
413
414def output_bin_file(file_all, output_dir_in, ram_bin_file):
415    if os.path.exists(os.path.join(output_dir_in, ram_bin_file)):
416        shutil.move(os.path.join(output_dir_in, ram_bin_file), os.path.join(output_dir_in, "unpatch.bin"))
417    if os.path.exists(os.path.join(output_dir_in, "rw.bin")):
418        os.remove(os.path.join(output_dir_in, "rw.bin"))
419    shutil.move(get_dir(file_all['RW_Bin_File']), os.path.join(output_dir_in, ram_bin_file))
420    os.remove(get_dir(file_all['CMP_Bin_File']))
421    os.remove(get_dir(file_all['TBL_Bin_File']))
422    print("Generating %s..." % ram_bin_file)
423
424
425def get_patch_addr(patch_info, ram_bin_file, output_dir_in):
426    file_all = get_patch_info(patch_info)
427    funs = get_func_name(patch_info, 0)
428    funs_patch = get_func_name(patch_info, 1)
429    copy_bin_file(get_dir(file_all['RW_Bin_File']), ram_bin_file)
430    merge_output_file(file_all)
431    output_bin_file(file_all, output_dir_in, ram_bin_file)
432
433if __name__ == "__main__":
434    if(len(sys.argv) == 8):
435        ram_bin_file = sys.argv[1]
436        rom_bin_file = sys.argv[2]
437        nm_file = sys.argv[3]
438        partch_config_dir = sys.argv[4]
439        core = sys.argv[5]
440        target_name = sys.argv[6]
441        output_dir = sys.argv[7]
442
443        if os.path.exists(os.path.join(partch_config_dir, f'{target_name}.cfg')):
444            patch_info = os.path.join(partch_config_dir, f'{target_name}.cfg')
445        else:
446            patch_info = os.path.join(partch_config_dir, f'{core}.cfg')
447
448        create_patch(patch_info, nm_file, rom_bin_file)
449        get_patch_addr(patch_info, ram_bin_file, output_dir)
450
451    else:
452        print(
453            "Usage: %s <xx.bin> <xx_rom.bin> <xx.nm> <patch_confi_dir> <core> <target_name>"
454            "<output_dir>" % os.path.basename(sys.argv[0]))
455        sys.exit(1)
456