1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/constants.h"
16
17 #include <vector>
18
19 #include "source/opt/ir_context.h"
20
21 namespace spvtools {
22 namespace opt {
23 namespace analysis {
24
GetFloat() const25 float Constant::GetFloat() const {
26 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
27
28 if (const FloatConstant* fc = AsFloatConstant()) {
29 return fc->GetFloatValue();
30 } else {
31 assert(AsNullConstant() && "Must be a floating point constant.");
32 return 0.0f;
33 }
34 }
35
GetDouble() const36 double Constant::GetDouble() const {
37 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
38
39 if (const FloatConstant* fc = AsFloatConstant()) {
40 return fc->GetDoubleValue();
41 } else {
42 assert(AsNullConstant() && "Must be a floating point constant.");
43 return 0.0;
44 }
45 }
46
GetValueAsDouble() const47 double Constant::GetValueAsDouble() const {
48 assert(type()->AsFloat() != nullptr);
49 if (type()->AsFloat()->width() == 32) {
50 return GetFloat();
51 } else {
52 assert(type()->AsFloat()->width() == 64);
53 return GetDouble();
54 }
55 }
56
GetU32() const57 uint32_t Constant::GetU32() const {
58 assert(type()->AsInteger() != nullptr);
59 assert(type()->AsInteger()->width() == 32);
60
61 if (const IntConstant* ic = AsIntConstant()) {
62 return ic->GetU32BitValue();
63 } else {
64 assert(AsNullConstant() && "Must be an integer constant.");
65 return 0u;
66 }
67 }
68
GetU64() const69 uint64_t Constant::GetU64() const {
70 assert(type()->AsInteger() != nullptr);
71 assert(type()->AsInteger()->width() == 64);
72
73 if (const IntConstant* ic = AsIntConstant()) {
74 return ic->GetU64BitValue();
75 } else {
76 assert(AsNullConstant() && "Must be an integer constant.");
77 return 0u;
78 }
79 }
80
GetS32() const81 int32_t Constant::GetS32() const {
82 assert(type()->AsInteger() != nullptr);
83 assert(type()->AsInteger()->width() == 32);
84
85 if (const IntConstant* ic = AsIntConstant()) {
86 return ic->GetS32BitValue();
87 } else {
88 assert(AsNullConstant() && "Must be an integer constant.");
89 return 0;
90 }
91 }
92
GetS64() const93 int64_t Constant::GetS64() const {
94 assert(type()->AsInteger() != nullptr);
95 assert(type()->AsInteger()->width() == 64);
96
97 if (const IntConstant* ic = AsIntConstant()) {
98 return ic->GetS64BitValue();
99 } else {
100 assert(AsNullConstant() && "Must be an integer constant.");
101 return 0;
102 }
103 }
104
GetZeroExtendedValue() const105 uint64_t Constant::GetZeroExtendedValue() const {
106 const auto* int_type = type()->AsInteger();
107 assert(int_type != nullptr);
108 const auto width = int_type->width();
109 assert(width <= 64);
110
111 uint64_t value = 0;
112 if (const IntConstant* ic = AsIntConstant()) {
113 if (width <= 32) {
114 value = ic->GetU32BitValue();
115 } else {
116 value = ic->GetU64BitValue();
117 }
118 } else {
119 assert(AsNullConstant() && "Must be an integer constant.");
120 }
121 return value;
122 }
123
GetSignExtendedValue() const124 int64_t Constant::GetSignExtendedValue() const {
125 const auto* int_type = type()->AsInteger();
126 assert(int_type != nullptr);
127 const auto width = int_type->width();
128 assert(width <= 64);
129
130 int64_t value = 0;
131 if (const IntConstant* ic = AsIntConstant()) {
132 if (width <= 32) {
133 // Let the C++ compiler do the sign extension.
134 value = int64_t(ic->GetS32BitValue());
135 } else {
136 value = ic->GetS64BitValue();
137 }
138 } else {
139 assert(AsNullConstant() && "Must be an integer constant.");
140 }
141 return value;
142 }
143
ConstantManager(IRContext * ctx)144 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
145 // Populate the constant table with values from constant declarations in the
146 // module. The values of each OpConstant declaration is the identity
147 // assignment (i.e., each constant is its own value).
148 for (const auto& inst : ctx_->module()->GetConstants()) {
149 MapInst(inst);
150 }
151 }
152
GetType(const Instruction * inst) const153 Type* ConstantManager::GetType(const Instruction* inst) const {
154 return context()->get_type_mgr()->GetType(inst->type_id());
155 }
156
GetOperandConstants(const Instruction * inst) const157 std::vector<const Constant*> ConstantManager::GetOperandConstants(
158 const Instruction* inst) const {
159 std::vector<const Constant*> constants;
160 constants.reserve(inst->NumInOperands());
161 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
162 const Operand* operand = &inst->GetInOperand(i);
163 if (operand->type != SPV_OPERAND_TYPE_ID) {
164 constants.push_back(nullptr);
165 } else {
166 uint32_t id = operand->words[0];
167 const analysis::Constant* constant = FindDeclaredConstant(id);
168 constants.push_back(constant);
169 }
170 }
171 return constants;
172 }
173
FindDeclaredConstant(const Constant * c,uint32_t type_id) const174 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
175 uint32_t type_id) const {
176 c = FindConstant(c);
177 if (c == nullptr) {
178 return 0;
179 }
180
181 for (auto range = const_val_to_id_.equal_range(c);
182 range.first != range.second; ++range.first) {
183 Instruction* const_def =
184 context()->get_def_use_mgr()->GetDef(range.first->second);
185 if (type_id == 0 || const_def->type_id() == type_id) {
186 return range.first->second;
187 }
188 }
189 return 0;
190 }
191
GetConstantsFromIds(const std::vector<uint32_t> & ids) const192 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
193 const std::vector<uint32_t>& ids) const {
194 std::vector<const Constant*> constants;
195 for (uint32_t id : ids) {
196 if (const Constant* c = FindDeclaredConstant(id)) {
197 constants.push_back(c);
198 } else {
199 return {};
200 }
201 }
202 return constants;
203 }
204
BuildInstructionAndAddToModule(const Constant * new_const,Module::inst_iterator * pos,uint32_t type_id)205 Instruction* ConstantManager::BuildInstructionAndAddToModule(
206 const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
207 // TODO(1841): Handle id overflow.
208 uint32_t new_id = context()->TakeNextId();
209 if (new_id == 0) {
210 return nullptr;
211 }
212
213 auto new_inst = CreateInstruction(new_id, new_const, type_id);
214 if (!new_inst) {
215 return nullptr;
216 }
217 auto* new_inst_ptr = new_inst.get();
218 *pos = pos->InsertBefore(std::move(new_inst));
219 ++(*pos);
220 if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse))
221 context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
222 MapConstantToInst(new_const, new_inst_ptr);
223 return new_inst_ptr;
224 }
225
GetDefiningInstruction(const Constant * c,uint32_t type_id,Module::inst_iterator * pos)226 Instruction* ConstantManager::GetDefiningInstruction(
227 const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
228 uint32_t decl_id = FindDeclaredConstant(c, type_id);
229 if (decl_id == 0) {
230 auto iter = context()->types_values_end();
231 if (pos == nullptr) pos = &iter;
232 return BuildInstructionAndAddToModule(c, pos, type_id);
233 } else {
234 auto def = context()->get_def_use_mgr()->GetDef(decl_id);
235 assert(def != nullptr);
236 assert((type_id == 0 || def->type_id() == type_id) &&
237 "This constant already has an instruction with a different type.");
238 return def;
239 }
240 }
241
CreateConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids) const242 std::unique_ptr<Constant> ConstantManager::CreateConstant(
243 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
244 if (literal_words_or_ids.size() == 0) {
245 // Constant declared with OpConstantNull
246 return MakeUnique<NullConstant>(type);
247 } else if (auto* bt = type->AsBool()) {
248 assert(literal_words_or_ids.size() == 1 &&
249 "Bool constant should be declared with one operand");
250 return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
251 } else if (auto* it = type->AsInteger()) {
252 return MakeUnique<IntConstant>(it, literal_words_or_ids);
253 } else if (auto* ft = type->AsFloat()) {
254 return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
255 } else if (auto* vt = type->AsVector()) {
256 auto components = GetConstantsFromIds(literal_words_or_ids);
257 if (components.empty()) return nullptr;
258 // All components of VectorConstant must be of type Bool, Integer or Float.
259 if (!std::all_of(components.begin(), components.end(),
260 [](const Constant* c) {
261 if (c->type()->AsBool() || c->type()->AsInteger() ||
262 c->type()->AsFloat()) {
263 return true;
264 } else {
265 return false;
266 }
267 }))
268 return nullptr;
269 // All components of VectorConstant must be in the same type.
270 const auto* component_type = components.front()->type();
271 if (!std::all_of(components.begin(), components.end(),
272 [&component_type](const Constant* c) {
273 if (c->type() == component_type) return true;
274 return false;
275 }))
276 return nullptr;
277 return MakeUnique<VectorConstant>(vt, components);
278 } else if (auto* mt = type->AsMatrix()) {
279 auto components = GetConstantsFromIds(literal_words_or_ids);
280 if (components.empty()) return nullptr;
281 return MakeUnique<MatrixConstant>(mt, components);
282 } else if (auto* st = type->AsStruct()) {
283 auto components = GetConstantsFromIds(literal_words_or_ids);
284 if (components.empty()) return nullptr;
285 return MakeUnique<StructConstant>(st, components);
286 } else if (auto* at = type->AsArray()) {
287 auto components = GetConstantsFromIds(literal_words_or_ids);
288 if (components.empty()) return nullptr;
289 return MakeUnique<ArrayConstant>(at, components);
290 } else {
291 return nullptr;
292 }
293 }
294
GetConstantFromInst(const Instruction * inst)295 const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
296 std::vector<uint32_t> literal_words_or_ids;
297
298 // Collect the constant defining literals or component ids.
299 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
300 literal_words_or_ids.insert(literal_words_or_ids.end(),
301 inst->GetInOperand(i).words.begin(),
302 inst->GetInOperand(i).words.end());
303 }
304
305 switch (inst->opcode()) {
306 // OpConstant{True|False} have the value embedded in the opcode. So they
307 // are not handled by the for-loop above. Here we add the value explicitly.
308 case spv::Op::OpConstantTrue:
309 literal_words_or_ids.push_back(true);
310 break;
311 case spv::Op::OpConstantFalse:
312 literal_words_or_ids.push_back(false);
313 break;
314 case spv::Op::OpConstantNull:
315 case spv::Op::OpConstant:
316 case spv::Op::OpConstantComposite:
317 case spv::Op::OpSpecConstantComposite:
318 break;
319 default:
320 return nullptr;
321 }
322
323 return GetConstant(GetType(inst), literal_words_or_ids);
324 }
325
CreateInstruction(uint32_t id,const Constant * c,uint32_t type_id) const326 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
327 uint32_t id, const Constant* c, uint32_t type_id) const {
328 uint32_t type =
329 (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
330 if (c->AsNullConstant()) {
331 return MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, type, id,
332 std::initializer_list<Operand>{});
333 } else if (const BoolConstant* bc = c->AsBoolConstant()) {
334 return MakeUnique<Instruction>(
335 context(),
336 bc->value() ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse, type,
337 id, std::initializer_list<Operand>{});
338 } else if (const IntConstant* ic = c->AsIntConstant()) {
339 return MakeUnique<Instruction>(
340 context(), spv::Op::OpConstant, type, id,
341 std::initializer_list<Operand>{
342 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
343 ic->words())});
344 } else if (const FloatConstant* fc = c->AsFloatConstant()) {
345 return MakeUnique<Instruction>(
346 context(), spv::Op::OpConstant, type, id,
347 std::initializer_list<Operand>{
348 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
349 fc->words())});
350 } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
351 return CreateCompositeInstruction(id, cc, type_id);
352 } else {
353 return nullptr;
354 }
355 }
356
CreateCompositeInstruction(uint32_t result_id,const CompositeConstant * cc,uint32_t type_id) const357 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
358 uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
359 std::vector<Operand> operands;
360 Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
361 uint32_t component_index = 0;
362 for (const Constant* component_const : cc->GetComponents()) {
363 uint32_t component_type_id = 0;
364 if (type_inst && type_inst->opcode() == spv::Op::OpTypeStruct) {
365 component_type_id = type_inst->GetSingleWordInOperand(component_index);
366 } else if (type_inst && type_inst->opcode() == spv::Op::OpTypeArray) {
367 component_type_id = type_inst->GetSingleWordInOperand(0);
368 }
369 uint32_t id = FindDeclaredConstant(component_const, component_type_id);
370
371 if (id == 0) {
372 // Cannot get the id of the component constant, while all components
373 // should have been added to the module prior to the composite constant.
374 // Cannot create OpConstantComposite instruction in this case.
375 return nullptr;
376 }
377 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
378 std::initializer_list<uint32_t>{id});
379 component_index++;
380 }
381 uint32_t type =
382 (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
383 return MakeUnique<Instruction>(context(), spv::Op::OpConstantComposite, type,
384 result_id, std::move(operands));
385 }
386
GetConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids)387 const Constant* ConstantManager::GetConstant(
388 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
389 auto cst = CreateConstant(type, literal_words_or_ids);
390 return cst ? RegisterConstant(std::move(cst)) : nullptr;
391 }
392
GetNullCompositeConstant(const Type * type)393 const Constant* ConstantManager::GetNullCompositeConstant(const Type* type) {
394 std::vector<uint32_t> literal_words_or_id;
395
396 if (type->AsVector()) {
397 const Type* element_type = type->AsVector()->element_type();
398 const uint32_t null_id = GetNullConstId(element_type);
399 const uint32_t element_count = type->AsVector()->element_count();
400 for (uint32_t i = 0; i < element_count; i++) {
401 literal_words_or_id.push_back(null_id);
402 }
403 } else if (type->AsMatrix()) {
404 const Type* element_type = type->AsMatrix()->element_type();
405 const uint32_t null_id = GetNullConstId(element_type);
406 const uint32_t element_count = type->AsMatrix()->element_count();
407 for (uint32_t i = 0; i < element_count; i++) {
408 literal_words_or_id.push_back(null_id);
409 }
410 } else if (type->AsStruct()) {
411 // TODO (sfricke-lunarg) add proper struct support
412 return nullptr;
413 } else if (type->AsArray()) {
414 const Type* element_type = type->AsArray()->element_type();
415 const uint32_t null_id = GetNullConstId(element_type);
416 assert(type->AsArray()->length_info().words[0] ==
417 analysis::Array::LengthInfo::kConstant &&
418 "unexpected array length");
419 const uint32_t element_count = type->AsArray()->length_info().words[0];
420 for (uint32_t i = 0; i < element_count; i++) {
421 literal_words_or_id.push_back(null_id);
422 }
423 } else {
424 return nullptr;
425 }
426
427 return GetConstant(type, literal_words_or_id);
428 }
429
GetNumericVectorConstantWithWords(const Vector * type,const std::vector<uint32_t> & literal_words)430 const Constant* ConstantManager::GetNumericVectorConstantWithWords(
431 const Vector* type, const std::vector<uint32_t>& literal_words) {
432 const auto* element_type = type->element_type();
433 uint32_t words_per_element = 0;
434 if (const auto* float_type = element_type->AsFloat())
435 words_per_element = float_type->width() / 32;
436 else if (const auto* int_type = element_type->AsInteger())
437 words_per_element = int_type->width() / 32;
438 else if (element_type->AsBool() != nullptr)
439 words_per_element = 1;
440
441 if (words_per_element != 1 && words_per_element != 2) return nullptr;
442
443 if (words_per_element * type->element_count() !=
444 static_cast<uint32_t>(literal_words.size())) {
445 return nullptr;
446 }
447
448 std::vector<uint32_t> element_ids;
449 for (uint32_t i = 0; i < type->element_count(); ++i) {
450 auto first_word = literal_words.begin() + (words_per_element * i);
451 std::vector<uint32_t> const_data(first_word,
452 first_word + words_per_element);
453 const analysis::Constant* element_constant =
454 GetConstant(element_type, const_data);
455 auto element_id = GetDefiningInstruction(element_constant)->result_id();
456 element_ids.push_back(element_id);
457 }
458
459 return GetConstant(type, element_ids);
460 }
461
GetFloatConstId(float val)462 uint32_t ConstantManager::GetFloatConstId(float val) {
463 const Constant* c = GetFloatConst(val);
464 return GetDefiningInstruction(c)->result_id();
465 }
466
GetFloatConst(float val)467 const Constant* ConstantManager::GetFloatConst(float val) {
468 Type* float_type = context()->get_type_mgr()->GetFloatType();
469 utils::FloatProxy<float> v(val);
470 const Constant* c = GetConstant(float_type, v.GetWords());
471 return c;
472 }
473
GetDoubleConstId(double val)474 uint32_t ConstantManager::GetDoubleConstId(double val) {
475 const Constant* c = GetDoubleConst(val);
476 return GetDefiningInstruction(c)->result_id();
477 }
478
GetDoubleConst(double val)479 const Constant* ConstantManager::GetDoubleConst(double val) {
480 Type* float_type = context()->get_type_mgr()->GetDoubleType();
481 utils::FloatProxy<double> v(val);
482 const Constant* c = GetConstant(float_type, v.GetWords());
483 return c;
484 }
485
GetSIntConstId(int32_t val)486 uint32_t ConstantManager::GetSIntConstId(int32_t val) {
487 Type* sint_type = context()->get_type_mgr()->GetSIntType();
488 const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
489 return GetDefiningInstruction(c)->result_id();
490 }
491
GetIntConst(uint64_t val,int32_t bitWidth,bool isSigned)492 const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth,
493 bool isSigned) {
494 Type* int_type = context()->get_type_mgr()->GetIntType(bitWidth, isSigned);
495
496 if (isSigned) {
497 // Sign extend the value.
498 int32_t num_of_bit_to_ignore = 64 - bitWidth;
499 val = static_cast<int64_t>(val << num_of_bit_to_ignore) >>
500 num_of_bit_to_ignore;
501 } else {
502 // Clear the upper bit that are not used.
503 uint64_t mask = ((1ull << bitWidth) - 1);
504 val &= mask;
505 }
506
507 if (bitWidth <= 32) {
508 return GetConstant(int_type, {static_cast<uint32_t>(val)});
509 }
510
511 // If the value is more than 32-bit, we need to split the operands into two
512 // 32-bit integers.
513 return GetConstant(
514 int_type, {static_cast<uint32_t>(val >> 32), static_cast<uint32_t>(val)});
515 }
516
GetUIntConstId(uint32_t val)517 uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
518 Type* uint_type = context()->get_type_mgr()->GetUIntType();
519 const Constant* c = GetConstant(uint_type, {val});
520 return GetDefiningInstruction(c)->result_id();
521 }
522
GetNullConstId(const Type * type)523 uint32_t ConstantManager::GetNullConstId(const Type* type) {
524 const Constant* c = GetConstant(type, {});
525 return GetDefiningInstruction(c)->result_id();
526 }
527
GetVectorComponents(analysis::ConstantManager * const_mgr) const528 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
529 analysis::ConstantManager* const_mgr) const {
530 std::vector<const analysis::Constant*> components;
531 const analysis::VectorConstant* a = this->AsVectorConstant();
532 const analysis::Vector* vector_type = this->type()->AsVector();
533 assert(vector_type != nullptr);
534 if (a != nullptr) {
535 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
536 components.push_back(a->GetComponents()[i]);
537 }
538 } else {
539 const analysis::Type* element_type = vector_type->element_type();
540 const analysis::Constant* element_null_const =
541 const_mgr->GetConstant(element_type, {});
542 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
543 components.push_back(element_null_const);
544 }
545 }
546 return components;
547 }
548
549 } // namespace analysis
550 } // namespace opt
551 } // namespace spvtools
552