Add support for SPV_KHR_compute_shader_derivative (#5817)

* Add support for SPV_KHR_compute_shader_derivative

* Update tests for SPV_KHR_compute_shader_derivatives

---------

Co-authored-by: MagicPoncho <magicponcho@gmail.com>
This commit is contained in:
JN Mo 2024-09-25 09:59:33 -04:00 committed by GitHub
parent 362ce7c60d
commit 44936c4a9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 90 additions and 69 deletions

View File

@ -1010,7 +1010,7 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"

View File

@ -428,9 +428,9 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_uniform_group_instructions",
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
"SPV_EXT_fragment_shader_interlock",
"SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
}
bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(

View File

@ -291,7 +291,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});

View File

@ -141,7 +141,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_NV_bindless_texture",
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});

View File

@ -74,8 +74,8 @@ class TrimCapabilitiesPass : public Pass {
// contains unsupported instruction, the pass could yield bad results.
static constexpr std::array kSupportedCapabilities{
// clang-format off
spv::Capability::ComputeDerivativeGroupLinearNV,
spv::Capability::ComputeDerivativeGroupQuadsNV,
spv::Capability::ComputeDerivativeGroupLinearKHR,
spv::Capability::ComputeDerivativeGroupQuadsKHR,
spv::Capability::Float16,
spv::Capability::Float64,
spv::Capability::FragmentShaderPixelInterlockEXT,

View File

@ -60,12 +60,14 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message =
std::string(
"Derivative instructions require Fragment or GLCompute "
"execution model: ") +
"Derivative instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
@ -79,19 +81,23 @@ spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
(models->find(spv::ExecutionModel::GLCompute) !=
models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message = std::string(
"Derivative instructions require "
"DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for "
"GLCompute execution model: ") +
spvOpcodeString(opcode);
*message =
std::string(
"Derivative instructions require "
"DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for "
"GLCompute, MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
}

View File

@ -2026,11 +2026,13 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
->RegisterExecutionModelLimitation(
[&](spv::ExecutionModel model, std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message = std::string(
"OpImageQueryLod requires Fragment or GLCompute execution "
"model");
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or "
"TaskEXT execution model");
}
return false;
}
@ -2042,16 +2044,20 @@ spv_result_t ValidateImageQueryLod(ValidationState_t& _,
std::string* message) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models->find(spv::ExecutionModel::GLCompute) != models->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->end()) {
if (models &&
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message = std::string(
"OpImageQueryLod requires DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for GLCompute "
"execution model");
"OpImageQueryLod requires DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for GLCompute, "
"MeshEXT or TaskEXT execution model");
}
return false;
}
@ -2320,12 +2326,14 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
std::string* message) {
if (model != spv::ExecutionModel::Fragment &&
model != spv::ExecutionModel::GLCompute) {
model != spv::ExecutionModel::GLCompute &&
model != spv::ExecutionModel::MeshEXT &&
model != spv::ExecutionModel::TaskEXT) {
if (message) {
*message =
std::string(
"ImplicitLod instructions require Fragment or GLCompute "
"execution model: ") +
"ImplicitLod instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
@ -2339,19 +2347,22 @@ spv_result_t ImagePass(ValidationState_t& _, const Instruction* inst) {
const auto* models = state.GetExecutionModels(entry_point->id());
const auto* modes = state.GetExecutionModes(entry_point->id());
if (models &&
models->find(spv::ExecutionModel::GLCompute) != models->end() &&
(models->find(spv::ExecutionModel::GLCompute) != models->end() ||
models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
(!modes ||
(modes->find(spv::ExecutionMode::DerivativeGroupLinearNV) ==
(modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
modes->end() &&
modes->find(spv::ExecutionMode::DerivativeGroupQuadsNV) ==
modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
modes->end()))) {
if (message) {
*message =
std::string(
"ImplicitLod instructions require DerivativeGroupQuadsNV "
"or DerivativeGroupLinearNV execution mode for GLCompute "
"execution model: ") +
spvOpcodeString(opcode);
*message = std::string(
"ImplicitLod instructions require "
"DerivativeGroupQuadsKHR "
"or DerivativeGroupLinearKHR execution mode for "
"GLCompute, "
"MeshEXT or TaskEXT execution model: ") +
spvOpcodeString(opcode);
}
return false;
}

View File

@ -156,8 +156,8 @@ TEST_F(ValidateDerivatives, OpDPdxWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require Fragment or GLCompute "
"execution model: DPdx"));
HasSubstr("Derivative instructions require Fragment, GLCompute, "
"MeshEXT or TaskEXT execution model: DPdx"));
}
TEST_F(ValidateDerivatives, NoExecutionModeGLCompute) {
@ -181,8 +181,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Derivative instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
"execution mode for GLCompute execution model"));
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
}
using ValidateHalfDerivatives = spvtest::ValidateBase<std::string>;

View File

@ -4780,7 +4780,8 @@ TEST_F(ValidateImage, QueryLodWrongExecutionModel) {
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model"));
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
}
TEST_F(ValidateImage, QueryLodWrongExecutionModelWithFunc) {
@ -4801,7 +4802,8 @@ OpFunctionEnd
EXPECT_THAT(
getDiagnosticString(),
HasSubstr(
"OpImageQueryLod requires Fragment or GLCompute execution model"));
"OpImageQueryLod requires Fragment, GLCompute, MeshEXT or TaskEXT "
"execution model"));
}
TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
@ -4813,12 +4815,12 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivatives) {
)";
const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV
OpExecutionMode %main DerivativeGroupLinearKHR
)";
CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@ -4930,8 +4932,8 @@ TEST_F(ValidateImage, QueryLodComputeShaderDerivativesMissingMode) {
)";
const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
@ -4940,9 +4942,9 @@ OpExecutionMode %main LocalSize 8 8 1
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsNV or "
"DerivativeGroupLinearNV execution mode for GLCompute "
"execution model"));
HasSubstr("OpImageQueryLod requires DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearKHR execution mode for "
"GLCompute, MeshEXT or TaskEXT execution model"));
}
TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
@ -4956,8 +4958,8 @@ TEST_F(ValidateImage, ImplicitLodWrongExecutionModel) {
CompileSuccessfully(GenerateShaderCode(body, "", "Vertex").c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require Fragment or "
"GLCompute execution model"));
HasSubstr("ImplicitLod instructions require Fragment, "
"GLCompute, MeshEXT or TaskEXT execution model"));
}
TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
@ -4969,12 +4971,12 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivatives) {
)";
const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
OpExecutionMode %main DerivativeGroupLinearNV
OpExecutionMode %main DerivativeGroupLinearKHR
)";
CompileSuccessfully(
GenerateShaderCode(body, extra, "GLCompute", mode).c_str());
@ -4990,8 +4992,8 @@ TEST_F(ValidateImage, ImplicitLodComputeShaderDerivativesMissingMode) {
)";
const std::string extra = R"(
OpCapability ComputeDerivativeGroupLinearNV
OpExtension "SPV_NV_compute_shader_derivatives"
OpCapability ComputeDerivativeGroupLinearKHR
OpExtension "SPV_KHR_compute_shader_derivatives"
)";
const std::string mode = R"(
OpExecutionMode %main LocalSize 8 8 1
@ -5001,9 +5003,9 @@ OpExecutionMode %main LocalSize 8 8 1
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsNV or "
"DerivativeGroupLinearNV execution mode for GLCompute "
"execution model"));
HasSubstr("ImplicitLod instructions require DerivativeGroupQuadsKHR or "
"DerivativeGroupLinearKHR execution mode for GLCompute, "
"MeshEXT or TaskEXT execution model"));
}
TEST_F(ValidateImage, ReadSubpassDataWrongExecutionModel) {
@ -6505,8 +6507,9 @@ OpFunctionEnd
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("ImplicitLod instructions require "
"DerivativeGroupQuadsNV or DerivativeGroupLinearNV "
"execution mode for GLCompute execution model"));
"DerivativeGroupQuadsKHR or DerivativeGroupLinearKHR "
"execution mode for GLCompute, MeshEXT or TaskEXT "
"execution model"));
}
TEST_F(ValidateImage, TypeSampledImageNotBufferPost1p6) {