1 //===- Sparsification.cpp - Implementation of linalg sparsification -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering annotated linalg dialect to sparse code.
10 //
11 // The concept of letting a compiler generate sparse code automatically was
12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
14 // Algebra Compiler (TACO). The implementation in this file closely follows
15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting
16 // rule is applied to each tensor expression in linalg (MLIR's tensor index
17 // notation) where the sparsity of tensors is indicated with annotation using
18 // a per-dimension specification of sparse/dense storage together with a
19 // specification of the order on the dimensions. Subsequently, a topologically
20 // sorted iteration graph, reflecting the required order on indices with respect
21 // to the dimensions of each tensor, is constructed to ensure that all tensors
22 // are visited in natural index order. Next, iteration lattices are constructed
23 // for the tensor expression for every index in topological order. Each
24 // iteration lattice point consists of a conjunction of tensor indices together
25 // with a tensor (sub)expression that needs to be evaluated for that
26 // conjunction. Within the lattice, iteration points are ordered according to
27 // the way indices are exhausted. As such these iteration lattices drive actual
28 // sparse code generation, which consists of a tedious but relatively
29 // straightforward one-to-one mapping from iteration lattices to combinations
30 // of for-loops, while-loops, and if-statements.
31 //
32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
36 // Proceedings of the ACM on Programming Languages, October 2017.
37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org).
39 //
40 // Implementation detail: We use llvm::SmallVector for vectors with
41 // variable lengths and std::vector for vectors with fixed lengths.
42 //===----------------------------------------------------------------------===//
43
44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
45 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
46 #include "mlir/Dialect/Linalg/Utils/Utils.h"
47 #include "mlir/Dialect/SCF/SCF.h"
48 #include "mlir/Dialect/StandardOps/IR/Ops.h"
49
50 using namespace mlir;
51
52 namespace {
53
/// Kinds of tensor expression nodes. The enumerator order is significant:
/// every kind from kMulF onward denotes a binary operation (see the
/// `kind >= Kind::kMulF` check in TensorExp).
enum class Kind {
  kTensor,
  kInvariant,
  kMulF,
  kMulI,
  kAddF,
  kAddI,
};
55
56 /// Tensor expression. Represents a MLIR expression in tensor index notation.
57 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
58 /// stored directly. For binary operations, e0 and e1 denote the index of the
59 /// children tensor expressions.
60 struct TensorExp {
TensorExp__anon4430b2680111::TensorExp61 TensorExp(Kind k, unsigned x, unsigned y, Value v)
62 : kind(k), e0(x), e1(y), val(v) {
63 assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
64 (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
65 (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
66 }
67 Kind kind;
68 /// Indices of children expression(s).
69 unsigned e0;
70 unsigned e1;
71 /// Direct link to IR for an invariant. During code generation,
72 /// field is used to cache "hoisted" loop invariant tensor loads.
73 Value val;
74 };
75
76 /// Lattice point. Each lattice point consists of a conjunction of tensor
77 /// loop indices (encoded in a bitvector) and the index of the corresponding
78 /// tensor expression.
79 struct LatPoint {
LatPoint__anon4430b2680111::LatPoint80 LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
81 bits.set(b);
82 }
LatPoint__anon4430b2680111::LatPoint83 LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
84 /// Conjunction of tensor loop indices as bitvector.
85 llvm::BitVector bits;
86 /// Index of the tensor expresssion.
87 unsigned exp;
88 };
89
90 /// A class to handle all iteration lattice operations. This class abstracts
91 /// away from some implementation details of storing iteration lattices and
92 /// tensor expressions. This allows for fine-tuning performance characteristics
93 /// independently from the basic algorithm if bottlenecks are identified.
94 class Merger {
95 public:
Merger(unsigned t,unsigned l)96 Merger(unsigned t, unsigned l)
97 : numTensors(t), numLoops(l), isSparse(t, std::vector<bool>(l, false)) {}
98
99 /// Adds a tensor expression. Returns its index.
addExp(Kind k,unsigned e0,unsigned e1=-1u,Value v=Value ())100 unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
101 unsigned e = tensorExps.size();
102 tensorExps.push_back(TensorExp(k, e0, e1, v));
103 return e;
104 }
addExp(Kind k,Value v)105 unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
106
107 /// Adds an iteration lattice point. Returns its index.
addLat(unsigned t,unsigned i,unsigned e)108 unsigned addLat(unsigned t, unsigned i, unsigned e) {
109 assert(t < numTensors && i < numLoops);
110 unsigned p = latPoints.size();
111 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
112 return p;
113 }
114
115 /// Adds a new, initially empty, set. Returns its index.
addSet()116 unsigned addSet() {
117 unsigned s = latSets.size();
118 latSets.emplace_back(SmallVector<unsigned, 16>());
119 return s;
120 }
121
122 /// Computes a single conjunction of two lattice points by taking the "union"
123 /// of loop indices (effectively constucting a larger "intersection" of those
124 /// indices) with a newly constructed tensor (sub)expression of given kind.
125 /// Returns the index of the new lattice point.
conjLatPoint(Kind kind,unsigned p0,unsigned p1)126 unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
127 unsigned p = latPoints.size();
128 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
129 nb |= latPoints[p1].bits;
130 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
131 latPoints.push_back(LatPoint(nb, e));
132 return p;
133 }
134
135 /// Conjunctive merge of L1 and L2 is conjunction of cartesian product.
136 /// Returns the index of the new set.
takeConj(Kind kind,unsigned s0,unsigned s1)137 unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
138 unsigned s = addSet();
139 for (unsigned p0 : latSets[s0])
140 for (unsigned p1 : latSets[s1])
141 latSets[s].push_back(conjLatPoint(kind, p0, p1));
142 return s;
143 }
144
145 /// Disjunctive merge of L0 and L1 is (L0 /\_op L1, L0, L1).
146 /// Returns the index of the new set.
takeDisj(Kind kind,unsigned s0,unsigned s1)147 unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
148 unsigned s = takeConj(kind, s0, s1);
149 for (unsigned p : latSets[s0])
150 latSets[s].push_back(p);
151 for (unsigned p : latSets[s1])
152 latSets[s].push_back(p);
153 return s;
154 }
155
156 /// Optimizes the iteration lattice points in the given set. This
157 /// method should be called right before code generation to avoid
158 /// generating redundant loops and conditions.
optimize(unsigned s0)159 unsigned optimize(unsigned s0) {
160 unsigned s = addSet();
161 assert(latSets[s0].size() != 0);
162 unsigned p0 = latSets[s0][0];
163 for (unsigned p1 : latSets[s0]) {
164 bool add = true;
165 if (p0 != p1) {
166 // Is this a straightforward copy?
167 unsigned e = latPoints[p1].exp;
168 if (exp(e).kind == Kind::kTensor && exp(e).e0 == numTensors - 1)
169 continue;
170 // Is any dense index exhausted?
171 llvm::BitVector tmp = latPoints[p1].bits;
172 tmp ^= latPoints[p0].bits;
173 if (hasAnyOf(tmp, false))
174 continue;
175 // Is this a direct duplication of an earlier conjunction?
176 for (unsigned p2 : latSets[s]) {
177 tmp = latPoints[p1].bits;
178 tmp ^= latPoints[p2].bits;
179 if (tmp.count() == 0) {
180 add = false;
181 break;
182 }
183 }
184 assert(!add || latGT(p0, p1));
185 }
186 if (add)
187 latSets[s].push_back(p1);
188 }
189 return s;
190 }
191
192 // Returns true if Li > Lj.
latGT(unsigned i,unsigned j) const193 bool latGT(unsigned i, unsigned j) const {
194 const llvm::BitVector &bitsi = latPoints[i].bits;
195 const llvm::BitVector &bitsj = latPoints[j].bits;
196 assert(bitsi.size() == bitsj.size());
197 if (bitsi.count() > bitsj.count()) {
198 for (unsigned b = 0, be = bitsj.size(); b < be; b++)
199 if (bitsj[b] && !bitsi[b])
200 return false;
201 return true;
202 }
203 return false;
204 }
205
206 // Bit translation.
tensor(unsigned b) const207 unsigned tensor(unsigned b) const { return b % numTensors; }
index(unsigned b) const208 unsigned index(unsigned b) const { return b / numTensors; }
209
210 // Returns true if bit corresponds to sparse access.
isSparseBit(unsigned b) const211 bool isSparseBit(unsigned b) const {
212 return isSparseAccess(tensor(b), index(b));
213 }
214
215 // Returns true if tensor access at given index is sparse.
isSparseAccess(unsigned t,unsigned i) const216 bool isSparseAccess(unsigned t, unsigned i) const {
217 assert(t < numTensors && i < numLoops);
218 return isSparse[t][i];
219 }
220
221 // Returns true if any set bit corresponds to sparse/dense access.
hasAnyOf(const llvm::BitVector & bits,bool sparse) const222 bool hasAnyOf(const llvm::BitVector &bits, bool sparse) const {
223 for (unsigned b = 0, be = bits.size(); b < be; b++)
224 if (bits[b] && isSparseBit(b) == sparse)
225 return true;
226 return false;
227 }
228
229 // Getters.
sparse()230 std::vector<std::vector<bool>> &sparse() { return isSparse; }
exp(unsigned e)231 TensorExp &exp(unsigned e) { return tensorExps[e]; }
lat(unsigned l)232 LatPoint &lat(unsigned l) { return latPoints[l]; }
set(unsigned s)233 SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
234
235 private:
236 const unsigned numTensors;
237 const unsigned numLoops;
238
239 std::vector<std::vector<bool>> isSparse;
240 llvm::SmallVector<TensorExp, 32> tensorExps;
241 llvm::SmallVector<LatPoint, 16> latPoints;
242 llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
243 };
244
245 // Code generation.
246 struct CodeGen {
CodeGen__anon4430b2680111::CodeGen247 CodeGen(linalg::SparsificationOptions o, unsigned numTensors,
248 unsigned numLoops)
249 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
250 pointers(numTensors, std::vector<Value>(numLoops)),
251 indices(numTensors, std::vector<Value>(numLoops)),
252 highs(numTensors, std::vector<Value>(numLoops)),
253 pidxs(numTensors, std::vector<Value>(numLoops)),
254 idxs(numTensors, std::vector<Value>(numLoops)) {}
255 // Sparsification options.
256 linalg::SparsificationOptions options;
257 // Universal dense indices and upper bounds (by index). The loops array
258 // is updated with the value of the universal dense index in the current
259 // loop. The sizes array is set once with the inferred dimension sizes.
260 std::vector<Value> loops;
261 std::vector<Value> sizes;
262 // Buffers for storing dense and sparse numerical values (by tensor).
263 // This array is set once during bufferization of all tensors.
264 std::vector<Value> buffers;
265 // Sparse storage schemes (1-D): pointers and indices (by tensor and index).
266 // This array is set once during bufferization of all sparse tensors.
267 std::vector<std::vector<Value>> pointers;
268 std::vector<std::vector<Value>> indices;
269 // Sparse iteration information (by tensor and index). These arrays
270 // are updated to remain current within the current loop.
271 std::vector<std::vector<Value>> highs;
272 std::vector<std::vector<Value>> pidxs;
273 std::vector<std::vector<Value>> idxs;
274 };
275
276 } // namespace
277
278 /// Helper method to inspect sparse annotations in the linalg operation.
279 /// Fills the per-dimension sparsity information for all tensors.
findSparseAnnotations(linalg::GenericOp op,std::vector<std::vector<bool>> & isSparse)280 static void findSparseAnnotations(linalg::GenericOp op,
281 std::vector<std::vector<bool>> &isSparse) {
282 unsigned numTensors = op.getNumInputsAndOutputs();
283 ArrayAttr sparseAttr = op.sparseAttr();
284 for (unsigned t = 0; t < numTensors; t++) {
285 auto map = op.getIndexingMap(t);
286 auto dimAttr = sparseAttr[t].cast<ArrayAttr>();
287 // For each tensor, we accept a per-dimension Sparse or Dense annotation.
288 // This is translated to the loop index that indexes that dimension.
289 unsigned rank = op.getShapedType(t).getRank();
290 for (unsigned d = 0; d < rank; d++)
291 if (isSparseDim(dimAttr[d])) {
292 unsigned idx = map.getDimPosition(d);
293 isSparse[t][idx] = true;
294 } else {
295 assert(isDenseDim(dimAttr[d]));
296 }
297 }
298 }
299
/// Depth-first helper for the topological sort of the iteration graph.
/// `visit` tracks node state (0 = unvisited, 1 = on the current DFS path,
/// 2 = finished). Each finished node is appended to `topSort` in post-order.
/// Recursion depth is bounded by the number of implicit loops, which is
/// always small. Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] == 1)
    return false; // back edge: cycle!
  if (visit[i] != 0)
    return true; // already finished
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j] && !topSortDFS(j, visit, topSort, adjM))
      return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}
