blob: b1fc37710428bef4a55a6914a4b99759c69b1c56 [file] [log] [blame]
Tom Rini0344c602024-10-08 13:56:50 -06001"""Knowledge about the PSA key store as implemented in Mbed TLS.
2
3Note that if you need to make a change that affects how keys are
4stored, this may indicate that the key store is changing in a
5backward-incompatible way! Think carefully about backward compatibility
6before changing how test data is constructed or validated.
7"""
8
9# Copyright The Mbed TLS Contributors
10# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
11#
12
13import re
14import struct
15from typing import Dict, List, Optional, Set, Union
16import unittest
17
18from . import c_build_helper
19from . import build_tree
20
21
class Expr:
    """Representation of a C expression with a known or knowable numerical value.

    Values of string expressions are computed lazily and in bulk: every
    string `Expr` registers its normalized text in the class-wide
    `unknown_values` set, and the first `value()` call that misses the
    class-wide `value_cache` evaluates all pending expressions at once
    through `update_cache`. Cached values are shared by all instances.
    """

    def __init__(self, content: Union[int, str]):
        """Build an expression from an integer or a C expression string.

        For an integer, the value is known immediately and `string` is its
        hexadecimal rendering, zero-padded to at least 4 hex digits (at
        least 8 when the value exceeds 0xffff). For a string, the value is
        not known yet; its normalized form is queued in `unknown_values`.
        """
        if isinstance(content, int):
            digits = 8 if content > 0xffff else 4
            self.string = '{0:#0{1}x}'.format(content, digits + 2)
            self.value_if_known = content #type: Optional[int]
        else:
            self.string = content
            self.unknown_values.add(self.normalize(content))
            self.value_if_known = None

    value_cache = {} #type: Dict[str, int]
    """Cache of known values of expressions."""

    unknown_values = set() #type: Set[str]
    """Expressions whose values are not present in `value_cache` yet."""

    # Integer literals that Python's int(s, 0) and a C compiler read the
    # same way: "0", decimal without a leading zero, or hexadecimal.
    # A string such as "0123" is an octal constant in C but is rejected by
    # int(s, 0), so it must not match here; it is evaluated in C instead.
    _PLAIN_LITERAL_RE = re.compile(r'(0|[1-9][0-9]*|0x[0-9a-f]+)\Z', re.I)

    def update_cache(self) -> None:
        """Update `value_cache` for expressions registered in `unknown_values`.

        All pending expressions are evaluated in one batch through
        `c_build_helper.get_c_expression_values`, with `<psa/crypto.h>`
        included (presumably by compiling and running a small C program).
        Afterwards `unknown_values` is emptied.
        """
        expressions = sorted(self.unknown_values)
        if not expressions:
            # Nothing pending: avoid invoking the C helper for no work.
            return
        includes = ['include']
        # In a TF-PSA-Crypto source tree, also search drivers/builtin/include.
        if build_tree.looks_like_tf_psa_crypto_root('.'):
            includes.append('drivers/builtin/include')
        values = c_build_helper.get_c_expression_values(
            'unsigned long', '%lu',
            expressions,
            header="""
            #include <psa/crypto.h>
            """,
            include_path=includes) #type: List[str]
        for e, v in zip(expressions, values):
            self.value_cache[e] = int(v, 0)
        self.unknown_values.clear()

    @staticmethod
    def normalize(string: str) -> str:
        """Put the given C expression in a canonical form.

        This function is only intended to give correct results for the
        relatively simple kind of C expression typically used with this
        module.
        """
        return re.sub(r'\s+', r'', string)

    def value(self) -> int:
        """Return the numerical value of the expression.

        Integers given to the constructor are returned as-is. A string that
        is a plain decimal or hexadecimal literal is parsed directly in
        Python. Any other string is looked up in `value_cache`, calling
        `update_cache` first if it is missing. The result is remembered in
        `value_if_known` for subsequent calls.
        """
        if self.value_if_known is None:
            if self._PLAIN_LITERAL_RE.match(self.string):
                self.value_if_known = int(self.string, 0)
            else:
                normalized = self.normalize(self.string)
                if normalized not in self.value_cache:
                    self.update_cache()
                self.value_if_known = self.value_cache[normalized]
        return self.value_if_known
78
Exprable = Union[str, int, Expr]
"""Something that can be converted to a C expression with a known numerical value."""

def as_expr(thing: Exprable) -> Expr:
    """Coerce `thing` into an `Expr`.

    An existing `Expr` is handed back unchanged (the very same object).
    An integer, or a string holding a C expression, is wrapped in a newly
    constructed `Expr`.
    """
    return thing if isinstance(thing, Expr) else Expr(thing)
