• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <torch/csrc/jit/tensorexpr/registerizer.h>
2 #include <iostream>
3 
4 namespace torch::jit::tensorexpr {
5 namespace registerizer {
6 
7 // AccessInfo
8 
// Records a Store against this access: widens the enclosing block and the
// first/last usage span, and bumps the symbolic store cost by one.
void AccessInfo::addStore(
    const StorePtr& store,
    const std::shared_ptr<Scope>& scope) {
  // Widen block_ to the closest Block enclosing both the previous usages and
  // the scope of this new store.
  block_ =
      block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();

  // If there is already a usage and it's this store, that means the same
  // access is present in the RHS.
  // NOTE: this check must happen before first_usage_ is updated below.
  firstUsageOverlapped_ |= first_usage_ == store;
  // Re-root the first usage so it is a direct child of the (possibly widened)
  // enclosing block; this store becomes the last usage.
  first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store;
  last_usage_ = store;

  // Each store contributes 1 to the (simplified, symbolic) store cost.
  store_cost_ =
      IRSimplifier::simplify(alloc<Add>(store_cost_, immLike(store_cost_, 1)));
  stores_.push_back(store);

  // Adopt the conditional scope of this usage, and drop any stale hidden
  // (conditional) copy of this access.
  conditionId_ = scope->conditionId();
  hiddenAccess_.reset();
}
28 
// Records a Load against this access: widens the enclosing block and the
// first/last usage span, and bumps the symbolic load cost by one. `usage` is
// the enclosing Stmt in which the load occurs.
void AccessInfo::addLoad(
    const LoadPtr& load,
    const std::shared_ptr<Scope>& scope,
    const StmtPtr& usage) {
  // Widen block_ to the closest Block enclosing both the previous usages and
  // the scope of this new load.
  block_ =
      block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
  // Re-root the first usage under the (possibly widened) enclosing block; the
  // enclosing Stmt of this load becomes the last usage.
  first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
  last_usage_ = usage;

  // Each load contributes 1 to the (simplified, symbolic) load cost.
  load_cost_ =
      IRSimplifier::simplify(alloc<Add>(load_cost_, immLike(load_cost_, 1)));
  loads_.push_back(load);

  // Adopt the conditional scope of this usage, and drop any stale hidden
  // (conditional) copy of this access.
  conditionId_ = scope->conditionId();
  hiddenAccess_.reset();
}
45 
// Merges another (identical-hash) access into this one: concatenates its
// stores/loads, sums the costs, and widens the enclosing block and usage span
// to cover both accesses.
void AccessInfo::merge(const std::shared_ptr<AccessInfo>& other) {
  // Only accesses with the same hash (i.e. the same buf and indices) may be
  // merged.
  TORCH_INTERNAL_ASSERT(
      hash_ == other->hash(),
      buildErrorMessage(
          "Expected hashes to match in registerizer in the fuser."));
  TORCH_INTERNAL_ASSERT(
      indices_.size() == other->indices().size(),
      buildErrorMessage(
          "Expected ranks to match in registerizer in the fuser."));

  // `other` is assumed to occur after this access, so it supplies the new
  // last usage.
  last_usage_ = other->last_usage();
  for (const auto& s : other->stores()) {
    stores_.push_back(s);
  }
  for (const auto& l : other->loads()) {
    loads_.push_back(l);
  }

  // Costs are additive across the merged accesses.
  store_cost_ =
      IRSimplifier::simplify(alloc<Add>(store_cost_, other->store_cost()));
  load_cost_ =
      IRSimplifier::simplify(alloc<Add>(load_cost_, other->load_cost()));

  block_ = Block::getSharedParent(block_, other->block());
  // update first and last usage to be in the parent Block.
  first_usage_ = block_->getEnclosedRoot(first_usage_);
  last_usage_ = block_->getEnclosedRoot(last_usage_);
  hiddenAccess_.reset();
}
75 
overlaps(const std::shared_ptr<AccessInfo> & other)76 bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
77   // All accesses to a buf must have the same dimensionality.
78   TORCH_INTERNAL_ASSERT(
79       indices_.size() == other->indices().size(),
80       buildErrorMessage(
81           "Expected ranks to match in registerizer in the fuser."));
82 
83   auto& other_indices = other->indices();
84 
85   // They don't overlap if there is a guaranteed difference in any
86   // dimension.
87   bool overlap = true;
88   for (size_t i = 0; i < indices_.size(); ++i) {
89     ExprPtr diff = alloc<Sub>(indices_[i], other_indices[i]);
90     diff = IRSimplifier::simplify(diff);
91 
92     if (diff->isConstant() && !immediateEquals(diff, 0)) {
93       overlap = false;
94       break;
95     }
96   }
97 
98   return overlap;
99 }
100 
dependsOnVar(const VarPtr & v)101 bool AccessInfo::dependsOnVar(const VarPtr& v) {
102   VarFinder vf;
103   for (const auto& i : indices_) {
104     i->accept(&vf);
105   }
106 
107   return vf.vars().count(v);
108 }
109 
cloneWithHiddenInfo(const std::shared_ptr<AccessInfo> & orig)110 std::shared_ptr<AccessInfo> AccessInfo::cloneWithHiddenInfo(
111     const std::shared_ptr<AccessInfo>& orig) {
112   std::shared_ptr<AccessInfo> newInfo = std::make_shared<AccessInfo>(
113       orig->hash(), orig->buf(), orig->indices(), orig->accessOrder());
114 
115   newInfo->block_ = orig->block_;
116   newInfo->first_usage_ = orig->first_usage_;
117   newInfo->last_usage_ = orig->last_usage_;
118   newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_;
119   newInfo->store_cost_ = orig->store_cost_;
120   newInfo->load_cost_ = orig->load_cost_;
121   for (const auto& s : orig->stores_) {
122     newInfo->stores_.push_back(s);
123   }
124   for (const auto& s : orig->loads_) {
125     newInfo->loads_.push_back(s);
126   }
127 
128   newInfo->conditionId_ = orig->conditionId_;
129   newInfo->hiddenAccess_ = orig;
130   return newInfo;
131 }
132 
// Debug dump of this access: buf, indices, store/load counts and costs, and
// the conditional scope id (if any). Writes to stdout.
void AccessInfo::print() const {
  std::cout << "Access: " << *buf_ << "{";
  for (const auto& i : indices_) {
    std::cout << *i << " ";
  }
  std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -";
  std::cout << " loads: " << loads_.size() << " (" << *load_cost_ << ")";
  // conditionId_ == 0 means the access is unconditional and is not printed.
  if (conditionId_) {
    std::cout << " cond: " << conditionId_;
  }

  std::cout << "\n";
}
146 
147 // Scope
148 
closeAccess(const std::shared_ptr<AccessInfo> & info)149 void Scope::closeAccess(const std::shared_ptr<AccessInfo>& info) {
150   closedAccesses_.push_back(info);
151 }
152 
getAccessMapByBuf(const BufPtr & b)153 AccessHashMap& Scope::getAccessMapByBuf(const BufPtr& b) {
154   auto it = openAccesses_.find(b);
155   if (it == openAccesses_.end()) {
156     // create and return
157     return openAccesses_[b];
158   }
159 
160   return it->second;
161 }
162 
filterClosed()163 void Scope::filterClosed() {
164   closedAccesses_.erase(
165       std::remove_if(
166           closedAccesses_.begin(),
167           closedAccesses_.end(),
168           [](auto info) {
169             return info->store_cost()->isConstant() &&
170                 immediateAs<int>(info->store_cost()) <= 1 &&
171                 info->load_cost()->isConstant() &&
172                 immediateAs<int>(info->load_cost()) <= 1;
173           }),
174       closedAccesses_.end());
175 }
176 
177 // RegisterizerAnalysis
178 
// Finalizes `info` into `scope`, unless doing so is illegal or the access has
// a hidden fallback that should be closed instead.
void RegisterizerAnalysis::closeAccessIntoScope(
    const std::shared_ptr<AccessInfo>& info,
    const std::shared_ptr<Scope>& scope) {
  // Accesses conditional on an IfThenElse expression can never be closed:
  // we cannot insert initializer/finalizer Stmts inside an expression.
  if (exprConditionals_.count(info->conditionId()) != 0) {
    return;
  }

  // If this is a speculative hoisted copy (see cloneWithHiddenInfo), close
  // the original hidden access rather than the copy.
  if (info->hiddenAccess()) {
    closeAccessIntoScope(info->hiddenAccess(), scope);
    return;
  }
  scope->closeAccess(info);
}
192 
// Analyzes a For loop: builds a child scope for the body, then attempts to
// hoist each still-open access out of the loop, multiplying its costs by the
// loop extent. Accesses that depend on the loop variable are closed instead.
void RegisterizerAnalysis::visit(const ForPtr& v) {
  // GPU-bound loop axes disappear during parallelism flattening; seeing one
  // here means the passes ran in the wrong order.
  if (v->loop_options().is_gpu_block_index() ||
      v->loop_options().is_gpu_thread_index()) {
    throw malformed_input(
        "Registerization must occur after parallelism flattening");
  }

  // Enter a new scope for the loop body; the loop variable is local to it.
  auto parent = currentScope_;
  currentScope_ = std::make_shared<Scope>(v->body(), parent);

  currentScope_->addLocalVar(v->var());

  stmtStack_.push_front(v);
  v->body()->accept(this);
  stmtStack_.pop_front();

  // Symbolic trip count used to scale the hoisted access costs.
  ExprPtr loopExtent =
      IRSimplifier::simplify(alloc<Sub>(v->stop(), v->start()));

  // now we need to see which accesses we can hoist out of the for loop, their
  // costs should be multiplied by the loop extent.
  for (auto& pair : currentScope_->openAccesses()) {
    if (pair.second.empty()) {
      continue;
    }

    auto& childAccesses = pair.second;

    for (auto it = childAccesses.begin(); it != childAccesses.end();) {
      std::shared_ptr<AccessInfo>& candidate = it->second;

      // If the access is open, but conditional, then we have a problem. It's
      // possible that an access at a higher scope could "unhide" the
      // conditional access, in which case we need to hoist. If there is no
      // access to this element at a higher scope then we cannot safely hoist.
      // We cannot know at this level whether that will or wont occur.
      //
      // The solution we take here is to split the space-time continuum, and
      // keep both versions of the access handy. If the hoisted access is not
      // used above, we'll fall back to using the hidden, conditional
      // AccessInfo - if it is, we'll delete the copy.
      if (candidate->conditionId() != 0) {
        candidate = AccessInfo::cloneWithHiddenInfo(candidate);
      }

      bool closed = false;
      // If this access depends on a locally scoped variable, it cannot be
      // hosted out of the loop.
      for (const auto& v : currentScope_->localVars()) {
        if (candidate->dependsOnVar(v)) {
          closeAccessIntoScope(candidate, currentScope_);
          closed = true;
          break;
        }
      }
      if (closed) {
        it = childAccesses.erase(it);
        continue;
      }

      // hoist!
      // By hoisting we pull the reads and writes out of the loop, and so the
      // benefit of registerizing this access is multiplied by the loop extent.
      candidate->setEnclosingBlock(parent->block());
      candidate->hoistCosts(loopExtent);

      // in the parent block, this loop Stmt is the insertion point for the
      // initializer and finalizer.
      candidate->setUsageMarks(v, v);

      ++it;
    }
  }

  // If an access is closed within a loop then it cannot be merged into an
  // existing open access, but will still close that existing access. This is
  // somewhat different from the regular merge so we need to handle closed
  // accesses first.
  mergeHiddenScope(true);

  // having hoisted, now we can merge normally.
  mergeCurrentScopeIntoParent();
};
276 
// Analyzes a Cond (if/else Stmt): each branch gets its own conditional scope
// with a fresh conditionId, so accesses from within a branch remain marked as
// conditional when merged into the parent.
void RegisterizerAnalysis::visit(const CondPtr& v) {
  ExprPtr condition = v->condition();
  BlockPtr true_stmt = v->true_stmt();
  BlockPtr false_stmt = v->false_stmt();

  stmtStack_.push_front(v);

  // condition is in the enclosing scope.
  condition->accept(this);

  // Each branch gets a distinct conditionId so their accesses can never be
  // merged with each other.
  auto prev_scope = currentScope_;
  auto true_scope =
      std::make_shared<Scope>(true_stmt, prev_scope, ++conditionId_);
  auto false_scope =
      std::make_shared<Scope>(false_stmt, prev_scope, ++conditionId_);

  // Visit each branch in its own scope, then fold it back into prev_scope.
  // allowClosed=true: accesses closed inside a branch are kept.
  if (true_stmt) {
    currentScope_ = true_scope;
    true_stmt->accept(this);
    mergeHiddenScope(true);
    mergeCurrentScopeIntoParent();
  }
  if (false_stmt) {
    currentScope_ = false_scope;
    false_stmt->accept(this);
    mergeHiddenScope(true);
    mergeCurrentScopeIntoParent();
  }

  // TODO: even though both scopes are conditional, we can merge accesses if
  // they totally overlap in both branches, since we can guarantee one
  // definition will be hit. We might need a 3-way merge? Not as simple as
  // merging the true and false scopes together first.

  stmtStack_.pop_front();
}
313 
// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if it's visible outside the condition.
void RegisterizerAnalysis::visit(const IfThenElsePtr& v) {
  ExprPtr condition = v->condition();
  ExprPtr true_value = v->true_value();
  ExprPtr false_value = v->false_value();

  // condition is in enclosing scope.
  condition->accept(this);

  // Unlike Cond, the branches are expressions, not Blocks: the conditional
  // scopes reuse the enclosing scope's block but still get fresh condition
  // ids.
  auto prev_scope = currentScope_;
  auto true_scope =
      std::make_shared<Scope>(prev_scope->block(), prev_scope, ++conditionId_);
  auto false_scope =
      std::make_shared<Scope>(prev_scope->block(), prev_scope, ++conditionId_);

  // We store IfThenElse scopes in a global map, which we use to prevent closing
  // any access that would require inserting statements in the values, which
  // cannot enclose Stmts.
  exprConditionals_.insert(true_scope->conditionId());
  exprConditionals_.insert(false_scope->conditionId());

  // allowClosed=false: accesses cannot be finalized inside an expression.
  if (true_value) {
    currentScope_ = true_scope;
    true_value->accept(this);
    mergeHiddenScope(false);
    mergeCurrentScopeIntoParent();
  }

  if (false_value) {
    currentScope_ = false_scope;
    false_value->accept(this);
    mergeHiddenScope(false);
    mergeCurrentScopeIntoParent();
  }
}
351 
// A Let binding introduces a variable local to the current scope; accesses
// whose indices depend on it cannot be hoisted past this scope.
void RegisterizerAnalysis::visit(const LetPtr& v) {
  currentScope_->addLocalVar(v->var());

  stmtStack_.push_front(v);
  v->value()->accept(this);
  stmtStack_.pop_front();
}
359 
// Analyzes a Block: opens a child scope if this Block isn't already the
// current scope's block, folds inner accesses back after each statement, and
// — at the root Block — closes every access that is still open.
void RegisterizerAnalysis::visit(const BlockPtr& v) {
  auto prev_scope = currentScope_;
  // Only open a new scope if the caller (e.g. visit(For)) hasn't already
  // created one for this Block.
  if (currentScope_->block() != v) {
    currentScope_ = std::make_shared<Scope>(v, prev_scope);
  }

  stmtStack_.push_front(v);

  for (const auto& s : *v) {
    s->accept(this);
    if (currentScope_->block() != v) {
      // merge the inner block's accesses into this Block's accesses.
      mergeCurrentScopeIntoParent();
    }
  }

  stmtStack_.pop_front();

  // A null parent block marks the outermost scope: nothing further can merge
  // these accesses, so finalize them all.
  if (prev_scope->block() == nullptr) {
    // close any open candidates.
    for (auto& p1 : currentScope_->openAccesses()) {
      for (auto& p2 : p1.second) {
        closeAccessIntoScope(p2.second, currentScope_);
      }
    }
  }
}
387 
// Analyzes a Store: attributes it to an existing open access with identical
// buf+indices, or opens a new access — closing any open access it partially
// overlaps.
void RegisterizerAnalysis::visit(const StorePtr& v) {
  // Visit the RHS first so loads there are recorded before this store.
  stmtStack_.push_front(v);
  v->value()->accept(this);
  stmtStack_.pop_front();

  if (v->indices().empty()) {
    // already a scalar.
    return;
  }

  // hash the Store: buf identity combined with each index expression.
  SimplifierHashType accessHash = hasher_.hash(v->buf());
  for (const auto& i : v->indices()) {
    accessHash = hasher_.hash_combine(accessHash, i);
  }

  auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf());
  auto candidateIt = bufAccesses.find(accessHash);

  // If an identical access already exists, add this Store to it.
  if (candidateIt != bufAccesses.end()) {
    candidateIt->second->addStore(v, currentScope_);
    return;
  }

  // Otherwise make a new AccessInfo and add this store.
  auto info = std::make_shared<AccessInfo>(
      accessHash, v->buf(), v->indices(), accessOrder_++);
  info->addStore(v, currentScope_);

  // This new access may overlap an existing open access, in which case we need
  // to close the older of the two.
  bool alreadyOverlapped = false;
  for (auto it = bufAccesses.begin(); it != bufAccesses.end();) {
    auto other = it->second;
    if (info->overlaps(other)) {
      // An overlapping access whose last usage is this very Stmt came from
      // the RHS of this store.
      if (other->last_usage() == v) {
        // we are already overlapped by an access in the RHS.
        alreadyOverlapped = true;
      }
      closeAccessIntoScope(other, currentScope_);
      it = bufAccesses.erase(it);
    } else {
      ++it;
    }
  }

  // If overlapped within this Stmt the new access closes immediately too;
  // otherwise it stays open for future merging.
  if (alreadyOverlapped) {
    closeAccessIntoScope(info, currentScope_);
  } else {
    bufAccesses.emplace(accessHash, info);
  }
}
441 
// Analyzes a Load: attributes it to an existing open access with identical
// buf+indices, or opens a new access — closing any open access it partially
// overlaps. Mirrors visit(StorePtr) but records the enclosing Stmt as usage.
void RegisterizerAnalysis::visit(const LoadPtr& v) {
  if (v->indices().empty()) {
    // already a scalar.
    return;
  }
  // hash the Load: buf identity combined with each index expression.
  SimplifierHashType accessHash = hasher_.hash(v->buf());
  for (const auto& i : v->indices()) {
    accessHash = hasher_.hash_combine(accessHash, i);
  }

  auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf());
  auto candidateIt = bufAccesses.find(accessHash);
  if (candidateIt != bufAccesses.end()) {
    // found the right access, can just insert.
    candidateIt->second->addLoad(v, currentScope_, stmtStack_.front());
    return;
  }

  std::shared_ptr<AccessInfo> info = std::make_shared<AccessInfo>(
      accessHash, v->buf(), v->indices(), accessOrder_++);
  info->addLoad(v, currentScope_, stmtStack_.front());

  bool alreadyOverlapped = false;
  // This new access may overlap an existing open access, in which case we need
  // to finalize the older of the two.
  for (auto it = bufAccesses.begin(); it != bufAccesses.end();) {
    auto other = it->second;
    if (info->overlaps(other)) {
      if (info->last_usage() == other->last_usage()) {
        // if these two accesses are from the same Stmt, they already overlap
        // each other.
        alreadyOverlapped = true;
      }
      closeAccessIntoScope(other, currentScope_);
      it = bufAccesses.erase(it);
    } else {
      ++it;
    }
  }

  // If overlapped within the same Stmt the new access closes immediately too;
  // otherwise it stays open for future merging.
  if (alreadyOverlapped) {
    closeAccessIntoScope(info, currentScope_);
  } else {
    bufAccesses.emplace(accessHash, info);
  }
}
489 
// Loop and Conditional scopes are different in that it may or may not be
// possible to hoist the initializer of a scalar variable outside the block
// depending on if we can tell that the Buffer access is valid outside. This is
// tricky because the access that demonstrates this may be later in the tree and
// we haven't encountered it yet.
// The allowClosed flag indicates whether we want to keep the closed accesses
// (For and Cond), or not (IfThenElse).
void RegisterizerAnalysis::mergeHiddenScope(bool allowClosed) {
  // The rule is that if any access is closed within the conditional block, any
  // accesses which overlap it must also be closed - since their initializer
  // cannot be hoisted out of the block.
  std::list<std::shared_ptr<AccessInfo>> newClosed;
  for (auto& info : currentScope_->closedAccesses()) {
    auto& candidates = currentScope_->getAccessMapByBuf(info->buf());
    for (auto it = candidates.begin(); it != candidates.end();) {
      std::shared_ptr<AccessInfo> candidate = it->second;

      // Both identical accesses (same hash) and partially overlapping ones
      // must follow the closed access.
      if (info->hash() == candidate->hash() || info->overlaps(candidate)) {
        newClosed.push_back(candidate);
        it = candidates.erase(it);
      } else {
        ++it;
      }
    }
  }

  if (allowClosed) {
    // For/Cond: finalize the newly closed accesses in this scope.
    for (auto& info : newClosed) {
      closeAccessIntoScope(info, currentScope_);
    }
  } else {
    // IfThenElse: closed accesses cannot exist at all (no Stmt insertion
    // point inside an expression), so discard them.
    currentScope_->closedAccesses().clear();
  }
}
524 
// Merge currentScope_ into its parent, and make parent the new currentScope_.
void RegisterizerAnalysis::mergeCurrentScopeIntoParent() {
  auto parent = currentScope_->parent();

  // copy across current closed accesses, merging / closing as necessary
  for (auto& candidate : currentScope_->closedAccesses()) {
    auto& parentAccesses = parent->getAccessMapByBuf(candidate->buf());

    // Look for an identical (same-hash) open access in the parent first.
    auto parentIt = parentAccesses.find(candidate->hash());
    if (parentIt != parentAccesses.end()) {
      std::shared_ptr<AccessInfo> pCandidate = parentIt->second;

      // if the access is closed inside a condition, it can only be merged if
      // the parent is in the same condition.
      if (candidate->conditionId() &&
          pCandidate->conditionId() != candidate->conditionId()) {
        // the parent's access must be closed.
        closeAccessIntoScope(pCandidate, parent);
        parentAccesses.erase(parentIt);

        // the child's access inserted into the parent scope.
        closeAccessIntoScope(candidate, parent);
        continue;
      }

      // merge totally overlapping accesses.
      parentIt->second->merge(candidate);
      closeAccessIntoScope(parentIt->second, parent);
      parentAccesses.erase(parentIt);
      continue;
    }

    // we didn't find a perfect match, but we need to check all open accesses of
    // this buf for partial overlap.
    for (auto it = parentAccesses.begin(); it != parentAccesses.end();) {
      std::shared_ptr<AccessInfo> pCandidate = it->second;
      // Partial overlap of parent access: close parent access.
      if (candidate->overlaps(pCandidate)) {
        closeAccessIntoScope(pCandidate, parent);
        it = parentAccesses.erase(it);
        continue;
      }
      ++it;
    }

    // Insert the child's closed access into the parent scope.
    closeAccessIntoScope(candidate, parent);
  }

  // copy across current open accesses, merging as necessary.
  // for each Buf with an open access:
  for (auto& pair : currentScope_->openAccesses()) {
    BufPtr buf = pair.first;
    if (pair.second.empty()) {
      continue;
    }

    auto& parentAccesses = parent->getAccessMapByBuf(buf);

    // for each open access in the child scope for this Buf:
    for (auto& hpair : pair.second) {
      bool handled{false};
      std::shared_ptr<AccessInfo> candidate = hpair.second;

      for (auto it = parentAccesses.begin(); it != parentAccesses.end();) {
        std::shared_ptr<AccessInfo> pCandidate = it->second;

        // If it completely overlaps then merge.
        if (candidate->hash() == pCandidate->hash()) {
          // if both accesses are found in conditional blocks, they cannot be
          // merged, but the earlier must be closed.
          if (pCandidate->conditionId() != parent->conditionId() &&
              pCandidate->conditionId() != candidate->conditionId()) {
            closeAccessIntoScope(pCandidate, parent);
            it = parentAccesses.erase(it);
            continue;
          }
          // Merge into the parent access and keep it open; the child access
          // itself is now fully absorbed.
          pCandidate->merge(candidate);
          handled = true;
          ++it;
          continue;
        }

        // It can overlap an access in the parent: close the parent access.
        // The child access may still be open.
        if (candidate->overlaps(pCandidate)) {
          closeAccessIntoScope(pCandidate, parent);
          it = parentAccesses.erase(it);
          continue;
        }

        ++it;
      }

      // If this access depends on a locally scoped variable, it cannot be
      // lifted out of the loop.
      for (const auto& v : currentScope_->localVars()) {
        if (candidate->dependsOnVar(v)) {
          closeAccessIntoScope(candidate, parent);
          handled = true;
          break;
        }
      }

      if (!handled) {
        // If the inner scope was not conditional, but the outer scope is: all
        // current accesses are now conditional in the parent scope.
        if (candidate->conditionId() == 0) {
          candidate->setConditionId(parent->conditionId());
        }
        parentAccesses[candidate->hash()] = candidate;
      }
    }
  }

  currentScope_ = parent;
}
642 
getCandidates()643 std::vector<std::shared_ptr<AccessInfo>> RegisterizerAnalysis::getCandidates() {
644   currentScope_->filterClosed();
645   std::sort(
646       currentScope_->closedAccesses().begin(),
647       currentScope_->closedAccesses().end(),
648       [](auto i1, auto i2) { return i1->accessOrder() < i2->accessOrder(); });
649   return currentScope_->closedAccesses();
650 }
651 
652 // RegisterizerReplacer
653 
mutate(const LoadPtr & v)654 ExprPtr RegisterizerReplacer::mutate(const LoadPtr& v) {
655   auto it = loadToAccess_.find(v);
656   if (it == loadToAccess_.end()) {
657     // This access cannot be registerized.
658     return v;
659   }
660 
661   auto& info = it->second;
662 
663   return info->replacement().var;
664 }
665 
// Rewrites a registerized Store to target the scalar replacement var; drops
// Stores that were converted into the scalar's Let initializer.
StmtPtr RegisterizerReplacer::mutate(const StorePtr& v) {
  if (eliminatedIntializers_.count(v) != 0) {
    // This store is the initializer for a scalar var that is already inserted.
    // Returning nullptr deletes the Stmt from the enclosing Block.
    return nullptr;
  }

  auto it = storeToAccess_.find(v);
  if (it == storeToAccess_.end()) {
    // This access cannot be registerized.
    return IRMutator::mutate(v);
  }

  auto& info = it->second;

  // Rewrite the RHS first (it may contain registerized loads), then retarget
  // the store at the zero-dim wrapper Buf for the scalar var.
  ExprPtr new_val = v->value()->accept_mutator(this);

  v->set_value(new_val);
  v->set_buf(info->replacement().var_wrapper);
  v->set_indices({});
  return v;
}
687 
// Rebuilds a Block, inserting scalar initializers before - and finalizers
// after - the statements registered as each access's first/last usage.
StmtPtr RegisterizerReplacer::mutate(const BlockPtr& v) {
  // Accesses scoped to this block, keyed by their insertion-point Stmts.
  auto& scope = parentToAccesses_[v];

  std::vector<StmtPtr> stmts;
  for (const StmtPtr& stmt : v->stmts()) {
    {
      // Insert the initializer for any Scalars scoped to this block.
      auto it = scope.initializerPoints_.find(stmt);
      if (it != scope.initializerPoints_.end()) {
        for (auto& info : it->second) {
          // The initializer itself may contain registerized accesses, so
          // mutate it too.
          StmtPtr initializer =
              info->replacement().initializer->accept_mutator(this);
          stmts.push_back(initializer);
        }
        scope.initializerPoints_.erase(it);
      }
    }

    // A nullptr result means the Stmt was deleted (an eliminated
    // initializer store).
    StmtPtr stmt_new = stmt->accept_mutator(this);
    if (stmt_new) {
      // A Stmt may not have two parents; clone if it is still attached.
      if (stmt_new->get_parent()) {
        stmt_new = Stmt::clone(stmt_new);
      }
      stmts.push_back(stmt_new);
    }

    {
      // Insert the finalizer for any Scalars scoped to this block.
      auto it = scope.finalizePoints_.find(stmt);
      if (it != scope.finalizePoints_.end()) {
        for (auto& info : it->second) {
          // Write the scalar back to its original buffer element.
          StorePtr finalizer = alloc<Store>(
              info->buf(), info->indices(), info->replacement().var);
          stmts.push_back(finalizer);
        }
        scope.finalizePoints_.erase(it);
      }
    }
  }

  return alloc<Block>(stmts);
}
730 
// For each candidate access: creates the scalar Var (and its zero-dim Buf
// wrapper), decides how it is initialized, and records the maps that the
// mutate() overloads consult during replacement.
void RegisterizerReplacer::buildReplacements() {
  // Traverse the list of replacements, creating vars and updating our local
  // maps.
  for (auto& info : infoSet_) {
    // Name the scalar after the buffer plus a per-buffer counter to keep
    // names unique.
    VarPtr v = alloc<Var>(
        info->buf()->name_hint() + "_" +
            std::to_string(getBufferAccessCount(info->buf())),
        info->buf()->dtype());

    info->replacement().var = v;

    // we need to wrap the Var in a Buf so we can Load or Store it.
    info->replacement().var_wrapper =
        alloc<Buf>(v, std::vector<ExprPtr>({}), info->buf()->dtype());

    bool first = true;
    for (const auto& s : info->stores()) {
      // If the very first usage is an unoverlapped store, its RHS can become
      // the scalar's Let initializer and the store itself is eliminated.
      if (first && info->first_usage() == s && !info->firstUsageOverlapped()) {
        info->replacement().initializer = alloc<Let>(v, s->value());
        eliminatedIntializers_.insert(s);
      } else {
        storeToAccess_[s] = info;
      }

      first = false;
    }

    for (const auto& s : info->loads()) {
      loadToAccess_[s] = info;
    }

    // Register the insertion points in the access's enclosing block.
    auto& scope = parentToAccesses_[info->block()];
    scope.initializerPoints_[info->first_usage()].push_back(info);

    // Only finalize if the scalar is written.
    if (!info->stores().empty()) {
      // push front to finalize in reverse order of encounter.
      scope.finalizePoints_[info->last_usage()].push_front(info);
    }

    // create a default initializer by reading the access.
    if (info->replacement().initializer == nullptr) {
      info->replacement().initializer = alloc<Let>(
          v, alloc<Load>(info->buf()->dtype(), info->buf(), info->indices()));
    }
  }
}
778 
779 } // namespace registerizer
780 
781 // Apply scalar replacement to all accesses in s.
registerize(StmtPtr s)782 StmtPtr registerize(StmtPtr s) {
783   s = IRSimplifier::simplify(s);
784 
785   // The outermost node must be a Block so we have somewhere to put outer scope
786   // scalars.
787   if (!to<Block>(s)) {
788     s = alloc<Block>(std::vector<StmtPtr>({s}));
789   }
790   registerizer::RegisterizerAnalysis analysis;
791   s->accept(&analysis);
792   auto candidates = analysis.getCandidates();
793 
794   registerizer::RegisterizerReplacer replacer(candidates);
795   s = s->accept_mutator(&replacer);
796   return s;
797 }
798 
799 } // namespace torch::jit::tensorexpr
800