• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/MathExtras.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 /// Fully compose map with operands and canonicalize the result.
34 /// Return the `createOrFold`'ed AffineApply op.
createFoldedComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operandsRef)35 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
36                                              AffineMap map,
37                                              ValueRange operandsRef) {
38   SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
39   fullyComposeAffineMapAndOperands(&map, &operands);
40   canonicalizeMapAndOperands(&map, &operands);
41   return b.createOrFold<AffineApplyOp>(loc, map, operands);
42 }
43 
applyMapToValues(OpBuilder & b,Location loc,AffineMap map,ValueRange values)44 SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc,
45                                                      AffineMap map,
46                                                      ValueRange values) {
47   SmallVector<Value, 4> res;
48   res.reserve(map.getNumResults());
49   unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
50   // For each `expr` in `map`, applies the `expr` to the values extracted from
51   // ranges. If the resulting application can be folded into a Value, the
52   // folding occurs eagerly.
53   for (auto expr : map.getResults()) {
54     AffineMap map = AffineMap::get(numDims, numSym, expr);
55     res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
56   }
57   return res;
58 }
59 
createFlatListOfOperandDims(OpBuilder & b,Location loc)60 SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
61                                                             Location loc) {
62   SmallVector<Value, 4> res;
63   for (Value v : getShapedOperands()) {
64     ShapedType t = v.getType().template cast<ShapedType>();
65     for (unsigned i = 0, e = t.getRank(); i < e; ++i)
66       res.push_back(b.create<DimOp>(loc, v, i));
67   }
68   return res;
69 }
70 
createLoopRanges(OpBuilder & b,Location loc)71 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
72   AffineMap map = getLoopsToShapesMap();
73   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
74   auto viewSizes = createFlatListOfOperandDims(b, loc);
75   SmallVector<Range, 4> res(numDims);
76   Value zeroVal = b.create<ConstantIndexOp>(loc, 0);
77   Value oneVal = b.create<ConstantIndexOp>(loc, 1);
78   for (unsigned idx = 0; idx < numRes; ++idx) {
79     auto result = map.getResult(idx);
80     if (auto d = result.dyn_cast<AffineDimExpr>()) {
81       if (res[d.getPosition()].offset)
82         continue;
83       res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal};
84     }
85   }
86   return res;
87 }
88 
89 /// Forward declarations.
90 template <typename NamedStructuredOpType>
91 static void buildNamedStructuredOpRegionAndAttributes(
92     OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
93     TypeRange outputBufferTypes, TypeRange initTensorTypes,
94     TypeRange resultTypes);
95 
96 static ParseResult
97 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
98                              SmallVectorImpl<Type> &inputTypes,
99                              SmallVectorImpl<Type> &outputBufferTypes,
100                              SmallVectorImpl<Type> &initTensorTypes);
101 
102 template <typename NamedStructuredOpType>
103 static ParseResult
104 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
105                              TypeRange inputTypes, TypeRange outputBufferTypes,
106                              TypeRange initTensorTypes, TypeRange resultTypes);
107 static ParseResult
108 parseNamedStructuredOpResults(OpAsmParser &parser,
109                               SmallVectorImpl<Type> &resultTypes);
110 
111 template <typename NamedStructuredOpType>
112 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
113                                           OperationState &result);
114 
115 template <typename NamedStructuredOpType>
116 static void printCommonStructuredOpParts(OpAsmPrinter &p,
117                                          NamedStructuredOpType op);
118 
119 static void printNamedStructuredOpResults(OpAsmPrinter &p,
120                                           TypeRange resultTypes);
121 
122 template <typename NamedStructuredOpType>
123 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
124 
125 template <typename NamedStructuredOpType>
126 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
127 
128 /// This is a common class used for patterns of the form
129 /// ```
130 ///    someop(memrefcast) -> someop
131 /// ```
132 /// It folds the source of the memref_cast into the root operation directly.
foldMemRefCast(Operation * op)133 static LogicalResult foldMemRefCast(Operation *op) {
134   bool folded = false;
135   for (OpOperand &operand : op->getOpOperands()) {
136     auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
137     if (castOp && canFoldIntoConsumerOp(castOp)) {
138       operand.set(castOp.getOperand());
139       folded = true;
140     }
141   }
142   return success(folded);
143 }
144 
145 ///////////////////// Operations defined with Tablegen /////////////////////////
146 // For such operations that do not correspond to library calls (i.e. defined in
147 // LinalgOps.td), we define an overloaded `print` function and a
148 // parse`className` function.
149 
150 //===----------------------------------------------------------------------===//
151 // GenericOps
152 //===----------------------------------------------------------------------===//
build(OpBuilder & builder,OperationState & result,TypeRange resultTensorTypes,ValueRange inputs,ValueRange outputBuffers,ValueRange initTensors,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,StringRef doc,StringRef libraryCall,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuild)153 void GenericOp::build(
154     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
155     ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
156     ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
157     StringRef doc, StringRef libraryCall,
158     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
159   build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
160         builder.getAffineMapArrayAttr(indexingMaps),
161         builder.getStrArrayAttr(iteratorTypes),
162         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
163         libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
164         ArrayAttr());
165   if (!bodyBuild)
166     return;
167 
168   SmallVector<Type, 4> blockArgTypes;
169   for (ValueRange container : {inputs, outputBuffers, initTensors})
170     for (Value v : container)
171       blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
172 
173   OpBuilder::InsertionGuard guard(builder);
174   auto &region = *result.regions.front();
175   Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
176   bodyBuild(builder, result.location, bodyBlock->getArguments());
177 }
178 
build(OpBuilder & builder,OperationState & result,ValueRange inputs,ValueRange outputBuffers,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,StringRef doc,StringRef libraryCall,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuild)179 void GenericOp::build(
180     OpBuilder &builder, OperationState &result, ValueRange inputs,
181     ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
182     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
183     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
184   build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
185         indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
186 }
187 
build(OpBuilder & builder,OperationState & result,ValueRange inputs,ValueRange outputBuffers,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuild)188 void GenericOp::build(
189     OpBuilder &builder, OperationState &result, ValueRange inputs,
190     ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
191     ArrayRef<StringRef> iteratorTypes,
192     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
193   build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
194         /*doc=*/"",
195         /*libraryCall=*/"", bodyBuild);
196 }
197 
build(OpBuilder & builder,OperationState & result,TypeRange resultTensorTypes,ValueRange inputs,ValueRange outputBuffers,ValueRange initTensors,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuild)198 void GenericOp::build(
199     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
200     ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
201     ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
202     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
203   build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
204         indexingMaps, iteratorTypes,
205         /*doc=*/"",
206         /*libraryCall=*/"", bodyBuild);
207 }
build(OpBuilder & builder,OperationState & result,TypeRange resultTensorTypes,ValueRange inputs,ValueRange outputBuffers,ValueRange initTensors,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,StringRef doc,StringRef libraryCall,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuild)208 void IndexedGenericOp::build(
209     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
210     ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
211     ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
212     StringRef doc, StringRef libraryCall,
213     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
214         bodyBuild) {
215   build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
216         builder.getAffineMapArrayAttr(indexingMaps),
217         builder.getStrArrayAttr(iteratorTypes),
218         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
219         libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
220         ArrayAttr());
221   if (!bodyBuild)
222     return;
223 
224   unsigned nLoops = iteratorTypes.size();
225   SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
226   for (ValueRange container : {inputs, outputBuffers, initTensors})
227     for (Value v : container)
228       blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
229 
230   OpBuilder::InsertionGuard guard(builder);
231   auto &region = *result.regions.front();
232   Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
233   bodyBuild(builder, result.location,
234             bodyBlock->getArguments().take_front(nLoops),
235             bodyBlock->getArguments().drop_front(nLoops));
236 }
237 
build(OpBuilder & builder,OperationState & result,ValueRange inputs,ValueRange outputBuffers,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,StringRef doc,StringRef libraryCall,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuild)238 void IndexedGenericOp::build(
239     OpBuilder &builder, OperationState &result, ValueRange inputs,
240     ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
241     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
242     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
243         bodyBuild) {
244   build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{},
245         indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild);
246 }
247 
build(OpBuilder & builder,OperationState & result,ValueRange inputs,ValueRange outputBuffers,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuild)248 void IndexedGenericOp::build(
249     OpBuilder &builder, OperationState &result, ValueRange inputs,
250     ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps,
251     ArrayRef<StringRef> iteratorTypes,
252     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
253         bodyBuild) {
254   build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes,
255         /*doc=*/"", /*libraryCall=*/"", bodyBuild);
256 }
257 
build(OpBuilder & builder,OperationState & result,TypeRange resultTensorTypes,ValueRange inputs,ValueRange outputBuffers,ValueRange initTensors,ArrayRef<AffineMap> indexingMaps,ArrayRef<StringRef> iteratorTypes,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuild)258 void IndexedGenericOp::build(
259     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
260     ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors,
261     ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
262     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
263         bodyBuild) {
264   build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors,
265         indexingMaps, iteratorTypes,
266         /*doc=*/"",
267         /*libraryCall=*/"", bodyBuild);
268 }
269 
270 template <typename GenericOpType>
printGenericOp(OpAsmPrinter & p,GenericOpType op)271 static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
272   p << op.getOperationName() << " ";
273 
274   // Print extra attributes.
275   auto genericAttrNames = op.linalgTraitAttrNames();
276 
277   llvm::StringSet<> genericAttrNamesSet;
278   genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
279   SmallVector<NamedAttribute, 8> genericAttrs;
280   for (auto attr : op.getAttrs())
281     if (genericAttrNamesSet.count(attr.first.strref()) > 0)
282       genericAttrs.push_back(attr);
283   if (!genericAttrs.empty()) {
284     auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
285     p << genericDictAttr;
286   }
287 
288   // Printing is shared with named ops, except for the region and attributes
289   printCommonStructuredOpParts(p, op);
290 
291   genericAttrNames.push_back("operand_segment_sizes");
292   genericAttrNamesSet.insert(genericAttrNames.back());
293 
294   bool hasExtraAttrs = false;
295   for (NamedAttribute n : op.getAttrs()) {
296     if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref())))
297       break;
298   }
299   if (hasExtraAttrs) {
300     p << " attrs = ";
301     p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/genericAttrNames);
302   }
303 
304   // Print region.
305   if (!op.region().empty())
306     p.printRegion(op.region());
307 
308   // Print results.
309   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
310 }
311 
print(OpAsmPrinter & p,GenericOp op)312 static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
313 
print(OpAsmPrinter & p,IndexedGenericOp op)314 static void print(OpAsmPrinter &p, IndexedGenericOp op) {
315   printGenericOp(p, op);
316 }
317 
parseGenericOp(OpAsmParser & parser,OperationState & result)318 static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
319   DictionaryAttr dictAttr;
320   // Parse the core linalg traits that must check into a dictAttr.
321   // The name is unimportant as we will overwrite result.attributes.
322   // The core linalg traits must contain the information necessary to pass the
323   // verifier.
324   if (parser.parseAttribute(dictAttr, "_", result.attributes))
325     return failure();
326   result.attributes.assign(dictAttr.getValue().begin(),
327                            dictAttr.getValue().end());
328 
329   // Parsing is shared with named ops, except for the region.
330   SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
331   if (parseCommonStructuredOpParts(parser, result, inputTypes,
332                                    outputBufferTypes, initTensorTypes))
333     return failure();
334 
335   // Optional attributes may be added.
336   if (succeeded(parser.parseOptionalKeyword("attrs")))
337     if (failed(parser.parseEqual()) ||
338         failed(parser.parseOptionalAttrDict(result.attributes)))
339       return failure();
340 
341   SmallVector<OpAsmParser::OperandType, 8> regionOperands;
342   std::unique_ptr<Region> region = std::make_unique<Region>();
343   SmallVector<Type, 8> operandTypes, regionTypes;
344   if (parser.parseRegion(*region, regionOperands, regionTypes))
345     return failure();
346   result.addRegion(std::move(region));
347 
348   // Generic ops may specify that a subset of its outputs are tensors. Such
349   // outputs are specified in the result type.
350   // TODO: may need to move output parsing before region parsing.
351   // Need to wait for declarative assembly resolution to decide.
352   SmallVector<Type, 1> outputTensorsTypes;
353   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
354     return failure();
355   result.addTypes(outputTensorsTypes);
356 
357   return success();
358 }
359 
getGenericEffectsImpl(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects,ValueRange results,ValueRange inputBuffers,ValueRange outputBuffers)360 static void getGenericEffectsImpl(
361     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
362         &effects,
363     ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
364   for (Value value : results) {
365     effects.emplace_back(MemoryEffects::Allocate::get(), value,
366                          SideEffects::DefaultResource::get());
367   }
368   for (Value value : inputBuffers) {
369     effects.emplace_back(MemoryEffects::Read::get(), value,
370                          SideEffects::DefaultResource::get());
371   }
372   for (Value value : outputBuffers) {
373     effects.emplace_back(MemoryEffects::Read::get(), value,
374                          SideEffects::DefaultResource::get());
375     effects.emplace_back(MemoryEffects::Write::get(), value,
376                          SideEffects::DefaultResource::get());
377   }
378 }
379 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)380 void GenericOp::getEffects(
381     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
382         &effects) {
383   getGenericEffectsImpl(effects, getOperation()->getResults(),
384                         getInputBuffers(), getOutputBuffers());
385 }
386 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)387 void IndexedGenericOp::getEffects(
388     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
389         &effects) {
390   getGenericEffectsImpl(effects, getOperation()->getResults(),
391                         getInputBuffers(), getOutputBuffers());
392 }
393 
394 namespace {
395 
396 template <typename GenericOpType>
397 struct BlockArgsVerifier {
398   static LogicalResult verify(GenericOpType op, Block &block);
399 };
400 
401 template <typename GenericOpType>
verify(GenericOpType op,Block & block)402 LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op,
403                                                        Block &block) {
404   auto nOperands = op.getNumOperands();
405   if (block.getNumArguments() != nOperands)
406     return op.emitOpError("expected number of block arguments to match number "
407                           "of operands");
408 
409   // Note: the number and type of yield values are checked in the YieldOp.
410   auto nInputViews = op.getNumInputs();
411   for (unsigned i = 0; i < nOperands; ++i) {
412     auto viewType = op.getShapedType(i);
413     if (viewType.getElementType() != block.getArgument(i).getType())
414       return op.emitOpError("expected block argument ")
415              << (i + 1) << " of the same type as elemental type of "
416              << ((i < nInputViews) ? "input " : "output ")
417              << "operand: " << viewType;
418   }
419   return success();
420 }
421 
422 template <>
verify(IndexedGenericOp op,Block & block)423 LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op,
424                                                           Block &block) {
425   auto nInputViews = op.getNumInputs();
426   auto nLoops = op.getNumLoops();
427   auto nOperands = op.getNumOperands();
428   if (block.getNumArguments() != nOperands + nLoops)
429     return op.emitOpError(
430         "expected number of block arguments to match number of operands + "
431         "number of loops");
432 
433   // Note: the number and type of yield values are checked in the YieldOp.
434   for (unsigned i = 0; i < nLoops; ++i)
435     if (!block.getArgument(i).getType().isIndex())
436       return op.emitOpError("expected block argument ")
437              << (i + 1) << " to be an index";
438 
439   for (unsigned i = 0; i < nOperands; ++i) {
440     unsigned memrefArgIndex = i + nLoops;
441     auto viewType = op.getShapedType(i);
442     if (viewType.getElementType() !=
443         block.getArgument(memrefArgIndex).getType())
444       return op.emitOpError("expected block argument ")
445              << (memrefArgIndex + 1)
446              << " of the same type as elemental type of "
447              << ((i < nInputViews) ? "input " : "output ")
448              << "operand: " << viewType;
449   }
450   return success();
451 }
452 
453 template <typename GenericOpType>
454 struct AnnotationsVerifier {
verify__anon9ff5be7e0111::AnnotationsVerifier455   static LogicalResult verify(GenericOpType op) { return success(); }
456 };
457 
458 template <>
verify(GenericOp op)459 LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
460   ArrayAttr sparseAttr = op.sparseAttr();
461   if (!sparseAttr)
462     return success();
463   // Verify consistency of sparse annotations.
464   if (!op.hasTensorSemantics())
465     return op.emitOpError("expected sparse annotations on tensors only");
466   if (op.getNumOutputs() != 1)
467     return op.emitOpError("expected single output tensor");
468   unsigned numTensors = op.getNumInputsAndOutputs();
469   if (sparseAttr.size() != numTensors)
470     return op.emitOpError("expected one sparse annotation for each tensor");
471   for (unsigned t = 0; t < numTensors; t++) {
472     auto dimAttr = sparseAttr[t].dyn_cast_or_null<ArrayAttr>();
473     if (!dimAttr)
474       return op.emitOpError("expected sparse annotation array for tensor ")
475              << t;
476     unsigned rank = op.getShapedType(t).getRank();
477     if (dimAttr.size() != rank)
478       return op.emitOpError("expected sparse annotation with rank ")
479              << rank << " for tensor " << t;
480     // Per-dimension annotations for each tensor consist of only "D" or "S".
481     for (unsigned d = 0; d < rank; d++) {
482       if (isDenseDim(dimAttr[d])) {
483         continue;
484       } else if (isSparseDim(dimAttr[d])) {
485         if (t == numTensors - 1)
486           return op.emitOpError("sparse output tensors not supported (yet)");
487         continue;
488       }
489       return op.emitOpError("expected sparse annotation at position ")
490              << d << " for tensor " << t;
491     }
492   }
493   return success();
494 }
495 
496 } // namespace
497 
498 template <typename GenericOpType>
verifyGenericOp(GenericOpType op)499 static LogicalResult verifyGenericOp(GenericOpType op) {
500   auto nLoops = op.getNumLoops();
501 
502   if (op.inputs().size() + op.output_buffers().size() +
503           op.init_tensors().size() + op.getNumResults() ==
504       0)
505     return op.emitOpError("expected at least 1 Shaped operand or return");
506 
507   auto &region = op.region();
508   if (!llvm::hasSingleElement(region))
509     return op.emitOpError("expected region with 1 block");
510   if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front())))
511     return failure();
512 
513   if (op.indexing_maps().size() != op.getNumInputsAndOutputs())
514     return op.emitOpError("expected the number of indexing_map (")
515            << op.indexing_maps().size()
516            << ") to be equal to the number of inputs and outputs ("
517            << op.getNumInputsAndOutputs() << ")";
518 
519   SmallVector<AffineMap, 4> indexingMaps;
520   indexingMaps.reserve(op.indexing_maps().size());
521   for (auto en : llvm::enumerate(op.indexing_maps())) {
522     auto idx = en.index();
523     auto m = en.value().template cast<AffineMapAttr>().getValue();
524     indexingMaps.push_back(m); // Save reference to map for further checks.
525     auto view = op.getShapedType(idx);
526 
527     if (m.getNumSymbols() != 0)
528       return op.emitOpError("unexpected symbols in indexing_map #") << idx;
529 
530     if (m.getNumDims() != nLoops)
531       return op.emitOpError("expected indexing_map #")
532              << idx << " to have " << nLoops
533              << " dim(s) to match the number of loops";
534 
535     if (m.getNumResults() != view.getRank())
536       return op.emitOpError("expected indexing_map #")
537              << idx << " results to match view rank: " << view;
538   }
539 
540   if (!op.getShapesToLoopsMap())
541     return op.emitOpError("expected the shape-to-loops map to be non-null");
542 
543   if (failed(AnnotationsVerifier<GenericOpType>::verify(op)))
544     return failure();
545 
546   return success();
547 }
548 
verify(GenericOp op)549 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
550 
verify(IndexedGenericOp op)551 static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
552 
553 //===----------------------------------------------------------------------===//
554 // ReshapeOp
555 //===----------------------------------------------------------------------===//
556 
557 /// Collapse reassociation maps that are used in pair of reshape ops where one
558 /// is a producer and other is the consumer. Only valid to use this method when
559 /// both the producer and consumer are collapsing dimensions or both are
560 /// expanding dimensions.
561 ///
562 /// For example,
563 ///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
564 ///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
565 ///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
566 ///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
567 ///                   affine_map<(d0, d1, d2) -> (d2)>]
568 ///
569 /// is folded into
570 ///
571 ///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
572 ///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,ArrayRef<AffineMap> mapsConsumer,MLIRContext * context)573 static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
574                                            ArrayRef<AffineMap> mapsConsumer,
575                                            MLIRContext *context) {
576   // Handle the corner case of the result being a rank 0 shaped type. Return an
577   // emtpy ArrayAttr.
578   if (mapsConsumer.empty() && !mapsProducer.empty())
579     return ArrayAttr::get(ArrayRef<Attribute>(), context);
580   if (mapsProducer.empty() || mapsConsumer.empty() ||
581       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
582       mapsProducer.size() != mapsConsumer[0].getNumDims())
583     return nullptr;
584   unsigned numLhsDims = mapsProducer[0].getNumDims();
585   unsigned currDim = 0;
586   SmallVector<AffineExpr, 4> reassociations;
587   SmallVector<Attribute, 4> reassociationMaps;
588   for (AffineMap rhs : mapsConsumer) {
589     for (AffineExpr rhsExpr : rhs.getResults()) {
590       AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
591       for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
592            i < e; ++i) {
593         reassociations.push_back(getAffineDimExpr(currDim++, context));
594       }
595     }
596     reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
597         numLhsDims, /*numSymbols =*/0, reassociations, context)));
598     reassociations.clear();
599   }
600   return ArrayAttr::get(reassociationMaps, context);
601 }
602 
603 namespace {
604 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
605 /// dimensions or are both expanding dimensions.
606 template <typename ReshapeOpTy>
607 struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
608   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
matchAndRewrite__anon9ff5be7e0211::CollapseReshapeOps609   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
610                                 PatternRewriter &rewriter) const override {
611     auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
612     if (!srcReshapeOp)
613       return failure();
614 
615     auto areReshapeOpsFoldable = [](ShapedType largerType,
616                                     ShapedType intermediateType,
617                                     ShapedType smallerType) -> bool {
618       return largerType.getRank() > intermediateType.getRank() &&
619              intermediateType.getRank() > smallerType.getRank();
620     };
621     // Check if producer and consumer are both expanding dims.
622     if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
623                               srcReshapeOp.getSrcType())) {
624       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
625           reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
626           collapseReassociationMaps(reshapeOp.getReassociationMaps(),
627                                     srcReshapeOp.getReassociationMaps(),
628                                     rewriter.getContext()));
629       return success();
630     }
631     // Check if producer and consumer are both collapsing dims.
632     if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(),
633                               reshapeOp.getResultType())) {
634       rewriter.replaceOpWithNewOp<ReshapeOpTy>(
635           reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
636           collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
637                                     reshapeOp.getReassociationMaps(),
638                                     rewriter.getContext()));
639       return success();
640     }
641     return failure();
642   }
643 };
644 } // namespace
645 
646 template <typename ReshapeOpTy>
foldReshapeOp(ReshapeOpTy reshapeOp,ArrayRef<Attribute> operands)647 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
648                                   ArrayRef<Attribute> operands) {
649   // Fold producer-consumer reshape ops that where the operand type of the
650   // producer is same as the return type of the consumer. This can only be
651   // verified if the shapes in question are static.
652   ReshapeOpTy reshapeSrcOp =
653       reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
654   if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
655       reshapeOp.getResultType().hasStaticShape() &&
656       reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
657     return reshapeSrcOp.src();
658   // Reshape of a constant can be replaced with a new constant.
659   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
660     return elements.reshape(
661         reshapeOp.getResult().getType().template cast<ShapedType>());
662   }
663   return nullptr;
664 }
665 
666 /// Return true if the reassociation specification is valid, false otherwise.
667 /// When false, the `invalidIndex` integer pointer is optionally filled with the
668 /// index of the offending reassociation map.
isReassociationValid(ArrayRef<AffineMap> reassociation,int * invalidIndex=nullptr)669 static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
670                                  int *invalidIndex = nullptr) {
671   if (reassociation.empty())
672     return true;
673   unsigned nDims = reassociation[0].getNumDims();
674   unsigned nextExpectedDim = 0;
675   for (auto it : llvm::enumerate(reassociation)) {
676     auto m = it.value();
677     if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
678       if (invalidIndex)
679         *invalidIndex = it.index();
680       return false;
681     }
682     for (auto e : m.getResults()) {
683       auto d = e.dyn_cast<AffineDimExpr>();
684       if (!d || d.getPosition() != nextExpectedDim++) {
685         if (invalidIndex)
686           *invalidIndex = it.index();
687         return false;
688       }
689     }
690   }
691   if (nextExpectedDim != nDims) {
692     if (invalidIndex)
693       *invalidIndex = reassociation.size() - 1;
694     return false;
695   }
696   return true;
697 }
698 
699 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
700 /// copies.
isReshapableDimBand(unsigned dim,unsigned extent,ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> strides)701 static bool isReshapableDimBand(unsigned dim, unsigned extent,
702                                 ArrayRef<int64_t> sizes,
703                                 ArrayRef<AffineExpr> strides) {
704   assert(sizes.size() == strides.size() && "mismatched ranks");
705   // off by 1 indexing to avoid out of bounds
706   //                       V
707   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
708     // Only bands of static shapes are reshapable. This is due to the fact that
709     // there is no relation between dynamic sizes and dynamic strides: we do not
710     // have enough information to know whether a "-1" size corresponds to the
711     // proper symbol in the AffineExpr of a stride.
712     if (ShapedType::isDynamic(sizes[dim + 1]))
713       return false;
714     // TODO: Refine this by passing the proper nDims and nSymbols so we can
715     // simplify on the fly and catch more reshapable cases.
716     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
717       return false;
718   }
719   return true;
720 }
721 
722 /// Compute the MemRefType obtained by applying the `reassociation` (which is
723 /// expected to be valid) to `type`.
724 /// If `type` is Contiguous MemRefType, this always produce a contiguous
725 /// MemRefType.
726 static MemRefType
computeReshapeCollapsedType(MemRefType type,ArrayRef<AffineMap> reassociation)727 computeReshapeCollapsedType(MemRefType type,
728                             ArrayRef<AffineMap> reassociation) {
729   auto sizes = type.getShape();
730   AffineExpr offset;
731   SmallVector<AffineExpr, 4> strides;
732   auto status = getStridesAndOffset(type, strides, offset);
733   (void)status;
734   assert(succeeded(status) && "expected strided memref");
735 
736   SmallVector<int64_t, 4> newSizes;
737   newSizes.reserve(reassociation.size());
738   SmallVector<AffineExpr, 4> newStrides;
739   newStrides.reserve(reassociation.size());
740 
741   // Use the fact that reassociation is valid to simplify the logic: only use
742   // each map's rank.
743   assert(isReassociationValid(reassociation) && "invalid reassociation");
744   unsigned currentDim = 0;
745   for (AffineMap m : reassociation) {
746     unsigned dim = m.getNumResults();
747     int64_t size = 1;
748     AffineExpr stride = strides[currentDim + dim - 1];
749     if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
750       size = ShapedType::kDynamicSize;
751       stride = AffineExpr();
752     } else {
753       for (unsigned d = 0; d < dim; ++d)
754         size *= sizes[currentDim + d];
755     }
756     newSizes.push_back(size);
757     newStrides.push_back(stride);
758     currentDim += dim;
759   }
760 
761   // Early-exit: if `type` is contiguous, the result must be contiguous.
762   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
763     return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
764 
765   // Convert back to int64_t because we don't have enough information to create
766   // new strided layouts from AffineExpr only. This corresponds to a case where
767   // copies may be necessary.
768   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
769   if (auto o = offset.dyn_cast<AffineConstantExpr>())
770     intOffset = o.getValue();
771   SmallVector<int64_t, 4> intStrides;
772   intStrides.reserve(strides.size());
773   for (auto stride : newStrides) {
774     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
775       intStrides.push_back(cst.getValue());
776     else
777       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
778   }
779   auto layout =
780       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
781   return canonicalizeStridedLayout(
782       MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
783 }
784 
785 /// Helper functions assert Attribute of the proper type in attr and returns the
786 /// corresponding vector.
787 /// TODO: this should be evolved into a generic
788 /// `getRangeOfType<AffineMap>(ArrayAttr attrs)` that does not copy.
getAffineMaps(ArrayAttr attrs)789 static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
790   return llvm::to_vector<8>(llvm::map_range(
791       attrs, [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
792 }
793 
794 template <typename AffineExprTy>
getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays)795 unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
796   unsigned pos = 0;
797   for (const auto &exprs : exprArrays) {
798     for (auto expr : exprs) {
799       expr.walk([&pos](AffineExpr e) {
800         if (auto d = e.dyn_cast<AffineExprTy>())
801           pos = std::max(pos, d.getPosition());
802       });
803     }
804   }
805   return pos;
806 }
807 
808 static SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation)809 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
810   unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
811   assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
812          "Expected symbol-less expressions");
813   SmallVector<AffineMap, 4> maps;
814   maps.reserve(reassociation.size());
815   for (const auto &exprs : reassociation) {
816     assert(!exprs.empty());
817     maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
818   }
819   return maps;
820 }
821 
822 static SmallVector<SmallVector<AffineExpr, 2>, 2>
convertReassociationIndicesToMaps(OpBuilder & b,ArrayRef<ReassociationIndices> reassociationIndices)823 convertReassociationIndicesToMaps(
824     OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
825   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
826   for (const auto &indices : reassociationIndices) {
827     SmallVector<AffineExpr, 2> reassociationMap;
828     reassociationMap.reserve(indices.size());
829     for (int64_t index : indices)
830       reassociationMap.push_back(b.getAffineDimExpr(index));
831     reassociationMaps.push_back(std::move(reassociationMap));
832   }
833   return reassociationMaps;
834 }
835 
build(OpBuilder & b,OperationState & result,Value src,ArrayRef<ReassociationExprs> reassociation,ArrayRef<NamedAttribute> attrs)836 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
837                                     Value src,
838                                     ArrayRef<ReassociationExprs> reassociation,
839                                     ArrayRef<NamedAttribute> attrs) {
840   auto maps = getSymbolLessAffineMaps(reassociation);
841   auto memRefType = src.getType().cast<MemRefType>();
842   auto resultType = computeReshapeCollapsedType(memRefType, maps);
843   build(b, result, resultType, src, attrs);
844   result.addAttribute(ReshapeOp::getReassociationAttrName(),
845                       b.getAffineMapArrayAttr(maps));
846 }
847 
build(OpBuilder & b,OperationState & result,Type resultType,Value src,ArrayRef<ReassociationExprs> reassociation,ArrayRef<NamedAttribute> attrs)848 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
849                                     Type resultType, Value src,
850                                     ArrayRef<ReassociationExprs> reassociation,
851                                     ArrayRef<NamedAttribute> attrs) {
852   auto maps = getSymbolLessAffineMaps(reassociation);
853   build(b, result, resultType, src, attrs);
854   result.addAttribute(ReshapeOp::getReassociationAttrName(),
855                       b.getAffineMapArrayAttr(maps));
856 }
857 
getViewSource()858 Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
859 
860 // Common verifier for reshape-like types. Fills `expandedType` and
861 // `collapsedType` with the proper `src` or `result` type.
862 template <typename Op, typename T>
verifyReshapeLikeTypes(Op op,T & expandedType,T & collapsedType)863 static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
864                                             T &collapsedType) {
865   expandedType = op.getSrcType();
866   collapsedType = op.getResultType();
867   unsigned expandedRank = expandedType.getRank();
868   unsigned collapsedRank = collapsedType.getRank();
869   bool isCollapse = expandedRank > collapsedRank;
870   if (!isCollapse) {
871     std::swap(expandedRank, collapsedRank);
872     std::swap(expandedType, collapsedType);
873   }
874   if (expandedRank == 0)
875     return op.emitOpError("expected non-zero memref ranks");
876   if (expandedRank == collapsedRank)
877     return op.emitOpError("expected to collapse or expand dims");
878 
879   if (collapsedRank == 0) {
880     // If collapsed rank is 0, then expanded type must be static shaped and of
881     // sizes 1.
882     if (llvm::any_of(expandedType.getShape(),
883                      [](int64_t dim) -> bool { return dim != 1; }))
884       return op.emitOpError(
885           "invalid to reshape tensor/memref with non-unit extent dimensions to "
886           "zero-rank tensor/memref");
887     return success();
888   }
889   if (collapsedRank != op.reassociation().size())
890     return op.emitOpError("expected rank of the collapsed type(")
891            << collapsedRank << ") to be the number of reassociation maps("
892            << op.reassociation().size() << ")";
893   auto maps = getAffineMaps(op.reassociation());
894   for (auto it : llvm::enumerate(maps))
895     if (it.value().getNumDims() != expandedRank)
896       return op.emitOpError("expected reassociation map #")
897              << it.index() << " of same rank as expanded memref("
898              << expandedRank << "), but got " << it.value().getNumDims();
899   int invalidIdx = 0;
900   if (!isReassociationValid(maps, &invalidIdx))
901     return op.emitOpError("expected reassociation map #")
902            << invalidIdx << " to be valid and contiguous";
903   return success();
904 }
905 
verify(ReshapeOp op)906 static LogicalResult verify(ReshapeOp op) {
907   MemRefType expandedType, collapsedType;
908   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
909     return failure();
910   auto maps = getAffineMaps(op.reassociation());
911   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
912   if (collapsedType != expectedType)
913     return op.emitOpError("expected collapsed type to be ")
914            << expectedType << ", but got " << collapsedType;
915   return success();
916 }
917 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)918 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
919                                             MLIRContext *context) {
920   results.insert<CollapseReshapeOps<ReshapeOp>>(context);
921 }
922 
923 //===----------------------------------------------------------------------===//
924 // TensorReshapeOp
925 //===----------------------------------------------------------------------===//
926 
927 /// Compute the RankedTensorType obtained by applying `reassociation` to `type`.
928 static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,ArrayRef<AffineMap> reassociation)929 computeTensorReshapeCollapsedType(RankedTensorType type,
930                                   ArrayRef<AffineMap> reassociation) {
931   auto shape = type.getShape();
932   SmallVector<int64_t, 4> newShape;
933   newShape.reserve(reassociation.size());
934 
935   // Use the fact that reassociation is valid to simplify the logic: only use
936   // each map's rank.
937   assert(isReassociationValid(reassociation) && "invalid reassociation");
938   unsigned currentDim = 0;
939   for (AffineMap m : reassociation) {
940     unsigned dim = m.getNumResults();
941     auto band = shape.slice(currentDim, dim);
942     int64_t size = 1;
943     if (llvm::is_contained(band, ShapedType::kDynamicSize))
944       size = ShapedType::kDynamicSize;
945     else
946       for (unsigned d = 0; d < dim; ++d)
947         size *= shape[currentDim + d];
948     newShape.push_back(size);
949     currentDim += dim;
950   }
951 
952   return RankedTensorType::get(newShape, type.getElementType());
953 }
954 
build(OpBuilder & b,OperationState & result,Value src,ArrayRef<ReassociationExprs> reassociation,ArrayRef<NamedAttribute> attrs)955 void mlir::linalg::TensorReshapeOp::build(
956     OpBuilder &b, OperationState &result, Value src,
957     ArrayRef<ReassociationExprs> reassociation,
958     ArrayRef<NamedAttribute> attrs) {
959   auto maps = getSymbolLessAffineMaps(reassociation);
960   auto resultType = computeTensorReshapeCollapsedType(
961       src.getType().cast<RankedTensorType>(), maps);
962   build(b, result, resultType, src, attrs);
963   result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
964                       b.getAffineMapArrayAttr(maps));
965 }
966 
build(OpBuilder & b,OperationState & result,Type resultType,Value src,ArrayRef<ReassociationExprs> reassociation,ArrayRef<NamedAttribute> attrs)967 void mlir::linalg::TensorReshapeOp::build(
968     OpBuilder &b, OperationState &result, Type resultType, Value src,
969     ArrayRef<ReassociationExprs> reassociation,
970     ArrayRef<NamedAttribute> attrs) {
971   auto maps = getSymbolLessAffineMaps(reassociation);
972   build(b, result, resultType, src, attrs);
973   result.addAttribute(TensorReshapeOp::getReassociationAttrName(),
974                       b.getAffineMapArrayAttr(maps));
975 }
976 
verify(TensorReshapeOp op)977 static LogicalResult verify(TensorReshapeOp op) {
978   RankedTensorType expandedType, collapsedType;
979   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
980     return failure();
981   auto maps = getAffineMaps(op.reassociation());
982   // TODO: expanding a ? with a non-constant is under-specified. Error
983   // out.
984   RankedTensorType expectedType =
985       computeTensorReshapeCollapsedType(expandedType, maps);
986   if (collapsedType != expectedType)
987     return op.emitOpError("expected collapsed type to be ")
988            << expectedType << ", but got " << collapsedType;
989   return success();
990 }
991 
992 namespace {
993 /// Reshape of a splat constant can be replaced with a constant of the result
994 /// type.
995 struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
996   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon9ff5be7e0711::FoldReshapeWithConstant997   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
998                                 PatternRewriter &rewriter) const override {
999     DenseElementsAttr attr;
1000     if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
1001       return failure();
1002     if (!attr || !attr.isSplat())
1003       return failure();
1004     DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
1005         reshapeOp.getResultType(), attr.getRawData(), true);
1006     rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
1007     return success();
1008   }
1009 };
1010 } // namespace
1011 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1012 void TensorReshapeOp::getCanonicalizationPatterns(
1013     OwningRewritePatternList &results, MLIRContext *context) {
1014   results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
1015       context);
1016 }
1017 
1018 //===----------------------------------------------------------------------===//
1019 // SliceOp
1020 //===----------------------------------------------------------------------===//
build(OpBuilder & b,OperationState & result,Value base,ValueRange indexings)1021 void mlir::linalg::SliceOp::build(OpBuilder &b, OperationState &result,
1022                                   Value base, ValueRange indexings) {
1023   result.addOperands(base);
1024   result.addOperands(indexings);
1025 
1026   auto memRefType = base.getType().cast<MemRefType>();
1027   int64_t offset;
1028   SmallVector<int64_t, 4> strides;
1029   auto res = getStridesAndOffset(memRefType, strides, offset);
1030   assert(succeeded(res) && strides.size() == indexings.size());
1031   (void)res;
1032 
1033   unsigned rank = memRefType.getRank();
1034   // TODO: propagate static size and stride information when available.
1035   SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
1036   result.addTypes({MemRefType::Builder(memRefType)
1037                        .setShape(sizes)
1038                        .setAffineMaps(makeStridedLinearLayoutMap(
1039                            strides, offset, b.getContext()))});
1040 }
1041 
print(OpAsmPrinter & p,SliceOp op)1042 static void print(OpAsmPrinter &p, SliceOp op) {
1043   auto indexings = op.indexings();
1044   p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings
1045     << "] ";
1046   p.printOptionalAttrDict(op.getAttrs());
1047   p << " : " << op.getBaseViewType();
1048   if (!indexings.empty())
1049     p << ", " << op.indexings().getTypes();
1050   p << ", " << op.getType();
1051 }
1052 
parseSliceOp(OpAsmParser & parser,OperationState & result)1053 static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) {
1054   OpAsmParser::OperandType baseInfo;
1055   SmallVector<OpAsmParser::OperandType, 8> operands;
1056   SmallVector<Type, 8> types;
1057   if (parser.parseOperand(baseInfo) ||
1058       parser.parseOperandList(operands, OpAsmParser::Delimiter::Square) ||
1059       parser.parseOptionalAttrDict(result.attributes) ||
1060       parser.parseColonTypeList(types))
1061     return failure();
1062 
1063   if (types.size() < 2)
1064     return parser.emitError(parser.getCurrentLocation(),
1065                             "expected at least input and result view types");
1066 
1067   ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
1068   return failure(
1069       parser.resolveOperand(baseInfo, types.front(), result.operands) ||
1070       (!operands.empty() &&
1071        parser.resolveOperands(operands, indexingTypes,
1072                               operands.front().location, result.operands)) ||
1073       parser.addTypeToList(types.back(), result.types));
1074 }
1075 
verify(SliceOp op)1076 static LogicalResult verify(SliceOp op) {
1077   unsigned rank = op.getBaseViewRank();
1078   if (rank != llvm::size(op.indexings()))
1079     return op.emitOpError("expected ")
1080            << rank << " indexings, got " << llvm::size(op.indexings());
1081   unsigned index = 0;
1082   for (auto indexing : op.indexings()) {
1083     if (indexing.getType().isa<IndexType>())
1084       --rank;
1085     ++index;
1086   }
1087   if (op.getRank() != rank)
1088     return op.emitOpError() << "expected rank of the view(" << op.getRank()
1089                             << ") to be the number of ranges(" << rank << ")";
1090   return success();
1091 }
1092 
getViewSource()1093 Value SliceOp::getViewSource() { return view(); }
1094 
1095 //===----------------------------------------------------------------------===//
1096 // YieldOp
1097 //===----------------------------------------------------------------------===//
1098 
print(OpAsmPrinter & p,linalg::YieldOp op)1099 static void print(OpAsmPrinter &p, linalg::YieldOp op) {
1100   p << op.getOperationName();
1101   if (op.getNumOperands() > 0)
1102     p << ' ' << op.getOperands();
1103   p.printOptionalAttrDict(op.getAttrs());
1104   if (op.getNumOperands() > 0)
1105     p << " : " << op.getOperandTypes();
1106 }
1107 
parseYieldOp(OpAsmParser & parser,OperationState & result)1108 static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
1109   SmallVector<OpAsmParser::OperandType, 2> opInfo;
1110   SmallVector<Type, 2> types;
1111   llvm::SMLoc loc = parser.getCurrentLocation();
1112   return failure(parser.parseOperandList(opInfo) ||
1113                  parser.parseOptionalAttrDict(result.attributes) ||
1114                  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1115                  parser.resolveOperands(opInfo, types, loc, result.operands));
1116 }
1117 
1118 // Check the operand number and types must match the element types of the
1119 // LinalgOp interface's shaped operands.
verifyYield(linalg::YieldOp op,LinalgOp linalgOpInterface)1120 static LogicalResult verifyYield(linalg::YieldOp op,
1121                                  LinalgOp linalgOpInterface) {
1122   auto nOutputs = linalgOpInterface.getNumOutputs();
1123   if (op.getNumOperands() != nOutputs)
1124     return op.emitOpError("expected number of yield values (")
1125            << nOutputs << ") to match the number of operands of the enclosing "
1126            << "LinalgOp (" << op.getNumOperands() << ")";
1127 
1128   for (unsigned i = 0; i != nOutputs; ++i) {
1129     auto elementType =
1130         linalgOpInterface.getOutputShapedType(i).getElementType();
1131     if (op.getOperand(i).getType() != elementType)
1132       return op.emitOpError("type of yield operand ")
1133              << (i + 1) << " (" << op.getOperand(i).getType()
1134              << ") doesn't match "
1135              << "the element type of the enclosing linalg.generic op ("
1136              << elementType << ")";
1137   }
1138   return success();
1139 }
1140 
verify(linalg::YieldOp op)1141 static LogicalResult verify(linalg::YieldOp op) {
1142   auto *parentOp = op->getParentOp();
1143   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1144     return op.emitOpError("expected single non-empty parent region");
1145 
1146   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1147     return verifyYield(op, cast<LinalgOp>(parentOp));
1148 
1149   return op.emitOpError("expected parent op with LinalgOp interface");
1150 }
1151 
1152 /////// Operations corresponding to library calls defined with Tablegen ////////
1153 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)1154 void FillOp::getEffects(
1155     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1156         &effects) {
1157   effects.emplace_back(MemoryEffects::Write::get(), output(),
1158                        SideEffects::DefaultResource::get());
1159 }
1160 
verify(FillOp op)1161 static LogicalResult verify(FillOp op) {
1162   auto viewType = op.getOutputShapedType(0);
1163   auto fillType = op.value().getType();
1164   if (viewType.getElementType() != fillType)
1165     return op.emitOpError("expects fill type to match view elemental type");
1166   return success();
1167 }
1168 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)1169 void CopyOp::getEffects(
1170     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1171         &effects) {
1172   effects.emplace_back(MemoryEffects::Read::get(), input(),
1173                        SideEffects::DefaultResource::get());
1174   effects.emplace_back(MemoryEffects::Write::get(), output(),
1175                        SideEffects::DefaultResource::get());
1176 }
1177 
verify(CopyOp op)1178 static LogicalResult verify(CopyOp op) {
1179   auto outputViewType = op.getOutputShapedType(0);
1180   auto inputViewType = op.getInputShapedType(0);
1181   if (inputViewType.getElementType() != outputViewType.getElementType())
1182     return op.emitOpError("expects views of the same type");
1183   if (inputViewType.getRank() != outputViewType.getRank())
1184     return op.emitOpError("expects views of the same rank");
1185   auto rank = op.getNumParallelLoops();
1186   auto inputPermutationMap = op.inputPermutation();
1187   if (inputPermutationMap) {
1188     if (inputPermutationMap->getNumInputs() != rank)
1189       return op.emitOpError("expects optional input_permutation map of rank ")
1190              << rank;
1191     if (!inputPermutationMap->isPermutation())
1192       return op.emitOpError(
1193           "expects optional input_permutation map to be a permutation");
1194   }
1195   auto outputPermutationMap = op.outputPermutation();
1196   if (outputPermutationMap) {
1197     if (outputPermutationMap->getNumInputs() != rank)
1198       return op.emitOpError("expects optional output_permutation map of rank ")
1199              << rank;
1200     if (!outputPermutationMap->isPermutation())
1201       return op.emitOpError(
1202           "expects optional output_permutation map to be a permutation");
1203   }
1204   if (rank == 0 && inputPermutationMap)
1205     return op.emitOpError("expected no input permutation when rank == 0");
1206   if (rank == 0 && outputPermutationMap)
1207     return op.emitOpError("expected no output permutation when rank == 0");
1208   return success();
1209 }
1210 
1211 template <typename LinalgPoolingOp>
verifyStrideOrDilation(LinalgPoolingOp op,ArrayRef<Attribute> attrs,bool isStride)1212 static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
1213                                             ArrayRef<Attribute> attrs,
1214                                             bool isStride) {
1215   auto strideOrDilation = isStride ? "stride" : "dilation";
1216   if (attrs.size() != op.getNumWindowLoops())
1217     return op.emitOpError("expects num ")
1218            << strideOrDilation
1219            << "s equal to number of window dimensions: " << attrs.size()
1220            << " vs " << op.getNumWindowLoops();
1221   return success();
1222 }
1223 
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)1224 void ConvOp::getEffects(
1225     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1226         &effects) {
1227   effects.emplace_back(MemoryEffects::Read::get(), input(),
1228                        SideEffects::DefaultResource::get());
1229   effects.emplace_back(MemoryEffects::Read::get(), filter(),
1230                        SideEffects::DefaultResource::get());
1231   effects.emplace_back(MemoryEffects::Write::get(), output(),
1232                        SideEffects::DefaultResource::get());
1233 }
1234 
verify(ConvOp op)1235 static LogicalResult verify(ConvOp op) {
1236   auto oType = op.output().getType().cast<MemRefType>();
1237   auto fType = op.filter().getType().cast<MemRefType>();
1238   auto iType = op.input().getType().cast<MemRefType>();
1239   if (oType.getElementType() != iType.getElementType() ||
1240       oType.getElementType() != fType.getElementType())
1241     return op.emitOpError("expects memref elemental types to match");
1242   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
1243     return op.emitOpError("expects memref ranks to match");
1244   if (oType.getRank() <= 2)
1245     return op.emitOpError("expects memref ranks to be greater than 2");
1246   if (auto strides = op.strides()) {
1247     if (failed(
1248             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
1249       return failure();
1250   }
1251   if (auto dilations = op.dilations()) {
1252     if (failed(verifyStrideOrDilation(op, dilations->getValue(),
1253                                       /*isStride=*/false)))
1254       return failure();
1255   }
1256   return success();
1257 }
1258 
1259 template <typename PoolingOp>
verifySingleInputPoolingOp(PoolingOp op)1260 static LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
1261   auto inputType = op.input().getType().template cast<MemRefType>();
1262   auto outputType = op.output().getType().template cast<MemRefType>();
1263   if (outputType.getElementType() != inputType.getElementType())
1264     return op.emitOpError("expects memref elemental types to match");
1265 
1266   auto windowDimsType = op.windowDims().getType().template cast<MemRefType>();
1267   if (outputType.getRank() != inputType.getRank() ||
1268       outputType.getRank() != windowDimsType.getRank())
1269     return op.emitOpError("expects memref ranks to match");
1270 
1271   if (auto strides = op.strides()) {
1272     if (failed(
1273             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
1274       return failure();
1275   }
1276   if (auto dilations = op.dilations()) {
1277     if (failed(verifyStrideOrDilation(op, dilations->getValue(),
1278                                       /*isStride=*/false)))
1279       return failure();
1280   }
1281   return success();
1282 }
1283 
/// Defines `OP_NAME::getEffects` for a single-input pooling op. Exactly two
/// memory effects are reported, both on SideEffects::DefaultResource:
///   1. a read of the op's `input()` operand, then
///   2. a write of the op's `output()` operand.
/// No other operand (e.g. the window-dims operand) is listed.
#define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME)                                 \
  void OP_NAME::getEffects(                                                    \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
          &effects) {                                                          \
    effects.emplace_back(MemoryEffects::Read::get(), input(),                  \
                         SideEffects::DefaultResource::get());                 \
    effects.emplace_back(MemoryEffects::Write::get(), output(),                \
                         SideEffects::DefaultResource::get());                 \
  }
1293 
verify(PoolingMaxOp op)1294 static LogicalResult verify(PoolingMaxOp op) {
1295   return verifySingleInputPoolingOp(op);
1296 }
verify(PoolingMinOp op)1297 static LogicalResult verify(PoolingMinOp op) {
1298   return verifySingleInputPoolingOp(op);
1299 }
verify(PoolingSumOp op)1300 static LogicalResult verify(PoolingSumOp op) {
1301   return verifySingleInputPoolingOp(op);
1302 }
1303 
// getEffects definitions (read `input`, write `output`) for the three
// single-input pooling ops.
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
1307 
// Forward declarations of canonicalization patterns that are defined later in
// this file. Presumably these let the generated .cpp.inc code included below
// name the patterns before their definitions — confirm against the generated
// sources.
namespace {
struct EraseDeadLinalgOp;
struct FoldTensorCastOp;
} // namespace
1312 
1313 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
1314 
1315 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
1316 
1317 #define GET_OP_CLASSES
1318 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1319 
1320 #define GET_OP_CLASSES
1321 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1322 
1323 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1324 /// Assumes `op` is a LinalgOp.
getDimsOfType(Operation * op,StringRef iteratorTypeName,SmallVectorImpl<AffineExpr> & res)1325 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1326                                  SmallVectorImpl<AffineExpr> &res) {
1327   if (!cast<LinalgOp>(op).iterator_types())
1328     return;
1329 
1330   unsigned dim = 0;
1331   MLIRContext *ctx = op->getContext();
1332   for (auto tn :
1333        cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1334     if (tn == iteratorTypeName)
1335       res.push_back(getAffineDimExpr(dim, ctx));
1336     ++dim;
1337   }
1338 }
1339 
extractOrIdentityMap(Optional<AffineMap> maybeMap,unsigned rank,MLIRContext * context)1340 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
1341                                              unsigned rank,
1342                                              MLIRContext *context) {
1343   if (maybeMap)
1344     return maybeMap.getValue();
1345   if (rank == 0)
1346     return AffineMap::get(context);
1347   return AffineMap::getMultiDimIdentityMap(rank, context);
1348 }
1349 
1350 SmallVector<AffineExpr, 4>
makeAffineDimExprs(unsigned num,unsigned & startIdx,MLIRContext * context)1351 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1352                                  MLIRContext *context) {
1353   SmallVector<AffineExpr, 4> res;
1354   res.reserve(num);
1355   for (unsigned i = 0; i < num; ++i)
1356     res.push_back(getAffineDimExpr(startIdx++, context));
1357   return res;
1358 }
1359 
1360 template <typename PoolingOp>
1361 SmallVector<AffineExpr, 4>
weightedPoolingInputIndex(PoolingOp op,ArrayRef<AffineExpr> outputDims,ArrayRef<AffineExpr> windowDims)1362 mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
1363                                         ArrayRef<AffineExpr> outputDims,
1364                                         ArrayRef<AffineExpr> windowDims) {
1365   assert(outputDims.size() == windowDims.size());
1366   SmallVector<AffineExpr, 4> res;
1367   res.reserve(outputDims.size());
1368   for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
1369     // TODO: add a level of indirection to linalg.generic.
1370     auto expr = op.getStride(i) * outputDims[i] +
1371                 op.getDilation(i) * windowDims[i] - op.getLowPad(i);
1372     res.push_back(expr);
1373   }
1374   return res;
1375 }
1376 
/// Explicitly instantiates weightedPoolingInputIndex for OP_TYPE. The template
/// body is defined in this .cpp file, so presumably only the instantiations
/// listed below are available to other translation units.
#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE)                      \
  template SmallVector<AffineExpr, 4>                                          \
  mlir::linalg::weightedPoolingInputIndex<OP_TYPE>(                            \
      OP_TYPE op, ArrayRef<AffineExpr> outputDims,                             \
      ArrayRef<AffineExpr> windowDims);

INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp)
1387 
1388 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
1389                                                 ArrayRef<AffineExpr> b) {
1390   auto rangeA = llvm::make_range(a.begin(), a.end());
1391   auto rangeB = llvm::make_range(b.begin(), b.end());
1392   auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1393   return llvm::to_vector<4>(concatRanges);
1394 }
1395 
appendMangledType(llvm::raw_string_ostream & ss,Type t)1396 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1397   if (auto memref = t.dyn_cast<MemRefType>()) {
1398     ss << "view";
1399     for (auto size : memref.getShape())
1400       if (size < 0)
1401         ss << "sx";
1402       else
1403         ss << size << "x";
1404     appendMangledType(ss, memref.getElementType());
1405   } else if (auto vec = t.dyn_cast<VectorType>()) {
1406     ss << "vector";
1407     llvm::interleave(
1408         vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1409     appendMangledType(ss, vec.getElementType());
1410   } else if (t.isSignlessIntOrIndexOrFloat()) {
1411     ss << t;
1412   } else {
1413     llvm_unreachable("Invalid type for linalg library name mangling");
1414   }
1415 }
1416 
generateLibraryCallName(Operation * op)1417 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
1418   assert(isa<LinalgOp>(op));
1419   std::string name(op->getName().getStringRef().str());
1420   name.reserve(128);
1421   std::replace(name.begin(), name.end(), '.', '_');
1422   llvm::raw_string_ostream ss(name);
1423   ss << "_";
1424   auto types = op->getOperandTypes();
1425   llvm::interleave(
1426       types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1427       [&]() { ss << "_"; });
1428   return ss.str();
1429 }
1430 
1431 // TODO: Consider making all this boilerplate easy to autogenerate
1432 // with Tablegen. This seems a desirable property in the context of
1433 // OpInterfaces where a Linalg "named" op **isa** LinalgOp.
fold(ArrayRef<Attribute> operands)1434 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
1435   if (succeeded(foldMemRefCast(*this)))
1436     return getResult();
1437   return foldReshapeOp(*this, operands);
1438 }
fold(ArrayRef<Attribute>)1439 OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
1440   if (succeeded(foldMemRefCast(*this)))
1441     return getResult();
1442   return {};
1443 }
fold(ArrayRef<Attribute> operands)1444 OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
1445   return foldReshapeOp(*this, operands);
1446 }
1447 
1448 //===----------------------------------------------------------------------===//
1449 // Auto-generated Linalg named ops.
1450 //===----------------------------------------------------------------------===//
1451 
1452 template <typename NamedStructuredOpType>
buildNamedStructuredOpRegionAndAttributesImpl(OpBuilder & opBuilder,Region & region,TypeRange inputTypes,TypeRange outputBufferTypes,TypeRange initTensorTypes,TypeRange resultTypes,std::function<void (unsigned,unsigned)> errorHandler)1453 static void buildNamedStructuredOpRegionAndAttributesImpl(
1454     OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
1455     TypeRange outputBufferTypes, TypeRange initTensorTypes,
1456     TypeRange resultTypes,
1457     std::function<void(unsigned, unsigned)> errorHandler) {
1458   // TODO: atm all operands go through getElementTypeOrSelf,
1459   // reconsider when we have evidence we need to.
1460   SmallVector<Type, 8> argTypes;
1461   for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
1462     for (auto t : containers)
1463       argTypes.push_back(getElementTypeOrSelf(t));
1464 
1465   // RAII.
1466   OpBuilder::InsertionGuard guard(opBuilder);
1467   Block *body = opBuilder.createBlock(&region, {}, argTypes);
1468   unsigned actual = body->getNumArguments();
1469   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
1470   if (expected != actual)
1471     return errorHandler(expected, actual);
1472 
1473   opBuilder.setInsertionPointToStart(body);
1474   mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
1475   NamedStructuredOpType::regionBuilder(*body);
1476 
1477   // indexing_maps is an auto-generated method.
1478 
1479   // iterator_types is an auto-generated method.
1480 }
1481 
1482 template <typename NamedStructuredOpType>
buildNamedStructuredOpRegionAndAttributes(OpBuilder & opBuilder,OperationState & result,TypeRange inputTypes,TypeRange outputBufferTypes,TypeRange initTensorTypes,TypeRange resultTypes)1483 void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
1484                                                OperationState &result,
1485                                                TypeRange inputTypes,
1486                                                TypeRange outputBufferTypes,
1487                                                TypeRange initTensorTypes,
1488                                                TypeRange resultTypes) {
1489   Region &region = *result.addRegion();
1490   buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
1491       opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
1492       resultTypes, [&](unsigned expected, unsigned actual) {
1493         llvm::errs() << "region expects " << expected << " args, got "
1494                      << actual;
1495         assert(expected != actual && "incorrect number of arguments");
1496       });
1497 }
1498 
1499 template <typename NamedStructuredOpType>
1500 static ParseResult
parseNamedStructuredOpRegion(OpAsmParser & parser,Region & region,TypeRange inputTypes,TypeRange outputBufferTypes,TypeRange initTensorTypes,TypeRange resultTypes)1501 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
1502                              TypeRange inputTypes, TypeRange outputBufferTypes,
1503                              TypeRange initTensorTypes, TypeRange resultTypes) {
1504   ParseResult res = success();
1505   OpBuilder opBuilder(parser.getBuilder().getContext());
1506   buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
1507       opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
1508       resultTypes, [&](unsigned expected, unsigned actual) {
1509         res = parser.emitError(parser.getCurrentLocation(),
1510                                llvm::formatv("region expects {0} args, got {1}",
1511                                              expected, actual));
1512       });
1513   return res;
1514 }
1515 
1516 static ParseResult
parseNamedStructuredOpResults(OpAsmParser & parser,SmallVectorImpl<Type> & resultTypes)1517 parseNamedStructuredOpResults(OpAsmParser &parser,
1518                               SmallVectorImpl<Type> &resultTypes) {
1519   if (succeeded(parser.parseOptionalArrow()))
1520     if (parser.parseTypeList(resultTypes))
1521       return failure();
1522   return success();
1523 }
1524 
1525 static ParseResult
parseCommonStructuredOpParts(OpAsmParser & parser,OperationState & result,SmallVectorImpl<Type> & inputTypes,SmallVectorImpl<Type> & outputBufferTypes,SmallVectorImpl<Type> & initTensorTypes)1526 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
1527                              SmallVectorImpl<Type> &inputTypes,
1528                              SmallVectorImpl<Type> &outputBufferTypes,
1529                              SmallVectorImpl<Type> &initTensorTypes) {
1530   llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc,
1531       initTensorsOperandsLoc;
1532   SmallVector<OpAsmParser::OperandType, 4> inputsOperands,
1533       outputBuffersOperands, initTensorsOperands;
1534 
1535   parser.parseOptionalAttrDict(result.attributes);
1536 
1537   if (succeeded(parser.parseOptionalKeyword("ins"))) {
1538     if (parser.parseLParen())
1539       return failure();
1540 
1541     inputsOperandsLoc = parser.getCurrentLocation();
1542     if (parser.parseOperandList(inputsOperands) ||
1543         parser.parseColonTypeList(inputTypes) || parser.parseRParen())
1544       return failure();
1545   }
1546 
1547   if (succeeded(parser.parseOptionalKeyword("outs"))) {
1548     outputBuffersOperandsLoc = parser.getCurrentLocation();
1549     if (parser.parseLParen() ||
1550         parser.parseOperandList(outputBuffersOperands) ||
1551         parser.parseColonTypeList(outputBufferTypes) || parser.parseRParen())
1552       return failure();
1553   }
1554   if (succeeded(parser.parseOptionalKeyword("init"))) {
1555     initTensorsOperandsLoc = parser.getCurrentLocation();
1556     if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) ||
1557         parser.parseColonTypeList(initTensorTypes) || parser.parseRParen())
1558       return failure();
1559   }
1560 
1561   if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
1562                              result.operands) ||
1563       parser.resolveOperands(outputBuffersOperands, outputBufferTypes,
1564                              outputBuffersOperandsLoc, result.operands) ||
1565       parser.resolveOperands(initTensorsOperands, initTensorTypes,
1566                              initTensorsOperandsLoc, result.operands))
1567     return failure();
1568 
1569   result.addAttribute("operand_segment_sizes",
1570                       parser.getBuilder().getI32VectorAttr(
1571                           {static_cast<int32_t>(inputsOperands.size()),
1572                            static_cast<int32_t>(outputBuffersOperands.size()),
1573                            static_cast<int32_t>(initTensorsOperands.size())}));
1574   return success();
1575 }
1576 
1577 template <typename NamedStructuredOpType>
parseNamedStructuredOp(OpAsmParser & parser,OperationState & result)1578 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
1579                                           OperationState &result) {
1580   SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes;
1581   if (parseCommonStructuredOpParts(parser, result, inputTypes,
1582                                    outputBufferTypes, initTensorTypes))
1583     return failure();
1584 
1585   // TODO: consider merging results parsing into region parsing.
1586   // Need to wait for declarative assembly resolution to decide.
1587   SmallVector<Type, 1> outputTensorsTypes;
1588   if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1589     return failure();
1590   result.addTypes(outputTensorsTypes);
1591 
1592   std::unique_ptr<Region> region = std::make_unique<Region>();
1593   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
1594           parser, *region, inputTypes, outputBufferTypes, initTensorTypes,
1595           outputTensorsTypes))
1596     return failure();
1597   result.addRegion(std::move(region));
1598 
1599   return success();
1600 }
1601 
printNamedStructuredOpResults(OpAsmPrinter & p,TypeRange resultTypes)1602 static void printNamedStructuredOpResults(OpAsmPrinter &p,
1603                                           TypeRange resultTypes) {
1604   if (resultTypes.empty())
1605     return;
1606   p.printOptionalArrowTypeList(resultTypes);
1607 }
1608 
1609 template <typename NamedStructuredOpType>
printCommonStructuredOpParts(OpAsmPrinter & p,NamedStructuredOpType op)1610 static void printCommonStructuredOpParts(OpAsmPrinter &p,
1611                                          NamedStructuredOpType op) {
1612   if (!op.inputs().empty())
1613     p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
1614   if (!op.output_buffers().empty())
1615     p << " outs(" << op.output_buffers() << " : "
1616       << op.output_buffers().getTypes() << ")";
1617   if (!op.init_tensors().empty())
1618     p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
1619       << ") ";
1620 }
1621 
1622 template <typename NamedStructuredOpType>
printNamedStructuredOp(OpAsmPrinter & p,NamedStructuredOpType op)1623 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
1624   p << op.getOperationName();
1625   p.printOptionalAttrDict(op.getAttrs(),
1626                           /*elidedAttrs=*/{"operand_segment_sizes"});
1627 
1628   // Printing is shared with generic ops, except for the region and
1629   // attributes.
1630   printCommonStructuredOpParts(p, op);
1631 
1632   // Results printing.
1633   printNamedStructuredOpResults(p, op.result_tensors().getTypes());
1634 
1635   // Region is elided.
1636 }
1637 
1638 template <typename NamedStructuredOpType>
verifyNamedStructuredOp(NamedStructuredOpType op)1639 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
1640   return verifyGenericOp<NamedStructuredOpType>(op);
1641 }
1642 
1643 namespace {
1644 struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp__anon9ff5be7e0f11::EraseDeadLinalgOp1645   EraseDeadLinalgOp(PatternBenefit benefit = 1)
1646       : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
1647 
matchAndRewrite__anon9ff5be7e0f11::EraseDeadLinalgOp1648   LogicalResult matchAndRewrite(Operation *op,
1649                                 PatternRewriter &rewriter) const override {
1650     auto linalgOp = dyn_cast<LinalgOp>(op);
1651     if (!linalgOp)
1652       return failure();
1653     for (Value v : linalgOp.getInputsAndOutputBuffers()) {
1654       // Linalg "inputs" may be either tensor or memref type.
1655       // tensor<0xelt_type> is a convention that may not always mean
1656       // "0 iterations". Only erase in cases we see memref<...x0x...>.
1657       auto mt = v.getType().dyn_cast<MemRefType>();
1658       if (!mt)
1659         continue;
1660       if (llvm::is_contained(mt.getShape(), 0)) {
1661         rewriter.eraseOp(linalgOp);
1662         return success();
1663       }
1664     }
1665     return failure();
1666   }
1667 };
1668 
1669 struct FoldTensorCastOp : public RewritePattern {
FoldTensorCastOp__anon9ff5be7e0f11::FoldTensorCastOp1670   FoldTensorCastOp(PatternBenefit benefit = 1)
1671       : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
1672 
matchAndRewrite__anon9ff5be7e0f11::FoldTensorCastOp1673   LogicalResult matchAndRewrite(Operation *op,
1674                                 PatternRewriter &rewriter) const override {
1675     auto linalgOp = dyn_cast<LinalgOp>(op);
1676     if (!linalgOp)
1677       return failure();
1678 
1679     // If no operand comes from a TensorCastOp and can be folded then fail.
1680     bool hasTensorCastOperand =
1681         llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
1682           if (v.isa<BlockArgument>())
1683             return false;
1684           auto castOp = v.getDefiningOp<TensorCastOp>();
1685           return castOp && canFoldIntoConsumerOp(castOp);
1686         });
1687     if (!hasTensorCastOperand)
1688       return failure();
1689 
1690     SmallVector<Type, 4> newResultTypes;
1691     newResultTypes.reserve(op->getNumResults());
1692     SmallVector<Value, 4> newOperands;
1693     newOperands.reserve(op->getNumOperands());
1694     // Inputs may fold.
1695     for (Value v : linalgOp.getInputs()) {
1696       auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
1697       newOperands.push_back(
1698           canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
1699     }
1700     // Output buffers are memrefs, they don't fold.
1701     newOperands.append(linalgOp.getOutputBuffers().begin(),
1702                        linalgOp.getOutputBuffers().end());
1703     // Init tensors may fold, in which case the resultType must also change.
1704     for (Value v : linalgOp.getInitTensors()) {
1705       auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
1706       bool fold = canFoldIntoConsumerOp(tensorCastOp);
1707       newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
1708       newResultTypes.push_back(newOperands.back().getType());
1709     }
1710     auto extraOperands = linalgOp.getAssumedNonShapedOperands();
1711     newOperands.append(extraOperands.begin(), extraOperands.end());
1712     // Clone op.
1713     Operation *newOp =
1714         linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
1715     rewriter.replaceOp(op, newOp->getResults());
1716 
1717     return success();
1718   }
1719 };
1720 } // namespace
1721 
namespace {
// Deduplicate redundant args of a linalg op.
// An arg is redundant if it has the same Value and indexing map as another.
//
// Only inputs are deduplicated, and only for GenericOp / IndexedGenericOp.
// The rewrite clones the op without the redundant input operands, drops their
// indexing maps, updates the recorded number of inputs, then forwards each
// redundant payload block argument to the block argument of its canonical
// (earliest equivalent) input before erasing it.
struct DeduplicateInputs : public RewritePattern {
  DeduplicateInputs(PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // This pattern reduces the number of arguments of an op, which breaks
    // the invariants of semantically charged named ops.
    if (!isa<GenericOp, IndexedGenericOp>(op))
      return failure();
    auto linalgOp = cast<LinalgOp>(op);

    // Associate each input to an equivalent "canonical" input that has the same
    // Value and indexing map.
    //
    // In the non-duplicate case, input `i` will have canonical input `i`. But
    // in the case of duplicated inputs, the canonical input could be some other
    // input `< i`. That is, a later input will have some earlier input as its
    // canonical input.
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
    // For later remapping tasks like deduplicating payload block arguments,
    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
    // convenient.
    SmallVector<int, 6> canonicalInputIndices;
    for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
      Value input = linalgOp.getInput(i);
      AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
      // STL-like maps have a convenient behavior for our use case here. In the
      // case of duplicate keys, the insertion is rejected, and the returned
      // iterator gives access to the value already in the map.
      auto pair = canonicalInput.insert({{input, indexingMap}, i});
      canonicalInputIndices.push_back(pair.first->second);
    }

