1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Performs validation of arithmetic instructions.
16
17 #include "source/val/validate.h"
18
#include <cassert>
#include <vector>
20
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25
26 namespace spvtools {
27 namespace val {
28
29 // Validates correctness of arithmetic instructions.
ArithmeticsPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
31 const SpvOp opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case SpvOpFAdd:
36 case SpvOpFSub:
37 case SpvOpFMul:
38 case SpvOpFDiv:
39 case SpvOpFRem:
40 case SpvOpFMod:
41 case SpvOpFNegate: {
42 if (!_.IsFloatScalarType(result_type) &&
43 !_.IsFloatVectorType(result_type))
44 return _.diag(SPV_ERROR_INVALID_DATA, inst)
45 << "Expected floating scalar or vector type as Result Type: "
46 << spvOpcodeString(opcode);
47
48 for (size_t operand_index = 2; operand_index < inst->operands().size();
49 ++operand_index) {
50 if (_.GetOperandTypeId(inst, operand_index) != result_type)
51 return _.diag(SPV_ERROR_INVALID_DATA, inst)
52 << "Expected arithmetic operands to be of Result Type: "
53 << spvOpcodeString(opcode) << " operand index "
54 << operand_index;
55 }
56 break;
57 }
58
59 case SpvOpUDiv:
60 case SpvOpUMod: {
61 if (!_.IsUnsignedIntScalarType(result_type) &&
62 !_.IsUnsignedIntVectorType(result_type))
63 return _.diag(SPV_ERROR_INVALID_DATA, inst)
64 << "Expected unsigned int scalar or vector type as Result Type: "
65 << spvOpcodeString(opcode);
66
67 for (size_t operand_index = 2; operand_index < inst->operands().size();
68 ++operand_index) {
69 if (_.GetOperandTypeId(inst, operand_index) != result_type)
70 return _.diag(SPV_ERROR_INVALID_DATA, inst)
71 << "Expected arithmetic operands to be of Result Type: "
72 << spvOpcodeString(opcode) << " operand index "
73 << operand_index;
74 }
75 break;
76 }
77
78 case SpvOpISub:
79 case SpvOpIAdd:
80 case SpvOpIMul:
81 case SpvOpSDiv:
82 case SpvOpSMod:
83 case SpvOpSRem:
84 case SpvOpSNegate: {
85 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
86 return _.diag(SPV_ERROR_INVALID_DATA, inst)
87 << "Expected int scalar or vector type as Result Type: "
88 << spvOpcodeString(opcode);
89
90 const uint32_t dimension = _.GetDimension(result_type);
91 const uint32_t bit_width = _.GetBitWidth(result_type);
92
93 for (size_t operand_index = 2; operand_index < inst->operands().size();
94 ++operand_index) {
95 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
96 if (!type_id ||
97 (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
98 return _.diag(SPV_ERROR_INVALID_DATA, inst)
99 << "Expected int scalar or vector type as operand: "
100 << spvOpcodeString(opcode) << " operand index "
101 << operand_index;
102
103 if (_.GetDimension(type_id) != dimension)
104 return _.diag(SPV_ERROR_INVALID_DATA, inst)
105 << "Expected arithmetic operands to have the same dimension "
106 << "as Result Type: " << spvOpcodeString(opcode)
107 << " operand index " << operand_index;
108
109 if (_.GetBitWidth(type_id) != bit_width)
110 return _.diag(SPV_ERROR_INVALID_DATA, inst)
111 << "Expected arithmetic operands to have the same bit width "
112 << "as Result Type: " << spvOpcodeString(opcode)
113 << " operand index " << operand_index;
114 }
115 break;
116 }
117
118 case SpvOpDot: {
119 if (!_.IsFloatScalarType(result_type))
120 return _.diag(SPV_ERROR_INVALID_DATA, inst)
121 << "Expected float scalar type as Result Type: "
122 << spvOpcodeString(opcode);
123
124 uint32_t first_vector_num_components = 0;
125
126 for (size_t operand_index = 2; operand_index < inst->operands().size();
127 ++operand_index) {
128 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
129
130 if (!type_id || !_.IsFloatVectorType(type_id))
131 return _.diag(SPV_ERROR_INVALID_DATA, inst)
132 << "Expected float vector as operand: "
133 << spvOpcodeString(opcode) << " operand index "
134 << operand_index;
135
136 const uint32_t component_type = _.GetComponentType(type_id);
137 if (component_type != result_type)
138 return _.diag(SPV_ERROR_INVALID_DATA, inst)
139 << "Expected component type to be equal to Result Type: "
140 << spvOpcodeString(opcode) << " operand index "
141 << operand_index;
142
143 const uint32_t num_components = _.GetDimension(type_id);
144 if (operand_index == 2) {
145 first_vector_num_components = num_components;
146 } else if (num_components != first_vector_num_components) {
147 return _.diag(SPV_ERROR_INVALID_DATA, inst)
148 << "Expected operands to have the same number of componenets: "
149 << spvOpcodeString(opcode);
150 }
151 }
152 break;
153 }
154
155 case SpvOpVectorTimesScalar: {
156 if (!_.IsFloatVectorType(result_type))
157 return _.diag(SPV_ERROR_INVALID_DATA, inst)
158 << "Expected float vector type as Result Type: "
159 << spvOpcodeString(opcode);
160
161 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
162 if (result_type != vector_type_id)
163 return _.diag(SPV_ERROR_INVALID_DATA, inst)
164 << "Expected vector operand type to be equal to Result Type: "
165 << spvOpcodeString(opcode);
166
167 const uint32_t component_type = _.GetComponentType(vector_type_id);
168
169 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
170 if (component_type != scalar_type_id)
171 return _.diag(SPV_ERROR_INVALID_DATA, inst)
172 << "Expected scalar operand type to be equal to the component "
173 << "type of the vector operand: " << spvOpcodeString(opcode);
174
175 break;
176 }
177
178 case SpvOpMatrixTimesScalar: {
179 if (!_.IsFloatMatrixType(result_type))
180 return _.diag(SPV_ERROR_INVALID_DATA, inst)
181 << "Expected float matrix type as Result Type: "
182 << spvOpcodeString(opcode);
183
184 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
185 if (result_type != matrix_type_id)
186 return _.diag(SPV_ERROR_INVALID_DATA, inst)
187 << "Expected matrix operand type to be equal to Result Type: "
188 << spvOpcodeString(opcode);
189
190 const uint32_t component_type = _.GetComponentType(matrix_type_id);
191
192 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
193 if (component_type != scalar_type_id)
194 return _.diag(SPV_ERROR_INVALID_DATA, inst)
195 << "Expected scalar operand type to be equal to the component "
196 << "type of the matrix operand: " << spvOpcodeString(opcode);
197
198 break;
199 }
200
201 case SpvOpVectorTimesMatrix: {
202 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
203 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
204
205 if (!_.IsFloatVectorType(result_type))
206 return _.diag(SPV_ERROR_INVALID_DATA, inst)
207 << "Expected float vector type as Result Type: "
208 << spvOpcodeString(opcode);
209
210 const uint32_t res_component_type = _.GetComponentType(result_type);
211
212 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
213 return _.diag(SPV_ERROR_INVALID_DATA, inst)
214 << "Expected float vector type as left operand: "
215 << spvOpcodeString(opcode);
216
217 if (res_component_type != _.GetComponentType(vector_type_id))
218 return _.diag(SPV_ERROR_INVALID_DATA, inst)
219 << "Expected component types of Result Type and vector to be "
220 << "equal: " << spvOpcodeString(opcode);
221
222 uint32_t matrix_num_rows = 0;
223 uint32_t matrix_num_cols = 0;
224 uint32_t matrix_col_type = 0;
225 uint32_t matrix_component_type = 0;
226 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
227 &matrix_num_cols, &matrix_col_type,
228 &matrix_component_type))
229 return _.diag(SPV_ERROR_INVALID_DATA, inst)
230 << "Expected float matrix type as right operand: "
231 << spvOpcodeString(opcode);
232
233 if (res_component_type != matrix_component_type)
234 return _.diag(SPV_ERROR_INVALID_DATA, inst)
235 << "Expected component types of Result Type and matrix to be "
236 << "equal: " << spvOpcodeString(opcode);
237
238 if (matrix_num_cols != _.GetDimension(result_type))
239 return _.diag(SPV_ERROR_INVALID_DATA, inst)
240 << "Expected number of columns of the matrix to be equal to "
241 << "Result Type vector size: " << spvOpcodeString(opcode);
242
243 if (matrix_num_rows != _.GetDimension(vector_type_id))
244 return _.diag(SPV_ERROR_INVALID_DATA, inst)
245 << "Expected number of rows of the matrix to be equal to the "
246 << "vector operand size: " << spvOpcodeString(opcode);
247
248 break;
249 }
250
251 case SpvOpMatrixTimesVector: {
252 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
253 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
254
255 if (!_.IsFloatVectorType(result_type))
256 return _.diag(SPV_ERROR_INVALID_DATA, inst)
257 << "Expected float vector type as Result Type: "
258 << spvOpcodeString(opcode);
259
260 uint32_t matrix_num_rows = 0;
261 uint32_t matrix_num_cols = 0;
262 uint32_t matrix_col_type = 0;
263 uint32_t matrix_component_type = 0;
264 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
265 &matrix_num_cols, &matrix_col_type,
266 &matrix_component_type))
267 return _.diag(SPV_ERROR_INVALID_DATA, inst)
268 << "Expected float matrix type as left operand: "
269 << spvOpcodeString(opcode);
270
271 if (result_type != matrix_col_type)
272 return _.diag(SPV_ERROR_INVALID_DATA, inst)
273 << "Expected column type of the matrix to be equal to Result "
274 "Type: "
275 << spvOpcodeString(opcode);
276
277 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
278 return _.diag(SPV_ERROR_INVALID_DATA, inst)
279 << "Expected float vector type as right operand: "
280 << spvOpcodeString(opcode);
281
282 if (matrix_component_type != _.GetComponentType(vector_type_id))
283 return _.diag(SPV_ERROR_INVALID_DATA, inst)
284 << "Expected component types of the operands to be equal: "
285 << spvOpcodeString(opcode);
286
287 if (matrix_num_cols != _.GetDimension(vector_type_id))
288 return _.diag(SPV_ERROR_INVALID_DATA, inst)
289 << "Expected number of columns of the matrix to be equal to the "
290 << "vector size: " << spvOpcodeString(opcode);
291
292 break;
293 }
294
295 case SpvOpMatrixTimesMatrix: {
296 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
297 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
298
299 uint32_t res_num_rows = 0;
300 uint32_t res_num_cols = 0;
301 uint32_t res_col_type = 0;
302 uint32_t res_component_type = 0;
303 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
304 &res_col_type, &res_component_type))
305 return _.diag(SPV_ERROR_INVALID_DATA, inst)
306 << "Expected float matrix type as Result Type: "
307 << spvOpcodeString(opcode);
308
309 uint32_t left_num_rows = 0;
310 uint32_t left_num_cols = 0;
311 uint32_t left_col_type = 0;
312 uint32_t left_component_type = 0;
313 if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
314 &left_col_type, &left_component_type))
315 return _.diag(SPV_ERROR_INVALID_DATA, inst)
316 << "Expected float matrix type as left operand: "
317 << spvOpcodeString(opcode);
318
319 uint32_t right_num_rows = 0;
320 uint32_t right_num_cols = 0;
321 uint32_t right_col_type = 0;
322 uint32_t right_component_type = 0;
323 if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
324 &right_col_type, &right_component_type))
325 return _.diag(SPV_ERROR_INVALID_DATA, inst)
326 << "Expected float matrix type as right operand: "
327 << spvOpcodeString(opcode);
328
329 if (!_.IsFloatScalarType(res_component_type))
330 return _.diag(SPV_ERROR_INVALID_DATA, inst)
331 << "Expected float matrix type as Result Type: "
332 << spvOpcodeString(opcode);
333
334 if (res_col_type != left_col_type)
335 return _.diag(SPV_ERROR_INVALID_DATA, inst)
336 << "Expected column types of Result Type and left matrix to be "
337 << "equal: " << spvOpcodeString(opcode);
338
339 if (res_component_type != right_component_type)
340 return _.diag(SPV_ERROR_INVALID_DATA, inst)
341 << "Expected component types of Result Type and right matrix to "
342 "be "
343 << "equal: " << spvOpcodeString(opcode);
344
345 if (res_num_cols != right_num_cols)
346 return _.diag(SPV_ERROR_INVALID_DATA, inst)
347 << "Expected number of columns of Result Type and right matrix "
348 "to "
349 << "be equal: " << spvOpcodeString(opcode);
350
351 if (left_num_cols != right_num_rows)
352 return _.diag(SPV_ERROR_INVALID_DATA, inst)
353 << "Expected number of columns of left matrix and number of "
354 "rows "
355 << "of right matrix to be equal: " << spvOpcodeString(opcode);
356
357 assert(left_num_rows == res_num_rows);
358 break;
359 }
360
361 case SpvOpOuterProduct: {
362 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
363 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
364
365 uint32_t res_num_rows = 0;
366 uint32_t res_num_cols = 0;
367 uint32_t res_col_type = 0;
368 uint32_t res_component_type = 0;
369 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
370 &res_col_type, &res_component_type))
371 return _.diag(SPV_ERROR_INVALID_DATA, inst)
372 << "Expected float matrix type as Result Type: "
373 << spvOpcodeString(opcode);
374
375 if (left_type_id != res_col_type)
376 return _.diag(SPV_ERROR_INVALID_DATA, inst)
377 << "Expected column type of Result Type to be equal to the type "
378 << "of the left operand: " << spvOpcodeString(opcode);
379
380 if (!right_type_id || !_.IsFloatVectorType(right_type_id))
381 return _.diag(SPV_ERROR_INVALID_DATA, inst)
382 << "Expected float vector type as right operand: "
383 << spvOpcodeString(opcode);
384
385 if (res_component_type != _.GetComponentType(right_type_id))
386 return _.diag(SPV_ERROR_INVALID_DATA, inst)
387 << "Expected component types of the operands to be equal: "
388 << spvOpcodeString(opcode);
389
390 if (res_num_cols != _.GetDimension(right_type_id))
391 return _.diag(SPV_ERROR_INVALID_DATA, inst)
392 << "Expected number of columns of the matrix to be equal to the "
393 << "vector size of the right operand: "
394 << spvOpcodeString(opcode);
395
396 break;
397 }
398
399 case SpvOpIAddCarry:
400 case SpvOpISubBorrow:
401 case SpvOpUMulExtended:
402 case SpvOpSMulExtended: {
403 std::vector<uint32_t> result_types;
404 if (!_.GetStructMemberTypes(result_type, &result_types))
405 return _.diag(SPV_ERROR_INVALID_DATA, inst)
406 << "Expected a struct as Result Type: "
407 << spvOpcodeString(opcode);
408
409 if (result_types.size() != 2)
410 return _.diag(SPV_ERROR_INVALID_DATA, inst)
411 << "Expected Result Type struct to have two members: "
412 << spvOpcodeString(opcode);
413
414 if (opcode == SpvOpSMulExtended) {
415 if (!_.IsIntScalarType(result_types[0]) &&
416 !_.IsIntVectorType(result_types[0]))
417 return _.diag(SPV_ERROR_INVALID_DATA, inst)
418 << "Expected Result Type struct member types to be integer "
419 "scalar "
420 << "or vector: " << spvOpcodeString(opcode);
421 } else {
422 if (!_.IsUnsignedIntScalarType(result_types[0]) &&
423 !_.IsUnsignedIntVectorType(result_types[0]))
424 return _.diag(SPV_ERROR_INVALID_DATA, inst)
425 << "Expected Result Type struct member types to be unsigned "
426 << "integer scalar or vector: " << spvOpcodeString(opcode);
427 }
428
429 if (result_types[0] != result_types[1])
430 return _.diag(SPV_ERROR_INVALID_DATA, inst)
431 << "Expected Result Type struct member types to be identical: "
432 << spvOpcodeString(opcode);
433
434 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
435 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
436
437 if (left_type_id != result_types[0] || right_type_id != result_types[0])
438 return _.diag(SPV_ERROR_INVALID_DATA, inst)
439 << "Expected both operands to be of Result Type member type: "
440 << spvOpcodeString(opcode);
441
442 break;
443 }
444
445 default:
446 break;
447 }
448
449 return SPV_SUCCESS;
450 }
451
452 } // namespace val
453 } // namespace spvtools
454