1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the AffineExpr visitor class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H 14 #define MLIR_IR_AFFINE_EXPR_VISITOR_H 15 16 #include "mlir/IR/AffineExpr.h" 17 18 namespace mlir { 19 20 /// Base class for AffineExpr visitors/walkers. 21 /// 22 /// AffineExpr visitors are used when you want to perform different actions 23 /// for different kinds of AffineExprs without having to use lots of casts 24 /// and a big switch instruction. 25 /// 26 /// To define your own visitor, inherit from this class, specifying your 27 /// new type for the 'SubClass' template parameter, and "override" visitXXX 28 /// functions in your class. This class is defined in terms of statically 29 /// resolved overloading, not virtual functions. 30 /// 31 /// For example, here is a visitor that counts the number of for AffineDimExprs 32 /// in an AffineExpr. 33 /// 34 /// /// Declare the class. Note that we derive from AffineExprVisitor 35 /// /// instantiated with our new subclasses_ type. 36 /// 37 /// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> { 38 /// unsigned numDimExprs; 39 /// DimExprCounter() : numDimExprs(0) {} 40 /// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; } 41 /// }; 42 /// 43 /// And this class would be used like this: 44 /// DimExprCounter dec; 45 /// dec.visit(affineExpr); 46 /// numDimExprs = dec.numDimExprs; 47 /// 48 /// AffineExprVisitor provides visit methods for the following binary affine 49 /// op expressions: 50 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr, 51 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr, 52 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these 53 /// methods will call the general AffineBinaryOpExpr method. 54 /// 55 /// In addition, visit methods are provided for the following affine 56 // expressions: AffineConstantExpr, AffineDimExpr, and 57 // AffineSymbolExpr. 58 /// 59 /// Note that if you don't implement visitXXX for some affine expression type, 60 /// the visitXXX method for Instruction superclass will be invoked. 61 /// 62 /// Note that this class is specifically designed as a template to avoid 63 /// virtual function call overhead. Defining and using a AffineExprVisitor is 64 /// just as efficient as having your own switch instruction over the instruction 65 /// opcode. 66 67 template <typename SubClass, typename RetTy = void> class AffineExprVisitor { 68 //===--------------------------------------------------------------------===// 69 // Interface code - This is the public interface of the AffineExprVisitor 70 // that you use to visit affine expressions... 71 public: 72 // Function to walk an AffineExpr (in post order). walkPostOrder(AffineExpr expr)73 RetTy walkPostOrder(AffineExpr expr) { 74 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, 75 "Must instantiate with a derived type of AffineExprVisitor"); 76 switch (expr.getKind()) { 77 case AffineExprKind::Add: { 78 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 79 walkOperandsPostOrder(binOpExpr); 80 return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); 81 } 82 case AffineExprKind::Mul: { 83 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 84 walkOperandsPostOrder(binOpExpr); 85 return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); 86 } 87 case AffineExprKind::Mod: { 88 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 89 walkOperandsPostOrder(binOpExpr); 90 return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); 91 } 92 case AffineExprKind::FloorDiv: { 93 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 94 walkOperandsPostOrder(binOpExpr); 95 return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); 96 } 97 case AffineExprKind::CeilDiv: { 98 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 99 walkOperandsPostOrder(binOpExpr); 100 return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); 101 } 102 case AffineExprKind::Constant: 103 return static_cast<SubClass *>(this)->visitConstantExpr( 104 expr.cast<AffineConstantExpr>()); 105 case AffineExprKind::DimId: 106 return static_cast<SubClass *>(this)->visitDimExpr( 107 expr.cast<AffineDimExpr>()); 108 case AffineExprKind::SymbolId: 109 return static_cast<SubClass *>(this)->visitSymbolExpr( 110 expr.cast<AffineSymbolExpr>()); 111 } 112 } 113 114 // Function to visit an AffineExpr. visit(AffineExpr expr)115 RetTy visit(AffineExpr expr) { 116 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, 117 "Must instantiate with a derived type of AffineExprVisitor"); 118 switch (expr.getKind()) { 119 case AffineExprKind::Add: { 120 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 121 return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); 122 } 123 case AffineExprKind::Mul: { 124 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 125 return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); 126 } 127 case AffineExprKind::Mod: { 128 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 129 return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); 130 } 131 case AffineExprKind::FloorDiv: { 132 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 133 return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); 134 } 135 case AffineExprKind::CeilDiv: { 136 auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); 137 return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); 138 } 139 case AffineExprKind::Constant: 140 return static_cast<SubClass *>(this)->visitConstantExpr( 141 expr.cast<AffineConstantExpr>()); 142 case AffineExprKind::DimId: 143 return static_cast<SubClass *>(this)->visitDimExpr( 144 expr.cast<AffineDimExpr>()); 145 case AffineExprKind::SymbolId: 146 return static_cast<SubClass *>(this)->visitSymbolExpr( 147 expr.cast<AffineSymbolExpr>()); 148 } 149 llvm_unreachable("Unknown AffineExpr"); 150 } 151 152 //===--------------------------------------------------------------------===// 153 // Visitation functions... these functions provide default fallbacks in case 154 // the user does not specify what to do for a particular instruction type. 155 // The default behavior is to generalize the instruction type to its subtype 156 // and try visiting the subtype. All of this should be inlined perfectly, 157 // because there are no virtual functions to get in the way. 158 // 159 160 // Default visit methods. Note that the default op-specific binary op visit 161 // methods call the general visitAffineBinaryOpExpr visit method. visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)162 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {} visitAddExpr(AffineBinaryOpExpr expr)163 void visitAddExpr(AffineBinaryOpExpr expr) { 164 static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 165 } visitMulExpr(AffineBinaryOpExpr expr)166 void visitMulExpr(AffineBinaryOpExpr expr) { 167 static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 168 } visitModExpr(AffineBinaryOpExpr expr)169 void visitModExpr(AffineBinaryOpExpr expr) { 170 static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 171 } visitFloorDivExpr(AffineBinaryOpExpr expr)172 void visitFloorDivExpr(AffineBinaryOpExpr expr) { 173 static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 174 } visitCeilDivExpr(AffineBinaryOpExpr expr)175 void visitCeilDivExpr(AffineBinaryOpExpr expr) { 176 static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); 177 } visitConstantExpr(AffineConstantExpr expr)178 void visitConstantExpr(AffineConstantExpr expr) {} visitDimExpr(AffineDimExpr expr)179 void visitDimExpr(AffineDimExpr expr) {} visitSymbolExpr(AffineSymbolExpr expr)180 void visitSymbolExpr(AffineSymbolExpr expr) {} 181 182 private: 183 // Walk the operands - each operand is itself walked in post order. walkOperandsPostOrder(AffineBinaryOpExpr expr)184 void walkOperandsPostOrder(AffineBinaryOpExpr expr) { 185 walkPostOrder(expr.getLHS()); 186 walkPostOrder(expr.getRHS()); 187 } 188 }; 189 190 // This class is used to flatten a pure affine expression (AffineExpr, 191 // which is in a tree form) into a sum of products (w.r.t constants) when 192 // possible, and in that process simplifying the expression. For a modulo, 193 // floordiv, or a ceildiv expression, an additional identifier, called a local 194 // identifier, is introduced to rewrite the expression as a sum of product 195 // affine expression. Each local identifier is always and by construction a 196 // floordiv of a pure add/mul affine function of dimensional, symbolic, and 197 // other local identifiers, in a non-mutually recursive way. Hence, every local 198 // identifier can ultimately always be recovered as an affine function of 199 // dimensional and symbolic identifiers (involving floordiv's); note however 200 // that by AffineExpr construction, some floordiv combinations are converted to 201 // mod's. The result of the flattening is a flattened expression and a set of 202 // constraints involving just the local variables. 203 // 204 // d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local 205 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3. 206 // 207 // The simplification performed includes the accumulation of contributions for 208 // each dimensional and symbolic identifier together, the simplification of 209 // floordiv/ceildiv/mod expressions and other simplifications that in turn 210 // happen as a result. A simplification that this flattening naturally performs 211 // is of simplifying the numerator and denominator of floordiv/ceildiv, and 212 // folding a modulo expression to a zero, if possible. Three examples are below: 213 // 214 // (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1 215 // (d0 - d0 mod 4 + 4) mod 4 simplified to 0 216 // (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1 217 // 218 // The way the flattening works for the second example is as follows: d0 % 4 is 219 // replaced by d0 - 4*q with q being introduced: the expression then simplifies 220 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to 221 // zero. Note that an affine expression may not always be expressible purely as 222 // a sum of products involving just the original dimensional and symbolic 223 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that 224 // may not be eliminated after simplification; in such cases, the final 225 // expression can be reconstructed by replacing the local identifiers with their 226 // corresponding explicit form stored in 'localExprs' (note that each of the 227 // explicit forms itself would have been simplified). 228 // 229 // The expression walk method here performs a linear time post order walk that 230 // performs the above simplifications through visit methods, with partial 231 // results being stored in 'operandExprStack'. When a parent expr is visited, 232 // the flattened expressions corresponding to its two operands would already be 233 // on the stack - the parent expression looks at the two flattened expressions 234 // and combines the two. It pops off the operand expressions and pushes the 235 // combined result (although this is done in-place on its LHS operand expr). 236 // When the walk is completed, the flattened form of the top-level expression 237 // would be left on the stack. 238 // 239 // A flattener can be repeatedly used for multiple affine expressions that bind 240 // to the same operands, for example, for all result expressions of an 241 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions 242 // is more efficient than creating a new flattener for each expression since 243 // common identical div and mod expressions appearing across different 244 // expressions are mapped to the same local identifier (same column position in 245 // 'localVarCst'). 246 class SimpleAffineExprFlattener 247 : public AffineExprVisitor<SimpleAffineExprFlattener> { 248 public: 249 // Flattend expression layout: [dims, symbols, locals, constant] 250 // Stack that holds the LHS and RHS operands while visiting a binary op expr. 251 // In future, consider adding a prepass to determine how big the SmallVector's 252 // will be, and linearize this to std::vector<int64_t> to prevent 253 // SmallVector moves on re-allocation. 254 std::vector<SmallVector<int64_t, 8>> operandExprStack; 255 256 unsigned numDims; 257 unsigned numSymbols; 258 259 // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's. 260 unsigned numLocals; 261 262 // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for 263 // which new identifiers were introduced; if the latter do not get canceled 264 // out, these expressions can be readily used to reconstruct the AffineExpr 265 // (tree) form. Note that these expressions themselves would have been 266 // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 267 // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) 268 // ceildiv 2 would be the local expression stored for q. 269 SmallVector<AffineExpr, 4> localExprs; 270 271 SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols); 272 273 virtual ~SimpleAffineExprFlattener() = default; 274 275 // Visitor method overrides. 276 void visitMulExpr(AffineBinaryOpExpr expr); 277 void visitAddExpr(AffineBinaryOpExpr expr); 278 void visitDimExpr(AffineDimExpr expr); 279 void visitSymbolExpr(AffineSymbolExpr expr); 280 void visitConstantExpr(AffineConstantExpr expr); 281 void visitCeilDivExpr(AffineBinaryOpExpr expr); 282 void visitFloorDivExpr(AffineBinaryOpExpr expr); 283 284 // 285 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 286 // 287 // A mod expression "expr mod c" is thus flattened by introducing a new local 288 // variable q (= expr floordiv c), such that expr mod c is replaced with 289 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. 290 void visitModExpr(AffineBinaryOpExpr expr); 291 292 protected: 293 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 294 // The local identifier added is always a floordiv of a pure add/mul affine 295 // function of other identifiers, coefficients of which are specified in 296 // dividend and with respect to a positive constant divisor. localExpr is the 297 // simplified tree expression (AffineExpr) corresponding to the quantifier. 298 virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 299 AffineExpr localExpr); 300 301 private: 302 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 303 // A floordiv is thus flattened by introducing a new local variable q, and 304 // replacing that expression with 'q' while adding the constraints 305 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 306 // FlatAffineConstraints::addLocalFloorDiv). 307 // 308 // A ceildiv is similarly flattened: 309 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 310 void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); 311 312 int findLocalId(AffineExpr localExpr); 313 getNumCols()314 inline unsigned getNumCols() const { 315 return numDims + numSymbols + numLocals + 1; 316 } getConstantIndex()317 inline unsigned getConstantIndex() const { return getNumCols() - 1; } getLocalVarStartIndex()318 inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; } getSymbolStartIndex()319 inline unsigned getSymbolStartIndex() const { return numDims; } getDimStartIndex()320 inline unsigned getDimStartIndex() const { return 0; } 321 }; 322 323 } // end namespace mlir 324 325 #endif // MLIR_IR_AFFINE_EXPR_VISITOR_H 326