1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
// Validates mesh shading instructions from SPV_EXT_mesh_shader and SPV_NV_mesh_shader
16
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21
22 namespace spvtools {
23 namespace val {
24
MeshShadingPass(ValidationState_t & _,const Instruction * inst)25 spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
26 const spv::Op opcode = inst->opcode();
27 switch (opcode) {
28 case spv::Op::OpEmitMeshTasksEXT: {
29 _.function(inst->function()->id())
30 ->RegisterExecutionModelLimitation(
31 [](spv::ExecutionModel model, std::string* message) {
32 if (model != spv::ExecutionModel::TaskEXT) {
33 if (message) {
34 *message =
35 "OpEmitMeshTasksEXT requires TaskEXT execution model";
36 }
37 return false;
38 }
39 return true;
40 });
41
42 const uint32_t group_count_x = _.GetOperandTypeId(inst, 0);
43 if (!_.IsUnsignedIntScalarType(group_count_x) ||
44 _.GetBitWidth(group_count_x) != 32) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Group Count X must be a 32-bit unsigned int scalar";
47 }
48
49 const uint32_t group_count_y = _.GetOperandTypeId(inst, 1);
50 if (!_.IsUnsignedIntScalarType(group_count_y) ||
51 _.GetBitWidth(group_count_y) != 32) {
52 return _.diag(SPV_ERROR_INVALID_DATA, inst)
53 << "Group Count Y must be a 32-bit unsigned int scalar";
54 }
55
56 const uint32_t group_count_z = _.GetOperandTypeId(inst, 2);
57 if (!_.IsUnsignedIntScalarType(group_count_z) ||
58 _.GetBitWidth(group_count_z) != 32) {
59 return _.diag(SPV_ERROR_INVALID_DATA, inst)
60 << "Group Count Z must be a 32-bit unsigned int scalar";
61 }
62
63 if (inst->operands().size() == 4) {
64 const auto payload = _.FindDef(inst->GetOperandAs<uint32_t>(3));
65 if (payload->opcode() != spv::Op::OpVariable) {
66 return _.diag(SPV_ERROR_INVALID_DATA, inst)
67 << "Payload must be the result of a OpVariable";
68 }
69 if (payload->GetOperandAs<spv::StorageClass>(2) !=
70 spv::StorageClass::TaskPayloadWorkgroupEXT) {
71 return _.diag(SPV_ERROR_INVALID_DATA, inst)
72 << "Payload OpVariable must have a storage class of "
73 "TaskPayloadWorkgroupEXT";
74 }
75 }
76 break;
77 }
78
79 case spv::Op::OpSetMeshOutputsEXT: {
80 _.function(inst->function()->id())
81 ->RegisterExecutionModelLimitation(
82 [](spv::ExecutionModel model, std::string* message) {
83 if (model != spv::ExecutionModel::MeshEXT) {
84 if (message) {
85 *message =
86 "OpSetMeshOutputsEXT requires MeshEXT execution model";
87 }
88 return false;
89 }
90 return true;
91 });
92
93 const uint32_t vertex_count = _.GetOperandTypeId(inst, 0);
94 if (!_.IsUnsignedIntScalarType(vertex_count) ||
95 _.GetBitWidth(vertex_count) != 32) {
96 return _.diag(SPV_ERROR_INVALID_DATA, inst)
97 << "Vertex Count must be a 32-bit unsigned int scalar";
98 }
99
100 const uint32_t primitive_count = _.GetOperandTypeId(inst, 1);
101 if (!_.IsUnsignedIntScalarType(primitive_count) ||
102 _.GetBitWidth(primitive_count) != 32) {
103 return _.diag(SPV_ERROR_INVALID_DATA, inst)
104 << "Primitive Count must be a 32-bit unsigned int scalar";
105 }
106
107 break;
108 }
109
110 case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
111 // No validation rules (for the moment).
112 break;
113 }
114
115 default:
116 break;
117 }
118
119 return SPV_SUCCESS;
120 }
121
122 } // namespace val
123 } // namespace spvtools
124