1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/constants.h"
16
17 #include <unordered_map>
18 #include <vector>
19
20 #include "source/opt/ir_context.h"
21
22 namespace spvtools {
23 namespace opt {
24 namespace analysis {
25
GetFloat() const26 float Constant::GetFloat() const {
27 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
28
29 if (const FloatConstant* fc = AsFloatConstant()) {
30 return fc->GetFloatValue();
31 } else {
32 assert(AsNullConstant() && "Must be a floating point constant.");
33 return 0.0f;
34 }
35 }
36
GetDouble() const37 double Constant::GetDouble() const {
38 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
39
40 if (const FloatConstant* fc = AsFloatConstant()) {
41 return fc->GetDoubleValue();
42 } else {
43 assert(AsNullConstant() && "Must be a floating point constant.");
44 return 0.0;
45 }
46 }
47
GetValueAsDouble() const48 double Constant::GetValueAsDouble() const {
49 assert(type()->AsFloat() != nullptr);
50 if (type()->AsFloat()->width() == 32) {
51 return GetFloat();
52 } else {
53 assert(type()->AsFloat()->width() == 64);
54 return GetDouble();
55 }
56 }
57
GetU32() const58 uint32_t Constant::GetU32() const {
59 assert(type()->AsInteger() != nullptr);
60 assert(type()->AsInteger()->width() == 32);
61
62 if (const IntConstant* ic = AsIntConstant()) {
63 return ic->GetU32BitValue();
64 } else {
65 assert(AsNullConstant() && "Must be an integer constant.");
66 return 0u;
67 }
68 }
69
GetU64() const70 uint64_t Constant::GetU64() const {
71 assert(type()->AsInteger() != nullptr);
72 assert(type()->AsInteger()->width() == 64);
73
74 if (const IntConstant* ic = AsIntConstant()) {
75 return ic->GetU64BitValue();
76 } else {
77 assert(AsNullConstant() && "Must be an integer constant.");
78 return 0u;
79 }
80 }
81
GetS32() const82 int32_t Constant::GetS32() const {
83 assert(type()->AsInteger() != nullptr);
84 assert(type()->AsInteger()->width() == 32);
85
86 if (const IntConstant* ic = AsIntConstant()) {
87 return ic->GetS32BitValue();
88 } else {
89 assert(AsNullConstant() && "Must be an integer constant.");
90 return 0;
91 }
92 }
93
GetS64() const94 int64_t Constant::GetS64() const {
95 assert(type()->AsInteger() != nullptr);
96 assert(type()->AsInteger()->width() == 64);
97
98 if (const IntConstant* ic = AsIntConstant()) {
99 return ic->GetS64BitValue();
100 } else {
101 assert(AsNullConstant() && "Must be an integer constant.");
102 return 0;
103 }
104 }
105
GetZeroExtendedValue() const106 uint64_t Constant::GetZeroExtendedValue() const {
107 const auto* int_type = type()->AsInteger();
108 assert(int_type != nullptr);
109 const auto width = int_type->width();
110 assert(width <= 64);
111
112 uint64_t value = 0;
113 if (const IntConstant* ic = AsIntConstant()) {
114 if (width <= 32) {
115 value = ic->GetU32BitValue();
116 } else {
117 value = ic->GetU64BitValue();
118 }
119 } else {
120 assert(AsNullConstant() && "Must be an integer constant.");
121 }
122 return value;
123 }
124
GetSignExtendedValue() const125 int64_t Constant::GetSignExtendedValue() const {
126 const auto* int_type = type()->AsInteger();
127 assert(int_type != nullptr);
128 const auto width = int_type->width();
129 assert(width <= 64);
130
131 int64_t value = 0;
132 if (const IntConstant* ic = AsIntConstant()) {
133 if (width <= 32) {
134 // Let the C++ compiler do the sign extension.
135 value = int64_t(ic->GetS32BitValue());
136 } else {
137 value = ic->GetS64BitValue();
138 }
139 } else {
140 assert(AsNullConstant() && "Must be an integer constant.");
141 }
142 return value;
143 }
144
ConstantManager(IRContext * ctx)145 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
146 // Populate the constant table with values from constant declarations in the
147 // module. The values of each OpConstant declaration is the identity
148 // assignment (i.e., each constant is its own value).
149 for (const auto& inst : ctx_->module()->GetConstants()) {
150 MapInst(inst);
151 }
152 }
153
GetType(const Instruction * inst) const154 Type* ConstantManager::GetType(const Instruction* inst) const {
155 return context()->get_type_mgr()->GetType(inst->type_id());
156 }
157
GetOperandConstants(const Instruction * inst) const158 std::vector<const Constant*> ConstantManager::GetOperandConstants(
159 const Instruction* inst) const {
160 std::vector<const Constant*> constants;
161 constants.reserve(inst->NumInOperands());
162 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
163 const Operand* operand = &inst->GetInOperand(i);
164 if (operand->type != SPV_OPERAND_TYPE_ID) {
165 constants.push_back(nullptr);
166 } else {
167 uint32_t id = operand->words[0];
168 const analysis::Constant* constant = FindDeclaredConstant(id);
169 constants.push_back(constant);
170 }
171 }
172 return constants;
173 }
174
FindDeclaredConstant(const Constant * c,uint32_t type_id) const175 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
176 uint32_t type_id) const {
177 c = FindConstant(c);
178 if (c == nullptr) {
179 return 0;
180 }
181
182 for (auto range = const_val_to_id_.equal_range(c);
183 range.first != range.second; ++range.first) {
184 Instruction* const_def =
185 context()->get_def_use_mgr()->GetDef(range.first->second);
186 if (type_id == 0 || const_def->type_id() == type_id) {
187 return range.first->second;
188 }
189 }
190 return 0;
191 }
192
GetConstantsFromIds(const std::vector<uint32_t> & ids) const193 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
194 const std::vector<uint32_t>& ids) const {
195 std::vector<const Constant*> constants;
196 for (uint32_t id : ids) {
197 if (const Constant* c = FindDeclaredConstant(id)) {
198 constants.push_back(c);
199 } else {
200 return {};
201 }
202 }
203 return constants;
204 }
205
BuildInstructionAndAddToModule(const Constant * new_const,Module::inst_iterator * pos,uint32_t type_id)206 Instruction* ConstantManager::BuildInstructionAndAddToModule(
207 const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
208 // TODO(1841): Handle id overflow.
209 uint32_t new_id = context()->TakeNextId();
210 if (new_id == 0) {
211 return nullptr;
212 }
213
214 auto new_inst = CreateInstruction(new_id, new_const, type_id);
215 if (!new_inst) {
216 return nullptr;
217 }
218 auto* new_inst_ptr = new_inst.get();
219 *pos = pos->InsertBefore(std::move(new_inst));
220 ++(*pos);
221 if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse))
222 context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
223 MapConstantToInst(new_const, new_inst_ptr);
224 return new_inst_ptr;
225 }
226
GetDefiningInstruction(const Constant * c,uint32_t type_id,Module::inst_iterator * pos)227 Instruction* ConstantManager::GetDefiningInstruction(
228 const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
229 uint32_t decl_id = FindDeclaredConstant(c, type_id);
230 if (decl_id == 0) {
231 auto iter = context()->types_values_end();
232 if (pos == nullptr) pos = &iter;
233 return BuildInstructionAndAddToModule(c, pos, type_id);
234 } else {
235 auto def = context()->get_def_use_mgr()->GetDef(decl_id);
236 assert(def != nullptr);
237 assert((type_id == 0 || def->type_id() == type_id) &&
238 "This constant already has an instruction with a different type.");
239 return def;
240 }
241 }
242
CreateConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids) const243 std::unique_ptr<Constant> ConstantManager::CreateConstant(
244 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
245 if (literal_words_or_ids.size() == 0) {
246 // Constant declared with OpConstantNull
247 return MakeUnique<NullConstant>(type);
248 } else if (auto* bt = type->AsBool()) {
249 assert(literal_words_or_ids.size() == 1 &&
250 "Bool constant should be declared with one operand");
251 return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
252 } else if (auto* it = type->AsInteger()) {
253 return MakeUnique<IntConstant>(it, literal_words_or_ids);
254 } else if (auto* ft = type->AsFloat()) {
255 return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
256 } else if (auto* vt = type->AsVector()) {
257 auto components = GetConstantsFromIds(literal_words_or_ids);
258 if (components.empty()) return nullptr;
259 // All components of VectorConstant must be of type Bool, Integer or Float.
260 if (!std::all_of(components.begin(), components.end(),
261 [](const Constant* c) {
262 if (c->type()->AsBool() || c->type()->AsInteger() ||
263 c->type()->AsFloat()) {
264 return true;
265 } else {
266 return false;
267 }
268 }))
269 return nullptr;
270 // All components of VectorConstant must be in the same type.
271 const auto* component_type = components.front()->type();
272 if (!std::all_of(components.begin(), components.end(),
273 [&component_type](const Constant* c) {
274 if (c->type() == component_type) return true;
275 return false;
276 }))
277 return nullptr;
278 return MakeUnique<VectorConstant>(vt, components);
279 } else if (auto* mt = type->AsMatrix()) {
280 auto components = GetConstantsFromIds(literal_words_or_ids);
281 if (components.empty()) return nullptr;
282 return MakeUnique<MatrixConstant>(mt, components);
283 } else if (auto* st = type->AsStruct()) {
284 auto components = GetConstantsFromIds(literal_words_or_ids);
285 if (components.empty()) return nullptr;
286 return MakeUnique<StructConstant>(st, components);
287 } else if (auto* at = type->AsArray()) {
288 auto components = GetConstantsFromIds(literal_words_or_ids);
289 if (components.empty()) return nullptr;
290 return MakeUnique<ArrayConstant>(at, components);
291 } else {
292 return nullptr;
293 }
294 }
295
GetConstantFromInst(const Instruction * inst)296 const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
297 std::vector<uint32_t> literal_words_or_ids;
298
299 // Collect the constant defining literals or component ids.
300 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
301 literal_words_or_ids.insert(literal_words_or_ids.end(),
302 inst->GetInOperand(i).words.begin(),
303 inst->GetInOperand(i).words.end());
304 }
305
306 switch (inst->opcode()) {
307 // OpConstant{True|False} have the value embedded in the opcode. So they
308 // are not handled by the for-loop above. Here we add the value explicitly.
309 case SpvOp::SpvOpConstantTrue:
310 literal_words_or_ids.push_back(true);
311 break;
312 case SpvOp::SpvOpConstantFalse:
313 literal_words_or_ids.push_back(false);
314 break;
315 case SpvOp::SpvOpConstantNull:
316 case SpvOp::SpvOpConstant:
317 case SpvOp::SpvOpConstantComposite:
318 case SpvOp::SpvOpSpecConstantComposite:
319 break;
320 default:
321 return nullptr;
322 }
323
324 return GetConstant(GetType(inst), literal_words_or_ids);
325 }
326
CreateInstruction(uint32_t id,const Constant * c,uint32_t type_id) const327 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
328 uint32_t id, const Constant* c, uint32_t type_id) const {
329 uint32_t type =
330 (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
331 if (c->AsNullConstant()) {
332 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type,
333 id, std::initializer_list<Operand>{});
334 } else if (const BoolConstant* bc = c->AsBoolConstant()) {
335 return MakeUnique<Instruction>(
336 context(),
337 bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
338 type, id, std::initializer_list<Operand>{});
339 } else if (const IntConstant* ic = c->AsIntConstant()) {
340 return MakeUnique<Instruction>(
341 context(), SpvOp::SpvOpConstant, type, id,
342 std::initializer_list<Operand>{
343 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
344 ic->words())});
345 } else if (const FloatConstant* fc = c->AsFloatConstant()) {
346 return MakeUnique<Instruction>(
347 context(), SpvOp::SpvOpConstant, type, id,
348 std::initializer_list<Operand>{
349 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
350 fc->words())});
351 } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
352 return CreateCompositeInstruction(id, cc, type_id);
353 } else {
354 return nullptr;
355 }
356 }
357
CreateCompositeInstruction(uint32_t result_id,const CompositeConstant * cc,uint32_t type_id) const358 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
359 uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
360 std::vector<Operand> operands;
361 Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
362 uint32_t component_index = 0;
363 for (const Constant* component_const : cc->GetComponents()) {
364 uint32_t component_type_id = 0;
365 if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
366 component_type_id = type_inst->GetSingleWordInOperand(component_index);
367 } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
368 component_type_id = type_inst->GetSingleWordInOperand(0);
369 }
370 uint32_t id = FindDeclaredConstant(component_const, component_type_id);
371
372 if (id == 0) {
373 // Cannot get the id of the component constant, while all components
374 // should have been added to the module prior to the composite constant.
375 // Cannot create OpConstantComposite instruction in this case.
376 return nullptr;
377 }
378 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
379 std::initializer_list<uint32_t>{id});
380 component_index++;
381 }
382 uint32_t type =
383 (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
384 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type,
385 result_id, std::move(operands));
386 }
387
GetConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids)388 const Constant* ConstantManager::GetConstant(
389 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
390 auto cst = CreateConstant(type, literal_words_or_ids);
391 return cst ? RegisterConstant(std::move(cst)) : nullptr;
392 }
393
GetNumericVectorConstantWithWords(const Vector * type,const std::vector<uint32_t> & literal_words)394 const Constant* ConstantManager::GetNumericVectorConstantWithWords(
395 const Vector* type, const std::vector<uint32_t>& literal_words) {
396 const auto* element_type = type->element_type();
397 uint32_t words_per_element = 0;
398 if (const auto* float_type = element_type->AsFloat())
399 words_per_element = float_type->width() / 32;
400 else if (const auto* int_type = element_type->AsInteger())
401 words_per_element = int_type->width() / 32;
402
403 if (words_per_element != 1 && words_per_element != 2) return nullptr;
404
405 if (words_per_element * type->element_count() !=
406 static_cast<uint32_t>(literal_words.size())) {
407 return nullptr;
408 }
409
410 std::vector<uint32_t> element_ids;
411 for (uint32_t i = 0; i < type->element_count(); ++i) {
412 auto first_word = literal_words.begin() + (words_per_element * i);
413 std::vector<uint32_t> const_data(first_word,
414 first_word + words_per_element);
415 const analysis::Constant* element_constant =
416 GetConstant(element_type, const_data);
417 auto element_id = GetDefiningInstruction(element_constant)->result_id();
418 element_ids.push_back(element_id);
419 }
420
421 return GetConstant(type, element_ids);
422 }
423
GetFloatConstId(float val)424 uint32_t ConstantManager::GetFloatConstId(float val) {
425 const Constant* c = GetFloatConst(val);
426 return GetDefiningInstruction(c)->result_id();
427 }
428
GetFloatConst(float val)429 const Constant* ConstantManager::GetFloatConst(float val) {
430 Type* float_type = context()->get_type_mgr()->GetFloatType();
431 utils::FloatProxy<float> v(val);
432 const Constant* c = GetConstant(float_type, v.GetWords());
433 return c;
434 }
435
GetDoubleConstId(double val)436 uint32_t ConstantManager::GetDoubleConstId(double val) {
437 const Constant* c = GetDoubleConst(val);
438 return GetDefiningInstruction(c)->result_id();
439 }
440
GetDoubleConst(double val)441 const Constant* ConstantManager::GetDoubleConst(double val) {
442 Type* float_type = context()->get_type_mgr()->GetDoubleType();
443 utils::FloatProxy<double> v(val);
444 const Constant* c = GetConstant(float_type, v.GetWords());
445 return c;
446 }
447
GetSIntConst(int32_t val)448 uint32_t ConstantManager::GetSIntConst(int32_t val) {
449 Type* sint_type = context()->get_type_mgr()->GetSIntType();
450 const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
451 return GetDefiningInstruction(c)->result_id();
452 }
453
GetUIntConst(uint32_t val)454 uint32_t ConstantManager::GetUIntConst(uint32_t val) {
455 Type* uint_type = context()->get_type_mgr()->GetUIntType();
456 const Constant* c = GetConstant(uint_type, {val});
457 return GetDefiningInstruction(c)->result_id();
458 }
459
GetVectorComponents(analysis::ConstantManager * const_mgr) const460 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
461 analysis::ConstantManager* const_mgr) const {
462 std::vector<const analysis::Constant*> components;
463 const analysis::VectorConstant* a = this->AsVectorConstant();
464 const analysis::Vector* vector_type = this->type()->AsVector();
465 assert(vector_type != nullptr);
466 if (a != nullptr) {
467 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
468 components.push_back(a->GetComponents()[i]);
469 }
470 } else {
471 const analysis::Type* element_type = vector_type->element_type();
472 const analysis::Constant* element_null_const =
473 const_mgr->GetConstant(element_type, {});
474 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
475 components.push_back(element_null_const);
476 }
477 }
478 return components;
479 }
480
481 } // namespace analysis
482 } // namespace opt
483 } // namespace spvtools
484