[wasm-gc] Subtyping support for call_indirect

This CL adds subtyping support to call_indirect: signature comparison
for call_indirect will now succeed if the real signature of the table
element is a canonical subtype of the declared signature. This makes
wasm-gc semantics strictly more permissive, i.e., fewer programs will
trap.
Drive-by: Since the Liftoff implementation of call_indirect became more
complex, we try to make it a little more readable by renaming registers.

Bug: v8:7748
Change-Id: I42ba94161269e3a4535193d18bf00b3423e946bf
Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3937466
Commit-Queue: Manos Koukoutos <manoskouk@chromium.org>
Reviewed-by: Jakob Kummerow <jkummerow@chromium.org>
Cr-Commit-Position: refs/heads/main@{#84903}
This commit is contained in:
Manos Koukoutos 2022-12-16 13:46:19 +01:00 committed by V8 LUCI CQ
parent 30bc957217
commit fee78cd432
5 changed files with 248 additions and 109 deletions

View File

@ -2857,12 +2857,6 @@ Node* WasmGraphBuilder::BuildIndirectCall(uint32_t table_index,
Node* in_bounds = gasm_->Uint32LessThan(key, ift_size);
TrapIfFalse(wasm::kTrapTableOutOfBounds, in_bounds, position);
// Check that the table entry is not null and that the type of the function is
// **identical with** the function type declared at the call site (no
// subtyping of functions is allowed).
// Note: Since null entries are identified by having ift_sig_id (-1), we only
// need one comparison.
// TODO(9495): Change this if we should do full function subtyping instead.
Node* isorecursive_canonical_types =
LOAD_INSTANCE_FIELD(IsorecursiveCanonicalTypes, MachineType::Pointer());
Node* expected_sig_id =
@ -2875,7 +2869,61 @@ Node* WasmGraphBuilder::BuildIndirectCall(uint32_t table_index,
int32_scaled_key);
Node* sig_match = gasm_->Word32Equal(loaded_sig, expected_sig_id);
TrapIfFalse(wasm::kTrapFuncSigMismatch, sig_match, position);
if (v8_flags.experimental_wasm_gc) {
// Do a full subtyping check.
// TODO(7748): Optimize for non-nullable tables.
// TODO(7748): Optimize if type annotation matches table type.
auto end_label = gasm_->MakeLabel();
gasm_->GotoIf(sig_match, &end_label);
// Trap on null element.
TrapIfTrue(wasm::kTrapFuncSigMismatch,
gasm_->Word32Equal(loaded_sig, Int32Constant(-1)), position);
Node* formal_rtt = RttCanon(sig_index);
int rtt_depth = wasm::GetSubtypingDepth(env_->module, sig_index);
DCHECK_GE(rtt_depth, 0);
// Since we have the canonical index of the real rtt, we have to load it
// from the isolate rtt-array (which is canonically indexed). Since this
// reference is weak, we have to promote it to a strong reference.
// Note: The reference cannot have been cleared: Since the loaded_sig
// corresponds to a function of the same canonical type, that function will
// have kept the type alive.
Node* rtts = LOAD_ROOT(WasmCanonicalRtts, wasm_canonical_rtts);
Node* real_rtt =
gasm_->WordAnd(gasm_->LoadWeakArrayListElement(rtts, loaded_sig),
gasm_->IntPtrConstant(~kWeakHeapObjectMask));
Node* type_info = gasm_->LoadWasmTypeInfo(real_rtt);
// If the depth of the rtt is known to be less than the minimum supertype
// array length, we can access the supertype without bounds-checking the
// supertype array.
if (static_cast<uint32_t>(rtt_depth) >= wasm::kMinimumSupertypeArraySize) {
Node* supertypes_length =
gasm_->BuildChangeSmiToIntPtr(gasm_->LoadImmutableFromObject(
MachineType::TaggedSigned(), type_info,
wasm::ObjectAccess::ToTagged(
WasmTypeInfo::kSupertypesLengthOffset)));
TrapIfFalse(wasm::kTrapFuncSigMismatch,
gasm_->UintLessThan(gasm_->IntPtrConstant(rtt_depth),
supertypes_length),
position);
}
Node* maybe_match = gasm_->LoadImmutableFromObject(
MachineType::TaggedPointer(), type_info,
wasm::ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesOffset +
kTaggedSize * rtt_depth));
TrapIfFalse(wasm::kTrapFuncSigMismatch,
gasm_->TaggedEqual(maybe_match, formal_rtt), position);
gasm_->Goto(&end_label);
gasm_->Bind(&end_label);
} else {
// In absence of subtyping, we just need to check for type equality.
TrapIfFalse(wasm::kTrapFuncSigMismatch, sig_match, position);
}
Node* key_intptr = gasm_->BuildChangeUint32ToUintPtr(key);

View File

@ -232,6 +232,15 @@ Node* WasmGraphAssembler::LoadFixedArrayElement(Node* fixed_array,
return LoadFromObject(type, fixed_array, offset);
}
// Loads the element at {index_intptr} from a WeakArrayList-shaped object:
// the offset is index * kTaggedSize past the (tag-adjusted) WeakArrayList
// header.
// NOTE(review): despite the parameter name {fixed_array}, the header size
// used is WeakArrayList::kHeaderSize, and the caller in BuildIndirectCall
// passes the WasmCanonicalRtts weak array list; the name is presumably kept
// for symmetry with LoadFixedArrayElement — confirm.
// The loaded value may carry the weak-reference tag; callers are expected to
// strip it (e.g. via masking with ~kWeakHeapObjectMask) — TODO confirm.
Node* WasmGraphAssembler::LoadWeakArrayListElement(Node* fixed_array,
Node* index_intptr,
MachineType type) {
Node* offset = IntAdd(
IntMul(index_intptr, IntPtrConstant(kTaggedSize)),
IntPtrConstant(wasm::ObjectAccess::ToTagged(WeakArrayList::kHeaderSize)));
return LoadFromObject(type, fixed_array, offset);
}
Node* WasmGraphAssembler::LoadImmutableFixedArrayElement(Node* fixed_array,
Node* index_intptr,
MachineType type) {

View File

@ -211,6 +211,9 @@ class WasmGraphAssembler : public GraphAssembler {
ObjectAccess(MachineType::AnyTagged(), kFullWriteBarrier));
}
Node* LoadWeakArrayListElement(Node* fixed_array, Node* index_intptr,
MachineType type = MachineType::AnyTagged());
// Functions, SharedFunctionInfos, FunctionData.
Node* LoadSharedFunctionInfo(Node* js_function);

View File

@ -7342,122 +7342,189 @@ class LiftoffCompiler {
LiftoffRegList pinned{index};
// Get all temporary registers unconditionally up front.
Register table = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
Register tmp_const = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
Register scratch = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
// We do not use temporary registers directly; instead we rename them as
// appropriate in each scope they are used.
Register tmp1 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
Register tmp2 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
Register tmp3 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
Register indirect_function_table = no_reg;
if (imm.table_imm.index != 0) {
Register indirect_function_tables =
if (imm.table_imm.index > 0) {
indirect_function_table =
pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp();
LOAD_TAGGED_PTR_INSTANCE_FIELD(indirect_function_tables,
LOAD_TAGGED_PTR_INSTANCE_FIELD(indirect_function_table,
IndirectFunctionTables, pinned);
indirect_function_table = indirect_function_tables;
__ LoadTaggedPointer(
indirect_function_table, indirect_function_tables, no_reg,
indirect_function_table, indirect_function_table, no_reg,
ObjectAccess::ElementOffsetInTaggedFixedArray(imm.table_imm.index));
}
{
CODE_COMMENT("Check index is in-bounds");
Register table_size = tmp1;
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(table_size, IndirectFunctionTableSize, kUInt32Size,
pinned);
} else {
__ Load(LiftoffRegister(table_size), indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kSizeOffset),
LoadType::kI32Load);
}
// Bounds check against the table size.
Label* invalid_func_label =
AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapTableOutOfBounds);
// Compare against table size stored in
// {instance->indirect_function_table_size}.
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(tmp_const, IndirectFunctionTableSize, kUInt32Size,
pinned);
} else {
__ Load(
LiftoffRegister(tmp_const), indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(WasmIndirectFunctionTable::kSizeOffset),
LoadType::kI32Load);
// Bounds check against the table size: Compare against table size stored
// in {instance->indirect_function_table_size}.
Label* out_of_bounds_label =
AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapTableOutOfBounds);
{
FREEZE_STATE(trapping);
__ emit_cond_jump(kUnsignedGreaterEqual, out_of_bounds_label, kI32,
index, table_size, trapping);
}
}
{
FREEZE_STATE(trapping);
__ emit_cond_jump(kUnsignedGreaterEqual, invalid_func_label, kI32, index,
tmp_const, trapping);
}
CODE_COMMENT("Check indirect call signature");
Register real_sig_id = tmp1;
Register formal_sig_id = tmp2;
CODE_COMMENT("Check indirect call signature");
// Load the signature from {instance->ift_sig_ids[key]}
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(table, IndirectFunctionTableSigIds,
// Load the signature from {instance->ift_sig_ids[key]}
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(real_sig_id, IndirectFunctionTableSigIds,
kSystemPointerSize, pinned);
} else {
__ Load(LiftoffRegister(real_sig_id), indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kSigIdsOffset),
kPointerLoadType);
}
static_assert((1 << 2) == kInt32Size);
__ Load(LiftoffRegister(real_sig_id), real_sig_id, index, 0,
LoadType::kI32Load, nullptr, false, false, true);
// Compare against expected signature.
LOAD_INSTANCE_FIELD(formal_sig_id, IsorecursiveCanonicalTypes,
kSystemPointerSize, pinned);
} else {
__ Load(LiftoffRegister(table), indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kSigIdsOffset),
kPointerLoadType);
__ Load(LiftoffRegister(formal_sig_id), formal_sig_id, no_reg,
imm.sig_imm.index * kInt32Size, LoadType::kI32Load);
Label* sig_mismatch_label =
AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapFuncSigMismatch);
__ DropValues(1);
if (v8_flags.experimental_wasm_gc) {
Label success_label;
FREEZE_STATE(frozen);
__ emit_cond_jump(kEqual, &success_label, kI32, real_sig_id,
formal_sig_id, frozen);
__ emit_i32_cond_jumpi(kEqual, sig_mismatch_label, real_sig_id, -1,
frozen);
Register real_rtt = tmp3;
LOAD_INSTANCE_FIELD(real_rtt, IsolateRoot, kSystemPointerSize, pinned);
__ LoadFullPointer(
real_rtt, real_rtt,
IsolateData::root_slot_offset(RootIndex::kWasmCanonicalRtts));
__ LoadTaggedPointer(real_rtt, real_rtt, real_sig_id,
ObjectAccess::ToTagged(WeakArrayList::kHeaderSize),
true);
// Remove the weak reference tag.
if (kSystemPointerSize == 4) {
__ emit_i32_andi(real_rtt, real_rtt,
static_cast<int32_t>(~kWeakHeapObjectMask));
} else {
__ emit_i64_andi(LiftoffRegister(real_rtt), LiftoffRegister(real_rtt),
static_cast<int64_t>(~kWeakHeapObjectMask));
}
// Constant-time subtyping check: load exactly one candidate RTT from
// the supertypes list.
// Step 1: load the WasmTypeInfo.
constexpr int kTypeInfoOffset = wasm::ObjectAccess::ToTagged(
Map::kConstructorOrBackPointerOrNativeContextOffset);
Register type_info = real_rtt;
__ LoadTaggedPointer(type_info, real_rtt, no_reg, kTypeInfoOffset);
// Step 2: check the list's length if needed.
uint32_t rtt_depth =
GetSubtypingDepth(decoder->module_, imm.sig_imm.index);
if (rtt_depth >= kMinimumSupertypeArraySize) {
LiftoffRegister list_length(formal_sig_id);
int offset =
ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesLengthOffset);
__ LoadSmiAsInt32(list_length, type_info, offset);
__ emit_i32_cond_jumpi(kUnsignedLessEqual, sig_mismatch_label,
list_length.gp(), rtt_depth, frozen);
}
// Step 3: load the candidate list slot, and compare it.
Register maybe_match = type_info;
__ LoadTaggedPointer(
maybe_match, type_info, no_reg,
ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesOffset +
rtt_depth * kTaggedSize));
Register formal_rtt = formal_sig_id;
LOAD_TAGGED_PTR_INSTANCE_FIELD(formal_rtt, ManagedObjectMaps, pinned);
__ LoadTaggedPointer(
formal_rtt, formal_rtt, no_reg,
wasm::ObjectAccess::ElementOffsetInTaggedFixedArray(
imm.sig_imm.index));
__ emit_cond_jump(kUnequal, sig_mismatch_label, kRtt, formal_rtt,
maybe_match, frozen);
__ bind(&success_label);
} else {
FREEZE_STATE(trapping);
__ emit_cond_jump(kUnequal, sig_mismatch_label, kI32, real_sig_id,
formal_sig_id, trapping);
}
}
static_assert((1 << 2) == kInt32Size);
__ Load(LiftoffRegister(scratch), table, index, 0, LoadType::kI32Load,
nullptr, false, false, true);
// Compare against expected signature.
LOAD_INSTANCE_FIELD(tmp_const, IsorecursiveCanonicalTypes,
kSystemPointerSize, pinned);
__ Load(LiftoffRegister(tmp_const), tmp_const, no_reg,
imm.sig_imm.index * kInt32Size, LoadType::kI32Load);
Label* sig_mismatch_label =
AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapFuncSigMismatch);
__ DropValues(1);
{
FREEZE_STATE(trapping);
__ emit_cond_jump(kUnequal, sig_mismatch_label, kIntPtrKind, scratch,
tmp_const, trapping);
}
CODE_COMMENT("Execute indirect call");
CODE_COMMENT("Execute indirect call");
// At this point {index} has already been multiplied by kTaggedSize.
Register function_instance = tmp1;
Register function_target = tmp2;
// Load the instance from {instance->ift_instances[key]}
if (imm.table_imm.index == 0) {
LOAD_TAGGED_PTR_INSTANCE_FIELD(table, IndirectFunctionTableRefs, pinned);
} else {
__ LoadTaggedPointer(
table, indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(WasmIndirectFunctionTable::kRefsOffset));
}
__ LoadTaggedPointer(tmp_const, table, index,
ObjectAccess::ElementOffsetInTaggedFixedArray(0),
true);
// Load the instance from {instance->ift_instances[key]}
if (imm.table_imm.index == 0) {
LOAD_TAGGED_PTR_INSTANCE_FIELD(function_instance,
IndirectFunctionTableRefs, pinned);
} else {
__ LoadTaggedPointer(function_instance, indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kRefsOffset));
}
__ LoadTaggedPointer(function_instance, function_instance, index,
ObjectAccess::ElementOffsetInTaggedFixedArray(0),
true);
Register* explicit_instance = &tmp_const;
// Load the target from {instance->ift_targets[key]}
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(function_target, IndirectFunctionTableTargets,
kSystemPointerSize, pinned);
} else {
__ Load(LiftoffRegister(function_target), indirect_function_table,
no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kTargetsOffset),
kPointerLoadType);
}
__ Load(LiftoffRegister(function_target), function_target, index, 0,
kPointerLoadType, nullptr, false, false, true);
// Load the target from {instance->ift_targets[key]}
if (imm.table_imm.index == 0) {
LOAD_INSTANCE_FIELD(table, IndirectFunctionTableTargets,
kSystemPointerSize, pinned);
} else {
__ Load(LiftoffRegister(table), indirect_function_table, no_reg,
wasm::ObjectAccess::ToTagged(
WasmIndirectFunctionTable::kTargetsOffset),
kPointerLoadType);
}
__ Load(LiftoffRegister(scratch), table, index, 0, kPointerLoadType,
nullptr, false, false, true);
auto call_descriptor =
compiler::GetWasmCallDescriptor(compilation_zone_, imm.sig);
call_descriptor =
GetLoweredCallDescriptor(compilation_zone_, call_descriptor);
auto call_descriptor =
compiler::GetWasmCallDescriptor(compilation_zone_, imm.sig);
call_descriptor =
GetLoweredCallDescriptor(compilation_zone_, call_descriptor);
__ PrepareCall(&sig, call_descriptor, &function_target,
&function_instance);
if (tail_call) {
__ PrepareTailCall(
static_cast<int>(call_descriptor->ParameterSlotCount()),
static_cast<int>(
call_descriptor->GetStackParameterDelta(descriptor_)));
__ TailCallIndirect(function_target);
} else {
source_position_table_builder_.AddPosition(
__ pc_offset(), SourcePosition(decoder->position()), true);
__ CallIndirect(&sig, call_descriptor, function_target);
Register target = scratch;
__ PrepareCall(&sig, call_descriptor, &target, explicit_instance);
if (tail_call) {
__ PrepareTailCall(
static_cast<int>(call_descriptor->ParameterSlotCount()),
static_cast<int>(
call_descriptor->GetStackParameterDelta(descriptor_)));
__ TailCallIndirect(target);
} else {
source_position_table_builder_.AddPosition(
__ pc_offset(), SourcePosition(decoder->position()), true);
__ CallIndirect(&sig, call_descriptor, target);
FinishCall(decoder, &sig, call_descriptor);
FinishCall(decoder, &sig, call_descriptor);
}
}
}

View File

@ -97,7 +97,7 @@ class WasmGCTester {
}
// Registers {sig} (optionally with a declared {supertype}) and returns its
// type index. Uses ForceAddSignature so that structurally identical
// signatures still get distinct type entries — presumably required so tests
// can define an "unrelated" signature with the same structure (see the
// unrelated_sig_index test below); confirm against WasmModuleBuilder.
// Note: the flattened diff had left both the old AddSignature call and the
// new ForceAddSignature call in place, making the second return unreachable;
// only the post-commit call is kept.
byte DefineSignature(FunctionSig* sig, uint32_t supertype = kNoSuperType) {
return builder_.ForceAddSignature(sig, supertype);
}
byte DefineTable(ValueType type, uint32_t min_size, uint32_t max_size) {
@ -1986,6 +1986,7 @@ WASM_COMPILED_EXEC_TEST(GlobalInitReferencingGlobal) {
WASM_COMPILED_EXEC_TEST(GCTables) {
WasmGCTester tester(execution_tier);
tester.builder()->StartRecursiveTypeGroup();
byte super_struct = tester.DefineStruct({F(kWasmI32, false)});
byte sub_struct = tester.DefineStruct({F(kWasmI32, false), F(kWasmI32, true)},
super_struct);
@ -1995,6 +1996,8 @@ WASM_COMPILED_EXEC_TEST(GCTables) {
FunctionSig* sub_sig =
FunctionSig::Build(tester.zone(), {kWasmI32}, {refNull(super_struct)});
byte sub_sig_index = tester.DefineSignature(sub_sig, super_sig_index);
byte unrelated_sig_index = tester.DefineSignature(sub_sig, super_sig_index);
tester.builder()->EndRecursiveTypeGroup();
tester.DefineTable(refNull(super_sig_index), 10, 10);
@ -2012,8 +2015,8 @@ WASM_COMPILED_EXEC_TEST(GCTables) {
tester.sigs.i_v(), {},
{WASM_TABLE_SET(0, WASM_I32V(0), WASM_REF_NULL(super_sig_index)),
WASM_TABLE_SET(0, WASM_I32V(1), WASM_REF_FUNC(super_func)),
WASM_TABLE_SET(0, WASM_I32V(2), WASM_REF_FUNC(sub_func)), WASM_I32V(0),
WASM_END});
WASM_TABLE_SET(0, WASM_I32V(2), WASM_REF_FUNC(sub_func)), // --
WASM_I32V(0), WASM_END});
byte super_struct_producer = tester.DefineFunction(
FunctionSig::Build(tester.zone(), {ref(super_struct)}, {}), {},
@ -2045,12 +2048,20 @@ WASM_COMPILED_EXEC_TEST(GCTables) {
WASM_CALL_FUNCTION0(super_struct_producer),
WASM_I32V(2)),
WASM_END});
// Calling with a signature that is a subtype of the type of the table should
// work, provided the entry has a subtype of the declared signature.
byte call_table_subtype_entry_subtype = tester.DefineFunction(
tester.sigs.i_v(), {},
{WASM_CALL_INDIRECT(super_sig_index,
WASM_CALL_FUNCTION0(sub_struct_producer),
WASM_I32V(2)),
WASM_END});
// Calling with a signature that is mismatched to that of the entry should
// trap.
byte call_type_mismatch = tester.DefineFunction(
tester.sigs.i_v(), {},
{WASM_CALL_INDIRECT(super_sig_index,
WASM_CALL_FUNCTION0(sub_struct_producer),
{WASM_CALL_INDIRECT(unrelated_sig_index,
WASM_CALL_FUNCTION0(super_struct_producer),
WASM_I32V(2)),
WASM_END});
// Getting a table element and then calling it with call_ref should work.
@ -2072,6 +2083,7 @@ WASM_COMPILED_EXEC_TEST(GCTables) {
tester.CheckHasThrown(call_null);
tester.CheckResult(call_same_type, 18);
tester.CheckResult(call_subtype, -5);
tester.CheckResult(call_table_subtype_entry_subtype, 7);
tester.CheckHasThrown(call_type_mismatch);
tester.CheckResult(table_get_and_call_ref, 7);
}