• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2    Copyright 2022 The StableHLO Authors.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #ifndef STABLEHLO_DIALECT_STABLEHLO_OPS_H
18 #define STABLEHLO_DIALECT_STABLEHLO_OPS_H
19 
20 #include <algorithm>
21 
22 #include "dialect/Base.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "mlir/Dialect/Quant/QuantTypes.h"
26 #include "mlir/Dialect/Shape/IR/Shape.h"
27 #include "mlir/IR/Attributes.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/BuiltinAttributes.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Dialect.h"
32 #include "mlir/IR/DialectImplementation.h"
33 #include "mlir/IR/Location.h"
34 #include "mlir/IR/MLIRContext.h"
35 #include "mlir/IR/OpDefinition.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/TensorEncoding.h"
38 #include "mlir/IR/TypeUtilities.h"
39 #include "mlir/IR/Types.h"
40 #include "mlir/Interfaces/InferTypeOpInterface.h"
41 #include "mlir/Interfaces/SideEffectInterfaces.h"
42 #include "mlir/Support/LogicalResult.h"
43 
44 // Include order matters.
45 #include "dialect/StablehloEnums.h.inc"
46 #define GET_ATTRDEF_CLASSES
47 #include "dialect/StablehloAttrs.h.inc"
48 
49 namespace mlir {
50 namespace stablehlo {
51 
52 class StablehloDialect : public Dialect {
53  public:
54   explicit StablehloDialect(MLIRContext *context);
getDialectNamespace()55   static StringRef getDialectNamespace() { return "stablehlo"; }
56 
57   // Registered hook to materialize a constant operation from a given attribute
58   // value with the desired resultant type.
59   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
60                                  Location loc) override;
61 
62   // Registered hook to verify region arg attributes on operations.
63   LogicalResult verifyRegionArgAttribute(mlir::Operation *op,
64                                          unsigned regionIndex,
65                                          unsigned argIndex,
66                                          mlir::NamedAttribute attr) override;
67 
68   // Registered hook to verify an attribute from this dialect on operations.
69   LogicalResult verifyOperationAttribute(mlir::Operation *op,
70                                          mlir::NamedAttribute attr) override;
71 
72   // Parses a type registered to this dialect.
73   Type parseType(DialectAsmParser &parser) const override;
74 
75   // Prints a type registered to this dialect.
76   void printType(Type type, DialectAsmPrinter &os) const override;
77 
78   // Parses an attribute registered to this dialect.
79   Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
80 
81   // Prints an attribute registered to this dialect.
82   void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
83 };
84 
85 class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
86  public:
87   using Base::Base;
88 };
89 
90 // TODO(b/236017415): remove when we migrate to prefix accessor.
91 namespace accessor_dispatch {
92 template <typename OpT>
93 auto getReplicaGroups(OpT op, int)
94     -> decltype(op.getReplicaGroups(), DenseIntElementsAttr{}) {
95   return op.getReplicaGroups();
96 }
97 template <typename OpT>
98 auto getReplicaGroups(OpT op, char)
99     -> decltype(op.replica_groups(), DenseIntElementsAttr{}) {
100   return op.replica_groups();
101 }
102 }  // namespace accessor_dispatch
103 
104 // Verifies replica groups attached to collective communication operations.
105 // If the attribute is not empty, it must be a rank 2 tensor, and each replica
106 // should appear exactly once. If `is_uniform_sized` is true, then we also check
107 // that each group is of the same size. If the operation has
108 // `use_global_device_ids` set, then replica group cannot be empty.
109 template <typename OpT>
verifyReplicaGroups(OpT op,bool isUniformSized)110 LogicalResult verifyReplicaGroups(OpT op, bool isUniformSized) {
111   DenseIntElementsAttr attr = accessor_dispatch::getReplicaGroups(op, 0);
112   auto replicaGroupType = attr.getType().dyn_cast<RankedTensorType>();
113   if (!replicaGroupType || replicaGroupType.getRank() != 2 ||
114       !replicaGroupType.getElementType().isInteger(/*width=*/64))
115     return op.emitOpError(
116         "replica groups should be a rank 2 tensor of 64 bit integers");
117 
118   if (replicaGroupType.getShape().equals(ArrayRef<int64_t>{0, 0})) {
119     if (op->hasAttr("use_global_device_ids") &&
120         op->getAttr("use_global_device_ids")
121             .template cast<BoolAttr>()
122             .getValue()) {
123       return op.emitOpError(
124           "if `use_global_device_ids` is set, the replica groups cannot be "
125           "empty");
126     }
127     return success();
128   }
129 
130   int64_t maxReplicaIdSeen = 0;
131   llvm::SmallSet<int64_t, 8> replicaSeen;
132   for (int64_t id : attr.getValues<int64_t>()) {
133     // Replica groups are stored in a 2D tensor. If the op supports non-uniform
134     // groups, null replica IDs are stored as -1.
135     if (id == -1) {
136       if (isUniformSized) {
137         return op.emitOpError("Invalid replica id -1");
138       }
139       continue;
140     }
141 
142     if (!replicaSeen.insert(id).second) {
143       return op.emitOpError("replica id #") << id << " seen more than once";
144     }
145     maxReplicaIdSeen = std::max(maxReplicaIdSeen, id);
146   }
147 
148   for (int64_t id = 0; id <= maxReplicaIdSeen; id++) {
149     if (!replicaSeen.contains(id)) {
150       return op.emitOpError("replica id #")
151              << id << " not seen in replica groups";
152     }
153   }
154   return success();
155 }
156 
157 // Verifies the source target pairs attached to collective permute.
158 LogicalResult verifyCollectivePermuteSourceTargetPairs(
159     Operation *op, DenseIntElementsAttr attr);
160 
161 LogicalResult verifyReduceScatter(Operation *op, TypeRange operandTypes,
162                                   TypeRange resultTypes,
163                                   uint64_t scatterDimension);
164 
165 void printConvolutionDimensions(AsmPrinter &p, ConvDimensionNumbersAttr dnums);
166 void printConvolutionDimensions(AsmPrinter &p, Operation *,
167                                 ConvDimensionNumbersAttr dnums);
168 ParseResult parseConvolutionDimensions(AsmParser &parser,
169                                        ConvDimensionNumbersAttr &dnums);
170 
171 // Custom formatting for convolution window attributes.
172 void printWindowAttributes(OpAsmPrinter &p, Operation *op,
173                            llvm::Optional<DenseIntElementsAttr> windowStrides,
174                            llvm::Optional<DenseIntElementsAttr> padding,
175                            llvm::Optional<DenseIntElementsAttr> lhsDilation,
176                            llvm::Optional<DenseIntElementsAttr> rhsDilation,
177                            llvm::Optional<DenseElementsAttr> windowReversal);
178 
179 ParseResult parseWindowAttributes(OpAsmParser &parser,
180                                   DenseIntElementsAttr &windowStrides,
181                                   DenseIntElementsAttr &padding,
182                                   DenseIntElementsAttr &lhsDilation,
183                                   DenseIntElementsAttr &rhsDilation,
184                                   DenseElementsAttr &windowReversal);
185 
186 }  // end namespace stablehlo
187 }  // end namespace mlir
188 
189 #define GET_OP_CLASSES
190 #include "dialect/StablehloOps.h.inc"
191 
192 namespace mlir {
193 namespace stablehlo {
194 
195 SortOp createSortOp(PatternRewriter *rewriter, const Location &loc,
196                     const llvm::ArrayRef<Value> &operands,
197                     const llvm::ArrayRef<Type> &elementTypes, int64_t dimension,
198                     bool isStable, ComparisonDirection direction);
199 
200 }  // end namespace stablehlo
201 }  // end namespace mlir
202 
203 #endif  // STABLEHLO_DIALECT_STABLEHLO_OPS_H
204