From fee78cd432f2a1e04f88c41a6c7bacd995a74a85 Mon Sep 17 00:00:00 2001 From: Manos Koukoutos Date: Fri, 16 Dec 2022 13:46:19 +0100 Subject: [PATCH] [wasm-gc] Subtyping support for call_indirect This CL adds subtyping support to call_indirect: signature comparison for call_indirect will now succeed if the real signature of the table element is a canonical subtype of the declared signature. This makes wasm-gc semantics strictly more permissive, i.e., less programs will trap. Drive-by: Since liftoff call_indirect became more complex, we try to make it a little more readable by renaming registers. Bug: v8:7748 Change-Id: I42ba94161269e3a4535193d18bf00b3423e946bf Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3937466 Commit-Queue: Manos Koukoutos Reviewed-by: Jakob Kummerow Cr-Commit-Position: refs/heads/main@{#84903} --- src/compiler/wasm-compiler.cc | 62 +++++- src/compiler/wasm-graph-assembler.cc | 9 + src/compiler/wasm-graph-assembler.h | 3 + src/wasm/baseline/liftoff-compiler.cc | 261 ++++++++++++++++---------- test/cctest/wasm/test-gc.cc | 22 ++- 5 files changed, 248 insertions(+), 109 deletions(-) diff --git a/src/compiler/wasm-compiler.cc b/src/compiler/wasm-compiler.cc index 3b47949040..bc8d0fd674 100644 --- a/src/compiler/wasm-compiler.cc +++ b/src/compiler/wasm-compiler.cc @@ -2857,12 +2857,6 @@ Node* WasmGraphBuilder::BuildIndirectCall(uint32_t table_index, Node* in_bounds = gasm_->Uint32LessThan(key, ift_size); TrapIfFalse(wasm::kTrapTableOutOfBounds, in_bounds, position); - // Check that the table entry is not null and that the type of the function is - // **identical with** the function type declared at the call site (no - // subtyping of functions is allowed). - // Note: Since null entries are identified by having ift_sig_id (-1), we only - // need one comparison. - // TODO(9495): Change this if we should do full function subtyping instead. Node* isorecursive_canonical_types = LOAD_INSTANCE_FIELD(IsorecursiveCanonicalTypes, MachineType::Pointer()); Node* expected_sig_id = @@ -2875,7 +2869,61 @@ Node* WasmGraphBuilder::BuildIndirectCall(uint32_t table_index, int32_scaled_key); Node* sig_match = gasm_->Word32Equal(loaded_sig, expected_sig_id); - TrapIfFalse(wasm::kTrapFuncSigMismatch, sig_match, position); + if (v8_flags.experimental_wasm_gc) { + // Do a full subtyping check. + // TODO(7748): Optimize for non-nullable tables. + // TODO(7748): Optimize if type annotation matches table type. + auto end_label = gasm_->MakeLabel(); + gasm_->GotoIf(sig_match, &end_label); + + // Trap on null element. + TrapIfTrue(wasm::kTrapFuncSigMismatch, + gasm_->Word32Equal(loaded_sig, Int32Constant(-1)), position); + + Node* formal_rtt = RttCanon(sig_index); + int rtt_depth = wasm::GetSubtypingDepth(env_->module, sig_index); + DCHECK_GE(rtt_depth, 0); + + // Since we have the canonical index of the real rtt, we have to load it + // from the isolate rtt-array (which is canonically indexed). Since this + // reference is weak, we have to promote it to a strong reference. + // Note: The reference cannot have been cleared: Since the loaded_sig + // corresponds to a function of the same canonical type, that function will + // have kept the type alive. + Node* rtts = LOAD_ROOT(WasmCanonicalRtts, wasm_canonical_rtts); + Node* real_rtt = + gasm_->WordAnd(gasm_->LoadWeakArrayListElement(rtts, loaded_sig), + gasm_->IntPtrConstant(~kWeakHeapObjectMask)); + Node* type_info = gasm_->LoadWasmTypeInfo(real_rtt); + + // If the depth of the rtt is known to be less than the minimum supertype + // array length, we can access the supertype without bounds-checking the + // supertype array. + if (static_cast(rtt_depth) >= wasm::kMinimumSupertypeArraySize) { + Node* supertypes_length = + gasm_->BuildChangeSmiToIntPtr(gasm_->LoadImmutableFromObject( + MachineType::TaggedSigned(), type_info, + wasm::ObjectAccess::ToTagged( + WasmTypeInfo::kSupertypesLengthOffset))); + TrapIfFalse(wasm::kTrapFuncSigMismatch, + gasm_->UintLessThan(gasm_->IntPtrConstant(rtt_depth), + supertypes_length), + position); + } + + Node* maybe_match = gasm_->LoadImmutableFromObject( + MachineType::TaggedPointer(), type_info, + wasm::ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesOffset + + kTaggedSize * rtt_depth)); + TrapIfFalse(wasm::kTrapFuncSigMismatch, + gasm_->TaggedEqual(maybe_match, formal_rtt), position); + gasm_->Goto(&end_label); + + gasm_->Bind(&end_label); + } else { + // In absence of subtyping, we just need to check for type equality. + TrapIfFalse(wasm::kTrapFuncSigMismatch, sig_match, position); + } Node* key_intptr = gasm_->BuildChangeUint32ToUintPtr(key); diff --git a/src/compiler/wasm-graph-assembler.cc b/src/compiler/wasm-graph-assembler.cc index 8b405e4f6c..5ad8680d24 100644 --- a/src/compiler/wasm-graph-assembler.cc +++ b/src/compiler/wasm-graph-assembler.cc @@ -232,6 +232,15 @@ Node* WasmGraphAssembler::LoadFixedArrayElement(Node* fixed_array, return LoadFromObject(type, fixed_array, offset); } +Node* WasmGraphAssembler::LoadWeakArrayListElement(Node* fixed_array, + Node* index_intptr, + MachineType type) { + Node* offset = IntAdd( + IntMul(index_intptr, IntPtrConstant(kTaggedSize)), + IntPtrConstant(wasm::ObjectAccess::ToTagged(WeakArrayList::kHeaderSize))); + return LoadFromObject(type, fixed_array, offset); +} + Node* WasmGraphAssembler::LoadImmutableFixedArrayElement(Node* fixed_array, Node* index_intptr, MachineType type) { diff --git a/src/compiler/wasm-graph-assembler.h b/src/compiler/wasm-graph-assembler.h index 0cb71ce264..39d3a39597 100644 --- a/src/compiler/wasm-graph-assembler.h +++ b/src/compiler/wasm-graph-assembler.h @@ -211,6 +211,9 @@ class WasmGraphAssembler : public GraphAssembler { ObjectAccess(MachineType::AnyTagged(), kFullWriteBarrier)); } + Node* LoadWeakArrayListElement(Node* fixed_array, Node* index_intptr, + MachineType type = MachineType::AnyTagged()); + // Functions, SharedFunctionInfos, FunctionData. Node* LoadSharedFunctionInfo(Node* js_function); diff --git a/src/wasm/baseline/liftoff-compiler.cc b/src/wasm/baseline/liftoff-compiler.cc index c69f1bcc01..58e2fb689a 100644 --- a/src/wasm/baseline/liftoff-compiler.cc +++ b/src/wasm/baseline/liftoff-compiler.cc @@ -7342,122 +7342,189 @@ class LiftoffCompiler { LiftoffRegList pinned{index}; // Get all temporary registers unconditionally up front. - Register table = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); - Register tmp_const = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); - Register scratch = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); + // We do not use temporary registers directly; instead we rename them as + // appropriate in each scope they are used. + Register tmp1 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); + Register tmp2 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); + Register tmp3 = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); Register indirect_function_table = no_reg; - if (imm.table_imm.index != 0) { - Register indirect_function_tables = + if (imm.table_imm.index > 0) { + indirect_function_table = pinned.set(__ GetUnusedRegister(kGpReg, pinned)).gp(); - LOAD_TAGGED_PTR_INSTANCE_FIELD(indirect_function_tables, + LOAD_TAGGED_PTR_INSTANCE_FIELD(indirect_function_table, IndirectFunctionTables, pinned); - - indirect_function_table = indirect_function_tables; __ LoadTaggedPointer( - indirect_function_table, indirect_function_tables, no_reg, + indirect_function_table, indirect_function_table, no_reg, ObjectAccess::ElementOffsetInTaggedFixedArray(imm.table_imm.index)); } + { + CODE_COMMENT("Check index is in-bounds"); + Register table_size = tmp1; + if (imm.table_imm.index == 0) { + LOAD_INSTANCE_FIELD(table_size, IndirectFunctionTableSize, kUInt32Size, + pinned); + } else { + __ Load(LiftoffRegister(table_size), indirect_function_table, no_reg, + wasm::ObjectAccess::ToTagged( + WasmIndirectFunctionTable::kSizeOffset), + LoadType::kI32Load); + } - // Bounds check against the table size. - Label* invalid_func_label = - AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapTableOutOfBounds); - - // Compare against table size stored in - // {instance->indirect_function_table_size}. - if (imm.table_imm.index == 0) { - LOAD_INSTANCE_FIELD(tmp_const, IndirectFunctionTableSize, kUInt32Size, - pinned); - } else { - __ Load( - LiftoffRegister(tmp_const), indirect_function_table, no_reg, - wasm::ObjectAccess::ToTagged(WasmIndirectFunctionTable::kSizeOffset), - LoadType::kI32Load); + // Bounds check against the table size: Compare against table size stored + // in {instance->indirect_function_table_size}. + Label* out_of_bounds_label = + AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapTableOutOfBounds); + { + FREEZE_STATE(trapping); + __ emit_cond_jump(kUnsignedGreaterEqual, out_of_bounds_label, kI32, + index, table_size, trapping); + } } { - FREEZE_STATE(trapping); - __ emit_cond_jump(kUnsignedGreaterEqual, invalid_func_label, kI32, index, - tmp_const, trapping); - } + CODE_COMMENT("Check indirect call signature"); + Register real_sig_id = tmp1; + Register formal_sig_id = tmp2; - CODE_COMMENT("Check indirect call signature"); - // Load the signature from {instance->ift_sig_ids[key]} - if (imm.table_imm.index == 0) { - LOAD_INSTANCE_FIELD(table, IndirectFunctionTableSigIds, + // Load the signature from {instance->ift_sig_ids[key]} + if (imm.table_imm.index == 0) { + LOAD_INSTANCE_FIELD(real_sig_id, IndirectFunctionTableSigIds, + kSystemPointerSize, pinned); + } else { + __ Load(LiftoffRegister(real_sig_id), indirect_function_table, no_reg, + wasm::ObjectAccess::ToTagged( + WasmIndirectFunctionTable::kSigIdsOffset), + kPointerLoadType); + } + static_assert((1 << 2) == kInt32Size); + __ Load(LiftoffRegister(real_sig_id), real_sig_id, index, 0, + LoadType::kI32Load, nullptr, false, false, true); + + // Compare against expected signature. + LOAD_INSTANCE_FIELD(formal_sig_id, IsorecursiveCanonicalTypes, kSystemPointerSize, pinned); - } else { - __ Load(LiftoffRegister(table), indirect_function_table, no_reg, - wasm::ObjectAccess::ToTagged( - WasmIndirectFunctionTable::kSigIdsOffset), - kPointerLoadType); + __ Load(LiftoffRegister(formal_sig_id), formal_sig_id, no_reg, + imm.sig_imm.index * kInt32Size, LoadType::kI32Load); + + Label* sig_mismatch_label = + AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapFuncSigMismatch); + __ DropValues(1); + + if (v8_flags.experimental_wasm_gc) { + Label success_label; + FREEZE_STATE(frozen); + __ emit_cond_jump(kEqual, &success_label, kI32, real_sig_id, + formal_sig_id, frozen); + __ emit_i32_cond_jumpi(kEqual, sig_mismatch_label, real_sig_id, -1, + frozen); + Register real_rtt = tmp3; + LOAD_INSTANCE_FIELD(real_rtt, IsolateRoot, kSystemPointerSize, pinned); + __ LoadFullPointer( + real_rtt, real_rtt, + IsolateData::root_slot_offset(RootIndex::kWasmCanonicalRtts)); + __ LoadTaggedPointer(real_rtt, real_rtt, real_sig_id, + ObjectAccess::ToTagged(WeakArrayList::kHeaderSize), + true); + // Remove the weak reference tag. + if (kSystemPointerSize == 4) { + __ emit_i32_andi(real_rtt, real_rtt, + static_cast(~kWeakHeapObjectMask)); + } else { + __ emit_i64_andi(LiftoffRegister(real_rtt), LiftoffRegister(real_rtt), + static_cast(~kWeakHeapObjectMask)); + } + // Constant-time subtyping check: load exactly one candidate RTT from + // the supertypes list. + // Step 1: load the WasmTypeInfo. + constexpr int kTypeInfoOffset = wasm::ObjectAccess::ToTagged( + Map::kConstructorOrBackPointerOrNativeContextOffset); + Register type_info = real_rtt; + __ LoadTaggedPointer(type_info, real_rtt, no_reg, kTypeInfoOffset); + // Step 2: check the list's length if needed. + uint32_t rtt_depth = + GetSubtypingDepth(decoder->module_, imm.sig_imm.index); + if (rtt_depth >= kMinimumSupertypeArraySize) { + LiftoffRegister list_length(formal_sig_id); + int offset = + ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesLengthOffset); + __ LoadSmiAsInt32(list_length, type_info, offset); + __ emit_i32_cond_jumpi(kUnsignedLessEqual, sig_mismatch_label, + list_length.gp(), rtt_depth, frozen); + } + // Step 3: load the candidate list slot, and compare it. + Register maybe_match = type_info; + __ LoadTaggedPointer( + maybe_match, type_info, no_reg, + ObjectAccess::ToTagged(WasmTypeInfo::kSupertypesOffset + + rtt_depth * kTaggedSize)); + Register formal_rtt = formal_sig_id; + LOAD_TAGGED_PTR_INSTANCE_FIELD(formal_rtt, ManagedObjectMaps, pinned); + __ LoadTaggedPointer( + formal_rtt, formal_rtt, no_reg, + wasm::ObjectAccess::ElementOffsetInTaggedFixedArray( + imm.sig_imm.index)); + __ emit_cond_jump(kUnequal, sig_mismatch_label, kRtt, formal_rtt, + maybe_match, frozen); + + __ bind(&success_label); + } else { + FREEZE_STATE(trapping); + __ emit_cond_jump(kUnequal, sig_mismatch_label, kI32, real_sig_id, + formal_sig_id, trapping); + } } - static_assert((1 << 2) == kInt32Size); - __ Load(LiftoffRegister(scratch), table, index, 0, LoadType::kI32Load, - nullptr, false, false, true); - - // Compare against expected signature. - LOAD_INSTANCE_FIELD(tmp_const, IsorecursiveCanonicalTypes, - kSystemPointerSize, pinned); - __ Load(LiftoffRegister(tmp_const), tmp_const, no_reg, - imm.sig_imm.index * kInt32Size, LoadType::kI32Load); - - Label* sig_mismatch_label = - AddOutOfLineTrap(decoder, WasmCode::kThrowWasmTrapFuncSigMismatch); - __ DropValues(1); { - FREEZE_STATE(trapping); - __ emit_cond_jump(kUnequal, sig_mismatch_label, kIntPtrKind, scratch, - tmp_const, trapping); - } + CODE_COMMENT("Execute indirect call"); - CODE_COMMENT("Execute indirect call"); - // At this point {index} has already been multiplied by kTaggedSize. + Register function_instance = tmp1; + Register function_target = tmp2; - // Load the instance from {instance->ift_instances[key]} - if (imm.table_imm.index == 0) { - LOAD_TAGGED_PTR_INSTANCE_FIELD(table, IndirectFunctionTableRefs, pinned); - } else { - __ LoadTaggedPointer( - table, indirect_function_table, no_reg, - wasm::ObjectAccess::ToTagged(WasmIndirectFunctionTable::kRefsOffset)); - } - __ LoadTaggedPointer(tmp_const, table, index, - ObjectAccess::ElementOffsetInTaggedFixedArray(0), - true); + // Load the instance from {instance->ift_instances[key]} + if (imm.table_imm.index == 0) { + LOAD_TAGGED_PTR_INSTANCE_FIELD(function_instance, + IndirectFunctionTableRefs, pinned); + } else { + __ LoadTaggedPointer(function_instance, indirect_function_table, no_reg, + wasm::ObjectAccess::ToTagged( + WasmIndirectFunctionTable::kRefsOffset)); + } + __ LoadTaggedPointer(function_instance, function_instance, index, + ObjectAccess::ElementOffsetInTaggedFixedArray(0), + true); - Register* explicit_instance = &tmp_const; + // Load the target from {instance->ift_targets[key]} + if (imm.table_imm.index == 0) { + LOAD_INSTANCE_FIELD(function_target, IndirectFunctionTableTargets, + kSystemPointerSize, pinned); + } else { + __ Load(LiftoffRegister(function_target), indirect_function_table, + no_reg, + wasm::ObjectAccess::ToTagged( + WasmIndirectFunctionTable::kTargetsOffset), + kPointerLoadType); + } + __ Load(LiftoffRegister(function_target), function_target, index, 0, + kPointerLoadType, nullptr, false, false, true); - // Load the target from {instance->ift_targets[key]} - if (imm.table_imm.index == 0) { - LOAD_INSTANCE_FIELD(table, IndirectFunctionTableTargets, - kSystemPointerSize, pinned); - } else { - __ Load(LiftoffRegister(table), indirect_function_table, no_reg, - wasm::ObjectAccess::ToTagged( - WasmIndirectFunctionTable::kTargetsOffset), - kPointerLoadType); - } - __ Load(LiftoffRegister(scratch), table, index, 0, kPointerLoadType, - nullptr, false, false, true); + auto call_descriptor = + compiler::GetWasmCallDescriptor(compilation_zone_, imm.sig); + call_descriptor = + GetLoweredCallDescriptor(compilation_zone_, call_descriptor); - auto call_descriptor = - compiler::GetWasmCallDescriptor(compilation_zone_, imm.sig); - call_descriptor = - GetLoweredCallDescriptor(compilation_zone_, call_descriptor); + __ PrepareCall(&sig, call_descriptor, &function_target, + &function_instance); + if (tail_call) { + __ PrepareTailCall( + static_cast(call_descriptor->ParameterSlotCount()), + static_cast( + call_descriptor->GetStackParameterDelta(descriptor_))); + __ TailCallIndirect(function_target); + } else { + source_position_table_builder_.AddPosition( + __ pc_offset(), SourcePosition(decoder->position()), true); + __ CallIndirect(&sig, call_descriptor, function_target); - Register target = scratch; - __ PrepareCall(&sig, call_descriptor, &target, explicit_instance); - if (tail_call) { - __ PrepareTailCall( - static_cast(call_descriptor->ParameterSlotCount()), - static_cast( - call_descriptor->GetStackParameterDelta(descriptor_))); - __ TailCallIndirect(target); - } else { - source_position_table_builder_.AddPosition( - __ pc_offset(), SourcePosition(decoder->position()), true); - __ CallIndirect(&sig, call_descriptor, target); - - FinishCall(decoder, &sig, call_descriptor); + FinishCall(decoder, &sig, call_descriptor); + } } } diff --git a/test/cctest/wasm/test-gc.cc b/test/cctest/wasm/test-gc.cc index 6bb4bd1a77..864c4dac23 100644 --- a/test/cctest/wasm/test-gc.cc +++ b/test/cctest/wasm/test-gc.cc @@ -97,7 +97,7 @@ class WasmGCTester { } byte DefineSignature(FunctionSig* sig, uint32_t supertype = kNoSuperType) { - return builder_.AddSignature(sig, supertype); + return builder_.ForceAddSignature(sig, supertype); } byte DefineTable(ValueType type, uint32_t min_size, uint32_t max_size) { @@ -1986,6 +1986,7 @@ WASM_COMPILED_EXEC_TEST(GlobalInitReferencingGlobal) { WASM_COMPILED_EXEC_TEST(GCTables) { WasmGCTester tester(execution_tier); + tester.builder()->StartRecursiveTypeGroup(); byte super_struct = tester.DefineStruct({F(kWasmI32, false)}); byte sub_struct = tester.DefineStruct({F(kWasmI32, false), F(kWasmI32, true)}, super_struct); @@ -1995,6 +1996,8 @@ WASM_COMPILED_EXEC_TEST(GCTables) { FunctionSig* sub_sig = FunctionSig::Build(tester.zone(), {kWasmI32}, {refNull(super_struct)}); byte sub_sig_index = tester.DefineSignature(sub_sig, super_sig_index); + byte unrelated_sig_index = tester.DefineSignature(sub_sig, super_sig_index); + tester.builder()->EndRecursiveTypeGroup(); tester.DefineTable(refNull(super_sig_index), 10, 10); @@ -2012,8 +2015,8 @@ WASM_COMPILED_EXEC_TEST(GCTables) { tester.sigs.i_v(), {}, {WASM_TABLE_SET(0, WASM_I32V(0), WASM_REF_NULL(super_sig_index)), WASM_TABLE_SET(0, WASM_I32V(1), WASM_REF_FUNC(super_func)), - WASM_TABLE_SET(0, WASM_I32V(2), WASM_REF_FUNC(sub_func)), WASM_I32V(0), - WASM_END}); + WASM_TABLE_SET(0, WASM_I32V(2), WASM_REF_FUNC(sub_func)), // -- + WASM_I32V(0), WASM_END}); byte super_struct_producer = tester.DefineFunction( FunctionSig::Build(tester.zone(), {ref(super_struct)}, {}), {}, @@ -2045,12 +2048,20 @@ WASM_COMPILED_EXEC_TEST(GCTables) { WASM_CALL_FUNCTION0(super_struct_producer), WASM_I32V(2)), WASM_END}); + // Calling with a signature that is a subtype of the type of the table should + // work, provided the entry has a subtype of the declared signature. + byte call_table_subtype_entry_subtype = tester.DefineFunction( + tester.sigs.i_v(), {}, + {WASM_CALL_INDIRECT(super_sig_index, + WASM_CALL_FUNCTION0(sub_struct_producer), + WASM_I32V(2)), + WASM_END}); // Calling with a signature that is mismatched to that of the entry should // trap. byte call_type_mismatch = tester.DefineFunction( tester.sigs.i_v(), {}, - {WASM_CALL_INDIRECT(super_sig_index, - WASM_CALL_FUNCTION0(sub_struct_producer), + {WASM_CALL_INDIRECT(unrelated_sig_index, + WASM_CALL_FUNCTION0(super_struct_producer), WASM_I32V(2)), WASM_END}); // Getting a table element and then calling it with call_ref should work. @@ -2072,6 +2083,7 @@ WASM_COMPILED_EXEC_TEST(GCTables) { tester.CheckHasThrown(call_null); tester.CheckResult(call_same_type, 18); tester.CheckResult(call_subtype, -5); + tester.CheckResult(call_table_subtype_entry_subtype, 7); tester.CheckHasThrown(call_type_mismatch); tester.CheckResult(table_get_and_call_ref, 7); }