// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "folding_rules.h" #include #include "latest_version_glsl_std_450_header.h" namespace spvtools { namespace opt { namespace { const uint32_t kExtractCompositeIdInIdx = 0; const uint32_t kInsertObjectIdInIdx = 0; const uint32_t kInsertCompositeIdInIdx = 1; const uint32_t kExtInstSetIdInIdx = 0; const uint32_t kExtInstInstructionInIdx = 1; const uint32_t kFMixXIdInIdx = 2; const uint32_t kFMixYIdInIdx = 3; // Returns the element width of |type|. uint32_t ElementWidth(const analysis::Type* type) { if (const analysis::Vector* vec_type = type->AsVector()) { return ElementWidth(vec_type->element_type()); } else if (const analysis::Float* float_type = type->AsFloat()) { return float_type->width(); } else { assert(type->AsInteger()); return type->AsInteger()->width(); } } // Returns true if |type| is Float or a vector of Float. bool HasFloatingPoint(const analysis::Type* type) { if (type->AsFloat()) { return true; } else if (const analysis::Vector* vec_type = type->AsVector()) { return vec_type->element_type()->AsFloat() != nullptr; } return false; } // Returns false if |val| is NaN, infinite or subnormal. template bool IsValidResult(T val) { int classified = std::fpclassify(val); switch (classified) { case FP_NAN: case FP_INFINITE: case FP_SUBNORMAL: return false; default: return true; } } const analysis::Constant* ConstInput( const std::vector& constants) { return constants[0] ? 
constants[0] : constants[1]; } ir::Instruction* NonConstInput(ir::IRContext* context, const analysis::Constant* c, ir::Instruction* inst) { uint32_t in_op = c ? 1u : 0u; return context->get_def_use_mgr()->GetDef( inst->GetSingleWordInOperand(in_op)); } // Returns the negation of |c|. |c| must be a 32 or 64 bit floating point // constant. uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr, const analysis::Constant* c) { assert(c); assert(c->type()->AsFloat()); uint32_t width = c->type()->AsFloat()->width(); assert(width == 32 || width == 64); std::vector words; if (width == 64) { spvutils::FloatProxy result(c->GetDouble() * -1.0); words = result.GetWords(); } else { spvutils::FloatProxy result(c->GetFloat() * -1.0f); words = result.GetWords(); } const analysis::Constant* negated_const = const_mgr->GetConstant(c->type(), std::move(words)); return const_mgr->GetDefiningInstruction(negated_const)->result_id(); } std::vector ExtractInts(uint64_t val) { std::vector words; words.push_back(static_cast(val)); words.push_back(static_cast(val >> 32)); return words; } // Negates the integer constant |c|. Returns the id of the defining instruction. uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr, const analysis::Constant* c) { assert(c); assert(c->type()->AsInteger()); uint32_t width = c->type()->AsInteger()->width(); assert(width == 32 || width == 64); std::vector words; if (width == 64) { uint64_t uval = static_cast(0 - c->GetU64()); words = ExtractInts(uval); } else { words.push_back(static_cast(0 - c->GetU32())); } const analysis::Constant* negated_const = const_mgr->GetConstant(c->type(), std::move(words)); return const_mgr->GetDefiningInstruction(negated_const)->result_id(); } // Negates the vector constant |c|. Returns the id of the defining instruction. 
// A null vector constant is returned unchanged. Otherwise every component is
// negated on its own and the resulting ids form the new vector constant.
uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
                              const analysis::Constant* c) {
  assert(const_mgr && c);
  assert(c->type()->AsVector());
  if (c->AsNullConstant()) {
    // 0.0 vs -0.0 shouldn't matter.
    return const_mgr->GetDefiningInstruction(c)->result_id();
  }

  const analysis::VectorConstant* vector_const = c->AsVectorConstant();
  const analysis::Type* component_type = vector_const->component_type();
  std::vector<uint32_t> words;
  for (auto& comp : vector_const->GetComponents()) {
    if (component_type->AsFloat()) {
      words.push_back(NegateFloatingPointConstant(const_mgr, comp));
    } else {
      assert(component_type->AsInteger());
      words.push_back(NegateIntegerConstant(const_mgr, comp));
    }
  }

  const analysis::Constant* negated_const =
      const_mgr->GetConstant(c->type(), std::move(words));
  return const_mgr->GetDefiningInstruction(negated_const)->result_id();
}

// Negates |c|. Returns the id of the defining instruction. |c| must be a
// vector, a Float or an Integer constant; the work is forwarded to the
// matching helper.
uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
                        const analysis::Constant* c) {
  if (c->type()->AsVector()) {
    return NegateVectorConstant(const_mgr, c);
  }
  if (c->type()->AsFloat()) {
    return NegateFloatingPointConstant(const_mgr, c);
  }
  assert(c->type()->AsInteger());
  return NegateIntegerConstant(const_mgr, c);
}

// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
// Returns 0 if the reciprocal is NaN, infinite or subnormal.
uint32_t Reciprocal(analysis::ConstantManager* const_mgr, const analysis::Constant* c) { assert(const_mgr && c); assert(c->type()->AsFloat()); uint32_t width = c->type()->AsFloat()->width(); assert(width == 32 || width == 64); std::vector words; if (width == 64) { spvutils::FloatProxy result(1.0 / c->GetDouble()); if (!IsValidResult(result.getAsFloat())) return 0; words = result.GetWords(); } else { spvutils::FloatProxy result(1.0f / c->GetFloat()); if (!IsValidResult(result.getAsFloat())) return 0; words = result.GetWords(); } const analysis::Constant* negated_const = const_mgr->GetConstant(c->type(), std::move(words)); return const_mgr->GetDefiningInstruction(negated_const)->result_id(); } // Replaces fdiv where second operand is constant with fmul. FoldingRule ReciprocalFDiv() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; if (constants[1] != nullptr) { uint32_t id = 0; if (const analysis::VectorConstant* vector_const = constants[1]->AsVectorConstant()) { std::vector neg_ids; for (auto& comp : vector_const->GetComponents()) { id = Reciprocal(const_mgr, comp); if (id == 0) return false; neg_ids.push_back(id); } const analysis::Constant* negated_const = const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids)); id = const_mgr->GetDefiningInstruction(negated_const)->result_id(); } else if (constants[1]->AsFloatConstant()) { id = Reciprocal(const_mgr, constants[1]); if (id == 0) return false; } else { // Don't fold a null constant. 
return false; } inst->SetOpcode(SpvOpFMul); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}}, {SPV_OPERAND_TYPE_ID, {id}}}); return true; } return false; }; } // Elides consecutive negate instructions. FoldingRule MergeNegateArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; ir::Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; if (op_inst->opcode() == inst->opcode()) { // Elide negates. inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}}); return true; } return false; }; } // Merges negate into a mul or div operation if that operation contains a // constant operand. 
// Cases: // -(x * 2) = x * -2 // -(2 * x) = x * -2 // -(x / 2) = x / -2 // -(2 / x) = -2 / x FoldingRule MergeNegateMulDivArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; ir::Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; SpvOp opcode = op_inst->opcode(); if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul || opcode == SpvOpSDiv || opcode == SpvOpUDiv) { std::vector op_constants = const_mgr->GetOperandConstants(op_inst); // Merge negate into mul or div if one operand is constant. if (op_constants[0] || op_constants[1]) { bool zero_is_variable = op_constants[0] == nullptr; const analysis::Constant* c = ConstInput(op_constants); uint32_t neg_id = NegateConstant(const_mgr, c); uint32_t non_const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : op_inst->GetSingleWordInOperand(1u); // Change this instruction to a mul/div. inst->SetOpcode(op_inst->opcode()); if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) { uint32_t op0 = zero_is_variable ? non_const_id : neg_id; uint32_t op1 = zero_is_variable ? 
neg_id : non_const_id; inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); } else { inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, {SPV_OPERAND_TYPE_ID, {neg_id}}}); } return true; } } return false; }; } // Merges negate into a add or sub operation if that operation contains a // constant operand. // Cases: // -(x + 2) = -2 - x // -(2 + x) = -2 - x // -(x - 2) = 2 - x // -(2 - x) = x - 2 FoldingRule MergeNegateAddSubArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate); (void)constants; ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; ir::Instruction* op_inst = context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub || op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) { std::vector op_constants = const_mgr->GetOperandConstants(op_inst); if (op_constants[0] || op_constants[1]) { bool zero_is_variable = op_constants[0] == nullptr; bool is_add = (op_inst->opcode() == SpvOpFAdd) || (op_inst->opcode() == SpvOpIAdd); bool swap_operands = !is_add || zero_is_variable; bool negate_const = is_add; const analysis::Constant* c = ConstInput(op_constants); uint32_t const_id = 0; if (negate_const) { const_id = NegateConstant(const_mgr, c); } else { const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u) : op_inst->GetSingleWordInOperand(0u); } // Swap operands if necessary and make the instruction a subtraction. 
uint32_t op0 = zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id; uint32_t op1 = zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u); if (swap_operands) std::swap(op0, op1); inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); return true; } } return false; }; } // Returns true if |c| has a zero element. bool HasZero(const analysis::Constant* c) { if (c->AsNullConstant()) { return true; } if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) { for (auto& comp : vec_const->GetComponents()) if (HasZero(comp)) return true; } else { assert(c->AsScalarConstant()); return c->AsScalarConstant()->IsZero(); } return false; } // Performs |input1| |opcode| |input2| and returns the merged constant result // id. Returns 0 if the result is not a valid value. The input types must be // Float. uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, const analysis::Constant* input1, const analysis::Constant* input2) { const analysis::Type* type = input1->type(); assert(type->AsFloat()); uint32_t width = type->AsFloat()->width(); assert(width == 32 || width == 64); std::vector words; #define FOLD_OP(op) \ if (width == 64) { \ spvutils::FloatProxy val = \ input1->GetDouble() op input2->GetDouble(); \ double dval = val.getAsFloat(); \ if (!IsValidResult(dval)) return 0; \ words = val.GetWords(); \ } else { \ spvutils::FloatProxy val = \ input1->GetFloat() op input2->GetFloat(); \ float fval = val.getAsFloat(); \ if (!IsValidResult(fval)) return 0; \ words = val.GetWords(); \ } switch (opcode) { case SpvOpFMul: FOLD_OP(*); break; case SpvOpFDiv: if (HasZero(input2)) return 0; FOLD_OP(/); break; case SpvOpFAdd: FOLD_OP(+); break; case SpvOpFSub: FOLD_OP(-); break; default: assert(false && "Unexpected operation"); break; } #undef FOLD_OP const analysis::Constant* merged_const = const_mgr->GetConstant(type, 
words); return const_mgr->GetDefiningInstruction(merged_const)->result_id(); } // Performs |input1| |opcode| |input2| and returns the merged constant result // id. Returns 0 if the result is not a valid value. The input types must be // Integers. uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, const analysis::Constant* input1, const analysis::Constant* input2) { assert(input1->type()->AsInteger()); const analysis::Integer* type = input1->type()->AsInteger(); uint32_t width = type->AsInteger()->width(); assert(width == 32 || width == 64); std::vector words; #define FOLD_OP(op) \ if (width == 64) { \ if (type->IsSigned()) { \ int64_t val = input1->GetS64() op input2->GetS64(); \ words = ExtractInts(static_cast(val)); \ } else { \ uint64_t val = input1->GetU64() op input2->GetU64(); \ words = ExtractInts(val); \ } \ } else { \ if (type->IsSigned()) { \ int32_t val = input1->GetS32() op input2->GetS32(); \ words.push_back(static_cast(val)); \ } else { \ uint32_t val = input1->GetU32() op input2->GetU32(); \ words.push_back(val); \ } \ } switch (opcode) { case SpvOpIMul: FOLD_OP(*); break; case SpvOpSDiv: case SpvOpUDiv: assert(false && "Should not merge integer division"); break; case SpvOpIAdd: FOLD_OP(+); break; case SpvOpISub: FOLD_OP(-); break; default: assert(false && "Unexpected operation"); break; } #undef FOLD_OP const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); return const_mgr->GetDefiningInstruction(merged_const)->result_id(); } // Performs |input1| |opcode| |input2| and returns the merged constant result // id. Returns 0 if the result is not a valid value. The input types must be // Integers, Floats or Vectors of such. 
uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode, const analysis::Constant* input1, const analysis::Constant* input2) { assert(input1 && input2); assert(input1->type() == input2->type()); const analysis::Type* type = input1->type(); std::vector words; if (const analysis::Vector* vector_type = type->AsVector()) { const analysis::Type* ele_type = vector_type->element_type(); for (uint32_t i = 0; i != vector_type->element_count(); ++i) { uint32_t id = 0; const analysis::Constant* input1_comp = nullptr; if (const analysis::VectorConstant* input1_vector = input1->AsVectorConstant()) { input1_comp = input1_vector->GetComponents()[i]; } else { assert(input1->AsNullConstant()); input1_comp = const_mgr->GetConstant(ele_type, {}); } const analysis::Constant* input2_comp = nullptr; if (const analysis::VectorConstant* input2_vector = input2->AsVectorConstant()) { input2_comp = input2_vector->GetComponents()[i]; } else { assert(input2->AsNullConstant()); input2_comp = const_mgr->GetConstant(ele_type, {}); } if (ele_type->AsFloat()) { id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp, input2_comp); } else { assert(ele_type->AsInteger()); id = PerformIntegerOperation(const_mgr, opcode, input1_comp, input2_comp); } if (id == 0) return 0; words.push_back(id); } const analysis::Constant* merged_const = const_mgr->GetConstant(type, words); return const_mgr->GetDefiningInstruction(merged_const)->result_id(); } else if (type->AsFloat()) { return PerformFloatingPointOperation(const_mgr, opcode, input1, input2); } else { assert(type->AsInteger()); return PerformIntegerOperation(const_mgr, opcode, input1, input2); } } // Merges consecutive multiplies where each contains one constant operand. 
// Cases: // 2 * (x * 2) = x * 4 // 2 * (2 * x) = x * 4 // (x * 2) * 2 = x * 4 // (2 * x) * 2 = x * 4 FoldingRule MergeMulMulArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; // Determine the constant input and the variable input in |inst|. const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == inst->opcode()) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; bool other_first_is_variable = other_constants[0] == nullptr; uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); if (merged_id == 0) return false; uint32_t non_const_id = other_first_is_variable ? other_inst->GetSingleWordInOperand(0u) : other_inst->GetSingleWordInOperand(1u); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, {SPV_OPERAND_TYPE_ID, {merged_id}}}); return true; } return false; }; } // Merges divides into subsequent multiplies if each instruction contains one // constant operand. Does not support integer operations. 
// Cases: // 2 * (x / 2) = x * 1 // 2 * (2 / x) = 4 / x // (x / 2) * 2 = x * 1 // (2 / x) * 2 = 4 / x // (y / x) * x = y // x * (y / x) = y FoldingRule MergeMulDivArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; for (uint32_t i = 0; i < 2; i++) { uint32_t op_id = inst->GetSingleWordInOperand(i); ir::Instruction* op_inst = def_use_mgr->GetDef(op_id); if (op_inst->opcode() == SpvOpFDiv) { if (op_inst->GetSingleWordInOperand(1) == inst->GetSingleWordInOperand(1 - i)) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}}); return true; } } } const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFDiv) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; // If the variable value is the second operand of the divide, multiply // the constants together. Otherwise divide the constants. uint32_t merged_id = PerformOperation( const_mgr, other_first_is_variable ? other_inst->opcode() : inst->opcode(), const_input1, const_input2); if (merged_id == 0) return false; uint32_t non_const_id = other_first_is_variable ? 
other_inst->GetSingleWordInOperand(0u) : other_inst->GetSingleWordInOperand(1u); // If the variable value is on the second operand of the div, then this // operation is a div. Otherwise it should be a multiply. inst->SetOpcode(other_first_is_variable ? inst->opcode() : other_inst->opcode()); if (other_first_is_variable) { inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, {SPV_OPERAND_TYPE_ID, {merged_id}}}); } else { inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}}, {SPV_OPERAND_TYPE_ID, {non_const_id}}}); } return true; } return false; }; } // Merges multiply of constant and negation. // Cases: // (-x) * 2 = x * -2 // 2 * (-x) = x * -2 FoldingRule MergeMulNegateArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFNegate || other_inst->opcode() == SpvOpSNegate) { uint32_t neg_id = NegateConstant(const_mgr, const_input1); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, {SPV_OPERAND_TYPE_ID, {neg_id}}}); return true; } return false; }; } // Merges consecutive divides if each instruction contains one constant operand. // Does not support integer division. 
// Cases: // 2 / (x / 2) = 4 / x // 4 / (2 / x) = 2 * x // (4 / x) / 2 = 2 / x // (x / 2) / 2 = x / 4 FoldingRule MergeDivDivArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; bool first_is_variable = constants[0] == nullptr; if (other_inst->opcode() == inst->opcode()) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2 || HasZero(const_input2)) return false; bool other_first_is_variable = other_constants[0] == nullptr; SpvOp merge_op = inst->opcode(); if (other_first_is_variable) { // Constants magnify. merge_op = SpvOpFMul; } // This is an x / (*) case. Swap the inputs. Doesn't harm multiply // because it is commutative. if (first_is_variable) std::swap(const_input1, const_input2); uint32_t merged_id = PerformOperation(const_mgr, merge_op, const_input1, const_input2); if (merged_id == 0) return false; uint32_t non_const_id = other_first_is_variable ? other_inst->GetSingleWordInOperand(0u) : other_inst->GetSingleWordInOperand(1u); SpvOp op = inst->opcode(); if (!first_is_variable && !other_first_is_variable) { // Effectively div of 1/x, so change to multiply. 
op = SpvOpFMul; } uint32_t op1 = merged_id; uint32_t op2 = non_const_id; if (first_is_variable && other_first_is_variable) std::swap(op1, op2); inst->SetOpcode(op); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } // Fold multiplies succeeded by divides where each instruction contains a // constant operand. Does not support integer divide. // Cases: // 4 / (x * 2) = 2 / x // 4 / (2 * x) = 2 / x // (x * 4) / 2 = x * 2 // (4 * x) / 2 = x * 2 // (x * y) / x = y // (y * x) / x = y FoldingRule MergeDivMulArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv); ir::IRContext* context = inst->context(); analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); if (!inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; uint32_t op_id = inst->GetSingleWordInOperand(0); ir::Instruction* op_inst = def_use_mgr->GetDef(op_id); if (op_inst->opcode() == SpvOpFMul) { for (uint32_t i = 0; i < 2; i++) { if (op_inst->GetSingleWordInOperand(i) == inst->GetSingleWordInOperand(1)) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(1 - i)}}}); return true; } } } const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1 || HasZero(const_input1)) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (!other_inst->IsFloatingPointFoldingAllowed()) return false; bool first_is_variable = constants[0] == nullptr; if (other_inst->opcode() == SpvOpFMul) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) 
return false; bool other_first_is_variable = other_constants[0] == nullptr; // This is an x / (*) case. Swap the inputs. if (first_is_variable) std::swap(const_input1, const_input2); uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); if (merged_id == 0) return false; uint32_t non_const_id = other_first_is_variable ? other_inst->GetSingleWordInOperand(0u) : other_inst->GetSingleWordInOperand(1u); uint32_t op1 = merged_id; uint32_t op2 = non_const_id; if (first_is_variable) std::swap(op1, op2); // Convert to multiply if (first_is_variable) inst->SetOpcode(other_inst->opcode()); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } // Fold divides of a constant and a negation. // Cases: // (-x) / 2 = x / -2 // 2 / (-x) = 2 / -x FoldingRule MergeDivNegateArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv || inst->opcode() == SpvOpUDiv); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; bool first_is_variable = constants[0] == nullptr; if (other_inst->opcode() == SpvOpFNegate || other_inst->opcode() == SpvOpSNegate) { uint32_t neg_id = NegateConstant(const_mgr, const_input1); if (first_is_variable) { inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}, 
{SPV_OPERAND_TYPE_ID, {neg_id}}}); } else { inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {neg_id}}, {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); } return true; } return false; }; } // Folds addition of a constant and a negation. // Cases: // (-x) + 2 = 2 - x // 2 + (-x) = 2 - x FoldingRule MergeAddNegateArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpSNegate || other_inst->opcode() == SpvOpFNegate) { inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub); uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u) : inst->GetSingleWordInOperand(1u); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {const_id}}, {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}}); return true; } return false; }; } // Folds subtraction of a constant and a negation. 
// Cases: // (-x) - 2 = -2 - x // 2 - (-x) = x + 2 FoldingRule MergeSubNegateArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpSNegate || other_inst->opcode() == SpvOpFNegate) { uint32_t op1 = 0; uint32_t op2 = 0; SpvOp opcode = inst->opcode(); if (constants[0] != nullptr) { op1 = other_inst->GetSingleWordInOperand(0u); op2 = inst->GetSingleWordInOperand(0u); opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd; } else { op1 = NegateConstant(const_mgr, const_input1); op2 = other_inst->GetSingleWordInOperand(0u); } inst->SetOpcode(opcode); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } // Folds addition of an addition where each operation has a constant operand. 
// Cases: // (x + 2) + 2 = x + 4 // (2 + x) + 2 = x + 4 // 2 + (x + 2) = x + 4 // 2 + (2 + x) = x + 4 FoldingRule MergeAddAddArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFAdd || other_inst->opcode() == SpvOpIAdd) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; ir::Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); if (merged_id == 0) return false; inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}}, {SPV_OPERAND_TYPE_ID, {merged_id}}}); return true; } return false; }; } // Folds addition of a subtraction where each operation has a constant operand. 
// Cases: // (x - 2) + 2 = x + 0 // (2 - x) + 2 = 4 - x // 2 + (x - 2) = x + 0 // 2 + (2 - x) = 4 - x FoldingRule MergeAddSubArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd); ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFSub || other_inst->opcode() == SpvOpISub) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; bool first_is_variable = other_constants[0] == nullptr; SpvOp op = inst->opcode(); uint32_t op1 = 0; uint32_t op2 = 0; if (first_is_variable) { // Subtract constants. Non-constant operand is first. op1 = other_inst->GetSingleWordInOperand(0u); op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1, const_input2); } else { // Add constants. Constant operand is first. Change the opcode. op1 = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); op2 = other_inst->GetSingleWordInOperand(1u); op = other_inst->opcode(); } if (op1 == 0 || op2 == 0) return false; inst->SetOpcode(op); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } // Folds subtraction of an addition where each operand has a constant operand. 
// Cases: // (x + 2) - 2 = x + 0 // (2 + x) - 2 = x + 0 // 2 - (x + 2) = 0 - x // 2 - (2 + x) = 0 - x FoldingRule MergeSubAddArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFAdd || other_inst->opcode() == SpvOpIAdd) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; ir::Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); // If the first operand of the sub is not a constant, swap the constants // so the subtraction has the correct operands. if (constants[0] == nullptr) std::swap(const_input1, const_input2); // Subtract the constants. uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(), const_input1, const_input2); SpvOp op = inst->opcode(); uint32_t op1 = 0; uint32_t op2 = 0; if (constants[0] == nullptr) { // Non-constant operand is first. Change the opcode. op1 = non_const_input->result_id(); op2 = merged_id; op = other_inst->opcode(); } else { // Constant operand is first. 
op1 = merged_id; op2 = non_const_input->result_id(); } if (op1 == 0 || op2 == 0) return false; inst->SetOpcode(op); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } // Folds subtraction of a subtraction where each operand has a constant operand. // Cases: // (x - 2) - 2 = x - 4 // (2 - x) - 2 = 0 - x // 2 - (x - 2) = 4 - x // 2 - (2 - x) = x + 0 FoldingRule MergeSubSubArithmetic() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub); ir::IRContext* context = inst->context(); const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); bool uses_float = HasFloatingPoint(type); if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false; uint32_t width = ElementWidth(type); if (width != 32 && width != 64) return false; const analysis::Constant* const_input1 = ConstInput(constants); if (!const_input1) return false; ir::Instruction* other_inst = NonConstInput(context, constants[0], inst); if (uses_float && !other_inst->IsFloatingPointFoldingAllowed()) return false; if (other_inst->opcode() == SpvOpFSub || other_inst->opcode() == SpvOpISub) { std::vector other_constants = const_mgr->GetOperandConstants(other_inst); const analysis::Constant* const_input2 = ConstInput(other_constants); if (!const_input2) return false; ir::Instruction* non_const_input = NonConstInput(context, other_constants[0], other_inst); // Merge the constants. uint32_t merged_id = 0; SpvOp merge_op = inst->opcode(); if (other_constants[0] == nullptr) { merge_op = uses_float ? 
SpvOpFAdd : SpvOpIAdd; } else if (constants[0] == nullptr) { std::swap(const_input1, const_input2); } merged_id = PerformOperation(const_mgr, merge_op, const_input1, const_input2); if (merged_id == 0) return false; SpvOp op = inst->opcode(); if (constants[0] != nullptr && other_constants[0] != nullptr) { // Change the operation. op = uses_float ? SpvOpFAdd : SpvOpIAdd; } uint32_t op1 = 0; uint32_t op2 = 0; if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) { op1 = merged_id; op2 = non_const_input->result_id(); } else { op1 = non_const_input->result_id(); op2 = merged_id; } inst->SetOpcode(op); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}); return true; } return false; }; } FoldingRule IntMultipleBy1() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul."); for (uint32_t i = 0; i < 2; i++) { if (constants[i] == nullptr) { continue; } const analysis::IntConstant* int_constant = constants[i]->AsIntConstant(); if (int_constant) { uint32_t width = ElementWidth(int_constant->type()); if (width != 32 && width != 64) return false; bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u : int_constant->GetU64BitValue() == 1ull; if (is_one) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}}); return true; } } } return false; }; } FoldingRule CompositeConstructFeedingExtract() { return [](ir::Instruction* inst, const std::vector&) { // If the input to an OpCompositeExtract is an OpCompositeConstruct, // then we can simply use the appropriate element in the construction. assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. 
Should be OpCompositeExtract."); analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); ir::Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpCompositeConstruct) { return false; } std::vector operands; analysis::Type* composite_type = type_mgr->GetType(cinst->type_id()); if (composite_type->AsVector() == nullptr) { // Get the element being extracted from the OpCompositeConstruct // Since it is not a vector, it is simple to extract the single element. uint32_t element_index = inst->GetSingleWordInOperand(1); uint32_t element_id = cinst->GetSingleWordInOperand(element_index); operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); // Add the remaining indices for extraction. for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); } } else { // With vectors we have to handle the case where it is concatenating // vectors. assert(inst->NumInOperands() == 2 && "Expecting a vector of scalar values."); uint32_t element_index = inst->GetSingleWordInOperand(1); for (uint32_t construct_index = 0; construct_index < cinst->NumInOperands(); ++construct_index) { uint32_t element_id = cinst->GetSingleWordInOperand(construct_index); ir::Instruction* element_def = def_use_mgr->GetDef(element_id); analysis::Vector* element_type = type_mgr->GetType(element_def->type_id())->AsVector(); if (element_type) { uint32_t vector_size = element_type->element_count(); if (vector_size < element_index) { // The element we want comes after this vector. element_index -= vector_size; } else { // We want an element of this vector. 
operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); operands.push_back( {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}}); break; } } else { if (element_index == 0) { // This is a scalar, and we this is the element we are extracting. operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}}); break; } else { // Skip over this scalar value. --element_index; } } } } // If there were no extra indices, then we have the final object. No need // to extract even more. if (operands.size() == 1) { inst->SetOpcode(SpvOpCopyObject); } inst->SetInOperands(std::move(operands)); return true; }; } FoldingRule CompositeExtractFeedingConstruct() { // If the OpCompositeConstruct is simply putting back together elements that // where extracted from the same souce, we can simlpy reuse the source. // // This is a common code pattern because of the way that scalar replacement // works. return [](ir::Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeConstruct && "Wrong opcode. Should be OpCompositeConstruct."); analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); uint32_t original_id = 0; // Check each element to make sure they are: // - extractions // - extracting the same position they are inserting // - all extract from the same id. for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { uint32_t element_id = inst->GetSingleWordInOperand(i); ir::Instruction* element_inst = def_use_mgr->GetDef(element_id); if (element_inst->opcode() != SpvOpCompositeExtract) { return false; } if (element_inst->NumInOperands() != 2) { return false; } if (element_inst->GetSingleWordInOperand(1) != i) { return false; } if (i == 0) { original_id = element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); } else if (original_id != element_inst->GetSingleWordInOperand( kExtractCompositeIdInIdx)) { return false; } } // The last check it to see that the object being extracted from is the // correct type. 
ir::Instruction* original_inst = def_use_mgr->GetDef(original_id); if (original_inst->type_id() != inst->type_id()) { return false; } // Simplify by using the original object. inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); return true; }; } FoldingRule InsertFeedingExtract() { return [](ir::Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); ir::Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpCompositeInsert) { return false; } // Find the first position where the list of insert and extract indicies // differ, if at all. uint32_t i; for (i = 1; i < inst->NumInOperands(); ++i) { if (i + 1 >= cinst->NumInOperands()) { break; } if (inst->GetSingleWordInOperand(i) != cinst->GetSingleWordInOperand(i + 1)) { break; } } // We are extracting the element that was inserted. if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}}); return true; } // Extracting the value that was inserted along with values for the base // composite. Cannot do anything. if (i == inst->NumInOperands()) { return false; } // Extracting an element of the value that was inserted. Extract from // that value directly. if (i + 1 == cinst->NumInOperands()) { std::vector operands; operands.push_back( {SPV_OPERAND_TYPE_ID, {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}); for (; i < inst->NumInOperands(); ++i) { operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); } inst->SetInOperands(std::move(operands)); return true; } // Extracting a value that is disjoint from the element being inserted. 
// Rewrite the extract to use the composite input to the insert. std::vector operands; operands.push_back( {SPV_OPERAND_TYPE_ID, {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}}); for (i = 1; i < inst->NumInOperands(); ++i) { operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {inst->GetSingleWordInOperand(i)}}); } inst->SetInOperands(std::move(operands)); return true; }; } // When a VectorShuffle is feeding an Extract, we can extract from one of the // operands of the VectorShuffle. We just need to adjust the index in the // extract instruction. FoldingRule VectorShuffleFeedingExtract() { return [](ir::Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpCompositeExtract && "Wrong opcode. Should be OpCompositeExtract."); analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx); ir::Instruction* cinst = def_use_mgr->GetDef(cid); if (cinst->opcode() != SpvOpVectorShuffle) { return false; } // Find the size of the first vector operand of the VectorShuffle ir::Instruction* first_input = def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0)); analysis::Type* first_input_type = type_mgr->GetType(first_input->type_id()); assert(first_input_type->AsVector() && "Input to vector shuffle should be vectors."); uint32_t first_input_size = first_input_type->AsVector()->element_count(); // Get index of the element the vector shuffle is placing in the position // being extracted. uint32_t new_index = cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1)); // Get the id of the of the vector the elemtent comes from, and update the // index if needed. uint32_t new_vector = 0; if (new_index < first_input_size) { new_vector = cinst->GetSingleWordInOperand(0); } else { new_vector = cinst->GetSingleWordInOperand(1); new_index -= first_input_size; } // Update the extract instruction. 
inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector}); inst->SetInOperand(1, {new_index}); return true; }; } FoldingRule RedundantPhi() { // An OpPhi instruction where all values are the same or the result of the phi // itself, can be replaced by the value itself. return [](ir::Instruction* inst, const std::vector&) { assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi."); uint32_t incoming_value = 0; for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) { uint32_t op_id = inst->GetSingleWordInOperand(i); if (op_id == inst->result_id()) { continue; } if (incoming_value == 0) { incoming_value = op_id; } else if (op_id != incoming_value) { // Found two possible value. Can't simplify. return false; } } if (incoming_value == 0) { // Code looks invalid. Don't do anything. return false; } // We have a single incoming value. Simplify using that value. inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}}); return true; }; } FoldingRule RedundantSelect() { // An OpSelect instruction where both values are the same or the condition is // constant can be replaced by one of the values return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpSelect && "Wrong opcode. Should be OpSelect."); assert(inst->NumInOperands() == 3); assert(constants.size() == 3); uint32_t true_id = inst->GetSingleWordInOperand(1); uint32_t false_id = inst->GetSingleWordInOperand(2); if (true_id == false_id) { // Both results are the same, condition doesn't matter inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); return true; } else if (constants[0]) { const analysis::Type* type = constants[0]->type(); if (type->AsBool()) { // Scalar constant value, select the corresponding value. 
inst->SetOpcode(SpvOpCopyObject); if (constants[0]->AsNullConstant() || !constants[0]->AsBoolConstant()->value()) { inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); } else { inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}}); } return true; } else { assert(type->AsVector()); if (constants[0]->AsNullConstant()) { // All values come from false id. inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}}); return true; } else { // Convert to a vector shuffle. std::vector ops; ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}}); ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}}); const analysis::VectorConstant* vector_const = constants[0]->AsVectorConstant(); uint32_t size = static_cast(vector_const->GetComponents().size()); for (uint32_t i = 0; i != size; ++i) { const analysis::Constant* component = vector_const->GetComponents()[i]; if (component->AsNullConstant() || !component->AsBoolConstant()->value()) { // Selecting from the false vector which is the second input // vector to the shuffle. Offset the index by |size|. ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}}); } else { // Selecting from true vector which is the first input vector to // the shuffle. 
ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}); } } inst->SetOpcode(SpvOpVectorShuffle); inst->SetInOperands(std::move(ops)); return true; } } } return false; }; } enum class FloatConstantKind { Unknown, Zero, One }; FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) { if (constant == nullptr) { return FloatConstantKind::Unknown; } assert(HasFloatingPoint(constant->type()) && "Unexpected constant type"); if (constant->AsNullConstant()) { return FloatConstantKind::Zero; } else if (const analysis::VectorConstant* vc = constant->AsVectorConstant()) { const std::vector& components = vc->GetComponents(); assert(!components.empty()); FloatConstantKind kind = getFloatConstantKind(components[0]); for (size_t i = 1; i < components.size(); ++i) { if (getFloatConstantKind(components[i]) != kind) { return FloatConstantKind::Unknown; } } return kind; } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) { if (fc->IsZero()) return FloatConstantKind::Zero; uint32_t width = fc->type()->AsFloat()->width(); if (width != 32 && width != 64) return FloatConstantKind::Unknown; double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue(); if (value == 0.0) { return FloatConstantKind::Zero; } else if (value == 1.0) { return FloatConstantKind::One; } else { return FloatConstantKind::Unknown; } } else { return FloatConstantKind::Unknown; } } FoldingRule RedundantFAdd() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. 
Should be OpFAdd."); assert(constants.size() == 2); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } FloatConstantKind kind0 = getFloatConstantKind(constants[0]); FloatConstantKind kind1 = getFloatConstantKind(constants[1]); if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand( kind0 == FloatConstantKind::Zero ? 1 : 0)}}}); return true; } return false; }; } FoldingRule RedundantFSub() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub."); assert(constants.size() == 2); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } FloatConstantKind kind0 = getFloatConstantKind(constants[0]); FloatConstantKind kind1 = getFloatConstantKind(constants[1]); if (kind0 == FloatConstantKind::Zero) { inst->SetOpcode(SpvOpFNegate); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}}); return true; } if (kind1 == FloatConstantKind::Zero) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); return true; } return false; }; } FoldingRule RedundantFMul() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul."); assert(constants.size() == 2); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } FloatConstantKind kind0 = getFloatConstantKind(constants[0]); FloatConstantKind kind1 = getFloatConstantKind(constants[1]); if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand( kind0 == FloatConstantKind::Zero ? 
0 : 1)}}}); return true; } if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand( kind0 == FloatConstantKind::One ? 1 : 0)}}}); return true; } return false; }; } FoldingRule RedundantFDiv() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv."); assert(constants.size() == 2); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } FloatConstantKind kind0 = getFloatConstantKind(constants[0]); FloatConstantKind kind1 = getFloatConstantKind(constants[1]); if (kind0 == FloatConstantKind::Zero) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); return true; } if (kind1 == FloatConstantKind::One) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}}); return true; } return false; }; } FoldingRule RedundantFMix() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpExtInst && "Wrong opcode. Should be OpExtInst."); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } uint32_t instSetId = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId && inst->GetSingleWordInOperand(kExtInstInstructionInIdx) == GLSLstd450FMix) { assert(constants.size() == 5); FloatConstantKind kind4 = getFloatConstantKind(constants[4]); if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) { inst->SetOpcode(SpvOpCopyObject); inst->SetInOperands( {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero ? kFMixXIdInIdx : kFMixYIdInIdx)}}}); return true; } } return false; }; } // This rule look for a dot with a constant vector containing a single 1 and // the rest 0s. 
This is the same as doing an extract. FoldingRule DotProductDoingExtract() { return [](ir::Instruction* inst, const std::vector& constants) { assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot."); ir::IRContext* context = inst->context(); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); if (!inst->IsFloatingPointFoldingAllowed()) { return false; } for (int i = 0; i < 2; ++i) { if (!constants[i]) { continue; } const analysis::Vector* vector_type = constants[i]->type()->AsVector(); assert(vector_type && "Inputs to OpDot must be vectors."); const analysis::Float* element_type = vector_type->element_type()->AsFloat(); assert(element_type && "Inputs to OpDot must be vectors of floats."); uint32_t element_width = element_type->width(); if (element_width != 32 && element_width != 64) { return false; } std::vector components; components = constants[i]->GetVectorComponents(const_mgr); const uint32_t kNotFound = std::numeric_limits::max(); uint32_t component_with_one = kNotFound; bool all_others_zero = true; for (uint32_t j = 0; j < components.size(); ++j) { const analysis::Constant* element = components[j]; double value = (element_width == 32 ? element->GetFloat() : element->GetDouble()); if (value == 0.0) { continue; } else if (value == 1.0) { if (component_with_one == kNotFound) { component_with_one = j; } else { component_with_one = kNotFound; break; } } else { all_others_zero = false; break; } } if (!all_others_zero || component_with_one == kNotFound) { continue; } std::vector operands; operands.push_back( {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}}); operands.push_back( {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}}); inst->SetOpcode(SpvOpCompositeExtract); inst->SetInOperands(std::move(operands)); return true; } return false; }; } } // namespace spvtools::opt::FoldingRules::FoldingRules() { // Add all folding rules to the list for the opcodes to which they apply. 
// Note that the order in which rules are added to the list matters. If a rule // applies to the instruction, the rest of the rules will not be attempted. // Take that into consideration. rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct()); rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract()); rules_[SpvOpDot].push_back(DotProductDoingExtract()); rules_[SpvOpExtInst].push_back(RedundantFMix()); rules_[SpvOpFAdd].push_back(RedundantFAdd()); rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic()); rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic()); rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic()); rules_[SpvOpFDiv].push_back(RedundantFDiv()); rules_[SpvOpFDiv].push_back(ReciprocalFDiv()); rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic()); rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic()); rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic()); rules_[SpvOpFMul].push_back(RedundantFMul()); rules_[SpvOpFMul].push_back(MergeMulMulArithmetic()); rules_[SpvOpFMul].push_back(MergeMulDivArithmetic()); rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic()); rules_[SpvOpFNegate].push_back(MergeNegateArithmetic()); rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic()); rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic()); rules_[SpvOpFSub].push_back(RedundantFSub()); rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic()); rules_[SpvOpFSub].push_back(MergeSubAddArithmetic()); rules_[SpvOpFSub].push_back(MergeSubSubArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic()); rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic()); rules_[SpvOpIMul].push_back(IntMultipleBy1()); rules_[SpvOpIMul].push_back(MergeMulMulArithmetic()); rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic()); 
rules_[SpvOpISub].push_back(MergeSubNegateArithmetic()); rules_[SpvOpISub].push_back(MergeSubAddArithmetic()); rules_[SpvOpISub].push_back(MergeSubSubArithmetic()); rules_[SpvOpPhi].push_back(RedundantPhi()); rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic()); rules_[SpvOpSNegate].push_back(MergeNegateArithmetic()); rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic()); rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic()); rules_[SpvOpSelect].push_back(RedundantSelect()); rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic()); } } // namespace opt } // namespace spvtools