SPIRV-Tools/source/opt/folding_rules.cpp
Steven Perron 00ca4e5bdf
Don't crash when folding construct of empty struct (#3092)
* Don't crash when folding construct of empty struct

An OpCompositeConstruct of an empty struct will be folded to a constant
under normal circumstances.  However, if the id limit has been reached
and the constant cannot be generated, then other folding rules will be
tried.

These rules do not handle the case of an empty struct.  We now allow it
to be handled.

Fixes http://crbug/1030194

* Changes based on the review.
2019-12-10 14:58:30 -05:00

2538 lines
94 KiB
C++

// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/folding_rules.h"
#include <limits>
#include <memory>
#include <utility>
#include "ir_builder.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/opt/ir_context.h"
namespace spvtools {
namespace opt {
namespace {
// Operand positions referenced by the folding rules in this file.
// NOTE(review): the "InIdx" suffix suggests these count input operands only
// (result type and result id excluded) -- confirm against the users.

// Composite extract / insert.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;

// Extended instructions, including the FMix operands.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;

// Store.
constexpr uint32_t kStoreObjectInIdx = 1;
// Locates the optional "image operands" argument of an image instruction.
// Returns the operand index of that argument, or -1 when |inst| is not one of
// the opcodes handled here or has too few operands to carry one.
int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
  // The argument is present only when the instruction has more than
  // |operands_without_mask| operands; it then sits at |mask_index|.
  uint32_t operands_without_mask = 0;
  int32_t mask_index = -1;
  switch (inst->opcode()) {
    case SpvOpImageSampleImplicitLod:
    case SpvOpImageSampleExplicitLod:
    case SpvOpImageSampleProjImplicitLod:
    case SpvOpImageSampleProjExplicitLod:
    case SpvOpImageFetch:
    case SpvOpImageRead:
    case SpvOpImageSparseSampleImplicitLod:
    case SpvOpImageSparseSampleExplicitLod:
    case SpvOpImageSparseSampleProjImplicitLod:
    case SpvOpImageSparseSampleProjExplicitLod:
    case SpvOpImageSparseFetch:
    case SpvOpImageSparseRead:
      operands_without_mask = 4;
      mask_index = 2;
      break;
    case SpvOpImageSampleDrefImplicitLod:
    case SpvOpImageSampleDrefExplicitLod:
    case SpvOpImageSampleProjDrefImplicitLod:
    case SpvOpImageSampleProjDrefExplicitLod:
    case SpvOpImageGather:
    case SpvOpImageDrefGather:
    case SpvOpImageSparseSampleDrefImplicitLod:
    case SpvOpImageSparseSampleDrefExplicitLod:
    case SpvOpImageSparseSampleProjDrefImplicitLod:
    case SpvOpImageSparseSampleProjDrefExplicitLod:
    case SpvOpImageSparseGather:
    case SpvOpImageSparseDrefGather:
      operands_without_mask = 5;
      mask_index = 3;
      break;
    case SpvOpImageWrite:
      operands_without_mask = 3;
      mask_index = 3;
      break;
    default:
      return -1;
  }
  if (inst->NumOperands() > operands_without_mask) return mask_index;
  return -1;
}
// Returns the bit width of the scalar component of |type|. |type| is expected
// to be a float, an integer, or a vector of those.
uint32_t ElementWidth(const analysis::Type* type) {
  // Peel off vector wrappers until the component type is reached.
  const analysis::Type* scalar_type = type;
  while (const analysis::Vector* as_vector = scalar_type->AsVector()) {
    scalar_type = as_vector->element_type();
  }
  if (const analysis::Float* as_float = scalar_type->AsFloat()) {
    return as_float->width();
  }
  assert(scalar_type->AsInteger());
  return scalar_type->AsInteger()->width();
}
// Reports whether |type| is a Float, or a vector whose element type is a
// Float.
bool HasFloatingPoint(const analysis::Type* type) {
  if (type->AsFloat() != nullptr) return true;
  const analysis::Vector* as_vector = type->AsVector();
  return as_vector != nullptr &&
         as_vector->element_type()->AsFloat() != nullptr;
}
// Reports whether |val| is acceptable as a folded result: anything that is
// not NaN, not infinite and not subnormal (zero and normal values pass).
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  return category != FP_NAN && category != FP_INFINITE &&
         category != FP_SUBNORMAL;
}
// Returns the first entry of |constants| when it is non-null, otherwise the
// second entry (which may itself be null). |constants| must hold at least two
// entries.
const analysis::Constant* ConstInput(
    const std::vector<const analysis::Constant*>& constants) {
  if (constants[0] != nullptr) return constants[0];
  return constants[1];
}
// Returns the instruction defining the in-operand of |inst| that is paired
// with |c|: in-operand 1 when |c| is non-null, in-operand 0 otherwise. Callers
// in this file pass the constant found for in-operand 0 as |c|.
Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
                           Instruction* inst) {
  const uint32_t operand_index = (c != nullptr) ? 1u : 0u;
  return context->get_def_use_mgr()->GetDef(
      inst->GetSingleWordInOperand(operand_index));
}
// Builds the constant equal to |c| multiplied by -1 and returns the result id
// of its defining instruction. |c| must be a 32- or 64-bit floating point
// constant.
uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
                                     const analysis::Constant* c) {
  assert(c);
  assert(c->type()->AsFloat());
  const uint32_t bit_width = c->type()->AsFloat()->width();
  assert(bit_width == 32 || bit_width == 64);

  std::vector<uint32_t> negated_words;
  if (bit_width != 64) {
    utils::FloatProxy<float> negated(c->GetFloat() * -1.0f);
    negated_words = negated.GetWords();
  } else {
    utils::FloatProxy<double> negated(c->GetDouble() * -1.0);
    negated_words = negated.GetWords();
  }

  const analysis::Constant* negated_const =
      const_mgr->GetConstant(c->type(), std::move(negated_words));
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
}
// Splits |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return {low_word, high_word};
}
// Builds the constant equal to 0 - |c| (computed on the unsigned value, so it
// wraps) and returns the result id of its defining instruction. |c| must be a
// 32- or 64-bit integer constant.
uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
                               const analysis::Constant* c) {
  assert(c);
  assert(c->type()->AsInteger());
  const uint32_t bit_width = c->type()->AsInteger()->width();
  assert(bit_width == 32 || bit_width == 64);

  std::vector<uint32_t> negated_words;
  if (bit_width != 64) {
    const uint32_t negated = static_cast<uint32_t>(0 - c->GetU32());
    negated_words.push_back(negated);
  } else {
    const uint64_t negated = static_cast<uint64_t>(0 - c->GetU64());
    negated_words = ExtractInts(negated);
  }

  const analysis::Constant* negated_const =
      const_mgr->GetConstant(c->type(), std::move(negated_words));
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
}
// Builds the component-wise negation of the vector constant |c| and returns
// the result id of its defining instruction. A null constant is returned
// unchanged.
uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
                              const analysis::Constant* c) {
  assert(const_mgr && c);
  assert(c->type()->AsVector());
  if (c->AsNullConstant()) {
    // 0.0 vs -0.0 shouldn't matter.
    return const_mgr->GetDefiningInstruction(c)->result_id();
  }

  const analysis::VectorConstant* vector_const = c->AsVectorConstant();
  const analysis::Type* component_type = vector_const->component_type();
  // Ids of the negated component constants, in component order.
  std::vector<uint32_t> component_ids;
  for (const auto& component : vector_const->GetComponents()) {
    if (component_type->AsFloat()) {
      component_ids.push_back(
          NegateFloatingPointConstant(const_mgr, component));
    } else {
      assert(component_type->AsInteger());
      component_ids.push_back(NegateIntegerConstant(const_mgr, component));
    }
  }

  const analysis::Constant* negated_const =
      const_mgr->GetConstant(c->type(), std::move(component_ids));
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
}
// Negates |c|, dispatching on its type (vector, float, or integer), and
// returns the result id of the instruction defining the negated constant.
uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
                        const analysis::Constant* c) {
  const analysis::Type* const_type = c->type();
  if (const_type->AsVector()) return NegateVectorConstant(const_mgr, c);
  if (const_type->AsFloat()) return NegateFloatingPointConstant(const_mgr, c);
  assert(const_type->AsInteger());
  return NegateIntegerConstant(const_mgr, c);
}
// Builds the constant 1 / |c| and returns the result id of its defining
// instruction. |c| must be a 32- or 64-bit Float constant. Returns 0 when the
// reciprocal is NaN, infinite or subnormal.
uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
                    const analysis::Constant* c) {
  assert(const_mgr && c);
  assert(c->type()->AsFloat());
  const uint32_t bit_width = c->type()->AsFloat()->width();
  assert(bit_width == 32 || bit_width == 64);

  std::vector<uint32_t> reciprocal_words;
  if (bit_width != 64) {
    spvtools::utils::FloatProxy<float> inverse(1.0f / c->GetFloat());
    if (!IsValidResult(inverse.getAsFloat())) return 0;
    reciprocal_words = inverse.GetWords();
  } else {
    spvtools::utils::FloatProxy<double> inverse(1.0 / c->GetDouble());
    if (!IsValidResult(inverse.getAsFloat())) return 0;
    reciprocal_words = inverse.GetWords();
  }

  const analysis::Constant* reciprocal_const =
      const_mgr->GetConstant(c->type(), std::move(reciprocal_words));
  return const_mgr->GetDefiningInstruction(reciprocal_const)->result_id();
}
// Rewrites an FDiv whose second operand is a constant into an FMul by the
// reciprocal of that constant. The rule does not fire for a null-constant
// divisor, or when any reciprocal is NaN, infinite or subnormal.
FoldingRule ReciprocalFDiv() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFDiv);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* divisor = constants[1];
    if (divisor == nullptr) return false;

    uint32_t reciprocal_id = 0;
    if (const analysis::VectorConstant* vector_divisor =
            divisor->AsVectorConstant()) {
      // Invert each component, then assemble the vector of reciprocals.
      std::vector<uint32_t> component_ids;
      for (const auto& component : vector_divisor->GetComponents()) {
        reciprocal_id = Reciprocal(const_mgr, component);
        if (reciprocal_id == 0) return false;
        component_ids.push_back(reciprocal_id);
      }
      const analysis::Constant* reciprocal_const =
          const_mgr->GetConstant(divisor->type(), std::move(component_ids));
      reciprocal_id =
          const_mgr->GetDefiningInstruction(reciprocal_const)->result_id();
    } else if (divisor->AsFloatConstant()) {
      reciprocal_id = Reciprocal(const_mgr, divisor);
      if (reciprocal_id == 0) return false;
    } else {
      // Don't fold a null constant.
      return false;
    }

    inst->SetOpcode(SpvOpFMul);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
         {SPV_OPERAND_TYPE_ID, {reciprocal_id}}});
    return true;
  };
}
// Folds a negate of a negate (same opcode) into a copy of the innermost
// operand: -(-x) = x.
FoldingRule MergeNegateArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
    (void)constants;
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    Instruction* inner_inst =
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
    if (is_float && !inner_inst->IsFloatingPointFoldingAllowed())
      return false;

    if (inner_inst->opcode() != inst->opcode()) return false;

    // Elide negates.
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inner_inst->GetSingleWordInOperand(0u)}}});
    return true;
  };
}
// Merges negate into a mul or div operation if that operation contains a
// constant operand.
// Cases:
// -(x * 2) = x * -2
// -(2 * x) = x * -2
// -(x / 2) = x / -2
// -(2 / x) = -2 / x
//
// Unsigned division is deliberately not handled: negating one operand of an
// OpUDiv does not negate its result. With 32-bit values, -(4 udiv 2) is
// 0xFFFFFFFE, whereas 4 udiv (0 - 2) is 4 udiv 0xFFFFFFFE, which is 0.
FoldingRule MergeNegateMulDivArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
    (void)constants;
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
      return false;

    Instruction* op_inst =
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
    if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
      return false;

    // Constants are only negated at these widths.
    uint32_t width = ElementWidth(type);
    if (width != 32 && width != 64) return false;

    SpvOp opcode = op_inst->opcode();
    if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
        opcode == SpvOpSDiv) {
      std::vector<const analysis::Constant*> op_constants =
          const_mgr->GetOperandConstants(op_inst);
      // Merge negate into mul or div if one operand is constant.
      if (op_constants[0] || op_constants[1]) {
        bool zero_is_variable = op_constants[0] == nullptr;
        const analysis::Constant* c = ConstInput(op_constants);
        uint32_t neg_id = NegateConstant(const_mgr, c);
        uint32_t non_const_id = zero_is_variable
                                    ? op_inst->GetSingleWordInOperand(0u)
                                    : op_inst->GetSingleWordInOperand(1u);
        // Change this instruction to a mul/div.
        inst->SetOpcode(op_inst->opcode());
        if (opcode == SpvOpFDiv || opcode == SpvOpSDiv) {
          // Division is not commutative: keep the operands in their original
          // positions.
          uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
          uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
          inst->SetInOperands(
              {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
        } else {
          // Multiplication: always place the variable first.
          inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
                               {SPV_OPERAND_TYPE_ID, {neg_id}}});
        }
        return true;
      }
    }
    return false;
  };
}
// Folds a negate of an add or sub that has a constant operand into a single
// subtraction.
// Cases:
// -(x + 2) = -2 - x
// -(2 + x) = -2 - x
// -(x - 2) = 2 - x
// -(2 - x) = x - 2
FoldingRule MergeNegateAddSubArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
    (void)constants;
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    Instruction* inner_inst =
        context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
    if (is_float && !inner_inst->IsFloatingPointFoldingAllowed())
      return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    const SpvOp inner_opcode = inner_inst->opcode();
    const bool is_add =
        inner_opcode == SpvOpFAdd || inner_opcode == SpvOpIAdd;
    const bool is_sub =
        inner_opcode == SpvOpFSub || inner_opcode == SpvOpISub;
    if (!is_add && !is_sub) return false;

    std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_inst);
    if (!inner_constants[0] && !inner_constants[1]) return false;
    const bool zero_is_variable = inner_constants[0] == nullptr;

    // An add needs its constant negated; a sub reuses the existing constant
    // and only has its operands exchanged.
    uint32_t const_id = 0;
    if (is_add) {
      const_id = NegateConstant(const_mgr, ConstInput(inner_constants));
    } else {
      const_id = inner_inst->GetSingleWordInOperand(zero_is_variable ? 1u : 0u);
    }

    // Start from the operand order of the inner instruction, then swap if
    // necessary so that the result is a subtraction.
    uint32_t lhs_id = 0;
    uint32_t rhs_id = 0;
    if (zero_is_variable) {
      lhs_id = inner_inst->GetSingleWordInOperand(0u);
      rhs_id = const_id;
    } else {
      lhs_id = const_id;
      rhs_id = inner_inst->GetSingleWordInOperand(1u);
    }
    if (!is_add || zero_is_variable) std::swap(lhs_id, rhs_id);

    inst->SetOpcode(is_float ? SpvOpFSub : SpvOpISub);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Reports whether |c| is a null constant, a zero scalar, or a vector with at
