From b5db2c480fb4e6fdcedee59c36cdd2e8544d2b8a Mon Sep 17 00:00:00 2001 From: gabor-mezei-arm Date: Wed, 23 Jun 2021 17:33:30 +0200 Subject: [PATCH] Convert iterators to lists to remove late binding Remove late binding of iterators to enable the creation of an object with an actual state of a variable. Signed-off-by: gabor-mezei-arm --- tests/scripts/generate_psa_tests.py | 87 +++++++++++++++-------------- 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py index da759835e..7e04b2ee9 100755 --- a/tests/scripts/generate_psa_tests.py +++ b/tests/scripts/generate_psa_tests.py @@ -313,10 +313,11 @@ class StorageFormat: description=description) return key - def all_keys_for_lifetimes(self) -> Iterator[StorageKey]: + def all_keys_for_lifetimes(self) -> List[StorageKey]: """Generate test keys covering lifetimes.""" lifetimes = sorted(self.constructors.lifetimes) expressions = self.constructors.generate_expressions(lifetimes) + keys = [] #type List[StorageKey] for lifetime in expressions: # Don't attempt to create or load a volatile key in storage if 'VOLATILE' in lifetime: @@ -325,7 +326,8 @@ class StorageFormat: # but do attempt to load one. if 'READ_ONLY' in lifetime and self.forward: continue - yield self.key_for_lifetime(lifetime) + keys.append(self.key_for_lifetime(lifetime)) + return keys def key_for_usage_flags( self, @@ -337,36 +339,38 @@ class StorageFormat: if short is None: short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage) description = 'usage: ' + short - key = StorageKey(version=self.version, - id=1, lifetime=0x00000001, - type='PSA_KEY_TYPE_RAW_DATA', bits=8, - usage=usage, alg=0, alg2=0, - material=b'K', - description=description) - return key + return StorageKey(version=self.version, + id=1, lifetime=0x00000001, + type='PSA_KEY_TYPE_RAW_DATA', bits=8, + usage=usage, alg=0, alg2=0, + material=b'K', + description=description) - def all_keys_for_usage_flags(self) -> Iterator[StorageKey]: + def all_keys_for_usage_flags(self) -> List[StorageKey]: """Generate test keys covering usage flags.""" known_flags = sorted(self.constructors.key_usage_flags) - yield self.key_for_usage_flags(['0']) - for usage_flag in known_flags: - yield self.key_for_usage_flags([usage_flag]) - for flag1, flag2 in zip(known_flags, - known_flags[1:] + [known_flags[0]]): - yield self.key_for_usage_flags([flag1, flag2]) - yield self.key_for_usage_flags(known_flags, short='all known') + keys = [] #type List[StorageKey] + keys.append(self.key_for_usage_flags(['0'])) + keys += [self.key_for_usage_flags([usage_flag]) + for usage_flag in known_flags] + keys += [self.key_for_usage_flags([flag1, flag2]) + for flag1, flag2 in zip(known_flags, + known_flags[1:] + [known_flags[0]])] + keys.append(self.key_for_usage_flags(known_flags, short='all known')) + return keys def keys_for_type( self, key_type: str, params: Optional[Iterable[str]] = None - ) -> Iterator[StorageKey]: + ) -> List[StorageKey]: """Generate test keys for the given key type. For key types that depend on a parameter (e.g. elliptic curve family), `param` is the parameter to pass to the constructor. Only a single parameter is supported. """ + keys = [] #type: List[StorageKey] kt = crypto_knowledge.KeyType(key_type, params) for bits in kt.sizes_to_test(): usage_flags = 'PSA_KEY_USAGE_EXPORT' @@ -377,21 +381,21 @@ class StorageFormat: r'', kt.expression) description = 'type: {} {}-bit'.format(short_expression, bits) - key = StorageKey(version=self.version, - id=1, lifetime=0x00000001, - type=kt.expression, bits=bits, - usage=usage_flags, alg=alg, alg2=alg2, - material=key_material, - description=description) - yield key + keys.append(StorageKey(version=self.version, + id=1, lifetime=0x00000001, + type=kt.expression, bits=bits, + usage=usage_flags, alg=alg, alg2=alg2, + material=key_material, + description=description)) + return keys - def all_keys_for_types(self) -> Iterator[StorageKey]: + def all_keys_for_types(self) -> List[StorageKey]: """Generate test keys covering key types and their representations.""" key_types = sorted(self.constructors.key_types) - for key_type in self.constructors.generate_expressions(key_types): - yield from self.keys_for_type(key_type) + return [key for key_type in self.constructors.generate_expressions(key_types) + for key in self.keys_for_type(key_type)] - def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]: + def keys_for_algorithm(self, alg: str) -> List[StorageKey]: """Generate test keys for the specified algorithm.""" # For now, we don't have information on the compatibility of key # types and algorithms. So we just test the encoding of algorithms, @@ -405,22 +409,21 @@ class StorageFormat: usage=usage, alg=alg, alg2=0, material=b'K', description='alg: ' + descr) - yield key1 key2 = StorageKey(version=self.version, id=1, lifetime=0x00000001, type='PSA_KEY_TYPE_RAW_DATA', bits=8, usage=usage, alg=0, alg2=alg, material=b'L', description='alg2: ' + descr) - yield key2 + return [key1, key2] - def all_keys_for_algorithms(self) -> Iterator[StorageKey]: + def all_keys_for_algorithms(self) -> List[StorageKey]: """Generate test keys covering algorithm encodings.""" algorithms = sorted(self.constructors.algorithms) - for alg in self.constructors.generate_expressions(algorithms): - yield from self.keys_for_algorithm(alg) + return [key for alg in self.constructors.generate_expressions(algorithms) + for key in self.keys_for_algorithm(alg)] - def all_test_cases(self) -> Iterator[test_case.TestCase]: + def all_test_cases(self) -> List[test_case.TestCase]: """Generate all storage format test cases.""" # First build a list of all keys, then construct all the corresponding # test cases. This allows all required information to be obtained in @@ -431,13 +434,13 @@ class StorageFormat: keys += self.all_keys_for_usage_flags() keys += self.all_keys_for_types() keys += self.all_keys_for_algorithms() - for key in keys: - if key.location_value() != 0: - # Skip keys with a non-default location, because they - # require a driver and we currently have no mechanism to - # determine whether a driver is available. - continue - yield self.make_test_case(key) + + # Skip keys with a non-default location, because they + # require a driver and we currently have no mechanism to + # determine whether a driver is available. + keys = filter(lambda key: key.location_value() == 0, keys) + + return [self.make_test_case(key) for key in keys] class TestGenerator: