#pragma once
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/tensorexpr/hash_provider.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

#include <deque>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch::jit::tensorexpr {
namespace registerizer {

/* The Registerizer performs scalar replacement: it looks for common Stores and
Loads to a single item in a buffer and replaces them with a local temporary
scalar, which is cheaper to read and write.

For example it can replace:

{
  A[0] = 0;
  for(const auto x : c10::irange(10)) {
    A[0] = (A[0]) + x;
  }
}

with:

{
  int A_ = 0;
  for(const auto x : c10::irange(10)) {
    A_ = x + A_;
  }
  A[0] = A_;
}

This is particularly useful on GPUs when parallelizing, since after replacing
loops with metavars we have a lot of accesses like this. */
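/* A minimal usage sketch (hypothetical driver code; `stmt` here is assumed to
   be a StmtPtr produced by earlier lowering passes). The single-call entry
   point is the registerize() function declared at the bottom of this file:

     StmtPtr stmt = ...;        // IR containing repeated accesses such as A[0]
     stmt = registerize(stmt);  // analysis + replacement in one call

   which is roughly equivalent to driving the two phases declared below by
   hand:

     RegisterizerAnalysis analysis;
     stmt->accept(&analysis);
     auto candidates = analysis.getCandidates();
     RegisterizerReplacer replacer(candidates);
     stmt = stmt->accept_mutator(&replacer);
*/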
class Scope;

/* Holds analysis information about accesses to a specific range of a buffer,
 * including the number of loads and stores and the lowest common parent
 * Block.
 */
class AccessInfo {
 public:
  AccessInfo() = default;
  AccessInfo(
      SimplifierHashType h,
      BufPtr b,
      std::vector<ExprPtr> i,
      size_t accessOrder)
      : hash_(h),
        buf_(std::move(b)),
        indices_(std::move(i)),
        store_cost_(alloc<IntImm>(0)),
        load_cost_(alloc<IntImm>(0)),
        accessOrder_(accessOrder) {}

  // Adds a Store to this access, which is in the provided scope.
  void addStore(const StorePtr& store, const std::shared_ptr<Scope>& scope);

  // Adds a Load to this access, which occurs in the usage Stmt in the provided
  // scope.
  void addLoad(
      const LoadPtr& load,
      const std::shared_ptr<Scope>& scope,
      const StmtPtr& usage);

  // Merge another AccessInfo into this one.
  void merge(const std::shared_ptr<AccessInfo>& other);

  // Returns true if the other AccessInfo's bounds may overlap this one.
  bool overlaps(const std::shared_ptr<AccessInfo>& other);

  // Returns true if the indices of this access depend on the provided Var.
  bool dependsOnVar(const VarPtr& v);

  // Clone this AccessInfo, and set this as the new access's hiddenAccess.
  static std::shared_ptr<AccessInfo> cloneWithHiddenInfo(
      const std::shared_ptr<AccessInfo>& orig);

  // Print for debugging.
  void print() const;

  SimplifierHashType hash() const {
    return hash_;
  }

  BufPtr buf() const {
    return buf_;
  }

  const std::vector<ExprPtr>& indices() const {
    return indices_;
  }

  BlockPtr block() const {
    return block_;
  }

  void setEnclosingBlock(BlockPtr b) {
    block_ = std::move(b);
  }

  StmtPtr first_usage() const {
    return first_usage_;
  }

  StmtPtr last_usage() const {
    return last_usage_;
  }

  void setUsageMarks(StmtPtr first, StmtPtr last) {
    first_usage_ = std::move(first);
    last_usage_ = std::move(last);
  }

  bool firstUsageOverlapped() const {
    return firstUsageOverlapped_;
  }

  ExprPtr store_cost() const {
    return store_cost_;
  }

  ExprPtr load_cost() const {
    return load_cost_;
  }

  const std::vector<StorePtr>& stores() const {
    return stores_;
  }

  const std::vector<LoadPtr>& loads() const {
    return loads_;
  }

  // Multiplies the store/load costs by the extent of the enclosing loop, used
  // when the access is hoisted out of it.
  void hoistCosts(const ExprPtr& extent) {
    store_cost_ = IRSimplifier::simplify(alloc<Mul>(store_cost_, extent));
    load_cost_ = IRSimplifier::simplify(alloc<Mul>(load_cost_, extent));
  }

  size_t conditionId() const {
    return conditionId_;
  }

  void setConditionId(size_t c) {
    conditionId_ = c;
  }

  size_t accessOrder() const {
    return accessOrder_;
  }

  std::shared_ptr<AccessInfo> hiddenAccess() const {
    return hiddenAccess_;
  }

  // Holds state relating to the scalar variable we will insert to replace some
  // number of loads and stores.
  struct ScalarReplacement {
    VarPtr var{nullptr};
    BufPtr var_wrapper{nullptr};
    LetPtr initializer{nullptr};
  };

  ScalarReplacement& replacement() {
    return replacement_;
  }

 private:
  SimplifierHashType hash_;
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
  BlockPtr block_{nullptr};

  StmtPtr first_usage_{nullptr};
  StmtPtr last_usage_{nullptr};

  // Whether or not this access is overlapped in the first Stmt it appears in.
  // This means we cannot use its first Store as the initializer.
  bool firstUsageOverlapped_{false};

  // The cost in real ops that this access represents, to enable filtering out
  // accesses that won't save any loads or stores.
  ExprPtr store_cost_;
  ExprPtr load_cost_;

  // The actual Stores and Loads which represent this access.
  // Be careful with these; any mutator will invalidate these pointers.
  std::vector<StorePtr> stores_;
  std::vector<LoadPtr> loads_;

  // An identifier representing the conditional block, if any, this access
  // depends on.
  size_t conditionId_{0};

  // An identifier representing the order this access was first encountered,
  // for sorting returned results.
  size_t accessOrder_{0};

  // Sometimes when traversing the tree we need to record what would happen if
  // we hoisted an access, but sometimes it doesn't work out. This lets us
  // "undo" some mutation and return to the internal hidden AccessInfo.
  // It will be removed after any further additions to this AccessInfo.
  std::shared_ptr<AccessInfo> hiddenAccess_;

  ScalarReplacement replacement_;
};

using AccessHashMap =
    std::unordered_map<SimplifierHashType, std::shared_ptr<AccessInfo>>;
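/* A sketch of how accesses are keyed (hypothetical example, using the IR
   notation from the comment at the top of this file): in

     C[x] = (A[x]) + (A[x]);

   both loads of A[x] hash to the same SimplifierHashType (cf. the
   hash(Buf+Indices) note in Scope below), so they are merged into one
   AccessInfo entry in the AccessHashMap. An access to A[y] in the same scope
   gets its own entry under the same Buf and is compared against the first via
   AccessInfo::overlaps() to decide whether the two accesses cut each other. */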
// Represents a scope block and holds all accesses contained within it.
class Scope {
 public:
  Scope(BlockPtr b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
      : block_(std::move(b)),
        parent_(std::move(parent)),
        conditionId_(conditionId) {}

  AccessHashMap& getAccessMapByBuf(const BufPtr& b);

  std::unordered_map<BufPtr, AccessHashMap>& openAccesses() {
    return openAccesses_;
  }

  std::vector<std::shared_ptr<AccessInfo>>& closedAccesses() {
    return closedAccesses_;
  }

  BlockPtr block() const {
    return block_;
  }

  std::shared_ptr<Scope> parent() const {
    return parent_;
  }

  size_t conditionId() const {
    return conditionId_;
  }

  const std::unordered_set<VarPtr>& localVars() const {
    return localVars_;
  }

  void addLocalVar(VarPtr v) {
    localVars_.insert(std::move(v));
  }

  void closeAccess(const std::shared_ptr<AccessInfo>& info);

  void filterClosed();

 private:
  // Map of maps to accesses, narrowing first by Buf and then by
  // hash(Buf+Indices). This allows us to find a candidate access easily, and
  // also to check for overlap with other accesses to the same Buf:
  //   Buf ->
  //     Hash ->
  //       Access
  std::unordered_map<BufPtr, AccessHashMap> openAccesses_;
  std::vector<std::shared_ptr<AccessInfo>> closedAccesses_;

  // The Block object this scope represents.
  BlockPtr block_;

  // The enclosing scope object.
  std::shared_ptr<Scope> parent_;

  // An identifier representing the condition block this scope depends on.
  size_t conditionId_;

  // A set of variables local to this scope (e.g. loop vars).
  std::unordered_set<VarPtr> localVars_;
};

/* Analyzes the graph and collects accesses to the same symbolic tensor element
 * which can be replaced by a single local scalar.
 *
 * This works by recursively walking the tree in postfix order, building sets
 * of accesses to the same symbolic element by scope and then merging lower
 * scopes into their enclosing scope.
 *
 * It is safe to move two accesses of the same Tensor element to a local scalar
 * Var if between all usages of the element there are no other Loads or Stores
 * that may refer to it. In the comments I refer to this as overlapping the
 * access, or "cutting" the existing AccessInfo. In the case where a candidate
 * for registerization is cut, it may be possible to finalize the access early
 * by writing it back to the Tensor and then create a new scalar variable after
 * the overlapping access is complete. We will attempt to do this when it saves
 * memory accesses.
 *
 * There are a few cases that make this more challenging:
 *
 *  - For: Loops multiply the number of real usages of a buffer by the loop
 *    extent, but only if we can pull the definition and finalization of the
 *    scalar variable out of the loop block.
 *
 *  - Cond: Conditions complicate lifting scalars out of internal scopes.
 *    Generally we cannot lift an access outside of a conditional scope unless
 *    there is already a reference to that same access at the higher scope,
 *    since we don't know if the condition was guarding an array access that is
 *    not safe at the higher scope. In the comments I refer to this as the
 *    condition "hiding" the access, and the outer access "unhiding" it.
 *
 *  - IfThenElse: Same situation as Cond, except that since IfThenElse is an
 *    Expr rather than a Stmt we cannot insert the scalar definition or
 *    finalizer within the conditional scope. Accesses inside an IfThenElse can
 *    be safely combined with external accesses, but cannot exist completely
 *    within it.
 *
 *  - Let: Accesses dependent on local variables via Let Stmts, or on loop
 *    vars, cannot be raised outside of the scope of the dependent var.
 *
 * A sketch of the Cond case follows this comment block.
 */
class TORCH_API RegisterizerAnalysis : public IRVisitor {
 public:
  RegisterizerAnalysis()
      : currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
  ~RegisterizerAnalysis() override = default;

  void visit(const ForPtr& v) override;

  void visit(const CondPtr& v) override;

  void visit(const BlockPtr& v) override;

  void visit(const StorePtr& v) override;

  void visit(const LoadPtr& v) override;

  void visit(const IfThenElsePtr& v) override;

  void visit(const LetPtr& v) override;

#define STMT_ON_STACK(Op)                   \
  void visit(const Op##Ptr& v) override {   \
    stmtStack_.push_front(v);               \
    IRVisitor::visit(v);                    \
    stmtStack_.pop_front();                 \
  }

  STMT_ON_STACK(AtomicAdd);
  STMT_ON_STACK(Allocate);
  STMT_ON_STACK(Free);

#undef STMT_ON_STACK

  std::vector<std::shared_ptr<AccessInfo>> getCandidates();

 private:
  void mergeCurrentScopeIntoParent();
  void mergeHiddenScope(bool allowClosed);
  void closeAccessIntoScope(
      const std::shared_ptr<AccessInfo>& info,
      const std::shared_ptr<Scope>& scope);

  std::unordered_set<size_t> exprConditionals_;

  // A stack of enclosing Stmts for tracking the usage Stmt of Loads.
  std::deque<StmtPtr> stmtStack_;

  // The current scope being analyzed.
  std::shared_ptr<Scope> currentScope_;

  HashProvider hasher_;

  size_t conditionId_{0};
  size_t accessOrder_{0};
};

/* Replaces each registerizable access with a scalar variable, including its
 * definition, initializer and finalizer. A sketch of the result follows this
 * comment.
 */
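/* For example (a sketch using the notation from the top of this file): when
   the first usage of the element is a Load rather than a Store, the scalar is
   initialized from the buffer and written back after its last usage, so

     for(const auto x : c10::irange(10)) {
       B[x] = (A[0]) + x;
       A[0] = (A[0]) + 1;
     }

   becomes roughly

     int A_ = A[0];
     for(const auto x : c10::irange(10)) {
       B[x] = A_ + x;
       A_ = A_ + 1;
     }
     A[0] = A_;
*/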
class TORCH_API RegisterizerReplacer : public IRMutator {
 public:
  RegisterizerReplacer(std::vector<std::shared_ptr<AccessInfo>>& vec)
      : infoSet_(vec) {
    buildReplacements();
  }

  ExprPtr mutate(const LoadPtr& v) override;

  StmtPtr mutate(const StorePtr& v) override;

  StmtPtr mutate(const BlockPtr& v) override;

 private:
  struct ReplacerScope {
    std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
        initializerPoints_;
    std::unordered_map<StmtPtr, std::deque<std::shared_ptr<AccessInfo>>>
        finalizePoints_;
  };

  // Creates the various ReplacerScope objects and builds internal maps.
  void buildReplacements();

  // State relating to the accesses yet to be replaced.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::vector<std::shared_ptr<AccessInfo>>& infoSet_;
  std::unordered_map<StorePtr, std::shared_ptr<AccessInfo>> storeToAccess_;
  std::unordered_map<LoadPtr, std::shared_ptr<AccessInfo>> loadToAccess_;
  std::unordered_map<BlockPtr, ReplacerScope> parentToAccesses_;

  // Holds the set of Stores that should be pulled into an initializer, so they
  // can be eliminated.
  std::set<StorePtr> eliminatedIntializers_;

  // Tracks the number of times we've seen each buffer, so we can name the
  // scalar Vars appropriately.
  std::unordered_map<BufPtr, unsigned int> bufferAccessCounts_;
  unsigned int getBufferAccessCount(const BufPtr& b) {
    return ++bufferAccessCounts_[b];
  }
};
} // namespace registerizer

// Apply scalar replacement to all accesses in s.
// To produce safe code, this must occur after handling parallelized axes and
// atomics.
TORCH_API StmtPtr registerize(StmtPtr s);

} // namespace torch::jit::tensorexpr