Implement source extraction logic for spirv-objdump (#5150)

* dump: add ability to extract HLSL from module

Only adds the code to extract the source from the module. The extracted files are written to the given directory.
Android NDK21 has C++17 support, but no std::filesystem support (that requires NDK22). For now, the tool is not built on Android.
This might be something to revisit if the need to have this tool on Android arises.
This commit is contained in:
Nathan Gauër 2023-03-22 23:57:18 +01:00 committed by GitHub
parent 5aab2a8fef
commit 5f4e694e10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 479 additions and 38 deletions

View File

@ -318,6 +318,7 @@ cc_binary(
":tools_util",
":spirv_tools_internal",
":spirv_tools_opt_internal",
"@spirv_headers//:spirv_cpp_headers",
],
)

View File

@ -26,4 +26,6 @@ add_spvtools_unittest(
DEFINES TESTING=1)
add_subdirectory(opt)
# The objdump subdirectory is skipped on Android: the tool relies on
# std::filesystem, which Android NDK21 does not provide.
if(NOT (${CMAKE_SYSTEM_NAME} STREQUAL "Android"))
add_subdirectory(objdump)
endif ()

View File

@ -27,7 +27,7 @@ namespace {
constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
std::pair<bool, std::unordered_map<std::string, std::string>> extractSource(
std::pair<bool, std::unordered_map<std::string, std::string>> ExtractSource(
const std::string& spv_source) {
std::unique_ptr<spvtools::opt::IRContext> ctx = spvtools::BuildModule(
kDefaultEnvironment, spvtools::utils::CLIMessageConsumer, spv_source,
@ -36,7 +36,7 @@ std::pair<bool, std::unordered_map<std::string, std::string>> extractSource(
std::vector<uint32_t> binary;
ctx->module()->ToBinary(&binary, /* skip_nop = */ false);
std::unordered_map<std::string, std::string> output;
bool result = extract_source_from_module(binary, &output);
bool result = ExtractSourceFromModule(binary, &output);
return std::make_pair(result, std::move(output));
}
@ -57,7 +57,209 @@ TEST(ExtractSourceTest, no_debug) {
OpFunctionEnd
)";
auto[success, result] = extractSource(source);
auto[success, result] = ExtractSource(source);
ASSERT_TRUE(success);
ASSERT_TRUE(result.size() == 0);
}
// A single OpSource carrying both a file name and inline code must yield
// exactly one entry, keyed by that file name.
TEST(ExtractSourceTest, SimpleSource) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2 "[numthreads(1, 1, 1)] void compute_1(){ }"
OpName %1 "compute_1"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 1 41
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"],
            "[numthreads(1, 1, 1)] void compute_1(){ }");
}
// Code split between OpSource and a following OpSourceContinued must be
// concatenated into a single entry.
TEST(ExtractSourceTest, SourceContinued) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2 "[numthreads(1, 1, 1)] "
OpSourceContinued "void compute_1(){ }"
OpName %1 "compute_1"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 1 41
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"],
            "[numthreads(1, 1, 1)] void compute_1(){ }");
}
// An OpSource that names a file but embeds no code must still produce an
// entry for that file, with empty content.
TEST(ExtractSourceTest, OnlyFilename) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2
OpName %1 "compute_1"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 1 41
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"], "");
}
// Two OpSource instructions referring to two different file names must yield
// two independent entries.
TEST(ExtractSourceTest, MultipleFiles) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute1.hlsl"
%3 = OpString "compute2.hlsl"
OpSource HLSL 660 %2 "some instruction"
OpSource HLSL 660 %3 "some other instruction"
OpName %1 "compute_1"
%4 = OpTypeVoid
%5 = OpTypeFunction %4
%1 = OpFunction %4 None %5
%6 = OpLabel
OpLine %2 1 41
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 2u);
  ASSERT_EQ(result["compute1.hlsl"], "some instruction");
  ASSERT_EQ(result["compute2.hlsl"], "some other instruction");
}
// Line breaks embedded in the OpSource literal must be preserved verbatim in
// the extracted code.
TEST(ExtractSourceTest, MultilineCode) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2 "[numthreads(1, 1, 1)]
void compute_1() {
}
"
OpName %1 "compute_1"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 3 1
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"],
            "[numthreads(1, 1, 1)]\nvoid compute_1() {\n}\n");
}
// When the OpString used as file name is empty, the extracted source must be
// stored under the generated name "unnamed-0.hlsl".
TEST(ExtractSourceTest, EmptyFilename) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute_1"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString ""
OpSource HLSL 660 %2 "void compute(){}"
OpName %1 "compute_1"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 3 1
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["unnamed-0.hlsl"], "void compute(){}");
}
// An escaped quote written in the assembly literal must come out of the
// extraction as a plain quote, without the backslash.
TEST(ExtractSourceTest, EscapeEscaped) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2 "// check \" escape removed"
OpName %1 "compute"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 6 1
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"], "// check \" escape removed");
}
// An OpSource without the optional source operand must succeed and map the
// file name to an empty string.
TEST(ExtractSourceTest, OpSourceWithNoSource) {
  std::string source = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "compute"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "compute.hlsl"
OpSource HLSL 660 %2
OpName %1 "compute"
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%1 = OpFunction %3 None %4
%5 = OpLabel
OpLine %2 6 1
OpReturn
OpFunctionEnd
)";

  auto [success, result] = ExtractSource(source);
  ASSERT_TRUE(success);
  // ASSERT_EQ prints both operands on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(result.size(), 1u);
  ASSERT_EQ(result["compute.hlsl"], "");
}