// least one such component.
bool HasZero(const analysis::Constant* c) {
  if (c->AsNullConstant()) return true;

  const analysis::VectorConstant* vec_const = c->AsVectorConstant();
  if (vec_const == nullptr) {
    assert(c->AsScalarConstant());
    return c->AsScalarConstant()->IsZero();
  }

  for (const auto& component : vec_const->GetComponents()) {
    if (HasZero(component)) return true;
  }
  return false;
}
// Evaluates |input1| |opcode| |input2| for 32- or 64-bit Float constants and
// returns the result id of the instruction defining the folded constant.
// Returns 0 when the result is NaN, infinite or subnormal, or when an FDiv
// has a zero divisor.
uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
                                       SpvOp opcode,
                                       const analysis::Constant* input1,
                                       const analysis::Constant* input2) {
  const analysis::Type* result_type = input1->type();
  assert(result_type->AsFloat());
  const uint32_t bit_width = result_type->AsFloat()->width();
  assert(bit_width == 32 || bit_width == 64);
  std::vector<uint32_t> result_words;
// Applies |op| at the precision selected by |bit_width|, bailing out of the
// enclosing function with 0 when the result is not a valid value.
#define FOLD_OP(op)                                      \
  if (bit_width != 64) {                                 \
    utils::FloatProxy<float> folded =                    \
        input1->GetFloat() op input2->GetFloat();        \
    if (!IsValidResult(folded.getAsFloat())) return 0;   \
    result_words = folded.GetWords();                    \
  } else {                                               \
    utils::FloatProxy<double> folded =                   \
        input1->GetDouble() op input2->GetDouble();      \
    if (!IsValidResult(folded.getAsFloat())) return 0;   \
    result_words = folded.GetWords();                    \
  }
  switch (opcode) {
    case SpvOpFAdd:
      FOLD_OP(+);
      break;
    case SpvOpFSub:
      FOLD_OP(-);
      break;
    case SpvOpFMul:
      FOLD_OP(*);
      break;
    case SpvOpFDiv:
      if (HasZero(input2)) return 0;
      FOLD_OP(/);
      break;
    default:
      assert(false && "Unexpected operation");
      break;
  }
#undef FOLD_OP
  const analysis::Constant* merged_const =
      const_mgr->GetConstant(result_type, result_words);
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
}
// Performs |input1| |opcode| |input2| and returns the merged constant result
// id. The input types must be 32- or 64-bit Integers, and |opcode| must be
// OpIMul, OpIAdd or OpISub. Returns 0 for any other opcode; integer division
// is never merged.
//
// The arithmetic is carried out on unsigned values even for signed types:
// signed overflow is undefined behaviour in C++, while unsigned arithmetic
// wraps and yields the same bit pattern for multiply, add and subtract.
uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
                                 SpvOp opcode, const analysis::Constant* input1,
                                 const analysis::Constant* input2) {
  assert(input1->type()->AsInteger());
  const analysis::Integer* type = input1->type()->AsInteger();
  uint32_t width = type->AsInteger()->width();
  assert(width == 32 || width == 64);
  std::vector<uint32_t> words;
#define FOLD_OP(op)                                                 \
  if (width == 64) {                                                \
    if (type->IsSigned()) {                                         \
      uint64_t val = static_cast<uint64_t>(input1->GetS64())        \
          op static_cast<uint64_t>(input2->GetS64());               \
      words = ExtractInts(val);                                     \
    } else {                                                        \
      uint64_t val = input1->GetU64() op input2->GetU64();          \
      words = ExtractInts(val);                                     \
    }                                                               \
  } else {                                                          \
    if (type->IsSigned()) {                                         \
      uint32_t val = static_cast<uint32_t>(input1->GetS32())        \
          op static_cast<uint32_t>(input2->GetS32());               \
      words.push_back(val);                                         \
    } else {                                                        \
      uint32_t val = input1->GetU32() op input2->GetU32();          \
      words.push_back(val);                                         \
    }                                                               \
  }
  switch (opcode) {
    case SpvOpIMul:
      FOLD_OP(*);
      break;
    case SpvOpSDiv:
    case SpvOpUDiv:
      assert(false && "Should not merge integer division");
      // With asserts disabled, report "no valid result" instead of building a
      // constant from an empty word list.
      return 0;
    case SpvOpIAdd:
      FOLD_OP(+);
      break;
    case SpvOpISub:
      FOLD_OP(-);
      break;
    default:
      assert(false && "Unexpected operation");
      return 0;
  }
#undef FOLD_OP
  const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
}
// Evaluates |input1| |opcode| |input2| and returns the result id of the
// instruction defining the folded constant, or 0 when no valid result exists.
// The inputs must be Integers, Floats, or Vectors of such; vectors are folded
// one component at a time, with a null constant standing for all-zero
// components.
uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
                          const analysis::Constant* input1,
                          const analysis::Constant* input2) {
  assert(input1 && input2);
  const analysis::Type* type = input1->type();
  const analysis::Vector* vector_type = type->AsVector();

  // Scalars go straight to the matching helper.
  if (vector_type == nullptr) {
    if (type->AsFloat()) {
      return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
    }
    assert(type->AsInteger());
    return PerformIntegerOperation(const_mgr, opcode, input1, input2);
  }

  const analysis::Type* ele_type = vector_type->element_type();
  // Ids of the folded component constants, in component order.
  std::vector<uint32_t> component_ids;
  for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
    const analysis::Constant* lhs_comp = nullptr;
    if (const analysis::VectorConstant* lhs_vector =
            input1->AsVectorConstant()) {
      lhs_comp = lhs_vector->GetComponents()[i];
    } else {
      assert(input1->AsNullConstant());
      lhs_comp = const_mgr->GetConstant(ele_type, {});
    }

    const analysis::Constant* rhs_comp = nullptr;
    if (const analysis::VectorConstant* rhs_vector =
            input2->AsVectorConstant()) {
      rhs_comp = rhs_vector->GetComponents()[i];
    } else {
      assert(input2->AsNullConstant());
      rhs_comp = const_mgr->GetConstant(ele_type, {});
    }

    uint32_t component_id = 0;
    if (ele_type->AsFloat()) {
      component_id =
          PerformFloatingPointOperation(const_mgr, opcode, lhs_comp, rhs_comp);
    } else {
      assert(ele_type->AsInteger());
      component_id =
          PerformIntegerOperation(const_mgr, opcode, lhs_comp, rhs_comp);
    }
    // One invalid component invalidates the whole vector.
    if (component_id == 0) return 0;
    component_ids.push_back(component_id);
  }

  const analysis::Constant* merged_const =
      const_mgr->GetConstant(type, component_ids);
  return const_mgr->GetDefiningInstruction(merged_const)->result_id();
}
// Folds a multiply of a multiply when each one has a constant operand, by
// multiplying the two constants together.
// Cases:
// 2 * (x * 2) = x * 4
// 2 * (2 * x) = x * 4
// (x * 2) * 2 = x * 4
// (2 * x) * 2 = x * 4
FoldingRule MergeMulMulArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    // Split the operands of |inst| into its constant and the instruction
    // producing the other operand.
    const analysis::Constant* outer_const = ConstInput(constants);
    if (outer_const == nullptr) return false;
    Instruction* inner_inst = NonConstInput(context, constants[0], inst);
    if (is_float && !inner_inst->IsFloatingPointFoldingAllowed()) return false;

    if (inner_inst->opcode() != inst->opcode()) return false;

    std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_inst);
    const analysis::Constant* inner_const = ConstInput(inner_constants);
    if (inner_const == nullptr) return false;

    const uint32_t merged_id =
        PerformOperation(const_mgr, inst->opcode(), outer_const, inner_const);
    if (merged_id == 0) return false;

    const uint32_t variable_index = inner_constants[0] == nullptr ? 0u : 1u;
    const uint32_t variable_id =
        inner_inst->GetSingleWordInOperand(variable_index);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {variable_id}},
                         {SPV_OPERAND_TYPE_ID, {merged_id}}});
    return true;
  };
}
// Folds a floating point multiply whose operand is a divide, either by
// cancelling a shared operand or by merging the two constants. Integer
// operations are not supported.
// Cases:
// 2 * (x / 2) = x * 1
// 2 * (2 / x) = 4 / x
// (x / 2) * 2 = x * 1
// (2 / x) * 2 = 4 / x
// (y / x) * x = y
// x * (y / x) = y
FoldingRule MergeMulDivArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFMul);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    // Cancellation: one operand is a divide whose divisor is the other
    // operand of the multiply.
    for (uint32_t i = 0; i < 2; i++) {
      Instruction* operand_def =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
      if (operand_def->opcode() != SpvOpFDiv) continue;
      if (operand_def->GetSingleWordInOperand(1) !=
          inst->GetSingleWordInOperand(1 - i)) {
        continue;
      }
      inst->SetOpcode(SpvOpCopyObject);
      inst->SetInOperands(
          {{SPV_OPERAND_TYPE_ID, {operand_def->GetSingleWordInOperand(0)}}});
      return true;
    }

    const analysis::Constant* mul_const = ConstInput(constants);
    if (!mul_const) return false;
    Instruction* div_inst = NonConstInput(context, constants[0], inst);
    if (!div_inst->IsFloatingPointFoldingAllowed()) return false;

    if (div_inst->opcode() != SpvOpFDiv) return false;

    std::vector<const analysis::Constant*> div_constants =
        const_mgr->GetOperandConstants(div_inst);
    const analysis::Constant* div_const = ConstInput(div_constants);
    if (!div_const || HasZero(div_const)) return false;

    const bool dividend_is_variable = div_constants[0] == nullptr;
    // If the variable value is the second operand of the divide, multiply
    // the constants together. Otherwise divide the constants.
    const SpvOp merge_opcode =
        dividend_is_variable ? div_inst->opcode() : inst->opcode();
    const uint32_t merged_id =
        PerformOperation(const_mgr, merge_opcode, mul_const, div_const);
    if (merged_id == 0) return false;

    const uint32_t variable_id =
        div_inst->GetSingleWordInOperand(dividend_is_variable ? 0u : 1u);

    // If the variable value is on the second operand of the div, then this
    // operation is a div. Otherwise it should be a multiply.
    inst->SetOpcode(dividend_is_variable ? inst->opcode()
                                         : div_inst->opcode());
    if (dividend_is_variable) {
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {variable_id}},
                           {SPV_OPERAND_TYPE_ID, {merged_id}}});
    } else {
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
                           {SPV_OPERAND_TYPE_ID, {variable_id}}});
    }
    return true;
  };
}
// Folds a multiply of a constant and a negation by moving the negation onto
// the constant.
// Cases:
// (-x) * 2 = x * -2
// 2 * (-x) = x * -2
FoldingRule MergeMulNegateArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* mul_const = ConstInput(constants);
    if (!mul_const) return false;
    Instruction* negate_inst = NonConstInput(context, constants[0], inst);
    if (is_float && !negate_inst->IsFloatingPointFoldingAllowed())
      return false;

    const SpvOp other_opcode = negate_inst->opcode();
    if (other_opcode != SpvOpFNegate && other_opcode != SpvOpSNegate) {
      return false;
    }

    const uint32_t negated_id = NegateConstant(const_mgr, mul_const);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {negate_inst->GetSingleWordInOperand(0u)}},
         {SPV_OPERAND_TYPE_ID, {negated_id}}});
    return true;
  };
}
// Folds a floating point divide of a divide when each one has a constant
// operand. Integer division is not supported.
// Cases:
// 2 / (x / 2) = 4 / x
// 4 / (2 / x) = 2 * x
// (4 / x) / 2 = 2 / x
// (x / 2) / 2 = x / 4
FoldingRule MergeDivDivArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFDiv);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* outer_const = ConstInput(constants);
    if (!outer_const || HasZero(outer_const)) return false;
    Instruction* inner_inst = NonConstInput(context, constants[0], inst);
    if (!inner_inst->IsFloatingPointFoldingAllowed()) return false;

    const bool first_is_variable = constants[0] == nullptr;
    if (inner_inst->opcode() != inst->opcode()) return false;

    std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_inst);
    const analysis::Constant* inner_const = ConstInput(inner_constants);
    if (!inner_const || HasZero(inner_const)) return false;

    const bool other_first_is_variable = inner_constants[0] == nullptr;
    // When the inner divide is (x / c), the constants magnify, so they are
    // multiplied; otherwise they are divided.
    const SpvOp merge_opcode =
        other_first_is_variable ? SpvOpFMul : inst->opcode();

    // In the (*) / c case the inner constant is the left-hand input. The
    // order is irrelevant when multiplying because it is commutative.
    const analysis::Constant* lhs_const =
        first_is_variable ? inner_const : outer_const;
    const analysis::Constant* rhs_const =
        first_is_variable ? outer_const : inner_const;
    const uint32_t merged_id =
        PerformOperation(const_mgr, merge_opcode, lhs_const, rhs_const);
    if (merged_id == 0) return false;

    const uint32_t variable_id =
        inner_inst->GetSingleWordInOperand(other_first_is_variable ? 0u : 1u);

    // c1 / (c2 / x) is effectively a divide by 1/x, so it becomes a multiply.
    const bool becomes_multiply = !first_is_variable && !other_first_is_variable;
    const SpvOp new_opcode = becomes_multiply ? SpvOpFMul : inst->opcode();

    // The merged constant goes first, except for (x / c1) / c2 where the
    // variable stays the dividend.
    uint32_t lhs_id = merged_id;
    uint32_t rhs_id = variable_id;
    if (first_is_variable && other_first_is_variable) {
      std::swap(lhs_id, rhs_id);
    }

    inst->SetOpcode(new_opcode);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Folds a floating point divide whose operand is a multiply, either by
