Add support for SPV_KHR_compute_shader_derivative (#5817)

* Add support for SPV_KHR_compute_shader_derivative

* Update tests for SPV_KHR_compute_shader_derivatives

---------

Co-authored-by: MagicPoncho <magicponcho@gmail.com>
This commit is contained in:
JN Mo 2024-09-25 09:59:33 -04:00 committed by GitHub
parent 362ce7c60d
commit 44936c4a9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 90 additions and 69 deletions

View File

@ -1010,7 +1010,7 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_NV_bindless_texture", "SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives", "SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch" "SPV_KHR_ray_tracing_position_fetch"

View File

@ -428,9 +428,9 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_uniform_group_instructions", "SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model", "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", "SPV_EXT_fragment_shader_interlock",
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix", "SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"}); "SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
} }
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds( bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

View File

@ -291,7 +291,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_NV_bindless_texture", "SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives", "SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"}); "SPV_KHR_ray_tracing_position_fetch"});

View File

@ -141,7 +141,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_NV_bindless_texture", "SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives", "SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"}); "SPV_KHR_ray_tracing_position_fetch"});

View File

@ -74,8 +74,8 @@ class TrimCapabilitiesPass : public Pass {
// contains unsupported instruction, the pass could yield bad results. // contains unsupported instruction, the pass could yield bad results.
static constexpr std::array kSupportedCapabilities{ static constexpr std::array kSupportedCapabilities{
// clang-format off // clang-format off
spv::Capability::ComputeDerivativeGroupLinearNV, spv::Capability::ComputeDerivativeGroupLinearKHR,
spv::Capability::ComputeDerivativeGroupQuadsNV, spv::Capability::ComputeDerivativeGroupQuadsKHR,
spv::Capability::Float16, spv::Capability::Float16,
spv::Capability::Float64, spv::Capability::Float64,
spv::Capability::FragmentShaderPixelInterlockEXT, spv::Capability::FragmentShaderPixelInterlockEXT,

View File

@ -60,12 +60,14 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model, ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) { std::string* message) {
if (model != spv::ExecutionModel::Fragment && if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) { model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) { if (message) {
*message = *message =
std::string( std::string(
"Derivative instructions require Fragment or GLCompute " "Derivative instructions require Fragment, GLCompute, "
"execution model: ") + "MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode); spvOpcodeString(opcode);
} }
return false; return false;
@ -79,19 +81,23 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id()); const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id());
if (models && if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() && (models->find(spv::ExecutionModel::GLCompute) !=
models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes || (!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() && modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) { modes->end()))) {
if (message) { if (message) {
*message = std::string( *message =
"Derivative instructions require " std::string(
"DerivativeGroupQuadsNV " "Derivative instructions require "
"or DerivativeGroupLinearNV execution mode for " "DerivativeGroupQuadsKHR "
"GLCompute execution model: ") + "or DerivativeGroupLinearKHR execution mode for "
spvOpcodeString(opcode); "GLCompute, MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
} }
return false; return false;
} }

View File

@ -2026,11 +2026,13 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
->RegisterExecutionModelLimitation( ->RegisterExecutionModelLimitation(
[&](spv::ExecutionModel model, std::string* message) { [&](spv::ExecutionModel model, std::string* message) {
if (model != spv::ExecutionModel::Fragment && if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) { model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) { if (message) {
*message = std::string( *message = std::string(
"OpImageQueryLod requires Fragment or GLCompute execution " "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or "
"model"); "TaskEXT execution model");
} }
return false; return false;
} }
@ -2042,16 +2044,20 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
std::string* message) { std::string* message) {
const auto* models = state.GetExecutionModels(entry_point->id()); const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id());
if (models->find(spv::ExecutionModel::GLCompute) != models->end() && if (models &&
modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == (models->find(spv::ExecutionModel::GLCompute) != models->end() ||
modes->end() && models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
modes->end()) { (!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) { if (message) {
*message = std::string( *message = std::string(
"OpImageQueryLod requires DerivativeGroupQuadsNV " "OpImageQueryLod requires DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearNV execution mode for GLCompute " "or DerivativeGroupLinearKHR execution mode for GLCompute, "
"execution model"); "MeshEXT or TaskEXT execution model");
} }
return false; return false;
} }
@ -2320,12 +2326,14 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model, ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) { std::string* message) {
if (model != spv::ExecutionModel::Fragment && if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) { model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) { if (message) {
*message = *message =
std::string( std::string(
"ImplicitLod instructions require Fragment or GLCompute " "ImplicitLod instructions require Fragment, GLCompute, "
"execution model: ") + "MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode); spvOpcodeString(opcode);
} }
return false; return false;
@ -2339,19 +2347,22 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id()); const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id()); const auto* modes = state.GetExecutionModes(entry_point->id());
if (models && if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() && (models->find(spv::ExecutionModel::GLCompute) != models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes || (!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) == (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() && modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) == modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) { modes->end()))) {
if (message) { if (message) {
*message = *message = std::string(
std::string( "ImplicitLod instructions require "
"ImplicitLod instructions require DerivativeGroupQuadsNV " "DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearNV execution mode for GLCompute " "or DerivativeGroupLinearKHR execution mode for "
"execution model: ") + "GLCompute, "
spvOpcodeString(opcode); "MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
} }
return false; return false;
} }

View File

@ -156,8 +156,8 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require Fragment or GLCompute " HasSubstr("Derivative instructions require Fragment, GLCompute, "
"execution model: DPdx")); "MeshEXT or TaskEXT execution model: DPdx"));
} }
TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) { TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) {
@ -181,8 +181,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require " HasSubstr("Derivative instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV " "DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute execution model")); "execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
} }
using ValidateHalfDerivatives = spvtest::ValidateBase<std::string>; using ValidateHalfDerivatives = spvtest::ValidateBase<std::string>;

