• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "source/opt/constants.h"
16 
17 #include <vector>
18 
19 #include "source/opt/ir_context.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace analysis {
24 
GetFloat() const25 float Constant::GetFloat() const {
26   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
27 
28   if (const FloatConstant* fc = AsFloatConstant()) {
29     return fc->GetFloatValue();
30   } else {
31     assert(AsNullConstant() && "Must be a floating point constant.");
32     return 0.0f;
33   }
34 }
35 
GetDouble() const36 double Constant::GetDouble() const {
37   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
38 
39   if (const FloatConstant* fc = AsFloatConstant()) {
40     return fc->GetDoubleValue();
41   } else {
42     assert(AsNullConstant() && "Must be a floating point constant.");
43     return 0.0;
44   }
45 }
46 
GetValueAsDouble() const47 double Constant::GetValueAsDouble() const {
48   assert(type()->AsFloat() != nullptr);
49   if (type()->AsFloat()->width() == 32) {
50     return GetFloat();
51   } else {
52     assert(type()->AsFloat()->width() == 64);
53     return GetDouble();
54   }
55 }
56 
GetU32() const57 uint32_t Constant::GetU32() const {
58   assert(type()->AsInteger() != nullptr);
59   assert(type()->AsInteger()->width() == 32);
60 
61   if (const IntConstant* ic = AsIntConstant()) {
62     return ic->GetU32BitValue();
63   } else {
64     assert(AsNullConstant() && "Must be an integer constant.");
65     return 0u;
66   }
67 }
68 
GetU64() const69 uint64_t Constant::GetU64() const {
70   assert(type()->AsInteger() != nullptr);
71   assert(type()->AsInteger()->width() == 64);
72 
73   if (const IntConstant* ic = AsIntConstant()) {
74     return ic->GetU64BitValue();
75   } else {
76     assert(AsNullConstant() && "Must be an integer constant.");
77     return 0u;
78   }
79 }
80 
GetS32() const81 int32_t Constant::GetS32() const {
82   assert(type()->AsInteger() != nullptr);
83   assert(type()->AsInteger()->width() == 32);
84 
85   if (const IntConstant* ic = AsIntConstant()) {
86     return ic->GetS32BitValue();
87   } else {
88     assert(AsNullConstant() && "Must be an integer constant.");
89     return 0;
90   }
91 }
92 
GetS64() const93 int64_t Constant::GetS64() const {
94   assert(type()->AsInteger() != nullptr);
95   assert(type()->AsInteger()->width() == 64);
96 
97   if (const IntConstant* ic = AsIntConstant()) {
98     return ic->GetS64BitValue();
99   } else {
100     assert(AsNullConstant() && "Must be an integer constant.");
101     return 0;
102   }
103 }
104 
GetZeroExtendedValue() const105 uint64_t Constant::GetZeroExtendedValue() const {
106   const auto* int_type = type()->AsInteger();
107   assert(int_type != nullptr);
108   const auto width = int_type->width();
109   assert(width <= 64);
110 
111   uint64_t value = 0;
112   if (const IntConstant* ic = AsIntConstant()) {
113     if (width <= 32) {
114       value = ic->GetU32BitValue();
115     } else {
116       value = ic->GetU64BitValue();
117     }
118   } else {
119     assert(AsNullConstant() && "Must be an integer constant.");
120   }
121   return value;
122 }
123 
GetSignExtendedValue() const124 int64_t Constant::GetSignExtendedValue() const {
125   const auto* int_type = type()->AsInteger();
126   assert(int_type != nullptr);
127   const auto width = int_type->width();
128   assert(width <= 64);
129 
130   int64_t value = 0;
131   if (const IntConstant* ic = AsIntConstant()) {
132     if (width <= 32) {
133       // Let the C++ compiler do the sign extension.
134       value = int64_t(ic->GetS32BitValue());
135     } else {
136       value = ic->GetS64BitValue();
137     }
138   } else {
139     assert(AsNullConstant() && "Must be an integer constant.");
140   }
141   return value;
142 }
143 
ConstantManager(IRContext * ctx)144 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
145   // Populate the constant table with values from constant declarations in the
146   // module.  The values of each OpConstant declaration is the identity
147   // assignment (i.e., each constant is its own value).
148   for (const auto& inst : ctx_->module()->GetConstants()) {
149     MapInst(inst);
150   }
151 }
152 
GetType(const Instruction * inst) const153 Type* ConstantManager::GetType(const Instruction* inst) const {
154   return context()->get_type_mgr()->GetType(inst->type_id());
155 }
156 
GetOperandConstants(const Instruction * inst) const157 std::vector<const Constant*> ConstantManager::GetOperandConstants(
158     const Instruction* inst) const {
159   std::vector<const Constant*> constants;
160   constants.reserve(inst->NumInOperands());
161   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
162     const Operand* operand = &inst->GetInOperand(i);
163     if (operand->type != SPV_OPERAND_TYPE_ID) {
164       constants.push_back(nullptr);
165     } else {
166       uint32_t id = operand->words[0];
167       const analysis::Constant* constant = FindDeclaredConstant(id);
168       constants.push_back(constant);
169     }
170   }
171   return constants;
172 }
173 
FindDeclaredConstant(const Constant * c,uint32_t type_id) const174 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
175                                                uint32_t type_id) const {
176   c = FindConstant(c);
177   if (c == nullptr) {
178     return 0;
179   }
180 
181   for (auto range = const_val_to_id_.equal_range(c);
182        range.first != range.second; ++range.first) {
183     Instruction* const_def =
184         context()->get_def_use_mgr()->GetDef(range.first->second);
185     if (type_id == 0 || const_def->type_id() == type_id) {
186       return range.first->second;
187     }
188   }
189   return 0;
190 }
191 
GetConstantsFromIds(const std::vector<uint32_t> & ids) const192 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
193     const std::vector<uint32_t>& ids) const {
194   std::vector<const Constant*> constants;
195   for (uint32_t id : ids) {
196     if (const Constant* c = FindDeclaredConstant(id)) {
197       constants.push_back(c);
198     } else {
199       return {};
200     }
201   }
202   return constants;
203 }
204 
BuildInstructionAndAddToModule(const Constant * new_const,Module::inst_iterator * pos,uint32_t type_id)205 Instruction* ConstantManager::BuildInstructionAndAddToModule(
206     const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
207   // TODO(1841): Handle id overflow.
208   uint32_t new_id = context()->TakeNextId();
209   if (new_id == 0) {
210     return nullptr;
211   }
212 
213   auto new_inst = CreateInstruction(new_id, new_const, type_id);
214   if (!new_inst) {
215     return nullptr;
216   }
217   auto* new_inst_ptr = new_inst.get();
218   *pos = pos->InsertBefore(std::move(new_inst));
219   ++(*pos);
220   if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse))
221     context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
222   MapConstantToInst(new_const, new_inst_ptr);
223   return new_inst_ptr;
224 }
225 
GetDefiningInstruction(const Constant * c,uint32_t type_id,Module::inst_iterator * pos)226 Instruction* ConstantManager::GetDefiningInstruction(
227     const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
228   uint32_t decl_id = FindDeclaredConstant(c, type_id);
229   if (decl_id == 0) {
230     auto iter = context()->types_values_end();
231     if (pos == nullptr) pos = &iter;
232     return BuildInstructionAndAddToModule(c, pos, type_id);
233   } else {
234     auto def = context()->get_def_use_mgr()->GetDef(decl_id);
235     assert(def != nullptr);
236     assert((type_id == 0 || def->type_id() == type_id) &&
237            "This constant already has an instruction with a different type.");
238     return def;
239   }
240 }
241 
CreateConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids) const242 std::unique_ptr<Constant> ConstantManager::CreateConstant(
243     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
244   if (literal_words_or_ids.size() == 0) {
245     // Constant declared with OpConstantNull
246     return MakeUnique<NullConstant>(type);
247   } else if (auto* bt = type->AsBool()) {
248     assert(literal_words_or_ids.size() == 1 &&
249            "Bool constant should be declared with one operand");
250     return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
251   } else if (auto* it = type->AsInteger()) {
252     return MakeUnique<IntConstant>(it, literal_words_or_ids);
253   } else if (auto* ft = type->AsFloat()) {
254     return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
255   } else if (auto* vt = type->AsVector()) {
256     auto components = GetConstantsFromIds(literal_words_or_ids);
257     if (components.empty()) return nullptr;
258     // All components of VectorConstant must be of type Bool, Integer or Float.
259     if (!std::all_of(components.begin(), components.end(),
260                      [](const Constant* c) {
261                        if (c->type()->AsBool() || c->type()->AsInteger() ||
262                            c->type()->AsFloat()) {
263                          return true;
264                        } else {
265                          return false;
266                        }
267                      }))
268       return nullptr;
269     // All components of VectorConstant must be in the same type.
270     const auto* component_type = components.front()->type();
271     if (!std::all_of(components.begin(), components.end(),
272                      [&component_type](const Constant* c) {
273                        if (c->type() == component_type) return true;
274                        return false;
275                      }))
276       return nullptr;
277     return MakeUnique<VectorConstant>(vt, components);
278   } else if (auto* mt = type->AsMatrix()) {
279     auto components = GetConstantsFromIds(literal_words_or_ids);
280     if (components.empty()) return nullptr;
281     return MakeUnique<MatrixConstant>(mt, components);
282   } else if (auto* st = type->AsStruct()) {
283     auto components = GetConstantsFromIds(literal_words_or_ids);
284     if (components.empty()) return nullptr;
285     return MakeUnique<StructConstant>(st, components);
286   } else if (auto* at = type->AsArray()) {
287     auto components = GetConstantsFromIds(literal_words_or_ids);
288     if (components.empty()) return nullptr;
289     return MakeUnique<ArrayConstant>(at, components);
290   } else {
291     return nullptr;
292   }
293 }
294 
GetConstantFromInst(const Instruction * inst)295 const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
296   std::vector<uint32_t> literal_words_or_ids;
297 
298   // Collect the constant defining literals or component ids.
299   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
300     literal_words_or_ids.insert(literal_words_or_ids.end(),
301                                 inst->GetInOperand(i).words.begin(),
302                                 inst->GetInOperand(i).words.end());
303   }
304 
305   switch (inst->opcode()) {
306     // OpConstant{True|False} have the value embedded in the opcode. So they
307     // are not handled by the for-loop above. Here we add the value explicitly.
308     case spv::Op::OpConstantTrue:
309       literal_words_or_ids.push_back(true);
310       break;
311     case spv::Op::OpConstantFalse:
312       literal_words_or_ids.push_back(false);
313       break;
314     case spv::Op::OpConstantNull:
315     case spv::Op::OpConstant:
316     case spv::Op::OpConstantComposite:
317     case spv::Op::OpSpecConstantComposite:
318       break;
319     default:
320       return nullptr;
321   }
322 
323   return GetConstant(GetType(inst), literal_words_or_ids);
324 }
325 
CreateInstruction(uint32_t id,const Constant * c,uint32_t type_id) const326 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
327     uint32_t id, const Constant* c, uint32_t type_id) const {
328   uint32_t type =
329       (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
330   if (c->AsNullConstant()) {
331     return MakeUnique<Instruction>(context(), spv::Op::OpConstantNull, type, id,
332                                    std::initializer_list<Operand>{});
333   } else if (const BoolConstant* bc = c->AsBoolConstant()) {
334     return MakeUnique<Instruction>(
335         context(),
336         bc->value() ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse, type,
337         id, std::initializer_list<Operand>{});
338   } else if (const IntConstant* ic = c->AsIntConstant()) {
339     return MakeUnique<Instruction>(
340         context(), spv::Op::OpConstant, type, id,
341         std::initializer_list<Operand>{
342             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
343                     ic->words())});
344   } else if (const FloatConstant* fc = c->AsFloatConstant()) {
345     return MakeUnique<Instruction>(
346         context(), spv::Op::OpConstant, type, id,
347         std::initializer_list<Operand>{
348             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
349                     fc->words())});
350   } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
351     return CreateCompositeInstruction(id, cc, type_id);
352   } else {
353     return nullptr;
354   }
355 }
356 
CreateCompositeInstruction(uint32_t result_id,const CompositeConstant * cc,uint32_t type_id) const357 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
358     uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
359   std::vector<Operand> operands;
360   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
361   uint32_t component_index = 0;
362   for (const Constant* component_const : cc->GetComponents()) {
363     uint32_t component_type_id = 0;
364     if (type_inst && type_inst->opcode() == spv::Op::OpTypeStruct) {
365       component_type_id = type_inst->GetSingleWordInOperand(component_index);
366     } else if (type_inst && type_inst->opcode() == spv::Op::OpTypeArray) {
367       component_type_id = type_inst->GetSingleWordInOperand(0);
368     }
369     uint32_t id = FindDeclaredConstant(component_const, component_type_id);
370 
371     if (id == 0) {
372       // Cannot get the id of the component constant, while all components
373       // should have been added to the module prior to the composite constant.
374       // Cannot create OpConstantComposite instruction in this case.
375       return nullptr;
376     }
377     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
378                           std::initializer_list<uint32_t>{id});
379     component_index++;
380   }
381   uint32_t type =
382       (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
383   return MakeUnique<Instruction>(context(), spv::Op::OpConstantComposite, type,
384                                  result_id, std::move(operands));
385 }
386 
GetConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids)387 const Constant* ConstantManager::GetConstant(
388     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
389   auto cst = CreateConstant(type, literal_words_or_ids);
390   return cst ? RegisterConstant(std::move(cst)) : nullptr;
391 }
392 
GetNullCompositeConstant(const Type * type)393 const Constant* ConstantManager::GetNullCompositeConstant(const Type* type) {
394   std::vector<uint32_t> literal_words_or_id;
395 
396   if (type->AsVector()) {
397     const Type* element_type = type->AsVector()->element_type();
398     const uint32_t null_id = GetNullConstId(element_type);
399     const uint32_t element_count = type->AsVector()->element_count();
400     for (uint32_t i = 0; i < element_count; i++) {
401       literal_words_or_id.push_back(null_id);
402     }
403   } else if (type->AsMatrix()) {
404     const Type* element_type = type->AsMatrix()->element_type();
405     const uint32_t null_id = GetNullConstId(element_type);
406     const uint32_t element_count = type->AsMatrix()->element_count();
407     for (uint32_t i = 0; i < element_count; i++) {
408       literal_words_or_id.push_back(null_id);
409     }
410   } else if (type->AsStruct()) {
411     // TODO (sfricke-lunarg) add proper struct support
412     return nullptr;
413   } else if (type->AsArray()) {
414     const Type* element_type = type->AsArray()->element_type();
415     const uint32_t null_id = GetNullConstId(element_type);
416     assert(type->AsArray()->length_info().words[0] ==
417                analysis::Array::LengthInfo::kConstant &&
418            "unexpected array length");
419     const uint32_t element_count = type->AsArray()->length_info().words[0];
420     for (uint32_t i = 0; i < element_count; i++) {
421       literal_words_or_id.push_back(null_id);
422     }
423   } else {
424     return nullptr;
425   }
426 
427   return GetConstant(type, literal_words_or_id);
428 }
429 
GetNumericVectorConstantWithWords(const Vector * type,const std::vector<uint32_t> & literal_words)430 const Constant* ConstantManager::GetNumericVectorConstantWithWords(
431     const Vector* type, const std::vector<uint32_t>& literal_words) {
432   const auto* element_type = type->element_type();
433   uint32_t words_per_element = 0;
434   if (const auto* float_type = element_type->AsFloat())
435     words_per_element = float_type->width() / 32;
436   else if (const auto* int_type = element_type->AsInteger())
437     words_per_element = int_type->width() / 32;
438   else if (element_type->AsBool() != nullptr)
439     words_per_element = 1;
440 
441   if (words_per_element != 1 && words_per_element != 2) return nullptr;
442 
443   if (words_per_element * type->element_count() !=
444       static_cast<uint32_t>(literal_words.size())) {
445     return nullptr;
446   }
447 
448   std::vector<uint32_t> element_ids;
449   for (uint32_t i = 0; i < type->element_count(); ++i) {
450     auto first_word = literal_words.begin() + (words_per_element * i);
451     std::vector<uint32_t> const_data(first_word,
452                                      first_word + words_per_element);
453     const analysis::Constant* element_constant =
454         GetConstant(element_type, const_data);
455     auto element_id = GetDefiningInstruction(element_constant)->result_id();
456     element_ids.push_back(element_id);
457   }
458 
459   return GetConstant(type, element_ids);
460 }
461 
GetFloatConstId(float val)462 uint32_t ConstantManager::GetFloatConstId(float val) {
463   const Constant* c = GetFloatConst(val);
464   return GetDefiningInstruction(c)->result_id();
465 }
466 
GetFloatConst(float val)467 const Constant* ConstantManager::GetFloatConst(float val) {
468   Type* float_type = context()->get_type_mgr()->GetFloatType();
469   utils::FloatProxy<float> v(val);
470   const Constant* c = GetConstant(float_type, v.GetWords());
471   return c;
472 }
473 
GetDoubleConstId(double val)474 uint32_t ConstantManager::GetDoubleConstId(double val) {
475   const Constant* c = GetDoubleConst(val);
476   return GetDefiningInstruction(c)->result_id();
477 }
478 
GetDoubleConst(double val)479 const Constant* ConstantManager::GetDoubleConst(double val) {
480   Type* float_type = context()->get_type_mgr()->GetDoubleType();
481   utils::FloatProxy<double> v(val);
482   const Constant* c = GetConstant(float_type, v.GetWords());
483   return c;
484 }
485 
GetSIntConstId(int32_t val)486 uint32_t ConstantManager::GetSIntConstId(int32_t val) {
487   Type* sint_type = context()->get_type_mgr()->GetSIntType();
488   const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
489   return GetDefiningInstruction(c)->result_id();
490 }
491 
GetIntConst(uint64_t val,int32_t bitWidth,bool isSigned)492 const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth,
493                                              bool isSigned) {
494   Type* int_type = context()->get_type_mgr()->GetIntType(bitWidth, isSigned);
495 
496   if (isSigned) {
497     // Sign extend the value.
498     int32_t num_of_bit_to_ignore = 64 - bitWidth;
499     val = static_cast<int64_t>(val << num_of_bit_to_ignore) >>
500           num_of_bit_to_ignore;
501   } else {
502     // Clear the upper bit that are not used.
503     uint64_t mask = ((1ull << bitWidth) - 1);
504     val &= mask;
505   }
506 
507   if (bitWidth <= 32) {
508     return GetConstant(int_type, {static_cast<uint32_t>(val)});
509   }
510 
511   // If the value is more than 32-bit, we need to split the operands into two
512   // 32-bit integers.
513   return GetConstant(
514       int_type, {static_cast<uint32_t>(val >> 32), static_cast<uint32_t>(val)});
515 }
516 
GetUIntConstId(uint32_t val)517 uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
518   Type* uint_type = context()->get_type_mgr()->GetUIntType();
519   const Constant* c = GetConstant(uint_type, {val});
520   return GetDefiningInstruction(c)->result_id();
521 }
522 
GetNullConstId(const Type * type)523 uint32_t ConstantManager::GetNullConstId(const Type* type) {
524   const Constant* c = GetConstant(type, {});
525   return GetDefiningInstruction(c)->result_id();
526 }
527 
GetVectorComponents(analysis::ConstantManager * const_mgr) const528 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
529     analysis::ConstantManager* const_mgr) const {
530   std::vector<const analysis::Constant*> components;
531   const analysis::VectorConstant* a = this->AsVectorConstant();
532   const analysis::Vector* vector_type = this->type()->AsVector();
533   assert(vector_type != nullptr);
534   if (a != nullptr) {
535     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
536       components.push_back(a->GetComponents()[i]);
537     }
538   } else {
539     const analysis::Type* element_type = vector_type->element_type();
540     const analysis::Constant* element_null_const =
541         const_mgr->GetConstant(element_type, {});
542     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
543       components.push_back(element_null_const);
544     }
545   }
546   return components;
547 }
548 
549 }  // namespace analysis
550 }  // namespace opt
551 }  // namespace spvtools
552