mirror of
https://github.com/KhronosGroup/SPIRV-Tools
synced 2024-11-22 11:40:05 +00:00
185 lines
6.2 KiB
C++
185 lines
6.2 KiB
C++
// Copyright (c) 2024 NVIDIA Corporation
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Validate instructions that manipulate tensor layout and view objects
|
|
|
|
#include "source/opcode.h"
|
|
#include "source/spirv_target_env.h"
|
|
#include "source/val/instruction.h"
|
|
#include "source/val/validate.h"
|
|
#include "source/val/validation_state.h"
|
|
|
|
namespace spvtools {
|
|
namespace val {
|
|
namespace {
|
|
|
|
// Checks that the Result Type (operand 0) of |inst| names an
// OpTypeTensorLayoutNV definition. Emits SPV_ERROR_INVALID_ID otherwise.
spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _,
                                              const Instruction* inst) {
  const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
  const Instruction* type_def = _.FindDef(type_id);

  // An unknown id is treated the same as a definition of the wrong kind.
  const bool is_layout_type =
      type_def != nullptr &&
      type_def->opcode() == spv::Op::OpTypeTensorLayoutNV;
  if (is_layout_type) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << spvOpcodeString(inst->opcode()) << " Result Type <id> "
         << _.getIdName(type_id) << " is not a tensor layout type.";
}
|
|
|
|
// Checks that the Result Type (operand 0) of |inst| names an
// OpTypeTensorViewNV definition. Emits SPV_ERROR_INVALID_ID otherwise.
spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _,
                                            const Instruction* inst) {
  const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
  const Instruction* type_def = _.FindDef(type_id);

  // Reject both unknown ids and ids defined by some other opcode.
  if (type_def == nullptr ||
      type_def->opcode() != spv::Op::OpTypeTensorViewNV) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << spvOpcodeString(inst->opcode()) << " Result Type <id> "
           << _.getIdName(type_id) << " is not a tensor view type.";
  }
  return SPV_SUCCESS;
}
|
|
|
|
// Validates OpCreateTensorLayoutNV. Only the Result Type is checked here: it
// must be a tensor layout type.
spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _,
                                          const Instruction* inst) {
  return ValidateTensorLayoutResultTypeNV(_, inst);
}
|
|
|
|
// Validates OpCreateTensorViewNV. Only the Result Type is checked here: it
// must be a tensor view type.
spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _,
                                        const Instruction* inst) {
  return ValidateTensorViewResultTypeNV(_, inst);
}
|
|
|
|
// How many trailing value operands a tensor layout/view modifier instruction
// is required to have, relative to the Dim of its tensor layout/view type.
enum ExpectedNumValues {
  DIM = 0,    // exactly Dim values
  DIMx2 = 1,  // exactly 2 * Dim values
  ONE = 2,    // exactly one value
  FOUR = 3,   // exactly four values
};
|
|
|
|
spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
|
|
const Instruction* inst,
|
|
ExpectedNumValues expected,
|
|
bool is_view) {
|
|
std::string type_str;
|
|
if (is_view) {
|
|
if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
|
|
type_str = "TensorView";
|
|
} else {
|
|
if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
|
|
type_str = "TensorLayout";
|
|
}
|
|
|
|
const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
|
|
const auto tensor_id = inst->GetOperandAs<uint32_t>(2);
|
|
const auto tensor = _.FindDef(tensor_id);
|
|
if (!tensor || result_type_id != tensor->type_id()) {
|
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
<< spvOpcodeString(inst->opcode()) << " Result Type <id> "
|
|
<< _.getIdName(result_type_id) << " does not match " << type_str
|
|
<< " type.";
|
|
}
|
|
|
|
const auto num_values = inst->operands().size() - 3;
|
|
|
|
const auto result_type = _.FindDef(result_type_id);
|
|
const auto dim_index = 1;
|
|
const auto dim_id = result_type->GetOperandAs<uint32_t>(dim_index);
|
|
uint64_t dim_value;
|
|
if (_.EvalConstantValUint64(dim_id, &dim_value)) {
|
|
uint64_t expected_num_values = 0;
|
|
switch (expected) {
|
|
case DIM:
|
|
expected_num_values = dim_value;
|
|
break;
|
|
case DIMx2:
|
|
expected_num_values = dim_value * 2;
|
|
break;
|
|
case ONE:
|
|
expected_num_values = 1;
|
|
break;
|
|
case FOUR:
|
|
expected_num_values = 4;
|
|
break;
|
|
}
|
|
|
|
if (num_values != expected_num_values) {
|
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
<< spvOpcodeString(inst->opcode())
|
|
<< " unexpected number of operands.";
|
|
}
|
|
}
|
|
|
|
for (uint32_t i = 0; i < num_values; ++i) {
|
|
const auto val_id = inst->GetOperandAs<uint32_t>(i + 3);
|
|
const auto val = _.FindDef(val_id);
|
|
if (!val || !_.IsIntScalarType(val->type_id()) ||
|
|
_.GetBitWidth(val->type_id()) != 32) {
|
|
return _.diag(SPV_ERROR_INVALID_ID, inst)
|
|
<< spvOpcodeString(inst->opcode()) << " operand <id> "
|
|
<< _.getIdName(val_id) << " is not a 32-bit integer.";
|
|
}
|
|
}
|
|
|
|
return SPV_SUCCESS;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Validation pass for the tensor layout/view instructions. Each recognized
// opcode is dispatched to its validator, whose result is returned unchanged.
// All other opcodes are accepted without checks.
spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) {
  switch (inst->opcode()) {
    // Object creation.
    case spv::Op::OpCreateTensorLayoutNV:
      return ValidateCreateTensorLayoutNV(_, inst);
    case spv::Op::OpCreateTensorViewNV:
      return ValidateCreateTensorViewNV(_, inst);

    // Tensor layout modifiers.
    case spv::Op::OpTensorLayoutSetBlockSizeNV:
    case spv::Op::OpTensorLayoutSetDimensionNV:
    case spv::Op::OpTensorLayoutSetStrideNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIM,
                                               /*is_view=*/false);
    case spv::Op::OpTensorLayoutSliceNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2,
                                               /*is_view=*/false);
    case spv::Op::OpTensorLayoutSetClampValueNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, ONE,
                                               /*is_view=*/false);

    // Tensor view modifiers.
    case spv::Op::OpTensorViewSetDimensionNV:
    case spv::Op::OpTensorViewSetStrideNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, DIM,
                                               /*is_view=*/true);
    case spv::Op::OpTensorViewSetClipNV:
      return ValidateTensorTypeWithDimValuesNV(_, inst, FOUR,
                                               /*is_view=*/true);

    default:
      return SPV_SUCCESS;
  }
}
|
|
|
|
} // namespace val
|
|
} // namespace spvtools
|