View File

@ -4780,7 +4780,8 @@ TEST_F(ValidateImage, QueryLodWrongExecutionModel) {
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr( HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model")); "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
} }
TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) { TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) {
@ -4801,7 +4802,8 @@ OpFunctionEnd
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr( HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model")); "OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
} }
TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) { TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
@ -4813,12 +4815,12 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
)"; )";
const std::string extra = R"( const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_NV_compute_shader_derivatives" OpExtension "SPV_KHR_compute_shader_derivatives"
)"; )";
const std::string mode = R"( const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1 OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV OpExecutionMode %main DerivativeGroupLinearKHR
)"; )";
CompileSuccessfully( CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@ -4930,8 +4932,8 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivativesMissingMode) {
)"; )";
const std::string extra = R"( const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_NV_compute_shader_derivatives" OpExtension "SPV_KHR_compute_shader_derivatives"
)"; )";
const std::string mode = R"( const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1 OpExecutionMode %main LocalSize 8 8 1
@ -4940,9 +4942,9 @@ OpExecutionMode %main LocalSize 8 8 1
GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsNV or " HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearNV execution mode for GLCompute " "DerivativeGroupLinearKHR execution mode for "
"execution model")); "GLCompute, MeshEXT or TaskEXT execution model"));
} }
TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) { TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
@ -4956,8 +4958,8 @@ TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str()); CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require Fragment or " HasSubstr("ImplicitLod instructions require Fragment, "
"GLCompute execution model")); "GLCompute, MeshEXT or TaskEXT execution model"));
} }
TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) { TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
@ -4969,12 +4971,12 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
)"; )";
const std::string extra = R"( const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_NV_compute_shader_derivatives" OpExtension "SPV_KHR_compute_shader_derivatives"
)"; )";
const std::string mode = R"( const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1 OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV OpExecutionMode %main DerivativeGroupLinearKHR
)"; )";
CompileSuccessfully( CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str()); GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@ -4990,8 +4992,8 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivativesMissingMode) {
)"; )";
const std::string extra = R"( const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_NV_compute_shader_derivatives" OpExtension "SPV_KHR_compute_shader_derivatives"
)"; )";
const std::string mode = R"( const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1 OpExecutionMode %main LocalSize 8 8 1
@ -5001,9 +5003,9 @@ OpExecutionMode %main LocalSize 8 8 1
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT( EXPECT_THAT(
getDiagnosticString(), getDiagnosticString(),
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsNV or " HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearNV execution mode for GLCompute " "DerivativeGroupLinearKHR execution mode for GLCompute, "
"execution model")); "MeshEXT or TaskEXT execution model"));
} }
TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) { TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) {
@ -6505,8 +6507,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(), EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require " HasSubstr("ImplicitLod instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV " "DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute execution model")); "execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
} }
TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) { TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) {