SPIRV-Tools/source/assembly_grammar.cpp
alan-baker d35a78db57
Switch SPIRV-Tools to use spirv.hpp11 internally (#4981)
Fixes #4960

* Switches to using enum classes with an underlying type to avoid
  undefined behaviour
2022-11-04 17:27:10 -04:00

265 lines
8.8 KiB
C++

// Copyright (c) 2015-2016 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/assembly_grammar.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include "source/ext_inst.h"
#include "source/opcode.h"
#include "source/operand.h"
#include "source/table.h"
namespace spvtools {
namespace {
/// @brief Parses a mask expression string for the given operand type.
///
/// A mask expression is one or more terms separated by '|', where each term
/// is the name of an enum value of the given operand type. No whitespace is
/// permitted. Every term is looked up by name, including an empty term
/// produced by a leading, trailing, or doubled '|'. The error code of the
/// first failing lookup is returned unchanged.
///
/// On success, the bitwise OR of all term values is written to pValue;
/// on failure pValue is left untouched.
///
/// @param[in] env target environment passed through to the name lookup
/// @param[in] operandTable operand lookup table
/// @param[in] type of the operand
/// @param[in] textValue word of text to be parsed
/// @param[out] pValue where the resulting value is written
///
/// @return result code
spv_result_t spvTextParseMaskOperand(spv_target_env env,
                                     const spv_operand_table operandTable,
                                     const spv_operand_type_t type,
                                     const char* textValue, uint32_t* pValue) {
  // A missing or empty expression cannot name any value.
  if (textValue == nullptr || textValue[0] == '\0') {
    return SPV_ERROR_INVALID_TEXT;
  }

  const char* const text_end = textValue + strlen(textValue);
  // Mask expressions are ASCII-only, so a single char is enough to split on.
  constexpr char kSeparator = '|';

  // Scan left to right, OR-ing in the value of each named term.
  uint32_t mask = 0;
  const char* word_begin = textValue;
  for (;;) {
    const char* const word_end = std::find(word_begin, text_end, kSeparator);
    spv_operand_desc entry = nullptr;
    const spv_result_t status = spvOperandTableNameLookup(
        env, operandTable, type, word_begin,
        static_cast<size_t>(word_end - word_begin), &entry);
    if (status != SPV_SUCCESS) return status;
    mask |= entry->value;

    if (word_end == text_end) break;  // That was the final term.
    word_begin = word_end + 1;        // Step over the separator.
  }

  *pValue = mask;
  return SPV_SUCCESS;
}
// Associates an opcode with its name.
// NOTE: member order matters. The CASE macro used to build
// kOpSpecConstantOpcodes aggregate-initializes entries as { opcode, name }.
struct SpecConstantOpcodeEntry {
  spv::Op opcode;    // Opcode value, e.g. spv::Op::OpIAdd.
  const char* name;  // Opcode name without the "Op" prefix, e.g. "IAdd".
};
// All the opcodes allowed as the operation for OpSpecConstantOp.
// The name does not have the usual "Op" prefix. For example opcode
// spv::Op::IAdd is associated with the name "IAdd".
//
// clang-format off
#define CASE(NAME) { spv::Op::Op##NAME, #NAME }
const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
// Conversion
CASE(SConvert),
CASE(FConvert),
CASE(ConvertFToS),
CASE(ConvertSToF),
CASE(ConvertFToU),
CASE(ConvertUToF),
CASE(UConvert),
CASE(ConvertPtrToU),
CASE(ConvertUToPtr),
CASE(GenericCastToPtr),
CASE(PtrCastToGeneric),
CASE(Bitcast),
CASE(QuantizeToF16),
// Arithmetic
CASE(SNegate),
CASE(Not),
CASE(IAdd),
CASE(ISub),
CASE(IMul),
CASE(UDiv),
CASE(SDiv),
CASE(UMod),
CASE(SRem),
CASE(SMod),
CASE(ShiftRightLogical),
CASE(ShiftRightArithmetic),
CASE(ShiftLeftLogical),
CASE(BitwiseOr),
CASE(BitwiseAnd),
CASE(BitwiseXor),
CASE(FNegate),
CASE(FAdd),
CASE(FSub),
CASE(FMul),
CASE(FDiv),
CASE(FRem),
CASE(FMod),
// Composite
CASE(VectorShuffle),
CASE(CompositeExtract),
CASE(CompositeInsert),
// Logical
CASE(LogicalOr),
CASE(LogicalAnd),
CASE(LogicalNot),
CASE(LogicalEqual),
CASE(LogicalNotEqual),
CASE(Select),
// Comparison
CASE(IEqual),
CASE(INotEqual),
CASE(ULessThan),
CASE(SLessThan),
CASE(UGreaterThan),
CASE(SGreaterThan),
CASE(ULessThanEqual),
CASE(SLessThanEqual),
CASE(UGreaterThanEqual),
CASE(SGreaterThanEqual),
// Memory
CASE(AccessChain),
CASE(InBoundsAccessChain),
CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain),
CASE(CooperativeMatrixLengthNV)
};
// The 60 is determined by counting the opcodes listed in the spec.
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete");
#undef CASE
// clang-format on
const size_t kNumOpSpecConstantOpcodes =
sizeof(kOpSpecConstantOpcodes) / sizeof(kOpSpecConstantOpcodes[0]);
} // namespace
// Reports whether the grammar is usable. The operand, opcode and extended
// instruction tables must all be set.
bool AssemblyGrammar::isValid() const {
  if (!operandTable_ || !opcodeTable_) return false;
  return extInstTable_ ? true : false;
}
// Builds the set of capabilities from cap_array[0..count) that the operand
// table recognizes for this grammar's target environment.
CapabilitySet AssemblyGrammar::filterCapsAgainstTargetEnv(
    const spv::Capability* cap_array, uint32_t count) const {
  CapabilitySet allowed;
  const spv::Capability* const caps_end = cap_array + count;
  for (const spv::Capability* cap = cap_array; cap != caps_end; ++cap) {
    // Only the success/failure of the lookup matters; the descriptor itself
    // is discarded.
    spv_operand_desc unused_desc = nullptr;
    const spv_result_t status = lookupOperand(
        SPV_OPERAND_TYPE_CAPABILITY, static_cast<uint32_t>(*cap),
        &unused_desc);
    // spvOperandTableValueLookup() (reached through lookupOperand) filters
    // capabilities by the current target environment itself. A successful
    // lookup therefore means the capability may be added.
    if (status != SPV_SUCCESS) continue;
    allowed.Add(*cap);
  }
  return allowed;
}
// Finds the descriptor of the opcode called `name`. The opcode table is
// queried with this grammar's target environment.
spv_result_t AssemblyGrammar::lookupOpcode(const char* name,
                                           spv_opcode_desc* desc) const {
  const spv_result_t status =
      spvOpcodeTableNameLookup(target_env_, opcodeTable_, name, desc);
  return status;
}
// Finds the descriptor for the numeric `opcode`. The opcode table is
// queried with this grammar's target environment.
spv_result_t AssemblyGrammar::lookupOpcode(spv::Op opcode,
                                           spv_opcode_desc* desc) const {
  const spv_result_t status =
      spvOpcodeTableValueLookup(target_env_, opcodeTable_, opcode, desc);
  return status;
}
// Finds the descriptor of the operand of kind `type` named by the
// name_len characters starting at `name`. The operand table is queried with
// this grammar's target environment.
spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type,
                                            const char* name, size_t name_len,
                                            spv_operand_desc* desc) const {
  const spv_result_t status = spvOperandTableNameLookup(
      target_env_, operandTable_, type, name, name_len, desc);
  return status;
}
// Finds the descriptor of the operand of kind `type` whose numeric value is
// `operand`. The operand table is queried with this grammar's target
// environment.
spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type,
                                            uint32_t operand,
                                            spv_operand_desc* desc) const {
  const spv_result_t status = spvOperandTableValueLookup(
      target_env_, operandTable_, type, operand, desc);
  return status;
}
// Looks up an opcode permitted as the operation of OpSpecConstantOp by its
// name without the "Op" prefix (e.g. "IAdd").
// On success writes the opcode to *opcode and returns SPV_SUCCESS.
// Returns SPV_ERROR_INVALID_LOOKUP if `name` is null or is not in
// kOpSpecConstantOpcodes; *opcode is not written in that case.
spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(const char* name,
                                                       spv::Op* opcode) const {
  // strcmp() on a null pointer is undefined behaviour, so reject it before
  // the table scan instead of crashing inside the comparison.
  if (name == nullptr) return SPV_ERROR_INVALID_LOOKUP;
  const auto* last = kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes;
  const auto* found =
      std::find_if(kOpSpecConstantOpcodes, last,
                   [name](const SpecConstantOpcodeEntry& entry) {
                     return 0 == strcmp(name, entry.name);
                   });
  if (found == last) return SPV_ERROR_INVALID_LOOKUP;
  *opcode = found->opcode;
  return SPV_SUCCESS;
}
// Reports whether `opcode` may be used as the operation of OpSpecConstantOp.
// Returns SPV_SUCCESS if it appears in kOpSpecConstantOpcodes, otherwise
// SPV_ERROR_INVALID_LOOKUP.
spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(spv::Op opcode) const {
  const SpecConstantOpcodeEntry* const table_end =
      kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes;
  const bool allowed =
      std::any_of(kOpSpecConstantOpcodes, table_end,
                  [opcode](const SpecConstantOpcodeEntry& candidate) {
                    return candidate.opcode == opcode;
                  });
  return allowed ? SPV_SUCCESS : SPV_ERROR_INVALID_LOOKUP;
}
// Parses a '|'-separated mask expression (e.g. "A|B") for operand kind
// `type`, using this grammar's target environment and operand table.
// On success the combined value is written to *pValue.
spv_result_t AssemblyGrammar::parseMaskOperand(const spv_operand_type_t type,
                                               const char* textValue,
                                               uint32_t* pValue) const {
  const spv_result_t status = spvTextParseMaskOperand(
      target_env_, operandTable_, type, textValue, pValue);
  return status;
}
// Looks up an instruction of extended instruction set `type` by its name
// `textValue`. Unlike the opcode and operand lookups, no target environment
// is passed to the table query.
spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type,
                                            const char* textValue,
                                            spv_ext_inst_desc* extInst) const {
  const spv_result_t status =
      spvExtInstTableNameLookup(extInstTable_, type, textValue, extInst);
  return status;
}
// Looks up an instruction of extended instruction set `type` by the numeric
// value `firstWord`. No target environment is passed to the table query.
spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type,
                                            uint32_t firstWord,
                                            spv_ext_inst_desc* extInst) const {
  const spv_result_t status =
      spvExtInstTableValueLookup(extInstTable_, type, firstWord, extInst);
  return status;
}
// Delegates to spvPushOperandTypesForMask with this grammar's target
// environment and operand table. Presumably it appends to *pattern the
// operand types implied by the bits set in `mask`; confirm against operand.h.
void AssemblyGrammar::pushOperandTypesForMask(
    const spv_operand_type_t type, const uint32_t mask,
    spv_operand_pattern_t* pattern) const {
  spvPushOperandTypesForMask(target_env_, operandTable_, type, mask,
                             pattern);
}
} // namespace spvtools