Add validaton for SPV_KHR_8bit_storage + convert to/from floats. (#1990)

The SPV_KHR_8bit_storage extension does not permit 8-bit integers to be
cast directly to floating point types. We are seeing shaders in the
wild, being produced by toolchains like glslang, that are generating
invalid SPIR-V.

This change adds validation to check for the patterns not permitted, and
some tests that expose the failure.
This commit is contained in:
Neil Henning 2018-10-19 18:45:26 +01:00 committed by Steven Perron
parent 715afb0cea
commit d29a1f98f3
4 changed files with 200 additions and 5 deletions

View File

@ -49,6 +49,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Invalid cast to 8-bit integer from a floating-point: "
<< spvOpcodeString(opcode);
break;
}
@ -70,6 +75,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Invalid cast to 8-bit integer from a floating-point: "
<< spvOpcodeString(opcode);
break;
}
@ -93,6 +103,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
<< "Expected input to have the same dimension as Result Type: "
<< spvOpcodeString(opcode);
if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Invalid cast to floating-point from an 8-bit integer: "
<< spvOpcodeString(opcode);
break;
}

View File

@ -365,6 +365,9 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) {
features_.group_ops_reduce_and_scans = true;
break;
case SpvCapabilityInt8:
features_.use_int8_type = true;
features_.declare_int8_type = true;
break;
case SpvCapabilityStorageBuffer8BitAccess:
case SpvCapabilityUniformAndStorageBuffer8BitAccess:
case SpvCapabilityStoragePushConstant8:

View File

@ -86,6 +86,10 @@ class ValidationState_t {
// Target environment uses relaxed block layout.
// This is true for Vulkan 1.1 or later.
bool env_relaxed_block_layout = false;
// Allow an OpTypeInt with 8 bit width to be used in more than just int
// conversion opcodes
bool use_int8_type = false;
};
ValidationState_t(const spv_const_context context,

View File

@ -31,18 +31,23 @@ using ValidateConversion = spvtest::ValidateBase<bool>;
std::string GenerateShaderCode(
const std::string& body,
const std::string& capabilities_and_extensions = "") {
const std::string& capabilities_and_extensions = "",
const std::string& decorations = "", const std::string& types = "",
const std::string& variables = "") {
const std::string capabilities =
R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64)";
const std::string after_extension_before_body =
const std::string after_extension_before_decorations =
R"(
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpExecutionMode %main OriginUpperLeft)";
const std::string after_decorations_before_types =
R"(
%void = OpTypeVoid
%func = OpTypeFunction %void
%bool = OpTypeBool
@ -140,8 +145,10 @@ OpExecutionMode %main OriginUpperLeft
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%f32ptr_func = OpTypePointer Function %f32
%f32ptr_func = OpTypePointer Function %f32)";
const std::string after_variables_before_body =
R"(
%main = OpFunction %void None %func
%main_entry = OpLabel)";
@ -151,7 +158,9 @@ OpReturn
OpFunctionEnd)";
return capabilities + capabilities_and_extensions +
after_extension_before_body + body + after_body;
after_extension_before_decorations + decorations +
after_decorations_before_types + types + variables +
after_variables_before_body + body + after_body;
}
std::string GenerateKernelCode(
@ -630,6 +639,170 @@ TEST_F(ValidateConversion, QuantizeToF16WrongInputType) {
"Expected input type to be equal to Result Type: QuantizeToF16"));
}
TEST_F(ValidateConversion, ConvertFToS8BitStorage) {
const std::string capabilities_and_extensions = R"(
OpCapability StorageBuffer8BitAccess
OpExtension "SPV_KHR_8bit_storage"
OpExtension "SPV_KHR_storage_buffer_storage_class"
)";
const std::string decorations = R"(
OpDecorate %ssbo Block
OpDecorate %ssbo Binding 0
OpDecorate %ssbo DescriptorSet 0
OpMemberDecorate %ssbo 0 Offset 0
)";
const std::string types = R"(
%i8 = OpTypeInt 8 1
%i8ptr = OpTypePointer StorageBuffer %i8
%ssbo = OpTypeStruct %i8
%ssboptr = OpTypePointer StorageBuffer %ssbo
)";
const std::string variables = R"(
%var = OpVariable %ssboptr StorageBuffer
)";
const std::string body = R"(
%val = OpConvertFToS %i8 %f32_2
%accesschain = OpAccessChain %i8ptr %var %u32_0
OpStore %accesschain %val
)";
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
decorations, types, variables)
.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Invalid cast to 8-bit integer from a floating-point: ConvertFToS"));
}
TEST_F(ValidateConversion, ConvertFToU8BitStorage) {
const std::string capabilities_and_extensions = R"(
OpCapability StorageBuffer8BitAccess
OpExtension "SPV_KHR_8bit_storage"
OpExtension "SPV_KHR_storage_buffer_storage_class"
)";
const std::string decorations = R"(
OpDecorate %ssbo Block
OpDecorate %ssbo Binding 0
OpDecorate %ssbo DescriptorSet 0
OpMemberDecorate %ssbo 0 Offset 0
)";
const std::string types = R"(
%u8 = OpTypeInt 8 0
%u8ptr = OpTypePointer StorageBuffer %u8
%ssbo = OpTypeStruct %u8
%ssboptr = OpTypePointer StorageBuffer %ssbo
)";
const std::string variables = R"(
%var = OpVariable %ssboptr StorageBuffer
)";
const std::string body = R"(
%val = OpConvertFToU %u8 %f32_2
%accesschain = OpAccessChain %u8ptr %var %u32_0
OpStore %accesschain %val
)";
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
decorations, types, variables)
.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Invalid cast to 8-bit integer from a floating-point: ConvertFToU"));
}
TEST_F(ValidateConversion, ConvertSToF8BitStorage) {
const std::string capabilities_and_extensions = R"(
OpCapability StorageBuffer8BitAccess
OpExtension "SPV_KHR_8bit_storage"
OpExtension "SPV_KHR_storage_buffer_storage_class"
)";
const std::string decorations = R"(
OpDecorate %ssbo Block
OpDecorate %ssbo Binding 0
OpDecorate %ssbo DescriptorSet 0
OpMemberDecorate %ssbo 0 Offset 0
)";
const std::string types = R"(
%i8 = OpTypeInt 8 1
%i8ptr = OpTypePointer StorageBuffer %i8
%ssbo = OpTypeStruct %i8
%ssboptr = OpTypePointer StorageBuffer %ssbo
)";
const std::string variables = R"(
%var = OpVariable %ssboptr StorageBuffer
)";
const std::string body = R"(
%accesschain = OpAccessChain %i8ptr %var %u32_0
%load = OpLoad %i8 %accesschain
%val = OpConvertSToF %f32 %load
)";
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
decorations, types, variables)
.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Invalid cast to floating-point from an 8-bit integer: ConvertSToF"));
}
TEST_F(ValidateConversion, ConvertUToF8BitStorage) {
const std::string capabilities_and_extensions = R"(
OpCapability StorageBuffer8BitAccess
OpExtension "SPV_KHR_8bit_storage"
OpExtension "SPV_KHR_storage_buffer_storage_class"
)";
const std::string decorations = R"(
OpDecorate %ssbo Block
OpDecorate %ssbo Binding 0
OpDecorate %ssbo DescriptorSet 0
OpMemberDecorate %ssbo 0 Offset 0
)";
const std::string types = R"(
%u8 = OpTypeInt 8 0
%u8ptr = OpTypePointer StorageBuffer %u8
%ssbo = OpTypeStruct %u8
%ssboptr = OpTypePointer StorageBuffer %ssbo
)";
const std::string variables = R"(
%var = OpVariable %ssboptr StorageBuffer
)";
const std::string body = R"(
%accesschain = OpAccessChain %u8ptr %var %u32_0
%load = OpLoad %u8 %accesschain
%val = OpConvertUToF %f32 %load
)";
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
decorations, types, variables)
.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"Invalid cast to floating-point from an 8-bit integer: ConvertUToF"));
}
TEST_F(ValidateConversion, ConvertPtrToUSuccess) {
const std::string body = R"(
%ptr = OpVariable %f32ptr_func Function