diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index 53af0a524..e64040802 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -269,8 +269,10 @@ def remove_file_if_exists(filename): except OSError: pass -def run_c(options, type_word, expressions): +def run_c(type_word, expressions, include_path=None, keep_c=False): """Generate and run a program to print out numerical values for expressions.""" + if include_path is None: + include_path = [] if type_word == 'status': cast_to = 'long' printf_format = '%ld' @@ -304,9 +306,9 @@ int main(void) c_file.close() cc = os.getenv('CC', 'cc') subprocess.check_call([cc] + - ['-I' + dir for dir in options.include] + + ['-I' + dir for dir in include_path] + ['-o', exe_name, c_name]) - if options.keep_c: + if keep_c: sys.stderr.write('List of {} tests kept at {}\n' .format(type_word, c_name)) else: @@ -324,7 +326,7 @@ def normalize(expr): """ return re.sub(NORMALIZE_STRIP_RE, '', expr) -def collect_values(options, inputs, type_word): +def collect_values(inputs, type_word, include_path=None, keep_c=False): """Generate expressions using known macro names and calculate their values. Return a list of pairs of (expr, value) where expr is an expression and @@ -332,7 +334,8 @@ def collect_values(options, inputs, type_word): """ names = inputs.get_names(type_word) expressions = sorted(inputs.generate_expressions(names)) - values = run_c(options, type_word, expressions) + values = run_c(type_word, expressions, + include_path=include_path, keep_c=keep_c) return expressions, values def do_test(options, inputs, type_word): @@ -346,7 +349,9 @@ def do_test(options, inputs, type_word): that have been tested and ``errors`` is the list of errors that were encountered. """ - expressions, values = collect_values(options, inputs, type_word) + expressions, values = collect_values(inputs, type_word, + include_path=options.include, + keep_c=options.keep_c) output = subprocess.check_output([options.program, type_word] + values) outputs = output.decode('ascii').strip().split('\n') errors = [(type_word, expr, value, output)