• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file defines utilities that operate on builtin types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/BuiltinAttributes.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24 
25 namespace mlir {
26 
isRowMajorMatmul(ArrayAttr indexingMaps)27 inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
28   auto context = indexingMaps.getContext();
29   AffineExpr m, n, k;
30   bindDims(context, m, n, k);
31   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
32   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
33   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
34   auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
35   return indexingMaps == maps;
36 }
37 
isColumnMajorMatmul(ArrayAttr indexingMaps)38 inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
39   auto context = indexingMaps.getContext();
40   AffineExpr m, n, k;
41   bindDims(context, m, n, k);
42   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
43   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
44   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
45   auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
46   return indexingMaps == maps;
47 }
48 
49 /// Attribute name for the AffineArrayAttr which encodes the relationship
50 /// between a structured op iterators' and its operands.
getIndexingMapsAttrName()51 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
52 
53 /// Attribute name for the StrArrayAttr which encodes the type of a structured
54 /// op's iterators.
getIteratorTypesAttrName()55 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
56 
57 /// Attribute name for the StringAttr which encodes an optional documentation
58 /// string of the structured op.
getDocAttrName()59 constexpr StringRef getDocAttrName() { return "doc"; }
60 
61 /// Attribute name for the StrArrayAttr which encodes the external library
62 /// function that implements the structured op.
getLibraryCallAttrName()63 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
64 
65 /// Attribute name for the ArrayAttr of StrArrayAttr that encodes sparsity.
getSparseAttrName()66 constexpr StringRef getSparseAttrName() { return "sparse"; }
67 
68 /// Attribute name for the StrArrayAttr which encodes the value of strides.
getStridesAttrName()69 constexpr StringRef getStridesAttrName() { return "strides"; }
70 
71 /// Attribute name for the StrArrayAttr which encodes the value of dilations.
getDilationsAttrName()72 constexpr StringRef getDilationsAttrName() { return "dilations"; }
73 
74 /// Attribute name for the StrArrayAttr which encodes the value of paddings.
getPaddingAttrName()75 constexpr StringRef getPaddingAttrName() { return "padding"; }
76 
77 /// Use to encode that a particular iterator type has parallel semantics.
getParallelIteratorTypeName()78 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
isParallelIterator(Attribute attr)79 inline bool isParallelIterator(Attribute attr) {
80   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
81   return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
82 }
83 
84 /// Use to encode that a particular iterator type has reduction semantics.
getReductionIteratorTypeName()85 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
isReductionIterator(Attribute attr)86 inline bool isReductionIterator(Attribute attr) {
87   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
88   return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
89 }
90 
91 /// Use to encode that a particular iterator type has window semantics.
getWindowIteratorTypeName()92 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
isWindowIterator(Attribute attr)93 inline bool isWindowIterator(Attribute attr) {
94   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
95   return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
96 }
97 
98 /// Use to encode that a particular iterator type has window semantics.
getAllIteratorTypeNames()99 inline ArrayRef<StringRef> getAllIteratorTypeNames() {
100   static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
101                                          getReductionIteratorTypeName(),
102                                          getWindowIteratorTypeName()};
103   return llvm::makeArrayRef(names);
104 }
105 
106 /// Returns the iterator of a certain type.
getNumIterators(StringRef name,ArrayAttr iteratorTypes)107 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
108   auto names = getAllIteratorTypeNames();
109   (void)names;
110   assert(llvm::is_contained(names, name));
111   return llvm::count_if(iteratorTypes, [name](Attribute a) {
112     return a.cast<StringAttr>().getValue() == name;
113   });
114 }
115 
getNumIterators(ArrayAttr iteratorTypes)116 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
117   unsigned res = 0;
118   for (auto n : getAllIteratorTypeNames())
119     res += getNumIterators(n, iteratorTypes);
120   return res;
121 }
122 
123 /// Typed representation for loop type strings.
124 enum class IteratorType { Parallel, Reduction };
125 
toString(IteratorType t)126 inline StringRef toString(IteratorType t) {
127   switch (t) {
128   case IteratorType::Parallel:
129     return getParallelIteratorTypeName();
130   case IteratorType::Reduction:
131     return getReductionIteratorTypeName();
132   }
133   llvm_unreachable("Unsupported IteratorType");
134 }
135 
136 /// Use to encode a dense or sparse dimension.
getSparseDimName()137 constexpr StringRef getSparseDimName() { return "S"; }
isSparseDim(Attribute attr)138 inline bool isSparseDim(Attribute attr) {
139   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
140   return strAttr && strAttr.getValue() == getSparseDimName();
141 }
getDenseDimName()142 constexpr StringRef getDenseDimName() { return "D"; }
isDenseDim(Attribute attr)143 inline bool isDenseDim(Attribute attr) {
144   auto strAttr = attr.dyn_cast_or_null<StringAttr>();
145   return strAttr && strAttr.getValue() == getDenseDimName();
146 }
147 
148 } // end namespace mlir
149 
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
151