1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Copyright 2022 The StableHLO Authors.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #ifndef STABLEHLO_DIALECT_STABLEHLO_OPS_H
18 #define STABLEHLO_DIALECT_STABLEHLO_OPS_H
19
20 #include <algorithm>
21
22 #include "dialect/Base.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "mlir/Dialect/Quant/QuantTypes.h"
26 #include "mlir/Dialect/Shape/IR/Shape.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinAttributes.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Dialect.h"
32 #include "mlir/IR/DialectImplementation.h"
33 #include "mlir/IR/Location.h"
34 #include "mlir/IR/MLIRContext.h"
35 #include "mlir/IR/OpDefinition.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/TensorEncoding.h"
38 #include "mlir/IR/TypeUtilities.h"
39 #include "mlir/IR/Types.h"
40 #include "mlir/Interfaces/InferTypeOpInterface.h"
41 #include "mlir/Interfaces/SideEffectInterfaces.h"
42 #include "mlir/Support/LogicalResult.h"
43
44 // Include order matters.
45 #include "dialect/StablehloEnums.h.inc"
46 #define GET_ATTRDEF_CLASSES
47 #include "dialect/StablehloAttrs.h.inc"
48
49 namespace mlir {
50 namespace stablehlo {
51
52 class StablehloDialect : public Dialect {
53 public:
54 explicit StablehloDialect(MLIRContext *context);
getDialectNamespace()55 static StringRef getDialectNamespace() { return "stablehlo"; }
56
57 // Registered hook to materialize a constant operation from a given attribute
58 // value with the desired resultant type.
59 Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
60 Location loc) override;
61
62 // Registered hook to verify region arg attributes on operations.
63 LogicalResult verifyRegionArgAttribute(mlir::Operation *op,
64 unsigned regionIndex,
65 unsigned argIndex,
66 mlir::NamedAttribute attr) override;
67
68 // Registered hook to verify an attribute from this dialect on operations.
69 LogicalResult verifyOperationAttribute(mlir::Operation *op,
70 mlir::NamedAttribute attr) override;
71
72 // Parses a type registered to this dialect.
73 Type parseType(DialectAsmParser &parser) const override;
74
75 // Prints a type registered to this dialect.
76 void printType(Type type, DialectAsmPrinter &os) const override;
77
78 // Parses an attribute registered to this dialect.
79 Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
80
81 // Prints an attribute registered to this dialect.
82 void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
83 };
84
85 class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
86 public:
87 using Base::Base;
88 };
89
90 // TODO(b/236017415): remove when we migrate to prefix accessor.
91 namespace accessor_dispatch {
92 template <typename OpT>
93 auto getReplicaGroups(OpT op, int)
94 -> decltype(op.getReplicaGroups(), DenseIntElementsAttr{}) {
95 return op.getReplicaGroups();
96 }
97 template <typename OpT>
98 auto getReplicaGroups(OpT op, char)
99 -> decltype(op.replica_groups(), DenseIntElementsAttr{}) {
100 return op.replica_groups();
101 }
102 } // namespace accessor_dispatch
103
104 // Verifies replica groups attached to collective communication operations.
105 // If the attribute is not empty, it must be a rank 2 tensor, and each replica
106 // should appear exactly once. If `is_uniform_sized` is true, then we also check
107 // that each group is of the same size. If the operation has
108 // `use_global_device_ids` set, then replica group cannot be empty.
109 template <typename OpT>
verifyReplicaGroups(OpT op,bool isUniformSized)110 LogicalResult verifyReplicaGroups(OpT op, bool isUniformSized) {
111 DenseIntElementsAttr attr = accessor_dispatch::getReplicaGroups(op, 0);
112 auto replicaGroupType = attr.getType().dyn_cast<RankedTensorType>();
113 if (!replicaGroupType || replicaGroupType.getRank() != 2 ||
114 !replicaGroupType.getElementType().isInteger(/*width=*/64))
115 return op.emitOpError(
116 "replica groups should be a rank 2 tensor of 64 bit integers");
117
118 if (replicaGroupType.getShape().equals(ArrayRef<int64_t>{0, 0})) {
119 if (op->hasAttr("use_global_device_ids") &&
120 op->getAttr("use_global_device_ids")
121 .template cast<BoolAttr>()
122 .getValue()) {
123 return op.emitOpError(
124 "if `use_global_device_ids` is set, the replica groups cannot be "
125 "empty");
126 }
127 return success();
128 }
129
130 int64_t maxReplicaIdSeen = 0;
131 llvm::SmallSet<int64_t, 8> replicaSeen;
132 for (int64_t id : attr.getValues<int64_t>()) {
133 // Replica groups are stored in a 2D tensor. If the op supports non-uniform
134 // groups, null replica IDs are stored as -1.
135 if (id == -1) {
136 if (isUniformSized) {
137 return op.emitOpError("Invalid replica id -1");
138 }
139 continue;
140 }
141
142 if (!replicaSeen.insert(id).second) {
143 return op.emitOpError("replica id #") << id << " seen more than once";
144 }
145 maxReplicaIdSeen = std::max(maxReplicaIdSeen, id);
146 }
147
148 for (int64_t id = 0; id <= maxReplicaIdSeen; id++) {
149 if (!replicaSeen.contains(id)) {
150 return op.emitOpError("replica id #")
151 << id << " not seen in replica groups";
152 }
153 }
154 return success();
155 }
156
157 // Verifies the source target pairs attached to collective permute.
158 LogicalResult verifyCollectivePermuteSourceTargetPairs(
159 Operation *op, DenseIntElementsAttr attr);
160
161 LogicalResult verifyReduceScatter(Operation *op, TypeRange operandTypes,
162 TypeRange resultTypes,
163 uint64_t scatterDimension);
164
165 void printConvolutionDimensions(AsmPrinter &p, ConvDimensionNumbersAttr dnums);
166 void printConvolutionDimensions(AsmPrinter &p, Operation *,
167 ConvDimensionNumbersAttr dnums);
168 ParseResult parseConvolutionDimensions(AsmParser &parser,
169 ConvDimensionNumbersAttr &dnums);
170
171 // Custom formatting for convolution window attributes.
172 void printWindowAttributes(OpAsmPrinter &p, Operation *op,
173 llvm::Optional<DenseIntElementsAttr> windowStrides,
174 llvm::Optional<DenseIntElementsAttr> padding,
175 llvm::Optional<DenseIntElementsAttr> lhsDilation,
176 llvm::Optional<DenseIntElementsAttr> rhsDilation,
177 llvm::Optional<DenseElementsAttr> windowReversal);
178
179 ParseResult parseWindowAttributes(OpAsmParser &parser,
180 DenseIntElementsAttr &windowStrides,
181 DenseIntElementsAttr &padding,
182 DenseIntElementsAttr &lhsDilation,
183 DenseIntElementsAttr &rhsDilation,
184 DenseElementsAttr &windowReversal);
185
186 } // end namespace stablehlo
187 } // end namespace mlir
188
189 #define GET_OP_CLASSES
190 #include "dialect/StablehloOps.h.inc"
191
192 namespace mlir {
193 namespace stablehlo {
194
195 SortOp createSortOp(PatternRewriter *rewriter, const Location &loc,
196 const llvm::ArrayRef<Value> &operands,
197 const llvm::ArrayRef<Type> &elementTypes, int64_t dimension,
198 bool isStable, ComparisonDirection direction);
199
200 } // end namespace stablehlo
201 } // end namespace mlir
202
203 #endif // STABLEHLO_DIALECT_STABLEHLO_OPS_H
204