1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of bitwise instructions.
16
17 #include "source/diagnostic.h"
18 #include "source/opcode.h"
19 #include "source/spirv_target_env.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validation_state.h"
23
24 namespace spvtools {
25 namespace val {
26
27 // Validates when base and result need to be the same type
ValidateBaseType(ValidationState_t & _,const Instruction * inst,const uint32_t base_type)28 spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
29 const uint32_t base_type) {
30 const SpvOp opcode = inst->opcode();
31
32 if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
33 return _.diag(SPV_ERROR_INVALID_DATA, inst)
34 << _.VkErrorID(4781)
35 << "Expected int scalar or vector type for Base operand: "
36 << spvOpcodeString(opcode);
37 }
38
39 // Vulkan has a restriction to 32 bit for base
40 if (spvIsVulkanEnv(_.context()->target_env)) {
41 if (_.GetBitWidth(base_type) != 32) {
42 return _.diag(SPV_ERROR_INVALID_DATA, inst)
43 << _.VkErrorID(4781)
44 << "Expected 32-bit int type for Base operand: "
45 << spvOpcodeString(opcode);
46 }
47 }
48
49 // OpBitCount just needs same number of components
50 if (base_type != inst->type_id() && opcode != SpvOpBitCount) {
51 return _.diag(SPV_ERROR_INVALID_DATA, inst)
52 << "Expected Base Type to be equal to Result Type: "
53 << spvOpcodeString(opcode);
54 }
55
56 return SPV_SUCCESS;
57 }
58
59 // Validates correctness of bitwise instructions.
BitwisePass(ValidationState_t & _,const Instruction * inst)60 spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
61 const SpvOp opcode = inst->opcode();
62 const uint32_t result_type = inst->type_id();
63
64 switch (opcode) {
65 case SpvOpShiftRightLogical:
66 case SpvOpShiftRightArithmetic:
67 case SpvOpShiftLeftLogical: {
68 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
69 return _.diag(SPV_ERROR_INVALID_DATA, inst)
70 << "Expected int scalar or vector type as Result Type: "
71 << spvOpcodeString(opcode);
72
73 const uint32_t result_dimension = _.GetDimension(result_type);
74 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
75 const uint32_t shift_type = _.GetOperandTypeId(inst, 3);
76
77 if (!base_type ||
78 (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
79 return _.diag(SPV_ERROR_INVALID_DATA, inst)
80 << "Expected Base to be int scalar or vector: "
81 << spvOpcodeString(opcode);
82
83 if (_.GetDimension(base_type) != result_dimension)
84 return _.diag(SPV_ERROR_INVALID_DATA, inst)
85 << "Expected Base to have the same dimension "
86 << "as Result Type: " << spvOpcodeString(opcode);
87
88 if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type))
89 return _.diag(SPV_ERROR_INVALID_DATA, inst)
90 << "Expected Base to have the same bit width "
91 << "as Result Type: " << spvOpcodeString(opcode);
92
93 if (!shift_type ||
94 (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type)))
95 return _.diag(SPV_ERROR_INVALID_DATA, inst)
96 << "Expected Shift to be int scalar or vector: "
97 << spvOpcodeString(opcode);
98
99 if (_.GetDimension(shift_type) != result_dimension)
100 return _.diag(SPV_ERROR_INVALID_DATA, inst)
101 << "Expected Shift to have the same dimension "
102 << "as Result Type: " << spvOpcodeString(opcode);
103 break;
104 }
105
106 case SpvOpBitwiseOr:
107 case SpvOpBitwiseXor:
108 case SpvOpBitwiseAnd:
109 case SpvOpNot: {
110 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
111 return _.diag(SPV_ERROR_INVALID_DATA, inst)
112 << "Expected int scalar or vector type as Result Type: "
113 << spvOpcodeString(opcode);
114
115 const uint32_t result_dimension = _.GetDimension(result_type);
116 const uint32_t result_bit_width = _.GetBitWidth(result_type);
117
118 for (size_t operand_index = 2; operand_index < inst->operands().size();
119 ++operand_index) {
120 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
121 if (!type_id ||
122 (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
123 return _.diag(SPV_ERROR_INVALID_DATA, inst)
124 << "Expected int scalar or vector as operand: "
125 << spvOpcodeString(opcode) << " operand index "
126 << operand_index;
127
128 if (_.GetDimension(type_id) != result_dimension)
129 return _.diag(SPV_ERROR_INVALID_DATA, inst)
130 << "Expected operands to have the same dimension "
131 << "as Result Type: " << spvOpcodeString(opcode)
132 << " operand index " << operand_index;
133
134 if (_.GetBitWidth(type_id) != result_bit_width)
135 return _.diag(SPV_ERROR_INVALID_DATA, inst)
136 << "Expected operands to have the same bit width "
137 << "as Result Type: " << spvOpcodeString(opcode)
138 << " operand index " << operand_index;
139 }
140 break;
141 }
142
143 case SpvOpBitFieldInsert: {
144 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
145 const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
146 const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
147 const uint32_t count_type = _.GetOperandTypeId(inst, 5);
148
149 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
150 return error;
151 }
152
153 if (insert_type != result_type)
154 return _.diag(SPV_ERROR_INVALID_DATA, inst)
155 << "Expected Insert Type to be equal to Result Type: "
156 << spvOpcodeString(opcode);
157
158 if (!offset_type || !_.IsIntScalarType(offset_type))
159 return _.diag(SPV_ERROR_INVALID_DATA, inst)
160 << "Expected Offset Type to be int scalar: "
161 << spvOpcodeString(opcode);
162
163 if (!count_type || !_.IsIntScalarType(count_type))
164 return _.diag(SPV_ERROR_INVALID_DATA, inst)
165 << "Expected Count Type to be int scalar: "
166 << spvOpcodeString(opcode);
167 break;
168 }
169
170 case SpvOpBitFieldSExtract:
171 case SpvOpBitFieldUExtract: {
172 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
173 const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
174 const uint32_t count_type = _.GetOperandTypeId(inst, 4);
175
176 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
177 return error;
178 }
179
180 if (!offset_type || !_.IsIntScalarType(offset_type))
181 return _.diag(SPV_ERROR_INVALID_DATA, inst)
182 << "Expected Offset Type to be int scalar: "
183 << spvOpcodeString(opcode);
184
185 if (!count_type || !_.IsIntScalarType(count_type))
186 return _.diag(SPV_ERROR_INVALID_DATA, inst)
187 << "Expected Count Type to be int scalar: "
188 << spvOpcodeString(opcode);
189 break;
190 }
191
192 case SpvOpBitReverse: {
193 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
194
195 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
196 return error;
197 }
198
199 break;
200 }
201
202 case SpvOpBitCount: {
203 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
204 return _.diag(SPV_ERROR_INVALID_DATA, inst)
205 << "Expected int scalar or vector type as Result Type: "
206 << spvOpcodeString(opcode);
207
208 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
209 const uint32_t base_dimension = _.GetDimension(base_type);
210 const uint32_t result_dimension = _.GetDimension(result_type);
211
212 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
213 return error;
214 }
215
216 if (base_dimension != result_dimension)
217 return _.diag(SPV_ERROR_INVALID_DATA, inst)
218 << "Expected Base dimension to be equal to Result Type "
219 "dimension: "
220 << spvOpcodeString(opcode);
221 break;
222 }
223
224 default:
225 break;
226 }
227
228 return SPV_SUCCESS;
229 }
230
231 } // namespace val
232 } // namespace spvtools
233