mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-23 04:00:05 +00:00
Add validaton for SPV_KHR_8bit_storage + convert to/from floats. (#1990)
The SPV_KHR_8bit_storage extension does not permit 8-bit integers to be cast directly to floating point types. We are seeing shaders in the wild, being produced by toolchains like glslang, that are generating invalid SPIR-V. This change adds validation to check for the patterns not permitted, and some tests that expose the failure.
This commit is contained in:
parent
715afb0cea
commit
d29a1f98f3
@ -49,6 +49,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Invalid cast to 8-bit integer from a floating-point: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@ -70,6 +75,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Invalid cast to 8-bit integer from a floating-point: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
@ -93,6 +103,11 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
|
||||
<< "Expected input to have the same dimension as Result Type: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
|
||||
return _.diag(SPV_ERROR_INVALID_DATA, inst)
|
||||
<< "Invalid cast to floating-point from an 8-bit integer: "
|
||||
<< spvOpcodeString(opcode);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -365,6 +365,9 @@ void ValidationState_t::RegisterCapability(SpvCapability cap) {
|
||||
features_.group_ops_reduce_and_scans = true;
|
||||
break;
|
||||
case SpvCapabilityInt8:
|
||||
features_.use_int8_type = true;
|
||||
features_.declare_int8_type = true;
|
||||
break;
|
||||
case SpvCapabilityStorageBuffer8BitAccess:
|
||||
case SpvCapabilityUniformAndStorageBuffer8BitAccess:
|
||||
case SpvCapabilityStoragePushConstant8:
|
||||
|
@ -86,6 +86,10 @@ class ValidationState_t {
|
||||
// Target environment uses relaxed block layout.
|
||||
// This is true for Vulkan 1.1 or later.
|
||||
bool env_relaxed_block_layout = false;
|
||||
|
||||
// Allow an OpTypeInt with 8 bit width to be used in more than just int
|
||||
// conversion opcodes
|
||||
bool use_int8_type = false;
|
||||
};
|
||||
|
||||
ValidationState_t(const spv_const_context context,
|
||||
|
@ -31,18 +31,23 @@ using ValidateConversion = spvtest::ValidateBase<bool>;
|
||||
|
||||
std::string GenerateShaderCode(
|
||||
const std::string& body,
|
||||
const std::string& capabilities_and_extensions = "") {
|
||||
const std::string& capabilities_and_extensions = "",
|
||||
const std::string& decorations = "", const std::string& types = "",
|
||||
const std::string& variables = "") {
|
||||
const std::string capabilities =
|
||||
R"(
|
||||
OpCapability Shader
|
||||
OpCapability Int64
|
||||
OpCapability Float64)";
|
||||
|
||||
const std::string after_extension_before_body =
|
||||
const std::string after_extension_before_decorations =
|
||||
R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint Fragment %main "main"
|
||||
OpExecutionMode %main OriginUpperLeft
|
||||
OpExecutionMode %main OriginUpperLeft)";
|
||||
|
||||
const std::string after_decorations_before_types =
|
||||
R"(
|
||||
%void = OpTypeVoid
|
||||
%func = OpTypeFunction %void
|
||||
%bool = OpTypeBool
|
||||
@ -140,8 +145,10 @@ OpExecutionMode %main OriginUpperLeft
|
||||
%true = OpConstantTrue %bool
|
||||
%false = OpConstantFalse %bool
|
||||
|
||||
%f32ptr_func = OpTypePointer Function %f32
|
||||
%f32ptr_func = OpTypePointer Function %f32)";
|
||||
|
||||
const std::string after_variables_before_body =
|
||||
R"(
|
||||
%main = OpFunction %void None %func
|
||||
%main_entry = OpLabel)";
|
||||
|
||||
@ -151,7 +158,9 @@ OpReturn
|
||||
OpFunctionEnd)";
|
||||
|
||||
return capabilities + capabilities_and_extensions +
|
||||
after_extension_before_body + body + after_body;
|
||||
after_extension_before_decorations + decorations +
|
||||
after_decorations_before_types + types + variables +
|
||||
after_variables_before_body + body + after_body;
|
||||
}
|
||||
|
||||
std::string GenerateKernelCode(
|
||||
@ -630,6 +639,170 @@ TEST_F(ValidateConversion, QuantizeToF16WrongInputType) {
|
||||
"Expected input type to be equal to Result Type: QuantizeToF16"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, ConvertFToS8BitStorage) {
|
||||
const std::string capabilities_and_extensions = R"(
|
||||
OpCapability StorageBuffer8BitAccess
|
||||
OpExtension "SPV_KHR_8bit_storage"
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
)";
|
||||
|
||||
const std::string decorations = R"(
|
||||
OpDecorate %ssbo Block
|
||||
OpDecorate %ssbo Binding 0
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpMemberDecorate %ssbo 0 Offset 0
|
||||
)";
|
||||
|
||||
const std::string types = R"(
|
||||
%i8 = OpTypeInt 8 1
|
||||
%i8ptr = OpTypePointer StorageBuffer %i8
|
||||
%ssbo = OpTypeStruct %i8
|
||||
%ssboptr = OpTypePointer StorageBuffer %ssbo
|
||||
)";
|
||||
|
||||
const std::string variables = R"(
|
||||
%var = OpVariable %ssboptr StorageBuffer
|
||||
)";
|
||||
|
||||
const std::string body = R"(
|
||||
%val = OpConvertFToS %i8 %f32_2
|
||||
%accesschain = OpAccessChain %i8ptr %var %u32_0
|
||||
OpStore %accesschain %val
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
|
||||
decorations, types, variables)
|
||||
.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Invalid cast to 8-bit integer from a floating-point: ConvertFToS"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, ConvertFToU8BitStorage) {
|
||||
const std::string capabilities_and_extensions = R"(
|
||||
OpCapability StorageBuffer8BitAccess
|
||||
OpExtension "SPV_KHR_8bit_storage"
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
)";
|
||||
|
||||
const std::string decorations = R"(
|
||||
OpDecorate %ssbo Block
|
||||
OpDecorate %ssbo Binding 0
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpMemberDecorate %ssbo 0 Offset 0
|
||||
)";
|
||||
|
||||
const std::string types = R"(
|
||||
%u8 = OpTypeInt 8 0
|
||||
%u8ptr = OpTypePointer StorageBuffer %u8
|
||||
%ssbo = OpTypeStruct %u8
|
||||
%ssboptr = OpTypePointer StorageBuffer %ssbo
|
||||
)";
|
||||
|
||||
const std::string variables = R"(
|
||||
%var = OpVariable %ssboptr StorageBuffer
|
||||
)";
|
||||
|
||||
const std::string body = R"(
|
||||
%val = OpConvertFToU %u8 %f32_2
|
||||
%accesschain = OpAccessChain %u8ptr %var %u32_0
|
||||
OpStore %accesschain %val
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
|
||||
decorations, types, variables)
|
||||
.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Invalid cast to 8-bit integer from a floating-point: ConvertFToU"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, ConvertSToF8BitStorage) {
|
||||
const std::string capabilities_and_extensions = R"(
|
||||
OpCapability StorageBuffer8BitAccess
|
||||
OpExtension "SPV_KHR_8bit_storage"
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
)";
|
||||
|
||||
const std::string decorations = R"(
|
||||
OpDecorate %ssbo Block
|
||||
OpDecorate %ssbo Binding 0
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpMemberDecorate %ssbo 0 Offset 0
|
||||
)";
|
||||
|
||||
const std::string types = R"(
|
||||
%i8 = OpTypeInt 8 1
|
||||
%i8ptr = OpTypePointer StorageBuffer %i8
|
||||
%ssbo = OpTypeStruct %i8
|
||||
%ssboptr = OpTypePointer StorageBuffer %ssbo
|
||||
)";
|
||||
|
||||
const std::string variables = R"(
|
||||
%var = OpVariable %ssboptr StorageBuffer
|
||||
)";
|
||||
|
||||
const std::string body = R"(
|
||||
%accesschain = OpAccessChain %i8ptr %var %u32_0
|
||||
%load = OpLoad %i8 %accesschain
|
||||
%val = OpConvertSToF %f32 %load
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
|
||||
decorations, types, variables)
|
||||
.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Invalid cast to floating-point from an 8-bit integer: ConvertSToF"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, ConvertUToF8BitStorage) {
|
||||
const std::string capabilities_and_extensions = R"(
|
||||
OpCapability StorageBuffer8BitAccess
|
||||
OpExtension "SPV_KHR_8bit_storage"
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
)";
|
||||
|
||||
const std::string decorations = R"(
|
||||
OpDecorate %ssbo Block
|
||||
OpDecorate %ssbo Binding 0
|
||||
OpDecorate %ssbo DescriptorSet 0
|
||||
OpMemberDecorate %ssbo 0 Offset 0
|
||||
)";
|
||||
|
||||
const std::string types = R"(
|
||||
%u8 = OpTypeInt 8 0
|
||||
%u8ptr = OpTypePointer StorageBuffer %u8
|
||||
%ssbo = OpTypeStruct %u8
|
||||
%ssboptr = OpTypePointer StorageBuffer %ssbo
|
||||
)";
|
||||
|
||||
const std::string variables = R"(
|
||||
%var = OpVariable %ssboptr StorageBuffer
|
||||
)";
|
||||
|
||||
const std::string body = R"(
|
||||
%accesschain = OpAccessChain %u8ptr %var %u32_0
|
||||
%load = OpLoad %u8 %accesschain
|
||||
%val = OpConvertUToF %f32 %load
|
||||
)";
|
||||
|
||||
CompileSuccessfully(GenerateShaderCode(body, capabilities_and_extensions,
|
||||
decorations, types, variables)
|
||||
.c_str());
|
||||
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
|
||||
EXPECT_THAT(
|
||||
getDiagnosticString(),
|
||||
HasSubstr(
|
||||
"Invalid cast to floating-point from an 8-bit integer: ConvertUToF"));
|
||||
}
|
||||
|
||||
TEST_F(ValidateConversion, ConvertPtrToUSuccess) {
|
||||
const std::string body = R"(
|
||||
%ptr = OpVariable %f32ptr_func Function
|
||||
|
Loading…
Reference in New Issue
Block a user