// cancelling a shared operand or by merging the two constants. Integer divide
// is not supported.
// Cases:
// 4 / (x * 2) = 2 / x
// 4 / (2 * x) = 2 / x
// (x * 4) / 2 = x * 2
// (4 * x) / 2 = x * 2
// (x * y) / x = y
// (y * x) / x = y
FoldingRule MergeDivMulArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFDiv);
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(type);
    if (bit_width != 32 && bit_width != 64) return false;

    // Cancellation: the dividend is a multiply that has the divisor as one
    // of its operands, so the result is the other operand.
    Instruction* dividend_def =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    if (dividend_def->opcode() == SpvOpFMul) {
      for (uint32_t i = 0; i < 2; i++) {
        if (dividend_def->GetSingleWordInOperand(i) !=
            inst->GetSingleWordInOperand(1)) {
          continue;
        }
        inst->SetOpcode(SpvOpCopyObject);
        inst->SetInOperands(
            {{SPV_OPERAND_TYPE_ID,
              {dividend_def->GetSingleWordInOperand(1 - i)}}});
        return true;
      }
    }

    const analysis::Constant* div_const = ConstInput(constants);
    if (!div_const || HasZero(div_const)) return false;
    Instruction* mul_inst = NonConstInput(context, constants[0], inst);
    if (!mul_inst->IsFloatingPointFoldingAllowed()) return false;

    const bool first_is_variable = constants[0] == nullptr;
    if (mul_inst->opcode() != SpvOpFMul) return false;

    std::vector<const analysis::Constant*> mul_constants =
        const_mgr->GetOperandConstants(mul_inst);
    const analysis::Constant* mul_const = ConstInput(mul_constants);
    if (!mul_const) return false;
    const bool other_first_is_variable = mul_constants[0] == nullptr;

    // In the (*) / c case the multiply's constant is the dividend.
    const analysis::Constant* lhs_const =
        first_is_variable ? mul_const : div_const;
    const analysis::Constant* rhs_const =
        first_is_variable ? div_const : mul_const;
    const uint32_t merged_id =
        PerformOperation(const_mgr, inst->opcode(), lhs_const, rhs_const);
    if (merged_id == 0) return false;

    const uint32_t variable_id =
        mul_inst->GetSingleWordInOperand(other_first_is_variable ? 0u : 1u);

    uint32_t lhs_id = merged_id;
    uint32_t rhs_id = variable_id;
    if (first_is_variable) {
      std::swap(lhs_id, rhs_id);
      // Convert to multiply
      inst->SetOpcode(mul_inst->opcode());
    }
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Folds a divide of a constant and a negation by moving the negation onto the
// constant.
// Cases:
// (-x) / 2 = x / -2
// 2 / (-x) = -2 / x
//
// The rule never fires for OpUDiv: negating one operand of an unsigned divide
// does not negate its result. With 32-bit values, (-4) udiv 2 is 0x7FFFFFFE,
// whereas 4 udiv (0 - 2) is 4 udiv 0xFFFFFFFE, which is 0.
FoldingRule MergeDivNegateArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
           inst->opcode() == SpvOpUDiv);
    // Unsigned division cannot absorb a negation (see the header comment).
    if (inst->opcode() == SpvOpUDiv) return false;

    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    bool uses_float = HasFloatingPoint(type);
    if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    // Constants are only negated at these widths.
    uint32_t width = ElementWidth(type);
    if (width != 32 && width != 64) return false;

    const analysis::Constant* const_input1 = ConstInput(constants);
    if (!const_input1) return false;
    Instruction* other_inst = NonConstInput(context, constants[0], inst);
    if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
      return false;

    bool first_is_variable = constants[0] == nullptr;
    if (other_inst->opcode() == SpvOpFNegate ||
        other_inst->opcode() == SpvOpSNegate) {
      uint32_t neg_id = NegateConstant(const_mgr, const_input1);
      // Keep the operands in their original positions; only the negation
      // moves from the variable to the constant.
      if (first_is_variable) {
        inst->SetInOperands(
            {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
             {SPV_OPERAND_TYPE_ID, {neg_id}}});
      } else {
        inst->SetInOperands(
            {{SPV_OPERAND_TYPE_ID, {neg_id}},
             {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
      }
      return true;
    }
    return false;
  };
}
// Folds an addition of a constant and a negation into a subtraction from the
// constant.
// Cases:
// (-x) + 2 = 2 - x
// 2 + (-x) = 2 - x
FoldingRule MergeAddNegateArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
    const analysis::Type* type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const analysis::Constant* add_const = ConstInput(constants);
    if (!add_const) return false;
    Instruction* negate_inst = NonConstInput(context, constants[0], inst);
    if (is_float && !negate_inst->IsFloatingPointFoldingAllowed())
      return false;

    const SpvOp other_opcode = negate_inst->opcode();
    if (other_opcode != SpvOpSNegate && other_opcode != SpvOpFNegate) {
      return false;
    }

    // The constant keeps its existing id; only its position may change.
    const uint32_t const_id =
        inst->GetSingleWordInOperand(constants[0] ? 0u : 1u);
    inst->SetOpcode(is_float ? SpvOpFSub : SpvOpISub);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {const_id}},
         {SPV_OPERAND_TYPE_ID, {negate_inst->GetSingleWordInOperand(0u)}}});
    return true;
  };
}
// Removes the negation from a subtraction of a constant and a negated value.
// Cases:
// (-x) - 2 = -2 - x
// 2 - (-x) = x + 2
FoldingRule MergeSubNegateArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* constant = ConstInput(constants);
    if (constant == nullptr) return false;

    Instruction* negate = NonConstInput(context, constants[0], inst);
    if (is_float && !negate->IsFloatingPointFoldingAllowed()) return false;

    const SpvOp negate_op = negate->opcode();
    if (negate_op != SpvOpSNegate && negate_op != SpvOpFNegate) return false;

    SpvOp new_opcode = inst->opcode();
    uint32_t lhs_id = 0;
    uint32_t rhs_id = 0;
    if (constants[0] == nullptr) {
      // (-x) - c: negate the constant and keep subtracting x from it.
      lhs_id = NegateConstant(const_mgr, constant);
      rhs_id = negate->GetSingleWordInOperand(0u);
    } else {
      // c - (-x): becomes the addition x + c.
      lhs_id = negate->GetSingleWordInOperand(0u);
      rhs_id = inst->GetSingleWordInOperand(0u);
      new_opcode = is_float ? SpvOpFAdd : SpvOpIAdd;
    }

    inst->SetOpcode(new_opcode);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Combines the constants of two nested additions into a single constant.
