1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2    Copyright 2022 The StableHLO Authors.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include "dialect/StablehloOps.h"
18 
19 #include <assert.h>
20 #include <stddef.h>
21 #include <stdint.h>
22 
23 #include <algorithm>
24 #include <array>
25 #include <cstdint>
26 #include <functional>
27 #include <numeric>
28 #include <set>
29 #include <unordered_map>
30 #include <utility>
31 
32 #include "dialect/StablehloOps.h.inc"
33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/DenseMap.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/Twine.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/iterator_range.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/Support/MathExtras.h"
48 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
49 #include "mlir/Dialect/Complex/IR/Complex.h"
50 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
51 #include "mlir/Dialect/Tensor/IR/Tensor.h"
52 #include "mlir/IR/Attributes.h"
53 #include "mlir/IR/Builders.h"
54 #include "mlir/IR/BuiltinAttributes.h"
55 #include "mlir/IR/BuiltinTypes.h"
56 #include "mlir/IR/Diagnostics.h"
57 #include "mlir/IR/Dialect.h"
58 #include "mlir/IR/FunctionInterfaces.h"
59 #include "mlir/IR/Location.h"
60 #include "mlir/IR/MLIRContext.h"
61 #include "mlir/IR/Matchers.h"
62 #include "mlir/IR/OpDefinition.h"
63 #include "mlir/IR/OpImplementation.h"
64 #include "mlir/IR/Operation.h"
65 #include "mlir/IR/OperationSupport.h"
66 #include "mlir/IR/PatternMatch.h"
67 #include "mlir/IR/TypeUtilities.h"
68 #include "mlir/IR/Types.h"
69 #include "mlir/IR/Value.h"
70 #include "mlir/Support/LLVM.h"
71 #include "mlir/Support/LogicalResult.h"
72 #include "mlir/Transforms/InliningUtils.h"
73 
74 // Include order matters
75 #include "dialect/StablehloEnums.cpp.inc"
76 #define GET_ATTRDEF_CLASSES
77 #include "dialect/StablehloAttrs.cpp.inc"
78 
79 namespace mlir {
80 namespace stablehlo {
81 namespace {
82 void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,
83                 ArrayRef<Type> types,
84                 SmallVector<OpAsmParser::Argument>& args) {
85   for (auto argAndType : llvm::zip(operands, types)) {
86     auto& arg = args.emplace_back();
87     arg.ssaName = std::get<0>(argAndType);
88     arg.type = std::get<1>(argAndType);
89   }
90 }
91 
92 const auto hasDuplicates = [](SmallVector<int64_t>& nums) {
93   if (!llvm::is_sorted(nums)) std::sort(nums.begin(), nums.end());
94   auto* last = std::unique(nums.begin(), nums.end());
95   return last != nums.end();
96 };
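// Illustrative usage (not from the original source): hasDuplicates may sort
// `nums` in place; e.g. it returns true for {3, 1, 3} and false for {1, 2, 3}.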
97 
98 //===----------------------------------------------------------------------===//
99 // Utilities for the canonicalize patterns
100 //===----------------------------------------------------------------------===//
101 
102 // Verifies that dimension attribute for the op correctly indexes in operand or
103 // result shape.
104 template <typename OpT>
105 static LogicalResult verifyDimAttr(OpT op) {
106   int64_t rank = -1;
107   if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
108     rank = ty.getRank();
109   } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
110     rank = ty.getRank();
111   } else {
112     return success();
113   }
114 
115   int64_t dim = op.dimension();
116   if (dim < 0 || dim >= rank)
117     return op.emitOpError() << "requires dimension attribute in range [0, "
118                             << rank << "); found (" << dim << ")";
119   return success();
120 }
121 
122 // Check if the dimension size is dynamic.
123 inline static bool isDynamicDimSize(int64_t val) {
124   return val == ShapedType::kDynamicSize;
125 }
126 
127 // Common shape function helper for RngNormal and RngUniform.
128 static LogicalResult rngInferReturnTypeComponents(
129     MLIRContext* context, Optional<Location> location, ValueRange operands,
130     DictionaryAttr attributes, RegionRange regions,
131     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
132   if (operands.size() != 3)
133     return emitOptionalError(location, "expected 3 operands");
134 
135   SmallVector<int64_t> shapeVector;
136   Value shapeOperand = operands[2];
137   auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
138   Type elementType = getElementTypeOrSelf(operands[1]);
139 
140   // Operand `shape` (1D by ODS) may or may not be a constant. If `shape` is:
141   // 1. not constant and has a dynamic dim (e.g. tensor<?xi64>): infer tensor<*x>.
142   // 2. not constant but statically shaped (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
143   // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.
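  // For instance (illustrative, assuming an f32 element type): a non-constant
  // `shape` of type tensor<3xi64> infers tensor<?x?x?xf32>, while a constant
  // dense<[2, 3, 5]> infers tensor<2x3x5xf32>.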
144 
145   // Match to check whether the `shape` operand is a constant.
146   DenseIntElementsAttr shape;
147   if (!matchPattern(shapeOperand, m_Constant(&shape))) {
148     int size = shapeOperandType.getDimSize(0);
149     if (isDynamicDimSize(size)) {
150       inferredReturnShapes.emplace_back(elementType);
151       return success();
152     }
153     shapeVector.resize(size, ShapedType::kDynamicSize);
154     inferredReturnShapes.emplace_back(shapeVector, elementType);
155     return success();
156   }
157 
158   // `shape` operand is a constant.
159   shapeVector.reserve(shape.size());
160   for (const APInt& fp : shape.getValues<APInt>())
161     shapeVector.push_back(fp.getSExtValue());
162   inferredReturnShapes.emplace_back(shapeVector, elementType);
163   return success();
164 }
165 
166 // Returns a new scalar integer value having type `type`. Here `type` must be
167 // an integer or index type.
168 Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
169   if (type == value.getType()) return value;
170   assert(type.isIndex() || value.getType().isIndex());
171   return b.create<arith::IndexCastOp>(loc, type, value);
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Utilities for verifiers
176 //===----------------------------------------------------------------------===//
177 
178 // Convert a 1D dense int64 attribute to a list of values.
179 SmallVector<int64_t> convertDenseIntAttr(
180     llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr) {
181   if (!optionalAttr.has_value()) return SmallVector<int64_t>{};
182 
183   mlir::DenseIntElementsAttr attr = *optionalAttr;
184   auto values = attr.getValues<int64_t>();
185   return {values.begin(), values.end()};
186 }
187 
188 // Convert a 1D or Nx2 dense int64 attribute to a list of tuples.
189 FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertNx2Attribute(
190     llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr, Location loc) {
191   if (!optionalAttr.has_value())
192     return SmallVector<std::pair<int64_t, int64_t>>{};
193   mlir::DenseIntElementsAttr attr = *optionalAttr;
194 
195   auto attrType = attr.getType().cast<RankedTensorType>();  // ensured by ODS.
196   if (attrType.getRank() > 1) {
197     if (attrType.getRank() != 2 || attrType.getShape()[1] != 2)
198       return (mlir::emitError(loc) << "expects the shape of padding-attribute "
199                                       "to be {N, 2}, but got {"
200                                    << attrType.getShape() << "}.",
201               failure());
202   } else {
203     // Padding values can be provided as a 1D vector as well.
204     if (attr.getValues<int64_t>().size() % 2 != 0)
205       return (mlir::emitError(loc)
206                   << "expects the padding-entries to have even number of "
207                      "elements, but got "
208                   << attr.getValues<int64_t>().size() << " elements.",
209               failure());
210   }
211 
212   auto it = attr.getValues<int64_t>().begin();
213   SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
214   for (auto& item : out) {
215     int64_t first = *it;
216     ++it;
217     int64_t second = *it;
218     ++it;
219     item = {first, second};
220   }
221   return out;
222 }
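// Illustrative example (not from the original source): both
// dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> and dense<[0, 1, 2, 3]> :
// tensor<4xi64> convert to the pairs {(0, 1), (2, 3)}.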
223 
224 // If a window with the given bound in some dimension is dilated with the given
225 // dilation factor in that dimension, then the value returned is the bound for
226 // the array in that dimension after dilation.
227 //
228 // For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
229 // window with values 1, x, 2, x, 3, where x indicates holes left by the
230 // dilation. So DilatedBound(3, 2) == 5.
231 int64_t dilatedBound(int64_t bound, int64_t dilation) {
232   assert(bound >= 0 && "The dimension to dilate must be >= 0");
233   if (bound == 0) return 0;
234 
235   // Suppose the array has three entries 123 and the dilation factor is 4. Then
236   // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
237   // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
238   // add 1 to account for the final input element.
239   return (bound - 1) * dilation + 1;
240 }
241 
242 // Returns the number of valid positions of a window with the given size and
243 // stride within an array with the given bound. This is the bound of an output
244 // array with one element per valid position of the window.
245 //
246 // For example, for arguments of (bound=5, window_size=2, stride=2), the
247 // returned value is 2. There are valid positions at offset 0 and offset 2,
248 // while offset 4 is not valid since the window's last entry would be at 5,
249 // which is beyond the bound of 5.
250 int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) {
251   assert(windowSize >= 0 && "Expected window size to be >= 0");
252   assert(bound >= 0 && "Expected bound to be >= 0");
253 
254   if (bound == 0 || windowSize > bound) return 0;
255 
256   // Without considering stride, the maximum valid offset is bound -
257   // window_size. Taking stride into account, the valid offsets then have the
258   // form q * stride for q = 0, ..., Q such that q * stride <= bound -
259   // window_size. This implies that Q equals floor((bound - window_size) /
260   // stride). There are Q + 1 valid values of q, yielding the formula below.
261   return (bound - windowSize) / stride + 1;
262 }
263 
264 // WindowDimension describes how the kernel window moves across the base area
265 // in a particular dimension.
266 // Describes the windowing in an operation such as convolution.
267 // The window is moved across a base area and for each position of the
268 // window a computation is performed. The fields below describe the
269 // window and the movement of the window across a base area.
270 struct WindowDimension {
271   int64_t size = 0;
272   int64_t stride = 1;
273   int64_t paddingLow = 0;
274   int64_t paddingHigh = 0;
275   int64_t windowDilation = 1;
276   int64_t baseDilation = 1;
277   bool windowReversal = false;
278 };
279 
280 // Verifies various properties of window-attributes (viz., stride, padding,
281 // lhs_dilation and rhs_dilation) and collects all the window-attributes for
282 // each kernel spatial dimension.
283 FailureOr<SmallVector<WindowDimension>>
284 verifyWindowAttributesAndInferWindowDimensions(
285     ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,
286     ArrayRef<std::pair<int64_t, int64_t>> padding,
287     ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
288     Location loc) {
289   const auto verifySize = [&](const size_t attrSize,
290                               StringRef attrName) -> LogicalResult {
291     if (attrSize == 0 || attrSize == windowDimensions.size()) return success();
292     return mlir::emitError(loc)
293            << "expects " << attrName
294            << " to have same dimension-size as size of "
295               "window dimensions "
296               "("
297            << windowDimensions.size() << "), but got: " << attrSize << ".";
298   };
299 
300   if (failed(verifySize(windowStrides.size(), "window-strides")))
301     return failure();
302   if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
303     return failure();
304   if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
305     return failure();
306   if (failed(verifySize(padding.size(), "padding-entries"))) return failure();
307 
308   SmallVector<WindowDimension> window(windowDimensions.size());
309   for (size_t i = 0; i < windowDimensions.size(); i++) {
310     WindowDimension& dim = window[i];
311 
312     dim.size = windowDimensions[i];
313     if (!isDynamicDimSize(dim.size) && dim.size <= 0)
314       return (mlir::emitError(loc)
315                   << "expects window to have positive value for " << i
316                   << "-th window dimension, but got " << dim.size << ".",
317               failure());
318 
319     if (!windowStrides.empty()) dim.stride = windowStrides[i];
320     if (dim.stride <= 0)
321       return (mlir::emitError(loc)
322                   << "expects window to have positive stride for " << i
323                   << "-th window dimension, but got " << dim.stride << ".",
324               failure());
325 
326     if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
327     if (dim.baseDilation <= 0)
328       return (mlir::emitError(loc) << "expects window to have positive base "
329                                       "dilation factor for "
330                                    << i << "-th window dimension, but got "
331                                    << dim.baseDilation << ".",
332               failure());
333 
334     if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
335     if (dim.windowDilation <= 0)
336       return (mlir::emitError(loc) << "expects window to have positive window "
337                                       "dilation factor for "
338                                    << i << "-th window dimension, but got "
339                                    << dim.windowDilation << ".",
340               failure());
341 
342     if (!padding.empty()) {
343       dim.paddingLow = padding[i].first;
344       dim.paddingHigh = padding[i].second;
345     }
346   }
347 
348   return window;
349 }
350 
351 // Infer the shape of the output window.
352 //  For each dimension d,
353 //    output-window-shape[d] =
354 //        stridedBound(padding_low + dilatedBound(base_shape[d], base_dilation)
355 //                         + padding_high,
356 //                     dilatedBound(window_shape[d], window_dilation), stride)
357 //      where (padding_low, padding_high) is the padding-pair for d.
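//
//  Worked example (illustrative): with base_shape[d] = 4, window_shape[d] = 2,
//  stride 2, and no padding or dilation, this is
//    stridedBound(0 + dilatedBound(4, 1) + 0, dilatedBound(2, 1), 2)
//      = stridedBound(4, 2, 2) = (4 - 2) / 2 + 1 = 2.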
358 SmallVector<int64_t> inferWindowOutputShape(
359     const ArrayRef<int64_t> baseShape, const ArrayRef<WindowDimension> window) {
360   assert(baseShape.size() == window.size() &&
361          "Size of window dimensions must match the size of base shape.");
362 
363   SmallVector<int64_t> outputDimensions(window.size());
364   for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) {
365     if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) {
366       outputDimensions[i] = ShapedType::kDynamicSize;
367     } else {
368       const auto& dim = window[i];
369 
370       const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation);
371       const int64_t paddedDilatedBase =
372           dim.paddingLow + dilatedBase + dim.paddingHigh;
373       const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation);
374 
375       outputDimensions[i] =
376           stridedBound(paddedDilatedBase, dilatedWindow, dim.stride);
377     }
378   }
379 
380   return outputDimensions;
381 }
382 
383 // Return true if type1 and type2 are tensors and have the same
384 // element-type, else return false. With float element-types, ignore comparing
385 // floating-point precision if ignoreFpPrecision is True.
386 bool tensorsHaveSameElType(Type type1, Type type2, bool ignoreFpPrecision) {
387   auto tensorTy1 = type1.dyn_cast<TensorType>();
388   auto tensorTy2 = type2.dyn_cast<TensorType>();
389 
390   if (!tensorTy1 || !tensorTy2) return false;
391 
392   if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() &&
393       tensorTy2.getElementType().isa<FloatType>())
394     return true;
395 
396   return tensorTy1.getElementType() == tensorTy2.getElementType();
397 }
398 
399 // Return true if type1 and type2 are shape-compatible and have same element
400 // type. If 'ignoreFpPrecision' is True, then allow floats with different
401 // precisions while checking element-types.
402 bool compatibleShapeAndElementType(Type type1, Type type2,
403                                    bool ignoreFpPrecision = false) {
404   if (failed(verifyCompatibleShape(type1, type2))) return false;
405   return tensorsHaveSameElType(type1.cast<ShapedType>(),
406                                type2.cast<ShapedType>(), ignoreFpPrecision);
407 }
408 
409 LogicalResult verifyReducerShape(
410     Location loc, Block& block, ArrayRef<TensorType> inputArgTypes,
411     ArrayRef<TensorType> initValueTypes, int64_t numInputs,
412     ArrayRef<int64_t> allowedDimensions, bool allInputsUnranked,
413     SmallVectorImpl<TensorType>& accumulatorSubShapes) {
414   // Check that the number of reduction-region arguments matches with that of
415   // reduce-op's arguments.
416   if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
417     return mlir::emitError(loc)
418            << "Reduction-region must take " << numInputs * 2
419            << " parameters, but takes " << block.getArguments().size()
420            << " parameter(s)";
421 
422   // Check if the reduction-region produces non-zero outputs.
423   if (block.getTerminator()->getOperands().empty())
424     return mlir::emitError(loc)
425            << "The reduction-region expected to return some value(s)";
426 
427   // Check that the reduction-region returns a list of tensors.
428   // The number of result-tensors must match the `numInputs`.
429   if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
430       numInputs)
431     return mlir::emitError(loc)
432            << "Reduction-region here must produce " << numInputs
433            << " tensors, but produces "
434            << block.getTerminator()->getOperands().size() << " instead";
435 
436   for (Value retOperand : block.getTerminator()->getOperands()) {
437     auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
438     if (!tensorTy)
439       return mlir::emitError(loc) << "Reduction-region here must produce "
440                                      "tensor-typed result(s), but "
441                                      "produces "
442                                   << retOperand.getType() << " instead";
443 
444     accumulatorSubShapes.push_back(tensorTy);
445   }
446 
447   // Consider typical reduce-* op syntax:
448   //
449   //      op(I(i), V(j)):
450   //       block(BI(i), BV(j)):
451   //         ... some computation ...
452   //         return(R(i))
453   //
454   // where
455   //  I(i)  : i-th input of op
456   //  V(j)  : j-th init-value of op
457   //  BI(i) : i-th input of reducer-function
458   //  BV(j) : j-th init-value of reducer-function
459   //  R(i)  : i-th return-type
460   //
461   //  Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
462   //
463   //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
464   //    C1 : Check that BI(i) and R(i) have same shape and element-type.
465   //    C2 : Check that BV(j) and R(i) have same shape and element-type.
466   //    C3 : Check that V(j) and R(i) have same shape and element-type.
467   //
468   //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
469   //  have compatible shapes and element-types.
470   //  The next check, C4, adds constraints on how the type of I(i) is related
471   //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j):
472   //
473   //  C4.1 : Check that I(i) and BV(j) have same element-type.
474   //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of
475   //         'allowedDimensions'. 'allowedDimensions' is a list of dimensions
476   //         which any of BI(i), BV(j), and R(i) is allowed to have.
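  //
  //  A typical well-formed reducer region (illustrative sketch for
  //  numInputs == 1; the exact assembly syntax shown is approximate):
  //    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
  //      %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
  //      stablehlo.return %0 : tensor<f32>
  //  i.e. 2 * numInputs block arguments and numInputs returned tensors.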
477   for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
478     // Check C1.
479     if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
480                                        block.getArgument(inputIdx).getType()))
481       return mlir::emitError(loc)
482              << "The type of reduction-region's parameter at index " << inputIdx
483              << " is different than the corresponding result type: "
484              << block.getArgument(inputIdx).getType() << " vs "
485              << accumulatorSubShapes[inputIdx];
486 
487     // Check C2.
488     if (!compatibleShapeAndElementType(
489             accumulatorSubShapes[inputIdx],
490             block.getArgument(numInputs + inputIdx).getType(),
491             /*ignoreFpPrecision=*/true))
492       return mlir::emitError(loc)
493              << "The type of reduction-region's parameter at index "
494              << numInputs + inputIdx
495              << " is different than the corresponding result type: "
496              << block.getArgument(numInputs + inputIdx).getType() << " vs "
497              << accumulatorSubShapes[inputIdx];
498 
499     // Check C3.
500     if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
501                                        initValueTypes[inputIdx],
502                                        /*ignoreFpPrecision=*/true))
503       return mlir::emitError(loc)
504              << "The type of reduction-region's result type at index "
505              << inputIdx
506              << " differs from the op's corresponding init-value type: "
507              << accumulatorSubShapes[inputIdx] << " vs "
508              << initValueTypes[inputIdx];
509 
510     // Check C4.1.
511     if (!tensorsHaveSameElType(
512             inputArgTypes[inputIdx],
513             block.getArgument(numInputs + inputIdx).getType(), true))
514       return mlir::emitError(loc)
515              << "The element-type of reduction-region's argument at index "
516              << numInputs + inputIdx << " is expected to be "
517              << inputArgTypes[inputIdx].getElementType() << ", but got "
518              << block.getArgument(numInputs + inputIdx).getType()
519              << " as its type.";
520 
521     // Check C4.2.
522     Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
523     auto blockArgTensorTy = blockArgType.cast<TensorType>();
524 
525     if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();
526 
527     auto argShape = blockArgTensorTy.getShape();
528     if (argShape.size() > allowedDimensions.size())
529       return mlir::emitError(loc)
530              << "The rank of reduction-region's argument at index "
531              << numInputs + inputIdx
532              << " is expected to be <= " << allowedDimensions.size() << ", got "
533              << argShape.size();
534 
535     int64_t argShapeIdx = 0;
536     for (int64_t outputShapeIdx = 0;
537          outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) &&
538          argShapeIdx < static_cast<int64_t>(argShape.size());
539          outputShapeIdx++)
540       if (allowedDimensions[outputShapeIdx] == argShape[argShapeIdx])
541         argShapeIdx++;
542 
543     if (argShapeIdx != static_cast<int64_t>(argShape.size()))
544       return mlir::emitError(loc)
545              << "The shape of reduction-region's argument at index "
546              << numInputs + inputIdx
547              << " is not compatible with that of reduce-op's input-parameter "
548                 "at index "
549              << inputIdx;
550   }
551 
552   return success();
553 }
554 
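// For example (illustrative): potentiallyComplexBitwidth returns 64 for
// complex<f32> (2 * 32) and 8 for i8.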
555 unsigned potentiallyComplexBitwidth(Type type) {
556   auto complexTy = type.dyn_cast<ComplexType>();
557   return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
558                    : type.getIntOrFloatBitWidth();
559 }
560 }  // namespace
561 
562 //===----------------------------------------------------------------------===//
563 // Utilities for attributes
564 //===----------------------------------------------------------------------===//
565 
566 LogicalResult TypeExtensionsAttr::verifyEncoding(
567     llvm::ArrayRef<int64_t> bounds, mlir::Type elementType,
568     llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
569   return hlo::verifyBounds(
570       getBounds(), RankedTensorType::get(bounds, elementType), emitError);
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // AllReduceOp
575 //===----------------------------------------------------------------------===//
576 
577 void AllReduceOp::build(
578     ::mlir::OpBuilder& odsBuilder, ::mlir::OperationState& odsState,
579     ::mlir::Type resultType, ::mlir::Value operand,
580     ::mlir::DenseIntElementsAttr replicaGroups,
581     /*optional*/ ::mlir::stablehlo::ChannelHandleAttr channelHandle) {
582   AllReduceOp::build(odsBuilder, odsState, resultType, operand,
583                      replicaGroups, channelHandle, nullptr);
584 }
585 
586 //===----------------------------------------------------------------------===//
587 // ReduceScatterOp
588 //===----------------------------------------------------------------------===//
589 
590 LogicalResult verifyReduceScatter(Operation* op, TypeRange operandTypes,
591                                   TypeRange resultTypes,
592                                   uint64_t scatterDimension) {
593   // If operand and result are both ranked, then the size of the scatter
594   // dimension in the operand should be a multiple of the size of the scatter
595   // dimension in the result.
596 
597   // TODO(zhouxin) Change the ODS definition to return int64_t.
598   if (static_cast<int64_t>(scatterDimension) < 0) {
599     return op->emitOpError("expects scatter_dimension >= 0");
600   }
601 
602   for (auto it : llvm::zip(operandTypes, resultTypes)) {
603     auto operandType = std::get<0>(it).cast<ShapedType>();
604     auto resultType = std::get<1>(it).cast<ShapedType>();
605     if (!operandType.hasRank() || !resultType.hasRank()) continue;
606     if (operandType.getRank() != resultType.getRank())
607       return op->emitOpError() << "operand and result should have same rank";
608     if (static_cast<int64_t>(scatterDimension) >= operandType.getRank())
609       return op->emitOpError()
610              << "scatter dim should be less than operand/result rank";
611     if (operandType.isDynamicDim(scatterDimension) ||
612         resultType.isDynamicDim(scatterDimension))
613       continue;
614     if (operandType.getDimSize(scatterDimension) == 0)
615       return op->emitOpError() << "operand scatter dimension cannot be zero";
616     if (resultType.getDimSize(scatterDimension) == 0)
617       return op->emitOpError() << "result scatter dimension cannot be zero";
618     if ((operandType.getDimSize(scatterDimension) %
619          resultType.getDimSize(scatterDimension)) != 0)
620       return op->emitOpError()
621              << "operand scatter dimension has size "
622              << operandType.getDimSize(scatterDimension)
623              << ", expected to be a multiple of result scatter dimension size "
624              << resultType.getDimSize(scatterDimension);
625 
626     // Non scatter dimensions should be equal.
627     for (uint64_t index : llvm::seq<uint64_t>(0, operandType.getRank())) {
628       if (index == scatterDimension || operandType.isDynamicDim(index) ||
629           resultType.isDynamicDim(index))
630         continue;
631       if (operandType.getDimSize(index) != resultType.getDimSize(index))
632         return op->emitOpError()
633                << "non scatter dimensions should be same for operand ("
634                << operandType.getDimSize(index) << ") and result ("
635                << resultType.getDimSize(index) << ")";
636     }
637   }
638   return success();
639 }
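// Illustrative example (not from the original source): a reduce_scatter with a
// tensor<4x8xf32> operand, scatter_dimension = 1, and a tensor<4x2xf32> result
// passes these checks: 8 is a multiple of 2 and the non-scatter dimension
// sizes (4) match.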
640 
641 LogicalResult ReduceScatterOp::verify() {
642   if (failed(verifyReplicaGroups(*this, /*is_uniform_sized=*/true)))
643     return failure();
644   auto operandType = operand().getType().cast<TensorType>();
645   bool operandTypeRanked = operandType.isa<RankedTensorType>();
646   Block& block = computation().front();
647   SmallVector<TensorType> accumulatorSubshapes;
648   if (failed(verifyReducerShape(
649           this->getLoc(), block, {operandType},
650           {RankedTensorType::get({}, operandType.getElementType())},
651           /*numInputs=*/1, /*allowedDimensions=*/{},
652           /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes)))
653     return failure();
654 
655   return verifyReduceScatter(*this,
656                              /*operandTypes=*/{operand().getType()},
657                              /*resultTypes=*/{getType()},
658                              /*scatterDimension=*/scatter_dimension());
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // CompatibleOperandsAndResultType
663 //===----------------------------------------------------------------------===//
664 
665 // TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
666 // support quantization or sparsity.
667 #define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                        \
668   LogicalResult Op::inferReturnTypeComponents(                                \
669       MLIRContext* context, Optional<Location> location,                      \
670       ValueShapeRange operands, DictionaryAttr attributes,                    \
671       RegionRange regions,                                                    \
672       SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {          \
673     return inferReturnTypeComponentsFromOperands(context, location, operands, \
674                                                  attributes, regions,         \
675                                                  inferredReturnShapes);       \
676   }
677 
678 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
679 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
680 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
681 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
682 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
683 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
684 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
685 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
686 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
687 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
688 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp)
689 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp)
690 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op)
691 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp)
692 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogOp)
693 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
694 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp)
695 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp)
696 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp)
697 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp)
698 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp)
699 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp)
700 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp)
701 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PopulationCountOp)
702 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PowOp)
703 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReducePrecisionOp)
704 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RemOp)
705 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReverseOp)
706 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundNearestEvenOp)
707 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp)
708 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RsqrtOp)
709 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp)
710 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp)
711 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp)
712 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp)
713 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp)
714 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp)
715 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp)
716 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp)
717 INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp)
718 
719 //===----------------------------------------------------------------------===//
720 // ConstantOp
721 //===----------------------------------------------------------------------===//
722 
723 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
724   assert(operands.empty() && "constant has no operands");
725 
726   // Return the held attribute value.
727   return value();
728 }
729 
730 // Builds a constant op with the specified attribute `value`.
731 void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result,
732                        Attribute value) {
733   Type type;
734   if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
735     type = elemAttr.getType();
736   } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
737     // All XLA types must be tensor types. In the build() method, we want to
738     // provide more flexibility by allowing attributes of scalar types. But we
739     // need to wrap it up with ElementsAttr to construct valid XLA constants.
740     type =
741         RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
742     value = DenseElementsAttr::get(type.cast<TensorType>(), value);
743   } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
744     type = RankedTensorType::get(/*shape=*/{},
745                                  complexAttr.cast<TypedAttr>().getType());
746     value =
747         DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
748   }
749 
750   // TODO: support other XLA specific types.
751   assert(type && "unsupported attribute type for building constant");
752   result.types.push_back(type);
753   result.addAttribute("value", value);
754 }
755 
756 LogicalResult ConstantOp::inferReturnTypes(
757     MLIRContext*, Optional<Location>, ValueRange operands,
758     DictionaryAttr attributes, RegionRange,
759     SmallVectorImpl<Type>& inferredReturnTypes) {
760   ConstantOpAdaptor adaptor(operands, attributes);
761   Type type = adaptor.value().getType();
762   inferredReturnTypes.push_back(type);
763   return success();
764 }
765 
766 bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
767   if (l.size() != r.size() || l.size() != 1) return false;
768   auto lhsTy = l.front().cast<TensorType>();
769   auto rhsTy = r.front().cast<TensorType>();
770   // For tensors with a uniform-quantized element type, compare using the
771   // storage type, since the constant value is stored through the underlying
772   // storage type.
773   if (auto rhsElemTy =
774           rhsTy.getElementType().dyn_cast<quant::QuantizedType>()) {
775     rhsTy = hlo::getSameShapeTensorType(rhsTy, rhsElemTy.getStorageType());
776   }
777   return lhsTy == rhsTy;
778 }
779 
780 ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
781   // Parse the generic form.
782   if (succeeded(parser.parseOptionalLParen())) {
783     if (parser.parseRParen()) return failure();
784     if (parser.parseOptionalAttrDict(result.attributes)) return failure();
785     if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
786         parser.parseArrow())
787       return failure();
788     Type resultTy;
789     if (parser.parseType(resultTy)) {
790       return failure();
791     }
792     result.addTypes(resultTy);
793     return success();
794   }
795 
796   ElementsAttr valueAttr;
797   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
798 
799   if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
800                                               result.attributes)) {
801     return failure();
802   }
803   result.addTypes(valueAttr.getType());
804   return success();
805 }
806 
807 /// Print a `constant` op.
808 ///
809 /// op ::= attr-dict $value
810 ///
811 /// When the `value` and `output` have different type, it just uses the default
812 /// operator assembly format as a fallback.
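/// For example (illustrative, assuming the value and result types match), this
/// prints the compact form: stablehlo.constant dense<1> : tensor<i32>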
813 void ConstantOp::print(::mlir::OpAsmPrinter& p) {
814   // If not all types are the same, use generic form.
815   if (value().getType() != getType()) {
816     p.printGenericOp(getOperation(), /*printOpName=*/false);
817     return;
818   }
819 
820   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
821   p << ' ';
822   p.printStrippedAttrOrType(valueAttr());
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // CustomCallOp
827 //===----------------------------------------------------------------------===//
828 
829 LogicalResult CustomCallOp::verify() {
830   // If neither operand nor result layout attributes are specified, then
831   // there is nothing to verify.
832   if (!operand_layouts().has_value() && !result_layouts().has_value())
833     return success();
834 
835   // Layout constraints for either both operands & results or none should be
836   // specified.
837   if (operand_layouts().has_value() != result_layouts().has_value())
838     return emitOpError() << "Layout attributes should be specified for "
839                             "either both operands and results or none.";
840 
841   // Helper function to verify types and the corresponding layouts.
842   auto verifyTypesAndLayouts =
843       [this](TypeRange types, mlir::ArrayAttr layouts,
844              const std::string& valueName) -> LogicalResult {
845     if (types.size() != layouts.size())
846       return emitOpError() << "Number of " << valueName
847                            << "s must match the number of " << valueName
848                            << " layouts, " << types.size()
849                            << " != " << layouts.size();
850 
851     for (const auto& indexedTypeAndLayout :
852          llvm::enumerate(llvm::zip(types, layouts))) {
853       // Get index for more descriptive error message.
854       auto index = indexedTypeAndLayout.index();
855 
856       auto type = std::get<0>(indexedTypeAndLayout.value());
857       auto layout = std::get<1>(indexedTypeAndLayout.value())
858                         .cast<DenseIntElementsAttr>();
859 
860       if (type.isa<TupleType>())
861         return emitOpError() << "Tuple types are not fully supported with "
862                                 "layout constraints yet";
863       auto tensorType = type.dyn_cast<TensorType>();
864 
865       // For non-tensor types e.g. !stablehlo.token, the layout should be empty.
866       if (!tensorType) {
867         if (layout.empty()) continue;
868         return emitOpError()
869                << "Only tensor types can have non-empty layout: " << valueName
870                << " #" << index << " of type " << type << " has layout "
871                << layout;
872       }
873 
874       // For unranked tensors, we cannot verify the compatibility with layout
875       // any further.
876       if (!tensorType.hasRank()) continue;
877 
878       // Layout must be a permutation of [0, N) where N is the rank of the
879       // tensor type.
880       std::vector<int64_t> range(tensorType.getRank());
881       std::iota(range.begin(), range.end(), 0);
882       if (tensorType.getRank() != layout.size() ||
883           !std::is_permutation(range.begin(), range.end(), layout.begin()))
884         return emitOpError() << "incorrect layout " << layout << " for type "
885                              << type << ", layout must be a permutation of [0, "
886                              << tensorType.getRank() << ")";
887     }
888     return success();
889   };
890 
891   // At this point both `operand_layouts` and `result_layouts` are defined.
892   ArrayAttr operandLayouts = this->operand_layouts().value();
893   ArrayAttr resultLayouts = this->result_layouts().value();
894 
895   // Full support for layouts for arbitrary nesting of tuples is not
896   // supported yet.
897   //
898   // If result does not have any tuples, then i-th element of `result_layouts`
899   // specifies the layout constraints on i-th result.
900   //
901   // For the common case of a single tuple result packing non-tuple values, the
902   // i-th element of `result_layouts` specifies layout for i-th element of the
903   // result tuple.
904   TypeRange resultTypes;
905   if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>())
906     resultTypes = getResult(0).getType().cast<TupleType>().getTypes();
907   else
908     resultTypes = getResultTypes();
909 
910   // Verify that operands and operand layouts match.
911   if (failed(
912           verifyTypesAndLayouts(getOperandTypes(), operandLayouts, "operand")))
913     return failure();
914 
915   // Verify that results and result layouts match.
916   return verifyTypesAndLayouts(resultTypes, resultLayouts, "result");
917 }
918 
919 void CustomCallOp::getEffects(
920     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
921         effects) {
922   // CustomCall has "all possible effects" unless the has_side_effect
923   // attribute is present and set to false.
924   auto hasSideEffect = (*this)->getAttrOfType<BoolAttr>("has_side_effect");
925   if (hasSideEffect && !hasSideEffect.getValue()) return;
926   effects.emplace_back(MemoryEffects::Allocate::get());
927   effects.emplace_back(MemoryEffects::Free::get());
928   effects.emplace_back(MemoryEffects::Write::get());
929   effects.emplace_back(MemoryEffects::Read::get());
930 }
931 
932 //===----------------------------------------------------------------------===//
933 // CholeskyOp
934 //===----------------------------------------------------------------------===//
935 
936 // The following properties are already enforced by the ODS:
937 //   P0. a.element_type is floating or complex
938 // We intend to verify the following properties
939 //   P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s
940 //   P2. The two minor dimensions of 'a' must have equal size, got %s.
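// Illustrative examples (not from the original source): tensor<3x4x4xf32>
// satisfies P1 and P2 (a batch of three 4x4 matrices), whereas tensor<4xf32>
// fails P1 and tensor<3x4x5xf32> fails P2.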
941 LogicalResult CholeskyOp::inferReturnTypeComponents(
942     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
943     DictionaryAttr attributes, RegionRange regions,
944     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
945   CholeskyOp::Adaptor adaptor(operands, attributes, regions);
946   Type aType = adaptor.a().getType();
947   RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>();
948   if (!aRankedType) {
949     inferredReturnShapes.emplace_back(
950         aType.cast<TensorType>().getElementType());
951     return success();
952   }
953 
954   ArrayRef<int64_t> aShape = aRankedType.getShape();
955   if (aShape.size() < 2) {
956     return emitOptionalError(
957         location, "argument 'a' must have rank >= 2, got shape ", aShape, ".");
958   }
959 
960   int64_t lastDim = aShape[aShape.size() - 1];
961   int64_t penultimateDim = aShape[aShape.size() - 2];
962   if (!isDynamicDimSize(lastDim) && !isDynamicDimSize(penultimateDim) &&
963       lastDim != penultimateDim) {
964     return emitOptionalError(
965         location, "minor dimensions of 'a' must have equal size, got shape ",
966         aShape, ".");
967   }
968   inferredReturnShapes.emplace_back(aRankedType.getShape(),
969                                     aRankedType.getElementType());
970   return success();
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // DotOp
975 //===----------------------------------------------------------------------===//
976 namespace {
977 bool dimCompatible(int64_t a, int64_t b) {
978   return isDynamicDimSize(a) || isDynamicDimSize(b) || a == b;
979 }
980 
981 ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) {
982   auto elementType = lhs.getElementType();
983   if (!lhs.hasRank() || !rhs.hasRank()) {
984     return UnrankedTensorType::get(elementType);
985   }
986 
987   // vector dot vector
988   if (1 == lhs.getRank() && 1 == rhs.getRank() &&
989       dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
990     return RankedTensorType::get({}, elementType);
991   }
992   // matrix dot vector
993   if (2 == lhs.getRank() && 1 == rhs.getRank() &&
994       dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
995     return RankedTensorType::get({lhs.getDimSize(0)}, elementType);
996   }
997   // vector dot matrix
998   if (1 == lhs.getRank() && 2 == rhs.getRank() &&
999       dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
1000     return RankedTensorType::get({rhs.getDimSize(1)}, elementType);
1001   }
1002   // matrix dot matrix
1003   if (2 == lhs.getRank() && 2 == rhs.getRank() &&
1004       dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
1005     int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)};
1006     return RankedTensorType::get(shape, elementType);
1007   }
1008   return {};
1009 }
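// Illustrative examples (not from the original source): inferDotReturnType
// gives tensor<f32> for tensor<3xf32> dot tensor<3xf32>, and tensor<2x4xf32>
// for tensor<2x3xf32> dot tensor<3x4xf32>.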
1010 }  // namespace
1011 
1012 LogicalResult DotOp::inferReturnTypes(
1013     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1014     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1015   DotOp::Adaptor op(operands);
1016   auto lhsType = op.lhs().getType().cast<ShapedType>();
1017   auto rhsType = op.rhs().getType().cast<ShapedType>();
1018   inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType));
1019   return success();
1020 }
1021 
1022 LogicalResult DotOp::verify() {
1023   auto lhsType = lhs().getType().cast<ShapedType>();
1024   auto rhsType = rhs().getType().cast<ShapedType>();
1025   auto resultType = getType().cast<ShapedType>();
1026   auto expectReturnType = inferDotReturnType(lhsType, rhsType);
1027   if (!expectReturnType) {
1028     return emitError() << "Unexpected operands type: " << lhsType << " and "
1029                        << rhsType;
1030   }
1031   if (resultType.hasRank() && expectReturnType.hasRank()) {
1032     if (resultType.getShape() != expectReturnType.getShape()) {
1033       return emitError() << "Unexpected result type: has " << resultType
1034                          << " but inferred " << expectReturnType
1035                          << " from operands " << lhsType << " and " << rhsType;
1036     }
1037   }
1038   return success();
1039 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // DotGeneralOp
1043 //===----------------------------------------------------------------------===//
1044 
1045 LogicalResult DotGeneralOp::verify() {
1046   auto dimNumbers = this->dot_dimension_numbers();
1047 
1048   ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
1049   ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
1050   ArrayRef<int64_t> lhsContractingDims =
1051       dimNumbers.getLhsContractingDimensions();
1052   ArrayRef<int64_t> rhsContractingDims =
1053       dimNumbers.getRhsContractingDimensions();
1054 
1055   if (lhsBatchingDims.size() != rhsBatchingDims.size()) {
1056     return emitOpError() << "lhs and rhs should have the same number of "
1057                             "batching dimensions";
1058   }
1059   if (lhsContractingDims.size() != rhsContractingDims.size()) {
1060     return emitOpError() << "lhs and rhs should have the same number of "
1061                             "contracting dimensions";
1062   }
1063 
1064   llvm::SmallDenseSet<int64_t> dimSet;
1065 
1066   auto checkDimsDistinct =
1067       [this](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims,
1068              llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs,
1069              llvm::StringRef rhs) -> LogicalResult {
1070     auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims);
1071     for (auto dim : dims) {
1072       auto [_, wasInserted] = dimSet.insert(dim);
1073       if (!wasInserted) {
1074         return emitOpError() << "has duplicated dimension from " << lhs
1075                              << " and " << rhs << ": " << dim;
1076       }
1077     }
1078     return success();
1079   };
1080 
1081   if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet,
1082                                "lhs_batching_dimensions",
1083                                "lhs_contracting_dimensions"))) {
1084     return failure();
1085   }
1086   dimSet.clear();
1087   if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet,
1088                                "rhs_batching_dimensions",
1089                                "rhs_contracting_dimensions"))) {
1090     return failure();
1091   }
1092 
1093   auto checkDimsInRange = [this](int64_t rank, ArrayRef<int64_t> dims,
1094                                  llvm::StringRef dimName) -> LogicalResult {
1095     auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; };
1096     const auto* dimsNotInRange =
1097         std::find_if_not(dims.begin(), dims.end(), inRange);
1098     if (dimsNotInRange != dims.end()) {
1099       return emitOpError() << dimName << " value: " << *dimsNotInRange
1100                            << " is out of range: "
1101                            << "[0, " << rank << ")";
1102     }
1103     return success();
1104   };
1105 
1106   auto lhsType = this->lhs().getType().dyn_cast<RankedTensorType>();
1107   auto rhsType = this->rhs().getType().dyn_cast<RankedTensorType>();
1108 
1109   if (lhsType) {
1110     if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims,
1111                                 "lhs_batching_dimensions")) ||
1112         failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims,
1113                                 "lhs_contracting_dimensions"))) {
1114       return failure();
1115     }
1116   }
1117   if (rhsType) {
1118     if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims,
1119                                 "rhs_batching_dimensions")) ||
1120         failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims,
1121                                 "rhs_contracting_dimensions"))) {
1122       return failure();
1123     }
1124   }
1125 
1126   if (lhsType && rhsType) {
1127     // Dimension sizes must be compatible for lhs/rhs.
1128     auto lhsShape = lhsType.getShape();
1129     auto rhsShape = rhsType.getShape();
1130 
1131     for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) {
1132       if (lhsShape[lhs] != rhsShape[rhs]) {
1133         return emitOpError() << "batching dimension sizes must match for "
1134                                 "lhs/rhs";
1135       }
1136     }
1137     for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) {
1138       if (lhsShape[lhs] != rhsShape[rhs]) {
1139         return emitOpError() << "contracting dimension sizes must match for "
1140                                 "lhs/rhs";
1141       }
1142     }
1143   }
1144   return success();
1145 }
1146 
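// Reifies the result shape as (batching dims, then free lhs dims, then free
// rhs dims), materializing each dimension size with tensor.dim and assembling
// them with tensor.from_elements.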
1147 LogicalResult DotGeneralOp::reifyReturnTypeShapes(
1148     OpBuilder& builder, ValueRange operands,
1149     SmallVectorImpl<Value>& reifiedReturnShapes) {
1150   auto lhsType = lhs().getType().dyn_cast<ShapedType>();
1151   auto rhsType = rhs().getType().dyn_cast<ShapedType>();
1152   if (!lhsType || !rhsType) {
1153     return failure();
1154   }
1155 
1156   Adaptor adaptor(operands);
1157   auto dimNumbers = dot_dimension_numbers();
1158   SmallVector<Value> dimensions;
1159   for (const int64_t lhsDim : dimNumbers.getLhsBatchingDimensions()) {
1160     dimensions.push_back(
1161         builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhsDim));
1162   }
1163 
1164   for (int64_t i = 0; i < lhsType.getRank(); i++) {
1165     if (!llvm::is_contained(dimNumbers.getLhsContractingDimensions(), i) &&
1166         !llvm::is_contained(dimNumbers.getLhsBatchingDimensions(), i)) {
1167       dimensions.push_back(
1168           builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
1169     }
1170   }
1171   for (int64_t i = 0; i < rhsType.getRank(); i++) {
1172     if (!llvm::is_contained(dimNumbers.getRhsContractingDimensions(), i) &&
1173         !llvm::is_contained(dimNumbers.getRhsBatchingDimensions(), i)) {
1174       dimensions.push_back(
1175           builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
1176     }
1177   }
1178 
1179   reifiedReturnShapes.push_back(
1180       builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
1181   return success();
1182 }
1183 
1184 //===----------------------------------------------------------------------===//
1185 // FftOp
1186 //===----------------------------------------------------------------------===//
1187 
1188 // We intend to verify the following properties
1189 // P1. 1 <= rank <= 3
1190 // P2. Element types agree with fft_type
1191 // P3. Operand shape dimensions agree with fft_length for the given fft_type
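// Illustrative shape example (values assumed): with fft_type = RFFT,
// fft_length = [16], and an operand of type tensor<4x16xf32>, P1-P3 hold and
// the inferred result is tensor<4x9xcomplex<f32>>: the element type is
// complexified and the innermost dimension becomes
// fft_length[-1] / 2 + 1 = 16 / 2 + 1 = 9.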
1192 LogicalResult FftOp::inferReturnTypeComponents(
1193     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
1194     DictionaryAttr attributes, RegionRange regions,
1195     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1196   FftOp::Adaptor adaptor(operands, attributes, regions);
1197   auto fftLength = adaptor.fft_length().getValues<int64_t>();
1198   int64_t fftRank = fftLength.size();
1199 
1200   // P1.
1201   if (fftRank > 3 || fftRank < 1) {
1202     return emitOptionalError(location, "rank must be between 1 and 3, but got ",
1203                              fftRank, ".");
1204   }
1205 
1206   // P2. Element type agreement
1207   // FFT : C -> C
1208   // IFFT : C -> C
1209   // RFFT : R -> C
1210   // IRFFT : C -> R
1211   auto fftType = adaptor.fft_type();
1212   auto operandType = adaptor.operand().getType().cast<TensorType>();
1213   Type operandElementType = operandType.getElementType();
1214   // Check the input element type and infer return element type
1215   if (fftType == FftType::RFFT) {
1216     if (!operandElementType.isF32() && !operandElementType.isF64()) {
1217       return emitOptionalError(
1218           location, "RFFT requires f32 or f64 input type, but is given ",
1219           operandElementType, ".");
1220     }
1221   } else {
1222     if (!operandElementType.isa<ComplexType>()) {
1223       return emitOptionalError(
1224           location, stringifyFftType(fftType),
1225           " takes a complex tensor as input, but is given ", operandType, ".");
1226     }
1227   }
1228   // Generate the output element type
1229   Type resultElementType = operandElementType;
1230   if (fftType == FftType::RFFT) {  // RFFT : R -> C
1231     resultElementType = ComplexType::get(resultElementType);
1232   } else if (fftType == FftType::IRFFT) {  // IRFFT : C -> R
1233     resultElementType = operandElementType.cast<ComplexType>().getElementType();
1234   }
1235 
1236   // P3. Check input shape and infer return shape
1237   operandType = operandType.dyn_cast<RankedTensorType>();
1238   if (!operandType) {
1239     inferredReturnShapes.emplace_back(resultElementType);
1240     return success();
1241   }
1242   auto operandShape = operandType.getShape();
1243   if (static_cast<int64_t>(operandShape.size()) < fftRank) {
1244     return emitOptionalError(
1245         location, "operand rank must not be less than fft rank of ", fftRank,
1246         " for operand of type ", operandType, ".");
1247   }
1248 
1249   SmallVector<int64_t> resultShape = to_vector(operandShape);
1250 
1251   if (fftType == FftType::RFFT) {
1252     auto shapeBack = operandShape.take_back(fftRank);
1253     for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
1254       if (operandDim != fftDim) {
1255         return emitOptionalError(
1256             location,
1257             "RFFT requires innermost dimensions match fft_length. Got: ",
1258             operandShape, " but wanted ", fftLength, ".");
1259       }
1260     }
1261     if (fftLength[fftRank - 1] != 0) {
1262       resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1;
1263     }
1264   }
1265   if (fftType == FftType::IRFFT) {
1266     auto shapeBack = operandShape.take_back(fftRank).drop_back();
1267     for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
1268       if (operandDim != fftDim) {
1269         return emitOptionalError(location,
1270                                  "IRFFT requires non-final dimensions "
1271                                  "match fft_length. Got: ",
1272                                  operandShape, " but wanted ", fftLength,
1273                                  ", and ", operandDim, " != ", fftDim, ".");
1274       }
1275     }
1276     if ((operandShape[operandShape.size() - 1] != 0 ||
1277          fftLength[fftRank - 1] != 0) &&
1278         operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1)
1279       return emitOptionalError(location,
1280                                "IRFFT requires innermost dimension match "
1281                                "fft_length[-1]/2+1. Got: ",
1282                                operandShape, " but fft_length is ", fftLength,
1283                                ".");
1284     resultShape[resultShape.size() - 1] = fftLength[fftRank - 1];
1285   }
1286 
1287   inferredReturnShapes.emplace_back(resultShape, resultElementType);
1288   return success();
1289 }
1290 
1291 //===----------------------------------------------------------------------===//
1292 // GatherOp
1293 //===----------------------------------------------------------------------===//
1294 
1295 namespace {
1296 
1297 // following https://www.tensorflow.org/xla/operation_semantics#gather
1298 // The bounds for the output array along dimension i is computed as follows:
1299 // (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k)
1300 // then we pick
1301 // the corresponding dimension bounds out of start_indices.shape, skipping
1302 // index_vector_dim
1303 // (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
1304 // start_indices.shape.dims[k+1] otherwise).
1305 // (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k)
1306 // then we pick
1307 // the corresponding bound out of slice_sizes after accounting for
1308 // collapsed_slice_dims
1309 // (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
1310 // slice_sizes with the bounds at indices collapsed_slice_dims removed).
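// Worked example (hypothetical operand/attribute values, for exposition):
//   operand       : tensor<3x4x2xf32>
//   start_indices : tensor<2x3x2xi64>
//   offset_dims = [2, 3], collapsed_slice_dims = [0],
//   start_index_map = [1, 0], index_vector_dim = 2, slice_sizes = [1, 2, 2]
// Output dims 0 and 1 are batch dims and come from start_indices.shape with
// index_vector_dim skipped, giving 2 and 3. Output dims 2 and 3 are offset
// dims and come from adjusted_slice_sizes = [2, 2] (slice_sizes with the
// collapsed dim 0 removed). The result shape is therefore 2x3x2x2.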
1311 
1312 void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
1313                         ValueRange operands,
1314                         SmallVectorImpl<Value>& sliceSizes) {
1315   for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
1316     sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
1317   }
1318 }
1319 
1320 void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder,
1321                         Location loc, ValueRange operands,
1322                         SmallVectorImpl<Value>& sliceSizeValues) {
1323   DynamicGatherOp::Adaptor adaptor(operands);
1324   Value sliceSizes = adaptor.slice_sizes();
1325   auto sliceSizesTy = sliceSizes.getType().cast<ShapedType>();
1326   for (int64_t i = 0; i < sliceSizesTy.getDimSize(0); ++i) {
1327     Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
1328     sliceSizeValues.push_back(
1329         builder.create<tensor::ExtractOp>(loc, sliceSizes, idx));
1330   }
1331 }
1332 
1333 // Verify the following properties:
1334 //  P1. Verify no repeat in start_index_map.
1335 //  P2. Verify 0 <= start_index_map[i] < rank(operand), for every i.
1336 //  P3. Verify 0 <= index_vector_dim <= rank(start_indices).
1337 //  P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim].
1338 //  P5. Verify offset_dims is_sorted and no repeated.
1339 //  P6. Verify collapsed_slice_dims is_sorted and no repeated.
1340 //  P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims).
1341 //  P8. Verify slice_sizes has rank of 1.
1342 //  P9. Verify size(slice_sizes) == rank(operand).
1343 //  P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items.
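// For instance (hypothetical attribute values): start_index_map = [0, 0]
// trips P1 ("expects start_index_map to not repeat"), and
// index_vector_dim = 3 with a rank-2 start_indices tensor trips P3.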
1344 static LogicalResult verifyGather(
1345     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1346     ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers,
1347     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1348   int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
1349 
1350   // Check startIndexMap
1351   auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap());
1352   // P1.
1353   if (hasDuplicates(startIndexMap))
1354     return errorEmitter() << "expects start_index_map to not repeat, got: ["
1355                           << startIndexMap << "]";
1356 
1357   // P2.
1358   for (int64_t i = 0; i < static_cast<int64_t>(startIndexMap.size()); ++i)
1359     if (startIndexMap[i] < 0 ||
1360         (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank()))
1361       return errorEmitter()
1362              << "start_index_map[" << i << "]: " << startIndexMap[i]
1363              << " is out of bounds for "
1364              << "operand rank " << operandShape.getRank();
1365 
1366   if (startIndicesShape.hasRank()) {
1367     // P3.
1368     // index_vector_dim == start_indices.rank implies a trailing 1 on the shape
1369     // of start_indices.
1370     if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0)
1371       return errorEmitter() << "index_vector_dim " << indexVectorDim
1372                             << " is out of bounds for start indices with rank "
1373                             << startIndicesShape.getRank();
1374 
1375     bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
1376     if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
1377       int64_t effectiveDimSize;
1378       if (impliedTrailingDim)
1379         effectiveDimSize = 1;
1380       else
1381         effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim);
1382       // P4.
1383       if (effectiveDimSize !=
1384           static_cast<int64_t>(dimensionNumbers.getStartIndexMap().size()))
1385         return errorEmitter() << "start_index_map size ("
1386                               << dimensionNumbers.getStartIndexMap().size()
1387                               << ") is not equal to size of index dimension ("
1388                               << indexVectorDim << ") of start_indices ("
1389                               << effectiveDimSize << ")";
1390     }
1391   }
1392 
1393   // P5.
1394   auto offsetDims = to_vector(dimensionNumbers.getOffsetDims());
1395   if (!llvm::is_sorted(offsetDims))
1396     return errorEmitter() << "expects offset_dims to be sorted, got: ["
1397                           << offsetDims << "]";
1398   if (hasDuplicates(offsetDims))
1399     return errorEmitter() << "expects offset_dims to not repeat, got: ["
1400                           << offsetDims << "]";
1401 
1402   // P6.
1403   auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims());
1404   if (!llvm::is_sorted(collapsedSliceDims))
1405     return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: ["
1406                           << collapsedSliceDims << "]";
1407   if (hasDuplicates(collapsedSliceDims))
1408     return errorEmitter()
1409            << "expects collapsed_slice_dims to not repeat, got: ["
1410            << collapsedSliceDims << "]";
1411 
1412   // P7.
1413   int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() +
1414                                dimensionNumbers.getCollapsedSliceDims().size();
1415   if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
1416     return errorEmitter() << "offset_dims size ("
1417                           << dimensionNumbers.getOffsetDims().size()
1418                           << ") plus collapse_slice_dims size ("
1419                           << dimensionNumbers.getCollapsedSliceDims().size()
1420                           << ") is not equal to operand rank ("
1421                           << operandShape.getRank() << ")";
1422 
1423   // P8.
1424   // This should be fully expressible with type constraints, but it isn't
1425   // obvious how to do that with the current infrastructure.
1426   if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
1427     return errorEmitter() << "slice_sizes.rank != 1";
1428   if (sliceSizesShape.hasStaticShape()) {
1429     int64_t sliceSize = sliceSizesShape.getNumElements();
1430 
1431     // P9.
1432     if (sliceSize != impliedOperandRank)
1433       return errorEmitter() << "slice_sizes size (" << sliceSize
1434                             << ") not equal to (implied) operand rank ("
1435                             << impliedOperandRank << ")";
1436 
1437     // P10.
1438     for (auto dim : dimensionNumbers.getCollapsedSliceDims())
1439       if (dim < 0 || dim >= sliceSize)
1440         return errorEmitter() << "collapsed dimension " << dim
1441                               << " is out of bounds for slice_sizes.size ("
1442                               << sliceSize << ")";
1443   }
1444 
1445   return success();
1446 }
1447 
1448 // Verify the following properties:
1449 //  P1. Verifications by verifyGather().
1450 //  P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims.
1451 //  P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i.
1452 static LogicalResult verifyStaticGather(
1453     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1454     DenseIntElementsAttr sliceSizes,
1455     GatherDimensionNumbersAttr dimensionNumbers,
1456     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1457   // P1.
1458   // For some reason the getType call is necessary here
1459   if (failed(verifyGather(
1460           /*operandShape=*/operandShape,
1461           /*startIndicesShape=*/startIndicesShape,
1462           /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers,
1463           errorEmitter)))
1464     return failure();
1465 
1466   // P2.
1467   for (auto dim : dimensionNumbers.getCollapsedSliceDims()) {
1468     int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim];
1469     if (sliceDimSize > 1) {
1470       return errorEmitter() << "slice_sizes collapsed dimension " << dim
1471                             << " should be <= 1 but got " << sliceDimSize;
1472     }
1473   }
1474 
1475   // P3.
1476   if (operandShape.hasRank()) {
1477     for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) {
1478       if (operandShape.isDynamicDim(it.index())) continue;
1479       auto operandDimSize = operandShape.getDimSize(it.index());
1480       auto sliceDimSize = it.value();
1481       if (sliceDimSize < 0 || sliceDimSize > operandDimSize)
1482         return errorEmitter() << "slice size (" << sliceDimSize
1483                               << ") is out of bounds for operand dimension ("
1484                               << operandDimSize << ") at index " << it.index();
1485     }
1486   }
1487   return success();
1488 }
1489 
1490 template <typename dimTy>
1491 static void inferGatherShape(
1492     int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
1493     llvm::function_ref<dimTy(int64_t)> getSliceDim,
1494     GatherDimensionNumbersAttr dimensionNumbers,
1495     SmallVectorImpl<dimTy>& shape) {
1496   ArrayRef<int64_t> collapsedSliceDims =
1497       dimensionNumbers.getCollapsedSliceDims();
1498   int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();
1499 
1500   // We don't necessarily know the rank of sliceSizes, but we do know that it
1501   // can't be larger than the highest collapsed dimension. So go through those
1502   // and populate the leading dimensions of adjustedSliceSizes. The trailing
1503   // dimensions can just be adjusted by an offset.
1504   const auto* maxCollapsedDimIt =
1505       std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
1506   int64_t maxCollapsedDim = -1;
1507   if (maxCollapsedDimIt != collapsedSliceDims.end())
1508     maxCollapsedDim = *maxCollapsedDimIt;
1509 
1510   SmallVector<dimTy> adjustedSliceSizePrefix;
1511   for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
1512     if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
1513     adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex));
1514   }
1515   auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
1516     if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size()))
1517       return adjustedSliceSizePrefix[index];
1518     return getSliceDim(index + collapsedSliceDims.size());
1519   };
1520 
1521   ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
1522 
1523   // Dimensions in the output that aren't offset dimensions are called batch
1524   // dimensions.
1525   SmallVector<int64_t> batchDims;
1526   for (int dim = 0; dim < resultRank; ++dim)
1527     if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
1528 
1529   for (int i = 0; i < resultRank; ++i) {
1530     const auto* offsetDimsIt =
1531         std::find(offsetDims.begin(), offsetDims.end(), i);
1532     if (offsetDimsIt != offsetDims.end()) {
1533       auto index = std::distance(offsetDims.begin(), offsetDimsIt);
1534       shape.push_back(getAdjustedSliceDim(index));
1535       continue;
1536     }
1537     auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
1538     assert(batchDimsIt != batchDims.end());
1539     auto index = std::distance(batchDims.begin(), batchDimsIt);
1540     // This can never run into the special case where start_indices gets
1541     // implicitly expanded with a trailing 1 if
1542     // index_vector_dim = start_indices.rank because then index would equal
1543     // index_vector_dim, which means we'd be looking at index+1, which would be
1544     // out of bounds anyway.
1545     if (index >= indexVectorDim) ++index;
1546     shape.push_back(getStartIndicesDim(index));
1547   }
1548 }
1549 
1550 // Verify the following properties:
1551 //  P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i.
1552 //      (output_shape_rank = size(offset_dims) + rank(start_indices) -1)
inferGatherReturnTypeComponents(ShapeAdaptor operandShape,ShapeAdaptor startIndicesShape,llvm::function_ref<int64_t (int64_t)> getSliceDim,GatherDimensionNumbersAttr dimensionNumbers,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes,llvm::function_ref<InFlightDiagnostic ()> errorEmitter)1553 static LogicalResult inferGatherReturnTypeComponents(
1554     ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1555     llvm::function_ref<int64_t(int64_t)> getSliceDim,
1556     GatherDimensionNumbersAttr dimensionNumbers,
1557     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes,
1558     llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1559   Type elementType = operandShape.getElementType();
1560 
1561   // We need this to determine the result rank. We could still place bounds on
1562   // the result rank if that was something ShapedTypeComponents could express.
1563   if (!startIndicesShape.hasRank()) {
1564     inferredReturnShapes.push_back(elementType);
1565     return success();
1566   }
1567 
1568   ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
1569   int64_t startIndicesRank = startIndicesShape.getRank();
1570   // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
1571   // appended to start_indices shape.
1572   if (dimensionNumbers.getIndexVectorDim() == startIndicesRank)
1573     ++startIndicesRank;
1574   int64_t resultRank = offsetDims.size() + startIndicesRank - 1;
1575   // P1.
1576   for (int64_t i = 0; i < static_cast<int64_t>(offsetDims.size()); ++i)
1577     if (offsetDims[i] < 0 || offsetDims[i] >= resultRank)
1578       return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i]
1579                             << " is out of bounds for "
1580                             << "implied result rank " << resultRank;
1581 
1582   auto getStartIndicesDim = [&](int64_t index) {
1583     return startIndicesShape.getDimSize(index);
1584   };
1585 
1586   SmallVector<int64_t> shape;
1587   inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
1588                             dimensionNumbers, shape);
1589 
1590   inferredReturnShapes.emplace_back(shape, elementType);
1591   return success();
1592 }
1593 
1594 template <typename Op>
1595 LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
1596                                SmallVectorImpl<Value>& reifiedReturnShapes) {
1597   // No support for unranked gather output shape a.t.m.
1598   auto resultTy =
1599       op->getResult().getType().template dyn_cast<RankedTensorType>();
1600   if (!resultTy) return failure();
1601 
1602   typename Op::Adaptor adaptor(operands);
1603   Value startIndices = adaptor.start_indices();
1604 
1605   Location loc = op->getLoc();
1606   int resultRank = resultTy.getRank();
1607   Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType();
1608   auto toShapeElType = [&](Value v) {
1609     return maybeCastTo(builder, loc, v, shapeElTy);
1610   };
1611 
1612   SmallVector<Value, 4> sliceSizes;
1613   getSliceSizeValues(op, builder, loc, operands, sliceSizes);
1614   llvm::transform(sliceSizes, sliceSizes.begin(),
1615                   [&](Value v) { return toShapeElType(v); });
1616 
1617   auto getStartIndicesDim = [&](int64_t index) {
1618     return toShapeElType(
1619         builder.create<tensor::DimOp>(loc, startIndices, index));
1620   };
1621   SmallVector<Value, 4> shapeValues;
1622   auto getSliceDim = [&sliceSizes](int64_t index) -> Value {
1623     return sliceSizes[index];
1624   };
1625   inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
1626                           op->dimension_numbers(), shapeValues);
1627 
1628   Value outputShape = builder.create<tensor::FromElementsOp>(
1629       loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues);
1630   reifiedReturnShapes.push_back(outputShape);
1631 
1632   return success();
1633 }
1634 
1635 }  // namespace
1636 
1637 LogicalResult GatherOp::reifyReturnTypeShapes(
1638     OpBuilder& builder, ValueRange operands,
1639     SmallVectorImpl<Value>& reifiedReturnShapes) {
1640   return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1641 }
1642 
1643 // The following properties are already enforced by the ODS:
1644 //  P0. Verify the start_indices has element type of integer.
1645 // Verify the following properties:
1646 //  Verifications by verifyStaticGather() and verifyGather() inside it.
1647 //  Verifications by inferGatherReturnTypeComponents.
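// A plausible instance these checks accept, reusing the worked example above
// (attribute syntax sketched from the dialect's conventions, not copied from
// a test):
//
//   %result = "stablehlo.gather"(%operand, %start_indices) {
//     dimension_numbers = #stablehlo.gather<
//       offset_dims = [2, 3], collapsed_slice_dims = [0],
//       start_index_map = [1, 0], index_vector_dim = 2>,
//     slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
//     indices_are_sorted = false
//   } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32>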
1648 LogicalResult GatherOp::inferReturnTypeComponents(
1649     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1650     DictionaryAttr attributes, RegionRange regions,
1651     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1652   // TODO(zhouxin) remove this comment after the ordering issue is clear.
1653   // This can get called before other op verify methods, so we have to do a
1654   // bunch of verification up front. With a better story for ordering and/or
1655   // multi-phase op verification, this should hopefully all go away.
1656   Location loc = location.value_or(UnknownLoc::get(context));
1657   auto errorEmitter = [&loc]() {
1658     return mlir::emitError(loc)
1659            << "'" << GatherOp::getOperationName() << "' op ";
1660   };
1661   GatherOp::Adaptor adaptor(operands, attributes, regions);
1662   if (failed(adaptor.verify(loc))) return failure();
1663 
1664   // We want the ShapeAdaptors, so can't route via the adaptor :-/
1665   ShapeAdaptor operandShape = operands.getShape(0);
1666   ShapeAdaptor startIndicesShape = operands.getShape(1);
1667   GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1668   DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes();
1669 
1670   if (failed(verifyStaticGather(/*operandShape=*/operandShape,
1671                                 /*startIndicesShape=*/startIndicesShape,
1672                                 /*sliceSizes=*/sliceSizesAttr, dimensionNumbers,
1673                                 errorEmitter)))
1674     return failure();
1675 
1676   auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t {
1677     return sliceSizesAttr.getValues<int64_t>()[index];
1678   };
1679 
1680   return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1681                                          getSliceDim, dimensionNumbers,
1682                                          inferredReturnShapes, errorEmitter);
1683 }
1684 
1685 //===----------------------------------------------------------------------===//
1686 // DynamicGatherOp
1687 //===----------------------------------------------------------------------===//
1688 
1689 LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
1690     OpBuilder& builder, ValueRange operands,
1691     SmallVectorImpl<Value>& reifiedReturnShapes) {
1692   return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1693 }
1694 
1695 LogicalResult DynamicGatherOp::inferReturnTypeComponents(
1696     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1697     DictionaryAttr attributes, RegionRange regions,
1698     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1699   // This can get called before other op verify methods, so we have to do a
1700   // bunch of verification up front. With a better story for ordering and/or
1701   // multi-phase op verification, this should hopefully all go away.
1702   Location loc = location.value_or(UnknownLoc::get(context));
1703   auto errorEmitter = [&loc]() {
1704     return mlir::emitError(loc)
1705            << "'" << DynamicGatherOp::getOperationName() << "' op ";
1706   };
1707   DynamicGatherOp::Adaptor adaptor(operands, attributes, regions);
1708   if (failed(adaptor.verify(loc))) return failure();
1709 
1710   // We want the ShapeAdaptors, so can't route via the adaptor :-/
1711   ShapeAdaptor operandShape = operands.getShape(0);
1712   ShapeAdaptor startIndicesShape = operands.getShape(1);
1713   ShapeAdaptor sliceSizesShape = operands.getShape(2);
1714   GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1715 
1716   if (failed(verifyGather(/*operandShape=*/operandShape,
1717                           /*startIndicesShape=*/startIndicesShape,
1718                           /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers,
1719                           errorEmitter)))
1720     return failure();
1721 
1722   auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; };
1723   return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1724                                          getSliceDim, dimensionNumbers,
1725                                          inferredReturnShapes, errorEmitter);
1726 }
1727 
1728 //===----------------------------------------------------------------------===//
1729 // GetDimensionSizeOp
1730 //===----------------------------------------------------------------------===//
1731 //
1732 LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); }
1733 
1734 //===----------------------------------------------------------------------===//
1735 // IotaOp
1736 //===----------------------------------------------------------------------===//
1737 
1738 LogicalResult IotaOp::verify() {
1739   auto shape = getType().cast<ShapedType>();
1740   if (!shape.hasRank()) return success();
1741 
1742   if (shape.getRank() == 0) return emitOpError() << "does not support scalars.";
1743 
1744   auto iotaDimension = static_cast<int64_t>(this->iota_dimension());
1745   if (iotaDimension >= shape.getRank() || iotaDimension < 0)
1746     return emitOpError()
1747            << "iota dimension cannot go beyond the output rank or be negative.";
1748   return success();
1749 }
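// For example (illustrative): iota_dimension = 1 on a tensor<3x4xi32> result
// fills each row with [0, 1, 2, 3]; iota_dimension = 2 on the same rank-2
// result would be rejected by the range check above, as would a 0-d result.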
1750 
1751 //===----------------------------------------------------------------------===//
1752 // DynamicIotaOp
1753 //===----------------------------------------------------------------------===//
1754 
1755 static Value castToIndexTensor(OpBuilder& builder, Location loc,
1756                                Value shapeOp) {
1757   ShapedType resultTy = shape::getExtentTensorType(
1758       builder.getContext(), shapeOp.getType().cast<ShapedType>().getDimSize(0));
1759   if (shapeOp.getType() == resultTy) return shapeOp;  // Nothing to do.
1760   return builder.create<arith::IndexCastOp>(loc, resultTy, shapeOp);
1761 }
1762 
1763 LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
1764     OpBuilder& builder, ValueRange operands,
1765     SmallVectorImpl<Value>& reifiedReturnShapes) {
1766   DynamicIotaOp::Adaptor adaptor(operands);
1767   reifiedReturnShapes.push_back(
1768       castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
1769   return success();
1770 }
1771 
1772 //===----------------------------------------------------------------------===//
1773 // DynamicUpdateSliceOp
1774 //===----------------------------------------------------------------------===//
1775 
1776 LogicalResult DynamicUpdateSliceOp::verify() {
1777   OperandRange indices = start_indices();
1778   if (indices.size() <= 1) return success();
1779 
1780   // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
1781   // is OK to cast indices to ShapedType here.
1782   auto idxTensor = indices.take_front().front().getType().cast<ShapedType>();
1783   Type firstElemTy = idxTensor.getElementType();
1784   Type elemTy;
1785 
1786   for (auto idx : llvm::drop_begin(indices, 1)) {
1787     idxTensor = idx.getType().cast<ShapedType>();
1788     elemTy = idxTensor.getElementType();
1789 
1790     if (firstElemTy != elemTy) {
1791       return emitOpError() << "start indices must have same element type "
1792                               "(encountered mismatch: "
1793                            << firstElemTy << " vs " << elemTy << ")";
1794     }
1795   }
1796   return success();
1797 }
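// For example (illustrative): mixing a tensor<i32> start index with a
// tensor<i64> start index in one dynamic_update_slice trips the mismatch
// diagnostic above; using tensor<i32> scalars for all start indices passes.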
1798 
1799 //===----------------------------------------------------------------------===//
1800 // AbsOp
1801 //===----------------------------------------------------------------------===//
1802 
1803 LogicalResult AbsOp::inferReturnTypes(
1804     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1805     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1806   auto operandTy = (*operands.begin()).getType().cast<ShapedType>();
1807   Type elementTy = operandTy.getElementType();
1808   if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
1809     elementTy = complexTy.getElementType();
1810   }
1811 
1812   Type resultTy;
1813   if (auto rankedOperandTy = operandTy.dyn_cast<RankedTensorType>()) {
1814     resultTy = RankedTensorType::get(operandTy.getShape(), elementTy,
1815                                      rankedOperandTy.getEncoding());
1816   } else if (operandTy.hasRank()) {
1817     resultTy = RankedTensorType::get(operandTy.getShape(), elementTy);
1818   } else {
1819     resultTy = UnrankedTensorType::get(elementTy);
1820   }
1821   inferredReturnTypes.push_back(resultTy);
1822   return success();
1823 }
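// For example (illustrative): abs on a tensor<4xcomplex<f32>> operand infers
// a tensor<4xf32> result (the complex element type is unwrapped), while abs
// on tensor<4xf32> keeps tensor<4xf32>.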
1824 
1825 //===----------------------------------------------------------------------===//
1826 // CollectivePermuteOp
1827 //===----------------------------------------------------------------------===//
1828 
1829 // Verifies the source target pairs attached to collective permute.
1830 LogicalResult verifyCollectivePermuteSourceTargetPairs(
1831     Operation* op, DenseIntElementsAttr attr) {
1832   auto type = attr.getType().dyn_cast<RankedTensorType>();
1833   if (type.getRank() != 2)
1834     return op->emitError() << "expect source_target_pairs attribute to be of "
1835                               "rank 2, but got rank "
1836                            << type.getRank();
1837   if (type.getShape()[1] != 2)
1838     return op->emitError()
1839            << "expect source_target_pairs attribute of shape (N, 2), but got ("
1840            << type.getShape() << ")";
1841   // Check source target pairs for duplicate sources or targets.
1842   llvm::DenseSet<int64_t> sources;
1843   llvm::DenseSet<int64_t> targets;
1844   for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
1845     auto val = (*i).getSExtValue();
1846     if (i.getIndex() % 2 == 0) {
1847       bool isUnique = sources.insert(val).second;
1848       if (!isUnique) return op->emitError() << "duplicate sources not allowed.";
1849     } else {
1850       bool isUnique = targets.insert(val).second;
1851       if (!isUnique) return op->emitError() << "duplicate targets not allowed.";
1852     }
1853   }
1854   return success();
1855 }
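// For example (illustrative): source_target_pairs given as
// dense<[[0, 1], [1, 2]]> : tensor<2x2xi64> is accepted, whereas
// [[0, 1], [0, 2]] repeats source 0 and is rejected as a duplicate source.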
1856 
1857 LogicalResult CollectivePermuteOp::verify() {
1858   return verifyCollectivePermuteSourceTargetPairs(*this,
1859                                                        source_target_pairs());
1860 }
1861 
1862 //===----------------------------------------------------------------------===//
1863 // ConvolutionOp
1864 //===----------------------------------------------------------------------===//
1865 
1866 namespace {
1867 // Checks:
1868 //  P1. Same sizes for input, kernel and output spatial_dims.
1869 //  P2. Spatial and non-spatial dimensions (for input, kernel, and output)
1870 //      should be unique and in range [0, num_dims), where num_dims = rank of
1871 //      the input (lhs/rhs) tensors.
1872 //
1873 //  Note that the spatial + non-spatial dimensions may not cover all the
1874 //  dimensions in the range [0,num) because of the presence of 'unknown'
1875 //  dimensions (ref. cl/415132294).
1876 LogicalResult isSpatialDimensionsValid(ConvolutionOp op) {
1877   auto inputSpatialDimensions =
1878       op.dimension_numbers().getInputSpatialDimensions();
1879   auto kernelSpatialDimensions =
1880       op.dimension_numbers().getKernelSpatialDimensions();
1881   auto outputSpatialDimensions =
1882       op.dimension_numbers().getOutputSpatialDimensions();
1883 
1884   // P1.
1885   if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) ||
1886       (inputSpatialDimensions.size() != outputSpatialDimensions.size()))
1887     return op.emitOpError() << "expects the same size for input, kernel and "
1888                                "output spatial-dimensions, but got "
1889                             << inputSpatialDimensions.size() << ", "
1890                             << kernelSpatialDimensions.size() << ", and "
1891                             << outputSpatialDimensions.size() << " resp.";
1892 
1893   // P2.
1894   SmallVector<int64_t> inputDnums(inputSpatialDimensions.size() + 2);
1895   inputDnums[0] = op.dimension_numbers().getInputBatchDimension();
1896   inputDnums[1] = op.dimension_numbers().getInputFeatureDimension();
1897   std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(),
1898             inputDnums.begin() + 2);
1899 
1900   SmallVector<int64_t> windowDnums(kernelSpatialDimensions.size() + 2);
1901   windowDnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
1902   windowDnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
1903   std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(),
1904             windowDnums.begin() + 2);
1905 
1906   SmallVector<int64_t> outputDnums(outputSpatialDimensions.size() + 2);
1907   outputDnums[0] = op.dimension_numbers().getOutputBatchDimension();
1908   outputDnums[1] = op.dimension_numbers().getOutputFeatureDimension();
1909   std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(),
1910             outputDnums.begin() + 2);
1911 
1912   auto numDims = op.lhs().getType().cast<RankedTensorType>().getRank();
1913   const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
1914 
1915   if (!llvm::all_of(inputDnums, inRange) ||
1916       !llvm::all_of(windowDnums, inRange) ||
1917       !llvm::all_of(outputDnums, inRange))
1918     return op.emitOpError() << "expects input, kernel, and output "
1919                                "dimension-numbers to be in-range [0, "
1920                             << numDims << ").";
1921 
1922   if (hasDuplicates(inputDnums))
1923     return op.emitOpError()
1924            << "expects input dimension-numbers to be unique, got {"
1925            << inputDnums << "}.";
1926 
1927   if (hasDuplicates(windowDnums))
1928     return op.emitOpError()
1929            << "expects kernel dimension-numbers to be unique, got {"
1930            << windowDnums << "}.";
1931 
1932   if (hasDuplicates(outputDnums))
1933     return op.emitOpError()
1934            << "expects output dimension-numbers to be unique, got {"
1935            << outputDnums << "}.";
1936 
1937   return success();
1938 }
1939 
1940 // Verifies the following properties:
1941 //  P1. The input, kernel, and output spatial-dimensions are valid.
1942 //  P2. Given,
1943 //          input-dimensions: b * input-spatial-dims * f
1944 //          kernel-dimensions: kernel-spatial-dims * i * o
1945 //          output-dimensions: b' * out-spatial-dims * f'
1946 //            where b = input-batch-dims
1947 //            where f = input-feature-dims
1948 //            where i = kernel-input-feature-dims
1949 //            where o = kernel-output-feature-dims
1950 //            where b' = output-batch-dims
1951 //            where f' = output-feature-dims
1952 //      Check the following properties w.r.t feature_group_count (fgc) and
1953 //      batch_group_count (bgc).
1954 //        fgc > 0, bgc > 0, and !(fgc > 1 && bgc > 1)
1955 //        b % bgc == 0
1956 //        f % fgc == 0 and i = f / fgc
1957 //        o (or f') % bgc == 0 and o (or f') % fgc == 0
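// Numeric sketch (hypothetical sizes): with fgc = 2, bgc = 1, an input
// feature dimension f = 8 and a kernel input feature dimension i = 4, the
// checks below pass because f % fgc == 0 and i == f / fgc; a kernel output
// feature dimension o = 6 also passes (o % fgc == 0 and o % bgc == 0),
// whereas o = 5 would be rejected.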
1958 LogicalResult verifyConvolutionAttributes(ConvolutionOp op) {
1959   // P1.
1960   if (failed(isSpatialDimensionsValid(op))) return failure();
1961 
1962   // P2.
1963   const int64_t featureGroupCount = op.feature_group_count();
1964   const int64_t batchGroupCount = op.batch_group_count();
1965 
1966   if (featureGroupCount <= 0)
1967     return op.emitOpError()
1968            << "expects feature_group_count to be a positive number, got "
1969            << featureGroupCount << ".";
1970 
1971   if (batchGroupCount <= 0)
1972     return op.emitOpError()
1973            << "expects batch_group_count to be a positive number, got "
1974            << batchGroupCount << ".";
1975 
1976   if (batchGroupCount > 1 && featureGroupCount > 1)
1977     return op.emitOpError()
1978            << "expects batch_group_count and feature_group_count not to be "
1979               "both greater than 1. Got "
1980            << batchGroupCount << " and " << featureGroupCount << " resp.";
1981 
1982   auto lhsType = op.lhs().getType().cast<RankedTensorType>();
1983   const int64_t inputFeatures =
1984       lhsType.getShape()[op.dimension_numbers().getInputFeatureDimension()];
1985   const int64_t inputBatch =
1986       lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
1987 
1988   auto rhsType = op.rhs().getType().cast<RankedTensorType>();
1989   const int64_t kernelInputFeatures =
1990       rhsType
1991           .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
1992   const int64_t kernelOutputFeatures =
1993       rhsType
1994           .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
1995 
1996   if (!isDynamicDimSize(kernelOutputFeatures)) {
1997     if (kernelOutputFeatures % batchGroupCount != 0)
1998       return op.emitOpError() << "expects output feature dimension size ("
1999                               << kernelOutputFeatures
2000                               << ") to be a multiple of "
2001                                  "batch_group_count. Got batch_group_count = "
2002                               << batchGroupCount << ".";
2003 
2004     if (kernelOutputFeatures % featureGroupCount != 0)
2005       return op.emitOpError()
2006              << "expects kernel output feature dimension ("
2007              << kernelOutputFeatures
2008              << ") to be divisible by "
2009                 "feature_group_count. For feature_group_count = "
2010              << featureGroupCount << ".";
2011   }
2012 
2013   if (!isDynamicDimSize(inputFeatures)) {
2014     if (inputFeatures % featureGroupCount != 0)
2015       return op.emitOpError()
2016              << "expects input feature dimension (" << inputFeatures
2017              << ") to be a multiple of "
2018                 "feature_group_count. Got feature_group_count = "
2019              << featureGroupCount << ".";
2020 
2021     if (!isDynamicDimSize(kernelInputFeatures) &&
2022         inputFeatures / featureGroupCount != kernelInputFeatures)
2023       return op.emitOpError()
2024              << "expects input feature dimension (" << inputFeatures
2025              << ") / "
2026                 "feature_group_count = kernel input feature dimension ("
2027              << kernelInputFeatures
2028              << "). Got feature_group_count = " << featureGroupCount << ".";
2029   }
2030 
2031   if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
2032     return op.emitOpError() << "expects input batch dimension (" << inputBatch
2033                             << ") to be divisible by "
2034                                "batch_group_count. Got batch_group_count = "
2035                             << batchGroupCount << ".";
2036 
2037   return success();
2038 }
2039 
2040 // Infer the return-shape of ConvolutionOp.
2041 // Precondition:
2042 //  1. Input args to ConvolutionOp 'op' are RankedTypes.
2043 //  2. rank-of(input-type) == rank-of(output-type)
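// Shape sketch (hypothetical NHWC-style example): lhs tensor<1x8x8x3xf32>,
// rhs tensor<3x3x3x16xf32>, unit strides/dilations, no padding, and
// dimension numbers [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] with
// batch_group_count = feature_group_count = 1. Each spatial output size is
// (8 - 3) + 1 = 6, the batch dimension is 1 / bgc = 1, and the feature
// dimension equals the kernel output features (16), giving 1x6x6x16.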
2044 SmallVector<int64_t> inferConvolutionOpReturnShape(
2045     ConvolutionOp op, const ArrayRef<WindowDimension> window) {
2046   // We keep the 'unknown' dimensions (cl/415132294) as they are in the
2047   // output-shape. To do that we initialize the output dimensions with the shape
2048   // of the return-type and update only the spatial + non-spatial dimensions.
2049   // Precondition 2 ensures that size of output-shape == size of input-shape.
2050   SmallVector<int64_t> outputDimensions =
2051       to_vector(op.getResult().getType().cast<ShapedType>().getShape());
2052 
2053   // Infer the output spatial dimensions.
2054   auto lhsType = op.lhs().getType().cast<RankedTensorType>();
2055   auto inputSpatialDims = op.dimension_numbers().getInputSpatialDimensions();
2056   auto numSpatialDims = inputSpatialDims.size();
2057   SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
2058   for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
2059     inputSpatialDimVals[i] = lhsType.getShape()[inputSpatialDims[i]];
2060 
2061   auto windowOutputShape = inferWindowOutputShape(inputSpatialDimVals, window);
2062 
2063   for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i)
2064     outputDimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
2065         windowOutputShape[i];
2066 
2067   // Infer the output-batch-dimension and output-feature-dimension.
2068   auto rhsType = op.rhs().getType().cast<RankedTensorType>();
2069   const int64_t inputBatch =
2070       lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
2071   const int64_t kernelOutputFeatures =
2072       rhsType
2073           .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
2074 
2075   outputDimensions[op.dimension_numbers().getOutputBatchDimension()] =
2076       isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize
2077                                    : inputBatch / op.batch_group_count();
2078   outputDimensions[op.dimension_numbers().getOutputFeatureDimension()] =
2079       kernelOutputFeatures;
2080 
2081   return outputDimensions;
2082 }
2083 
2084 }  // namespace
2085 
2086 /*
2087  * We intend to verify the following properties
2088  *  P1. Verify the input, kernel types.
2089  *  P2. Verify the convolution attributes.
2090  *  P3. Verify and collect the window attributes.
2091  *  P4. Verify the return shape.
2092  *      TODO(b/232574102): Verify the element-type of return-value.
2093  */
2094 LogicalResult ConvolutionOp::verify() {
2095   auto lhsType = lhs().getType().dyn_cast<RankedTensorType>();
2096   auto rhsType = rhs().getType().dyn_cast<RankedTensorType>();
2097 
2098   if (!lhsType || !rhsType) return success();
2099 
2100   // P1.
2101   int numDims = lhsType.getRank();
2102   if (numDims != rhsType.getRank())
2103     return emitOpError()
2104            << "expects convolution arguments to have same number of "
2105               "dimensions. Got: "
2106            << lhsType << " and " << rhsType << ".";
2107 
2108   if (numDims < 2)
2109     return emitOpError()
2110            << "expects convolution arguments to have >= 2 dimensions. "
2111               "Got: "
2112            << lhsType << " and " << rhsType << ".";
2113 
2114   // P2.
2115   if (failed(verifyConvolutionAttributes(*this))) return failure();
2116 
2117   // P3.
2118   auto kernelSpatialDimensions =
2119       dimension_numbers().getKernelSpatialDimensions();
2120   SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size());
2121   for (size_t i = 0; i < windowDimensions.size(); i++)
2122     windowDimensions[i] = rhsType.getShape()[kernelSpatialDimensions[i]];
2123 
2124   auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
2125   if (failed(paddingOrErr)) return failure();
2126   SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
2127 
2128   auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
2129       windowDimensions, convertDenseIntAttr(window_strides()), padding,
2130       convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
2131       getLoc());
2132   if (failed(windowOrErr)) return failure();
2133 
2134   // P4.
2135   auto actualReturnType = getResult().getType().cast<TensorType>();
2136   auto actualReturnElementType = actualReturnType.getElementType();
2137   if (!actualReturnType.hasRank()) return success();
2138 
2139   auto actualReturnRankedType = actualReturnType.cast<RankedTensorType>();
2140   if (numDims != actualReturnRankedType.getRank())
2141     return emitOpError() << "expects rank of convolution return-type to be "
2142                             "equal to input-ranks ("
2143                          << numDims << "), but got "
2144                          << actualReturnRankedType.getRank() << ".";
2145 
2146   auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr);
2147   auto expectedReturnType =
2148       RankedTensorType::get(expectedReturnShape, actualReturnElementType);
2149   if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType)))
2150     return emitOpError()
2151            << "has shape mismatch between the expected return-type ("
2152            << expectedReturnType << ") and actual return-type ("
2153            << actualReturnRankedType << ").";
2154 
2155   return success();
2156 }
2157 
2158 //===----------------------------------------------------------------------===//
2159 // ConvertOp
2160 //===----------------------------------------------------------------------===//
2161 
2162 void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
2163                       Type resultElementTy) {
2164   Type resultTy;
2165   Type operandTy = operand.getType();
2166   if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>()) {
2167     resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy);
2168   } else {
2169     resultTy = UnrankedTensorType::get(resultElementTy);
2170   }
2171   build(builder, result, resultTy, operand);
2172 }
2173 
2174 //===----------------------------------------------------------------------===//
2175 // GetTupleElementOp
2176 //===----------------------------------------------------------------------===//
2177 
2178 LogicalResult GetTupleElementOp::verify() {
2179   auto indexVal = index();
2180   auto operandType = getOperand().getType().cast<TupleType>();
2181   if (indexVal >= operandType.size()) {
2182     return emitOpError(
2183         llvm::formatv("index {0} is out of bounds of operand with size {1}",
2184                       indexVal, operandType.size()));
2185   }
2186 
2187   auto expectedType = operandType.getType(indexVal);
2188   if (getType() != expectedType) {
2189     return emitOpError(llvm::formatv("has return type {0}, but expected {1}",
2190                                      getType(), expectedType));
2191   }
2192   return success();
2193 }
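// For example (illustrative): with an operand of type
// tuple<tensor<f32>, tensor<4xi32>>, index = 1 must produce tensor<4xi32>;
// index = 2 is out of bounds, and index = 1 with a declared tensor<f32>
// result type is a return-type mismatch.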
2194 
2195 //===----------------------------------------------------------------------===//
2196 // TupleOp
2197 //===----------------------------------------------------------------------===//
2198 
2199 LogicalResult TupleOp::verify() {
2200   auto opType = getType().dyn_cast<TupleType>();
2201   if (!opType) return emitOpError("tuple op with non-tuple result");
2202   if (getNumOperands() != opType.size())
2203     return emitOpError(
2204         "number of operands to tuple expected to match number of types in "
2205         "resultant tuple type");
2206   for (const auto& it :
2207        llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) {
2208     if (std::get<0>(it.value()) != std::get<1>(it.value()))
2209       return emitOpError("has return type mismatch at ")
2210              << it.index() << "th value (" << std::get<0>(it.value())
2211              << " != " << std::get<1>(it.value()) << ")";
2212   }
2213   return success();
2214 }
2215 
2216 //===----------------------------------------------------------------------===//
2217 // AllToAllOp
2218 //===----------------------------------------------------------------------===//
2219 
2220 LogicalResult AllToAllOp::inferReturnTypeComponents(
2221     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2222     DictionaryAttr attributes, RegionRange regions,
2223     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2224   AllToAllOp::Adaptor adaptor(operands, attributes, regions);
2225   Type operandType = adaptor.operand().getType();
2226   RankedTensorType operandRankedType = operandType.dyn_cast<RankedTensorType>();
2227   if (!operandRankedType) {
2228     inferredReturnShapes.emplace_back(
2229         operandType.cast<TensorType>().getElementType());
2230     return success();
2231   }
2232 
2233   int64_t inputRank = operandRankedType.getRank();
2234   int64_t splitDimension = static_cast<int64_t>(adaptor.split_dimension());
2235   int64_t concatDimension = static_cast<int64_t>(adaptor.concat_dimension());
2236   if (splitDimension >= inputRank || splitDimension < 0) {
2237     return emitOptionalError(location, "AllToAll split_dimension ",
2238                              splitDimension,
2239                              " is out-of-bounds for input rank ", inputRank);
2240   }
2241   if (concatDimension >= inputRank || concatDimension < 0) {
2242     return emitOptionalError(location, "AllToAll concat_dimension ",
2243                              concatDimension,
2244                              " is out-of-bounds for input rank ", inputRank);
2245   }
2246 
2247   // If operand is ranked, size of split dimension should be a multiple of split
2248   // count.
2249   int64_t splitCount = adaptor.split_count();
2250   auto splitDimSize = operandRankedType.getDimSize(splitDimension);
2251   if (splitDimSize % splitCount != 0) {
2252     return emitOptionalError(
2253         location, "split dimension has size ", splitDimSize,
2254         ", expected to be a multiple of split_count ", splitCount);
2255   }
2256   SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(),
2257                                    operandRankedType.getShape().end());
2258   resultShape[splitDimension] /= splitCount;
2259   resultShape[concatDimension] *= splitCount;
2260   inferredReturnShapes.emplace_back(resultShape,
2261                                     operandRankedType.getElementType());
2262   return success();
2263 }
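// For example (illustrative): an operand of type tensor<4x16xf32> with
// split_dimension = 1, concat_dimension = 0, and split_count = 4 infers
// tensor<16x4xf32>: the split dimension shrinks to 16 / 4 = 4 and the concat
// dimension grows to 4 * 4 = 16.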
2264 
2265 //===----------------------------------------------------------------------===//
2266 // AllGatherOp
2267 //===----------------------------------------------------------------------===//
2268 
2269 LogicalResult AllGatherOp::verify() {
2270   // If operand and result are both ranked, then the size of the gather
2271   // dimension in the result should be a multiple of the size of the gather
2272   // dimension in the operand.
2273   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2274   auto resultType = getType().dyn_cast<RankedTensorType>();
2275   uint64_t allGatherDimIndex = all_gather_dim();
2276   if (!operandType || !resultType ||
2277       operandType.isDynamicDim(allGatherDimIndex) ||
2278       resultType.isDynamicDim(allGatherDimIndex))
2279     return success();
2280   if (operandType.getDimSize(allGatherDimIndex) == 0)
2281     return emitOpError() << "operand gather dimension cannot be zero.";
2282   if ((resultType.getDimSize(allGatherDimIndex) %
2283        operandType.getDimSize(allGatherDimIndex)) != 0)
2284     return emitOpError()
2285            << "result gather dimension has size "
2286            << resultType.getDimSize(allGatherDimIndex)
2287            << ", expected to be a multiple of operand gather dimension size "
2288            << operandType.getDimSize(allGatherDimIndex);
2289 
2290   return success();
2291 }
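// For example (illustrative): an operand of type tensor<8x2xf32> with
// all_gather_dim = 1 and a result of type tensor<8x8xf32> passes (8 is a
// multiple of 2); a result of type tensor<8x7xf32> would be rejected by the
// multiple-of check above.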
2292 
2293 //===----------------------------------------------------------------------===//
2294 // BatchNormGradOp
2295 //===----------------------------------------------------------------------===//
2296 
2297 LogicalResult BatchNormGradOp::verify() {
2298   // The following properties are already enforced by the ODS:
2299   //  1. Inputs 'operand' & 'grad_output' and outputs 'grad_operand',
2300   //     are ranked-tensors with floating-point (fp) type.
2301   //  2. The shapes of inputs 'operand' & 'grad_output' match.
2302   //  3. Inputs 'scale', 'mean', 'variance' and Outputs 'grad_scale',
2303   //     'grad_offset'  are all 1D fp tensors with same shape.
2304   //  4. The element-types of input 'operand' and outputs 'grad_scale',
2305   //     'grad_offset' match.
2306   //  5. The type of input 'operand' and output 'grad_operand' match.
2307   //
2308   // We intend to verify the following properties
2309   //  P1. Inputs 'operand' & 'grad_output' have the same shape with fp
2310   //      element-types, ignoring fp-precision: Inferred from (1) & (2).
2311   //  P2. The feature dimension 'feature_index' is a valid index in 'operand':
2312   //      Inferred from check C2 below.
2313   //  P3. Inputs 'scale', 'mean', 'variance' must be 1D tensors with same shape
2314   //      and fp element-type (ignoring precision) and the number of elements
2315   //      in its sole-dimension == number of features in the 'operand's
2316   //      feature-dimension 'feature_index': Inferred from (3) and check C3
2317   //      below.
2318   //  P4. Outputs 'grad_scale' & 'grad_offset' are 1D tensors with
2319   //      element-type == element-type of(operand) and same shape as any of
2320   //      the inputs 'scale', 'mean', or 'variance': Inferred from (3), (4) and
2321   //      check C3 below.
2322   //  P5. The type (shape + element-type) of input 'operand' and
2323   //      output 'grad_operand' must match: Inferred from (5).
2324 
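  // For illustration only (example shapes): an 'operand' of type
  // tensor<2x3x4xf32> with feature_index = 2 has a feature count of 4, so
  // 'scale' (and, by property (3), 'mean' and 'variance') is expected to be a
  // tensor<4xf32>; feature_index = 3 would be rejected by C2 below.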
2325   // C2.
2326   auto operandType = operand().getType().cast<RankedTensorType>();
2327   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2328     return emitOpError() << "expects feature_index to be smaller "
2329                             "than the rank of operand type; got feature_index "
2330                          << feature_index() << ", and rank "
2331                          << operandType.getRank() << ".";
2332 
2333   if (static_cast<int64_t>(feature_index()) < 0)
2334     return emitOpError() << "expects feature_index to be a "
2335                          << "non-negative number, got "
2336                          << static_cast<int64_t>(feature_index()) << ".";
2337 
2338   auto gradOutputType = grad_output().getType().cast<RankedTensorType>();
2339   if (operandType.getRank() != gradOutputType.getRank())
2340     return emitOpError() << "expects 'operand' and 'grad_output' to have the "
2341                             "same rank, but got rank(operand) "
2342                          << operandType.getRank() << " and rank(grad_output) "
2343                          << gradOutputType.getRank() << ".";
2344 
2345   // C3.
2346   const int64_t featureCount = operandType.getShape()[feature_index()];
2347   const int64_t scaleShape =
2348       scale().getType().cast<RankedTensorType>().getShape()[0];
2349   if (scaleShape != featureCount)
2350     return emitOpError() << "expects the size of scale factor to be "
2351                             "same as the feature count,"
2352                             " but the size of scale factor is "
2353                          << scaleShape << " and the feature count is "
2354                          << featureCount << ".";
2355 
2356   return success();
2357 }
2358 
2359 //===----------------------------------------------------------------------===//
2360 // BatchNormTrainingOp
2361 //===----------------------------------------------------------------------===//
2362 
2363 LogicalResult BatchNormTrainingOp::verify() {
2364   // The following properties are already enforced by the ODS:
2365   //  1. 'operand' and 'output' are ranked tensors.
2366   //  2. 'scale', 'offset', 'batch_mean', 'batch_var' are 1D tensors.
2367   //  3. Types of 'operand' and 'output' matches.
2368   //  4. Same element-types for 'operand', 'batch_mean', & 'batch_var'.
2369   //  5. Same shapes for 'scale', 'offset', 'batch_mean', & 'batch_var'.
2370 
2371   auto operandType = operand().getType().cast<RankedTensorType>();
2372   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2373     return emitOpError() << "expects feature_index to be smaller "
2374                             "than the rank of operand type; got feature_index "
2375                          << feature_index() << ", and rank "
2376                          << operandType.getRank() << ".";
2377 
2378   if (static_cast<int64_t>(feature_index()) < 0)
2379     return emitOpError() << "expects feature_index to be a "
2380                          << "non-negative number, got "
2381                          << static_cast<int64_t>(feature_index()) << ".";
2382 
2383   // Note: A valid value of feature-index implies 'operand_type.getRank() >= 1'.
2384 
2385   const int64_t featureCount = operandType.getShape()[feature_index()];
2386   const int64_t scaleShape =
2387       scale().getType().cast<RankedTensorType>().getShape()[0];
2388   // Check number of elements in input 'scale' equals feature_count.
2389   // Together with (5) implies that 'scale', 'offset', 'batch_mean', &
2390   // 'batch_var' all have the same shape.
2391   if (scaleShape != featureCount)
2392     return emitOpError() << "expects the size of scale factor to be "
2393                             "same as the feature count,"
2394                             " but the size of scale factor is "
2395                          << scaleShape << " and the feature count is "
2396                          << featureCount << ".";
2397 
2398   return success();
2399 }
2400 
2401 //===----------------------------------------------------------------------===//
2402 // BatchNormInferenceOp
2403 //===----------------------------------------------------------------------===//
2404 
2405 LogicalResult BatchNormInferenceOp::verify() {
2406   // The following properties are already enforced by the ODS:
2407   //  1. 'operand' and 'result' are ranked tensors.
2408   //  2. 'scale', 'offset', 'mean', 'variance' are 1D tensors.
2409   //  3. Types of 'operand' and 'result' matches.
2410   //  4. Same shapes for 'scale', 'offset', 'mean', & 'variance'.
2411 
2412   auto operandType = operand().getType().cast<RankedTensorType>();
2413   if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2414     return emitOpError() << "expects feature_index to be smaller "
2415                             "than the rank of operand type; got feature_index "
2416                          << feature_index() << ", and rank "
2417                          << operandType.getRank() << ".";
2418 
2419   if (static_cast<int64_t>(feature_index()) < 0)
2420     return emitOpError() << "expects feature_index to be a "
2421                          << "non-negative number, got "
2422                          << static_cast<int64_t>(feature_index()) << ".";
2423 
2424   // Note: A valid value of feature-index implies 'operand_type.getRank() >= 1'.
2425 
2426   const int64_t featureCount = operandType.getShape()[feature_index()];
2427   const int64_t scaleSize =
2428       scale().getType().cast<RankedTensorType>().getShape()[0];
2429   // Check number of elements in input 'scale' equals feature_count.
2430   // Together with (4) implies that 'scale', 'offset', 'mean', &
2431   // 'variance' all have the same shape.
2432   if (scaleSize != featureCount)
2433     return emitOpError() << "expects the size of scale factor to be "
2434                             "same as the feature count,"
2435                             " but the size of scale factor is "
2436                          << scaleSize << " and the feature count is "
2437                          << featureCount << ".";
2438 
2439   return success();
2440 }
2441 
2442 //===----------------------------------------------------------------------===//
2443 // BitcastConvertOp
2444 //===----------------------------------------------------------------------===//
2445 
2446 LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
2447     OpBuilder& builder, ValueRange operands,
2448     SmallVectorImpl<Value>& reifiedReturnShapes) {
2449   auto operandType = operands[0].getType().dyn_cast<RankedTensorType>();
2450   auto resultType = getType().dyn_cast<RankedTensorType>();
2451 
2452   // Only ranked tensors are supported.
2453   if (!operandType || !resultType) return failure();
2454 
2455   // Shape-changing bitcast convert is not implemented.
2456   // TODO(kramerb): This could be done by adjusting the last dimension.
2457   DataLayout dataLayout = DataLayout::closest(*this);
2458   unsigned operandElementSize =
2459       dataLayout.getTypeSizeInBits(operandType.getElementType());
2460   unsigned resultElementSize =
2461       dataLayout.getTypeSizeInBits(resultType.getElementType());
2462   if (operandElementSize != resultElementSize) return failure();
2463 
2464   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
2465                                      &reifiedReturnShapes);
2466 }
2467 
2468 /*
2469  * We intend to verify the following properties
2470  * P1. We cannot convert between complex and real types (cf. XLA)
2471  * P2. The dimensions of the operand and the target
2472  * shape must match, except that the shape with the smaller element bitwidth has
2473  * an appropriately-sized additional innermost dimension, e.g.
2474  * ... x f32 => [bitcast_convert] => ... x 4 x i8
2475  * ... x 4 x i8 => [bitcast_convert] => ... x f32
2476  */
2477 LogicalResult BitcastConvertOp::verify() {
2478   auto operandTensorType = operand().getType().cast<TensorType>();
2479   auto targetTensorType = getResult().getType().cast<TensorType>();
2480 
2481   // P1.
2482   auto targetElt = targetTensorType.getElementType();
2483   auto operandElt = operandTensorType.getElementType();
2484   if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>()) {
2485     return emitOpError()
2486            << "cannot convert between real and complex types, but got: "
2487            << operandTensorType << " and " << targetTensorType;
2488   }
2489 
2490   auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt);
2491   auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt);
2492 
2493   // P2.
2494   auto operandType = operandTensorType.dyn_cast<RankedTensorType>();
2495   auto targetType = targetTensorType.dyn_cast<RankedTensorType>();
2496   if (!operandType || !targetType) return success();
2497 
2498   auto targetShape = targetType.getShape();
2499   auto operandShape = operandType.getShape();
2500   ArrayRef<int64_t> smallerEltShape, biggerEltShape;
2501   Type smallerElt, biggerElt;
2502   if (operandEltBitwidth < targetEltBitwidth) {
2503     smallerEltShape = operandShape;
2504     smallerElt = operandElt;
2505     biggerEltShape = targetShape;
2506     biggerElt = targetElt;
2507   } else {
2508     smallerEltShape = targetShape;
2509     smallerElt = targetElt;
2510     biggerEltShape = operandShape;
2511     biggerElt = operandElt;
2512   }
2513 
2514   ArrayRef<int64_t> smallerEltPrefix;
2515   auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth);
2516   auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth);
2517   if (operandEltBitwidth != targetEltBitwidth) {
2518     if (smallerEltShape.empty()) {
2519       return emitOpError() << "does not allow the smaller element type to be "
2520                               "part of a 0d tensor, but got: "
2521                            << operandType << " and " << targetType << ".";
2522     }
2523     smallerEltPrefix = smallerEltShape.drop_back();
2524     if (!isDynamicDimSize(smallerEltShape.back()) &&
2525         smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) {
2526       return emitOpError() << "requires compatible bitwidths. "
2527                            << "Got: " << operandType << " and " << targetType
2528                            << ", but " << smallerEltBitwidth << " * "
2529                            << smallerEltShape.back()
2530                            << " != " << biggerEltBitwidth << ".";
2531     }
2532   } else {
2533     smallerEltPrefix = smallerEltShape;
2534   }
2535 
2536   for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
2537     auto targetDim = std::get<0>(it);
2538     auto operandDim = std::get<1>(it);
2539     if (!isDynamicDimSize(targetDim) && !isDynamicDimSize(operandDim)) {
2540       if (targetDim != operandDim) {
2541         return emitOpError() << "operand and result shapes must match except "
2542                                 "for the innermost dimension of the shape with "
2543                                 "the smaller element type. Got: "
2544                              << operandType << " and " << targetType << ".";
2545       }
2546     }
2547   }
2548 
2549   return success();
2550 }
2551 
2552 //===----------------------------------------------------------------------===//
2553 // BroadcastOp
2554 //===----------------------------------------------------------------------===//
2555 
2556 // TODO(b/129012527) These should be expressed as type constraints.
2557 LogicalResult BroadcastOp::verify() {
2558   auto sizes = broadcast_sizes();
2559   auto sizesType = sizes.getType();
2560   auto sizesRank = sizesType.getRank();
2561   if (sizesRank != 1) {
2562     return emitOpError(llvm::formatv(
2563         "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
2564   }
2565 
2566   return success();
2567 }
2568 
2569 LogicalResult BroadcastOp::inferReturnTypeComponents(
2570     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2571     DictionaryAttr attributes, RegionRange regions,
2572     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2573   BroadcastOp::Adaptor adaptor(operands, attributes, regions);
2574   Value operand = adaptor.operand();
2575   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
2576   if (!operandType) return failure();
2577 
2578   Type elementTy = operandType.getElementType();
2579   auto dimensionAttr = adaptor.broadcast_sizes();
2580   for (int64_t size : dimensionAttr.getValues<int64_t>()) {
2581     if (size < 0)
2582       return emitOptionalError(location,
2583                                "Broadcast with negative dimension size ", size);
2584   }
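  // For illustration only (example shapes): broadcast_sizes = [2, 3] applied
  // to an operand of type tensor<4x5xf32> infers tensor<2x3x4x5xf32>; the new
  // sizes are prepended to the operand shape below.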
2585   SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>());
2586   llvm::append_range(shapeValues, operandType.getShape());
2587 
2588   inferredReturnShapes.emplace_back(shapeValues, elementTy);
2589   return success();
2590 }
2591 
2592 LogicalResult BroadcastOp::reifyReturnTypeShapes(
2593     OpBuilder& builder, ValueRange operands,
2594     SmallVectorImpl<Value>& reifiedReturnShapes) {
2595   BroadcastOp::Adaptor adaptor(operands);
2596   Value operand = adaptor.operand();
2597 
2598   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
2599   // Unranked tensors are not supported.
2600   if (!operandType) return failure();
2601 
2602   Location loc = getLoc();
2603   SmallVector<Value, 4> shapeValues;
2604 
2605   // Collect the broadcast sizes.
2606   for (const auto& size : broadcast_sizes()) {
2607     shapeValues.push_back(
2608         builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
2609   }
2610 
2611   // Collect the operand sizes.
2612   for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
2613     shapeValues.push_back(
2614         builder.createOrFold<tensor::DimOp>(loc, operand, index));
2615   }
2616 
2617   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
2618       loc,
2619       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
2620                             builder.getIndexType()),
2621       shapeValues));
2622 
2623   return success();
2624 }
2625 
2626 //===----------------------------------------------------------------------===//
2627 // BroadcastInDimOp
2628 //===----------------------------------------------------------------------===//
2629 
2630 LogicalResult BroadcastInDimOp::verify() {
2631   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2632   if (!operandType) {
2633     // The following verification checks all depend on knowing the rank of
2634     // the operand. Bail out now if we don't know the rank of the operand.
2635     return success();
2636   }
2637 
2638   auto operandRank = operandType.getRank();
2639   if (!broadcast_dimensions()) {
2640     if (operandRank == 0) {
2641       return success();
2642     }
2643     return emitOpError(
2644         llvm::formatv("broadcast_dimensions is absent, but required because "
2645                       "operand has non-zero rank ({0})",
2646                       operandRank));
2647   }
2648 
2649   auto dimensions = broadcast_dimensions();
2650   auto dimensionsType = broadcast_dimensions().getType();
2651   auto dimensionsRank = dimensionsType.getRank();
2652   if (dimensionsRank != 1) {
2653     return emitOpError(llvm::formatv(
2654         "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
2655   }
2656 
2657   auto dimensionsSize = dimensionsType.getNumElements();
2658   if (dimensionsSize != operandRank) {
2659     return emitOpError(llvm::formatv(
2660         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
2661         dimensionsSize, operandRank));
2662   }
2663 
2664   auto resultType = getResult().getType().cast<RankedTensorType>();
2665   auto resultRank = resultType.getRank();
2666   if (resultRank < operandRank) {
2667     return emitOpError(
2668         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
2669                       resultRank, operandRank));
2670   }
2671 
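  // For illustration only (example shapes): an operand of type tensor<1x3xf32>
  // with broadcast_dimensions = [1, 2] and a result of type tensor<2x4x3xf32>
  // passes the checks below: operand dim 0 (size 1) broadcasts into result
  // dim 1 (size 4), and operand dim 1 (size 3) matches result dim 2 (size 3).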
2672   for (int i = 0; i != dimensionsSize; ++i) {
2673     auto dimIndex = dimensions.getValues<int64_t>()[i];
2674     if (dimIndex >= resultRank) {
2675       return emitOpError(
2676           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
2677                         "result with rank {1}",
2678                         dimIndex, resultRank));
2679     }
2680 
2681     if (!operandType.isDynamicDim(i)) {
2682       auto dimSize = operandType.getDimSize(i);
2683       auto resultDimSize = resultType.getDimSize(dimIndex);
2684       if (dimSize != 1 && dimSize != resultDimSize) {
2685         return emitOpError(
2686             llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
2687                           "1 or size of result dimension {2} ({3})",
2688                           i, dimSize, dimIndex, resultDimSize));
2689       }
2690     }
2691   }
2692 
2693   return success();
2694 }
2695 
2696 //===----------------------------------------------------------------------===//
2697 // DynamicBroadcastInDimOp
2698 //===----------------------------------------------------------------------===//
2699 
2700 LogicalResult DynamicBroadcastInDimOp::verify() {
2701   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2702   auto resultType = getResult().getType().dyn_cast<RankedTensorType>();
2703 
2704   // If either the operand or the result is unranked, there is very little
2705   // to verify statically.
2706   if (!operandType || !resultType) {
2707     return success();
2708   }
2709 
2710   auto outputDimensionsType =
2711       output_dimensions().getType().cast<RankedTensorType>();
2712   auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
2713   auto operandRank = operandType.getRank();
2714   auto resultRank = resultType.getRank();
2715 
2716   // Verify broadcast_dimensions.
2717   auto bcastDimensions = broadcast_dimensions();
2718   auto bcastDimensionsType = broadcast_dimensions().getType();
2719   auto bcastDimensionsRank = bcastDimensionsType.getRank();
2720   // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
2721   if (bcastDimensionsRank != 1) {
2722     return emitOpError(
2723         llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
2724                       bcastDimensionsRank));
2725   }
2726 
2727   auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
2728   if (bcastDimensionsSize != operandRank) {
2729     return emitOpError(llvm::formatv(
2730         "broadcast_dimensions size ({0}) does not match operand rank ({1})",
2731         bcastDimensionsSize, operandRank));
2732   }
2733 
2734   if (resultRank < operandRank) {
2735     return emitOpError(
2736         llvm::formatv("result rank ({0}) is less than operand rank ({1})",
2737                       resultRank, operandRank));
2738   }
2739 
2740   for (int i = 0; i != bcastDimensionsSize; ++i) {
2741     auto dimIndex = bcastDimensions.getValues<int64_t>()[i];
2742     if (dimIndex >= resultRank) {
2743       return emitOpError(
2744           llvm::formatv("broadcast_dimensions contains invalid value {0} for "
2745                         "result with rank {1}",
2746                         dimIndex, resultRank));
2747     }
2748 
2749     auto dimSize = operandType.getDimSize(i);
2750     auto resultDimSize = resultType.getDimSize(dimIndex);
2751     // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
2752     // add a manual check for this.
2753     if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
2754       return emitOpError(
2755           llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
2756                         "with size of result dimension {2} ({3})",
2757                         i, dimSize, dimIndex, resultDimSize));
2758     }
2759   }
2760 
2761   if (outputDimensionsSize != resultRank) {
2762     return emitOpError(
2763         llvm::formatv("result rank ({0}) is not equal to number of output "
2764                       "dimensions ({1})",
2765                       resultRank, outputDimensionsSize));
2766   }
2767 
2768   // Verify that the known expanding and non-expanding dimensions are a subset
2769   // of the operand's dimensions.
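  // For illustration only (example values): for an operand of rank 2,
  // known_expanding_dimensions = [0] with known_nonexpanding_dimensions = [1]
  // is accepted; listing dimension 0 in both attributes, or referring to
  // dimension 2, is rejected below.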
2770   int64_t numKnownExpansionBehavior = 0;
2771   DenseSet<int64_t> knownExpansionBehavior;
2772   auto collectExpansionBehaviorDims =
2773       [&](const Optional<DenseIntElementsAttr>& attr) {
2774         if (!attr) return;
2775         for (const APInt& it : *attr) {
2776           numKnownExpansionBehavior++;
2777           knownExpansionBehavior.insert(it.getLimitedValue());
2778         }
2779       };
2780   collectExpansionBehaviorDims(known_expanding_dimensions());
2781   collectExpansionBehaviorDims(known_nonexpanding_dimensions());
2782   if (knownExpansionBehavior.size() != numKnownExpansionBehavior) {
2783     return emitOpError(
2784         "duplicate expansion hint for at least one operand dimension");
2785   }
2786   for (int64_t i : knownExpansionBehavior) {
2787     if (i < 0 || i >= operandRank) {
2788       return emitOpError(
2789           llvm::formatv("hint for expanding dimension {0} does not refer to a "
2790                         "valid operand dimension",
2791                         i));
2792     }
2793   }
2794 
2795   return success();
2796 }
2797 
2798 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
2799     OpBuilder& builder, ValueRange operands,
2800     SmallVectorImpl<Value>& reifiedReturnShapes) {
2801   DynamicBroadcastInDimOp::Adaptor adaptor(operands);
2802   reifiedReturnShapes.push_back(
2803       castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
2804   return success();
2805 }
2806 
2807 //===----------------------------------------------------------------------===//
2808 // ClampOp
2809 //===----------------------------------------------------------------------===//
2810 
2811 LogicalResult ClampOp::verify() {
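  // For illustration only (example shapes): with an operand of type
  // tensor<4xf32>, both min : tensor<f32> (a scalar) and min : tensor<4xf32>
  // are accepted; min : tensor<3xf32> fails the compatibility check below.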
2812   auto operandType = operand().getType().cast<RankedTensorType>();
2813   auto operandShape = operandType.getShape();
2814   auto minType = min().getType().cast<RankedTensorType>();
2815 
2816   auto minShape = minType.getShape();
2817   if (failed(verifyCompatibleShape(minType, operandType)) &&
2818       minType.getRank() != 0) {
2819     return emitOpError(llvm::formatv(
2820         "min shape [{0}] is not scalar and is not compatible to operand shape "
2821         "[{1}]",
2822         llvm::make_range(minShape.begin(), minShape.end()),
2823         llvm::make_range(operandShape.begin(), operandShape.end())));
2824   }
2825 
2826   auto maxType = max().getType().cast<RankedTensorType>();
2827   auto maxShape = maxType.getShape();
2828   if (failed(verifyCompatibleShape(maxType, operandType)) &&
2829       maxType.getRank() != 0) {
2830     return emitOpError(llvm::formatv(
2831         "max shape [{0}] is not scalar and is not compatible to operand shape "
2832         "[{1}]",
2833         llvm::make_range(maxShape.begin(), maxShape.end()),
2834         llvm::make_range(operandShape.begin(), operandShape.end())));
2835   }
2836 
2837   return success();
2838 }
2839 
2840 LogicalResult ClampOp::inferReturnTypeComponents(
2841     MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
2842     DictionaryAttr attributes, RegionRange regions,
2843     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2844   ClampOp::Adaptor adaptor(operands, attributes, regions);
2845   RankedTensorType operandType =
2846       adaptor.operand().getType().cast<RankedTensorType>();
2847   inferredReturnShapes.emplace_back(operandType.getShape(),
2848                                     operandType.getElementType());
2849   return success();
2850 }
2851 
2852 LogicalResult ClampOp::reifyReturnTypeShapes(
2853     OpBuilder& builder, ValueRange operands,
2854     SmallVectorImpl<Value>& reifiedReturnShapes) {
2855   // For `stablehlo.clamp`, the first operand may be a scalar.
2856   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1],
2857                                      &reifiedReturnShapes);
2858 }
2859 
2860 //===----------------------------------------------------------------------===//
2861 // ComplexOp
2862 //===----------------------------------------------------------------------===//
2863 
2864 LogicalResult ComplexOp::inferReturnTypes(
2865     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2866     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
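  // For illustration only (example shapes): real and imaginary operands of
  // type tensor<4xf32> infer a result of type tensor<4xcomplex<f32>>.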
2867   TensorType operandType = operands[0].getType().cast<TensorType>();
2868   ComplexType elementTy = ComplexType::get(operandType.getElementType());
2869   inferredReturnTypes.push_back(
2870       hlo::getSameShapeTensorType(operandType, elementTy));
2871   return success();
2872 }
2873 
2874 //===----------------------------------------------------------------------===//
2875 // ImagOp
2876 //===----------------------------------------------------------------------===//
2877 
2878 namespace {
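// Maps a tensor type to its real counterpart. For illustration only (example
// shapes): tensor<4xcomplex<f32>> becomes tensor<4xf32>, while a tensor that
// is already real, e.g. tensor<4xf32>, is returned unchanged.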
2879 Type createRealType(TensorType type) {
2880   auto elementTy = type.getElementType();
2881   if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
2882     elementTy = complexTy.getElementType();
2883   }
2884   return hlo::getSameShapeTensorType(type, elementTy);
2885 }
2886 }  // namespace
2887 
2888 LogicalResult ImagOp::inferReturnTypes(
2889     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2890     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2891   inferredReturnTypes.push_back(
2892       createRealType(operands[0].getType().cast<TensorType>()));
2893   return success();
2894 }
2895 
2896 //===----------------------------------------------------------------------===//
2897 // IsFiniteOp
2898 //===----------------------------------------------------------------------===//
2899 
2900 LogicalResult IsFiniteOp::inferReturnTypes(
2901     MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
2902     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2903   auto argTy = operands.front().getType().cast<TensorType>();
2904   Builder b(ctx);
2905   inferredReturnTypes.push_back(
2906       hlo::getSameShapeTensorType(argTy, b.getI1Type()));
2907   return success();
2908 }
2909 
2910 //===----------------------------------------------------------------------===//
2911 // RealOp
2912 //===----------------------------------------------------------------------===//
2913 
2914 LogicalResult RealOp::inferReturnTypes(
2915     MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2916     RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2917   inferredReturnTypes.push_back(
2918       createRealType(operands[0].getType().cast<TensorType>()));
2919   return success();
2920 }
2921 
2922 //===----------------------------------------------------------------------===//
2923 // ConcatenateOp
2924 //===----------------------------------------------------------------------===//
2925 
2926 LogicalResult ConcatenateOp::inferReturnTypes(
2927     MLIRContext*, Optional<Location> location, ValueRange operands,
2928     DictionaryAttr attributes, RegionRange regions,
2929     SmallVectorImpl<Type>& inferredReturnTypes) {
2930   if (operands.empty()) {
2931     return failure();
2932   }
2933 
2934   auto dimensionAttr = attributes.get("dimension").cast<IntegerAttr>();
2935   auto dimension = dimensionAttr.getInt();
2936 
2937   auto firstType = (*operands.begin()).getType().cast<ShapedType>();
2938   auto outElement = firstType.getElementType();
2939 
2940   // Find the first ranked input to determine the output rank.
2941   for (auto type : operands.getTypes()) {
2942     auto shapedType = type.cast<ShapedType>();
2943     if (shapedType.hasRank()) {
2944       firstType = shapedType;
2945       break;
2946     }
2947   }
2948 
2949   // If all inputs are unranked, the result must be unranked.
2950   if (!firstType.hasRank()) {
2951     inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
2952     return success();
2953   }
2954 
2955   auto outShape = llvm::to_vector<6>(firstType.getShape());
2956 
2957   // Determine what the non-concatenate dimensions should be.
2958   for (auto type : operands.getTypes()) {
2959     auto shapedTy = type.cast<ShapedType>();
2960     if (!shapedTy.hasRank()) {
2961       continue;
2962     }
2963 
2964     for (const auto& it : llvm::enumerate(shapedTy.getShape())) {
2965       // If the dimension inferred so far is dynamic, take this operand's size.
2966       if (ShapedType::isDynamic(outShape[it.index()])) {
2967         outShape[it.index()] = it.value();
2968       }
2969     }
2970   }
2971 
2972   outShape[dimension] = 0;
2973 
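  // For illustration only (example shapes): concatenating tensor<2x3xf32> and
  // tensor<2x5xf32> with dimension = 1 accumulates 3 + 5 below, inferring
  // tensor<2x8xf32>; a dynamic size in any operand makes the concatenated
  // dimension dynamic.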
2974   for (auto operand : operands.getTypes()) {
2975     auto type = operand.cast<ShapedType>();
2976     if (!type.hasRank()) {
2977       inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
2978       return success();
2979     }
2980 
2981     // If the dimension is dynamic we know the output dimension is dynamic.
2982     auto dim = type.getShape()[dimension];
2983     if (ShapedType::isDynamic(dim)) {
2984       outShape[dimension] = ShapedType::kDynamicSize;
2985       break;
2986     }
2987 
2988     outShape[dimension] += dim;
2989   }
2990 
2991   inferredReturnTypes.push_back(RankedTensorType::get(outShape, outElement));
2992 
2993   return success();
2994 }
2995 
2996 LogicalResult ConcatenateOp::verify() {
2997   RankedTensorType firstRankedType;
2998   int firstRankedIndex;
2999   int numOperands = getNumOperands();
3000   int64_t concatDimension = static_cast<int64_t>(dimension());
3001   if (concatDimension < 0) {
3002     return emitOpError(
3003         llvm::formatv("dimension {0} is negative", concatDimension));
3004   }
3005   for (int i = 0; i < numOperands; i++) {
3006     auto secondType = getOperand(i).getType().dyn_cast<ShapedType>();
3007     if (!secondType.hasRank()) {
3008       continue;
3009     }
3010 
3011     if (!firstRankedType) {
3012       firstRankedType = secondType.cast<RankedTensorType>();
3013       firstRankedIndex = i;
3014       if (firstRankedType.getRank() == 0)
3015         return emitOpError(
3016             llvm::formatv("rank-0 values cannot be concatenated"));
3017       if (concatDimension >= firstRankedType.getRank()) {
3018         return emitOpError(
3019             llvm::formatv("dimension {0} is out-of-bounds for input rank {1}",
3020                           concatDimension, firstRankedType.getRank()));
3021       }
3022       continue;
3023     }
3024 
3025     if (firstRankedType.getRank() != secondType.getRank()) {
3026       return emitOpError(llvm::formatv(
3027           "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i));
3028     }
3029 
3030     auto firstShape = firstRankedType.getShape();
3031     auto secondShape = secondType.getShape();
3032     for (int d = 0; d < firstRankedType.getRank(); ++d) {
3033       if (!ShapedType::isDynamic(firstShape[d]) &&
3034           !ShapedType::isDynamic(secondShape[d]) &&
3035           firstShape[d] != secondShape[d] && d != concatDimension) {
3036         return emitOpError(llvm::formatv(
3037             "shapes of operand ({0}) and ({1}) do not match at non-concat "
3038             "index: ({2}) != ({3}) at non-concat index {4}",
3039             firstRankedIndex, i,
3040             llvm::make_range(firstShape.begin(), firstShape.end()),
3041             llvm::make_range(secondShape.begin(), secondShape.end()), d));
3042       }
3043     }
3044   }
3045   return success();
3046 }
3047 
3048 LogicalResult ConcatenateOp::reifyReturnTypeShapes(
3049     OpBuilder& builder, ValueRange operands,
3050     SmallVectorImpl<Value>& reifiedReturnShapes) {
3051   ConcatenateOp::Adaptor adaptor(operands);
3052   auto inputs = adaptor.val();
3053 
3054   auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
3055   // Unranked types are not supported at the moment.
3056   if (!operandType) return failure();
3057 
3058   Location loc = this->getLoc();
3059   Type shapeScalarType = builder.getIndexType();
3060   auto toShapeScalarType = [&](Value v) {
3061     return maybeCastTo(builder, loc, v, shapeScalarType);
3062   };
3063 
3064   SmallVector<SmallVector<Value, 4>, 4> allShapeValues;
3065   for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
3066     Value operand = inputs[inputId];
3067     auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3068     if (!operandType) return failure();
3069 
3070     SmallVector<Value, 4> shapeVals;
3071     for (const auto& element : llvm::enumerate(operandType.getShape())) {
3072       Value valueDim = toShapeScalarType(
3073           builder.create<tensor::DimOp>(loc, operand, element.index()));
3074       shapeVals.push_back(valueDim);
3075     }
3076     allShapeValues.emplace_back(std::move(shapeVals));
3077   }
3078 
3079   int axis = this->dimension();
3080   auto& shapeValues = allShapeValues[0];
3081   for (size_t vecId = 1; vecId < allShapeValues.size(); ++vecId) {
3082     auto& otherShapeValues = allShapeValues[vecId];
3083     if (otherShapeValues.size() != shapeValues.size()) {
3084       this->emitOpError()
3085           << "Concatenate expects all operands to be of the same rank";
3086       return failure();
3087     }
3088     shapeValues[axis] = builder.create<arith::AddIOp>(loc, shapeValues[axis],
3089                                                       otherShapeValues[axis]);
3090   }
3091 
3092   Value outputShape = builder.create<tensor::FromElementsOp>(
3093       loc,
3094       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3095                             shapeScalarType),
3096       shapeValues);
3097   reifiedReturnShapes.push_back(outputShape);
3098 
3099   return success();
3100 }
3101 
3102 //===----------------------------------------------------------------------===//
3103 // DynamicReshapeOp
3104 //===----------------------------------------------------------------------===//
3105 
3106 LogicalResult DynamicReshapeOp::verify() {
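  // For illustration only (example shapes): a result of type tensor<2x3xf32>
  // requires an output_shape operand with exactly two elements, e.g. a
  // tensor<2xindex> value.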
3107   auto resultType = result().getType().dyn_cast<RankedTensorType>();
3108   auto outputShapeType = output_shape().getType().dyn_cast<RankedTensorType>();
3109   if (resultType && outputShapeType && outputShapeType.hasStaticShape() &&
3110       outputShapeType.getDimSize(0) != resultType.getRank()) {
3111     return emitError() << "output should have a rank equal to the number of "
3112                           "elements in output_shape";
3113   }
3114   return success();
3115 }
3116 
3117 LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
3118     OpBuilder& builder, ValueRange operands,
3119     SmallVectorImpl<Value>& reifiedReturnShapes) {
3120   DynamicReshapeOp::Adaptor adaptor(operands);
3121   reifiedReturnShapes.push_back(
3122       castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
3123   return success();
3124 }
3125 
3126 //===----------------------------------------------------------------------===//
3127 // DynamicSliceOp
3128 //===----------------------------------------------------------------------===//
3129 
3130 // Verifies that the number of slice sizes and the number of start indices match
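// For illustration only (example shapes): slicing an operand of type
// tensor<8x8xf32> requires two scalar start indices and a slice_sizes
// attribute with two entries, each between 0 and the corresponding dimension
// size, e.g. dense<[4, 8]>.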
3131 LogicalResult DynamicSliceOp::verify() {
3132   int numSliceSizes = slice_sizes().getNumElements();
3133   int numStartIndices = start_indices().size();
3134   if (numStartIndices != numSliceSizes) {
3135     return emitOpError() << "has mismatched number of slice sizes ("
3136                          << numSliceSizes << ") and number of start indices ("
3137                          << numStartIndices << ")";
3138   }
3139   auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3140   if (!operandType) return failure();
3141 
3142   if (operandType.getRank() != numStartIndices) {
3143     return emitOpError() << "has mismatched number of start indices ("
3144                          << numStartIndices << ") and the rank of operand ("
3145                          << operandType.getRank() << ")";
3146   }
3147 
3148   for (int i = 0; i < numSliceSizes; ++i) {
3149     int64_t sliceSize = slice_sizes().getValues<int64_t>()[i];
3150     if (sliceSize < 0) {
3151       return emitOpError() << "has negative size index to dynamic slice: "
3152                            << sliceSize;
3153     }
3154     if (!operandType.isDynamicDim(i)) {
3155       int64_t dimSize = operandType.getDimSize(i);
3156       if (sliceSize > dimSize) {
3157         return emitOpError() << "has slice size " << sliceSize
3158                              << " greater than dimension size " << dimSize
3159                              << " in dimension " << i << " of operand";
3160       }
3161     }
3162   }
3163   return success();
3164 }
3165 
3166 LogicalResult DynamicSliceOp::inferReturnTypeComponents(
3167     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
3168     DictionaryAttr attributes, RegionRange regions,
3169     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3170   DynamicSliceOp::Adaptor adaptor(operands, attributes, regions);
3171   Value operand = adaptor.operand();
3172   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3173   if (!operandType) return failure();
3174 
3175   auto sliceSizes = adaptor.slice_sizes();
3176   Type elementTy = operandType.getElementType();
3177   inferredReturnShapes.emplace_back(sliceSizes.getValues<int64_t>(), elementTy);
3178   return success();
3179 }
3180 
3181 //===----------------------------------------------------------------------===//
3182 // RealDynamicSliceOp
3183 //===----------------------------------------------------------------------===//
3184 // Verifies that operand rank matches start_indices/limit_indices/strides size
3185 LogicalResult RealDynamicSliceOp::verify() {
3186   auto inputType = operand().getType().dyn_cast<RankedTensorType>();
3187   // If operand is unranked, there is very little to verify statically.
3188   if (!inputType) return success();
3189   int inputRank = inputType.getRank();
3190 
3191   auto startType = start_indices().getType().cast<RankedTensorType>();
3192   auto limitType = limit_indices().getType().cast<RankedTensorType>();
3193   auto stridesType = strides().getType().cast<RankedTensorType>();
3194 
3195   if (inputRank != startType.getNumElements()) {
3196     return emitOpError() << "has mismatched number of operand rank ("
3197                          << inputRank << ") and start_indices size ("
3198                          << startType.getNumElements() << ")";
3199   }
3200 
3201   if (inputRank != limitType.getNumElements()) {
3202     return emitOpError() << "has mismatched number of operand rank ("
3203                          << inputRank << ") and limit_indices size ("
3204                          << limitType.getNumElements() << ")";
3205   }
3206 
3207   if (inputRank != stridesType.getNumElements()) {
3208     return emitOpError() << "has mismatched number of operand rank ("
3209                          << inputRank << ") and strides size ("
3210                          << stridesType.getNumElements() << ")";
3211   }
3212 
3213   return success();
3214 }
3215 
3216 LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
3217     OpBuilder& builder, ValueRange operands,
3218     SmallVectorImpl<Value>& reifiedReturnShapes) {
3219   RealDynamicSliceOp::Adaptor adaptor(operands);
3220   Value operand = adaptor.operand();
3221   Value startIndices = adaptor.start_indices();
3222   Value limitIndices = adaptor.limit_indices();
3223   Value strides = adaptor.strides();
3224 
3225   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3226   // Unranked types are not supported at the moment.
3227   if (!operandType) return failure();
3228 
3229   Location loc = this->getLoc();
3230   SmallVector<Value, 4> shapeValues;
3231   shapeValues.reserve(operandType.getRank());
3232   Type shapeScalarType =
3233       startIndices.getType().cast<ShapedType>().getElementType();
3234   Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
3235   one = maybeCastTo(builder, loc, one, shapeScalarType);
3236   for (const auto& element : llvm::enumerate(operandType.getShape())) {
3237     Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index());
3238     Value valueStart =
3239         builder.create<tensor::ExtractOp>(loc, startIndices, offset);
3240     Value valueLimit =
3241         builder.create<tensor::ExtractOp>(loc, limitIndices, offset);
3242     Value valueStride = builder.create<tensor::ExtractOp>(loc, strides, offset);
3243     // size = (limit - start + stride - 1) / stride
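    // For illustration only (example values): start = 2, limit = 10,
    // stride = 3 gives (10 - 2 + 3 - 1) / 3 = 3, i.e. the elements at
    // indices 2, 5 and 8.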
3244     shapeValues.push_back(builder.create<arith::DivSIOp>(
3245         loc,
3246         builder.create<arith::SubIOp>(
3247             loc,
3248             builder.create<arith::AddIOp>(
3249                 loc, valueStride,
3250                 builder.create<arith::SubIOp>(loc, valueLimit, valueStart)),
3251             one),
3252         valueStride));
3253   }
3254 
3255   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
3256       loc,
3257       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
3258                             shapeScalarType),
3259       shapeValues));
3260   return success();
3261 }
3262 
3263 //===----------------------------------------------------------------------===//
3264 // InfeedOp
3265 //===----------------------------------------------------------------------===//
3266 
3267 // Checks that the result type is of the form `zero_or_more_type(s),
3268 // stablehlo::token`
3269 LogicalResult InfeedOp::verify() {
3270   auto resultTypes = getResultTypes();
3271   if (resultTypes.empty())
3272     return emitOpError()
3273            << "result is expected to be at least of size 1, but got "
3274            << resultTypes.size();
3275 
3276   if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
3277     return emitOpError() << "last element of result types is expected to "
3278                             "be of token type, but got "
3279                          << resultTypes[resultTypes.size() - 1];
3280 
3281   // Verify layout attribute
3282   constexpr char kLayoutAttr[] = "layout";
3283   if (!getOperation()->hasAttr(kLayoutAttr)) return success();
3284 
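  // For illustration only (hypothetical attribute, syntax abbreviated): for an
  // infeed with two data results plus the trailing token, a well-formed
  // attribute could look like layout = [[1, 0], [0]] -- one inner array of
  // integers per non-token result.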
3285   mlir::ArrayAttr layout =
3286       getOperation()->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
3287   if (!layout)
3288     return emitOpError() << "layout-attribute expected to be of array-type.";
3289 
3290   if (layout.size() != resultTypes.size() - 1) {
3291     return emitOpError() << "layout-attribute size must be "
3292                          << resultTypes.size() - 1
3293                          << " (which is the number of "
3294                             "op-results - 1 (for token result)), but got "
3295                          << layout.size();
3296   }
3297 
3298   for (auto childLayout : layout) {
3299     mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>();
3300     if (!childLayoutArr) {
3301       return emitOpError() << "layout-attribute expected to have "
3302                               "elements of type array, but got "
3303                            << childLayout;
3304     }
3305 
3306     for (auto i : childLayoutArr) {
3307       mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>();
3308       if (!attr) {
3309         return emitOpError() << "layout-attribute's leaf elements are "
3310                                 "expected to be of type integer, but got "
3311                              << i;
3312       }
3313     }
3314   }
3315 
3316   return success();
3317 }
3318 
3319 //===----------------------------------------------------------------------===//
3320 // MapOp
3321 //===----------------------------------------------------------------------===//
3322 
3323 LogicalResult MapOp::verify() {
3324   // Checks that the number of `operands` matches the arity of the map `computation`
3325   // region.
3326   auto& computationBlock = computation().front();
3327   auto computationArgs = computationBlock.getArguments();
3328   if (operands().size() != computationArgs.size())
3329     return emitOpError() << "expects number of operands to match the arity "
3330                             "of map computation, but got: "
3331                          << operands().size() << " and "
3332                          << computationArgs.size();
3333 
3334   // The parameters of computation should all be scalars and match the element
3335   // type of operands.
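  // For illustration only (example shapes, region elided): mapping two
  // tensor<4xf32> operands with dimensions = [0] requires a computation block
  // with two tensor<f32> arguments and a single tensor<f32> return value.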
3336   for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
3337     auto argType = indexedArg.value().getType().dyn_cast<TensorType>();
3338     if (!argType || argType.getRank() != 0)
3339       return emitOpError()
3340              << "computation arguments must be 0-rank tensor, but got: arg #"
3341              << indexedArg.index() << " of type "
3342              << indexedArg.value().getType();
3343     auto operandElemTy = operands()[indexedArg.index()]
3344                              .getType()
3345                              .cast<TensorType>()
3346                              .getElementType();
3347     if (argType.getElementType() != operandElemTy) {
3348       return emitOpError()
3349              << "element type of operands and computation arguments must "
3350                 "match, but got: "
3351              << operandElemTy << " and " << argType.getElementType();
3352     }
3353   }
3354 
3355   // The mapped computation must return a single output.
3356   auto computationOutputs = computationBlock.getTerminator()->getOperands();
3357   if (computationOutputs.size() != 1)
3358     return emitOpError() << "computation must return single output, but got: "
3359                          << computationOutputs.size();
3360 
3361   // The output of computation must be scalar and have the same element type
3362   // as op result.
3363   auto computationOutputType =
3364       computationOutputs[0].getType().dyn_cast<TensorType>();
3365   if (!computationOutputType || computationOutputType.getRank() != 0)
3366     return emitOpError() << "computation must return 0-rank tensor, but got: "
3367                          << computationOutputs[0].getType();
3368 
3369   auto resultType = getType().cast<TensorType>();
3370   if (computationOutputType.getElementType() != resultType.getElementType())
3371     return emitOpError() << "element type of result and computation output "
3372                             "must match, but got: "
3373                          << resultType.getElementType() << " and "
3374                          << computationOutputType.getElementType();
3375 
3376   // Checks that the requested map dimension numbers are monotonically
3377   // increasing.
3378   DenseIntElementsAttr dimensions = this->dimensions();
3379   for (const auto& indexedValue :
3380        llvm::enumerate(dimensions.getValues<int64_t>())) {
3381     if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
3382       return emitOpError() << "requires monotonically increasing dimension "
3383                               "numbers, but got: "
3384                            << dimensions;
3385   }
3386 
3387   // Checks that number of dimensions of operands matches the size of
3388   // `dimensions` since we currently only support mapping across all
3389   // dimensions: i.e., scalar map functions.
3390   auto operandType = operands()[0].getType().cast<TensorType>();
3391   if (operandType.hasRank()) {
3392     if (dimensions.size() !=
3393         static_cast<int64_t>(operandType.getShape().size()))
3394       return emitOpError()
3395              << "applied to a subset of dimensions currently not supported: "
3396                 "operand dimensions = "
3397              << operandType.getShape().size()
3398              << ", requested map dimensions size = " << dimensions.size();
3399   }
3400 
3401   return success();
3402 }
3403 
3404 LogicalResult MapOp::reifyReturnTypeShapes(
3405     OpBuilder& builder, ValueRange operands,
3406     SmallVectorImpl<Value>& reifiedReturnShapes) {
3407   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
3408                                      &reifiedReturnShapes);
3409 }
3410 
3411 //===----------------------------------------------------------------------===//
3412 // RecvOp
3413 //===----------------------------------------------------------------------===//
3414 
3415 // Checks that the result type is of the form `zero_or_more_type(s),
3416 // stablehlo::token`
3417 LogicalResult RecvOp::verify() {
3418   auto resultTypes = getResultTypes();
3419   if (resultTypes.empty())
3420     return emitOpError()
3421            << "result is expected to be at least of size 1, but got "
3422            << resultTypes.size();
3423   if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
3424     return emitOpError() << "last element of result types is expected to "
3425                             "be of token type, but got "
3426                          << resultTypes[resultTypes.size() - 1];
3427   return success();
3428 }
3429 
3430 //===----------------------------------------------------------------------===//
3431 // ReduceWindowOp
3432 //===----------------------------------------------------------------------===//
3433 
3434 namespace {
3435 // Infer the return-type of ReduceWindowOp.
3436 SmallVector<TensorType> inferReduceWindowOpReturnType(
3437     ArrayRef<TensorType> inputTypes, ArrayRef<TensorType> initTypes,
3438     const ArrayRef<WindowDimension> window) {
3439   SmallVector<TensorType> outputTypes;
3440   for (size_t i = 0; i < inputTypes.size(); ++i) {
3441     if (!inputTypes[i].hasRank()) {
3442       outputTypes.push_back(
3443           UnrankedTensorType::get(initTypes[i].getElementType()));
3444       continue;
3445     }
3446 
3447     outputTypes.push_back(RankedTensorType::get(
3448         inferWindowOutputShape(inputTypes[i].getShape(), window),
3449         initTypes[i].getElementType()));
3450   }
3451 
3452   return outputTypes;
3453 }
3454 }  // namespace
3455 
3456 // We intend to verify the following properties
3457 //  P1. The sizes of 'inputs' and 'init_values' must be at least 1.
3458 //  P2. All `inputs` need to have compatible shapes.
3459 //  P3. size-of(window_dimension) == rank-of(input),
3460 //        where input is an element of 'inputs'.
3461 //  P4. Verify and collect the window attributes.
3462 //  P5. Verify the inner block defining the reducer function.
3463 //  P6. Verify the return type.
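// For illustration only (example shapes, other attributes defaulted): reducing
// a tensor<4x6xf32> input with window_dimensions = [2, 3], unit strides, no
// padding and no dilation yields a tensor<3x4xf32> result, since each output
// dimension is input dim - window dim + 1 in that configuration.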
3464 LogicalResult ReduceWindowOp::verify() {
3465   // P1.
3466   // Note that the ODS ensures that there is an even number of operands; check
3467   // that this number is not zero.
3468   if (getOperands().empty())
3469     return emitOpError() << "expects the size of operands to be >= 2.";
3470 
3471   // Collect the input and init-value operands. Note that the operand-type is
3472   // enforced as "TensorType" by ODS.
3473   int64_t numInputs = getNumOperands() / 2;
3474   auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
3475       getOperandTypes(),
3476       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
3477   ArrayRef<TensorType> inputTypes(operandTensorTypes.begin(),
3478                                   operandTensorTypes.begin() + numInputs);
3479   ArrayRef<TensorType> initTypes(operandTensorTypes.begin() + numInputs,
3480                                  operandTensorTypes.end());
3481 
3482   // P2.
3483   if (failed(verifyCompatibleShapes(operands().getTypes())))
3484     return emitOpError() << "requires same shape for all inputs";
3485 
3486   // P3.
3487   SmallVector<int64_t> windowDims =
3488       convertDenseIntAttr(this->window_dimensions());
3489   for (const auto inputType : inputTypes) {
3490     if (!inputType.hasRank()) continue;
3491     if (inputType.getRank() != static_cast<int64_t>(windowDims.size()))
3492       return emitOpError()
3493              << "expects window-dimensions size == input rank, but got "
3494                 "window-dimensions size: "
3495              << windowDims.size() << " and input: " << inputType
3496              << " with rank = " << inputType.getRank() << ".";
3497   }
3498 
3499   // P4.
3500   auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
3501   if (failed(paddingOrErr)) return failure();
3502   SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
3503 
3504   auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
3505       windowDims, convertDenseIntAttr(window_strides()), padding,
3506       /*lhs_dilation=*/convertDenseIntAttr(base_dilations()),
3507       /*rhs_dilation=*/convertDenseIntAttr(this->window_dilations()), getLoc());
3508   if (failed(windowOrErr)) return failure();
3509 
3510   // P5.
3511   bool allInputsUnranked =
3512       llvm::all_of(inputTypes, [](TensorType t) { return !t.hasRank(); });
3513 
3514   Block& block = body().front();
3515   SmallVector<TensorType> accumulatorSubshapes;
3516   if (failed(verifyReducerShape(this->getLoc(), block, inputTypes, initTypes,
3517                                 numInputs, windowDims, allInputsUnranked,
3518                                 accumulatorSubshapes)))
3519     return failure();
3520 
3521   // P6.
3522   if (numInputs != getNumResults())
3523     return emitOpError() << "expects " << numInputs
3524                          << " result values, but got " << getNumResults()
3525                          << ".";
3526 
3527   // The result-type is enforced as "TensorType" by ODS.
3528   auto resultTensorTypes = llvm::to_vector<4>(llvm::map_range(
3529       getResultTypes(),
3530       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
3531 
3532   // Check that the element-types of the results match the ones derived from
3533   // the reducer-block. We have already ensured that |accumulator_subshapes| ==
3534   // num_inputs == num_of_results.
3535   for (int64_t shapeIdx = 0;
3536        shapeIdx < static_cast<int64_t>(accumulatorSubshapes.size());
3537        shapeIdx++) {
3538     if (accumulatorSubshapes[shapeIdx].getElementType() !=
3539         resultTensorTypes[shapeIdx].getElementType()) {
3540       return emitError()
3541              << "expects the element-type of reduce-op's return-value at index "
3542              << shapeIdx
3543              << " to match the element-type of reducer-block's "
3544                 "corresponding return-value, but got "
3545              << resultTensorTypes[shapeIdx].getElementType() << " and "
3546              << accumulatorSubshapes[shapeIdx].getElementType() << " resp.";
3547     }
3548   }
3549 
3550   // Check that the shapes of the results match the ones derived from
3551   // the input-types and window-attributes.
3552   auto inferredReturnTypes = inferReduceWindowOpReturnType(
3553       inputTypes, accumulatorSubshapes, *windowOrErr);
3554 
3555   for (size_t i = 0; i < getNumResults(); i++) {
3556     if (failed(verifyCompatibleShape(resultTensorTypes[i],
3557                                      inferredReturnTypes[i]))) {
3558       return emitOpError()
3559              << "expects result at index " << i
3560              << " to have compatible shape with the corresponding "
3561                 "inferred type, but got "
3562              << resultTensorTypes[i] << " and " << inferredReturnTypes[i]
3563              << " resp.";
3564     }
3565   }
3566 
3567   return success();
3568 }
3569 
3570 // Get the operation used for the reduction applied to the `result_index`th
3571 // result. It is expected to be a binary operation that consumes the
3572 // `result_index`th and `result_index + operands().size`th arguments of the body.
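//
// For illustration, a hypothetical reducer body with two results (not from the
// original source) might look like:
//   ^bb0(%a0: tensor<f32>, %a1: tensor<i32>, %a2: tensor<f32>, %a3: tensor<i32>):
//     %0 = stablehlo.add %a0, %a2 : tensor<f32>
//     %1 = stablehlo.maximum %a1, %a3 : tensor<i32>
//     "stablehlo.return"(%0, %1) : (tensor<f32>, tensor<i32>) -> ()
// Here getReductionOp(0) would return the add and getReductionOp(1) the maximum,
// since each consumes block arguments `i` and `i + operands().size`.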
3573 Operation* ReduceWindowOp::getReductionOp(int resultIndex) {
3574   auto returnOp = cast<ReturnOp>(body().front().getTerminator());
3575   Operation* computeOp = returnOp.results()[resultIndex].getDefiningOp();
3576   if (computeOp->getNumOperands() != 2) return nullptr;
3577   auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
3578   auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
3579   if (!arg0 || !arg1) return nullptr;
3580   int64_t arg0Num = arg0.getArgNumber();
3581   int64_t arg1Num = arg1.getArgNumber();
3582   int64_t otherArgIndex = resultIndex + operands().size();
3583   if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
3584   if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
3585       computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
3586     return computeOp;
3587   return nullptr;
3588 }
3589 
3590 //===----------------------------------------------------------------------===//
3591 // ReducePrecisionOp
3592 //===----------------------------------------------------------------------===//
3593 
3594 // The following properties are already enforced by the ODS:
3595 //  P0. operand element type is float
3596 //  P1. mantissa_bits >= 0
3597 // We intend to verify the following properties
3598 //  P2. exponent_bits >= 1
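//
// For example, exponent_bits = 0 is rejected, whereas a combination such as
// (exponent_bits = 8, mantissa_bits = 23) roughly models rounding to f32
// precision (illustrative values, not from the original source).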
3599 LogicalResult ReducePrecisionOp::verify() {
3600   if (exponent_bits() < 1) {
3601     return emitOpError() << "exponent_bits must be at least 1.";
3602   }
3603   return success();
3604 }
3605 
3606 //===----------------------------------------------------------------------===//
3607 // ReduceOp
3608 //===----------------------------------------------------------------------===//
3609 
3610 // Returns the result type after reducing operand of the given type across the
3611 // specified dimensions.
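// For example (hypothetical types): reducing a tensor<4x8x16xf32> operand across
// dimensions [0, 2] yields tensor<8xf32>, while an unranked operand keeps an
// unranked result type.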
3612 static TensorType getReduceResultType(Type operandTy,
3613                                       DenseIntElementsAttr dimensions,
3614                                       Builder* builder) {
3615   Type elementTy = getElementTypeOrSelf(operandTy);
3616 
3617   auto rankedTy = operandTy.dyn_cast<RankedTensorType>();
3618   if (!rankedTy) return UnrankedTensorType::get(elementTy);
3619 
3620   int64_t rank = rankedTy.getRank();
3621   llvm::SmallVector<bool, 4> dimsMask(rank, false);
3622   for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true;
3623 
3624   SmallVector<int64_t, 4> shape;
3625   for (int64_t i = 0; i < rank; ++i) {
3626     if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i));
3627   }
3628 
3629   return RankedTensorType::get(shape, elementTy);
3630 }
3631 
3632 void ReduceOp::build(OpBuilder& builder, OperationState& state,
3633                      ValueRange operands, ValueRange initValues,
3634                      DenseIntElementsAttr dimensions) {
3635   SmallVector<Type, 1> resultTy;
3636   resultTy.reserve(operands.size());
3637 
3638   for (Value operand : operands) {
3639     resultTy.push_back(
3640         getReduceResultType(operand.getType(), dimensions, &builder));
3641   }
3642   build(builder, state, resultTy, operands, initValues, dimensions);
3643 }
3644 
3645 bool hasSameOperandAndResultTypes(Operation& op) {
3646   Type expected;
3647   if (op.getNumResults() != 0) expected = op.getResult(0).getType();
3648   if (op.getNumOperands() != 0) expected = op.getOperand(0).getType();
3649   if (!expected) return false;
3650 
3651   auto typeMatch = [&](Type actual) { return actual == expected; };
3652   return llvm::all_of(op.getOperandTypes(), typeMatch) &&
3653          llvm::all_of(op.getResultTypes(), typeMatch);
3654 }
3655 
3656 // Checks the following eligibility criteria for compact printing of reduce:
3657 // E1. The reduce-op wraps a single inner-op in the associated region.
3658 // E2. The single operation is a commutative binary-op from the dialect, zero
3659 //     region, producing single result such that the operands and result all
3660 //     have the same type.
3661 // E3. The reduce-op consists of at least one input-operand; the operand-types
3662 //     of the inner-op should be derived trivially from the element-type of
3663 //     reduce-op's first input-operand.
3664 // E4. The arguments of the region's only basic block are forwarded perfectly
3665 //     to inner-op's operands.
3666 // E5. The reduce-op, inner-op, block arguments, and the return-op all have the
3667 //     same location.
3668 // E6. The single operation result is perfectly forwarded to the reduce op
3669 //     return.
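//
// An op meeting E1-E6 prints compactly as, e.g. (hypothetical operands and
// types, not from the original source):
//   %0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.add
//          across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>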
3670 static bool isEligibleForCompactPrint(ReduceOp op) {
3671   // Check E1.
3672   auto& block = op.body().front();
3673   if (!hasSingleElement(block.without_terminator())) return false;
3674 
3675   Operation& innerOp = *block.begin();
3676 
3677   // Check E2.
3678   if (innerOp.getDialect() != op->getDialect()) return false;
3679 
3680   if (innerOp.getNumOperands() != 2 ||
3681       !innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
3682       !hasSameOperandAndResultTypes(innerOp) ||
3683       !innerOp.hasTrait<mlir::OpTrait::IsCommutative>() ||
3684       !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
3685     return false;
3686 
3687   // Check E3.
3688   if (op.operands().empty()) return false;
3689 
3690   auto elemType =
3691       op.operands()[0].getType().cast<TensorType>().getElementType();
3692   auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
3693   if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false;
3694 
3695   // Check E4.
3696   if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false;
3697 
3698   // Check E5.
3699   auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
3700   if (!retOp) return false;
3701 
3702   auto blockArgLoc = block.getArgument(0).getLoc();
3703   if (blockArgLoc != block.getArgument(1).getLoc()) return false;
3704 
3705   if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() ||
3706       blockArgLoc != op.getLoc())
3707     return false;
3708 
3709   // Check E6.
3710   return llvm::equal(innerOp.getResults(), retOp.getOperands());
3711 }
3712 
3713 void ReduceOp::print(OpAsmPrinter& p) {
3714   {
3715     // Print the pairs of operands under the form:
3716     //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
3717     StringRef comma = "";
3718     int numOperandPairs = getNumOperands() / 2;
3719     for (int opId : llvm::seq<int>(0, numOperandPairs)) {
3720       p << comma << "(" << getOperand(opId)
3721         << " init: " << getOperand(opId + numOperandPairs) << ")";
3722       comma = ", ";
3723     }
3724   }
3725 
3726   // If the reduce-op is eligible for compact printing, we emit the one-liner:
3727   // stablehlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
3728   // Note: We are not printing the function type of the inner reduction
3729   // operation. We rely on some simplifying assumptions (refer to
3730   // isEligibleForCompactPrint::E3) to derive its type from that of the reduce-op.
3731   if (isEligibleForCompactPrint(*this)) {
3732     Operation& innerOp = body().front().front();
3733     p << " applies ";
3734     printEscapedString(innerOp.getName().getStringRef(), p.getStream());
3735 
3736     p << " across dimensions = [";
3737     llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
3738     p << "]";
3739     p << " : ";
3740     p.printFunctionalType(*this);
3741   } else {
3742     p << " across dimensions = [";
3743     llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
3744     p << "]";
3745     p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
3746     p << " : ";
3747     p.printFunctionalType(*this);
3748     p.printNewline();
3749     p << " reducer";
3750     {
3751       // Print the pairs of block operands under the form:
3752       //   (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc):
3753       Block& reducer = body().front();
3754       int numOperandPairs = getNumOperands() / 2;
3755       for (int opId : llvm::seq<int>(0, numOperandPairs)) {
3756         p << "(";
3757         p.printRegionArgument(reducer.getArgument(opId));
3758         p << ", ";
3759         p.printRegionArgument(reducer.getArgument(opId + numOperandPairs));
3760         p << ") ";
3761       }
3762     }
3763     p << ' ';
3764     p.printRegion(body(), /*printEntryBlockArgs=*/false);
3765   }
3766 }
3767 
3768 ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) {
3769   llvm::SMLoc loc = parser.getCurrentLocation();
3770   Location currLocation = parser.getEncodedSourceLoc(loc);
3771 
3772   // Parse the operands of the reduce-op; this is a list of pairs of the form:
3773   //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
3774   // Each input to reduce is paired with its init value, even though in memory
3775   // they are stored with the input first and the init values after.
3776   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
3777   SmallVector<OpAsmParser::UnresolvedOperand, 2> initOperands;
3778   do {
3779     (void)parser.parseOptionalComma();
3780     if (parser.parseOptionalLParen()) break;
3781     OpAsmParser::UnresolvedOperand operand, initOperand;
3782     if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
3783         parser.parseColon() || parser.parseOperand(initOperand) ||
3784         parser.parseRParen())
3785       return failure();
3786     operands.push_back(operand);
3787     initOperands.push_back(initOperand);
3788   } while (true);
3789   operands.append(initOperands);
3790 
3791   // Check if we are parsing the compact version of reduce-op:
3792   // stablehlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
3793   // else parse the "region-based" variant.
3794   if (failed(parser.parseOptionalKeyword("applies"))) {
3795     // Parse the inner-op dimensions, reduce-op's function-type and
3796     // optional location.
3797     SmallVector<int64_t> dimensions;
3798     auto parseDim = [&]() -> ParseResult {
3799       if (parser.parseInteger(dimensions.emplace_back())) return failure();
3800       return success();
3801     };
3802 
3803     FunctionType reduceOpFntype;
3804     if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
3805         parser.parseEqual() ||
3806         parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
3807                                        parseDim) ||
3808         parser.parseOptionalAttrDict(result.attributes) ||
3809         parser.parseColon() || parser.parseType(reduceOpFntype) ||
3810         parser.parseKeyword("reducer"))
3811       return failure();
3812     OpBuilder builder(parser.getBuilder().getContext());
3813     result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));
3814 
3815     // Parse the "reducer" region now.
3816     SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerOperands;
3817     SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerInitOperands;
3818     SmallVector<Type, 2> reducerTypes;
3819     SmallVector<Type, 2> reducerInitTypes;
3820     SmallVector<Optional<Location>, 2> reducerLocs;
3821     SmallVector<Optional<Location>, 2> reducerInitLocs;
3822     auto parseBlockOperand =
3823         [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands,
3824             SmallVectorImpl<Type>& types,
3825             SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
3826       OpAsmParser::UnresolvedOperand operand;
3827       Type type;
3828       Optional<Location> loc;
3829       if (parser.parseOperand(operand, /*allowResultNumber=*/false) ||
3830           parser.parseColon() || parser.parseType(type) ||
3831           parser.parseOptionalLocationSpecifier(loc))
3832         return failure();
3833       operands.push_back(operand);
3834       types.push_back(type);
3835       locs.push_back(loc);
3836       return success();
3837     };
3838     do {
3839       if (failed(parser.parseOptionalLParen())) break;
3840       if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
3841           parser.parseComma() ||
3842           parseBlockOperand(reducerInitOperands, reducerInitTypes,
3843                             reducerInitLocs) ||
3844           parser.parseRParen())
3845         return failure();
3846     } while (true);
3847     reducerOperands.append(reducerInitOperands);
3848     reducerTypes.append(reducerInitTypes);
3849     reducerLocs.append(reducerInitLocs);
3850     result.addTypes(reduceOpFntype.getResults());
3851     SmallVector<OpAsmParser::Argument> reducerArgs;
3852     createArgs(reducerOperands, reducerTypes, reducerArgs);
3853 
3854     // Derive the SSA-values for reduce-op's operands and parse the region, and
3855     // the optional trailing location.
3856     Optional<Location> trailingLoc;
3857     if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
3858                                result.operands) ||
3859         parser.parseRegion(*result.addRegion(), reducerArgs))
3860       return failure();
3861     // Set the individual block arguments.
3862     for (auto argAndLoc :
3863          llvm::zip(result.regions.front()->front().getArguments(), reducerLocs))
3864       if (std::get<1>(argAndLoc))
3865         std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value());
3866     result.location = trailingLoc.value_or(currLocation);
3867     return success();
3868   }
3869 
3870   // Parse the inner-op name and check if the contract on inner-op
3871   // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
3872   FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
3873   if (failed(innerOpNameInfo)) return failure();
3874 
3875   StringRef innerOpName = innerOpNameInfo->getStringRef();
3876   Dialect* innerOpDialect = innerOpNameInfo->getDialect();
3877   if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") ||
3878       !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
3879       !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
3880       !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
3881       !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegions>()) {
3882     parser.emitError(loc,
3883                      "expected the inner-op to be a commutative binary-op from "
3884                      "stablehlo dialect, zero region, producing single result");
3885     return failure();
3886   }
3887 
3888   // Parse the inner-op dimensions, reduce-op's function-type and
3889   // optional location.
3890   SmallVector<int64_t> dimensions;
3891   auto parseDim = [&]() -> ParseResult {
3892     if (parser.parseInteger(dimensions.emplace_back())) return failure();
3893     return success();
3894   };
3895 
3896   Optional<Location> explicitLoc;
3897   FunctionType reduceOpFntype;
3898   if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
3899       parser.parseEqual() ||
3900       parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
3901       parser.parseColon() || parser.parseType(reduceOpFntype) ||
3902       parser.parseOptionalLocationSpecifier(explicitLoc))
3903     return failure();
3904 
3905   if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) {
3906     if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
3907     return parser.emitError(loc,
3908                             "input types missing in reduce-op function type");
3909   }
3910 
3911   // If the location of the reduce-op is explicitly provided, use it; else use
3912   // the parser's current location.
3913   Location reduceOpLoc = explicitLoc.value_or(currLocation);
3914 
3915   // Derive the SSA-values for reduce-op's operands.
3916   if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
3917                              result.operands))
3918     return failure();
3919 
3920   // Derive the type of inner-op from that of reduce-op's input operand.
3921   auto innerOpType = RankedTensorType::get(
3922       /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));
3923 
3924   // Add a region for reduce-op.
3925   Region& region = *result.addRegion();
3926 
3927   // Create a basic-block inside reduce-op's region.
3928   Block& block = region.emplaceBlock();
3929   auto lhs = block.addArgument(innerOpType, reduceOpLoc);
3930   auto rhs = block.addArgument(innerOpType, reduceOpLoc);
3931 
3932   // Create and insert an "inner-op" operation in the block.
3933   OpBuilder builder(parser.getBuilder().getContext());
3934   builder.setInsertionPointToStart(&block);
3935 
3936   OperationState innerOpState(reduceOpLoc, innerOpName);
3937   innerOpState.operands.push_back(lhs);
3938   innerOpState.operands.push_back(rhs);
3939   innerOpState.addTypes(innerOpType);
3940 
3941   Operation* innerOp = builder.create(innerOpState);
3942 
3943   // Insert a return statement in the block returning the inner-op's result.
3944   builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());
3945 
3946   // Populate the reduce-op operation-state with result-type, location, and
3947   // dimension attribute.
3948   result.addTypes(reduceOpFntype.getResults());
3949   result.location = innerOp->getLoc();
3950   result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));
3951 
3952   return success();
3953 }
3954 
3955 LogicalResult ReduceOp::verify() {
3956   // Check that there is an even number of operands and that the count is >= 2.
3957   if (getNumOperands() % 2 != 0 || getOperands().empty())
3958     return emitOpError() << "expects the size of operands to be even and >= 2";
3959 
3960   // Collect the input and init-value operands. Note that the operand-type is
3961   // enforced as "TensorType" by ODS.
3962   int64_t numInputs = getNumOperands() / 2;
3963   auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
3964       getOperandTypes(),
3965       [](Type t) -> TensorType { return t.cast<TensorType>(); }));
3966   ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(),
3967                                      operandTensorTypes.begin() + numInputs);
3968   ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs,
3969                                       operandTensorTypes.end());
3970 
3971   // Check for unranked tensors in input operands.
3972   int64_t rankedInputIdx = -1;
3973   for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
3974     if (inputArgTypes[inputIdx].hasRank()) {
3975       rankedInputIdx = inputIdx;
3976       break;
3977     }
3978   }
3979 
3980   bool allInputsUnranked = (rankedInputIdx == -1);
3981 
3982   // Check that all input operands have compatible shapes. The element types may
3983   // be different.
3984   if (!allInputsUnranked) {
3985     for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
3986       if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
3987                                              inputArgTypes[inputIdx]))) {
3988         return emitOpError()
3989                << "expects all inputs to have compatible shapes. Shape at"
3990                << " input-index " << inputIdx
3991                << " is not compatible with shape at input-index "
3992                << rankedInputIdx;
3993       }
3994     }
3995   }
3996 
3997   // Check that
3998   //   1. the dimensions of reduce-op are in-bounds for the given shape.
3999   //   2. the dimension-attribute has no duplicate entries.
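  // For example (illustrative values), with a rank-3 ranked input,
  // dimensions = [1, 1] is rejected as a duplicate and dimensions = [3] as
  // out-of-bounds.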
4000   DenseSet<int64_t> dimensionsToReduceSet;
4001   for (int64_t dimension : dimensions().getValues<int64_t>()) {
4002     if ((!allInputsUnranked &&
4003          dimension >= inputArgTypes[rankedInputIdx].getRank()) ||
4004         dimension < 0) {
4005       return emitError() << "Out-of-bounds dimension " << dimension
4006                          << " for input-tensor rank: "
4007                          << inputArgTypes[rankedInputIdx].getRank();
4008     }
4009 
4010     if (!dimensionsToReduceSet.insert(dimension).second) {
4011       return emitError() << "Duplicate reduction dimension: " << dimension;
4012     }
4013   }
4014 
4015   // Verify the inner block defining the reducer function.
4016   SmallVector<int64_t> newDimensions;
4017   if (!allInputsUnranked) {
4018     for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
4019          ++inputIdx) {
4020       if (!dimensionsToReduceSet.count(inputIdx)) {
4021         newDimensions.push_back(
4022             inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
4023       }
4024     }
4025   }
4026 
4027   Block& block = body().front();
4028   SmallVector<TensorType> accumulatorSubShapes;
4029   if (failed(verifyReducerShape(this->getLoc(), block, inputArgTypes,
4030                                 initValueTypes, numInputs, newDimensions,
4031                                 allInputsUnranked, accumulatorSubShapes)))
4032     return failure();
4033 
4034   // Check if the reduce-op's result-type matches with the one derived from
4035   // the reducer-block and dimensions attribute.
4036   if (getResults().size() != accumulatorSubShapes.size())
4037     return emitError() << "Unexpected number of reduce-op's returned values: "
4038                        << getResults().size() << " vs "
4039                        << accumulatorSubShapes.size() << " (expected)";
4040 
4041   for (int64_t shapeIdx = 0;
4042        shapeIdx < static_cast<int64_t>(accumulatorSubShapes.size());
4043        shapeIdx++) {
4044     // The result-type is enforced as "TensorType" by ODS.
4045     auto opResultType = getResult(shapeIdx).getType().cast<TensorType>();
4046 
4047     // Check element-type.
4048     if (accumulatorSubShapes[shapeIdx].getElementType() !=
4049         opResultType.getElementType()) {
4050       return emitError()
4051              << "Unexpected element-type for reduce-op's return value at index "
4052              << shapeIdx << ": " << opResultType.getElementType() << " vs "
4053              << accumulatorSubShapes[shapeIdx].getElementType()
4054              << " (expected)";
4055     }
4056 
4057     // Check shape.
4058     if (!allInputsUnranked && opResultType.hasRank() &&
4059         failed(verifyCompatibleShape(newDimensions, opResultType.getShape()))) {
4060       Type expectedResultType = RankedTensorType::get(
4061           newDimensions, accumulatorSubShapes[shapeIdx].getElementType());
4062       return emitError()
4063              << "Unexpected type for reduce-op's return value at index "
4064              << shapeIdx << ": " << opResultType << " vs " << expectedResultType
4065              << " (expected)";
4066     }
4067   }
4068 
4069   return success();
4070 }
4071 
4072 LogicalResult ReduceOp::reifyReturnTypeShapes(
4073     OpBuilder& builder, ValueRange operands,
4074     SmallVectorImpl<Value>& reifiedReturnShapes) {
4075   ReduceOp::Adaptor adaptor(operands);
4076   auto inputs = adaptor.operands();
4077 
4078   auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
4079   // Unranked types are not supported at the moment.
4080   if (!operandType) return failure();
4081 
4082   Location loc = this->getLoc();
4083   SmallVector<Value, 4> shapeValues;
4084   SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
4085   shapeValues.reserve(operandType.getRank());
4086   Type shapeScalarType = builder.getIndexType();
4087   auto toShapeScalarType = [&](Value v) {
4088     return maybeCastTo(builder, loc, v, shapeScalarType);
4089   };
4090 
4091   for (const auto& element : llvm::enumerate(operandType.getShape())) {
4092     int64_t idx = element.index();
4093     auto* it = std::find(dimensions.begin(), dimensions.end(), idx);
4094     if (it != dimensions.end()) {
4095       continue;
4096     }
4097     Value valueDim = toShapeScalarType(
4098         builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
4099     shapeValues.push_back(valueDim);
4100   }
4101 
4102   Value outputShape = builder.create<tensor::FromElementsOp>(
4103       loc,
4104       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4105                             shapeScalarType),
4106       shapeValues);
4107   for (size_t i = 0; i < inputs.size(); ++i) {
4108     reifiedReturnShapes.push_back(outputShape);
4109   }
4110 
4111   return success();
4112 }
4113 
4114 //===----------------------------------------------------------------------===//
4115 // RngBitGeneratorOp
4116 //===----------------------------------------------------------------------===//
4117 
4118 // Verify that the input state has the same shape as the output state.
4119 LogicalResult RngBitGeneratorOp::verify() {
4120   auto initialShape = initial_state().getType().dyn_cast<RankedTensorType>();
4121   auto outputShape = output_state().getType().dyn_cast<RankedTensorType>();
4122   if (initialShape.getShape() != outputShape.getShape())
4123     return emitOpError()
4124            << "output state shape must match initial state shape. Got: "
4125            << initialShape << " and " << outputShape;
4126   return success();
4127 }
4128 
4129 //===----------------------------------------------------------------------===//
4130 // RngOp
4131 //===----------------------------------------------------------------------===//
4132 
4133 LogicalResult RngOp::verify() {
4134   auto dist = rng_distribution();
4135   if (dist == RngDistribution::UNIFORM) {
4136     return success();
4137   }
4138   auto muTy = a().getType().cast<TensorType>().getElementType();
4139   auto sigmaTy = b().getType().cast<TensorType>().getElementType();
4140   if (muTy.isa<FloatType>() && sigmaTy.isa<FloatType>()) {
4141     return success();
4142   }
4143   return emitOpError() << "mu and sigma must be floats";
4144 }
4145 
4146 LogicalResult RngOp::inferReturnTypeComponents(
4147     MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
4148     DictionaryAttr attributes, RegionRange regions,
4149     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4150   return rngInferReturnTypeComponents(context, location, operands, attributes,
4151                                       regions, inferredReturnShapes);
4152 }
4153 
4154 LogicalResult RngOp::reifyReturnTypeShapes(
4155     OpBuilder& builder, ValueRange operands,
4156     SmallVectorImpl<Value>& reifiedReturnShapes) {
4157   RngOp::Adaptor adaptor(operands);
4158   reifiedReturnShapes.push_back(
4159       castToIndexTensor(builder, getLoc(), adaptor.shape()));
4160   return success();
4161 }
4162 
4163 //===----------------------------------------------------------------------===//
4164 // SelectOp
4165 //===----------------------------------------------------------------------===//
4166 
4167 LogicalResult SelectOp::verify() {
4168   // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
4169   //   (a) have the same element type, and
4170   //   (b) have compatible shapes (i.e. the same shape and/or at least one
4171   //       dynamic shape)
4172   if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
4173     return emitOpError()
4174            << "requires compatible types for non-predicate operands";
4175 
4176   // The predicate, if not-scalar, should have the same shape as the remaining
4177   // operands.
4178   auto predTy = pred().getType().dyn_cast<RankedTensorType>();
4179   bool predMayBeScalar = !predTy || predTy.getRank() == 0;
4180   if (predMayBeScalar) return success();
4181 
4182   if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
4183     return emitOpError() << "requires the same shape for all operands";
4184 
4185   return success();
4186 }
4187 
4188 // Ensures that a SelectOp used as a non-root operation in a DRR pattern infers
4189 // its return type based on the operand types.
4190 LogicalResult SelectOp::inferReturnTypeComponents(
4191     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
4192     DictionaryAttr attributes, RegionRange,
4193     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4194   SelectOp::Adaptor op(operands, attributes);
4195   auto trueType = op.on_true().getType().cast<TensorType>();
4196   auto falseType = op.on_false().getType().cast<TensorType>();
4197 
4198   // The output shape should be the most general of the operand shapes at each
4199   // dimension.
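  // For example (hypothetical types), tensor<2x?xf32> and tensor<2x3xf32> would
  // infer tensor<2x?xf32>: equal dimensions are kept and mismatched ones are
  // widened to dynamic.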
4200   ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back();
4201   if (trueType == falseType || !trueType.hasRank()) {
4202     outputType = ShapedTypeComponents(trueType.cast<ShapedType>());
4203   } else if (!falseType.hasRank()) {
4204     outputType = ShapedTypeComponents(falseType.cast<ShapedType>());
4205   } else {
4206     assert(trueType.getRank() == falseType.getRank());
4207     llvm::SmallVector<int64_t, 4> dims;
4208     dims.reserve(trueType.getRank());
4209     for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) {
4210       dims.push_back(std::get<0>(dim) == std::get<1>(dim)
4211                          ? std::get<0>(dim)
4212                          : ShapedType::kDynamicSize);
4213     }
4214     outputType = ShapedTypeComponents(dims, trueType.getElementType());
4215   }
4216   return success();
4217 }
4218 
4219 LogicalResult SelectOp::reifyReturnTypeShapes(
4220     OpBuilder& builder, ValueRange operands,
4221     SmallVectorImpl<Value>& reifiedReturnShapes) {
4222   // For `stablehlo.select`, the first operand (the predicate) may be a scalar.
4223   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1],
4224                                      &reifiedReturnShapes);
4225 }
4226 
4227 //===----------------------------------------------------------------------===//
4228 // SetDimensionSizeOp
4229 //===----------------------------------------------------------------------===//
4230 
4231 LogicalResult SetDimensionSizeOp::verify() {
4232   if (auto size = this->size().getType().dyn_cast<RankedTensorType>()) {
4233     if (size.getRank() != 0)
4234       return emitOpError() << "size operand should be of rank-0";
4235   }
4236 
4237   return verifyDimAttr(*this);
4238 }
4239 
4240 // TODO(b/238903565): Switch to inferReturnTypeComponents after adding support
4241 // for the encoding upstream.
4242 LogicalResult SetDimensionSizeOp::inferReturnTypes(
4243     MLIRContext* context, Optional<Location> location, ValueRange operands,
4244     DictionaryAttr attributes, RegionRange regions,
4245     SmallVectorImpl<Type>& inferredReturnTypes) {
4246   Location loc = location.value_or(UnknownLoc::get(context));
4247 
4248   SetDimensionSizeOp::Adaptor adaptor(operands, attributes, regions);
4249   if (failed(adaptor.verify(loc))) return failure();
4250 
4251   auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
4252   if (!inputType) {
4253     inferredReturnTypes.push_back(adaptor.operand().getType());
4254     return success();
4255   }
4256 
4257   int64_t dim = adaptor.dimension();
4258   int64_t rank = inputType.getRank();
4259   if (dim < 0 || dim >= rank) {
4260     return mlir::emitError(loc) << "expects dimension to be in range [0, "
4261                                 << rank << "); got: [" << dim << "].";
4262   }
4263 
4264   auto shape = llvm::to_vector<4>(inputType.getShape());
4265   llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamicSize);
4266   if (auto encoding =
4267           inputType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>())
4268     bounds = llvm::to_vector<4>(encoding.getBounds());
4269 
4270   // TODO(hinsu): Handle the case when the size operand is a constant.
4271   if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim];
4272   shape[dim] = ShapedType::kDynamicSize;
4273 
4274   auto extensions = TypeExtensionsAttr::get(context, bounds);
4275   auto resultType =
4276       RankedTensorType::get(shape, inputType.getElementType(), extensions);
4277   inferredReturnTypes.push_back(resultType);
4278   return success();
4279 }
4280 
4281 //===----------------------------------------------------------------------===//
4282 // PadOp
4283 //===----------------------------------------------------------------------===//
4284 
4285 LogicalResult PadOp::inferReturnTypeComponents(
4286     MLIRContext*, Optional<Location> location, ValueShapeRange operands,
4287     DictionaryAttr attributes, RegionRange regions,
4288     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4289   PadOp::Adaptor adaptor(operands, attributes, regions);
4290   auto inputType = adaptor.operand().getType().cast<RankedTensorType>();
4291   auto padType = adaptor.padding_value().getType().cast<RankedTensorType>();
4292 
4293   if (padType.getRank() != 0) {
4294     return emitOptionalError(
4295         location, llvm::formatv("padding value type should be a rank-0 "
4296                                 "tensor, is rank {0}",
4297                                 padType.getRank()));
4298   }
4299 
4300   const auto& paddingLow = adaptor.edge_padding_low();
4301   if (paddingLow.getType().getNumElements() != inputType.getRank()) {
4302     return emitOptionalError(
4303         location,
4304         llvm::formatv(
4305             "edge_padding_low length ({0}) must match operand rank ({1})",
4306             paddingLow.getType().getNumElements(), inputType.getRank()));
4307   }
4308 
4309   const auto& paddingHigh = adaptor.edge_padding_high();
4310   if (paddingHigh.getType().getNumElements() != inputType.getRank()) {
4311     return emitOptionalError(
4312         location,
4313         llvm::formatv(
4314             "edge_padding_high length ({0}) must match operand rank ({1})",
4315             paddingHigh.getType().getNumElements(), inputType.getRank()));
4316   }
4317 
4318   const auto& paddingInterior = adaptor.interior_padding();
4319   if (paddingInterior.getType().getNumElements() != inputType.getRank()) {
4320     return emitOptionalError(
4321         location,
4322         llvm::formatv(
4323             "interior_padding length ({0}) must match operand rank ({1})",
4324             paddingInterior.getType().getNumElements(), inputType.getRank()));
4325   }
4326 
4327   auto inputShape = inputType.getShape();
4328   SmallVector<int64_t> resultShape;
4329   for (int i = 0, e = inputShape.size(); i < e; i++) {
4330     if (isDynamicDimSize(inputShape[i])) {
4331       resultShape.push_back(ShapedType::kDynamicSize);
4332       continue;
4333     }
4334 
4335     int64_t paddingLowVal = paddingLow.getValues<APInt>()[i].getSExtValue();
4336     int64_t paddingHighVal = paddingHigh.getValues<APInt>()[i].getSExtValue();
4337     int64_t paddingInteriorVal =
4338         paddingInterior.getValues<APInt>()[i].getSExtValue();
4339     if (paddingInteriorVal < 0) {
4340       return emitOptionalError(
4341           location, llvm::formatv("Interior padding cannot be negative: {0}",
4342                                   paddingInteriorVal));
4343     }
4344     int64_t expectedOutput =
4345         inputShape[i] + paddingLowVal + paddingHighVal +
4346         std::max<int64_t>(inputShape[i] - 1, 0LL) * paddingInteriorVal;
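    // Worked example (hypothetical values): inputShape[i] = 5, low = 1, high = 2,
    // interior = 1 gives 5 + 1 + 2 + max(5 - 1, 0) * 1 = 12.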
4347     if (expectedOutput < 0) {
4348       return emitOptionalError(
4349           location,
4350           llvm::formatv("Padding result in negative size for dimension {0}",
4351                         i));
4352     }
4353     resultShape.push_back(expectedOutput);
4354   }
4355   inferredReturnShapes.emplace_back(resultShape, inputType.getElementType());
4356 
4357   return success();
4358 }
4359 
4360 LogicalResult PadOp::reifyReturnTypeShapes(
4361     OpBuilder& builder, ValueRange operands,
4362     SmallVectorImpl<Value>& reifiedReturnShapes) {
4363   PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary());
4364   auto loc = this->getLoc();
4365   Value operand = adaptor.operand();
4366   auto operandTy = operand.getType().cast<RankedTensorType>();
4367 
4368   llvm::SmallVector<int32_t> padHigh;
4369   llvm::SmallVector<int32_t> padLow;
4370   llvm::SmallVector<int32_t> padInterior;
4371 
4372   auto padHighAttr = adaptor.edge_padding_high();
4373   auto padLowAttr = adaptor.edge_padding_low();
4374   auto padInteriorAttr = adaptor.interior_padding();
4375 
4376   padHigh.reserve(padHighAttr.getNumElements());
4377   padLow.reserve(padLowAttr.getNumElements());
4378   padInterior.reserve(padInteriorAttr.getNumElements());
4379 
4380   for (const APInt& val : padHighAttr.getValues<APInt>())
4381     padHigh.push_back(val.getSExtValue());
4382 
4383   for (const APInt& val : padLowAttr.getValues<APInt>())
4384     padLow.push_back(val.getSExtValue());
4385 
4386   for (const APInt& val : padInteriorAttr.getValues<APInt>())
4387     padInterior.push_back(val.getSExtValue());
4388 
4389   Value one = builder.create<arith::ConstantIndexOp>(loc, 1).getResult();
4390   Value zero = builder.create<arith::ConstantIndexOp>(loc, 0).getResult();
4391 
4392   llvm::SmallVector<Value> dimensions;
4393   dimensions.reserve(operandTy.getRank());
4394   for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
4395     Value padEdge =
4396         builder.create<arith::ConstantIndexOp>(loc, padHigh[i] + padLow[i]);
4397 
4398     // First grab the original size of this dimension.
4399     Value dim = builder.create<tensor::DimOp>(loc, operand, i).getResult();
4400 
4401     // Compute the interior of the tensor and determine padding size.
4402     if (padInterior[i] > 0) {
4403       Value padInter =
4404           builder.create<arith::ConstantIndexOp>(loc, padInterior[i])
4405               .getResult();
4406       Value interior = builder.create<arith::SubIOp>(loc, dim, one).getResult();
4407       interior = builder.create<arith::MaxSIOp>(loc, interior, zero);
4408       interior = builder.create<arith::MulIOp>(loc, interior, padInter);
4409       dim = builder.create<arith::AddIOp>(loc, dim, interior).getResult();
4410     }
4411 
4412     // Then we add the padding on the edge of the tensor.
4413     dim = builder.create<arith::AddIOp>(loc, dim, padEdge).getResult();
4414     dimensions.push_back(dim);
4415   }
4416 
4417   Value dimensionTensor =
4418       builder.create<tensor::FromElementsOp>(loc, dimensions).getResult();
4419   reifiedReturnShapes.push_back(dimensionTensor);
4420   return success();
4421 }
4422 
4423 //===----------------------------------------------------------------------===//
4424 // DynamicPadOp
4425 //===----------------------------------------------------------------------===//
4426 
4427 LogicalResult DynamicPadOp::verify() {
4428   auto inputType = operand().getType().dyn_cast<RankedTensorType>();
4429   // If operand is unranked, there is very little to verify statically.
4430   if (!inputType) return success();
4431   int inputRank = inputType.getRank();
4432 
4433   auto padType = padding_value().getType().cast<RankedTensorType>();
4434   if (padType.getRank() != 0) {
4435     return emitOpError() << "padding value type should be a rank-0";
4436   }
4437 
4438   auto paddingLowType = edge_padding_low().getType().cast<RankedTensorType>();
4439   if (paddingLowType.getNumElements() != inputRank) {
4440     return emitOpError() << "edge_padding_low length("
4441                          << paddingLowType.getNumElements()
4442                          << ") must match operand rank(" << inputRank << ").";
4443   }
4444 
4445   auto paddingHighType = edge_padding_high().getType().cast<RankedTensorType>();
4446   if (paddingHighType.getNumElements() != inputRank) {
4447     return emitOpError() << "edge_padding_high length("
4448                          << paddingHighType.getNumElements()
4449                          << ") must match operand rank(" << inputRank << ").";
4450   }
4451 
4452   auto interiorPaddingType =
4453       interior_padding().getType().cast<RankedTensorType>();
4454   if (interiorPaddingType.getNumElements() != inputRank) {
4455     return emitOpError() << "edge_padding_interior length("
4456                          << interiorPaddingType.getNumElements()
4457                          << ") must match operand rank(" << inputRank << ").";
4458   }
4459 
4460   auto outputType = getResult().getType().dyn_cast<RankedTensorType>();
4461   // If result is unranked, there is very little to verify statically.
4462   if (!outputType) return success();
4463   int outputRank = outputType.getRank();
4464   if (inputRank != outputRank) {
4465     return emitOpError() << "operand rank(" << inputRank
4466                          << ") must match result(" << outputRank << ").";
4467   }
4468 
4469   return success();
4470 }
4471 
4472 LogicalResult DynamicPadOp::reifyReturnTypeShapes(
4473     OpBuilder& builder, ValueRange operands,
4474     SmallVectorImpl<Value>& reifiedReturnShapes) {
4475   DynamicPadOp::Adaptor adaptor(operands);
4476   Value operand = adaptor.operand();
4477   Value edgePaddingLow = adaptor.edge_padding_low();
4478   Value edgePaddingHigh = adaptor.edge_padding_high();
4479   Value interiorPadding = adaptor.interior_padding();
4480 
4481   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4482   // Unranked operands are not supported at the moment.
4483   if (!operandType) return failure();
4484 
4485   auto loc = this->getLoc();
4486   SmallVector<Value, 4> shapeValues;
4487   shapeValues.reserve(operandType.getRank());
4488   Type shapeScalarType =
4489       edgePaddingLow.getType().cast<ShapedType>().getElementType();
4490 
4491   auto toShapeScalarType = [&](Value v) {
4492     return maybeCastTo(builder, loc, v, shapeScalarType);
4493   };
4494 
4495   Value zero =
4496       toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 0));
4497   Value one = toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 1));
4498 
4499   for (int idx : llvm::seq<int>(0, operandType.getShape().size())) {
4500     Value valueDim =
4501         toShapeScalarType(builder.create<tensor::DimOp>(loc, operand, idx));
4502     Value offset = builder.create<arith::ConstantIndexOp>(loc, idx);
4503     Value valueLow =
4504         builder.create<tensor::ExtractOp>(loc, edgePaddingLow, offset);
4505     Value valueHigh =
4506         builder.create<tensor::ExtractOp>(loc, edgePaddingHigh, offset);
4507     Value valueInterior =
4508         builder.create<tensor::ExtractOp>(loc, interiorPadding, offset);
4509     // output_size = input_size + padding_low + padding_high + interior *
4510     // max(input_size - 1, 0)
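    // e.g. (hypothetical values) input_size = 3, low = 0, high = 1, interior = 2
    // gives 3 + 0 + 1 + 2 * max(3 - 1, 0) = 8.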
4511     Value valueDimLessThanOne = builder.create<arith::CmpIOp>(
4512         loc, arith::CmpIPredicate::slt, valueDim, one);
4513     Value interiorSize = builder.create<arith::MulIOp>(
4514         loc, valueInterior,
4515         builder.create<mlir::arith::SelectOp>(
4516             loc, valueDimLessThanOne, zero,
4517             builder.create<arith::SubIOp>(loc, valueDim, one)));
4518     shapeValues.push_back(builder.create<arith::AddIOp>(
4519         loc,
4520         builder.create<arith::AddIOp>(
4521             loc, builder.create<arith::AddIOp>(loc, interiorSize, valueDim),
4522             valueLow),
4523         valueHigh));
4524   }
4525 
4526   reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
4527       loc,
4528       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4529                             shapeScalarType),
4530       shapeValues));
4531 
4532   return success();
4533 }
4534 
4535 //===----------------------------------------------------------------------===//
4536 // ReshapeOp
4537 //===----------------------------------------------------------------------===//
4538 
4539 LogicalResult ReshapeOp::verify() {
4540   // If the operand type is dynamically shaped there is nothing to verify.
4541   auto operandTy = operand().getType().dyn_cast<RankedTensorType>();
4542   if (!operandTy || !operandTy.hasStaticShape()) return success();
4543 
4544   // If the operand type is statically shaped (which is not required), the number
4545   // of elements must match that of the result type.
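  // For example (hypothetical types), tensor<2x3xf32> -> tensor<6xf32> is
  // accepted, while tensor<2x3xf32> -> tensor<4xf32> is rejected.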
4546   auto resultTy = getType().cast<RankedTensorType>();
4547   assert(resultTy && resultTy.hasStaticShape() &&
4548          "result type must be statically shaped");
4549   int64_t numResultElements = resultTy.getNumElements();
4550   int64_t numOperandElements = operandTy.getNumElements();
4551   if (numResultElements != numOperandElements)
4552     return emitOpError() << "number of output elements (" << numResultElements
4553                          << ") doesn't match expected number of elements ("
4554                          << numOperandElements << ")";
4555 
4556   return success();
4557 }
4558 
4559 //===----------------------------------------------------------------------===//
4560 // ReplicaId Op
4561 //===----------------------------------------------------------------------===//
4562 
4563 LogicalResult ReplicaIdOp::inferReturnTypes(
4564     MLIRContext* context, Optional<Location>, ValueRange operands,
4565     DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
4566   inferredReturnTypes.push_back(RankedTensorType::get(
4567       /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
4568   return success();
4569 }
4570 
4571 //===----------------------------------------------------------------------===//
4572 // If Op
4573 //===----------------------------------------------------------------------===//
4574 
4575 static LogicalResult verifyConditionalBranch(Operation* op, Region& region,
4576                                              llvm::Twine branchName) {
4577   if (region.getNumArguments() != 0)
4578     return op->emitOpError()
4579            << branchName << " must have 0 arguments, but found "
4580            << region.getNumArguments();
4581 
4582   TypeRange branchReturnTypes =
4583       region.front().getTerminator()->getOperandTypes();
4584   if (branchReturnTypes != op->getResultTypes())
4585     return op->emitOpError()
4586            << branchName << " returned types (" << branchReturnTypes
4587            << ") do not match op result types (" << op->getResultTypes() << ")";
4588 
4589   return success();
4590 }
4591 
4592 LogicalResult IfOp::verify() {
4593   if (failed(verifyConditionalBranch(*this, true_branch(),
4594                                      /*branchName=*/"true_branch"))) {
4595     return failure();
4596   }
4597 
4598   if (failed(verifyConditionalBranch(*this, false_branch(),
4599                                      /*branchName=*/"false_branch"))) {
4600     return failure();
4601   }
4602   return success();
4603 }
4604 
4605 //===----------------------------------------------------------------------===//
4606 // Case Op
4607 //===----------------------------------------------------------------------===//
4608 
4609 LogicalResult CaseOp::verify() {
4610   auto numBranches = branches().size();
4611 
4612   for (unsigned i = 0; i < numBranches; ++i)
4613     if (failed(verifyConditionalBranch(*this, branches()[i],
4614                                        /*branchName=*/"branch " + Twine(i))))
4615       return failure();
4616 
4617   return success();
4618 }
4619 
4620 //===----------------------------------------------------------------------===//
4621 // UnaryOps
4622 //===----------------------------------------------------------------------===//
4623 
4624 ParseResult parseUnaryOp(OpAsmParser& parser, OperationState& result) {
4625   SmallVector<OpAsmParser::UnresolvedOperand> operands;
4626   Type type;
4627   // If the operand is in-between parentheses, use generic form.
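  // A sketch of the two accepted forms (hypothetical operands and types):
  //   shorthand: stablehlo.abs %arg0 : tensor<4xf32>
  //   generic:   stablehlo.abs(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>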
4628   SMLoc loc = parser.getCurrentLocation();
4629   if (!parser.parseOptionalLParen()) {
4630     if (parser.parseOperandList(operands) || parser.parseRParen() ||
4631         parser.parseOptionalAttrDict(result.attributes) ||
4632         parser.parseColon() || parser.parseType(type))
4633       return failure();
4634     auto fnType = type.dyn_cast<FunctionType>();
4635     if (!fnType) {
4636       parser.emitError(loc, "expected function type");
4637       return failure();
4638     }
4639     if (parser.resolveOperands(operands, fnType.getInputs(), loc,
4640                                result.operands))
4641       return failure();
4642     result.addTypes(fnType.getResults());
4643     return success();
4644   }
4645   // Otherwise, use shorthand syntax.
4646   return failure(parser.parseOperandList(operands) ||
4647                  parser.parseOptionalAttrDict(result.attributes) ||
4648                  parser.parseColonType(type) ||
4649                  parser.resolveOperands(operands, type, result.operands) ||
4650                  parser.addTypeToList(type, result.types));
4651 }
4652 
4653 void printUnaryOp(Operation* op, OpAsmPrinter& p) {
4654   assert(op->getNumResults() == 1 && "op should have one result");
4655   assert(op->getNumOperands() == 1 && "op should have one input");
4656   // If not all types are the same, use generic form.
4657   auto resultType = op->getResult(0).getType();
4658   if (resultType != op->getOperandTypes()[0]) {
4659     p.printGenericOp(op, /*printOpName=*/false);
4660     return;
4661   }
4662   // Otherwise, use the shorthand syntax.
4663   p << ' ';
4664   p.printOperands(op->getOperands());
4665   p.printOptionalAttrDict(op->getAttrs());
4666   p << " : " << resultType;
4667 }
4668 
4669 //===----------------------------------------------------------------------===//
4670 // BinaryOps
4671 //===----------------------------------------------------------------------===//
4672 
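// A sketch of the two forms accepted by parseBinaryOp (hypothetical operands
// and types):
//   shorthand: stablehlo.add %arg0, %arg1 : tensor<4xf32>
//   generic:   stablehlo.add(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// printBinaryOp falls back to the generic form whenever an operand type differs
// from the result type.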
4673 ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) {
4674   SmallVector<OpAsmParser::UnresolvedOperand> operands;
4675   Type type;
4676   // If the operand list is in-between parentheses, use generic form.
4677   SMLoc loc = parser.getCurrentLocation();
4678   if (!parser.parseOptionalLParen()) {
4679     if (parser.parseOperandList(operands) || parser.parseRParen() ||
4680         parser.parseOptionalAttrDict(result.attributes) ||
4681         parser.parseColon() || parser.parseType(type))
4682       return failure();
4683     auto fnType = type.dyn_cast<FunctionType>();
4684     if (!fnType) {
4685       parser.emitError(loc, "expected function type");
4686       return failure();
4687     }
4688     if (parser.resolveOperands(operands, fnType.getInputs(), loc,
4689                                result.operands))
4690       return failure();
4691     result.addTypes(fnType.getResults());
4692     return success();
4693   }
4694   // Otherwise, use shorthand syntax.
4695   return failure(parser.parseOperandList(operands) ||
4696                  parser.parseOptionalAttrDict(result.attributes) ||
4697                  parser.parseColonType(type) ||
4698                  parser.resolveOperands(operands, type, result.operands) ||
4699                  parser.addTypeToList(type, result.types));
4700 }
4701 
4702 void printBinaryOp(Operation* op, OpAsmPrinter& p) {
4703   assert(op->getNumResults() == 1 && "op should have one result");
4704   // If not all types are the same, use generic form.
4705   auto resultType = op->getResult(0).getType();
4706   if (llvm::any_of(op->getOperandTypes(),
4707                    [&](Type type) { return type != resultType; })) {
4708     p.printGenericOp(op, /*printOpName=*/false);
4709     return;
4710   }
4711   // Otherwise, use the shorthand syntax.
4712   p << ' ';
4713   p.printOperands(op->getOperands());
4714   p.printOptionalAttrDict(op->getAttrs());
4715   p << " : " << resultType;
4716 }
4717 
4718 //===----------------------------------------------------------------------===//
4719 // SliceOp
4720 //===----------------------------------------------------------------------===//
4721 
4722 // Returns output dimension size for slice result for the given arguments.
4723 // Returns -1 if arguments are illegal.
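// As a rough worked example of the formula below (values chosen purely for
// illustration): inputDim = 8, start = 1, end = 7, stride = 2 yields
// ceil((7 - 1) / 2) = 3 output elements, and start = 0, end = 8, stride = 3
// yields ceil(8 / 3) = 3.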
4724 static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end,
4725                              int64_t stride) {
4726   if (inputDim == -1 || start < 0 || start > end || end > inputDim ||
4727       stride == 0)
4728     return -1;
4729 
4730   return llvm::divideCeil(end - start, stride);
4731 }
4732 
4733 // The following properties are already enforced by the ODS:
4734 //  type(start_indices) == type(limit_indices) == type(strides).
4735 // Verify the following properties:
4736 //  P1. Verify rank(start_indices) == 1.
4737 //  P2. Verify size(start_indices) == rank(operand).
4738 //  P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i].
4739 //  P6. Verify stride[i] > 0.
4740 LogicalResult SliceOp::inferReturnTypes(
4741     MLIRContext* context, Optional<Location> location, ValueRange operands,
4742     DictionaryAttr attributes, RegionRange regions,
4743     SmallVectorImpl<Type>& inferredReturnTypes) {
4744   SliceOpAdaptor slice(operands, attributes);
4745   Type ty = slice.operand().getType();
4746   RankedTensorType rankedTy = ty.dyn_cast<RankedTensorType>();
4747   if (!rankedTy) {
4748     // The operand type is unranked, so the best we can infer for the result
4749     // type is an unranked tensor with the same element type as the operand
4750     // type.
4751     inferredReturnTypes.assign({ty});
4752     return success();
4753   }
4754 
4755   ShapedType attrTy = slice.start_indices().getType();
4756   // P1.
4757   // Note: ODS has type(start_indices) == type(limit_indices) == type(strides)
4758   // So this implies rank(limit_indices) == rank(strides) == 1 also.
4759   if (attrTy.getRank() != 1) {
4760     return emitOptionalError(location, "start_indices has rank ",
4761                              attrTy.getRank(), " instead of required rank 1");
4762   }
4763 
4764   // P2.
4765   int64_t rank = rankedTy.getRank();
4766   if (attrTy.getNumElements() != rank) {
4767     return emitOptionalError(
4768         location, "the number of elements in start_indices (",
4769         attrTy.getNumElements(), ") does not match the rank of the operand (",
4770         rank, ")");
4771   }
4772 
4773   SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
4774   SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
4775   SmallVector<int64_t, 4> strideVals(slice.strides().getValues<int64_t>());
4776 
4777   SmallVector<int64_t, 4> shape;
4778   shape.reserve(rank);
4779   for (int64_t i = 0, e = rank; i != e; i++) {
4780     if (isDynamicDimSize(rankedTy.getDimSize(i))) {
4781       shape.push_back(ShapedType::kDynamicSize);
4782       continue;
4783     }
4784     // P3.
4785     if (start[i] < 0)
4786       return emitOptionalError(location, "negative start index ", start[i],
4787                                " in dimension ", i);
4788     // P4.
4789     if (limit[i] > rankedTy.getDimSize(i))
4790       return emitOptionalError(location, "limit index ", limit[i],
4791                                " is larger than dimension size ",
4792                                rankedTy.getDimSize(i), " in dimension ", i);
4793     // P5.
4794     if (start[i] > limit[i])
4795       return emitOptionalError(location, "start index ", start[i],
4796                                " is larger than limit index ", limit[i],
4797                                " in dimension ", i);
4798     // P6.
4799     if (strideVals[i] <= 0)
4800       return emitOptionalError(location, "stride must be positive but got ",
4801                                strideVals[i], " in dimension ", i);
4802 
4803     shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i],
4804                                   strideVals[i]));
4805   }
4806   inferredReturnTypes.assign(
4807       {RankedTensorType::get(shape, rankedTy.getElementType())});
4808   return success();
4809 }
4810 
4811 //===----------------------------------------------------------------------===//
4812 // SortOp
4813 //===----------------------------------------------------------------------===//
4814 
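// A minimal builder-usage sketch (the value names below are hypothetical, not
// from this file):
//   SmallVector<Value> inputs = {keys, values};
//   builder.create<SortOp>(loc, inputs, /*dimension=*/0, /*isStable=*/true);
// The result types mirror the operand types, and the comparator region is
// added empty here, to be populated by the caller.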
4815 void SortOp::build(OpBuilder& builder, OperationState& state,
4816                    ValueRange operands, int64_t dimension, bool isStable) {
4817   state.addOperands(operands);
4818   state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
4819   state.addAttribute("is_stable", builder.getBoolAttr(isStable));
4820 
4821   for (Value operand : operands) state.addTypes(operand.getType());
4822 
4823   state.addRegion();
4824 }
4825 
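// For illustration of the comparator contract checked below: with operands of
// type tensor<4xf32> and tensor<4xi32>, the comparator block must take four
// scalar arguments (tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>) and
// return a single tensor<i1>.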
4826 LogicalResult SortOp::verify() {
4827   Operation::operand_range operands = this->operands();
4828   if (operands.empty()) return emitOpError("requires at least one input");
4829 
4830   // TODO(antiagainst): verify partially dynamic shapes
4831   if (llvm::all_of(operands, [](Value operand) {
4832         return operand.getType().cast<ShapedType>().hasRank();
4833       })) {
4834     ArrayRef<int64_t> inputShape =
4835         (*operands.begin()).getType().cast<ShapedType>().getShape();
4836 
4837     if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
4838           return operand.getType().cast<ShapedType>().getShape() != inputShape;
4839         }))
4840       return emitOpError("requires all inputs to have the same dimensions");
4841 
4842     int64_t rank = inputShape.size();
4843     int64_t cmpDim = dimension();
4844     if (cmpDim < -rank || cmpDim >= rank)
4845       return emitOpError("dimension attribute value must be in range [-")
4846              << rank << ", " << rank << "), but found " << cmpDim;
4847   }
4848 
4849   Block& block = comparator().front();
4850   size_t numOperands = getOperation()->getNumOperands();
4851   if (block.getNumArguments() != 2 * numOperands)
4852     return emitOpError("comparator block should have ")
4853            << 2 * numOperands << " arguments";
4854 
4855   for (const auto& indexedOperand : llvm::enumerate(operands)) {
4856     int index = indexedOperand.index();
4857     Type elementType =
4858         indexedOperand.value().getType().cast<ShapedType>().getElementType();
4859     Type tensorType = RankedTensorType::get({}, elementType);
4860     for (int i : {2 * index, 2 * index + 1}) {
4861       Type argType = block.getArgument(i).getType();
4862       if (argType != tensorType)
4863         return emitOpError("comparator block argument #")
4864                << i << " should be of type " << tensorType << " but got "
4865                << argType;
4866     }
4867   }
4868 
4869   // The comparator must return a single output.
4870   auto comparatorResult = block.getTerminator()->getOperands();
4871   if (comparatorResult.size() != 1)
4872     return emitOpError() << "comparator must return single output, but got: "
4873                          << comparatorResult.size();
4874 
4875   // The comparator output must be a 0-ranked tensor with element type i1.
4876   auto comparatorResultType =
4877       comparatorResult[0].getType().dyn_cast<RankedTensorType>();
4878   if (!comparatorResultType || comparatorResultType.getRank() != 0 ||
4879       !comparatorResultType.getElementType().isInteger(1))
4880     return emitOpError() << "comparator must return tensor<i1>, but got: "
4881                          << comparatorResult[0].getType();
4882 
4883   // Check the number of results and their element types.
4884   auto resultTypes = getResultTypes();
4885   if (resultTypes.size() != numOperands)
4886     return emitOpError() << "expects the number of results to be same as "
4887                             "number of operands. Got number of results = "
4888                          << resultTypes.size()
4889                          << " and number of operands = " << numOperands;
4890 
4891   for (auto it : llvm::zip(operands, getResultTypes()))
4892     if (std::get<0>(it).getType().cast<TensorType>().getElementType() !=
4893         std::get<1>(it).cast<TensorType>().getElementType())
4894       return emitOpError()
4895              << "expects the operands and results to have pairwise equal "
4896                 "element-types, but got "
4897              << std::get<0>(it).getType().cast<TensorType>().getElementType()
4898              << " vs " << std::get<1>(it).cast<TensorType>().getElementType();
4899 
4900   return success();
4901 }
4902 
4903 //===----------------------------------------------------------------------===//
4904 // TransposeOp
4905 //===----------------------------------------------------------------------===//
4906 
4907 LogicalResult TransposeOp::reifyReturnTypeShapes(
4908     OpBuilder& builder, ValueRange operands,
4909     SmallVectorImpl<Value>& reifiedReturnShapes) {
4910   TransposeOp::Adaptor adaptor(operands);
4911   Value operand = adaptor.operand();
4912 
4913   auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4914   // Unranked types are not supported at the moment.
4915   if (!operandType) return failure();
4916 
4917   Location loc = this->getLoc();
4918   SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
4919   SmallVector<Value, 4> shapeValues(permutation.size());
4920 
4921   Type shapeScalarType = builder.getIndexType();
4922   auto toShapeScalarType = [&](Value v) {
4923     return maybeCastTo(builder, loc, v, shapeScalarType);
4924   };
4925 
4926   for (const auto& element : llvm::enumerate(operandType.getShape())) {
4927     int64_t idx = element.index();
4928     auto* it = std::find(permutation.begin(), permutation.end(), idx);
4929     Value valueDim = toShapeScalarType(
4930         builder.createOrFold<tensor::DimOp>(loc, operand, element.index()));
4931     shapeValues[std::distance(permutation.begin(), it)] = valueDim;
4932   }
4933 
4934   Value outputShape = builder.create<tensor::FromElementsOp>(
4935       loc,
4936       RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4937                             shapeScalarType),
4938       shapeValues);
4939   reifiedReturnShapes.push_back(outputShape);
4940 
4941   return success();
4942 }
4943 
4944 // Method for InferTypeOpInterface: infer the return type from the operand type
4945 // and the permutation.
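// For example (an illustrative shape, not tied to any test): an operand of
// type tensor<2x3x4xf32> with permutation [2, 0, 1] infers a result of type
// tensor<4x2x3xf32>, since resultShape[i] = inputShape[permutation[i]].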
4946 LogicalResult TransposeOp::inferReturnTypes(
4947     MLIRContext* /*context*/, Optional<Location> loc, ValueRange operands,
4948     DictionaryAttr attributes, RegionRange,
4949     SmallVectorImpl<Type>& inferredReturnTypes) {
4950   auto type = operands[0].getType();
4951   auto rankedTy = type.dyn_cast<RankedTensorType>();
4952   if (!rankedTy) {
4953     auto shapedTy = type.dyn_cast<ShapedType>();
4954     inferredReturnTypes.emplace_back(shapedTy);
4955     return success();
4956   }
4957   auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation");
4958   int64_t rank = rankedTy.getRank();
4959   if (permutation.getType().getRank() != 1)
4960     return emitOptionalError(loc, "TransposeOp permutation has rank ",
4961                              permutation.getType().getRank(),
4962                              " instead of rank 1");
4963 
4964   if (permutation.size() != rank)
4965     return emitOptionalError(loc, "TransposeOp operand rank ", rank,
4966                              " does not match permutation size ",
4967                              permutation.size());
4968 
4969   std::vector<int64_t> range(rank);
4970   std::iota(range.begin(), range.end(), 0);
4971   if (!std::is_permutation(range.begin(), range.end(), permutation.begin()))
4972     return emitOptionalError(loc,
4973                              "attribute permutation must be a permutation"
4974                              " of [",
4975                              range, "] but got ", permutation);
4976 
4977   SmallVector<int64_t> resultShape;
4978   ArrayRef<int64_t> inputShape = rankedTy.getShape();
4979   for (int64_t dim : permutation.getValues<int64_t>()) {
4980     resultShape.push_back(inputShape[dim]);
4981   }
4982   inferredReturnTypes.emplace_back(RankedTensorType::get(
4983       resultShape, rankedTy.getElementType(), rankedTy.getEncoding()));
4984   return success();
4985 }
4986 
4987 //===----------------------------------------------------------------------===//
4988 // TriangularSolveOp
4989 //===----------------------------------------------------------------------===//
4990 
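// A hypothetical configuration that satisfies the checks below when
// left_side = true: 'a' of type tensor<4x4xf32> (square minor dimensions),
// 'b' of type tensor<4x3xf32> (shared dimension 4), and a result with the
// same type as 'b'.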
4991 LogicalResult TriangularSolveOp::verify() {
4992   auto aType = a().getType().dyn_cast<RankedTensorType>();
4993 
4994   // Skip verifier if a is unranked tensor.
4995   if (!aType) return success();
4996 
4997   // Check that 'a' has rank >= 2.
4998   auto aRank = aType.getRank();
4999   if (aRank < 2)
5000     return emitOpError() << "operand 'a' must have rank >= 2, but got "
5001                          << aType;
5002 
5003   // The two minor dimensions of 'a' must have the same size.
5004   if (aType.getDimSize(aRank - 2) != aType.getDimSize(aRank - 1))
5005     return emitOpError() << "two minor dimensions of operand 'a' must have "
5006                             "equal size, but got "
5007                          << aType;
5008 
5009   auto bType = b().getType().dyn_cast<RankedTensorType>();
5010   // If b is unranked skip remaining checks.
5011   if (!bType) return success();
5012 
5013   // Check that a and b have same rank.
5014   auto bRank = bType.getRank();
5015   if (aRank != bRank)
5016     return emitOpError() << "operands must have equal rank, but got " << aType
5017                          << " and " << bType;
5018 
5019   // The shared dimension of a and b should match.
5020   if (aType.getDimSize(aRank - 1) !=
5021       bType.getDimSize(bRank - (left_side() ? 2 : 1)))
5022     return emitOpError() << "shared dimension of operands 'a' and 'b' does "
5023                             "not match, but got "
5024                          << aType << " and " << bType;
5025 
5026   // The leading batch dimensions of a and b must be equal.
5027   auto aBatchDims = aType.getShape().drop_back(2);
5028   auto bBatchDims = bType.getShape().drop_back(2);
5029   if (aBatchDims != bBatchDims)
5030     return emitOpError()
5031            << "leading batch dimensions of the operands must be same, but got "
5032            << aType << " and " << bType;
5033 
5034   // Result and argument b must have same shape.
5035   auto resultType = getType().dyn_cast<RankedTensorType>();
5036   if (!resultType) return success();
5037   if (resultType != bType)
5038     return emitOpError()
5039            << "result and operand 'b' must have same shape, but got "
5040            << resultType << " and " << bType;
5041   return success();
5042 }
5043 
5044 //===----------------------------------------------------------------------===//
5045 // GetTupleElementOp
5046 //===----------------------------------------------------------------------===//
5047 
5048 LogicalResult GetTupleElementOp::inferReturnTypes(
5049     MLIRContext*, Optional<Location>, ValueRange operands,
5050     DictionaryAttr attributes, RegionRange,
5051     SmallVectorImpl<Type>& inferredReturnTypes) {
5052   auto tupleType = operands[0].getType().dyn_cast<TupleType>();
5053   if (!tupleType) return failure();
5054 
5055   auto indexAttr = attributes.get("index").cast<IntegerAttr>();
5056   auto index = indexAttr.getInt();
5057   if (index < 0 || index >= static_cast<int64_t>(tupleType.size()))
5058     return failure();
5059 
5060   inferredReturnTypes.push_back(tupleType.getType(index));
5061   return success();
5062 }
5063 
5064 //===----------------------------------------------------------------------===//
5065 // TupleOp
5066 //===----------------------------------------------------------------------===//
5067 
5068 LogicalResult TupleOp::inferReturnTypes(
5069     MLIRContext* context, Optional<Location>, ValueRange operands,
5070     DictionaryAttr attributes, RegionRange,
5071     SmallVectorImpl<Type>& inferredReturnTypes) {
5072   inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
5073   return success();
5074 }
5075 
5076 //===----------------------------------------------------------------------===//
5077 // CompareOp
5078 //===----------------------------------------------------------------------===//
5079 
5080 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
5081                       Value rhs, ComparisonDirection comparisonDirection,
5082                       ComparisonType compareType) {
5083   build(builder, result, lhs, rhs,
5084         ComparisonDirectionAttr::get(builder.getContext(), comparisonDirection),
5085         ComparisonTypeAttr::get(builder.getContext(), compareType));
5086 }
5087 
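// For illustration: comparing two tensor<4x8xf32> values infers a result of
// type tensor<4x8xi1>; the element type is always i1 and the shape follows
// the (ranked) first operand.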
5088 LogicalResult CompareOp::inferReturnTypeComponents(
5089     mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
5090     ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
5091     llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
5092   ShapedTypeComponents& components =
5093       inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1));
5094   auto argTy = operands.front().getType().cast<TensorType>();
5095   if (argTy.hasRank()) {
5096     components =
5097         ShapedTypeComponents(argTy.getShape(), components.getElementType());
5098   }
5099   return success();
5100 }
5101 
5102 LogicalResult CompareOp::reifyReturnTypeShapes(
5103     OpBuilder& builder, ValueRange operands,
5104     SmallVectorImpl<Value>& reifiedReturnShapes) {
5105   return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
5106                                      &reifiedReturnShapes);
5107 }
5108 
5109 //===----------------------------------------------------------------------===//
5110 // SelectAndScatterOp
5111 //===----------------------------------------------------------------------===//
5112 
5113 namespace {
5114 // Infer the return-type of SelectAndScatterOp.
5115 TensorType inferSelectAndScatterOpReturnType(
5116     TensorType operandType, const ArrayRef<WindowDimension> window) {
5117   if (!operandType.hasRank())
5118     return UnrankedTensorType::get(operandType.getElementType());
5119 
5120   return RankedTensorType::get(
5121       inferWindowOutputShape(operandType.getShape(), window),
5122       operandType.getElementType());
5123 }
5124 }  // namespace
5125 
5126 //  We intend to verify the following properties:
5127 //   P1. Check if the select function has a proper shape of (T,T) -> PRED, where
5128 //        T is a 0-D tensor with element-type same as 'operand' element-type.
5129 //   P2. Verify scatter-computation type.
5130 //   P3. size-of(window_dimension) == rank-of(input),
5131 //         where input is an element of 'inputs'.
5132 //   P4. Verify and collect the window attributes.
5133 //   P5. Verify the return type matches the operand-type.
5134 //   P6. Check if the result type of window operation matches the source type.
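//  As a hedged example of P6 (shapes chosen only for illustration): with an
//  operand of type tensor<10x24x24x64xf32>, window_dimensions = [1, 2, 2, 1],
//  window_strides = [1, 2, 2, 1], and no padding, the windowed result type is
//  tensor<10x12x12x64xf32>, so 'source' must have that shape.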
5135 LogicalResult SelectAndScatterOp::verify() {
5136   auto operandType = operand().getType().cast<TensorType>();
5137   auto initValueType = init_value().getType().cast<TensorType>();
5138   auto sourceType = source().getType().cast<TensorType>();
5139   auto resultType = getResult().getType().cast<TensorType>();
5140 
5141   // P1.
5142   Block& selectBlock = select().front();
5143 
5144   if (selectBlock.getArguments().size() != 2)
5145     return emitOpError()
5146            << "expects the select-region to take 2 parameters, but takes "
5147            << selectBlock.getArguments().size();
5148 
5149   Type expectedSelectArgType =
5150       RankedTensorType::get({}, operandType.getElementType());
5151   for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments()))
5152     if (!compatibleShapeAndElementType(expectedSelectArgType,
5153                                        selectArgIt.value().getType(),
5154                                        /*ignoreFpPrecision=*/true))
5155       return emitOpError()
5156              << "expects the type of select-region's parameter at index "
5157              << selectArgIt.index() << " to be " << expectedSelectArgType
5158              << ", but got " << selectArgIt.value().getType();
5159 
5160   auto selectResult = selectBlock.getTerminator()->getOperands();
5161   if (selectResult.size() != 1)
5162     return emitOpError()
5163            << "expects select-region to return single value, but got: "
5164            << selectResult.size();
5165 
5166   auto selectResultType = selectResult[0].getType().dyn_cast<TensorType>();
5167   if (!selectResultType || !selectResultType.getElementType().isInteger(1) ||
5168       (selectResultType.hasRank() &&
5169        selectResultType.cast<RankedTensorType>().getRank() != 0))
5170     return emitOpError() << "expects the return-type of select-region to be "
5171                             "tensor<i1>, but got: "
5172                          << selectResult[0].getType();
5173 
5174   // P2.
5175   Block& scatterBlock = scatter().front();
5176   SmallVector<TensorType> accumulatorSubshapes;
5177   if (failed(verifyReducerShape(
5178           this->getLoc(), scatterBlock,
5179           {RankedTensorType::get({}, sourceType.getElementType())},
5180           {initValueType},
5181           /*numInputs=*/1, /*allowedDimensions=*/{},
5182           /*allInputsUnranked=*/false, accumulatorSubshapes)))
5183     return failure();
5184 
5185   // P3.
5186   SmallVector<int64_t> windowDims =
5187       convertDenseIntAttr(this->window_dimensions());
5188   if (operandType.hasRank()) {
5189     if (operandType.getRank() != static_cast<int64_t>(windowDims.size()))
5190       return emitOpError()
5191              << "expects window-dimensions size == operand rank, but got "
5192                 "window-dimensions size: "
5193              << windowDims.size() << " and operand-type: " << operandType
5194              << " with rank = " << operandType.getRank() << ".";
5195   }
5196 
5197   // P4.
5198   auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
5199   if (failed(paddingOrErr)) return failure();
5200   SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
5201 
5202   auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
5203       windowDims, convertDenseIntAttr(window_strides()), padding,
5204       /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc());
5205   if (failed(windowOrErr)) return failure();
5206 
5207   // P5.
5208   if (!compatibleShapeAndElementType(operandType, resultType))
5209     return emitOpError()
5210            << "expects the return-type to match the operand-type, but got "
5211            << resultType << " and " << operandType << " resp.";
5212 
5213   // P6.
5214   auto windowResultType =
5215       inferSelectAndScatterOpReturnType(operandType, *windowOrErr);
5216 
5217   if (!compatibleShapeAndElementType(windowResultType, sourceType,
5218                                      /*ignoreFpPrecision=*/true))
5219     return emitOpError() << "expects source-type to be " << windowResultType
5220                          << ", but got " << sourceType;
5221 
5222   return success();
5223 }
5224 
5225 //===----------------------------------------------------------------------===//
5226 // ScatterOp
5227 //===----------------------------------------------------------------------===//
5228 
5229 /*
5230  * We intend to verify the following properties:
5231  * P1. The 'update_window_dims' must be valid indices of 'updates' tensor.
5232  * P2. The 'inserted_window_dims' must be valid indices of 'operand' tensor.
5233  * P3. Check if the rank-of('operand') == size-of('update_window_dims') +
5234  *     size-of('inserted_window_dims')
5235  * P4. size-of('scatter_dims_to_operand_dims') =
5236  *         'scatter_indices'['index_vector_dim'] &
5237  *     'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor.
5238  */
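// One configuration that satisfies P1-P4 (purely illustrative): operand of
// type tensor<200x100x300xf32>, scatter_indices of type tensor<10x2xi32> with
// index_vector_dim = 1, updates of type tensor<10x300xf32>,
// update_window_dims = [1], inserted_window_dims = [0, 1], and
// scatter_dims_to_operand_dims = [0, 1].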
5239 LogicalResult validateScatterDimensionNumbers(
5240     ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape,
5241     ShapedType updateType, bool operandTypeRanked,
5242     bool scatterIndicesTypeRanked, bool updatesTypeRanked,
5243     ScatterDimensionNumbersAttr dimNumbers, Location loc) {
5244   // P1.
5245   auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims());
5246   if (!llvm::is_sorted(updateWindowDims))
5247     return mlir::emitError(loc)
5248            << "Expects update_window_dims to be sorted; got: ["
5249            << updateWindowDims << "].";
5250 
5251   if (hasDuplicates(updateWindowDims))
5252     return mlir::emitError(loc)
5253            << "Expects update_window_dims to not repeat; got: ["
5254            << updateWindowDims << "].";
5255 
5256   if (updatesTypeRanked) {
5257     for (int64_t windowDim : updateWindowDims) {
5258       if (windowDim < 0 || windowDim >= updateType.getRank()) {
5259         return mlir::emitError(loc)
5260                << "Expects each element of update_window_dims to be in range "
5261                   "[0, "
5262                   "rank-of('updates') i.e. [0, "
5263                << updateType.getRank() << "). got: " << windowDim << ".";
5264       }
5265     }
5266   }
5267 
5268   // P2.
5269   auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims());
5270   if (!llvm::is_sorted(insertedWindowDims))
5271     return mlir::emitError(loc)
5272            << "Expects inserted_window_dims to be sorted; got: ["
5273            << insertedWindowDims << "].";
5274 
5275   if (hasDuplicates(insertedWindowDims))
5276     return mlir::emitError(loc)
5277            << "Expects inserted_window_dims to not repeat; got: ["
5278            << insertedWindowDims << "].";
5279 
5280   if (operandTypeRanked) {
5281     for (int64_t insertedDim : insertedWindowDims) {
5282       if (insertedDim < 0 || insertedDim >= operandType.getRank()) {
5283         return mlir::emitError(loc)
5284                << "Expects each element of inserted_window_dims to be in range "
5285                   "[0, rank-of('operand') i.e. [0, "
5286                << operandType.getRank() << "). got: " << insertedDim << ".";
5287       }
5288     }
5289   }
5290 
5291   // P3.
5292   if (operandTypeRanked) {
5293     auto windowSize = updateWindowDims.size() + insertedWindowDims.size();
5294     if (operandType.getRank() != static_cast<int64_t>(windowSize))
5295       return mlir::emitError(loc)
5296              << "Expects rank-of operand to match "
5297                 "size-of('update_window_dims') + "
5298                 "size-of('inserted_window_dims') i.e. "
5299              << windowSize << " but got " << operandType.getRank() << ".";
5300   }
5301 
5302   // P4.
5303   auto scatterDimsToOperandDims =
5304       to_vector(dimNumbers.getScatterDimsToOperandDims());
5305   auto indexVectorDim = dimNumbers.getIndexVectorDim();
5306   if (scatterIndicesTypeRanked) {
5307     if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) &&
5308         static_cast<int64_t>(scatterDimsToOperandDims.size()) !=
5309             scatterIndicesShape[dimNumbers.getIndexVectorDim()])
5310       return mlir::emitError(loc)
5311              << "Scatter op has " << scatterDimsToOperandDims.size()
5312              << " elements in scatter_dims_to_operand_dims and the bound of "
5313                 "dimension index_vector_dim="
5314              << dimNumbers.getIndexVectorDim() << " of scatter_indices is "
5315              << scatterIndicesShape[dimNumbers.getIndexVectorDim()]
5316              << ". These two numbers must be equal.";
5317   }
5318 
5319   if (operandTypeRanked) {
5320     for (int64_t i = 0;
5321          i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) {
5322       int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i];
5323       if (scatterDimToOperandDim < 0 ||
5324           scatterDimToOperandDim >= operandType.getRank())
5325         return mlir::emitError(loc)
5326                << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, "
5327                << operandType.getRank() << "), got: " << i << "->"
5328                << scatterDimToOperandDim << ".";
5329     }
5330   }
5331 
5332   if (hasDuplicates(scatterDimsToOperandDims))
5333     return mlir::emitError(loc)
5334            << "Expects scatter_dims_to_operand_dims to not repeat; got: ["
5335            << scatterDimsToOperandDims << "].";
5336 
5337   return success();
5338 }
5339 /*
5340  * We intend to verify the following properties:
5341  *  P0. scatter_indices argument must be an integral tensor. Enforced by ODS.
5342  *  P1. Scatter index leaf dimension must be within
5343  *      [0, rank(scatter_indices) + 1).
5344  *  P2. Verify reducer shape.
5345  *  P3. rank-of('updates[i]') == size-of('update_window_dims') +
5346  *      rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a
5347  *      trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')
5348  *      for all values of `i`.
5349  *  P4. Validate the scatter-dimensions-numbers.
5350  *  P5. Validate the bounds of each of the 'updates' w.r.t. the operands.
5351  *  P6. Validate the bounds of each of the 'updates' w.r.t. the
5352  * 'scatter_indices'.
5353  *  P7. Check return types.
5354  */
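// Note on P3: if index_vector_dim == rank(scatter_indices), the indices shape
// is treated as if extended by a trailing 1; e.g. scatter_indices of type
// tensor<10xi32> with index_vector_dim = 1 behaves like tensor<10x1xi32>.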
5355 LogicalResult ScatterOp::verify() {
5356   // Collect the operand, scatter-indices, and updates types; the checks
5357   // below are applied to every (operand, update) pair.
5358   auto numOperands = operands().size();
5359   auto scatterIndicesType = scatter_indices().getType().dyn_cast<TensorType>();
5360 
5361   SmallVector<TensorType, 1> operandTypes =
5362       llvm::to_vector(llvm::map_range(operands().getTypes(), [](Type type) {
5363         return type.cast<TensorType>();
5364       }));
5365   SmallVector<TensorType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
5366       updates().getTypes(), [](Type type) { return type.cast<TensorType>(); }));
5367   bool allOperandTypesRanked =
5368       llvm::all_of(operands().getTypes(),
5369                    [](Type type) { return type.isa<RankedTensorType>(); });
5370   bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>();
5371 
5372   // P1.
5373   int64_t indexVectorDim = scatter_dimension_numbers().getIndexVectorDim();
5374   if (scatterIndicesTypeRanked) {
5375     if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0)
5376       return emitOpError()
5377              << "expects scatter index leaf dimension to be within [0, "
5378                 "rank(scatter_indices) + 1."
5379                 "rank(scatter_indices) + 1)."
5380              << scatterIndicesType.getRank()
5381              << " and scatter index leaf dimension is " << indexVectorDim
5382              << ".";
5383   }
5384 
5385   // P2.
5386   Block& block = update_computation().front();
5387   SmallVector<TensorType> accumulatorSubshapes;
5388   SmallVector<TensorType> inputTypes, initValueTypes;
5389   for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
5390     inputTypes.push_back(operandTypes[i]);
5391     initValueTypes.push_back(
5392         RankedTensorType::get({}, updatesTypes[i].getElementType()));
5393   }
5394   if (failed(verifyReducerShape(
5395           this->getLoc(), block, inputTypes, initValueTypes, numOperands,
5396           /*allowedDimensions=*/{},
5397           /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes)))
5398     return failure();
5399 
5400   // P3.
5401   auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
5402   SmallVector<int64_t> expandedScatterIndicesShape;
5403   if (scatterIndicesTypeRanked) {
5404     expandedScatterIndicesShape =
5405         llvm::to_vector(scatterIndicesType.getShape());
5406     if (static_cast<int64_t>(expandedScatterIndicesShape.size()) ==
5407         indexVectorDim)
5408       expandedScatterIndicesShape.push_back(1);
5409   }
5410 
5411   for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
5412     if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) {
5413       int64_t expectedUpdatesRank =
5414           expandedScatterIndicesShape.size() - 1 + updateWindowDims.size();
5415       if (updatesTypes[i].getRank() != expectedUpdatesRank)
5416         return emitOpError()
5417                << "expects updates tensor must be of rank "
5418                << expectedUpdatesRank
5419                << " ( == rank-of('scatter_indices') - 1 + "
5420                   "size-of('update_window_dims'), where 'scatter_indices' is "
5421                   "expanded by a trailing 1 dimension if 'index_vector_dim' == "
5422                   "rank-of('scatter_indices')), but got "
5423                << updatesTypes[i].getRank() << ".";
5424     }
5425   }
5426 
5427   // P4.
5428   for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
5429     if (failed(validateScatterDimensionNumbers(
5430             operandTypes[i], expandedScatterIndicesShape, updatesTypes[i],
5431             operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked,
5432             updatesTypes[i].isa<RankedTensorType>(),
5433             scatter_dimension_numbers(), getLoc())))
5434       return failure();
5435   }
5436 
5437   // P5.
5438   for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
5439     if (updatesTypes[i].isa<RankedTensorType>()) {
5440       auto updatesShape = updatesTypes[i].getShape();
5441       if (operandTypes[i].isa<RankedTensorType>()) {
5442         auto operandShape = operandTypes[i].getShape();
5443         auto insertedWindowDims =
5444             scatter_dimension_numbers().getInsertedWindowDims();
5445 
5446         int64_t insertedDimsSeen = 0;
5447         SmallVector<int64_t> maxUpdateSliceSizes;
5448         const auto dimensionsSize = operandTypes[i].getRank();
5449         maxUpdateSliceSizes.reserve(dimensionsSize);
5450         for (int i = 0; i < dimensionsSize; ++i) {
5451           if (insertedDimsSeen <
5452                   static_cast<int64_t>(insertedWindowDims.size()) &&
5453               insertedWindowDims[insertedDimsSeen] == i) {
5454             ++insertedDimsSeen;
5455           } else {
5456             maxUpdateSliceSizes.push_back(operandShape[i]);
5457           }
5458         }
5459 
5460         for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size());
5461              ++i) {
5462           auto updateWindowDim = updateWindowDims[i];
5463 
5464           if (isDynamicDimSize(updatesShape[updateWindowDim]) ||
5465               isDynamicDimSize(maxUpdateSliceSizes[i]))
5466             continue;
5467 
5468           if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) {
5469             return emitOpError()
5470                    << "expects bounds of the window dimensions of "
5471                       "updates to not exceed the "
5472                       "bounds of the corresponding dimensions of "
5473                       "operand. For dimension "
5474                    << updateWindowDim << ", updates bound is "
5475                    << updatesShape[updateWindowDim] << ", operand bound is "
5476                    << maxUpdateSliceSizes[i] << ".";
5477           }
5478         }
5479       }
5480 
5481       // P6.
5482       if (scatterIndicesTypeRanked) {
5483         int64_t scatterDimsSeen = 0;
5484         for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size());
5485              ++i) {
5486           bool isUpdateWindowDim = std::binary_search(
5487               updateWindowDims.begin(), updateWindowDims.end(), i);
5488 
5489           if (isUpdateWindowDim) continue;
5490           if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen;
5491 
5492           if (!isDynamicDimSize(updatesShape[i]) &&
5493               !isDynamicDimSize(expandedScatterIndicesShape[scatterDimsSeen]) &&
5494               (updatesShape[i] !=
5495                expandedScatterIndicesShape[scatterDimsSeen])) {
5496             return emitOpError()
5497                    << "expects bounds of the scatter dimensions of "
5498                       "updates to be same as the "
5499                       "bounds of the corresponding dimensions of "
5500                       "scatter indices. For "
5501                       "scatter dimension "
5502                    << i << ", updates bound is " << updatesShape[i]
5503                    << " , scatter_indices "
5504                       "bound is "
5505                    << expandedScatterIndicesShape[scatterDimsSeen] << ".";
5506           }
5507           ++scatterDimsSeen;
5508         }
5509       }
5510     }
5511   }
5512 
5513   // P7.
5514   for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
5515     if (!compatibleShapeAndElementType(operandTypes[i], getResult(i).getType()))
5516       return emitOpError()
5517              << "expects the return type to be same as the operand type: "
5518              << operandTypes[i] << ", but got " << getResult(i).getType()
5519              << ".";
5520   }
5521 
5522   return success();
5523 }
5524 
5525 //===----------------------------------------------------------------------===//
5526 // WhileOp
5527 //===----------------------------------------------------------------------===//
5528 
5529 LogicalResult WhileOp::verify() {
5530   if (getNumOperands() != cond().front().getNumArguments())
5531     return emitOpError() << "mismatch in operand count (" << getNumOperands()
5532                          << ") vs the condition block argument count ("
5533                          << cond().front().getNumArguments() << ")";
5534   if (getNumOperands() != body().front().getNumArguments())
5535     return emitOpError() << "mismatch in operand count (" << getNumOperands()
5536                          << ") vs the body block argument count ("
5537                          << body().front().getNumArguments() << ")";
5538   for (const auto& enumeratedOperands : llvm::enumerate(
5539            llvm::zip(getOperandTypes(), cond().front().getArgumentTypes(),
5540                      body().front().getArgumentTypes()))) {
5541     int argCount = enumeratedOperands.index();
5542     const auto& operands = enumeratedOperands.value();
5543     Type operandType = std::get<0>(operands);
5544     Type condType = std::get<1>(operands);
5545     Type bodyType = std::get<2>(operands);
5546     if (operandType != condType)
5547       return emitOpError() << "type mismatch between operand #" << argCount
5548                            << " and the matching condition block argument: "
5549                            << operandType << " vs " << condType;
5550     if (operandType != bodyType)
5551       return emitOpError() << "type mismatch between operand #" << argCount
5552                            << " and the matching body block argument: "
5553                            << operandType << " vs " << bodyType;
5554   }
5555   // Check the return type for the condition block.
5556   {
5557     auto condReturnOp = cast<ReturnOp>(cond().front().back());
5558     if (condReturnOp->getNumOperands() != 1)
5559       return condReturnOp.emitOpError()
5560              << "expects a single operand for while condition body return, got "
5561              << condReturnOp->getNumOperands();
5562     auto operandType =
5563         condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
5564     if (!operandType || operandType.getRank() != 0 ||
5565         !operandType.getElementType().isInteger(1))
5566       return condReturnOp.emitOpError()
5567              << "expects a zero-ranked tensor of i1, got "
5568              << condReturnOp->getOperand(0).getType();
5569   }
5570   // Check the return type for the body block.
5571   {
5572     auto bodyReturnOp = cast<ReturnOp>(body().front().back());
5573     if (bodyReturnOp->getNumOperands() != getNumOperands())
5574       return bodyReturnOp.emitOpError()
5575              << "expects body to return as many values as the operands ("
5576              << getNumOperands() << "), got " << bodyReturnOp->getNumOperands();
5577     for (const auto& enumeratedOperandTypes : llvm::enumerate(
5578              llvm::zip(bodyReturnOp->getOperandTypes(), getOperandTypes()))) {
5579       Type operandType = std::get<0>(enumeratedOperandTypes.value());
5580       Type returnType = std::get<1>(enumeratedOperandTypes.value());
5581       if (operandType != returnType)
5582         return bodyReturnOp.emitOpError()
5583                << "type mismatch between operand #"
5584                << enumeratedOperandTypes.index()
5585                << " and the enclosing WhileOp returned value: " << operandType
5586                << " vs " << returnType;
5587     }
5588   }
5589   return success();
5590 }
5591 
5592 /// Print a `while` op.
5593 ///
5594 /// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict
5595 ///         `cond` region
5596 ///         `do` region
5597 /// assignment-list ::= assignment | assignment `,` assignment-list
5598 /// assignment ::= ssa-value `=` ssa-value
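/// A sketch following the grammar above (operand names are illustrative and
/// the region bodies are elided):
///   %results:2 = stablehlo.while(%arg0 = %init0, %arg1 = %init1)
///       : tensor<i32>, tensor<4xf32>
///    cond { ... }
///    do { ... }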
5599 void WhileOp::print(OpAsmPrinter& p) {
5600   p << '(';
5601   llvm::interleaveComma(llvm::zip(getBody()->getArguments(), getOperands()), p,
5602                         [&](auto zip) {
5603                           p.printOperand(std::get<0>(zip));
5604                           p << " = ";
5605                           p.printOperand(std::get<1>(zip));
5606                         });
5607   p << ")";
5608   if (getNumOperands()) {
5609     p << " : ";
5610     llvm::interleaveComma(getOperandTypes(), p);
5611   }
5612   p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
5613   p.printNewline();
5614   p << " cond ";
5615   p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false);
5616   p << " do ";
5617   p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false);
5618 }
5619 
5620 ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) {
5621   llvm::SMLoc loc = parser.getCurrentLocation();
5622   // Parse the operands of the while: these are of the form:
5623   //   %iter_arg = %init_val
5624   // where %iter_arg is the name of the block argument in the cond/body blocks
5625   // and %init_val is the actual operand.
5626   SmallVector<OpAsmParser::UnresolvedOperand> operands;
5627   SmallVector<OpAsmParser::UnresolvedOperand> iterArgs;
5628   if (parser.parseLParen()) return failure();
5629   do {
5630     if (succeeded(parser.parseOptionalRParen())) break;
5631     OpAsmParser::UnresolvedOperand operand, iterArg;
5632     if (parser.parseOperand(iterArg) || parser.parseEqual() ||
5633         parser.parseOperand(operand))
5634       return failure();
5635     iterArgs.push_back(iterArg);
5636     operands.push_back(operand);
5637     if (succeeded(parser.parseOptionalRParen())) break;
5638     if (failed(parser.parseComma())) return failure();
5639   } while (true);
5640   if (!operands.empty()) {
5641     if (parser.parseColon() || parser.parseTypeList(result.types))
5642       return failure();
5643   }
5644 
5645   SmallVector<OpAsmParser::Argument> args;
5646   createArgs(iterArgs, result.types, args);
5647   if (parser.resolveOperands(operands, result.types, loc, result.operands) ||
5648       parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
5649       parser.parseKeyword("cond") ||
5650       parser.parseRegion(*result.addRegion(), args) ||
5651       parser.parseKeyword("do") ||
5652       parser.parseRegion(*result.addRegion(), args))
5653     return failure();
5654   return success();
5655 }
5656 
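// For illustration (the quantized type parameters are hypothetical): an
// operand of type tensor<16x!quant.uniform<i8:f32, 0.5>> infers a result of
// shape [16] with the expressed element type f32, i.e. tensor<16xf32>.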
5657 LogicalResult UniformDequantizeOp::inferReturnTypeComponents(
5658     MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
5659     DictionaryAttr attributes, RegionRange regions,
5660     SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5661   UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions);
5662   auto operandType = (*operands.begin()).getType().cast<ShapedType>();
5663   // The HLO_QuantizedIntTensor trait in ODS guarantees a QuantizedType
5664   // element type.
5664   auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
5665   auto shape = operandType.dyn_cast<ShapedType>().getShape();
5666   inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
5667   return success();
5668 }
5669 
5670 }  // namespace stablehlo
5671 }  // namespace mlir
5672 
5673 #define GET_OP_CLASSES
5674 #include "dialect/StablehloOps.cpp.inc"
5675 
5676 namespace mlir {
5677 namespace stablehlo {
5678 
5679 //===----------------------------------------------------------------------===//
5680 // StableHLO Dialect Interfaces
5681 //===----------------------------------------------------------------------===//
5682 
5683 namespace {
5684 struct HLOBoundedDialectInterface : public hlo::BoundedDialectInterface {
5685   using BoundedDialectInterface::BoundedDialectInterface;
5686 
5687   Attribute createBoundedAttr(ArrayRef<int64_t> bounds) const override {
5688     return TypeExtensionsAttr::get(getDialect()->getContext(), bounds);
5689   }
5690 };
5691 }  // end anonymous namespace
5692 
5693 //===----------------------------------------------------------------------===//
5694 // StableHLO Dialect Constructor
5695 //===----------------------------------------------------------------------===//
5696 
5697 StablehloDialect::StablehloDialect(MLIRContext* context)
5698     : Dialect(getDialectNamespace(), context, TypeID::get<StablehloDialect>()) {
5699   addOperations<
5700 #define GET_OP_LIST
5701 #include "dialect/StablehloOps.cpp.inc"
5702       >();
5703   addInterfaces<HLOBoundedDialectInterface>();
5704   addTypes<TokenType>();
5705   addAttributes<
5706 #define GET_ATTRDEF_LIST
5707 #include "dialect/StablehloAttrs.cpp.inc"
5708       >();
5709   context->loadDialect<tensor::TensorDialect>();
5710 }
5711 
5712 Type StablehloDialect::parseType(DialectAsmParser& parser) const {
5713   StringRef dataType;
5714   if (parser.parseKeyword(&dataType)) return Type();
5715 
5716   if (dataType == "token") return TokenType::get(getContext());
5717   parser.emitError(parser.getNameLoc())
5718       << "unknown stablehlo type: " << dataType;
5719   return nullptr;
5720 }
5721 
5722 void StablehloDialect::printType(Type type, DialectAsmPrinter& os) const {
5723   if (type.isa<TokenType>()) {
5724     os << "token";
5725     return;
5726   }
5727   os << "<unknown stablehlo type>";
5728 }
5729 
5730 // Entry point for Attribute parsing, TableGen generated code will handle the
5731 // dispatch to the individual classes.
5732 Attribute StablehloDialect::parseAttribute(DialectAsmParser& parser,
5733                                            Type type) const {
5734   StringRef attrTag;
5735   Attribute attr;
5736   auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
5737   if (parseResult.hasValue()) return attr;
5738   parser.emitError(parser.getNameLoc(), "unknown stablehlo attribute");
5739   return Attribute();
5740 }
5741 
5742 // Entry point for Attribute printing, TableGen generated code will handle the
5743 // dispatch to the individual classes.
5744 void StablehloDialect::printAttribute(Attribute attr,
5745                                       DialectAsmPrinter& os) const {
5746   LogicalResult result = generatedAttributePrinter(attr, os);
5747   (void)result;
5748   assert(succeeded(result));
5749 }
5750 
5751 /// Helpers for attributes parsing.
5752 
5753 static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) {
5754   dims.clear();
5755   if (parser.parseLSquare()) return failure();
5756   while (failed(parser.parseOptionalRSquare())) {
5757     dims.emplace_back();
5758     if (parser.parseInteger(dims.back())) return failure();
5759     (void)parser.parseOptionalComma();
5760   }
5761   return success();
5762 }
5763 
5764 static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
5765                                                 SmallVector<int64_t>& dims,
5766                                                 int minElements) {
5767   if (failed(parseDims(parser, dims))) return failure();
5768   if (static_cast<int64_t>(dims.size()) < minElements)
5769     return parser.emitError(parser.getCurrentLocation())
5770            << "expected at least " << minElements << " element(s), found "
5771            << dims.size();
5772   return success();
5773 }
5774 
5775 /// Parse a custom attribute that resembles a struct of the form
5776 /// <
5777 ///   foo = something_parsed_by_custom_parser,
5778 ///   bar = something_parsed_by_different_custom_parser,
5779 ///   baz something_parsed_by_another_custom_parser
5780 /// >
5781 /// The optional argument `parse_equal` array can be used to denote if
5782 /// '=' follows the keyword (see baz in the example above) for a field. If
5783 /// not provided, all fields must be followed by a '='.
5784 static ParseResult parseStruct(
5785     AsmParser& parser, ArrayRef<StringRef> keywords,
5786     ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
5787     ArrayRef<bool> parseEqual = {}) {
5788   assert(keywords.size() == parseFuncs.size());
5789   assert(parseEqual.empty() || parseEqual.size() == keywords.size());
5790   SmallVector<bool> seen(keywords.size(), false);
5791   while (failed(parser.parseOptionalGreater())) {
5792     bool foundOne = false;
5793     for (const auto& it : llvm::enumerate(keywords)) {
5794       size_t index = it.index();
5795       StringRef keyword = it.value();
5796       if (succeeded(parser.parseOptionalKeyword(keyword))) {
5797         if (seen[index]) {
5798           return parser.emitError(parser.getCurrentLocation())
5799                  << "duplicated `" << keyword << "` entry";
5800         }
5801         if (parseEqual.empty() || parseEqual[index]) {
5802           if (failed(parser.parseEqual())) return failure();
5803         }
5804         if (failed(parseFuncs[index]())) return failure();
5805         if (failed(parser.parseOptionalComma())) return parser.parseGreater();
5806         seen[index] = true;
5807         foundOne = true;
5808       }
5809     }
5810     if (!foundOne) {
5811       auto parseError = parser.emitError(parser.getCurrentLocation())
5812                         << "expected one of: ";
5813       llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
5814         parseError << '`' << kw << '`';
5815       });
5816       return parseError;
5817     }
5818   }
5819   return success();
5820 }
5821 
5822 // Helpers to print an optional array or integer field, to simplify writing
5823 // attribute printers.
5824 template <typename T>
5825 static void printField(AsmPrinter& printer, StringRef name, T field,
5826                        StringRef& separator) {
5827   if (field != 0) {
5828     printer << separator << name << " = " << field;
5829     separator = ", ";
5830   }
5831 }
5832 template <typename T>
5833 static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
5834                        StringRef& separator) {
5835   if (!field.empty()) {
5836     printer << separator << name << " = [";
5837     llvm::interleaveComma(field, printer);
5838     printer << "]";
5839     separator = ", ";
5840   }
5841 }
5842 template <typename... Ts>
5843 static void printStruct(AsmPrinter& printer, StringRef name,
5844                         Ts... printFields) {
5845   printer << "<";
5846   StringRef separator = "";
5847   // Fold expression to print each entry in the parameter pack.
5848   // TODO(stablehlo-team): this can be simplified when TF moves to C++17.
5849   using unused = int[];
5850   (void)unused{0, (printField(printer, std::get<0>(printFields),
5851                               std::get<1>(printFields), separator),
5852                    0)...};
5853   printer << ">";
5854 }
5855 
5856 // Custom printer and parser for ScatterDimensionNumbersAttr.
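// With the generated mnemonic this renders roughly as (values illustrative,
// wrapped here only for readability):
//   #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0, 1],
//     scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>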
5857 void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
5858   printStruct(printer, "scatter",
5859               std::make_pair("update_window_dims", getUpdateWindowDims()),
5860               std::make_pair("inserted_window_dims", getInsertedWindowDims()),
5861               std::make_pair("scatter_dims_to_operand_dims",
5862                              getScatterDimsToOperandDims()),
5863               std::make_pair("index_vector_dim", getIndexVectorDim()));
5864 }
5865 Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5866   if (failed(parser.parseLess())) return {};
5867   SmallVector<int64_t> updateWindowDims;
5868   SmallVector<int64_t> insertedWindowDims;
5869   SmallVector<int64_t> scatterDimsToOperandDims;
5870   int64_t indexVectorDim = 0;
5871 
5872   if (failed(parseStruct(
5873           parser,
5874           {"update_window_dims", "inserted_window_dims",
5875            "scatter_dims_to_operand_dims", "index_vector_dim"},
5876           {[&]() { return parseDims(parser, updateWindowDims); },
5877            [&]() { return parseDims(parser, insertedWindowDims); },
5878            [&]() { return parseDims(parser, scatterDimsToOperandDims); },
5879            [&]() { return parser.parseInteger(indexVectorDim); }}))) {
5880     parser.emitError(parser.getCurrentLocation())
5881         << "failed parsing scatter dimension numbers attribute";
5882     return {};
5883   }
5884 
5885   return ScatterDimensionNumbersAttr::get(
5886       parser.getContext(), updateWindowDims, insertedWindowDims,
5887       scatterDimsToOperandDims, indexVectorDim);
5888 }
5889 
5890 // Custom printer and parser for GatherDimensionNumbersAttr.
5891 void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
5892   printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
5893               std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
5894               std::make_pair("start_index_map", getStartIndexMap()),
5895               std::make_pair("index_vector_dim", getIndexVectorDim()));
5896 }
5897 
5898 Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5899   if (failed(parser.parseLess())) return {};
5900 
5901   SmallVector<int64_t> offsetDims;
5902   SmallVector<int64_t> collapsedSliceDims;
5903   SmallVector<int64_t> startIndexMap;
5904   int64_t indexVectorDim = 0;
5905 
5906   if (failed(parseStruct(
5907           parser,
5908           {"offset_dims", "collapsed_slice_dims", "start_index_map",
5909            "index_vector_dim"},
5910           {[&]() { return parseDims(parser, offsetDims); },
5911            [&]() { return parseDims(parser, collapsedSliceDims); },
5912            [&]() { return parseDims(parser, startIndexMap); },
5913            [&]() { return parser.parseInteger(indexVectorDim); }}))) {
5914     parser.emitError(parser.getCurrentLocation())
5915         << "failed parsing gather dimension numbers attribute";
5916     return {};
5917   }
5918 
5919   return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims,
5920                                          collapsedSliceDims, startIndexMap,
5921                                          indexVectorDim);
5922 }
5923 
5924 // Custom printer and parser for DotDimensionNumbersAttr.
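// As a sketch, a typical batched-matmul configuration prints as
//   <lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
//    lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>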
5925 void DotDimensionNumbersAttr::print(AsmPrinter& printer) const {
5926   printStruct(
5927       printer, "dot",
5928       std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
5929       std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
5930       std::make_pair("lhs_contracting_dimensions",
5931                      getLhsContractingDimensions()),
5932       std::make_pair("rhs_contracting_dimensions",
5933                      getRhsContractingDimensions()));
5934 }
5935 
5936 Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5937   if (failed(parser.parseLess())) return {};
5938 
5939   SmallVector<int64_t> lhsBatchingDimensions;
5940   SmallVector<int64_t> rhsBatchingDimensions;
5941   SmallVector<int64_t> lhsContractingDimensions;
5942   SmallVector<int64_t> rhsContractingDimensions;
5943 
5944   if (failed(parseStruct(
5945           parser,
5946           {"lhs_batching_dimensions", "rhs_batching_dimensions",
5947            "lhs_contracting_dimensions", "rhs_contracting_dimensions"},
5948           {[&]() { return parseDims(parser, lhsBatchingDimensions); },
5949            [&]() { return parseDims(parser, rhsBatchingDimensions); },
5950            [&]() { return parseDims(parser, lhsContractingDimensions); },
5951            [&]() { return parseDims(parser, rhsContractingDimensions); }}))) {
5952     parser.emitError(parser.getCurrentLocation())
5953         << "failed parsing dot dimension numbers attribute";
5954     return {};
5955   }
5956   return DotDimensionNumbersAttr::get(
5957       parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions,
5958       lhsContractingDimensions, rhsContractingDimensions);
5959 }
5960 
5961 namespace {
5962 enum NonSpatialDim : int64_t {
5963   IOBatch = -1,    // Input or output batch dimension
5964   IOFeature = -2,  // Input or output feature dimension
5965   KIFeature = -3,  // Kernel input feature dimension
5966   KOFeature = -4,  // Kernel output feature dimensions.
5967 };
5968 
5969 struct DenseMapInfoNonSpatialDim {
5970   static inline NonSpatialDim getEmptyKey() {
5971     return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey());
5972   }
5973 
5974   static inline NonSpatialDim getTombstoneKey() {
5975     return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey());
5976   }
5977 
5978   static unsigned getHashValue(const NonSpatialDim& key) {
5979     return DenseMapInfo<int64_t>::getHashValue(key);
5980   }
5981 
5982   static bool isEqual(const NonSpatialDim& lhs, const NonSpatialDim& rhs) {
5983     return lhs == rhs;
5984   }
5985 };
5986 
5987 char nonSpatialDimToString(NonSpatialDim dim) {
5988   switch (dim) {
5989     case IOBatch:
5990       return 'b';
5991     case IOFeature:
5992       return 'f';
5993     case KIFeature:
5994       return 'i';
5995     case KOFeature:
5996       return 'o';
5997   }
5998   llvm_unreachable("Unknown NonSpatialDim");
5999 }
6000 }  // namespace
6001 
6002 // Custom printer and parser for convolution attribute.
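// As a sketch, an NHWC convolution with an HWIO kernel prints in the
// compressed form
//   [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// where b/f/i/o denote the non-spatial dimensions defined above and the
// integers are spatial dimension indices.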
6003 void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) {
6004   // TODO(b/202040055): we should check the attribute invariants and print the
6005   // "raw" form if they are violated; otherwise we may crash here.
6006   constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min();
6007   auto printDim =
6008       [&](ArrayRef<int64_t> spatialDims,
6009           ArrayRef<std::pair<int64_t, NonSpatialDim>> nonSpatialDims) {
6010         int64_t numDims = 0;
6011         if (!spatialDims.empty()) {
6012           numDims =
6013               *std::max_element(spatialDims.begin(), spatialDims.end()) + 1;
6014         }
6015         for (const auto& dim : nonSpatialDims) {
6016           numDims = std::max(numDims, dim.first + 1);
6017         }
6018 
6019         llvm::SmallVector<int64_t> dims(numDims, kUnknownDim);
6020         // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
6021         // spatial dimension index.
6022         for (const std::pair<int64_t, NonSpatialDim>& nonSpatialDim :
6023              nonSpatialDims) {
6024           dims[nonSpatialDim.first] = nonSpatialDim.second;
6025         }
6026         for (const auto& spatialDim : llvm::enumerate(spatialDims)) {
6027           dims[spatialDim.value()] = static_cast<int64_t>(spatialDim.index());
6028         }
6029 
6030         // Each set of dimension numbers is printed as a comma-separated list
6031         // surrounded by square brackets, e.g., [b, 0, 1, 2, f].
6032         p << '[';
6033         llvm::interleaveComma(dims, p, [&](int64_t dim) {
6034           if (dim == kUnknownDim) {
6035             p << "?";
6036           } else if (dim >= 0) {
6037             p << dim;
6038           } else {
6039             p << nonSpatialDimToString(static_cast<NonSpatialDim>(dim));
6040           }
6041         });
6042         p << ']';
6043       };
6044 
6045   printDim(dnums.getInputSpatialDimensions(),
6046            {{dnums.getInputBatchDimension(), IOBatch},
6047             {dnums.getInputFeatureDimension(), IOFeature}});
6048   p << "x";
6049   printDim(dnums.getKernelSpatialDimensions(),
6050            {{dnums.getKernelInputFeatureDimension(), KIFeature},
6051             {dnums.getKernelOutputFeatureDimension(), KOFeature}});
6052   p << "->";
6053   printDim(dnums.getOutputSpatialDimensions(),
6054            {{dnums.getOutputBatchDimension(), IOBatch},
6055             {dnums.getOutputFeatureDimension(), IOFeature}});
6056 }
6057 
6058 void printConvolutionDimensions(AsmPrinter& p, Operation*,
6059                                 ConvDimensionNumbersAttr dnums) {
6060   printConvolutionDimensions(p, dnums);
6061 }
6062 
6063 // Custom printer and parser for ConvDimensionNumbersAttr.
6064 void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const {
6065   printer << "<";
6066   printConvolutionDimensions(printer, *this);
6067   printer << ">";
6068 }
6069 
6070 // If the attribute is written with `#stablehlo.conv<raw ...>`, we parse it as
6071 // a struct instead of the compressed format. This enables writing tests that
6072 // cover impossible/invalid internal representations of the attribute.
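// A minimal sketch of the raw form (field values here are arbitrary, and the
// fields are comma-separated like the struct attributes above):
//   #stablehlo.conv<raw input_batch_dimension = 0, input_feature_dimension = 3,
//     input_spatial_dimensions = [1, 2], kernel_input_feature_dimension = 2,
//     kernel_output_feature_dimension = 3, kernel_spatial_dimensions = [0, 1],
//     output_batch_dimension = 0, output_feature_dimension = 3,
//     output_spatial_dimensions = [1, 2]>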
6073 static ParseResult parseConvolutionDimensionsRaw(
6074     AsmParser& parser, ConvDimensionNumbersAttr& dnums) {
6075   int64_t inputBatchDimension = 0;
6076   int64_t inputFeatureDimension = 0;
6077   SmallVector<int64_t> inputSpatialDimensions;
6078   int64_t kernelInputFeatureDimension = 0;
6079   int64_t kernelOutputFeatureDimension = 0;
6080   SmallVector<int64_t> kernelSpatialDimensions;
6081   int64_t outBatchDimension = 0;
6082   int64_t outputFeatureDimension = 0;
6083   SmallVector<int64_t> outputSpatialDimensions;
6084   if (failed(parseStruct(
6085           parser,
6086           {"input_batch_dimension", "input_feature_dimension",
6087            "input_spatial_dimensions", "kernel_input_feature_dimension",
6088            "kernel_output_feature_dimension", "kernel_spatial_dimensions",
6089            "output_batch_dimension", "output_feature_dimension",
6090            "output_spatial_dimensions"},
6091           {
6092               [&]() { return parser.parseInteger(inputBatchDimension); },
6093               [&]() { return parser.parseInteger(inputFeatureDimension); },
6094               [&]() { return parseDims(parser, inputSpatialDimensions); },
6095               [&]() {
6096                 return parser.parseInteger(kernelInputFeatureDimension);
6097               },
6098               [&]() {
6099                 return parser.parseInteger(kernelOutputFeatureDimension);
6100               },
6101               [&]() { return parseDims(parser, kernelSpatialDimensions); },
6102               [&]() { return parser.parseInteger(outBatchDimension); },
6103               [&]() { return parser.parseInteger(outputFeatureDimension); },
6104               [&]() { return parseDims(parser, outputSpatialDimensions); },
6105           }))) {
6106     parser.emitError(parser.getCurrentLocation())
6107         << "failed parsing dot dimension numbers attribute";
6108     return failure();
6109   }
6110   dnums = ConvDimensionNumbersAttr::get(
6111       parser.getBuilder().getContext(), inputBatchDimension,
6112       inputFeatureDimension, inputSpatialDimensions,
6113       kernelInputFeatureDimension, kernelOutputFeatureDimension,
6114       kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
6115       outputSpatialDimensions);
6116   return success();
6117 }
6118 
6119 ParseResult parseConvolutionDimensions(AsmParser& parser,
6120                                        ConvDimensionNumbersAttr& dnums) {
6121   // Parsing a single set of dim numbers gives the spatial dimensions as a
6122   // vector of int64_t and the non-spatial dimensions as a map from the
6123   // NonSpatialDim enum to the parsed dimension index.
6124   using parse_dim_result_t =
6125       std::pair<llvm::SmallVector<int64_t>,
6126                 llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
6127                                     DenseMapInfoNonSpatialDim>>;
6128 
6129   // Note that allowedNonSpatialDims is an ordered set (as opposed to an
6130   // unordered set) because it is used to print the list of allowed non-spatial
6131   // dims in error messages, and ordering keeps those messages deterministic.
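  // For example, parsing "[b, 0, 1, f]" with allowed dims {IOBatch, IOFeature}
  // yields spatialDims = [1, 2] (tensor positions of spatial dims 0 and 1) and
  // nonSpatialDims = {IOBatch -> 0, IOFeature -> 3}.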
6132   auto parseDims =
6133       [&](std::set<NonSpatialDim, std::greater<>> allowedNonSpatialDims,
6134           parse_dim_result_t& parsedDims) -> ParseResult {
6135     auto& spatialDims = std::get<0>(parsedDims);
6136     auto& nonSpatialDims = std::get<1>(parsedDims);
6137     spatialDims.clear();
6138     nonSpatialDims.clear();
6139 
6140     // Parse the starting [
6141     if (parser.parseLSquare()) {
6142       return failure();
6143     }
6144 
6145     llvm::SmallDenseMap<int64_t, int64_t> spatialDimsMap;
6146     constexpr int64_t kInvalidDimension = -1;
6147     // Keep track of the maximum spatial dimension parsed, as we expect to see
6148     // all the dimensions from 0 up to the maximum parsed dimension.
6149     int64_t maxParsedSpatialDim = kInvalidDimension;
6150 
6151     int64_t index = 0;
6152     do {
6153       int64_t spatialDim;
6154       auto dimLocation = parser.getCurrentLocation();
6155       OptionalParseResult parseResult = parser.parseOptionalInteger(spatialDim);
6156       if (parseResult.hasValue()) {
6157         if (parseResult.getValue().failed()) {
6158           return failure();
6159         }
6160         // We were successful in parsing an integer. Check if it is a valid
6161         // dimension (non-negative and no duplicate) and add its index to the
6162         // spatial dims map.
6163         if (spatialDim < 0)
6164           return parser.emitError(dimLocation)
6165                  << "Unexpected dimension " << spatialDim;
6166         if (!spatialDimsMap
6167                  .insert(std::pair<int64_t, int64_t>(spatialDim, index))
6168                  .second)
6169           return parser.emitError(dimLocation)
6170                  << "Duplicate entries for spatial dimension " << spatialDim;
6171         maxParsedSpatialDim = std::max(spatialDim, maxParsedSpatialDim);
6172       } else if (!parser.parseOptionalQuestion()) {
6173         // Do nothing other than increment `index` at the bottom of the loop;
6174         // '?' means "unknown dimension", and it's not represented in the
6175         // return value of this function.
6176       } else {
6177         // We did not parse an integer or question mark. We expect a keyword
6178         // token.
6179         StringRef keyword;
6180         if (parser.parseKeyword(&keyword)) {
6181           return failure();
6182         }
6183         if (keyword.size() != 1 || allowedNonSpatialDims.empty()) {
6184           return parser.emitError(dimLocation, "Unexpected keyword ")
6185                  << keyword;
6186         }
6187         // Check if the keyword matches one of the allowed non-spatial dims.
6188         // If so, add it to the non-spatial dims and remove it from the
6189         // allowed set so that it won't be allowed again.
6190         bool isAllowed = false;
6191         for (NonSpatialDim allowed : allowedNonSpatialDims) {
6192           if (keyword[0] == nonSpatialDimToString(allowed)) {
6193             nonSpatialDims.insert({allowed, index});
6194             allowedNonSpatialDims.erase(allowed);
6195             isAllowed = true;
6196             break;
6197           }
6198         }
6199 
6200         if (!isAllowed) {
6201           mlir::InFlightDiagnostic diag =
6202               parser.emitError(dimLocation, "Unexpected dimension ");
6203           diag << keyword << ", expecting ";
6204           llvm::interleaveComma(
6205               allowedNonSpatialDims, diag,
6206               [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
6207           return diag;
6208         }
6209       }
6210       index++;
6211     } while (parser.parseOptionalComma().succeeded());
6212 
6213     // Make sure all expected non-spatial dimensions are parsed.
6214     if (!allowedNonSpatialDims.empty()) {
6215       mlir::InFlightDiagnostic diag =
6216           parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
6217       llvm::interleaveComma(
6218           allowedNonSpatialDims, diag,
6219           [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
6220       diag << " not specified";
6221       return diag;
6222     }
6223 
6224     // parse ending ]
6225     if (parser.parseRSquare()) {
6226       return failure();
6227     }
6228 
6229     // Number of expected spatial dimensions is one more than the maximum parsed
6230     // spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the
6231     // maximum parsed spatial dimension is 3 and the number of expected spatial
6232     // dimensions is 4.
6233     int64_t numSpatialDimensions = maxParsedSpatialDim + 1;
6234     spatialDims.resize(numSpatialDimensions);
6235     // Store spatial dimensions in a vector which maps spatial dim (vector
6236     // index) -> index in the tensor dimensions. For example, for parsed
6237     // dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would
6238     // be [0, 5, 2, 1].
6239     //
6240     // Get all the unspecified spatial dimensions to throw a more descriptive
6241     // error later.
6242     llvm::SmallVector<int64_t> unspecifiedSpatialDims;
6243     constexpr int kPrintUnspecifiedDimsMax = 10;
6244     for (int dim = 0; dim < numSpatialDimensions; ++dim) {
6245       auto it = spatialDimsMap.find(dim);
6246       if (it == spatialDimsMap.end()) {
6247         // Have an upper bound on the number of unspecified dimensions to print
6248         // in the error message.
6249         if (unspecifiedSpatialDims.size() < kPrintUnspecifiedDimsMax)
6250           unspecifiedSpatialDims.push_back(dim);
6251         continue;
6252       }
6253       spatialDims[dim] = it->second;
6254     }
6255 
6256     // Verify that we got all spatial dimensions between 0 and maximum parsed
6257     // spatial dimension.
6258     if (!unspecifiedSpatialDims.empty()) {
6259       mlir::InFlightDiagnostic diag = parser.emitError(
6260           parser.getCurrentLocation(), "Expected spatial dimensions ");
6261       llvm::interleaveComma(unspecifiedSpatialDims, diag);
6262       diag << " not specified";
6263       return diag;
6264     }
6265 
6266     return success();
6267   };
6268 
6269   parse_dim_result_t parsedDims;
6270   if (parseDims({IOBatch, IOFeature}, parsedDims)) {
6271     return failure();
6272   }
6273   llvm::SmallVector<int64_t> inputSpatialDimensions = parsedDims.first;
6274   int64_t inputBatchDimension = parsedDims.second[IOBatch];
6275   int64_t inputFeatureDimension = parsedDims.second[IOFeature];
6276   if (parser.parseKeyword("x")) return failure();
6277   if (parseDims({KIFeature, KOFeature}, parsedDims)) {
6278     return failure();
6279   }
6280   llvm::SmallVector<int64_t> kernelSpatialDimensions = parsedDims.first;
6281   int64_t kernelInputFeatureDimension = parsedDims.second[KIFeature];
6282   int64_t kernelOutputFeatureDimension = parsedDims.second[KOFeature];
6283   if (parser.parseArrow()) {
6284     return failure();
6285   }
6286   if (parseDims({IOBatch, IOFeature}, parsedDims)) {
6287     return failure();
6288   }
6289   llvm::SmallVector<int64_t> outputSpatialDimensions = parsedDims.first;
6290   const int64_t outBatchDimension = parsedDims.second[IOBatch];
6291   const int64_t outputFeatureDimension = parsedDims.second[IOFeature];
6292   dnums = ConvDimensionNumbersAttr::get(
6293       parser.getBuilder().getContext(), inputBatchDimension,
6294       inputFeatureDimension, inputSpatialDimensions,
6295       kernelInputFeatureDimension, kernelOutputFeatureDimension,
6296       kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
6297       outputSpatialDimensions);
6298 
6299   return success();
6300 }
6301 
6302 Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
6303   if (failed(parser.parseLess())) return {};
6304   ConvDimensionNumbersAttr dnums;
6305   if (succeeded(parser.parseOptionalKeyword("raw"))) {
6306     if (failed(parseConvolutionDimensionsRaw(parser, dnums))) return {};
6307     return dnums;
6308   }
6309   if (failed(parseConvolutionDimensions(parser, dnums))) return {};
6310   if (failed(parser.parseGreater())) return {};
6311   return dnums;
6312 }
6313 
6314 // Custom printer and parser for ArgResultAliasAttr.
6315 constexpr char kMustAlias[] = "must_alias";
6316 constexpr char kResult[] = "result_index";
6317 constexpr char kArgTupleIndices[] = "tuple_indices";
6318 
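// As a sketch, an alias from argument tuple element [0] to tuple element [0]
// of result 1, with must_alias set, prints as
//   <tuple_indices = [0], result_index = [1, 0], must_alias>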
6319 void ArgResultAliasAttr::print(AsmPrinter& printer) const {
6320   printer << "<";
6321 
6322   // The attribute can have empty tuple indices. Only print argument tuple
6323   // indices if they are non-empty.
6324   if (!getArgTupleIndices().empty())
6325     printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";
6326 
6327   // Print the result index followed by any result tuple indices if present.
6328   printer << kResult << " = [";
6329   printer << getResultIndex();
6330   if (!getResultTupleIndices().empty()) {
6331     printer << ", " << getResultTupleIndices();
6332   }
6333   printer << "]";
6334 
6335   // Print the "must_alias" keyword if this is a must alias, otherwise skip.
6336   if (getIsMustAlias()) printer << ", " << kMustAlias;
6337 
6338   printer << ">";
6339 }
6340 
6341 Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
6342   if (failed(parser.parseLess())) return {};
6343   llvm::SmallVector<int64_t> argTupleIndices;
6344   // The first element of result indices holds the aliased result index and the
6345   // remaining elements are the result tuple indices.
6346   llvm::SmallVector<int64_t> resultIndices;
6347   bool isMustAlias = false;
6348 
6349   // This conveys to parseStruct that the keyword "must_alias" (3rd field) is
6350   // not followed by an "=", but the other fields are.
6351   llvm::SmallVector<bool, 3> parseEqual = {true, true, false};
6352 
6353   if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
6354                          {[&]() { return parseDims(parser, argTupleIndices); },
6355                           [&]() {
6356                             // Since the first element is the index of result,
6357                             // at least one element is expected.
6358                             return parseDimsWithMinimumElements(
6359                                 parser, resultIndices, /*minElements=*/1);
6360                           },
6361                           [&]() {
6362                             // always succeeds if the keyword "must_alias" was
6363                             // parsed
6364                             isMustAlias = true;
6365                             return success();
6366                           }},
6367                          parseEqual))) {
6368     parser.emitError(parser.getCurrentLocation())
6369         << "failed parsing argument-result alias attribute";
6370     return {};
6371   }
6372 
6373   int64_t resultIndex = resultIndices[0];
6374   auto resultTupleIndices =
6375       ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()};
6376 
6377   return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices,
6378                                  resultIndex, resultTupleIndices, isMustAlias);
6379 }
6380 
6381 // Returns the type pointed to by `indices` within `type`. If the indices
6382 // are invalid, returns a null Type.
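// For example, for type tuple<tensor<f32>, tuple<tensor<i32>>> and indices
// [1, 0] this returns tensor<i32>, while indices [2] are out of range and
// yield a null Type.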
6383 static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
6384   Type current = type;
6385   for (auto index : indices) {
6386     TupleType tupleType = current.dyn_cast<TupleType>();
6387     if (!tupleType || index >= static_cast<int64_t>(tupleType.size()))
6388       return {};
6389     current = tupleType.getType(index);
6390   }
6391   return current;
6392 }
6393 
6394 static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
6395                                               ArgResultAliasAttr aliasAttr,
6396                                               unsigned argIndex,
6397                                               Operation* op) {
6398   // The attribute can only be applied to function-like operations.
6399   if (!isa<mlir::FunctionOpInterface>(op))
6400     return op->emitOpError() << "attribute " << attrName
6401                              << " can only be used on function-like operations";
6402 
6403   // Verify there are no negative indices.
6404   auto tupleIndices = llvm::concat<const int64_t>(
6405       aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices());
6406   if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) ||
6407       aliasAttr.getResultIndex() < 0)
6408     return op->emitOpError()
6409            << "attribute " << attrName
6410            << " expects all argument and result indices to be >= 0";
6411 
6412   // Verify that the result index is not out of range. Since the attribute is a
6413   // function argument attribute, the argument index is always correct when this
6414   // verifier is called.
6415   FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
6416   ArrayRef<Type> argTypes = funcOp.getArgumentTypes();
6417   ArrayRef<Type> resultTypes = funcOp.getResultTypes();
6418   if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size()))
6419     return op->emitOpError()
6420            << "attribute " << attrName
6421            << " result index is out of range, must be <" << resultTypes.size();
6422 
6423   // Verify that argument and result types pointed to by the indices are valid
6424   // and compatible.
6425   Type argType = getTypeFromTupleIndices(argTypes[argIndex],
6426                                          aliasAttr.getArgTupleIndices());
6427   if (!argType)
6428     return op->emitOpError()
6429            << "attribute " << attrName << " argument tuple indices are invalid";
6430   Type resultType =
6431       getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()],
6432                               aliasAttr.getResultTupleIndices());
6433   if (!resultType)
6434     return op->emitOpError()
6435            << "attribute " << attrName << " result tuple indices are invalid";
6436 
6437   if (failed(mlir::verifyCompatibleShape(argType, resultType)) ||
6438       getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType))
6439     return op->emitOpError() << "attribute " << attrName
6440                              << " aliases do not have compatible types, "
6441                              << argType << " vs. " << resultType;
6442   return success();
6443 }
6444 
6445 namespace {
6446 // Custom formatting for convolution window attributes.
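// As a sketch, the window attributes of a 2-D convolution print as, e.g.,
//   stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]
// where padding is rendered as an Nx2 list of [low, high] pairs and boolean
// window-reversal values are rendered as 0/1.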
6447 void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) {
6448   if (attribute.getElementType().isInteger(/*width=*/1)) {
6449     // boolean attribute.
6450     llvm::interleaveComma(attribute.getValues<bool>(), p,
6451                           [&](bool b) { p << (b ? 1 : 0); });
6452     return;
6453   }
6454   if (attribute.getType().getRank() == 2) {
6455     // Padding is an Nx2 attribute.
6456     auto it = attribute.value_begin<int64_t>();
6457     std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
6458                                                     2);
6459     for (auto& item : values) {
6460       int64_t first = *it;
6461       ++it;
6462       int64_t second = *it;
6463       ++it;
6464       item = {first, second};
6465     }
6466     llvm::interleaveComma(
6467         values, p, [&](const std::pair<int64_t, int64_t> pair) {
6468           p << '[' << pair.first << ", " << pair.second << ']';
6469         });
6470   } else {
6471     llvm::interleaveComma(attribute.getValues<int64_t>(), p);
6472   }
6473 }
6474 }  // namespace
6475 
6476 void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
6477                            llvm::Optional<DenseIntElementsAttr> windowStrides,
6478                            llvm::Optional<DenseIntElementsAttr> padding,
6479                            llvm::Optional<DenseIntElementsAttr> lhsDilation,
6480                            llvm::Optional<DenseIntElementsAttr> rhsDilation,
6481                            llvm::Optional<DenseElementsAttr> windowReversal) {
6482   using pair_t = std::pair<DenseElementsAttr, StringRef>;
6483   std::array<pair_t, 5> printedAttributes = {{
6484       {windowStrides ? *windowStrides : nullptr, "stride"},
6485       {padding ? *padding : nullptr, "pad"},
6486       {lhsDilation ? *lhsDilation : nullptr, "lhs_dilate"},
6487       {rhsDilation ? *rhsDilation : nullptr, "rhs_dilate"},
6488       {windowReversal ? *windowReversal : nullptr, "reverse"},
6489   }};
6490 
6491   // Do not print attributes that do not exist.
6492   auto nonNullAttributes = llvm::make_filter_range(
6493       printedAttributes,
6494       [](const pair_t& a) { return static_cast<bool>(a.first); });
6495 
6496   llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& a) {
6497     p << a.second << " = [";
6498     printWindowAttribute(p, a.first);
6499     p << "]";
6500   });
6501 }
6502 
6503 ParseResult parseWindowAttributes(OpAsmParser& parser,
6504                                   DenseIntElementsAttr& windowStrides,
6505                                   DenseIntElementsAttr& padding,
6506                                   DenseIntElementsAttr& lhsDilation,
6507                                   DenseIntElementsAttr& rhsDilation,
6508                                   DenseElementsAttr& windowReversal) {
6509   StringRef attributeName;
6510 
6511   llvm::StringSet<> allowedAttributeNames{
6512       {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};
6513 
6514   while (parser.parseOptionalKeyword(&attributeName).succeeded()) {
6515     // Verify that the attribute name is valid and erase it.
6516     if (!allowedAttributeNames.erase(attributeName)) {
6517       return parser.emitError(parser.getCurrentLocation(),
6518                               "Unexpected keyword ")
6519              << attributeName;
6520     }
6521 
6522     if (parser.parseEqual()) {
6523       return failure();
6524     }
6525 
6526     // Parse the attribute value. We need to support parsing either a 1D or an
6527     // Nx2 array of integers.
6528     llvm::SmallVector<int64_t> values;
6529     auto int64Parser = [&]() {
6530       return parser.parseInteger(values.emplace_back(0));
6531     };
6532 
6533     if (attributeName == "pad") {
6534       // Parse 2D array of integers.
6535       // Helper to parse an array of two integer elements such as [e0, e1].
6536       auto innerParser = [&]() -> ParseResult {
6537         size_t numOldElements = values.size();
6538         if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Square,
6539                                            int64Parser))
6540           return failure();
6541         size_t numParsedElements = values.size() - numOldElements;
6542         constexpr size_t kExpectedElements = 2;
6543         if (numParsedElements != kExpectedElements)
6544           return parser.emitError(parser.getCurrentLocation())
6545                  << "Expected array with " << kExpectedElements
6546                  << " elements, got " << numParsedElements
6547                  << " elements instead";
6548         return success();
6549       };
6550 
6551       if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
6552                                          innerParser)) {
6553         return failure();
6554       }
6555       const int64_t size = static_cast<int64_t>(values.size());
6556       // values should be filled with the Nx2 padding values.
6557       assert(size % 2 == 0);
6558       auto ty = RankedTensorType::get({size / 2, 2},
6559                                       parser.getBuilder().getIntegerType(64));
6560       padding = DenseIntElementsAttr::get(ty, values);
6561     } else {
6562       // Parse 1D array of integers.
6563       if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
6564                                          int64Parser)) {
6565         return failure();
6566       }
6567       const int64_t size = static_cast<int64_t>(values.size());
6568       if (attributeName == "reverse") {
6569         auto ty = RankedTensorType::get({size},
6570                                         parser.getBuilder().getIntegerType(1));
6571         auto boolVector = llvm::to_vector<4>(
6572             llvm::map_range(values, [](int64_t v) { return v != 0; }));
6573         windowReversal = DenseElementsAttr::get(ty, boolVector);
6574       } else {
6575         auto attr = parser.getBuilder().getI64TensorAttr(values);
6576 
6577         if (attributeName == "stride") {
6578           windowStrides = attr;
6579         } else if (attributeName == "lhs_dilate") {
6580           lhsDilation = attr;
6581         } else if (attributeName == "rhs_dilate") {
6582           rhsDilation = attr;
6583         } else {
6584           llvm_unreachable("Unexpected attribute name");
6585         }
6586       }
6587     }
6588     // continue parsing if there is a comma at the end.
6589     if (parser.parseOptionalComma().failed()) break;
6590   }
6591   return success();
6592 }
6593 
6594 //===----------------------------------------------------------------------===//
6595 // Builder utilities
6596 //===----------------------------------------------------------------------===//
6597 
6598 // Builds the region `body` for stablehlo.sort's comparator: for each type in
6599 // `elementTypes`, creates two block arguments, one for the lhs and one for the
6600 // rhs, and generates a stablehlo.compare op to compare them with the given
6601 // `direction`.
6602 //
6603 // Note that this currently only compares the first pair of block arguments.
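// As a rough sketch, for a single f32 element type and direction GT, the
// generated region has one block with arguments (%lhs: tensor<f32>,
// %rhs: tensor<f32>), a stablehlo.compare of %lhs and %rhs with direction GT,
// and a stablehlo.return of the comparison result.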
6604 static void buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,
6605                                     ComparisonDirection direction,
6606                                     llvm::Optional<StringRef> compareType,
6607                                     Region* body, OpBuilder* builder) {
6608   OpBuilder::InsertionGuard insertionPointGuard(*builder);
6609 
6610   Location loc = body->getLoc();
6611   Block* block = builder->createBlock(body);
6612   // Add two arguments for each element type.
6613   for (Type elementType : elementTypes) {
6614     TensorType tensorType = RankedTensorType::get({}, elementType);
6615     block->addArguments({tensorType, tensorType},
6616                         SmallVector<Location, 2>(2, loc));
6617   }
6618 
6619   ComparisonType typeAttr;
6620   if (compareType)
6621     typeAttr = symbolizeComparisonType(*compareType).value();
6622   else
6623     typeAttr = ComparisonType::NOTYPE;
6624   Value compare = builder->create<CompareOp>(
6625       loc, block->getArgument(0), block->getArgument(1), direction, typeAttr);
6626 
6627   builder->create<ReturnOp>(loc, compare);
6628 }
6629 
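// Example usage (hypothetical values), e.g. from inside a rewrite pattern:
//   SortOp sort = createSortOp(&rewriter, loc, {keys, values},
//                              {keyElementType, valueElementType},
//                              /*dimension=*/0, /*isStable=*/true,
//                              ComparisonDirection::LT);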
6630 SortOp createSortOp(PatternRewriter* rewriter, const Location& loc,
6631                     const llvm::ArrayRef<Value>& operands,
6632                     const llvm::ArrayRef<Type>& elementTypes, int64_t dimension,
6633                     bool isStable, ComparisonDirection direction) {
6634   assert(!operands.empty() && "No operands to sort");
6635   // Create the sort op.
6636   auto sortOp = rewriter->create<SortOp>(loc, operands, dimension, isStable);
6637 
6638   // Use the TOTALORDER comparison type instead of the default comparison if
6639   // any element type is a floating-point type.
6640   llvm::Optional<StringRef> compareType = llvm::None;
6641   for (auto const& elementType : elementTypes)
6642     if (elementType.isa<FloatType>()) {
6643       compareType.emplace("TOTALORDER");
6644       break;
6645     }
6646   buildSortComparisonBody(elementTypes, direction, compareType,
6647                           &sortOp.comparator(), rewriter);
6648   return sortOp;
6649 }
6650 
6651 //===----------------------------------------------------------------------===//
6652 // StableHLO Dialect Hooks
6653 //===----------------------------------------------------------------------===//
6654 
6655 Operation* StablehloDialect::materializeConstant(OpBuilder& builder,
6656                                                  Attribute value, Type type,
6657                                                  Location loc) {
6658   auto elementsAttr = value.dyn_cast<ElementsAttr>();
6659   // HLO dialect constants only support ElementsAttr, unlike the standard
6660   // dialect constant, which supports all attributes.
6661   if (!elementsAttr) return nullptr;
6662   // HLO dialect constants require the type of value and result to match.
6663   if (type != elementsAttr.getType()) return nullptr;
6664 
6665   return builder.create<ConstantOp>(loc, type, elementsAttr);
6666 }
6667 
6668 LogicalResult StablehloDialect::verifyRegionArgAttribute(
6669     Operation* op, unsigned /*regionIndex*/, unsigned argIndex,
6670     NamedAttribute attr) {
6671   if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
6672     if (failed(
6673             verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op)))
6674       return failure();
6675   }
6676   return success();
6677 }
6678 
6679 LogicalResult StablehloDialect::verifyOperationAttribute(Operation* op,
6680                                                          NamedAttribute attr) {
6681   if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
6682     if (!isa<mlir::FunctionOpInterface>(op))
6683       return op->emitOpError()
6684              << "attribute " << attr.getName()
6685              << " can only be used on function-like operations";
6686   }
6687   return success();
6688 }
6689 
6690 }  // namespace stablehlo
6691 }  // namespace mlir
6692