From e99a5c033ea2c933d0403b97db590212fb5a2542 Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Wed, 24 Jul 2024 14:38:19 +0200 Subject: [PATCH] spirv-link: allow linking functions with different pointer arguments (#5534) * linker: run dedup earlier Otherwise `linkings_to_do` might end up with stale IDs. * linker: allow linking functions with different pointer arguments Since llvm-17 there are no typed pointers and hte SPIRV-LLVM-Translator doesn't know the function signature of imported functions. I'm investigating different ways of solving this problem and adding an option to work around it inside `spirv-link` is one of those. The code is almost complete, just I'm having troubles constructing the bitcast to cast the pointer parameters to the final type. Closes: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/2153 * test/linker: add tests to test the AllowPtrTypeMismatch feature --- include/spirv-tools/linker.hpp | 7 +- source/link/linker.cpp | 123 +++++- .../link/matching_imports_to_exports_test.cpp | 393 +++++++++++++++++- tools/link/linker.cpp | 24 +- 4 files changed, 498 insertions(+), 49 deletions(-) diff --git a/include/spirv-tools/linker.hpp b/include/spirv-tools/linker.hpp index 6ba6e9654..9037b9488 100644 --- a/include/spirv-tools/linker.hpp +++ b/include/spirv-tools/linker.hpp @@ -16,7 +16,6 @@ #define INCLUDE_SPIRV_TOOLS_LINKER_HPP_ #include - #include #include @@ -63,11 +62,17 @@ class SPIRV_TOOLS_EXPORT LinkerOptions { use_highest_version_ = use_highest_vers; } + bool GetAllowPtrTypeMismatch() const { return allow_ptr_type_mismatch_; } + void SetAllowPtrTypeMismatch(bool allow_ptr_type_mismatch) { + allow_ptr_type_mismatch_ = allow_ptr_type_mismatch; + } + private: bool create_library_{false}; bool verify_ids_{false}; bool allow_partial_linkage_{false}; bool use_highest_version_{false}; + bool allow_ptr_type_mismatch_{false}; }; // Links one or more SPIR-V modules into a new SPIR-V module. That is, combine diff --git a/source/link/linker.cpp b/source/link/linker.cpp index 58930e452..e6aa72e32 100644 --- a/source/link/linker.cpp +++ b/source/link/linker.cpp @@ -31,6 +31,7 @@ #include "source/opt/build_module.h" #include "source/opt/compact_ids_pass.h" #include "source/opt/decoration_manager.h" +#include "source/opt/ir_builder.h" #include "source/opt/ir_loader.h" #include "source/opt/pass_manager.h" #include "source/opt/remove_duplicates_pass.h" @@ -46,12 +47,14 @@ namespace spvtools { namespace { using opt::Instruction; +using opt::InstructionBuilder; using opt::IRContext; using opt::Module; using opt::PassManager; using opt::RemoveDuplicatesPass; using opt::analysis::DecorationManager; using opt::analysis::DefUseManager; +using opt::analysis::Function; using opt::analysis::Type; using opt::analysis::TypeManager; @@ -126,6 +129,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer, // checked. spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, const LinkageTable& linkings_to_do, + bool allow_ptr_type_mismatch, opt::IRContext* context); // Remove linkage specific instructions, such as prototypes of imported @@ -502,6 +506,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer, spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, const LinkageTable& linkings_to_do, + bool allow_ptr_type_mismatch, opt::IRContext* context) { spv_position_t position = {}; @@ -513,7 +518,34 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, type_manager.GetType(linking_entry.imported_symbol.type_id); Type* exported_symbol_type = type_manager.GetType(linking_entry.exported_symbol.type_id); - if (!(*imported_symbol_type == *exported_symbol_type)) + if (!(*imported_symbol_type == *exported_symbol_type)) { + Function* imported_symbol_type_func = imported_symbol_type->AsFunction(); + Function* exported_symbol_type_func = exported_symbol_type->AsFunction(); + + if (imported_symbol_type_func && exported_symbol_type_func) { + const auto& imported_params = imported_symbol_type_func->param_types(); + const auto& exported_params = exported_symbol_type_func->param_types(); + // allow_ptr_type_mismatch allows linking functions where the pointer + // type of arguments doesn't match. Everything else still needs to be + // equal. This is to workaround LLVM-17+ not having typed pointers and + // generated SPIR-Vs not knowing the actual pointer types in some cases. + if (allow_ptr_type_mismatch && + imported_params.size() == exported_params.size()) { + bool correct = true; + for (size_t i = 0; i < imported_params.size(); i++) { + const auto& imported_param = imported_params[i]; + const auto& exported_param = exported_params[i]; + + if (!imported_param->IsSame(exported_param) && + (imported_param->kind() != Type::kPointer || + exported_param->kind() != Type::kPointer)) { + correct = false; + break; + } + } + if (correct) continue; + } + } return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY) << "Type mismatch on symbol \"" << linking_entry.imported_symbol.name @@ -521,6 +553,7 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer, << linking_entry.imported_symbol.id << " and exported variable/function %" << linking_entry.exported_symbol.id << "."; + } } // Ensure the import and export decorations are similar @@ -696,6 +729,57 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer, return SPV_SUCCESS; } +spv_result_t FixFunctionCallTypes(opt::IRContext& context, + const LinkageTable& linkings) { + auto mod = context.module(); + const auto type_manager = context.get_type_mgr(); + const auto def_use_mgr = context.get_def_use_mgr(); + + for (auto& func : *mod) { + func.ForEachInst([&](Instruction* inst) { + if (inst->opcode() != spv::Op::OpFunctionCall) return; + opt::Operand& target = inst->GetInOperand(0); + + // only fix calls to imported functions + auto linking = std::find_if( + linkings.begin(), linkings.end(), [&](const auto& entry) { + return entry.exported_symbol.id == target.AsId(); + }); + if (linking == linkings.end()) return; + + auto builder = InstructionBuilder(&context, inst); + for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { + auto exported_func_param = + def_use_mgr->GetDef(linking->exported_symbol.parameter_ids[i - 1]); + const Type* target_type = + type_manager->GetType(exported_func_param->type_id()); + if (target_type->kind() != Type::kPointer) continue; + + opt::Operand& arg = inst->GetInOperand(i); + const Type* param_type = + type_manager->GetType(def_use_mgr->GetDef(arg.AsId())->type_id()); + + // No need to cast if it already matches + if (*param_type == *target_type) continue; + + auto new_id = context.TakeNextId(); + + // cast to the expected pointer type + builder.AddInstruction(MakeUnique( + &context, spv::Op::OpBitcast, exported_func_param->type_id(), + new_id, + opt::Instruction::OperandList( + {{SPV_OPERAND_TYPE_ID, {arg.AsId()}}}))); + + inst->SetInOperand(i, {new_id}); + } + }); + } + context.InvalidateAnalyses(opt::IRContext::kAnalysisDefUse | + opt::IRContext::kAnalysisInstrToBlockMapping); + return SPV_SUCCESS; +} + } // namespace spv_result_t Link(const Context& context, @@ -773,7 +857,14 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries, if (res != SPV_SUCCESS) return res; } - // Phase 4: Find the import/export pairs + // Phase 4: Remove duplicates + PassManager manager; + manager.SetMessageConsumer(consumer); + manager.AddPass(); + opt::Pass::Status pass_res = manager.Run(&linked_context); + if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; + + // Phase 5: Find the import/export pairs LinkageTable linkings_to_do; res = GetImportExportPairs(consumer, linked_context, *linked_context.get_def_use_mgr(), @@ -781,18 +872,12 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries, options.GetAllowPartialLinkage(), &linkings_to_do); if (res != SPV_SUCCESS) return res; - // Phase 5: Ensure the import and export have the same types and decorations. - res = - CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context); + // Phase 6: Ensure the import and export have the same types and decorations. + res = CheckImportExportCompatibility(consumer, linkings_to_do, + options.GetAllowPtrTypeMismatch(), + &linked_context); if (res != SPV_SUCCESS) return res; - // Phase 6: Remove duplicates - PassManager manager; - manager.SetMessageConsumer(consumer); - manager.AddPass(); - opt::Pass::Status pass_res = manager.Run(&linked_context); - if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; - // Phase 7: Remove all names and decorations of import variables/functions for (const auto& linking_entry : linkings_to_do) { linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id); @@ -815,21 +900,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries, &linked_context); if (res != SPV_SUCCESS) return res; - // Phase 10: Compact the IDs used in the module + // Phase 10: Optionally fix function call types + if (options.GetAllowPtrTypeMismatch()) { + res = FixFunctionCallTypes(linked_context, linkings_to_do); + if (res != SPV_SUCCESS) return res; + } + + // Phase 11: Compact the IDs used in the module manager.AddPass(); pass_res = manager.Run(&linked_context); if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; - // Phase 11: Recompute EntryPoint variables + // Phase 12: Recompute EntryPoint variables manager.AddPass(); pass_res = manager.Run(&linked_context); if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA; - // Phase 12: Warn if SPIR-V limits were exceeded + // Phase 13: Warn if SPIR-V limits were exceeded res = VerifyLimits(consumer, linked_context); if (res != SPV_SUCCESS) return res; - // Phase 13: Output the module + // Phase 14: Output the module linked_context.module()->ToBinary(linked_binary, true); return SPV_SUCCESS; diff --git a/test/link/matching_imports_to_exports_test.cpp b/test/link/matching_imports_to_exports_test.cpp index 6b02fc46d..c7c962fa2 100644 --- a/test/link/matching_imports_to_exports_test.cpp +++ b/test/link/matching_imports_to_exports_test.cpp @@ -174,14 +174,18 @@ OpDecorate %1 LinkageAttributes "foo" Export %1 = OpVariable %2 Uniform %3 )"; - spvtest::Binary linked_binary; - EXPECT_EQ(SPV_ERROR_INVALID_BINARY, - AssembleAndLink({body1, body2}, &linked_binary)) - << GetErrorMessage(); - EXPECT_THAT( - GetErrorMessage(), - HasSubstr("Type mismatch on symbol \"foo\" between imported " - "variable/function %1 and exported variable/function %4")); + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %4")); + } } TEST_F(MatchingImportsToExports, MultipleDefinitions) { @@ -216,13 +220,17 @@ OpDecorate %1 LinkageAttributes "foo" Export %1 = OpVariable %2 Uniform %3 )"; - spvtest::Binary linked_binary; - EXPECT_EQ(SPV_ERROR_INVALID_BINARY, - AssembleAndLink({body1, body2, body3}, &linked_binary)) - << GetErrorMessage(); - EXPECT_THAT(GetErrorMessage(), - HasSubstr("Too many external references, 2, were found " - "for \"foo\".")); + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2, body3}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT(GetErrorMessage(), + HasSubstr("Too many external references, 2, were found " + "for \"foo\".")); + } } TEST_F(MatchingImportsToExports, SameNameDifferentTypes) { @@ -289,14 +297,18 @@ OpDecorate %1 LinkageAttributes "foo" Export %1 = OpVariable %2 Uniform %3 )"; - spvtest::Binary linked_binary; - EXPECT_EQ(SPV_ERROR_INVALID_BINARY, - AssembleAndLink({body1, body2}, &linked_binary)) - << GetErrorMessage(); - EXPECT_THAT( - GetErrorMessage(), - HasSubstr("Type mismatch on symbol \"foo\" between imported " - "variable/function %1 and exported variable/function %4")); + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %4")); + } } TEST_F(MatchingImportsToExports, @@ -557,5 +569,340 @@ OpFunctionEnd EXPECT_EQ(expected_res, res_body); } +TEST_F(MatchingImportsToExports, FunctionCall) { + const std::string body1 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %3 "param" +OpDecorate %1 LinkageAttributes "foo" Import + %5 = OpTypeVoid + %6 = OpTypeInt 32 0 + %9 = OpTypePointer Function %6 + %7 = OpTypeFunction %5 %9 + %1 = OpFunction %5 None %7 + %3 = OpFunctionParameter %9 +OpFunctionEnd + %8 = OpFunction %5 None %7 + %4 = OpFunctionParameter %9 +%10 = OpLabel +%11 = OpFunctionCall %5 %1 %4 +OpReturn +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpDecorate %1 LinkageAttributes "foo" Export +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%7 = OpTypePointer Function %4 +%5 = OpTypeFunction %3 %7 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %7 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + ASSERT_EQ(SPV_SUCCESS, + AssembleAndLink({body1, body2}, &linked_binary, options)) + << GetErrorMessage(); + + const std::string expected_res = R"(OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpModuleProcessed "Linked by SPIR-V Tools Linker" +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypePointer Function %4 +%6 = OpTypeFunction %3 %5 +%7 = OpFunction %3 None %6 +%8 = OpFunctionParameter %5 +%9 = OpLabel +%10 = OpFunctionCall %3 %1 %8 +OpReturn +OpFunctionEnd +%1 = OpFunction %3 None %6 +%2 = OpFunctionParameter %5 +%11 = OpLabel +OpReturn +OpFunctionEnd +)"; + + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + ASSERT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); + } +} + +TEST_F(MatchingImportsToExports, FunctionSignatureMismatchPointer) { + const std::string body1 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %3 "param" +OpDecorate %1 LinkageAttributes "foo" Import + %5 = OpTypeVoid + %6 = OpTypeInt 8 0 + %9 = OpTypePointer Function %6 + %7 = OpTypeFunction %5 %9 + %1 = OpFunction %5 None %7 + %3 = OpFunctionParameter %9 +OpFunctionEnd + %8 = OpFunction %5 None %7 + %4 = OpFunctionParameter %9 +%10 = OpLabel +%11 = OpFunctionCall %5 %1 %4 +OpReturn +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpDecorate %1 LinkageAttributes "foo" Export +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%7 = OpTypePointer Function %4 +%5 = OpTypeFunction %3 %7 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %7 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + spvtest::Binary linked_binary; + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %11")); + + LinkerOptions options; + options.SetAllowPtrTypeMismatch(true); + ASSERT_EQ(SPV_SUCCESS, + AssembleAndLink({body1, body2}, &linked_binary, options)) + << GetErrorMessage(); + + const std::string expected_res = R"(OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpModuleProcessed "Linked by SPIR-V Tools Linker" +%3 = OpTypeVoid +%4 = OpTypeInt 8 0 +%5 = OpTypePointer Function %4 +%6 = OpTypeFunction %3 %5 +%7 = OpTypeInt 32 0 +%8 = OpTypePointer Function %7 +%9 = OpTypeFunction %3 %8 +%10 = OpFunction %3 None %6 +%11 = OpFunctionParameter %5 +%12 = OpLabel +%13 = OpBitcast %8 %11 +%14 = OpFunctionCall %3 %1 %13 +OpReturn +OpFunctionEnd +%1 = OpFunction %3 None %9 +%2 = OpFunctionParameter %8 +%15 = OpLabel +OpReturn +OpFunctionEnd +)"; + std::string res_body; + SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER); + ASSERT_EQ(SPV_SUCCESS, Disassemble(linked_binary, &res_body)) + << GetErrorMessage(); + EXPECT_EQ(expected_res, res_body); +} + +TEST_F(MatchingImportsToExports, FunctionSignatureMismatchValue) { + const std::string body1 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %3 "param" +OpDecorate %1 LinkageAttributes "foo" Import + %5 = OpTypeVoid + %6 = OpTypeInt 8 0 + %7 = OpTypeFunction %5 %6 + %1 = OpFunction %5 None %7 + %3 = OpFunctionParameter %6 +OpFunctionEnd + %8 = OpFunction %5 None %7 + %4 = OpFunctionParameter %6 +%10 = OpLabel +%11 = OpFunctionCall %5 %1 %4 +OpReturn +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpDecorate %1 LinkageAttributes "foo" Export +%3 = OpTypeVoid +%4 = OpTypeInt 32 0 +%5 = OpTypeFunction %3 %4 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %4 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %10")); + } +} + +TEST_F(MatchingImportsToExports, FunctionSignatureMismatchTypePointerInt) { + const std::string body1 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %3 "param" +OpDecorate %1 LinkageAttributes "foo" Import + %5 = OpTypeVoid + %6 = OpTypeInt 64 0 + %7 = OpTypeFunction %5 %6 + %1 = OpFunction %5 None %7 + %3 = OpFunctionParameter %6 +OpFunctionEnd + %8 = OpFunction %5 None %7 + %4 = OpFunctionParameter %6 +%10 = OpLabel +%11 = OpFunctionCall %5 %1 %4 +OpReturn +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpDecorate %1 LinkageAttributes "foo" Export +%3 = OpTypeVoid +%4 = OpTypeInt 64 0 +%7 = OpTypePointer Function %4 +%5 = OpTypeFunction %3 %7 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %7 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %10")); + } +} + +TEST_F(MatchingImportsToExports, FunctionSignatureMismatchTypeIntPointer) { + const std::string body1 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %3 "param" +OpDecorate %1 LinkageAttributes "foo" Import + %5 = OpTypeVoid + %6 = OpTypeInt 64 0 + %9 = OpTypePointer Function %6 + %7 = OpTypeFunction %5 %9 + %1 = OpFunction %5 None %7 + %3 = OpFunctionParameter %9 +OpFunctionEnd + %8 = OpFunction %5 None %7 + %4 = OpFunctionParameter %9 +%10 = OpLabel +%11 = OpFunctionCall %5 %1 %4 +OpReturn +OpFunctionEnd +)"; + const std::string body2 = R"( +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpMemoryModel Physical64 OpenCL +OpName %1 "foo" +OpName %2 "param" +OpDecorate %1 LinkageAttributes "foo" Export +%3 = OpTypeVoid +%4 = OpTypeInt 64 0 +%5 = OpTypeFunction %3 %4 +%1 = OpFunction %3 None %5 +%2 = OpFunctionParameter %4 +%6 = OpLabel +OpReturn +OpFunctionEnd +)"; + + LinkerOptions options; + for (int i = 0; i < 2; i++) { + spvtest::Binary linked_binary; + options.SetAllowPtrTypeMismatch(i == 1); + ASSERT_EQ(SPV_ERROR_INVALID_BINARY, + AssembleAndLink({body1, body2}, &linked_binary)) + << GetErrorMessage(); + EXPECT_THAT( + GetErrorMessage(), + HasSubstr("Type mismatch on symbol \"foo\" between imported " + "variable/function %1 and exported variable/function %11")); + } +} + } // namespace } // namespace spvtools diff --git a/tools/link/linker.cpp b/tools/link/linker.cpp index f3898aab0..7dbd8596f 100644 --- a/tools/link/linker.cpp +++ b/tools/link/linker.cpp @@ -48,6 +48,10 @@ Options (in lexicographical order): --allow-partial-linkage Allow partial linkage by accepting imported symbols to be unresolved. + --allow-pointer-mismatch + Allow pointer function parameters to mismatch the target link + target. This is useful to workaround lost correct parameter type + information due to LLVM's opaque pointers. --create-library Link the binaries into a library, keeping all exported symbols. -h, --help @@ -77,15 +81,16 @@ Options (in lexicographical order): } // namespace // clang-format off -FLAG_SHORT_bool( h, /* default_value= */ false, /* required= */ false); -FLAG_LONG_bool( help, /* default_value= */ false, /* required= */false); -FLAG_LONG_bool( version, /* default_value= */ false, /* required= */ false); -FLAG_LONG_bool( verify_ids, /* default_value= */ false, /* required= */ false); -FLAG_LONG_bool( create_library, /* default_value= */ false, /* required= */ false); -FLAG_LONG_bool( allow_partial_linkage, /* default_value= */ false, /* required= */ false); -FLAG_SHORT_string(o, /* default_value= */ "", /* required= */ false); -FLAG_LONG_string( target_env, /* default_value= */ kDefaultEnvironment, /* required= */ false); -FLAG_LONG_bool( use_highest_version, /* default_value= */ false, /* required= */ false); +FLAG_SHORT_bool( h, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( help, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( version, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( verify_ids, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( create_library, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( allow_partial_linkage, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( allow_pointer_mismatch, /* default_value= */ false, /* required= */ false); +FLAG_SHORT_string(o, /* default_value= */ "", /* required= */ false); +FLAG_LONG_string( target_env, /* default_value= */ kDefaultEnvironment, /* required= */ false); +FLAG_LONG_bool( use_highest_version, /* default_value= */ false, /* required= */ false); // clang-format on int main(int, const char* argv[]) { @@ -126,6 +131,7 @@ int main(int, const char* argv[]) { spvtools::LinkerOptions options; options.SetAllowPartialLinkage(flags::allow_partial_linkage.value()); + options.SetAllowPtrTypeMismatch(flags::allow_pointer_mismatch.value()); options.SetCreateLibrary(flags::create_library.value()); options.SetVerifyIds(flags::verify_ids.value()); options.SetUseHighestVersion(flags::use_highest_version.value());