// Cases:
// (x + 2) + 2 = x + 4
// (2 + x) + 2 = x + 4
// 2 + (x + 2) = x + 4
// 2 + (2 + x) = x + 4
FoldingRule MergeAddAddArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* outer_const = ConstInput(constants);
    if (outer_const == nullptr) return false;

    Instruction* inner_add = NonConstInput(context, constants[0], inst);
    if (is_float && !inner_add->IsFloatingPointFoldingAllowed()) return false;

    const SpvOp inner_op = inner_add->opcode();
    if (inner_op != SpvOpFAdd && inner_op != SpvOpIAdd) return false;

    // The nested addition must itself have a constant operand.
    const std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_add);
    const analysis::Constant* inner_const = ConstInput(inner_constants);
    if (inner_const == nullptr) return false;

    Instruction* variable_input =
        NonConstInput(context, inner_constants[0], inner_add);
    const uint32_t sum_id =
        PerformOperation(const_mgr, inst->opcode(), outer_const, inner_const);
    if (sum_id == 0) return false;

    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {variable_input->result_id()}},
                         {SPV_OPERAND_TYPE_ID, {sum_id}}});
    return true;
  };
}
// Combines the constants of an addition whose other operand is a subtraction
// with a constant operand.
// Cases:
// (x - 2) + 2 = x + 0
// (2 - x) + 2 = 4 - x
// 2 + (x - 2) = x + 0
// 2 + (2 - x) = 4 - x
FoldingRule MergeAddSubArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* outer_const = ConstInput(constants);
    if (outer_const == nullptr) return false;

    Instruction* inner_sub = NonConstInput(context, constants[0], inst);
    if (is_float && !inner_sub->IsFloatingPointFoldingAllowed()) return false;

    const SpvOp inner_op = inner_sub->opcode();
    if (inner_op != SpvOpFSub && inner_op != SpvOpISub) return false;

    const std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_sub);
    const analysis::Constant* inner_const = ConstInput(inner_constants);
    if (inner_const == nullptr) return false;

    SpvOp new_opcode = inst->opcode();
    uint32_t lhs_id = 0;
    uint32_t rhs_id = 0;
    if (inner_constants[0] == nullptr) {
      // (x - c2) + c1: keep the addition and use c1 - c2 as the constant.
      lhs_id = inner_sub->GetSingleWordInOperand(0u);
      rhs_id = PerformOperation(const_mgr, inner_sub->opcode(), outer_const,
                                inner_const);
    } else {
      // (c2 - x) + c1: add the constants and turn |inst| into a subtraction.
      lhs_id = PerformOperation(const_mgr, inst->opcode(), outer_const,
                                inner_const);
      rhs_id = inner_sub->GetSingleWordInOperand(1u);
      new_opcode = inner_sub->opcode();
    }
    if (lhs_id == 0 || rhs_id == 0) return false;

    inst->SetOpcode(new_opcode);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Combines the constants of a subtraction whose other operand is an addition
// with a constant operand.
// Cases:
// (x + 2) - 2 = x + 0
// (2 + x) - 2 = x + 0
// 2 - (x + 2) = 0 - x
// 2 - (2 + x) = 0 - x
FoldingRule MergeSubAddArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* minuend_const = ConstInput(constants);
    if (minuend_const == nullptr) return false;

    Instruction* inner_add = NonConstInput(context, constants[0], inst);
    if (is_float && !inner_add->IsFloatingPointFoldingAllowed()) return false;

    const SpvOp inner_op = inner_add->opcode();
    if (inner_op != SpvOpFAdd && inner_op != SpvOpIAdd) return false;

    const std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_add);
    const analysis::Constant* subtrahend_const = ConstInput(inner_constants);
    if (subtrahend_const == nullptr) return false;

    Instruction* variable_input =
        NonConstInput(context, inner_constants[0], inner_add);

    // When the addition is the first operand of |inst|, the constant of the
    // addition is the one being subtracted from, so exchange the roles.
    const bool add_is_first = constants[0] == nullptr;
    if (add_is_first) std::swap(minuend_const, subtrahend_const);

    const uint32_t difference_id = PerformOperation(
        const_mgr, inst->opcode(), minuend_const, subtrahend_const);

    SpvOp new_opcode = inst->opcode();
    uint32_t lhs_id = 0;
    uint32_t rhs_id = 0;
    if (add_is_first) {
      // (x + c2) - c1 = x + (c2 - c1): |inst| becomes an addition.
      lhs_id = variable_input->result_id();
      rhs_id = difference_id;
      new_opcode = inner_add->opcode();
    } else {
      // c1 - (x + c2) = (c1 - c2) - x: |inst| stays a subtraction.
      lhs_id = difference_id;
      rhs_id = variable_input->result_id();
    }
    if (lhs_id == 0 || rhs_id == 0) return false;

    inst->SetOpcode(new_opcode);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Combines the constants of two nested subtractions into a single constant.
