• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
# Copyright (c) 2024, Arm Limited and Contributors. All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#
6
7import sys
8import re
9from pydevicetree.ast import CellArray, LabelReference
10from pydevicetree import Devicetree, Property, Node
11from pathlib import Path
12from typing import List, Optional
13
class COT:
    """Chain of Trust (CoT) model built from a CoT devicetree file.

    The tree is looked up as a ``cot`` node holding ``manifests``
    (certificates) and ``images`` children, plus top-level
    ``non_volatile_counters`` and (optional) ``rot_keys`` nodes. The class
    validates that tree, can plot it, and can emit the equivalent C
    descriptors.
    """

    def __init__(self, inputfile: str, outputfile=None):
        """Parse *inputfile*; print an error and exit(1) if it cannot be parsed.

        Args:
            inputfile: path to the CoT devicetree source.
            outputfile: path of the C file written by generate_c_file().
        """
        try:
            self.tree = Devicetree.parseFile(inputfile)
        except Exception:
            # Not a bare except: Ctrl-C and SystemExit must still propagate.
            print("not a valid CoT DT file")
            sys.exit(1)

        self.output = outputfile
        self.input = inputfile
        self.has_root = False  # set by validate_cert() when a root cert is seen

        # Edge case: a root certificate without an explicit signing key gets
        # the generic `subject_pk` parameter as its signing key.
        for c in self.get_all_certificates():
            if self.if_root(c) and not c.get_fields("signing-key"):
                c.properties.append(Property("signing-key", CellArray([LabelReference("subject_pk")])))

    # ------------------------------------------------------------------ #
    # Hover-text helpers for tree_visualization()
    # ------------------------------------------------------------------ #

    def print_cert_info(self, node: "Node") -> str:
        """Return an HTML summary of certificate *node* (name, image id, keys)."""
        img_id = node.get_field("image-id").values[0].replace('"', "")
        sign_key = self.get_sign_key(node)
        nv = self.get_nv_ctr(node)

        root_part = "<b>root-certificate</b><br>" if self.if_root(node) else ""
        sign_part = f"<b>signing-key:</b> {self.extract_label(sign_key)}<br>" if sign_key else ""
        nv_part = f"<b>nv counter:</b> {self.extract_label(nv)}<br>" if nv else ""
        return (f"<b>name:</b> {node.name}<br><b>image-id:</b> {img_id}<br>"
                f"{root_part}{sign_part}{nv_part}")

    def print_data_info(self, node: "Node") -> str:
        """Return an HTML summary (name and oid) of a key / counter node."""
        oid = node.get_field("oid")
        return f"<b>name:</b> {node.name}<br><b>oid:</b> {oid}<br>"

    def print_img_info(self, node: "Node") -> str:
        """Return an HTML summary (name, image id, hash label) of image *node*."""
        hash_label = self.extract_label(node.get_fields("hash"))
        img_id = node.get_field("image-id").values[0].replace('"', "")
        return f"<b>name:</b> {node.name}<br><b>image-id:</b> {img_id}<br><b>hash:</b> {hash_label}"

    # ------------------------------------------------------------------ #
    # Layout / visualization
    # ------------------------------------------------------------------ #

    def tree_width(self, parent_set, root) -> int:
        """Return the largest number of nodes on any single level below *root*.

        *parent_set* maps a node name to the list of its children's names and
        must yield an empty list for leaves (e.g. ``defaultdict(list)``, as
        built by tree_visualization()). The result is at least 1.
        """
        widest = 1
        level = [root]
        while level:
            level = [child for node in level for child in parent_set[node]]
            widest = max(widest, len(level))
        return widest

    def resolve_lay(self, parent_set, lay, name_idx, root, bounds, break_name):
        """Assign x coordinates to every descendant of *root*, recursively.

        The horizontal range *bounds* ``[lo, hi]`` is split among the children
        of *root* in proportion to each child's tree_width(); each child is
        centred in its slice and its own subtree is laid out inside it.

        Args:
            parent_set: mapping node name -> list of child names.
            lay: per-vertex ``[x, y]`` positions; x is updated in place.
            name_idx: mapping node name -> index into *lay*.
            root: name of the node whose children are placed.
            bounds: ``[lo, hi]`` horizontal range available to *root*'s subtree.
            break_name: set that collects the names of children whose slice
                is narrower than 0.28 (their labels get wrapped later).
        """
        children = parent_set[root]
        if not children:
            return

        widths = [self.tree_width(parent_set, c) for c in children]
        interval = (bounds[1] - bounds[0]) / sum(widths)
        start = bounds[0]
        for child, w in zip(children, widths):
            end = start + interval * w
            lay[name_idx[child]][0] = start + (end - start) / 2
            if end - start < 0.28:
                break_name.add(child)
            self.resolve_lay(parent_set, lay, name_idx, child, [start, end], break_name)
            start = end

    def tree_visualization(self):
        """Plot the CoT as an interactive tree with igraph and plotly.

        Vertices are a synthetic "CoT" root, the certificate / public-key /
        NV-counter group nodes, and every certificate, image, key and counter
        beneath them. x positions come from resolve_lay(), y from tree depth.
        Hovering a vertex shows the HTML from the print_*_info helpers. The
        figure is displayed with ``fig.show()``; nothing is returned.
        """
        import collections

        import plotly.graph_objects as go
        from igraph import Graph

        cert = self.get_certificates()
        pk = self.get_rot_keys()
        nv = self.get_nv_counters()
        image = self.get_images()

        certs = cert.children
        pks = pk.children if pk else []
        nvs = nv.children if nv else []
        images = image.children

        root_name = "CoT"

        G = Graph()
        detail = []     # hover HTML, indexed by vertex id
        lay = []        # [x, y] per vertex, indexed by vertex id
        name_idx = {}   # vertex name -> vertex id
        parent_set = collections.defaultdict(list)  # vertex name -> child names

        def add_node(name, info, parent, depth):
            """Add vertex *name* at *depth* and, if given, the edge from *parent*."""
            G.add_vertex(name)
            detail.append(info)
            name_idx[name] = len(lay)  # equals the new igraph vertex id
            lay.append([0, depth])
            if parent is not None:
                G.add_edge(parent, name)
                parent_set[parent].append(name)

        def depth_below(parent):
            """Return the depth of a child of *parent*."""
            return lay[name_idx[parent]][1] + 1

        add_node(root_name, "CoT Root", None, 0)
        add_node(cert.name, "All Certificates", root_name, 1)
        if pk:
            add_node(pk.name, "All Public Trusted Key", root_name, 1)
        if nv:
            add_node(nv.name, "All NV Counters", root_name, 1)

        for p in pks:
            add_node(p.name, self.print_data_info(p), pk.name, depth_below(pk.name))

        # NOTE: parents are looked up by name while vertices are being added,
        # so each certificate/image must appear after its parent certificate
        # in the DT; otherwise the lookup raises.
        for c in certs:
            if self.if_root(c):
                add_node(c.name, self.print_cert_info(c), cert.name, depth_below(cert.name))
            else:
                parent = self.extract_label(c.get_fields("parent"))
                add_node(c.name, self.print_cert_info(c), parent, depth_below(parent))

        for img in images:
            parent = self.extract_label(img.get_fields("parent"))
            add_node(img.name, self.print_img_info(img), parent, depth_below(parent))

        for n in nvs:
            add_node(n.name, self.print_data_info(n), nv.name, depth_below(nv.name))

        break_name = set()
        self.resolve_lay(parent_set, lay, name_idx, root_name, [-3, 3], break_name)

        # Wrap long names of vertices that were given a narrow slice.
        v_label = []
        for name in G.vs["name"]:
            if name in break_name and len(name) > 10:
                middle = len(name) // 2
                v_label.append(name[:middle] + "<br>" + name[middle:])
            else:
                v_label.append(name)

        num_vertex = G.vcount()
        max_y = max(lay[k][1] for k in range(num_vertex))

        def flip(y):
            # Mirror the depth so the root (depth 0) gets the largest y value.
            return 2 * max_y - y

        Xn = [lay[k][0] for k in range(num_vertex)]
        Yn = [flip(lay[k][1]) for k in range(num_vertex)]
        Xe = []
        Ye = []
        for src, dst in (e.tuple for e in G.es):
            Xe += [lay[src][0], lay[dst][0], None]
            Ye += [flip(lay[src][1]), flip(lay[dst][1]), None]

        def make_annotations(pos, text, font_size=10, font_color='rgb(0,0,0)'):
            """Build one plotly text annotation per entry of *pos* / *text*."""
            if len(text) != len(pos):
                raise ValueError('The lists pos and text must have the same len')
            return [
                dict(
                    text=text[k],
                    x=pos[k][0], y=flip(pos[k][1]),
                    xref='x1', yref='y1',
                    font=dict(color=font_color, size=font_size),
                    showarrow=False)
                for k in range(len(pos))
            ]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=Xe,
                                 y=Ye,
                                 mode='lines',
                                 line=dict(color='rgb(210,210,210)', width=2),
                                 hoverinfo='none'))
        fig.add_trace(go.Scatter(x=Xn,
                                 y=Yn,
                                 mode='markers',
                                 name='detail',
                                 marker=dict(symbol='circle-dot',
                                             size=50,
                                             color='rgba(135, 206, 250, 0.8)',
                                             line=dict(color='MediumPurple', width=3)),
                                 text=detail,
                                 hoverinfo='text',
                                 hovertemplate='<b>Detail</b><br>%{text}',
                                 opacity=0.8))

        # Hide axis line, grid, tick labels and title.
        axis = dict(showline=False,
                    zeroline=False,
                    showgrid=False,
                    showticklabels=False)

        fig.update_layout(title='CoT Device Tree',
                          annotations=make_annotations(lay, v_label),
                          font_size=12,
                          showlegend=False,
                          xaxis=axis,
                          yaxis=axis,
                          margin=dict(l=40, r=40, b=85, t=100),
                          hovermode='closest',
                          plot_bgcolor='rgb(248,248,248)')

        fig.show()

    # ------------------------------------------------------------------ #
    # Node accessors
    # ------------------------------------------------------------------ #

    def if_root(self, node: "Node") -> bool:
        """Return True if *node* carries a ``root-certificate`` property."""
        return any(p.name == "root-certificate" for p in node.properties)

    def get_sign_key(self, node: "Node"):
        """Return the values of *node*'s ``signing-key`` property, or None."""
        return next((p.values for p in node.properties if p.name == "signing-key"), None)

    def get_nv_ctr(self, node: "Node"):
        """Return the values of *node*'s ``antirollback-counter`` property, or None."""
        return next((p.values for p in node.properties if p.name == "antirollback-counter"), None)

    def extract_label(self, label) -> Optional[str]:
        """Return ``label[0].label.name``, i.e. the label of the first reference.

        Falsy input (None / empty value list) is returned unchanged.
        """
        if not label:
            return label
        return label[0].label.name

    def get_auth_data(self, node: "Node"):
        """Return the authenticated-data child nodes of certificate *node*."""
        return node.children

    def format_auth_data_val(self, node: "Node", cert: "Node"):
        """Return ``(type_desc, buffer_name, length_macro)`` for auth data *node*.

        Names ending in ``_pk`` use ``PK_DER_LEN``, everything else
        ``HASH_DER_LEN``. Content keys (``*content_pk``) of non-root
        ``*key_cert*`` certificates share the single ``content_pk_buf``.
        """
        type_desc = node.name
        ptr = type_desc + "_buf"
        data_len = "PK_DER_LEN" if type_desc.endswith("_pk") else "HASH_DER_LEN"

        # Edge case: all non-root key certificates share one content-key buffer.
        if not self.if_root(cert) and "key_cert" in cert.name and "content_pk" in ptr:
            ptr = "content_pk_buf"

        return type_desc, ptr, data_len

    def get_node(self, nodes: "List[Node]", name: str) -> "Optional[Node]":
        """Return the first node in *nodes* called *name*, or None."""
        return next((n for n in nodes if n.name == name), None)

    def _get_cot_child(self, name: str) -> "Optional[Node]":
        """Return child *name* of the top-level ``cot`` node, or None."""
        cot = self.get_node(self.tree.children, "cot")
        return self.get_node(cot.children, name) if cot else None

    def get_certificates(self) -> "Optional[Node]":
        """Return the ``cot/manifests`` node, or None if absent."""
        return self._get_cot_child("manifests")

    def get_images(self) -> "Optional[Node]":
        """Return the ``cot/images`` node, or None if absent."""
        return self._get_cot_child("images")

    def get_nv_counters(self) -> "Optional[Node]":
        """Return the top-level ``non_volatile_counters`` node, or None."""
        return self.get_node(self.tree.children, "non_volatile_counters")

    def get_rot_keys(self) -> "Optional[Node]":
        """Return the top-level ``rot_keys`` node, or None."""
        return self.get_node(self.tree.children, "rot_keys")

    def get_all_certificates(self) -> "List[Node]":
        """Return all certificate nodes (``[]`` if there is no manifests node)."""
        cert = self.get_certificates()
        return cert.children if cert else []

    def get_all_images(self) -> "List[Node]":
        """Return all image nodes (``[]`` if there is no images node)."""
        image = self.get_images()
        return image.children if image else []

    def get_all_nv_counters(self) -> "List[Node]":
        """Return all NV-counter nodes (``[]`` if there is no counters node)."""
        nv = self.get_nv_counters()
        return nv.children if nv else []

    def get_all_pks(self) -> "List[Node]":
        """Return all root-of-trust public-key nodes (``[]`` if none)."""
        pk = self.get_rot_keys()
        return pk.children if pk else []

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #

    def _cert_exists(self, name) -> bool:
        """Return True if a certificate node called *name* exists."""
        return any(c.name == name for c in self.get_all_certificates())

    def validate_cert(self, node: "Node") -> bool:
        """Check certificate *node*'s mandatory fields, printing each problem.

        Sets ``self.has_root`` when *node* is a root certificate.
        Returns True if no problem was found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False

        if node.has_field("root-certificate"):
            self.has_root = True
        elif not node.has_field("parent"):
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False
        elif not self._cert_exists(self.extract_label(node.get_fields("parent"))):
            print("{} refer to non existing parent".format(node.name))
            valid = False

        for c in node.children:
            if not c.has_field("oid"):
                print("{} missing mandatory attribute oid".format(c.name))
                valid = False

        return valid

    def validate_img(self, node: "Node") -> bool:
        """Check image *node*'s mandatory fields, printing each problem.

        Returns True if no problem was found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False

        has_parent = node.has_field("parent")
        if not has_parent:
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False

        if not node.has_field("hash"):
            print("{} missing mandatory attribute hash".format(node.name))
            valid = False

        # Only check that the parent exists when one is given; a missing
        # parent has already been reported above.
        if has_parent and not self._cert_exists(self.extract_label(node.get_fields("parent"))):
            print("{} refer to non existing parent".format(node.name))
            valid = False

        return valid

    def validate_nodes(self) -> bool:
        """Validate every certificate and image; return True if all are valid.

        A CoT without any root certificate is reported and treated as invalid.
        """
        valid = True
        self.has_root = False

        # Validate every node (no short-circuit) so all problems get printed.
        for n in self.get_all_certificates():
            node_valid = self.validate_cert(n)
            valid = valid and node_valid

        for i in self.get_all_images():
            node_valid = self.validate_img(i)
            valid = valid and node_valid

        if not self.has_root:
            print("missing root certificate")
            valid = False

        return valid

    # ------------------------------------------------------------------ #
    # C code generation
    # ------------------------------------------------------------------ #

    def include_to_c(self, f):
        """Write the ``#include`` lines of the generated C file to *f*."""
        f.write("#include <stddef.h>\n")
        f.write("#include <mbedtls/version.h>\n")
        f.write("#include <common/tbbr/cot_def.h>\n")
        f.write("#include <drivers/auth/auth_mod.h>\n")
        f.write("#include <platform_def.h>\n\n")

    def generate_header(self, output):
        """Write the header section (currently only the includes) to *output*."""
        self.include_to_c(output)

    def all_cert_to_c(self, f):
        """Write the C descriptor of every certificate to *f*."""
        for c in self.get_all_certificates():
            self.cert_to_c(c, f)

        f.write("\n")

    def cert_to_c(self, node: "Node", f):
        """Write the ``auth_img_desc_t`` definition for certificate *node* to *f*.

        ``img_auth_methods`` is emitted only when the certificate has a signing
        key and/or an anti-rollback counter, and ``authenticated_data`` only
        when the certificate has child nodes.
        """
        node_image_id = node.get_field("image-id")

        f.write(f"static const auth_img_desc_t {node.name} = {{\n")
        f.write(f"\t.img_id = {node_image_id},\n")
        f.write("\t.img_type = IMG_CERT,\n")

        if self.if_root(node):
            f.write("\t.parent = NULL,\n")
        else:
            node_parent = node.get_field("parent")
            f.write(f"\t.parent = &{node_parent.label.name},\n")

        sign = self.get_sign_key(node)
        nv_ctr = self.get_nv_ctr(node)

        if sign or nv_ctr:
            f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")

            if sign:
                f.write("\t\t[0] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_SIG,\n")
                f.write("\t\t\t.param.sig = {\n")
                f.write(f"\t\t\t\t.pk = &{self.extract_label(sign)},\n")
                f.write("\t\t\t\t.sig = &sig,\n")
                f.write("\t\t\t\t.alg = &sig_alg,\n")
                f.write("\t\t\t\t.data = &raw_data\n")
                f.write("\t\t\t}\n")
                f.write("\t\t},\n" if nv_ctr else "\t\t}\n")

            if nv_ctr:
                nv_label = self.extract_label(nv_ctr)
                f.write("\t\t[1] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_NV_CTR,\n")
                f.write("\t\t\t.param.nv_ctr = {\n")
                f.write(f"\t\t\t\t.cert_nv_ctr = &{nv_label},\n")
                f.write(f"\t\t\t\t.plat_nv_ctr = &{nv_label}\n")
                f.write("\t\t\t}\n")
                f.write("\t\t}\n")

            # Close img_auth_methods only when it was opened above, so the
            # generated initializer always has balanced braces.
            f.write("\t},\n")

        auth_data = self.get_auth_data(node)
        if auth_data:
            last = len(auth_data) - 1
            f.write("\t.authenticated_data = (const auth_param_desc_t[COT_MAX_VERIFIED_PARAMS]) {\n")

            for i, d in enumerate(auth_data):
                type_desc, ptr, data_len = self.format_auth_data_val(d, node)

                f.write(f"\t\t[{i}] = {{\n")
                f.write(f"\t\t\t.type_desc = &{type_desc},\n")
                f.write("\t\t\t.data = {\n")
                f.write(f"\t\t\t\t.ptr = (void *){ptr},\n")
                f.write(f"\t\t\t\t.len = (unsigned int){data_len}\n")
                f.write("\t\t\t}\n")
                f.write("\t\t},\n" if i != last else "\t\t}\n")

            f.write("\t}\n")

        f.write("};\n\n")

    def img_to_c(self, node: "Node", f):
        """Write the ``auth_img_desc_t`` definition for raw image *node* to *f*.

        The image gets a single ``AUTH_METHOD_HASH`` method whose ``.hash``
        is the parameter referenced by the node's ``hash`` field.
        """
        node_image_id = node.get_field("image-id")
        node_parent = node.get_field("parent")
        node_hash = node.get_field("hash")

        f.write(f"static const auth_img_desc_t {node.name} = {{\n")
        f.write(f"\t.img_id = {node_image_id},\n")
        f.write("\t.img_type = IMG_RAW,\n")
        f.write(f"\t.parent = &{node_parent.label.name},\n")
        f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")

        f.write("\t\t[0] = {\n")
        f.write("\t\t\t.type = AUTH_METHOD_HASH,\n")
        f.write("\t\t\t.param.hash = {\n")
        f.write("\t\t\t\t.data = &raw_data,\n")
        f.write(f"\t\t\t\t.hash = &{node_hash.label.name}\n")
        f.write("\t\t\t}\n")

        f.write("\t\t}\n")
        f.write("\t}\n")
        f.write("};\n\n")

    def all_img_to_c(self, f):
        """Write the C descriptor of every image to *f*."""
        for i in self.get_all_images():
            self.img_to_c(i, f)

        f.write("\n")

    def nv_to_c(self, f):
        """Write one ``AUTH_PARAM_NV_CTR`` descriptor per NV counter to *f*."""
        for nv in self.get_all_nv_counters():
            nv_oid = nv.get_field("oid")
            f.write(f"static auth_param_type_desc_t {nv.name} = "
                    f"AUTH_PARAM_TYPE_DESC(AUTH_PARAM_NV_CTR, \"{nv_oid}\");\n")

        f.write("\n")

    def pk_to_c(self, f):
        """Write one ``AUTH_PARAM_PUB_KEY`` descriptor per root-of-trust key to *f*."""
        for p in self.get_all_pks():
            pk_oid = p.get_field("oid")
            f.write(f"static auth_param_type_desc_t {p.name} = "
                    f"AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, \"{pk_oid}\");\n")

        f.write("\n")

    def buf_to_c(self, f):
        """Write one static buffer per distinct auth-data buffer name to *f*."""
        buffers = set()

        for c in self.get_all_certificates():
            for a in self.get_auth_data(c):
                _, ptr, data_len = self.format_auth_data_val(a, c)

                # Several auth data may share a buffer (see format_auth_data_val).
                if ptr not in buffers:
                    f.write(f"static unsigned char {ptr}[{data_len}];\n")
                    buffers.add(ptr)

        f.write("\n")

    def param_to_c(self, f):
        """Write the generic and per-certificate ``auth_param_type_desc_t`` objects.

        Each certificate child whose name ends in ``_pk`` becomes an
        ``AUTH_PARAM_PUB_KEY`` descriptor and each ending in ``_hash`` an
        ``AUTH_PARAM_HASH`` one, using the child's ``oid`` field.

        Raises:
            ValueError: if a certificate child name has neither suffix.
        """
        f.write("static auth_param_type_desc_t subject_pk = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, 0);\n")
        f.write("static auth_param_type_desc_t sig = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG, 0);\n")
        f.write("static auth_param_type_desc_t sig_alg = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG_ALG, 0);\n")
        f.write("static auth_param_type_desc_t raw_data = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_RAW_DATA, 0);\n")
        f.write("\n")

        for c in self.get_all_certificates():
            for param in c.children:
                name = param.name
                oid = param.get_field("oid")

                if name.endswith("_pk"):
                    ty = "AUTH_PARAM_PUB_KEY"
                elif name.endswith("_hash"):
                    ty = "AUTH_PARAM_HASH"
                else:
                    # Fail loudly instead of emitting a descriptor of unknown type.
                    raise ValueError(
                        f"{c.name}: cannot infer parameter type of '{name}' "
                        "(expected a '_pk' or '_hash' suffix)")

                f.write(f"static auth_param_type_desc_t {name} = "
                        f"AUTH_PARAM_TYPE_DESC({ty}, \"{oid}\");\n")

        f.write("\n")

    def cot_to_c(self, f):
        """Write the ``cot_desc`` table (indexed by image id) and its registration."""
        f.write("static const auth_img_desc_t * const cot_desc[] = {\n")

        for desc in list(self.get_all_certificates()) + list(self.get_all_images()):
            image_id = desc.get_field("image-id")
            f.write(f"\t[{image_id}]\t=\t&{desc.name},\n")

        f.write("};\n\n")
        f.write("REGISTER_COT(cot_desc);\n")

    def generate_c_file(self):
        """Write the complete C CoT description to ``self.output``.

        The parent directory is created if needed.
        """
        filename = Path(self.output)
        filename.parent.mkdir(exist_ok=True, parents=True)

        with open(filename, "w") as output:
            self.generate_header(output)
            self.buf_to_c(output)
            self.param_to_c(output)
            self.nv_to_c(output)
            self.pk_to_c(output)
            self.all_cert_to_c(output)
            self.all_img_to_c(output)
            self.cot_to_c(output)
674