93
94
class Key:
    """Representation of a PSA crypto key object and its storage encoding.
    """

    LATEST_VERSION = 0
    """The latest version of the storage format."""

    def __init__(self, *,
                 version: Optional[int] = None,
                 id: Optional[int] = None, #pylint: disable=redefined-builtin
                 lifetime: Exprable = 'PSA_KEY_LIFETIME_PERSISTENT',
                 type: Exprable, #pylint: disable=redefined-builtin
                 bits: int,
                 usage: Exprable, alg: Exprable, alg2: Exprable,
                 material: bytes #pylint: disable=used-before-assignment
                ) -> None:
        """Record the key's metadata and material.

        `version` falls back to `LATEST_VERSION` when omitted. The lifetime,
        type, usage and algorithm arguments may each be an integer, a C
        expression string or an `Expr`; all are stored as `Expr`. `id` is
        kept as an attribute but does not appear in `bytes()`.
        """
        if version is None:
            version = self.LATEST_VERSION
        self.version = version
        self.id = id #pylint: disable=invalid-name #type: Optional[int]
        self.bits = bits #type: int
        self.material = material #type: bytes
        # Normalize every Exprable argument to an Expr.
        self.lifetime = as_expr(lifetime) #type: Expr
        self.type = as_expr(type) #type: Expr
        self.usage = as_expr(usage) #type: Expr
        self.alg = as_expr(alg) #type: Expr
        self.alg2 = as_expr(alg2) #type: Expr

    MAGIC = b'PSA\000KEY\000'

    @staticmethod
    def pack(
            fmt: str,
            *args: Union[int, Expr]
    ) -> bytes: #pylint: disable=used-before-assignment
        """Serialize `args` with `struct` in little-endian, standard-size form.

        Behaves like `struct.pack` except that:
        * `fmt` must not carry its own byte-order prefix; '<' (little-endian,
          standard sizes) is always prepended.
        * Any argument may be an `Expr`, in which case its `value()` is used.
        * Only integer-valued elements are supported.
        """
        def to_number(arg):
            return arg.value() if isinstance(arg, Expr) else arg
        return struct.pack('<' + fmt, *map(to_number, args))

    def bytes(self) -> bytes:
        """Return the representation of the key in storage as a byte array.

        This is the content of the PSA storage file. When PSA storage is
        implemented over stdio files, this does not include any wrapping made
        by the PSA-storage-over-stdio-file implementation.

        Note that if you need to make a change in this function,
        this may indicate that the key store is changing in a
        backward-incompatible way! Think carefully about backward
        compatibility before making any change here.

        Raises NotImplementedError for any version other than 0 (the
        header is packed first, so an unpackable version raises
        `struct.error` instead).
        """
        header = self.MAGIC + self.pack('L', self.version)
        if self.version != 0:
            raise NotImplementedError
        # Version 0 layout: lifetime, type, bits, usage, alg, alg2,
        # then a 32-bit length prefix followed by the raw material.
        attributes = self.pack('LHHLLL',
                               self.lifetime, self.type, self.bits,
                               self.usage, self.alg, self.alg2)
        length_prefix = self.pack('L', len(self.material))
        return b''.join((header, attributes, length_prefix, self.material))

    def hex(self) -> str:
        """Return the stored representation of the key as a hexadecimal string.

        The result is `self.bytes()` rendered as lowercase hex digits.
        """
        encoded = self.bytes()
        return encoded.hex()

    def location_value(self) -> int:
        """The numerical value of the location encoded in the key's lifetime.

        Computed as the lifetime value with its low 8 bits shifted out.
        """
        lifetime_value = self.lifetime.value()
        return lifetime_value >> 8
173
174
class TestKey(unittest.TestCase):
    # pylint: disable=line-too-long
    """A few smoke tests for the functionality of the `Key` class."""

    def _check_encoding(self, key, expected_hex):
        """Assert that `key` encodes to `expected_hex`, both as bytes and as hex."""
        self.assertEqual(key.bytes(), bytes.fromhex(expected_hex))
        self.assertEqual(key.hex(), expected_hex)

    def test_numerical(self):
        """Every attribute supplied as a plain integer."""
        key = Key(version=0,
                  id=1, lifetime=0x00000001,
                  type=0x2400, bits=128,
                  usage=0x00000300, alg=0x05500200, alg2=0x04c01000,
                  material=b'@ABCDEFGHIJKLMNO')
        expected_hex = '505341004b45590000000000010000000024800000030000000250050010c00410000000404142434445464748494a4b4c4d4e4f'
        self._check_encoding(key, expected_hex)

    def test_names(self):
        """Attributes supplied as C macro names, with the largest key material."""
        max_bytes = 0xfff8 // 8 # PSA_MAX_KEY_BITS in bytes
        key = Key(version=0,
                  id=1, lifetime='PSA_KEY_LIFETIME_PERSISTENT',
                  type='PSA_KEY_TYPE_RAW_DATA', bits=max_bytes*8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x00' * max_bytes)
        expected_hex = ('505341004b45590000000000010000000110f8ff000000000000000000000000ff1f0000'
                        + '00' * max_bytes)
        self._check_encoding(key, expected_hex)

    def test_defaults(self):
        """Version, id and lifetime omitted, so their defaults apply."""
        key = Key(type=0x1001, bits=8,
                  usage=0, alg=0, alg2=0,
                  material=b'\x2a')
        expected_hex = '505341004b455900000000000100000001100800000000000000000000000000010000002a'
        self._check_encoding(key, expected_hex)
206 self.assertEqual(key.hex(), expected_hex)