diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index 2ee1dc80..a4ef0c04 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -752,6 +752,32 @@ void CompilerHLSL::emit_binary_func_op_transpose_first(uint32_t result_type, uin } } +void CompilerHLSL::emit_binary_func_op_transpose_second(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1); + emit_op(result_type, result_id, join(op, "(", to_expression(op0), ", transpose(", to_expression(op1), "))"), forward, false); + + if (forward && forced_temporaries.find(result_id) == end(forced_temporaries)) + { + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + } +} + +void CompilerHLSL::emit_binary_func_op_transpose_all(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op) +{ + bool forward = should_forward(op0) && should_forward(op1); + emit_op(result_type, result_id, join("transpose(", op, "(transpose(", to_expression(op0), "), transpose(", to_expression(op1), ")))"), forward, false); + + if (forward && forced_temporaries.find(result_id) == end(forced_temporaries)) + { + inherit_expression_dependencies(result_id, op0); + inherit_expression_dependencies(result_id, op1); + } +} + void CompilerHLSL::emit_instruction(const Instruction &instruction) { auto ops = stream(instruction); @@ -775,12 +801,16 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction) emit_binary_func_op_transpose_first(ops[0], ops[1], ops[2], ops[3], "mul"); break; } - - // TODO: Take care of remaining matrix binops - //case OpMatrixTimesScalar: - //case OpVectorTimesMatrix: - //case OpMatrixTimesMatrix: - + case OpVectorTimesMatrix: + { + emit_binary_func_op_transpose_second(ops[0], ops[1], ops[2], ops[3], "mul"); + break; + } + case OpMatrixTimesMatrix: + { + emit_binary_func_op_transpose_all(ops[0], ops[1], ops[2], ops[3], "mul"); + break; + } default: CompilerGLSL::emit_instruction(instruction); break; diff --git a/spirv_hlsl.hpp b/spirv_hlsl.hpp index 2b8af614..30563087 100644 --- a/spirv_hlsl.hpp +++ b/spirv_hlsl.hpp @@ -60,6 +60,10 @@ private: void emit_instruction(const Instruction &instruction) override; void emit_binary_func_op_transpose_first(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, const char *op); + void emit_binary_func_op_transpose_second(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op); + void emit_binary_func_op_transpose_all(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1, + const char *op); Options options; };