// Cases:
// (x - 2) - 2 = x - 4
// (2 - x) - 2 = 0 - x
// 2 - (x - 2) = 4 - x
// 2 - (2 - x) = x + 0
FoldingRule MergeSubSubArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    const analysis::Constant* first_const = ConstInput(constants);
    if (first_const == nullptr) return false;

    Instruction* inner_sub = NonConstInput(context, constants[0], inst);
    if (is_float && !inner_sub->IsFloatingPointFoldingAllowed()) return false;

    const SpvOp inner_op = inner_sub->opcode();
    if (inner_op != SpvOpFSub && inner_op != SpvOpISub) return false;

    const std::vector<const analysis::Constant*> inner_constants =
        const_mgr->GetOperandConstants(inner_sub);
    const analysis::Constant* second_const = ConstInput(inner_constants);
    if (second_const == nullptr) return false;

    Instruction* variable_input =
        NonConstInput(context, inner_constants[0], inner_sub);

    const bool outer_var_first = constants[0] == nullptr;
    const bool inner_var_first = inner_constants[0] == nullptr;
    const SpvOp add_op = is_float ? SpvOpFAdd : SpvOpIAdd;

    // Pick how the two constants are combined: they are added when the
    // variable is the minuend of the nested subtraction; otherwise they are
    // subtracted, in swapped order when the nested subtraction comes first.
    SpvOp merge_op = inst->opcode();
    if (inner_var_first) {
      merge_op = add_op;
    } else if (outer_var_first) {
      std::swap(first_const, second_const);
    }

    const uint32_t merged_id =
        PerformOperation(const_mgr, merge_op, first_const, second_const);
    if (merged_id == 0) return false;

    // Only 2 - (2 - x) turns into an addition; every other case remains a
    // subtraction.
    const SpvOp new_opcode =
        (!outer_var_first && !inner_var_first) ? add_op : inst->opcode();

    // The merged constant goes first exactly when one, but not both, of the
    // subtractions has its non-constant operand in the first position.
    const bool constant_first = outer_var_first != inner_var_first;
    const uint32_t variable_id = variable_input->result_id();
    const uint32_t lhs_id = constant_first ? merged_id : variable_id;
    const uint32_t rhs_id = constant_first ? variable_id : merged_id;

    inst->SetOpcode(new_opcode);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {lhs_id}}, {SPV_OPERAND_TYPE_ID, {rhs_id}}});
    return true;
  };
}
// Helper function for MergeGenericAddSubArithmetic. If |addend| and
// subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
IRContext* context = inst->context();
analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
Instruction* sub_inst = def_use_mgr->GetDef(sub);
if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
return false;
if (sub_inst->opcode() == SpvOpFSub &&
!sub_inst->IsFloatingPointFoldingAllowed())
return false;
if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
context->UpdateDefUse(inst);
return true;
}
// Folds an addition in which one addend is a subtraction whose subtrahend is
// the other addend. The result is a copy of the minuend. The operands do not
// have to be constants.
// Cases:
// (a - b) + b = a
// b + (a - b) = a
FoldingRule MergeGenericAddSubArithmetic() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    const uint32_t bit_width = ElementWidth(result_type);
    if (bit_width != 32 && bit_width != 64) return false;

    // Try each operand as the subtraction, with the other one as the addend.
    const uint32_t first_id = inst->GetSingleWordInOperand(0);
    const uint32_t second_id = inst->GetSingleWordInOperand(1);
    return MergeGenericAddendSub(first_id, second_id, inst) ||
           MergeGenericAddendSub(second_id, first_id, inst);
  };
}
// Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
// generate |factor0_0| * (|factor0_1| + |factor1_1|).
bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
uint32_t factor1_0, uint32_t factor1_1,
Instruction* inst) {
IRContext* context = inst->context();
if (factor0_0 != factor1_0) return false;
InstructionBuilder ir_builder(
context, inst,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
Instruction* new_add_inst = ir_builder.AddBinaryOp(
inst->type_id(), inst->opcode(), factor0_1, factor1_1);
inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
{SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
context->UpdateDefUse(inst);
return true;
}
// Applies the factoring identity (a * b) + (a * c) = a * (b + c), trying every
// pairing of the operands of the two multiplications.
FoldingRule FactorAddMuls() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    const bool is_float = HasFloatingPoint(result_type);
    if (is_float && !inst->IsFloatingPointFoldingAllowed()) return false;

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

    // Both addends have to be multiplications.
    Instruction* lhs_mul = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    const SpvOp lhs_op = lhs_mul->opcode();
    if (lhs_op != SpvOpFMul && lhs_op != SpvOpIMul) return false;

    Instruction* rhs_mul = def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
    const SpvOp rhs_op = rhs_mul->opcode();
    if (rhs_op != SpvOpFMul && rhs_op != SpvOpIMul) return false;

    // Factoring only pays off when each multiplication has a single use;
    // otherwise the result is larger and slower than the original.
    if (def_use_mgr->NumUses(lhs_mul) > 1) return false;
    if (def_use_mgr->NumUses(rhs_mul) > 1) return false;

    if (lhs_mul->opcode() == SpvOpFMul) {
      if (!lhs_mul->IsFloatingPointFoldingAllowed()) return false;
      if (!rhs_mul->IsFloatingPointFoldingAllowed()) return false;
    }

    // Look for an operand of the first multiplication that is also an operand
    // of the second one.
    for (int lhs_idx = 0; lhs_idx < 2; ++lhs_idx) {
      for (int rhs_idx = 0; rhs_idx < 2; ++rhs_idx) {
        const bool factored = FactorAddMulsOpnds(
            lhs_mul->GetSingleWordInOperand(lhs_idx),
            lhs_mul->GetSingleWordInOperand(1 - lhs_idx),
            rhs_mul->GetSingleWordInOperand(rhs_idx),
            rhs_mul->GetSingleWordInOperand(1 - rhs_idx), inst);
        if (factored) return true;
      }
    }
    return false;
  };
}
// Replaces an OpIMul that has an integer constant operand equal to 1 with a
// copy of the other operand. Gives up as soon as an integer constant operand
// with an element width other than 32 or 64 bits is seen.
FoldingRule IntMultipleBy1() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
    for (uint32_t const_idx = 0; const_idx < 2; ++const_idx) {
      const analysis::Constant* operand_const = constants[const_idx];
      const analysis::IntConstant* int_const =
          operand_const != nullptr ? operand_const->AsIntConstant() : nullptr;
      if (int_const == nullptr) continue;

      bool is_one = false;
      switch (ElementWidth(int_const->type())) {
        case 32:
          is_one = int_const->GetU32BitValue() == 1u;
          break;
        case 64:
          is_one = int_const->GetU64BitValue() == 1ull;
          break;
        default:
          return false;
      }
      if (!is_one) continue;

      // x * 1 and 1 * x are both just x.
      const uint32_t other_id = inst->GetSingleWordInOperand(1 - const_idx);
      inst->SetOpcode(SpvOpCopyObject);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {other_id}}});
      return true;
    }
    return false;
  };
}
// Folds an OpCompositeExtract whose composite operand is the result of an
// OpCompositeConstruct so that it reads directly from the operand of the
// construct that supplies the requested element. When no further indices
// remain, the extract becomes an OpCopyObject of that operand.
//
// Returns a rule that rewrites |inst| in place and returns true when the fold
// was applied, false otherwise.
FoldingRule CompositeConstructFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    // If the input to an OpCompositeExtract is an OpCompositeConstruct,
    // then we can simply use the appropriate element in the construction.
    assert(inst->opcode() == SpvOpCompositeExtract &&
           "Wrong opcode. Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    // If there are no index operands, then this rule cannot do anything.
    if (inst->NumInOperands() <= 1) {
      return false;
    }

    uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
    Instruction* cinst = def_use_mgr->GetDef(cid);

    if (cinst->opcode() != SpvOpCompositeConstruct) {
      return false;
    }

    std::vector<Operand> operands;
    analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
    if (composite_type->AsVector() == nullptr) {
      // Get the element being extracted from the OpCompositeConstruct.
      // Since it is not a vector, each operand of the construct is exactly one
      // element, so the first index selects the operand directly.
      uint32_t element_index = inst->GetSingleWordInOperand(1);
      uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
      operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});

      // Add the remaining indices for extraction.
      for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
        operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
                            {inst->GetSingleWordInOperand(i)}});
      }
    } else {
      // With vectors we have to handle the case where it is concatenating
      // vectors.
      assert(inst->NumInOperands() == 2 &&
             "Expecting a vector of scalar values.");

      // |element_index| is kept relative to the first component of the
      // constituent currently being examined.
      uint32_t element_index = inst->GetSingleWordInOperand(1);
      for (uint32_t construct_index = 0;
           construct_index < cinst->NumInOperands(); ++construct_index) {
        uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
        Instruction* element_def = def_use_mgr->GetDef(element_id);
        analysis::Vector* element_type =
            type_mgr->GetType(element_def->type_id())->AsVector();
        if (element_type) {
          uint32_t vector_size = element_type->element_count();
          // Valid indices into this vector are [0, vector_size), so an index
          // equal to |vector_size| already belongs to the next constituent.
          if (vector_size <= element_index) {
            // The element we want comes after this vector.
            element_index -= vector_size;
          } else {
            // We want an element of this vector.
            operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
            operands.push_back(
                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
            break;
          }
        } else {
          if (element_index == 0) {
            // This is a scalar, and this is the element we are extracting.
            operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
            break;
          } else {
            // Skip over this scalar value.
            --element_index;
          }
        }
      }
    }

    // No constituent supplies the requested element (the index is out of
    // range). Leave the instruction untouched instead of erasing its operands.
    if (operands.empty()) {
      return false;
    }

    // If there were no extra indices, then we have the final object. No need
    // to extract even more.
    if (operands.size() == 1) {
      inst->SetOpcode(SpvOpCopyObject);
    }

    inst->SetInOperands(std::move(operands));
    return true;
  };
}
// Replaces an OpCompositeConstruct whose operands are, in order, the elements
// extracted from a single object of the same type with a copy of that object.
//
// This pattern shows up often because of the way scalar replacement works.
bool CompositeExtractFeedingConstruct(
    IRContext* context, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  assert(inst->opcode() == SpvOpCompositeConstruct &&
         "Wrong opcode. Should be OpCompositeConstruct.");
  analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

  // A struct without members has nothing to match against.
  const uint32_t num_elements = inst->NumInOperands();
  if (num_elements == 0) return false;

  // Every operand must be a single-index extract of position |idx| taken from
  // one common source object.
  uint32_t source_id = 0;
  for (uint32_t idx = 0; idx < num_elements; ++idx) {
    Instruction* extract =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(idx));
    if (extract->opcode() != SpvOpCompositeExtract) return false;
    if (extract->NumInOperands() != 2) return false;
    if (extract->GetSingleWordInOperand(1) != idx) return false;

    const uint32_t extracted_from =
        extract->GetSingleWordInOperand(kExtractCompositeIdInIdx);
    if (idx == 0) {
      source_id = extracted_from;
    } else if (source_id != extracted_from) {
      return false;
    }
  }

  // Finally, the source object must have the type that is being constructed.
  if (def_use_mgr->GetDef(source_id)->type_id() != inst->type_id()) {
    return false;
  }

  // Reuse the source object directly.
  inst->SetOpcode(SpvOpCopyObject);
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {source_id}}});
  return true;
}
// Folds an OpCompositeExtract whose composite operand is the result of an
// OpCompositeInsert, based on how the index lists of the two instructions
// relate to each other.
FoldingRule InsertFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpCompositeExtract &&
           "Wrong opcode. Should be OpCompositeExtract.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    Instruction* insert = def_use_mgr->GetDef(
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));
    if (insert->opcode() != SpvOpCompositeInsert) return false;

    const uint32_t num_extract_ops = inst->NumInOperands();
    const uint32_t num_insert_ops = insert->NumInOperands();

    // Walk the two index lists in lockstep until they disagree or one of them
    // ends. Extract indices start at in-operand 1, insert indices one later.
    uint32_t pos = 1;
    while (pos < num_extract_ops && pos + 1 < num_insert_ops &&
           inst->GetSingleWordInOperand(pos) ==
               insert->GetSingleWordInOperand(pos + 1)) {
      ++pos;
    }
    const bool extract_done = pos == num_extract_ops;
    const bool insert_done = pos + 1 == num_insert_ops;

    if (extract_done) {
      // The extract reaches the inserted value together with parts of the
      // base composite. Nothing can be done.
      if (!insert_done) return false;

      // The index lists are identical: the extract yields exactly the object
      // that was inserted.
      inst->SetOpcode(SpvOpCopyObject);
      inst->SetInOperands(
          {{SPV_OPERAND_TYPE_ID,
            {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
      return true;
    }

    std::vector<Operand> new_operands;
    uint32_t first_kept_index = 0;
    if (insert_done) {
      // The extract reads part of the inserted object: read from that object
      // using only the indices that were not consumed by the match.
      new_operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
      first_kept_index = pos;
    } else {
      // The extract reads something the insert did not touch: read from the
      // composite operand of the insert with all of the original indices.
      new_operands.push_back(
          {SPV_OPERAND_TYPE_ID,
           {insert->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
      first_kept_index = 1;
    }
    for (uint32_t k = first_kept_index; k < num_extract_ops; ++k) {
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(k)}});
    }
    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
// Folds an OpCompositeExtract that reads from the result of an
// OpVectorShuffle so that it reads from one of the shuffle's input vectors
// instead, with the index adjusted accordingly. An extract of a component the
// shuffle leaves undefined becomes an OpUndef.
FoldingRule VectorShuffleFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpCompositeExtract &&
           "Wrong opcode. Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    Instruction* shuffle = def_use_mgr->GetDef(
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));
    if (shuffle->opcode() != SpvOpVectorShuffle) return false;

    // Number of components in the first input vector of the shuffle.
    const uint32_t first_vec_id = shuffle->GetSingleWordInOperand(0);
    Instruction* first_vec_def = def_use_mgr->GetDef(first_vec_id);
    analysis::Type* first_vec_type = type_mgr->GetType(first_vec_def->type_id());
    assert(first_vec_type->AsVector() &&
           "Input to vector shuffle should be vectors.");
    const uint32_t first_vec_size =
        first_vec_type->AsVector()->element_count();

    // The shuffle component literals follow its two vector operands.
    const uint32_t extracted_pos = inst->GetSingleWordInOperand(1);
    uint32_t component = shuffle->GetSingleWordInOperand(2 + extracted_pos);

    // 0xffffffff marks a component with no source, so the result is undefined.
    const uint32_t kUndefComponent = 0xffffffff;
    if (component == kUndefComponent) {
      inst->SetOpcode(SpvOpUndef);
      inst->SetInOperands({});
      return true;
    }

    // Components are numbered across both inputs; indices past the first
    // vector refer to the second one.
    uint32_t source_vec_id = first_vec_id;
    if (component >= first_vec_size) {
      source_vec_id = shuffle->GetSingleWordInOperand(1);
      component -= first_vec_size;
    }

    inst->SetInOperand(kExtractCompositeIdInIdx, {source_vec_id});
    inst->SetInOperand(1, {component});
    return true;
  };
}
// Folds an OpCompositeExtract that reads from the result of a GLSL.std.450
// FMix when the corresponding element of the |a| operand is the constant 0 or
// 1. The extract then reads from the |x| operand (a == 0) or the |y| operand
// (a == 1) of the FMix instead.
FoldingRule FMixFeedingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpCompositeExtract &&
           "Wrong opcode. Should be OpCompositeExtract.");
    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::ConstantManager* const_mgr = context->get_constant_mgr();

    Instruction* fmix = def_use_mgr->GetDef(
        inst->GetSingleWordInOperand(kExtractCompositeIdInIdx));
    if (fmix->opcode() != SpvOpExtInst) return false;

    // It has to be the FMix instruction of the GLSL.std.450 import.
    const uint32_t glsl_set_id =
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    if (fmix->GetSingleWordInOperand(kExtInstSetIdInIdx) != glsl_set_id) {
      return false;
    }
    if (fmix->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
        GLSLstd450FMix) {
      return false;
    }

    // Fold a copy of this extract applied to |a| to learn which value the
    // selected element of |a| has.
    const uint32_t a_id = fmix->GetSingleWordInOperand(kFMixAIdInIdx);
    std::unique_ptr<Instruction> a_extract(inst->Clone(context));
    a_extract->SetInOperand(kExtractCompositeIdInIdx, {a_id});
    context->get_instruction_folder().FoldInstruction(a_extract.get());
    if (a_extract->opcode() != SpvOpCopyObject) return false;

    const analysis::Constant* a_const =
        const_mgr->FindDeclaredConstant(a_extract->GetSingleWordInOperand(0));
    if (a_const == nullptr) return false;

    assert(a_const->type()->AsFloat());
    const double a_value = a_const->GetValueAsDouble();

    // a == 0 selects |x|, a == 1 selects |y|; anything else cannot be folded.
    uint32_t source_in_idx = 0;
    if (a_value == 0.0) {
      source_in_idx = kFMixXIdInIdx;
    } else if (a_value == 1.0) {
      source_in_idx = kFMixYIdInIdx;
    } else {
      return false;
    }

    // Redirect the extract to the selected operand.
    const uint32_t source_id = fmix->GetSingleWordInOperand(source_in_idx);
    inst->SetInOperand(kExtractCompositeIdInIdx, {source_id});
    return true;
  };
}
// Replaces an OpPhi whose incoming values are all the same id, or the result
// of the phi itself, with a copy of that single id.
FoldingRule RedundantPhi() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");

    const uint32_t phi_id = inst->result_id();
    const uint32_t num_operands = inst->NumInOperands();
    uint32_t unique_value = 0;

    // In-operands alternate between a value and another operand that is not
    // inspected here, so only every other operand is visited.
    for (uint32_t idx = 0; idx < num_operands; idx += 2) {
      const uint32_t value_id = inst->GetSingleWordInOperand(idx);
      if (value_id == phi_id) continue;  // Self-references do not count.

      if (unique_value == 0) {
        unique_value = value_id;
        continue;
      }
      // A second distinct value means the phi is really needed.
      if (value_id != unique_value) return false;
    }

    // Only self-references (or no operands at all): the code looks invalid, so
    // leave it alone.
    if (unique_value == 0) return false;

    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {unique_value}}});
    return true;
  };
}
// Simplifies an OpSelect whose two value operands are the same id or whose
// condition is a constant. The result is a copy of one of the values or, for
// a non-null constant vector condition, an OpVectorShuffle of the two values.
FoldingRule RedundantSelect() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpSelect &&
           "Wrong opcode. Should be OpSelect.");
    assert(inst->NumInOperands() == 3);
    assert(constants.size() == 3);

    // Turns |inst| into a copy of |id|; always reports success.
    const auto replace_with_copy = [inst](uint32_t id) {
      inst->SetOpcode(SpvOpCopyObject);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {id}}});
      return true;
    };

    const uint32_t true_id = inst->GetSingleWordInOperand(1);
    const uint32_t false_id = inst->GetSingleWordInOperand(2);

    // Identical values: the condition is irrelevant.
    if (true_id == false_id) return replace_with_copy(true_id);

    const analysis::Constant* condition = constants[0];
    if (condition == nullptr) return false;

    if (condition->type()->AsBool()) {
      // Scalar condition: a null constant or a false value picks |false_id|.
      const bool pick_false = condition->AsNullConstant() ||
                              !condition->AsBoolConstant()->value();
      return replace_with_copy(pick_false ? false_id : true_id);
    }

    assert(condition->type()->AsVector());

    // An all-null condition takes every component from the false value.
    if (condition->AsNullConstant()) return replace_with_copy(false_id);

    // Otherwise build a shuffle with the true vector as the first input and
    // the false vector as the second.
    const analysis::VectorConstant* condition_vec =
        condition->AsVectorConstant();
    const uint32_t num_components =
        static_cast<uint32_t>(condition_vec->GetComponents().size());

    std::vector<Operand> shuffle_ops;
    shuffle_ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
    shuffle_ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
    for (uint32_t lane = 0; lane != num_components; ++lane) {
      const analysis::Constant* lane_cond =
          condition_vec->GetComponents()[lane];
      const bool lane_is_false =
          lane_cond->AsNullConstant() || !lane_cond->AsBoolConstant()->value();
      // Components of the second shuffle input are numbered after those of
      // the first one, hence the |num_components| offset for the false vector.
      const uint32_t shuffle_index =
          lane_is_false ? lane + num_components : lane;
      shuffle_ops.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {shuffle_index}});
    }

    inst->SetOpcode(SpvOpVectorShuffle);
    inst->SetInOperands(std::move(shuffle_ops));
    return true;
  };
}
// Coarse classification of a floating-point constant used by the redundant
// arithmetic rules.
enum class FloatConstantKind { Unknown, Zero, One };

