MSL: Add missing casts to Op?MulExtended.

It is possible to pass unsigned integers to `OpSMulExtended`. In that
case, we want to do a signed multiply with sign extension, so make sure
the operands are forced to be interpreted as signed.

This was an oversight on my part when I added these instructions.

Fixes the CTS test
`dEQP-VK.spirv_assembly.instruction.compute.signed_op.uint_smulextended`.
This commit is contained in:
Chip Davis 2022-11-08 16:00:06 -08:00
parent edd66a2fc9
commit 51d2dfe02a
4 changed files with 136 additions and 4 deletions

View File

@ -0,0 +1,25 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Layout used for every buffer bound to main0 (buffers 0 through 3): a single
// member holding an array of uint. The array is declared with one element, but
// main0 indexes it with gl_GlobalInvocationID.x.
struct _4
{
uint _m0[1];
};
// Two-member temporary that main0 fills with the results of the extended
// multiply: _m0 receives the plain product and _m1 receives the mulhi() result.
struct _20
{
uint _m0;
uint _m1;
};
// Compute kernel. For the element selected by gl_GlobalInvocationID.x it reads
// one uint from buffer 0 (_5) and one from buffer 1 (_6), multiplies them as
// signed integers, and writes the two result words to buffer 2 (_7) and
// buffer 3 (_8).
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_20 _28;
// Both operands are converted to int before multiplying so the operation is a
// signed multiply; the product is converted back to uint to match the member type.
_28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
// Same int conversion for mulhi(), so the upper word comes from a signed
// multiply rather than an unsigned one; the result is stored as uint.
_28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
// Write each word of the result to its own output buffer at the same index.
_7._m0[gl_GlobalInvocationID.x] = _28._m0;
_8._m0[gl_GlobalInvocationID.x] = _28._m1;
}

View File

@ -0,0 +1,25 @@
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Layout used for every buffer bound to main0 (buffers 0 through 3): a single
// member holding an array of uint. The array is declared with one element, but
// main0 indexes it with gl_GlobalInvocationID.x.
struct _4
{
uint _m0[1];
};
// Two-member temporary that main0 fills with the results of the extended
// multiply: _m0 receives the plain product and _m1 receives the mulhi() result.
struct _20
{
uint _m0;
uint _m1;
};
// Compute kernel. For the element selected by gl_GlobalInvocationID.x it reads
// one uint from buffer 0 (_5) and one from buffer 1 (_6), multiplies them as
// signed integers, and writes the two result words to buffer 2 (_7) and
// buffer 3 (_8).
kernel void main0(device _4& _5 [[buffer(0)]], device _4& _6 [[buffer(1)]], device _4& _7 [[buffer(2)]], device _4& _8 [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
_20 _28;
// Both operands are converted to int before multiplying so the operation is a
// signed multiply; the product is converted back to uint to match the member type.
_28._m0 = uint(int(_5._m0[gl_GlobalInvocationID.x]) * int(_6._m0[gl_GlobalInvocationID.x]));
// Same int conversion for mulhi(), so the upper word comes from a signed
// multiply rather than an unsigned one; the result is stored as uint.
_28._m1 = uint(mulhi(int(_5._m0[gl_GlobalInvocationID.x]), int(_6._m0[gl_GlobalInvocationID.x])));
// Write each word of the result to its own output buffer at the same index.
_7._m0[gl_GlobalInvocationID.x] = _28._m0;
_8._m0[gl_GlobalInvocationID.x] = _28._m1;
}

View File

@ -0,0 +1,61 @@
; Module header: a GLCompute entry point %main with a 1x1x1 local size whose
; only interface variable is %gl_GlobalInvocationId.
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %gl_GlobalInvocationId
OpExecutionMode %main LocalSize 1 1 1
; Decorations: the invocation-ID builtin, the layout of the buffer struct
; (4-byte array stride, member 0 at offset 0, BufferBlock), and four buffer
; variables in descriptor set 0 at bindings 0-3 (two inputs, then two outputs).
OpDecorate %gl_GlobalInvocationId BuiltIn GlobalInvocationId
OpDecorate %ra_uint ArrayStride 4
OpDecorate %struct_uint4 BufferBlock
OpMemberDecorate %struct_uint4 0 Offset 0
OpDecorate %input0 DescriptorSet 0
OpDecorate %input0 Binding 0
OpDecorate %input1 DescriptorSet 0
OpDecorate %input1 Binding 1
OpDecorate %output0 DescriptorSet 0
OpDecorate %output0 Binding 2
OpDecorate %output1 DescriptorSet 0
OpDecorate %output1 Binding 3
; Types and constants. %uint is a 32-bit integer with Signedness 0 (unsigned);
; every value in this module, including the multiply operands, uses it.
%uint = OpTypeInt 32 0
%ptr_uint = OpTypePointer Uniform %uint
%ptr_input_uint = OpTypePointer Input %uint
%uint3 = OpTypeVector %uint 3
%ptr_input_uint3 = OpTypePointer Input %uint3
%void = OpTypeVoid
%voidFn = OpTypeFunction %void
%uint_0 = OpConstant %uint 0
; %uint_1 is not referenced anywhere else in this module.
%uint_1 = OpConstant %uint 1
%ra_uint = OpTypeRuntimeArray %uint
; %uint4 is not referenced anywhere else in this module.
%uint4 = OpTypeVector %uint 4
; Despite its name, %struct_uint4 wraps the runtime array of uint, not %uint4.
%struct_uint4 = OpTypeStruct %ra_uint
%ptr_struct_uint4 = OpTypePointer Uniform %struct_uint4
; Result struct for the extended multiply: two unsigned members.
%resulttype = OpTypeStruct %uint %uint
; Global variables: the builtin input and the four buffers.
%gl_GlobalInvocationId = OpVariable %ptr_input_uint3 Input
%input0 = OpVariable %ptr_struct_uint4 Uniform
%input1 = OpVariable %ptr_struct_uint4 Uniform
%output0 = OpVariable %ptr_struct_uint4 Uniform
%output1 = OpVariable %ptr_struct_uint4 Uniform
; Entry point. Loads one element from each input buffer at the index given by
; component 0 of %gl_GlobalInvocationId, applies OpSMulExtended to the pair, and
; stores member 0 of the result to %output0 and member 1 to %output1 at the
; same index.
%main = OpFunction %void None %voidFn
%mainStart = OpLabel
; index = gl_GlobalInvocationId.x
%index_ptr = OpAccessChain %ptr_input_uint %gl_GlobalInvocationId %uint_0
%index = OpLoad %uint %index_ptr
; Load the two operands; both have the unsigned type %uint.
%in_ptr0 = OpAccessChain %ptr_uint %input0 %uint_0 %index
%invalue0 = OpLoad %uint %in_ptr0
%in_ptr1 = OpAccessChain %ptr_uint %input1 %uint_0 %index
%invalue1 = OpLoad %uint %in_ptr1
; Signed extended multiply applied to unsigned-typed operands, producing a
; struct of two unsigned members.
%outvalue = OpSMulExtended %resulttype %invalue0 %invalue1
; Member 0 of the result goes to %output0[index].
%outvalue0 = OpCompositeExtract %uint %outvalue 0
%out_ptr0 = OpAccessChain %ptr_uint %output0 %uint_0 %index
OpStore %out_ptr0 %outvalue0
; Member 1 of the result goes to %output1[index].
%outvalue1 = OpCompositeExtract %uint %outvalue 1
%out_ptr1 = OpAccessChain %ptr_uint %output1 %uint_0 %index
OpStore %out_ptr1 %outvalue1
OpReturn
OpFunctionEnd

View File

@ -8937,12 +8937,33 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t op0 = ops[2];
uint32_t op1 = ops[3];
auto &type = get<SPIRType>(result_type);
auto input_type = opcode == OpSMulExtended ? int_type : uint_type;
auto &output_type = get_type(result_type);
string cast_op0, cast_op1;
auto expected_type = binary_op_bitcast_helper(cast_op0, cast_op1, input_type, op0, op1, false);
emit_uninitialized_temporary_expression(result_type, result_id);
statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";");
statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(",
to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");");
string mullo_expr, mulhi_expr;
mullo_expr = join(cast_op0, " * ", cast_op1);
mulhi_expr = join("mulhi(", cast_op0, ", ", cast_op1, ")");
auto &low_type = get_type(output_type.member_types[0]);
auto &high_type = get_type(output_type.member_types[1]);
if (low_type.basetype != input_type)
{
expected_type.basetype = input_type;
mullo_expr = join(bitcast_glsl_op(low_type, expected_type), "(", mullo_expr, ")");
}
if (high_type.basetype != input_type)
{
expected_type.basetype = input_type;
mulhi_expr = join(bitcast_glsl_op(high_type, expected_type), "(", mulhi_expr, ")");
}
statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", mullo_expr, ";");
statement(to_expression(result_id), ".", to_member_name(type, 1), " = ", mulhi_expr, ";");
break;
}