Fix nits from review.

This commit is contained in:
Hans-Kristian Arntzen 2024-01-16 14:16:37 +01:00
parent 6c24be197f
commit b9abac5024
2 changed files with 13 additions and 11 deletions

View File

@ -34,11 +34,14 @@ struct InOut3
int result;
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
kernel void main0(device InOut& comp [[buffer(0)]], device void* spvBufferAliasSet0Binding1 [[buffer(1)]])
{
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
comp.result = reduce_add(int4(comp.x) * int4(comp.y));
comp2.result = uint(reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y))));
comp3.result = addsat(reduce_add(int4(comp3.x) * int4(comp3.y)), comp3.acc);
comp3.result = int4(addsat(reduce_add(int4(comp3.x) * int4(comp3.y)), comp3.acc));
}

View File

@ -9666,7 +9666,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t vec1 = ops[2];
uint32_t vec2 = ops[3];
auto& input_type1 = expression_type(vec1);
auto &input_type1 = expression_type(vec1);
string vec1input, vec2input;
auto input_size = input_type1.vecsize;
@ -9689,15 +9689,15 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
vec2input = to_expression(vec2);
}
auto& type = get<SPIRType>(result_type);
auto &type = get<SPIRType>(result_type);
auto result_type_cast = join(type_to_glsl(type), input_size);
// When the opcode specifies signed integers, we always cast to the signed integer type, regardless of the output type.
string_view type_cast1(result_type_cast);
string type_cast1 = result_type_cast;
if (type_cast1[0] == 'u' && (opcode == OpSDot || opcode == OpSUDot))
type_cast1 = type_cast1.substr(1);
string_view type_cast2(result_type_cast);
string type_cast2 = result_type_cast;
if (type_cast2[0] == 'u' && opcode == OpSDot)
type_cast2 = type_cast2.substr(1);
@ -9718,8 +9718,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t vec2 = ops[3];
uint32_t acc = ops[4];
auto& input_type1 = expression_type(vec1);
auto& input_type2 = expression_type(vec2);
auto &input_type1 = expression_type(vec1);
string vec1input, vec2input;
auto input_size = input_type1.vecsize;
@ -9742,18 +9741,18 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
vec2input = to_expression(vec2);
}
auto& type = get<SPIRType>(result_type);
auto &type = get<SPIRType>(result_type);
auto result_type_cast = join(type_to_glsl(type), input_size);
string_view type_cast1(result_type_cast);
string type_cast1 = result_type_cast;
if (type_cast1[0] == 'u' && (opcode == OpSDotAccSat || opcode == OpSUDotAccSat))
type_cast1 = type_cast1.substr(1);
string_view type_cast2(result_type_cast);
string type_cast2 = result_type_cast;
if (type_cast2[0] == 'u' && opcode == OpSDotAccSat)
type_cast2 = type_cast2.substr(1);
string exp = join("addsat(reduce_add(", std::string(type_cast1), "(", vec1input, ") * ", std::string(type_cast2), "(", vec2input, ")), ", to_expression(acc), ")");
string exp = join(result_type_cast, "(addsat(reduce_add(", std::string(type_cast1), "(", vec1input, ") * ", std::string(type_cast2), "(", vec2input, ")), ", to_expression(acc), "))");
emit_op(result_type, id, exp, should_forward(vec1) && should_forward(vec2));
inherit_expression_dependencies(id, vec1);
inherit_expression_dependencies(id, vec2);