#include <torch/csrc/jit/tensorexpr/registerizer.h>
#include <iostream>

namespace torch::jit::tensorexpr {
namespace registerizer {

// AccessInfo

void AccessInfo::addStore(
    const StorePtr& store,
    const std::shared_ptr<Scope>& scope) {
  block_ =
      block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();

  // If there is already a usage and it's this store, that means the same
  // access is present in the RHS.
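  // e.g. in `A[0] = A[0] + 1` the load of A[0] on the RHS is visited first
  // and records the enclosing Store as its usage, so it overlaps this store.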
  firstUsageOverlapped_ |= first_usage_ == store;
  first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : store;
  last_usage_ = store;

  store_cost_ =
      IRSimplifier::simplify(alloc<Add>(store_cost_, immLike(store_cost_, 1)));
  stores_.push_back(store);

  conditionId_ = scope->conditionId();
  hiddenAccess_.reset();
}

void AccessInfo::addLoad(
    const LoadPtr& load,
    const std::shared_ptr<Scope>& scope,
    const StmtPtr& usage) {
  block_ =
      block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
  first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
  last_usage_ = usage;

  load_cost_ =
      IRSimplifier::simplify(alloc<Add>(load_cost_, immLike(load_cost_, 1)));
  loads_.push_back(load);

  conditionId_ = scope->conditionId();
  hiddenAccess_.reset();
}

void AccessInfo::merge(const std::shared_ptr<AccessInfo>& other) {
  TORCH_INTERNAL_ASSERT(
      hash_ == other->hash(),
      buildErrorMessage(
          "Expected hashes to match in registerizer in the fuser."));
  TORCH_INTERNAL_ASSERT(
      indices_.size() == other->indices().size(),
      buildErrorMessage(
          "Expected ranks to match in registerizer in the fuser."));

  last_usage_ = other->last_usage();
  for (const auto& s : other->stores()) {
    stores_.push_back(s);
  }
  for (const auto& l : other->loads()) {
    loads_.push_back(l);
  }

  store_cost_ =
      IRSimplifier::simplify(alloc<Add>(store_cost_, other->store_cost()));
  load_cost_ =
      IRSimplifier::simplify(alloc<Add>(load_cost_, other->load_cost()));

  block_ = Block::getSharedParent(block_, other->block());
  // update first and last usage to be in the parent Block.
  first_usage_ = block_->getEnclosedRoot(first_usage_);
  last_usage_ = block_->getEnclosedRoot(last_usage_);
  hiddenAccess_.reset();
}

bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
  // All accesses to a buf must have the same dimensionality.
  TORCH_INTERNAL_ASSERT(
      indices_.size() == other->indices().size(),
      buildErrorMessage(
          "Expected ranks to match in registerizer in the fuser."));

  auto& other_indices = other->indices();

  // They don't overlap if there is a guaranteed difference in any
  // dimension.
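  // e.g. A[i] and A[i + 1] differ by the constant 1 in that dimension and can
  // never overlap, while A[i] and A[j] have a symbolic difference and are
  // conservatively assumed to overlap.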
  bool overlap = true;
  for (size_t i = 0; i < indices_.size(); ++i) {
    ExprPtr diff = alloc<Sub>(indices_[i], other_indices[i]);
    diff = IRSimplifier::simplify(diff);

    if (diff->isConstant() && !immediateEquals(diff, 0)) {
      overlap = false;
      break;
    }
  }

  return overlap;
}

bool AccessInfo::dependsOnVar(const VarPtr& v) {
  VarFinder vf;
  for (const auto& i : indices_) {
    i->accept(&vf);
  }

  return vf.vars().count(v);
}

std::shared_ptr<AccessInfo> AccessInfo::cloneWithHiddenInfo(
    const std::shared_ptr<AccessInfo>& orig) {
  std::shared_ptr<AccessInfo> newInfo = std::make_shared<AccessInfo>(
      orig->hash(), orig->buf(), orig->indices(), orig->accessOrder());

  newInfo->block_ = orig->block_;
  newInfo->first_usage_ = orig->first_usage_;
  newInfo->last_usage_ = orig->last_usage_;
  newInfo->firstUsageOverlapped_ = orig->firstUsageOverlapped_;
  newInfo->store_cost_ = orig->store_cost_;
  newInfo->load_cost_ = orig->load_cost_;
  for (const auto& s : orig->stores_) {
    newInfo->stores_.push_back(s);
  }
  for (const auto& s : orig->loads_) {
    newInfo->loads_.push_back(s);
  }

  newInfo->conditionId_ = orig->conditionId_;
  newInfo->hiddenAccess_ = orig;
  return newInfo;
}

void AccessInfo::print() const {
  std::cout << "Access: " << *buf_ << "{";
  for (const auto& i : indices_) {
    std::cout << *i << " ";
  }
  std::cout << "} stores: " << stores_.size() << " (" << *store_cost_ << ") -";
  std::cout << " loads: " << loads_.size() << " (" << *load_cost_ << ")";
  if (conditionId_) {
    std::cout << " cond: " << conditionId_;
  }

  std::cout << "\n";
}

// Scope

void Scope::closeAccess(const std::shared_ptr<AccessInfo>& info) {
  closedAccesses_.push_back(info);
}

AccessHashMap& Scope::getAccessMapByBuf(const BufPtr& b) {
  auto it = openAccesses_.find(b);
  if (it == openAccesses_.end()) {
    // create and return
    return openAccesses_[b];
  }

  return it->second;
}

void Scope::filterClosed() {
  closedAccesses_.erase(
      std::remove_if(
          closedAccesses_.begin(),
          closedAccesses_.end(),
          [](auto info) {
            return info->store_cost()->isConstant() &&
                immediateAs<int>(info->store_cost()) <= 1 &&
                info->load_cost()->isConstant() &&
                immediateAs<int>(info->load_cost()) <= 1;
          }),
      closedAccesses_.end());
}

// RegisterizerAnalysis

void RegisterizerAnalysis::closeAccessIntoScope(
    const std::shared_ptr<AccessInfo>& info,
    const std::shared_ptr<Scope>& scope) {
  if (exprConditionals_.count(info->conditionId()) != 0) {
    return;
  }

  if (info->hiddenAccess()) {
    closeAccessIntoScope(info->hiddenAccess(), scope);
    return;
  }
  scope->closeAccess(info);
}

void RegisterizerAnalysis::visit(const ForPtr& v) {
  if (v->loop_options().is_gpu_block_index() ||
      v->loop_options().is_gpu_thread_index()) {
    throw malformed_input(
        "Registerization must occur after parallelism flattening");
  }

  auto parent = currentScope_;
  currentScope_ = std::make_shared<Scope>(v->body(), parent);

  currentScope_->addLocalVar(v->var());

  stmtStack_.push_front(v);
  v->body()->accept(this);
  stmtStack_.pop_front();

  ExprPtr loopExtent =
      IRSimplifier::simplify(alloc<Sub>(v->stop(), v->start()));

  // Now we need to see which accesses we can hoist out of the for loop;
  // their costs should be multiplied by the loop extent.
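  // e.g. an access with a single store of cost 1 inside a loop with extent 10
  // has an effective cost of 10 once hoisted, which can make registerization
  // worthwhile where the raw cost alone would not.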
  for (auto& pair : currentScope_->openAccesses()) {
    if (pair.second.empty()) {
      continue;
    }

    auto& childAccesses = pair.second;

    for (auto it = childAccesses.begin(); it != childAccesses.end();) {
      std::shared_ptr<AccessInfo>& candidate = it->second;

      // If the access is open, but conditional, then we have a problem. It's
      // possible that an access at a higher scope could "unhide" the
      // conditional access, in which case we need to hoist. If there is no
      // access to this element at a higher scope then we cannot safely hoist.
      // We cannot know at this level whether that will or won't occur.
      //
      // The solution we take here is to split the space-time continuum, and
      // keep both versions of the access handy. If the hoisted access is not
      // used above, we'll fall back to using the hidden, conditional
      // AccessInfo - if it is, we'll delete the copy.
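      // e.g. a store to A[0] guarded by a condition inside this loop can only
      // be hoisted if some unconditional access to A[0] exists at a higher
      // scope; the hidden copy preserves the conditional version in case no
      // such access turns up.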
      if (candidate->conditionId() != 0) {
        candidate = AccessInfo::cloneWithHiddenInfo(candidate);
      }

      bool closed = false;
      // If this access depends on a locally scoped variable, it cannot be
      // hoisted out of the loop.
      for (const auto& v : currentScope_->localVars()) {
        if (candidate->dependsOnVar(v)) {
          closeAccessIntoScope(candidate, currentScope_);
          closed = true;
          break;
        }
      }
      if (closed) {
        it = childAccesses.erase(it);
        continue;
      }

      // hoist!
      // By hoisting we pull the reads and writes out of the loop, and so the
      // benefit of registerizing this access is multiplied by the loop extent.
      candidate->setEnclosingBlock(parent->block());
      candidate->hoistCosts(loopExtent);

      // in the parent block, this loop Stmt is the insertion point for the
      // initializer and finalizer.
      candidate->setUsageMarks(v, v);

      ++it;
    }
  }

  // If an access is closed within a loop then it cannot be merged into an
  // existing open access, but will still close that existing access. This is
  // somewhat different from the regular merge, so we need to handle closed
  // accesses first.
  mergeHiddenScope(true);

  // having hoisted, now we can merge normally.
  mergeCurrentScopeIntoParent();
}

void RegisterizerAnalysis::visit(const CondPtr& v) {
  ExprPtr condition = v->condition();
  BlockPtr true_stmt = v->true_stmt();
  BlockPtr false_stmt = v->false_stmt();

  stmtStack_.push_front(v);

  // condition is in the enclosing scope.
  condition->accept(this);

  auto prev_scope = currentScope_;
  auto true_scope =
      std::make_shared<Scope>(true_stmt, prev_scope, ++conditionId_);
  auto false_scope =
      std::make_shared<Scope>(false_stmt, prev_scope, ++conditionId_);

  if (true_stmt) {
    currentScope_ = true_scope;
    true_stmt->accept(this);
    mergeHiddenScope(true);
    mergeCurrentScopeIntoParent();
  }
  if (false_stmt) {
    currentScope_ = false_scope;
    false_stmt->accept(this);
    mergeHiddenScope(true);
    mergeCurrentScopeIntoParent();
  }

  // TODO: even though both scopes are conditional, we can merge accesses if
  // they totally overlap in both branches, since we can guarantee one
  // definition will be hit. We might need a 3-way merge? Not as simple as
  // merging the true and false scopes together first.

  stmtStack_.pop_front();
}

// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if it's visible outside the condition.
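// e.g. in `B[0] = IfThenElse(flag, A[0], 0)` the load of A[0] happens only
// when flag is true, so it may seed a registerized access only if A[0] is
// also accessed unconditionally somewhere outside the IfThenElse.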
void RegisterizerAnalysis::visit(const IfThenElsePtr& v) {
  ExprPtr condition = v->condition();
  ExprPtr true_value = v->true_value();
  ExprPtr false_value = v->false_value();

  // condition is in enclosing scope.
  condition->accept(this);

  auto prev_scope = currentScope_;
  auto true_scope =
      std::make_shared<Scope>(prev_scope->block(), prev_scope, ++conditionId_);
  auto false_scope =
      std::make_shared<Scope>(prev_scope->block(), prev_scope, ++conditionId_);

  // We store IfThenElse scopes in a global map, which we use to prevent
  // closing any access that would require inserting statements in the values,
  // which cannot enclose Stmts.
  exprConditionals_.insert(true_scope->conditionId());
  exprConditionals_.insert(false_scope->conditionId());

  if (true_value) {
    currentScope_ = true_scope;
    true_value->accept(this);
    mergeHiddenScope(false);
    mergeCurrentScopeIntoParent();
  }

  if (false_value) {
    currentScope_ = false_scope;
    false_value->accept(this);
    mergeHiddenScope(false);
    mergeCurrentScopeIntoParent();
  }
}

void RegisterizerAnalysis::visit(const LetPtr& v) {
  currentScope_->addLocalVar(v->var());

  stmtStack_.push_front(v);
  v->value()->accept(this);
  stmtStack_.pop_front();
}

void RegisterizerAnalysis::visit(const BlockPtr& v) {
  auto prev_scope = currentScope_;
  if (currentScope_->block() != v) {
    currentScope_ = std::make_shared<Scope>(v, prev_scope);
  }

  stmtStack_.push_front(v);

  for (const auto& s : *v) {
    s->accept(this);
    if (currentScope_->block() != v) {
      // merge the inner block's accesses into this Block's accesses.
      mergeCurrentScopeIntoParent();
    }
  }

  stmtStack_.pop_front();

  if (prev_scope->block() == nullptr) {
    // close any open candidates.
    for (auto& p1 : currentScope_->openAccesses()) {
      for (auto& p2 : p1.second) {
        closeAccessIntoScope(p2.second, currentScope_);
      }
    }
  }
}

void RegisterizerAnalysis::visit(const StorePtr& v) {
  stmtStack_.push_front(v);
  v->value()->accept(this);
  stmtStack_.pop_front();

  if (v->indices().empty()) {
    // already a scalar.
    return;
  }

  // hash the Store:
  SimplifierHashType accessHash = hasher_.hash(v->buf());
  for (const auto& i : v->indices()) {
    accessHash = hasher_.hash_combine(accessHash, i);
  }

  auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf());
  auto candidateIt = bufAccesses.find(accessHash);

  // If an identical access already exists, add this Store to it.
  if (candidateIt != bufAccesses.end()) {
    candidateIt->second->addStore(v, currentScope_);
    return;
  }

  // Otherwise make a new AccessInfo and add this store.
  auto info = std::make_shared<AccessInfo>(
      accessHash, v->buf(), v->indices(), accessOrder_++);
  info->addStore(v, currentScope_);

  // This new access may overlap an existing open access, in which case we need
  // to close the older of the two.
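  // e.g. a new store to A[x] closes an open access to A[0], since x may be 0
  // at runtime and the scalar replacement would otherwise miss the aliasing
  // write.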
  bool alreadyOverlapped = false;
  for (auto it = bufAccesses.begin(); it != bufAccesses.end();) {
    auto other = it->second;
    if (info->overlaps(other)) {
      if (other->last_usage() == v) {
        // we are already overlapped by an access in the RHS.
        alreadyOverlapped = true;
      }
      closeAccessIntoScope(other, currentScope_);
      it = bufAccesses.erase(it);
    } else {
      ++it;
    }
  }

  if (alreadyOverlapped) {
    closeAccessIntoScope(info, currentScope_);
  } else {
    bufAccesses.emplace(accessHash, info);
  }
}

void RegisterizerAnalysis::visit(const LoadPtr& v) {
  if (v->indices().empty()) {
    // already a scalar.
    return;
  }
  // hash the Load:
  SimplifierHashType accessHash = hasher_.hash(v->buf());
  for (const auto& i : v->indices()) {
    accessHash = hasher_.hash_combine(accessHash, i);
  }

  auto& bufAccesses = currentScope_->getAccessMapByBuf(v->buf());
  auto candidateIt = bufAccesses.find(accessHash);
  if (candidateIt != bufAccesses.end()) {
    // found the right access, can just insert.
    candidateIt->second->addLoad(v, currentScope_, stmtStack_.front());
    return;
  }

  std::shared_ptr<AccessInfo> info = std::make_shared<AccessInfo>(
      accessHash, v->buf(), v->indices(), accessOrder_++);
  info->addLoad(v, currentScope_, stmtStack_.front());

  bool alreadyOverlapped = false;
  // This new access may overlap an existing open access, in which case we need
  // to finalize the older of the two.
  for (auto it = bufAccesses.begin(); it != bufAccesses.end();) {
    auto other = it->second;
    if (info->overlaps(other)) {
      if (info->last_usage() == other->last_usage()) {
        // if these two accesses are from the same Stmt, they already overlap
        // each other.
        alreadyOverlapped = true;
      }
      closeAccessIntoScope(other, currentScope_);
      it = bufAccesses.erase(it);
    } else {
      ++it;
    }
  }

  if (alreadyOverlapped) {
    closeAccessIntoScope(info, currentScope_);
  } else {
    bufAccesses.emplace(accessHash, info);
  }
}

// Loop and Conditional scopes are different in that it may or may not be
// possible to hoist the initializer of a scalar variable outside the block,
// depending on whether we can tell that the Buffer access is valid outside.
// This is tricky because the access that demonstrates this may be later in the
// tree and we haven't encountered it yet.
// The allowClosed flag indicates whether we want to keep the closed accesses
// (For and Cond), or not (IfThenElse).
void RegisterizerAnalysis::mergeHiddenScope(bool allowClosed) {
  // The rule is that if any access is closed within the conditional block, any
  // accesses which overlap it must also be closed - since their initializer
  // cannot be hoisted out of the block.
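  // e.g. if a store to A[1] was closed inside this conditional block, an open
  // access to A[1] (or to any overlapping A[i]) must close with it, since its
  // initializer could not be hoisted above the condition either.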
  std::list<std::shared_ptr<AccessInfo>> newClosed;
  for (auto& info : currentScope_->closedAccesses()) {
    auto& candidates = currentScope_->getAccessMapByBuf(info->buf());
    for (auto it = candidates.begin(); it != candidates.end();) {
      std::shared_ptr<AccessInfo> candidate = it->second;

      if (info->hash() == candidate->hash() || info->overlaps(candidate)) {
        newClosed.push_back(candidate);
        it = candidates.erase(it);
      } else {
        ++it;
      }
    }
  }

  if (allowClosed) {
    for (auto& info : newClosed) {
      closeAccessIntoScope(info, currentScope_);
    }
  } else {
    currentScope_->closedAccesses().clear();
  }
}

// Merge currentScope_ into its parent, and make the parent the new
// currentScope_.
void RegisterizerAnalysis::mergeCurrentScopeIntoParent() {
  auto parent = currentScope_->parent();

  // copy across current closed accesses, merging / closing as necessary
  for (auto& candidate : currentScope_->closedAccesses()) {
    auto& parentAccesses = parent->getAccessMapByBuf(candidate->buf());

    auto parentIt = parentAccesses.find(candidate->hash());
    if (parentIt != parentAccesses.end()) {
      std::shared_ptr<AccessInfo> pCandidate = parentIt->second;

      // if the access is closed inside a condition, it can only be merged if
      // the parent is in the same condition.
      if (candidate->conditionId() &&
          pCandidate->conditionId() != candidate->conditionId()) {
        // the parent's access must be closed.
        closeAccessIntoScope(pCandidate, parent);
        parentAccesses.erase(parentIt);

        // the child's access is inserted into the parent scope.
        closeAccessIntoScope(candidate, parent);
        continue;
      }

      // merge totally overlapping accesses.
      parentIt->second->merge(candidate);
      closeAccessIntoScope(parentIt->second, parent);
      parentAccesses.erase(parentIt);
      continue;
    }

    // we didn't find a perfect match, but we need to check all open accesses
    // of this buf for partial overlap.
    for (auto it = parentAccesses.begin(); it != parentAccesses.end();) {
      std::shared_ptr<AccessInfo> pCandidate = it->second;
      // Partial overlap of parent access: close parent access.
      if (candidate->overlaps(pCandidate)) {
        closeAccessIntoScope(pCandidate, parent);
        it = parentAccesses.erase(it);
        continue;
      }
      ++it;
    }

    // Insert the child's closed access into the parent scope.
    closeAccessIntoScope(candidate, parent);
  }

  // copy across current open accesses, merging as necessary.
  // for each Buf with an open access:
  for (auto& pair : currentScope_->openAccesses()) {
    BufPtr buf = pair.first;
    if (pair.second.empty()) {
      continue;
    }

    auto& parentAccesses = parent->getAccessMapByBuf(buf);

    // for each open access in the child scope for this Buf:
    for (auto& hpair : pair.second) {
      bool handled{false};
      std::shared_ptr<AccessInfo> candidate = hpair.second;

      for (auto it = parentAccesses.begin(); it != parentAccesses.end();) {
        std::shared_ptr<AccessInfo> pCandidate = it->second;

        // If it completely overlaps then merge.
        if (candidate->hash() == pCandidate->hash()) {
          // if both accesses are found in conditional blocks, they cannot be
          // merged, but the earlier must be closed.
          if (pCandidate->conditionId() != parent->conditionId() &&
              pCandidate->conditionId() != candidate->conditionId()) {
            closeAccessIntoScope(pCandidate, parent);
            it = parentAccesses.erase(it);
            continue;
          }
          pCandidate->merge(candidate);
          handled = true;
          ++it;
          continue;
        }

        // It can overlap an access in the parent: close the parent access.
        // The child access may still be open.
        if (candidate->overlaps(pCandidate)) {
          closeAccessIntoScope(pCandidate, parent);
          it = parentAccesses.erase(it);
          continue;
        }

        ++it;
      }

      // If this access depends on a locally scoped variable, it cannot be
      // lifted out of the loop.
      for (const auto& v : currentScope_->localVars()) {
        if (candidate->dependsOnVar(v)) {
          closeAccessIntoScope(candidate, parent);
          handled = true;
          break;
        }
      }

      if (!handled) {
        // If the inner scope was not conditional, but the outer scope is: all
        // current accesses are now conditional in the parent scope.
        if (candidate->conditionId() == 0) {
          candidate->setConditionId(parent->conditionId());
        }
        parentAccesses[candidate->hash()] = candidate;
      }
    }
  }

  currentScope_ = parent;
}

std::vector<std::shared_ptr<AccessInfo>> RegisterizerAnalysis::getCandidates() {
  currentScope_->filterClosed();
  std::sort(
      currentScope_->closedAccesses().begin(),
      currentScope_->closedAccesses().end(),
      [](auto i1, auto i2) { return i1->accessOrder() < i2->accessOrder(); });
  return currentScope_->closedAccesses();
}

// RegisterizerReplacer

ExprPtr RegisterizerReplacer::mutate(const LoadPtr& v) {
  auto it = loadToAccess_.find(v);
  if (it == loadToAccess_.end()) {
    // This access cannot be registerized.
    return v;
  }

  auto& info = it->second;

  return info->replacement().var;
}

StmtPtr RegisterizerReplacer::mutate(const StorePtr& v) {
  if (eliminatedIntializers_.count(v) != 0) {
    // This store is the initializer for a scalar var that is already inserted.
    return nullptr;
  }

  auto it = storeToAccess_.find(v);
  if (it == storeToAccess_.end()) {
    // This access cannot be registerized.
    return IRMutator::mutate(v);
  }

  auto& info = it->second;

  ExprPtr new_val = v->value()->accept_mutator(this);

  v->set_value(new_val);
  v->set_buf(info->replacement().var_wrapper);
  v->set_indices({});
  return v;
}

StmtPtr RegisterizerReplacer::mutate(const BlockPtr& v) {
  auto& scope = parentToAccesses_[v];

  std::vector<StmtPtr> stmts;
  for (const StmtPtr& stmt : v->stmts()) {
    {
      // Insert the initializer for any Scalars scoped to this block.
      auto it = scope.initializerPoints_.find(stmt);
      if (it != scope.initializerPoints_.end()) {
        for (auto& info : it->second) {
          StmtPtr initializer =
              info->replacement().initializer->accept_mutator(this);
          stmts.push_back(initializer);
        }
        scope.initializerPoints_.erase(it);
      }
    }

    StmtPtr stmt_new = stmt->accept_mutator(this);
    if (stmt_new) {
      if (stmt_new->get_parent()) {
        stmt_new = Stmt::clone(stmt_new);
      }
      stmts.push_back(stmt_new);
    }

    {
      // Insert the finalizer for any Scalars scoped to this block.
      auto it = scope.finalizePoints_.find(stmt);
      if (it != scope.finalizePoints_.end()) {
        for (auto& info : it->second) {
          StorePtr finalizer = alloc<Store>(
              info->buf(), info->indices(), info->replacement().var);
          stmts.push_back(finalizer);
        }
        scope.finalizePoints_.erase(it);
      }
    }
  }

  return alloc<Block>(stmts);
}

void RegisterizerReplacer::buildReplacements() {
  // Traverse the list of replacements, creating vars and updating our local
  // maps.
  for (auto& info : infoSet_) {
    VarPtr v = alloc<Var>(
        info->buf()->name_hint() + "_" +
            std::to_string(getBufferAccessCount(info->buf())),
        info->buf()->dtype());

    info->replacement().var = v;

    // we need to wrap the Var in a Buf so we can Load or Store it.
    info->replacement().var_wrapper =
        alloc<Buf>(v, std::vector<ExprPtr>({}), info->buf()->dtype());

    bool first = true;
    for (const auto& s : info->stores()) {
      if (first && info->first_usage() == s && !info->firstUsageOverlapped()) {
        info->replacement().initializer = alloc<Let>(v, s->value());
        eliminatedIntializers_.insert(s);
      } else {
        storeToAccess_[s] = info;
      }

      first = false;
    }

    for (const auto& s : info->loads()) {
      loadToAccess_[s] = info;
    }

    auto& scope = parentToAccesses_[info->block()];
    scope.initializerPoints_[info->first_usage()].push_back(info);

    // Only finalize if the scalar is written.
    if (!info->stores().empty()) {
      // push front to finalize in reverse order of encounter.
      scope.finalizePoints_[info->last_usage()].push_front(info);
    }

    // create a default initializer by reading the access.
    if (info->replacement().initializer == nullptr) {
      info->replacement().initializer = alloc<Let>(
          v, alloc<Load>(info->buf()->dtype(), info->buf(), info->indices()));
    }
  }
}

} // namespace registerizer

// Apply scalar replacement to all accesses in s.
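// e.g. it can replace a pattern like:
//   A[0] = 0;
//   for (int x = 0; x < 10; x++) {
//     A[0] = (A[0]) + x;
//   }
// with a scalar register, an initializer, and a finalizer:
//   int A_ = 0;
//   for (int x = 0; x < 10; x++) {
//     A_ = A_ + x;
//   }
//   A[0] = A_;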
StmtPtr registerize(StmtPtr s) {
  s = IRSimplifier::simplify(s);

  // The outermost node must be a Block so we have somewhere to put outer scope
  // scalars.
  if (!to<Block>(s)) {
    s = alloc<Block>(std::vector<StmtPtr>({s}));
  }
  registerizer::RegisterizerAnalysis analysis;
  s->accept(&analysis);
  auto candidates = analysis.getCandidates();

  registerizer::RegisterizerReplacer replacer(candidates);
  s = s->accept_mutator(&replacer);
  return s;
}

} // namespace torch::jit::tensorexpr