diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp index 9a080ee21..faf545fae 100644 --- a/source/comp/markv_codec.cpp +++ b/source/comp/markv_codec.cpp @@ -38,6 +38,7 @@ #include "diagnostic.h" #include "enum_string_mapping.h" #include "extensions.h" +#include "ext_inst.h" #include "instruction.h" #include "opcode.h" #include "operand.h" @@ -161,6 +162,8 @@ size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) { case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: case SPV_OPERAND_TYPE_SELECTION_CONTROL: return 4; + case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: + return 6; default: return 0; } @@ -674,6 +677,8 @@ class MarkvDecoder : public MarkvCodecBase { std::unordered_map id_to_type_id_; // Maps a type ID to its number type description. std::unordered_map type_id_to_number_type_info_; + // Maps an ExtInstImport id to the extended instruction type. + std::unordered_map import_id_to_ext_inst_type_; }; void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction, @@ -1028,16 +1033,35 @@ spv_result_t MarkvDecoder::DecodeOperand( parsed_operand.type = SPV_OPERAND_TYPE_ID; if (opcode == SpvOpExtInst && parsed_operand.offset == 3) { - // TODO(atgoo@github.com) Work in progress. - assert(0 && "Not implemented"); + // The current word is the extended instruction set id. + // Set the extended instruction set type for the current instruction. + auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); + if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { + return vstate_.diag(SPV_ERROR_INVALID_ID) + << "OpExtInst set id " << id + << " does not reference an OpExtInstImport result Id"; + } + inst->ext_inst_type = ext_inst_type_iter->second; } } break; } case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { - // TODO(atgoo@github.com) Work in progress. - assert(0 && "Not implemented"); + uint32_t word = 0; + if (!DecodeOperandWord(type, &word)) + return vstate_.diag(SPV_ERROR_INVALID_BINARY) + << "Failed to read enum"; + + spirv_.push_back(word); + + assert(SpvOpExtInst == opcode); + assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE); + spv_ext_inst_desc ext_inst; + if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst)) + return vstate_.diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction number: " << word; + spvPushOperandTypes(ext_inst->operandTypes, expected_operands); break; } @@ -1130,8 +1154,22 @@ spv_result_t MarkvDecoder::DecodeOperand( std::memcpy(&spirv_[first_word_index], str.data(), str.size()); if (SpvOpExtInstImport == opcode) { - // TODO(atgoo@github.com) Work in progress. - assert(0 && "Not implemented"); + // Record the extended instruction type for the ID for this import. + // There is only one string literal argument to OpExtInstImport, + // so it's sufficient to guard this just on the opcode. + const spv_ext_inst_type_t ext_inst_type = + spvExtInstImportTypeGet(str.data()); + if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { + return vstate_.diag(SPV_ERROR_INVALID_BINARY) + << "Invalid extended instruction import '" << str.data() << "'"; + } + // We must have parsed a valid result ID. It's a condition + // of the grammar, and we only accept non-zero result Ids. + assert(inst->result_id); + const bool inserted = import_id_to_ext_inst_type_.emplace( + inst->result_id, ext_inst_type).second; + (void)inserted; + assert(inserted); } break; } @@ -1327,7 +1365,6 @@ spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) { spirv_[instruction_offset] = spvOpcodeMake(inst->num_words, SpvOp(inst->opcode)); - assert(inst->num_words == std::accumulate( parsed_operands_.begin(), parsed_operands_.end(), 1, [](size_t num_words, const spv_parsed_operand_t& operand) { diff --git a/test/comp/markv_codec_test.cpp b/test/comp/markv_codec_test.cpp index b4e029615..c43cc77ad 100644 --- a/test/comp/markv_codec_test.cpp +++ b/test/comp/markv_codec_test.cpp @@ -410,4 +410,24 @@ OpDecorate %1 Uniform )"); } +TEST(Markv, WithExtInst) { + TestEncodeDecode(R"( +OpCapability Addresses +OpCapability Kernel +OpCapability GenericPointer +OpCapability Linkage +%opencl = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical32 OpenCL +%f32 = OpTypeFloat 32 +%void = OpTypeVoid +%void_func = OpTypeFunction %void +%100 = OpConstant %f32 1.1 +%main = OpFunction %void None %void_func +%entry_main = OpLabel +%200 = OpExtInst %f32 %opencl cos %100 +OpReturn +OpFunctionEnd +)"); +} + } // namespace