• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 
17 #include "source/enum_string_mapping.h"
18 #include "source/opcode.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 // Returns true if |a| and |b| are instructions defining pointers that point to
28 // types logically match and the decorations that apply to |b| are a subset
29 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)30 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31                               ValidationState_t& _) {
32   if (a->opcode() != spv::Op::OpTypePointer ||
33       b->opcode() != spv::Op::OpTypePointer) {
34     return false;
35   }
36 
37   const auto& dec_a = _.id_decorations(a->id());
38   const auto& dec_b = _.id_decorations(b->id());
39   for (const auto& dec : dec_b) {
40     if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41       return false;
42     }
43   }
44 
45   uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46   uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47 
48   if (a_type == b_type) {
49     return true;
50   }
51 
52   Instruction* a_type_inst = _.FindDef(a_type);
53   Instruction* b_type_inst = _.FindDef(b_type);
54 
55   return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56 }
57 
ValidateFunction(ValidationState_t & _,const Instruction * inst)58 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60   const auto function_type = _.FindDef(function_type_id);
61   if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62     return _.diag(SPV_ERROR_INVALID_ID, inst)
63            << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64            << " is not a function type.";
65   }
66 
67   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68   if (return_id != inst->type_id()) {
69     return _.diag(SPV_ERROR_INVALID_ID, inst)
70            << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71            << " does not match the Function Type's return type <id> "
72            << _.getIdName(return_id) << ".";
73   }
74 
75   const std::vector<spv::Op> acceptable = {
76       spv::Op::OpGroupDecorate,
77       spv::Op::OpDecorate,
78       spv::Op::OpEnqueueKernel,
79       spv::Op::OpEntryPoint,
80       spv::Op::OpExecutionMode,
81       spv::Op::OpExecutionModeId,
82       spv::Op::OpFunctionCall,
83       spv::Op::OpGetKernelNDrangeSubGroupCount,
84       spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85       spv::Op::OpGetKernelWorkGroupSize,
86       spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87       spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88       spv::Op::OpGetKernelMaxNumSubgroups,
89       spv::Op::OpName};
90   for (auto& pair : inst->uses()) {
91     const auto* use = pair.first;
92     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
93             acceptable.end() &&
94         !use->IsNonSemantic() && !use->IsDebugInfo()) {
95       return _.diag(SPV_ERROR_INVALID_ID, use)
96              << "Invalid use of function result id " << _.getIdName(inst->id())
97              << ".";
98     }
99   }
100 
101   return SPV_SUCCESS;
102 }
103 
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)104 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
105                                        const Instruction* inst) {
106   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
107   size_t param_index = 0;
108   size_t inst_num = inst->LineNum() - 1;
109   if (inst_num == 0) {
110     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
111            << "Function parameter cannot be the first instruction.";
112   }
113 
114   auto func_inst = &_.ordered_instructions()[inst_num];
115   while (--inst_num) {
116     func_inst = &_.ordered_instructions()[inst_num];
117     if (func_inst->opcode() == spv::Op::OpFunction) {
118       break;
119     } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
120       ++param_index;
121     }
122   }
123 
124   if (func_inst->opcode() != spv::Op::OpFunction) {
125     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
126            << "Function parameter must be preceded by a function.";
127   }
128 
129   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
130   const auto function_type = _.FindDef(function_type_id);
131   if (!function_type) {
132     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
133            << "Missing function type definition.";
134   }
135   if (param_index >= function_type->words().size() - 3) {
136     return _.diag(SPV_ERROR_INVALID_ID, inst)
137            << "Too many OpFunctionParameters for " << func_inst->id()
138            << ": expected " << function_type->words().size() - 3
139            << " based on the function's type";
140   }
141 
142   const auto param_type =
143       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
144   if (!param_type || inst->type_id() != param_type->id()) {
145     return _.diag(SPV_ERROR_INVALID_ID, inst)
146            << "OpFunctionParameter Result Type <id> "
147            << _.getIdName(inst->type_id())
148            << " does not match the OpTypeFunction parameter "
149               "type of the same index.";
150   }
151 
152   // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
153   // RestrictPointer, or AliasedPointer.
154   auto param_nonarray_type_id = param_type->id();
155   while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
156     param_nonarray_type_id =
157         _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
158   }
159   if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) {
160     auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
161     if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
162         spv::StorageClass::PhysicalStorageBuffer) {
163       // check for Aliased or Restrict
164       const auto& decorations = _.id_decorations(inst->id());
165 
166       bool foundAliased = std::any_of(
167           decorations.begin(), decorations.end(), [](const Decoration& d) {
168             return spv::Decoration::Aliased == d.dec_type();
169           });
170 
171       bool foundRestrict = std::any_of(
172           decorations.begin(), decorations.end(), [](const Decoration& d) {
173             return spv::Decoration::Restrict == d.dec_type();
174           });
175 
176       if (!foundAliased && !foundRestrict) {
177         return _.diag(SPV_ERROR_INVALID_ID, inst)
178                << "OpFunctionParameter " << inst->id()
179                << ": expected Aliased or Restrict for PhysicalStorageBuffer "
180                   "pointer.";
181       }
182       if (foundAliased && foundRestrict) {
183         return _.diag(SPV_ERROR_INVALID_ID, inst)
184                << "OpFunctionParameter " << inst->id()
185                << ": can't specify both Aliased and Restrict for "
186                   "PhysicalStorageBuffer pointer.";
187       }
188     } else {
189       const auto pointee_type_id =
190           param_nonarray_type->GetOperandAs<uint32_t>(2);
191       const auto pointee_type = _.FindDef(pointee_type_id);
192       if (spv::Op::OpTypePointer == pointee_type->opcode() &&
193           pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
194               spv::StorageClass::PhysicalStorageBuffer) {
195         // check for AliasedPointer/RestrictPointer
196         const auto& decorations = _.id_decorations(inst->id());
197 
198         bool foundAliased = std::any_of(
199             decorations.begin(), decorations.end(), [](const Decoration& d) {
200               return spv::Decoration::AliasedPointer == d.dec_type();
201             });
202 
203         bool foundRestrict = std::any_of(
204             decorations.begin(), decorations.end(), [](const Decoration& d) {
205               return spv::Decoration::RestrictPointer == d.dec_type();
206             });
207 
208         if (!foundAliased && !foundRestrict) {
209           return _.diag(SPV_ERROR_INVALID_ID, inst)
210                  << "OpFunctionParameter " << inst->id()
211                  << ": expected AliasedPointer or RestrictPointer for "
212                     "PhysicalStorageBuffer pointer.";
213         }
214         if (foundAliased && foundRestrict) {
215           return _.diag(SPV_ERROR_INVALID_ID, inst)
216                  << "OpFunctionParameter " << inst->id()
217                  << ": can't specify both AliasedPointer and "
218                     "RestrictPointer for PhysicalStorageBuffer pointer.";
219         }
220       }
221     }
222   }
223 
224   return SPV_SUCCESS;
225 }
226 
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)227 spv_result_t ValidateFunctionCall(ValidationState_t& _,
228                                   const Instruction* inst) {
229   const auto function_id = inst->GetOperandAs<uint32_t>(2);
230   const auto function = _.FindDef(function_id);
231   if (!function || spv::Op::OpFunction != function->opcode()) {
232     return _.diag(SPV_ERROR_INVALID_ID, inst)
233            << "OpFunctionCall Function <id> " << _.getIdName(function_id)
234            << " is not a function.";
235   }
236 
237   auto return_type = _.FindDef(function->type_id());
238   if (!return_type || return_type->id() != inst->type_id()) {
239     return _.diag(SPV_ERROR_INVALID_ID, inst)
240            << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
241            << "s type does not match Function <id> "
242            << _.getIdName(return_type->id()) << "s return type.";
243   }
244 
245   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
246   const auto function_type = _.FindDef(function_type_id);
247   if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
248     return _.diag(SPV_ERROR_INVALID_ID, inst)
249            << "Missing function type definition.";
250   }
251 
252   const auto function_call_arg_count = inst->words().size() - 4;
253   const auto function_param_count = function_type->words().size() - 3;
254   if (function_param_count != function_call_arg_count) {
255     return _.diag(SPV_ERROR_INVALID_ID, inst)
256            << "OpFunctionCall Function <id>'s parameter count does not match "
257               "the argument count.";
258   }
259 
260   for (size_t argument_index = 3, param_index = 2;
261        argument_index < inst->operands().size();
262        argument_index++, param_index++) {
263     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
264     const auto argument = _.FindDef(argument_id);
265     if (!argument) {
266       return _.diag(SPV_ERROR_INVALID_ID, inst)
267              << "Missing argument " << argument_index - 3 << " definition.";
268     }
269 
270     const auto argument_type = _.FindDef(argument->type_id());
271     if (!argument_type) {
272       return _.diag(SPV_ERROR_INVALID_ID, inst)
273              << "Missing argument " << argument_index - 3
274              << " type definition.";
275     }
276 
277     const auto parameter_type_id =
278         function_type->GetOperandAs<uint32_t>(param_index);
279     const auto parameter_type = _.FindDef(parameter_type_id);
280     if (!parameter_type || argument_type->id() != parameter_type->id()) {
281       if (!_.options()->before_hlsl_legalization ||
282           !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
283         return _.diag(SPV_ERROR_INVALID_ID, inst)
284                << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
285                << "s type does not match Function <id> "
286                << _.getIdName(parameter_type_id) << "s parameter type.";
287       }
288     }
289 
290     if (_.addressing_model() == spv::AddressingModel::Logical) {
291       if (parameter_type->opcode() == spv::Op::OpTypePointer &&
292           !_.options()->relax_logical_pointer) {
293         spv::StorageClass sc =
294             parameter_type->GetOperandAs<spv::StorageClass>(1u);
295         // Validate which storage classes can be pointer operands.
296         switch (sc) {
297           case spv::StorageClass::UniformConstant:
298           case spv::StorageClass::Function:
299           case spv::StorageClass::Private:
300           case spv::StorageClass::Workgroup:
301           case spv::StorageClass::AtomicCounter:
302             // These are always allowed.
303             break;
304           case spv::StorageClass::StorageBuffer:
305             if (!_.features().variable_pointers) {
306               return _.diag(SPV_ERROR_INVALID_ID, inst)
307                      << "StorageBuffer pointer operand "
308                      << _.getIdName(argument_id)
309                      << " requires a variable pointers capability";
310             }
311             break;
312           default:
313             return _.diag(SPV_ERROR_INVALID_ID, inst)
314                    << "Invalid storage class for pointer operand "
315                    << _.getIdName(argument_id);
316         }
317 
318         // Validate memory object declaration requirements.
319         if (argument->opcode() != spv::Op::OpVariable &&
320             argument->opcode() != spv::Op::OpFunctionParameter) {
321           const bool ssbo_vptr = _.features().variable_pointers &&
322                                  sc == spv::StorageClass::StorageBuffer;
323           const bool wg_vptr =
324               _.HasCapability(spv::Capability::VariablePointers) &&
325               sc == spv::StorageClass::Workgroup;
326           const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
327           if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
328             return _.diag(SPV_ERROR_INVALID_ID, inst)
329                    << "Pointer operand " << _.getIdName(argument_id)
330                    << " must be a memory object declaration";
331           }
332         }
333       }
334     }
335   }
336   return SPV_SUCCESS;
337 }
338 
339 }  // namespace
340 
FunctionPass(ValidationState_t & _,const Instruction * inst)341 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
342   switch (inst->opcode()) {
343     case spv::Op::OpFunction:
344       if (auto error = ValidateFunction(_, inst)) return error;
345       break;
346     case spv::Op::OpFunctionParameter:
347       if (auto error = ValidateFunctionParameter(_, inst)) return error;
348       break;
349     case spv::Op::OpFunctionCall:
350       if (auto error = ValidateFunctionCall(_, inst)) return error;
351       break;
352     default:
353       break;
354   }
355 
356   return SPV_SUCCESS;
357 }
358 
359 }  // namespace val
360 }  // namespace spvtools
361