From 9c006b724ee2a094b57ffcdf2cd0da060af1f821 Mon Sep 17 00:00:00 2001 From: titzer Date: Thu, 4 Feb 2016 02:16:20 -0800 Subject: [PATCH] [wasm] Refactor handling of operands to bytecodes. This cleans up and simplifies handling the bytes following an opcode with little helper structs that will be useful in the interpreter and already have been in keeping OpcodeArity and OpcodeLength up to date with the decoder. R=bradnelson@chromium.org, ahaas@chromium.org BUG= Review URL: https://codereview.chromium.org/1664883002 Cr-Commit-Position: refs/heads/master@{#33723} --- src/wasm/ast-decoder.cc | 681 ++++++++++---------- src/wasm/ast-decoder.h | 156 ++++- src/wasm/decoder.h | 68 ++ test/unittests/wasm/ast-decoder-unittest.cc | 30 +- 4 files changed, 561 insertions(+), 374 deletions(-) diff --git a/src/wasm/ast-decoder.cc b/src/wasm/ast-decoder.cc index 59ca446493..93129f6d2f 100644 --- a/src/wasm/ast-decoder.cc +++ b/src/wasm/ast-decoder.cc @@ -41,7 +41,6 @@ struct Tree { WasmOpcode opcode() const { return static_cast(*pc); } }; - // A production represents an incomplete decoded tree in the LR decoder. struct Production { Tree* tree; // the root of the syntax tree. 
@@ -103,8 +102,8 @@ struct IfEnv { class WasmDecoder : public Decoder { public: WasmDecoder() : Decoder(nullptr, nullptr), function_env_(nullptr) {} - - protected: + WasmDecoder(FunctionEnv* env, const byte* start, const byte* end) + : Decoder(start, end), function_env_(env) {} FunctionEnv* function_env_; void Reset(FunctionEnv* function_env, const byte* start, const byte* end) { @@ -136,62 +135,209 @@ class WasmDecoder : public Decoder { return read_u64(pc + 1); } - LocalType LocalOperand(const byte* pc, uint32_t* index, int* length) { - *index = UnsignedLEB128Operand(pc, length); - if (function_env_->IsValidLocal(*index)) { - return function_env_->GetLocalType(*index); + inline bool Validate(const byte* pc, LocalIndexOperand& operand) { + if (operand.index < function_env_->total_locals) { + operand.type = function_env_->GetLocalType(operand.index); + return true; } - error(pc, "invalid local variable index"); - return kAstStmt; + error(pc, pc + 1, "invalid local index"); + return false; } - LocalType GlobalOperand(const byte* pc, uint32_t* index, int* length) { - *index = UnsignedLEB128Operand(pc, length); - if (function_env_->module->IsValidGlobal(*index)) { - return WasmOpcodes::LocalTypeFor( - function_env_->module->GetGlobalType(*index)); + inline bool Validate(const byte* pc, GlobalIndexOperand& operand) { + ModuleEnv* m = function_env_->module; + if (m && m->module && operand.index < m->module->globals->size()) { + operand.machine_type = m->module->globals->at(operand.index).type; + operand.type = WasmOpcodes::LocalTypeFor(operand.machine_type); + return true; } - error(pc, "invalid global variable index"); - return kAstStmt; + error(pc, pc + 1, "invalid global index"); + return false; } - FunctionSig* FunctionSigOperand(const byte* pc, uint32_t* index, - int* length) { - *index = UnsignedLEB128Operand(pc, length); - if (function_env_->module->IsValidFunction(*index)) { - return function_env_->module->GetFunctionSignature(*index); + inline bool 
Validate(const byte* pc, FunctionIndexOperand& operand) { + ModuleEnv* m = function_env_->module; + if (m && m->module && operand.index < m->module->functions->size()) { + operand.sig = m->module->functions->at(operand.index).sig; + return true; } - error(pc, "invalid function index"); - return nullptr; + error(pc, pc + 1, "invalid function index"); + return false; } - FunctionSig* SigOperand(const byte* pc, uint32_t* index, int* length) { - *index = UnsignedLEB128Operand(pc, length); - if (function_env_->module->IsValidSignature(*index)) { - return function_env_->module->GetSignature(*index); + inline bool Validate(const byte* pc, SignatureIndexOperand& operand) { + ModuleEnv* m = function_env_->module; + if (m && m->module && operand.index < m->module->signatures->size()) { + operand.sig = m->module->signatures->at(operand.index); + return true; } - error(pc, "invalid signature index"); - return nullptr; + error(pc, pc + 1, "invalid signature index"); + return false; } - uint32_t UnsignedLEB128Operand(const byte* pc, int* length) { - uint32_t result = 0; - ReadUnsignedLEB128ErrorCode error_code = - ReadUnsignedLEB128Operand(pc + 1, limit_, length, &result); - if (error_code == kInvalidLEB128) error(pc, "invalid LEB128 varint"); - if (error_code == kMissingLEB128) error(pc, "expected LEB128 varint"); - (*length)++; - return result; + inline bool Validate(const byte* pc, BreakDepthOperand& operand, + ZoneVector& blocks) { + if (operand.depth < blocks.size()) { + operand.target = &blocks[blocks.size() - operand.depth - 1]; + return true; + } + error(pc, pc + 1, "invalid break depth"); + return false; } - void MemoryAccessOperand(const byte* pc, int* length, uint32_t* offset) { - byte bitfield = ByteOperand(pc, "missing memory access operand"); - if (MemoryAccess::OffsetField::decode(bitfield)) { - *offset = UnsignedLEB128Operand(pc + 1, length); - (*length)++; // to account for the memory access byte - } else { - *offset = 0; - *length = 2; + bool Validate(const 
byte* pc, TableSwitchOperand& operand, + size_t block_depth) { + if (operand.table_count == 0) { + error(pc, "tableswitch with 0 entries"); + return false; + } + // Verify table. + for (uint32_t i = 0; i < operand.table_count; i++) { + uint16_t target = operand.read_entry(this, i); + if (target >= 0x8000) { + size_t depth = target - 0x8000; + if (depth > block_depth) { + error(operand.table + i * 2, "improper branch in tableswitch"); + return false; + } + } else { + if (target >= operand.case_count) { + error(operand.table + i * 2, "invalid case target in tableswitch"); + return false; + } + } + } + return true; + } + + int OpcodeArity(const byte* pc) { +#define DECLARE_ARITY(name, ...) \ + static const LocalType kTypes_##name[] = {__VA_ARGS__}; \ + static const int kArity_##name = \ + static_cast(arraysize(kTypes_##name) - 1); + + FOREACH_SIGNATURE(DECLARE_ARITY); +#undef DECLARE_ARITY + + switch (static_cast(*pc)) { + case kExprI8Const: + case kExprI32Const: + case kExprI64Const: + case kExprF64Const: + case kExprF32Const: + case kExprGetLocal: + case kExprLoadGlobal: + case kExprNop: + case kExprUnreachable: + return 0; + + case kExprBr: + case kExprStoreGlobal: + case kExprSetLocal: + return 1; + + case kExprIf: + case kExprBrIf: + return 2; + case kExprIfElse: + case kExprSelect: + return 3; + + case kExprBlock: + case kExprLoop: { + BlockCountOperand operand(this, pc); + return operand.count; + } + + case kExprCallFunction: { + FunctionIndexOperand operand(this, pc); + return static_cast( + function_env_->module->GetFunctionSignature(operand.index) + ->parameter_count()); + } + case kExprCallIndirect: { + SignatureIndexOperand operand(this, pc); + return 1 + static_cast( + function_env_->module->GetSignature(operand.index) + ->parameter_count()); + } + case kExprReturn: { + return static_cast(function_env_->sig->return_count()); + } + case kExprTableSwitch: { + TableSwitchOperand operand(this, pc); + return 1 + operand.case_count; + } + +#define 
DECLARE_OPCODE_CASE(name, opcode, sig) \ + case kExpr##name: \ + return kArity_##sig; + + FOREACH_LOAD_MEM_OPCODE(DECLARE_OPCODE_CASE) + FOREACH_STORE_MEM_OPCODE(DECLARE_OPCODE_CASE) + FOREACH_MISC_MEM_OPCODE(DECLARE_OPCODE_CASE) + FOREACH_SIMPLE_OPCODE(DECLARE_OPCODE_CASE) +#undef DECLARE_OPCODE_CASE + } + UNREACHABLE(); + return 0; + } + + int OpcodeLength(const byte* pc) { + switch (static_cast(*pc)) { +#define DECLARE_OPCODE_CASE(name, opcode, sig) case kExpr##name: + FOREACH_LOAD_MEM_OPCODE(DECLARE_OPCODE_CASE) + FOREACH_STORE_MEM_OPCODE(DECLARE_OPCODE_CASE) +#undef DECLARE_OPCODE_CASE + { + MemoryAccessOperand operand(this, pc); + return 1 + operand.length; + } + case kExprBlock: + case kExprLoop: { + BlockCountOperand operand(this, pc); + return 1 + operand.length; + } + case kExprBr: + case kExprBrIf: { + BreakDepthOperand operand(this, pc); + return 1 + operand.length; + } + case kExprStoreGlobal: + case kExprLoadGlobal: { + GlobalIndexOperand operand(this, pc); + return 1 + operand.length; + } + + case kExprCallFunction: { + FunctionIndexOperand operand(this, pc); + return 1 + operand.length; + } + case kExprCallIndirect: { + SignatureIndexOperand operand(this, pc); + return 1 + operand.length; + } + + case kExprSetLocal: + case kExprGetLocal: { + LocalIndexOperand operand(this, pc); + return 1 + operand.length; + } + case kExprTableSwitch: { + TableSwitchOperand operand(this, pc); + return 1 + operand.length; + } + case kExprI8Const: + return 2; + case kExprI32Const: + case kExprF32Const: + return 5; + case kExprI64Const: + case kExprF64Const: + return 9; + + default: + return 1; } } }; @@ -431,25 +577,25 @@ class LR_WasmDecoder : public WasmDecoder { Leaf(kAstStmt); break; case kExprBlock: { - int length = ByteOperand(pc_); - if (length < 1) { + BlockCountOperand operand(this, pc_); + if (operand.count < 1) { Leaf(kAstStmt); } else { - Shift(kAstEnd, length); + Shift(kAstEnd, operand.count); // The break environment is the outer environment. 
SsaEnv* break_env = ssa_env_; PushBlock(break_env); SetEnv("block:start", Steal(break_env)); } - len = 2; + len = 1 + operand.length; break; } case kExprLoop: { - int length = ByteOperand(pc_); - if (length < 1) { + BlockCountOperand operand(this, pc_); + if (operand.count < 1) { Leaf(kAstStmt); } else { - Shift(kAstEnd, length); + Shift(kAstEnd, operand.count); // The break environment is the outer environment. SsaEnv* break_env = ssa_env_; PushBlock(break_env); @@ -461,7 +607,7 @@ class LR_WasmDecoder : public WasmDecoder { PushBlock(cont_env); blocks_.back().stack_depth = -1; // no production for inner block. } - len = 2; + len = 1 + operand.length; break; } case kExprIf: @@ -474,58 +620,27 @@ class LR_WasmDecoder : public WasmDecoder { Shift(kAstStmt, 3); // Result type is typeof(x) in {c ? x : y}. break; case kExprBr: { - uint32_t depth = ByteOperand(pc_); - Shift(kAstEnd, 1); - if (depth >= blocks_.size()) { - error("improperly nested branch"); + BreakDepthOperand operand(this, pc_); + if (Validate(pc_, operand, blocks_)) { + Shift(kAstEnd, 1); } - len = 2; + len = 1 + operand.length; break; } case kExprBrIf: { - uint32_t depth = ByteOperand(pc_); - Shift(kAstStmt, 2); - if (depth >= blocks_.size()) { - error("improperly nested conditional branch"); + BreakDepthOperand operand(this, pc_); + if (Validate(pc_, operand, blocks_)) { + Shift(kAstStmt, 2); } - len = 2; + len = 1 + operand.length; break; } case kExprTableSwitch: { - if (!checkAvailable(5)) { - error("expected #tableswitch , fell off end"); - break; - } - uint16_t case_count = read_u16(pc_ + 1); - uint16_t table_count = read_u16(pc_ + 3); - len = 5 + table_count * 2; - - if (table_count == 0) { - error("tableswitch with 0 entries"); - break; - } - - if (!checkAvailable(len)) { - error("expected #tableswitch
, fell off end"); - break; - } - - Shift(kAstEnd, 1 + case_count); - - // Verify table. - for (int i = 0; i < table_count; i++) { - uint16_t target = read_u16(pc_ + 5 + i * 2); - if (target >= 0x8000) { - size_t depth = target - 0x8000; - if (depth > blocks_.size()) { - error(pc_ + 5 + i * 2, "improper branch in tableswitch"); - } - } else { - if (target >= case_count) { - error(pc_ + 5 + i * 2, "invalid case target in tableswitch"); - } - } + TableSwitchOperand operand(this, pc_); + if (Validate(pc_, operand, blocks_.size())) { + Shift(kAstEnd, 1 + operand.case_count); } + len = 1 + operand.length; break; } case kExprReturn: { @@ -546,59 +661,66 @@ class LR_WasmDecoder : public WasmDecoder { break; } case kExprI8Const: { - int32_t value = bit_cast(ByteOperand(pc_)); - Leaf(kAstI32, BUILD(Int32Constant, value)); - len = 2; + ImmI8Operand operand(this, pc_); + Leaf(kAstI32, BUILD(Int32Constant, operand.value)); + len = 1 + operand.length; break; } case kExprI32Const: { - uint32_t value = Uint32Operand(pc_); - Leaf(kAstI32, BUILD(Int32Constant, value)); - len = 5; + ImmI32Operand operand(this, pc_); + Leaf(kAstI32, BUILD(Int32Constant, operand.value)); + len = 1 + operand.length; break; } case kExprI64Const: { - uint64_t value = Uint64Operand(pc_); - Leaf(kAstI64, BUILD(Int64Constant, value)); - len = 9; + ImmI64Operand operand(this, pc_); + Leaf(kAstI64, BUILD(Int64Constant, operand.value)); + len = 1 + operand.length; break; } case kExprF32Const: { - float value = bit_cast(Uint32Operand(pc_)); - Leaf(kAstF32, BUILD(Float32Constant, value)); - len = 5; + ImmF32Operand operand(this, pc_); + Leaf(kAstF32, BUILD(Float32Constant, operand.value)); + len = 1 + operand.length; break; } case kExprF64Const: { - double value = bit_cast(Uint64Operand(pc_)); - Leaf(kAstF64, BUILD(Float64Constant, value)); - len = 9; + ImmF64Operand operand(this, pc_); + Leaf(kAstF64, BUILD(Float64Constant, operand.value)); + len = 1 + operand.length; break; } case kExprGetLocal: { - uint32_t 
index; - LocalType type = LocalOperand(pc_, &index, &len); - TFNode* val = - build() && type != kAstStmt ? ssa_env_->locals[index] : nullptr; - Leaf(type, val); + LocalIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + TFNode* val = build() ? ssa_env_->locals[operand.index] : nullptr; + Leaf(operand.type, val); + } + len = 1 + operand.length; break; } case kExprSetLocal: { - uint32_t index; - LocalType type = LocalOperand(pc_, &index, &len); - Shift(type, 1); + LocalIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + Shift(operand.type, 1); + } + len = 1 + operand.length; break; } case kExprLoadGlobal: { - uint32_t index; - LocalType type = GlobalOperand(pc_, &index, &len); - Leaf(type, BUILD(LoadGlobal, index)); + GlobalIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + Leaf(operand.type, BUILD(LoadGlobal, operand.index)); + } + len = 1 + operand.length; break; } case kExprStoreGlobal: { - uint32_t index; - LocalType type = GlobalOperand(pc_, &index, &len); - Shift(type, 1); + GlobalIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + Shift(operand.type, 1); + } + len = 1 + operand.length; break; } case kExprI32LoadMem8S: @@ -647,27 +769,25 @@ class LR_WasmDecoder : public WasmDecoder { Shift(kAstI32, 1); break; case kExprCallFunction: { - uint32_t unused; - FunctionSig* sig = FunctionSigOperand(pc_, &unused, &len); - if (sig) { - LocalType type = - sig->return_count() == 0 ? kAstStmt : sig->GetReturn(); - Shift(type, static_cast(sig->parameter_count())); - } else { - Leaf(kAstI32); // error + FunctionIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + LocalType type = operand.sig->return_count() == 0 + ? 
kAstStmt + : operand.sig->GetReturn(); + Shift(type, static_cast(operand.sig->parameter_count())); } + len = 1 + operand.length; break; } case kExprCallIndirect: { - uint32_t unused; - FunctionSig* sig = SigOperand(pc_, &unused, &len); - if (sig) { - LocalType type = - sig->return_count() == 0 ? kAstStmt : sig->GetReturn(); - Shift(type, static_cast(1 + sig->parameter_count())); - } else { - Leaf(kAstI32); // error + SignatureIndexOperand operand(this, pc_); + if (Validate(pc_, operand)) { + LocalType type = operand.sig->return_count() == 0 + ? kAstStmt + : operand.sig->GetReturn(); + Shift(type, static_cast(1 + operand.sig->parameter_count())); } + len = 1 + operand.length; break; } default: @@ -690,19 +810,15 @@ class LR_WasmDecoder : public WasmDecoder { } int DecodeLoadMem(const byte* pc, LocalType type) { - int length = 2; - uint32_t offset; - MemoryAccessOperand(pc, &length, &offset); + MemoryAccessOperand operand(this, pc); Shift(type, 1); - return length; + return 1 + operand.length; } int DecodeStoreMem(const byte* pc, LocalType type) { - int length = 2; - uint32_t offset; - MemoryAccessOperand(pc, &length, &offset); + MemoryAccessOperand operand(this, pc); Shift(type, 2); - return length; + return 1 + operand.length; } void AddImplicitReturnAtEnd() { @@ -876,31 +992,23 @@ class LR_WasmDecoder : public WasmDecoder { break; } case kExprBr: { - uint32_t depth = ByteOperand(p->pc()); - if (depth >= blocks_.size()) { - error("improperly nested branch"); - break; - } - Block* block = &blocks_[blocks_.size() - depth - 1]; - ReduceBreakToExprBlock(p, block); + BreakDepthOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand, blocks_)); + ReduceBreakToExprBlock(p, operand.target); break; } case kExprBrIf: { if (p->index == 1) { TypeCheckLast(p, kAstI32); } else if (p->done()) { - uint32_t depth = ByteOperand(p->pc()); - if (depth >= blocks_.size()) { - error("improperly nested branch"); - break; - } - Block* block = &blocks_[blocks_.size() - depth - 1]; 
+ BreakDepthOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand, blocks_)); SsaEnv* fenv = ssa_env_; SsaEnv* tenv = Split(fenv); BUILD(Branch, p->tree->children[0]->node, &tenv->control, &fenv->control); ssa_env_ = tenv; - ReduceBreakToExprBlock(p, block); + ReduceBreakToExprBlock(p, operand.target); ssa_env_ = fenv; } break; @@ -909,18 +1017,22 @@ class LR_WasmDecoder : public WasmDecoder { if (p->index == 1) { // Switch key finished. TypeCheckLast(p, kAstI32); + if (failed()) break; - uint16_t table_count = read_u16(p->pc() + 3); + TableSwitchOperand operand(this, p->pc()); + DCHECK(Validate(p->pc(), operand, blocks_.size())); // Build the switch only if it has more than just a default target. - bool build_switch = table_count > 1; + bool build_switch = operand.table_count > 1; TFNode* sw = nullptr; - if (build_switch) sw = BUILD(Switch, table_count, p->last()->node); + if (build_switch) + sw = BUILD(Switch, operand.table_count, p->last()->node); // Allocate environments for each case. - uint16_t case_count = read_u16(p->pc() + 1); - SsaEnv** case_envs = zone_->NewArray(case_count); - for (int i = 0; i < case_count; i++) case_envs[i] = UnreachableEnv(); + SsaEnv** case_envs = zone_->NewArray(operand.case_count); + for (uint32_t i = 0; i < operand.case_count; i++) { + case_envs[i] = UnreachableEnv(); + } ifs_.push_back({nullptr, nullptr, case_envs}); SsaEnv* break_env = ssa_env_; @@ -929,13 +1041,14 @@ class LR_WasmDecoder : public WasmDecoder { ssa_env_ = copy; // Build the environments for each case based on the table. - for (int i = 0; i < table_count; i++) { - uint16_t target = read_u16(p->pc() + 5 + i * 2); + for (uint32_t i = 0; i < operand.table_count; i++) { + uint16_t target = operand.read_entry(this, i); SsaEnv* env = copy; if (build_switch) { env = Split(env); - env->control = (i == table_count - 1) ? BUILD(IfDefault, sw) - : BUILD(IfValue, i, sw); + env->control = (i == operand.table_count - 1) + ? 
BUILD(IfDefault, sw) + : BUILD(IfValue, i, sw); } if (target >= 0x8000) { // Targets an outer block. @@ -981,12 +1094,11 @@ class LR_WasmDecoder : public WasmDecoder { break; } case kExprSetLocal: { - int unused = 0; - uint32_t index; - LocalType type = LocalOperand(p->pc(), &index, &unused); + LocalIndexOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand)); Tree* val = p->last(); - if (type == val->type) { - if (build()) ssa_env_->locals[index] = val->node; + if (operand.type == val->type) { + if (build()) ssa_env_->locals[operand.index] = val->node; p->tree->node = val->node; } else { error(p->pc(), val->pc, "Typecheck failed in SetLocal"); @@ -994,12 +1106,11 @@ class LR_WasmDecoder : public WasmDecoder { break; } case kExprStoreGlobal: { - int unused = 0; - uint32_t index; - LocalType type = GlobalOperand(p->pc(), &index, &unused); + GlobalIndexOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand)); Tree* val = p->last(); - if (type == val->type) { - BUILD(StoreGlobal, index, val->node); + if (operand.type == val->type) { + BUILD(StoreGlobal, operand.index, val->node); p->tree->node = val->node; } else { error(p->pc(), val->pc, "Typecheck failed in StoreGlobal"); @@ -1068,34 +1179,29 @@ class LR_WasmDecoder : public WasmDecoder { return; case kExprCallFunction: { - int len; - uint32_t index; - FunctionSig* sig = FunctionSigOperand(p->pc(), &index, &len); - if (!sig) break; + FunctionIndexOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand)); if (p->index > 0) { - TypeCheckLast(p, sig->GetParam(p->index - 1)); + TypeCheckLast(p, operand.sig->GetParam(p->index - 1)); } if (p->done() && build()) { uint32_t count = p->tree->count + 1; TFNode** buffer = builder_->Buffer(count); - FunctionSig* sig = FunctionSigOperand(p->pc(), &index, &len); - USE(sig); buffer[0] = nullptr; // reserved for code object. 
for (uint32_t i = 1; i < count; i++) { buffer[i] = p->tree->children[i - 1]->node; } - p->tree->node = builder_->CallDirect(index, buffer); + p->tree->node = builder_->CallDirect(operand.index, buffer); } break; } case kExprCallIndirect: { - int len; - uint32_t index; - FunctionSig* sig = SigOperand(p->pc(), &index, &len); + SignatureIndexOperand operand(this, p->pc()); + CHECK(Validate(p->pc(), operand)); if (p->index == 1) { TypeCheckLast(p, kAstI32); } else { - TypeCheckLast(p, sig->GetParam(p->index - 2)); + TypeCheckLast(p, operand.sig->GetParam(p->index - 2)); } if (p->done() && build()) { uint32_t count = p->tree->count; @@ -1103,7 +1209,7 @@ class LR_WasmDecoder : public WasmDecoder { for (uint32_t i = 0; i < count; i++) { buffer[i] = p->tree->children[i]->node; } - p->tree->node = builder_->CallIndirect(index, buffer); + p->tree->node = builder_->CallIndirect(operand.index, buffer); } break; } @@ -1152,11 +1258,9 @@ class LR_WasmDecoder : public WasmDecoder { DCHECK_EQ(1, p->index); TypeCheckLast(p, kAstI32); // index if (build()) { - int length = 0; - uint32_t offset = 0; - MemoryAccessOperand(p->pc(), &length, &offset); + MemoryAccessOperand operand(this, p->pc()); p->tree->node = - builder_->LoadMem(type, mem_type, p->last()->node, offset); + builder_->LoadMem(type, mem_type, p->last()->node, operand.offset); } } @@ -1167,11 +1271,10 @@ class LR_WasmDecoder : public WasmDecoder { DCHECK_EQ(2, p->index); TypeCheckLast(p, type); if (build()) { - int length = 0; - uint32_t offset = 0; - MemoryAccessOperand(p->pc(), &length, &offset); + MemoryAccessOperand operand(this, p->pc()); TFNode* val = p->tree->children[1]->node; - builder_->StoreMem(mem_type, p->tree->children[0]->node, offset, val); + builder_->StoreMem(mem_type, p->tree->children[0]->node, operand.offset, + val); p->tree->node = val; } } @@ -1194,7 +1297,7 @@ class LR_WasmDecoder : public WasmDecoder { void SetEnv(const char* reason, SsaEnv* env) { TRACE(" env = %p, block depth = %d, reason = 
%s", static_cast(env), static_cast(blocks_.size()), reason); - if (env->control != nullptr && FLAG_trace_wasm_decoder) { + if (FLAG_trace_wasm_decoder && env && env->control) { TRACE(", control = "); compiler::WasmGraphBuilder::PrintDebugName(env->control); } @@ -1447,158 +1550,29 @@ ReadUnsignedLEB128ErrorCode ReadUnsignedLEB128Operand(const byte* pc, const byte* limit, int* length, uint32_t* result) { - *result = 0; - const byte* ptr = pc; - const byte* end = pc + 5; // maximum 5 bytes. - if (end > limit) end = limit; - int shift = 0; - byte b = 0; - while (ptr < end) { - b = *ptr++; - *result = *result | ((b & 0x7F) << shift); - if ((b & 0x80) == 0) break; - shift += 7; - } - DCHECK_LE(ptr - pc, 5); - *length = static_cast(ptr - pc); - if (ptr == end && (b & 0x80)) { - return kInvalidLEB128; - } else if (*length == 0) { - return kMissingLEB128; - } else { - return kNoError; - } + Decoder decoder(pc, limit); + *result = decoder.checked_read_u32v(pc, 0, length); + if (decoder.ok()) return kNoError; + return (limit - pc) > 1 ? kInvalidLEB128 : kMissingLEB128; } - -// TODO(titzer): move this into WasmDecoder and bounds check accesses. -int OpcodeLength(const byte* pc) { - switch (static_cast(*pc)) { -#define DECLARE_OPCODE_CASE(name, opcode, sig) case kExpr##name: - FOREACH_LOAD_MEM_OPCODE(DECLARE_OPCODE_CASE) - FOREACH_STORE_MEM_OPCODE(DECLARE_OPCODE_CASE) -#undef DECLARE_OPCODE_CASE - { - // Loads and stores have an optional offset. 
- byte bitfield = pc[1]; - if (MemoryAccess::OffsetField::decode(bitfield)) { - int length; - uint32_t result = 0; - ReadUnsignedLEB128Operand(pc + 2, pc + 7, &length, &result); - return 2 + length; - } - return 2; - } - case kExprI8Const: - case kExprBlock: - case kExprLoop: - case kExprBr: - case kExprBrIf: - return 2; - case kExprI32Const: - case kExprF32Const: - return 5; - case kExprI64Const: - case kExprF64Const: - return 9; - case kExprStoreGlobal: - case kExprSetLocal: - case kExprLoadGlobal: - case kExprCallFunction: - case kExprCallIndirect: - case kExprGetLocal: { - int length; - uint32_t result = 0; - ReadUnsignedLEB128Operand(pc + 1, pc + 6, &length, &result); - return 1 + length; - } - case kExprTableSwitch: { - uint16_t table_count = *reinterpret_cast(pc + 3); - return 5 + table_count * 2; - } - - default: - return 1; - } +int OpcodeLength(const byte* pc, const byte* end) { + WasmDecoder decoder(nullptr, pc, end); + return decoder.OpcodeLength(pc); } - -// TODO(titzer): move this into WasmDecoder and bounds check accesses. -int OpcodeArity(FunctionEnv* env, const byte* pc) { -#define DECLARE_ARITY(name, ...) 
\ - static const LocalType kTypes_##name[] = {__VA_ARGS__}; \ - static const int kArity_##name = \ - static_cast(arraysize(kTypes_##name) - 1); - - FOREACH_SIGNATURE(DECLARE_ARITY); -#undef DECLARE_ARITY - - switch (static_cast(*pc)) { - case kExprI8Const: - case kExprI32Const: - case kExprI64Const: - case kExprF64Const: - case kExprF32Const: - case kExprGetLocal: - case kExprLoadGlobal: - case kExprNop: - case kExprUnreachable: - return 0; - - case kExprBr: - case kExprStoreGlobal: - case kExprSetLocal: - return 1; - - case kExprIf: - case kExprBrIf: - return 2; - case kExprIfElse: - case kExprSelect: - return 3; - case kExprBlock: - case kExprLoop: - return *(pc + 1); - - case kExprCallFunction: { - int index = *(pc + 1); - return static_cast( - env->module->GetFunctionSignature(index)->parameter_count()); - } - case kExprCallIndirect: { - int index = *(pc + 1); - return 1 + static_cast( - env->module->GetSignature(index)->parameter_count()); - } - case kExprReturn: - return static_cast(env->sig->return_count()); - case kExprTableSwitch: { - uint16_t case_count = *reinterpret_cast(pc + 1); - return 1 + case_count; - } - -#define DECLARE_OPCODE_CASE(name, opcode, sig) \ - case kExpr##name: \ - return kArity_##sig; - - FOREACH_LOAD_MEM_OPCODE(DECLARE_OPCODE_CASE) - FOREACH_STORE_MEM_OPCODE(DECLARE_OPCODE_CASE) - FOREACH_MISC_MEM_OPCODE(DECLARE_OPCODE_CASE) - FOREACH_SIMPLE_OPCODE(DECLARE_OPCODE_CASE) -#undef DECLARE_OPCODE_CASE - } - UNREACHABLE(); - return 0; +int OpcodeArity(FunctionEnv* env, const byte* pc, const byte* end) { + WasmDecoder decoder(env, pc, end); + return decoder.OpcodeArity(pc); } - - void PrintAst(FunctionEnv* env, const byte* start, const byte* end) { + WasmDecoder decoder(env, start, end); const byte* pc = start; std::vector arity_stack; while (pc < end) { - int arity = OpcodeArity(env, pc); - size_t length = OpcodeLength(pc); + int arity = decoder.OpcodeArity(pc); + size_t length = decoder.OpcodeLength(pc); for (auto arity : arity_stack) { 
printf(" "); @@ -1623,7 +1597,6 @@ void PrintAst(FunctionEnv* env, const byte* start, const byte* end) { } } - // Analyzes loop bodies for static assignments to locals, which helps in // reducing the number of phis introduced at loop headers. class LoopAssignmentAnalyzer : public WasmDecoder { @@ -1641,7 +1614,7 @@ class LoopAssignmentAnalyzer : public WasmDecoder { new (zone_) BitVector(function_env_->total_locals, zone_); // Keep a stack to model the nesting of expressions. std::vector arity_stack; - arity_stack.push_back(OpcodeArity(function_env_, pc_)); + arity_stack.push_back(OpcodeArity(pc_)); pc_ += OpcodeLength(pc_); // Iteratively process all AST nodes nested inside the loop. @@ -1650,16 +1623,16 @@ class LoopAssignmentAnalyzer : public WasmDecoder { int arity = 0; int length = 1; if (opcode == kExprSetLocal) { - uint32_t index; - LocalOperand(pc_, &index, &length); + LocalIndexOperand operand(this, pc_); if (assigned->length() > 0 && - static_cast(index) < assigned->length()) { + static_cast(operand.index) < assigned->length()) { // Unverified code might have an out-of-bounds index. - assigned->Add(index); + assigned->Add(operand.index); } arity = 1; + length = 1 + operand.length; } else { - arity = OpcodeArity(function_env_, pc_); + arity = OpcodeArity(pc_); length = OpcodeLength(pc_); } diff --git a/src/wasm/ast-decoder.h b/src/wasm/ast-decoder.h index 742c844121..f07c8800f8 100644 --- a/src/wasm/ast-decoder.h +++ b/src/wasm/ast-decoder.h @@ -6,6 +6,7 @@ #define V8_WASM_AST_DECODER_H_ #include "src/signature.h" +#include "src/wasm/decoder.h" #include "src/wasm/wasm-opcodes.h" #include "src/wasm/wasm-result.h" @@ -20,6 +21,156 @@ class WasmGraphBuilder; namespace wasm { +// Helpers for decoding different kinds of operands which follow bytecodes. 
+struct LocalIndexOperand { + uint32_t index; + LocalType type; + int length; + + inline LocalIndexOperand(Decoder* decoder, const byte* pc) { + index = decoder->checked_read_u32v(pc, 1, &length, "local index"); + type = kAstStmt; + } +}; + +struct ImmI8Operand { + int8_t value; + int length; + inline ImmI8Operand(Decoder* decoder, const byte* pc) { + value = bit_cast(decoder->checked_read_u8(pc, 1, "immi8")); + length = 1; + } +}; + +struct ImmI32Operand { + int32_t value; + int length; + inline ImmI32Operand(Decoder* decoder, const byte* pc) { + value = bit_cast(decoder->checked_read_u32(pc, 1, "immi32")); + length = 4; + } +}; + +struct ImmI64Operand { + int64_t value; + int length; + inline ImmI64Operand(Decoder* decoder, const byte* pc) { + value = bit_cast(decoder->checked_read_u64(pc, 1, "immi64")); + length = 8; + } +}; + +struct ImmF32Operand { + float value; + int length; + inline ImmF32Operand(Decoder* decoder, const byte* pc) { + value = bit_cast(decoder->checked_read_u32(pc, 1, "immf32")); + length = 4; + } +}; + +struct ImmF64Operand { + double value; + int length; + inline ImmF64Operand(Decoder* decoder, const byte* pc) { + value = bit_cast(decoder->checked_read_u64(pc, 1, "immf64")); + length = 8; + } +}; + +struct GlobalIndexOperand { + uint32_t index; + LocalType type; + MachineType machine_type; + int length; + + inline GlobalIndexOperand(Decoder* decoder, const byte* pc) { + index = decoder->checked_read_u32v(pc, 1, &length, "global index"); + type = kAstStmt; + machine_type = MachineType::None(); + } +}; + +struct Block; +struct BreakDepthOperand { + uint32_t depth; + Block* target; + int length; + inline BreakDepthOperand(Decoder* decoder, const byte* pc) { + depth = decoder->checked_read_u8(pc, 1, "break depth"); + length = 1; + target = nullptr; + } +}; + +struct BlockCountOperand { + uint32_t count; + int length; + inline BlockCountOperand(Decoder* decoder, const byte* pc) { + count = decoder->checked_read_u8(pc, 1, "block count"); + length 
= 1; + } +}; + +struct SignatureIndexOperand { + uint32_t index; + FunctionSig* sig; + int length; + inline SignatureIndexOperand(Decoder* decoder, const byte* pc) { + index = decoder->checked_read_u32v(pc, 1, &length, "signature index"); + sig = nullptr; + } +}; + +struct FunctionIndexOperand { + uint32_t index; + FunctionSig* sig; + int length; + inline FunctionIndexOperand(Decoder* decoder, const byte* pc) { + index = decoder->checked_read_u32v(pc, 1, &length, "function index"); + sig = nullptr; + } +}; + +struct TableSwitchOperand { + uint32_t case_count; + uint32_t table_count; + const byte* table; + int length; + inline TableSwitchOperand(Decoder* decoder, const byte* pc) { + case_count = decoder->checked_read_u16(pc, 1, "expected #cases"); + table_count = decoder->checked_read_u16(pc, 3, "expected #entries"); + length = 4 + table_count * 2; + + if (decoder->check(pc, 5, table_count * 2, "expected
")) { + table = pc + 5; + } else { + table = nullptr; + } + } + inline uint16_t read_entry(Decoder* decoder, int i) { + DCHECK(i >= 0 && static_cast(i) < table_count); + return table ? decoder->read_u16(table + i * sizeof(uint16_t)) : 0; + } +}; + +struct MemoryAccessOperand { + bool aligned; + uint32_t offset; + int length; + inline MemoryAccessOperand(Decoder* decoder, const byte* pc) { + byte bitfield = decoder->checked_read_u8(pc, 1, "memory access byte"); + aligned = MemoryAccess::AlignmentField::decode(bitfield); + if (MemoryAccess::OffsetField::decode(bitfield)) { + offset = decoder->checked_read_u32v(pc, 2, &length, "memory offset"); + length++; + } else { + offset = 0; + length = 1; + } + } +}; + typedef compiler::WasmGraphBuilder TFBuilder; struct ModuleEnv; // forward declaration of module interface. @@ -34,7 +185,6 @@ struct FunctionEnv { uint32_t local_f64_count; // number of float64 locals uint32_t total_locals; // sum of parameters and all locals - bool IsValidLocal(uint32_t index) { return index < total_locals; } uint32_t GetLocalCount() { return total_locals; } LocalType GetLocalType(uint32_t index) { if (index < static_cast(sig->parameter_count())) { @@ -112,10 +262,10 @@ BitVector* AnalyzeLoopAssignmentForTesting(Zone* zone, FunctionEnv* env, const byte* start, const byte* end); // Computes the length of the opcode at the given address. -int OpcodeLength(const byte* pc); +int OpcodeLength(const byte* pc, const byte* end); // Computes the arity (number of sub-nodes) of the opcode at the given address. 
-int OpcodeArity(FunctionEnv* env, const byte* pc); +int OpcodeArity(FunctionEnv* env, const byte* pc, const byte* end); } // namespace wasm } // namespace internal } // namespace v8 diff --git a/src/wasm/decoder.h b/src/wasm/decoder.h index 433e80a186..0e88eda022 100644 --- a/src/wasm/decoder.h +++ b/src/wasm/decoder.h @@ -44,6 +44,68 @@ class Decoder { virtual ~Decoder() {} + inline bool check(const byte* base, int offset, int length, const char* msg) { + DCHECK_GE(base, start_); + if ((base + offset + length) > limit_) { + error(base, base + offset, msg); + return false; + } + return true; + } + + // Reads a single 8-bit byte, reporting an error if out of bounds. + inline uint8_t checked_read_u8(const byte* base, int offset, + const char* msg = "expected 1 byte") { + return check(base, offset, 1, msg) ? base[offset] : 0; + } + + // Reads 16-bit word, reporting an error if out of bounds. + inline uint16_t checked_read_u16(const byte* base, int offset, + const char* msg = "expected 2 bytes") { + return check(base, offset, 2, msg) ? read_u16(base + offset) : 0; + } + + // Reads 32-bit word, reporting an error if out of bounds. + inline uint32_t checked_read_u32(const byte* base, int offset, + const char* msg = "expected 4 bytes") { + return check(base, offset, 4, msg) ? read_u32(base + offset) : 0; + } + + // Reads 64-bit word, reporting an error if out of bounds. + inline uint64_t checked_read_u64(const byte* base, int offset, + const char* msg = "expected 8 bytes") { + return check(base, offset, 8, msg) ? read_u64(base + offset) : 0; + } + + uint32_t checked_read_u32v(const byte* base, int offset, int* length, + const char* msg = "expected LEB128") { + if (!check(base, offset, 1, msg)) { + *length = 0; + return 0; + } + + const ptrdiff_t kMaxDiff = 5; // maximum 5 bytes. 
+ const byte* ptr = base + offset; + const byte* end = ptr + kMaxDiff; + if (end > limit_) end = limit_; + int shift = 0; + byte b = 0; + uint32_t result = 0; + while (ptr < end) { + b = *ptr++; + result = result | ((b & 0x7F) << shift); + if ((b & 0x80) == 0) break; + shift += 7; + } + DCHECK_LE(ptr - (base + offset), kMaxDiff); + *length = static_cast(ptr - (base + offset)); + if (ptr == end && (b & 0x80)) { + error(base, ptr, msg); + return 0; + } + return result; + } + // Reads a single 16-bit unsigned integer (little endian). inline uint16_t read_u16(const byte* ptr) { DCHECK(ptr >= start_ && (ptr + 2) <= end_); @@ -170,6 +232,12 @@ class Decoder { } } + bool RangeOk(const byte* pc, int length) { + if (pc < start_ || pc_ >= limit_) return false; + if ((pc + length) >= limit_) return false; + return true; + } + void error(const char* msg) { error(pc_, nullptr, msg); } void error(const byte* pc, const char* msg) { error(pc, nullptr, msg); } diff --git a/test/unittests/wasm/ast-decoder-unittest.cc b/test/unittests/wasm/ast-decoder-unittest.cc index 412e719e6a..d0c467053c 100644 --- a/test/unittests/wasm/ast-decoder-unittest.cc +++ b/test/unittests/wasm/ast-decoder-unittest.cc @@ -1913,14 +1913,12 @@ class WasmOpcodeLengthTest : public TestWithZone { WasmOpcodeLengthTest() : TestWithZone() {} }; - -#define EXPECT_LENGTH(expected, opcode) \ - { \ - static const byte code[] = {opcode, 0, 0, 0, 0, 0, 0, 0, 0}; \ - EXPECT_EQ(expected, OpcodeLength(code)); \ +#define EXPECT_LENGTH(expected, opcode) \ + { \ + static const byte code[] = {opcode, 0, 0, 0, 0, 0, 0, 0, 0}; \ + EXPECT_EQ(expected, OpcodeLength(code, code + sizeof(code))); \ } - TEST_F(WasmOpcodeLengthTest, Statements) { EXPECT_LENGTH(1, kExprNop); EXPECT_LENGTH(2, kExprBlock); @@ -1961,11 +1959,11 @@ TEST_F(WasmOpcodeLengthTest, VariableLength) { byte size5[] = {kExprLoadGlobal, 1 | 0x80, 2 | 0x80, 3 | 0x80, 4}; byte size6[] = {kExprLoadGlobal, 1 | 0x80, 2 | 0x80, 3 | 0x80, 4 | 0x80, 5}; - EXPECT_EQ(2, 
OpcodeLength(size2)); - EXPECT_EQ(3, OpcodeLength(size3)); - EXPECT_EQ(4, OpcodeLength(size4)); - EXPECT_EQ(5, OpcodeLength(size5)); - EXPECT_EQ(6, OpcodeLength(size6)); + EXPECT_EQ(2, OpcodeLength(size2, size2 + sizeof(size2))); + EXPECT_EQ(3, OpcodeLength(size3, size3 + sizeof(size3))); + EXPECT_EQ(4, OpcodeLength(size4, size4 + sizeof(size4))); + EXPECT_EQ(5, OpcodeLength(size5, size5 + sizeof(size5))); + EXPECT_EQ(6, OpcodeLength(size6, size6 + sizeof(size6))); } @@ -2130,14 +2128,12 @@ class WasmOpcodeArityTest : public TestWithZone { WasmOpcodeArityTest() : TestWithZone() {} }; - -#define EXPECT_ARITY(expected, ...) \ - { \ - static const byte code[] = {__VA_ARGS__}; \ - EXPECT_EQ(expected, OpcodeArity(&env, code)); \ +#define EXPECT_ARITY(expected, ...) \ + { \ + static const byte code[] = {__VA_ARGS__}; \ + EXPECT_EQ(expected, OpcodeArity(&env, code, code + sizeof(code))); \ } - TEST_F(WasmOpcodeArityTest, Control) { FunctionEnv env; EXPECT_ARITY(0, kExprNop);