test: vboot: Add a test for a forged configuration

Add a check to make sure that it is not possible to add a new
configuration and use the hashed nodes and hash of another configuration.

Signed-off-by: Simon Glass <sjg@chromium.org>
diff --git a/test/py/tests/vboot_forge.py b/test/py/tests/vboot_forge.py
new file mode 100644
index 0000000..0fb7ef4
--- /dev/null
+++ b/test/py/tests/vboot_forge.py
@@ -0,0 +1,423 @@
+#!/usr/bin/python3
+# SPDX-License-Identifier: GPL-2.0
+# Copyright (c) 2020, F-Secure Corporation, https://foundry.f-secure.com
+#
+# pylint: disable=E1101,W0201,C0103
+
+"""
+Verified boot image forgery tools and utilities
+
+This module provides services to both take apart and regenerate FIT images
+in a way that preserves all existing verified boot signatures, unless you
+manipulate nodes in the process.
+"""
+
+import struct
+import binascii
+from io import BytesIO
+
+#
+# struct parsing helpers
+#
+
+class BetterStructMeta(type):
+    """
+    Preprocesses field definitions and creates a struct.Struct instance from them
+    """
+    def __new__(cls, clsname, superclasses, attributedict):
+        if clsname != 'BetterStruct':
+            fields = attributedict['__fields__']
+            field_types = [_[0] for _ in fields]
+            field_names = [_[1] for _ in fields if _[1] is not None]
+            attributedict['__names__'] = field_names
+            s = struct.Struct(attributedict.get('__endian__', '') + ''.join(field_types))
+            attributedict['__struct__'] = s
+            attributedict['size'] = s.size
+        return type.__new__(cls, clsname, superclasses, attributedict)
+
+class BetterStruct(metaclass=BetterStructMeta):
+    """
+    Base class for better structures
+    """
+    def __init__(self):
+        for t, n in self.__fields__:
+            if 's' in t:
+                setattr(self, n, '')
+            elif t in ('Q', 'I', 'H', 'B'):
+                setattr(self, n, 0)
+
+    @classmethod
+    def unpack_from(cls, buffer, offset=0):
+        """
+        Unpack structure instance from a buffer
+        """
+        fields = cls.__struct__.unpack_from(buffer, offset)
+        instance = cls()
+        for n, v in zip(cls.__names__, fields):
+            setattr(instance, n, v)
+        return instance
+
+    def pack(self):
+        """
+        Pack structure instance into bytes
+        """
+        return self.__struct__.pack(*[getattr(self, n) for n in self.__names__])
+
+    def __str__(self):
+        items = ["'%s': %s" % (n, repr(getattr(self, n))) for n in self.__names__ if n is not None]
+        return '(' + ', '.join(items) + ')'
+
+#
+# some defs for flat DT data
+#
+
+class HeaderV17(BetterStruct):
+    __endian__ = '>'
+    __fields__ = [
+        ('I', 'magic'),
+        ('I', 'totalsize'),
+        ('I', 'off_dt_struct'),
+        ('I', 'off_dt_strings'),
+        ('I', 'off_mem_rsvmap'),
+        ('I', 'version'),
+        ('I', 'last_comp_version'),
+        ('I', 'boot_cpuid_phys'),
+        ('I', 'size_dt_strings'),
+        ('I', 'size_dt_struct'),
+    ]
+
+class RRHeader(BetterStruct):
+    __endian__ = '>'
+    __fields__ = [
+        ('Q', 'address'),
+        ('Q', 'size'),
+    ]
+
+class PropHeader(BetterStruct):
+    __endian__ = '>'
+    __fields__ = [
+        ('I', 'value_size'),
+        ('I', 'name_offset'),
+    ]
+
+# magical constants for DTB format
+OF_DT_HEADER = 0xd00dfeed
+OF_DT_BEGIN_NODE = 1
+OF_DT_END_NODE = 2
+OF_DT_PROP = 3
+OF_DT_END = 9
+
+class StringsBlock:
+    """
+    Represents a parsed device tree string block
+    """
+    def __init__(self, values=None):
+        if values is None:
+            self.values = []
+        else:
+            self.values = values
+
+    def __getitem__(self, at):
+        if isinstance(at, str):
+            offset = 0
+            for value in self.values:
+                if value == at:
+                    break
+                offset += len(value) + 1
+            else:
+                self.values.append(at)
+            return offset
+
+        if isinstance(at, int):
+            offset = 0
+            for value in self.values:
+                if offset == at:
+                    return value
+                offset += len(value) + 1
+            raise IndexError('no string found corresponding to the given offset')
+
+        raise TypeError('only strings and integers are accepted')
+
+class Prop:
+    """
+    Represents a parsed device tree property
+    """
+    def __init__(self, name=None, value=None):
+        self.name = name
+        self.value = value
+
+    def clone(self):
+        return Prop(self.name, self.value)
+
+    def __repr__(self):
+        return "<Prop(name='%s', value=%s>" % (self.name, repr(self.value))
+
+class Node:
+    """
+    Represents a parsed device tree node
+    """
+    def __init__(self, name=None):
+        self.name = name
+        self.props = []
+        self.children = []
+
+    def clone(self):
+        o = Node(self.name)
+        o.props = [x.clone() for x in self.props]
+        o.children = [x.clone() for x in self.children]
+        return o
+
+    def __getitem__(self, index):
+        return self.children[index]
+
+    def __repr__(self):
+        return "<Node('%s'), %s, %s>" % (self.name, repr(self.props), repr(self.children))
+
+#
+# flat DT to memory
+#
+
+def parse_strings(strings):
+    """
+    Converts the bytes into a StringsBlock instance so it is convenient to work with
+    """
+    strings = strings.split(b'\x00')
+    return StringsBlock(strings)
+
+def parse_struct(stream):
+    """
+    Parses DTB structure(s) into a Node or Prop instance
+    """
+    tag = bytearray(stream.read(4))[3]
+    if tag == OF_DT_BEGIN_NODE:
+        name = b''
+        while b'\x00' not in name:
+            name += stream.read(4)
+        name = name.rstrip(b'\x00')
+        node = Node(name)
+
+        item = parse_struct(stream)
+        while item is not None:
+            if isinstance(item, Node):
+                node.children.append(item)
+            elif isinstance(item, Prop):
+                node.props.append(item)
+            item = parse_struct(stream)
+
+        return node
+
+    if tag == OF_DT_PROP:
+        h = PropHeader.unpack_from(stream.read(PropHeader.size))
+        length = (h.value_size + 3) & (~3)
+        value = stream.read(length)[:h.value_size]
+        prop = Prop(h.name_offset, value)
+        return prop
+
+    if tag in (OF_DT_END_NODE, OF_DT_END):
+        return None
+
+    raise ValueError('unexpected tag value')
+
+def read_fdt(fp):
+    """
+    Reads and parses the flattened device tree (or derivatives like FIT)
+    """
+    header = HeaderV17.unpack_from(fp.read(HeaderV17.size))
+    if header.magic != OF_DT_HEADER:
+        raise ValueError('invalid magic value %08x; expected %08x' % (header.magic, OF_DT_HEADER))
+    # TODO: read/parse reserved regions
+    fp.seek(header.off_dt_struct)
+    structs = fp.read(header.size_dt_struct)
+    fp.seek(header.off_dt_strings)
+    strings = fp.read(header.size_dt_strings)
+    strblock = parse_strings(strings)
+    root = parse_struct(BytesIO(structs))
+
+    return root, strblock
+
+#
+# memory to flat DT
+#
+
+def compose_structs_r(item):
+    """
+    Recursive part of composing Nodes and Props into a bytearray
+    """
+    t = bytearray()
+
+    if isinstance(item, Node):
+        t.extend(struct.pack('>I', OF_DT_BEGIN_NODE))
+        if isinstance(item.name, str):
+            item.name = bytes(item.name, 'utf-8')
+        name = item.name + b'\x00'
+        if len(name) & 3:
+            name += b'\x00' * (4 - (len(name) & 3))
+        t.extend(name)
+        for p in item.props:
+            t.extend(compose_structs_r(p))
+        for c in item.children:
+            t.extend(compose_structs_r(c))
+        t.extend(struct.pack('>I', OF_DT_END_NODE))
+
+    elif isinstance(item, Prop):
+        t.extend(struct.pack('>I', OF_DT_PROP))
+        value = item.value
+        h = PropHeader()
+        h.name_offset = item.name
+        if value:
+            h.value_size = len(value)
+            t.extend(h.pack())
+            if len(value) & 3:
+                value += b'\x00' * (4 - (len(value) & 3))
+            t.extend(value)
+        else:
+            h.value_size = 0
+            t.extend(h.pack())
+
+    return t
+
+def compose_structs(root):
+    """
+    Composes the parsed Nodes into a flat bytearray instance
+    """
+    t = compose_structs_r(root)
+    t.extend(struct.pack('>I', OF_DT_END))
+    return t
+
+def compose_strings(strblock):
+    """
+    Composes the StringsBlock instance back into a bytearray instance
+    """
+    b = bytearray()
+    for s in strblock.values:
+        b.extend(s)
+        b.append(0)
+    return bytes(b)
+
+def write_fdt(root, strblock, fp):
+    """
+    Writes out a complete flattened device tree (or FIT)
+    """
+    header = HeaderV17()
+    header.magic = OF_DT_HEADER
+    header.version = 17
+    header.last_comp_version = 16
+    fp.write(header.pack())
+
+    header.off_mem_rsvmap = fp.tell()
+    fp.write(RRHeader().pack())
+
+    structs = compose_structs(root)
+    header.off_dt_struct = fp.tell()
+    header.size_dt_struct = len(structs)
+    fp.write(structs)
+
+    strings = compose_strings(strblock)
+    header.off_dt_strings = fp.tell()
+    header.size_dt_strings = len(strings)
+    fp.write(strings)
+
+    header.totalsize = fp.tell()
+
+    fp.seek(0)
+    fp.write(header.pack())
+
+#
+# pretty printing / converting to DT source
+#
+
+def as_bytes(value):
+    return ' '.join(["%02X" % x for x in value])
+
+def prety_print_value(value):
+    """
+    Formats a property value as appropriate depending on the guessed data type
+    """
+    if not value:
+        return '""'
+    if value[-1] == b'\x00':
+        printable = True
+        for x in value[:-1]:
+            x = ord(x)
+            if x != 0 and (x < 0x20 or x > 0x7F):
+                printable = False
+                break
+        if printable:
+            value = value[:-1]
+            return ', '.join('"' + x + '"' for x in value.split(b'\x00'))
+    if len(value) > 0x80:
+        return '[' + as_bytes(value[:0x80]) + ' ... ]'
+    return '[' + as_bytes(value) + ']'
+
+def pretty_print_r(node, strblock, indent=0):
+    """
+    Prints out a single node, recursing further for each of its children
+    """
+    spaces = '  ' * indent
+    print((spaces + '%s {' % (node.name.decode('utf-8') if node.name else '/')))
+    for p in node.props:
+        print((spaces + '  %s = %s;' % (strblock[p.name].decode('utf-8'), prety_print_value(p.value))))
+    for c in node.children:
+        pretty_print_r(c, strblock, indent+1)
+    print((spaces + '};'))
+
+def pretty_print(node, strblock):
+    """
+    Generates an almost-DTS formatted printout of the parsed device tree
+    """
+    print('/dts-v1/;')
+    pretty_print_r(node, strblock, 0)
+
+#
+# manipulating the DT structure
+#
+
+def manipulate(root, strblock):
+    """
+    Maliciously manipulates the structure to create a crafted FIT file
+    """
+    # locate /images/kernel@1 (frankly, it just expects it to be the first one)
+    kernel_node = root[0][0]
+    # clone it to save time filling all the properties
+    fake_kernel = kernel_node.clone()
+    # rename the node
+    fake_kernel.name = b'kernel@2'
+    # get rid of signatures/hashes
+    fake_kernel.children = []
+    # NOTE: this simply replaces the first prop... either description or data
+    # should be good for testing purposes
+    fake_kernel.props[0].value = b'Super 1337 kernel\x00'
+    # insert the new kernel node under /images
+    root[0].children.append(fake_kernel)
+
+    # modify the default configuration
+    root[1].props[0].value = b'conf@2\x00'
+    # clone the first (only?) configuration
+    fake_conf = root[1][0].clone()
+    # rename and change kernel and fdt properties to select the crafted kernel
+    fake_conf.name = b'conf@2'
+    fake_conf.props[0].value = b'kernel@2\x00'
+    fake_conf.props[1].value = b'fdt@1\x00'
+    # insert the new configuration under /configurations
+    root[1].children.append(fake_conf)
+
+    return root, strblock
+
+def main(argv):
+    with open(argv[1], 'rb') as fp:
+        root, strblock = read_fdt(fp)
+
+    print("Before:")
+    pretty_print(root, strblock)
+
+    root, strblock = manipulate(root, strblock)
+    print("After:")
+    pretty_print(root, strblock)
+
+    with open('blah', 'w+b') as fp:
+        write_fdt(root, strblock, fp)
+
+if __name__ == '__main__':
+    import sys
+    main(sys.argv)
+# EOF