• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates mesh shading instructions from SPV_EXT_mesh_shader and SPV_NV_mesh_shader
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 
MeshShadingPass(ValidationState_t & _,const Instruction * inst)25 spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst) {
26   const spv::Op opcode = inst->opcode();
27   switch (opcode) {
28     case spv::Op::OpEmitMeshTasksEXT: {
29       _.function(inst->function()->id())
30           ->RegisterExecutionModelLimitation(
31               [](spv::ExecutionModel model, std::string* message) {
32                 if (model != spv::ExecutionModel::TaskEXT) {
33                   if (message) {
34                     *message =
35                         "OpEmitMeshTasksEXT requires TaskEXT execution model";
36                   }
37                   return false;
38                 }
39                 return true;
40               });
41 
42       const uint32_t group_count_x = _.GetOperandTypeId(inst, 0);
43       if (!_.IsUnsignedIntScalarType(group_count_x) ||
44           _.GetBitWidth(group_count_x) != 32) {
45         return _.diag(SPV_ERROR_INVALID_DATA, inst)
46                << "Group Count X must be a 32-bit unsigned int scalar";
47       }
48 
49       const uint32_t group_count_y = _.GetOperandTypeId(inst, 1);
50       if (!_.IsUnsignedIntScalarType(group_count_y) ||
51           _.GetBitWidth(group_count_y) != 32) {
52         return _.diag(SPV_ERROR_INVALID_DATA, inst)
53                << "Group Count Y must be a 32-bit unsigned int scalar";
54       }
55 
56       const uint32_t group_count_z = _.GetOperandTypeId(inst, 2);
57       if (!_.IsUnsignedIntScalarType(group_count_z) ||
58           _.GetBitWidth(group_count_z) != 32) {
59         return _.diag(SPV_ERROR_INVALID_DATA, inst)
60                << "Group Count Z must be a 32-bit unsigned int scalar";
61       }
62 
63       if (inst->operands().size() == 4) {
64         const auto payload = _.FindDef(inst->GetOperandAs<uint32_t>(3));
65         if (payload->opcode() != spv::Op::OpVariable) {
66           return _.diag(SPV_ERROR_INVALID_DATA, inst)
67                  << "Payload must be the result of a OpVariable";
68         }
69         if (payload->GetOperandAs<spv::StorageClass>(2) !=
70             spv::StorageClass::TaskPayloadWorkgroupEXT) {
71           return _.diag(SPV_ERROR_INVALID_DATA, inst)
72                  << "Payload OpVariable must have a storage class of "
73                     "TaskPayloadWorkgroupEXT";
74         }
75       }
76       break;
77     }
78 
79     case spv::Op::OpSetMeshOutputsEXT: {
80       _.function(inst->function()->id())
81           ->RegisterExecutionModelLimitation(
82               [](spv::ExecutionModel model, std::string* message) {
83                 if (model != spv::ExecutionModel::MeshEXT) {
84                   if (message) {
85                     *message =
86                         "OpSetMeshOutputsEXT requires MeshEXT execution model";
87                   }
88                   return false;
89                 }
90                 return true;
91               });
92 
93       const uint32_t vertex_count = _.GetOperandTypeId(inst, 0);
94       if (!_.IsUnsignedIntScalarType(vertex_count) ||
95           _.GetBitWidth(vertex_count) != 32) {
96         return _.diag(SPV_ERROR_INVALID_DATA, inst)
97                << "Vertex Count must be a 32-bit unsigned int scalar";
98       }
99 
100       const uint32_t primitive_count = _.GetOperandTypeId(inst, 1);
101       if (!_.IsUnsignedIntScalarType(primitive_count) ||
102           _.GetBitWidth(primitive_count) != 32) {
103         return _.diag(SPV_ERROR_INVALID_DATA, inst)
104                << "Primitive Count must be a 32-bit unsigned int scalar";
105       }
106 
107       break;
108     }
109 
110     case spv::Op::OpWritePackedPrimitiveIndices4x8NV: {
111       // No validation rules (for the moment).
112       break;
113     }
114 
115     default:
116       break;
117   }
118 
119   return SPV_SUCCESS;
120 }
121 
122 }  // namespace val
123 }  // namespace spvtools
124