InsertExtractElim: Optimize through VectorShuffle, Mix

This improves Extract replacement to continue through VectorShuffle.
It also handles GLSL.std.450 FMix when the a-value of the desired
component is the constant 0.0 or 1.0, continuing the search through
x or y respectively.

To facilitate optimization of VectorShuffle, the algorithm was refactored
to pass around the indices of the extract in a vector rather than pass the
extract instruction itself. This allows the indices to be modified as the
algorithm progresses.
This commit is contained in:
Greg Fischer 2017-12-29 16:44:43 -07:00 committed by Steven Perron
parent 1ebd860daa
commit 5eafc00ad5
3 changed files with 426 additions and 70 deletions

View File

@ -18,6 +18,9 @@
#include "ir_context.h"
#include "iterator.h"
#include "spirv/1.2/GLSL.std.450.h"
#include <vector>
namespace spvtools {
namespace opt {
@ -27,33 +30,42 @@ namespace {
// In-operand indices, as passed to Instruction::GetSingleWordInOperand, for
// the instructions this pass inspects.

// OpCompositeExtract / OpCompositeInsert id operands. The literal indices of
// an insert follow its two id operands.
const uint32_t kExtractCompositeIdInIdx = 0;
const uint32_t kInsertObjectIdInIdx = 0;
const uint32_t kInsertCompositeIdInIdx = 1;
// OpConstant: first word of the literal value.
const uint32_t kConstantValueInIdx = 0;
// OpVectorShuffle: the two source vectors, then the first component selector.
const uint32_t kVectorShuffleVec1IdInIdx = 0;
const uint32_t kVectorShuffleVec2IdInIdx = 1;
const uint32_t kVectorShuffleCompsInIdx = 2;
// OpTypeVector: component type id and component count.
const uint32_t kTypeVectorCompTypeIdInIdx = 0;
const uint32_t kTypeVectorLengthInIdx = 1;
// OpTypeFloat: bit width.
const uint32_t kTypeFloatWidthInIdx = 0;
// OpExtInst: extended instruction set id and instruction number.
const uint32_t kExtInstSetIdInIdx = 0;
const uint32_t kExtInstInstructionInIdx = 1;
// GLSL.std.450 FMix arguments, which follow the two OpExtInst operands above.
const uint32_t kFMixXIdInIdx = 2;
const uint32_t kFMixYIdInIdx = 3;
const uint32_t kFMixAIdInIdx = 4;
} // anonymous namespace
// Return true if the extract indices in |extIndices| starting at |extOffset|
// match the indices of insert |insInst| exactly (same count, same values).
// NOTE(review): this span appears to be a rendered diff whose +/- markers were
// stripped, so the pre-change lines (taking |extInst|) and the post-change
// lines (taking |extIndices|) are interleaved below.
bool InsertExtractElimPass::ExtInsMatch(const ir::Instruction* extInst,
bool InsertExtractElimPass::ExtInsMatch(const std::vector<uint32_t>& extIndices,
const ir::Instruction* insInst,
const uint32_t extOffset) const {
if (extInst->NumInOperands() - extOffset != insInst->NumInOperands() - 1)
return false;
uint32_t numIdx = extInst->NumInOperands() - 1 - extOffset;
for (uint32_t i = 0; i < numIdx; ++i)
if (extInst->GetSingleWordInOperand(i + 1 + extOffset) !=
insInst->GetSingleWordInOperand(i + 2))
// Number of extract indices that remain to be matched.
uint32_t numIndices = static_cast<uint32_t>(extIndices.size()) - extOffset;
// The insert's indices start at in-operand 2, after its object and composite
// ids, so it has NumInOperands() - 2 of them.
if (numIndices != insInst->NumInOperands() - 2) return false;
for (uint32_t i = 0; i < numIndices; ++i)
if (extIndices[i + extOffset] != insInst->GetSingleWordInOperand(i + 2))
return false;
return true;
}
// Return true if the indices in |extIndices| starting at |extOffset| and the
// indices of insert |insInst| conflict: the two index lists have different
// lengths but agree on every index of the shorter one.
// NOTE(review): this span appears to be a rendered diff whose +/- markers were
// stripped, so the pre-change lines (taking |extInst|) and the post-change
// lines (taking |extIndices|) are interleaved below.
bool InsertExtractElimPass::ExtInsConflict(const ir::Instruction* extInst,
const ir::Instruction* insInst,
const uint32_t extOffset) const {
if (extInst->NumInOperands() - extOffset == insInst->NumInOperands() - 1)
bool InsertExtractElimPass::ExtInsConflict(
const std::vector<uint32_t>& extIndices, const ir::Instruction* insInst,
const uint32_t extOffset) const {
// Equal index counts are never reported as a conflict here.
if (extIndices.size() - extOffset == insInst->NumInOperands() - 2)
return false;
uint32_t extNumIdx = extInst->NumInOperands() - 1 - extOffset;
uint32_t insNumIdx = insInst->NumInOperands() - 2;
uint32_t numIdx = std::min(extNumIdx, insNumIdx);
for (uint32_t i = 0; i < numIdx; ++i)
if (extInst->GetSingleWordInOperand(i + 1 + extOffset) !=
insInst->GetSingleWordInOperand(i + 2))
uint32_t extNumIndices = static_cast<uint32_t>(extIndices.size()) - extOffset;
uint32_t insNumIndices = insInst->NumInOperands() - 2;
// Only the common prefix of the two index lists is compared.
uint32_t numIndices = std::min(extNumIndices, insNumIndices);
for (uint32_t i = 0; i < numIndices; ++i)
if (extIndices[i + extOffset] != insInst->GetSingleWordInOperand(i + 2))
return false;
return true;
}
@ -63,6 +75,120 @@ bool InsertExtractElimPass::IsVectorType(uint32_t typeId) {
return typeInst->opcode() == SpvOpTypeVector;
}
// Returns the id of the component of |compInst| selected by the indices in
// |*pExtIndices| starting at |extOffset|, or 0 if no existing id can be shown
// to hold exactly that component.
//
// The search walks backwards from |compInst|:
//  - OpCompositeInsert: a matching insert yields its inserted object; a
//    conflicting insert with fewer indices than the extract is searched
//    through its inserted object; a conflicting insert with more indices ends
//    the search; any other insert is skipped in favor of the composite it
//    modifies.
//  - OpVectorShuffle: the index at |extOffset| is remapped onto the selected
//    source vector and the search continues there.
//  - GLSL.std.450 FMix: if the a-value of the desired component is the 32-bit
//    constant 0.0 or 1.0, the search continues through x or y respectively.
// If the search ends at an OpCompositeConstruct or OpConstantComposite with
// exactly one index remaining, the corresponding constituent is returned.
//
// Note: |*pExtIndices| is modified in place when the search passes through an
// OpVectorShuffle.
uint32_t InsertExtractElimPass::DoExtract(ir::Instruction* compInst,
                                          std::vector<uint32_t>* pExtIndices,
                                          uint32_t extOffset) {
  // Bit patterns of the 32-bit IEEE-754 values an FMix a-component must have
  // for the mix to reduce to one of its inputs. Comparing the constant's
  // literal word directly avoids reading a uint32_t through a float pointer.
  const uint32_t kFloat32PosZeroBits = 0x00000000u;
  const uint32_t kFloat32NegZeroBits = 0x80000000u;
  const uint32_t kFloat32OneBits = 0x3F800000u;

  ir::Instruction* cinst = compInst;
  uint32_t cid = 0;
  uint32_t replId = 0;
  while (true) {
    if (cinst->opcode() == SpvOpCompositeInsert) {
      if (ExtInsMatch(*pExtIndices, cinst, extOffset)) {
        // Match! Use inserted value as replacement
        replId = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
        break;
      } else if (ExtInsConflict(*pExtIndices, cinst, extOffset)) {
        // If extract has fewer indices than the insert, stop searching.
        // Otherwise increment offset of extract indices considered and
        // continue searching through the inserted value
        if (pExtIndices->size() - extOffset < cinst->NumInOperands() - 2) {
          break;
        } else {
          extOffset += cinst->NumInOperands() - 2;
          cid = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
        }
      } else {
        // Consider next composite in insert chain
        cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
      }
    } else if (cinst->opcode() == SpvOpVectorShuffle) {
      // Get length of vector1
      uint32_t v1_id = cinst->GetSingleWordInOperand(kVectorShuffleVec1IdInIdx);
      ir::Instruction* v1_inst = get_def_use_mgr()->GetDef(v1_id);
      uint32_t v1_type_id = v1_inst->type_id();
      ir::Instruction* v1_type_inst = get_def_use_mgr()->GetDef(v1_type_id);
      uint32_t v1_len =
          v1_type_inst->GetSingleWordInOperand(kTypeVectorLengthInIdx);
      // Get shuffle idx
      uint32_t comp_idx = (*pExtIndices)[extOffset];
      uint32_t shuffle_idx =
          cinst->GetSingleWordInOperand(kVectorShuffleCompsInIdx + comp_idx);
      // If undefined, give up
      // TODO(greg-lunarg): Return OpUndef
      if (shuffle_idx == 0xFFFFFFFF) break;
      // Selectors below the length of vector1 address vector1; the rest
      // address vector2, rebased to start at 0.
      if (shuffle_idx < v1_len) {
        cid = v1_id;
        (*pExtIndices)[extOffset] = shuffle_idx;
      } else {
        cid = cinst->GetSingleWordInOperand(kVectorShuffleVec2IdInIdx);
        (*pExtIndices)[extOffset] = shuffle_idx - v1_len;
      }
    } else if (cinst->opcode() == SpvOpExtInst &&
               cinst->GetSingleWordInOperand(kExtInstSetIdInIdx) ==
                   get_module()->GetExtInstImportId("GLSL.std.450") &&
               cinst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
                   GLSLstd450FMix) {
      // If mixing value component is 0 or 1 we just match with x or y.
      // Otherwise give up.
      uint32_t comp_idx = (*pExtIndices)[extOffset];
      std::vector<uint32_t> aIndices = {comp_idx};
      uint32_t a_id = cinst->GetSingleWordInOperand(kFMixAIdInIdx);
      ir::Instruction* a_inst = get_def_use_mgr()->GetDef(a_id);
      // Resolve the same component of the a operand with a nested search.
      uint32_t a_comp_id = DoExtract(a_inst, &aIndices, 0);
      if (a_comp_id == 0) break;
      ir::Instruction* a_comp_inst = get_def_use_mgr()->GetDef(a_comp_id);
      if (a_comp_inst->opcode() != SpvOpConstant) break;
      // If a value is not 32-bit, give up
      uint32_t a_comp_type_id = a_comp_inst->type_id();
      ir::Instruction* a_comp_type = get_def_use_mgr()->GetDef(a_comp_type_id);
      if (a_comp_type->GetSingleWordInOperand(kTypeFloatWidthInIdx) != 32)
        break;
      const uint32_t a_bits =
          a_comp_inst->GetSingleWordInOperand(kConstantValueInIdx);
      // Both +0.0 and -0.0 compare equal to 0.0 and select x.
      if (a_bits == kFloat32PosZeroBits || a_bits == kFloat32NegZeroBits)
        cid = cinst->GetSingleWordInOperand(kFMixXIdInIdx);
      else if (a_bits == kFloat32OneBits)
        cid = cinst->GetSingleWordInOperand(kFMixYIdInIdx);
      else
        break;
    } else {
      break;
    }
    cinst = get_def_use_mgr()->GetDef(cid);
  }
  // If search ended with CompositeConstruct or ConstantComposite
  // and the extract has one index, return the appropriate component.
  // TODO(greg-lunarg): Handle multiple-indices, ConstantNull, special
  // vector composition, and additional CompositeInsert.
  if (replId == 0 &&
      (cinst->opcode() == SpvOpCompositeConstruct ||
       cinst->opcode() == SpvOpConstantComposite) &&
      pExtIndices->size() - extOffset == 1) {
    uint32_t compIdx = (*pExtIndices)[extOffset];
    // Never read past the constituents that are actually present.
    if (compIdx < cinst->NumInOperands()) {
      uint32_t ctype_id = cinst->type_id();
      ir::Instruction* ctype_inst = get_def_use_mgr()->GetDef(ctype_id);
      bool operandIsComponent = true;
      // A vector OpCompositeConstruct may be built from smaller vectors, in
      // which case operand |compIdx| is not component |compIdx|. For any
      // vector result, require every constituent up to and including the
      // desired one to be of the component type (not vector composition).
      if (ctype_inst->opcode() == SpvOpTypeVector) {
        uint32_t vec_comp_type_id =
            ctype_inst->GetSingleWordInOperand(kTypeVectorCompTypeIdInIdx);
        for (uint32_t i = 0; i <= compIdx; ++i) {
          uint32_t compId = cinst->GetSingleWordInOperand(i);
          ir::Instruction* componentInst = get_def_use_mgr()->GetDef(compId);
          if (componentInst->type_id() != vec_comp_type_id) {
            operandIsComponent = false;
            break;
          }
        }
      }
      if (operandIsComponent) replId = cinst->GetSingleWordInOperand(compIdx);
    }
  }
  return replId;
}
bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) {
bool modified = false;
for (auto bi = func->begin(); bi != func->end(); ++bi) {
@ -72,57 +198,16 @@ bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) {
case SpvOpCompositeExtract: {
uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
ir::Instruction* cinst = get_def_use_mgr()->GetDef(cid);
uint32_t replId = 0;
// Capture extract indices
std::vector<uint32_t> extIndices;
uint32_t icnt = 0;
inst->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
if (icnt > 0) extIndices.push_back(*idp);
++icnt;
});
// Offset of extract indices being compared to insert indices.
// Offset increases as indices are matched.
uint32_t extOffset = 0;
while (cinst->opcode() == SpvOpCompositeInsert) {
if (ExtInsMatch(inst, cinst, extOffset)) {
// Match! Use inserted value as replacement
replId = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
break;
} else if (ExtInsConflict(inst, cinst, extOffset)) {
// If extract has fewer indices than the insert, stop searching.
// Otherwise increment offset of extract indices considered and
// continue searching through the inserted value
if (inst->NumInOperands() - extOffset <
cinst->NumInOperands() - 1) {
break;
} else {
extOffset += cinst->NumInOperands() - 2;
cid = cinst->GetSingleWordInOperand(kInsertObjectIdInIdx);
}
} else {
// Consider next composite in insert chain
cid = cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
}
cinst = get_def_use_mgr()->GetDef(cid);
}
// If search ended with CompositeConstruct or ConstantComposite
// and the extract has one index, return the appropriate component.
// If a vector CompositeConstruct we make sure all preceding
// components are of component type (not vector composition).
// TODO(greg-lunarg): Handle multiple-indices, ConstantNull, special
// vector composition, and additional CompositeInsert.
if ((cinst->opcode() == SpvOpCompositeConstruct ||
cinst->opcode() == SpvOpConstantComposite) &&
inst->NumInOperands() - extOffset == 2) {
uint32_t compIdx = inst->GetSingleWordInOperand(extOffset + 1);
if (IsVectorType(cinst->type_id())) {
if (compIdx < cinst->NumInOperands()) {
uint32_t i = 0;
for (; i <= compIdx; i++) {
uint32_t compId = cinst->GetSingleWordInOperand(i);
ir::Instruction* compInst = get_def_use_mgr()->GetDef(compId);
if (compInst->type_id() != inst->type_id()) break;
}
if (i > compIdx)
replId = cinst->GetSingleWordInOperand(compIdx);
}
} else {
replId = cinst->GetSingleWordInOperand(compIdx);
}
}
uint32_t replId = DoExtract(cinst, &extIndices, 0);
if (replId != 0) {
const uint32_t extId = inst->result_id();
(void)context()->ReplaceAllUsesWith(extId, replId);

View File

@ -40,24 +40,29 @@ class InsertExtractElimPass : public Pass {
Status Process(ir::IRContext*) override;
private:
// Return true if indices of extract |extInst| starting at |extOffset|
// Return true if the extract indices in |extIndices| starting at |extOffset|
// match indices of insert |insInst|.
bool ExtInsMatch(const ir::Instruction* extInst,
bool ExtInsMatch(const std::vector<uint32_t>& extIndices,
const ir::Instruction* insInst,
const uint32_t extOffset) const;
// Return true if indices of extract |extInst| starting at |extOffset| and
// Return true if indices in |extIndices| starting at |extOffset| and
// indices of insert |insInst| conflict, specifically, if the insert
// changes bits specified by the extract, but changes either more bits
// or less bits than the extract specifies, meaning the exact value being
// inserted cannot be used to replace the extract.
bool ExtInsConflict(const ir::Instruction* extInst,
bool ExtInsConflict(const std::vector<uint32_t>& extIndices,
const ir::Instruction* insInst,
const uint32_t extOffset) const;
// Return true if |typeId| is a vector type
bool IsVectorType(uint32_t typeId);
// Return id of component of |cinst| specified by |extIndices| starting with
// index at |extOffset|. Return 0 if indices cannot be matched exactly.
// |extIndices| is passed by pointer because the search may rewrite the index
// at |extOffset| in place; it does so when it passes through an
// OpVectorShuffle.
uint32_t DoExtract(ir::Instruction* cinst, std::vector<uint32_t>* extIndices,
uint32_t extOffset);
// Look for OpExtract on sequence of OpInserts in |func|. If there is a
// reaching insert which corresponds to the indices of the extract, replace
// the extract with the value that is inserted. Also resolve extracts from

View File

@ -539,6 +539,272 @@ OpFunctionEnd
true);
}
TEST_F(InsertExtractElimTest, MixWithConstants) {
// Extract component of FMix with 0.0 or 1.0 as the a-value.
//
// Note: The SPIR-V assembly has had store/load elimination
// performed to allow the inserts and extracts to directly
// reference each other.
//
// #version 450
//
// layout (location=0) in float bc;
// layout (location=1) in float bc2;
// layout (location=2) in float m;
// layout (location=3) in float m2;
// layout (location=0) out vec4 OutColor;
//
// void main()
// {
// vec4 bcv = vec4(bc, bc2, 0.0, 1.0);
// vec4 bcv2 = vec4(bc2, bc, 1.0, 0.0);
// vec4 v = mix(bcv, bcv2, vec4(0.0,1.0,m,m2));
// OutColor = vec4(v.y);
// }
// Declarations shared by the input and the expected output.
const std::string predefs =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %bc %bc2 %m %m2 %OutColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %bc "bc"
OpName %bc2 "bc2"
OpName %m "m"
OpName %m2 "m2"
OpName %OutColor "OutColor"
OpDecorate %bc Location 0
OpDecorate %bc2 Location 1
OpDecorate %m Location 2
OpDecorate %m2 Location 3
OpDecorate %OutColor Location 0
%void = OpTypeVoid
%9 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Function_v4float = OpTypePointer Function %v4float
%_ptr_Input_float = OpTypePointer Input %float
%bc = OpVariable %_ptr_Input_float Input
%bc2 = OpVariable %_ptr_Input_float Input
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%m = OpVariable %_ptr_Input_float Input
%m2 = OpVariable %_ptr_Input_float Input
%_ptr_Output_v4float = OpTypePointer Output %v4float
%OutColor = OpVariable %_ptr_Output_v4float Output
%uint = OpTypeInt 32 0
%_ptr_Function_float = OpTypePointer Function %float
)";
// Input function body: %30 extracts component 1 of the FMix result %29.
const std::string before =
R"(%main = OpFunction %void None %9
%19 = OpLabel
%20 = OpLoad %float %bc
%21 = OpLoad %float %bc2
%22 = OpCompositeConstruct %v4float %20 %21 %float_0 %float_1
%23 = OpLoad %float %bc2
%24 = OpLoad %float %bc
%25 = OpCompositeConstruct %v4float %23 %24 %float_1 %float_0
%26 = OpLoad %float %m
%27 = OpLoad %float %m2
%28 = OpCompositeConstruct %v4float %float_0 %float_1 %26 %27
%29 = OpExtInst %v4float %1 FMix %22 %25 %28
%30 = OpCompositeExtract %float %29 1
%31 = OpCompositeConstruct %v4float %30 %30 %30 %30
OpStore %OutColor %31
OpReturn
OpFunctionEnd
)";
// Expected output: component 1 of the a-value %28 is %float_1, so the extract
// is resolved through FMix's y operand %25, whose component 1 is %24. The
// extract %30 no longer appears and %31 is built from %24.
const std::string after =
R"(%main = OpFunction %void None %9
%19 = OpLabel
%20 = OpLoad %float %bc
%21 = OpLoad %float %bc2
%22 = OpCompositeConstruct %v4float %20 %21 %float_0 %float_1
%23 = OpLoad %float %bc2
%24 = OpLoad %float %bc
%25 = OpCompositeConstruct %v4float %23 %24 %float_1 %float_0
%26 = OpLoad %float %m
%27 = OpLoad %float %m2
%28 = OpCompositeConstruct %v4float %float_0 %float_1 %26 %27
%29 = OpExtInst %v4float %1 FMix %22 %25 %28
%31 = OpCompositeConstruct %v4float %24 %24 %24 %24
OpStore %OutColor %31
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<opt::InsertExtractElimPass>(
predefs + before, predefs + after, true, true);
}
TEST_F(InsertExtractElimTest, VectorShuffle1) {
// Extract component from first vector in VectorShuffle
//
// Note: The SPIR-V assembly has had store/load elimination
// performed to allow the inserts and extracts to directly
// reference each other.
//
// #version 450
//
// layout (location=0) in float bc;
// layout (location=1) in float bc2;
// layout (location=0) out vec4 OutColor;
//
// void main()
// {
// vec4 bcv = vec4(bc, bc2, 0.0, 1.0);
// vec4 v = bcv.zwxy;
// OutColor = vec4(v.y);
// }
// Declarations shared by the input and the expected output.
const std::string predefs =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %bc "bc"
OpName %bc2 "bc2"
OpName %OutColor "OutColor"
OpDecorate %bc Location 0
OpDecorate %bc2 Location 1
OpDecorate %OutColor Location 0
%void = OpTypeVoid
%7 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Function_v4float = OpTypePointer Function %v4float
%_ptr_Input_float = OpTypePointer Input %float
%bc = OpVariable %_ptr_Input_float Input
%bc2 = OpVariable %_ptr_Input_float Input
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%_ptr_Output_v4float = OpTypePointer Output %v4float
%OutColor = OpVariable %_ptr_Output_v4float Output
%uint = OpTypeInt 32 0
%_ptr_Function_float = OpTypePointer Function %float
)";
// Input function body: %22 extracts component 1 of the shuffle %21.
const std::string before =
R"(%main = OpFunction %void None %7
%17 = OpLabel
%18 = OpLoad %float %bc
%19 = OpLoad %float %bc2
%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1
%21 = OpVectorShuffle %v4float %20 %20 2 3 0 1
%22 = OpCompositeExtract %float %21 1
%23 = OpCompositeConstruct %v4float %22 %22 %22 %22
OpStore %OutColor %23
OpReturn
OpFunctionEnd
)";
// Expected output: shuffle component 1 has selector 3, which addresses
// component 3 of the first vector %20, i.e. %float_1. The extract %22 no
// longer appears and %23 is built from %float_1.
const std::string after =
R"(%main = OpFunction %void None %7
%17 = OpLabel
%18 = OpLoad %float %bc
%19 = OpLoad %float %bc2
%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1
%21 = OpVectorShuffle %v4float %20 %20 2 3 0 1
%23 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1
OpStore %OutColor %23
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<opt::InsertExtractElimPass>(
predefs + before, predefs + after, true, true);
}
TEST_F(InsertExtractElimTest, VectorShuffle2) {
// Extract component from second vector in VectorShuffle
// Identical to test VectorShuffle1 except for the vector
// shuffle index of 7.
//
// Note: The SPIR-V assembly has had store/load elimination
// performed to allow the inserts and extracts to directly
// reference each other.
//
// The GLSL below is the same source shown for VectorShuffle1; the selector 7
// appears only in the SPIR-V assembly of this test.
//
// #version 450
//
// layout (location=0) in float bc;
// layout (location=1) in float bc2;
// layout (location=0) out vec4 OutColor;
//
// void main()
// {
// vec4 bcv = vec4(bc, bc2, 0.0, 1.0);
// vec4 v = bcv.zwxy;
// OutColor = vec4(v.y);
// }
// Declarations shared by the input and the expected output.
const std::string predefs =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main" %bc %bc2 %OutColor
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %bc "bc"
OpName %bc2 "bc2"
OpName %OutColor "OutColor"
OpDecorate %bc Location 0
OpDecorate %bc2 Location 1
OpDecorate %OutColor Location 0
%void = OpTypeVoid
%7 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_ptr_Function_v4float = OpTypePointer Function %v4float
%_ptr_Input_float = OpTypePointer Input %float
%bc = OpVariable %_ptr_Input_float Input
%bc2 = OpVariable %_ptr_Input_float Input
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%_ptr_Output_v4float = OpTypePointer Output %v4float
%OutColor = OpVariable %_ptr_Output_v4float Output
%uint = OpTypeInt 32 0
%_ptr_Function_float = OpTypePointer Function %float
)";
// Input function body: %22 extracts component 1 of the shuffle %21.
const std::string before =
R"(%main = OpFunction %void None %7
%17 = OpLabel
%18 = OpLoad %float %bc
%19 = OpLoad %float %bc2
%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1
%21 = OpVectorShuffle %v4float %20 %20 2 7 0 1
%22 = OpCompositeExtract %float %21 1
%23 = OpCompositeConstruct %v4float %22 %22 %22 %22
OpStore %OutColor %23
OpReturn
OpFunctionEnd
)";
// Expected output: shuffle component 1 has selector 7, which is past the four
// components of the first vector and so addresses component 7 - 4 = 3 of the
// second vector (also %20), i.e. %float_1. The extract %22 no longer appears
// and %23 is built from %float_1.
const std::string after =
R"(%main = OpFunction %void None %7
%17 = OpLabel
%18 = OpLoad %float %bc
%19 = OpLoad %float %bc2
%20 = OpCompositeConstruct %v4float %18 %19 %float_0 %float_1
%21 = OpVectorShuffle %v4float %20 %20 2 7 0 1
%23 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1
OpStore %OutColor %23
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<opt::InsertExtractElimPass>(
predefs + before, predefs + after, true, true);
}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//