mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-23 12:10:06 +00:00
Move OpVectorShuffle check into validate_composites (#1741)
This CL moves the OpVectorShuffle ID check out of validate_id and into validate_composites with the rest of the composite checks.
This commit is contained in:
parent
ee22928bd9
commit
673483d6a7
@ -457,6 +457,69 @@ spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
spv_result_t ValidateVectorShuffle(ValidationState_t& _,
|
||||
const Instruction* inst) {
|
||||
auto resultType = _.FindDef(inst->type_id());
|
||||
if (!resultType || resultType->opcode() != SpvOpTypeVector) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "The Result Type of OpVectorShuffle must be"
|
||||
<< " OpTypeVector. Found Op"
|
||||
<< spvOpcodeString(static_cast<SpvOp>(resultType->opcode())) << ".";
|
||||
}
|
||||
|
||||
// The number of components in Result Type must be the same as the number of
|
||||
// Component operands.
|
||||
auto componentCount = inst->operands().size() - 4;
|
||||
auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
|
||||
if (componentCount != resultVectorDimension) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "OpVectorShuffle component literals count does not match "
|
||||
"Result Type <id> '"
|
||||
<< _.getIdName(resultType->id()) << "'s vector component count.";
|
||||
}
|
||||
|
||||
// Vector 1 and Vector 2 must both have vector types, with the same Component
|
||||
// Type as Result Type.
|
||||
auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
|
||||
auto vector1Type = _.FindDef(vector1Object->type_id());
|
||||
auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
|
||||
auto vector2Type = _.FindDef(vector2Object->type_id());
|
||||
if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "The type of Vector 1 must be OpTypeVector.";
|
||||
}
|
||||
if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "The type of Vector 2 must be OpTypeVector.";
|
||||
}
|
||||
|
||||
auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
|
||||
if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "The Component Type of Vector 1 must be the same as ResultType.";
|
||||
}
|
||||
if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "The Component Type of Vector 2 must be the same as ResultType.";
|
||||
}
|
||||
|
||||
// All Component literals must either be FFFFFFFF or in [0, N - 1].
|
||||
auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
|
||||
auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
|
||||
auto N = vector1ComponentCount + vector2ComponentCount;
|
||||
auto firstLiteralIndex = 4;
|
||||
for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
|
||||
auto literal = inst->GetOperandAs<uint32_t>(i);
|
||||
if (literal != 0xFFFFFFFF && literal >= N) {
|
||||
return _.diag(SPV_ERROR_INVALID_ID, inst->InstructionPosition())
|
||||
<< "Component index " << literal << " is out of bounds for "
|
||||
<< "combined (Vector1 + Vector2) size of " << N << ".";
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Validates correctness of composite instructions.
|
||||
@ -466,11 +529,8 @@ spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
|
||||
return ValidateVectorExtractDynamic(_, inst);
|
||||
case SpvOpVectorInsertDynamic:
|
||||
return ValidateVectorInsertDyanmic(_, inst);
|
||||
case SpvOpVectorShuffle: {
|
||||
// Handled in validate_id.cpp.
|
||||
// TODO(atgoo@github.com) Consider moving it here.
|
||||
break;
|
||||
}
|
||||
case SpvOpVectorShuffle:
|
||||
return ValidateVectorShuffle(_, inst);
|
||||
case SpvOpCompositeConstruct:
|
||||
return ValidateCompositeConstruct(_, inst);
|
||||
case SpvOpCompositeExtract:
|
||||
|
@ -1707,90 +1707,6 @@ bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool idUsage::isValid<SpvOpVectorShuffle>(const spv_instruction_t* inst,
|
||||
const spv_opcode_desc) {
|
||||
auto instr_name = [&inst]() {
|
||||
std::string name =
|
||||
"Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode)));
|
||||
return name;
|
||||
};
|
||||
|
||||
// Result Type must be an OpTypeVector.
|
||||
auto resultTypeIndex = 1;
|
||||
auto resultType = module_.FindDef(inst->words[resultTypeIndex]);
|
||||
if (!resultType || resultType->opcode() != SpvOpTypeVector) {
|
||||
DIAG(resultType) << "The Result Type of " << instr_name()
|
||||
<< " must be OpTypeVector. Found Op"
|
||||
<< spvOpcodeString(
|
||||
static_cast<SpvOp>(resultType->opcode()))
|
||||
<< ".";
|
||||
return false;
|
||||
}
|
||||
|
||||
// The number of components in Result Type must be the same as the number of
|
||||
// Component operands.
|
||||
auto componentCount = inst->words.size() - 5;
|
||||
auto vectorComponentCountIndex = 3;
|
||||
auto resultVectorDimension = resultType->words()[vectorComponentCountIndex];
|
||||
if (componentCount != resultVectorDimension) {
|
||||
DIAG(module_.FindDef(inst->words.back()))
|
||||
<< instr_name()
|
||||
<< " component literals count does not match "
|
||||
"Result Type <id> '"
|
||||
<< module_.getIdName(resultType->id()) << "'s vector component count.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Vector 1 and Vector 2 must both have vector types, with the same Component
|
||||
// Type as Result Type.
|
||||
auto vector1Index = 3;
|
||||
auto vector1Object = module_.FindDef(inst->words[vector1Index]);
|
||||
auto vector1Type = module_.FindDef(vector1Object->type_id());
|
||||
auto vector2Index = 4;
|
||||
auto vector2Object = module_.FindDef(inst->words[vector2Index]);
|
||||
auto vector2Type = module_.FindDef(vector2Object->type_id());
|
||||
if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) {
|
||||
DIAG(vector1Object) << "The type of Vector 1 must be OpTypeVector.";
|
||||
return false;
|
||||
}
|
||||
if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) {
|
||||
DIAG(vector2Object) << "The type of Vector 2 must be OpTypeVector.";
|
||||
return false;
|
||||
}
|
||||
auto vectorComponentTypeIndex = 2;
|
||||
auto resultComponentType = resultType->words()[vectorComponentTypeIndex];
|
||||
auto vector1ComponentType = vector1Type->words()[vectorComponentTypeIndex];
|
||||
if (vector1ComponentType != resultComponentType) {
|
||||
DIAG(vector1Object) << "The Component Type of Vector 1 must be the same "
|
||||
"as ResultType.";
|
||||
return false;
|
||||
}
|
||||
auto vector2ComponentType = vector2Type->words()[vectorComponentTypeIndex];
|
||||
if (vector2ComponentType != resultComponentType) {
|
||||
DIAG(vector2Object) << "The Component Type of Vector 2 must be the same "
|
||||
"as ResultType.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// All Component literals must either be FFFFFFFF or in [0, N - 1].
|
||||
auto vector1ComponentCount = vector1Type->words()[vectorComponentCountIndex];
|
||||
auto vector2ComponentCount = vector2Type->words()[vectorComponentCountIndex];
|
||||
auto N = vector1ComponentCount + vector2ComponentCount;
|
||||
auto firstLiteralIndex = 5;
|
||||
for (size_t i = firstLiteralIndex; i < inst->words.size(); ++i) {
|
||||
auto literal = inst->words[i];
|
||||
if (literal != 0xFFFFFFFF && literal >= N) {
|
||||
DIAG(module_.FindDef(inst->words[2]))
|
||||
<< "Component index " << literal << " is out of range for a result "
|
||||
<< "vector of size " << N << ".";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool idUsage::isValid<SpvOpPhi>(const spv_instruction_t* inst,
|
||||
const spv_opcode_desc /*opcodeEntry*/) {
|
||||
@ -1977,11 +1893,7 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
#define CASE(OpCode) \
|
||||
case Spv##OpCode: \
|
||||
return isValid<Spv##OpCode>(inst, opcodeEntry);
|
||||
#define TODO(OpCode) \
|
||||
case Spv##OpCode: \
|
||||
return true;
|
||||
switch (inst->opcode) {
|
||||
TODO(OpUndef)
|
||||
CASE(OpMemberName)
|
||||
CASE(OpLine)
|
||||
CASE(OpDecorate)
|
||||
@ -1989,7 +1901,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
CASE(OpDecorationGroup)
|
||||
CASE(OpGroupDecorate)
|
||||
CASE(OpGroupMemberDecorate)
|
||||
TODO(OpExtInst)
|
||||
CASE(OpEntryPoint)
|
||||
CASE(OpExecutionMode)
|
||||
CASE(OpTypeVector)
|
||||
@ -2010,7 +1921,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
CASE(OpSpecConstantFalse)
|
||||
CASE(OpSpecConstantComposite)
|
||||
CASE(OpSampledImage)
|
||||
TODO(OpSpecConstantOp)
|
||||
CASE(OpVariable)
|
||||
CASE(OpLoad)
|
||||
CASE(OpStore)
|
||||
@ -2020,87 +1930,19 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
|
||||
CASE(OpInBoundsAccessChain)
|
||||
CASE(OpPtrAccessChain)
|
||||
CASE(OpInBoundsPtrAccessChain)
|
||||
TODO(OpArrayLength)
|
||||
TODO(OpGenericPtrMemSemantics)
|
||||
CASE(OpFunction)
|
||||
CASE(OpFunctionParameter)
|
||||
CASE(OpFunctionCall)
|
||||
// Conversion opcodes are validated in validate_conversion.cpp.
|
||||
CASE(OpVectorShuffle)
|
||||
// Other composite opcodes are validated in validate_composites.cpp.
|
||||
// Arithmetic opcodes are validated in validate_arithmetics.cpp.
|
||||
// Bitwise opcodes are validated in validate_bitwise.cpp.
|
||||
// Logical opcodes are validated in validate_logicals.cpp.
|
||||
// Derivative opcodes are validated in validate_derivatives.cpp.
|
||||
CASE(OpPhi)
|
||||
TODO(OpLoopMerge)
|
||||
TODO(OpSelectionMerge)
|
||||
// OpBranch is validated in validate_cfg.cpp.
|
||||
// See tests in test/val/val_cfg_test.cpp.
|
||||
CASE(OpBranchConditional)
|
||||
TODO(OpSwitch)
|
||||
CASE(OpReturnValue)
|
||||
TODO(OpLifetimeStart)
|
||||
TODO(OpLifetimeStop)
|
||||
TODO(OpAtomicLoad)
|
||||
TODO(OpAtomicStore)
|
||||
TODO(OpAtomicExchange)
|
||||
TODO(OpAtomicCompareExchange)
|
||||
TODO(OpAtomicCompareExchangeWeak)
|
||||
TODO(OpAtomicIIncrement)
|
||||
TODO(OpAtomicIDecrement)
|
||||
TODO(OpAtomicIAdd)
|
||||
TODO(OpAtomicISub)
|
||||
TODO(OpAtomicUMin)
|
||||
TODO(OpAtomicUMax)
|
||||
TODO(OpAtomicAnd)
|
||||
TODO(OpAtomicOr)
|
||||
TODO(OpAtomicSMin)
|
||||
TODO(OpAtomicSMax)
|
||||
TODO(OpEmitStreamVertex)
|
||||
TODO(OpEndStreamPrimitive)
|
||||
TODO(OpGroupAsyncCopy)
|
||||
TODO(OpGroupWaitEvents)
|
||||
TODO(OpGroupAll)
|
||||
TODO(OpGroupAny)
|
||||
TODO(OpGroupBroadcast)
|
||||
TODO(OpGroupIAdd)
|
||||
TODO(OpGroupFAdd)
|
||||
TODO(OpGroupFMin)
|
||||
TODO(OpGroupUMin)
|
||||
TODO(OpGroupSMin)
|
||||
TODO(OpGroupFMax)
|
||||
TODO(OpGroupUMax)
|
||||
TODO(OpGroupSMax)
|
||||
TODO(OpEnqueueMarker)
|
||||
TODO(OpEnqueueKernel)
|
||||
TODO(OpGetKernelNDrangeSubGroupCount)
|
||||
TODO(OpGetKernelNDrangeMaxSubGroupSize)
|
||||
TODO(OpGetKernelWorkGroupSize)
|
||||
TODO(OpGetKernelPreferredWorkGroupSizeMultiple)
|
||||
TODO(OpRetainEvent)
|
||||
TODO(OpReleaseEvent)
|
||||
TODO(OpCreateUserEvent)
|
||||
TODO(OpIsValidEvent)
|
||||
TODO(OpSetUserEventStatus)
|
||||
TODO(OpCaptureEventProfilingInfo)
|
||||
TODO(OpGetDefaultQueue)
|
||||
TODO(OpBuildNDRange)
|
||||
TODO(OpReadPipe)
|
||||
TODO(OpWritePipe)
|
||||
TODO(OpReservedReadPipe)
|
||||
TODO(OpReservedWritePipe)
|
||||
TODO(OpReserveReadPipePackets)
|
||||
TODO(OpReserveWritePipePackets)
|
||||
TODO(OpCommitReadPipe)
|
||||
TODO(OpCommitWritePipe)
|
||||
TODO(OpIsValidReserveId)
|
||||
TODO(OpGetNumPipePackets)
|
||||
TODO(OpGetMaxPipePackets)
|
||||
TODO(OpGroupReserveReadPipePackets)
|
||||
TODO(OpGroupReserveWritePipePackets)
|
||||
TODO(OpGroupCommitReadPipe)
|
||||
TODO(OpGroupCommitWritePipe)
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
@ -3843,7 +3843,7 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) {
|
||||
%var2 = OpVariable %ptr_vec3 Function %2
|
||||
%6 = OpLoad %vec2 %var
|
||||
%7 = OpLoad %vec3 %var2
|
||||
%8 = OpVectorShuffle %vec4 %6 %7 0 5 2 6
|
||||
%8 = OpVectorShuffle %vec4 %6 %7 0 8 2 6
|
||||
OpReturnValue %8
|
||||
OpFunctionEnd)";
|
||||
CompileSuccessfully(spirv.c_str());
|
||||
@ -3851,7 +3851,8 @@ TEST_F(ValidateIdWithMessage, OpVectorShuffleLiterals) {
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Component index 5 is out of range for a result vector of size 5."));
|
||||
"Component index 8 is out of bounds for combined (Vector1 + Vector2) "
|
||||
"size of 5."));
|
||||
}
|
||||
|
||||
// TODO: OpCompositeConstruct
|
||||
@ -4730,7 +4731,8 @@ TEST_F(ValidateIdWithMessage, CorrectErrorForShuffle) {
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Component index 4 is out of range for a result vector of size 4."));
|
||||
"Component index 4 is out of bounds for combined (Vector1 + Vector2) "
|
||||
"size of 4."));
|
||||
EXPECT_EQ(23, getErrorPosition().index);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user