mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-28 06:01:04 +00:00
Move cfg opcode validation to another file.
* Moved cfg opcode validation out of idUsage and into validate_cfg.cpp * minor style updates
This commit is contained in:
parent
70de4a35aa
commit
6cd4441c87
@ -336,6 +336,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
|
||||
if (auto error = ArithmeticsPass(*vstate, &instruction)) return error;
|
||||
if (auto error = BitwisePass(*vstate, &instruction)) return error;
|
||||
if (auto error = LogicalsPass(*vstate, &instruction)) return error;
|
||||
if (auto error = ControlFlowPass(*vstate, &instruction)) return error;
|
||||
if (auto error = DerivativesPass(*vstate, &instruction)) return error;
|
||||
if (auto error = AtomicsPass(*vstate, &instruction)) return error;
|
||||
if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
|
||||
|
@ -116,9 +116,12 @@ void printDominatorList(BasicBlock& block);
|
||||
/// spec.
|
||||
spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Performs Control Flow Graph validation of a module
|
||||
/// Performs Control Flow Graph validation and construction.
|
||||
spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Validates Control Flow Graph instructions.
|
||||
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst);
|
||||
|
||||
/// Performs Id and SSA validation of a module
|
||||
spv_result_t IdPass(ValidationState_t& _, Instruction* inst);
|
||||
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
@ -27,6 +28,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "source/cfa.h"
|
||||
#include "source/opcode.h"
|
||||
#include "source/spirv_validator_options.h"
|
||||
#include "source/val/basic_block.h"
|
||||
#include "source/val/construct.h"
|
||||
@ -35,6 +37,158 @@
|
||||
|
||||
namespace spvtools {
|
||||
namespace val {
|
||||
namespace {
|
||||
|
||||
spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) {
|
||||
SpvOp type_op = _.GetIdOpcode(inst->type_id());
|
||||
if (!spvOpcodeGeneratesType(type_op)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi's type <id> " << _.getIdName(inst->type_id())
|
||||
<< " is not a type instruction.";
|
||||
}
|
||||
|
||||
auto block = inst->block();
|
||||
size_t num_in_ops = inst->words().size() - 3;
|
||||
if (num_in_ops % 2 != 0) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi does not have an equal number of incoming values and "
|
||||
"basic blocks.";
|
||||
}
|
||||
|
||||
// Create a uniqued vector of predecessor ids for comparison against
|
||||
// incoming values. OpBranchConditional %cond %label %label produces two
|
||||
// predecessors in the CFG.
|
||||
std::vector<uint32_t> pred_ids;
|
||||
std::transform(block->predecessors()->begin(), block->predecessors()->end(),
|
||||
std::back_inserter(pred_ids),
|
||||
[](const BasicBlock* b) { return b->id(); });
|
||||
std::sort(pred_ids.begin(), pred_ids.end());
|
||||
pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end());
|
||||
|
||||
size_t num_edges = num_in_ops / 2;
|
||||
if (num_edges != pred_ids.size()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi's number of incoming blocks (" << num_edges
|
||||
<< ") does not match block's predecessor count ("
|
||||
<< block->predecessors()->size() << ").";
|
||||
}
|
||||
|
||||
for (size_t i = 3; i < inst->words().size(); ++i) {
|
||||
auto inc_id = inst->word(i);
|
||||
if (i % 2 == 1) {
|
||||
// Incoming value type must match the phi result type.
|
||||
auto inc_type_id = _.GetTypeId(inc_id);
|
||||
if (inst->type_id() != inc_type_id) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi's result type <id> " << _.getIdName(inst->type_id())
|
||||
<< " does not match incoming value <id> " << _.getIdName(inc_id)
|
||||
<< " type <id> " << _.getIdName(inc_type_id) << ".";
|
||||
}
|
||||
} else {
|
||||
if (_.GetIdOpcode(inc_id) != SpvOpLabel) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
|
||||
<< " is not an OpLabel.";
|
||||
}
|
||||
|
||||
// Incoming basic block must be an immediate predecessor of the phi's
|
||||
// block.
|
||||
if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
|
||||
<< " is not a predecessor of <id> " << _.getIdName(block->id())
|
||||
<< ".";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateBranchConditional(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
// num_operands is either 3 or 5 --- if 5, the last two need to be literal
|
||||
// integers
|
||||
const auto num_operands = inst->operands().size();
|
||||
if (num_operands != 3 && num_operands != 5) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpBranchConditional requires either 3 or 5 parameters";
|
||||
}
|
||||
|
||||
// grab the condition operand and check that it is a bool
|
||||
const auto cond_id = inst->GetOperandAs<uint32_t>(0);
|
||||
const auto cond_op = _.FindDef(cond_id);
|
||||
if (!cond_op || !_.IsBoolScalarType(cond_op->type_id())) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for "
|
||||
"OpBranchConditional must be "
|
||||
"of boolean type";
|
||||
}
|
||||
|
||||
// target operands must be OpLabel
|
||||
// note that we don't need to check that the target labels are in the same
|
||||
// function,
|
||||
// PerformCfgChecks already checks for that
|
||||
const auto true_id = inst->GetOperandAs<uint32_t>(1);
|
||||
const auto true_target = _.FindDef(true_id);
|
||||
if (!true_target || SpvOpLabel != true_target->opcode()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "The 'True Label' operand for OpBranchConditional must be the "
|
||||
"ID of an OpLabel instruction";
|
||||
}
|
||||
|
||||
const auto false_id = inst->GetOperandAs<uint32_t>(2);
|
||||
const auto false_target = _.FindDef(false_id);
|
||||
if (!false_target || SpvOpLabel != false_target->opcode()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "The 'False Label' operand for OpBranchConditional must be the "
|
||||
"ID of an OpLabel instruction";
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateReturnValue(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
const auto value_id = inst->GetOperandAs<uint32_t>(0);
|
||||
const auto value = _.FindDef(value_id);
|
||||
if (!value || !value->type_id()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpReturnValue Value <id> '" << _.getIdName(value_id)
|
||||
<< "' does not represent a value.";
|
||||
}
|
||||
auto value_type = _.FindDef(value->type_id());
|
||||
if (!value_type || SpvOpTypeVoid == value_type->opcode()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpReturnValue value's type <id> '"
|
||||
<< _.getIdName(value->type_id()) << "' is missing or void.";
|
||||
}
|
||||
|
||||
const bool uses_variable_pointer =
|
||||
_.features().variable_pointers ||
|
||||
_.features().variable_pointers_storage_buffer;
|
||||
|
||||
if (_.addressing_model() == SpvAddressingModelLogical &&
|
||||
SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer &&
|
||||
!_.options()->relax_logical_pointer) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpReturnValue value's type <id> '"
|
||||
<< _.getIdName(value->type_id())
|
||||
<< "' is a pointer, which is invalid in the Logical addressing "
|
||||
"model.";
|
||||
}
|
||||
|
||||
const auto function = inst->function();
|
||||
const auto return_type = _.FindDef(function->GetResultTypeId());
|
||||
if (!return_type || return_type->id() != value_type->id()) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||
<< "OpReturnValue Value <id> '" << _.getIdName(value_id)
|
||||
<< "'s type does not match OpFunction's return type.";
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void printDominatorList(const BasicBlock& b) {
|
||||
std::cout << b.id() << " is dominated by: ";
|
||||
@ -596,5 +750,23 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
|
||||
switch (inst->opcode()) {
|
||||
case SpvOpPhi:
|
||||
if (auto error = ValidatePhi(_, inst)) return error;
|
||||
break;
|
||||
case SpvOpBranchConditional:
|
||||
if (auto error = ValidateBranchConditional(_, inst)) return error;
|
||||
break;
|
||||
case SpvOpReturnValue:
|
||||
if (auto error = ValidateReturnValue(_, inst)) return error;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // namespace val
|
||||
} // namespace spvtools
|
||||
|
@ -784,182 +784,6 @@ bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool idUsage::isValid<SpvOpPhi>(const spv_instruction_t* inst,
|
||||
const spv_opcode_desc /*opcodeEntry*/) {
|
||||
auto thisInst = module_.FindDef(inst->words[2]);
|
||||
SpvOp typeOp = module_.GetIdOpcode(thisInst->type_id());
|
||||
if (!spvOpcodeGeneratesType(typeOp)) {
|
||||
DIAG(thisInst) << "OpPhi's type <id> "
|
||||
<< module_.getIdName(thisInst->type_id())
|
||||
<< " is not a type instruction.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto block = thisInst->block();
|
||||
size_t numInOps = inst->words.size() - 3;
|
||||
if (numInOps % 2 != 0) {
|
||||
DIAG(thisInst)
|
||||
<< "OpPhi does not have an equal number of incoming values and "
|
||||
"basic blocks.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create a uniqued vector of predecessor ids for comparison against
|
||||
// incoming values. OpBranchConditional %cond %label %label produces two
|
||||
// predecessors in the CFG.
|
||||
std::vector<uint32_t> predIds;
|
||||
std::transform(block->predecessors()->begin(), block->predecessors()->end(),
|
||||
std::back_inserter(predIds),
|
||||
[](const BasicBlock* b) { return b->id(); });
|
||||
std::sort(predIds.begin(), predIds.end());
|
||||
predIds.erase(std::unique(predIds.begin(), predIds.end()), predIds.end());
|
||||
|
||||
size_t numEdges = numInOps / 2;
|
||||
if (numEdges != predIds.size()) {
|
||||
DIAG(thisInst) << "OpPhi's number of incoming blocks (" << numEdges
|
||||
<< ") does not match block's predecessor count ("
|
||||
<< block->predecessors()->size() << ").";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 3; i < inst->words.size(); ++i) {
|
||||
auto incId = inst->words[i];
|
||||
if (i % 2 == 1) {
|
||||
// Incoming value type must match the phi result type.
|
||||
auto incTypeId = module_.GetTypeId(incId);
|
||||
if (thisInst->type_id() != incTypeId) {
|
||||
DIAG(thisInst) << "OpPhi's result type <id> "
|
||||
<< module_.getIdName(thisInst->type_id())
|
||||
<< " does not match incoming value <id> "
|
||||
<< module_.getIdName(incId) << " type <id> "
|
||||
<< module_.getIdName(incTypeId) << ".";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (module_.GetIdOpcode(incId) != SpvOpLabel) {
|
||||
DIAG(thisInst) << "OpPhi's incoming basic block <id> "
|
||||
<< module_.getIdName(incId) << " is not an OpLabel.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Incoming basic block must be an immediate predecessor of the phi's
|
||||
// block.
|
||||
if (!std::binary_search(predIds.begin(), predIds.end(), incId)) {
|
||||
DIAG(thisInst) << "OpPhi's incoming basic block <id> "
|
||||
<< module_.getIdName(incId)
|
||||
<< " is not a predecessor of <id> "
|
||||
<< module_.getIdName(block->id()) << ".";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool idUsage::isValid<SpvOpBranchConditional>(const spv_instruction_t* inst,
|
||||
const spv_opcode_desc) {
|
||||
const size_t numOperands = inst->words.size() - 1;
|
||||
const size_t condOperandIndex = 1;
|
||||
const size_t targetTrueIndex = 2;
|
||||
const size_t targetFalseIndex = 3;
|
||||
|
||||
// num_operands is either 3 or 5 --- if 5, the last two need to be literal
|
||||
// integers
|
||||
if (numOperands != 3 && numOperands != 5) {
|
||||
Instruction* fake_inst = nullptr;
|
||||
DIAG(fake_inst) << "OpBranchConditional requires either 3 or 5 parameters";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ret = true;
|
||||
|
||||
// grab the condition operand and check that it is a bool
|
||||
const auto condOp = module_.FindDef(inst->words[condOperandIndex]);
|
||||
if (!condOp || !module_.IsBoolScalarType(condOp->type_id())) {
|
||||
DIAG(condOp)
|
||||
<< "Condition operand for OpBranchConditional must be of boolean type";
|
||||
ret = false;
|
||||
}
|
||||
|
||||
// target operands must be OpLabel
|
||||
// note that we don't need to check that the target labels are in the same
|
||||
// function,
|
||||
// PerformCfgChecks already checks for that
|
||||
const auto targetOpTrue = module_.FindDef(inst->words[targetTrueIndex]);
|
||||
if (!targetOpTrue || SpvOpLabel != targetOpTrue->opcode()) {
|
||||
DIAG(targetOpTrue)
|
||||
<< "The 'True Label' operand for OpBranchConditional must be the "
|
||||
"ID of an OpLabel instruction";
|
||||
ret = false;
|
||||
}
|
||||
|
||||
const auto targetOpFalse = module_.FindDef(inst->words[targetFalseIndex]);
|
||||
if (!targetOpFalse || SpvOpLabel != targetOpFalse->opcode()) {
|
||||
DIAG(targetOpFalse)
|
||||
<< "The 'False Label' operand for OpBranchConditional must be the "
|
||||
"ID of an OpLabel instruction";
|
||||
ret = false;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst,
|
||||
const spv_opcode_desc) {
|
||||
auto valueIndex = 1;
|
||||
auto value = module_.FindDef(inst->words[valueIndex]);
|
||||
if (!value || !value->type_id()) {
|
||||
DIAG(value) << "OpReturnValue Value <id> '"
|
||||
<< module_.getIdName(inst->words[valueIndex])
|
||||
<< "' does not represent a value.";
|
||||
return false;
|
||||
}
|
||||
auto valueType = module_.FindDef(value->type_id());
|
||||
if (!valueType || SpvOpTypeVoid == valueType->opcode()) {
|
||||
DIAG(value) << "OpReturnValue value's type <id> '"
|
||||
<< module_.getIdName(value->type_id())
|
||||
<< "' is missing or void.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool uses_variable_pointer =
|
||||
module_.features().variable_pointers ||
|
||||
module_.features().variable_pointers_storage_buffer;
|
||||
|
||||
if (addressingModel == SpvAddressingModelLogical &&
|
||||
SpvOpTypePointer == valueType->opcode() && !uses_variable_pointer &&
|
||||
!module_.options()->relax_logical_pointer) {
|
||||
DIAG(value)
|
||||
<< "OpReturnValue value's type <id> '"
|
||||
<< module_.getIdName(value->type_id())
|
||||
<< "' is a pointer, which is invalid in the Logical addressing model.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// NOTE: Find OpFunction
|
||||
const spv_instruction_t* function = inst - 1;
|
||||
while (firstInst != function) {
|
||||
if (SpvOpFunction == function->opcode) break;
|
||||
function--;
|
||||
}
|
||||
if (SpvOpFunction != function->opcode) {
|
||||
DIAG(value) << "OpReturnValue is not in a basic block.";
|
||||
return false;
|
||||
}
|
||||
auto returnType = module_.FindDef(function->words[1]);
|
||||
if (!returnType || returnType->id() != valueType->id()) {
|
||||
DIAG(value) << "OpReturnValue Value <id> '"
|
||||
<< module_.getIdName(inst->words[valueIndex])
|
||||
<< "'s type does not match OpFunction's return type.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#undef DIAG
|
||||
|
||||
bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
@ -988,11 +812,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
// Bitwise opcodes are validated in validate_bitwise.cpp.
|
||||
// Logical opcodes are validated in validate_logicals.cpp.
|
||||
// Derivative opcodes are validated in validate_derivatives.cpp.
|
||||
CASE(OpPhi)
|
||||
// OpBranch is validated in validate_cfg.cpp.
|
||||
// See tests in test/val/val_cfg_test.cpp.
|
||||
CASE(OpBranchConditional)
|
||||
CASE(OpReturnValue)
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user