#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <deque>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch::jit::tensorexpr {
namespace registerizer {

/* The Registerizer performs scalar replacement by looking for common Stores
and Loads to a single item in a buffer, and replacing them with a local
temporary scalar which is cheaper to read and write.

For example, it can replace:

{
  A[0] = 0;
  for (const auto x : c10::irange(10)) {
    A[0] = (A[0]) + x;
  }
}

with:

{
  int A_ = 0;
  for (const auto x : c10::irange(10)) {
    A_ = x + A_;
  }
  A[0] = A_;
}

This is particularly useful on GPUs when parallelizing, since after replacing
loops with metavars we have a lot of accesses like this. */

class Scope;

/* Holds analysis information about accesses to a specific range of a buffer,
 * including the number of loads and stores and the lowest common parent
 * Block.
 */
class AccessInfo {
 public:
  AccessInfo() = default;
  AccessInfo(
      SimplifierHashType h,
      BufPtr b,
      std::vector<ExprPtr> i,
      size_t accessOrder)
      : hash_(h),
        buf_(std::move(b)),
        indices_(std::move(i)),
        store_cost_(alloc<IntImm>(0)),
        load_cost_(alloc<IntImm>(0)),
        accessOrder_(accessOrder) {}

  // Adds a Store to this access, which is in the provided scope.
  void addStore(const StorePtr& store, const std::shared_ptr<Scope>& scope);

  // Adds a Load to this access, which occurs in the usage Stmt in the
  // provided scope.
  void addLoad(
      const LoadPtr& load,
      const std::shared_ptr<Scope>& scope,
      const StmtPtr& usage);

  // Merge another AccessInfo into this one.
  void merge(const std::shared_ptr<AccessInfo>& other);

  // Returns true if the other AccessInfo's bounds may overlap this one.
  bool overlaps(const std::shared_ptr<AccessInfo>& other);

  // Returns true if the indices of this access depend on the provided Var.
  bool dependsOnVar(const VarPtr& v);

  // Clones this AccessInfo, and sets the original as the new access's
  // hiddenAccess.
  static std::shared_ptr<AccessInfo> cloneWithHiddenInfo(
      const std::shared_ptr<AccessInfo>& orig);

  // Print for debugging.
  void print() const;

  SimplifierHashType hash() const {
    return hash_;
  }

  BufPtr buf() const {
    return buf_;
  }

  const std::vector<ExprPtr>& indices() const {
    return indices_;
  }

  BlockPtr block() const {
    return block_;
  }

  void setEnclosingBlock(BlockPtr b) {
    block_ = std::move(b);
  }

  StmtPtr first_usage() const {
    return first_usage_;
  }
  StmtPtr last_usage() const {
    return last_usage_;
  }

  void setUsageMarks(StmtPtr first, StmtPtr last) {
    first_usage_ = std::move(first);
    last_usage_ = std::move(last);
  }

  bool firstUsageOverlapped() const {
    return firstUsageOverlapped_;
  }

  ExprPtr store_cost() const {
    return store_cost_;
  }

  ExprPtr load_cost() const {
    return load_cost_;
  }

  const std::vector<StorePtr>& stores() const {
    return stores_;
  }

  const std::vector<LoadPtr>& loads() const {
    return loads_;
  }

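  // Scales the store/load costs by the given loop extent when the access is
  // hoisted out of a loop, since the scalar then stands in for extent-many
  // real accesses.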
  void hoistCosts(const ExprPtr& extent) {
    store_cost_ = IRSimplifier::simplify(alloc<Mul>(store_cost_, extent));
    load_cost_ = IRSimplifier::simplify(alloc<Mul>(load_cost_, extent));
  }

  size_t conditionId() const {
    return conditionId_;
  }

  void setConditionId(size_t c) {
    conditionId_ = c;
  }

  size_t accessOrder() const {
    return accessOrder_;
  }

  std::shared_ptr<AccessInfo> hiddenAccess() const {
    return hiddenAccess_;
  }

  // Holds state relating to the scalar variable we will insert to replace
  // some number of loads and stores.
  struct ScalarReplacement {
    VarPtr var{nullptr};
    BufPtr var_wrapper{nullptr};
    LetPtr initializer{nullptr};
  };

  ScalarReplacement& replacement() {
    return replacement_;
  }

 private:
  SimplifierHashType hash_;
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
  BlockPtr block_{nullptr};

  StmtPtr first_usage_{nullptr};
  StmtPtr last_usage_{nullptr};

  // Whether or not this access is overlapped in the first Stmt in which it
  // appears. This means we cannot use its first Store as the initializer.
  bool firstUsageOverlapped_{false};

  // The cost in real ops that this access represents, to enable filtering out
  // accesses that won't save any loads or stores.
  ExprPtr store_cost_;
  ExprPtr load_cost_;

  // The actual Stores and Loads which represent this access.
  // Be careful with these: any mutator will invalidate these pointers.
  std::vector<StorePtr> stores_;
  std::vector<LoadPtr> loads_;

  // An identifier representing the conditional block, if any, this access
  // depends on.
  size_t conditionId_{0};

  // An identifier representing the order in which this access was first
  // encountered, used for sorting returned results.
  size_t accessOrder_{0};

  // Sometimes when traversing the tree we need to record what would happen if
  // we hoisted an access, but it doesn't always work out. This lets us "undo"
  // the mutation and return to the internal hidden AccessInfo. It will be
  // removed after any further additions to this AccessInfo.
  std::shared_ptr<AccessInfo> hiddenAccess_;

  ScalarReplacement replacement_;
};

using AccessHashMap =
    std::unordered_map<SimplifierHashType, std::shared_ptr<AccessInfo>>;

// Represents a scope block and holds all accesses contained within it.
class Scope {
 public:
  Scope(BlockPtr b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
      : block_(std::move(b)),
        parent_(std::move(parent)),
        conditionId_(conditionId) {}

  AccessHashMap& getAccessMapByBuf(const BufPtr& b);

  std::unordered_map<BufPtr, AccessHashMap>& openAccesses() {
    return openAccesses_;
  }

  std::vector<std::shared_ptr<AccessInfo>>& closedAccesses() {
    return closedAccesses_;
  }

  BlockPtr block() const {
    return block_;
  }

  std::shared_ptr<Scope> parent() const {
    return parent_;
  }

  size_t conditionId() const {
    return conditionId_;
  }

  const std::unordered_set<VarPtr>& localVars() const {
    return localVars_;
  }
  void addLocalVar(VarPtr v) {
    localVars_.insert(std::move(v));
  }

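  // Closes the given access, recording it in closedAccesses_.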
  void closeAccess(const std::shared_ptr<AccessInfo>& info);

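  // Drops closed accesses that would not save any loads or stores if
  // registerized (see store_cost_ / load_cost_ in AccessInfo).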
  void filterClosed();

 private:
  // Map of maps of accesses, narrowing first by Buf and then by
  // hash(Buf+Indices). This allows us to find a candidate access easily, and
  // also to check for overlap with other accesses to the same buf:
  //   Buf ->
  //     Hash ->
  //       Access
  std::unordered_map<BufPtr, AccessHashMap> openAccesses_;
  std::vector<std::shared_ptr<AccessInfo>> closedAccesses_;

  // The Block object this scope represents.
  BlockPtr block_;

  // The enclosing scope object.
  std::shared_ptr<Scope> parent_;

  // An identifier representing the condition block this scope depends on.
  size_t conditionId_;

  // A set of variables local to this scope (e.g. loop vars).
  std::unordered_set<VarPtr> localVars_;
};

/* Analyzes the graph and collects accesses to the same symbolic tensor element
 * which can be replaced by a single local scalar.
 *
 * This works by recursively walking the tree in postfix order, building sets
 * of accesses to the same symbolic element by scope, and then merging lower
 * scopes into their enclosing scope.
 *
 * It is safe to move two accesses of the same Tensor element to a local scalar
 * Var if between all usages of the element there are no other Loads or Stores
 * that may refer to it. In the comments I refer to this as overlapping the
 * access, or "cutting" the existing AccessInfo. In the case where a candidate
 * for registerization is cut, it may be possible to finalize the access early
 * by writing it back to the Tensor, and then to create a new scalar variable
 * after the overlapping access is complete. We will attempt to do this when it
 * saves memory accesses.
 *
 * There are a few cases that make this more challenging:
 *
 * - For: Loops multiply the number of real usages of a buffer by the loop
 *   extent, but only if we can pull the definition and finalization of the
 *   scalar variable out of the loop block.
 *
 * - Cond: Conditions complicate lifting scalars out of internal scopes.
 *   Generally we cannot lift an access outside of a conditional scope unless
 *   there is already a reference to that same access at the higher scope,
 *   since we don't know if the condition was guarding an array access that is
 *   not safe at the higher scope. In the comments I refer to this as the
 *   condition "hiding" the access, and the outer access "unhiding" it.
 *
 * - IfThenElse: The same situation as Cond, except that since IfThenElse is
 *   an Expr rather than a Stmt, we cannot insert the scalar definition or
 *   finalizer within the conditional scope. Accesses inside an IfThenElse can
 *   be safely combined with external accesses but cannot exist completely
 *   within.
 *
 * - Let: Accesses dependent on local variables via Let Stmts, or on loop
 *   vars, cannot be raised outside of the scope of the dependent var.
 */
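/* An illustrative sketch of the Cond "hiding" case above (the code below is
 * hypothetical, not taken from the implementation): the store to A[0] inside
 * the condition cannot be registerized on its own, but the unconditional
 * access to A[0] in the enclosing scope "unhides" it, so the two accesses can
 * share one scalar:
 *
 * {
 *   B[0] = A[0];          // reference at the higher scope
 *   if (x < 3) {
 *     A[0] = (B[0]) + 1;  // can be merged with the outer access
 *   }
 * }
 */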
class TORCH_API RegisterizerAnalysis : public IRVisitor {
 public:
  RegisterizerAnalysis()
      : currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
  ~RegisterizerAnalysis() override = default;

  void visit(const ForPtr& v) override;

  void visit(const CondPtr& v) override;

  void visit(const BlockPtr& v) override;

  void visit(const StorePtr& v) override;

  void visit(const LoadPtr& v) override;

  void visit(const IfThenElsePtr& v) override;

  void visit(const LetPtr& v) override;

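// For these Stmts we only need to track them as the enclosing usage Stmt of
// any Loads they contain (see stmtStack_ below).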
#define STMT_ON_STACK(Op)                 \
  void visit(const Op##Ptr& v) override { \
    stmtStack_.push_front(v);             \
    IRVisitor::visit(v);                  \
    stmtStack_.pop_front();               \
  }

  STMT_ON_STACK(AtomicAdd);
  STMT_ON_STACK(Allocate);
  STMT_ON_STACK(Free);

#undef STMT_ON_STACK

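  // Returns the accesses found by the analysis that are candidates for
  // replacement, sorted by the order in which each was first encountered
  // (see accessOrder_ in AccessInfo).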
  std::vector<std::shared_ptr<AccessInfo>> getCandidates();

 private:
  void mergeCurrentScopeIntoParent();
  void mergeHiddenScope(bool allowClosed);
  void closeAccessIntoScope(
      const std::shared_ptr<AccessInfo>& info,
      const std::shared_ptr<Scope>& scope);

  std::unordered_set<size_t> exprConditionals_;

  // A stack of enclosing Stmts for tracking the usage Stmt of Loads.
  std::deque<StmtPtr> stmtStack_;

  // The current scope being analyzed.
  std::shared_ptr<Scope> currentScope_;

  HashProvider hasher_;

  size_t conditionId_{0};
  size_t accessOrder_{0};
};

/* Replaces each registerizable access with a Scalar variable, including
 * definition, initializer and finalizer.
 */
class TORCH_API RegisterizerReplacer : public IRMutator {
 public:
  RegisterizerReplacer(std::vector<std::shared_ptr<AccessInfo>>& vec)
      : infoSet_(vec) {
    buildReplacements();
  }

  ExprPtr mutate(const LoadPtr& v) override;

  StmtPtr mutate(const StorePtr& v) override;

  StmtPtr mutate(const BlockPtr& v) override;

 private:
  struct ReplacerScope {
    std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
        initializerPoints_;
    std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
        finalizePoints_;
  };

  // Creates the various ReplacerScope objects and builds internal maps.
  void buildReplacements();

  // State relating to the accesses yet to be replaced.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::vector<std::shared_ptr<AccessInfo>>& infoSet_;
  std::unordered_map<StorePtr, std::shared_ptr<AccessInfo>> storeToAccess_;
  std::unordered_map<LoadPtr, std::shared_ptr<AccessInfo>> loadToAccess_;
  std::unordered_map<BlockPtr, ReplacerScope> parentToAccesses_;

  // Holds the set of Stores that should be pulled into an initializer, so
  // they can be eliminated.
  std::set<StorePtr> eliminatedIntializers_;

  // Tracks the number of times we've seen each buffer, so we can name the
  // scalar Vars appropriately.
  std::unordered_map<BufPtr, unsigned int> bufferAccessCounts_;
  unsigned int getBufferAccessCount(const BufPtr& b) {
    return ++bufferAccessCounts_[b];
  }
};
} // namespace registerizer

// Apply scalar replacement to all accesses in s.
// To produce safe code, this must occur after handling parallelized axes and
// atomics.
TORCH_API StmtPtr registerize(StmtPtr s);
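//
// A minimal usage sketch (`stmt` here is a hypothetical StmtPtr, not a name
// from this API):
//
//   StmtPtr stmt = ...;  // parallelized axes and atomics already handled
//   stmt = registerize(stmt);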

} // namespace torch::jit::tensorexpr