// Classifies |constant| as Zero, One or Unknown. A missing constant is
// Unknown, a null constant is Zero, and a vector constant gets the kind shared
// by all of its components (Unknown if they disagree). Scalars with a width
// other than 32 or 64 bits are Unknown unless they are zero.
FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
  if (constant == nullptr) return FloatConstantKind::Unknown;

  assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");

  if (constant->AsNullConstant()) return FloatConstantKind::Zero;

  if (const analysis::VectorConstant* vec = constant->AsVectorConstant()) {
    const std::vector<const analysis::Constant*>& elems = vec->GetComponents();
    assert(!elems.empty());

    // Every component has to agree with the first one.
    const FloatConstantKind first_kind = getFloatConstantKind(elems[0]);
    for (size_t idx = 1; idx < elems.size(); ++idx) {
      if (getFloatConstantKind(elems[idx]) != first_kind) {
        return FloatConstantKind::Unknown;
      }
    }
    return first_kind;
  }

  const analysis::FloatConstant* scalar = constant->AsFloatConstant();
  if (scalar == nullptr) return FloatConstantKind::Unknown;

  if (scalar->IsZero()) return FloatConstantKind::Zero;

  const uint32_t bit_width = scalar->type()->AsFloat()->width();
  if (bit_width != 32 && bit_width != 64) return FloatConstantKind::Unknown;

  const double value =
      (bit_width == 64) ? scalar->GetDoubleValue() : scalar->GetFloatValue();
  if (value == 0.0) return FloatConstantKind::Zero;
  if (value == 1.0) return FloatConstantKind::One;
  return FloatConstantKind::Unknown;
}
// Replaces an OpFAdd that has a zero constant operand with a copy of the other
// operand:
// 0 + x = x
// x + 0 = x
FoldingRule RedundantFAdd() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind lhs_kind = getFloatConstantKind(constants[0]);
    const FloatConstantKind rhs_kind = getFloatConstantKind(constants[1]);
    const bool lhs_is_zero = lhs_kind == FloatConstantKind::Zero;
    const bool rhs_is_zero = rhs_kind == FloatConstantKind::Zero;
    if (!lhs_is_zero && !rhs_is_zero) return false;

    // Keep the operand that is not the zero (the second one if both are zero).
    const uint32_t kept_id = inst->GetSingleWordInOperand(lhs_is_zero ? 1 : 0);
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
// Simplifies an OpFSub that has a zero constant operand:
// 0 - x = -x
// x - 0 = x
FoldingRule RedundantFSub() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind minuend_kind = getFloatConstantKind(constants[0]);
    const FloatConstantKind subtrahend_kind =
        getFloatConstantKind(constants[1]);

    if (minuend_kind == FloatConstantKind::Zero) {
      // Subtracting from zero is a negation of the subtrahend.
      const uint32_t subtrahend_id = inst->GetSingleWordInOperand(1);
      inst->SetOpcode(SpvOpFNegate);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {subtrahend_id}}});
      return true;
    }

    if (subtrahend_kind == FloatConstantKind::Zero) {
      // Subtracting zero leaves the minuend unchanged.
      const uint32_t minuend_id = inst->GetSingleWordInOperand(0);
      inst->SetOpcode(SpvOpCopyObject);
      inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {minuend_id}}});
      return true;
    }

    return false;
  };
}
// Simplifies an OpFMul that has a zero or one constant operand:
// 0 * x = 0 and x * 0 = 0 (copy of the zero operand)
// 1 * x = x and x * 1 = x (copy of the other operand)
// The zero case is checked first.
FoldingRule RedundantFMul() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) return false;

    const FloatConstantKind lhs_kind = getFloatConstantKind(constants[0]);
    const FloatConstantKind rhs_kind = getFloatConstantKind(constants[1]);

    uint32_t kept_in_idx = 0;
    if (lhs_kind == FloatConstantKind::Zero) {
      kept_in_idx = 0;  // The result is the zero itself.
    } else if (rhs_kind == FloatConstantKind::Zero) {
      kept_in_idx = 1;  // The result is the zero itself.
    } else if (lhs_kind == FloatConstantKind::One) {
      kept_in_idx = 1;  // The result is the operand that is not the one.
    } else if (rhs_kind == FloatConstantKind::One) {
      kept_in_idx = 0;  // The result is the operand that is not the one.
    } else {
      return false;
    }

    const uint32_t kept_id = inst->GetSingleWordInOperand(kept_in_idx);
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
// Simplifies an OpFDiv whose numerator is a floating-point zero constant or
// whose denominator is a floating-point one constant. In both cases the
// result equals the numerator, so the instruction becomes an OpCopyObject of
// its first operand. Nothing is changed unless |inst| allows floating-point
// folding.
FoldingRule RedundantFDiv() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return false;
    }

    const bool numerator_is_zero =
        getFloatConstantKind(constants[0]) == FloatConstantKind::Zero;
    const bool denominator_is_one =
        getFloatConstantKind(constants[1]) == FloatConstantKind::One;
    if (!numerator_is_zero && !denominator_is_one) {
      return false;
    }

    // 0 / x -> 0 and x / 1 -> x: either way the numerator is the answer.
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
    return true;
  };
}
// Simplifies a GLSL.std.450 FMix extended instruction whose interpolation
// factor (the last operand) is a floating-point zero or one constant:
//   FMix(x, y, 0)  becomes  OpCopyObject x
//   FMix(x, y, 1)  becomes  OpCopyObject y
// Nothing is changed unless |inst| allows floating-point folding.
FoldingRule RedundantFMix() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpExtInst &&
           "Wrong opcode. Should be OpExtInst.");

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return false;
    }

    const uint32_t glsl_set_id =
        context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    const bool is_fmix =
        inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == glsl_set_id &&
        inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
            GLSLstd450FMix;
    if (!is_fmix) {
      return false;
    }

    // In-operands are: set id, instruction number, x, y, a.
    assert(constants.size() == 5);

    uint32_t kept_in_idx = 0;
    switch (getFloatConstantKind(constants[4])) {
      case FloatConstantKind::Zero:
        kept_in_idx = kFMixXIdInIdx;
        break;
      case FloatConstantKind::One:
        kept_in_idx = kFMixYIdInIdx;
        break;
      default:
        return false;
    }

    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(kept_in_idx)}}});
    return true;
  };
}
// Folds an OpIAdd in which one operand is a zero constant. The instruction is
// replaced by the other operand: as an OpCopyObject when the result type is
// the same as the type of the zero constant, and as an OpBitcast otherwise.
//
// NOTE(review): the type compared against the result type is that of the zero
// constant, not of the operand that is kept -- confirm this is intended.
FoldingRule RedundantIAdd() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");

    const uint32_t kNoOperand = std::numeric_limits<uint32_t>::max();

    // Find the first operand that is a zero constant; the other one is kept.
    uint32_t kept_id = kNoOperand;
    const analysis::Type* zero_type = nullptr;
    for (uint32_t i = 0; i < 2; ++i) {
      if (constants[i] && constants[i]->IsZero()) {
        kept_id = inst->GetSingleWordInOperand(1 - i);
        zero_type = constants[i]->type();
        break;
      }
    }

    if (kept_id == kNoOperand) {
      return false;
    }

    const analysis::Type* result_type =
        context->get_type_mgr()->GetType(inst->type_id());
    inst->SetOpcode(result_type->IsSame(zero_type) ? SpvOpCopyObject
                                                   : SpvOpBitcast);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {kept_id}}});
    return true;
  };
}
// This rule looks for an OpDot in which one operand is a constant vector
// containing a single 1.0 with every other component 0.0. Such a dot product
// just selects one component of the other operand, so the instruction is
// rewritten as an OpCompositeExtract of that component. Only 32-bit and
// 64-bit float elements are handled, and nothing is changed unless |inst|
// allows floating-point folding.
FoldingRule DotProductDoingExtract() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");

    analysis::ConstantManager* const_mgr = context->get_constant_mgr();

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return false;
    }

    const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();

    for (uint32_t const_idx = 0; const_idx < 2; ++const_idx) {
      const analysis::Constant* candidate = constants[const_idx];
      if (candidate == nullptr) {
        continue;
      }

      const analysis::Vector* vector_type = candidate->type()->AsVector();
      assert(vector_type && "Inputs to OpDot must be vectors.");
      const analysis::Float* element_type =
          vector_type->element_type()->AsFloat();
      assert(element_type && "Inputs to OpDot must be vectors of floats.");
      const uint32_t element_width = element_type->width();
      if (element_width != 32 && element_width != 64) {
        return false;
      }

      const std::vector<const analysis::Constant*> components =
          candidate->GetVectorComponents(const_mgr);

      // Scan for exactly one 1.0 among 0.0s. A second 1.0 or any other value
      // disqualifies this operand.
      uint32_t one_position = kNotFound;
      bool selects_single_component = true;
      for (uint32_t j = 0; j < components.size(); ++j) {
        const double value = (element_width == 32)
                                 ? components[j]->GetFloat()
                                 : components[j]->GetDouble();
        if (value == 0.0) {
          continue;
        }
        if (value != 1.0 || one_position != kNotFound) {
          selects_single_component = false;
          break;
        }
        one_position = j;
      }

      if (!selects_single_component || one_position == kNotFound) {
        continue;
      }

      // Extract the selected component from the operand that is not the
      // constant selector.
      const uint32_t source_id = inst->GetSingleWordInOperand(1u - const_idx);
      std::vector<Operand> extract_operands;
      extract_operands.push_back({SPV_OPERAND_TYPE_ID, {source_id}});
      extract_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {one_position}});

      inst->SetOpcode(SpvOpCompositeExtract);
      inst->SetInOperands(std::move(extract_operands));
      return true;
    }
    return false;
  };
}
// If we are storing an undef, then we can remove the store: the OpStore is
// turned into a nop. Volatile stores are never removed.
//
// TODO: We can do something similar for OpImageWrite, but checking for volatile
// is complicated. Waiting to see if it is needed.
FoldingRule StoringUndef() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();

    // If this is a volatile store, the store cannot be removed. The memory
    // access mask, when present, is the third in-operand. It may be followed
    // by further operands belonging to mask bits (e.g. the alignment literal
    // for Aligned), so the mask must be inspected whenever there are at least
    // three in-operands, not only when there are exactly three.
    const uint32_t kStoreMemoryAccessInIdx = 2;
    if (inst->NumInOperands() > kStoreMemoryAccessInIdx) {
      if (inst->GetSingleWordInOperand(kStoreMemoryAccessInIdx) &
          SpvMemoryAccessVolatileMask) {
        return false;
      }
    }

    uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
    Instruction* object_inst = def_use_mgr->GetDef(object_id);
    if (object_inst->opcode() == SpvOpUndef) {
      inst->ToNop();
      return true;
    }
    return false;
  };
}
// Folds an OpVectorShuffle one of whose vector operands is itself the result
// of an OpVectorShuffle (the "feeder"). When every component that |inst| takes
// from the feeder comes from a single operand of the feeder, |inst| is
// rewritten to read from that operand directly, with the component literals
// remapped accordingly. If |inst| needs components from both of the feeder's
// operands the rule does not apply.
//
// The literal 0xffffffff means "undefined component". It is never treated as
// an index, whether it appears in |inst| or in the feeder.
FoldingRule VectorShuffleFeedingShuffle() {
  return [](IRContext* context, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpVectorShuffle &&
           "Wrong opcode. Should be OpVectorShuffle.");

    analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
    analysis::TypeManager* type_mgr = context->get_type_mgr();

    // Try the first vector operand as the feeder. Its length is needed in
    // either case to tell which operand a component literal refers to.
    Instruction* feeding_shuffle_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
    analysis::Vector* op0_type =
        type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
    uint32_t op0_length = op0_type->element_count();

    bool feeder_is_op0 = true;
    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
      feeding_shuffle_inst =
          def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
      feeder_is_op0 = false;
    }

    if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
      return false;
    }

    Instruction* feeder2 =
        def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
    analysis::Vector* feeder_op0_type =
        type_mgr->GetType(feeder2->type_id())->AsVector();
    uint32_t feeder_op0_length = feeder_op0_type->element_count();

    // Id of the single feeder operand that |inst| ends up reading from; 0
    // until the first such component is seen.
    uint32_t new_feeder_id = 0;
    std::vector<Operand> new_operands;
    new_operands.resize(
        2, {SPV_OPERAND_TYPE_ID, {0}});  // Place holders for vector operands.
    const uint32_t undef_literal = 0xffffffff;
    for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
      uint32_t component_index = inst->GetSingleWordInOperand(op);

      // Do not interpret the undefined value literal as coming from operand 1.
      if (component_index != undef_literal &&
          feeder_is_op0 == (component_index < op0_length)) {
        // This component comes from the feeding_shuffle_inst. Adjust
        // |component_index| to get the index into the result of the
        // feeding_shuffle_inst, then look up which component the feeder
        // selected for that position.
        if (component_index >= op0_length) {
          component_index -= op0_length;
        }
        component_index =
            feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);

        // The feeder may itself have selected an undefined component. In that
        // case the component stays undefined and constrains nothing.
        if (component_index != undef_literal) {
          // Determine whether the component comes from the first or second
          // operand of the feeding instruction.
          const uint32_t source_in_idx =
              (component_index < feeder_op0_length) ? 0 : 1;
          const uint32_t source_id =
              feeding_shuffle_inst->GetSingleWordInOperand(source_in_idx);
          if (new_feeder_id == 0) {
            // First time through, save the id of the operand the element comes
            // from.
            new_feeder_id = source_id;
          } else if (new_feeder_id != source_id) {
            // We need both operands of the feeding_shuffle_inst, so we cannot
            // fold.
            return false;
          }

          // Make the index relative to the selected feeder operand.
          if (source_in_idx == 1) {
            component_index -= feeder_op0_length;
          }

          // If the feeder is being replaced as the second operand of |inst|,
          // its components are numbered after those of the first operand.
          if (!feeder_is_op0) {
            component_index += op0_length;
          }
        }
      }
      new_operands.push_back(
          {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
    }

    if (new_feeder_id == 0) {
      // No defined component is taken from the feeder, so any vector of the
      // feeder's type will do; use a null constant.
      analysis::ConstantManager* const_mgr = context->get_constant_mgr();
      const analysis::Type* type =
          type_mgr->GetType(feeding_shuffle_inst->type_id());
      const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
      new_feeder_id =
          const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
    }

    if (feeder_is_op0) {
      // If the size of the first vector operand changed then the indices
      // referring to the second operand need to be adjusted.
      Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
      analysis::Type* new_feeder_type =
          type_mgr->GetType(new_feeder_inst->type_id());
      uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
      int32_t adjustment = op0_length - new_op0_size;

      if (adjustment != 0) {
        for (uint32_t i = 2; i < new_operands.size(); i++) {
          // The undefined value literal is not an index into operand 1 and
          // must be left untouched.
          const uint32_t original_index = inst->GetSingleWordInOperand(i);
          if (original_index >= op0_length &&
              original_index != undef_literal) {
            new_operands[i].words[0] -= adjustment;
          }
        }
      }

      new_operands[0].words[0] = new_feeder_id;
      new_operands[1] = inst->GetInOperand(1);
    } else {
      new_operands[1].words[0] = new_feeder_id;
      new_operands[0] = inst->GetInOperand(0);
    }

    inst->SetInOperands(std::move(new_operands));
    return true;
  };
}
// Removes duplicate ids from the interface list of an OpEntryPoint
// instruction. The first three operands are always kept as they are; every
// operand after them is kept only the first time its id is seen. Returns
// false, leaving |inst| untouched, when there is no duplicate.
FoldingRule RemoveRedundantOperands() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>&) {
    assert(inst->opcode() == SpvOpEntryPoint &&
           "Wrong opcode. Should be OpEntryPoint.");

    const uint32_t kFirstInterfaceIdx = 3;

    std::vector<Operand> deduped_operands;
    for (uint32_t i = 0; i < kFirstInterfaceIdx; ++i) {
      deduped_operands.emplace_back(inst->GetOperand(i));
    }

    std::unordered_set<uint32_t> seen_ids;
    bool dropped_any = false;
    for (uint32_t i = kFirstInterfaceIdx; i < inst->NumOperands(); ++i) {
      const bool first_occurrence =
          seen_ids.insert(inst->GetSingleWordOperand(i)).second;
      if (!first_occurrence) {
        dropped_any = true;
        continue;
      }
      deduped_operands.emplace_back(inst->GetOperand(i));
    }

    if (!dropped_any) {
      return false;
    }

    // NOTE(review): the full operand list is installed through
    // SetInOperands; presumably that is equivalent here because OpEntryPoint
    // has no result type or result id -- confirm.
    inst->SetInOperands(std::move(deduped_operands));
    return true;
  };
}
// If an image instruction's Offset image operand is a constant, updates the
// image operands mask to use ConstOffset instead of Offset. Returns true when
// the mask was changed.
//
// |constants| holds one entry per in-operand of |inst|; an entry is non-null
// when that operand is a known constant.
FoldingRule UpdateImageOperands() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    const auto opcode = inst->opcode();
    (void)opcode;
    assert((opcode == SpvOpImageSampleImplicitLod ||
            opcode == SpvOpImageSampleExplicitLod ||
            opcode == SpvOpImageSampleDrefImplicitLod ||
            opcode == SpvOpImageSampleDrefExplicitLod ||
            opcode == SpvOpImageSampleProjImplicitLod ||
            opcode == SpvOpImageSampleProjExplicitLod ||
            opcode == SpvOpImageSampleProjDrefImplicitLod ||
            opcode == SpvOpImageSampleProjDrefExplicitLod ||
            opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
            opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
            opcode == SpvOpImageWrite ||
            opcode == SpvOpImageSparseSampleImplicitLod ||
            opcode == SpvOpImageSparseSampleExplicitLod ||
            opcode == SpvOpImageSparseSampleDrefImplicitLod ||
            opcode == SpvOpImageSparseSampleDrefExplicitLod ||
            opcode == SpvOpImageSparseSampleProjImplicitLod ||
            opcode == SpvOpImageSparseSampleProjExplicitLod ||
            opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
            opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
            opcode == SpvOpImageSparseFetch ||
            opcode == SpvOpImageSparseGather ||
            opcode == SpvOpImageSparseDrefGather ||
            opcode == SpvOpImageSparseRead) &&
           "Wrong opcode. Should be an image instruction.");

    // In-operand index of the image operands mask, or -1 if there is none.
    int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
    if (operand_index < 0) {
      return false;
    }

    auto image_operands = inst->GetSingleWordInOperand(operand_index);
    if (!(image_operands & SpvImageOperandsOffsetMask)) {
      return false;
    }

    // The operands that belong to the mask follow it in mask-bit order. Skip
    // the ones that precede Offset: Bias and Lod take one operand each, Grad
    // takes two.
    uint32_t offset_operand_index = operand_index + 1;
    if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
    if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
    if (image_operands & SpvImageOperandsGradMask) offset_operand_index += 2;
    assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
           "Offset and ConstOffset may not be used together");

    // |offset_operand_index| is an in-operand index and is used to index
    // |constants|, so it must be checked against the size of |constants|.
    // The total operand count also includes any result type and result id
    // and would allow an out-of-range read.
    if (offset_operand_index >= constants.size()) {
      return false;
    }
    if (!constants[offset_operand_index]) {
      return false;
    }

    image_operands = image_operands | SpvImageOperandsConstOffsetMask;
    image_operands = image_operands & ~SpvImageOperandsOffsetMask;
    inst->SetInOperand(operand_index, {image_operands});
    return true;
  };
}
} // namespace
void FoldingRules::AddFoldingRules() {
// Add all folding rules to the list for the opcodes to which they apply.
// Note that the order in which rules are added to the list matters. If a rule
// applies to the instruction, the rest of the rules will not be attempted.
// Take that into consideration.
rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);
rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
rules_[SpvOpFAdd].push_back(RedundantFAdd());
rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[SpvOpFAdd].push_back(FactorAddMuls());
rules_[SpvOpFDiv].push_back(RedundantFDiv());
rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());
rules_[SpvOpFMul].push_back(RedundantFMul());
rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());
rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());
rules_[SpvOpFSub].push_back(RedundantFSub());
rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
rules_[SpvOpIAdd].push_back(RedundantIAdd());
rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
rules_[SpvOpIAdd].push_back(FactorAddMuls());
rules_[SpvOpIMul].push_back(IntMultipleBy1());
rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());
rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
rules_[SpvOpISub].push_back(MergeSubSubArithmetic());
rules_[SpvOpPhi].push_back(RedundantPhi());
rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());
rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());
rules_[SpvOpSelect].push_back(RedundantSelect());
rules_[SpvOpStore].push_back(StoringUndef());
rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());
rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());
rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
rules_[SpvOpImageGather].push_back(UpdateImageOperands());
rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
rules_[SpvOpImageRead].push_back(UpdateImageOperands());
rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
UpdateImageOperands());
rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());
FeatureManager* feature_manager = context_->get_feature_mgr();
// Add rules for GLSLstd450
uint32_t ext_inst_glslstd450_id =
feature_manager->GetExtInstImportId_GLSLstd450();
if (ext_inst_glslstd450_id != 0) {
ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
RedundantFMix());
}
}
} // namespace opt
} // namespace spvtools