1#!/usr/bin/env python3 2# encoding=utf-8 3# ============================================================================ 4# @brief riscv ROM Patch File 5# Copyright (c) 2020 HiSilicon (Shanghai) Technologies CO., LIMITED. 6# Licensed under the Apache License, Version 2.0 (the "License"); 7# you may not use this file except in compliance with the License. 8# You may obtain a copy of the License at 9# 10# http://www.apache.org/licenses/LICENSE-2.0 11# 12# Unless required by applicable law or agreed to in writing, software 13# distributed under the License is distributed on an "AS IS" BASIS, 14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15# See the License for the specific language governing permissions and 16# limitations under the License. 17# ============================================================================ 18import struct 19import ctypes 20import sys 21import os 22import shutil 23import traceback 24 25dir_name = os.path.dirname(os.path.realpath(__file__)) 26info_item = [ 27 "Device_Code_Version", 28 "Patch_Cpu_Core", 29 "Patch_File_Address", 30 "Patch_TBL_Address", 31 "Patch_TBL_Run_Address", 32 "Table_Max_Size", 33 "Table_Reg_Size", 34 "ROM_Address", 35 "ROM_Size", 36 "CMP_Bin_File", 37 "TBL_Bin_File", 38 "RW_Bin_File", 39 "TABLE_REG_CONUT", 40 ] 41 42# The default value is 4. 43# The value will be set based on long jump or short jump for linx131. 44CMP_HEAD_LEN = 3 45g_cmp_total_len = 131 # 128个比较表 + 3个头部信息 46CMP_REG_SIZE = 4 47PATCH_COUNT_REG_INDEX = 2 48DATA_PATCH_COUNT = 2 49pid = str(os.getpid()) 50 51 52# 目录转换 53def get_dir(dir_in): 54 return os.path.join(dir_name, dir_in) 55 56def remove_bin_file(): 57 os.remove(get_dir(pid + "cmp.bin")) if os.path.exists(get_dir(pid + "cmp.bin")) else None 58 os.remove(get_dir(pid + "tbl.bin")) if os.path.exists(get_dir(pid + "tbl.bin")) else None 59 os.remove(get_dir(pid + "rw.bin")) if os.path.exists(get_dir(pid + "rw.bin")) else None 60 61# 转换成bin文件 62def copy_bin_file(str_dst, str_src): 63 try: 64 with open(str_src, "rb")as file_src: 65 try: 66 with open(str_dst, "wb+")as file_dst: 67 byte = file_src.read(1) 68 while byte: 69 file_dst.write(byte) 70 byte = file_src.read(1) 71 file_dst.close() 72 except Exception as e: 73 print("Error: %s Can't Open!" % file_dst) 74 remove_bin_file() 75 sys.exit(1) 76 file_src.close() 77 except Exception as e: 78 print("Error: %s Can't Open!" % str_src) 79 remove_bin_file() 80 sys.exit(1) 81 82 83# 生成bin文件 84def merge_output_file(files): 85 try: 86 reg_size = int(files['Table_Reg_Size']) 87 max_size = int(files['Table_Max_Size']) 88 with open(get_dir(files['RW_Bin_File']), "rb+") as file_rw: 89 try: 90 with open(get_dir(files['TBL_Bin_File']), "rb+")as file_table: 91 data_num = int(files['TABLE_REG_CONUT']) * reg_size 92 data_table = [] 93 for num in range(data_num): 94 data_table.append(0) 95 buff = file_table.read(1) 96 j = 0 97 while buff: 98 data_table[j] = struct.unpack('<B', buff)[0] 99 buff = file_table.read(1) 100 j += 1 101 file_table.close() 102 offset_addr = int(files['Patch_TBL_Address'], 16) - int(files['Patch_File_Address'], 16) +\ 103 DATA_PATCH_COUNT * reg_size 104 file_rw.seek(offset_addr, 0) 105 i = 0 106 while i < len(data_table): 107 byte1 = struct.pack('B', data_table[i]) 108 file_rw.write(byte1) 109 i += 1 110 except Exception as e: 111 print("Error: %s Can't Open!" % file_table) 112 remove_bin_file() 113 sys.exit(1) 114 try: 115 with open(get_dir(files['CMP_Bin_File']), "rb+")as file_cmp: 116 data_num = g_cmp_total_len * CMP_REG_SIZE 117 data_cmp = [] 118 for num1 in range(data_num): 119 data_cmp.append(0) 120 l = 0 121 buff1 = file_cmp.read(1) 122 while buff1: 123 data_cmp[l] = struct.unpack('<B', buff1)[0] 124 buff1 = file_cmp.read(1) 125 l += 1 126 file_cmp.close() 127 offset_addr = int(files['Patch_TBL_Address'], 16) - int(files['Patch_File_Address'], 16) +\ 128 int(files['TABLE_REG_CONUT']) * max_size 129 file_rw.seek(offset_addr, 0) 130 i = 0 131 while i < len(data_cmp): 132 byte2 = struct.pack('B', data_cmp[i]) 133 file_rw.write(byte2) 134 i += 1 135 file_rw.close() 136 except Exception as e: 137 print("Error: %s Can't Open!" % file_cmp) 138 remove_bin_file() 139 sys.exit(1) 140 except Exception as e: 141 print("Error: %s Can't Open!" % file_rw) 142 remove_bin_file() 143 sys.exit(1) 144 145 146# Table or Cmp 147def creat_bin_file(file_name, contents, type_in): 148 cmp_len = len(contents) 149 try: 150 with open(file_name, "wb")as file_o: 151 for i in range(cmp_len): 152 byte = struct.pack('I', contents[i]) 153 file_o.write(byte) 154 file_o.close() 155 except Exception as e: 156 print("Error: %s Can't Open!" % file_name) 157 remove_bin_file() 158 sys.exit(1) 159 160 161# 获取函数对应的二进制文件 162def get_func_code_in_file(files, func_addr, bt_rom_file_in): 163 rom_pack = ['', 0xffffffff] 164 temp_rom_begin = int(files['ROM_Address'], 16) 165 temp_rom_size = int(files['ROM_Size'], 16) 166 if (temp_rom_begin <= func_addr) and ( 167 func_addr < (temp_rom_begin + temp_rom_size)): 168 rom_pack[0] = temp_rom_begin 169 rom_pack[1] = bt_rom_file_in 170 return rom_pack 171 172 173# 获取函数对应的二进制 174def get_func_code_in_bin(rom_addr, file_name, func_addr, align): 175 bin_content = [0, 0, 0, 0] 176 func_addr_offset = func_addr - rom_addr - align 177 try: 178 with open(file_name, "rb") as file_o: 179 file_o.seek(func_addr_offset, 0) 180 for i in range(4): 181 readonce = file_o.read(4) 182 bin_content[i] = struct.unpack('<I', readonce)[0] 183 file_o.close() 184 except Exception as e: 185 print("Error: %s Can't Open!" % file_name) 186 remove_bin_file() 187 sys.exit(1) 188 return bin_content 189 190 191# 获取patch相关信息 192def get_patch_info(patchinfo): 193 all_name = {} 194 try: 195 with open(str(patchinfo), "r", encoding="utf-8") as file_o: 196 source_in_lines = file_o.readlines() 197 for line in source_in_lines: 198 if '[Function]' in line: 199 break 200 if '=' in line: 201 config_key= line.split('=')[0].strip() 202 config_value= line.split('=')[1].strip() 203 if config_key in info_item: 204 all_name[config_key] = config_value 205 # 在bin文件名上加上进程id前缀,防止多进程时出错 206 all_name["CMP_Bin_File"] = pid + all_name["CMP_Bin_File"] 207 all_name["TBL_Bin_File"] = pid + all_name["TBL_Bin_File"] 208 all_name["RW_Bin_File"] = pid + all_name["RW_Bin_File"] 209 return all_name 210 except Exception: 211 print("Error: %s Can't Open!" % patchinfo) 212 remove_bin_file() 213 sys.exit(1) 214 215# 获取patch函数名 216def get_func_name(func_file_name, index): 217 func_names = [] 218 try: 219 with open(func_file_name, "r", encoding='utf-8') as file_o: 220 is_func = False 221 for line in file_o: 222 line = line.strip() # 去掉每行头尾空白 223 if '[Function]' in line: 224 is_func = True 225 continue 226 if not len(line) or line.startswith('#') or not is_func: 227 continue 228 temp = line.split() 229 if len(temp) != 2: 230 print( 231 "Error format file_o:%s,line:%s," % 232 (func_file_name, temp)) 233 func_names.append(temp[index]) 234 except UnicodeDecodeError as e: 235 print("get_func_name catch UnicodeDecodeError: %s" % func_file_name) 236 remove_bin_file() 237 sys.exit(1) 238 except Exception as e: 239 print(traceback.format_exc()) 240 print("Error: %s Can't Open!" % func_file_name) 241 remove_bin_file() 242 sys.exit(1) 243 return func_names 244 245 246# gcc编译器nm文件中所有行有效 247def get_nm_content(file_name): 248 nm_ontent = [] 249 try: 250 with open(file_name, "r", encoding='utf-8') as file_o: 251 nm_lines = file_o.readlines() 252 for nm_line in nm_lines: 253 if nm_line.find(" T ") != -1 or nm_line.find(" t ") != -1 or nm_line.find(" A "): 254 nm_line = nm_line.strip('\n') 255 nm_ontent.append(nm_line) 256 except Exception as e: 257 print("Error: %s Can't Open!" % file_name) 258 remove_bin_file() 259 sys.exit(1) 260 return nm_ontent 261 262def get_func_addr(func_names, m_contents, compiler_name): 263 index = 0 264 func_addrs = [] 265 while index < len(func_names): 266 index_temp = 0 267 for m_content in m_contents: 268 temp = m_content.split('|') 269 if compiler_name == "GCC" and func_names[index] == temp[0].strip(): 270 func_addrs.append(temp[1].strip()) 271 if int(temp[1], 16) < 4: 272 print( 273 "Error: %s Length is %s , it can not be less than 4 " 274 "bytes!" % (temp[3], temp[1])) 275 sys.exit(1) 276 elif int(temp[1], 16) == 4: 277 print( 278 "Warning: %s Length is %s bytes, it may be unsafe!" % 279 (temp[3], temp[1])) 280 break 281 index_temp += 1 282 if index_temp == len(m_contents): 283 print( 284 "Error: %s Function Can't Find in Map File !" % 285 func_names[index]) 286 sys.exit(1) 287 index += 1 288 return func_addrs 289 290 291# Ctrl寄存器 + Remap寄存器 + CMP数量 + CMP表 292def get_cmp_content(func_addrs, patch_tbl_addr, version): 293 cmp_content = [] 294 if version == "Version1": 295 cmp_content.append(0) 296 cmp_content.append(int(patch_tbl_addr, 16)) 297 cmp_content.append(0) 298 index = 0 299 300 while index < len(func_addrs): 301 func_addr = int(func_addrs[index], 16) 302 func_addr = func_addr & (~0x01) # 地址最后一位为指令标记位,清除 303 if version == "Version1": 304 func_addr |= 0x1 305 cmp_content.append(func_addr) 306 index += 1 307 cmp_count = len(cmp_content) - CMP_HEAD_LEN 308 if version == "Version1" or version == "Version2": 309 if version == "Version1": 310 cmp_content[PATCH_COUNT_REG_INDEX] = cmp_count 311 if len(cmp_content) > g_cmp_total_len: # 128个比较表 + 头部信息 312 print("Error: CMP Packet is larger than CMP Reg Capacitance") 313 while cmp_count < g_cmp_total_len - 3: 314 cmp_content.append(0) 315 cmp_count += 1 316 317 return cmp_content 318 319 320def get_table_content_for_short_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in): 321 table_content = [] 322 bit0_to6 = 0x6F 323 bit7_to11 = 0x5 << 7 324 325 func_num = len(func_addrs) 326 index = 0 327 while index < func_num: 328 func_addr = int(func_addrs[index], 16) 329 func_patch_addr = int(func_patch_addrs[index], 16) 330 331 off_addr = func_patch_addr - func_addr 332 333 off_bit1_to10 = (off_addr & 0x7fe) >> 1 334 off_bit12_to19 = (off_addr & 0xff000) >> 12 335 off_bit11 = (off_addr & 0x800) >> 11 336 off_bit20 = (off_addr & 0x100000) >> 20 337 338 bit_code = bit0_to6 + bit7_to11 + (off_bit12_to19 << 12) + (off_bit11 << 20) + (off_bit1_to10 << 21) + (off_bit20 << 31) 339 table_content.append(bit_code) 340 index += 1 341 342 table_count = len(table_content) 343 if table_count > int(files['TABLE_REG_CONUT']): # 128个比较表 344 print("Error: TABLE Packet is larger than CMP Reg Capacitance") 345 sys.exit(1) 346 while table_count < int(files['TABLE_REG_CONUT']): 347 table_content.append(0) 348 table_count += 1 349 350 return table_content 351 352def get_table_content_for_long_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in): 353 table_content = [] 354 auipc_opt_bits = 0x17 355 jalr_opt_bits = 0x67 356 base_addr_bits = 0x6 # x6 357 jalr_bit12_to14 = 0x0 << 12 358 jalr_bit7_to11 = 0x0 << 7 359 360 func_num = len(func_addrs) 361 index = 0 362 while index < func_num: 363 func_addr = int(func_addrs[index], 16) 364 func_patch_addr = int(func_patch_addrs[index], 16) 365 366 off_addr = func_patch_addr - func_addr 367 off_bit12_to31 = off_addr & 0xfffff000 368 off_bit0_to11 = off_addr & 0xfff 369 if off_bit0_to11 > 0x7FF: 370 off_bit12_to31 = off_bit12_to31 + 0x1000 371 off_bit0_to11 = 0x1000 - off_bit0_to11 372 off_bit0_to11 = (~off_bit0_to11 + 1) & 0xfff 373 374 auipc_bit_code = auipc_opt_bits + (base_addr_bits << 7) + off_bit12_to31 375 table_content.append(auipc_bit_code) 376 377 jalr_bit_code = jalr_opt_bits + jalr_bit7_to11 + jalr_bit12_to14 + (base_addr_bits << 15) + (off_bit0_to11 << 20) 378 table_content.append(jalr_bit_code) 379 380 index += 1 381 382 table_count = len(table_content) 383 if table_count > 2 * int(files['TABLE_REG_CONUT']): # 128个比较表 384 print("Error: TABLE Packet is larger than CMP Reg Capacitance") 385 sys.exit(1) 386 while table_count < 2 * int(files['TABLE_REG_CONUT']): 387 table_content.append(0) 388 table_content.append(0) 389 table_count += 2 390 391 return table_content 392 393def create_patch(patch_info, nm_file_in, rom_bin_file): 394 global g_cmp_total_len 395 file_all = get_patch_info(patch_info) 396 g_cmp_total_len = int(file_all['TABLE_REG_CONUT']) + 3 397 core = file_all['Patch_Cpu_Core'] 398 funs = get_func_name(patch_info, 0) 399 funs_patch = get_func_name(patch_info, 1) 400 nm_contents = get_nm_content(nm_file_in) 401 func_addrs = get_func_addr(funs, nm_contents, "GCC") 402 func_patch_addrs = get_func_addr(funs_patch, nm_contents, "GCC") 403 cmp_contents = get_cmp_content(func_addrs, file_all['Patch_TBL_Run_Address'], file_all['Device_Code_Version']) 404 reg_size = int(file_all['Table_Reg_Size']) 405 if reg_size == 4: 406 table_contents = get_table_content_for_short_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file) 407 elif reg_size == 8: 408 table_contents = get_table_content_for_long_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file) 409 else: 410 print("ErrorCore %s for rom patch" % core) 411 creat_bin_file(get_dir(file_all['CMP_Bin_File']), cmp_contents, file_all['CMP_Bin_File']) 412 creat_bin_file(get_dir(file_all['TBL_Bin_File']), table_contents, file_all['TBL_Bin_File']) 413 414def output_bin_file(file_all, output_dir_in, ram_bin_file): 415 if os.path.exists(os.path.join(output_dir_in, ram_bin_file)): 416 shutil.move(os.path.join(output_dir_in, ram_bin_file), os.path.join(output_dir_in, "unpatch.bin")) 417 if os.path.exists(os.path.join(output_dir_in, "rw.bin")): 418 os.remove(os.path.join(output_dir_in, "rw.bin")) 419 shutil.move(get_dir(file_all['RW_Bin_File']), os.path.join(output_dir_in, ram_bin_file)) 420 os.remove(get_dir(file_all['CMP_Bin_File'])) 421 os.remove(get_dir(file_all['TBL_Bin_File'])) 422 print("Generating %s..." % ram_bin_file) 423 424 425def get_patch_addr(patch_info, ram_bin_file, output_dir_in): 426 file_all = get_patch_info(patch_info) 427 funs = get_func_name(patch_info, 0) 428 funs_patch = get_func_name(patch_info, 1) 429 copy_bin_file(get_dir(file_all['RW_Bin_File']), ram_bin_file) 430 merge_output_file(file_all) 431 output_bin_file(file_all, output_dir_in, ram_bin_file) 432 433if __name__ == "__main__": 434 if(len(sys.argv) == 8): 435 ram_bin_file = sys.argv[1] 436 rom_bin_file = sys.argv[2] 437 nm_file = sys.argv[3] 438 partch_config_dir = sys.argv[4] 439 core = sys.argv[5] 440 target_name = sys.argv[6] 441 output_dir = sys.argv[7] 442 443 if os.path.exists(os.path.join(partch_config_dir, f'{target_name}.cfg')): 444 patch_info = os.path.join(partch_config_dir, f'{target_name}.cfg') 445 else: 446 patch_info = os.path.join(partch_config_dir, f'{core}.cfg') 447 448 create_patch(patch_info, nm_file, rom_bin_file) 449 get_patch_addr(patch_info, ram_bin_file, output_dir) 450 451 else: 452 print( 453 "Usage: %s <xx.bin> <xx_rom.bin> <xx.nm> <patch_confi_dir> <core> <target_name>" 454 "<output_dir>" % os.path.basename(sys.argv[0])) 455 sys.exit(1) 456