    // If there are no duplicate args, then bail out.
    // (The map holds one entry per distinct (Value, map) key, so equal sizes
    // mean every input is its own canonical input.)
    if (canonicalInput.size() == linalgOp.getNumInputs())
      return failure();

    // The operands for the newly canonicalized op: the canonical inputs in
    // their original relative order, then the output buffers, init tensors
    // and non-shaped operands, all unchanged.
    SmallVector<Value, 6> newOperands;
    for (auto v : llvm::enumerate(linalgOp.getInputs()))
      if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
        newOperands.push_back(v.value());
    llvm::append_range(newOperands, linalgOp.getOutputBuffers());
    llvm::append_range(newOperands, linalgOp.getInitTensors());
    llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());

    // Clone the old op with new operands.
    // The clone keeps the original result types, the original indexing_maps
    // attribute and the original input count; the last two are repaired
    // below.
    Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
                                      op->getResultTypes(), newOperands);
    auto newLinalgOp = cast<LinalgOp>(newOp);

    // Repair the indexing maps by filtering out the ones that have been
    // eliminated.
    // newLinalgOp still reports the original input count and maps at this
    // point, so original input indices `i` can be used directly.
    SmallVector<AffineMap, 6> newIndexingMaps;
    for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
      if (canonicalInputIndices[i] == i)
        newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
    for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
      newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
    newOp->setAttr("indexing_maps",
                   rewriter.getAffineMapArrayAttr(newIndexingMaps));

    // Set the number of inputs to the new value. The `clone` call above kept
    // the value from the original op.
    newLinalgOp.setNumInputs(canonicalInput.size());

    // linalg.indexed_generic payloads have additional arguments prepended to
    // the block arg list. The number of such args is one per dimension of the
    // iteration space.
    // (newIndexingMaps is non-empty: at least one canonical input remains.)
    int bbArgBaseOffset = 0;
    if (isa<IndexedGenericOp>(op))
      bbArgBaseOffset = newIndexingMaps[0].getNumInputs();

    // Repair the payload entry block by RAUW'ing redundant arguments and
    // erasing them.
    Block &payload = newOp->getRegion(0).front();
    for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
      // Iterate in reverse, so that we erase later args first, preventing the
      // argument list from shifting unexpectedly and invalidating all our
      // indices.
      int reversed = e - i - 1;
      int canonicalIndex = canonicalInputIndices[reversed];
      if (canonicalInputIndices[reversed] == reversed)
        continue;
      payload.getArgument(bbArgBaseOffset + reversed)
          .replaceAllUsesWith(
              payload.getArgument(bbArgBaseOffset + canonicalIndex));
      payload.eraseArgument(bbArgBaseOffset + reversed);
    }

    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
1821 
/// Defines two hooks for the hand-written op XXX:
///   * XXX::getCanonicalizationPatterns, which registers EraseDeadLinalgOp,
///     FoldTensorCastOp and DeduplicateInputs (the last one bails out on
///     anything other than GenericOp / IndexedGenericOp);
///   * XXX::fold, whose only action is foldMemRefCast on the op.
#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,     \
                                        MLIRContext *context) {                \
    results.insert<EraseDeadLinalgOp>();                                       \
    results.insert<FoldTensorCastOp>();                                        \
    results.insert<DeduplicateInputs>();                                       \
  }                                                                            \
                                                                               \
  LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                          SmallVectorImpl<OpFoldResult> &) {                   \
    return foldMemRefCast(*this);                                              \
  }

CANONICALIZERS_AND_FOLDERS(ConvOp)
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
CANONICALIZERS_AND_FOLDERS(CopyOp)
CANONICALIZERS_AND_FOLDERS(FillOp)
CANONICALIZERS_AND_FOLDERS(GenericOp)
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
1843 
1844 // All named ops canonicalizers and folders are auto-generated in the
1845 // .cpp.inc.
1846