317
318 /// Computes a topologically sorted iteration graph for the linalg operation.
319 /// Ensures all tensors are visited in natural index order. This is essential
320 /// for sparse storage formats since these only support access along fixed
321 /// dimensions. Even for dense storage formats, however, the natural index
322 /// order yields innermost unit-stride access with better spatial locality.
computeIterationGraph(linalg::GenericOp op,std::vector<unsigned> & topSort)323 static bool computeIterationGraph(linalg::GenericOp op,
324 std::vector<unsigned> &topSort) {
325 // Set up an n x n from/to adjacency matrix of the iteration graph
326 // for the implicit loop indices i_0 .. i_n-1.
327 unsigned n = op.getNumLoops();
328 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
329
330 // Iterate over the indexing maps of every tensor in the tensor expression.
331 for (auto imap : llvm::enumerate(op.indexing_maps())) {
332 auto map = imap.value().template cast<AffineMapAttr>().getValue();
333 assert(map.getNumDims() == n);
334 // At the moment, we take the index variables in the tensor access
335 // expression in the order in which they appear (conceptually a
336 // "row-major" layout of every tensor). So, a tensor access A_ijk
337 // forces the ordering i < j < k on the loop indices.
338 // TODO: support affine map to define alternative dimension orders.
339 for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
340 unsigned f = map.getDimPosition(d - 1);
341 unsigned t = map.getDimPosition(d);
342 adjM[f][t] = true;
343 }
344 }
345
346 // Topologically sort the iteration graph to determine loop order.
347 // Report failure for a cyclic iteration graph.
348 topSort.reserve(n);
349 std::vector<unsigned> visit(n, 0);
350 for (unsigned i = 0; i < n; i++)
351 if (visit[i] == 0)
352 if (!topSortDFS(i, visit, topSort, adjM))
353 return false; // cycle!
354 std::reverse(std::begin(topSort), std::end(topSort));
355 return true;
356 }
357
358 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
359 /// This simplifies constructing (sub)expressions during iteration lattice
360 /// building (compared to using the SSA representation everywhere).
buildTensorExp(Merger & merger,linalg::GenericOp op,Value val)361 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
362 Value val) {
363 if (auto arg = val.dyn_cast<BlockArgument>()) {
364 unsigned argN = arg.getArgNumber();
365 if (arg.getOwner()->getParentOp() == op) {
366 // Any parameter of the generic op is considered a tensor,
367 // indexed by the implicit loop bounds.
368 auto map = op.getIndexingMap(argN);
369 if (map.isProjectedPermutation())
370 return merger.addExp(Kind::kTensor, argN);
371 // Cannot handle (yet).
372 return None;
373 }
374 // Any parameter of a higher op is invariant.
375 return merger.addExp(Kind::kInvariant, val);
376 }
377 Operation *def = val.getDefiningOp();
378 if (def->getBlock() != &op.region().front()) {
379 // Something defined outside is invariant.
380 return merger.addExp(Kind::kInvariant, val);
381 } else if (def->getNumOperands() == 2) {
382 // Construct binary operations if subexpressions could be built.
383 auto x = buildTensorExp(merger, op, def->getOperand(0));
384 auto y = buildTensorExp(merger, op, def->getOperand(1));
385 if (x.hasValue() && y.hasValue()) {
386 unsigned e0 = x.getValue();
387 unsigned e1 = y.getValue();
388 if (isa<MulFOp>(def))
389 return merger.addExp(Kind::kMulF, e0, e1);
390 if (isa<MulIOp>(def))
391 return merger.addExp(Kind::kMulI, e0, e1);
392 if (isa<AddFOp>(def))
393 return merger.addExp(Kind::kAddF, e0, e1);
394 if (isa<AddIOp>(def))
395 return merger.addExp(Kind::kAddI, e0, e1);
396 }
397 }
398 // Cannot build (yet).
399 return None;
400 }
401
402 /// Builds the iteration lattices in a bottom-up traversal given the remaining
403 /// tensor (sub)expression and the next loop index in the iteration graph.
buildLattices(Merger & merger,linalg::GenericOp op,unsigned exp,unsigned idx)404 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
405 unsigned exp, unsigned idx) {
406 Kind kind = merger.exp(exp).kind;
407 if (kind == Kind::kTensor || kind == Kind::kInvariant) {
408 // Either the index is really used in the tensor expression, or it is
409 // set to the "non-existing dense index" in that dimension. Invariant
410 // expressions borrow the output tensor indices.
411 unsigned s = merger.addSet();
412 unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
413 : op.getNumInputsAndOutputs() - 1;
414 merger.set(s).push_back(merger.addLat(t, idx, exp));
415 return s;
416 }
417 unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
418 unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
419 switch (kind) {
420 case Kind::kTensor:
421 case Kind::kInvariant:
422 llvm_unreachable("handled above");
423 case Kind::kMulF:
424 case Kind::kMulI:
425 return merger.takeConj(kind, s0, s1);
426 case Kind::kAddF:
427 case Kind::kAddI:
428 return merger.takeDisj(kind, s0, s1);
429 }
430 }
431
432 /// Maps sparse integer option to actual integral storage type.
genIntType(PatternRewriter & rewriter,linalg::SparseIntType tp)433 static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
434 switch (tp) {
435 case linalg::SparseIntType::kNative:
436 return rewriter.getIndexType();
437 case linalg::SparseIntType::kI64:
438 return rewriter.getIntegerType(64);
439 case linalg::SparseIntType::kI32:
440 return rewriter.getIntegerType(32);
441 }
442 }
443
444 /// Local bufferization of all dense and sparse data structures.
445 /// This code enables testing the first prototype sparse compiler.
446 // TODO: replace this with a proliferated bufferization strategy
genBuffers(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op)447 static void genBuffers(Merger &merger, CodeGen &codegen,
448 PatternRewriter &rewriter, linalg::GenericOp op) {
449 Location loc = op.getLoc();
450 unsigned numTensors = op.getNumInputsAndOutputs();
451 unsigned numInputs = op.getNumInputs();
452 assert(numTensors == numInputs + 1);
453
454 // For now, set all unknown dimensions to 999.
455 // TODO: compute these values (using sparsity or by reading tensor)
456 Value unknown = rewriter.create<ConstantIndexOp>(loc, 999);
457
458 // For every tensor, find lower and upper bound on dimensions, set the
459 // same bounds on loop indices, and allocate dense or sparse buffer(s).
460 SmallVector<Value, 4> args;
461 for (unsigned t = 0; t < numTensors; t++) {
462 auto tensorType = op.getShapedType(t);
463 auto shape = tensorType.getShape();
464 auto map = op.getIndexingMap(t);
465 // Scan all dimensions of current tensor.
466 bool allDense = true;
467 args.clear();
468 for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
469 unsigned i = map.getDimPosition(d);
470 // Handle sparse storage schemes.
471 if (merger.isSparseAccess(t, i)) {
472 allDense = false;
473 auto dynShape = {ShapedType::kDynamicSize};
474 auto ptrTp = MemRefType::get(
475 dynShape, genIntType(rewriter, codegen.options.ptrType));
476 auto indTp = MemRefType::get(
477 dynShape, genIntType(rewriter, codegen.options.indType));
478 codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, ptrTp, unknown);
479 codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, indTp, unknown);
480 }
481 // Find lower and upper bound in current dimension.
482 Value up;
483 if (shape[d] == TensorType::kDynamicSize) {
484 // For the output tensor, we may need to infer the upper bound.
485 // For all others, we look at the incoming argument.
486 if (t == numInputs && !op.getNumInitTensors()) {
487 up = codegen.sizes[i];
488 assert(up); // TODO: what else?
489 } else {
490 Value arg = t < numInputs ? op.getInput(t) : op.getInitTensor(0);
491 up = rewriter.create<DimOp>(loc, arg, d);
492 }
493 args.push_back(up);
494 } else {
495 up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
496 }
497 codegen.sizes[i] = codegen.highs[t][i] = up;
498 }
499 // Allocate dense or sparse buffer for numerical values.
500 if (allDense) {
501 auto denseTp = MemRefType::get(shape, tensorType.getElementType());
502 codegen.buffers[t] = rewriter.create<AllocaOp>(loc, denseTp, args);
503 } else {
504 auto sparseTp = MemRefType::get({ShapedType::kDynamicSize},
505 tensorType.getElementType());
506 codegen.buffers[t] = rewriter.create<AllocaOp>(loc, sparseTp, unknown);
507 }
508 }
509 }
510
511 /// Generates a load on a dense or sparse tensor.
genTensorLoad(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned exp)512 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
513 PatternRewriter &rewriter, linalg::GenericOp op,
514 unsigned exp) {
515 // Test if the load was hoisted to a higher loop nest.
516 Value val = merger.exp(exp).val;
517 if (val) {
518 merger.exp(exp).val = Value(); // reset
519 return val;
520 }
521 // Actual load.
522 SmallVector<Value, 4> args;
523 unsigned tensor = merger.exp(exp).e0;
524 auto map = op.getIndexingMap(tensor);
525 bool sparse = false;
526 for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
527 unsigned idx = map.getDimPosition(i);
528 args.push_back(codegen.loops[idx]); // universal dense index
529 if (sparse || merger.isSparseAccess(tensor, idx)) {
530 sparse = true;
531 args.clear();
532 args.push_back(codegen.pidxs[tensor][idx]); // position index
533 }
534 }
535 Location loc = op.getLoc();
536 Value ptr = codegen.buffers[tensor];
537 return rewriter.create<LoadOp>(loc, ptr, args);
538 }
539
540 /// Generates a store on a dense tensor.
genTensorStore(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned tensor,Value rhs)541 static void genTensorStore(Merger &merger, CodeGen &codegen,
542 PatternRewriter &rewriter, linalg::GenericOp op,
543 unsigned tensor, Value rhs) {
544 SmallVector<Value, 4> args;
545 auto map = op.getIndexingMap(tensor);
546 for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
547 unsigned idx = map.getDimPosition(i);
548 args.push_back(codegen.loops[idx]); // universal dense index
549 }
550 Location loc = op.getLoc();
551 Value ptr = codegen.buffers[tensor];
552 rewriter.create<StoreOp>(loc, rhs, ptr, args);
553 }
554
555 /// Generates a pointer/index load from the sparse storage scheme.
genLoad(PatternRewriter & rewriter,Location loc,Value ptr,Value s)556 static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
557 Value s) {
558 Value load = rewriter.create<LoadOp>(loc, ptr, s);
559 return load.getType().isa<IndexType>()
560 ? load
561 : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
562 }
563
564 /// Generates an invariant value.
genInvariantValue(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,unsigned exp)565 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
566 PatternRewriter &rewriter, unsigned exp) {
567 return merger.exp(exp).val;
568 }
569
570 /// Recursively generates tensor expression.
genExp(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned exp)571 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
572 linalg::GenericOp op, unsigned exp) {
573 if (merger.exp(exp).kind == Kind::kTensor)
574 return genTensorLoad(merger, codegen, rewriter, op, exp);
575 else if (merger.exp(exp).kind == Kind::kInvariant)
576 return genInvariantValue(merger, codegen, rewriter, exp);
577 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
578 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
579 switch (merger.exp(exp).kind) {
580 case Kind::kTensor:
581 case Kind::kInvariant:
582 llvm_unreachable("handled above");
583 case Kind::kMulF:
584 return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
585 case Kind::kMulI:
586 return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
587 case Kind::kAddF:
588 return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
589 case Kind::kAddI:
590 return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
591 }
592 }
593
594 /// Hoists loop invariant tensor loads for which indices have been exhausted.
genInvariants(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned exp)595 static void genInvariants(Merger &merger, CodeGen &codegen,
596 PatternRewriter &rewriter, linalg::GenericOp op,
597 unsigned exp) {
598 if (merger.exp(exp).kind == Kind::kTensor) {
599 unsigned lhs = op.getNumInputsAndOutputs() - 1;
600 unsigned tensor = merger.exp(exp).e0;
601 if (tensor == lhs)
602 return; // TODO: scalarize reduction as well (using scf.yield)
603 auto map = op.getIndexingMap(tensor);
604 for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
605 unsigned idx = map.getDimPosition(i);
606 if (!codegen.loops[idx])
607 return; // still in play
608 }
609 // All exhausted at this level.
610 merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);
611
612 } else if (merger.exp(exp).kind != Kind::kInvariant) {
613 // Traverse into the binary operations. Note that we only hoist
614 // tensor loads, since subsequent MLIR/LLVM passes know how to
615 // deal with all other kinds of derived loop invariants.
616 genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
617 genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
618 }
619 }
620
621 /// Generates initialization code for the subsequent loop sequence at
622 /// current index level. Returns true if the loop sequence needs to
623 /// maintain the universal index.
genInit(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,std::vector<unsigned> & topSort,unsigned at,llvm::BitVector & inits)624 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
625 linalg::GenericOp op, std::vector<unsigned> &topSort,
626 unsigned at, llvm::BitVector &inits) {
627 bool needsUniv = false;
628 Location loc = op.getLoc();
629 unsigned idx = topSort[at];
630
631 // Initialize sparse positions.
632 for (unsigned b = 0, be = inits.size(); b < be; b++) {
633 if (inits[b]) {
634 unsigned tensor = merger.tensor(b);
635 assert(idx == merger.index(b));
636 if (merger.isSparseBit(b)) {
637 // Initialize sparse index.
638 unsigned pat = at;
639 for (; pat != 0; pat--) {
640 if (codegen.pidxs[tensor][topSort[pat - 1]])
641 break;
642 }
643 Value ptr = codegen.pointers[tensor][idx];
644 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
645 Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
646 : codegen.pidxs[tensor][topSort[pat - 1]];
647 codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
648 Value p1 = rewriter.create<AddIOp>(loc, p0, one);
649 codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
650 } else {
651 // Dense index still in play.
652 needsUniv = true;
653 }
654 }
655 }
656
657 // Initialize the universal dense index.
658 codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
659 return needsUniv;
660 }
661
662 /// Generates a for-loop on a single index.
genFor(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,bool isOuter,bool isInner,unsigned idx,llvm::BitVector & indices)663 static Operation *genFor(Merger &merger, CodeGen &codegen,
664 PatternRewriter &rewriter, linalg::GenericOp op,
665 bool isOuter, bool isInner, unsigned idx,
666 llvm::BitVector &indices) {
667 unsigned fb = indices.find_first();
668 unsigned tensor = merger.tensor(fb);
669 assert(idx == merger.index(fb));
670
671 // Parallelization strategy. Any implicit loop in the Linalg operation that
672 // is marked "parallel" is a candidate. Whether it is actually converted to
673 // a parallel operation depends on the requested strategy.
674 auto iteratorTypes = op.iterator_types().getValue();
675 bool isSparse = merger.isSparseBit(fb);
676 bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
677 switch (codegen.options.parallelizationStrategy) {
678 case linalg::SparseParallelizationStrategy::kNone:
679 isParallel = false;
680 break;
681 case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
682 isParallel &= isOuter && !isSparse;
683 break;
684 case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
685 isParallel &= isOuter;
686 break;
687 case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
688 isParallel &= !isSparse;
689 break;
690 case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
691 break;
692 }
693
694 // Loop bounds and increment.
695 Location loc = op.getLoc();
696 Value lo;
697 Value hi;
698 Value step = rewriter.create<ConstantIndexOp>(loc, 1);
699 Value index;
700 if (isSparse) {
701 lo = codegen.pidxs[tensor][idx];
702 hi = codegen.highs[tensor][idx];
703 } else {
704 lo = codegen.loops[idx];
705 hi = codegen.sizes[idx];
706 }
707
708 // Emit a parallel loop.
709 if (isParallel) {
710 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
711 if (isSparse)
712 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
713 else
714 codegen.loops[idx] = parOp.getInductionVars()[0];
715 rewriter.setInsertionPointToStart(parOp.getBody());
716 return parOp;
717 }
718
719 // Emit a sequential loop.
720 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
721 if (isSparse)
722 codegen.pidxs[tensor][idx] = forOp.getInductionVar();
723 else
724 codegen.loops[idx] = forOp.getInductionVar();
725 rewriter.setInsertionPointToStart(forOp.getBody());
726 return forOp;
727 }
728
729 /// Emit a while-loop for co-iteration over multiple indices.
genWhile(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned idx,bool needsUniv,llvm::BitVector & indices)730 static Operation *genWhile(Merger &merger, CodeGen &codegen,
731 PatternRewriter &rewriter, linalg::GenericOp op,
732 unsigned idx, bool needsUniv,
733 llvm::BitVector &indices) {
734 SmallVector<Type, 4> types;
735 SmallVector<Value, 4> operands;
736 // Construct the while-loop with a parameter for each index.
737 Type indexType = rewriter.getIndexType();
738 for (unsigned b = 0, be = indices.size(); b < be; b++) {
739 if (indices[b] && merger.isSparseBit(b)) {
740 unsigned tensor = merger.tensor(b);
741 assert(idx == merger.index(b));
742 types.push_back(indexType);
743 operands.push_back(codegen.pidxs[tensor][idx]);
744 }
745 }
746 if (needsUniv) {
747 types.push_back(indexType);
748 operands.push_back(codegen.loops[idx]);
749 }
750 Location loc = op.getLoc();
751 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
752 Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
753 Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
754
755 // Build the "before" region, which effectively consists
756 // of a conjunction of "i < upper" tests on all induction.
757 rewriter.setInsertionPointToStart(&whileOp.before().front());
758 Value cond;
759 unsigned o = 0;
760 for (unsigned b = 0, be = indices.size(); b < be; b++) {
761 if (indices[b] && merger.isSparseBit(b)) {
762 unsigned tensor = merger.tensor(b);
763 assert(idx == merger.index(b));
764 Value op1 = before->getArgument(o);
765 Value op2 = codegen.highs[tensor][idx];
766 Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
767 cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
768 codegen.pidxs[tensor][idx] = after->getArgument(o++);
769 }
770 }
771 if (needsUniv)
772 codegen.loops[idx] = after->getArgument(o++);
773 assert(o == operands.size());
774 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
775 rewriter.setInsertionPointToStart(&whileOp.after().front());
776 return whileOp;
777 }
778
779 /// Generates a for-loop or a while-loop, depending on whether it implements
780 /// singleton iteration or co-iteration over the given conjunction.
genLoop(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,std::vector<unsigned> & topSort,unsigned at,bool needsUniv,llvm::BitVector & indices)781 static Operation *genLoop(Merger &merger, CodeGen &codegen,
782 PatternRewriter &rewriter, linalg::GenericOp op,
783 std::vector<unsigned> &topSort, unsigned at,
784 bool needsUniv, llvm::BitVector &indices) {
785 unsigned idx = topSort[at];
786 if (indices.count() == 1) {
787 bool isOuter = at == 0;
788 bool isInner = at == topSort.size() - 1;
789 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
790 indices);
791 }
792 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
793 }
794
795 /// Generates the local variables for this loop, consisting of the sparse
796 /// indices, restored universal dense index, and dense positions.
genLocals(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,std::vector<unsigned> & topSort,unsigned at,bool needsUniv,llvm::BitVector & locals)797 static void genLocals(Merger &merger, CodeGen &codegen,
798 PatternRewriter &rewriter, linalg::GenericOp op,
799 std::vector<unsigned> &topSort, unsigned at,
800 bool needsUniv, llvm::BitVector &locals) {
801 Location loc = op.getLoc();
802 unsigned idx = topSort[at];
803
804 // Initialize sparse indices.
805 Value min;
806 for (unsigned b = 0, be = locals.size(); b < be; b++) {
807 if (locals[b] && merger.isSparseBit(b)) {
808 unsigned tensor = merger.tensor(b);
809 assert(idx == merger.index(b));
810 Value ptr = codegen.indices[tensor][idx];
811 Value s = codegen.pidxs[tensor][idx];
812 Value load = genLoad(rewriter, loc, ptr, s);
813 codegen.idxs[tensor][idx] = load;
814 if (!needsUniv) {
815 if (min) {
816 Value cmp =
817 rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
818 min = rewriter.create<SelectOp>(loc, cmp, load, min);
819 } else {
820 min = load;
821 }
822 }
823 }
824 }
825
826 // Merge dense universal index over minimum.
827 if (min) {
828 assert(!needsUniv);
829 codegen.loops[idx] = min;
830 }
831
832 // Initialize dense positions.
833 for (unsigned b = 0, be = locals.size(); b < be; b++) {
834 if (locals[b] && !merger.isSparseBit(b)) {
835 unsigned tensor = merger.tensor(b);
836 assert(idx == merger.index(b));
837 if (!codegen.highs[tensor][idx])
838 continue; // unused dimension
839 unsigned pat = at;
840 for (; pat != 0; pat--)
841 if (codegen.pidxs[tensor][topSort[pat - 1]])
842 break;
843 Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
844 : codegen.pidxs[tensor][topSort[pat - 1]];
845 Value m = rewriter.create<MulIOp>(loc, codegen.sizes[idx], p);
846 codegen.pidxs[tensor][idx] =
847 rewriter.create<AddIOp>(loc, m, codegen.loops[idx]);
848 }
849 }
850 }
851
852 /// Generates the induction structure for a while-loop.
genWhileInduction(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned idx,bool needsUniv,llvm::BitVector & induction,ResultRange results)853 static void genWhileInduction(Merger &merger, CodeGen &codegen,
854 PatternRewriter &rewriter, linalg::GenericOp op,
855 unsigned idx, bool needsUniv,
856 llvm::BitVector &induction, ResultRange results) {
857 Location loc = op.getLoc();
858 unsigned o = 0;
859 SmallVector<Value, 4> operands;
860 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
861 for (unsigned b = 0, be = induction.size(); b < be; b++)
862 if (induction[b] && merger.isSparseBit(b)) {
863 unsigned tensor = merger.tensor(b);
864 assert(idx == merger.index(b));
865 Value op1 = codegen.idxs[tensor][idx];
866 Value op2 = codegen.loops[idx];
867 Value op3 = codegen.pidxs[tensor][idx];
868 Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
869 Value add = rewriter.create<AddIOp>(loc, op3, one);
870 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
871 codegen.pidxs[tensor][idx] = results[o++];
872 }
873 if (needsUniv) {
874 operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
875 codegen.loops[idx] = results[o++];
876 }
877 assert(o == operands.size());
878 rewriter.create<scf::YieldOp>(loc, operands);
879 }
880
881 /// Generates a single if-statement within a while-loop.
genIf(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,unsigned idx,llvm::BitVector & conditions,scf::IfOp & ifOp)882 static void genIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
883 linalg::GenericOp op, unsigned idx,
884 llvm::BitVector &conditions, scf::IfOp &ifOp) {
885 Location loc = op.getLoc();
886 if (ifOp)
887 rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
888 Value cond;
889 for (unsigned b = 0, be = conditions.size(); b < be; b++) {
890 if (conditions[b]) {
891 unsigned tensor = merger.tensor(b);
892 assert(idx == merger.index(b));
893 Value clause;
894 if (merger.isSparseBit(b)) {
895 Value op1 = codegen.idxs[tensor][idx];
896 Value op2 = codegen.loops[idx];
897 clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
898 } else {
899 clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
900 }
901 cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
902 }
903 }
904 ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
905 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
906 }
907
908 /// Optimize the loop indices of Li with two rules rules:
909 /// (1) convert multiple dense to single dense, and
910 /// (2) convert singleton sparse/dense to sparse/random access.
optimizeIndices(Merger merger,unsigned lsize,llvm::BitVector & indices)911 static void optimizeIndices(Merger merger, unsigned lsize,
912 llvm::BitVector &indices) {
913 if (merger.hasAnyOf(indices, false)) {
914 bool reset = lsize == 1 && merger.hasAnyOf(indices, true);
915 for (unsigned b = 0, be = indices.size(); b < be; b++) {
916 if (indices[b] && !merger.isSparseBit(b)) {
917 if (reset)
918 indices.reset(b);
919 reset = true;
920 }
921 }
922 }
923 }
924
925 /// Recursively generates code while computing iteration lattices in order
926 /// to manage the complexity of implementing co-iteration over unions
927 /// and intersections of sparse iterations spaces.
genStmt(Merger & merger,CodeGen & codegen,PatternRewriter & rewriter,linalg::GenericOp op,std::vector<unsigned> & topSort,unsigned exp,unsigned at)928 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
929 linalg::GenericOp op, std::vector<unsigned> &topSort,
930 unsigned exp, unsigned at) {
931 // At each leaf, assign remaining tensor (sub)expression to output tensor.
932 if (at == topSort.size()) {
933 unsigned lhs = op.getNumInputsAndOutputs() - 1;
934 Value rhs = genExp(merger, codegen, rewriter, op, exp);
935 genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
936 return;
937 }
938
939 // Construct iteration lattices for current loop index, with L0 at top.
940 // Then emit initialization code for the loop sequence at this level.
941 // We maintain the universal dense index if dense indices are still
942 // in play for a non-singleton loop sequence.
943 unsigned idx = topSort[at];
944 unsigned lts = merger.optimize(buildLattices(merger, op, exp, idx));
945 unsigned lsize = merger.set(lts).size();
946 assert(lsize != 0);
947 unsigned l0 = merger.set(lts)[0];
948 LatPoint lat0 = merger.lat(l0);
949 genInvariants(merger, codegen, rewriter, op, exp);
950 bool needsUniv =
951 genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
952 lsize > 1;
953
954 // Emit a loop for every lattice point L0 >= Li.
955 for (unsigned li : merger.set(lts)) {
956 LatPoint lati = merger.lat(li);
957
958 // Emit loop.
959 llvm::BitVector indices = lati.bits;
960 optimizeIndices(merger, lsize, indices);
961 Operation *loop =
962 genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
963 genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);
964
965 // Visit all lattices points with Li >= Lj to generate the
966 // loop-body, possibly with if statements for coiteration.
967 bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
968 scf::IfOp ifOp;
969 for (unsigned lj : merger.set(lts)) {
970 if (li == lj || merger.latGT(li, lj)) {
971 LatPoint latj = merger.lat(lj);
972 llvm::BitVector tmp = latj.bits;
973 tmp ^= lati.bits;
974 if (merger.hasAnyOf(tmp, false))
975 continue; // dense exhausted within if/else
976 // Recurse into body of each branch.
977 if (isWhile)
978 genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp);
979 genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
980 }
981 }
982
983 // Wrap-up induction and restore insertion point.
984 if (isWhile) {
985 scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
986 rewriter.setInsertionPointToEnd(&whileOp.after().front());
987 genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
988 lati.bits, whileOp.results());
989 } else {
990 needsUniv = false;
991 }
992 rewriter.setInsertionPointAfter(loop);
993 }
994 codegen.loops[idx] = Value();
995 }
996
997 namespace {
998
999 /// Sparse rewriting rule for generic Lingalg operation.
1000 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1001 public:
GenericOpSparsifier__anon4430b2680211::GenericOpSparsifier1002 GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o)
1003 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1004
matchAndRewrite__anon4430b2680211::GenericOpSparsifier1005 LogicalResult matchAndRewrite(linalg::GenericOp op,
1006 PatternRewriter &rewriter) const override {
1007 // Detects sparse annotations and translate the per-dimension sparsity
1008 // information for all tensors to loop indices in the kernel.
1009 if (!op.hasSparseSemantics())
1010 return failure();
1011 assert(op.getNumOutputs() == 1);
1012 unsigned numTensors = op.getNumInputsAndOutputs();
1013 unsigned numLoops = op.iterator_types().getValue().size();
1014 Merger merger(numTensors, numLoops);
1015 findSparseAnnotations(op, merger.sparse());
1016
1017 // Computes a topologically sorted iteration graph to ensure
1018 // tensors are visited in natural index order. Fails on cycles.
1019 // This assumes that higher-level passes have already put the
1020 // tensors in each tensor expression in a feasible order.
1021 // TODO: try again without *dense* constraints on failure or
1022 // even try to insert sparse reorderings to resolve cycles
1023 std::vector<unsigned> topSort;
1024 if (!computeIterationGraph(op, topSort))
1025 return failure();
1026
1027 // Finds the terminating yield statement and builds the tensor
1028 // expression for the Linalg operation in SSA form.
1029 Operation *yield = op.region().front().getTerminator();
1030 Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
1031 if (!exp.hasValue())
1032 return failure(); // build failure
1033
1034 // Recursively generates code.
1035 CodeGen codegen(options, numTensors, numLoops);
1036 genBuffers(merger, codegen, rewriter, op);
1037 genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
1038 Value result =
1039 rewriter.create<TensorLoadOp>(op.getLoc(), codegen.buffers.back());
1040 rewriter.replaceOp(op, result);
1041 return success();
1042 }
1043
1044 private:
1045 /// Options to control sparse code generation.
1046 linalg::SparsificationOptions options;
1047 };
1048
1049 } // namespace
1050
1051 /// Populates the given patterns list with rewriting rules required for
1052 /// the sparsification of linear algebra operations.
populateSparsificationPatterns(MLIRContext * context,OwningRewritePatternList & patterns,const SparsificationOptions & options)1053 void linalg::populateSparsificationPatterns(
1054 MLIRContext *context, OwningRewritePatternList &patterns,
1055 const SparsificationOptions &options) {
1056 patterns.insert<GenericOpSparsifier>(context, options);
1057 }
1058