Fold Min, Max, and Clamp instructions. (#2836)

Fixes #2830.
This commit is contained in:
Steven Perron 2019-09-05 13:30:03 -04:00 committed by GitHub
parent a41520eaa4
commit b218ad1994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 508 additions and 3 deletions

View File

@ -353,6 +353,10 @@ ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) {
if (!inst->IsFloatingPointFoldingAllowed()) {
return nullptr;
}
if (inst->opcode() == SpvOpExtInst) {
return FoldFPBinaryOp(scalar_rule, inst->type_id(),
{constants[1], constants[2]}, context);
}
return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context);
};
}
@ -893,6 +897,179 @@ ConstantFoldingRule FoldFMix() {
};
}
// Clamps |x| to the range described by |min_val| and |max_val|.  The lower
// bound is applied first and the upper bound second, so if |min_val| happens
// to be greater than |max_val| the result is |max_val|.
template <class IntType>
IntType FoldIClamp(IntType x, IntType min_val, IntType max_val) {
  const IntType raised = (x < min_val) ? min_val : x;
  return (raised > max_val) ? max_val : raised;
}
// Scalar folding rule for the min operations.  Interprets |a| and |b|
// according to |result_type| and returns whichever of the two constants holds
// the smaller value; |b| is returned whenever |a| is not strictly less than
// |b|.  Handles 32- and 64-bit integers (signed or unsigned) and 32- and
// 64-bit floats.  Returns nullptr for any other type.  The constant manager
// argument is unused because the result is always one of the inputs.
const analysis::Constant* FoldMin(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* integer = result_type->AsInteger()) {
    switch (integer->width()) {
      case 32:
        if (integer->IsSigned()) {
          return a->GetS32() < b->GetS32() ? a : b;
        }
        return a->GetU32() < b->GetU32() ? a : b;
      case 64:
        if (integer->IsSigned()) {
          return a->GetS64() < b->GetS64() ? a : b;
        }
        return a->GetU64() < b->GetU64() ? a : b;
      default:
        // Unsupported integer width.
        return nullptr;
    }
  }

  if (const analysis::Float* fp = result_type->AsFloat()) {
    switch (fp->width()) {
      case 32:
        return a->GetFloat() < b->GetFloat() ? a : b;
      case 64:
        return a->GetDouble() < b->GetDouble() ? a : b;
      default:
        // Unsupported floating point width.
        return nullptr;
    }
  }

  // Neither an integer nor a floating point type.
  return nullptr;
}
// Scalar folding rule for the max operations.  Interprets |a| and |b|
// according to |result_type| and returns whichever of the two constants holds
// the larger value; |b| is returned whenever |a| is not strictly greater than
// |b|.  Handles 32- and 64-bit integers (signed or unsigned) and 32- and
// 64-bit floats.  Returns nullptr for any other type.  The constant manager
// argument is unused because the result is always one of the inputs.
const analysis::Constant* FoldMax(const analysis::Type* result_type,
                                  const analysis::Constant* a,
                                  const analysis::Constant* b,
                                  analysis::ConstantManager*) {
  if (const analysis::Integer* integer = result_type->AsInteger()) {
    switch (integer->width()) {
      case 32:
        if (integer->IsSigned()) {
          return a->GetS32() > b->GetS32() ? a : b;
        }
        return a->GetU32() > b->GetU32() ? a : b;
      case 64:
        if (integer->IsSigned()) {
          return a->GetS64() > b->GetS64() ? a : b;
        }
        return a->GetU64() > b->GetU64() ? a : b;
      default:
        // Unsupported integer width.
        return nullptr;
    }
  }

  if (const analysis::Float* fp = result_type->AsFloat()) {
    switch (fp->width()) {
      case 32:
        return a->GetFloat() > b->GetFloat() ? a : b;
      case 64:
        return a->GetDouble() > b->GetDouble() ? a : b;
      default:
        // Unsupported floating point width.
        return nullptr;
    }
  }

  // Neither an integer nor a floating point type.
  return nullptr;
}
// Folds a clamp instruction when all three operands are constant.
//
// The clamp operands x, min_val and max_val are read from |constants| at
// indices 1, 2 and 3 respectively (in-operand 0 of |inst| is asserted to be
// the GLSLstd450 extended instruction set id).  The result is computed as
// min(max(x, min_val), max_val).  Returns nullptr if any operand is not a
// constant or if the intermediate max cannot be folded.
const analysis::Constant* FoldClamp1(
    IRContext* context, Instruction* inst,
    const std::vector<const analysis::Constant*>& constants) {
  assert(inst->opcode() == SpvOpExtInst &&
         "Expecting an extended instruction.");
  assert(inst->GetSingleWordInOperand(0) ==
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
         "Expecting a GLSLstd450 extended instruction.");

  // Make sure all Clamp operands are constants.  The operands live at indices
  // 1 through 3 inclusive, so |max_val| (index 3) must be checked as well
  // rather than relying on the callee to reject a null operand.
  for (uint32_t i = 1; i < 4; i++) {
    if (constants[i] == nullptr) {
      return nullptr;
    }
  }

  // First apply the lower bound: max(x, min_val).
  const analysis::Constant* temp = FoldFPBinaryOp(
      FoldMax, inst->type_id(), {constants[1], constants[2]}, context);
  if (temp == nullptr) {
    return nullptr;
  }

  // Then apply the upper bound: min(temp, max_val).
  return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]},
                        context);
}
// Folds a clamp instruction when |x <= min_val|, i.e. when max(x, min_val)
// evaluates to |min_val|.  Only x (index 1 of |constants|) and min_val
// (index 2) need to be constants.  Returns |min_val| when the fold applies and
// nullptr otherwise.
const analysis::Constant* FoldClamp2(
    IRContext* context, Instruction* inst,
    const std::vector<const analysis::Constant*>& constants) {
  assert(inst->opcode() == SpvOpExtInst &&
         "Expecting an extended instruction.");
  assert(inst->GetSingleWordInOperand(0) ==
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
         "Expecting a GLSLstd450 extended instruction.");

  const analysis::Constant* const value = constants[1];
  const analysis::Constant* const lower_bound = constants[2];
  if (!value || !lower_bound) {
    return nullptr;
  }

  const analysis::Constant* const max_result =
      FoldFPBinaryOp(FoldMax, inst->type_id(), {value, lower_bound}, context);

  // We can assume that |min_val| is less than |max_val|.  Therefore, if the
  // max operation produced |min_val|, the subsequent min with |max_val| must
  // produce |min_val| too, even if |max_val| is not a constant.
  return max_result == lower_bound ? lower_bound : nullptr;
}
// Folds a clamp instruction when |x >= max_val|, i.e. when min(x, max_val)
// evaluates to |max_val|.  Only x (index 1 of |constants|) and max_val
// (index 3) need to be constants.  Returns |max_val| when the fold applies and
// nullptr otherwise.
const analysis::Constant* FoldClamp3(
    IRContext* context, Instruction* inst,
    const std::vector<const analysis::Constant*>& constants) {
  assert(inst->opcode() == SpvOpExtInst &&
         "Expecting an extended instruction.");
  assert(inst->GetSingleWordInOperand(0) ==
             context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
         "Expecting a GLSLstd450 extended instruction.");

  const analysis::Constant* const value = constants[1];
  const analysis::Constant* const upper_bound = constants[3];
  if (!value || !upper_bound) {
    return nullptr;
  }

  const analysis::Constant* const min_result =
      FoldFPBinaryOp(FoldMin, inst->type_id(), {value, upper_bound}, context);

  // We can assume that |min_val| is less than |max_val|.  Therefore, if the
  // min of |x| and |max_val| is |max_val|, then |x| is at least |max_val| and
  // so also at least |min_val|; the clamp yields |max_val| even if |min_val|
  // is not a constant.
  return min_result == upper_bound ? upper_bound : nullptr;
}
} // namespace
void ConstantFoldingRules::AddFoldingRules() {
@ -968,6 +1145,36 @@ void ConstantFoldingRules::AddFoldingRules() {
feature_manager->GetExtInstImportId_GLSLstd450();
// Register the folding rules for GLSLstd450 extended instructions.  This is
// skipped entirely when |ext_inst_glslstd450_id| is zero.  Rules are keyed in
// |ext_rules_| by the pair {instruction set id, extended opcode}.
if (ext_inst_glslstd450_id != 0) {
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix());
// Min: the signed, unsigned and floating point opcodes all use the same
// scalar rule, FoldMin, wrapped by FoldFPBinaryOp.
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back(
FoldFPBinaryOp(FoldMin));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back(
FoldFPBinaryOp(FoldMin));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back(
FoldFPBinaryOp(FoldMin));
// Max: likewise, all three opcodes use FoldMax.
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back(
FoldFPBinaryOp(FoldMax));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back(
FoldFPBinaryOp(FoldMax));
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back(
FoldFPBinaryOp(FoldMax));
// Clamp: every clamp opcode gets the same three rules, appended in the order
// FoldClamp1, FoldClamp2, FoldClamp3.  FoldClamp1 folds when all operands are
// constants; FoldClamp2 and FoldClamp3 only need |x| plus one bound.
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
FoldClamp1);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
FoldClamp2);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back(
FoldClamp3);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
FoldClamp1);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
FoldClamp2);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back(
FoldClamp3);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
FoldClamp1);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
FoldClamp2);
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back(
FoldClamp3);
}
}
} // namespace opt

View File

@ -232,7 +232,9 @@ OpName %main "main"
%double_2 = OpConstant %double 2
%double_3 = OpConstant %double 3
%double_4 = OpConstant %double 4
%double_5 = OpConstant %double 5
%double_0p5 = OpConstant %double 0.5
%double_0p2 = OpConstant %double 0.2
%v2double_0_0 = OpConstantComposite %v2double %double_0 %double_0
%v2double_2_2 = OpConstantComposite %v2double %double_2 %double_2
%v2double_2_3 = OpConstantComposite %v2double %double_2 %double_3
@ -558,7 +560,155 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
"%2 = OpSNegate %int %int_min\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, std::numeric_limits<int32_t>::min())
2, std::numeric_limits<int32_t>::min()),
// Test case 30: fold UMin 3 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UMin %uint_3 %uint_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 31: fold UMin 4 2
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UMin %uint_4 %uint_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
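// Test case 32: fold SMin 3 4
// NOTE(review): the instruction below uses UMin on %int operands, so SMin is
// not actually exercised by this case; SMin appears to be intended.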
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 UMin %int_3 %int_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 33: fold SMin 4 2
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 SMin %int_4 %int_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
// Test case 34: fold UMax 3 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UMax %uint_3 %uint_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4),
// Test case 35: fold UMax 3 2
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UMax %uint_3 %uint_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
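// Test case 36: fold SMax 3 4
// NOTE(review): the instruction below uses UMax on %int operands, so SMax is
// not actually exercised by this case; SMax appears to be intended.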
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 UMax %int_3 %int_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4),
// Test case 37: fold SMax 3 2
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 SMax %int_3 %int_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 38: fold UClamp 2 3 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UClamp %uint_2 %uint_3 %uint_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 39: fold UClamp 2 0 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UClamp %uint_2 %uint_0 %uint_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
// Test case 40: fold UClamp 2 0 1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %uint %1 UClamp %uint_2 %uint_0 %uint_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1),
// Test case 41: fold SClamp 2 3 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 SClamp %int_2 %int_3 %int_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 42: fold SClamp 2 0 4
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 SClamp %int_2 %int_0 %int_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
// Test case 43: fold SClamp 2 0 1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %int %1 SClamp %int_2 %int_0 %int_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1),
// Test case 44: SClamp 1 2 x
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %int\n" +
"%2 = OpExtInst %int %1 SClamp %int_1 %int_2 %undef\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
// Test case 45: SClamp 2 x 1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %int\n" +
"%2 = OpExtInst %int %1 SClamp %int_2 %undef %int_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1),
// Test case 46: UClamp 1 2 x
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %uint\n" +
"%2 = OpExtInst %uint %1 UClamp %uint_1 %uint_2 %undef\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2),
// Test case 47: UClamp 2 x 1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %uint\n" +
"%2 = OpExtInst %uint %1 UClamp %uint_2 %undef %uint_1\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1)
));
// clang-format on
@ -1526,7 +1676,81 @@ INSTANTIATE_TEST_SUITE_P(FloatConstantFoldingTest, FloatInstructionFoldingTest,
"%2 = OpExtInst %float %1 FMix %float_1 %float_4 %float_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.6f)
2, 1.6f),
// Test case 21: FMin 1.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FMin %float_1 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
// Test case 22: FMin 4.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FMin %float_4 %float_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.2f),
// Test case 23: FMax 1.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FMax %float_1 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
// Test case 24: FMax 1.0 0.2
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FMax %float_1 %float_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
// Test case 25: FClamp 1.0 0.2 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FClamp %float_1 %float_0p2 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0f),
// Test case 26: FClamp 0.2 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FClamp %float_0p2 %float_2 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0f),
// Test case 27: FClamp 2049.0 2.0 4.0
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %float %1 FClamp %float_2049 %float_2 %float_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0f),
// Test case 28: FClamp 1.0 2.0 x
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %float\n" +
"%2 = OpExtInst %float %1 FClamp %float_1 %float_2 %undef\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 29: FClamp 1.0 x 0.5
InstructionFoldingCase<float>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %float\n" +
"%2 = OpExtInst %float %1 FClamp %float_1 %undef %float_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.5)
));
// clang-format on
@ -1669,7 +1893,81 @@ INSTANTIATE_TEST_SUITE_P(DoubleConstantFoldingTest, DoubleInstructionFoldingTest
"%2 = OpFNegate %double %double_2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, -2)
2, -2),
// Test case 12: FMin 1.0 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FMin %double_1 %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 13: FMin 4.0 0.2
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FMin %double_4 %double_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.2),
// Test case 14: FMax 1.0 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FMax %double_1 %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0),
// Test case 15: FMax 1.0 0.2
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FMax %double_1 %double_0p2\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 16: FClamp 1.0 0.2 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FClamp %double_1 %double_0p2 %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 1.0),
// Test case 17: FClamp 0.2 2.0 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FClamp %double_0p2 %double_2 %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 18: FClamp 5.0 2.0 4.0
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpExtInst %double %1 FClamp %double_5 %double_2 %double_4\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 4.0),
// Test case 19: FClamp 1.0 2.0 x
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %double\n" +
"%2 = OpExtInst %double %1 FClamp %double_1 %double_2 %undef\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 2.0),
// Test case 20: FClamp 1.0 x 0.5
InstructionFoldingCase<double>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%undef = OpUndef %double\n" +
"%2 = OpExtInst %double %1 FClamp %double_1 %undef %double_0p5\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0.5)
));
// clang-format on