diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp index 49ba236fc..e26df2880 100644 --- a/source/val/validate_extensions.cpp +++ b/source/val/validate_extensions.cpp @@ -2701,8 +2701,9 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { "Generic, CrossWorkgroup, Workgroup or Function"; } - if (!_.IsFloatScalarType(p_data_type) || - _.GetBitWidth(p_data_type) != 16) { + if ((!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) && + !_.ContainsUntypedPointer(p_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; @@ -2763,8 +2764,9 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { "Generic, CrossWorkgroup, Workgroup or Function"; } - if (!_.IsFloatScalarType(p_data_type) || - _.GetBitWidth(p_data_type) != 16) { + if ((!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) && + !_.ContainsUntypedPointer(p_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; @@ -2855,8 +2857,9 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { "CrossWorkgroup, Workgroup or Function"; } - if (!_.IsFloatScalarType(p_data_type) || - _.GetBitWidth(p_data_type) != 16) { + if ((!_.IsFloatScalarType(p_data_type) || + _.GetBitWidth(p_data_type) != 16) && + !_.ContainsUntypedPointer(p_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected operand P data type to be 16-bit float scalar"; @@ -2994,8 +2997,9 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) { if (_.IsIntArrayType(format_data_type)) format_data_type = _.GetComponentType(format_data_type); - if (!_.IsIntScalarType(format_data_type) || - _.GetBitWidth(format_data_type) != 8) { + if ((!_.IsIntScalarType(format_data_type) || + _.GetBitWidth(format_data_type) != 8) && + !_.ContainsUntypedPointer(format_type)) { return _.diag(SPV_ERROR_INVALID_DATA, inst) << ext_inst_name() << ": " << "expected Format data type to be 8-bit int";