diff --git a/src/sksl/SkSLDSLParser.cpp b/src/sksl/SkSLDSLParser.cpp index 1dbdb8b55c..8c44c8493d 100644 --- a/src/sksl/SkSLDSLParser.cpp +++ b/src/sksl/SkSLDSLParser.cpp @@ -420,26 +420,32 @@ bool DSLParser::functionDeclarationEnd(const DSLModifiers& modifiers, return true; } -SKSL_INT DSLParser::arraySize() { +bool DSLParser::arraySize(SKSL_INT* outResult) { DSLExpression sizeExpr = this->expression(); - if (!sizeExpr.isValid()) { - return 1; + if (!sizeExpr.hasValue()) { + return false; } - std::unique_ptr sizeLiteral = sizeExpr.release(); - SKSL_INT size; - if (!ConstantFolder::GetConstantInt(*sizeLiteral, &size)) { - this->error(sizeLiteral->fLine, "array size must be an integer"); - return 1; + // Start out with a safe value that won't generate any errors downstream + *outResult = 1; + if (sizeExpr.isValid()) { + std::unique_ptr sizeLiteral = sizeExpr.release(); + SKSL_INT size; + if (!ConstantFolder::GetConstantInt(*sizeLiteral, &size)) { + this->error(sizeLiteral->fLine, "array size must be an integer"); + return true; + } + if (size > INT32_MAX) { + this->error(sizeLiteral->fLine, "array size out of bounds"); + return true; + } + if (size <= 0) { + this->error(sizeLiteral->fLine, "array size must be positive"); + return true; + } + // Now that we've validated it, output the real value + *outResult = size; } - if (size > INT32_MAX) { - this->error(sizeLiteral->fLine, "array size out of bounds"); - return 1; - } - if (size <= 0) { - this->error(sizeLiteral->fLine, "array size must be positive"); - return 1; - } - return size; + return true; } bool DSLParser::parseArrayDimensions(int line, DSLType* type) { @@ -447,7 +453,11 @@ bool DSLParser::parseArrayDimensions(int line, DSLType* type) { if (this->checkNext(Token::Kind::TK_RBRACKET)) { this->error(line, "expected array dimension"); } else { - *type = Array(*type, this->arraySize(), this->position(line)); + SKSL_INT size; + if (!this->arraySize(&size)) { + return false; + } + *type = Array(*type, size, this->position(line)); if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) { return false; } @@ -633,7 +643,11 @@ std::optional DSLParser::structDeclaration() { } while (this->checkNext(Token::Kind::TK_LBRACKET)) { - actualType = dsl::Array(actualType, this->arraySize(), this->position(memberName)); + SKSL_INT size; + if (!this->arraySize(&size)) { + return std::nullopt; + } + actualType = dsl::Array(actualType, size, this->position(memberName)); if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) { return std::nullopt; } @@ -874,7 +888,11 @@ std::optional DSLParser::type(DSLModifiers* modifiers) { DSLType result(this->text(type), modifiers, this->position(type)); while (this->checkNext(Token::Kind::TK_LBRACKET)) { if (this->peek().fKind != Token::Kind::TK_RBRACKET) { - result = Array(result, this->arraySize(), this->position(type)); + SKSL_INT size; + if (!this->arraySize(&size)) { + return std::nullopt; + } + result = Array(result, size, this->position(type)); } else { this->error(this->peek(), "expected array dimension"); } @@ -885,7 +903,7 @@ std::optional DSLParser::type(DSLModifiers* modifiers) { /* IDENTIFIER LBRACE varDeclaration+ - RBRACE (IDENTIFIER (LBRACKET expression? RBRACKET)*)? SEMICOLON */ + RBRACE (IDENTIFIER (LBRACKET expression RBRACKET)*)? SEMICOLON */ bool DSLParser::interfaceBlock(const dsl::DSLModifiers& modifiers) { Token typeName; if (!this->expectIdentifier(&typeName)) { @@ -916,8 +934,11 @@ bool DSLParser::interfaceBlock(const dsl::DSLModifiers& modifiers) { if (this->checkNext(Token::Kind::TK_LBRACKET)) { Token sizeToken = this->peek(); if (sizeToken.fKind != Token::Kind::TK_RBRACKET) { - actualType = Array(std::move(actualType), this->arraySize(), - this->position(typeName)); + SKSL_INT size; + if (!this->arraySize(&size)) { + return false; + } + actualType = Array(std::move(actualType), size, this->position(typeName)); } else { this->error(sizeToken, "unsized arrays are not permitted"); } @@ -945,11 +966,13 @@ bool DSLParser::interfaceBlock(const dsl::DSLModifiers& modifiers) { } std::string_view instanceName; Token instanceNameToken; - SKSL_INT arraySize = 0; + SKSL_INT size = 0; if (this->checkIdentifier(&instanceNameToken)) { instanceName = this->text(instanceNameToken); if (this->checkNext(Token::Kind::TK_LBRACKET)) { - arraySize = this->arraySize(); + if (!this->arraySize(&size)) { + return false; + } this->expect(Token::Kind::TK_RBRACKET, "']'"); } } @@ -959,7 +982,7 @@ bool DSLParser::interfaceBlock(const dsl::DSLModifiers& modifiers) { "' must contain at least one member"); } else { dsl::InterfaceBlock(modifiers, this->text(typeName), std::move(fields), instanceName, - arraySize, this->position(typeName)); + size, this->position(typeName)); } return true; } diff --git a/src/sksl/SkSLDSLParser.h b/src/sksl/SkSLDSLParser.h index b3ca41ed4c..f8af04c438 100644 --- a/src/sksl/SkSLDSLParser.h +++ b/src/sksl/SkSLDSLParser.h @@ -125,7 +125,14 @@ private: void declarations(); - SKSL_INT arraySize(); + /** + * Parses an expression representing an array size. Reports errors if the array size is not + * valid (out of bounds, not a literal integer). Returns true if an expression was + * successfully parsed, even if that array size is not actually valid. In the event of a true + * return, outResult always contains a valid array size (even if the parsed array size was not + * actually valid; invalid array sizes result in a 1 to avoid additional errors downstream). + */ + bool arraySize(SKSL_INT* outResult); void directive(); diff --git a/tests/sksl/errors/Ossfuzz36850.asm.frag b/tests/sksl/errors/Ossfuzz36850.asm.frag index 64374ca894..acc1b65f02 100644 --- a/tests/sksl/errors/Ossfuzz36850.asm.frag +++ b/tests/sksl/errors/Ossfuzz36850.asm.frag @@ -1,6 +1,4 @@ ### Compilation failed: error: 1: expected expression, but found ']' -error: 1: expected ']', but found ';' -error: 2: expected ';', but found 'void' -3 errors +1 error diff --git a/tests/sksl/errors/Ossfuzz37469.asm.frag b/tests/sksl/errors/Ossfuzz37469.asm.frag index 4dc1ef8387..acc1b65f02 100644 --- a/tests/sksl/errors/Ossfuzz37469.asm.frag +++ b/tests/sksl/errors/Ossfuzz37469.asm.frag @@ -1,6 +1,4 @@ ### Compilation failed: error: 1: expected expression, but found ']' -error: 1: expected ']', but found ';' -error: 1: expected ';', but found 'void' -3 errors +1 error