• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_analysis.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <string>
20 #include <utility>
21 
22 #include "source/opt/ir_context.h"
23 
24 // Transforms a given scalar operation instruction into a DAG representation.
25 //
26 // 1. Take an instruction and traverse its operands until we reach a
27 // constant node or an instruction which we do not know how to compute the
28 // value, such as a load.
29 //
30 // 2. Create a new node for each instruction traversed and build the nodes for
31 // the in operands of that instruction as well.
32 //
33 // 3. Add the operand nodes as children of the first and hash the node. Use the
34 // hash to see if the node is already in the cache. We ensure the children are
35 // always in sorted order so that two nodes with the same children but inserted
36 // in a different order have the same hash and so that the overloaded operator==
37 // will return true. If the node is already in the cache return the cached
38 // version instead.
39 //
40 // 4. The created DAG can then be simplified by
41 // ScalarAnalysis::SimplifyExpression, implemented in
42 // scalar_analysis_simplification.cpp. See that file for further information on
43 // the simplification process.
44 //
45 
46 namespace spvtools {
47 namespace opt {
48 
49 uint32_t SENode::NumberOfNodes = 0;
50 
ScalarEvolutionAnalysis(IRContext * context)51 ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
52     : context_(context), pretend_equal_{} {
53   // Create and cached the CantComputeNode.
54   cached_cant_compute_ =
55       GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
56 }
57 
CreateNegation(SENode * operand)58 SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
59   // If operand is can't compute then the whole graph is can't compute.
60   if (operand->IsCantCompute()) return CreateCantComputeNode();
61 
62   if (operand->GetType() == SENode::Constant) {
63     return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
64   }
65   std::unique_ptr<SENode> negation_node{new SENegative(this)};
66   negation_node->AddChild(operand);
67   return GetCachedOrAdd(std::move(negation_node));
68 }
69 
CreateConstant(int64_t integer)70 SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
71   return GetCachedOrAdd(
72       std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
73 }
74 
CreateRecurrentExpression(const Loop * loop,SENode * offset,SENode * coefficient)75 SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
76     const Loop* loop, SENode* offset, SENode* coefficient) {
77   assert(loop && "Recurrent add expressions must have a valid loop.");
78 
79   // If operands are can't compute then the whole graph is can't compute.
80   if (offset->IsCantCompute() || coefficient->IsCantCompute())
81     return CreateCantComputeNode();
82 
83   const Loop* loop_to_use = nullptr;
84   if (pretend_equal_[loop]) {
85     loop_to_use = pretend_equal_[loop];
86   } else {
87     loop_to_use = loop;
88   }
89 
90   std::unique_ptr<SERecurrentNode> phi_node{
91       new SERecurrentNode(this, loop_to_use)};
92   phi_node->AddOffset(offset);
93   phi_node->AddCoefficient(coefficient);
94 
95   return GetCachedOrAdd(std::move(phi_node));
96 }
97 
AnalyzeMultiplyOp(const Instruction * multiply)98 SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
99     const Instruction* multiply) {
100   assert(multiply->opcode() == SpvOp::SpvOpIMul &&
101          "Multiply node did not come from a multiply instruction");
102   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
103 
104   SENode* op1 =
105       AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
106   SENode* op2 =
107       AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
108 
109   return CreateMultiplyNode(op1, op2);
110 }
111 
CreateMultiplyNode(SENode * operand_1,SENode * operand_2)112 SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
113                                                     SENode* operand_2) {
114   // If operands are can't compute then the whole graph is can't compute.
115   if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
116     return CreateCantComputeNode();
117 
118   if (operand_1->GetType() == SENode::Constant &&
119       operand_2->GetType() == SENode::Constant) {
120     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
121                           operand_2->AsSEConstantNode()->FoldToSingleValue());
122   }
123 
124   std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
125 
126   multiply_node->AddChild(operand_1);
127   multiply_node->AddChild(operand_2);
128 
129   return GetCachedOrAdd(std::move(multiply_node));
130 }
131 
CreateSubtraction(SENode * operand_1,SENode * operand_2)132 SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
133                                                    SENode* operand_2) {
134   // Fold if both operands are constant.
135   if (operand_1->GetType() == SENode::Constant &&
136       operand_2->GetType() == SENode::Constant) {
137     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
138                           operand_2->AsSEConstantNode()->FoldToSingleValue());
139   }
140 
141   return CreateAddNode(operand_1, CreateNegation(operand_2));
142 }
143 
CreateAddNode(SENode * operand_1,SENode * operand_2)144 SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
145                                                SENode* operand_2) {
146   // Fold if both operands are constant and the |simplify| flag is true.
147   if (operand_1->GetType() == SENode::Constant &&
148       operand_2->GetType() == SENode::Constant) {
149     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
150                           operand_2->AsSEConstantNode()->FoldToSingleValue());
151   }
152 
153   // If operands are can't compute then the whole graph is can't compute.
154   if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
155     return CreateCantComputeNode();
156 
157   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
158 
159   add_node->AddChild(operand_1);
160   add_node->AddChild(operand_2);
161 
162   return GetCachedOrAdd(std::move(add_node));
163 }
164 
AnalyzeInstruction(const Instruction * inst)165 SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
166   auto itr = recurrent_node_map_.find(inst);
167   if (itr != recurrent_node_map_.end()) return itr->second;
168 
169   SENode* output = nullptr;
170   switch (inst->opcode()) {
171     case SpvOp::SpvOpPhi: {
172       output = AnalyzePhiInstruction(inst);
173       break;
174     }
175     case SpvOp::SpvOpConstant:
176     case SpvOp::SpvOpConstantNull: {
177       output = AnalyzeConstant(inst);
178       break;
179     }
180     case SpvOp::SpvOpISub:
181     case SpvOp::SpvOpIAdd: {
182       output = AnalyzeAddOp(inst);
183       break;
184     }
185     case SpvOp::SpvOpIMul: {
186       output = AnalyzeMultiplyOp(inst);
187       break;
188     }
189     default: {
190       output = CreateValueUnknownNode(inst);
191       break;
192     }
193   }
194 
195   return output;
196 }
197 
AnalyzeConstant(const Instruction * inst)198 SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
199   if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);
200 
201   assert(inst->opcode() == SpvOp::SpvOpConstant);
202   assert(inst->NumInOperands() == 1);
203   int64_t value = 0;
204 
205   // Look up the instruction in the constant manager.
206   const analysis::Constant* constant =
207       context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
208 
209   if (!constant) return CreateCantComputeNode();
210 
211   const analysis::IntConstant* int_constant = constant->AsIntConstant();
212 
213   // Exit out if it is a 64 bit integer.
214   if (!int_constant || int_constant->words().size() != 1)
215     return CreateCantComputeNode();
216 
217   if (int_constant->type()->AsInteger()->IsSigned()) {
218     value = int_constant->GetS32BitValue();
219   } else {
220     value = int_constant->GetU32BitValue();
221   }
222 
223   return CreateConstant(value);
224 }
225 
226 // Handles both addition and subtraction. If the |sub| flag is set then the
227 // addition will be op1+(-op2) otherwise op1+op2.
AnalyzeAddOp(const Instruction * inst)228 SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
229   assert((inst->opcode() == SpvOp::SpvOpIAdd ||
230           inst->opcode() == SpvOp::SpvOpISub) &&
231          "Add node must be created from a OpIAdd or OpISub instruction");
232 
233   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
234 
235   SENode* op1 =
236       AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
237 
238   SENode* op2 =
239       AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
240 
241   // To handle subtraction we wrap the second operand in a unary negation node.
242   if (inst->opcode() == SpvOp::SpvOpISub) {
243     op2 = CreateNegation(op2);
244   }
245 
246   return CreateAddNode(op1, op2);
247 }
248 
AnalyzePhiInstruction(const Instruction * phi)249 SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
250   // The phi should only have two incoming value pairs.
251   if (phi->NumInOperands() != 4) {
252     return CreateCantComputeNode();
253   }
254 
255   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
256 
257   // Get the basic block this instruction belongs to.
258   BasicBlock* basic_block =
259       context_->get_instr_block(const_cast<Instruction*>(phi));
260 
261   // And then the function that the basic blocks belongs to.
262   Function* function = basic_block->GetParent();
263 
264   // Use the function to get the loop descriptor.
265   LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
266 
267   // We only handle phis in loops at the moment.
268   if (!loop_descriptor) return CreateCantComputeNode();
269 
270   // Get the innermost loop which this block belongs to.
271   Loop* loop = (*loop_descriptor)[basic_block->id()];
272 
273   // If the loop doesn't exist or doesn't have a preheader or latch block, exit
274   // out.
275   if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
276       loop->GetHeaderBlock() != basic_block)
277     return recurrent_node_map_[phi] = CreateCantComputeNode();
278 
279   const Loop* loop_to_use = nullptr;
280   if (pretend_equal_[loop]) {
281     loop_to_use = pretend_equal_[loop];
282   } else {
283     loop_to_use = loop;
284   }
285   std::unique_ptr<SERecurrentNode> phi_node{
286       new SERecurrentNode(this, loop_to_use)};
287 
288   // We add the node to this map to allow it to be returned before the node is
289   // fully built. This is needed as the subsequent call to AnalyzeInstruction
290   // could lead back to this |phi| instruction so we return the pointer
291   // immediately in AnalyzeInstruction to break the recursion.
292   recurrent_node_map_[phi] = phi_node.get();
293 
294   // Traverse the operands of the instruction an create new nodes for each one.
295   for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
296     uint32_t value_id = phi->GetSingleWordInOperand(i);
297     uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
298 
299     Instruction* value_inst = def_use->GetDef(value_id);
300     SENode* value_node = AnalyzeInstruction(value_inst);
301 
302     // If any operand is CantCompute then the whole graph is CantCompute.
303     if (value_node->IsCantCompute())
304       return recurrent_node_map_[phi] = CreateCantComputeNode();
305 
306     // If the value is coming from the preheader block then the value is the
307     // initial value of the phi.
308     if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
309       phi_node->AddOffset(value_node);
310     } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
311       // Assumed to be in the form of step + phi.
312       if (value_node->GetType() != SENode::Add)
313         return recurrent_node_map_[phi] = CreateCantComputeNode();
314 
315       SENode* step_node = nullptr;
316       SENode* phi_operand = nullptr;
317       SENode* operand_1 = value_node->GetChild(0);
318       SENode* operand_2 = value_node->GetChild(1);
319 
320       // Find which node is the step term.
321       if (!operand_1->AsSERecurrentNode())
322         step_node = operand_1;
323       else if (!operand_2->AsSERecurrentNode())
324         step_node = operand_2;
325 
326       // Find which node is the recurrent expression.
327       if (operand_1->AsSERecurrentNode())
328         phi_operand = operand_1;
329       else if (operand_2->AsSERecurrentNode())
330         phi_operand = operand_2;
331 
332       // If it is not in the form step + phi exit out.
333       if (!(step_node && phi_operand))
334         return recurrent_node_map_[phi] = CreateCantComputeNode();
335 
336       // If the phi operand is not the same phi node exit out.
337       if (phi_operand != phi_node.get())
338         return recurrent_node_map_[phi] = CreateCantComputeNode();
339 
340       if (!IsLoopInvariant(loop, step_node))
341         return recurrent_node_map_[phi] = CreateCantComputeNode();
342 
343       phi_node->AddCoefficient(step_node);
344     }
345   }
346 
347   // Once the node is fully built we update the map with the version from the
348   // cache (if it has already been added to the cache).
349   return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
350 }
351 
CreateValueUnknownNode(const Instruction * inst)352 SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
353     const Instruction* inst) {
354   std::unique_ptr<SEValueUnknown> load_node{
355       new SEValueUnknown(this, inst->result_id())};
356   return GetCachedOrAdd(std::move(load_node));
357 }
358 
CreateCantComputeNode()359 SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
360   return cached_cant_compute_;
361 }
362 
363 // Add the created node into the cache of nodes. If it already exists return it.
GetCachedOrAdd(std::unique_ptr<SENode> prospective_node)364 SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
365     std::unique_ptr<SENode> prospective_node) {
366   auto itr = node_cache_.find(prospective_node);
367   if (itr != node_cache_.end()) {
368     return (*itr).get();
369   }
370 
371   SENode* raw_ptr_to_node = prospective_node.get();
372   node_cache_.insert(std::move(prospective_node));
373   return raw_ptr_to_node;
374 }
375 
IsLoopInvariant(const Loop * loop,const SENode * node) const376 bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
377                                               const SENode* node) const {
378   for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
379     if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
380       const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
381 
382       // If the loop which the recurrent expression belongs to is either |loop
383       // or a nested loop inside |loop| then we assume it is variant.
384       if (loop->IsInsideLoop(header)) {
385         return false;
386       }
387     } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
388       // If the instruction is inside the loop we conservatively assume it is
389       // loop variant.
390       if (loop->IsInsideLoop(unknown->ResultId())) return false;
391     }
392   }
393 
394   return true;
395 }
396 
GetCoefficientFromRecurrentTerm(SENode * node,const Loop * loop)397 SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
398     SENode* node, const Loop* loop) {
399   // Traverse the DAG to find the recurrent expression belonging to |loop|.
400   for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
401     SERecurrentNode* rec = itr->AsSERecurrentNode();
402     if (rec && rec->GetLoop() == loop) {
403       return rec->GetCoefficient();
404     }
405   }
406   return CreateConstant(0);
407 }
408 
UpdateChildNode(SENode * parent,SENode * old_child,SENode * new_child)409 SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
410                                                  SENode* old_child,
411                                                  SENode* new_child) {
412   // Only handles add.
413   if (parent->GetType() != SENode::Add) return parent;
414 
415   std::vector<SENode*> new_children;
416   for (SENode* child : *parent) {
417     if (child == old_child) {
418       new_children.push_back(new_child);
419     } else {
420       new_children.push_back(child);
421     }
422   }
423 
424   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
425   for (SENode* child : new_children) {
426     add_node->AddChild(child);
427   }
428 
429   return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
430 }
431 
432 // Rebuild the |node| eliminating, if it exists, the recurrent term which
433 // belongs to the |loop|.
BuildGraphWithoutRecurrentTerm(SENode * node,const Loop * loop)434 SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
435     SENode* node, const Loop* loop) {
436   // If the node is already a recurrent expression belonging to loop then just
437   // return the offset.
438   SERecurrentNode* recurrent = node->AsSERecurrentNode();
439   if (recurrent) {
440     if (recurrent->GetLoop() == loop) {
441       return recurrent->GetOffset();
442     } else {
443       return node;
444     }
445   }
446 
447   std::vector<SENode*> new_children;
448   // Otherwise find the recurrent node in the children of this node.
449   for (auto itr : *node) {
450     recurrent = itr->AsSERecurrentNode();
451     if (recurrent && recurrent->GetLoop() == loop) {
452       new_children.push_back(recurrent->GetOffset());
453     } else {
454       new_children.push_back(itr);
455     }
456   }
457 
458   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
459   for (SENode* child : new_children) {
460     add_node->AddChild(child);
461   }
462 
463   return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
464 }
465 
466 // Return the recurrent term belonging to |loop| if it appears in the graph
467 // starting at |node| or null if it doesn't.
GetRecurrentTerm(SENode * node,const Loop * loop)468 SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
469                                                            const Loop* loop) {
470   for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
471     SERecurrentNode* rec = itr->AsSERecurrentNode();
472     if (rec && rec->GetLoop() == loop) {
473       return rec;
474     }
475   }
476   return nullptr;
477 }
AsString() const478 std::string SENode::AsString() const {
479   switch (GetType()) {
480     case Constant:
481       return "Constant";
482     case RecurrentAddExpr:
483       return "RecurrentAddExpr";
484     case Add:
485       return "Add";
486     case Negative:
487       return "Negative";
488     case Multiply:
489       return "Multiply";
490     case ValueUnknown:
491       return "Value Unknown";
492     case CanNotCompute:
493       return "Can not compute";
494   }
495   return "NULL";
496 }
497 
operator ==(const SENode & other) const498 bool SENode::operator==(const SENode& other) const {
499   if (GetType() != other.GetType()) return false;
500 
501   if (other.GetChildren().size() != children_.size()) return false;
502 
503   const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
504 
505   // Check the children are the same, for SERecurrentNodes we need to check the
506   // offset and coefficient manually as the child vector is sorted by ids so the
507   // offset/coefficient information is lost.
508   if (!this_as_recurrent) {
509     for (size_t index = 0; index < children_.size(); ++index) {
510       if (other.GetChildren()[index] != children_[index]) return false;
511     }
512   } else {
513     const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
514 
515     // We've already checked the types are the same, this should not fail if
516     // this->AsSERecurrentNode() succeeded.
517     assert(other_as_recurrent);
518 
519     if (this_as_recurrent->GetCoefficient() !=
520         other_as_recurrent->GetCoefficient())
521       return false;
522 
523     if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
524       return false;
525 
526     if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
527       return false;
528   }
529 
530   // If we're dealing with a value unknown node check both nodes were created by
531   // the same instruction.
532   if (GetType() == SENode::ValueUnknown) {
533     if (AsSEValueUnknown()->ResultId() !=
534         other.AsSEValueUnknown()->ResultId()) {
535       return false;
536     }
537   }
538 
539   if (AsSEConstantNode()) {
540     if (AsSEConstantNode()->FoldToSingleValue() !=
541         other.AsSEConstantNode()->FoldToSingleValue())
542       return false;
543   }
544 
545   return true;
546 }
547 
operator !=(const SENode & other) const548 bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
549 
namespace {
// Helpers that append a 4 or 8 byte value to the 32 bit hash string. They let
// pointers be added to the string once reinterpreted as uintptr_t.
// PushToString deduces the value type and uses its size to select the matching
// PushToStringImpl specialisation.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8 byte values are appended as two code units, upper 32 bits first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T value, std::u32string* out) const {
    const uint32_t upper_half = static_cast<uint32_t>(value >> 32);
    const uint32_t lower_half = static_cast<uint32_t>(value);
    out->push_back(upper_half);
    out->push_back(lower_half);
  }
};

// 4 byte values are appended as a single code unit.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T value, std::u32string* out) const {
    const uint32_t word = static_cast<uint32_t>(value);
    out->push_back(word);
  }
};

// Appends |id| to |str|, dispatching on sizeof(T).
template <typename T>
static void PushToString(T id, std::u32string* str) {
  const PushToStringImpl<T, sizeof(T)> pusher{};
  pusher(id, str);
}

}  // namespace
581 
582 // Implements the hashing of SENodes.
operator ()(const SENode * node) const583 size_t SENodeHash::operator()(const SENode* node) const {
584   // Concatinate the terms into a string which we can hash.
585   std::u32string hash_string{};
586 
587   // Hashing the type as a string is safer than hashing the enum as the enum is
588   // very likely to collide with constants.
589   for (char ch : node->AsString()) {
590     hash_string.push_back(static_cast<char32_t>(ch));
591   }
592 
593   // We just ignore the literal value unless it is a constant.
594   if (node->GetType() == SENode::Constant)
595     PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
596 
597   const SERecurrentNode* recurrent = node->AsSERecurrentNode();
598 
599   // If we're dealing with a recurrent expression hash the loop as well so that
600   // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
601   if (recurrent) {
602     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
603                  &hash_string);
604 
605     // Recurrent expressions can't be hashed using the normal method as the
606     // order of coefficient and offset matters to the hash.
607     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
608                  &hash_string);
609     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
610                  &hash_string);
611 
612     return std::hash<std::u32string>{}(hash_string);
613   }
614 
615   // Hash the result id of the original instruction which created this node if
616   // it is a value unknown node.
617   if (node->GetType() == SENode::ValueUnknown) {
618     PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
619   }
620 
621   // Hash the pointers of the child nodes, each SENode has a unique pointer
622   // associated with it.
623   const std::vector<SENode*>& children = node->GetChildren();
624   for (const SENode* child : children) {
625     PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
626   }
627 
628   return std::hash<std::u32string>{}(hash_string);
629 }
630 
631 // This overload is the actual overload used by the node_cache_ set.
operator ()(const std::unique_ptr<SENode> & node) const632 size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
633   return this->operator()(node.get());
634 }
635 
DumpDot(std::ostream & out,bool recurse) const636 void SENode::DumpDot(std::ostream& out, bool recurse) const {
637   size_t unique_id = std::hash<const SENode*>{}(this);
638   out << unique_id << " [label=\"" << AsString() << " ";
639   if (GetType() == SENode::Constant) {
640     out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
641   }
642   out << "\"]\n";
643   for (const SENode* child : children_) {
644     size_t child_unique_id = std::hash<const SENode*>{}(child);
645     out << unique_id << " -> " << child_unique_id << " \n";
646     if (recurse) child->DumpDot(out, true);
647   }
648 }
649 
650 namespace {
651 class IsGreaterThanZero {
652  public:
IsGreaterThanZero(IRContext * context)653   explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
654 
655   // Determine if the value of |node| is always strictly greater than zero if
656   // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
657   // true. It returns true is the evaluation was able to conclude something, in
658   // which case the result is stored in |result|.
659   // The algorithm work by going through all the nodes and determine the
660   // sign of each of them.
Eval(const SENode * node,bool or_equal_zero,bool * result)661   bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
662     *result = false;
663     switch (Visit(node)) {
664       case Signedness::kPositiveOrNegative: {
665         return false;
666       }
667       case Signedness::kStrictlyNegative: {
668         *result = false;
669         break;
670       }
671       case Signedness::kNegative: {
672         if (!or_equal_zero) {
673           return false;
674         }
675         *result = false;
676         break;
677       }
678       case Signedness::kStrictlyPositive: {
679         *result = true;
680         break;
681       }
682       case Signedness::kPositive: {
683         if (!or_equal_zero) {
684           return false;
685         }
686         *result = true;
687         break;
688       }
689     }
690     return true;
691   }
692 
693  private:
694   enum class Signedness {
695     kPositiveOrNegative,  // Yield a value positive or negative.
696     kStrictlyNegative,    // Yield a value strictly less than 0.
697     kNegative,            // Yield a value less or equal to 0.
698     kStrictlyPositive,    // Yield a value strictly greater than 0.
699     kPositive             // Yield a value greater or equal to 0.
700   };
701 
702   // Combine the signedness according to arithmetic rules of a given operator.
703   using Combiner = std::function<Signedness(Signedness, Signedness)>;
704 
705   // Returns a functor to interpret the signedness of 2 expressions as if they
706   // were added.
GetAddCombiner() const707   Combiner GetAddCombiner() const {
708     return [](Signedness lhs, Signedness rhs) {
709       switch (lhs) {
710         case Signedness::kPositiveOrNegative:
711           break;
712         case Signedness::kStrictlyNegative:
713           if (rhs == Signedness::kStrictlyNegative ||
714               rhs == Signedness::kNegative)
715             return lhs;
716           break;
717         case Signedness::kNegative: {
718           if (rhs == Signedness::kStrictlyNegative)
719             return Signedness::kStrictlyNegative;
720           if (rhs == Signedness::kNegative) return Signedness::kNegative;
721           break;
722         }
723         case Signedness::kStrictlyPositive: {
724           if (rhs == Signedness::kStrictlyPositive ||
725               rhs == Signedness::kPositive) {
726             return Signedness::kStrictlyPositive;
727           }
728           break;
729         }
730         case Signedness::kPositive: {
731           if (rhs == Signedness::kStrictlyPositive)
732             return Signedness::kStrictlyPositive;
733           if (rhs == Signedness::kPositive) return Signedness::kPositive;
734           break;
735         }
736       }
737       return Signedness::kPositiveOrNegative;
738     };
739   }
740 
741   // Returns a functor to interpret the signedness of 2 expressions as if they
742   // were multiplied.
GetMulCombiner() const743   Combiner GetMulCombiner() const {
744     return [](Signedness lhs, Signedness rhs) {
745       switch (lhs) {
746         case Signedness::kPositiveOrNegative:
747           break;
748         case Signedness::kStrictlyNegative: {
749           switch (rhs) {
750             case Signedness::kPositiveOrNegative: {
751               break;
752             }
753             case Signedness::kStrictlyNegative: {
754               return Signedness::kStrictlyPositive;
755             }
756             case Signedness::kNegative: {
757               return Signedness::kPositive;
758             }
759             case Signedness::kStrictlyPositive: {
760               return Signedness::kStrictlyNegative;
761             }
762             case Signedness::kPositive: {
763               return Signedness::kNegative;
764             }
765           }
766           break;
767         }
768         case Signedness::kNegative: {
769           switch (rhs) {
770             case Signedness::kPositiveOrNegative: {
771               break;
772             }
773             case Signedness::kStrictlyNegative:
774             case Signedness::kNegative: {
775               return Signedness::kPositive;
776             }
777             case Signedness::kStrictlyPositive:
778             case Signedness::kPositive: {
779               return Signedness::kNegative;
780             }
781           }
782           break;
783         }
784         case Signedness::kStrictlyPositive: {
785           return rhs;
786         }
787         case Signedness::kPositive: {
788           switch (rhs) {
789             case Signedness::kPositiveOrNegative: {
790               break;
791             }
792             case Signedness::kStrictlyNegative:
793             case Signedness::kNegative: {
794               return Signedness::kNegative;
795             }
796             case Signedness::kStrictlyPositive:
797             case Signedness::kPositive: {
798               return Signedness::kPositive;
799             }
800           }
801           break;
802         }
803       }
804       return Signedness::kPositiveOrNegative;
805     };
806   }
807 
Visit(const SENode * node)808   Signedness Visit(const SENode* node) {
809     switch (node->GetType()) {
810       case SENode::Constant:
811         return Visit(node->AsSEConstantNode());
812         break;
813       case SENode::RecurrentAddExpr:
814         return Visit(node->AsSERecurrentNode());
815         break;
816       case SENode::Negative:
817         return Visit(node->AsSENegative());
818         break;
819       case SENode::CanNotCompute:
820         return Visit(node->AsSECantCompute());
821         break;
822       case SENode::ValueUnknown:
823         return Visit(node->AsSEValueUnknown());
824         break;
825       case SENode::Add:
826         return VisitExpr(node, GetAddCombiner());
827         break;
828       case SENode::Multiply:
829         return VisitExpr(node, GetMulCombiner());
830         break;
831     }
832     return Signedness::kPositiveOrNegative;
833   }
834 
835   // Returns the signedness of a constant |node|.
Visit(const SEConstantNode * node)836   Signedness Visit(const SEConstantNode* node) {
837     if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
838     if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
839     if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
840     return Signedness::kPositiveOrNegative;
841   }
842 
843   // Returns the signedness of an unknown |node| based on its type.
Visit(const SEValueUnknown * node)844   Signedness Visit(const SEValueUnknown* node) {
845     Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
846     analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
847     assert(type && "Can't retrieve a type for the instruction");
848     analysis::Integer* int_type = type->AsInteger();
849     assert(type && "Can't retrieve an integer type for the instruction");
850     return int_type->IsSigned() ? Signedness::kPositiveOrNegative
851                                 : Signedness::kPositive;
852   }
853 
854   // Returns the signedness of a recurring expression.
Visit(const SERecurrentNode * node)855   Signedness Visit(const SERecurrentNode* node) {
856     Signedness coeff_sign = Visit(node->GetCoefficient());
857     // SERecurrentNode represent an affine expression in the range [0,
858     // loop_bound], so the result cannot be strictly positive or negative.
859     switch (coeff_sign) {
860       default:
861         break;
862       case Signedness::kStrictlyNegative:
863         coeff_sign = Signedness::kNegative;
864         break;
865       case Signedness::kStrictlyPositive:
866         coeff_sign = Signedness::kPositive;
867         break;
868     }
869     return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
870   }
871 
872   // Returns the signedness of a negation |node|.
Visit(const SENegative * node)873   Signedness Visit(const SENegative* node) {
874     switch (Visit(*node->begin())) {
875       case Signedness::kPositiveOrNegative: {
876         return Signedness::kPositiveOrNegative;
877       }
878       case Signedness::kStrictlyNegative: {
879         return Signedness::kStrictlyPositive;
880       }
881       case Signedness::kNegative: {
882         return Signedness::kPositive;
883       }
884       case Signedness::kStrictlyPositive: {
885         return Signedness::kStrictlyNegative;
886       }
887       case Signedness::kPositive: {
888         return Signedness::kNegative;
889       }
890     }
891     return Signedness::kPositiveOrNegative;
892   }
893 
Visit(const SECantCompute *)894   Signedness Visit(const SECantCompute*) {
895     return Signedness::kPositiveOrNegative;
896   }
897 
898   // Returns the signedness of a binary expression by using the combiner
899   // |reduce|.
VisitExpr(const SENode * node,std::function<Signedness (Signedness,Signedness)> reduce)900   Signedness VisitExpr(
901       const SENode* node,
902       std::function<Signedness(Signedness, Signedness)> reduce) {
903     Signedness result = Visit(*node->begin());
904     for (const SENode* operand : make_range(++node->begin(), node->end())) {
905       if (result == Signedness::kPositiveOrNegative) {
906         return Signedness::kPositiveOrNegative;
907       }
908       result = reduce(result, Visit(operand));
909     }
910     return result;
911   }
912 
913   IRContext* context_;
914 };
915 }  // namespace
916 
IsAlwaysGreaterThanZero(SENode * node,bool * is_gt_zero) const917 bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
918                                                       bool* is_gt_zero) const {
919   return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
920 }
921 
IsAlwaysGreaterOrEqualToZero(SENode * node,bool * is_ge_zero) const922 bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
923     SENode* node, bool* is_ge_zero) const {
924   return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
925 }
926 
927 namespace {
928 
929 // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
930 // if |node| is not in the chain, returns the original chain.
RemoveOneNodeFromMultiplyChain(SEMultiplyNode * mul,const SENode * node)931 static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
932                                               const SENode* node) {
933   SENode* lhs = mul->GetChildren()[0];
934   SENode* rhs = mul->GetChildren()[1];
935   if (lhs == node) {
936     return rhs;
937   }
938   if (rhs == node) {
939     return lhs;
940   }
941   if (lhs->AsSEMultiplyNode()) {
942     SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
943     if (res != lhs)
944       return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
945   }
946   if (rhs->AsSEMultiplyNode()) {
947     SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
948     if (res != rhs)
949       return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
950   }
951 
952   return mul;
953 }
954 }  // namespace
955 
operator /(SExpression rhs_wrapper) const956 std::pair<SExpression, int64_t> SExpression::operator/(
957     SExpression rhs_wrapper) const {
958   SENode* lhs = node_;
959   SENode* rhs = rhs_wrapper.node_;
960   // Check for division by 0.
961   if (rhs->AsSEConstantNode() &&
962       !rhs->AsSEConstantNode()->FoldToSingleValue()) {
963     return {scev_->CreateCantComputeNode(), 0};
964   }
965 
966   // Trivial case.
967   if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
968     int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
969     int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
970     return {scev_->CreateConstant(lhs_value / rhs_value),
971             lhs_value % rhs_value};
972   }
973 
974   // look for a "c U / U" pattern.
975   if (lhs->AsSEMultiplyNode()) {
976     assert(lhs->GetChildren().size() == 2 &&
977            "More than 2 operand for a multiply node.");
978     SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
979     if (res != lhs) {
980       return {res, 0};
981     }
982   }
983 
984   return {scev_->CreateCantComputeNode(), 0};
985 }
986 
987 }  // namespace opt
988 }  // namespace spvtools
989