Add knowledge of algorithms

Determine the category of operations supported by an algorithm based
on its name.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
Gilles Peskine 2021-04-29 20:38:01 +02:00
parent 8b4a38176a
commit ee7554e606

View File

@ -18,11 +18,21 @@ This module is entirely based on the PSA API.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import re
from typing import Dict, Iterable, Optional, Pattern, Tuple
from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
BLOCK_MAC_MODES = frozenset(['CBC_MAC', 'CMAC'])
BLOCK_CIPHER_MODES = frozenset([
'CTR', 'CFB', 'OFB', 'XTS', 'CCM_STAR_NO_TAG',
'ECB_NO_PADDING', 'CBC_NO_PADDING', 'CBC_PKCS7',
])
BLOCK_AEAD_MODES = frozenset(['CCM', 'GCM'])
class KeyType:
"""Knowledge about a PSA key type."""
@ -153,3 +163,144 @@ class KeyType:
"""
# This is just temporaly solution for the implicit usage flags.
return re.match(self.KEY_TYPE_FOR_SIGNATURE[usage], self.name) is not None
class AlgorithmCategory(enum.Enum):
"""PSA algorithm categories."""
# The numbers are aligned with the category bits in numerical values of
# algorithms.
HASH = 2
MAC = 3
CIPHER = 4
AEAD = 5
SIGN = 6
ASYMMETRIC_ENCRYPTION = 7
KEY_DERIVATION = 8
KEY_AGREEMENT = 9
PAKE = 10
def requires_key(self) -> bool:
return self not in {self.HASH, self.KEY_DERIVATION}
class AlgorithmNotRecognized(Exception):
def __init__(self, expr: str) -> None:
super().__init__('Algorithm not recognized: ' + expr)
self.expr = expr
class Algorithm:
"""Knowledge about a PSA algorithm."""
@staticmethod
def determine_base(expr: str) -> str:
"""Return an expression for the "base" of the algorithm.
This strips off variants of algorithms such as MAC truncation.
This function does not attempt to detect invalid inputs.
"""
m = re.match(r'PSA_ALG_(?:'
r'(?:TRUNCATED|AT_LEAST_THIS_LENGTH)_MAC|'
r'AEAD_WITH_(?:SHORTENED|AT_LEAST_THIS_LENGTH)_TAG'
r')\((.*),[^,]+\)\Z', expr)
if m:
expr = m.group(1)
return expr
@staticmethod
def determine_head(expr: str) -> str:
"""Return the head of an algorithm expression.
The head is the first (outermost) constructor, without its PSA_ALG_
prefix, and with some normalization of similar algorithms.
"""
m = re.match(r'PSA_ALG_(?:DETERMINISTIC_)?(\w+)', expr)
if not m:
raise AlgorithmNotRecognized(expr)
head = m.group(1)
if head == 'KEY_AGREEMENT':
m = re.match(r'PSA_ALG_KEY_AGREEMENT\s*\(\s*PSA_ALG_(\w+)', expr)
if not m:
raise AlgorithmNotRecognized(expr)
head = m.group(1)
head = re.sub(r'_ANY\Z', r'', head)
if re.match(r'ED[0-9]+PH\Z', head):
head = 'EDDSA_PREHASH'
return head
CATEGORY_FROM_HEAD = {
'SHA': AlgorithmCategory.HASH,
'SHAKE256_512': AlgorithmCategory.HASH,
'MD': AlgorithmCategory.HASH,
'RIPEMD': AlgorithmCategory.HASH,
'ANY_HASH': AlgorithmCategory.HASH,
'HMAC': AlgorithmCategory.MAC,
'STREAM_CIPHER': AlgorithmCategory.CIPHER,
'CHACHA20_POLY1305': AlgorithmCategory.AEAD,
'DSA': AlgorithmCategory.SIGN,
'ECDSA': AlgorithmCategory.SIGN,
'EDDSA': AlgorithmCategory.SIGN,
'PURE_EDDSA': AlgorithmCategory.SIGN,
'RSA_PSS': AlgorithmCategory.SIGN,
'RSA_PKCS1V15_SIGN': AlgorithmCategory.SIGN,
'RSA_PKCS1V15_CRYPT': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
'RSA_OAEP': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
'HKDF': AlgorithmCategory.KEY_DERIVATION,
'TLS12_PRF': AlgorithmCategory.KEY_DERIVATION,
'TLS12_PSK_TO_MS': AlgorithmCategory.KEY_DERIVATION,
'PBKDF': AlgorithmCategory.KEY_DERIVATION,
'ECDH': AlgorithmCategory.KEY_AGREEMENT,
'FFDH': AlgorithmCategory.KEY_AGREEMENT,
# KEY_AGREEMENT(...) is a key derivation with a key agreement component
'KEY_AGREEMENT': AlgorithmCategory.KEY_DERIVATION,
'JPAKE': AlgorithmCategory.PAKE,
}
for x in BLOCK_MAC_MODES:
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.MAC
for x in BLOCK_CIPHER_MODES:
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.CIPHER
for x in BLOCK_AEAD_MODES:
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.AEAD
def determine_category(self, expr: str, head: str) -> AlgorithmCategory:
"""Return the category of the given algorithm expression.
This function does not attempt to detect invalid inputs.
"""
prefix = head
while prefix:
if prefix in self.CATEGORY_FROM_HEAD:
return self.CATEGORY_FROM_HEAD[prefix]
if re.match(r'.*[0-9]\Z', prefix):
prefix = re.sub(r'_*[0-9]+\Z', r'', prefix)
else:
prefix = re.sub(r'_*[^_]*\Z', r'', prefix)
raise AlgorithmNotRecognized(expr)
@staticmethod
def determine_wildcard(expr) -> bool:
"""Whether the given algorithm expression is a wildcard.
This function does not attempt to detect invalid inputs.
"""
if re.search(r'\bPSA_ALG_ANY_HASH\b', expr):
return True
if re.search(r'_AT_LEAST_', expr):
return True
return False
def __init__(self, expr: str) -> None:
"""Analyze an algorithm value.
The algorithm must be expressed as a C expression containing only
calls to PSA algorithm constructor macros and numeric literals.
This class is only programmed to handle valid expressions. Invalid
expressions may result in exceptions or in nonsensical results.
"""
self.expression = re.sub(r'\s+', r'', expr)
self.base_expression = self.determine_base(self.expression)
self.head = self.determine_head(self.base_expression)
self.category = self.determine_category(self.base_expression, self.head)
self.is_wildcard = self.determine_wildcard(self.expression)