diff --git a/BUILD.bazel b/BUILD.bazel index 0b44b559f..759f043a8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -285,6 +285,7 @@ cc_binary( deps = [ ":spirv_tools_internal", ":tools_io", + ":tools_util", ], ) @@ -298,6 +299,7 @@ cc_binary( deps = [ ":spirv_tools", ":tools_io", + ":tools_util", ], ) @@ -357,6 +359,7 @@ cc_binary( ":spirv_tools_internal", ":spirv_tools_link", ":tools_io", + ":tools_util", ], ) @@ -387,6 +390,7 @@ cc_binary( deps = [ ":spirv_tools_internal", ":tools_io", + ":tools_util", ], ) @@ -416,7 +420,7 @@ cc_library( name = "base_{testcase}_test".format(testcase = f[len("test/"):-len("_test.cpp")]), size = "small", srcs = [f], - copts = TEST_COPTS, + copts = TEST_COPTS + ['-DTESTING'], linkstatic = 1, target_compatible_with = { "test/timer_test.cpp": incompatible_with(["@bazel_tools//src/conditions:windows"]), @@ -424,11 +428,12 @@ cc_library( deps = [ ":spirv_tools_internal", ":test_lib", + "tools_util", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], ) for f in glob( - ["test/*_test.cpp"], + ["test/*_test.cpp", "test/tools/*_test.cpp"], exclude = [ "test/cpp_interface_test.cpp", "test/log_test.cpp", diff --git a/BUILD.gn b/BUILD.gn index ee3743b85..b576be15c 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -1430,15 +1430,6 @@ if (spirv_tools_standalone) { } } -source_set("spvtools_util_cli_consumer") { - sources = [ - "tools/util/cli_consumer.cpp", - "tools/util/cli_consumer.h", - ] - deps = [ ":spvtools_headers" ] - configs += [ ":spvtools_internal_config" ] -} - source_set("spvtools_software_version") { sources = [ "source/software_version.cpp" ] deps = [ @@ -1448,12 +1439,23 @@ source_set("spvtools_software_version") { configs += [ ":spvtools_internal_config" ] } +source_set("spvtools_tools_util") { + sources = [ + "tools/util/flags.cpp" + "tools/util/cli_consumer.cpp", + "tools/util/cli_consumer.h", + ] + deps = [ ":spvtools_headers" ] + configs += [ ":spvtools_internal_config" ] +} + if (spvtools_build_executables) { executable("spirv-as") { sources = [ "tools/as/as.cpp" ] deps = [ ":spvtools", ":spvtools_software_version", + ":spvtools_tools_util", ] configs += [ ":spvtools_internal_config" ] } @@ -1463,6 +1465,7 @@ if (spvtools_build_executables) { deps = [ ":spvtools", ":spvtools_software_version", + ":spvtools_tools_util", ] configs += [ ":spvtools_internal_config" ] } @@ -1472,7 +1475,7 @@ if (spvtools_build_executables) { deps = [ ":spvtools", ":spvtools_software_version", - ":spvtools_util_cli_consumer", + ":spvtools_tools_util", ":spvtools_val", ] configs += [ ":spvtools_internal_config" ] @@ -1487,6 +1490,7 @@ if (spvtools_build_executables) { deps = [ ":spvtools", ":spvtools_software_version", + ":spvtools_tools_util", ] configs += [ ":spvtools_internal_config" ] } @@ -1497,7 +1501,7 @@ if (spvtools_build_executables) { ":spvtools", ":spvtools_opt", ":spvtools_software_version", - ":spvtools_util_cli_consumer", + ":spvtools_tools_util", ":spvtools_val", ] configs += [ ":spvtools_internal_config" ] @@ -1510,6 +1514,7 @@ if (spvtools_build_executables) { ":spvtools_link", ":spvtools_opt", ":spvtools_software_version", + ":spvtools_tools_util", ":spvtools_val", ] configs += [ ":spvtools_internal_config" ] @@ -1529,7 +1534,7 @@ if (!is_ios && !spirv_is_winuwp && build_with_chromium && spvtools_build_executa ":spvtools_opt", ":spvtools_reduce", ":spvtools_software_version", - ":spvtools_util_cli_consumer", + ":spvtools_tools_util", ":spvtools_val", "//third_party/protobuf:protobuf_full", ] @@ -1548,7 +1553,7 @@ if (!is_ios && !spirv_is_winuwp && spvtools_build_executables) { ":spvtools_opt", ":spvtools_reduce", ":spvtools_software_version", - ":spvtools_util_cli_consumer", + ":spvtools_tools_util", ":spvtools_val", ] configs += [ ":spvtools_internal_config" ] diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4ca8ef8fb..37c5e1d51 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -31,7 +31,7 @@ endif() function(add_spvtools_unittest) if (NOT "${SPIRV_SKIP_TESTS}" AND TARGET gmock_main) set(one_value_args TARGET PCH_FILE) - set(multi_value_args SRCS LIBS ENVIRONMENT) + set(multi_value_args SRCS LIBS ENVIRONMENT DEFINES) cmake_parse_arguments( ARG "" "${one_value_args}" "${multi_value_args}" ${ARGN}) set(target test_${ARG_TARGET}) @@ -40,6 +40,7 @@ function(add_spvtools_unittest) spvtools_pch(SRC_COPY ${ARG_PCH_FILE}) endif() add_executable(${target} ${SRC_COPY}) + target_compile_definitions(${target} PUBLIC ${ARG_DEFINES}) spvtools_default_compile_options(${target}) if(${COMPILER_IS_LIKE_GNU}) target_compile_options(${target} PRIVATE -Wno-undef) diff --git a/test/tools/CMakeLists.txt b/test/tools/CMakeLists.txt index 99f9780c5..0520bd751 100644 --- a/test/tools/CMakeLists.txt +++ b/test/tools/CMakeLists.txt @@ -18,4 +18,11 @@ add_test(NAME spirv-tools_expect_unittests add_test(NAME spirv-tools_spirv_test_framework_unittests COMMAND ${PYTHON_EXECUTABLE} -m unittest spirv_test_framework_unittest.py WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +add_spvtools_unittest( + TARGET spirv_unit_test_tools_util + SRCS flags_test.cpp ${spirv-tools_SOURCE_DIR}/tools/util/flags.cpp + LIBS ${SPIRV_TOOLS_FULL_VISIBILITY} + DEFINES TESTING=1) + add_subdirectory(opt) diff --git a/test/tools/flags_test.cpp b/test/tools/flags_test.cpp new file mode 100644 index 000000000..d92f00102 --- /dev/null +++ b/test/tools/flags_test.cpp @@ -0,0 +1,286 @@ +// Copyright (c) 2023 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/util/flags.h" + +#include "gmock/gmock.h" + +#ifdef UTIL_FLAGS_FLAG +#undef UTIL_FLAGS_FLAG +#define UTIL_FLAGS_FLAG(Type, Prefix, Name, Default, Required, IsShort) \ + flags::Flag Name(Default); \ + flags::FlagRegistration Name##_registration(Name, Prefix #Name, Required, \ + IsShort) +#else +#error \ + "UTIL_FLAGS_FLAG is not defined. Either flags.h is not included of the flag name changed." +#endif + +class FlagTest : public ::testing::Test { + protected: + void SetUp() override { flags::FlagList::reset(); } +}; + +TEST_F(FlagTest, NoFlags) { + const char* argv[] = {"binary", nullptr}; + EXPECT_TRUE(flags::Parse(argv)); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, DashIsPositional) { + const char* argv[] = {"binary", "-", nullptr}; + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(flags::positional_arguments.size(), 1); + EXPECT_EQ(flags::positional_arguments[0], "-"); +} + +TEST_F(FlagTest, Positional) { + const char* argv[] = {"binary", "A", "BCD", nullptr}; + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(flags::positional_arguments.size(), 2); + EXPECT_EQ(flags::positional_arguments[0], "A"); + EXPECT_EQ(flags::positional_arguments[1], "BCD"); +} + +TEST_F(FlagTest, MissingRequired) { + FLAG_SHORT_bool(g, false, true); + + const char* argv[] = {"binary", nullptr}; + EXPECT_FALSE(flags::Parse(argv)); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, BooleanShortValue) { + FLAG_SHORT_bool(g, false, false); + const char* argv[] = {"binary", "-g", nullptr}; + EXPECT_FALSE(g.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(g.value()); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, BooleanShortDefaultValue) { + FLAG_SHORT_bool(g, true, false); + const char* argv[] = {"binary", nullptr}; + EXPECT_TRUE(g.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(g.value()); +} + +TEST_F(FlagTest, BooleanLongValueNotParsed) { + FLAG_SHORT_bool(g, false, false); + const char* argv[] = {"binary", "-g", "false", nullptr}; + EXPECT_FALSE(g.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(g.value()); + EXPECT_EQ(flags::positional_arguments.size(), 1); + EXPECT_EQ(flags::positional_arguments[0], "false"); +} + +TEST_F(FlagTest, BooleanLongSplitNotParsed) { + FLAG_LONG_bool(foo, false, false); + const char* argv[] = {"binary", "--foo", "true", nullptr}; + EXPECT_FALSE(foo.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(foo.value()); + EXPECT_EQ(flags::positional_arguments.size(), 1); + EXPECT_EQ(flags::positional_arguments[0], "true"); +} + +TEST_F(FlagTest, BooleanLongExplicitTrue) { + FLAG_LONG_bool(foo, false, false); + const char* argv[] = {"binary", "--foo=true", nullptr}; + EXPECT_FALSE(foo.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(foo.value()); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, BooleanLongExplicitFalse) { + FLAG_LONG_bool(foo, false, false); + const char* argv[] = {"binary", "--foo=false", nullptr}; + EXPECT_FALSE(foo.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_FALSE(foo.value()); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, BooleanLongDefaultValue) { + FLAG_LONG_bool(foo, true, false); + const char* argv[] = {"binary", nullptr}; + EXPECT_TRUE(foo.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_TRUE(foo.value()); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, BooleanLongDefaultValueCancelled) { + FLAG_LONG_bool(foo, true, false); + const char* argv[] = {"binary", "--foo=false", nullptr}; + EXPECT_TRUE(foo.value()); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_FALSE(foo.value()); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringFlagDefaultValue) { + FLAG_SHORT_string(f, "default", false); + const char* argv[] = {"binary", nullptr}; + EXPECT_EQ(f.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + EXPECT_EQ(f.value(), "default"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringFlagShortMissingString) { + FLAG_SHORT_string(f, "default", false); + const char* argv[] = {"binary", "-f", nullptr}; + EXPECT_EQ(f.value(), "default"); + + EXPECT_FALSE(flags::Parse(argv)); +} + +TEST_F(FlagTest, StringFlagDefault) { + FLAG_SHORT_string(f, "default", false); + const char* argv[] = {"binary", nullptr}; + EXPECT_EQ(f.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(f.value(), "default"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringFlagSet) { + FLAG_SHORT_string(f, "default", false); + const char* argv[] = {"binary", "-f", "toto", nullptr}; + EXPECT_EQ(f.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(f.value(), "toto"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringLongFlagSetSplit) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--foo", "toto", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(foo.value(), "toto"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringLongFlagSetUnified) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--foo=toto", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(foo.value(), "toto"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, StringLongFlagSetEmpty) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--foo=", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(foo.value(), ""); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, AllPositionalAfterDoubleDash) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--", "--foo=toto", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(foo.value(), "default"); + EXPECT_EQ(flags::positional_arguments.size(), 1); + EXPECT_EQ(flags::positional_arguments[0], "--foo=toto"); +} + +TEST_F(FlagTest, NothingAfterDoubleDash) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_TRUE(flags::Parse(argv)); + + EXPECT_EQ(foo.value(), "default"); + EXPECT_EQ(flags::positional_arguments.size(), 0); +} + +TEST_F(FlagTest, FlagDoubleSetNotAllowed) { + FLAG_LONG_string(foo, "default", false); + const char* argv[] = {"binary", "--foo=abc", "--foo=def", nullptr}; + EXPECT_EQ(foo.value(), "default"); + + EXPECT_FALSE(flags::Parse(argv)); +} + +TEST_F(FlagTest, MultipleFlags) { + FLAG_LONG_string(foo, "default foo", false); + FLAG_LONG_string(bar, "default_bar", false); + const char* argv[] = {"binary", "--foo", "abc", "--bar=def", nullptr}; + EXPECT_EQ(foo.value(), "default foo"); + EXPECT_EQ(bar.value(), "default_bar"); + + EXPECT_TRUE(flags::Parse(argv)); + EXPECT_EQ(foo.value(), "abc"); + EXPECT_EQ(bar.value(), "def"); +} + +TEST_F(FlagTest, MixedStringAndBool) { + FLAG_LONG_string(foo, "default foo", false); + FLAG_LONG_string(bar, "default_bar", false); + FLAG_SHORT_bool(g, false, false); + const char* argv[] = {"binary", "--foo", "abc", "-g", "--bar=def", nullptr}; + EXPECT_EQ(foo.value(), "default foo"); + EXPECT_EQ(bar.value(), "default_bar"); + EXPECT_FALSE(g.value()); + + EXPECT_TRUE(flags::Parse(argv)); + EXPECT_EQ(foo.value(), "abc"); + EXPECT_EQ(bar.value(), "def"); + EXPECT_TRUE(g.value()); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index d272b08e7..6bf7a1190 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -39,21 +39,29 @@ function(add_spvtools_tool) set_property(TARGET ${ARG_TARGET} PROPERTY FOLDER "SPIRV-Tools executables") endfunction() +set(COMMON_TOOLS_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/util/flags.cpp") + if (NOT ${SPIRV_SKIP_EXECUTABLES}) - add_spvtools_tool(TARGET spirv-as SRCS as/as.cpp LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) - add_spvtools_tool(TARGET spirv-diff SRCS diff/diff.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-diff SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) - add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) - add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp util/cli_consumer.cpp LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) - add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-diff SRCS ${COMMON_TOOLS_SRCS} diff/diff.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-diff SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-dis SRCS ${COMMON_TOOLS_SRCS} dis/dis.cpp LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-val SRCS ${COMMON_TOOLS_SRCS} val/val.cpp util/cli_consumer.cpp LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-opt SRCS ${COMMON_TOOLS_SRCS} opt/opt.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) if(NOT (${CMAKE_SYSTEM_NAME} STREQUAL "iOS")) # iOS does not allow std::system calls which spirv-reduce requires - add_spvtools_tool(TARGET spirv-reduce SRCS reduce/reduce.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-reduce ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-reduce SRCS ${COMMON_TOOLS_SRCS} reduce/reduce.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-reduce ${SPIRV_TOOLS_FULL_VISIBILITY}) endif() - add_spvtools_tool(TARGET spirv-link SRCS link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS_FULL_VISIBILITY}) - add_spvtools_tool(TARGET spirv-lint SRCS lint/lint.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-lint SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-link SRCS ${COMMON_TOOLS_SRCS} link/linker.cpp LIBS SPIRV-Tools-link ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-lint SRCS ${COMMON_TOOLS_SRCS} lint/lint.cpp util/cli_consumer.cpp LIBS SPIRV-Tools-lint SPIRV-Tools-opt ${SPIRV_TOOLS_FULL_VISIBILITY}) + add_spvtools_tool(TARGET spirv-as + SRCS as/as.cpp + ${COMMON_TOOLS_SRCS} + LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) + target_include_directories(spirv-as PRIVATE ${spirv-tools_SOURCE_DIR} + ${SPIRV_HEADER_INCLUDE_DIR}) add_spvtools_tool(TARGET spirv-cfg SRCS cfg/cfg.cpp cfg/bin_to_dot.h cfg/bin_to_dot.cpp + ${COMMON_TOOLS_SRCS} LIBS ${SPIRV_TOOLS_FULL_VISIBILITY}) target_include_directories(spirv-cfg PRIVATE ${spirv-tools_SOURCE_DIR} ${SPIRV_HEADER_INCLUDE_DIR}) diff --git a/tools/as/as.cpp b/tools/as/as.cpp index 506b05856..2a000cf09 100644 --- a/tools/as/as.cpp +++ b/tools/as/as.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -19,11 +20,11 @@ #include "source/spirv_target_env.h" #include "spirv-tools/libspirv.h" #include "tools/io.h" +#include "tools/util/flags.h" -void print_usage(char* argv0) { - std::string target_env_list = spvTargetEnvList(19, 80); - printf( - R"(%s - Create a SPIR-V binary module from SPIR-V assembly text +static const auto kDefaultEnvironment = "spv1.6"; +static const std::string kHelpText = + R"(%s - Create a SPIR-V binary module from SPIR-V assembly text Usage: %s [options] [] @@ -42,94 +43,70 @@ Options: Numeric IDs in the binary will have the same values as in the source. Non-numeric IDs are allocated by filling in the gaps, starting with 1 and going up. - --target-env {%s} + --target-env %s Use specified environment. -)", - argv0, argv0, target_env_list.c_str()); -} +)"; -static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; +// clang-format off +FLAG_SHORT_bool( h, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( help, /* default_value= */ false, /* required= */false); +FLAG_LONG_bool( version, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( preserve_numeric_ids, /* default_value= */ false, /* required= */ false); +FLAG_SHORT_string(o, /* default_value= */ "", /* required= */ false); +FLAG_LONG_string( target_env, /* default_value= */ kDefaultEnvironment, /* required= */ false); +// clang-format on -int main(int argc, char** argv) { - const char* inFile = nullptr; - const char* outFile = nullptr; - uint32_t options = 0; - spv_target_env target_env = kDefaultEnvironment; - for (int argi = 1; argi < argc; ++argi) { - if ('-' == argv[argi][0]) { - switch (argv[argi][1]) { - case 'h': { - print_usage(argv[0]); - return 0; - } - case 'o': { - if (!outFile && argi + 1 < argc) { - outFile = argv[++argi]; - } else { - print_usage(argv[0]); - return 1; - } - } break; - case 0: { - // Setting a filename of "-" to indicate stdin. - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } break; - case '-': { - // Long options - if (0 == strcmp(argv[argi], "--version")) { - printf("%s\n", spvSoftwareVersionDetailsString()); - printf("Target: %s\n", - spvTargetEnvDescription(kDefaultEnvironment)); - return 0; - } else if (0 == strcmp(argv[argi], "--help")) { - print_usage(argv[0]); - return 0; - } else if (0 == strcmp(argv[argi], "--preserve-numeric-ids")) { - options |= SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS; - } else if (0 == strcmp(argv[argi], "--target-env")) { - if (argi + 1 < argc) { - const auto env_str = argv[++argi]; - if (!spvParseTargetEnv(env_str, &target_env)) { - fprintf(stderr, "error: Unrecognized target env: %s\n", - env_str); - return 1; - } - } else { - fprintf(stderr, "error: Missing argument to --target-env\n"); - return 1; - } - } else { - fprintf(stderr, "error: Unrecognized option: %s\n\n", argv[argi]); - print_usage(argv[0]); - return 1; - } - } break; - default: - fprintf(stderr, "error: Unrecognized option: %s\n\n", argv[argi]); - print_usage(argv[0]); - return 1; - } - } else { - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } +int main(int, const char** argv) { + if (!flags::Parse(argv)) { + return 1; } - if (!outFile) { + if (flags::h.value() || flags::help.value()) { + const std::string target_env_list = spvTargetEnvList(19, 80); + printf(kHelpText.c_str(), argv[0], argv[0], target_env_list.c_str()); + return 0; + } + + if (flags::version.value()) { + spv_target_env target_env; + bool success = spvParseTargetEnv(kDefaultEnvironment, &target_env); + assert(success && "Default environment should always parse."); + if (!success) { + fprintf(stderr, + "error: invalid default target environment. Please report this " + "issue."); + return 1; + } + printf("%s\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", spvTargetEnvDescription(target_env)); + return 0; + } + + std::string outFile = flags::o.value(); + if (outFile.empty()) { outFile = "out.spv"; } + uint32_t options = 0; + if (flags::preserve_numeric_ids.value()) { + options |= SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS; + } + + spv_target_env target_env; + if (!spvParseTargetEnv(flags::target_env.value().c_str(), &target_env)) { + fprintf(stderr, "error: Unrecognized target env: %s\n", + flags::target_env.value().c_str()); + return 1; + } + + if (flags::positional_arguments.size() != 1) { + fprintf(stderr, "error: exactly one input file must be specified.\n"); + return 1; + } + std::string inFile = flags::positional_arguments[0]; + std::vector contents; - if (!ReadTextFile(inFile, &contents)) return 1; + if (!ReadTextFile(inFile.c_str(), &contents)) return 1; spv_binary binary; spv_diagnostic diagnostic = nullptr; @@ -143,7 +120,8 @@ int main(int argc, char** argv) { return error; } - if (!WriteFile(outFile, "wb", binary->code, binary->wordCount)) { + if (!WriteFile(outFile.c_str(), "wb", binary->code, + binary->wordCount)) { spvBinaryDestroy(binary); return 1; } diff --git a/tools/cfg/cfg.cpp b/tools/cfg/cfg.cpp index 5380c21ec..2d11e6fb0 100644 --- a/tools/cfg/cfg.cpp +++ b/tools/cfg/cfg.cpp @@ -21,11 +21,11 @@ #include "spirv-tools/libspirv.h" #include "tools/cfg/bin_to_dot.h" #include "tools/io.h" +#include "tools/util/flags.h" -// Prints a program usage message to stdout. -static void print_usage(const char* argv0) { - printf( - R"(%s - Show the control flow graph in GraphiViz "dot" form. EXPERIMENTAL +static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; +static const std::string kHelpText = + R"(%s - Show the control flow graph in GraphiViz "dot" form. EXPERIMENTAL Usage: %s [options] [] @@ -40,71 +40,42 @@ Options: -o Set the output filename. Output goes to standard output if this option is not specified, or if the filename is "-". -)", - argv0, argv0); -} +)"; -static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; +// clang-format off +FLAG_SHORT_bool( h, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool( help, /* default_value= */ false, /* required= */false); +FLAG_LONG_bool( version, /* default_value= */ false, /* required= */ false); +FLAG_SHORT_string(o, /* default_value= */ "", /* required= */ false); +// clang-format on -int main(int argc, char** argv) { - const char* inFile = nullptr; - const char* outFile = nullptr; // Stays nullptr if printing to stdout. - - for (int argi = 1; argi < argc; ++argi) { - if ('-' == argv[argi][0]) { - switch (argv[argi][1]) { - case 'h': - print_usage(argv[0]); - return 0; - case 'o': { - if (!outFile && argi + 1 < argc) { - outFile = argv[++argi]; - } else { - print_usage(argv[0]); - return 1; - } - } break; - case '-': { - // Long options - if (0 == strcmp(argv[argi], "--help")) { - print_usage(argv[0]); - return 0; - } - if (0 == strcmp(argv[argi], "--version")) { - printf("%s EXPERIMENTAL\n", spvSoftwareVersionDetailsString()); - printf("Target: %s\n", - spvTargetEnvDescription(kDefaultEnvironment)); - return 0; - } - print_usage(argv[0]); - return 1; - } - case 0: { - // Setting a filename of "-" to indicate stdin. - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } break; - default: - print_usage(argv[0]); - return 1; - } - } else { - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } +int main(int, const char** argv) { + if (!flags::Parse(argv)) { + return 1; } + if (flags::h.value() || flags::help.value()) { + printf(kHelpText.c_str(), argv[0], argv[0]); + return 0; + } + + if (flags::version.value()) { + printf("%s EXPERIMENTAL\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", spvTargetEnvDescription(kDefaultEnvironment)); + return 0; + } + + if (flags::positional_arguments.size() != 1) { + fprintf(stderr, "error: exactly one input file must be specified.\n"); + return 1; + } + + std::string inFile = flags::positional_arguments[0]; + std::string outFile = flags::o.value(); + // Read the input binary. std::vector contents; - if (!ReadBinaryFile(inFile, &contents)) return 1; + if (!ReadBinaryFile(inFile.c_str(), &contents)) return 1; spv_context context = spvContextCreate(kDefaultEnvironment); spv_diagnostic diagnostic = nullptr; @@ -118,7 +89,8 @@ int main(int argc, char** argv) { return error; } std::string str = ss.str(); - WriteFile(outFile, "w", str.data(), str.size()); + WriteFile(outFile.empty() ? nullptr : outFile.c_str(), "w", str.data(), + str.size()); spvDiagnosticDestroy(diagnostic); spvContextDestroy(context); diff --git a/tools/dis/dis.cpp b/tools/dis/dis.cpp index 64380db06..aacd37f07 100644 --- a/tools/dis/dis.cpp +++ b/tools/dis/dis.cpp @@ -24,10 +24,9 @@ #include "spirv-tools/libspirv.h" #include "tools/io.h" +#include "tools/util/flags.h" -static void print_usage(char* argv0) { - printf( - R"(%s - Disassemble a SPIR-V binary module +static const std::string kHelpText = R"(%s - Disassemble a SPIR-V binary module Usage: %s [options] [] @@ -58,15 +57,49 @@ Options: --offsets Show byte offsets for each instruction. --comment Add comments to make reading easier -)", - argv0, argv0); -} +)"; + +// clang-format off +FLAG_SHORT_bool (h, /* default_value= */ false, /* required= */ false); +FLAG_SHORT_string(o, /* default_value= */ "-", /* required= */ false); +FLAG_LONG_bool (help, /* default_value= */ false, /* required= */false); +FLAG_LONG_bool (version, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (color, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (no_color, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (no_indent, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (no_header, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (raw_id, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (offsets, /* default_value= */ false, /* required= */ false); +FLAG_LONG_bool (comment, /* default_value= */ false, /* required= */ false); +// clang-format on static const auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_5; -int main(int argc, char** argv) { - const char* inFile = nullptr; - const char* outFile = nullptr; +int main(int, const char** argv) { + if (!flags::Parse(argv)) { + return 1; + } + + if (flags::h.value() || flags::help.value()) { + printf(kHelpText.c_str(), argv[0], argv[0]); + return 0; + } + + if (flags::version.value()) { + printf("%s\n", spvSoftwareVersionDetailsString()); + printf("Target: %s\n", spvTargetEnvDescription(kDefaultEnvironment)); + return 0; + } + + if (flags::positional_arguments.size() > 1) { + fprintf(stderr, "error: more than one input file specified.\n"); + return 1; + } + + const std::string inFile = flags::positional_arguments.size() == 0 + ? "-" + : flags::positional_arguments[0]; + const std::string outFile = flags::o.value(); bool color_is_possible = #if SPIRV_COLOR_TERMINAL @@ -74,105 +107,30 @@ int main(int argc, char** argv) { #else false; #endif - bool force_color = false; - bool force_no_color = false; - - bool allow_indent = true; - bool show_byte_offsets = false; - bool no_header = false; - bool friendly_names = true; - bool comments = false; - - for (int argi = 1; argi < argc; ++argi) { - if ('-' == argv[argi][0]) { - switch (argv[argi][1]) { - case 'h': - print_usage(argv[0]); - return 0; - case 'o': { - if (!outFile && argi + 1 < argc) { - outFile = argv[++argi]; - } else { - print_usage(argv[0]); - return 1; - } - } break; - case '-': { - // Long options - if (0 == strcmp(argv[argi], "--no-color")) { - force_no_color = true; - force_color = false; - } else if (0 == strcmp(argv[argi], "--color")) { - force_no_color = false; - force_color = true; - } else if (0 == strcmp(argv[argi], "--comment")) { - comments = true; - } else if (0 == strcmp(argv[argi], "--no-indent")) { - allow_indent = false; - } else if (0 == strcmp(argv[argi], "--offsets")) { - show_byte_offsets = true; - } else if (0 == strcmp(argv[argi], "--no-header")) { - no_header = true; - } else if (0 == strcmp(argv[argi], "--raw-id")) { - friendly_names = false; - } else if (0 == strcmp(argv[argi], "--help")) { - print_usage(argv[0]); - return 0; - } else if (0 == strcmp(argv[argi], "--version")) { - printf("%s\n", spvSoftwareVersionDetailsString()); - printf("Target: %s\n", - spvTargetEnvDescription(kDefaultEnvironment)); - return 0; - } else { - print_usage(argv[0]); - return 1; - } - } break; - case 0: { - // Setting a filename of "-" to indicate stdin. - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } break; - default: - print_usage(argv[0]); - return 1; - } - } else { - if (!inFile) { - inFile = argv[argi]; - } else { - fprintf(stderr, "error: More than one input file specified\n"); - return 1; - } - } - } uint32_t options = SPV_BINARY_TO_TEXT_OPTION_NONE; - if (allow_indent) options |= SPV_BINARY_TO_TEXT_OPTION_INDENT; + if (!flags::no_indent.value()) options |= SPV_BINARY_TO_TEXT_OPTION_INDENT; - if (show_byte_offsets) options |= SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET; + if (flags::offsets.value()) + options |= SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET; - if (no_header) options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER; + if (flags::no_header.value()) options |= SPV_BINARY_TO_TEXT_OPTION_NO_HEADER; - if (friendly_names) options |= SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES; + if (!flags::raw_id.value()) + options |= SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES; - if (comments) options |= SPV_BINARY_TO_TEXT_OPTION_COMMENT; + if (flags::comment.value()) options |= SPV_BINARY_TO_TEXT_OPTION_COMMENT; - if (!outFile || (0 == strcmp("-", outFile))) { + if (flags::o.value() == "-") { // Print to standard output. options |= SPV_BINARY_TO_TEXT_OPTION_PRINT; - - if (color_is_possible && !force_no_color) { + if (color_is_possible && !flags::no_color.value()) { bool output_is_tty = true; #if defined(_POSIX_VERSION) output_is_tty = isatty(fileno(stdout)); #endif - if (output_is_tty || force_color) { + if (output_is_tty || flags::color.value()) { options |= SPV_BINARY_TO_TEXT_OPTION_COLOR; } } @@ -180,7 +138,7 @@ int main(int argc, char** argv) { // Read the input binary. std::vector contents; - if (!ReadBinaryFile(inFile, &contents)) return 1; + if (!ReadBinaryFile(inFile.c_str(), &contents)) return 1; // If printing to standard output, then spvBinaryToText should // do the printing. In particular, colour printing on Windows is @@ -205,7 +163,7 @@ int main(int argc, char** argv) { } if (!print_to_stdout) { - if (!WriteFile(outFile, "w", text->str, text->length)) { + if (!WriteFile(outFile.c_str(), "w", text->str, text->length)) { spvTextDestroy(text); return 1; } diff --git a/tools/util/flags.cpp b/tools/util/flags.cpp new file mode 100644 index 000000000..c773347b9 --- /dev/null +++ b/tools/util/flags.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2023 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "flags.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace flags { + +std::vector positional_arguments; + +namespace { + +using token_t = const char*; +using token_iterator_t = token_t*; + +// Extracts the flag name from a potential token. +// This function only looks for a '=', to split the flag name from the value for +// long-form flags. Returns the name of the flag, prefixed with the hyphen(s). +inline std::string get_flag_name(const std::string& flag, bool is_short_flag) { + if (is_short_flag) { + return flag; + } + + size_t equal_index = flag.find('='); + if (equal_index == std::string::npos) { + return flag; + } + return flag.substr(0, equal_index); +} + +// Parse a boolean flag. Returns `true` if the parsing succeeded, `false` +// otherwise. +bool parse_flag(Flag& flag, bool is_short_flag, + const std::string& token) { + if (is_short_flag) { + flag.value() = true; + return true; + } + + const std::string raw_flag(token); + size_t equal_index = raw_flag.find('='); + if (equal_index == std::string::npos) { + flag.value() = true; + return true; + } + + const std::string value = raw_flag.substr(equal_index + 1); + if (value == "true") { + flag.value() = true; + return true; + } + + if (value == "false") { + flag.value() = false; + return true; + } + + return false; +} + +// Parse a string flag. Moved the iterator to the last flag's token if it's a +// multi-token flag. Returns `true` if the parsing succeeded. +// The iterator is moved to the last parsed token. +bool parse_flag(Flag& flag, bool is_short_flag, + token_iterator_t* iterator) { + const std::string raw_flag(**iterator); + const size_t equal_index = raw_flag.find('='); + if (is_short_flag || equal_index == std::string::npos) { + if ((*iterator)[1] == nullptr) { + return false; + } + + // This is a bi-token flag. Moving iterator to the last parsed token. + flag.value() = (*iterator)[1]; + *iterator += 1; + return true; + } + + // This is a mono-token flag, no need to move the iterator. + const std::string value = raw_flag.substr(equal_index + 1); + flag.value() = value; + return true; +} +} // namespace + +// This is the function to expand if you want to support a new type. +bool FlagList::parse_flag_info(FlagInfo& info, token_iterator_t* iterator) { + bool success = false; + + std::visit( + [&](auto&& item) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + success = parse_flag(item.get(), info.is_short, **iterator); + } else if constexpr (std::is_same_v>) { + success = parse_flag(item.get(), info.is_short, iterator); + } else { + static_assert(always_false_v, "Unsupported flag type."); + } + }, + info.flag); + + return success; +} + +bool FlagList::parse(token_t* argv) { + flags::positional_arguments.clear(); + std::unordered_set parsed_flags; + + bool ignore_flags = false; + for (const char** it = argv + 1; *it != nullptr; it++) { + if (ignore_flags) { + flags::positional_arguments.emplace_back(*it); + continue; + } + + // '--' alone is used to mark the end of the flags. + if (std::strcmp(*it, "--") == 0) { + ignore_flags = true; + continue; + } + + // '-' alone is not a flag, but often used to say 'stdin'. + if (std::strcmp(*it, "-") == 0) { + flags::positional_arguments.emplace_back(*it); + continue; + } + + const std::string raw_flag(*it); + if (raw_flag.size() == 0) { + continue; + } + + if (raw_flag[0] != '-') { + flags::positional_arguments.emplace_back(*it); + continue; + } + + // Only case left: flags (long and shorts). + if (raw_flag.size() < 2) { + std::cerr << "Unknown flag " << raw_flag << std::endl; + return false; + } + const bool is_short_flag = std::strncmp(*it, "--", 2) != 0; + const std::string flag_name = get_flag_name(raw_flag, is_short_flag); + + auto needle = std::find_if( + get_flags().begin(), get_flags().end(), + [&flag_name](const auto& item) { return item.name == flag_name; }); + if (needle == get_flags().end()) { + std::cerr << "Unknown flag " << flag_name << std::endl; + return false; + } + + if (parsed_flags.count(&*needle) != 0) { + std::cerr << "The flag " << flag_name << " was specified multiple times." + << std::endl; + return false; + } + parsed_flags.insert(&*needle); + + if (!parse_flag_info(*needle, &it)) { + std::cerr << "Invalid usage for flag " << flag_name << std::endl; + return false; + } + } + + // Check that we parsed all required flags. + for (const auto& flag : get_flags()) { + if (!flag.required) { + continue; + } + + if (parsed_flags.count(&flag) == 0) { + std::cerr << "Missing required flag " << flag.name << std::endl; + return false; + } + } + + return true; +} + +// Just the public wrapper around the parse function. +bool Parse(const char** argv) { return FlagList::parse(argv); } + +} // namespace flags diff --git a/tools/util/flags.h b/tools/util/flags.h new file mode 100644 index 000000000..e48982cd5 --- /dev/null +++ b/tools/util/flags.h @@ -0,0 +1,251 @@ +// Copyright (c) 2023 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef INCLUDE_SPIRV_TOOLS_UTIL_FLAGS_HPP_ +#define INCLUDE_SPIRV_TOOLS_UTIL_FLAGS_HPP_ + +#include +#include +#include +#include + +// This file provides some utils to define a command-line interface with +// required and optional flags. +// - Flag order is not checked. +// - Currently supported flag types: BOOLEAN, STRING +// - As with most nix tools, using '--' in the command-line means all following +// tokens will be considered positional +// arguments. +// Example: binary -g -- -g --some-other-flag +// - the first `-g` is a flag. +// - the second `-g` is not a flag. +// - `--some-other-flag` is not a flag. +// - Both long-form and short-form flags are supported, but boolean flags don't +// support split boolean literals (short and long form). +// Example: +// -g : allowed, sets g to true. +// --my-flag : allowed, sets --my-flag to true. +// --my-flag=true : allowed, sets --my-flag to true. +// --my-flag true : NOT allowed. +// -g true : NOT allowed. +// --my-flag=TRUE : NOT allowed. +// +// - This implementation also supports string flags: +// -o myfile.spv : allowed, sets -o to `myfile.spv`. +// --output=myfile.spv : allowed, sets --output to `myfile.spv`. +// --output myfile.spv : allowd, sets --output to `myfile.spv`. +// +// Note: then second token is NOT checked for hyphens. +// --output -file.spv +// flag name: `output` +// flag value: `-file.spv` +// +// - This implementation generates flag at compile time. Meaning flag names +// must be valid C++ identifiers. +// However, flags are usually using hyphens for word separation. Hence +// renaming is done behind the scenes. Example: +// // Declaring a long-form flag. +// FLAG_LONG_bool(my_flag, [...]) +// +// -> in the code: flags::my_flag.value() +// -> command-line: --my-flag +// +// - The only additional lexing done is around '='. Otherwise token list is +// processed as received in the Parse() +// function. +// Lexing the '=' sign: +// - This is only done when parsing a long-form flag name. +// - the first '=' found is considered a marker for long-form, splitting +// the token into 2. +// Example: --option=value=abc -> [--option, value=abc] +// +// In most cases, you want to define some flags, parse them, and query them. +// Here is a small code sample: +// +// ```c +// // Defines a '-h' boolean flag for help printing, optional. +// FLAG_SHORT_bool(h, /*default=*/ false, "Print the help.", false); +// // Defines a '--my-flag' string flag, required. +// FLAG_LONG_string(my_flag, /*default=*/ "", "A magic flag!", true); +// +// int main(int argc, const char** argv) { +// if (!flags::Parse(argv)) { +// return -1; +// } +// +// if (flags::h.value()) { +// printf("usage: my-bin --my-flag=\n"); +// return 0; +// } +// +// printf("flag value: %s\n", flags::my_flag.value().c_str()); +// for (const std::string& arg : flags::positional_arguments) { +// printf("arg: %s\n", arg.c_str()); +// } +// return 0; +// } +// ```c + +// Those macros can be used to define flags. +// - They should be used in the global scope. +// - Underscores in the flag variable name are replaced with hyphens ('-'). +// +// Example: +// FLAG_SHORT_bool(my_flag, false, "some help", false); +// - in the code: flags::my_flag +// - command line: --my-flag=true +// +#define FLAG_LONG_string(Name, Default, Required) \ + UTIL_FLAGS_FLAG_LONG(std::string, Name, Default, Required) +#define FLAG_LONG_bool(Name, Default, Required) \ + UTIL_FLAGS_FLAG_LONG(bool, Name, Default, Required) + +#define FLAG_SHORT_string(Name, Default, Required) \ + UTIL_FLAGS_FLAG_SHORT(std::string, Name, Default, Required) +#define FLAG_SHORT_bool(Name, Default, Required) \ + UTIL_FLAGS_FLAG_SHORT(bool, Name, Default, Required) + +namespace flags { + +// Parse the command-line arguments, checking flags, and separating positional +// arguments from flags. +// +// * argv: the argv array received in the main function. This utility expects +// the last pointer to +// be NULL, as it should if coming from the main() function. +// +// Returns `true` if the parsing succeeds, `false` otherwise. +bool Parse(const char** argv); + +} // namespace flags + +// ===================== BEGIN NON-PUBLIC SECTION ============================= +// All the code below belongs to the implementation, and there is no guaranteed +// around the API stability. Please do not use it directly. + +// Defines the static variable holding the flag, allowing access like +// flags::my_flag. +// By creating the FlagRegistration object, the flag can be added to +// the global list. +// The final `extern` definition is ONLY useful for clang-format: +// - if the macro doesn't ends with a semicolon, clang-format goes wild. +// - cannot disable clang-format for those macros on clang < 16. +// (https://github.com/llvm/llvm-project/issues/54522) +// - cannot allow trailing semi (-Wextra-semi). +#define UTIL_FLAGS_FLAG(Type, Prefix, Name, Default, Required, IsShort) \ + namespace flags { \ + Flag Name(Default); \ + namespace { \ + static FlagRegistration Name##_registration(Name, Prefix #Name, Required, \ + IsShort); \ + } \ + } \ + extern flags::Flag flags::Name + +#define UTIL_FLAGS_FLAG_LONG(Type, Name, Default, Required) \ + UTIL_FLAGS_FLAG(Type, "--", Name, Default, Required, false) +#define UTIL_FLAGS_FLAG_SHORT(Type, Name, Default, Required) \ + UTIL_FLAGS_FLAG(Type, "-", Name, Default, Required, true) + +namespace flags { + +// Just a wrapper around the flag value. +template +struct Flag { + public: + Flag(T&& default_value) : value_(default_value) {} + Flag(Flag&& other) = delete; + Flag(const Flag& other) = delete; + + const T& value() const { return value_; } + T& value() { return value_; } + + private: + T value_; +}; + +// To add support for new flag-types, this needs to be extended, and the visitor +// below. +using FlagType = std::variant>, + std::reference_wrapper>>; + +template +inline constexpr bool always_false_v = false; + +extern std::vector positional_arguments; + +// Static class keeping track of the flags/arguments values. +class FlagList { + struct FlagInfo { + FlagInfo(FlagType&& flag_, std::string&& name_, bool required_, + bool is_short_) + : flag(std::move(flag_)), + name(std::move(name_)), + required(required_), + is_short(is_short_) {} + + FlagType flag; + std::string name; + bool required; + bool is_short; + }; + + public: + template + static void register_flag(Flag& flag, std::string&& name, bool required, + bool is_short) { + get_flags().emplace_back(flag, std::move(name), required, is_short); + } + + static bool parse(const char** argv); + +#ifdef TESTING + // Flags are supposed to be constant for the whole app execution, hence the + // static storage. Gtest doesn't fork before running a test, meaning we have + // to manually clear the context at teardown. + static void reset() { + get_flags().clear(); + positional_arguments.clear(); + } +#endif + + private: + static std::vector& get_flags() { + static std::vector flags; + return flags; + } + + static bool parse_flag_info(FlagInfo& info, const char*** iterator); + static void print_usage(const char* binary_name, + const std::string& usage_format); +}; + +template +struct FlagRegistration { + FlagRegistration(Flag& flag, std::string&& name, bool required, + bool is_short) { + std::string fixed_name = name; + for (auto& c : fixed_name) { + if (c == '_') { + c = '-'; + } + } + + FlagList::register_flag(flag, std::move(fixed_name), required, is_short); + } +}; + +} // namespace flags + +#endif // INCLUDE_SPIRV_TOOLS_UTIL_FLAGS_HPP_