#
# Copyright (c) 2024, Arm Limited and Contributors. All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#

import sys
import re
from pydevicetree.ast import CellArray, LabelReference
from pydevicetree import Devicetree, Property, Node
from pathlib import Path
from typing import List, Optional


class COT:
    """Chain of Trust (CoT) described by a devicetree source file.

    The devicetree is parsed with pydevicetree. This class can validate it
    (validate_nodes), draw it as an interactive tree (tree_visualization)
    and translate it into a C CoT descriptor file (generate_c_file).
    """

    def __init__(self, inputfile: str, outputfile=None):
        """Parse *inputfile* and normalise root certificates.

        Args:
            inputfile: path of the CoT devicetree source to parse.
            outputfile: path of the C file written by generate_c_file().

        Exits the process with status 1 when *inputfile* cannot be parsed.
        """
        try:
            self.tree = Devicetree.parseFile(inputfile)
        except Exception:
            # Not a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
            print("not a valid CoT DT file")
            sys.exit(1)

        self.output = outputfile
        self.input = inputfile
        # Set to True by validate_cert() when a root certificate is seen.
        self.has_root = False

        # Edge case: a root certificate without an explicit signing-key gets
        # one pointing at the generic ``subject_pk`` parameter (declared by
        # param_to_c()).
        for c in self.get_all_certificates():
            if self.if_root(c) and not c.get_fields("signing-key"):
                c.properties.append(Property("signing-key", CellArray([LabelReference("subject_pk")])))

    # ------------------------------------------------------------------
    # HTML hover text used by tree_visualization()
    # ------------------------------------------------------------------

    def print_cert_info(self, node: Node) -> str:
        """Return an HTML summary of certificate *node* for hover text."""
        img_id = node.get_field("image-id").values[0].replace('"', "")
        sign_key = self.get_sign_key(node)
        nv = self.get_nv_ctr(node)

        info = "<b>name:</b> {}<br><b>image-id:</b> {}<br>".format(node.name, img_id)
        if self.if_root(node):
            info += "<b>root-certificate</b><br>"
        if sign_key:
            info += "<b>signing-key:</b> " + self.extract_label(sign_key) + "<br>"
        if nv:
            info += "<b>nv counter:</b> " + self.extract_label(nv) + "<br>"
        return info

    def print_data_info(self, node: Node) -> str:
        """Return an HTML summary (name and oid) of a key / NV counter node."""
        oid = node.get_field("oid")
        return "<b>name:</b> {}<br><b>oid:</b> {}<br>".format(node.name, oid)

    def print_img_info(self, node: Node) -> str:
        """Return an HTML summary (name, image-id, hash label) of an image node."""
        hash_label = self.extract_label(node.get_fields("hash"))
        img_id = node.get_field("image-id").values[0].replace('"', "")
        return "<b>name:</b> {}<br><b>image-id:</b> {}<br><b>hash:</b> {}" \
            .format(node.name, img_id, hash_label)

    # ------------------------------------------------------------------
    # Tree layout and visualization
    # ------------------------------------------------------------------

    def tree_width(self, parent_set, root) -> int:
        """Return the largest number of nodes on a single level of *root*'s subtree.

        *parent_set* maps a node name to the list of its child names. The
        level containing *root* itself counts as width 1.
        """
        widest = 1
        level = [root]
        while level:
            level = [child for name in level for child in parent_set[name]]
            widest = max(widest, len(level))
        return widest

    def resolve_lay(self, parent_set, lay, name_idx, root, bounds, break_name):
        """Assign x coordinates to the subtree rooted at *root*, in place.

        The interval *bounds* ([left, right]) is split among the children
        of *root*. Each child's share is proportional to its tree_width().
        Each child is centred in its slot (written to ``lay[name_idx[child]][0]``),
        and the procedure then recurses into that slot. Children whose slot
        is narrower than 0.28 are added to *break_name*, so the caller can
        wrap their label.
        """
        children = parent_set[root]
        if not children:
            return

        widths = [self.tree_width(parent_set, c) for c in children]
        # Every width is >= 1, so the division is safe.
        interval = (bounds[1] - bounds[0]) / sum(widths)

        start = bounds[0]
        for c, w in zip(children, widths):
            end = start + interval * w
            lay[name_idx[c]][0] = start + (end - start) / 2
            if end - start < 0.28:
                break_name.add(c)
            self.resolve_lay(parent_set, lay, name_idx, c, [start, end], break_name)
            start = end

    def tree_visualization(self):
        """Display the CoT as an interactive tree using igraph and plotly.

        A synthetic "CoT" root holds three children: the certificates
        container and, when present, the trusted-key and NV-counter
        containers. Certificates and images hang below their ``parent``.
        Hovering a node shows the print_*_info() summary.
        Expects the ``manifests`` and ``images`` nodes to exist.
        """
        from igraph import Graph
        import collections
        import plotly.graph_objects as go

        cert = self.get_certificates()
        pk = self.get_rot_keys()
        nv = self.get_nv_counters()
        image = self.get_images()

        certs = cert.children
        pks = pk.children if pk else []
        nvs = nv.children if nv is not None else []
        images = image.children

        root_name = "CoT"

        G = Graph()
        detail = []      # hover text, one entry per vertex (vertex order)
        lay = []         # [x, y] per vertex; x is filled in by resolve_lay()
        name_idx = {}    # vertex name -> index into detail/lay
        parent_set = collections.defaultdict(list)  # name -> child names

        def add_node(name, info, pos, parent=None):
            # Register a vertex and, when *parent* is given, the parent->name edge.
            G.add_vertex(name)
            detail.append(info)
            name_idx[name] = len(lay)
            lay.append(pos)
            if parent is not None:
                G.add_edge(parent, name)
                parent_set[parent].append(name)

        def below(parent):
            # One level below *parent*; x is recomputed by resolve_lay().
            return [0, lay[name_idx[parent]][1] + 1]

        add_node(root_name, "CoT Root", [0, 0])
        add_node(cert.name, "All Certificates", [0, 1], root_name)
        if pk:
            add_node(pk.name, "All Public Trusted Key", [-2.0, 1], root_name)
        if nv is not None:
            add_node(nv.name, "All NV Counters", [2.0, 1], root_name)

        for p in pks:
            add_node(p.name, self.print_data_info(p), below(pk.name), pk.name)

        for c in certs:
            if self.if_root(c):
                add_node(c.name, self.print_cert_info(c), [0, 2], cert.name)
            else:
                parent = self.extract_label(c.get_fields("parent"))
                add_node(c.name, self.print_cert_info(c), below(parent), parent)

        for img in images:
            parent = self.extract_label(img.get_fields("parent"))
            add_node(img.name, self.print_img_info(img), below(parent), parent)

        for n in nvs:
            add_node(n.name, self.print_data_info(n), below(nv.name), nv.name)

        break_name = set()
        self.resolve_lay(parent_set, lay, name_idx, root_name, [-3, 3], break_name)

        vertices = G.get_vertex_dataframe()
        num_vertex = len(vertices)

        # Wrap long labels of nodes that were squeezed into a narrow slot.
        v_label = []
        for name in vertices['name']:
            if name in break_name and len(name) > 10:
                middle = len(name) // 2
                v_label.append(name[:middle] + "<br>" + name[middle:])
            else:
                v_label.append(name)

        position = {k: lay[k] for k in range(num_vertex)}
        # y grows downwards in ``lay``; flip it so the root is drawn on top.
        M = max(lay[k][1] for k in range(num_vertex))

        Xn = [position[k][0] for k in range(num_vertex)]
        Yn = [2 * M - position[k][1] for k in range(num_vertex)]
        Xe = []
        Ye = []
        for src, dst in (e.tuple for e in G.es):
            # None separates consecutive segments in a single plotly line trace.
            Xe += [position[src][0], position[dst][0], None]
            Ye += [2 * M - position[src][1], 2 * M - position[dst][1], None]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=Xe,
                                 y=Ye,
                                 mode='lines',
                                 line=dict(color='rgb(210,210,210)', width=2),
                                 hoverinfo='none'
                                 ))
        fig.add_trace(go.Scatter(x=Xn,
                                 y=Yn,
                                 mode='markers',
                                 name='detail',
                                 marker=dict(symbol='circle-dot',
                                             size=50,
                                             color='rgba(135, 206, 250, 0.8)',
                                             line=dict(color='MediumPurple', width=3)
                                             ),
                                 text=detail,
                                 hoverinfo='text',
                                 hovertemplate='<b>Detail</b><br>'
                                               '%{text}',
                                 opacity=0.8
                                 ))

        def make_annotations(pos, text, font_size=10, font_color='rgb(0,0,0)'):
            # One text annotation per vertex, placed at the (flipped) vertex position.
            if len(text) != len(pos):
                raise ValueError('The lists pos and text must have the same len')
            return [
                dict(
                    text=text[k],
                    x=pos[k][0], y=2 * M - pos[k][1],
                    xref='x1', yref='y1',
                    font=dict(color=font_color, size=font_size),
                    showarrow=False)
                for k in range(len(pos))
            ]

        axis = dict(showline=False,  # hide axis line, grid, ticklabels and title
                    zeroline=False,
                    showgrid=False,
                    showticklabels=False,
                    )

        fig.update_layout(title='CoT Device Tree',
                          annotations=make_annotations(position, v_label),
                          font_size=12,
                          showlegend=False,
                          xaxis=axis,
                          yaxis=axis,
                          margin=dict(l=40, r=40, b=85, t=100),
                          hovermode='closest',
                          plot_bgcolor='rgb(248,248,248)'
                          )

        fig.show()

    # ------------------------------------------------------------------
    # Property / node accessors
    # ------------------------------------------------------------------

    def if_root(self, node: Node) -> bool:
        """Return True when *node* carries a ``root-certificate`` property."""
        return any(p.name == "root-certificate" for p in node.properties)

    def _get_prop_values(self, node: Node, prop_name: str):
        """Return ``values`` of the first property of *node* named *prop_name*, or None."""
        for p in node.properties:
            if p.name == prop_name:
                return p.values
        return None

    def get_sign_key(self, node: Node):
        """Return the values of *node*'s ``signing-key`` property, or None."""
        return self._get_prop_values(node, "signing-key")

    def get_nv_ctr(self, node: Node):
        """Return the values of *node*'s ``antirollback-counter`` property, or None."""
        return self._get_prop_values(node, "antirollback-counter")

    def extract_label(self, label) -> Optional[str]:
        """Return the label name of the first element of *label*.

        *label* is the property value list of a label-reference property
        (e.g. ``parent = <&cert>``). A falsy *label* (e.g. None for a
        missing property) is returned unchanged.
        """
        if not label:
            return label
        return label[0].label.name

    def get_auth_data(self, node: Node) -> List[Node]:
        """Return the authenticated-data child nodes of certificate *node*."""
        return node.children

    def format_auth_data_val(self, node: Node, cert: Node):
        """Return ``(type_desc, ptr, data_len)`` C names for auth-data *node*.

        *type_desc* is the node name. *ptr* is the name of its static
        buffer. *data_len* is PK_DER_LEN for ``*_pk`` nodes and
        HASH_DER_LEN otherwise.
        """
        type_desc = node.name
        ptr = type_desc + "_buf"
        data_len = "PK_DER_LEN" if re.search(r"_pk$", type_desc) else "HASH_DER_LEN"

        # Edge case: content public keys of non-root key certificates all
        # share a single ``content_pk_buf`` buffer.
        if not self.if_root(cert) and "key_cert" in cert.name and "content_pk" in ptr:
            ptr = "content_pk_buf"

        return type_desc, ptr, data_len

    def get_node(self, nodes: List[Node], name: str) -> Optional[Node]:
        """Return the first node of *nodes* named *name*, or None."""
        return next((n for n in nodes if n.name == name), None)

    def _get_cot_child(self, name: str) -> Optional[Node]:
        """Return child *name* of the first top-level ``cot`` node, or None."""
        cot = self.get_node(self.tree.children, "cot")
        if cot is None:
            return None
        return self.get_node(cot.children, name)

    def get_certificates(self) -> Optional[Node]:
        """Return the ``cot/manifests`` container node, or None."""
        return self._get_cot_child("manifests")

    def get_images(self) -> Optional[Node]:
        """Return the ``cot/images`` container node, or None."""
        return self._get_cot_child("images")

    def get_nv_counters(self) -> Optional[Node]:
        """Return the top-level ``non_volatile_counters`` node, or None."""
        return self.get_node(self.tree.children, "non_volatile_counters")

    def get_rot_keys(self) -> Optional[Node]:
        """Return the top-level ``rot_keys`` node, or None."""
        return self.get_node(self.tree.children, "rot_keys")

    def get_all_certificates(self) -> List[Node]:
        """Return all certificate nodes ([] when there is no manifests node)."""
        cert = self.get_certificates()
        return cert.children if cert is not None else []

    def get_all_images(self) -> List[Node]:
        """Return all image nodes ([] when there is no images node)."""
        image = self.get_images()
        return image.children if image is not None else []

    def get_all_nv_counters(self) -> List[Node]:
        """Return all NV counter nodes ([] when there is no counters node)."""
        nv = self.get_nv_counters()
        return nv.children if nv is not None else []

    def get_all_pks(self) -> List[Node]:
        """Return all trusted public-key nodes ([] when there is no rot_keys node)."""
        pk = self.get_rot_keys()
        if not pk:
            return []
        return pk.children

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def _parent_exists(self, node: Node) -> bool:
        """Return True when *node*'s ``parent`` label names an existing certificate."""
        parent = self.extract_label(node.get_fields("parent"))
        return any(c.name == parent for c in self.get_all_certificates())

    def validate_cert(self, node: Node) -> bool:
        """Check the mandatory attributes of certificate *node*.

        Prints one message per problem found. Sets ``self.has_root`` when
        *node* is a root certificate. Returns True when no problem was found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False

        if node.has_field("root-certificate"):
            self.has_root = True
        elif not node.has_field("parent"):
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False
        elif not self._parent_exists(node):
            print("{} refer to non existing parent".format(node.name))
            valid = False

        for c in node.children:
            if not c.has_field("oid"):
                print("{} missing mandatory attribute oid".format(c.name))
                valid = False

        return valid

    def validate_img(self, node: Node) -> bool:
        """Check the mandatory attributes of image *node*.

        Prints one message per problem found. Returns True when no problem
        was found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False

        has_parent = node.has_field("parent")
        if not has_parent:
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False

        if not node.has_field("hash"):
            print("{} missing mandatory attribute hash".format(node.name))
            valid = False

        # Only check the reference when there is one; a missing parent was
        # already reported above.
        if has_parent and not self._parent_exists(node):
            print("{} refer to non existing parent".format(node.name))
            valid = False

        return valid

    def validate_nodes(self) -> bool:
        """Validate every certificate and image node.

        Returns False when any node is invalid or when no root certificate
        exists.
        """
        # Recomputed on every call by validate_cert().
        self.has_root = False
        valid = True

        for n in self.get_all_certificates():
            node_valid = self.validate_cert(n)
            valid = valid and node_valid

        for i in self.get_all_images():
            node_valid = self.validate_img(i)
            valid = valid and node_valid

        if not self.has_root:
            print("missing root certificate")
            # Without a root certificate the chain cannot be anchored.
            valid = False

        return valid

    # ------------------------------------------------------------------
    # C code generation
    # ------------------------------------------------------------------

    def include_to_c(self, f):
        """Write the #include lines of the generated C file to *f*."""
        f.write("#include <stddef.h>\n")
        f.write("#include <mbedtls/version.h>\n")
        f.write("#include <common/tbbr/cot_def.h>\n")
        f.write("#include <drivers/auth/auth_mod.h>\n")
        f.write("#include <platform_def.h>\n\n")

    def generate_header(self, output):
        """Write the header section (currently the includes) to *output*."""
        self.include_to_c(output)

    def all_cert_to_c(self, f):
        """Write the descriptor of every certificate to *f*."""
        for c in self.get_all_certificates():
            self.cert_to_c(c, f)

        f.write("\n")

    def cert_to_c(self, node: Node, f):
        """Write the ``auth_img_desc_t`` definition of certificate *node* to *f*.

        The auth methods are a signature check (``signing-key``) at index 0
        and an NV-counter check (``antirollback-counter``) at index 1; each
        is emitted only when the property exists. The authenticated data
        lists one entry per child node of the certificate.
        """
        node_image_id = node.get_field("image-id")

        f.write(f"static const auth_img_desc_t {node.name} = {{\n")
        f.write(f"\t.img_id = {node_image_id},\n")
        f.write("\t.img_type = IMG_CERT,\n")

        if self.if_root(node):
            f.write("\t.parent = NULL,\n")
        else:
            node_parent = node.get_field("parent")
            f.write(f"\t.parent = &{node_parent.label.name},\n")

        sign = self.get_sign_key(node)
        nv_ctr = self.get_nv_ctr(node)

        if sign or nv_ctr:
            f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")

            if sign:
                f.write("\t\t[0] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_SIG,\n")
                f.write("\t\t\t.param.sig = {\n")
                f.write("\t\t\t\t.pk = &{},\n".format(self.extract_label(sign)))
                f.write("\t\t\t\t.sig = &sig,\n")
                f.write("\t\t\t\t.alg = &sig_alg,\n")
                f.write("\t\t\t\t.data = &raw_data\n")
                f.write("\t\t\t}\n")
                f.write("\t\t}}{}\n".format("," if nv_ctr else ""))

            if nv_ctr:
                f.write("\t\t[1] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_NV_CTR,\n")
                f.write("\t\t\t.param.nv_ctr = {\n")
                f.write("\t\t\t\t.cert_nv_ctr = &{},\n".format(self.extract_label(nv_ctr)))
                f.write("\t\t\t\t.plat_nv_ctr = &{}\n".format(self.extract_label(nv_ctr)))
                f.write("\t\t\t}\n")
                f.write("\t\t}\n")

            # Close .img_auth_methods only when it was opened above.
            f.write("\t},\n")

        auth_data = self.get_auth_data(node)
        if auth_data:
            f.write("\t.authenticated_data = (const auth_param_desc_t[COT_MAX_VERIFIED_PARAMS]) {\n")

            last = len(auth_data) - 1
            for i, d in enumerate(auth_data):
                type_desc, ptr, data_len = self.format_auth_data_val(d, node)

                f.write(f"\t\t[{i}] = {{\n")
                f.write(f"\t\t\t.type_desc = &{type_desc},\n")
                f.write("\t\t\t.data = {\n")
                f.write(f"\t\t\t\t.ptr = (void *){ptr},\n")
                f.write(f"\t\t\t\t.len = (unsigned int){data_len}\n")
                f.write("\t\t\t}\n")
                f.write("\t\t}}{}\n".format("," if i != last else ""))

            # Close .authenticated_data only when it was opened above.
            f.write("\t}\n")

        f.write("};\n\n")

    def img_to_c(self, node: Node, f):
        """Write the ``auth_img_desc_t`` definition of raw image *node* to *f*.

        A single AUTH_METHOD_HASH method is emitted. It pairs ``raw_data``
        with the parameter referenced by the node's ``hash`` property.
        """
        node_image_id = node.get_field("image-id")
        node_parent = node.get_field("parent")
        node_hash = node.get_field("hash")

        f.write(f"static const auth_img_desc_t {node.name} = {{\n")
        f.write(f"\t.img_id = {node_image_id},\n")
        f.write("\t.img_type = IMG_RAW,\n")
        f.write(f"\t.parent = &{node_parent.label.name},\n")
        f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")

        f.write("\t\t[0] = {\n")
        f.write("\t\t\t.type = AUTH_METHOD_HASH,\n")
        f.write("\t\t\t.param.hash = {\n")
        f.write("\t\t\t\t.data = &raw_data,\n")
        f.write(f"\t\t\t\t.hash = &{node_hash.label.name}\n")
        f.write("\t\t\t}\n")

        f.write("\t\t}\n")
        f.write("\t}\n")
        f.write("};\n\n")

    def all_img_to_c(self, f):
        """Write the descriptor of every image to *f*."""
        for i in self.get_all_images():
            self.img_to_c(i, f)

        f.write("\n")

    def nv_to_c(self, f):
        """Declare one AUTH_PARAM_NV_CTR type descriptor per NV counter node."""
        for nv in self.get_all_nv_counters():
            nv_oid = nv.get_field("oid")

            f.write(f"static auth_param_type_desc_t {nv.name} = "
                    f"AUTH_PARAM_TYPE_DESC(AUTH_PARAM_NV_CTR, \"{nv_oid}\");\n")

        f.write("\n")

    def pk_to_c(self, f):
        """Declare one AUTH_PARAM_PUB_KEY type descriptor per trusted key node."""
        for p in self.get_all_pks():
            pk_oid = p.get_field("oid")

            f.write(f"static auth_param_type_desc_t {p.name} = "
                    f"AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, \"{pk_oid}\");\n")

        f.write("\n")

    def buf_to_c(self, f):
        """Declare the static buffers backing the certificates' authenticated data.

        Each buffer name is declared once, even when several certificates
        share it (see format_auth_data_val()).
        """
        buffers = set()

        for c in self.get_all_certificates():
            for a in self.get_auth_data(c):
                _, ptr, data_len = self.format_auth_data_val(a, c)

                if ptr not in buffers:
                    f.write(f"static unsigned char {ptr}[{data_len}];\n")
                    buffers.add(ptr)

        f.write("\n")

    def param_to_c(self, f):
        """Declare the generic parameters and one descriptor per auth-data node.

        Auth-data node names must end in ``_pk`` (AUTH_PARAM_PUB_KEY) or
        ``_hash`` (AUTH_PARAM_HASH).

        Raises:
            ValueError: if an auth-data node name has neither suffix.
        """
        f.write("static auth_param_type_desc_t subject_pk = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, 0);\n")
        f.write("static auth_param_type_desc_t sig = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG, 0);\n")
        f.write("static auth_param_type_desc_t sig_alg = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG_ALG, 0);\n")
        f.write("static auth_param_type_desc_t raw_data = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_RAW_DATA, 0);\n")
        f.write("\n")

        for c in self.get_all_certificates():
            for data in c.children:
                name = data.name
                oid = data.get_field("oid")

                if re.search(r"_pk$", name):
                    ty = "AUTH_PARAM_PUB_KEY"
                elif re.search(r"_hash$", name):
                    ty = "AUTH_PARAM_HASH"
                else:
                    # Previously ``ty`` was left unbound or silently reused
                    # from the previous node, generating a wrong C type.
                    raise ValueError(
                        "{}: cannot infer parameter type of {} "
                        "(name must end in _pk or _hash)".format(c.name, name))

                f.write(f"static auth_param_type_desc_t {name} = "
                        f"AUTH_PARAM_TYPE_DESC({ty}, \"{oid}\");\n")

        f.write("\n")

    def cot_to_c(self, f):
        """Write the ``cot_desc`` table indexed by image-id and register it."""
        descriptors = list(self.get_all_certificates()) + list(self.get_all_images())

        f.write("static const auth_img_desc_t * const cot_desc[] = {\n")

        for desc in descriptors:
            desc_image_id = desc.get_field("image-id")
            f.write(f"\t[{desc_image_id}] = &{desc.name},\n")

        f.write("};\n\n")
        f.write("REGISTER_COT(cot_desc);\n")

    def generate_c_file(self):
        """Write the complete C CoT descriptor file to ``self.output``.

        Parent directories of the output path are created if needed.
        """
        filename = Path(self.output)
        filename.parent.mkdir(exist_ok=True, parents=True)

        # Write-only: the generated file is never read back here.
        with open(filename, 'w', encoding="utf-8") as output:
            self.generate_header(output)
            self.buf_to_c(output)
            self.param_to_c(output)
            self.nv_to_c(output)
            self.pk_to_c(output)
            self.all_cert_to_c(output)
            self.all_img_to_c(output)
            self.cot_to_c(output)