blob: c1d53e21114fc9a5c5eaf0ed2a388f9bf12dfc39 [file] [log] [blame]
#
# Copyright (c) 2024, Arm Limited and Contributors. All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
#
import sys
import re
from cot_dt2c.pydevicetree.source.parser import ifdef_stack
from cot_dt2c.pydevicetree.ast import CellArray, LabelReference
from cot_dt2c.pydevicetree import *
from pathlib import Path
def extractNumber(s):
    """Return the first run of ASCII digits in *s* as an int, or -1 if none.

    Used to recover the index N from numbered auth-data names such as
    "sp_pkg<N>_hash". The whole digit run is read, so multi-digit indices
    (e.g. "sp_pkg10_hash" -> 10) are handled. Only ASCII digits are matched,
    which avoids ValueError on characters like '²' that str.isdigit()
    accepts but int() rejects.
    """
    match = re.search(r"[0-9]+", s)
    return int(match.group()) if match else -1
def removeNumber(s):
    """Return *s* with every digit character dropped.

    Example: "sp_pkg1_hash" -> "sp_pkg_hash". "Digit" means any character
    for which str.isdigit() is true.
    """
    pieces = []
    for ch in s:
        if ch.isdigit():
            continue
        pieces.append(ch)
    return "".join(pieces)
