mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-12-02 07:40:08 +00:00
WebGPU: Array size at most max signed int + 1 (#3077)
This makes it easier to clamp indices for robust-buffer-access behaviour. See https://github.com/gpuweb/spirv-execution-env/issues/47
This commit is contained in:
parent
0a5d99d02c
commit
e82a428605
@ -19,29 +19,40 @@
|
|||||||
#include "source/val/instruction.h"
|
#include "source/val/instruction.h"
|
||||||
#include "source/val/validate.h"
|
#include "source/val/validate.h"
|
||||||
#include "source/val/validation_state.h"
|
#include "source/val/validation_state.h"
|
||||||
|
#include "spirv/unified1/spirv.h"
|
||||||
|
|
||||||
namespace spvtools {
|
namespace spvtools {
|
||||||
namespace val {
|
namespace val {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// True if the integer constant is > 0. |const_words| are words of the
|
// Returns, as an int64_t, the literal value from an OpConstant or the
|
||||||
// constant-defining instruction (either OpConstant or
|
// default value of an OpSpecConstant, assuming it is an integral type.
|
||||||
// OpSpecConstant). typeWords are the words of the constant's-type-defining
|
// For signed integers, relies the rule that literal value is sign extended
|
||||||
// OpTypeInt.
|
// to fill out to word granularity. Assumes that the constant value
|
||||||
bool AboveZero(const std::vector<uint32_t>& const_words,
|
// has
|
||||||
const std::vector<uint32_t>& type_words) {
|
int64_t ConstantLiteralAsInt64(uint32_t width,
|
||||||
const uint32_t width = type_words[2];
|
const std::vector<uint32_t>& const_words) {
|
||||||
const bool is_signed = type_words[3] > 0;
|
|
||||||
const uint32_t lo_word = const_words[3];
|
const uint32_t lo_word = const_words[3];
|
||||||
if (width > 32) {
|
if (width <= 32) return int32_t(lo_word);
|
||||||
// The spec currently doesn't allow integers wider than 64 bits.
|
assert(width <= 64);
|
||||||
const uint32_t hi_word = const_words[4]; // Must exist, per spec.
|
assert(const_words.size() > 4);
|
||||||
if (is_signed && (hi_word >> 31)) return false;
|
const uint32_t hi_word = const_words[4]; // Must exist, per spec.
|
||||||
return (lo_word | hi_word) > 0;
|
return static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
|
||||||
} else {
|
}
|
||||||
if (is_signed && (lo_word >> 31)) return false;
|
|
||||||
return lo_word > 0;
|
// Returns, as an uint64_t, the literal value from an OpConstant or the
|
||||||
}
|
// default value of an OpSpecConstant, assuming it is an integral type.
|
||||||
|
// For signed integers, relies the rule that literal value is sign extended
|
||||||
|
// to fill out to word granularity. Assumes that the constant value
|
||||||
|
// has
|
||||||
|
int64_t ConstantLiteralAsUint64(uint32_t width,
|
||||||
|
const std::vector<uint32_t>& const_words) {
|
||||||
|
const uint32_t lo_word = const_words[3];
|
||||||
|
if (width <= 32) return lo_word;
|
||||||
|
assert(width <= 64);
|
||||||
|
assert(const_words.size() > 4);
|
||||||
|
const uint32_t hi_word = const_words[4]; // Must exist, per spec.
|
||||||
|
return (uint64_t(lo_word) | uint64_t(hi_word) << 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validates that type declarations are unique, unless multiple declarations
|
// Validates that type declarations are unique, unless multiple declarations
|
||||||
@ -258,14 +269,33 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
|
|||||||
|
|
||||||
switch (length->opcode()) {
|
switch (length->opcode()) {
|
||||||
case SpvOpSpecConstant:
|
case SpvOpSpecConstant:
|
||||||
case SpvOpConstant:
|
case SpvOpConstant: {
|
||||||
if (AboveZero(length->words(), const_result_type->words())) break;
|
auto& type_words = const_result_type->words();
|
||||||
// Else fall through!
|
const bool is_signed = type_words[3] > 0;
|
||||||
case SpvOpConstantNull: {
|
const uint32_t width = type_words[2];
|
||||||
|
const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
|
||||||
|
if (ivalue == 0 || (ivalue < 0 && is_signed)) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "OpTypeArray Length <id> '" << _.getIdName(length_id)
|
||||||
|
<< "' default value must be at least 1: found " << ivalue;
|
||||||
|
}
|
||||||
|
if (spvIsWebGPUEnv(_.context()->target_env)) {
|
||||||
|
// WebGPU has maximum integer width of 32 bits, and max array size
|
||||||
|
// is one more than the max signed integer representation.
|
||||||
|
const uint64_t max_permitted = (uint64_t(1) << 31);
|
||||||
|
const uint64_t uvalue = ConstantLiteralAsUint64(width, length->words());
|
||||||
|
if (uvalue > max_permitted) {
|
||||||
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
|
<< "OpTypeArray Length <id> '" << _.getIdName(length_id)
|
||||||
|
<< "' size exceeds max value " << max_permitted
|
||||||
|
<< " permitted by WebGPU: got " << uvalue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
|
case SpvOpConstantNull:
|
||||||
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
||||||
<< "OpTypeArray Length <id> '" << _.getIdName(length_id)
|
<< "OpTypeArray Length <id> '" << _.getIdName(length_id)
|
||||||
<< "' default value must be at least 1.";
|
<< "' default value must be at least 1.";
|
||||||
}
|
|
||||||
case SpvOpSpecConstantOp:
|
case SpvOpSpecConstantOp:
|
||||||
// Assume it's OK, rather than try to evaluate the operation.
|
// Assume it's OK, rather than try to evaluate the operation.
|
||||||
break;
|
break;
|
||||||
|
@ -749,20 +749,40 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayElementTypeBad) {
|
|||||||
// Signed or unsigned.
|
// Signed or unsigned.
|
||||||
enum Signed { kSigned, kUnsigned };
|
enum Signed { kSigned, kUnsigned };
|
||||||
|
|
||||||
// Creates an assembly snippet declaring OpTypeArray with the given length.
|
// Creates an assembly module declaring OpTypeArray with the given length.
|
||||||
std::string MakeArrayLength(const std::string& len, Signed isSigned,
|
std::string MakeArrayLength(const std::string& len, Signed isSigned, int width,
|
||||||
int width) {
|
int max_int_width = 64,
|
||||||
|
bool use_vulkan_memory_model = false) {
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
ss << R"(
|
ss << R"(
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
OpCapability Linkage
|
|
||||||
OpCapability Int16
|
|
||||||
OpCapability Int64
|
|
||||||
)";
|
)";
|
||||||
ss << "OpMemoryModel Logical GLSL450\n";
|
if (use_vulkan_memory_model) {
|
||||||
|
ss << " OpCapability VulkanMemoryModel\n";
|
||||||
|
}
|
||||||
|
if (width == 16) {
|
||||||
|
ss << " OpCapability Int16\n";
|
||||||
|
}
|
||||||
|
if (max_int_width > 32) {
|
||||||
|
ss << "\n OpCapability Int64\n";
|
||||||
|
}
|
||||||
|
if (use_vulkan_memory_model) {
|
||||||
|
ss << " OpExtension \"SPV_KHR_vulkan_memory_model\"\n";
|
||||||
|
ss << "OpMemoryModel Logical Vulkan\n";
|
||||||
|
} else {
|
||||||
|
ss << "OpMemoryModel Logical GLSL450\n";
|
||||||
|
}
|
||||||
|
ss << "OpEntryPoint GLCompute %main \"main\"\n";
|
||||||
|
ss << "OpExecutionMode %main LocalSize 1 1 1\n";
|
||||||
ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0");
|
ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0");
|
||||||
ss << " %l = OpConstant %t " << len;
|
ss << " %l = OpConstant %t " << len;
|
||||||
ss << " %a = OpTypeArray %t %l";
|
ss << " %a = OpTypeArray %t %l";
|
||||||
|
ss << " %void = OpTypeVoid \n"
|
||||||
|
" %voidfn = OpTypeFunction %void \n"
|
||||||
|
" %main = OpFunction %void None %voidfn \n"
|
||||||
|
" %entry = OpLabel\n"
|
||||||
|
" OpReturn\n"
|
||||||
|
" OpFunctionEnd\n";
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -772,7 +792,8 @@ class OpTypeArrayLengthTest
|
|||||||
: public spvtest::TextToBinaryTestBase<::testing::TestWithParam<int>> {
|
: public spvtest::TextToBinaryTestBase<::testing::TestWithParam<int>> {
|
||||||
protected:
|
protected:
|
||||||
OpTypeArrayLengthTest()
|
OpTypeArrayLengthTest()
|
||||||
: position_(spv_position_t{0, 0, 0}),
|
: env_(SPV_ENV_UNIVERSAL_1_0),
|
||||||
|
position_(spv_position_t{0, 0, 0}),
|
||||||
diagnostic_(spvDiagnosticCreate(&position_, "")) {}
|
diagnostic_(spvDiagnosticCreate(&position_, "")) {}
|
||||||
|
|
||||||
~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); }
|
~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); }
|
||||||
@ -783,7 +804,7 @@ class OpTypeArrayLengthTest
|
|||||||
spvDiagnosticDestroy(diagnostic_);
|
spvDiagnosticDestroy(diagnostic_);
|
||||||
diagnostic_ = nullptr;
|
diagnostic_ = nullptr;
|
||||||
const auto status =
|
const auto status =
|
||||||
spvValidate(ScopedContext().context, &cbinary, &diagnostic_);
|
spvValidate(ScopedContext(env_).context, &cbinary, &diagnostic_);
|
||||||
if (status != SPV_SUCCESS) {
|
if (status != SPV_SUCCESS) {
|
||||||
spvDiagnosticPrint(diagnostic_);
|
spvDiagnosticPrint(diagnostic_);
|
||||||
EXPECT_THAT(std::string(diagnostic_->error),
|
EXPECT_THAT(std::string(diagnostic_->error),
|
||||||
@ -792,12 +813,15 @@ class OpTypeArrayLengthTest
|
|||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
spv_target_env env_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
spv_position_t position_; // For creating diagnostic_.
|
spv_position_t position_; // For creating diagnostic_.
|
||||||
spv_diagnostic diagnostic_;
|
spv_diagnostic diagnostic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_P(OpTypeArrayLengthTest, LengthPositive) {
|
TEST_P(OpTypeArrayLengthTest, LengthPositiveSmall) {
|
||||||
const int width = GetParam();
|
const int width = GetParam();
|
||||||
EXPECT_EQ(SPV_SUCCESS,
|
EXPECT_EQ(SPV_SUCCESS,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width))));
|
Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width))));
|
||||||
@ -814,20 +838,19 @@ TEST_P(OpTypeArrayLengthTest, LengthPositive) {
|
|||||||
const std::string fpad(width / 4 - 1, 'F');
|
const std::string fpad(width / 4 - 1, 'F');
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
SPV_SUCCESS,
|
SPV_SUCCESS,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width))));
|
Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width))))
|
||||||
EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully(
|
<< MakeArrayLength("0x7" + fpad, kSigned, width);
|
||||||
MakeArrayLength("0xF" + fpad, kUnsigned, width))));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(OpTypeArrayLengthTest, LengthZero) {
|
TEST_P(OpTypeArrayLengthTest, LengthZero) {
|
||||||
const int width = GetParam();
|
const int width = GetParam();
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -835,23 +858,88 @@ TEST_P(OpTypeArrayLengthTest, LengthNegative) {
|
|||||||
const int width = GetParam();
|
const int width = GetParam();
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0');
|
const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0');
|
||||||
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)),
|
Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)),
|
||||||
"OpTypeArray Length <id> '2\\[%.*\\]' default value must be at "
|
"OpTypeArray Length <id> '3\\[%.*\\]' default value must be at "
|
||||||
"least 1."));
|
"least 1."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the string form of an integer of the form 0x80....0 of the
|
||||||
|
// given bit width.
|
||||||
|
std::string big_num_ending_0(int bit_width) {
|
||||||
|
return "0x8" + std::string(bit_width / 4 - 1, '0');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the string form of an integer of the form 0x80..001 of the
|
||||||
|
// given bit width.
|
||||||
|
std::string big_num_ending_1(int bit_width) {
|
||||||
|
return "0x8" + std::string(bit_width / 4 - 2, '0') + "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InVulkan) {
|
||||||
|
env_ = SPV_ENV_VULKAN_1_0;
|
||||||
|
const int width = GetParam();
|
||||||
|
for (int max_int_width : {32, 64}) {
|
||||||
|
if (width > max_int_width) {
|
||||||
|
// Not valid to even make the OpConstant in this case.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto module = CompileSuccessfully(MakeArrayLength(
|
||||||
|
big_num_ending_0(width), kUnsigned, width, max_int_width));
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, Val(module));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InVulkan) {
|
||||||
|
env_ = SPV_ENV_VULKAN_1_0;
|
||||||
|
const int width = GetParam();
|
||||||
|
for (int max_int_width : {32, 64}) {
|
||||||
|
if (width > max_int_width) {
|
||||||
|
// Not valid to even make the OpConstant in this case.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto module = CompileSuccessfully(MakeArrayLength(
|
||||||
|
big_num_ending_1(width), kUnsigned, width, max_int_width));
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, Val(module));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InWebGPU) {
|
||||||
|
env_ = SPV_ENV_WEBGPU_0;
|
||||||
|
const int width = GetParam();
|
||||||
|
// WebGPU only has 32 bit integers.
|
||||||
|
if (width != 32) return;
|
||||||
|
const int max_int_width = 32;
|
||||||
|
const auto module = CompileSuccessfully(MakeArrayLength(
|
||||||
|
big_num_ending_0(width), kUnsigned, width, max_int_width, true));
|
||||||
|
EXPECT_EQ(SPV_SUCCESS, Val(module));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InWebGPU) {
|
||||||
|
env_ = SPV_ENV_WEBGPU_0;
|
||||||
|
const int width = GetParam();
|
||||||
|
// WebGPU only has 32 bit integers.
|
||||||
|
if (width != 32) return;
|
||||||
|
const int max_int_width = 32;
|
||||||
|
const auto module = CompileSuccessfully(MakeArrayLength(
|
||||||
|
big_num_ending_1(width), kUnsigned, width, max_int_width, true));
|
||||||
|
EXPECT_EQ(SPV_ERROR_INVALID_ID,
|
||||||
|
Val(module,
|
||||||
|
"OpTypeArray Length <id> '3\\[%.*\\]' size exceeds max value "
|
||||||
|
"2147483648 permitted by WebGPU: got 2147483649"));
|
||||||
|
}
|
||||||
|
|
||||||
// The only valid widths for integers are 8, 16, 32, and 64.
|
// The only valid widths for integers are 8, 16, 32, and 64.
|
||||||
// Since the Int8 capability requires the Kernel capability, and the Kernel
|
// Since the Int8 capability requires the Kernel capability, and the Kernel
|
||||||
// capability prohibits usage of signed integers, we can skip 8-bit integers
|
// capability prohibits usage of signed integers, we can skip 8-bit integers
|
||||||
|
Loading…
Reference in New Issue
Block a user