Move cfg opcode validation to another file.

* Moved cfg opcode validation out of idUsage and into validate_cfg.cpp
 * minor style updates
This commit is contained in:
Alan Baker 2018-08-10 12:49:26 -04:00
parent 70de4a35aa
commit 6cd4441c87
4 changed files with 177 additions and 182 deletions

View File

@ -336,6 +336,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
if (auto error = ArithmeticsPass(*vstate, &instruction)) return error;
if (auto error = BitwisePass(*vstate, &instruction)) return error;
if (auto error = LogicalsPass(*vstate, &instruction)) return error;
if (auto error = ControlFlowPass(*vstate, &instruction)) return error;
if (auto error = DerivativesPass(*vstate, &instruction)) return error;
if (auto error = AtomicsPass(*vstate, &instruction)) return error;
if (auto error = PrimitivesPass(*vstate, &instruction)) return error;

View File

@ -116,9 +116,12 @@ void printDominatorList(BasicBlock& block);
/// spec.
spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst);
/// Performs Control Flow Graph validation of a module
/// Performs Control Flow Graph validation and construction.
spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst);
/// Validates Control Flow Graph instructions.
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst);
/// Performs Id and SSA validation of a module
spv_result_t IdPass(ValidationState_t& _, Instruction* inst);

View File

@ -18,6 +18,7 @@
#include <cassert>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <tuple>
@ -27,6 +28,7 @@
#include <vector>
#include "source/cfa.h"
#include "source/opcode.h"
#include "source/spirv_validator_options.h"
#include "source/val/basic_block.h"
#include "source/val/construct.h"
@ -35,6 +37,158 @@
namespace spvtools {
namespace val {
namespace {
spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) {
SpvOp type_op = _.GetIdOpcode(inst->type_id());
if (!spvOpcodeGeneratesType(type_op)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi's type <id> " << _.getIdName(inst->type_id())
<< " is not a type instruction.";
}
auto block = inst->block();
size_t num_in_ops = inst->words().size() - 3;
if (num_in_ops % 2 != 0) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi does not have an equal number of incoming values and "
"basic blocks.";
}
// Create a uniqued vector of predecessor ids for comparison against
// incoming values. OpBranchConditional %cond %label %label produces two
// predecessors in the CFG.
std::vector<uint32_t> pred_ids;
std::transform(block->predecessors()->begin(), block->predecessors()->end(),
std::back_inserter(pred_ids),
[](const BasicBlock* b) { return b->id(); });
std::sort(pred_ids.begin(), pred_ids.end());
pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end());
size_t num_edges = num_in_ops / 2;
if (num_edges != pred_ids.size()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi's number of incoming blocks (" << num_edges
<< ") does not match block's predecessor count ("
<< block->predecessors()->size() << ").";
}
for (size_t i = 3; i < inst->words().size(); ++i) {
auto inc_id = inst->word(i);
if (i % 2 == 1) {
// Incoming value type must match the phi result type.
auto inc_type_id = _.GetTypeId(inc_id);
if (inst->type_id() != inc_type_id) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi's result type <id> " << _.getIdName(inst->type_id())
<< " does not match incoming value <id> " << _.getIdName(inc_id)
<< " type <id> " << _.getIdName(inc_type_id) << ".";
}
} else {
if (_.GetIdOpcode(inc_id) != SpvOpLabel) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
<< " is not an OpLabel.";
}
// Incoming basic block must be an immediate predecessor of the phi's
// block.
if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
<< " is not a predecessor of <id> " << _.getIdName(block->id())
<< ".";
}
}
}
return SPV_SUCCESS;
}
spv_result_t ValidateBranchConditional(ValidationState_t& _,
const Instruction* inst) {
// num_operands is either 3 or 5 --- if 5, the last two need to be literal
// integers
const auto num_operands = inst->operands().size();
if (num_operands != 3 && num_operands != 5) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpBranchConditional requires either 3 or 5 parameters";
}
// grab the condition operand and check that it is a bool
const auto cond_id = inst->GetOperandAs<uint32_t>(0);
const auto cond_op = _.FindDef(cond_id);
if (!cond_op || !_.IsBoolScalarType(cond_op->type_id())) {
return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for "
"OpBranchConditional must be "
"of boolean type";
}
// target operands must be OpLabel
// note that we don't need to check that the target labels are in the same
// function,
// PerformCfgChecks already checks for that
const auto true_id = inst->GetOperandAs<uint32_t>(1);
const auto true_target = _.FindDef(true_id);
if (!true_target || SpvOpLabel != true_target->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The 'True Label' operand for OpBranchConditional must be the "
"ID of an OpLabel instruction";
}
const auto false_id = inst->GetOperandAs<uint32_t>(2);
const auto false_target = _.FindDef(false_id);
if (!false_target || SpvOpLabel != false_target->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The 'False Label' operand for OpBranchConditional must be the "
"ID of an OpLabel instruction";
}
return SPV_SUCCESS;
}
spv_result_t ValidateReturnValue(ValidationState_t& _,
const Instruction* inst) {
const auto value_id = inst->GetOperandAs<uint32_t>(0);
const auto value = _.FindDef(value_id);
if (!value || !value->type_id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpReturnValue Value <id> '" << _.getIdName(value_id)
<< "' does not represent a value.";
}
auto value_type = _.FindDef(value->type_id());
if (!value_type || SpvOpTypeVoid == value_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpReturnValue value's type <id> '"
<< _.getIdName(value->type_id()) << "' is missing or void.";
}
const bool uses_variable_pointer =
_.features().variable_pointers ||
_.features().variable_pointers_storage_buffer;
if (_.addressing_model() == SpvAddressingModelLogical &&
SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer &&
!_.options()->relax_logical_pointer) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpReturnValue value's type <id> '"
<< _.getIdName(value->type_id())
<< "' is a pointer, which is invalid in the Logical addressing "
"model.";
}
const auto function = inst->function();
const auto return_type = _.FindDef(function->GetResultTypeId());
if (!return_type || return_type->id() != value_type->id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpReturnValue Value <id> '" << _.getIdName(value_id)
<< "'s type does not match OpFunction's return type.";
}
return SPV_SUCCESS;
}
} // namespace
void printDominatorList(const BasicBlock& b) {
std::cout << b.id() << " is dominated by: ";
@ -596,5 +750,23 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
return SPV_SUCCESS;
}
spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case SpvOpPhi:
if (auto error = ValidatePhi(_, inst)) return error;
break;
case SpvOpBranchConditional:
if (auto error = ValidateBranchConditional(_, inst)) return error;
break;
case SpvOpReturnValue:
if (auto error = ValidateReturnValue(_, inst)) return error;
break;
default:
break;
}
return SPV_SUCCESS;
}
} // namespace val
} // namespace spvtools