class COT:
    """Chain-of-Trust (CoT) description parsed from a device-tree source file.

    Provides accessors for the ``cot/manifests`` (certificates) and
    ``cot/images`` nodes and for the top-level ``non_volatile_counters`` and
    ``rot_keys`` nodes. It also validates mandatory attributes
    (validate_nodes), draws an interactive tree (tree_visualization) and
    emits the equivalent C CoT descriptors (generate_c_file).
    """

    def __init__(self, inputfile: str, outputfile=None):
        """Parse the CoT device tree stored in *inputfile*.

        Args:
            inputfile: path of the CoT device-tree source. generate_c_file()
                reads it again to copy its licence header and #include lines.
            outputfile: path of the C file written by generate_c_file().

        Prints "not a valid CoT DT file" and exits with status 1 when the
        text "cot" does not occur in the file or the parser rejects it.
        """
        with open(inputfile, 'r') as f:
            contents = f.read()

        # Parsing starts at the first "cot". The text before it is not
        # parsed here; generate_header() re-reads it from the input file.
        pos = contents.find("cot")
        if pos == -1:
            print("not a valid CoT DT file")
            sys.exit(1)

        contents = contents[pos:]

        try:
            self.tree = Devicetree.parseStr(contents)
        except Exception:
            # Any parser failure means the input is not a usable CoT DT.
            # (A bare except would also swallow KeyboardInterrupt.)
            print("not a valid CoT DT file")
            sys.exit(1)

        self.output = outputfile
        self.input = inputfile
        # Set to True by validate_cert() once a root certificate is seen.
        self.has_root = False

        # Edge case: a root certificate without an explicit "signing-key"
        # gets one that references the "subject_pk" label.
        for c in self.get_all_certificates():
            if self.if_root(c) and not c.get_fields("signing-key"):
                c.properties.append(
                    Property("signing-key",
                             CellArray([LabelReference("subject_pk")])))

    # ------------------------------------------------------------------
    # Hover-text helpers for tree_visualization()
    # ------------------------------------------------------------------

    def print_cert_info(self, node: "Node"):
        """Return an HTML summary of certificate *node* (does not print).

        Lists the name and image-id. Also marks root certificates and shows
        the signing-key and NV-counter labels when they are present.
        """
        img_id = node.get_field("image-id").values[0].replace('"', "")
        sign_key = self.get_sign_key(node)
        nv = self.get_nv_ctr(node)

        root_part = "<b>root-certificate</b><br>" if self.if_root(node) else ""
        sign_part = ("<b>signing-key:</b> " + self.extract_label(sign_key) + "<br>"
                     if sign_key else "")
        nv_part = ("<b>nv counter:</b> " + self.extract_label(nv) + "<br>"
                   if nv else "")
        return "<b>name:</b> {}<br><b>image-id:</b> {}<br>{}{}{}".format(
            node.name, img_id, root_part, sign_part, nv_part)

    def print_data_info(self, node: "Node"):
        """Return an HTML summary (name and oid) of a key / NV-counter node."""
        oid = node.get_field("oid")
        return "<b>name:</b> {}<br><b>oid:</b> {}<br>".format(node.name, oid)

    def print_img_info(self, node: "Node"):
        """Return an HTML summary (name, image-id, hash label) of an image node."""
        hash_label = self.extract_label(node.get_fields("hash"))
        img_id = node.get_field("image-id").values[0].replace('"', "")
        return "<b>name:</b> {}<br><b>image-id:</b> {}<br><b>hash:</b> {}".format(
            node.name, img_id, hash_label)

    # ------------------------------------------------------------------
    # Tree layout and visualization
    # ------------------------------------------------------------------

    def tree_width(self, parent_set, root):
        """Return the largest number of nodes on any single level below *root*.

        *parent_set* maps a node name to the list of its children's names
        (indexed with [], so a defaultdict(list) is expected for leaves). The
        result is at least 1, which accounts for *root*'s own level.
        """
        widest = 1
        level = [root]
        while level:
            level = [child for name in level for child in parent_set[name]]
            widest = max(widest, len(level))
        return widest

    def resolve_lay(self, parent_set, lay, name_idx, root, bounds, break_name):
        """Recursively assign x coordinates to every descendant of *root*.

        The horizontal interval *bounds* ([left, right]) is split among the
        children of *root* in proportion to their tree_width(). Each child
        is centred in its share (written to ``lay[name_idx[child]][0]``) and
        then lays out its own subtree inside that share. Children whose
        share is narrower than 0.28 are added to *break_name*, so
        tree_visualization() wraps their labels.
        """
        children = parent_set[root]
        if not children:
            return

        widths = [self.tree_width(parent_set, c) for c in children]
        interval = (bounds[1] - bounds[0]) / sum(widths)
        start = bounds[0]
        for c, w in zip(children, widths):
            end = start + interval * w
            lay[name_idx[c]][0] = start + (end - start) / 2
            if end - start < 0.28:
                break_name.add(c)
            self.resolve_lay(parent_set, lay, name_idx, c, [start, end], break_name)
            start = end

    def tree_visualization(self):
        """Show the CoT as an interactive tree (igraph graph, plotly figure).

        The vertices are:

        * a synthetic "CoT" root;
        * the certificate, trusted-key and NV-counter container nodes;
        * every certificate, image, trusted key and NV counter.

        Hovering a vertex shows the HTML from the print_*_info helpers. The
        optional igraph and plotly packages are imported lazily here.
        """
        import collections
        from igraph import Graph
        import plotly.graph_objects as go

        cert = self.get_certificates()
        pk = self.get_rot_keys()
        nv = self.get_nv_counters()
        image = self.get_images()

        G = Graph()
        detail = []        # hover text, one entry per vertex (vertex order)
        lay = []           # [x, y] per vertex; y is the depth in the tree
        name_idx = {}      # vertex name -> index into lay / detail
        parent_set = collections.defaultdict(list)  # name -> child names

        def add(name, info, parent=None):
            """Add vertex *name* one level below *parent* (x fixed later)."""
            G.add_vertex(name)
            detail.append(info)
            name_idx[name] = len(lay)
            if parent is None:
                lay.append([0, 0])
                return
            G.add_edge(parent, name)
            parent_set[parent].append(name)
            lay.append([0, lay[name_idx[parent]][1] + 1])

        root_name = "CoT"
        add(root_name, "CoT Root")
        add(cert.name, "All Certificates", root_name)
        if pk:
            add(pk.name, "All Public Trusted Key", root_name)
        add(nv.name, "All NV Counters", root_name)

        for p in (pk.children if pk else []):
            add(p.name, self.print_data_info(p), pk.name)
        for c in cert.children:
            # Root certificates hang off the certificate container; the
            # others hang off the certificate named by their "parent".
            if self.if_root(c):
                parent = cert.name
            else:
                parent = self.extract_label(c.get_fields("parent"))
            add(c.name, self.print_cert_info(c), parent)
        for i in image.children:
            add(i.name, self.print_img_info(i),
                self.extract_label(i.get_fields("parent")))
        for n in nv.children:
            add(n.name, self.print_data_info(n), nv.name)

        break_name = set()
        self.resolve_lay(parent_set, lay, name_idx, root_name, [-3, 3], break_name)

        vertices = G.get_vertex_dataframe()
        num_vertex = len(vertices)
        v_label = []
        for name in vertices['name']:
            # Split long labels of vertices that received a narrow slot.
            if name in break_name and len(name) > 10:
                middle = len(name) // 2
                v_label.append(name[:middle] + "<br>" + name[middle:])
            else:
                v_label.append(name)

        position = {k: lay[k] for k in range(num_vertex)}
        max_depth = max(position[k][1] for k in range(num_vertex))
        # Depth grows downwards, so y is flipped to draw the root on top.
        Xn = [position[k][0] for k in range(num_vertex)]
        Yn = [2 * max_depth - position[k][1] for k in range(num_vertex)]
        Xe = []
        Ye = []
        for src, dst in (e.tuple for e in G.es):
            Xe += [position[src][0], position[dst][0], None]
            Ye += [2 * max_depth - position[src][1],
                   2 * max_depth - position[dst][1], None]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=Xe,
                                 y=Ye,
                                 mode='lines',
                                 line=dict(color='rgb(210,210,210)', width=2),
                                 hoverinfo='none'))
        fig.add_trace(go.Scatter(x=Xn,
                                 y=Yn,
                                 mode='markers',
                                 name='detail',
                                 marker=dict(symbol='circle-dot',
                                             size=50,
                                             color='rgba(135, 206, 250, 0.8)',
                                             line=dict(color='MediumPurple', width=3)),
                                 text=detail,
                                 hoverinfo='text',
                                 hovertemplate='<b>Detail</b><br>%{text}',
                                 opacity=0.8))

        def make_annotations(pos, text, font_size=10, font_color='rgb(0,0,0)'):
            """Return one plotly text annotation per vertex position in *pos*."""
            if len(text) != len(pos):
                raise ValueError('The lists pos and text must have the same len')
            return [
                dict(text=text[k],
                     x=pos[k][0], y=2 * max_depth - pos[k][1],
                     xref='x1', yref='y1',
                     font=dict(color=font_color, size=font_size),
                     showarrow=False)
                for k in range(len(pos))
            ]

        # Hide axis line, grid, tick labels and title.
        axis = dict(showline=False,
                    zeroline=False,
                    showgrid=False,
                    showticklabels=False)

        fig.update_layout(title='CoT Device Tree',
                          annotations=make_annotations(position, v_label),
                          font_size=12,
                          showlegend=False,
                          xaxis=axis,
                          yaxis=axis,
                          margin=dict(l=40, r=40, b=85, t=100),
                          hovermode='closest',
                          plot_bgcolor='rgb(248,248,248)')
        fig.show()

    # ------------------------------------------------------------------
    # Node property helpers
    # ------------------------------------------------------------------

    def if_root(self, node: "Node") -> bool:
        """Return True if *node* has a "root-certificate" property."""
        return any(p.name == "root-certificate" for p in node.properties)

    def get_sign_key(self, node: "Node"):
        """Return the values of *node*'s "signing-key" property, or None."""
        for prop in node.properties:
            if prop.name == "signing-key":
                return prop.values
        return None

    def get_nv_ctr(self, node: "Node"):
        """Return the values of *node*'s "antirollback-counter" property, or None."""
        for prop in node.properties:
            if prop.name == "antirollback-counter":
                return prop.values
        return None

    def extract_label(self, label) -> str:
        """Return the referenced label name of ``label[0]``.

        When *label* is falsy (None / empty) it is returned unchanged.
        """
        if not label:
            return label
        return label[0].label.name

    def get_auth_data(self, node: "Node"):
        """Return the authenticated-data entries (the child nodes) of *node*."""
        return node.children

    def format_auth_data_val(self, node: "Node", cert: "Node"):
        """Describe auth-data *node* of certificate *cert* for C generation.

        Returns:
            (type_desc, buffer_name, length_expr) where type_desc is the node
            name, buffer_name is the C buffer that receives the data and
            length_expr is the C length expression for that buffer.
        """
        type_desc = node.name
        # Numbered sp_pkgN_* entries share one 2-D buffer (see buf_to_c()),
        # so the digits are stripped from the buffer name.
        if "sp_pkg" in type_desc:
            ptr = removeNumber(type_desc) + "_buf"
        else:
            ptr = type_desc + "_buf"

        # "pkg" contains "pk" but names a package hash, not a public key
        # (same rule as param_to_c()), so it must keep HASH_DER_LEN to match
        # the HASH_DER_LEN-sized buffer it is copied into.
        if "pk" in type_desc and "pkg" not in type_desc:
            data_len = "(unsigned int)PK_DER_LEN"
        else:
            data_len = "(unsigned int)HASH_DER_LEN"

        # Edge case: non-root key certificates store their content key in
        # the shared content_pk_buf buffer.
        if (not self.if_root(cert) and "key_cert" in cert.name
                and "content_pk" in ptr):
            ptr = "content_pk_buf"
        return type_desc, ptr, data_len

    # ------------------------------------------------------------------
    # Tree accessors
    # ------------------------------------------------------------------

    def get_node(self, nodes: "list[Node]", name: str) -> "Node":
        """Return the first node in *nodes* called *name*, or None."""
        return next((n for n in nodes if n.name == name), None)

    def _get_cot_child(self, name: str):
        """Return child *name* of the first top-level "cot" node, or None."""
        for top in self.tree.children:
            if top.name == "cot":
                return self.get_node(top.children, name)
        return None

    def get_certificates(self) -> "Node":
        """Return the ``cot/manifests`` node (certificate container), or None."""
        return self._get_cot_child("manifests")

    def get_images(self) -> "Node":
        """Return the ``cot/images`` node (image container), or None."""
        return self._get_cot_child("images")

    def get_nv_counters(self) -> "Node":
        """Return the top-level ``non_volatile_counters`` node, or None."""
        return self.get_node(self.tree.children, "non_volatile_counters")

    def get_rot_keys(self) -> "Node":
        """Return the top-level ``rot_keys`` node, or None."""
        return self.get_node(self.tree.children, "rot_keys")

    def get_all_certificates(self) -> list:
        """Return the certificate nodes (children of ``cot/manifests``)."""
        return self.get_certificates().children

    def get_all_images(self) -> list:
        """Return the image nodes (children of ``cot/images``)."""
        return self.get_images().children

    def get_all_nv_counters(self) -> list:
        """Return the NV-counter nodes (children of ``non_volatile_counters``)."""
        return self.get_nv_counters().children

    def get_all_pks(self) -> list:
        """Return the trusted-key nodes under ``rot_keys`` ([] if it is absent)."""
        pk = self.get_rot_keys()
        if not pk:
            return []
        return pk.children

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def _cert_exists(self, name) -> bool:
        """Return True if some certificate node is called *name*."""
        return any(c.name == name for c in self.get_all_certificates())

    def validate_cert(self, node: "Node") -> bool:
        """Check certificate *node*'s mandatory attributes; print each problem.

        Requires image-id. A non-root certificate also needs a parent that
        names an existing certificate, and every child needs an oid. A root
        certificate sets ``self.has_root``. Returns True if no problem was
        found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False

        if node.has_field("root-certificate"):
            self.has_root = True
        elif not node.has_field("parent"):
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False
        elif not self._cert_exists(self.extract_label(node.get_fields("parent"))):
            print("{} refer to non existing parent".format(node.name))
            valid = False

        for c in node.children or []:
            if not c.has_field("oid"):
                print("{} missing mandatory attribute oid".format(c.name))
                valid = False
        return valid

    def validate_img(self, node: "Node") -> bool:
        """Check image *node*'s mandatory attributes; print each problem.

        Requires image-id, hash and a parent that names an existing
        certificate. The parent-existence check is only done when a parent
        is given. Returns True if no problem was found.
        """
        valid = True
        if not node.has_field("image-id"):
            print("{} missing mandatory attribute image-id".format(node.name))
            valid = False
        has_parent = node.has_field("parent")
        if not has_parent:
            print("{} missing mandatory attribute parent".format(node.name))
            valid = False
        if not node.has_field("hash"):
            print("{} missing mandatory attribute hash".format(node.name))
            valid = False
        if has_parent and not self._cert_exists(
                self.extract_label(node.get_fields("parent"))):
            print("{} refer to non existing parent".format(node.name))
            valid = False
        return valid

    def validate_nodes(self) -> bool:
        """Validate the whole CoT and return True only if it is valid.

        Fails on unbalanced ifdef macros left by the parser, on any invalid
        certificate or image, and when no root certificate exists.
        """
        valid = True
        if ifdef_stack:
            print("invalid ifdef macro")
            valid = False

        certs = self.get_all_certificates()
        images = self.get_all_images()

        # Validate every node (no short-circuit) so all problems get printed.
        for n in certs:
            valid = self.validate_cert(n) and valid
        for i in images:
            valid = self.validate_img(i) and valid

        if not self.has_root:
            print("missing root certificate")
            valid = False
        return valid

    # ------------------------------------------------------------------
    # C header (licence + includes)
    # ------------------------------------------------------------------

    def extract_licence(self, f):
        """Collect the leading ``/* ... */`` comment from the lines of *f*.

        Returns the comment lines, including the opening and closing lines.
        A comment that opens and closes on one line is complete by itself.
        Returns [] when the first line read is not part of a comment; that
        line is consumed from *f*.
        """
        licence = []
        in_comment = False
        for line in f:
            start = line.find("/*")
            if start != -1:
                licence.append(line)
                in_comment = True
                if line.find("*/", start + 2) != -1:
                    return licence
                continue
            if line.find("*/") != -1:
                licence.append(line)
                return licence
            if not in_comment:
                return licence
            licence.append(line)
        return licence

    def licence_to_c(self, licence, f):
        """Write the *licence* lines followed by a blank line (if non-empty)."""
        if licence:
            for line in licence:
                f.write(line)
            f.write("\n")

    def extract_include(self, f):
        """Return the lines of *f* before the first line containing "cot".

        Blank lines and lines containing "common" are skipped; the fixed
        includes are emitted by include_to_c() instead.
        """
        include = []
        for line in f:
            if "cot" in line:
                return include
            if line != "" and "common" not in line and line != "\n":
                include.append(line)
        return include

    def include_to_c(self, include, f):
        """Write the fixed #include lines, the *include* lines and platform_def.h."""
        f.write("#include <stddef.h>\n")
        f.write("#include <mbedtls/version.h>\n")
        f.write("#include <common/tbbr/cot_def.h>\n")
        f.write("#include <drivers/auth/auth_mod.h>\n")
        f.write("\n")
        for line in include:
            f.write(line)
        f.write("\n")
        f.write("#include <platform_def.h>\n\n")

    def generate_header(self, input, output):
        """Copy the licence comment and #include lines of *input* to *output*.

        *input* is read once into a list so that a first line rejected by
        extract_licence() (e.g. an #include in a file without a licence
        header) is still seen by extract_include() instead of being lost.
        """
        lines = list(input)
        licence = self.extract_licence(lines)
        include = self.extract_include(lines[len(licence):])
        self.licence_to_c(licence, output)
        self.include_to_c(include, output)

    # ------------------------------------------------------------------
    # C body generation
    # ------------------------------------------------------------------

    @staticmethod
    def _write_ifdef_open(ifdef, f):
        """Write each conditional-compilation line in *ifdef* (may be None)."""
        for line in ifdef or []:
            f.write("{}\n".format(line))

    @staticmethod
    def _write_ifdef_close(ifdef, f):
        """Write one "#endif" per entry in *ifdef* (may be None)."""
        for _ in ifdef or []:
            f.write("#endif\n")

    def all_cert_to_c(self, f):
        """Write the descriptors of all certificates, then a blank line."""
        for c in self.get_all_certificates():
            self.cert_to_c(c, f)
        f.write("\n")

    def cert_to_c(self, node: "Node", f):
        """Write the ``auth_img_desc_t`` definition for certificate *node*.

        The descriptor gets:

        * a signature method when the node has a signing key;
        * an NV-counter method when it has an antirollback counter;
        * one authenticated_data entry per child node.

        The definition is wrapped in the node's ``ifdef`` lines, if any.
        """
        ifdef = node.get_fields("ifdef")
        self._write_ifdef_open(ifdef, f)

        f.write("static const auth_img_desc_t {} = {{\n".format(node.name))
        f.write("\t.img_id = {},\n".format(
            node.get_field("image-id").values[0].replace('"', "")))
        f.write("\t.img_type = IMG_CERT,\n")
        if self.if_root(node):
            f.write("\t.parent = NULL,\n")
        else:
            f.write("\t.parent = &{},\n".format(node.get_field("parent").label.name))

        sign = self.get_sign_key(node)
        nv_ctr = self.get_nv_ctr(node)
        if sign or nv_ctr:
            f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")
            if sign:
                f.write("\t\t[0] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_SIG,\n")
                f.write("\t\t\t.param.sig = {\n")
                f.write("\t\t\t\t.pk = &{},\n".format(self.extract_label(sign)))
                f.write("\t\t\t\t.sig = &sig,\n")
                f.write("\t\t\t\t.alg = &sig_alg,\n")
                f.write("\t\t\t\t.data = &raw_data\n")
                f.write("\t\t\t}\n")
                f.write("\t\t}}{}\n".format("," if nv_ctr else ""))
            if nv_ctr:
                f.write("\t\t[1] = {\n")
                f.write("\t\t\t.type = AUTH_METHOD_NV_CTR,\n")
                f.write("\t\t\t.param.nv_ctr = {\n")
                f.write("\t\t\t\t.cert_nv_ctr = &{},\n".format(self.extract_label(nv_ctr)))
                f.write("\t\t\t\t.plat_nv_ctr = &{}\n".format(self.extract_label(nv_ctr)))
                f.write("\t\t\t}\n")
                f.write("\t\t}\n")
            # Close img_auth_methods only when it was opened above.
            f.write("\t},\n")

        auth_data = self.get_auth_data(node)
        if auth_data:
            f.write("\t.authenticated_data = (const auth_param_desc_t[COT_MAX_VERIFIED_PARAMS]) {\n")
            last = len(auth_data) - 1
            for idx, d in enumerate(auth_data):
                type_desc, ptr, data_len = self.format_auth_data_val(d, node)
                f.write("\t\t[{}] = {{\n".format(idx))
                f.write("\t\t\t.type_desc = &{},\n".format(type_desc))
                f.write("\t\t\t.data = {\n")
                n = extractNumber(type_desc)
                # Numbered "pkg" entries (sp_pkgN_*) use row N-1 of the
                # shared 2-D buffer declared by buf_to_c().
                if "pkg" not in type_desc or n == -1:
                    f.write("\t\t\t\t.ptr = (void *){},\n".format(ptr))
                else:
                    f.write("\t\t\t\t.ptr = (void *){}[{}],\n".format(ptr, n - 1))
                f.write("\t\t\t\t.len = {}\n".format(data_len))
                f.write("\t\t\t}\n")
                f.write("\t\t}}{}\n".format("," if idx != last else ""))
            f.write("\t}\n")

        f.write("};\n\n")
        if ifdef:
            self._write_ifdef_close(ifdef, f)
            f.write("\n")

    def img_to_c(self, node: "Node", f):
        """Write the ``auth_img_desc_t`` definition for raw image *node*.

        The image is authenticated by hash against its parent certificate.
        The definition is wrapped in the node's ``ifdef`` lines, if any.
        """
        ifdef = node.get_fields("ifdef")
        self._write_ifdef_open(ifdef, f)

        f.write("static const auth_img_desc_t {} = {{\n".format(node.name))
        f.write("\t.img_id = {},\n".format(
            node.get_field("image-id").values[0].replace('"', "")))
        f.write("\t.img_type = IMG_RAW,\n")
        f.write("\t.parent = &{},\n".format(node.get_field("parent").label.name))
        f.write("\t.img_auth_methods = (const auth_method_desc_t[AUTH_METHOD_NUM]) {\n")
        f.write("\t\t[0] = {\n")
        f.write("\t\t\t.type = AUTH_METHOD_HASH,\n")
        f.write("\t\t\t.param.hash = {\n")
        f.write("\t\t\t\t.data = &raw_data,\n")
        f.write("\t\t\t\t.hash = &{}\n".format(node.get_field("hash").label.name))
        f.write("\t\t\t}\n")
        f.write("\t\t}\n")
        f.write("\t}\n")
        f.write("};\n\n")

        if ifdef:
            self._write_ifdef_close(ifdef, f)
            f.write("\n")

    def all_img_to_c(self, f):
        """Write the descriptors of all images, then a blank line."""
        for i in self.get_all_images():
            self.img_to_c(i, f)
        f.write("\n")

    def nv_to_c(self, f):
        """Write an AUTH_PARAM_NV_CTR type descriptor per NV counter."""
        for nv in self.get_all_nv_counters():
            f.write("static auth_param_type_desc_t {} = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_NV_CTR, {});\n"
                    .format(nv.name, nv.get_field("oid")))
        f.write("\n")

    def pk_to_c(self, f):
        """Write an AUTH_PARAM_PUB_KEY type descriptor per trusted key."""
        for p in self.get_all_pks():
            f.write("static auth_param_type_desc_t {} = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, {});\n"
                    .format(p.name, p.get_field("oid")))
        f.write("\n")

    def buf_to_c(self, f):
        """Declare one static buffer per distinct auth-data buffer name.

        Each buffer inherits the ``ifdef`` lines of the first certificate
        that references it. sp_pkg hash buffers are 2-D (one row per SP);
        public-key buffers use PK_DER_LEN; everything else uses HASH_DER_LEN.
        """
        buffers = {}
        for c in self.get_all_certificates():
            for a in self.get_auth_data(c):
                _, ptr, _ = self.format_auth_data_val(a, c)
                if ptr not in buffers:
                    buffers[ptr] = c.get_fields("ifdef")

        for key, ifdef in buffers.items():
            self._write_ifdef_open(ifdef, f)
            if "sp_pkg_hash_buf" in key:
                f.write("static unsigned char {}[MAX_SP_IDS][HASH_DER_LEN];\n".format(key))
            elif "pk" in key and "pkg" not in key:
                f.write("static unsigned char {}[PK_DER_LEN];\n".format(key))
            else:
                f.write("static unsigned char {}[HASH_DER_LEN];\n".format(key))
            self._write_ifdef_close(ifdef, f)
        f.write("\n")

    def param_to_c(self, f):
        """Write the fixed parameter descriptors and one per certificate child.

        A child name decides its parameter kind:

        * containing "pk" but not "pkg": a public key;
        * otherwise containing "hash": a hash;
        * otherwise containing "ctr": an NV counter;
        * anything else is skipped.
        """
        f.write("static auth_param_type_desc_t subject_pk = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_PUB_KEY, 0);\n")
        f.write("static auth_param_type_desc_t sig = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG, 0);\n")
        f.write("static auth_param_type_desc_t sig_alg = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_SIG_ALG, 0);\n")
        f.write("static auth_param_type_desc_t raw_data = AUTH_PARAM_TYPE_DESC(AUTH_PARAM_RAW_DATA, 0);\n")
        f.write("\n")

        for c in self.get_all_certificates():
            ifdef = c.get_fields("ifdef")
            self._write_ifdef_open(ifdef, f)
            for child in c.children:
                name = child.name
                oid = child.get_field("oid")
                if "pk" in name and "pkg" not in name:
                    kind = "AUTH_PARAM_PUB_KEY"
                elif "hash" in name:
                    kind = "AUTH_PARAM_HASH"
                elif "ctr" in name:
                    kind = "AUTH_PARAM_NV_CTR"
                else:
                    continue
                f.write("static auth_param_type_desc_t {} = "
                        "AUTH_PARAM_TYPE_DESC({}, {});\n".format(name, kind, oid))
            self._write_ifdef_close(ifdef, f)
        f.write("\n")

    def cot_to_c(self, f):
        """Write the ``cot_desc`` table (indexed by image-id) and REGISTER_COT.

        Certificates come first, then images. Every entry except the last
        image carries a trailing comma.
        """
        certs = self.get_all_certificates()
        images = self.get_all_images()
        f.write("static const auth_img_desc_t * const cot_desc[] = {\n")

        for c in certs:
            ifdef = c.get_fields("ifdef")
            self._write_ifdef_open(ifdef, f)
            f.write("\t[{}] = &{}{}\n".format(c.get_field("image-id").values[0], c.name, ","))
            self._write_ifdef_close(ifdef, f)

        last = len(images) - 1
        # The ifdef loops live in helpers, so they cannot clobber the
        # enumerate index used for the trailing-comma decision.
        for idx, img in enumerate(images):
            ifdef = img.get_fields("ifdef")
            self._write_ifdef_open(ifdef, f)
            f.write("\t[{}] = &{}{}\n".format(img.get_field("image-id").values[0], img.name,
                                              "," if idx != last else ""))
            self._write_ifdef_close(ifdef, f)

        f.write("};\n\n")
        f.write("REGISTER_COT(cot_desc);\n")

    def generate_c_file(self):
        """Write the complete C translation of the CoT to ``self.output``.

        Creates the output's parent directory if needed. It then writes, in
        order:

        * the licence header and #include lines copied from ``self.input``;
        * the data buffers, then the parameter descriptors;
        * the NV counters, then the trusted public keys;
        * the certificate descriptors, then the image descriptors;
        * the ``cot_desc`` table.
        """
        Path(self.output).parent.mkdir(exist_ok=True, parents=True)
        # Context managers close both files (flushing the output) even if
        # generation fails part-way.
        with open(self.output, 'w+') as output, open(self.input, "r") as input_file:
            self.generate_header(input_file, output)
            self.buf_to_c(output)
            self.param_to_c(output)
            self.nv_to_c(output)
            self.pk_to_c(output)
            self.all_cert_to_c(output)
            self.all_img_to_c(output)
            self.cot_to_c(output)