View File

@ -16,7 +16,6 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
add_subdirectory(lesspipe)
endif()
add_subdirectory(emacs)
#add_subdirectory(objdump)
# Add a SPIR-V Tools command line tool. Signature:
# add_spvtools_tool(
@ -66,7 +65,10 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
LIBS ${SPIRV_TOOLS_FULL_VISIBILITY})
target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt
spirv-cfg spirv-link spirv-lint)
if(NOT (${CMAKE_SYSTEM_NAME} STREQUAL "Android"))
add_spvtools_tool(TARGET spirv-objdump
SRCS objdump/objdump.cpp
objdump/extract_source.cpp
@ -75,9 +77,9 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
LIBS ${SPIRV_TOOLS_FULL_VISIBILITY})
target_include_directories(spirv-objdump PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-objdump)
endif()
set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt
spirv-cfg spirv-link spirv-lint spirv-objdump)
if(NOT (${CMAKE_SYSTEM_NAME} STREQUAL "iOS"))
set(SPIRV_INSTALL_TARGETS ${SPIRV_INSTALL_TARGETS} spirv-reduce)
endif()

View File

@ -14,43 +14,200 @@
#include "extract_source.h"
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>
#include "source/opt/log.h"
#include "spirv-tools/libspirv.hpp"
#include "spirv/unified1/spirv.hpp"
#include "tools/util/cli_consumer.h"
namespace {
constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
// Extract a string literal from a given range.
// Appends to `output` all the characters from `begin` up to (excluding) the
// first '\0' it encounters, while removing escape patterns: each '\\' is
// dropped and the character following it is copied verbatim.
// Not finding a '\0' before reaching `end` fails the extraction; this includes
// a range whose last character is a lone '\\'.
//
// Returns SPV_SUCCESS if the extraction succeeded. Otherwise reports an error
// at `loc` and returns SPV_ERROR_INVALID_BINARY; nothing is appended to
// `output` in that case.
spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
                                  const char* end, std::string* output) {
  const size_t sourceLength = static_cast<size_t>(std::distance(begin, end));
  std::string unescaped;
  unescaped.reserve(sourceLength);

  for (size_t readIndex = 0; readIndex < sourceLength; ++readIndex) {
    char read = begin[readIndex];
    if (read == '\0') {
      output->append(unescaped);
      return SPV_SUCCESS;
    }

    if (read == '\\') {
      ++readIndex;
      // A backslash in the very last position has no escaped character after
      // it: stop here instead of reading one byte past `end`.
      if (readIndex >= sourceLength) {
        break;
      }
      read = begin[readIndex];
    }
    unescaped.push_back(read);
  }

  spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
                  "Missing NULL terminator for literal string.");
  return SPV_ERROR_INVALID_BINARY;
}
// Decodes the string carried by an OpString instruction and appends it to
// `output`.
// The instruction must have exactly two operands; the literal is read from the
// words covered by the second one. Anything else is reported at `loc` and
// rejected with SPV_ERROR_INVALID_BINARY.
spv_result_t extractOpString(const spv_position_t& loc,
                             const spv_parsed_instruction_t& instruction,
                             std::string* output) {
  assert(output != nullptr);
  assert(instruction.opcode == spv::Op::OpString);

  if (instruction.num_operands == 2) {
    const auto& literal = instruction.operands[1];
    const auto* firstWord = instruction.words + literal.offset;
    const auto* pastLastWord = firstWord + literal.num_words;
    return ExtractStringLiteral(loc, reinterpret_cast<const char*>(firstWord),
                                reinterpret_cast<const char*>(pastLastWord),
                                output);
  }

  spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
                  "Missing operands for OpString.");
  return SPV_ERROR_INVALID_BINARY;
}
// Decodes the string carried by an OpSourceContinued instruction and appends
// it to `output`.
// The instruction must have exactly one operand, which holds the literal.
// Anything else is reported at `loc` and rejected with
// SPV_ERROR_INVALID_BINARY.
spv_result_t extractOpSourceContinued(
    const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
    std::string* output) {
  assert(output != nullptr);
  assert(instruction.opcode == spv::Op::OpSourceContinued);

  if (instruction.num_operands == 1) {
    const auto& literal = instruction.operands[0];
    const auto* firstWord = instruction.words + literal.offset;
    const auto* pastLastWord = firstWord + literal.num_words;
    return ExtractStringLiteral(loc, reinterpret_cast<const char*>(firstWord),
                                reinterpret_cast<const char*>(pastLastWord),
                                output);
  }

  spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
                  "Missing operands for OpSourceContinued.");
  return SPV_ERROR_INVALID_BINARY;
}
// Reads the optional file id and the optional inline source of an OpSource
// instruction.
// On success `filename` holds word 3 of the instruction (0 when absent) and
// `code` holds the literal starting at word 4 (empty when absent).
// An instruction shorter than three words is reported at `loc` and rejected
// with SPV_ERROR_INVALID_BINARY.
spv_result_t extractOpSource(const spv_position_t& loc,
                             const spv_parsed_instruction_t& instruction,
                             spv::Id* filename, std::string* code) {
  assert(filename != nullptr && code != nullptr);
  assert(instruction.opcode == spv::Op::OpSource);

  // Word layout:
  // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
  const auto wordCount = instruction.num_words;
  if (wordCount < 3) {
    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
                    "Missing operands for OpSource.");
    return SPV_ERROR_INVALID_BINARY;
  }

  const bool hasFile = wordCount >= 4;
  const bool hasCode = wordCount >= 5;

  *filename = hasFile ? instruction.words[3] : 0;
  *code = "";
  if (!hasCode) {
    return SPV_SUCCESS;
  }

  const auto* firstWord = instruction.words + 4;
  const auto* pastLastWord = instruction.words + wordCount;
  return ExtractStringLiteral(loc, reinterpret_cast<const char*>(firstWord),
                              reinterpret_cast<const char*>(pastLastWord),
                              code);
}
} // namespace
// Parses `binary` and fills `output` with one entry per OpSource instruction
// found: the key is the file name, the value is the embedded source text
// (OpSource literal plus any OpSourceContinued that follows it).
// Returns false when parsing fails or when two sources resolve to the same
// file name; true otherwise.
// NOTE(review): this span is a rendered diff. The snake_case signature, the
// first `headerParser` lambda opening, the no-op instruction lambda body and
// the `// FIXME` / `(void)output;` lines look like pre-change lines shown next
// to their replacements -- confirm against the actual file.
bool extract_source_from_module(
bool ExtractSourceFromModule(
const std::vector<uint32_t>& binary,
std::unordered_map<std::string, std::string>* output) {
auto context = spvtools::SpirvTools(kDefaultEnvironment);
context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
spvtools::HeaderParser headerParser =
[](const spv_endianness_t endianess,
const spv_parsed_header_t& instruction) {
(void)endianess;
(void)instruction;
// There is nothing valuable in the header.
spvtools::HeaderParser headerParser = [](const spv_endianness_t,
const spv_parsed_header_t&) {
return SPV_SUCCESS;
};
// OpString result id -> decoded string.
std::unordered_map<uint32_t, std::string> stringMap;
// One (file-name id, source text) pair per OpSource, in parse order.
std::vector<std::pair<spv::Id, std::string>> sources;
// Opcode of the previously parsed instruction (OpMax until one was parsed).
spv::Op lastOpcode = spv::Op::OpMax;
// Number of instructions fully handled so far; used for error positions.
size_t instructionIndex = 0;
spvtools::InstructionParser instructionParser =
[](const spv_parsed_instruction_t& instruction) {
(void)instruction;
return SPV_SUCCESS;
[&stringMap, &sources, &lastOpcode,
&instructionIndex](const spv_parsed_instruction_t& instruction) {
// Error position: first two fields zeroed, third is instructionIndex + 1.
const spv_position_t loc = {0, 0, instructionIndex + 1};
spv_result_t result = SPV_SUCCESS;
if (instruction.opcode == spv::Op::OpString) {
// Remember the string so OpSource file ids can be resolved afterwards.
std::string content;
result = extractOpString(loc, instruction, &content);
if (result == SPV_SUCCESS) {
stringMap.emplace(instruction.result_id, std::move(content));
}
} else if (instruction.opcode == spv::Op::OpSource) {
spv::Id filenameId;
std::string code;
result = extractOpSource(loc, instruction, &filenameId, &code);
if (result == SPV_SUCCESS) {
sources.emplace_back(std::make_pair(filenameId, std::move(code)));
}
} else if (instruction.opcode == spv::Op::OpSourceContinued) {
// Only accepted directly after an OpSource; its text is appended to the
// code of the most recently recorded source.
if (lastOpcode != spv::Op::OpSource) {
spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
"OpSourceContinued MUST follow an OpSource.");
return SPV_ERROR_INVALID_BINARY;
}
assert(sources.size() > 0);
result = extractOpSourceContinued(loc, instruction,
&sources.back().second);
}
// Not reached on the early return above.
++instructionIndex;
lastOpcode = static_cast<spv::Op>(instruction.opcode);
return result;
};
if (!context.Parse(binary, headerParser, instructionParser)) {
return false;
}
// FIXME
(void)output;
// NOTE(review): `defaultName` is never read below; the "unnamed-" literal is
// repeated when building the fallback name.
std::string defaultName = "unnamed-";
size_t unnamedCount = 0;
for (auto & [ id, code ] : sources) {
std::string filename;
const auto it = stringMap.find(id);
// Unknown id or empty string: fall back to "unnamed-<n>.hlsl".
if (it == stringMap.cend() || it->second.empty()) {
filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
++unnamedCount;
} else {
filename = it->second;
}
// Two sources resolving to the same name are treated as an error.
if (output->count(filename) != 0) {
spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
"Source file name conflict.");
return false;
}
output->insert({filename, code});
}
return true;
}

View File

@ -32,7 +32,7 @@
//
// Returns `true` if the extraction succeeded, `false` otherwise.
// `output` value is undefined if `false` is returned.
bool extract_source_from_module(
bool ExtractSourceFromModule(
const std::vector<uint32_t>& binary,
std::unordered_map<std::string, std::string>* output);

View File

@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <filesystem>
#include <iostream>
#include <system_error>
#include "extract_source.h"
#include "source/opt/log.h"
#include "tools/io.h"
@ -42,6 +45,57 @@ Source dump options:
File written to stdout if '-' is given. Default is `-`.
)";
// Removes trailing '/' from `input`.
// A behavior difference has been observed between libc++ implementations.
// Fixing path to prevent this edge case to be reached.
// (https://github.com/llvm/llvm-project/issues/60634)
// At least one character is always kept, so a path made only of '/' (the
// root directory) is reduced to "/" rather than to an empty, invalid path.
// An empty `input` is returned unchanged.
std::string fixPathForLLVM(std::string input) {
  while (input.size() > 1 && input.back() == '/') {
    input.pop_back();
  }
  return input;
}
// Write each HLSL file described in `sources` in a file in `outdirPath`.
// `outdirPath` is created (parents included) when it is not already a
// directory. Entries with no code are skipped with a notice on stdout.
// Doesn't overwrite existing files, unless `overwrite` is set to true. The
// created HLSL file's filename is the path's filename obtained from `sources`.
// Returns true if all files could be written. False otherwise.
bool OutputSourceFiles(
    const std::unordered_map<std::string, std::string>& sources,
    const std::string& outdirPath, bool overwrite) {
  // The std::error_code overloads are used on purpose: the throwing overloads
  // would raise std::filesystem::filesystem_error on OS-level failures and
  // bypass the error reporting below.
  std::error_code ec;

  const std::filesystem::path outdir(fixPathForLLVM(outdirPath));
  if (!std::filesystem::is_directory(outdir, ec)) {
    if (!std::filesystem::create_directories(outdir, ec)) {
      std::cerr << "error: could not create output directory " << outdir
                << std::endl;
      return false;
    }
  }

  for (const auto& [filepath, code] : sources) {
    if (code.empty()) {
      std::cout << "Ignoring source for " << filepath
                << ": no code source in debug infos." << std::endl;
      continue;
    }

    // Only the last path component is kept: every file lands directly in
    // `outdir`, whatever directory the module recorded.
    const std::filesystem::path old_path(filepath);
    const std::filesystem::path new_path = outdir / old_path.filename();

    if (!overwrite && std::filesystem::exists(new_path, ec)) {
      std::cerr << "file " << filepath
                << " already exists, aborting (use --overwrite to allow it)."
                << std::endl;
      return false;
    }

    std::cout << "Exporting " << new_path << std::endl;
    if (!WriteFile<char>(new_path.string().c_str(), "w", code.c_str(),
                         code.size())) {
      return false;
    }
  }
  return true;
}
} // namespace
// clang-format off
@ -71,14 +125,11 @@ int main(int, const char** argv) {
}
if (flags::positional_arguments.size() != 1) {
spvtools::Error(spvtools::utils::CLIMessageConsumer, nullptr, {},
"expected exactly one input file.");
std::cerr << "Expected exactly one input file." << std::endl;
return 1;
}
if (flags::source.value() || flags::entrypoint.value() ||
flags::compiler_cmd.value()) {
spvtools::Error(spvtools::utils::CLIMessageConsumer, nullptr, {},
"not implemented yet.");
if (flags::entrypoint.value() || flags::compiler_cmd.value()) {
std::cerr << "Unimplemented flags." << std::endl;
return 1;
}
@ -88,8 +139,34 @@ int main(int, const char** argv) {
}
if (flags::source.value()) {
std::unordered_map<std::string, std::string> output;
return extract_source_from_module(binary, &output) ? 0 : 1;
std::unordered_map<std::string, std::string> sourceCode;
if (!ExtractSourceFromModule(binary, &sourceCode)) {
return 1;
}
if (flags::list.value()) {
for (const auto & [ filename, source ] : sourceCode) {
printf("%s\n", filename.c_str());
}
return 0;
}
const bool outputToConsole = flags::outdir.value() == "-";
if (outputToConsole) {
for (const auto & [ filename, source ] : sourceCode) {
std::cout << filename << ":" << std::endl
<< source << std::endl
<< std::endl;
}
return 0;
}
const std::filesystem::path outdirPath(flags::outdir.value());
if (!OutputSourceFiles(sourceCode, outdirPath.string(),
flags::force.value())) {
return 1;
}
}
// FIXME: implement logic.