• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/constants.h"
16 
17 #include <unordered_map>
18 #include <vector>
19 
20 #include "source/opt/ir_context.h"
21 
22 namespace spvtools {
23 namespace opt {
24 namespace analysis {
25 
GetFloat() const26 float Constant::GetFloat() const {
27   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
28 
29   if (const FloatConstant* fc = AsFloatConstant()) {
30     return fc->GetFloatValue();
31   } else {
32     assert(AsNullConstant() && "Must be a floating point constant.");
33     return 0.0f;
34   }
35 }
36 
GetDouble() const37 double Constant::GetDouble() const {
38   assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
39 
40   if (const FloatConstant* fc = AsFloatConstant()) {
41     return fc->GetDoubleValue();
42   } else {
43     assert(AsNullConstant() && "Must be a floating point constant.");
44     return 0.0;
45   }
46 }
47 
GetValueAsDouble() const48 double Constant::GetValueAsDouble() const {
49   assert(type()->AsFloat() != nullptr);
50   if (type()->AsFloat()->width() == 32) {
51     return GetFloat();
52   } else {
53     assert(type()->AsFloat()->width() == 64);
54     return GetDouble();
55   }
56 }
57 
GetU32() const58 uint32_t Constant::GetU32() const {
59   assert(type()->AsInteger() != nullptr);
60   assert(type()->AsInteger()->width() == 32);
61 
62   if (const IntConstant* ic = AsIntConstant()) {
63     return ic->GetU32BitValue();
64   } else {
65     assert(AsNullConstant() && "Must be an integer constant.");
66     return 0u;
67   }
68 }
69 
GetU64() const70 uint64_t Constant::GetU64() const {
71   assert(type()->AsInteger() != nullptr);
72   assert(type()->AsInteger()->width() == 64);
73 
74   if (const IntConstant* ic = AsIntConstant()) {
75     return ic->GetU64BitValue();
76   } else {
77     assert(AsNullConstant() && "Must be an integer constant.");
78     return 0u;
79   }
80 }
81 
GetS32() const82 int32_t Constant::GetS32() const {
83   assert(type()->AsInteger() != nullptr);
84   assert(type()->AsInteger()->width() == 32);
85 
86   if (const IntConstant* ic = AsIntConstant()) {
87     return ic->GetS32BitValue();
88   } else {
89     assert(AsNullConstant() && "Must be an integer constant.");
90     return 0;
91   }
92 }
93 
GetS64() const94 int64_t Constant::GetS64() const {
95   assert(type()->AsInteger() != nullptr);
96   assert(type()->AsInteger()->width() == 64);
97 
98   if (const IntConstant* ic = AsIntConstant()) {
99     return ic->GetS64BitValue();
100   } else {
101     assert(AsNullConstant() && "Must be an integer constant.");
102     return 0;
103   }
104 }
105 
GetZeroExtendedValue() const106 uint64_t Constant::GetZeroExtendedValue() const {
107   const auto* int_type = type()->AsInteger();
108   assert(int_type != nullptr);
109   const auto width = int_type->width();
110   assert(width <= 64);
111 
112   uint64_t value = 0;
113   if (const IntConstant* ic = AsIntConstant()) {
114     if (width <= 32) {
115       value = ic->GetU32BitValue();
116     } else {
117       value = ic->GetU64BitValue();
118     }
119   } else {
120     assert(AsNullConstant() && "Must be an integer constant.");
121   }
122   return value;
123 }
124 
GetSignExtendedValue() const125 int64_t Constant::GetSignExtendedValue() const {
126   const auto* int_type = type()->AsInteger();
127   assert(int_type != nullptr);
128   const auto width = int_type->width();
129   assert(width <= 64);
130 
131   int64_t value = 0;
132   if (const IntConstant* ic = AsIntConstant()) {
133     if (width <= 32) {
134       // Let the C++ compiler do the sign extension.
135       value = int64_t(ic->GetS32BitValue());
136     } else {
137       value = ic->GetS64BitValue();
138     }
139   } else {
140     assert(AsNullConstant() && "Must be an integer constant.");
141   }
142   return value;
143 }
144 
ConstantManager(IRContext * ctx)145 ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
146   // Populate the constant table with values from constant declarations in the
147   // module.  The values of each OpConstant declaration is the identity
148   // assignment (i.e., each constant is its own value).
149   for (const auto& inst : ctx_->module()->GetConstants()) {
150     MapInst(inst);
151   }
152 }
153 
GetType(const Instruction * inst) const154 Type* ConstantManager::GetType(const Instruction* inst) const {
155   return context()->get_type_mgr()->GetType(inst->type_id());
156 }
157 
GetOperandConstants(const Instruction * inst) const158 std::vector<const Constant*> ConstantManager::GetOperandConstants(
159     const Instruction* inst) const {
160   std::vector<const Constant*> constants;
161   constants.reserve(inst->NumInOperands());
162   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
163     const Operand* operand = &inst->GetInOperand(i);
164     if (operand->type != SPV_OPERAND_TYPE_ID) {
165       constants.push_back(nullptr);
166     } else {
167       uint32_t id = operand->words[0];
168       const analysis::Constant* constant = FindDeclaredConstant(id);
169       constants.push_back(constant);
170     }
171   }
172   return constants;
173 }
174 
FindDeclaredConstant(const Constant * c,uint32_t type_id) const175 uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
176                                                uint32_t type_id) const {
177   c = FindConstant(c);
178   if (c == nullptr) {
179     return 0;
180   }
181 
182   for (auto range = const_val_to_id_.equal_range(c);
183        range.first != range.second; ++range.first) {
184     Instruction* const_def =
185         context()->get_def_use_mgr()->GetDef(range.first->second);
186     if (type_id == 0 || const_def->type_id() == type_id) {
187       return range.first->second;
188     }
189   }
190   return 0;
191 }
192 
GetConstantsFromIds(const std::vector<uint32_t> & ids) const193 std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
194     const std::vector<uint32_t>& ids) const {
195   std::vector<const Constant*> constants;
196   for (uint32_t id : ids) {
197     if (const Constant* c = FindDeclaredConstant(id)) {
198       constants.push_back(c);
199     } else {
200       return {};
201     }
202   }
203   return constants;
204 }
205 
BuildInstructionAndAddToModule(const Constant * new_const,Module::inst_iterator * pos,uint32_t type_id)206 Instruction* ConstantManager::BuildInstructionAndAddToModule(
207     const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
208   // TODO(1841): Handle id overflow.
209   uint32_t new_id = context()->TakeNextId();
210   if (new_id == 0) {
211     return nullptr;
212   }
213 
214   auto new_inst = CreateInstruction(new_id, new_const, type_id);
215   if (!new_inst) {
216     return nullptr;
217   }
218   auto* new_inst_ptr = new_inst.get();
219   *pos = pos->InsertBefore(std::move(new_inst));
220   ++(*pos);
221   if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse))
222     context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
223   MapConstantToInst(new_const, new_inst_ptr);
224   return new_inst_ptr;
225 }
226 
GetDefiningInstruction(const Constant * c,uint32_t type_id,Module::inst_iterator * pos)227 Instruction* ConstantManager::GetDefiningInstruction(
228     const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
229   uint32_t decl_id = FindDeclaredConstant(c, type_id);
230   if (decl_id == 0) {
231     auto iter = context()->types_values_end();
232     if (pos == nullptr) pos = &iter;
233     return BuildInstructionAndAddToModule(c, pos, type_id);
234   } else {
235     auto def = context()->get_def_use_mgr()->GetDef(decl_id);
236     assert(def != nullptr);
237     assert((type_id == 0 || def->type_id() == type_id) &&
238            "This constant already has an instruction with a different type.");
239     return def;
240   }
241 }
242 
CreateConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids) const243 std::unique_ptr<Constant> ConstantManager::CreateConstant(
244     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
245   if (literal_words_or_ids.size() == 0) {
246     // Constant declared with OpConstantNull
247     return MakeUnique<NullConstant>(type);
248   } else if (auto* bt = type->AsBool()) {
249     assert(literal_words_or_ids.size() == 1 &&
250            "Bool constant should be declared with one operand");
251     return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
252   } else if (auto* it = type->AsInteger()) {
253     return MakeUnique<IntConstant>(it, literal_words_or_ids);
254   } else if (auto* ft = type->AsFloat()) {
255     return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
256   } else if (auto* vt = type->AsVector()) {
257     auto components = GetConstantsFromIds(literal_words_or_ids);
258     if (components.empty()) return nullptr;
259     // All components of VectorConstant must be of type Bool, Integer or Float.
260     if (!std::all_of(components.begin(), components.end(),
261                      [](const Constant* c) {
262                        if (c->type()->AsBool() || c->type()->AsInteger() ||
263                            c->type()->AsFloat()) {
264                          return true;
265                        } else {
266                          return false;
267                        }
268                      }))
269       return nullptr;
270     // All components of VectorConstant must be in the same type.
271     const auto* component_type = components.front()->type();
272     if (!std::all_of(components.begin(), components.end(),
273                      [&component_type](const Constant* c) {
274                        if (c->type() == component_type) return true;
275                        return false;
276                      }))
277       return nullptr;
278     return MakeUnique<VectorConstant>(vt, components);
279   } else if (auto* mt = type->AsMatrix()) {
280     auto components = GetConstantsFromIds(literal_words_or_ids);
281     if (components.empty()) return nullptr;
282     return MakeUnique<MatrixConstant>(mt, components);
283   } else if (auto* st = type->AsStruct()) {
284     auto components = GetConstantsFromIds(literal_words_or_ids);
285     if (components.empty()) return nullptr;
286     return MakeUnique<StructConstant>(st, components);
287   } else if (auto* at = type->AsArray()) {
288     auto components = GetConstantsFromIds(literal_words_or_ids);
289     if (components.empty()) return nullptr;
290     return MakeUnique<ArrayConstant>(at, components);
291   } else {
292     return nullptr;
293   }
294 }
295 
GetConstantFromInst(const Instruction * inst)296 const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
297   std::vector<uint32_t> literal_words_or_ids;
298 
299   // Collect the constant defining literals or component ids.
300   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
301     literal_words_or_ids.insert(literal_words_or_ids.end(),
302                                 inst->GetInOperand(i).words.begin(),
303                                 inst->GetInOperand(i).words.end());
304   }
305 
306   switch (inst->opcode()) {
307     // OpConstant{True|False} have the value embedded in the opcode. So they
308     // are not handled by the for-loop above. Here we add the value explicitly.
309     case SpvOp::SpvOpConstantTrue:
310       literal_words_or_ids.push_back(true);
311       break;
312     case SpvOp::SpvOpConstantFalse:
313       literal_words_or_ids.push_back(false);
314       break;
315     case SpvOp::SpvOpConstantNull:
316     case SpvOp::SpvOpConstant:
317     case SpvOp::SpvOpConstantComposite:
318     case SpvOp::SpvOpSpecConstantComposite:
319       break;
320     default:
321       return nullptr;
322   }
323 
324   return GetConstant(GetType(inst), literal_words_or_ids);
325 }
326 
CreateInstruction(uint32_t id,const Constant * c,uint32_t type_id) const327 std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
328     uint32_t id, const Constant* c, uint32_t type_id) const {
329   uint32_t type =
330       (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
331   if (c->AsNullConstant()) {
332     return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type,
333                                    id, std::initializer_list<Operand>{});
334   } else if (const BoolConstant* bc = c->AsBoolConstant()) {
335     return MakeUnique<Instruction>(
336         context(),
337         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
338         type, id, std::initializer_list<Operand>{});
339   } else if (const IntConstant* ic = c->AsIntConstant()) {
340     return MakeUnique<Instruction>(
341         context(), SpvOp::SpvOpConstant, type, id,
342         std::initializer_list<Operand>{
343             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
344                     ic->words())});
345   } else if (const FloatConstant* fc = c->AsFloatConstant()) {
346     return MakeUnique<Instruction>(
347         context(), SpvOp::SpvOpConstant, type, id,
348         std::initializer_list<Operand>{
349             Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
350                     fc->words())});
351   } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
352     return CreateCompositeInstruction(id, cc, type_id);
353   } else {
354     return nullptr;
355   }
356 }
357 
CreateCompositeInstruction(uint32_t result_id,const CompositeConstant * cc,uint32_t type_id) const358 std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
359     uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
360   std::vector<Operand> operands;
361   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
362   uint32_t component_index = 0;
363   for (const Constant* component_const : cc->GetComponents()) {
364     uint32_t component_type_id = 0;
365     if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
366       component_type_id = type_inst->GetSingleWordInOperand(component_index);
367     } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
368       component_type_id = type_inst->GetSingleWordInOperand(0);
369     }
370     uint32_t id = FindDeclaredConstant(component_const, component_type_id);
371 
372     if (id == 0) {
373       // Cannot get the id of the component constant, while all components
374       // should have been added to the module prior to the composite constant.
375       // Cannot create OpConstantComposite instruction in this case.
376       return nullptr;
377     }
378     operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
379                           std::initializer_list<uint32_t>{id});
380     component_index++;
381   }
382   uint32_t type =
383       (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
384   return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type,
385                                  result_id, std::move(operands));
386 }
387 
GetConstant(const Type * type,const std::vector<uint32_t> & literal_words_or_ids)388 const Constant* ConstantManager::GetConstant(
389     const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
390   auto cst = CreateConstant(type, literal_words_or_ids);
391   return cst ? RegisterConstant(std::move(cst)) : nullptr;
392 }
393 
GetNumericVectorConstantWithWords(const Vector * type,const std::vector<uint32_t> & literal_words)394 const Constant* ConstantManager::GetNumericVectorConstantWithWords(
395     const Vector* type, const std::vector<uint32_t>& literal_words) {
396   const auto* element_type = type->element_type();
397   uint32_t words_per_element = 0;
398   if (const auto* float_type = element_type->AsFloat())
399     words_per_element = float_type->width() / 32;
400   else if (const auto* int_type = element_type->AsInteger())
401     words_per_element = int_type->width() / 32;
402 
403   if (words_per_element != 1 && words_per_element != 2) return nullptr;
404 
405   if (words_per_element * type->element_count() !=
406       static_cast<uint32_t>(literal_words.size())) {
407     return nullptr;
408   }
409 
410   std::vector<uint32_t> element_ids;
411   for (uint32_t i = 0; i < type->element_count(); ++i) {
412     auto first_word = literal_words.begin() + (words_per_element * i);
413     std::vector<uint32_t> const_data(first_word,
414                                      first_word + words_per_element);
415     const analysis::Constant* element_constant =
416         GetConstant(element_type, const_data);
417     auto element_id = GetDefiningInstruction(element_constant)->result_id();
418     element_ids.push_back(element_id);
419   }
420 
421   return GetConstant(type, element_ids);
422 }
423 
GetFloatConstId(float val)424 uint32_t ConstantManager::GetFloatConstId(float val) {
425   const Constant* c = GetFloatConst(val);
426   return GetDefiningInstruction(c)->result_id();
427 }
428 
GetFloatConst(float val)429 const Constant* ConstantManager::GetFloatConst(float val) {
430   Type* float_type = context()->get_type_mgr()->GetFloatType();
431   utils::FloatProxy<float> v(val);
432   const Constant* c = GetConstant(float_type, v.GetWords());
433   return c;
434 }
435 
GetDoubleConstId(double val)436 uint32_t ConstantManager::GetDoubleConstId(double val) {
437   const Constant* c = GetDoubleConst(val);
438   return GetDefiningInstruction(c)->result_id();
439 }
440 
GetDoubleConst(double val)441 const Constant* ConstantManager::GetDoubleConst(double val) {
442   Type* float_type = context()->get_type_mgr()->GetDoubleType();
443   utils::FloatProxy<double> v(val);
444   const Constant* c = GetConstant(float_type, v.GetWords());
445   return c;
446 }
447 
GetSIntConst(int32_t val)448 uint32_t ConstantManager::GetSIntConst(int32_t val) {
449   Type* sint_type = context()->get_type_mgr()->GetSIntType();
450   const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
451   return GetDefiningInstruction(c)->result_id();
452 }
453 
GetUIntConst(uint32_t val)454 uint32_t ConstantManager::GetUIntConst(uint32_t val) {
455   Type* uint_type = context()->get_type_mgr()->GetUIntType();
456   const Constant* c = GetConstant(uint_type, {val});
457   return GetDefiningInstruction(c)->result_id();
458 }
459 
GetVectorComponents(analysis::ConstantManager * const_mgr) const460 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
461     analysis::ConstantManager* const_mgr) const {
462   std::vector<const analysis::Constant*> components;
463   const analysis::VectorConstant* a = this->AsVectorConstant();
464   const analysis::Vector* vector_type = this->type()->AsVector();
465   assert(vector_type != nullptr);
466   if (a != nullptr) {
467     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
468       components.push_back(a->GetComponents()[i]);
469     }
470   } else {
471     const analysis::Type* element_type = vector_type->element_type();
472     const analysis::Constant* element_null_const =
473         const_mgr->GetConstant(element_type, {});
474     for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
475       components.push_back(element_null_const);
476     }
477   }
478   return components;
479 }
480 
481 }  // namespace analysis
482 }  // namespace opt
483 }  // namespace spvtools
484