View File

@ -784,182 +784,6 @@ bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst,
return true;
}
template <>
bool idUsage::isValid<SpvOpPhi>(const spv_instruction_t* inst,
const spv_opcode_desc /*opcodeEntry*/) {
auto thisInst = module_.FindDef(inst->words[2]);
SpvOp typeOp = module_.GetIdOpcode(thisInst->type_id());
if (!spvOpcodeGeneratesType(typeOp)) {
DIAG(thisInst) << "OpPhi's type <id> "
<< module_.getIdName(thisInst->type_id())
<< " is not a type instruction.";
return false;
}
auto block = thisInst->block();
size_t numInOps = inst->words.size() - 3;
if (numInOps % 2 != 0) {
DIAG(thisInst)
<< "OpPhi does not have an equal number of incoming values and "
"basic blocks.";
return false;
}
// Create a uniqued vector of predecessor ids for comparison against
// incoming values. OpBranchConditional %cond %label %label produces two
// predecessors in the CFG.
std::vector<uint32_t> predIds;
std::transform(block->predecessors()->begin(), block->predecessors()->end(),
std::back_inserter(predIds),
[](const BasicBlock* b) { return b->id(); });
std::sort(predIds.begin(), predIds.end());
predIds.erase(std::unique(predIds.begin(), predIds.end()), predIds.end());
size_t numEdges = numInOps / 2;
if (numEdges != predIds.size()) {
DIAG(thisInst) << "OpPhi's number of incoming blocks (" << numEdges
<< ") does not match block's predecessor count ("
<< block->predecessors()->size() << ").";
return false;
}
for (size_t i = 3; i < inst->words.size(); ++i) {
auto incId = inst->words[i];
if (i % 2 == 1) {
// Incoming value type must match the phi result type.
auto incTypeId = module_.GetTypeId(incId);
if (thisInst->type_id() != incTypeId) {
DIAG(thisInst) << "OpPhi's result type <id> "
<< module_.getIdName(thisInst->type_id())
<< " does not match incoming value <id> "
<< module_.getIdName(incId) << " type <id> "
<< module_.getIdName(incTypeId) << ".";
return false;
}
} else {
if (module_.GetIdOpcode(incId) != SpvOpLabel) {
DIAG(thisInst) << "OpPhi's incoming basic block <id> "
<< module_.getIdName(incId) << " is not an OpLabel.";
return false;
}
// Incoming basic block must be an immediate predecessor of the phi's
// block.
if (!std::binary_search(predIds.begin(), predIds.end(), incId)) {
DIAG(thisInst) << "OpPhi's incoming basic block <id> "
<< module_.getIdName(incId)
<< " is not a predecessor of <id> "
<< module_.getIdName(block->id()) << ".";
return false;
}
}
}
return true;
}
template <>
bool idUsage::isValid<SpvOpBranchConditional>(const spv_instruction_t* inst,
const spv_opcode_desc) {
const size_t numOperands = inst->words.size() - 1;
const size_t condOperandIndex = 1;
const size_t targetTrueIndex = 2;
const size_t targetFalseIndex = 3;
// num_operands is either 3 or 5 --- if 5, the last two need to be literal
// integers
if (numOperands != 3 && numOperands != 5) {
Instruction* fake_inst = nullptr;
DIAG(fake_inst) << "OpBranchConditional requires either 3 or 5 parameters";
return false;
}
bool ret = true;
// grab the condition operand and check that it is a bool
const auto condOp = module_.FindDef(inst->words[condOperandIndex]);
if (!condOp || !module_.IsBoolScalarType(condOp->type_id())) {
DIAG(condOp)
<< "Condition operand for OpBranchConditional must be of boolean type";
ret = false;
}
// target operands must be OpLabel
// note that we don't need to check that the target labels are in the same
// function,
// PerformCfgChecks already checks for that
const auto targetOpTrue = module_.FindDef(inst->words[targetTrueIndex]);
if (!targetOpTrue || SpvOpLabel != targetOpTrue->opcode()) {
DIAG(targetOpTrue)
<< "The 'True Label' operand for OpBranchConditional must be the "
"ID of an OpLabel instruction";
ret = false;
}
const auto targetOpFalse = module_.FindDef(inst->words[targetFalseIndex]);
if (!targetOpFalse || SpvOpLabel != targetOpFalse->opcode()) {
DIAG(targetOpFalse)
<< "The 'False Label' operand for OpBranchConditional must be the "
"ID of an OpLabel instruction";
ret = false;
}
return ret;
}
template <>
bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst,
const spv_opcode_desc) {
auto valueIndex = 1;
auto value = module_.FindDef(inst->words[valueIndex]);
if (!value || !value->type_id()) {
DIAG(value) << "OpReturnValue Value <id> '"
<< module_.getIdName(inst->words[valueIndex])
<< "' does not represent a value.";
return false;
}
auto valueType = module_.FindDef(value->type_id());
if (!valueType || SpvOpTypeVoid == valueType->opcode()) {
DIAG(value) << "OpReturnValue value's type <id> '"
<< module_.getIdName(value->type_id())
<< "' is missing or void.";
return false;
}
const bool uses_variable_pointer =
module_.features().variable_pointers ||
module_.features().variable_pointers_storage_buffer;
if (addressingModel == SpvAddressingModelLogical &&
SpvOpTypePointer == valueType->opcode() && !uses_variable_pointer &&
!module_.options()->relax_logical_pointer) {
DIAG(value)
<< "OpReturnValue value's type <id> '"
<< module_.getIdName(value->type_id())
<< "' is a pointer, which is invalid in the Logical addressing model.";
return false;
}
// NOTE: Find OpFunction
const spv_instruction_t* function = inst - 1;
while (firstInst != function) {
if (SpvOpFunction == function->opcode) break;
function--;
}
if (SpvOpFunction != function->opcode) {
DIAG(value) << "OpReturnValue is not in a basic block.";
return false;
}
auto returnType = module_.FindDef(function->words[1]);
if (!returnType || returnType->id() != valueType->id()) {
DIAG(value) << "OpReturnValue Value <id> '"
<< module_.getIdName(inst->words[valueIndex])
<< "'s type does not match OpFunction's return type.";
return false;
}
return true;
}
#undef DIAG
bool idUsage::isValid(const spv_instruction_t* inst) {
@ -988,11 +812,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
// Bitwise opcodes are validated in validate_bitwise.cpp.
// Logical opcodes are validated in validate_logicals.cpp.
// Derivative opcodes are validated in validate_derivatives.cpp.
CASE(OpPhi)
// OpBranch is validated in validate_cfg.cpp.
// See tests in test/val/val_cfg_test.cpp.
CASE(OpBranchConditional)
CASE(OpReturnValue)
default:
return true;
}