diff --git a/spirv_cross.cpp b/spirv_cross.cpp index c2f03c6d..e43b2dcb 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -2526,15 +2526,7 @@ vector Compiler::get_entry_points() const { vector entries; for (auto &entry : entry_points) - entries.push_back(entry.second.name); - return entries; -} - -unordered_map Compiler::get_entry_point_name_map() const -{ - unordered_map entries; - for (auto &entry : entry_points) - entries[entry.second.orig_name] = entry.second.name; + entries.push_back(entry.second.orig_name); return entries; } @@ -2547,8 +2539,9 @@ void Compiler::set_entry_point(const std::string &name) SPIREntryPoint &Compiler::get_entry_point(const std::string &name) { auto itr = - find_if(begin(entry_points), end(entry_points), - [&](const std::pair &entry) -> bool { return entry.second.name == name; }); + find_if(begin(entry_points), end(entry_points), [&](const std::pair &entry) -> bool { + return entry.second.orig_name == name; + }); if (itr == end(entry_points)) SPIRV_CROSS_THROW("Entry point does not exist."); @@ -2559,8 +2552,9 @@ SPIREntryPoint &Compiler::get_entry_point(const std::string &name) const SPIREntryPoint &Compiler::get_entry_point(const std::string &name) const { auto itr = - find_if(begin(entry_points), end(entry_points), - [&](const std::pair &entry) -> bool { return entry.second.name == name; }); + find_if(begin(entry_points), end(entry_points), [&](const std::pair &entry) -> bool { + return entry.second.orig_name == name; + }); if (itr == end(entry_points)) SPIRV_CROSS_THROW("Entry point does not exist."); @@ -2568,6 +2562,11 @@ const SPIREntryPoint &Compiler::get_entry_point(const std::string &name) const return itr->second; } +const string &Compiler::get_cleansed_entry_point_name(const std::string &name) const +{ + return get_entry_point(name).name; +} + const SPIREntryPoint &Compiler::get_entry_point() const { return entry_points.find(entry_point)->second; diff --git a/spirv_cross.hpp b/spirv_cross.hpp index 7d2fb9cb..e9f75987 100644 --- a/spirv_cross.hpp +++ b/spirv_cross.hpp @@ -264,17 +264,20 @@ public: std::vector get_entry_points() const; void set_entry_point(const std::string &name); - // Returns a mapping between the original entry point name in the SPIR-V and a modified - // name defined by the backend. Some backends (eg. MSL) restrict the legal names allowed - // for entry point names (eg. "main" is illegal in MSL). Renaming occurs during compile(). - // Calling this function after before compiling will return a map of the original names - // to those same original names. - std::unordered_map get_entry_point_name_map() const; - // Returns the internal data structure for entry points to allow poking around. const SPIREntryPoint &get_entry_point(const std::string &name) const; SPIREntryPoint &get_entry_point(const std::string &name); + // Some shader languages restrict the names that can be given to entry points, and the + // corresponding backend will automatically rename an entry point name, during the call + // to compile() if it is illegal. For example, the common entry point name main() is + // illegal in MSL, and is renamed to an alternate name by the MSL backend. + // Given the original entry point name contained in the SPIR-V, this function returns + // the name, as updated by the backend during the call to compile(). If the name is not + // illegal, and has not been renamed, or if this function is called before compile(), + // this function will simply return the same name. + const std::string &get_cleansed_entry_point_name(const std::string &name) const; + // Query and modify OpExecutionMode. uint64_t get_execution_mode_mask() const; void unset_execution_mode(spv::ExecutionMode mode); diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 78af8d25..99bbc821 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -74,23 +74,26 @@ static const uint32_t kPushConstBinding = 0; class CompilerMSL : public CompilerGLSL { public: -#define MAKE_MSL_VERSION(major, minor, patch) (((major)*10000) + ((minor)*100) + (patch)) - // Options for compiling to Metal Shading Language struct Options { - uint32_t msl_version = MAKE_MSL_VERSION(1, 2, 0); + uint32_t msl_version = make_msl_version(1, 2); bool enable_point_size_builtin = true; bool resolve_specialized_array_lengths = true; void set_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) { - msl_version = MAKE_MSL_VERSION(major, minor, patch); + msl_version = make_msl_version(major, minor, patch); } bool supports_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) { - return msl_version >= MAKE_MSL_VERSION(major, minor, patch); + return msl_version >= make_msl_version(major, minor, patch); + } + + static uint32_t make_msl_version(uint32_t major, uint32_t minor = 0, uint32_t patch = 0) + { + return (major * 10000) + (minor * 100) + patch; } };