• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Performs validation of arithmetic instructions.
16 
#include "source/val/validate.h"

#include <cassert>
#include <vector>

#include "source/diagnostic.h"
#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validation_state.h"
25 
26 namespace spvtools {
27 namespace val {
28 
29 // Validates correctness of arithmetic instructions.
ArithmeticsPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
31   const SpvOp opcode = inst->opcode();
32   const uint32_t result_type = inst->type_id();
33 
34   switch (opcode) {
35     case SpvOpFAdd:
36     case SpvOpFSub:
37     case SpvOpFMul:
38     case SpvOpFDiv:
39     case SpvOpFRem:
40     case SpvOpFMod:
41     case SpvOpFNegate: {
42       if (!_.IsFloatScalarType(result_type) &&
43           !_.IsFloatVectorType(result_type))
44         return _.diag(SPV_ERROR_INVALID_DATA, inst)
45                << "Expected floating scalar or vector type as Result Type: "
46                << spvOpcodeString(opcode);
47 
48       for (size_t operand_index = 2; operand_index < inst->operands().size();
49            ++operand_index) {
50         if (_.GetOperandTypeId(inst, operand_index) != result_type)
51           return _.diag(SPV_ERROR_INVALID_DATA, inst)
52                  << "Expected arithmetic operands to be of Result Type: "
53                  << spvOpcodeString(opcode) << " operand index "
54                  << operand_index;
55       }
56       break;
57     }
58 
59     case SpvOpUDiv:
60     case SpvOpUMod: {
61       if (!_.IsUnsignedIntScalarType(result_type) &&
62           !_.IsUnsignedIntVectorType(result_type))
63         return _.diag(SPV_ERROR_INVALID_DATA, inst)
64                << "Expected unsigned int scalar or vector type as Result Type: "
65                << spvOpcodeString(opcode);
66 
67       for (size_t operand_index = 2; operand_index < inst->operands().size();
68            ++operand_index) {
69         if (_.GetOperandTypeId(inst, operand_index) != result_type)
70           return _.diag(SPV_ERROR_INVALID_DATA, inst)
71                  << "Expected arithmetic operands to be of Result Type: "
72                  << spvOpcodeString(opcode) << " operand index "
73                  << operand_index;
74       }
75       break;
76     }
77 
78     case SpvOpISub:
79     case SpvOpIAdd:
80     case SpvOpIMul:
81     case SpvOpSDiv:
82     case SpvOpSMod:
83     case SpvOpSRem:
84     case SpvOpSNegate: {
85       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
86         return _.diag(SPV_ERROR_INVALID_DATA, inst)
87                << "Expected int scalar or vector type as Result Type: "
88                << spvOpcodeString(opcode);
89 
90       const uint32_t dimension = _.GetDimension(result_type);
91       const uint32_t bit_width = _.GetBitWidth(result_type);
92 
93       for (size_t operand_index = 2; operand_index < inst->operands().size();
94            ++operand_index) {
95         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
96         if (!type_id ||
97             (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
98           return _.diag(SPV_ERROR_INVALID_DATA, inst)
99                  << "Expected int scalar or vector type as operand: "
100                  << spvOpcodeString(opcode) << " operand index "
101                  << operand_index;
102 
103         if (_.GetDimension(type_id) != dimension)
104           return _.diag(SPV_ERROR_INVALID_DATA, inst)
105                  << "Expected arithmetic operands to have the same dimension "
106                  << "as Result Type: " << spvOpcodeString(opcode)
107                  << " operand index " << operand_index;
108 
109         if (_.GetBitWidth(type_id) != bit_width)
110           return _.diag(SPV_ERROR_INVALID_DATA, inst)
111                  << "Expected arithmetic operands to have the same bit width "
112                  << "as Result Type: " << spvOpcodeString(opcode)
113                  << " operand index " << operand_index;
114       }
115       break;
116     }
117 
118     case SpvOpDot: {
119       if (!_.IsFloatScalarType(result_type))
120         return _.diag(SPV_ERROR_INVALID_DATA, inst)
121                << "Expected float scalar type as Result Type: "
122                << spvOpcodeString(opcode);
123 
124       uint32_t first_vector_num_components = 0;
125 
126       for (size_t operand_index = 2; operand_index < inst->operands().size();
127            ++operand_index) {
128         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
129 
130         if (!type_id || !_.IsFloatVectorType(type_id))
131           return _.diag(SPV_ERROR_INVALID_DATA, inst)
132                  << "Expected float vector as operand: "
133                  << spvOpcodeString(opcode) << " operand index "
134                  << operand_index;
135 
136         const uint32_t component_type = _.GetComponentType(type_id);
137         if (component_type != result_type)
138           return _.diag(SPV_ERROR_INVALID_DATA, inst)
139                  << "Expected component type to be equal to Result Type: "
140                  << spvOpcodeString(opcode) << " operand index "
141                  << operand_index;
142 
143         const uint32_t num_components = _.GetDimension(type_id);
144         if (operand_index == 2) {
145           first_vector_num_components = num_components;
146         } else if (num_components != first_vector_num_components) {
147           return _.diag(SPV_ERROR_INVALID_DATA, inst)
148                  << "Expected operands to have the same number of componenets: "
149                  << spvOpcodeString(opcode);
150         }
151       }
152       break;
153     }
154 
155     case SpvOpVectorTimesScalar: {
156       if (!_.IsFloatVectorType(result_type))
157         return _.diag(SPV_ERROR_INVALID_DATA, inst)
158                << "Expected float vector type as Result Type: "
159                << spvOpcodeString(opcode);
160 
161       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
162       if (result_type != vector_type_id)
163         return _.diag(SPV_ERROR_INVALID_DATA, inst)
164                << "Expected vector operand type to be equal to Result Type: "
165                << spvOpcodeString(opcode);
166 
167       const uint32_t component_type = _.GetComponentType(vector_type_id);
168 
169       const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
170       if (component_type != scalar_type_id)
171         return _.diag(SPV_ERROR_INVALID_DATA, inst)
172                << "Expected scalar operand type to be equal to the component "
173                << "type of the vector operand: " << spvOpcodeString(opcode);
174 
175       break;
176     }
177 
178     case SpvOpMatrixTimesScalar: {
179       if (!_.IsFloatMatrixType(result_type))
180         return _.diag(SPV_ERROR_INVALID_DATA, inst)
181                << "Expected float matrix type as Result Type: "
182                << spvOpcodeString(opcode);
183 
184       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
185       if (result_type != matrix_type_id)
186         return _.diag(SPV_ERROR_INVALID_DATA, inst)
187                << "Expected matrix operand type to be equal to Result Type: "
188                << spvOpcodeString(opcode);
189 
190       const uint32_t component_type = _.GetComponentType(matrix_type_id);
191 
192       const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
193       if (component_type != scalar_type_id)
194         return _.diag(SPV_ERROR_INVALID_DATA, inst)
195                << "Expected scalar operand type to be equal to the component "
196                << "type of the matrix operand: " << spvOpcodeString(opcode);
197 
198       break;
199     }
200 
201     case SpvOpVectorTimesMatrix: {
202       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
203       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
204 
205       if (!_.IsFloatVectorType(result_type))
206         return _.diag(SPV_ERROR_INVALID_DATA, inst)
207                << "Expected float vector type as Result Type: "
208                << spvOpcodeString(opcode);
209 
210       const uint32_t res_component_type = _.GetComponentType(result_type);
211 
212       if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
213         return _.diag(SPV_ERROR_INVALID_DATA, inst)
214                << "Expected float vector type as left operand: "
215                << spvOpcodeString(opcode);
216 
217       if (res_component_type != _.GetComponentType(vector_type_id))
218         return _.diag(SPV_ERROR_INVALID_DATA, inst)
219                << "Expected component types of Result Type and vector to be "
220                << "equal: " << spvOpcodeString(opcode);
221 
222       uint32_t matrix_num_rows = 0;
223       uint32_t matrix_num_cols = 0;
224       uint32_t matrix_col_type = 0;
225       uint32_t matrix_component_type = 0;
226       if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
227                                &matrix_num_cols, &matrix_col_type,
228                                &matrix_component_type))
229         return _.diag(SPV_ERROR_INVALID_DATA, inst)
230                << "Expected float matrix type as right operand: "
231                << spvOpcodeString(opcode);
232 
233       if (res_component_type != matrix_component_type)
234         return _.diag(SPV_ERROR_INVALID_DATA, inst)
235                << "Expected component types of Result Type and matrix to be "
236                << "equal: " << spvOpcodeString(opcode);
237 
238       if (matrix_num_cols != _.GetDimension(result_type))
239         return _.diag(SPV_ERROR_INVALID_DATA, inst)
240                << "Expected number of columns of the matrix to be equal to "
241                << "Result Type vector size: " << spvOpcodeString(opcode);
242 
243       if (matrix_num_rows != _.GetDimension(vector_type_id))
244         return _.diag(SPV_ERROR_INVALID_DATA, inst)
245                << "Expected number of rows of the matrix to be equal to the "
246                << "vector operand size: " << spvOpcodeString(opcode);
247 
248       break;
249     }
250 
251     case SpvOpMatrixTimesVector: {
252       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
253       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
254 
255       if (!_.IsFloatVectorType(result_type))
256         return _.diag(SPV_ERROR_INVALID_DATA, inst)
257                << "Expected float vector type as Result Type: "
258                << spvOpcodeString(opcode);
259 
260       uint32_t matrix_num_rows = 0;
261       uint32_t matrix_num_cols = 0;
262       uint32_t matrix_col_type = 0;
263       uint32_t matrix_component_type = 0;
264       if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
265                                &matrix_num_cols, &matrix_col_type,
266                                &matrix_component_type))
267         return _.diag(SPV_ERROR_INVALID_DATA, inst)
268                << "Expected float matrix type as left operand: "
269                << spvOpcodeString(opcode);
270 
271       if (result_type != matrix_col_type)
272         return _.diag(SPV_ERROR_INVALID_DATA, inst)
273                << "Expected column type of the matrix to be equal to Result "
274                   "Type: "
275                << spvOpcodeString(opcode);
276 
277       if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
278         return _.diag(SPV_ERROR_INVALID_DATA, inst)
279                << "Expected float vector type as right operand: "
280                << spvOpcodeString(opcode);
281 
282       if (matrix_component_type != _.GetComponentType(vector_type_id))
283         return _.diag(SPV_ERROR_INVALID_DATA, inst)
284                << "Expected component types of the operands to be equal: "
285                << spvOpcodeString(opcode);
286 
287       if (matrix_num_cols != _.GetDimension(vector_type_id))
288         return _.diag(SPV_ERROR_INVALID_DATA, inst)
289                << "Expected number of columns of the matrix to be equal to the "
290                << "vector size: " << spvOpcodeString(opcode);
291 
292       break;
293     }
294 
295     case SpvOpMatrixTimesMatrix: {
296       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
297       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
298 
299       uint32_t res_num_rows = 0;
300       uint32_t res_num_cols = 0;
301       uint32_t res_col_type = 0;
302       uint32_t res_component_type = 0;
303       if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
304                                &res_col_type, &res_component_type))
305         return _.diag(SPV_ERROR_INVALID_DATA, inst)
306                << "Expected float matrix type as Result Type: "
307                << spvOpcodeString(opcode);
308 
309       uint32_t left_num_rows = 0;
310       uint32_t left_num_cols = 0;
311       uint32_t left_col_type = 0;
312       uint32_t left_component_type = 0;
313       if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
314                                &left_col_type, &left_component_type))
315         return _.diag(SPV_ERROR_INVALID_DATA, inst)
316                << "Expected float matrix type as left operand: "
317                << spvOpcodeString(opcode);
318 
319       uint32_t right_num_rows = 0;
320       uint32_t right_num_cols = 0;
321       uint32_t right_col_type = 0;
322       uint32_t right_component_type = 0;
323       if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
324                                &right_col_type, &right_component_type))
325         return _.diag(SPV_ERROR_INVALID_DATA, inst)
326                << "Expected float matrix type as right operand: "
327                << spvOpcodeString(opcode);
328 
329       if (!_.IsFloatScalarType(res_component_type))
330         return _.diag(SPV_ERROR_INVALID_DATA, inst)
331                << "Expected float matrix type as Result Type: "
332                << spvOpcodeString(opcode);
333 
334       if (res_col_type != left_col_type)
335         return _.diag(SPV_ERROR_INVALID_DATA, inst)
336                << "Expected column types of Result Type and left matrix to be "
337                << "equal: " << spvOpcodeString(opcode);
338 
339       if (res_component_type != right_component_type)
340         return _.diag(SPV_ERROR_INVALID_DATA, inst)
341                << "Expected component types of Result Type and right matrix to "
342                   "be "
343                << "equal: " << spvOpcodeString(opcode);
344 
345       if (res_num_cols != right_num_cols)
346         return _.diag(SPV_ERROR_INVALID_DATA, inst)
347                << "Expected number of columns of Result Type and right matrix "
348                   "to "
349                << "be equal: " << spvOpcodeString(opcode);
350 
351       if (left_num_cols != right_num_rows)
352         return _.diag(SPV_ERROR_INVALID_DATA, inst)
353                << "Expected number of columns of left matrix and number of "
354                   "rows "
355                << "of right matrix to be equal: " << spvOpcodeString(opcode);
356 
357       assert(left_num_rows == res_num_rows);
358       break;
359     }
360 
361     case SpvOpOuterProduct: {
362       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
363       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
364 
365       uint32_t res_num_rows = 0;
366       uint32_t res_num_cols = 0;
367       uint32_t res_col_type = 0;
368       uint32_t res_component_type = 0;
369       if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
370                                &res_col_type, &res_component_type))
371         return _.diag(SPV_ERROR_INVALID_DATA, inst)
372                << "Expected float matrix type as Result Type: "
373                << spvOpcodeString(opcode);
374 
375       if (left_type_id != res_col_type)
376         return _.diag(SPV_ERROR_INVALID_DATA, inst)
377                << "Expected column type of Result Type to be equal to the type "
378                << "of the left operand: " << spvOpcodeString(opcode);
379 
380       if (!right_type_id || !_.IsFloatVectorType(right_type_id))
381         return _.diag(SPV_ERROR_INVALID_DATA, inst)
382                << "Expected float vector type as right operand: "
383                << spvOpcodeString(opcode);
384 
385       if (res_component_type != _.GetComponentType(right_type_id))
386         return _.diag(SPV_ERROR_INVALID_DATA, inst)
387                << "Expected component types of the operands to be equal: "
388                << spvOpcodeString(opcode);
389 
390       if (res_num_cols != _.GetDimension(right_type_id))
391         return _.diag(SPV_ERROR_INVALID_DATA, inst)
392                << "Expected number of columns of the matrix to be equal to the "
393                << "vector size of the right operand: "
394                << spvOpcodeString(opcode);
395 
396       break;
397     }
398 
399     case SpvOpIAddCarry:
400     case SpvOpISubBorrow:
401     case SpvOpUMulExtended:
402     case SpvOpSMulExtended: {
403       std::vector<uint32_t> result_types;
404       if (!_.GetStructMemberTypes(result_type, &result_types))
405         return _.diag(SPV_ERROR_INVALID_DATA, inst)
406                << "Expected a struct as Result Type: "
407                << spvOpcodeString(opcode);
408 
409       if (result_types.size() != 2)
410         return _.diag(SPV_ERROR_INVALID_DATA, inst)
411                << "Expected Result Type struct to have two members: "
412                << spvOpcodeString(opcode);
413 
414       if (opcode == SpvOpSMulExtended) {
415         if (!_.IsIntScalarType(result_types[0]) &&
416             !_.IsIntVectorType(result_types[0]))
417           return _.diag(SPV_ERROR_INVALID_DATA, inst)
418                  << "Expected Result Type struct member types to be integer "
419                     "scalar "
420                  << "or vector: " << spvOpcodeString(opcode);
421       } else {
422         if (!_.IsUnsignedIntScalarType(result_types[0]) &&
423             !_.IsUnsignedIntVectorType(result_types[0]))
424           return _.diag(SPV_ERROR_INVALID_DATA, inst)
425                  << "Expected Result Type struct member types to be unsigned "
426                  << "integer scalar or vector: " << spvOpcodeString(opcode);
427       }
428 
429       if (result_types[0] != result_types[1])
430         return _.diag(SPV_ERROR_INVALID_DATA, inst)
431                << "Expected Result Type struct member types to be identical: "
432                << spvOpcodeString(opcode);
433 
434       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
435       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
436 
437       if (left_type_id != result_types[0] || right_type_id != result_types[0])
438         return _.diag(SPV_ERROR_INVALID_DATA, inst)
439                << "Expected both operands to be of Result Type member type: "
440                << spvOpcodeString(opcode);
441 
442       break;
443     }
444 
445     default:
446       break;
447   }
448 
449   return SPV_SUCCESS;
450 }
451 
452 }  // namespace val
453 }  // namespace spvtools
454