1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Copyright 2022 The StableHLO Authors.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #include "dialect/StablehloOps.h"
18
19 #include <assert.h>
20 #include <stddef.h>
21 #include <stdint.h>
22
23 #include <algorithm>
24 #include <array>
25 #include <cstdint>
26 #include <functional>
27 #include <numeric>
28 #include <set>
29 #include <unordered_map>
30 #include <utility>
31
32 #include "dialect/StablehloOps.h.inc"
33 #include "llvm/ADT/APFloat.h"
34 #include "llvm/ADT/APInt.h"
35 #include "llvm/ADT/ArrayRef.h"
36 #include "llvm/ADT/DenseMap.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/StringExtras.h"
40 #include "llvm/ADT/StringRef.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/Twine.h"
43 #include "llvm/ADT/TypeSwitch.h"
44 #include "llvm/ADT/iterator_range.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/FormatVariadic.h"
47 #include "llvm/Support/MathExtras.h"
48 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
49 #include "mlir/Dialect/Complex/IR/Complex.h"
50 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
51 #include "mlir/Dialect/Tensor/IR/Tensor.h"
52 #include "mlir/IR/Attributes.h"
53 #include "mlir/IR/Builders.h"
54 #include "mlir/IR/BuiltinAttributes.h"
55 #include "mlir/IR/BuiltinTypes.h"
56 #include "mlir/IR/Diagnostics.h"
57 #include "mlir/IR/Dialect.h"
58 #include "mlir/IR/FunctionInterfaces.h"
59 #include "mlir/IR/Location.h"
60 #include "mlir/IR/MLIRContext.h"
61 #include "mlir/IR/Matchers.h"
62 #include "mlir/IR/OpDefinition.h"
63 #include "mlir/IR/OpImplementation.h"
64 #include "mlir/IR/Operation.h"
65 #include "mlir/IR/OperationSupport.h"
66 #include "mlir/IR/PatternMatch.h"
67 #include "mlir/IR/TypeUtilities.h"
68 #include "mlir/IR/Types.h"
69 #include "mlir/IR/Value.h"
70 #include "mlir/Support/LLVM.h"
71 #include "mlir/Support/LogicalResult.h"
72 #include "mlir/Transforms/InliningUtils.h"
73
74 // Include order matters
75 #include "dialect/StablehloEnums.cpp.inc"
76 #define GET_ATTRDEF_CLASSES
77 #include "dialect/StablehloAttrs.cpp.inc"
78
79 namespace mlir {
80 namespace stablehlo {
81 namespace {
createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,ArrayRef<Type> types,SmallVector<OpAsmParser::Argument> & args)82 void createArgs(ArrayRef<OpAsmParser::UnresolvedOperand> operands,
83 ArrayRef<Type> types,
84 SmallVector<OpAsmParser::Argument>& args) {
85 for (auto argAndType : llvm::zip(operands, types)) {
86 auto& arg = args.emplace_back();
87 arg.ssaName = std::get<0>(argAndType);
88 arg.type = std::get<1>(argAndType);
89 }
90 }
91
__anona165a9100202(SmallVector<int64_t>& nums) 92 const auto hasDuplicates = [](SmallVector<int64_t>& nums) {
93 if (!llvm::is_sorted(nums)) std::sort(nums.begin(), nums.end());
94 auto* last = std::unique(nums.begin(), nums.end());
95 return last != nums.end();
96 };
97
98 //===----------------------------------------------------------------------===//
99 // Utilities for the canonicalize patterns
100 //===----------------------------------------------------------------------===//
101
// Verifies that dimension attribute for the op correctly indexes in operand or
// result shape.
template <typename OpT>
static LogicalResult verifyDimAttr(OpT op) {
  // Prefer the operand's rank; fall back to the result's rank. If neither
  // type is ranked there is nothing to validate against, so accept.
  int64_t rank = -1;
  if (auto ty = op.operand().getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else if (auto ty = op.getType().template dyn_cast<RankedTensorType>()) {
    rank = ty.getRank();
  } else {
    return success();
  }

  // The dimension attribute must name a valid axis: 0 <= dim < rank.
  int64_t dim = op.dimension();
  if (dim < 0 || dim >= rank)
    return op.emitOpError() << "requires dimension attribute in range [0, "
                            << rank << "); found (" << dim << ")";
  return success();
}
121
122 // Check if the dimension size is dynamic.
isDynamicDimSize(int64_t val)123 inline static bool isDynamicDimSize(int64_t val) {
124 return val == ShapedType::kDynamicSize;
125 }
126
// Common shape function helper for RngNormal and RngUniform.
static LogicalResult rngInferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  if (operands.size() != 3)
    return emitOptionalError(location, "expected 3 operands");

  SmallVector<int64_t> shapeVector;
  Value shapeOperand = operands[2];
  auto shapeOperandType = shapeOperand.getType().cast<ShapedType>();
  Type elementType = getElementTypeOrSelf(operands[1]);

  // Operand `shape` (1D by ODS) may be a constant or not, if `shape` is:
  // 1. not constant and has a dynamic dim (tensor<?x>): infer tensor<*x>.
  // 2. not constant nor dynamic (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
  // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.

  // Match to check whether the `shape` operand is a constant.
  DenseIntElementsAttr shape;
  if (!matchPattern(shapeOperand, m_Constant(&shape))) {
    // Case 1: result rank is unknown -> unranked result shape.
    int size = shapeOperandType.getDimSize(0);
    if (isDynamicDimSize(size)) {
      inferredReturnShapes.emplace_back(elementType);
      return success();
    }
    // Case 2: rank is known but individual dimension sizes are not.
    shapeVector.resize(size, ShapedType::kDynamicSize);
    inferredReturnShapes.emplace_back(shapeVector, elementType);
    return success();
  }

  // Case 3: `shape` operand is a constant; read its values directly.
  shapeVector.reserve(shape.size());
  for (const APInt& fp : shape.getValues<APInt>())
    shapeVector.push_back(fp.getSExtValue());
  inferredReturnShapes.emplace_back(shapeVector, elementType);
  return success();
}
165
166 // Returns a new scalar integer value having type `type`. Here `type` must be
167 // an integer or index type.
maybeCastTo(OpBuilder & b,Location loc,Value value,Type type)168 Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) {
169 if (type == value.getType()) return value;
170 assert(type.isIndex() || value.getType().isIndex());
171 return b.create<arith::IndexCastOp>(loc, type, value);
172 }
173
174 //===----------------------------------------------------------------------===//
175 // Utilities for verifiers
176 //===----------------------------------------------------------------------===//
177
178 // Convert a 1D dense int64 attribute to a list of values.
convertDenseIntAttr(llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr)179 SmallVector<int64_t> convertDenseIntAttr(
180 llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr) {
181 if (!optionalAttr.has_value()) return SmallVector<int64_t>{};
182
183 mlir::DenseIntElementsAttr attr = *optionalAttr;
184 auto values = attr.getValues<int64_t>();
185 return {values.begin(), values.end()};
186 }
187
// Convert a 1D or Nx2 dense int64 attribute to a list of tuples.
FailureOr<SmallVector<std::pair<int64_t, int64_t>>> convertNx2Attribute(
    llvm::Optional<mlir::DenseIntElementsAttr> optionalAttr, Location loc) {
  // A missing attribute converts to an empty list of pairs.
  if (!optionalAttr.has_value())
    return SmallVector<std::pair<int64_t, int64_t>>{};
  mlir::DenseIntElementsAttr attr = *optionalAttr;

  auto attrType = attr.getType().cast<RankedTensorType>();  // ensured by ODS.
  if (attrType.getRank() > 1) {
    // 2D form: shape must be exactly {N, 2}.
    if (attrType.getRank() != 2 || attrType.getShape()[1] != 2)
      return (mlir::emitError(loc) << "expects the shape of padding-attribute "
                                      "to be {N, 2}, but got {"
                                   << attrType.getShape() << "}.",
              failure());
  } else {
    // Padding values can be provided as a 1D vector as well.
    if (attr.getValues<int64_t>().size() % 2 != 0)
      return (mlir::emitError(loc)
                  << "expects the padding-entries to have even number of "
                     "elements, but got "
                  << attr.getValues<int64_t>().size() << " elements.",
              failure());
  }

  // Consume the flat element list two at a time: (low, high) per entry.
  auto it = attr.getValues<int64_t>().begin();
  SmallVector<std::pair<int64_t, int64_t>> out(attr.getNumElements() / 2);
  for (auto& item : out) {
    int64_t first = *it;
    ++it;
    int64_t second = *it;
    ++it;
    item = {first, second};
  }
  return out;
}
223
// If a window with the given bound in some dimension is dilated with the given
// dilation factor in that dimension, then the value returned is the bound for
// the array in that dimension after dilation.
//
// For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
// window with values 1, x, 2, x, 3, where x indicates holes left by the
// dilation. So DilatedBound(3, 2) == 5.
int64_t dilatedBound(int64_t bound, int64_t dilation) {
  // Fixed typo in the assertion message ("dialate" -> "dilate").
  assert(bound >= 0 && "The dimension to dilate must be >= 0");
  if (bound == 0) return 0;

  // Suppose the array has three entries 123 and the dilation factor is 4. Then
  // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
  // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
  // add 1 to account for the final input element.
  return (bound - 1) * dilation + 1;
}
241
// Returns the number of valid positions of a window with the given size and
// stride within an array with the given bound. This is the bound of an output
// array with one element per valid position of the window.
//
// For example, for arguments of (bound=5, window_size=2, stride=2), the
// returned value is 2. There are valid positions at offset 0 and offset 2,
// while offset 4 is not valid since the window's last entry would be at 5,
// which is beyond the bound of 5.
int64_t stridedBound(int64_t bound, int64_t windowSize, int64_t stride) {
  assert(windowSize >= 0 && "Expected window size to be >= 0");
  assert(bound >= 0 && "Expected bound to be >= 0");

  // An empty base area, or a window too large to fit at all, admits no
  // valid positions.
  if (windowSize > bound || bound == 0) return 0;

  // Without considering stride, the maximum valid offset is bound -
  // window_size. Taking stride into account, the valid offsets then have the
  // form q * stride for q = 0, ..., Q such that q * stride <= bound -
  // window_size, yielding Q + 1 positions in total.
  const int64_t maxOffset = bound - windowSize;
  return maxOffset / stride + 1;
}
263
// WindowDimension describes how the kernel window moves across the base area
// in a particular dimension.
// Describes the windowing in an operation such as convolution.
// The window is moved across a base area and for each position of the
// window a computation is performed. The fields below describe the
// window and the movement of the window across a base area.
struct WindowDimension {
  int64_t size = 0;             // Window extent along this dimension.
  int64_t stride = 1;           // Step between successive window positions.
  int64_t paddingLow = 0;       // Padding added before the base area.
  int64_t paddingHigh = 0;      // Padding added after the base area.
  int64_t windowDilation = 1;   // Dilation applied to the window elements.
  int64_t baseDilation = 1;     // Dilation applied to the base area.
  bool windowReversal = false;  // Whether the window is reversed in this dim.
};
279
// Verifies various properties of window-attributes (viz., stride, padding,
// lhs_dilation and rhs_dilation) and collects all the window-attributes for
// each kernel spatial dimensions.
FailureOr<SmallVector<WindowDimension>>
verifyWindowAttributesAndInferWindowDimensions(
    ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides,
    ArrayRef<std::pair<int64_t, int64_t>> padding,
    ArrayRef<int64_t> lhsDilation, ArrayRef<int64_t> rhsDilation,
    Location loc) {
  // Each optional attribute may be omitted entirely (size 0); otherwise it
  // must supply exactly one entry per window dimension.
  const auto verifySize = [&](const size_t attrSize,
                              StringRef attrName) -> LogicalResult {
    if (attrSize == 0 || attrSize == windowDimensions.size()) return success();
    return mlir::emitError(loc)
           << "expects " << attrName
           << " to have same dimension-size as size of "
              "window dimensions "
              "("
           << windowDimensions.size() << "), but got: " << attrSize << ".";
  };

  if (failed(verifySize(windowStrides.size(), "window-strides")))
    return failure();
  if (failed(verifySize(lhsDilation.size(), "base-dilation factors")))
    return failure();
  if (failed(verifySize(rhsDilation.size(), "window-dilation factors")))
    return failure();
  if (failed(verifySize(padding.size(), "padding-entries"))) return failure();

  // Populate one WindowDimension per window dimension, validating each
  // attribute as it is read. Omitted attributes keep the struct's defaults
  // (stride/dilations = 1, padding = 0).
  SmallVector<WindowDimension> window(windowDimensions.size());
  for (size_t i = 0; i < windowDimensions.size(); i++) {
    WindowDimension& dim = window[i];

    // Dynamic sizes are allowed; static sizes must be strictly positive.
    dim.size = windowDimensions[i];
    if (!isDynamicDimSize(dim.size) && dim.size <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive value for " << i
                  << "-th window dimension, but got " << dim.size << ".",
              failure());

    if (!windowStrides.empty()) dim.stride = windowStrides[i];
    if (dim.stride <= 0)
      return (mlir::emitError(loc)
                  << "expects window to have positive stride for " << i
                  << "-th window dimension, but got " << dim.stride << ".",
              failure());

    if (!lhsDilation.empty()) dim.baseDilation = lhsDilation[i];
    if (dim.baseDilation <= 0)
      return (mlir::emitError(loc) << "expects window to have positive base "
                                      "dilation factor for "
                                   << i << "-th window dimension, but got "
                                   << dim.baseDilation << ".",
              failure());

    if (!rhsDilation.empty()) dim.windowDilation = rhsDilation[i];
    if (dim.windowDilation <= 0)
      return (mlir::emitError(loc) << "expects window to have positive window "
                                      "dilation factor for "
                                   << i << "-th window dimension, but got "
                                   << dim.windowDilation << ".",
              failure());

    // Padding values may be negative (they shrink the effective base area),
    // so no sign check is applied here.
    if (!padding.empty()) {
      dim.paddingLow = padding[i].first;
      dim.paddingHigh = padding[i].second;
    }
  }

  return window;
}
350
351 // Infer the shape of the output window.
352 // Foreach dimension d,
353 // output-window-shape[d] =
354 // stridedBound(padding_low + dilatedBound(base_shape[d]) +
355 // padding_high,
356 // dilatedBound(window_shape[d]))
357 // where (padding_low, padding_high) is the padding-pair for d.
inferWindowOutputShape(const ArrayRef<int64_t> baseShape,const ArrayRef<WindowDimension> window)358 SmallVector<int64_t> inferWindowOutputShape(
359 const ArrayRef<int64_t> baseShape, const ArrayRef<WindowDimension> window) {
360 assert(baseShape.size() == window.size() &&
361 "Size of window dimensions must match the size of base shape.");
362
363 SmallVector<int64_t> outputDimensions(window.size());
364 for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i) {
365 if (isDynamicDimSize(baseShape[i]) || isDynamicDimSize(window[i].size)) {
366 outputDimensions[i] = ShapedType::kDynamicSize;
367 } else {
368 const auto& dim = window[i];
369
370 const int64_t dilatedBase = dilatedBound(baseShape[i], dim.baseDilation);
371 const int64_t paddedDilatedBase =
372 dim.paddingLow + dilatedBase + dim.paddingHigh;
373 const int64_t dilatedWindow = dilatedBound(dim.size, dim.windowDilation);
374
375 outputDimensions[i] =
376 stridedBound(paddedDilatedBase, dilatedWindow, dim.stride);
377 }
378 }
379
380 return outputDimensions;
381 }
382
383 // Return true if type1 and type2 are tensors and have the same
384 // element-type, else return false. With float element-types, ignore comparing
385 // floating-point precision if ignoreFpPrecision is True.
tensorsHaveSameElType(Type type1,Type type2,bool ignoreFpPrecision)386 bool tensorsHaveSameElType(Type type1, Type type2, bool ignoreFpPrecision) {
387 auto tensorTy1 = type1.dyn_cast<TensorType>();
388 auto tensorTy2 = type2.dyn_cast<TensorType>();
389
390 if (!tensorTy1 || !tensorTy2) return false;
391
392 if (ignoreFpPrecision && tensorTy1.getElementType().isa<FloatType>() &&
393 tensorTy2.getElementType().isa<FloatType>())
394 return true;
395
396 return tensorTy1.getElementType() == tensorTy2.getElementType();
397 }
398
399 // Return true if type1 and type2 are shape-compatible and have same element
400 // type. If 'ignoreFpPrecision' is True, then allow floats with different
401 // precisions while checking element-types.
compatibleShapeAndElementType(Type type1,Type type2,bool ignoreFpPrecision=false)402 bool compatibleShapeAndElementType(Type type1, Type type2,
403 bool ignoreFpPrecision = false) {
404 if (failed(verifyCompatibleShape(type1, type2))) return false;
405 return tensorsHaveSameElType(type1.cast<ShapedType>(),
406 type2.cast<ShapedType>(), ignoreFpPrecision);
407 }
408
// Verifies the type contract between a reduce-style op's inputs/init-values
// and its reduction-region `block` (argument count, result count, and the
// shape/element-type constraints C1-C4 described inline below). On success,
// the region's result tensor types are appended to `accumulatorSubShapes`.
LogicalResult verifyReducerShape(
    Location loc, Block& block, ArrayRef<TensorType> inputArgTypes,
    ArrayRef<TensorType> initValueTypes, int64_t numInputs,
    ArrayRef<int64_t> allowedDimensions, bool allInputsUnranked,
    SmallVectorImpl<TensorType>& accumulatorSubShapes) {
  // Check that the number of reduction-region arguments matches with that of
  // reduce-op's arguments.
  if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
    return mlir::emitError(loc)
           << "Reduction-region must take " << numInputs * 2
           << " parameters, but takes " << block.getArguments().size()
           << " parameter(s)";

  // Check if the reduction-region produces non-zero outputs.
  if (block.getTerminator()->getOperands().empty())
    return mlir::emitError(loc)
           << "The reduction-region expected to return some value(s)";

  // Check that the reduction-region returns list- of tensors.
  // The number of result-tensors must match the `numInputs`.
  if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
      numInputs)
    return mlir::emitError(loc)
           << "Reduction-region here must produce " << numInputs
           << " tensors, but produces "
           << block.getTerminator()->getOperands().size() << " instead";

  for (Value retOperand : block.getTerminator()->getOperands()) {
    auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
    if (!tensorTy)
      return mlir::emitError(loc) << "Reduction-region here must produce "
                                     "tensor-typed result(s), but "
                                     "produces "
                                  << retOperand.getType() << " instead";

    accumulatorSubShapes.push_back(tensorTy);
  }

  // Consider typical reduce-* op syntax:
  //
  //      op(I(i), V(j)):
  //       block(BI(i), BV(j)):
  //         ... some computation ...
  //         return(R(i))
  //
  // where
  //  I(i)  : i-th input of op
  //  V(j)  : j-th init-value of op
  //  BI(i) : i-th input of reducer-function
  //  BV(j) : j-th init-value of reducer-function
  //  R(i)  : i-th return-type
  //
  //  Note that: |I(i)| == V(j)| == |BI(i)| == |BV(j)| == |R(i)|
  //
  //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
  //    C1 : Check that BI(i) and R(i) have same shape and element-type.
  //    C2 : Check that BV(j) and R(i) have same shape and element-type.
  //    C3 : Check that V(j) and R(i) have same shape and element-type.
  //
  //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
  //  have compatible shapes and element-types.
  //  The next check, C4, adds constraints on how the type if I(i) is related
  //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j);
  //
  //  C4.1 : Check that I(i) and BV(j) have same element-type.
  //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of
  //         'allowedDimensions'. 'allowedDimensions' is a list of dimensions
  //         which any of BI(i), BV(j), and R(i) is allowed to have.
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Check C1.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       block.getArgument(inputIdx).getType()))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index " << inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C2.
    if (!compatibleShapeAndElementType(
            accumulatorSubShapes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's parameter at index "
             << numInputs + inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(numInputs + inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C3.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       initValueTypes[inputIdx],
                                       /*ignoreFpPrecision=*/true))
      return mlir::emitError(loc)
             << "The type of reduction-region's result type at index "
             << inputIdx
             << " differs from the op's corresponding init-value type: "
             << accumulatorSubShapes[inputIdx] << " vs "
             << initValueTypes[inputIdx];

    // Check C4.1.
    if (!tensorsHaveSameElType(
            inputArgTypes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(), true))
      return mlir::emitError(loc)
             << "The element-type of reduction-region's argument at index "
             << numInputs + inputIdx << " is expected to be "
             << inputArgTypes[inputIdx].getElementType() << ", but got "
             << block.getArgument(numInputs + inputIdx).getType()
             << " as its type.";

    // Check C4.2.
    Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
    auto blockArgTensorTy = blockArgType.cast<TensorType>();

    // NOTE: this returns success for the WHOLE function (not just this
    // input) as soon as one block argument is unranked; C4.2 is skipped for
    // the remaining inputs in that case.
    if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();

    auto argShape = blockArgTensorTy.getShape();
    if (argShape.size() > allowedDimensions.size())
      return mlir::emitError(loc)
             << "The rank of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is expected to be <= " << allowedDimensions.size() << ", got "
             << argShape.size();

    // Greedy subsequence match: advance through argShape whenever the
    // current allowed dimension matches.
    int64_t argShapeIdx = 0;
    for (int64_t outputShapeIdx = 0;
         outputShapeIdx < static_cast<int64_t>(allowedDimensions.size()) &&
         argShapeIdx < static_cast<int64_t>(argShape.size());
         outputShapeIdx++)
      if (allowedDimensions[outputShapeIdx] == argShape[argShapeIdx])
        argShapeIdx++;

    if (argShapeIdx != static_cast<int64_t>(argShape.size()))
      return mlir::emitError(loc)
             << "The shape of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's input-parameter "
                "at index "
             << inputIdx;
  }

  return success();
}
554
potentiallyComplexBitwidth(Type type)555 unsigned potentiallyComplexBitwidth(Type type) {
556 auto complexTy = type.dyn_cast<ComplexType>();
557 return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
558 : type.getIntOrFloatBitWidth();
559 }
560 } // namespace
561
562 //===----------------------------------------------------------------------===//
563 // Utilities for attributes
564 //===----------------------------------------------------------------------===//
565
// Validates this attribute's bounds against a tensor of the given shape and
// element type by delegating to hlo::verifyBounds.
LogicalResult TypeExtensionsAttr::verifyEncoding(
    llvm::ArrayRef<int64_t> bounds, mlir::Type elementType,
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
  return hlo::verifyBounds(
      getBounds(), RankedTensorType::get(bounds, elementType), emitError);
}
572
573 //===----------------------------------------------------------------------===//
574 // AllReduceOp
575 //===----------------------------------------------------------------------===//
576
// Convenience builder that forwards to the full generated builder, passing
// nullptr for the trailing (optional) attribute argument.
void AllReduceOp::build(
    ::mlir::OpBuilder& odsBuilder, ::mlir::OperationState& odsState,
    ::mlir::Type resultType, ::mlir::Value operand,
    ::mlir::DenseIntElementsAttr replicaGroups,
    /*optional*/ ::mlir::stablehlo::ChannelHandleAttr channelHandle) {
  AllReduceOp::build(odsBuilder, odsState, resultType, operand,
                     replicaGroups, channelHandle, nullptr);
}
585
586 //===----------------------------------------------------------------------===//
587 // ReduceScatterOp
588 //===----------------------------------------------------------------------===//
589
// Shared verifier for reduce-scatter-style ops: for each (operand, result)
// pair, checks equal rank, a valid scatter dimension, a non-zero operand
// scatter-dimension size that is a multiple of the result's, and equality of
// all remaining (static) dimensions.
LogicalResult verifyReduceScatter(Operation* op, TypeRange operandTypes,
                                  TypeRange resultTypes,
                                  uint64_t scatterDimension) {
  // If operand and result are both ranked, then the size of the scatter
  // dimension in the operand should be a multiple of the size of the scatter
  // dimension in the result.

  // TODO(zhouxin) Change the ODS definition to return int64_t.
  if (static_cast<int64_t>(scatterDimension) < 0) {
    return op->emitOpError("expects scatter_dimension >= 0");
  }

  for (auto it : llvm::zip(operandTypes, resultTypes)) {
    auto operandType = std::get<0>(it).cast<ShapedType>();
    auto resultType = std::get<1>(it).cast<ShapedType>();
    // Nothing to check unless both sides are ranked.
    if (!operandType.hasRank() || !resultType.hasRank()) continue;
    if (operandType.getRank() != resultType.getRank())
      return op->emitOpError() << "operand and result should have same rank";
    if (static_cast<int64_t>(scatterDimension) >= operandType.getRank())
      return op->emitOpError()
             << "scatter dim should be less than operand/result rank";
    // Dynamic scatter dimensions cannot be validated statically.
    if (operandType.isDynamicDim(scatterDimension) ||
        resultType.isDynamicDim(scatterDimension))
      continue;
    if (operandType.getDimSize(scatterDimension) == 0)
      return op->emitOpError() << "operand scatter dimension cannot be zero";
    if (resultType.getDimSize(scatterDimension) == 0)
      return op->emitOpError() << "result scatter dimension cannot be zero";
    if ((operandType.getDimSize(scatterDimension) %
         resultType.getDimSize(scatterDimension)) != 0)
      return op->emitOpError()
             << "operand scatter dimension has size "
             << operandType.getDimSize(scatterDimension)
             << ", expected to be a multiple of result scatter dimension size "
             << resultType.getDimSize(scatterDimension);

    // Non scatter dimensions should be equal.
    for (uint64_t index : llvm::seq<uint64_t>(0, operandType.getRank())) {
      if (index == scatterDimension || operandType.isDynamicDim(index) ||
          resultType.isDynamicDim(index))
        continue;
      if (operandType.getDimSize(index) != resultType.getDimSize(index))
        return op->emitOpError()
               << "non scatter dimensions should be same for operand ("
               << operandType.getDimSize(index) << ") and result ("
               << resultType.getDimSize(index) << ")";
    }
  }
  return success();
}
640
// Verifies ReduceScatterOp: replica groups must be uniform-sized, the
// reduction region must satisfy verifyReducerShape (one input, scalar
// init-value), and operand/result must satisfy verifyReduceScatter.
LogicalResult ReduceScatterOp::verify() {
  if (failed(verifyReplicaGroups(*this, /*is_uniform_sized=*/true)))
    return failure();
  auto operandType = operand().getType().cast<TensorType>();
  bool operandTypeRanked = operandType.isa<RankedTensorType>();
  Block& block = computation().front();
  SmallVector<TensorType> accumulatorSubshapes;
  // The reducer operates on scalars of the operand's element type.
  if (failed(verifyReducerShape(
          this->getLoc(), block, {operandType},
          {RankedTensorType::get({}, operandType.getElementType())},
          /*numInputs=*/1, /*allowedDimensions=*/{},
          /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes)))
    return failure();

  return verifyReduceScatter(*this,
                             /*operandTypes=*/{operand().getType()},
                             /*resultTypes=*/{getType()},
                             /*scatterDimension=*/scatter_dimension());
}
660
661 //===----------------------------------------------------------------------===//
662 // CompatibleOperandsAndResultType
663 //===----------------------------------------------------------------------===//
664
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
//
// Stamps out an inferReturnTypeComponents implementation for `Op` that simply
// forwards to inferReturnTypeComponentsFromOperands (result components are
// derived directly from the operands).
#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op)                        \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location,                      \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {          \
    return inferReturnTypeComponentsFromOperands(context, location, operands, \
                                                 attributes, regions,         \
                                                 inferredReturnShapes);       \
  }

// Ops whose result type components follow directly from their operands.
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ClzOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CollectivePermuteOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CrossReplicaSumOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DivOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ExpOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Expm1Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LogisticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MaxOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MinOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MulOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NotOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OrOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PopulationCountOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PowOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReducePrecisionOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RemOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ReverseOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundNearestEvenOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RsqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftLeftOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightArithmeticOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ShiftRightLogicalOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp)
718
719 //===----------------------------------------------------------------------===//
720 // ConstantOp
721 //===----------------------------------------------------------------------===//
722
723 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
724 assert(operands.empty() && "constant has no operands");
725
726 // Return the held attribute value.
727 return value();
728 }
729
730 // Builds a constant op with the specified attribute `value`.
build(OpBuilder &,OperationState & result,Attribute value)731 void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result,
732 Attribute value) {
733 Type type;
734 if (auto elemAttr = value.dyn_cast<ElementsAttr>()) {
735 type = elemAttr.getType();
736 } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
737 // All XLA types must be tensor types. In the build() method, we want to
738 // provide more flexibility by allowing attributes of scalar types. But we
739 // need to wrap it up with ElementsAttr to construct valid XLA constants.
740 type =
741 RankedTensorType::get(/*shape=*/{}, value.cast<TypedAttr>().getType());
742 value = DenseElementsAttr::get(type.cast<TensorType>(), value);
743 } else if (auto complexAttr = value.dyn_cast<complex::NumberAttr>()) {
744 type = RankedTensorType::get(/*shape=*/{},
745 complexAttr.cast<TypedAttr>().getType());
746 value =
747 DenseElementsAttr::get(type.cast<TensorType>(), complexAttr.getValue());
748 }
749
750 // TODO: support other XLA specific types.
751 assert(type && "unsupported attribute type for building constant");
752 result.types.push_back(type);
753 result.addAttribute("value", value);
754 }
755
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)756 LogicalResult ConstantOp::inferReturnTypes(
757 MLIRContext*, Optional<Location>, ValueRange operands,
758 DictionaryAttr attributes, RegionRange,
759 SmallVectorImpl<Type>& inferredReturnTypes) {
760 ConstantOpAdaptor adaptor(operands, attributes);
761 Type type = adaptor.value().getType();
762 inferredReturnTypes.push_back(type);
763 return success();
764 }
765
isCompatibleReturnTypes(TypeRange l,TypeRange r)766 bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
767 if (l.size() != r.size() || l.size() != 1) return false;
768 auto lhsTy = l.front().cast<TensorType>();
769 auto rhsTy = r.front().cast<TensorType>();
770 // For comparisons of the uniform quantized element based tensor type, use the
771 // storage type since the constant value will be stored through the underlying
772 // storage type.
773 if (auto rhsElemTy =
774 rhsTy.getElementType().dyn_cast<quant::QuantizedType>()) {
775 rhsTy = hlo::getSameShapeTensorType(rhsTy, rhsElemTy.getStorageType());
776 }
777 return lhsTy == rhsTy;
778 }
779
// Parses a `constant` op, accepting either the generic assembly form
// (`() attr-dict : () -> type`) or the compact form (`attr-dict $value`),
// where the result type is taken from the parsed `value` attribute.
ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
  // Parse the generic form. A leading `(` distinguishes it from the compact
  // form; this branch consumes `() attr-dict : () -> type`.
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseRParen()) return failure();
    if (parser.parseOptionalAttrDict(result.attributes)) return failure();
    if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
        parser.parseArrow())
      return failure();
    Type resultTy;
    if (parser.parseType(resultTy)) {
      return failure();
    }
    result.addTypes(resultTy);
    return success();
  }

  // Compact form: the `value` attribute doubles as the result type.
  ElementsAttr valueAttr;
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  if (parser.parseCustomAttributeWithFallback(valueAttr, Type{}, "value",
                                              result.attributes)) {
    return failure();
  }
  result.addTypes(valueAttr.getType());
  return success();
}
806
807 /// Print a `constant` op.
808 ///
809 /// op ::= attr-dict $value
810 ///
811 /// When the `value` and `output` have different type, it just uses the default
812 /// operator assembly format as a fallback.
print(::mlir::OpAsmPrinter & p)813 void ConstantOp::print(::mlir::OpAsmPrinter& p) {
814 // If not all types are the same, use generic form.
815 if (value().getType() != getType()) {
816 p.printGenericOp(getOperation(), /*printOpName=*/false);
817 return;
818 }
819
820 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
821 p << ' ';
822 p.printStrippedAttrOrType(valueAttr());
823 }
824
825 //===----------------------------------------------------------------------===//
826 // CustomCallOp
827 //===----------------------------------------------------------------------===//
828
// Verifies the optional operand/result layout attributes of a custom call:
// both must be present or both absent, each layout list must match the
// corresponding type list in length, and each layout must be a permutation of
// [0, rank) for ranked tensors (or empty for non-tensor types).
LogicalResult CustomCallOp::verify() {
  // If both operand and result layout attributes are not specified then nothing
  // to verify.
  if (!operand_layouts().has_value() && !result_layouts().has_value())
    return success();

  // Layout constraints for either both operands & results or none should be
  // specified.
  if (operand_layouts().has_value() != result_layouts().has_value())
    return emitOpError() << "Layout attributes should be specified for "
                            "either both operands and results or none.";

  // Helper function to verify types and the corresponding layouts.
  auto verifyTypesAndLayouts =
      [this](TypeRange types, mlir::ArrayAttr layouts,
             const std::string& valueName) -> LogicalResult {
    if (types.size() != layouts.size())
      return emitOpError() << "Number of " << valueName
                           << "s must match the number of " << valueName
                           << " layouts, " << types.size()
                           << " != " << layouts.size();

    for (const auto& indexedTypeAndLayout :
         llvm::enumerate(llvm::zip(types, layouts))) {
      // Get index for more descriptive error message.
      auto index = indexedTypeAndLayout.index();

      auto type = std::get<0>(indexedTypeAndLayout.value());
      auto layout = std::get<1>(indexedTypeAndLayout.value())
                        .cast<DenseIntElementsAttr>();

      // Layouts on (possibly nested) tuple types are not supported here; the
      // single-tuple-result case is unpacked by the caller below.
      if (type.isa<TupleType>())
        return emitOpError() << "Tuple types are not fully supported with "
                                "layout constraints yet";
      auto tensorType = type.dyn_cast<TensorType>();

      // For non-tensor types e.g. !stablehlo.token, the layout should be empty.
      if (!tensorType) {
        if (layout.empty()) continue;
        return emitOpError()
               << "Only tensor types can have non-empty layout: " << valueName
               << " #" << index << " of type " << type << " has layout "
               << layout;
      }

      // For unranked tensors, we cannot verify the compatibility with layout
      // any further.
      if (!tensorType.hasRank()) continue;

      // Layout must be a permutation of [0, N) where N is the rank of the
      // tensor type.
      std::vector<int64_t> range(tensorType.getRank());
      std::iota(range.begin(), range.end(), 0);
      if (tensorType.getRank() != layout.size() ||
          !std::is_permutation(range.begin(), range.end(), layout.begin()))
        return emitOpError() << "incorrect layout " << layout << " for type "
                             << type << ", layout must be a permutation of [0, "
                             << tensorType.getRank() << ")";
    }
    return success();
  };

  // At this point both `operand_layouts` and `result_layouts` are defined.
  ArrayAttr operandLayouts = this->operand_layouts().value();
  ArrayAttr resultLayouts = this->result_layouts().value();

  // Full support for layouts for arbitrary nesting of tuples is not
  // supported yet.
  //
  // If result does not have any tuples, then i-th element of `result_layouts`
  // specifies the layout constraints on i-th result.
  //
  // For the common case of a single tuple result packing non-tuple values, the
  // i-th element of `result_layouts` specifies layout for i-th element of the
  // result tuple.
  TypeRange resultTypes;
  if (getNumResults() == 1 && getResult(0).getType().isa<TupleType>())
    resultTypes = getResult(0).getType().cast<TupleType>().getTypes();
  else
    resultTypes = getResultTypes();

  // Verify that operands and operand layouts match.
  if (failed(
          verifyTypesAndLayouts(getOperandTypes(), operandLayouts, "operand")))
    return failure();

  // Verify that results and result layouts match.
  return verifyTypesAndLayouts(resultTypes, resultLayouts, "result");
}
918
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)919 void CustomCallOp::getEffects(
920 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>&
921 effects) {
922 // CustomCall has "all possible effects" unless the has_side_effect is present
923 // and set to false.
924 auto hasSideEffect = (*this)->getAttrOfType<BoolAttr>("has_side_effect");
925 if (hasSideEffect && !hasSideEffect.getValue()) return;
926 effects.emplace_back(MemoryEffects::Allocate::get());
927 effects.emplace_back(MemoryEffects::Free::get());
928 effects.emplace_back(MemoryEffects::Write::get());
929 effects.emplace_back(MemoryEffects::Read::get());
930 }
931
932 //===----------------------------------------------------------------------===//
933 // CholeskyOp
934 //===----------------------------------------------------------------------===//
935
// The following properties are already enforced by the ODS:
//   P0. a.element_type is floating or complex
// We intend to verify the following properties
//   P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s
//   P2. The two minor dimensions of 'a' must have equal size, got %s.
LogicalResult CholeskyOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  CholeskyOp::Adaptor adaptor(operands, attributes, regions);
  Type aType = adaptor.a().getType();
  RankedTensorType aRankedType = aType.dyn_cast<RankedTensorType>();
  // Unranked input: nothing further can be checked; the result inherits only
  // the element type.
  if (!aRankedType) {
    inferredReturnShapes.emplace_back(
        aType.cast<TensorType>().getElementType());
    return success();
  }

  // P1.
  ArrayRef<int64_t> aShape = aRankedType.getShape();
  if (aShape.size() < 2) {
    return emitOptionalError(
        location, "argument 'a' must have rank >= 2, got shape ", aShape, ".");
  }

  // P2. The two minor (trailing) dimensions must agree; dynamic dimensions
  // are exempt since their size is unknown at this point.
  int64_t lastDim = aShape[aShape.size() - 1];
  int64_t penultimateDim = aShape[aShape.size() - 2];
  if (!isDynamicDimSize(lastDim) && !isDynamicDimSize(penultimateDim) &&
      lastDim != penultimateDim) {
    return emitOptionalError(
        location, "minor dimensions of 'a' must have equal size, got shape ",
        aShape, ".");
  }
  // The result has exactly the shape and element type of 'a'.
  inferredReturnShapes.emplace_back(aRankedType.getShape(),
                                    aRankedType.getElementType());
  return success();
}
972
973 //===----------------------------------------------------------------------===//
974 // DotOp
975 //===----------------------------------------------------------------------===//
976 namespace {
dimCompatible(int64_t a,int64_t b)977 bool dimCompatible(int64_t a, int64_t b) {
978 return isDynamicDimSize(a) || isDynamicDimSize(b) || a == b;
979 }
980
inferDotReturnType(ShapedType lhs,ShapedType rhs)981 ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) {
982 auto elementType = lhs.getElementType();
983 if (!lhs.hasRank() || !rhs.hasRank()) {
984 return UnrankedTensorType::get(elementType);
985 }
986
987 // vector dot vector
988 if (1 == lhs.getRank() && 1 == rhs.getRank() &&
989 dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
990 return RankedTensorType::get({}, elementType);
991 }
992 // matrix dot vector
993 if (2 == lhs.getRank() && 1 == rhs.getRank() &&
994 dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
995 return RankedTensorType::get({lhs.getDimSize(0)}, elementType);
996 }
997 // vector dot matrix
998 if (1 == lhs.getRank() && 2 == rhs.getRank() &&
999 dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) {
1000 return RankedTensorType::get({rhs.getDimSize(1)}, elementType);
1001 }
1002 // matrix dot matrix
1003 if (2 == lhs.getRank() && 2 == rhs.getRank() &&
1004 dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) {
1005 int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)};
1006 return RankedTensorType::get(shape, elementType);
1007 }
1008 return {};
1009 }
1010 } // namespace
1011
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1012 LogicalResult DotOp::inferReturnTypes(
1013 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1014 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1015 DotOp::Adaptor op(operands);
1016 auto lhsType = op.lhs().getType().cast<ShapedType>();
1017 auto rhsType = op.rhs().getType().cast<ShapedType>();
1018 inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType));
1019 return success();
1020 }
1021
verify()1022 LogicalResult DotOp::verify() {
1023 auto lhsType = lhs().getType().cast<ShapedType>();
1024 auto rhsType = rhs().getType().cast<ShapedType>();
1025 auto resultType = getType().cast<ShapedType>();
1026 auto expectReturnType = inferDotReturnType(lhsType, rhsType);
1027 if (!expectReturnType) {
1028 return emitError() << "Unexpected operands type: " << lhsType << " and "
1029 << rhsType;
1030 }
1031 if (resultType.hasRank() && expectReturnType.hasRank()) {
1032 if (resultType.getShape() != expectReturnType.getShape()) {
1033 return emitError() << "Unexpected result type: has " << resultType
1034 << " but inferred " << expectReturnType
1035 << " from operands " << lhsType << " and " << rhsType;
1036 }
1037 }
1038 return success();
1039 }
1040
1041 //===----------------------------------------------------------------------===//
1042 // DotGeneralOp
1043 //===----------------------------------------------------------------------===//
1044
// Verifies the dot_dimension_numbers of a `dot_general`:
//  - lhs/rhs have the same number of batching and of contracting dimensions;
//  - within each operand, batching and contracting dimensions are disjoint;
//  - every listed dimension is within the operand's rank (when ranked);
//  - corresponding batching/contracting dimension sizes match (when ranked).
LogicalResult DotGeneralOp::verify() {
  auto dimNumbers = this->dot_dimension_numbers();

  ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
  ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
  ArrayRef<int64_t> lhsContractingDims =
      dimNumbers.getLhsContractingDimensions();
  ArrayRef<int64_t> rhsContractingDims =
      dimNumbers.getRhsContractingDimensions();

  if (lhsBatchingDims.size() != rhsBatchingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "batching dimensions";
  }
  if (lhsContractingDims.size() != rhsContractingDims.size()) {
    return emitOpError() << "lhs and rhs should have the same number of "
                            "contracting dimensions";
  }

  llvm::SmallDenseSet<int64_t> dimSet;

  // Inserts every batching and contracting dimension into `dimSet`; a failed
  // insert means a dimension appears twice across the two lists.
  auto checkDimsDistinct =
      [this](ArrayRef<int64_t> batchingDims, ArrayRef<int64_t> contractingDims,
             llvm::SmallDenseSet<int64_t>& dimSet, llvm::StringRef lhs,
             llvm::StringRef rhs) -> LogicalResult {
    auto dims = llvm::concat<const int64_t>(batchingDims, contractingDims);
    for (auto dim : dims) {
      auto [_, wasInserted] = dimSet.insert(dim);
      if (!wasInserted) {
        return emitOpError() << "has duplicated dimension from " << lhs
                             << " and " << rhs << ": " << dim;
      }
    }
    return success();
  };

  if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet,
                               "lhs_batching_dimensions",
                               "lhs_contracting_dimensions"))) {
    return failure();
  }
  // Reuse the set for the rhs check; lhs and rhs dimension indices are
  // independent namespaces.
  dimSet.clear();
  if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet,
                               "rhs_batching_dimensions",
                               "rhs_contracting_dimensions"))) {
    return failure();
  }

  // Reports the first dimension index outside [0, rank), if any.
  auto checkDimsInRange = [this](int64_t rank, ArrayRef<int64_t> dims,
                                 llvm::StringRef dimName) -> LogicalResult {
    auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; };
    const auto* dimsNotInRange =
        std::find_if_not(dims.begin(), dims.end(), inRange);
    if (dimsNotInRange != dims.end()) {
      return emitOpError() << dimName << " value: " << *dimsNotInRange
                           << " is out of range: "
                           << "[0, " << rank << ")";
    }
    return success();
  };

  // Range checks only apply to ranked operands.
  auto lhsType = this->lhs().getType().dyn_cast<RankedTensorType>();
  auto rhsType = this->rhs().getType().dyn_cast<RankedTensorType>();

  if (lhsType) {
    if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims,
                                "lhs_batching_dimensions")) ||
        failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims,
                                "lhs_contracting_dimensions"))) {
      return failure();
    }
  }
  if (rhsType) {
    if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims,
                                "rhs_batching_dimensions")) ||
        failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims,
                                "rhs_contracting_dimensions"))) {
      return failure();
    }
  }

  if (lhsType && rhsType) {
    // Dimension sizes must be compatible for lhs/rhs.
    // NOTE(review): sizes are compared with raw `!=`, so a static size paired
    // with a dynamic one would be reported as a mismatch — confirm this is the
    // intended treatment of dynamic dimensions here.
    auto lhsShape = lhsType.getShape();
    auto rhsShape = rhsType.getShape();

    for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "batching dimension sizes must match for "
                                "lhs/rhs";
      }
    }
    for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) {
      if (lhsShape[lhs] != rhsShape[rhs]) {
        return emitOpError() << "contracting dimension sizes must match for "
                                "lhs/rhs";
      }
    }
  }
  return success();
}
1146
// Reifies the result shape of a `dot_general` as SSA values: the result
// dimensions are the lhs batching dimensions, followed by the lhs free
// (non-batching, non-contracting) dimensions, followed by the rhs free
// dimensions.
LogicalResult DotGeneralOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto lhsType = lhs().getType().dyn_cast<ShapedType>();
  auto rhsType = rhs().getType().dyn_cast<ShapedType>();
  if (!lhsType || !rhsType) {
    return failure();
  }

  Adaptor adaptor(operands);
  auto dimNumbers = dot_dimension_numbers();
  SmallVector<Value> dimensions;
  // Batch dimensions come first; their sizes are read off the lhs operand.
  for (const int64_t lhsDim : dimNumbers.getLhsBatchingDimensions()) {
    dimensions.push_back(
        builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), lhsDim));
  }

  // Then the lhs free dimensions, in operand order.
  for (int64_t i = 0; i < lhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getLhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getLhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.lhs(), i));
    }
  }
  // Finally the rhs free dimensions, in operand order.
  for (int64_t i = 0; i < rhsType.getRank(); i++) {
    if (!llvm::is_contained(dimNumbers.getRhsContractingDimensions(), i) &&
        !llvm::is_contained(dimNumbers.getRhsBatchingDimensions(), i)) {
      dimensions.push_back(
          builder.create<tensor::DimOp>(getLoc(), adaptor.rhs(), i));
    }
  }

  // Pack the collected dimension values into a single 1-D shape tensor.
  reifiedReturnShapes.push_back(
      builder.create<tensor::FromElementsOp>(getLoc(), dimensions));
  return success();
}
1183
1184 //===----------------------------------------------------------------------===//
1185 // FftOp
1186 //===----------------------------------------------------------------------===//
1187
// We intend to verify the following properties
//  P1. 1 <= rank <= 3
//  P2. Element types agree with fft_type
//  P3. Operand shape dimensions agree with fft_length for the given fft_type
LogicalResult FftOp::inferReturnTypeComponents(
    MLIRContext*, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  FftOp::Adaptor adaptor(operands, attributes, regions);
  auto fftLength = adaptor.fft_length().getValues<int64_t>();
  int64_t fftRank = fftLength.size();

  // P1.
  if (fftRank > 3 || fftRank < 1) {
    return emitOptionalError(location, "rank must be between 1 and 3, but got ",
                             fftRank, ".");
  }

  // P2. Element type agreement
  // FFT : C -> C
  // IFFT : C -> C
  // RFFT : R -> C
  // IRFFT : C -> R
  auto fftType = adaptor.fft_type();
  auto operandType = adaptor.operand().getType().cast<TensorType>();
  Type operandElementType = operandType.getElementType();
  // Check the input element type and infer return element type
  if (fftType == FftType::RFFT) {
    if (!operandElementType.isF32() && !operandElementType.isF64()) {
      return emitOptionalError(
          location, "RFFT requires f32 or f64 input type, but is given ",
          operandElementType, ".");
    }
  } else {
    if (!operandElementType.isa<ComplexType>()) {
      return emitOptionalError(
          location, stringifyFftType(fftType),
          " takes a complex tensor as input, but is given ", operandType, ".");
    }
  }
  // Generate the output element type
  Type resultElementType = operandElementType;
  if (fftType == FftType::RFFT) {  // RFFT : R -> C
    resultElementType = ComplexType::get(resultElementType);
  } else if (fftType == FftType::IRFFT) {  // IRFFT : C -> R
    resultElementType = operandElementType.cast<ComplexType>().getElementType();
  }

  // P3. Check input shape and infer return shape
  // Unranked operand: only the element type can be inferred.
  operandType = operandType.dyn_cast<RankedTensorType>();
  if (!operandType) {
    inferredReturnShapes.emplace_back(resultElementType);
    return success();
  }
  auto operandShape = operandType.getShape();
  if (static_cast<int64_t>(operandShape.size()) < fftRank) {
    return emitOptionalError(
        location, "operand rank must not be less than fft rank of ", fftRank,
        " for operand of type ", operandType, ".");
  }

  SmallVector<int64_t> resultShape = to_vector(operandShape);

  if (fftType == FftType::RFFT) {
    // The trailing fftRank dimensions must match fft_length exactly.
    auto shapeBack = operandShape.take_back(fftRank);
    for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
      if (operandDim != fftDim) {
        return emitOptionalError(
            location,
            "RFFT requires innermost dimensions match fft_length. Got: ",
            operandShape, " but wanted ", fftLength, ".");
      }
    }
    // The innermost output dimension of a real-to-complex FFT is halved
    // (only the non-redundant half of the spectrum is produced).
    // NOTE(review): a zero fft_length[-1] leaves the dimension unchanged —
    // presumably to tolerate zero-length FFTs; confirm.
    if (fftLength[fftRank - 1] != 0) {
      resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1;
    }
  }
  if (fftType == FftType::IRFFT) {
    // All but the innermost of the trailing fftRank dimensions must match
    // fft_length; the innermost must be the halved spectrum size.
    auto shapeBack = operandShape.take_back(fftRank).drop_back();
    for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) {
      if (operandDim != fftDim) {
        return emitOptionalError(location,
                                 "IRFFT requires non-final dimensions "
                                 "match fft_length. Got: ",
                                 operandShape, " but wanted ", fftLength,
                                 ", and ", operandDim, " != ", fftDim, ".");
      }
    }
    if ((operandShape[operandShape.size() - 1] != 0 ||
         fftLength[fftRank - 1] != 0) &&
        operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1)
      return emitOptionalError(location,
                               "IRFFT requires innermost dimension match "
                               "fft_length[-1]/2+1. Got: ",
                               operandShape, " but fft_length is ", fftLength,
                               ".");
    // The complex-to-real inverse restores the full innermost length.
    resultShape[resultShape.size() - 1] = fftLength[fftRank - 1];
  }

  inferredReturnShapes.emplace_back(resultShape, resultElementType);
  return success();
}
1290
1291 //===----------------------------------------------------------------------===//
1292 // GatherOp
1293 //===----------------------------------------------------------------------===//
1294
1295 namespace {
1296
1297 // following https://www.tensorflow.org/xla/operation_semantics#gather
1298 // The bounds for the output array along dimension i is computed as follows:
1299 // (1) If i is present in batch_dims (i.e. is equal to batch_dims[k] for some k)
1300 // then we pick
1301 // the corresponding dimension bounds out of start_indices.shape, skipping
1302 // index_vector_dim
1303 // (i.e. pick start_indices.shape.dims[k] if k < index_vector_dim and
1304 // start_indices.shape.dims[k+1] otherwise).
1305 // (2) If i is present in offset_dims (i.e. equal to offset_dims[k] for some k)
1306 // then we pick
1307 // the corresponding bound out of slice_sizes after accounting for
1308 // collapsed_slice_dims
1309 // (i.e. we pick adjusted_slice_sizes[k] where adjusted_slice_sizes is
1310 // slice_sizes with the bounds at indices collapsed_slice_dims removed).
1311
getSliceSizeValues(GatherOp * gather,OpBuilder & builder,Location loc,ValueRange operands,SmallVectorImpl<Value> & sliceSizes)1312 void getSliceSizeValues(GatherOp* gather, OpBuilder& builder, Location loc,
1313 ValueRange operands,
1314 SmallVectorImpl<Value>& sliceSizes) {
1315 for (int64_t val : gather->slice_sizes().getValues<int64_t>()) {
1316 sliceSizes.push_back(builder.create<arith::ConstantIndexOp>(loc, val));
1317 }
1318 }
1319
getSliceSizeValues(DynamicGatherOp *,OpBuilder & builder,Location loc,ValueRange operands,SmallVectorImpl<Value> & sliceSizeValues)1320 void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder,
1321 Location loc, ValueRange operands,
1322 SmallVectorImpl<Value>& sliceSizeValues) {
1323 DynamicGatherOp::Adaptor adaptor(operands);
1324 Value sliceSizes = adaptor.slice_sizes();
1325 auto sliceSizesTy = sliceSizes.getType().cast<ShapedType>();
1326 for (int64_t i = 0; i < sliceSizesTy.getDimSize(0); ++i) {
1327 Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
1328 sliceSizeValues.push_back(
1329 builder.create<tensor::ExtractOp>(loc, sliceSizes, idx));
1330 }
1331 }
1332
// Verify the following properties:
//  P1. Verify no repeat in start_index_map.
//  P2. Verify 0 <= start_index_map[i] < rank(operand), for every i.
//  P3. Verify 0 <= index_vector_dim <= rank(start_indices).
//  P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim].
//  P5. Verify offset_dims is_sorted and no repeated.
//  P6. Verify collapsed_slice_dims is_sorted and no repeated.
//  P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims).
//  P8. Verify slice_sizes has rank of 1.
//  P9. Verify size(slice_sizes) == rank(operand).
//  P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items.
static LogicalResult verifyGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();

  // Check startIndexMap
  auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap());
  // P1.
  if (hasDuplicates(startIndexMap))
    return errorEmitter() << "expects start_index_map to not repeat, got: ["
                          << startIndexMap << "]";

  // P2. Upper-bound check is skipped for unranked operands.
  for (int64_t i = 0; i < static_cast<int64_t>(startIndexMap.size()); ++i)
    if (startIndexMap[i] < 0 ||
        (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank()))
      return errorEmitter()
             << "start_index_map[" << i << "]: " << startIndexMap[i]
             << " is out of bounds for "
             << "operand rank " << operandShape.getRank();

  if (startIndicesShape.hasRank()) {
    // P3.
    // index_vector_dim == start_indices.rank implies a trailing 1 on the shape
    // of start_indices.
    if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0)
      return errorEmitter() << "index_vector_dim " << indexVectorDim
                            << " is out of bounds for start indices with rank "
                            << startIndicesShape.getRank();

    bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank();
    // The size check only applies when the index dimension's size is known:
    // either it is the implied trailing 1, or it is a static dimension.
    if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) {
      int64_t effectiveDimSize;
      if (impliedTrailingDim)
        effectiveDimSize = 1;
      else
        effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim);
      // P4.
      if (effectiveDimSize !=
          static_cast<int64_t>(dimensionNumbers.getStartIndexMap().size()))
        return errorEmitter() << "start_index_map size ("
                              << dimensionNumbers.getStartIndexMap().size()
                              << ") is not equal to size of index dimension ("
                              << indexVectorDim << ") of start_indices ("
                              << effectiveDimSize << ")";
    }
  }

  // P5.
  auto offsetDims = to_vector(dimensionNumbers.getOffsetDims());
  if (!llvm::is_sorted(offsetDims))
    return errorEmitter() << "expects offset_dims to be sorted, got: ["
                          << offsetDims << "]";
  if (hasDuplicates(offsetDims))
    return errorEmitter() << "expects offset_dims to not repeat, got: ["
                          << offsetDims << "]";

  // P6.
  auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims());
  if (!llvm::is_sorted(collapsedSliceDims))
    return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: ["
                          << collapsedSliceDims << "]";
  if (hasDuplicates(collapsedSliceDims))
    return errorEmitter()
           << "expects collapsed_slice_dims to not repeat, got: ["
           << collapsedSliceDims << "]";

  // P7.
  int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() +
                               dimensionNumbers.getCollapsedSliceDims().size();
  if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank)
    return errorEmitter() << "offset_dims size ("
                          << dimensionNumbers.getOffsetDims().size()
                          << ") plus collapse_slice_dims size ("
                          << dimensionNumbers.getCollapsedSliceDims().size()
                          << ") is not equal to operand rank ("
                          << operandShape.getRank() << ")";

  // P8.
  // This should be fully expressible with type constraints, but it isn't
  // obvious how to do that with the current infrastructure.
  if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1)
    return errorEmitter() << "slice_sizes.rank != 1";
  if (sliceSizesShape.hasStaticShape()) {
    int64_t sliceSize = sliceSizesShape.getNumElements();

    // P9. Even for an unranked operand, the rank is implied by the dimension
    // numbers (see P7), so this check still applies.
    if (sliceSize != impliedOperandRank)
      return errorEmitter() << "slice_sizes size (" << sliceSize
                            << ") not equal to (implied) operand rank ("
                            << impliedOperandRank << ")";

    // P10.
    for (auto dim : dimensionNumbers.getCollapsedSliceDims())
      if (dim < 0 || dim >= sliceSize)
        return errorEmitter() << "collapsed dimension " << dim
                              << " is out of bounds for slice_sizes.size ("
                              << sliceSize << ")";
  }

  return success();
}
1447
// Verify the following properties:
//  P1. Verifications by verifyGather().
//  P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims.
//  P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i.
static LogicalResult verifyStaticGather(
    ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
    DenseIntElementsAttr sliceSizes,
    GatherDimensionNumbersAttr dimensionNumbers,
    llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
  // P1.
  // The getType() call converts the attribute into a ShapeAdaptor describing
  // the shape of the slice_sizes operand, which is what verifyGather expects.
  if (failed(verifyGather(
          /*operandShape=*/operandShape,
          /*startIndicesShape=*/startIndicesShape,
          /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers,
          errorEmitter)))
    return failure();

  // P2. Collapsed dimensions are removed from the result, so their slice
  // size may be at most 1.
  for (auto dim : dimensionNumbers.getCollapsedSliceDims()) {
    int64_t sliceDimSize = sliceSizes.getValues<int64_t>()[dim];
    if (sliceDimSize > 1) {
      return errorEmitter() << "slice_sizes collapsed dimension " << dim
                            << " should <= 1 but got " << sliceDimSize;
    }
  }

  // P3. Each slice size must fit inside the corresponding operand dimension;
  // dynamic operand dimensions are skipped.
  if (operandShape.hasRank()) {
    for (const auto& it : llvm::enumerate(sliceSizes.getValues<int64_t>())) {
      if (operandShape.isDynamicDim(it.index())) continue;
      auto operandDimSize = operandShape.getDimSize(it.index());
      auto sliceDimSize = it.value();
      if (sliceDimSize < 0 || sliceDimSize > operandDimSize)
        return errorEmitter() << "slice size (" << sliceDimSize
                              << ") is out of bounds for operand dimension ("
                              << operandDimSize << ") at index " << it.index();
    }
  }
  return success();
}
1489
// Populates `shape` with the `resultRank` dimensions of a gather result.
// Offset dimensions of the result are filled from the slice sizes (with
// collapsed_slice_dims removed); the remaining (batch) dimensions are filled
// from start_indices, skipping index_vector_dim. Parameterized over dimTy so
// it serves both static inference (dimTy = int64_t) and shape reification
// (dimTy = Value), with dimension accessors supplied as callbacks.
template <typename dimTy>
static void inferGatherShape(
    int64_t resultRank, llvm::function_ref<dimTy(int64_t)> getStartIndicesDim,
    llvm::function_ref<dimTy(int64_t)> getSliceDim,
    GatherDimensionNumbersAttr dimensionNumbers,
    SmallVectorImpl<dimTy>& shape) {
  ArrayRef<int64_t> collapsedSliceDims =
      dimensionNumbers.getCollapsedSliceDims();
  int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim();

  // We don't necessarily know the rank of sliceSizes, but we do know that it
  // can't be larger than the highest collapsed dimension. So go through those
  // and populate the leading dimensions of adjustedSliceSizes. The trailing
  // dimensions can just be adjusted by an offset.
  const auto* maxCollapsedDimIt =
      std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end());
  // max_element returns end() for an empty range; -1 then makes the prefix
  // loop below a no-op.
  int64_t maxCollapsedDim = -1;
  if (maxCollapsedDimIt != collapsedSliceDims.end())
    maxCollapsedDim = *maxCollapsedDimIt;

  SmallVector<dimTy> adjustedSliceSizePrefix;
  for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) {
    if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue;
    adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex));
  }
  // Reads slice sizes as if collapsed_slice_dims had been erased: indices
  // past the materialized prefix are shifted by the number of collapsed dims.
  auto getAdjustedSliceDim = [&](int64_t index) -> dimTy {
    if (index < static_cast<int64_t>(adjustedSliceSizePrefix.size()))
      return adjustedSliceSizePrefix[index];
    return getSliceDim(index + collapsedSliceDims.size());
  };

  ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();

  // Dimensions in the output that aren't offset dimensions are called batch
  // dimensions.
  SmallVector<int64_t> batchDims;
  for (int dim = 0; dim < resultRank; ++dim)
    if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);

  // Fill each result dimension from its source: offset dims come from the
  // adjusted slice sizes, batch dims come from start_indices.
  for (int i = 0; i < resultRank; ++i) {
    const auto* offsetDimsIt =
        std::find(offsetDims.begin(), offsetDims.end(), i);
    if (offsetDimsIt != offsetDims.end()) {
      auto index = std::distance(offsetDims.begin(), offsetDimsIt);
      shape.push_back(getAdjustedSliceDim(index));
      continue;
    }
    auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i);
    assert(batchDimsIt != batchDims.end());
    auto index = std::distance(batchDims.begin(), batchDimsIt);
    // This can never run into the special case where start_indices gets
    // implicitly expanded with a trailing 1 if
    // index_vector_dim = start_indices.rank because then index would equal
    // index_vector_dim, which means we'd be looking at index+1, which would be
    // out of bounds anyway.
    if (index >= indexVectorDim) ++index;
    shape.push_back(getStartIndicesDim(index));
  }
}
1549
1550 // Verify the following properties:
1551 // P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i.
1552 // (output_shape_rank = size(offset_dims) + rank(start_indices) -1)
inferGatherReturnTypeComponents(ShapeAdaptor operandShape,ShapeAdaptor startIndicesShape,llvm::function_ref<int64_t (int64_t)> getSliceDim,GatherDimensionNumbersAttr dimensionNumbers,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes,llvm::function_ref<InFlightDiagnostic ()> errorEmitter)1553 static LogicalResult inferGatherReturnTypeComponents(
1554 ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape,
1555 llvm::function_ref<int64_t(int64_t)> getSliceDim,
1556 GatherDimensionNumbersAttr dimensionNumbers,
1557 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes,
1558 llvm::function_ref<InFlightDiagnostic()> errorEmitter) {
1559 Type elementType = operandShape.getElementType();
1560
1561 // We need this to determine the result rank. We could still place bounds on
1562 // the result rank if that was something ShapedTypeComponents could express.
1563 if (!startIndicesShape.hasRank()) {
1564 inferredReturnShapes.push_back(elementType);
1565 return success();
1566 }
1567
1568 ArrayRef<int64_t> offsetDims = dimensionNumbers.getOffsetDims();
1569 int64_t startIndicesRank = startIndicesShape.getRank();
1570 // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is
1571 // appended to start_indices shape.
1572 if (dimensionNumbers.getIndexVectorDim() == startIndicesRank)
1573 ++startIndicesRank;
1574 int64_t resultRank = offsetDims.size() + startIndicesRank - 1;
1575 // P1.
1576 for (int64_t i = 0; i < static_cast<int64_t>(offsetDims.size()); ++i)
1577 if (offsetDims[i] < 0 || offsetDims[i] >= resultRank)
1578 return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i]
1579 << " is out of bounds for "
1580 << "implied result rank " << resultRank;
1581
1582 auto getStartIndicesDim = [&](int64_t index) {
1583 return startIndicesShape.getDimSize(index);
1584 };
1585
1586 SmallVector<int64_t> shape;
1587 inferGatherShape<int64_t>(resultRank, getStartIndicesDim, getSliceDim,
1588 dimensionNumbers, shape);
1589
1590 inferredReturnShapes.emplace_back(shape, elementType);
1591 return success();
1592 }
1593
1594 template <typename Op>
reifyGatherShape(Op * op,OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)1595 LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
1596 SmallVectorImpl<Value>& reifiedReturnShapes) {
1597 // No support for unranked gather output shape a.t.m.
1598 auto resultTy =
1599 op->getResult().getType().template dyn_cast<RankedTensorType>();
1600 if (!resultTy) return failure();
1601
1602 typename Op::Adaptor adaptor(operands);
1603 Value startIndices = adaptor.start_indices();
1604
1605 Location loc = op->getLoc();
1606 int resultRank = resultTy.getRank();
1607 Type shapeElTy = startIndices.getType().cast<ShapedType>().getElementType();
1608 auto toShapeElType = [&](Value v) {
1609 return maybeCastTo(builder, loc, v, shapeElTy);
1610 };
1611
1612 SmallVector<Value, 4> sliceSizes;
1613 getSliceSizeValues(op, builder, loc, operands, sliceSizes);
1614 llvm::transform(sliceSizes, sliceSizes.begin(),
1615 [&](Value v) { return toShapeElType(v); });
1616
1617 auto getStartIndicesDim = [&](int64_t index) {
1618 return toShapeElType(
1619 builder.create<tensor::DimOp>(loc, startIndices, index));
1620 };
1621 SmallVector<Value, 4> shapeValues;
1622 auto getSliceDim = [&sliceSizes](int64_t index) -> Value {
1623 return sliceSizes[index];
1624 };
1625 inferGatherShape<Value>(resultRank, getStartIndicesDim, getSliceDim,
1626 op->dimension_numbers(), shapeValues);
1627
1628 Value outputShape = builder.create<tensor::FromElementsOp>(
1629 loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues);
1630 reifiedReturnShapes.push_back(outputShape);
1631
1632 return success();
1633 }
1634
1635 } // namespace
1636
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)1637 LogicalResult GatherOp::reifyReturnTypeShapes(
1638 OpBuilder& builder, ValueRange operands,
1639 SmallVectorImpl<Value>& reifiedReturnShapes) {
1640 return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1641 }
1642
1643 // The following properties are already enforced by the ODS:
1644 // P0. Verify the start_indices has element type of integer.
1645 // Verify the following properties:
1646 // Verifications by verifyStaticGather() and verifyGather() inside it.
1647 // Verifications by inferGatherReturnTypeComponents.
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1648 LogicalResult GatherOp::inferReturnTypeComponents(
1649 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1650 DictionaryAttr attributes, RegionRange regions,
1651 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1652 // TODO(zhouxin) remove this comment after the ordering issue is clear.
1653 // This can get called before other op verify methods, so we have to do a
1654 // bunch of verification up front. With a better story for ordering and/or
1655 // multi-phase op verification, this should hopefully all go away.
1656 Location loc = location.value_or(UnknownLoc::get(context));
1657 auto errorEmitter = [&loc]() {
1658 return mlir::emitError(loc)
1659 << "'" << GatherOp::getOperationName() << "' op ";
1660 };
1661 GatherOp::Adaptor adaptor(operands, attributes, regions);
1662 if (failed(adaptor.verify(loc))) return failure();
1663
1664 // We want the ShapeAdaptors, so can't route via the adaptor :-/
1665 ShapeAdaptor operandShape = operands.getShape(0);
1666 ShapeAdaptor startIndicesShape = operands.getShape(1);
1667 GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1668 DenseIntElementsAttr sliceSizesAttr = adaptor.slice_sizes();
1669
1670 if (failed(verifyStaticGather(/*operandShape=*/operandShape,
1671 /*startIndicesShape=*/startIndicesShape,
1672 /*sliceSizes=*/sliceSizesAttr, dimensionNumbers,
1673 errorEmitter)))
1674 return failure();
1675
1676 auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t {
1677 return sliceSizesAttr.getValues<int64_t>()[index];
1678 };
1679
1680 return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1681 getSliceDim, dimensionNumbers,
1682 inferredReturnShapes, errorEmitter);
1683 }
1684
1685 //===----------------------------------------------------------------------===//
1686 // DynamicGatherOp
1687 //===----------------------------------------------------------------------===//
1688
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)1689 LogicalResult DynamicGatherOp::reifyReturnTypeShapes(
1690 OpBuilder& builder, ValueRange operands,
1691 SmallVectorImpl<Value>& reifiedReturnShapes) {
1692 return reifyGatherShape(this, builder, operands, reifiedReturnShapes);
1693 }
1694
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1695 LogicalResult DynamicGatherOp::inferReturnTypeComponents(
1696 MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
1697 DictionaryAttr attributes, RegionRange regions,
1698 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
1699 // This can get called before other op verify methods, so we have to do a
1700 // bunch of verification up front. With a better story for ordering and/or
1701 // multi-phase op verification, this should hopefully all go away.
1702 Location loc = location.value_or(UnknownLoc::get(context));
1703 auto errorEmitter = [&loc]() {
1704 return mlir::emitError(loc)
1705 << "'" << DynamicGatherOp::getOperationName() << "' op ";
1706 };
1707 DynamicGatherOp::Adaptor adaptor(operands, attributes, regions);
1708 if (failed(adaptor.verify(loc))) return failure();
1709
1710 // We want the ShapeAdaptors, so can't route via the adaptor :-/
1711 ShapeAdaptor operandShape = operands.getShape(0);
1712 ShapeAdaptor startIndicesShape = operands.getShape(1);
1713 ShapeAdaptor sliceSizesShape = operands.getShape(2);
1714 GatherDimensionNumbersAttr dimensionNumbers = adaptor.dimension_numbers();
1715
1716 if (failed(verifyGather(/*operandShape=*/operandShape,
1717 /*startIndicesShape=*/startIndicesShape,
1718 /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers,
1719 errorEmitter)))
1720 return failure();
1721
1722 auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; };
1723 return inferGatherReturnTypeComponents(operandShape, startIndicesShape,
1724 getSliceDim, dimensionNumbers,
1725 inferredReturnShapes, errorEmitter);
1726 }
1727
1728 //===----------------------------------------------------------------------===//
1729 // GetDimensionSizeOp
1730 //===----------------------------------------------------------------------===//
1731 //
verify()1732 LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); }
1733
1734 //===----------------------------------------------------------------------===//
1735 // IotaOp
1736 //===----------------------------------------------------------------------===//
1737
verify()1738 LogicalResult IotaOp::verify() {
1739 auto shape = getType().cast<ShapedType>();
1740 if (!shape.hasRank()) return success();
1741
1742 if (shape.getRank() == 0) return emitOpError() << "does not support scalars.";
1743
1744 auto iotaDimension = static_cast<int64_t>(this->iota_dimension());
1745 if (iotaDimension >= shape.getRank() || iotaDimension < 0)
1746 return emitOpError()
1747 << "iota dimension cannot go beyond the output rank or be negative.";
1748 return success();
1749 }
1750
1751 //===----------------------------------------------------------------------===//
1752 // DynamicIotaOp
1753 //===----------------------------------------------------------------------===//
1754
castToIndexTensor(OpBuilder & builder,Location loc,Value shapeOp)1755 static Value castToIndexTensor(OpBuilder& builder, Location loc,
1756 Value shapeOp) {
1757 ShapedType resultTy = shape::getExtentTensorType(
1758 builder.getContext(), shapeOp.getType().cast<ShapedType>().getDimSize(0));
1759 if (shapeOp.getType() == resultTy) return shapeOp; // Nothing to do.
1760 return builder.create<arith::IndexCastOp>(loc, resultTy, shapeOp);
1761 }
1762
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)1763 LogicalResult DynamicIotaOp::reifyReturnTypeShapes(
1764 OpBuilder& builder, ValueRange operands,
1765 SmallVectorImpl<Value>& reifiedReturnShapes) {
1766 DynamicIotaOp::Adaptor adaptor(operands);
1767 reifiedReturnShapes.push_back(
1768 castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
1769 return success();
1770 }
1771
1772 //===----------------------------------------------------------------------===//
1773 // DynamicUpdateSliceOp
1774 //===----------------------------------------------------------------------===//
1775
verify()1776 LogicalResult DynamicUpdateSliceOp::verify() {
1777 OperandRange indices = start_indices();
1778 if (indices.size() <= 1) return success();
1779
1780 // Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
1781 // is OK to cast indices to ShapedType here.
1782 auto idxTensor = indices.take_front().front().getType().cast<ShapedType>();
1783 Type firstElemTy = idxTensor.getElementType();
1784 Type elemTy;
1785
1786 for (auto idx : llvm::drop_begin(indices, 1)) {
1787 idxTensor = idx.getType().cast<ShapedType>();
1788 elemTy = idxTensor.getElementType();
1789
1790 if (firstElemTy != elemTy) {
1791 return emitOpError() << "start indices must have same element type "
1792 "(encountered mismatch: "
1793 << firstElemTy << " vs " << elemTy << ")";
1794 }
1795 }
1796 return success();
1797 }
1798
1799 //===----------------------------------------------------------------------===//
1800 // AbsOp
1801 //===----------------------------------------------------------------------===//
1802
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1803 LogicalResult AbsOp::inferReturnTypes(
1804 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
1805 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
1806 auto operandTy = (*operands.begin()).getType().cast<ShapedType>();
1807 Type elementTy = operandTy.getElementType();
1808 if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
1809 elementTy = complexTy.getElementType();
1810 }
1811
1812 Type resultTy;
1813 if (auto rankedOperandTy = operandTy.dyn_cast<RankedTensorType>()) {
1814 resultTy = RankedTensorType::get(operandTy.getShape(), elementTy,
1815 rankedOperandTy.getEncoding());
1816 } else if (operandTy.hasRank()) {
1817 resultTy = RankedTensorType::get(operandTy.getShape(), elementTy);
1818 } else {
1819 resultTy = UnrankedTensorType::get(elementTy);
1820 }
1821 inferredReturnTypes.push_back(resultTy);
1822 return success();
1823 }
1824
1825 //===----------------------------------------------------------------------===//
1826 // CollectivePermuteOp
1827 //===----------------------------------------------------------------------===//
1828
1829 // Verifies the source target pairs attached to collective permute.
verifyCollectivePermuteSourceTargetPairs(Operation * op,DenseIntElementsAttr attr)1830 LogicalResult verifyCollectivePermuteSourceTargetPairs(
1831 Operation* op, DenseIntElementsAttr attr) {
1832 auto type = attr.getType().dyn_cast<RankedTensorType>();
1833 if (type.getRank() != 2)
1834 return op->emitError() << "expect source_target_pairs attribute to be of "
1835 "rank 2, but got rank "
1836 << type.getRank();
1837 if (type.getShape()[1] != 2)
1838 return op->emitError()
1839 << "expect source_target_pairs attribute of shape (N, 2), but got ("
1840 << type.getShape() << ")";
1841 // Check source target pairs for duplicate sources or targets.
1842 llvm::DenseSet<int64_t> sources;
1843 llvm::DenseSet<int64_t> targets;
1844 for (auto i = attr.begin(), e = attr.end(); i != e; ++i) {
1845 auto val = (*i).getSExtValue();
1846 if (i.getIndex() % 2 == 0) {
1847 bool isUnique = sources.insert(val).second;
1848 if (!isUnique) return op->emitError() << "duplicate sources not allowed.";
1849 } else {
1850 bool isUnique = targets.insert(val).second;
1851 if (!isUnique) return op->emitError() << "duplicate targets not allowed.";
1852 }
1853 }
1854 return success();
1855 }
1856
verify()1857 LogicalResult CollectivePermuteOp::verify() {
1858 return verifyCollectivePermuteSourceTargetPairs(*this,
1859 source_target_pairs());
1860 }
1861
1862 //===----------------------------------------------------------------------===//
1863 // ConvolutionOp
1864 //===----------------------------------------------------------------------===//
1865
1866 namespace {
1867 // Checks:
1868 // P1. Same sizes for input, kernel and output spatial_dims.
// P2. Spatial and non-spatial dimensions (for input, kernel, & output) should
1870 // be unique and in range [0, num_dims), where num_dims = rank of input
1871 // (lhs/rhs) tensors.
1872 //
1873 // Note that the spatial + non-spatial dimensions may not cover all the
1874 // dimensions in the range [0,num) because of the presence of 'unknown'
1875 // dimensions (ref. cl/415132294).
isSpatialDimensionsValid(ConvolutionOp op)1876 LogicalResult isSpatialDimensionsValid(ConvolutionOp op) {
1877 auto inputSpatialDimensions =
1878 op.dimension_numbers().getInputSpatialDimensions();
1879 auto kernelSpatialDimensions =
1880 op.dimension_numbers().getKernelSpatialDimensions();
1881 auto outputSpatialDimensions =
1882 op.dimension_numbers().getOutputSpatialDimensions();
1883
1884 // P1.
1885 if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) ||
1886 (inputSpatialDimensions.size() != outputSpatialDimensions.size()))
1887 return op.emitOpError() << "expects the same size for input, kernel and "
1888 "output spatial-dimensions, but got "
1889 << inputSpatialDimensions.size() << ", "
1890 << kernelSpatialDimensions.size() << ", and "
1891 << outputSpatialDimensions.size() << " resp.";
1892
1893 // P2.
1894 SmallVector<int64_t> inputDnums(inputSpatialDimensions.size() + 2);
1895 inputDnums[0] = op.dimension_numbers().getInputBatchDimension();
1896 inputDnums[1] = op.dimension_numbers().getInputFeatureDimension();
1897 std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(),
1898 inputDnums.begin() + 2);
1899
1900 SmallVector<int64_t> windowDnums(kernelSpatialDimensions.size() + 2);
1901 windowDnums[0] = op.dimension_numbers().getKernelInputFeatureDimension();
1902 windowDnums[1] = op.dimension_numbers().getKernelOutputFeatureDimension();
1903 std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(),
1904 windowDnums.begin() + 2);
1905
1906 SmallVector<int64_t> outputDnums(outputSpatialDimensions.size() + 2);
1907 outputDnums[0] = op.dimension_numbers().getOutputBatchDimension();
1908 outputDnums[1] = op.dimension_numbers().getOutputFeatureDimension();
1909 std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(),
1910 outputDnums.begin() + 2);
1911
1912 auto numDims = op.lhs().getType().cast<RankedTensorType>().getRank();
1913 const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; };
1914
1915 if (!llvm::all_of(inputDnums, inRange) ||
1916 !llvm::all_of(windowDnums, inRange) ||
1917 !llvm::all_of(outputDnums, inRange))
1918 return op.emitOpError() << "expects input, kernel, and output "
1919 "dimension-numbers to be in-range [0, "
1920 << numDims << ").";
1921
1922 if (hasDuplicates(inputDnums))
1923 return op.emitOpError()
1924 << "expects input dimension-numbers to be unique, got {"
1925 << inputDnums << "}.";
1926
1927 if (hasDuplicates(windowDnums))
1928 return op.emitOpError()
1929 << "expects kernel dimension-numbers to be unique, got {"
1930 << windowDnums << "}.";
1931
1932 if (hasDuplicates(outputDnums))
1933 return op.emitOpError()
1934 << "expects output dimension-numbers to be unique, got {"
1935 << outputDnums << "}.";
1936
1937 return success();
1938 }
1939
1940 // Verifies the following properties:
// P1. The input, kernel, and output spatial-dimensions are valid.
1942 // P2. Given,
1943 // input-dimensions: b * input-spatial-dims * f
1944 // kernel-dimensions: kernel-spatial-dims * i * o
1945 // output-dimensions: b' * out-spatial-dims * f'
1946 // where b = input-batch-dims
1947 // where f = input-feature-dims
1948 // where i = kernel-input-feature-dims
1949 // where o = kernel-output-feature-dims
1950 // where b' = output-batch-dims
1951 // where f' = output-feature-dims
1952 // Check the following properties w.r.t feature_group_count (fgc) and
1953 // batch_group_count (bgc).
// fgc > 0, bgc > 0 and !(fgc > 1 && bgc > 1)
1955 // b % bgc == 0
1956 // f % fgc == 0 and i = f / fgc
1957 // o (or f') % bgc == 0 and o (or f') % fgc == 0
LogicalResult verifyConvolutionAttributes(ConvolutionOp op) {
  // P1. Spatial / non-spatial dimension numbers are validated first.
  if (failed(isSpatialDimensionsValid(op))) return failure();

  // P2. Group-count constraints (see the property list above).
  const int64_t featureGroupCount = op.feature_group_count();
  const int64_t batchGroupCount = op.batch_group_count();

  // Both group counts must be strictly positive.
  if (featureGroupCount <= 0)
    return op.emitOpError()
           << "expects feature_group_count to be a positive number, got "
           << featureGroupCount << ".";

  if (batchGroupCount <= 0)
    return op.emitOpError()
           << "expects batch_group_count to be a positive number, got "
           << batchGroupCount << ".";

  // Grouped-batch and grouped-feature convolution are mutually exclusive:
  // at most one of the two counts may exceed 1.
  if (batchGroupCount > 1 && featureGroupCount > 1)
    return op.emitOpError()
           << "expects batch_group_count and feature_group_count not to be "
              "both greater than 1. Got "
           << batchGroupCount << " and " << featureGroupCount << " resp.";

  // Operand types are known ranked here (checked by the caller,
  // ConvolutionOp::verify, before invoking this function).
  auto lhsType = op.lhs().getType().cast<RankedTensorType>();
  const int64_t inputFeatures =
      lhsType.getShape()[op.dimension_numbers().getInputFeatureDimension()];
  const int64_t inputBatch =
      lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];

  auto rhsType = op.rhs().getType().cast<RankedTensorType>();
  const int64_t kernelInputFeatures =
      rhsType
          .getShape()[op.dimension_numbers().getKernelInputFeatureDimension()];
  const int64_t kernelOutputFeatures =
      rhsType
          .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];

  // o % bgc == 0 and o % fgc == 0; skipped when the kernel output-feature
  // dimension is dynamic.
  if (!isDynamicDimSize(kernelOutputFeatures)) {
    if (kernelOutputFeatures % batchGroupCount != 0)
      return op.emitOpError() << "expects output feature dimension size ("
                              << kernelOutputFeatures
                              << ") to be a multiple of "
                                 "batch_group_count. Got batch_group_count = "
                              << batchGroupCount << ".";

    if (kernelOutputFeatures % featureGroupCount != 0)
      return op.emitOpError()
             << "expects kernel output feature dimension ("
             << kernelOutputFeatures
             << ") to be divisible by "
                "feature_group_count. For feature_group_count = "
             << featureGroupCount << ".";
  }

  // f % fgc == 0 and i == f / fgc; each check skipped if the dimension
  // involved is dynamic.
  if (!isDynamicDimSize(inputFeatures)) {
    if (inputFeatures % featureGroupCount != 0)
      return op.emitOpError()
             << "expects input feature dimension (" << inputFeatures
             << ") to be a multiple of "
                "feature_group_count. Got feature_group_count = "
             << featureGroupCount << ".";

    if (!isDynamicDimSize(kernelInputFeatures) &&
        inputFeatures / featureGroupCount != kernelInputFeatures)
      return op.emitOpError()
             << "expects input feature dimension (" << inputFeatures
             << ") / "
                "feature_group_count = kernel input feature dimension ("
             << kernelInputFeatures
             << "). Got feature_group_count = " << featureGroupCount << ".";
  }

  // b % bgc == 0; skipped when the input batch dimension is dynamic.
  if (!isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0)
    return op.emitOpError() << "expects input batch dimension (" << inputBatch
                            << ") to be divisible by "
                               "batch_group_count. Got batch_group_count = "
                            << batchGroupCount << ".";

  return success();
}
2039
2040 // Infer the return-shape of ConvolutionOp.
2041 // Precondition:
2042 // 1. Input args to ConvolutionOp 'op' are RankedTypes.
2043 // 2. rank-of(input-type) == rank-of(output-type)
inferConvolutionOpReturnShape(ConvolutionOp op,const ArrayRef<WindowDimension> window)2044 SmallVector<int64_t> inferConvolutionOpReturnShape(
2045 ConvolutionOp op, const ArrayRef<WindowDimension> window) {
2046 // We keep the 'unknown' dimensions (cl/415132294) as it is in the
2047 // output-shape. To do that we initilize the output dimensions with the shape
2048 // of the return-type and updates only the spatial + non-spatial dimensions.
2049 // Precondition 2 ensures that size of output-shape == size of input-shape.
2050 SmallVector<int64_t> outputDimensions =
2051 to_vector(op.getResult().getType().cast<ShapedType>().getShape());
2052
2053 // Infer the output spatial dimensions.
2054 auto lhsType = op.lhs().getType().cast<RankedTensorType>();
2055 auto inputSpatialDims = op.dimension_numbers().getInputSpatialDimensions();
2056 auto numSpatialDims = inputSpatialDims.size();
2057 SmallVector<int64_t> inputSpatialDimVals(numSpatialDims);
2058 for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
2059 inputSpatialDimVals[i] = lhsType.getShape()[inputSpatialDims[i]];
2060
2061 auto windowOutputShape = inferWindowOutputShape(inputSpatialDimVals, window);
2062
2063 for (int64_t i = 0; i < static_cast<int64_t>(window.size()); ++i)
2064 outputDimensions[op.dimension_numbers().getOutputSpatialDimensions()[i]] =
2065 windowOutputShape[i];
2066
2067 // Infer the output-batch-dimension and output-feature-dimension.
2068 auto rhsType = op.rhs().getType().cast<RankedTensorType>();
2069 const int64_t inputBatch =
2070 lhsType.getShape()[op.dimension_numbers().getInputBatchDimension()];
2071 const int64_t kernelOutputFeatures =
2072 rhsType
2073 .getShape()[op.dimension_numbers().getKernelOutputFeatureDimension()];
2074
2075 outputDimensions[op.dimension_numbers().getOutputBatchDimension()] =
2076 isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize
2077 : inputBatch / op.batch_group_count();
2078 outputDimensions[op.dimension_numbers().getOutputFeatureDimension()] =
2079 kernelOutputFeatures;
2080
2081 return outputDimensions;
2082 }
2083
2084 } // namespace
2085
2086 /*
2087 * We intend to verify the following properties
2088 * P1. Verify the input, kernel types.
 * P2. Verify the convolution attributes.
 * P3. Verify and collect the window attributes.
2091 * P4. Verify the return shape.
2092 * TODO(b/232574102): Verify the element-type of return-value.
2093 */
LogicalResult ConvolutionOp::verify() {
  auto lhsType = lhs().getType().dyn_cast<RankedTensorType>();
  auto rhsType = rhs().getType().dyn_cast<RankedTensorType>();

  // Nothing can be verified until both operands are ranked.
  if (!lhsType || !rhsType) return success();

  // P1. Input and kernel must agree on rank, and the rank must accommodate
  // at least the batch and feature dimensions.
  int numDims = lhsType.getRank();
  if (numDims != rhsType.getRank())
    return emitOpError()
           << "expects convolution arguments to have same number of "
              "dimensions. Got: "
           << lhsType << " and " << rhsType << ".";

  if (numDims < 2)
    return emitOpError()
           << "expects convolution arguments to have >= 2 dimensions. "
              "Got: "
           << lhsType << " and " << rhsType << ".";

  // P2. Dimension numbers and group counts.
  if (failed(verifyConvolutionAttributes(*this))) return failure();

  // P3. Collect the kernel's spatial sizes, then validate the window
  // attributes (strides, padding, dilations) against them.
  auto kernelSpatialDimensions =
      dimension_numbers().getKernelSpatialDimensions();
  SmallVector<int64_t> windowDimensions(kernelSpatialDimensions.size());
  for (size_t i = 0; i < windowDimensions.size(); i++)
    windowDimensions[i] = rhsType.getShape()[kernelSpatialDimensions[i]];

  auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
  if (failed(paddingOrErr)) return failure();
  SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;

  auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
      windowDimensions, convertDenseIntAttr(window_strides()), padding,
      convertDenseIntAttr(lhs_dilation()), convertDenseIntAttr(rhs_dilation()),
      getLoc());
  if (failed(windowOrErr)) return failure();

  // P4. The declared return type must be compatible with the inferred one.
  // An unranked return type is accepted as-is.
  auto actualReturnType = getResult().getType().cast<TensorType>();
  auto actualReturnElementType = actualReturnType.getElementType();
  if (!actualReturnType.hasRank()) return success();

  auto actualReturnRankedType = actualReturnType.cast<RankedTensorType>();
  if (numDims != actualReturnRankedType.getRank())
    return emitOpError() << "expects rank of convolution return-type to be "
                            "equal to input-ranks ("
                         << numDims << "), but got "
                         << actualReturnRankedType.getRank() << ".";

  auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr);
  auto expectedReturnType =
      RankedTensorType::get(expectedReturnShape, actualReturnElementType);
  if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType)))
    return emitOpError()
           << "has shape mismatch between the expected return-type ("
           << expectedReturnType << ") and actual return-type ("
           << actualReturnRankedType << ").";

  return success();
}
2157
2158 //===----------------------------------------------------------------------===//
2159 // ConvertOp
2160 //===----------------------------------------------------------------------===//
2161
build(OpBuilder & builder,OperationState & result,Value operand,Type resultElementTy)2162 void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
2163 Type resultElementTy) {
2164 Type resultTy;
2165 Type operandTy = operand.getType();
2166 if (auto rankedTy = operandTy.dyn_cast<RankedTensorType>()) {
2167 resultTy = RankedTensorType::get(rankedTy.getShape(), resultElementTy);
2168 } else {
2169 resultTy = UnrankedTensorType::get(resultElementTy);
2170 }
2171 build(builder, result, resultTy, operand);
2172 }
2173
2174 //===----------------------------------------------------------------------===//
2175 // GetTupleElementOp
2176 //===----------------------------------------------------------------------===//
2177
verify()2178 LogicalResult GetTupleElementOp::verify() {
2179 auto indexVal = index();
2180 auto operandType = getOperand().getType().cast<TupleType>();
2181 if (indexVal >= operandType.size()) {
2182 return emitOpError(
2183 llvm::formatv("index {0} is out of bounds of operand with size {1}",
2184 indexVal, operandType.size()));
2185 }
2186
2187 auto expectedType = operandType.getType(indexVal);
2188 if (getType() != expectedType) {
2189 return emitOpError(llvm::formatv("has return type {0}, but expected {1}",
2190 getType(), expectedType));
2191 }
2192 return success();
2193 }
2194
2195 //===----------------------------------------------------------------------===//
2196 // TupleOp
2197 //===----------------------------------------------------------------------===//
2198
verify()2199 LogicalResult TupleOp::verify() {
2200 auto opType = getType().dyn_cast<TupleType>();
2201 if (!opType) return emitOpError("tuple op with non-tuple result");
2202 if (getNumOperands() != opType.size())
2203 return emitOpError(
2204 "number of operands to tuple expected to match number of types in "
2205 "resultant tuple type");
2206 for (const auto& it :
2207 llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) {
2208 if (std::get<0>(it.value()) != std::get<1>(it.value()))
2209 return emitOpError("has return type mismatch at ")
2210 << it.index() << "th value (" << std::get<0>(it.value())
2211 << " != " << std::get<1>(it.value()) << ")";
2212 }
2213 return success();
2214 }
2215
2216 //===----------------------------------------------------------------------===//
2217 // AllToAllOp
2218 //===----------------------------------------------------------------------===//
2219
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)2220 LogicalResult AllToAllOp::inferReturnTypeComponents(
2221 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2222 DictionaryAttr attributes, RegionRange regions,
2223 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2224 AllToAllOp::Adaptor adaptor(operands, attributes, regions);
2225 Type operandType = adaptor.operand().getType();
2226 RankedTensorType operandRankedType = operandType.dyn_cast<RankedTensorType>();
2227 if (!operandRankedType) {
2228 inferredReturnShapes.emplace_back(
2229 operandType.cast<TensorType>().getElementType());
2230 return success();
2231 }
2232
2233 int64_t inputRank = operandRankedType.getRank();
2234 int64_t splitDimension = static_cast<int64_t>(adaptor.split_dimension());
2235 int64_t concatDimension = static_cast<int64_t>(adaptor.concat_dimension());
2236 if (splitDimension >= inputRank || splitDimension < 0) {
2237 return emitOptionalError(location, "AllToAll split_dimension ",
2238 splitDimension,
2239 " is out-of-bounds for input rank ", inputRank);
2240 }
2241 if (concatDimension >= inputRank || concatDimension < 0) {
2242 return emitOptionalError(location, "AllToAll concat_dimension ",
2243 concatDimension,
2244 " is out-of-bounds for input rank ", inputRank);
2245 }
2246
2247 // If operand is ranked, size of split dimension should be a multiple of split
2248 // count.
2249 int64_t splitCount = adaptor.split_count();
2250 auto splitDimSize = operandRankedType.getDimSize(splitDimension);
2251 if (splitDimSize % splitCount != 0) {
2252 return emitOptionalError(
2253 location, "split dimension has size ", splitDimSize,
2254 ", expected to be a multiple of split_count ", splitCount);
2255 }
2256 SmallVector<int64_t> resultShape(operandRankedType.getShape().begin(),
2257 operandRankedType.getShape().end());
2258 resultShape[splitDimension] /= splitCount;
2259 resultShape[concatDimension] *= splitCount;
2260 inferredReturnShapes.emplace_back(resultShape,
2261 operandRankedType.getElementType());
2262 return success();
2263 }
2264
2265 //===----------------------------------------------------------------------===//
2266 // AllGatherOp
2267 //===----------------------------------------------------------------------===//
2268
verify()2269 LogicalResult AllGatherOp::verify() {
2270 // If operand and result are both ranked, then the size of the gather
2271 // dimension in the result should be a multiple of the size of the gather
2272 // dimension in the operand.
2273 auto operandType = operand().getType().dyn_cast<RankedTensorType>();
2274 auto resultType = getType().dyn_cast<RankedTensorType>();
2275 uint64_t allGatherDimIndex = all_gather_dim();
2276 if (!operandType || !resultType ||
2277 operandType.isDynamicDim(allGatherDimIndex) ||
2278 resultType.isDynamicDim(allGatherDimIndex))
2279 return success();
2280 if (operandType.getDimSize(allGatherDimIndex) == 0)
2281 return emitOpError() << "operand gather dimension cannot be zero.";
2282 if ((resultType.getDimSize(allGatherDimIndex) %
2283 operandType.getDimSize(allGatherDimIndex)) != 0)
2284 return emitOpError()
2285 << "result gather dimension has size "
2286 << resultType.getDimSize(allGatherDimIndex)
2287 << ", expected to be a multiple of operand gather dimension size "
2288 << operandType.getDimSize(allGatherDimIndex);
2289
2290 return success();
2291 }
2292
2293 //===----------------------------------------------------------------------===//
2294 // BatchNormGradOp
2295 //===----------------------------------------------------------------------===//
2296
verify()2297 LogicalResult BatchNormGradOp::verify() {
2298 // The following properties are already enforced by the ODS:
2299 // 1. Inputs 'operand' & 'grad_output' and outputs 'grad_operand',
2300 // are ranked-tensors with floating-point (fp) type.
2301 // 2. The shapes of inputs 'operand' & 'grad_output' match.
2302 // 3. Inputs 'scale', 'mean', 'variance' and Outputs 'grad_scale',
2303 // 'grad_offset' are all 1D fp tensors with same shape.
2304 // 4. The element-types of input 'operand' and outputs 'grad_scale',
2305 // 'grad_offset' match.
2306 // 5. The type of input 'operand' and output 'grad_operand' match.
2307 //
2308 // We intend to verify the following properties
2309 // P1. Inputs 'operand' & 'grad_output' has the same shape with fp
2310 // element-types, ignoring fp-precision : Inferred from (1) & (2).
2311 // P2. The feature dimension 'feature_index' is a valid index in 'operand':
2312 // Inferred from check C2 below.
2313 // P3. Inputs 'scale', 'mean', 'variance' must be 1D tensors with same shape
2314 // and fp element-type (ignoring precision) and the number of elements
2315 // in its sole-dimension == number of features in the 'operand's
2316 // feature-dimension 'feature_index': Inferred from (3) and check C3
2317 // below.
2318 // P4. Outputs 'grad_scale' & 'grad_offset' are 1D tensors with
2319 // element-type == element-type of(operand) and same shape as any of
2320 // the inputs 'scale', 'mean', or 'variance': Inferred from (3), (4) and
2321 // check C3 below.
2322 // P5. The type (shape + element-type) of input 'operand' and
2323 // output 'grad_operand' must match: Inferred from (5).
2324
2325 // C2.
2326 auto operandType = operand().getType().cast<RankedTensorType>();
2327 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2328 return emitOpError() << "expects feature_index to be smaller "
2329 "than the rank of operand type; got feature_index "
2330 << feature_index() << ", and rank "
2331 << operandType.getRank() << ".";
2332
2333 if (static_cast<int64_t>(feature_index()) < 0)
2334 return emitOpError() << "expects feature_index to be a "
2335 << "non-negative number, got "
2336 << static_cast<int64_t>(feature_index()) << ".";
2337
2338 auto gradOutputType = grad_output().getType().cast<RankedTensorType>();
2339 if (operandType.getRank() != gradOutputType.getRank())
2340 return emitOpError() << "expects 'operand' and 'grad_output' to have the "
2341 "same rank. but got rank(oprand) "
2342 << operandType.getRank() << " and rank(grad_output) "
2343 << gradOutputType.getRank() << ".";
2344
2345 // C3.
2346 const int64_t featureCount = operandType.getShape()[feature_index()];
2347 const int64_t scaleShape =
2348 scale().getType().cast<RankedTensorType>().getShape()[0];
2349 if (scaleShape != featureCount)
2350 return emitOpError() << "expects the size of scale factor to be "
2351 "same as the feature count,"
2352 " but the size of scale factor is "
2353 << scaleShape << " and the feature count is "
2354 << featureCount << ".";
2355
2356 return success();
2357 }
2358
2359 //===----------------------------------------------------------------------===//
2360 // BatchNormTrainingOp
2361 //===----------------------------------------------------------------------===//
2362
verify()2363 LogicalResult BatchNormTrainingOp::verify() {
2364 // The following properties are already enforced by the ODS:
2365 // 1. 'operand' and 'output' are ranked tensors.
2366 // 2. 'scale', 'offset', 'batch_mean', 'batch_var' are 1D tensors.
2367 // 3. Types of 'operand' and 'output' matches.
2368 // 4. Same element-types for 'operand', 'batch_mean', & 'batch_var'.
2369 // 5. Same shapes for 'scale', 'offset', 'batch_mean', & 'batch_var'.
2370
2371 auto operandType = operand().getType().cast<RankedTensorType>();
2372 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2373 return emitOpError() << "expects feature_index to be smaller "
2374 "than the rank of operand type; got feature_index "
2375 << feature_index() << ", and rank "
2376 << operandType.getRank() << ".";
2377
2378 if (static_cast<int64_t>(feature_index()) < 0)
2379 return emitOpError() << "expects feature_index to be a "
2380 << "non-negative number, got "
2381 << static_cast<int64_t>(feature_index()) << ".";
2382
2383 // Note:A valid value of feature-index implies 'operand_type.getRank() >=1'.
2384
2385 const int64_t featureCount = operandType.getShape()[feature_index()];
2386 const int64_t scaleShape =
2387 scale().getType().cast<RankedTensorType>().getShape()[0];
2388 // Check number of elements in input 'scale' equals feature_count.
2389 // Together with (5) implies that 'scale', 'offset', 'batch_mean', &
2390 // 'batch_var' all have the same shape.
2391 if (scaleShape != featureCount)
2392 return emitOpError() << "expects the size of scale factor to be "
2393 "same as the feature count,"
2394 " but the size of scale factor is "
2395 << scaleShape << " and the feature count is "
2396 << featureCount << ".";
2397
2398 return success();
2399 }
2400
2401 //===----------------------------------------------------------------------===//
2402 // BatchNormInferenceOp
2403 //===----------------------------------------------------------------------===//
2404
verify()2405 LogicalResult BatchNormInferenceOp::verify() {
2406 // The following properties are already enforced by the ODS:
2407 // 1. 'operand' and 'result' are ranked tensors.
2408 // 2. 'scale', 'offset', 'mean', 'variance' are 1D tensors.
2409 // 3. Types of 'operand' and 'result' matches.
2410 // 4. Same shapes for 'scale', 'offset', 'mean', & 'variance'.
2411
2412 auto operandType = operand().getType().cast<RankedTensorType>();
2413 if (static_cast<int64_t>(feature_index()) >= operandType.getRank())
2414 return emitOpError() << "expects feature_index to be smaller "
2415 "than the rank of operand type; got feature_index "
2416 << feature_index() << ", and rank "
2417 << operandType.getRank() << ".";
2418
2419 if (static_cast<int64_t>(feature_index()) < 0)
2420 return emitOpError() << "expects feature_index to be a "
2421 << "non-negative number, got "
2422 << static_cast<int64_t>(feature_index()) << ".";
2423
2424 // Note:A valid value of feature-index implies 'operand_type.getRank() >=1'.
2425
2426 const int64_t featureCount = operandType.getShape()[feature_index()];
2427 const int64_t scaleSize =
2428 scale().getType().cast<RankedTensorType>().getShape()[0];
2429 // Check number of elements in input 'scale' equals feature_count.
2430 // Together with (4) implies that 'scale', 'offset', 'mean', &
2431 // 'variance' all have the same shape.
2432 if (scaleSize != featureCount)
2433 return emitOpError() << "expects the size of scale factor to be "
2434 "same as the feature count,"
2435 " but the size of scale factor is "
2436 << scaleSize << " and the feature count is "
2437 << featureCount << ".";
2438
2439 return success();
2440 }
2441
2442 //===----------------------------------------------------------------------===//
2443 // BitcastConvertOp
2444 //===----------------------------------------------------------------------===//
2445
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)2446 LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
2447 OpBuilder& builder, ValueRange operands,
2448 SmallVectorImpl<Value>& reifiedReturnShapes) {
2449 auto operandType = operands[0].getType().dyn_cast<RankedTensorType>();
2450 auto resultType = getType().dyn_cast<RankedTensorType>();
2451
2452 // Only ranked tensors are supported.
2453 if (!operandType || !resultType) return failure();
2454
2455 // Shape-changing bitcast convert is not implemented.
2456 // TODO(kramerb): This could be done by adjusting the last dimension.
2457 DataLayout dataLayout = DataLayout::closest(*this);
2458 unsigned operandElementSize =
2459 dataLayout.getTypeSizeInBits(operandType.getElementType());
2460 unsigned resultElementSize =
2461 dataLayout.getTypeSizeInBits(resultType.getElementType());
2462 if (operandElementSize != resultElementSize) return failure();
2463
2464 return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
2465 &reifiedReturnShapes);
2466 }
2467
2468 /*
2469 * We intend to verify the following properties
2470 * P1. We cannot convert between complex and real types (cf xla)
2471 * P3. The dimensions of the operand and the target
2472 * shape must match, except that the shape with the smaller element bitwidth has
2473 * an appropriately-sized additional innermost dimension, e.g.
2474 * ... x f32 => [bitcast_convert] => ... x 4 x i8
2475 * ... x 4 x i8 => [bitcast_convert] => ... x f32
2476 */
verify()2477 LogicalResult BitcastConvertOp::verify() {
2478 auto operandTensorType = operand().getType().cast<TensorType>();
2479 auto targetTensorType = getResult().getType().cast<TensorType>();
2480
2481 // P1.
2482 auto targetElt = targetTensorType.getElementType();
2483 auto operandElt = operandTensorType.getElementType();
2484 if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>()) {
2485 return emitOpError()
2486 << "cannot convert between real and complex types, but got: "
2487 << operandTensorType << " and " << targetTensorType;
2488 }
2489
2490 auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt);
2491 auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt);
2492
2493 // P2.
2494 auto operandType = operandTensorType.dyn_cast<RankedTensorType>();
2495 auto targetType = targetTensorType.dyn_cast<RankedTensorType>();
2496 if (!operandType || !targetType) return success();
2497
2498 auto targetShape = targetType.getShape();
2499 auto operandShape = operandType.getShape();
2500 ArrayRef<int64_t> smallerEltShape, biggerEltShape;
2501 Type smallerElt, biggerElt;
2502 if (operandEltBitwidth < targetEltBitwidth) {
2503 smallerEltShape = operandShape;
2504 smallerElt = operandElt;
2505 biggerEltShape = targetShape;
2506 biggerElt = targetElt;
2507 } else {
2508 smallerEltShape = targetShape;
2509 smallerElt = targetElt;
2510 biggerEltShape = operandShape;
2511 biggerElt = operandElt;
2512 }
2513
2514 ArrayRef<int64_t> smallerEltPrefix;
2515 auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth);
2516 auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth);
2517 if (operandEltBitwidth != targetEltBitwidth) {
2518 if (smallerEltShape.empty()) {
2519 return emitOpError() << "does not allow the smaller element type to be "
2520 "part of a 0d tensor, but got: "
2521 << operandType << " and " << targetType << ".";
2522 }
2523 smallerEltPrefix = smallerEltShape.drop_back();
2524 if (!isDynamicDimSize(smallerEltShape.back()) &&
2525 smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) {
2526 return emitOpError() << "requires compatible bitwidths. "
2527 << "Got: " << operandType << " and " << targetType
2528 << ", but " << smallerEltBitwidth << " * "
2529 << smallerEltShape.back()
2530 << " != " << biggerEltBitwidth << ".";
2531 }
2532 } else {
2533 smallerEltPrefix = smallerEltShape;
2534 }
2535
2536 for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
2537 auto targetDim = std::get<0>(it);
2538 auto operandDim = std::get<1>(it);
2539 if (!isDynamicDimSize(targetDim) && !isDynamicDimSize(operandDim)) {
2540 if (targetDim != operandDim) {
2541 return emitOpError() << "operand and result shapes must match except "
2542 "for the innermost dimension of the shape with "
2543 "the smaller element type. Got: "
2544 << operandType << " and " << targetType << ".";
2545 }
2546 }
2547 }
2548
2549 return success();
2550 }
2551
2552 //===----------------------------------------------------------------------===//
2553 // BroadcastOp
2554 //===----------------------------------------------------------------------===//
2555
2556 // TODO(b/129012527) These should be expressed as type constraints.
verify()2557 LogicalResult BroadcastOp::verify() {
2558 auto sizes = broadcast_sizes();
2559 auto sizesType = sizes.getType();
2560 auto sizesRank = sizesType.getRank();
2561 if (sizesRank != 1) {
2562 return emitOpError(llvm::formatv(
2563 "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
2564 }
2565
2566 return success();
2567 }
2568
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)2569 LogicalResult BroadcastOp::inferReturnTypeComponents(
2570 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
2571 DictionaryAttr attributes, RegionRange regions,
2572 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2573 BroadcastOp::Adaptor adaptor(operands, attributes, regions);
2574 Value operand = adaptor.operand();
2575 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
2576 if (!operandType) return failure();
2577
2578 Type elementTy = operandType.getElementType();
2579 auto dimensionAttr = adaptor.broadcast_sizes();
2580 for (int64_t size : dimensionAttr.getValues<int64_t>()) {
2581 if (size < 0)
2582 return emitOptionalError(location,
2583 "Broadcast with negative dimension size ", size);
2584 }
2585 SmallVector<int64_t> shapeValues(dimensionAttr.getValues<int64_t>());
2586 llvm::append_range(shapeValues, operandType.getShape());
2587
2588 inferredReturnShapes.emplace_back(shapeValues, elementTy);
2589 return success();
2590 }
2591
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)2592 LogicalResult BroadcastOp::reifyReturnTypeShapes(
2593 OpBuilder& builder, ValueRange operands,
2594 SmallVectorImpl<Value>& reifiedReturnShapes) {
2595 BroadcastOp::Adaptor adaptor(operands);
2596 Value operand = adaptor.operand();
2597
2598 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
2599 // Unranked tensors are not supported.
2600 if (!operandType) return failure();
2601
2602 Location loc = getLoc();
2603 SmallVector<Value, 4> shapeValues;
2604
2605 // Collect the broadcast sizes.
2606 for (const auto& size : broadcast_sizes()) {
2607 shapeValues.push_back(
2608 builder.create<arith::ConstantIndexOp>(loc, size.getZExtValue()));
2609 }
2610
2611 // Collect the operand sizes.
2612 for (auto index : llvm::seq<int64_t>(0, operandType.getRank())) {
2613 shapeValues.push_back(
2614 builder.createOrFold<tensor::DimOp>(loc, operand, index));
2615 }
2616
2617 reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
2618 loc,
2619 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
2620 builder.getIndexType()),
2621 shapeValues));
2622
2623 return success();
2624 }
2625
2626 //===----------------------------------------------------------------------===//
2627 // BroadcastInDimOp
2628 //===----------------------------------------------------------------------===//
2629
// Verifies broadcast_in_dim invariants that require a ranked operand:
// broadcast_dimensions is 1-D, sized to the operand rank, maps to valid
// result dimensions, and each static operand dimension is either 1 or equal
// to its mapped result dimension.
LogicalResult BroadcastInDimOp::verify() {
  auto operandType = operand().getType().dyn_cast<RankedTensorType>();
  if (!operandType) {
    // The following verification checks all depend on knowing the rank of
    // the operand. Bail out now if we don't know the rank of the operand.
    return success();
  }

  auto operandRank = operandType.getRank();
  if (!broadcast_dimensions()) {
    // A missing broadcast_dimensions attribute is acceptable only for a
    // scalar (rank-0) operand, where the mapping is trivially empty.
    if (operandRank == 0) {
      return success();
    }
    return emitOpError(
        llvm::formatv("broadcast_dimensions is absent, but required because "
                      "operand has non-zero rank ({0})",
                      operandRank));
  }

  // broadcast_dimensions must be a 1-D attribute...
  auto dimensions = broadcast_dimensions();
  auto dimensionsType = broadcast_dimensions().getType();
  auto dimensionsRank = dimensionsType.getRank();
  if (dimensionsRank != 1) {
    return emitOpError(llvm::formatv(
        "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank));
  }

  // ...with exactly one entry per operand dimension.
  auto dimensionsSize = dimensionsType.getNumElements();
  if (dimensionsSize != operandRank) {
    return emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        dimensionsSize, operandRank));
  }

  // Broadcasting may add dimensions, never remove them.
  auto resultType = getResult().getType().cast<RankedTensorType>();
  auto resultRank = resultType.getRank();
  if (resultRank < operandRank) {
    return emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != dimensionsSize; ++i) {
    // Each mapped dimension index must exist in the result.
    auto dimIndex = dimensions.getValues<int64_t>()[i];
    if (dimIndex >= resultRank) {
      return emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    // A static operand dimension must be broadcast-compatible: either 1 or
    // equal to the result dimension it maps to. Dynamic operand dimensions
    // are not checked here.
    if (!operandType.isDynamicDim(i)) {
      auto dimSize = operandType.getDimSize(i);
      auto resultDimSize = resultType.getDimSize(dimIndex);
      if (dimSize != 1 && dimSize != resultDimSize) {
        return emitOpError(
            llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
                          "1 or size of result dimension {2} ({3})",
                          i, dimSize, dimIndex, resultDimSize));
      }
    }
  }

  return success();
}
2695
2696 //===----------------------------------------------------------------------===//
2697 // DynamicBroadcastInDimOp
2698 //===----------------------------------------------------------------------===//
2699
// Verifies dynamic_broadcast_in_dim: checks broadcast_dimensions against the
// operand/result ranks, size-compatibility of mapped dimensions, agreement
// between the result rank and output_dimensions, and that expansion hints
// (known_expanding/known_nonexpanding_dimensions) name distinct, valid
// operand dimensions.
LogicalResult DynamicBroadcastInDimOp::verify() {
  auto operandType = operand().getType().dyn_cast<RankedTensorType>();
  auto resultType = getResult().getType().dyn_cast<RankedTensorType>();

  // If either the operand or result are unranked, there is very little
  // to verify statically.
  if (!operandType || !resultType) {
    return success();
  }

  auto outputDimensionsType =
      output_dimensions().getType().cast<RankedTensorType>();
  auto outputDimensionsSize = outputDimensionsType.getDimSize(0);
  auto operandRank = operandType.getRank();
  auto resultRank = resultType.getRank();

  // Verify broadcast_dimensions.
  auto bcastDimensions = broadcast_dimensions();
  auto bcastDimensionsType = broadcast_dimensions().getType();
  auto bcastDimensionsRank = bcastDimensionsType.getRank();
  // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
  if (bcastDimensionsRank != 1) {
    return emitOpError(
        llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1",
                      bcastDimensionsRank));
  }

  // One broadcast-dimension entry is required per operand dimension.
  auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
  if (bcastDimensionsSize != operandRank) {
    return emitOpError(llvm::formatv(
        "broadcast_dimensions size ({0}) does not match operand rank ({1})",
        bcastDimensionsSize, operandRank));
  }

  // Broadcasting may add dimensions, never remove them.
  if (resultRank < operandRank) {
    return emitOpError(
        llvm::formatv("result rank ({0}) is less than operand rank ({1})",
                      resultRank, operandRank));
  }

  for (int i = 0; i != bcastDimensionsSize; ++i) {
    // Each mapped dimension index must exist in the result.
    auto dimIndex = bcastDimensions.getValues<int64_t>()[i];
    if (dimIndex >= resultRank) {
      return emitOpError(
          llvm::formatv("broadcast_dimensions contains invalid value {0} for "
                        "result with rank {1}",
                        dimIndex, resultRank));
    }

    auto dimSize = operandType.getDimSize(i);
    auto resultDimSize = resultType.getDimSize(dimIndex);
    // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
    // add a manual check for this.
    if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
      return emitOpError(
          llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
                        "with size of result dimension {2} ({3})",
                        i, dimSize, dimIndex, resultDimSize));
    }
  }

  // output_dimensions must supply exactly one size per result dimension.
  if (outputDimensionsSize != resultRank) {
    return emitOpError(
        llvm::formatv("result rank ({0}) is not equal to number of output "
                      "dimensions ({1})",
                      resultRank, outputDimensionsSize));
  }

  // Verify that the known expanding and non-expanding dimensions are a subset
  // of the operand's dimensions.
  int64_t numKnownExpansionBehavior = 0;
  DenseSet<int64_t> knownExpansionBehavior;
  auto collectExpansionBehaviorDims =
      [&](const Optional<DenseIntElementsAttr>& attr) {
        if (!attr) return;
        for (const APInt& it : *attr) {
          numKnownExpansionBehavior++;
          knownExpansionBehavior.insert(it.getLimitedValue());
        }
      };
  collectExpansionBehaviorDims(known_expanding_dimensions());
  collectExpansionBehaviorDims(known_nonexpanding_dimensions());
  // A set size smaller than the total count means some dimension was named
  // twice (within one attribute or across the two).
  if (knownExpansionBehavior.size() != numKnownExpansionBehavior) {
    return emitOpError(
        "duplicate expansion hint for at least one operand dimension");
  }
  // Every hinted dimension must be a valid operand dimension index.
  for (int64_t i : knownExpansionBehavior) {
    if (i < 0 || i >= operandRank) {
      return emitOpError(
          llvm::formatv("hint for expanding dimension {0} does not refer to a "
                        "valid operand dimension",
                        i));
    }
  }

  return success();
}
2797
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)2798 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
2799 OpBuilder& builder, ValueRange operands,
2800 SmallVectorImpl<Value>& reifiedReturnShapes) {
2801 DynamicBroadcastInDimOp::Adaptor adaptor(operands);
2802 reifiedReturnShapes.push_back(
2803 castToIndexTensor(builder, getLoc(), adaptor.output_dimensions()));
2804 return success();
2805 }
2806
2807 //===----------------------------------------------------------------------===//
2808 // ClampOp
2809 //===----------------------------------------------------------------------===//
2810
verify()2811 LogicalResult ClampOp::verify() {
2812 auto operandType = operand().getType().cast<RankedTensorType>();
2813 auto operandShape = operandType.getShape();
2814 auto minType = min().getType().cast<RankedTensorType>();
2815
2816 auto minShape = minType.getShape();
2817 if (failed(verifyCompatibleShape(minType, operandType)) &&
2818 minType.getRank() != 0) {
2819 return emitOpError(llvm::formatv(
2820 "min shape [{0}] is not scalar and is not compatible to operand shape "
2821 "[{1}]",
2822 llvm::make_range(minShape.begin(), minShape.end()),
2823 llvm::make_range(operandShape.begin(), operandShape.end())));
2824 }
2825
2826 auto maxType = max().getType().cast<RankedTensorType>();
2827 auto maxShape = maxType.getShape();
2828 if (failed(verifyCompatibleShape(maxType, operandType)) &&
2829 maxType.getRank() != 0) {
2830 return emitOpError(llvm::formatv(
2831 "max shape [{0}] is not scalar and is not compatible to operand shape "
2832 "[{1}]",
2833 llvm::make_range(maxShape.begin(), maxShape.end()),
2834 llvm::make_range(operandShape.begin(), operandShape.end())));
2835 }
2836
2837 return success();
2838 }
2839
inferReturnTypeComponents(MLIRContext *,Optional<Location>,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)2840 LogicalResult ClampOp::inferReturnTypeComponents(
2841 MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
2842 DictionaryAttr attributes, RegionRange regions,
2843 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
2844 ClampOp::Adaptor adaptor(operands, attributes, regions);
2845 RankedTensorType operandType =
2846 adaptor.operand().getType().cast<RankedTensorType>();
2847 inferredReturnShapes.emplace_back(operandType.getShape(),
2848 operandType.getElementType());
2849 return success();
2850 }
2851
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)2852 LogicalResult ClampOp::reifyReturnTypeShapes(
2853 OpBuilder& builder, ValueRange operands,
2854 SmallVectorImpl<Value>& reifiedReturnShapes) {
2855 // For `stablehlo.clamp`, the first operand may be a scalar.
2856 return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1],
2857 &reifiedReturnShapes);
2858 }
2859
2860 //===----------------------------------------------------------------------===//
2861 // ComplexOp
2862 //===----------------------------------------------------------------------===//
2863
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2864 LogicalResult ComplexOp::inferReturnTypes(
2865 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2866 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2867 TensorType operandType = operands[0].getType().cast<TensorType>();
2868 ComplexType elementTy = ComplexType::get(operandType.getElementType());
2869 inferredReturnTypes.push_back(
2870 hlo::getSameShapeTensorType(operandType, elementTy));
2871 return success();
2872 }
2873
2874 //===----------------------------------------------------------------------===//
2875 // ImagOp
2876 //===----------------------------------------------------------------------===//
2877
2878 namespace {
createRealType(TensorType type)2879 Type createRealType(TensorType type) {
2880 auto elementTy = type.getElementType();
2881 if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
2882 elementTy = complexTy.getElementType();
2883 }
2884 return hlo::getSameShapeTensorType(type, elementTy);
2885 }
2886 } // namespace
2887
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2888 LogicalResult ImagOp::inferReturnTypes(
2889 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2890 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2891 inferredReturnTypes.push_back(
2892 createRealType(operands[0].getType().cast<TensorType>()));
2893 return success();
2894 }
2895
2896 //===----------------------------------------------------------------------===//
2897 // IsFiniteOp
2898 //===----------------------------------------------------------------------===//
2899
inferReturnTypes(MLIRContext * ctx,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2900 LogicalResult IsFiniteOp::inferReturnTypes(
2901 MLIRContext* ctx, Optional<Location>, ValueRange operands, DictionaryAttr,
2902 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2903 auto argTy = operands.front().getType().cast<TensorType>();
2904 Builder b(ctx);
2905 inferredReturnTypes.push_back(
2906 hlo::getSameShapeTensorType(argTy, b.getI1Type()));
2907 return success();
2908 }
2909
2910 //===----------------------------------------------------------------------===//
2911 // RealOp
2912 //===----------------------------------------------------------------------===//
2913
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)2914 LogicalResult RealOp::inferReturnTypes(
2915 MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
2916 RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
2917 inferredReturnTypes.push_back(
2918 createRealType(operands[0].getType().cast<TensorType>()));
2919 return success();
2920 }
2921
2922 //===----------------------------------------------------------------------===//
2923 // ConcatenateOp
2924 //===----------------------------------------------------------------------===//
2925
// Infers the result type of stablehlo.concatenate from the operand types and
// the "dimension" attribute. Non-concat dimensions are filled in from any
// operand that has a static size there; the concat dimension is the sum of
// all operand sizes along that dimension (dynamic if any contributor is
// dynamic). If every operand is unranked, the result is unranked.
LogicalResult ConcatenateOp::inferReturnTypes(
    MLIRContext*, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  if (operands.empty()) {
    return failure();
  }

  auto dimensionAttr = attributes.get("dimension").cast<IntegerAttr>();
  auto dimension = dimensionAttr.getInt();

  // Element type is taken from the first operand; shape compatibility of
  // element types is the verifier's job, not inference's.
  auto firstType = (*operands.begin()).getType().cast<ShapedType>();
  auto outElement = firstType.getElementType();

  // Find the first ranked input to determine the output rank.
  for (auto type : operands.getTypes()) {
    auto shapedType = type.cast<ShapedType>();
    if (shapedType.hasRank()) {
      firstType = shapedType;
      break;
    }
  }

  // If all inputs are unranked, the result must be unranked.
  if (!firstType.hasRank()) {
    inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
    return success();
  }

  auto outShape = llvm::to_vector<6>(firstType.getShape());

  // Determine what the non-concatenate dimensions should be.
  for (auto type : operands.getTypes()) {
    auto shapedTy = type.cast<ShapedType>();
    if (!shapedTy.hasRank()) {
      continue;
    }

    for (const auto& it : llvm::enumerate(shapedTy.getShape())) {
      // If a dimension is not dynamic, the output shape should match.
      // (A still-dynamic slot may be refined by a later operand's static
      // size; a static slot is kept as-is.)
      if (ShapedType::isDynamic(outShape[it.index()])) {
        outShape[it.index()] = it.value();
      }
    }
  }

  // Accumulate the concat-dimension size across all operands, starting at 0.
  outShape[dimension] = 0;

  for (auto operand : operands.getTypes()) {
    auto type = operand.cast<ShapedType>();
    // Any unranked operand makes the overall result unranked.
    if (!type.hasRank()) {
      inferredReturnTypes.push_back(UnrankedTensorType::get(outElement));
      return success();
    }

    // If the dimension is dynamic we know the output dimension is dynamic.
    auto dim = type.getShape()[dimension];
    if (ShapedType::isDynamic(dim)) {
      outShape[dimension] = ShapedType::kDynamicSize;
      break;
    }

    outShape[dimension] += dim;
  }

  inferredReturnTypes.push_back(RankedTensorType::get(outShape, outElement));

  return success();
}
2995
verify()2996 LogicalResult ConcatenateOp::verify() {
2997 RankedTensorType firstRankedType;
2998 int firstRankedIndex;
2999 int numOperands = getNumOperands();
3000 int64_t concatDimension = static_cast<int64_t>(dimension());
3001 if (concatDimension < 0) {
3002 return emitOpError(
3003 llvm::formatv("dimension {0} is negative", concatDimension));
3004 }
3005 for (int i = 0; i < numOperands; i++) {
3006 auto secondType = getOperand(i).getType().dyn_cast<ShapedType>();
3007 if (!secondType.hasRank()) {
3008 continue;
3009 }
3010
3011 if (!firstRankedType) {
3012 firstRankedType = secondType.cast<RankedTensorType>();
3013 firstRankedIndex = i;
3014 if (firstRankedType.getRank() == 0)
3015 return emitOpError(
3016 llvm::formatv("rank-0 values cannot be concatenated"));
3017 if (concatDimension >= firstRankedType.getRank()) {
3018 return emitOpError(
3019 llvm::formatv("dimension {0} is out-of-bounds for input rank {1}",
3020 concatDimension, firstRankedType.getRank()));
3021 }
3022 continue;
3023 }
3024
3025 if (firstRankedType.getRank() != secondType.getRank()) {
3026 return emitOpError(llvm::formatv(
3027 "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i));
3028 }
3029
3030 auto firstShape = firstRankedType.getShape();
3031 auto secondShape = secondType.getShape();
3032 for (int d = 0; d < firstRankedType.getRank(); ++d) {
3033 if (!ShapedType::isDynamic(firstShape[d]) &&
3034 !ShapedType::isDynamic(secondShape[d]) &&
3035 firstShape[d] != secondShape[d] && d != concatDimension) {
3036 return emitOpError(llvm::formatv(
3037 "shapes of operand ({0}) and ({1}) do not match at non-concat "
3038 "index: ({2}) != ({3}) at non-concat index {4}",
3039 firstRankedIndex, i,
3040 llvm::make_range(firstShape.begin(), firstShape.end()),
3041 llvm::make_range(secondShape.begin(), secondShape.end()), d));
3042 }
3043 }
3044 }
3045 return success();
3046 }
3047
// Reifies the result shape of stablehlo.concatenate as a rank-1 index tensor:
// dimension sizes are copied from the first operand, and the concat
// dimension is accumulated across all operands with arith.addi.
LogicalResult ConcatenateOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  ConcatenateOp::Adaptor adaptor(operands);
  auto inputs = adaptor.val();

  auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
  // Not support unranked type a.t.m.
  if (!operandType) return failure();

  Location loc = this->getLoc();
  Type shapeScalarType = builder.getIndexType();
  auto toShapeScalarType = [&](Value v) {
    return maybeCastTo(builder, loc, v, shapeScalarType);
  };

  // Materialize every input's dimension sizes as index-typed scalar values.
  SmallVector<SmallVector<Value, 4>, 4> allShapeValues;
  for (size_t inputId = 0; inputId < inputs.size(); ++inputId) {
    Value operand = inputs[inputId];
    auto operandType = operand.getType().dyn_cast<RankedTensorType>();
    if (!operandType) return failure();

    SmallVector<Value, 4> shapeVals;
    for (const auto& element : llvm::enumerate(operandType.getShape())) {
      Value valueDim = toShapeScalarType(
          builder.create<tensor::DimOp>(loc, operand, element.index()));
      shapeVals.push_back(valueDim);
    }
    allShapeValues.emplace_back(std::move(shapeVals));
  }

  // The output shape equals the first operand's shape, except along `axis`
  // where the remaining operands' sizes are summed in.
  int axis = this->dimension();
  auto& shapeValues = allShapeValues[0];
  for (size_t vecId = 1; vecId < allShapeValues.size(); ++vecId) {
    auto& otherShapeValues = allShapeValues[vecId];
    if (otherShapeValues.size() != shapeValues.size()) {
      this->emitOpError()
          << "Concatenate expects all operands must be of the same rank";
      return failure();
    }
    shapeValues[axis] = builder.create<arith::AddIOp>(loc, shapeValues[axis],
                                                      otherShapeValues[axis]);
  }

  // Pack the per-dimension scalars into a rank-1 tensor of index type.
  Value outputShape = builder.create<tensor::FromElementsOp>(
      loc,
      RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
                            shapeScalarType),
      shapeValues);
  reifiedReturnShapes.push_back(outputShape);

  return success();
}
3101
3102 //===----------------------------------------------------------------------===//
3103 // DynamicReshapeOp
3104 //===----------------------------------------------------------------------===//
3105
verify()3106 LogicalResult DynamicReshapeOp::verify() {
3107 auto resultType = result().getType().dyn_cast<RankedTensorType>();
3108 auto outputShapeType = output_shape().getType().dyn_cast<RankedTensorType>();
3109 if (resultType && outputShapeType && outputShapeType.hasStaticShape() &&
3110 outputShapeType.getDimSize(0) != resultType.getRank()) {
3111 return emitError() << "output should have a rank equal to the number of "
3112 "elements in output_shape";
3113 }
3114 return success();
3115 }
3116
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)3117 LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
3118 OpBuilder& builder, ValueRange operands,
3119 SmallVectorImpl<Value>& reifiedReturnShapes) {
3120 DynamicReshapeOp::Adaptor adaptor(operands);
3121 reifiedReturnShapes.push_back(
3122 castToIndexTensor(builder, getLoc(), adaptor.output_shape()));
3123 return success();
3124 }
3125
3126 //===----------------------------------------------------------------------===//
3127 // DynamicSliceOp
3128 //===----------------------------------------------------------------------===//
3129
3130 // Verifies that the number of slice sizes and the number of start indices match
verify()3131 LogicalResult DynamicSliceOp::verify() {
3132 int numSliceSizes = slice_sizes().getNumElements();
3133 int numStartIndices = start_indices().size();
3134 if (numStartIndices != numSliceSizes) {
3135 return emitOpError() << "has mismatched number of slice sizes ("
3136 << numSliceSizes << ") and number of start indices ("
3137 << numStartIndices << ")";
3138 }
3139 auto operandType = operand().getType().dyn_cast<RankedTensorType>();
3140 if (!operandType) return failure();
3141
3142 if (operandType.getRank() != numStartIndices) {
3143 return emitOpError() << "has mismatched number of start indices ("
3144 << numStartIndices << ") and the rank of operand ("
3145 << operandType.getRank() << ")";
3146 }
3147
3148 for (int i = 0; i < numSliceSizes; ++i) {
3149 int64_t sliceSize = slice_sizes().getValues<int64_t>()[i];
3150 if (sliceSize < 0) {
3151 return emitOpError() << "has negative size index to dynamic slice: "
3152 << sliceSize;
3153 }
3154 if (!operandType.isDynamicDim(i)) {
3155 int64_t dimSize = operandType.getDimSize(i);
3156 if (sliceSize > dimSize) {
3157 return emitOpError() << "has slice size " << sliceSize
3158 << " greater than dimension size " << dimSize
3159 << " in dimension " << i << " of operand";
3160 }
3161 }
3162 }
3163 return success();
3164 }
3165
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)3166 LogicalResult DynamicSliceOp::inferReturnTypeComponents(
3167 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
3168 DictionaryAttr attributes, RegionRange regions,
3169 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
3170 DynamicSliceOp::Adaptor adaptor(operands, attributes, regions);
3171 Value operand = adaptor.operand();
3172 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
3173 if (!operandType) return failure();
3174
3175 auto sliceSizes = adaptor.slice_sizes();
3176 Type elementTy = operandType.getElementType();
3177 inferredReturnShapes.emplace_back(sliceSizes.getValues<int64_t>(), elementTy);
3178 return success();
3179 }
3180
3181 //===----------------------------------------------------------------------===//
3182 // RealDynamicSliceOp
3183 //===----------------------------------------------------------------------===//
3184 // Verifies that operand rank matches start_indices/limit_indices/strides size
verify()3185 LogicalResult RealDynamicSliceOp::verify() {
3186 auto inputType = operand().getType().dyn_cast<RankedTensorType>();
3187 // If operand is unranked, there is very little to verify statically.
3188 if (!inputType) return success();
3189 int inputRank = inputType.getRank();
3190
3191 auto startType = start_indices().getType().cast<RankedTensorType>();
3192 auto limitType = limit_indices().getType().cast<RankedTensorType>();
3193 auto stridesType = strides().getType().cast<RankedTensorType>();
3194
3195 if (inputRank != startType.getNumElements()) {
3196 return emitOpError() << "has mismatched number of operand rank ("
3197 << inputRank << ") and start_indices size ("
3198 << startType.getNumElements() << ")";
3199 }
3200
3201 if (inputRank != limitType.getNumElements()) {
3202 return emitOpError() << "has mismatched number of operand rank ("
3203 << inputRank << ") and limit_indices size ("
3204 << limitType.getNumElements() << ")";
3205 }
3206
3207 if (inputRank != stridesType.getNumElements()) {
3208 return emitOpError() << "has mismatched number of operand rank ("
3209 << inputRank << ") and strides size ("
3210 << stridesType.getNumElements() << ")";
3211 }
3212
3213 return success();
3214 }
3215
// Reifies the result shape of stablehlo.real_dynamic_slice: for each
// dimension, computes size = (limit - start + stride - 1) / stride from
// values extracted out of the start/limit/stride operands at runtime, then
// packs the sizes into a rank-1 tensor.
LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  RealDynamicSliceOp::Adaptor adaptor(operands);
  Value operand = adaptor.operand();
  Value startIndices = adaptor.start_indices();
  Value limitIndices = adaptor.limit_indices();
  Value strides = adaptor.strides();

  auto operandType = operand.getType().dyn_cast<RankedTensorType>();
  // Not support unranked type a.t.m.
  if (!operandType) return failure();

  Location loc = this->getLoc();
  SmallVector<Value, 4> shapeValues;
  shapeValues.reserve(operandType.getRank());
  // Shape arithmetic is done in the element type of start_indices; the
  // constant 1 is cast into that type once, outside the loop.
  Type shapeScalarType =
      startIndices.getType().cast<ShapedType>().getElementType();
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  one = maybeCastTo(builder, loc, one, shapeScalarType);
  for (const auto& element : llvm::enumerate(operandType.getShape())) {
    Value offset = builder.create<arith::ConstantIndexOp>(loc, element.index());
    Value valueStart =
        builder.create<tensor::ExtractOp>(loc, startIndices, offset);
    Value valueLimit =
        builder.create<tensor::ExtractOp>(loc, limitIndices, offset);
    Value valueStride = builder.create<tensor::ExtractOp>(loc, strides, offset);
    // size = (limit - start + stride - 1) / stride
    shapeValues.push_back(builder.create<arith::DivSIOp>(
        loc,
        builder.create<arith::SubIOp>(
            loc,
            builder.create<arith::AddIOp>(
                loc, valueStride,
                builder.create<arith::SubIOp>(loc, valueLimit, valueStart)),
            one),
        valueStride));
  }

  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
      loc,
      RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
                            shapeScalarType),
      shapeValues));
  return success();
}
3262
3263 //===----------------------------------------------------------------------===//
3264 // InfeedOp
3265 //===----------------------------------------------------------------------===//
3266
3267 // Checks that the result type is of the form `zero_or_more_type(s),
3268 // stablehlo::token`
verify()3269 LogicalResult InfeedOp::verify() {
3270 auto resultTypes = getResultTypes();
3271 if (resultTypes.empty())
3272 return emitOpError()
3273 << "result is expected to be at least of size 1, but got "
3274 << resultTypes.size();
3275
3276 if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
3277 return emitOpError() << "last element of result types is expected to "
3278 "be of token type, but got "
3279 << resultTypes[resultTypes.size() - 1];
3280
3281 // Verify layout attribute
3282 constexpr char kLayoutAttr[] = "layout";
3283 if (!getOperation()->hasAttr(kLayoutAttr)) return success();
3284
3285 mlir::ArrayAttr layout =
3286 getOperation()->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
3287 if (!layout)
3288 return emitOpError() << "layout-attribute expected to be of array-type.";
3289
3290 if (layout.size() != resultTypes.size() - 1) {
3291 return emitOpError() << "layout-attribute size must be "
3292 << resultTypes.size() - 1
3293 << " (which is the number of "
3294 "op-results - 1 (for token result)), but got "
3295 << layout.size();
3296 }
3297
3298 for (auto childLayout : layout) {
3299 mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast<mlir::ArrayAttr>();
3300 if (!childLayoutArr) {
3301 return emitOpError() << "layout-attribute expected to have "
3302 "elements of type array, but got "
3303 << childLayout;
3304 }
3305
3306 for (auto i : childLayoutArr) {
3307 mlir::IntegerAttr attr = i.dyn_cast<mlir::IntegerAttr>();
3308 if (!attr) {
3309 return emitOpError() << "layout-attribute's leaf elements are "
3310 "expected to be of type integer, but got "
3311 << i;
3312 }
3313 }
3314 }
3315
3316 return success();
3317 }
3318
3319 //===----------------------------------------------------------------------===//
3320 // MapOp
3321 //===----------------------------------------------------------------------===//
3322
// Verifies stablehlo.map:
//  - the computation region's arity matches the number of operands;
//  - every computation argument is a rank-0 tensor whose element type matches
//    the corresponding operand's element type;
//  - the computation returns exactly one rank-0 tensor whose element type
//    matches the op result's element type;
//  - `dimensions` is exactly [0, 1, ..., rank-1] — mapping over a subset of
//    dimensions is currently unsupported.
LogicalResult MapOp::verify() {
  // Checks if the number of `operands` match the arity of the map `computation`
  // region.
  auto& computationBlock = computation().front();
  auto computationArgs = computationBlock.getArguments();
  if (operands().size() != computationArgs.size())
    return emitOpError() << "expects number of operands to match the arity "
                            "of map computation, but got: "
                         << operands().size() << " and "
                         << computationArgs.size();

  // The parameters of computation should all be scalars and match the element
  // type of operands.
  for (const auto& indexedArg : llvm::enumerate(computationArgs)) {
    auto argType = indexedArg.value().getType().dyn_cast<TensorType>();
    if (!argType || argType.getRank() != 0)
      return emitOpError()
             << "computation arguments must be 0-rank tensor, but got: arg #"
             << indexedArg.index() << " of type "
             << indexedArg.value().getType();
    auto operandElemTy = operands()[indexedArg.index()]
                             .getType()
                             .cast<TensorType>()
                             .getElementType();
    if (argType.getElementType() != operandElemTy) {
      return emitOpError()
             << "element type of operands and computation arguments must "
                "match, but got: "
             << operandElemTy << " and " << argType.getElementType();
    }
  }

  // Mapped computation must return single output
  auto computationOutputs = computationBlock.getTerminator()->getOperands();
  if (computationOutputs.size() != 1)
    return emitOpError() << "computation must return single output, but got: "
                         << computationOutputs.size();

  // The output of computation must be scalar and have the same element type
  // as op result.
  auto computationOutputType =
      computationOutputs[0].getType().dyn_cast<TensorType>();
  if (!computationOutputType || computationOutputType.getRank() != 0)
    return emitOpError() << "computation must return 0-rank tensor, but got: "
                         << computationOutputs[0].getType();

  auto resultType = getType().cast<TensorType>();
  if (computationOutputType.getElementType() != resultType.getElementType())
    return emitOpError() << "element type of result and computation output "
                            "must match, but got: "
                         << resultType.getElementType() << " and "
                         << computationOutputType.getElementType();

  // Checks that the requested map dimension numbers are monotonically
  // increasing.
  DenseIntElementsAttr dimensions = this->dimensions();
  for (const auto& indexedValue :
       llvm::enumerate(dimensions.getValues<int64_t>())) {
    // Entry i must equal i, i.e. dimensions is the identity permutation.
    if (indexedValue.value() != static_cast<int64_t>(indexedValue.index()))
      return emitOpError() << "requires monotonically increasing dimension "
                              "numbers, but got: "
                           << dimensions;
  }

  // Checks that number of dimensions of operands matches the size of
  // `dimensions` since we currently only support mapping across all
  // dimensions: i.e., scalar map functions.
  auto operandType = operands()[0].getType().cast<TensorType>();
  if (operandType.hasRank()) {
    if (dimensions.size() !=
        static_cast<int64_t>(operandType.getShape().size()))
      return emitOpError()
             << "applied to a subset of dimensions currently not supported: "
                "operand dimensions = "
             << operandType.getShape().size()
             << ", requested map dimensions size = " << dimensions.size();
  }

  return success();
}
3403
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)3404 LogicalResult MapOp::reifyReturnTypeShapes(
3405 OpBuilder& builder, ValueRange operands,
3406 SmallVectorImpl<Value>& reifiedReturnShapes) {
3407 return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
3408 &reifiedReturnShapes);
3409 }
3410
3411 //===----------------------------------------------------------------------===//
3412 // RecvOp
3413 //===----------------------------------------------------------------------===//
3414
3415 // Checks that the result type is of the form `zero_or_more_type(s),
3416 // stablehlo::token`
verify()3417 LogicalResult RecvOp::verify() {
3418 auto resultTypes = getResultTypes();
3419 if (resultTypes.empty())
3420 return emitOpError()
3421 << "result is expected to be at least of size 1, but got "
3422 << resultTypes.size();
3423 if (!resultTypes[resultTypes.size() - 1].isa<TokenType>())
3424 return emitOpError() << "last element of result types is expected to "
3425 "be of token type, but got "
3426 << resultTypes[resultTypes.size() - 1];
3427 return success();
3428 }
3429
3430 //===----------------------------------------------------------------------===//
3431 // ReduceWindowOp
3432 //===----------------------------------------------------------------------===//
3433
3434 namespace {
3435 // Infer the return-type of ReduceWindowOp.
inferReduceWindowOpReturnType(ArrayRef<TensorType> inputTypes,ArrayRef<TensorType> initTypes,const ArrayRef<WindowDimension> window)3436 SmallVector<TensorType> inferReduceWindowOpReturnType(
3437 ArrayRef<TensorType> inputTypes, ArrayRef<TensorType> initTypes,
3438 const ArrayRef<WindowDimension> window) {
3439 SmallVector<TensorType> outputTypes;
3440 for (size_t i = 0; i < inputTypes.size(); ++i) {
3441 if (!inputTypes[i].hasRank()) {
3442 outputTypes.push_back(
3443 UnrankedTensorType::get(initTypes[i].getElementType()));
3444 continue;
3445 }
3446
3447 outputTypes.push_back(RankedTensorType::get(
3448 inferWindowOutputShape(inputTypes[i].getShape(), window),
3449 initTypes[i].getElementType()));
3450 }
3451
3452 return outputTypes;
3453 }
3454 } // namespace
3455
3456 // We intend to verify the following properties
3457 // P1. The sizes of 'inputs' and 'init_values' must be at least 1.
3458 // P2. All `inputs` need to have compatible shapes.
3459 // P3. size-of(window_dimension) == rank-of(input),
3460 // where input is an element of 'inputs'.
3461 // P4. Verify and collect the window atributes.
3462 // P5. Verify the inner block defining the reducer function.
3463 // P6. Verify the return type.
// Verifies ReduceWindowOp against properties P1-P6 listed above.
LogicalResult ReduceWindowOp::verify() {
  // P1.
  // Note that the ODS ensures that there are even number of operands; Check if
  // that number is not zero.
  if (getOperands().empty())
    return emitOpError() << "expects the size of operands to be >= 2.";

  // Collect the input and init-value operands. Note that the operand-type is
  // enforced as "TensorType" by ODS.
  // Layout: the first half of the operands are the inputs, the second half
  // the corresponding init values.
  int64_t numInputs = getNumOperands() / 2;
  auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
      getOperandTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));
  ArrayRef<TensorType> inputTypes(operandTensorTypes.begin(),
                                  operandTensorTypes.begin() + numInputs);
  ArrayRef<TensorType> initTypes(operandTensorTypes.begin() + numInputs,
                                 operandTensorTypes.end());

  // P2.
  if (failed(verifyCompatibleShapes(operands().getTypes())))
    return emitOpError() << "requires same shape for all inputs";

  // P3.
  SmallVector<int64_t> windowDims =
      convertDenseIntAttr(this->window_dimensions());
  for (const auto inputType : inputTypes) {
    if (!inputType.hasRank()) continue;
    if (inputType.getRank() != static_cast<int64_t>(windowDims.size()))
      return emitOpError()
             << "expects window-dimensions size == input rank, but got "
                "window-dimensions size: "
             << windowDims.size() << " and input: " << inputType
             << " with rank = " << inputType.getRank() << ".";
  }

  // P4.
  // `padding` is expected to be an Nx2 matrix of (low, high) pairs.
  auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
  if (failed(paddingOrErr)) return failure();
  SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;

  auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
      windowDims, convertDenseIntAttr(window_strides()), padding,
      /*lhs_dilation=*/convertDenseIntAttr(base_dilations()),
      /*rhs_dilation=*/convertDenseIntAttr(this->window_dilations()), getLoc());
  if (failed(windowOrErr)) return failure();

  // P5.
  bool allInputsUnranked =
      llvm::all_of(inputTypes, [](TensorType t) { return !t.hasRank(); });

  Block& block = body().front();
  // verifyReducerShape also populates accumulatorSubshapes, used below for
  // the result checks of P6.
  SmallVector<TensorType> accumulatorSubshapes;
  if (failed(verifyReducerShape(this->getLoc(), block, inputTypes, initTypes,
                                numInputs, windowDims, allInputsUnranked,
                                accumulatorSubshapes)))
    return failure();

  // P6.
  if (numInputs != getNumResults())
    return emitOpError() << "expects " << numInputs
                         << " result values, but got " << getNumResults()
                         << ".";

  // The result-type is enforced as "TensorType" by ODS.
  auto resultTensorTypes = llvm::to_vector<4>(llvm::map_range(
      getResultTypes(),
      [](Type t) -> TensorType { return t.cast<TensorType>(); }));

  // Check if the element-type of results match with the ones derived from
  // the reducer-block. Already ensured that |accumulator_subshapes| ==
  // num_inputs == num_of_results.
  for (int64_t shapeIdx = 0;
       shapeIdx < static_cast<int64_t>(accumulatorSubshapes.size());
       shapeIdx++) {
    if (accumulatorSubshapes[shapeIdx].getElementType() !=
        resultTensorTypes[shapeIdx].getElementType()) {
      return emitError()
             << "expects the element-type of reduce-op's return-value at index "
             << shapeIdx
             << " to match the element-type of reducer-block's "
                "corresponding return-value, but got "
             << resultTensorTypes[shapeIdx].getElementType() << " and "
             << accumulatorSubshapes[shapeIdx].getElementType() << " resp.";
    }
  }

  // Check if the shape of results match with the ones derived from
  // the input-types and wndow-attributes.
  auto inferredReturnTypes = inferReduceWindowOpReturnType(
      inputTypes, accumulatorSubshapes, *windowOrErr);

  for (size_t i = 0; i < getNumResults(); i++) {
    if (failed(verifyCompatibleShape(resultTensorTypes[i],
                                     inferredReturnTypes[i]))) {
      return emitOpError()
             << "expects result at index " << i
             << " to have compatible shape with the corresponding "
                "inferred type, but got "
             << resultTensorTypes[i] << " and " << inferredReturnTypes[i]
             << " resp.";
    }
  }

  return success();
}
3569
3570 // Get the operation used for reduction applied to `result_index`th result. Its
3571 // expected to be a binary operation that consumes `result_index`th and
3572 // `result_index + operands().size`th arguments of the body.
Operation* ReduceWindowOp::getReductionOp(int resultIndex) {
  auto returnOp = cast<ReturnOp>(body().front().getTerminator());
  Operation* computeOp = returnOp.results()[resultIndex].getDefiningOp();
  // Only a binary op can qualify as a recognizable reduction.
  if (computeOp->getNumOperands() != 2) return nullptr;
  auto arg0 = computeOp->getOperand(0).dyn_cast<BlockArgument>();
  auto arg1 = computeOp->getOperand(1).dyn_cast<BlockArgument>();
  // Both operands must come straight from the body's block arguments.
  if (!arg0 || !arg1) return nullptr;
  int64_t arg0Num = arg0.getArgNumber();
  int64_t arg1Num = arg1.getArgNumber();
  // Block arguments are laid out as [inputs..., init_values...]: the
  // reduction for result `i` consumes arguments `i` and `i + num_inputs`.
  int64_t otherArgIndex = resultIndex + operands().size();
  if (arg0Num == resultIndex && arg1Num == otherArgIndex) return computeOp;
  // Swapped argument order is acceptable only for commutative ops.
  if (arg0Num == otherArgIndex && arg1Num == resultIndex &&
      computeOp->hasTrait<mlir::OpTrait::IsCommutative>())
    return computeOp;
  return nullptr;
}
3589
3590 //===----------------------------------------------------------------------===//
3591 // ReducePrecisionOp
3592 //===----------------------------------------------------------------------===//
3593
3594 // The following property is already enforced by the ODS:
3595 // P0. operand element type is float
3596 // P1. mantissa_bits >= 0
3597 // We intend to verify the following properties
3598 // P2. exponent_bits >= 1
verify()3599 LogicalResult ReducePrecisionOp::verify() {
3600 if (exponent_bits() < 1) {
3601 return emitOpError() << "exponent_bits must be at least 1.";
3602 }
3603 return success();
3604 }
3605
3606 //===----------------------------------------------------------------------===//
3607 // ReduceOp
3608 //===----------------------------------------------------------------------===//
3609
3610 // Returns the result type after reducing operand of the given type across the
3611 // specified dimensions.
getReduceResultType(Type operandTy,DenseIntElementsAttr dimensions,Builder * builder)3612 static TensorType getReduceResultType(Type operandTy,
3613 DenseIntElementsAttr dimensions,
3614 Builder* builder) {
3615 Type elementTy = getElementTypeOrSelf(operandTy);
3616
3617 auto rankedTy = operandTy.dyn_cast<RankedTensorType>();
3618 if (!rankedTy) return UnrankedTensorType::get(elementTy);
3619
3620 int64_t rank = rankedTy.getRank();
3621 llvm::SmallVector<bool, 4> dimsMask(rank, false);
3622 for (int64_t dim : dimensions.getValues<int64_t>()) dimsMask[dim] = true;
3623
3624 SmallVector<int64_t, 4> shape;
3625 for (int64_t i = 0; i < rank; ++i) {
3626 if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i));
3627 }
3628
3629 return RankedTensorType::get(shape, elementTy);
3630 }
3631
build(OpBuilder & builder,OperationState & state,ValueRange operands,ValueRange initValues,DenseIntElementsAttr dimensions)3632 void ReduceOp::build(OpBuilder& builder, OperationState& state,
3633 ValueRange operands, ValueRange initValues,
3634 DenseIntElementsAttr dimensions) {
3635 SmallVector<Type, 1> resultTy;
3636 resultTy.reserve(operands.size());
3637
3638 for (Value operand : operands) {
3639 resultTy.push_back(
3640 getReduceResultType(operand.getType(), dimensions, &builder));
3641 }
3642 build(builder, state, resultTy, operands, initValues, dimensions);
3643 }
3644
hasSameOperandAndResultTypes(Operation & op)3645 bool hasSameOperandAndResultTypes(Operation& op) {
3646 Type expected;
3647 if (op.getNumResults() != 0) expected = op.getResult(0).getType();
3648 if (op.getNumOperands() != 0) expected = op.getOperand(0).getType();
3649 if (!expected) return false;
3650
3651 auto typeMatch = [&](Type actual) { return actual == expected; };
3652 return llvm::all_of(op.getOperandTypes(), typeMatch) &&
3653 llvm::all_of(op.getResultTypes(), typeMatch);
3654 }
3655
3656 // Checks the following eligibility criteria for compact printing of reduce:
3657 // E1. The reduce-op wraps a single inner-op in the associated region.
3658 // E2. The single operation is a commutative binary-op from the dialect, zero
3659 // region, producing single result such that the operands and result all
3660 // have the same type.
3661 // E3. The reduce-op consist of at least one input-operand; The operand-types of
3662 // inner-op should be derived trivially from the element-type of reduce-op's
3663 // first input-operand.
3664 // E4. The arguments of the region's only basic block are forwarded perfectly
3665 // to inner-op's operands.
// E5. The reduce-op, inner-op, block's arguments, and the return-op all have
//     the same location.
3668 // E6. The single operation result is perfectly forwarded to the reduce op
3669 // return.
// Returns true when `op` satisfies all of the eligibility criteria E1-E6
// documented above and can therefore be printed in the compact one-line form.
static bool isEligibleForCompactPrint(ReduceOp op) {
  // Check E1: the region's block holds exactly one op besides the terminator.
  auto& block = op.body().front();
  if (!hasSingleElement(block.without_terminator())) return false;

  Operation& innerOp = *block.begin();

  // Check E2: the inner op is a commutative, single-result, zero-region
  // binary op of this dialect whose operand and result types all agree.
  if (innerOp.getDialect() != op->getDialect()) return false;

  if (innerOp.getNumOperands() != 2 ||
      !innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
      !hasSameOperandAndResultTypes(innerOp) ||
      !innerOp.hasTrait<mlir::OpTrait::IsCommutative>() ||
      !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
    return false;

  // Check E3: the inner op must operate on rank-0 tensors of the first
  // input-operand's element type, since that is how the parser of the
  // compact form re-derives the inner-op type.
  if (op.operands().empty()) return false;

  auto elemType =
      op.operands()[0].getType().cast<TensorType>().getElementType();
  auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
  if (innerOp.getOperands()[0].getType() != expectedInnerOpType) return false;

  // Check E4: block arguments feed the inner op, in order and unchanged.
  if (!llvm::equal(block.getArguments(), innerOp.getOperands())) return false;

  // Check E5: all locations agree, so the compact form (which records a
  // single location) loses no location information.
  auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
  if (!retOp) return false;

  auto blockArgLoc = block.getArgument(0).getLoc();
  if (blockArgLoc != block.getArgument(1).getLoc()) return false;

  if (innerOp.getLoc() != op.getLoc() || retOp.getLoc() != op.getLoc() ||
      blockArgLoc != op.getLoc())
    return false;

  // Check E6: the inner op's results are returned verbatim.
  return llvm::equal(innerOp.getResults(), retOp.getOperands());
}
3712
// Prints ReduceOp either in the compact one-line form (when eligible) or in
// the general region-based form.
void ReduceOp::print(OpAsmPrinter& p) {
  {
    // Print the pairs of operands under the form:
    //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
    // The op stores all inputs first and all init values after, so input i
    // pairs with operand i + numOperandPairs.
    StringRef comma = "";
    int numOperandPairs = getNumOperands() / 2;
    for (int opId : llvm::seq<int>(0, numOperandPairs)) {
      p << comma << "(" << getOperand(opId)
        << " init: " << getOperand(opId + numOperandPairs) << ")";
      comma = ", ";
    }
  }

  // If the reduce-op is eligible for compact printing, we emit the one-liner:
  //  stablehlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
  // Note: We are not printing the function type of reduction operation. We
  // have some simplifying assumptions (refer to isEligibleForCompactPrint::E3)
  // to derive the type from that of reduce-op.
  if (isEligibleForCompactPrint(*this)) {
    Operation& innerOp = body().front().front();
    p << " applies ";
    printEscapedString(innerOp.getName().getStringRef(), p.getStream());

    p << " across dimensions = [";
    llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
    p << "]";
    p << " : ";
    p.printFunctionalType(*this);
  } else {
    p << " across dimensions = [";
    llvm::interleaveComma(dimensions().getValues<int64_t>(), p);
    p << "]";
    // `dimensions` is printed inline above, so elide it from the attr dict.
    p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"});
    p << " : ";
    p.printFunctionalType(*this);
    p.printNewline();
    p << " reducer";
    {
      // Print the pairs of block operands under the form:
      //   (%arg0_elt, %arg0_acc) (%arg1_elt, %arg1_acc):
      Block& reducer = body().front();
      int numOperandPairs = getNumOperands() / 2;
      for (int opId : llvm::seq<int>(0, numOperandPairs)) {
        p << "(";
        p.printRegionArgument(reducer.getArgument(opId));
        p << ", ";
        p.printRegionArgument(reducer.getArgument(opId + numOperandPairs));
        p << ") ";
      }
    }
    p << ' ';
    // Entry-block args were already printed above as the reducer pairs.
    p.printRegion(body(), /*printEntryBlockArgs=*/false);
  }
}
3767
// Parses ReduceOp in either of its two assembly forms: the compact one-liner
// (`applies <inner-op> across dimensions = [...] : <func-type>`) or the
// general region-based form. The `(%input init: %init)` operand pairs are
// parsed first in both cases.
ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  Location currLocation = parser.getEncodedSourceLoc(loc);

  // Parse the operands of reduce-op, this is a list of pair under the form:
  //   (%arg0 init: %arg3), (%arg1 init: %arg4), (%arg2 init: %arg5)
  // Each input to reduce is paired with its init value, even though in memory
  // they are stored with the input first and the init values after.
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> initOperands;
  do {
    (void)parser.parseOptionalComma();
    if (parser.parseOptionalLParen()) break;
    OpAsmParser::UnresolvedOperand operand, initOperand;
    if (parser.parseOperand(operand) || parser.parseKeyword("init") ||
        parser.parseColon() || parser.parseOperand(initOperand) ||
        parser.parseRParen())
      return failure();
    operands.push_back(operand);
    initOperands.push_back(initOperand);
  } while (true);
  // Re-establish the in-memory layout: all inputs, then all init values.
  operands.append(initOperands);

  // Check if we are parsing the compact version of reduce-op:
  //  stablehlo.reduce applies <inner-op> across dimensions = [...] : <func-type>
  // else parse the "region-based" variant.
  if (failed(parser.parseOptionalKeyword("applies"))) {
    // Parse the inner-op dimensions, reduce-op's function-type and
    // optional location.
    SmallVector<int64_t> dimensions;
    auto parseDim = [&]() -> ParseResult {
      if (parser.parseInteger(dimensions.emplace_back())) return failure();
      return success();
    };

    FunctionType reduceOpFntype;
    if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
        parser.parseEqual() ||
        parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                       parseDim) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColon() || parser.parseType(reduceOpFntype) ||
        parser.parseKeyword("reducer"))
      return failure();
    OpBuilder builder(parser.getBuilder().getContext());
    result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));

    // Parse the "reducer" region now.
    SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerOperands;
    SmallVector<OpAsmParser::UnresolvedOperand, 2> reducerInitOperands;
    SmallVector<Type, 2> reducerTypes;
    SmallVector<Type, 2> reducerInitTypes;
    SmallVector<Optional<Location>, 2> reducerLocs;
    SmallVector<Optional<Location>, 2> reducerInitLocs;
    // Parses one `%arg : type [loc(...)]` block-argument specification.
    auto parseBlockOperand =
        [&](SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands,
            SmallVectorImpl<Type>& types,
            SmallVectorImpl<Optional<Location>>& locs) -> ParseResult {
      OpAsmParser::UnresolvedOperand operand;
      Type type;
      Optional<Location> loc;
      if (parser.parseOperand(operand, /*allowResultNumber=*/false) ||
          parser.parseColon() || parser.parseType(type) ||
          parser.parseOptionalLocationSpecifier(loc))
        return failure();
      operands.push_back(operand);
      types.push_back(type);
      locs.push_back(loc);
      return success();
    };
    // Block arguments arrive as `(%elt, %acc)` pairs, one pair per input.
    do {
      if (failed(parser.parseOptionalLParen())) break;
      if (parseBlockOperand(reducerOperands, reducerTypes, reducerLocs) ||
          parser.parseComma() ||
          parseBlockOperand(reducerInitOperands, reducerInitTypes,
                            reducerInitLocs) ||
          parser.parseRParen())
        return failure();
    } while (true);
    reducerOperands.append(reducerInitOperands);
    reducerTypes.append(reducerInitTypes);
    reducerLocs.append(reducerInitLocs);
    result.addTypes(reduceOpFntype.getResults());
    SmallVector<OpAsmParser::Argument> reducerArgs;
    createArgs(reducerOperands, reducerTypes, reducerArgs);

    // Derive the SSA-values for reduce-op's operands and parse the region, and
    // the optional trailing location.
    // NOTE(review): `trailingLoc` is never populated by any parse call below,
    // so the op location always falls back to `currLocation`; confirm whether
    // a trailing-location parse was intended here.
    Optional<Location> trailingLoc;
    if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                               result.operands) ||
        parser.parseRegion(*result.addRegion(), reducerArgs))
      return failure();
    // Set the individual block arguments.
    for (auto argAndLoc :
         llvm::zip(result.regions.front()->front().getArguments(), reducerLocs))
      if (std::get<1>(argAndLoc))
        std::get<0>(argAndLoc).setLoc(std::get<1>(argAndLoc).value());
    result.location = trailingLoc.value_or(currLocation);
    return success();
  }

  // Parse the inner-op name and check if the contract on inner-op
  // mentioned in "isEligibleForCompactPrint::E2" for pretty-printing is met.
  FailureOr<OperationName> innerOpNameInfo = parser.parseCustomOperationName();
  if (failed(innerOpNameInfo)) return failure();

  StringRef innerOpName = innerOpNameInfo->getStringRef();
  Dialect* innerOpDialect = innerOpNameInfo->getDialect();
  if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::NOperands<2>::Impl>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::OneResult>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::IsCommutative>() ||
      !innerOpNameInfo->hasTrait<mlir::OpTrait::ZeroRegions>()) {
    parser.emitError(loc,
                     "expected the inner-op to be a commutative binary-op from "
                     "stablehlo dialect, zero region, producing single result");
    return failure();
  }

  // Parse the inner-op dimensions, reduce-op's function-type and
  // optional location.
  SmallVector<int64_t> dimensions;
  auto parseDim = [&]() -> ParseResult {
    if (parser.parseInteger(dimensions.emplace_back())) return failure();
    return success();
  };

  Optional<Location> explicitLoc;
  FunctionType reduceOpFntype;
  if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") ||
      parser.parseEqual() ||
      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) ||
      parser.parseColon() || parser.parseType(reduceOpFntype) ||
      parser.parseOptionalLocationSpecifier(explicitLoc))
    return failure();

  if (!reduceOpFntype || reduceOpFntype.getInputs().empty()) {
    if (!reduceOpFntype) return parser.emitError(loc, "expected function type");
    return parser.emitError(loc,
                            "input types missing in reduce-op function type");
  }

  // If location of reduce-op is explicitly provided, then use it; Else use
  // the parser's current location.
  Location reduceOpLoc = explicitLoc.value_or(currLocation);

  // Derive the SSA-values for reduce-op's operands.
  if (parser.resolveOperands(operands, reduceOpFntype.getInputs(), loc,
                             result.operands))
    return failure();

  // Derive the type of inner-op from that of reduce-op's input operand:
  // a rank-0 tensor of the first input's element type (see
  // isEligibleForCompactPrint::E3).
  auto innerOpType = RankedTensorType::get(
      /*shape=*/{}, getElementTypeOrSelf(reduceOpFntype.getInput(0)));

  // Add a region for reduce-op.
  Region& region = *result.addRegion();

  // Create a basic-block inside reduce-op's region.
  Block& block = region.emplaceBlock();
  auto lhs = block.addArgument(innerOpType, reduceOpLoc);
  auto rhs = block.addArgument(innerOpType, reduceOpLoc);

  // Create and insert an "inner-op" operation in the block.
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToStart(&block);

  OperationState innerOpState(reduceOpLoc, innerOpName);
  innerOpState.operands.push_back(lhs);
  innerOpState.operands.push_back(rhs);
  innerOpState.addTypes(innerOpType);

  Operation* innerOp = builder.create(innerOpState);

  // Insert a return statement in the block returning the inner-op's result.
  builder.create<ReturnOp>(innerOp->getLoc(), innerOp->getResults());

  // Populate the reduce-op operation-state with result-type, location, and
  // dimension attribute.
  result.addTypes(reduceOpFntype.getResults());
  result.location = innerOp->getLoc();
  result.addAttribute("dimensions", builder.getI64TensorAttr(dimensions));

  return success();
}
3954
verify()3955 LogicalResult ReduceOp::verify() {
3956 // Check that there are even number of operands and >= 2.
3957 if (getNumOperands() % 2 != 0 || getOperands().empty())
3958 return emitOpError() << "expects the size of operands to be even and >= 2";
3959
3960 // Collect the input and init-value operands. Note that the operand-type is
3961 // enforced as "TensorType" by ODS.
3962 int64_t numInputs = getNumOperands() / 2;
3963 auto operandTensorTypes = llvm::to_vector<4>(llvm::map_range(
3964 getOperandTypes(),
3965 [](Type t) -> TensorType { return t.cast<TensorType>(); }));
3966 ArrayRef<TensorType> inputArgTypes(operandTensorTypes.begin(),
3967 operandTensorTypes.begin() + numInputs);
3968 ArrayRef<TensorType> initValueTypes(operandTensorTypes.begin() + numInputs,
3969 operandTensorTypes.end());
3970
3971 // Check for unranked tensors in input operands.
3972 int64_t rankedInputIdx = -1;
3973 for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
3974 if (inputArgTypes[inputIdx].hasRank()) {
3975 rankedInputIdx = inputIdx;
3976 break;
3977 }
3978 }
3979
3980 bool allInputsUnranked = (rankedInputIdx == -1);
3981
3982 // Check that all input operands have compatible shapes. The element types may
3983 // be different.
3984 if (!allInputsUnranked) {
3985 for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
3986 if (failed(mlir::verifyCompatibleShape(inputArgTypes[rankedInputIdx],
3987 inputArgTypes[inputIdx]))) {
3988 return emitOpError()
3989 << "expects all inputs to have compatible shapes. Shape at"
3990 << " input-index " << inputIdx
3991 << " is not compatible with shape at input-index "
3992 << rankedInputIdx;
3993 }
3994 }
3995 }
3996
3997 // Check that
3998 // 1. the dimensions of reduce-op are in-bounds for the given shape.
3999 // 2. the dimension-attribute have no duplicate entries.
4000 DenseSet<int64_t> dimensionsToReduceSet;
4001 for (int64_t dimension : dimensions().getValues<int64_t>()) {
4002 if ((!allInputsUnranked &&
4003 dimension >= inputArgTypes[rankedInputIdx].getRank()) ||
4004 dimension < 0) {
4005 return emitError() << "Out-of-bounds dimension " << dimension
4006 << " for input-tensor rank: "
4007 << inputArgTypes[rankedInputIdx].getRank();
4008 }
4009
4010 if (!dimensionsToReduceSet.insert(dimension).second) {
4011 return emitError() << "Duplicate reduction dimension: " << dimension;
4012 }
4013 }
4014
4015 // Verify the inner block defining the reducer function.
4016 SmallVector<int64_t> newDimensions;
4017 if (!allInputsUnranked) {
4018 for (int inputIdx = 0; inputIdx < inputArgTypes[rankedInputIdx].getRank();
4019 ++inputIdx) {
4020 if (!dimensionsToReduceSet.count(inputIdx)) {
4021 newDimensions.push_back(
4022 inputArgTypes[rankedInputIdx].getDimSize(inputIdx));
4023 }
4024 }
4025 }
4026
4027 Block& block = body().front();
4028 SmallVector<TensorType> accumulatorSubShapes;
4029 if (failed(verifyReducerShape(this->getLoc(), block, inputArgTypes,
4030 initValueTypes, numInputs, newDimensions,
4031 allInputsUnranked, accumulatorSubShapes)))
4032 return failure();
4033
4034 // Check if the reduce-op's result-type matches with the one derived from
4035 // the reducer-block and dimensions attribute.
4036 if (getResults().size() != accumulatorSubShapes.size())
4037 return emitError() << "Unexpected number of reduce-op's returned values: "
4038 << getResults().size() << " vs "
4039 << accumulatorSubShapes.size() << " (expected)";
4040
4041 for (int64_t shapeIdx = 0;
4042 shapeIdx < static_cast<int64_t>(accumulatorSubShapes.size());
4043 shapeIdx++) {
4044 // The result-type is enforced as "TensorType" by ODS.
4045 auto opResultType = getResult(shapeIdx).getType().cast<TensorType>();
4046
4047 // Check element-type.
4048 if (accumulatorSubShapes[shapeIdx].getElementType() !=
4049 opResultType.getElementType()) {
4050 return emitError()
4051 << "Unexpected element-type for reduce-op's return value at index "
4052 << shapeIdx << ": " << opResultType.getElementType() << " vs "
4053 << accumulatorSubShapes[shapeIdx].getElementType()
4054 << " (expected)";
4055 }
4056
4057 // Check shape.
4058 if (!allInputsUnranked && opResultType.hasRank() &&
4059 failed(verifyCompatibleShape(newDimensions, opResultType.getShape()))) {
4060 Type expectedResultType = RankedTensorType::get(
4061 newDimensions, accumulatorSubShapes[shapeIdx].getElementType());
4062 return emitError()
4063 << "Unexpected type for reduce-op's return value at index "
4064 << shapeIdx << ": " << opResultType << " vs " << expectedResultType
4065 << " (expected)";
4066 }
4067 }
4068
4069 return success();
4070 }
4071
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)4072 LogicalResult ReduceOp::reifyReturnTypeShapes(
4073 OpBuilder& builder, ValueRange operands,
4074 SmallVectorImpl<Value>& reifiedReturnShapes) {
4075 ReduceOp::Adaptor adaptor(operands);
4076 auto inputs = adaptor.operands();
4077
4078 auto operandType = inputs[0].getType().dyn_cast<RankedTensorType>();
4079 // Not support unranked type a.t.m.
4080 if (!operandType) return failure();
4081
4082 Location loc = this->getLoc();
4083 SmallVector<Value, 4> shapeValues;
4084 SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
4085 shapeValues.reserve(operandType.getRank());
4086 Type shapeScalarType = builder.getIndexType();
4087 auto toShapeScalarType = [&](Value v) {
4088 return maybeCastTo(builder, loc, v, shapeScalarType);
4089 };
4090
4091 for (const auto& element : llvm::enumerate(operandType.getShape())) {
4092 int64_t idx = element.index();
4093 auto* it = std::find(dimensions.begin(), dimensions.end(), idx);
4094 if (it != dimensions.end()) {
4095 continue;
4096 }
4097 Value valueDim = toShapeScalarType(
4098 builder.create<tensor::DimOp>(loc, inputs[0], element.index()));
4099 shapeValues.push_back(valueDim);
4100 }
4101
4102 Value outputShape = builder.create<tensor::FromElementsOp>(
4103 loc,
4104 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4105 shapeScalarType),
4106 shapeValues);
4107 for (size_t i = 0; i < inputs.size(); ++i) {
4108 reifiedReturnShapes.push_back(outputShape);
4109 }
4110
4111 return success();
4112 }
4113
4114 //===----------------------------------------------------------------------===//
4115 // RngBitGeneratorOp
4116 //===----------------------------------------------------------------------===//
4117
4118 // Verify that input state has the same shape as output shape
verify()4119 LogicalResult RngBitGeneratorOp::verify() {
4120 auto initialShape = initial_state().getType().dyn_cast<RankedTensorType>();
4121 auto outputShape = output_state().getType().dyn_cast<RankedTensorType>();
4122 if (initialShape.getShape() != outputShape.getShape())
4123 return emitOpError()
4124 << "output state shape must match initial state shape. Got: "
4125 << initialShape << " and " << outputShape;
4126 return success();
4127 }
4128
4129 //===----------------------------------------------------------------------===//
4130 // RngOp
4131 //===----------------------------------------------------------------------===//
4132
verify()4133 LogicalResult RngOp::verify() {
4134 auto dist = rng_distribution();
4135 if (dist == RngDistribution::UNIFORM) {
4136 return success();
4137 }
4138 auto muTy = a().getType().cast<TensorType>().getElementType();
4139 auto sigmaTy = b().getType().cast<TensorType>().getElementType();
4140 if (muTy.isa<FloatType>() && sigmaTy.isa<FloatType>()) {
4141 return success();
4142 }
4143 return emitOpError() << "mu and sigma must be floats";
4144 }
4145
// Delegates shape inference to the shared RNG helper used by the rng ops.
LogicalResult RngOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueShapeRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
  return rngInferReturnTypeComponents(context, location, operands, attributes,
                                      regions, inferredReturnShapes);
}
4153
// The result shape is exactly the value of the `shape` operand, cast to an
// index tensor.
LogicalResult RngOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  RngOp::Adaptor adaptor(operands);
  reifiedReturnShapes.push_back(
      castToIndexTensor(builder, getLoc(), adaptor.shape()));
  return success();
}
4162
4163 //===----------------------------------------------------------------------===//
4164 // SelectOp
4165 //===----------------------------------------------------------------------===//
4166
verify()4167 LogicalResult SelectOp::verify() {
4168 // The operands 'on_true' and 'on_false' should have compatible types, i.e.,
4169 // (a) have the same element type, and
4170 // (b) have compatible shapes (i.e. the same shape and/or at least one
4171 // dynamic shape)
4172 if (!compatibleShapeAndElementType(on_true().getType(), on_false().getType()))
4173 return emitOpError()
4174 << "requires compatible types for non-predicate operands";
4175
4176 // The predicate, if not-scalar, should have the same shape as the remaining
4177 // operands.
4178 auto predTy = pred().getType().dyn_cast<RankedTensorType>();
4179 bool predMayBeScalar = !predTy || predTy.getRank() == 0;
4180 if (predMayBeScalar) return success();
4181
4182 if (failed(verifyCompatibleShape(pred().getType(), on_true().getType())))
4183 return emitOpError() << "requires the same shape for all operands";
4184
4185 return success();
4186 }
4187
4188 // Makes it such that a SelectOp that is a non-root operation in a DRR infers
4189 // the return type based on operand type.
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)4190 LogicalResult SelectOp::inferReturnTypeComponents(
4191 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
4192 DictionaryAttr attributes, RegionRange,
4193 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4194 SelectOp::Adaptor op(operands, attributes);
4195 auto trueType = op.on_true().getType().cast<TensorType>();
4196 auto falseType = op.on_false().getType().cast<TensorType>();
4197
4198 // The output shape should be the most general of the operand shapes at each
4199 // dimension.
4200 ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back();
4201 if (trueType == falseType || !trueType.hasRank()) {
4202 outputType = ShapedTypeComponents(trueType.cast<ShapedType>());
4203 } else if (!falseType.hasRank()) {
4204 outputType = ShapedTypeComponents(falseType.cast<ShapedType>());
4205 } else {
4206 assert(trueType.getRank() == falseType.getRank());
4207 llvm::SmallVector<int64_t, 4> dims;
4208 dims.reserve(trueType.getRank());
4209 for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) {
4210 dims.push_back(std::get<0>(dim) == std::get<1>(dim)
4211 ? std::get<0>(dim)
4212 : ShapedType::kDynamicSize);
4213 }
4214 outputType = ShapedTypeComponents(dims, trueType.getElementType());
4215 }
4216 return success();
4217 }
4218
LogicalResult SelectOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  // For `hlo.select`, the first operand may be a scalar predicate, so the
  // result shape is derived from `on_true` (operands[1]) instead.
  return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1],
                                     &reifiedReturnShapes);
}
4226
4227 //===----------------------------------------------------------------------===//
4228 // SetDimensionSizeOp
4229 //===----------------------------------------------------------------------===//
4230
verify()4231 LogicalResult SetDimensionSizeOp::verify() {
4232 if (auto size = this->size().getType().dyn_cast<RankedTensorType>()) {
4233 if (size.getRank() != 0)
4234 return emitOpError() << "size operand should be of rank-0";
4235 }
4236
4237 return verifyDimAttr(*this);
4238 }
4239
4240 // TODO(b/238903565): Switch to inferReturnTypeComponents after adding support
4241 // for the encoding upstream.
// Infers the result type: the operand type with dimension `dimension` made
// dynamic, recording its old static extent (if any) as a bound in a
// TypeExtensionsAttr encoding.
LogicalResult SetDimensionSizeOp::inferReturnTypes(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  Location loc = location.value_or(UnknownLoc::get(context));

  SetDimensionSizeOp::Adaptor adaptor(operands, attributes, regions);
  if (failed(adaptor.verify(loc))) return failure();

  // For an unranked operand the result type is simply forwarded.
  auto inputType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
  if (!inputType) {
    inferredReturnTypes.push_back(adaptor.operand().getType());
    return success();
  }

  int64_t dim = adaptor.dimension();
  int64_t rank = inputType.getRank();
  if (dim < 0 || dim >= rank) {
    return mlir::emitError(loc) << "expects dimension to be in range [0, "
                                << rank << "); got: [" << dim << "].";
  }

  // Start from the operand's shape and its bounds (when the operand carries a
  // TypeExtensionsAttr encoding); dimensions without a bound stay dynamic.
  auto shape = llvm::to_vector<4>(inputType.getShape());
  llvm::SmallVector<int64_t, 4> bounds(rank, ShapedType::kDynamicSize);
  if (auto encoding =
          inputType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>())
    bounds = llvm::to_vector<4>(encoding.getBounds());

  // The selected dimension becomes dynamic; a previously static extent is
  // preserved as its bound.
  // TODO(hinsu): Handle the case when the size operand is a constant.
  if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim];
  shape[dim] = ShapedType::kDynamicSize;

  auto extensions = TypeExtensionsAttr::get(context, bounds);
  auto resultType =
      RankedTensorType::get(shape, inputType.getElementType(), extensions);
  inferredReturnTypes.push_back(resultType);
  return success();
}
4280
4281 //===----------------------------------------------------------------------===//
4282 // PadOp
4283 //===----------------------------------------------------------------------===//
4284
inferReturnTypeComponents(MLIRContext *,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)4285 LogicalResult PadOp::inferReturnTypeComponents(
4286 MLIRContext*, Optional<Location> location, ValueShapeRange operands,
4287 DictionaryAttr attributes, RegionRange regions,
4288 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
4289 PadOp::Adaptor adaptor(operands, attributes, regions);
4290 auto inputType = adaptor.operand().getType().cast<RankedTensorType>();
4291 auto padType = adaptor.padding_value().getType().cast<RankedTensorType>();
4292
4293 if (padType.getRank() != 0) {
4294 return emitOptionalError(
4295 location, llvm::formatv("padding value type should be a rank-0 "
4296 "tensor, is rank {0}",
4297 padType.getRank()));
4298 }
4299
4300 const auto& paddingLow = adaptor.edge_padding_low();
4301 if (paddingLow.getType().getNumElements() != inputType.getRank()) {
4302 return emitOptionalError(
4303 location,
4304 llvm::formatv(
4305 "edge_padding_low length ({0}) must match operand rank ({1})",
4306 paddingLow.getType().getNumElements(), inputType.getRank()));
4307 }
4308
4309 const auto& paddingHigh = adaptor.edge_padding_high();
4310 if (paddingHigh.getType().getNumElements() != inputType.getRank()) {
4311 return emitOptionalError(
4312 location,
4313 llvm::formatv(
4314 "edge_padding_high length ({0}) must match operand rank ({1})",
4315 paddingHigh.getType().getNumElements(), inputType.getRank()));
4316 }
4317
4318 const auto& paddingInterior = adaptor.interior_padding();
4319 if (paddingInterior.getType().getNumElements() != inputType.getRank()) {
4320 return emitOptionalError(
4321 location,
4322 llvm::formatv(
4323 "interior_padding length ({0}) must match operand rank ({1})",
4324 paddingInterior.getType().getNumElements(), inputType.getRank()));
4325 }
4326
4327 auto inputShape = inputType.getShape();
4328 SmallVector<int64_t> resultShape;
4329 for (int i = 0, e = inputShape.size(); i < e; i++) {
4330 if (isDynamicDimSize(inputShape[i])) {
4331 resultShape.push_back(ShapedType::kDynamicSize);
4332 continue;
4333 }
4334
4335 int64_t paddingLowVal = paddingLow.getValues<APInt>()[i].getSExtValue();
4336 int64_t paddingHighVal = paddingHigh.getValues<APInt>()[i].getSExtValue();
4337 int64_t paddingInteriorVal =
4338 paddingInterior.getValues<APInt>()[i].getSExtValue();
4339 if (paddingInteriorVal < 0) {
4340 return emitOptionalError(
4341 location, llvm::formatv("Interior padding cannot be negative: {0}",
4342 paddingInteriorVal));
4343 }
4344 int64_t expectedOutput =
4345 inputShape[i] + paddingLowVal + paddingHighVal +
4346 std::max<int64_t>(inputShape[i] - 1, 0LL) * paddingInteriorVal;
4347 if (expectedOutput < 0) {
4348 return emitOptionalError(
4349 location,
4350 llvm::formatv("Padding result in negative size for dimension {0}",
4351 i));
4352 }
4353 resultShape.push_back(expectedOutput);
4354 }
4355 inferredReturnShapes.emplace_back(resultShape, inputType.getElementType());
4356
4357 return success();
4358 }
4359
// Reifies the pad result shape as index values, emitting per-dimension
// arithmetic: out[i] = in[i] + max(in[i]-1, 0)*interior[i] + low[i] + high[i].
LogicalResult PadOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  PadOp::Adaptor adaptor(operands, this->getOperation()->getAttrDictionary());
  auto loc = this->getLoc();
  Value operand = adaptor.operand();
  auto operandTy = operand.getType().cast<RankedTensorType>();

  llvm::SmallVector<int32_t> padHigh;
  llvm::SmallVector<int32_t> padLow;
  llvm::SmallVector<int32_t> padInterior;

  auto padHighAttr = adaptor.edge_padding_high();
  auto padLowAttr = adaptor.edge_padding_low();
  auto padInteriorAttr = adaptor.interior_padding();

  padHigh.reserve(padHighAttr.getNumElements());
  padLow.reserve(padLowAttr.getNumElements());
  padInterior.reserve(padInteriorAttr.getNumElements());

  // Extract the padding attribute values as plain integers.
  for (const APInt& val : padHighAttr.getValues<APInt>())
    padHigh.push_back(val.getSExtValue());

  for (const APInt& val : padLowAttr.getValues<APInt>())
    padLow.push_back(val.getSExtValue());

  for (const APInt& val : padInteriorAttr.getValues<APInt>())
    padInterior.push_back(val.getSExtValue());

  Value one = builder.create<arith::ConstantIndexOp>(loc, 1).getResult();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0).getResult();

  llvm::SmallVector<Value> dimensions;
  dimensions.reserve(operandTy.getRank());
  for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
    // Edge padding contribution is a static constant: low + high.
    Value padEdge =
        builder.create<arith::ConstantIndexOp>(loc, padHigh[i] + padLow[i]);

    // First we grab the initial interior size.
    Value dim = builder.create<tensor::DimOp>(loc, operand, i).getResult();

    // Compute the interior of the tensor and determine padding size:
    // max(dim - 1, 0) gaps, each of width padInterior[i].
    if (padInterior[i] > 0) {
      Value padInter =
          builder.create<arith::ConstantIndexOp>(loc, padInterior[i])
              .getResult();
      Value interior = builder.create<arith::SubIOp>(loc, dim, one).getResult();
      interior = builder.create<arith::MaxSIOp>(loc, interior, zero);
      interior = builder.create<arith::MulIOp>(loc, interior, padInter);
      dim = builder.create<arith::AddIOp>(loc, dim, interior).getResult();
    }

    // Then we add the padding on the edge of the tensor.
    dim = builder.create<arith::AddIOp>(loc, dim, padEdge).getResult();
    dimensions.push_back(dim);
  }

  Value dimensionTensor =
      builder.create<tensor::FromElementsOp>(loc, dimensions).getResult();
  reifiedReturnShapes.push_back(dimensionTensor);
  return success();
}
4422
4423 //===----------------------------------------------------------------------===//
4424 // DynamicPadOp
4425 //===----------------------------------------------------------------------===//
4426
verify()4427 LogicalResult DynamicPadOp::verify() {
4428 auto inputType = operand().getType().dyn_cast<RankedTensorType>();
4429 // If operand is unranked, there is very little to verify statically.
4430 if (!inputType) return success();
4431 int inputRank = inputType.getRank();
4432
4433 auto padType = padding_value().getType().cast<RankedTensorType>();
4434 if (padType.getRank() != 0) {
4435 return emitOpError() << "padding value type should be a rank-0";
4436 }
4437
4438 auto paddingLowType = edge_padding_low().getType().cast<RankedTensorType>();
4439 if (paddingLowType.getNumElements() != inputRank) {
4440 return emitOpError() << "edge_padding_low length("
4441 << paddingLowType.getNumElements()
4442 << ") must match operand rank(" << inputRank << ").";
4443 }
4444
4445 auto paddingHighType = edge_padding_high().getType().cast<RankedTensorType>();
4446 if (paddingHighType.getNumElements() != inputRank) {
4447 return emitOpError() << "edge_padding_high length("
4448 << paddingHighType.getNumElements()
4449 << ") must match operand rank(" << inputRank << ").";
4450 }
4451
4452 auto interiorPaddingType =
4453 interior_padding().getType().cast<RankedTensorType>();
4454 if (interiorPaddingType.getNumElements() != inputRank) {
4455 return emitOpError() << "edge_padding_interior length("
4456 << interiorPaddingType.getNumElements()
4457 << ") must match operand rank(" << inputRank << ").";
4458 }
4459
4460 auto outputType = getResult().getType().dyn_cast<RankedTensorType>();
4461 // If result is unranked, there is very little to verify statically.
4462 if (!outputType) return success();
4463 int outputRank = outputType.getRank();
4464 if (inputRank != outputRank) {
4465 return emitOpError() << "operand rank(" << inputRank
4466 << ") must match result(" << outputRank << ").";
4467 }
4468
4469 return success();
4470 }
4471
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)4472 LogicalResult DynamicPadOp::reifyReturnTypeShapes(
4473 OpBuilder& builder, ValueRange operands,
4474 SmallVectorImpl<Value>& reifiedReturnShapes) {
4475 DynamicPadOp::Adaptor adaptor(operands);
4476 Value operand = adaptor.operand();
4477 Value edgePaddingLow = adaptor.edge_padding_low();
4478 Value edgePaddingHigh = adaptor.edge_padding_high();
4479 Value interiorPadding = adaptor.interior_padding();
4480
4481 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4482 // Not support unranked pad a.t.m.
4483 if (!operandType) return failure();
4484
4485 auto loc = this->getLoc();
4486 SmallVector<Value, 4> shapeValues;
4487 shapeValues.reserve(operandType.getRank());
4488 Type shapeScalarType =
4489 edgePaddingLow.getType().cast<ShapedType>().getElementType();
4490
4491 auto toShapeScalarType = [&](Value v) {
4492 return maybeCastTo(builder, loc, v, shapeScalarType);
4493 };
4494
4495 Value zero =
4496 toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 0));
4497 Value one = toShapeScalarType(builder.create<arith::ConstantIndexOp>(loc, 1));
4498
4499 for (int idx : llvm::seq<int>(0, operandType.getShape().size())) {
4500 Value valueDim =
4501 toShapeScalarType(builder.create<tensor::DimOp>(loc, operand, idx));
4502 Value offset = builder.create<arith::ConstantIndexOp>(loc, idx);
4503 Value valueLow =
4504 builder.create<tensor::ExtractOp>(loc, edgePaddingLow, offset);
4505 Value valueHigh =
4506 builder.create<tensor::ExtractOp>(loc, edgePaddingHigh, offset);
4507 Value valueInterior =
4508 builder.create<tensor::ExtractOp>(loc, interiorPadding, offset);
4509 // output_size = input_size + padding_low + padding_high + interior *
4510 // max(input_size - 1, 0)
4511 Value valueDimLessThanOne = builder.create<arith::CmpIOp>(
4512 loc, arith::CmpIPredicate::slt, valueDim, one);
4513 Value interiorSize = builder.create<arith::MulIOp>(
4514 loc, valueInterior,
4515 builder.create<mlir::arith::SelectOp>(
4516 loc, valueDimLessThanOne, zero,
4517 builder.create<arith::SubIOp>(loc, valueDim, one)));
4518 shapeValues.push_back(builder.create<arith::AddIOp>(
4519 loc,
4520 builder.create<arith::AddIOp>(
4521 loc, builder.create<arith::AddIOp>(loc, interiorSize, valueDim),
4522 valueLow),
4523 valueHigh));
4524 }
4525
4526 reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
4527 loc,
4528 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4529 shapeScalarType),
4530 shapeValues));
4531
4532 return success();
4533 }
4534
4535 //===----------------------------------------------------------------------===//
4536 // ReshapeOp
4537 //===----------------------------------------------------------------------===//
4538
verify()4539 LogicalResult ReshapeOp::verify() {
4540 // If the operand type is dynamically shaped there is nothing to verify.
4541 auto operandTy = operand().getType().dyn_cast<RankedTensorType>();
4542 if (!operandTy || !operandTy.hasStaticShape()) return success();
4543
4544 // If the operand type is statically shaped (not required) the number of
4545 // elements must match that of the result type.
4546 auto resultTy = getType().cast<RankedTensorType>();
4547 assert(resultTy && resultTy.hasStaticShape() &&
4548 "result type must be statically shaped");
4549 int64_t numResultElements = resultTy.getNumElements();
4550 int64_t numOperandElements = operandTy.getNumElements();
4551 if (numResultElements != numOperandElements)
4552 return emitOpError() << "number of output elements (" << numResultElements
4553 << ") doesn't match expected number of elements ("
4554 << numOperandElements << ")";
4555
4556 return success();
4557 }
4558
4559 //===----------------------------------------------------------------------===//
4560 // ReplicaId Op
4561 //===----------------------------------------------------------------------===//
4562
inferReturnTypes(MLIRContext * context,Optional<Location>,ValueRange operands,DictionaryAttr,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)4563 LogicalResult ReplicaIdOp::inferReturnTypes(
4564 MLIRContext* context, Optional<Location>, ValueRange operands,
4565 DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
4566 inferredReturnTypes.push_back(RankedTensorType::get(
4567 /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned)));
4568 return success();
4569 }
4570
4571 //===----------------------------------------------------------------------===//
4572 // If Op
4573 //===----------------------------------------------------------------------===//
4574
verifyConditionalBranch(Operation * op,Region & region,llvm::Twine branchName)4575 static LogicalResult verifyConditionalBranch(Operation* op, Region& region,
4576 llvm::Twine branchName) {
4577 if (region.getNumArguments() != 0)
4578 return op->emitOpError()
4579 << branchName << " must have 0 arguments, but found "
4580 << region.getNumArguments();
4581
4582 TypeRange branchReturnTypes =
4583 region.front().getTerminator()->getOperandTypes();
4584 if (branchReturnTypes != op->getResultTypes())
4585 return op->emitOpError()
4586 << branchName << " returned types (" << branchReturnTypes
4587 << ") do not match op result types (" << op->getResultTypes() << ")";
4588
4589 return success();
4590 }
4591
verify()4592 LogicalResult IfOp::verify() {
4593 if (failed(verifyConditionalBranch(*this, true_branch(),
4594 /*branchName=*/"true_branch"))) {
4595 return failure();
4596 }
4597
4598 if (failed(verifyConditionalBranch(*this, false_branch(),
4599 /*branchName=*/"false_branch"))) {
4600 return failure();
4601 }
4602 return success();
4603 }
4604
4605 //===----------------------------------------------------------------------===//
4606 // Case Op
4607 //===----------------------------------------------------------------------===//
4608
verify()4609 LogicalResult CaseOp::verify() {
4610 auto numBranches = branches().size();
4611
4612 for (unsigned i = 0; i < numBranches; ++i)
4613 if (failed(verifyConditionalBranch(*this, branches()[i],
4614 /*branchName=*/"branch " + Twine(i))))
4615 return failure();
4616
4617 return success();
4618 }
4619
4620 //===----------------------------------------------------------------------===//
4621 // UnaryOps
4622 //===----------------------------------------------------------------------===//
4623
parseUnaryOp(OpAsmParser & parser,OperationState & result)4624 ParseResult parseUnaryOp(OpAsmParser& parser, OperationState& result) {
4625 SmallVector<OpAsmParser::UnresolvedOperand> operands;
4626 Type type;
4627 // If the operand is in-between parentheses, use generic form.
4628 SMLoc loc = parser.getCurrentLocation();
4629 if (!parser.parseOptionalLParen()) {
4630 if (parser.parseOperandList(operands) || parser.parseRParen() ||
4631 parser.parseOptionalAttrDict(result.attributes) ||
4632 parser.parseColon() || parser.parseType(type))
4633 return failure();
4634 auto fnType = type.dyn_cast<FunctionType>();
4635 if (!fnType) {
4636 parser.emitError(loc, "expected function type");
4637 return failure();
4638 }
4639 if (parser.resolveOperands(operands, fnType.getInputs(), loc,
4640 result.operands))
4641 return failure();
4642 result.addTypes(fnType.getResults());
4643 return success();
4644 }
4645 // Otherwise, use shorthand syntax.
4646 return failure(parser.parseOperandList(operands) ||
4647 parser.parseOptionalAttrDict(result.attributes) ||
4648 parser.parseColonType(type) ||
4649 parser.resolveOperands(operands, type, result.operands) ||
4650 parser.addTypeToList(type, result.types));
4651 }
4652
printUnaryOp(Operation * op,OpAsmPrinter & p)4653 void printUnaryOp(Operation* op, OpAsmPrinter& p) {
4654 assert(op->getNumResults() == 1 && "op should have one result");
4655 assert(op->getNumOperands() == 1 && "op should have one input");
4656 // If not all types are the same, use generic form.
4657 auto resultType = op->getResult(0).getType();
4658 if (resultType != op->getOperandTypes()[0]) {
4659 p.printGenericOp(op, /*printOpName=*/false);
4660 return;
4661 }
4662 // Otherwise, use the shorthand syntax.
4663 p << ' ';
4664 p.printOperands(op->getOperands());
4665 p.printOptionalAttrDict(op->getAttrs());
4666 p << " : " << resultType;
4667 }
4668
4669 //===----------------------------------------------------------------------===//
4670 // BinaryOps
4671 //===----------------------------------------------------------------------===//
4672
parseBinaryOp(OpAsmParser & parser,OperationState & result)4673 ParseResult parseBinaryOp(OpAsmParser& parser, OperationState& result) {
4674 SmallVector<OpAsmParser::UnresolvedOperand> operands;
4675 Type type;
4676 // If the operand list is in-between parentheses, use generic form.
4677 SMLoc loc = parser.getCurrentLocation();
4678 if (!parser.parseOptionalLParen()) {
4679 if (parser.parseOperandList(operands) || parser.parseRParen() ||
4680 parser.parseOptionalAttrDict(result.attributes) ||
4681 parser.parseColon() || parser.parseType(type))
4682 return failure();
4683 auto fnType = type.dyn_cast<FunctionType>();
4684 if (!fnType) {
4685 parser.emitError(loc, "expected function type");
4686 return failure();
4687 }
4688 if (parser.resolveOperands(operands, fnType.getInputs(), loc,
4689 result.operands))
4690 return failure();
4691 result.addTypes(fnType.getResults());
4692 return success();
4693 }
4694 // Otherwise, use shorthand syntax.
4695 return failure(parser.parseOperandList(operands) ||
4696 parser.parseOptionalAttrDict(result.attributes) ||
4697 parser.parseColonType(type) ||
4698 parser.resolveOperands(operands, type, result.operands) ||
4699 parser.addTypeToList(type, result.types));
4700 }
4701
printBinaryOp(Operation * op,OpAsmPrinter & p)4702 void printBinaryOp(Operation* op, OpAsmPrinter& p) {
4703 assert(op->getNumResults() == 1 && "op should have one result");
4704 // If not all types are the same, use generic form.
4705 auto resultType = op->getResult(0).getType();
4706 if (llvm::any_of(op->getOperandTypes(),
4707 [&](Type type) { return type != resultType; })) {
4708 p.printGenericOp(op, /*printOpName=*/false);
4709 return;
4710 }
4711 // Otherwise, use the shorthand syntax.
4712 p << ' ';
4713 p.printOperands(op->getOperands());
4714 p.printOptionalAttrDict(op->getAttrs());
4715 p << " : " << resultType;
4716 }
4717
4718 //===----------------------------------------------------------------------===//
4719 // SliceOp
4720 //===----------------------------------------------------------------------===//
4721
// Returns output dimension size for slice result for the given arguments.
// Returns -1 if arguments are illegal.
static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end,
                             int64_t stride) {
  // Reject dynamic dims (-1) and malformed ranges. A negative stride is as
  // illegal as a zero one; previously only `stride == 0` was rejected here,
  // letting a negative stride reach the unsigned ceil-division below.
  if (inputDim == -1 || start < 0 || start > end || end > inputDim ||
      stride <= 0)
    return -1;

  // Number of strided elements in [start, end): ceil((end - start) / stride).
  // Safe in signed arithmetic since 0 <= end - start and stride > 0 here.
  return (end - start + stride - 1) / stride;
}
4732
4733 // The following properties are already enforced by the ODS:
4734 // type(start_indices) == type(limit_indices) == type(strides).
4735 // Verify the following properties:
4736 // P1. Verify rank(start_indices) == 1.
4737 // P2. Verify size(start_indices) == rank(operand).
4738 // P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i].
4739 // P6. Verify stride[i] > 0.
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)4740 LogicalResult SliceOp::inferReturnTypes(
4741 MLIRContext* context, Optional<Location> location, ValueRange operands,
4742 DictionaryAttr attributes, RegionRange regions,
4743 SmallVectorImpl<Type>& inferredReturnTypes) {
4744 SliceOpAdaptor slice(operands, attributes);
4745 Type ty = slice.operand().getType();
4746 RankedTensorType rankedTy = ty.dyn_cast<RankedTensorType>();
4747 if (!rankedTy) {
4748 // The operand type is unranked, so the best we can infer for the result
4749 // type is an unranked tensor with the same element type as the operand
4750 // type.
4751 inferredReturnTypes.assign({ty});
4752 return success();
4753 }
4754
4755 ShapedType attrTy = slice.start_indices().getType();
4756 // P1.
4757 // Note: ODS has type(start_indices) == type(limit_indices) == type(strides)
4758 // So this implies rank(limit_indices) == rank(strides) == 1 also.
4759 if (attrTy.getRank() != 1) {
4760 return emitOptionalError(location, "start_indices has rank ",
4761 attrTy.getRank(), " instead of required rank 1");
4762 }
4763
4764 // P2.
4765 int64_t rank = rankedTy.getRank();
4766 if (attrTy.getNumElements() != rank) {
4767 return emitOptionalError(
4768 location, "the number of elements in start_indices (",
4769 attrTy.getNumElements(), ") does not match the rank of the operand (",
4770 rank, ")");
4771 }
4772
4773 SmallVector<int64_t, 4> start(slice.start_indices().getValues<int64_t>());
4774 SmallVector<int64_t, 4> limit(slice.limit_indices().getValues<int64_t>());
4775 SmallVector<int64_t, 4> strideVals(slice.strides().getValues<int64_t>());
4776
4777 SmallVector<int64_t, 4> shape;
4778 shape.reserve(rank);
4779 for (int64_t i = 0, e = rank; i != e; i++) {
4780 if (isDynamicDimSize(rankedTy.getDimSize(i))) {
4781 shape.push_back(ShapedType::kDynamicSize);
4782 continue;
4783 }
4784 // P3.
4785 if (start[i] < 0)
4786 return emitOptionalError(location, "negative start index ", start[i],
4787 " in dimension ", i);
4788 // P4.
4789 if (limit[i] > rankedTy.getDimSize(i))
4790 return emitOptionalError(location, "limit index ", limit[i],
4791 " is larger than dimension size ",
4792 rankedTy.getDimSize(i), " in dimension ", i);
4793 // P5.
4794 if (start[i] > limit[i])
4795 return emitOptionalError(location, "start index ", start[i],
4796 " is larger than limit index ", limit[i],
4797 " in dimension ", i);
4798 // P6.
4799 if (strideVals[i] <= 0)
4800 return emitOptionalError(location, "stride must be positive but got ",
4801 strideVals[i], " in dimension ", i);
4802
4803 shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i],
4804 strideVals[i]));
4805 }
4806 inferredReturnTypes.assign(
4807 {RankedTensorType::get(shape, rankedTy.getElementType())});
4808 return success();
4809 }
4810
4811 //===----------------------------------------------------------------------===//
4812 // SortOp
4813 //===----------------------------------------------------------------------===//
4814
build(OpBuilder & builder,OperationState & state,ValueRange operands,int64_t dimension,bool isStable)4815 void SortOp::build(OpBuilder& builder, OperationState& state,
4816 ValueRange operands, int64_t dimension, bool isStable) {
4817 state.addOperands(operands);
4818 state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
4819 state.addAttribute("is_stable", builder.getBoolAttr(isStable));
4820
4821 for (Value operand : operands) state.addTypes(operand.getType());
4822
4823 state.addRegion();
4824 }
4825
verify()4826 LogicalResult SortOp::verify() {
4827 Operation::operand_range operands = this->operands();
4828 if (operands.empty()) return emitOpError("requires at least one input");
4829
4830 // TODO(antiagainst): verify partionally dynamic shapes
4831 if (llvm::all_of(operands, [](Value operand) {
4832 return operand.getType().cast<ShapedType>().hasRank();
4833 })) {
4834 ArrayRef<int64_t> inputShape =
4835 (*operands.begin()).getType().cast<ShapedType>().getShape();
4836
4837 if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
4838 return operand.getType().cast<ShapedType>().getShape() != inputShape;
4839 }))
4840 return emitOpError("requires all inputs to have the same dimensions");
4841
4842 int64_t rank = inputShape.size();
4843 int64_t cmpDim = dimension();
4844 if (cmpDim < -rank || cmpDim >= rank)
4845 return emitOpError("dimension attribute value must be in range [-")
4846 << rank << ", " << rank << "), but found " << cmpDim;
4847 }
4848
4849 Block& block = comparator().front();
4850 size_t numOperands = getOperation()->getNumOperands();
4851 if (block.getNumArguments() != 2 * numOperands)
4852 return emitOpError("comparator block should have ")
4853 << 2 * numOperands << " arguments";
4854
4855 for (const auto& indexedOperand : llvm::enumerate(operands)) {
4856 int index = indexedOperand.index();
4857 Type elementType =
4858 indexedOperand.value().getType().cast<ShapedType>().getElementType();
4859 Type tensorType = RankedTensorType::get({}, elementType);
4860 for (int i : {2 * index, 2 * index + 1}) {
4861 Type argType = block.getArgument(i).getType();
4862 if (argType != tensorType)
4863 return emitOpError("comparator block argument #")
4864 << i << " should be of type " << tensorType << " but got "
4865 << argType;
4866 }
4867 }
4868
4869 // Mapped computation must return single output.
4870 auto comparatorResult = block.getTerminator()->getOperands();
4871 if (comparatorResult.size() != 1)
4872 return emitOpError() << "comparator must return single output, but got: "
4873 << comparatorResult.size();
4874
4875 // The output of computation must be 0-ranked tensor with element-type i1.
4876 auto comparatorResultType =
4877 comparatorResult[0].getType().dyn_cast<RankedTensorType>();
4878 if (!comparatorResultType || comparatorResultType.getRank() != 0 ||
4879 !comparatorResultType.getElementType().isInteger(1))
4880 return emitOpError() << "comparator must return tensor<i1>, but got: "
4881 << comparatorResult[0].getType();
4882
4883 // check number of return-values and their element-types.
4884 auto resultTypes = getResultTypes();
4885 if (resultTypes.size() != numOperands)
4886 return emitOpError() << "expects the number of results to be same as "
4887 "number of operands. Got number of results = "
4888 << resultTypes.size()
4889 << " and number of operands = " << numOperands;
4890
4891 for (auto it : llvm::zip(operands, getResultTypes()))
4892 if (std::get<0>(it).getType().cast<TensorType>().getElementType() !=
4893 std::get<1>(it).cast<TensorType>().getElementType())
4894 return emitOpError()
4895 << "expects the operands and results to have pairwize equal "
4896 "element-types, but got "
4897 << std::get<0>(it).getType().cast<TensorType>().getElementType()
4898 << " vs " << std::get<1>(it).cast<TensorType>().getElementType();
4899
4900 return success();
4901 }
4902
4903 //===----------------------------------------------------------------------===//
4904 // TransposeOp
4905 //===----------------------------------------------------------------------===//
4906
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)4907 LogicalResult TransposeOp::reifyReturnTypeShapes(
4908 OpBuilder& builder, ValueRange operands,
4909 SmallVectorImpl<Value>& reifiedReturnShapes) {
4910 TransposeOp::Adaptor adaptor(operands);
4911 Value operand = adaptor.operand();
4912
4913 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
4914 // Not support unranked type a.t.m.
4915 if (!operandType) return failure();
4916
4917 Location loc = this->getLoc();
4918 SmallVector<int64_t, 4> permutation(this->permutation().getValues<int64_t>());
4919 SmallVector<Value, 4> shapeValues(permutation.size());
4920
4921 Type shapeScalarType = builder.getIndexType();
4922 auto toShapeScalarType = [&](Value v) {
4923 return maybeCastTo(builder, loc, v, shapeScalarType);
4924 };
4925
4926 for (const auto& element : llvm::enumerate(operandType.getShape())) {
4927 int64_t idx = element.index();
4928 auto* it = std::find(permutation.begin(), permutation.end(), idx);
4929 Value valueDim = toShapeScalarType(
4930 builder.createOrFold<tensor::DimOp>(loc, operand, element.index()));
4931 shapeValues[std::distance(permutation.begin(), it)] = valueDim;
4932 }
4933
4934 Value outputShape = builder.create<tensor::FromElementsOp>(
4935 loc,
4936 RankedTensorType::get({static_cast<int64_t>(shapeValues.size())},
4937 shapeScalarType),
4938 shapeValues);
4939 reifiedReturnShapes.push_back(outputShape);
4940
4941 return success();
4942 }
4943
4944 // Method for InferTypeOpInterface: infer the return type from the operand type
4945 // and the permutation.
inferReturnTypes(MLIRContext *,Optional<Location> loc,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)4946 LogicalResult TransposeOp::inferReturnTypes(
4947 MLIRContext* /*context*/, Optional<Location> loc, ValueRange operands,
4948 DictionaryAttr attributes, RegionRange,
4949 SmallVectorImpl<Type>& inferredReturnTypes) {
4950 auto type = operands[0].getType();
4951 auto rankedTy = type.dyn_cast<RankedTensorType>();
4952 if (!rankedTy) {
4953 auto shapedTy = type.dyn_cast<ShapedType>();
4954 inferredReturnTypes.emplace_back(shapedTy);
4955 return success();
4956 }
4957 auto permutation = attributes.getAs<DenseIntElementsAttr>("permutation");
4958 int64_t rank = rankedTy.getRank();
4959 if (permutation.getType().getRank() != 1)
4960 return emitOptionalError(loc, "TransposeOp permutation has rank ",
4961 permutation.getType().getRank(),
4962 " instead of rank 1");
4963
4964 if (permutation.size() != rank)
4965 return emitOptionalError(loc, "TransposeOp operand rank ", rank,
4966 " does not match permutation size ",
4967 permutation.size());
4968
4969 std::vector<int64_t> range(rank);
4970 std::iota(range.begin(), range.end(), 0);
4971 if (!std::is_permutation(range.begin(), range.end(), permutation.begin()))
4972 return emitOptionalError(loc,
4973 "attribute permutation must be a permutation"
4974 " of [",
4975 range, "] but got ", permutation);
4976
4977 SmallVector<int64_t> resultShape;
4978 ArrayRef<int64_t> inputShape = rankedTy.getShape();
4979 for (int64_t dim : permutation.getValues<int64_t>()) {
4980 resultShape.push_back(inputShape[dim]);
4981 }
4982 inferredReturnTypes.emplace_back(RankedTensorType::get(
4983 resultShape, rankedTy.getElementType(), rankedTy.getEncoding()));
4984 return success();
4985 }
4986
4987 //===----------------------------------------------------------------------===//
4988 // TriangularSolveOp
4989 //===----------------------------------------------------------------------===//
4990
verify()4991 LogicalResult TriangularSolveOp::verify() {
4992 auto aType = a().getType().dyn_cast<RankedTensorType>();
4993
4994 // Skip verifier if a is unranked tensor.
4995 if (!aType) return success();
4996
4997 // Check that a should have rank >= 2
4998 auto aRank = aType.getRank();
4999 if (aRank < 2)
5000 return emitOpError() << "operand 'a' must have rank >= 2, but got "
5001 << aType;
5002
5003 // The two minor dimensions of a must have same size.
5004 if (aType.getDimSize(aRank - 2) != aType.getDimSize(aRank - 1))
5005 return emitOpError() << "two minor dimensions of operand 'a' must have "
5006 "equal size, but got "
5007 << aType;
5008
5009 auto bType = b().getType().dyn_cast<RankedTensorType>();
5010 // If b is unranked skip remaining checks.
5011 if (!bType) return success();
5012
5013 // Check that a and b have same rank.
5014 auto bRank = bType.getRank();
5015 if (aRank != bRank)
5016 return emitOpError() << "operands must have equal rank, but got " << aType
5017 << " and " << bType;
5018
5019 // The shared dimension of a and b should match.
5020 if (aType.getDimSize(aRank - 1) !=
5021 bType.getDimSize(bRank - (left_side() ? 2 : 1)))
5022 return emitOpError() << "shared dimension of operands 'a' and 'b' does "
5023 "not match, but got "
5024 << aType << " and " << bType;
5025
5026 // The leading batch dimensions of a and b must be equal.
5027 auto aBatchDims = aType.getShape().drop_back(2);
5028 auto bBatchDims = bType.getShape().drop_back(2);
5029 if (aBatchDims != bBatchDims)
5030 return emitOpError()
5031 << "leading batch dimensions of the operands must be same, but got "
5032 << aType << " and " << bType;
5033
5034 // Result and argument b must have same shape.
5035 auto resultType = getType().dyn_cast<RankedTensorType>();
5036 if (!resultType) return success();
5037 if (resultType != bType)
5038 return emitOpError()
5039 << "result and operand 'b' must have same shape, but got "
5040 << resultType << " and " << bType;
5041 return success();
5042 }
5043
5044 //===----------------------------------------------------------------------===//
5045 // GetTupleElementOp
5046 //===----------------------------------------------------------------------===//
5047
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)5048 LogicalResult GetTupleElementOp::inferReturnTypes(
5049 MLIRContext*, Optional<Location>, ValueRange operands,
5050 DictionaryAttr attributes, RegionRange,
5051 SmallVectorImpl<Type>& inferredReturnTypes) {
5052 auto tupleType = operands[0].getType().dyn_cast<TupleType>();
5053 if (!tupleType) return failure();
5054
5055 auto indexAttr = attributes.get("index").cast<IntegerAttr>();
5056 auto index = indexAttr.getInt();
5057 if (index < 0 || index >= static_cast<int64_t>(tupleType.size()))
5058 return failure();
5059
5060 inferredReturnTypes.push_back(tupleType.getType(index));
5061 return success();
5062 }
5063
5064 //===----------------------------------------------------------------------===//
5065 // TupleOp
5066 //===----------------------------------------------------------------------===//
5067
inferReturnTypes(MLIRContext * context,Optional<Location>,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)5068 LogicalResult TupleOp::inferReturnTypes(
5069 MLIRContext* context, Optional<Location>, ValueRange operands,
5070 DictionaryAttr attributes, RegionRange,
5071 SmallVectorImpl<Type>& inferredReturnTypes) {
5072 inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
5073 return success();
5074 }
5075
5076 //===----------------------------------------------------------------------===//
5077 // CompareOp
5078 //===----------------------------------------------------------------------===//
5079
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,ComparisonDirection comparisonDirection,ComparisonType compareType)5080 void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
5081 Value rhs, ComparisonDirection comparisonDirection,
5082 ComparisonType compareType) {
5083 build(builder, result, lhs, rhs,
5084 ComparisonDirectionAttr::get(builder.getContext(), comparisonDirection),
5085 ComparisonTypeAttr::get(builder.getContext(), compareType));
5086 }
5087
inferReturnTypeComponents(mlir::MLIRContext * ctx,llvm::Optional<mlir::Location>,ValueShapeRange operands,mlir::DictionaryAttr,mlir::RegionRange,llvm::SmallVectorImpl<mlir::ShapedTypeComponents> & inferredReturnTypes)5088 LogicalResult CompareOp::inferReturnTypeComponents(
5089 mlir::MLIRContext* ctx, llvm::Optional<mlir::Location>,
5090 ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange,
5091 llvm::SmallVectorImpl<mlir::ShapedTypeComponents>& inferredReturnTypes) {
5092 ShapedTypeComponents& components =
5093 inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1));
5094 auto argTy = operands.front().getType().cast<TensorType>();
5095 if (argTy.hasRank()) {
5096 components =
5097 ShapedTypeComponents(argTy.getShape(), components.getElementType());
5098 }
5099 return success();
5100 }
5101
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,SmallVectorImpl<Value> & reifiedReturnShapes)5102 LogicalResult CompareOp::reifyReturnTypeShapes(
5103 OpBuilder& builder, ValueRange operands,
5104 SmallVectorImpl<Value>& reifiedReturnShapes) {
5105 return hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(),
5106 &reifiedReturnShapes);
5107 }
5108
5109 //===----------------------------------------------------------------------===//
5110 // SelectAndScatterOp
5111 //===----------------------------------------------------------------------===//
5112
5113 namespace {
5114 // Infer the return-type of SelectAndScatterOp.
inferSelectAndScatterOpReturnType(TensorType operandType,const ArrayRef<WindowDimension> window)5115 TensorType inferSelectAndScatterOpReturnType(
5116 TensorType operandType, const ArrayRef<WindowDimension> window) {
5117 if (!operandType.hasRank())
5118 return UnrankedTensorType::get(operandType.getElementType());
5119
5120 return RankedTensorType::get(
5121 inferWindowOutputShape(operandType.getShape(), window),
5122 operandType.getElementType());
5123 }
5124 } // namespace
5125
5126 // We intend to verify the following properties:
5127 // P1. Check if the select function has a proper shape of (T,T) -> PRED, where
5128 // T is a 0-D tensor with element-type same as 'operand' element-type.
5129 // P2. Verify scatter-computation type.
5130 // P3. size-of(window_dimension) == rank-of(input),
5131 // where input is an element of 'inputs'.
5132 // P4. Verify and collect the window attributes.
5133 // P5. Verify the return type matches the operand-type.
5134 // P6. Check if the result type of window operation matches the source type.
verify()5135 LogicalResult SelectAndScatterOp::verify() {
5136 auto operandType = operand().getType().cast<TensorType>();
5137 auto initValueType = init_value().getType().cast<TensorType>();
5138 auto sourceType = source().getType().cast<TensorType>();
5139 auto resultType = getResult().getType().cast<TensorType>();
5140
5141 // P1.
5142 Block& selectBlock = select().front();
5143
5144 if (selectBlock.getArguments().size() != 2)
5145 return emitOpError()
5146 << "expects the select-region to take 2 parameters, but takes "
5147 << selectBlock.getArguments().size();
5148
5149 Type expectedSelectArgType =
5150 RankedTensorType::get({}, operandType.getElementType());
5151 for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments()))
5152 if (!compatibleShapeAndElementType(expectedSelectArgType,
5153 selectArgIt.value().getType(),
5154 /*ignoreFpPrecision=*/true))
5155 return emitOpError()
5156 << "expects the type of select-region's parameter at index "
5157 << selectArgIt.index() << " to be " << expectedSelectArgType
5158 << ", but got " << selectArgIt.value().getType();
5159
5160 auto selectResult = selectBlock.getTerminator()->getOperands();
5161 if (selectResult.size() != 1)
5162 return emitOpError()
5163 << "expects select-region to return single value, but got: "
5164 << selectResult.size();
5165
5166 auto selectResultType = selectResult[0].getType().dyn_cast<TensorType>();
5167 if (!selectResultType || !selectResultType.getElementType().isInteger(1) ||
5168 (selectResultType.hasRank() &&
5169 selectResultType.cast<RankedTensorType>().getRank() != 0))
5170 return emitOpError() << "expects the return-type of select-region to be "
5171 "tensor<i1>, but got: "
5172 << selectResult[0].getType();
5173
5174 // P2.
5175 Block& scatterBlock = scatter().front();
5176 SmallVector<TensorType> accumulatorSubshapes;
5177 if (failed(verifyReducerShape(
5178 this->getLoc(), scatterBlock,
5179 {RankedTensorType::get({}, sourceType.getElementType())},
5180 {initValueType},
5181 /*numInputs=*/1, /*allowedDimensions=*/{},
5182 /*allInputsUnranked=*/false, accumulatorSubshapes)))
5183 return failure();
5184
5185 // P3.
5186 SmallVector<int64_t> windowDims =
5187 convertDenseIntAttr(this->window_dimensions());
5188 if (operandType.hasRank()) {
5189 if (operandType.getRank() != static_cast<int64_t>(windowDims.size()))
5190 return emitOpError()
5191 << "expects window-dimensions size == operand rank, but got "
5192 "window-dimensions size: "
5193 << windowDims.size() << " and operand-type: " << operandType
5194 << " with rank = " << operandType.getRank() << ".";
5195 }
5196
5197 // P4.
5198 auto paddingOrErr = convertNx2Attribute(this->padding(), getLoc());
5199 if (failed(paddingOrErr)) return failure();
5200 SmallVector<std::pair<int64_t, int64_t>> padding = *paddingOrErr;
5201
5202 auto windowOrErr = verifyWindowAttributesAndInferWindowDimensions(
5203 windowDims, convertDenseIntAttr(window_strides()), padding,
5204 /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc());
5205 if (failed(windowOrErr)) return failure();
5206
5207 // P5.
5208 if (!compatibleShapeAndElementType(operandType, resultType))
5209 return emitOpError()
5210 << "expects the return-type to match the operand-type, but got "
5211 << resultType << " and " << operandType << " resp.";
5212
5213 // P6.
5214 auto windowResultType =
5215 inferSelectAndScatterOpReturnType(operandType, *windowOrErr);
5216
5217 if (!compatibleShapeAndElementType(windowResultType, sourceType,
5218 /*ignoreFpPrecision=*/true))
5219 return emitOpError() << "expects source-type to be " << windowResultType
5220 << ", but got" << sourceType;
5221
5222 return success();
5223 }
5224
5225 //===----------------------------------------------------------------------===//
5226 // ScatterOp
5227 //===----------------------------------------------------------------------===//
5228
5229 /*
5230 * We intend to verify the following properties:
5231 * P1. The 'update_window_dims' must be valid indices of 'updates' tensor.
5232 * P2. The 'inserted_window_dims' must be valid indices of 'operand' tensor.
5233 * P3. Check if the rank-of('operand') == size-of('update_window_dims') +
5234 * size-of('inserted_window_dims')
5235 * P4. size-of('scatter_dims_to_operand_dims') =
5236 * 'scatter_indices'['index_vector_dim'] &
5237 * 'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor.
5238 */
// Verifies a ScatterDimensionNumbersAttr against the operand, updates, and
// scatter_indices types of a ScatterOp. 'scatterIndicesShape' is the
// scatter_indices shape already expanded with a trailing 1 when
// index_vector_dim == rank(scatter_indices). The three *Ranked flags gate
// the checks that require a known rank, so unranked operands are accepted.
LogicalResult validateScatterDimensionNumbers(
    ShapedType operandType, ArrayRef<int64_t> scatterIndicesShape,
    ShapedType updateType, bool operandTypeRanked,
    bool scatterIndicesTypeRanked, bool updatesTypeRanked,
    ScatterDimensionNumbersAttr dimNumbers, Location loc) {
  // P1. update_window_dims must be sorted, duplicate-free, and each entry
  // must index a dimension of 'updates'.
  auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims());
  if (!llvm::is_sorted(updateWindowDims))
    return mlir::emitError(loc)
           << "Expects update_window_dims to be sorted; got: ["
           << updateWindowDims << "].";

  if (hasDuplicates(updateWindowDims))
    return mlir::emitError(loc)
           << "Expects update_window_dims to not repeat; got: ["
           << updateWindowDims << "].";

  if (updatesTypeRanked) {
    for (int64_t windowDim : updateWindowDims) {
      if (windowDim < 0 || windowDim >= updateType.getRank()) {
        return mlir::emitError(loc)
               << "Expects each element of update_window_dims to be in range "
                  "[0, "
                  "rank-of('updates') i.e. [0, "
               << updateType.getRank() << "). got: " << windowDim << ".";
      }
    }
  }

  // P2. inserted_window_dims must be sorted, duplicate-free, and each entry
  // must index a dimension of 'operand'.
  auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims());
  if (!llvm::is_sorted(insertedWindowDims))
    return mlir::emitError(loc)
           << "Expects inserted_window_dims to be sorted; got: ["
           << insertedWindowDims << "].";

  if (hasDuplicates(insertedWindowDims))
    return mlir::emitError(loc)
           << "Expects inserted_window_dims to not repeat; got: ["
           << insertedWindowDims << "].";

  if (operandTypeRanked) {
    for (int64_t insertedDim : insertedWindowDims) {
      if (insertedDim < 0 || insertedDim >= operandType.getRank()) {
        return mlir::emitError(loc)
               << "Expects each element of inserted_window_dims to be in range "
                  "[0, rank-of('operand') i.e. [0, "
               << operandType.getRank() << "). got: " << insertedDim << ".";
      }
    }
  }

  // P3. Every operand dimension is either a window dimension or an inserted
  // dimension, so the two lists together must cover the whole rank.
  if (operandTypeRanked) {
    auto windowSize = updateWindowDims.size() + insertedWindowDims.size();
    if (operandType.getRank() != static_cast<int64_t>(windowSize))
      return mlir::emitError(loc)
             << "Expects rank-of operand to match "
                "size-of('update_window_dims') + "
                "size-of('inserted_window_dims') i.e. "
             << windowSize << " but got " << operandType.getRank() << ".";
  }

  // P4. scatter_dims_to_operand_dims must have exactly one entry per index
  // component (the bound of index_vector_dim), each entry must be a valid
  // operand dimension, and entries must not repeat.
  auto scatterDimsToOperandDims =
      to_vector(dimNumbers.getScatterDimsToOperandDims());
  auto indexVectorDim = dimNumbers.getIndexVectorDim();
  if (scatterIndicesTypeRanked) {
    // Dynamic index-vector bounds are not checked.
    if (!isDynamicDimSize(scatterIndicesShape[indexVectorDim]) &&
        static_cast<int64_t>(scatterDimsToOperandDims.size()) !=
            scatterIndicesShape[dimNumbers.getIndexVectorDim()])
      return mlir::emitError(loc)
             << "Scatter op has " << scatterDimsToOperandDims.size()
             << " elements in scatter_dims_to_operand_dims and the bound of "
                "dimension index_vector_dim="
             << dimNumbers.getIndexVectorDim() << " of scatter_indices is "
             << scatterIndicesShape[dimNumbers.getIndexVectorDim()]
             << ". These two numbers must be equal.";
  }

  if (operandTypeRanked) {
    for (int64_t i = 0;
         i < static_cast<int64_t>(scatterDimsToOperandDims.size()); ++i) {
      int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i];
      if (scatterDimToOperandDim < 0 ||
          scatterDimToOperandDim >= operandType.getRank())
        return mlir::emitError(loc)
               << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, "
               << operandType.getRank() << "), got: " << i << "->"
               << scatterDimToOperandDim << ".";
    }
  }

  if (hasDuplicates(scatterDimsToOperandDims))
    return mlir::emitError(loc)
           << "Expects scatter_dims_to_operand_dims to not repeat; got: ["
           << scatterDimsToOperandDims << "].";

  return success();
}
5339 /*
5340 * We intend to verify the following properties:
5341 * P0. scatter_indices argument must be an integral tensor. Enforced by ODS.
 * P1. Scatter index leaf dimension must be within
 *     [0, rank(scatter_indices) + 1).
5344 * P2. Verify reducer shape.
5345 * P3. rank-of('updates[i]') == size-of('update_window_dims') +
5346 * rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a
5347 * trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices')
5348 * for all values of `i`.
5349 * P4. Validate the scatter-dimensions-numbers.
 * P5. Validate the bounds of each of the 'updates' w.r.t the operands.
5351 * P6. Validate the bounds of each of the 'updates' w.r.t the
5352 * 'scatter_indices'.
5353 * P7. Check return types.
5354 */
// Verifies ScatterOp properties P1-P7 documented above. All shape checks are
// skipped for unranked types, and dynamic dimensions are never compared.
LogicalResult ScatterOp::verify() {
  // Get the first operand and update, since variadic Scatter is not yet
  // implemented
  auto numOperands = operands().size();
  auto scatterIndicesType = scatter_indices().getType().dyn_cast<TensorType>();

  SmallVector<TensorType, 1> operandTypes =
      llvm::to_vector(llvm::map_range(operands().getTypes(), [](Type type) {
        return type.cast<TensorType>();
      }));
  SmallVector<TensorType, 1> updatesTypes = llvm::to_vector(llvm::map_range(
      updates().getTypes(), [](Type type) { return type.cast<TensorType>(); }));
  bool allOperandTypesRanked =
      llvm::all_of(operands().getTypes(),
                   [](Type type) { return type.isa<RankedTensorType>(); });
  bool scatterIndicesTypeRanked = scatterIndicesType.isa<RankedTensorType>();

  // P1. index_vector_dim may equal rank(scatter_indices) — that denotes the
  // implicit trailing dimension — hence '>' rather than '>='.
  int64_t indexVectorDim = scatter_dimension_numbers().getIndexVectorDim();
  if (scatterIndicesTypeRanked) {
    if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0)
      return emitOpError()
             << "expects scatter index leaf dimension to be within [0, "
                "rank(scatter_indices) + 1."
                " rank(scatter_indices) is "
             << scatterIndicesType.getRank()
             << " and scatter index leaf dimension is " << indexVectorDim
             << ".";
  }

  // P2. The update computation must combine a 0-D tensor of each updates'
  // element type with the corresponding operand.
  Block& block = update_computation().front();
  SmallVector<TensorType> accumulatorSubshapes;
  SmallVector<TensorType> inputTypes, initValueTypes;
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    inputTypes.push_back(operandTypes[i]);
    initValueTypes.push_back(
        RankedTensorType::get({}, updatesTypes[i].getElementType()));
  }
  if (failed(verifyReducerShape(
          this->getLoc(), block, inputTypes, initValueTypes, numOperands,
          /*allowedDimensions=*/{},
          /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes)))
    return failure();

  // P3. Compute the scatter_indices shape, expanded with a trailing 1 when
  // index_vector_dim == rank(scatter_indices), then check each updates rank.
  auto updateWindowDims = scatter_dimension_numbers().getUpdateWindowDims();
  SmallVector<int64_t> expandedScatterIndicesShape;
  if (scatterIndicesTypeRanked) {
    expandedScatterIndicesShape =
        llvm::to_vector(scatterIndicesType.getShape());
    if (static_cast<int64_t>(expandedScatterIndicesShape.size()) ==
        indexVectorDim)
      expandedScatterIndicesShape.push_back(1);
  }

  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (scatterIndicesTypeRanked && updatesTypes[i].isa<RankedTensorType>()) {
      int64_t expectedUpdatesRank =
          expandedScatterIndicesShape.size() - 1 + updateWindowDims.size();
      if (updatesTypes[i].getRank() != expectedUpdatesRank)
        return emitOpError()
               << "expects updates tensor must be of rank "
               << expectedUpdatesRank
               << " ( == rank-of('scatter_indices') - 1 + "
                  "size-of('update_window_dims'), where 'scatter_indices' is "
                  "expanded by a trailing 1 dimension if 'index_vector_dim' == "
                  "rank-of('scatter_indices')), but got "
               << updatesTypes[i].getRank() << ".";
    }
  }

  // P4. Per-operand validation of the dimension-numbers attribute.
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (failed(validateScatterDimensionNumbers(
            operandTypes[i], expandedScatterIndicesShape, updatesTypes[i],
            operandTypes[i].isa<RankedTensorType>(), scatterIndicesTypeRanked,
            updatesTypes[i].isa<RankedTensorType>(),
            scatter_dimension_numbers(), getLoc())))
      return failure();
  }

  // P5. Each window dimension of 'updates' must fit inside the matching
  // operand dimension (operand dims in inserted_window_dims are skipped).
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (updatesTypes[i].isa<RankedTensorType>()) {
      auto updatesShape = updatesTypes[i].getShape();
      if (operandTypes[i].isa<RankedTensorType>()) {
        auto operandShape = operandTypes[i].getShape();
        auto insertedWindowDims =
            scatter_dimension_numbers().getInsertedWindowDims();

        // Collect the operand bounds of the non-inserted dimensions; these
        // pair up positionally with update_window_dims.
        int64_t insertedDimsSeen = 0;
        SmallVector<int64_t> maxUpdateSliceSizes;
        const auto dimensionsSize = operandTypes[i].getRank();
        maxUpdateSliceSizes.reserve(dimensionsSize);
        // NOTE(review): this inner 'i' shadows the operand-loop 'i' above;
        // only dimension indices are used inside, so behavior is correct.
        for (int i = 0; i < dimensionsSize; ++i) {
          if (insertedDimsSeen <
                  static_cast<int64_t>(insertedWindowDims.size()) &&
              insertedWindowDims[insertedDimsSeen] == i) {
            ++insertedDimsSeen;
          } else {
            maxUpdateSliceSizes.push_back(operandShape[i]);
          }
        }

        for (int64_t i = 0; i < static_cast<int64_t>(updateWindowDims.size());
             ++i) {
          auto updateWindowDim = updateWindowDims[i];

          // Dynamic bounds on either side are not comparable; skip.
          if (isDynamicDimSize(updatesShape[updateWindowDim]) ||
              isDynamicDimSize(maxUpdateSliceSizes[i]))
            continue;

          if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) {
            return emitOpError()
                   << "expects bounds of the window dimensions of "
                      "updates to not exceed the "
                      "bounds of the corresponding dimensions of "
                      "operand. For dimension "
                   << updateWindowDim << ", updates bound is "
                   << updatesShape[updateWindowDim] << ", operand bound is "
                   << maxUpdateSliceSizes[i] << ".";
          }
        }
      }

      // P6. The scatter (non-window) dimensions of 'updates' must match the
      // corresponding scatter_indices dimensions, skipping index_vector_dim.
      if (scatterIndicesTypeRanked) {
        int64_t scatterDimsSeen = 0;
        for (int64_t i = 0; i < static_cast<int64_t>(updatesShape.size());
             ++i) {
          // update_window_dims is validated sorted, so binary_search is safe.
          bool isUpdateWindowDim = std::binary_search(
              updateWindowDims.begin(), updateWindowDims.end(), i);

          if (isUpdateWindowDim) continue;
          if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen;

          if (!isDynamicDimSize(updatesShape[i]) &&
              !isDynamicDimSize(expandedScatterIndicesShape[scatterDimsSeen]) &&
              (updatesShape[i] !=
               expandedScatterIndicesShape[scatterDimsSeen])) {
            return emitOpError()
                   << "expects bounds of the scatter dimensions of "
                      "updates to be same as the "
                      "bounds of the corresponding dimensions of "
                      "scatter indices. For "
                      "scatter dimension "
                   << i << ", updates bound is " << updatesShape[i]
                   << " , scatter_indices "
                      "bound is "
                   << expandedScatterIndicesShape[scatterDimsSeen] << ".";
          }
          ++scatterDimsSeen;
        }
      }
    }
  }

  // P7. Each result type must match its operand type.
  for (int64_t i = 0; i < static_cast<int64_t>(numOperands); i++) {
    if (!compatibleShapeAndElementType(operandTypes[i], getResult(i).getType()))
      return emitOpError()
             << "expects the return type to be same as the operand type: "
             << operandTypes[i] << ", but got " << getResult(i).getType()
             << ".";
  }

  return success();
}
5524
5525 //===----------------------------------------------------------------------===//
5526 // WhileOp
5527 //===----------------------------------------------------------------------===//
5528
verify()5529 LogicalResult WhileOp::verify() {
5530 if (getNumOperands() != cond().front().getNumArguments())
5531 return emitOpError() << "mismatch in operand count (" << getNumOperands()
5532 << ") vs the condition block argument count ("
5533 << cond().front().getNumArguments() << ")";
5534 if (getNumOperands() != body().front().getNumArguments())
5535 return emitOpError() << "mismatch in operand count (" << getNumOperands()
5536 << ") vs the body block argument count ("
5537 << body().front().getNumArguments() << ")";
5538 for (const auto& enumeratedOperands : llvm::enumerate(
5539 llvm::zip(getOperandTypes(), cond().front().getArgumentTypes(),
5540 body().front().getArgumentTypes()))) {
5541 int argCount = enumeratedOperands.index();
5542 const auto& operands = enumeratedOperands.value();
5543 Type operandType = std::get<0>(operands);
5544 Type condType = std::get<1>(operands);
5545 Type bodyType = std::get<2>(operands);
5546 if (operandType != condType)
5547 return emitOpError() << "type mismatch between operand #" << argCount
5548 << " and the matching condition block argument: "
5549 << operandType << " vs " << condType;
5550 if (operandType != bodyType)
5551 return emitOpError() << "type mismatch between operand #" << argCount
5552 << " and the matching body block argument: "
5553 << operandType << " vs " << bodyType;
5554 }
5555 // Check the return type for the condition block.
5556 {
5557 auto condReturnOp = cast<ReturnOp>(cond().front().back());
5558 if (condReturnOp->getNumOperands() != 1)
5559 return condReturnOp.emitOpError()
5560 << "expects a single operand for while condition body return, got "
5561 << condReturnOp->getNumOperands();
5562 auto operandType =
5563 condReturnOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
5564 if (!operandType || operandType.getRank() != 0 ||
5565 !operandType.getElementType().isInteger(1))
5566 return condReturnOp.emitOpError()
5567 << "expects a zero-ranked tensor of i1, got "
5568 << condReturnOp->getOperand(0).getType();
5569 }
5570 // Check the return type for the body block.
5571 {
5572 auto bodyReturnOp = cast<ReturnOp>(body().front().back());
5573 if (bodyReturnOp->getNumOperands() != getNumOperands())
5574 return bodyReturnOp.emitOpError()
5575 << "expects body to return a many value as the operands ("
5576 << getNumOperands() << "), got " << bodyReturnOp->getNumOperands();
5577 for (const auto& enumeratedOperandTypes : llvm::enumerate(
5578 llvm::zip(bodyReturnOp->getOperandTypes(), getOperandTypes()))) {
5579 Type operandType = std::get<0>(enumeratedOperandTypes.value());
5580 Type returnType = std::get<1>(enumeratedOperandTypes.value());
5581 if (operandType != returnType)
5582 return bodyReturnOp.emitOpError()
5583 << "type mismatch between operand #"
5584 << enumeratedOperandTypes.index()
5585 << " and the enclosing WhileOp returned value: " << operandType
5586 << " vs " << returnType;
5587 }
5588 }
5589 return success();
5590 }
5591
5592 /// Print a `while` op.
5593 ///
5594 /// op ::= `stablehlo.while` `(` assignment-list `)` `:` types attribute-dict
5595 /// `cond` region
5596 /// `do` region
5597 /// assignment-list ::= assignment | assignment `,` assignment-list
5598 /// assignment ::= ssa-value `=` ssa-value
print(OpAsmPrinter & p)5599 void WhileOp::print(OpAsmPrinter& p) {
5600 p << '(';
5601 llvm::interleaveComma(llvm::zip(getBody()->getArguments(), getOperands()), p,
5602 [&](auto zip) {
5603 p.printOperand(std::get<0>(zip));
5604 p << " = ";
5605 p.printOperand(std::get<1>(zip));
5606 });
5607 p << ")";
5608 if (getNumOperands()) {
5609 p << " : ";
5610 llvm::interleaveComma(getOperandTypes(), p);
5611 }
5612 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
5613 p.printNewline();
5614 p << " cond ";
5615 p.printRegion(getRegion(0), /*printEntryBlockArgs=*/false);
5616 p << " do ";
5617 p.printRegion(getRegion(1), /*printEntryBlockArgs=*/false);
5618 }
5619
// Parses the custom while form produced by WhileOp::print.
ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) {
  llvm::SMLoc loc = parser.getCurrentLocation();
  // Parse the operands of the while: these are of the form:
  //   %iter_arg = %init_val
  // where %iter_arg is the name of the block argument in the cond/body blocks
  // and %init_val is the actual operand.
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<OpAsmParser::UnresolvedOperand> iterArgs;
  if (parser.parseLParen()) return failure();
  do {
    // Handles both an empty list '()' and the ')' after the last assignment.
    if (succeeded(parser.parseOptionalRParen())) break;
    OpAsmParser::UnresolvedOperand operand, iterArg;
    if (parser.parseOperand(iterArg) || parser.parseEqual() ||
        parser.parseOperand(operand))
      return failure();
    iterArgs.push_back(iterArg);
    operands.push_back(operand);
    if (succeeded(parser.parseOptionalRParen())) break;
    if (failed(parser.parseComma())) return failure();
  } while (true);
  // The ': type-list' is only printed (and thus parsed) when there is at
  // least one operand.
  if (!operands.empty()) {
    if (parser.parseColon() || parser.parseTypeList(result.types))
      return failure();
  }

  // Bind the iter-arg names (with the parsed result types) as the entry
  // block arguments of both regions.
  SmallVector<OpAsmParser::Argument> args;
  createArgs(iterArgs, result.types, args);
  if (parser.resolveOperands(operands, result.types, loc, result.operands) ||
      parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
      parser.parseKeyword("cond") ||
      parser.parseRegion(*result.addRegion(), args) ||
      parser.parseKeyword("do") ||
      parser.parseRegion(*result.addRegion(), args))
    return failure();
  return success();
}
5656
inferReturnTypeComponents(MLIRContext *,Optional<Location>,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)5657 LogicalResult UniformDequantizeOp::inferReturnTypeComponents(
5658 MLIRContext*, Optional<Location> /*location*/, ValueShapeRange operands,
5659 DictionaryAttr attributes, RegionRange regions,
5660 SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
5661 UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions);
5662 auto operandType = (*operands.begin()).getType().cast<ShapedType>();
5663 // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType;
5664 auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
5665 auto shape = operandType.dyn_cast<ShapedType>().getShape();
5666 inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
5667 return success();
5668 }
5669
5670 } // namespace stablehlo
5671 } // namespace mlir
5672
5673 #define GET_OP_CLASSES
5674 #include "dialect/StablehloOps.cpp.inc"
5675
5676 namespace mlir {
5677 namespace stablehlo {
5678
5679 //===----------------------------------------------------------------------===//
5680 // StableHLO Dialect Interfaces
5681 //===----------------------------------------------------------------------===//
5682
namespace {
// Implements hlo::BoundedDialectInterface for StableHLO: dimension bounds
// are encoded as a TypeExtensionsAttr owned by this dialect.
struct HLOBoundedDialectInterface : public hlo::BoundedDialectInterface {
  using BoundedDialectInterface::BoundedDialectInterface;

  // Creates the dialect-specific attribute carrying the given bounds.
  Attribute createBoundedAttr(ArrayRef<int64_t> bounds) const override {
    return TypeExtensionsAttr::get(getDialect()->getContext(), bounds);
  }
};
}  // end anonymous namespace
5692
5693 //===----------------------------------------------------------------------===//
5694 // StableHLO Dialect Constructor
5695 //===----------------------------------------------------------------------===//
5696
// Registers all StableHLO ops, attributes, types, and dialect interfaces.
StablehloDialect::StablehloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<StablehloDialect>()) {
  // Ops generated from the TableGen definitions.
  addOperations<
#define GET_OP_LIST
#include "dialect/StablehloOps.cpp.inc"
      >();
  addInterfaces<HLOBoundedDialectInterface>();
  addTypes<TokenType>();
  // Attributes generated from the TableGen definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "dialect/StablehloAttrs.cpp.inc"
      >();
  // NOTE(review): the tensor dialect is eagerly loaded here — presumably it
  // is required by builders/canonicalizations in this file; confirm.
  context->loadDialect<tensor::TensorDialect>();
}
5711
parseType(DialectAsmParser & parser) const5712 Type StablehloDialect::parseType(DialectAsmParser& parser) const {
5713 StringRef dataType;
5714 if (parser.parseKeyword(&dataType)) return Type();
5715
5716 if (dataType == "token") return TokenType::get(getContext());
5717 parser.emitError(parser.getNameLoc())
5718 << "unknown stablehlo type: " << dataType;
5719 return nullptr;
5720 }
5721
printType(Type type,DialectAsmPrinter & os) const5722 void StablehloDialect::printType(Type type, DialectAsmPrinter& os) const {
5723 if (type.isa<TokenType>()) {
5724 os << "token";
5725 return;
5726 }
5727 os << "<unknown stablehlo type>";
5728 }
5729
5730 // Entry point for Attribute parsing, TableGen generated code will handle the
5731 // dispatch to the individual classes.
parseAttribute(DialectAsmParser & parser,Type type) const5732 Attribute StablehloDialect::parseAttribute(DialectAsmParser& parser,
5733 Type type) const {
5734 StringRef attrTag;
5735 Attribute attr;
5736 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
5737 if (parseResult.hasValue()) return attr;
5738 parser.emitError(parser.getNameLoc(), "unknown stablehlo attribute");
5739 return Attribute();
5740 }
5741
5742 // Entry point for Attribute printing, TableGen generated code will handle the
5743 // dispatch to the individual classes.
void StablehloDialect::printAttribute(Attribute attr,
                                      DialectAsmPrinter& os) const {
  // The generated printer must recognize every registered attribute; the
  // void-cast keeps 'result' used when asserts are compiled out (NDEBUG).
  LogicalResult result = generatedAttributePrinter(attr, os);
  (void)result;
  assert(succeeded(result));
}
5750
5751 /// Helpers for attributes parsing.
5752
parseDims(AsmParser & parser,SmallVector<int64_t> & dims)5753 static ParseResult parseDims(AsmParser& parser, SmallVector<int64_t>& dims) {
5754 dims.clear();
5755 if (parser.parseLSquare()) return failure();
5756 while (failed(parser.parseOptionalRSquare())) {
5757 dims.emplace_back();
5758 if (parser.parseInteger(dims.back())) return failure();
5759 (void)parser.parseOptionalComma();
5760 }
5761 return success();
5762 }
5763
parseDimsWithMinimumElements(AsmParser & parser,SmallVector<int64_t> & dims,int minElements)5764 static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
5765 SmallVector<int64_t>& dims,
5766 int minElements) {
5767 if (failed(parseDims(parser, dims))) return failure();
5768 if (static_cast<int64_t>(dims.size()) < minElements)
5769 return parser.emitError(parser.getCurrentLocation())
5770 << "expected at least " << minElements << " element(s), found "
5771 << dims.size();
5772 return success();
5773 }
5774
5775 /// Parse a custom attribute that resembles a struct of the form
5776 /// <
5777 /// foo = something_parsed_by_custom_parser,
5778 /// bar = something_parsed_by_different_custom_parser,
5779 /// baz something_parsed_by_another_custom_parser
5780 /// >
5781 /// The optional argument `parse_equal` array can be used to denote if
5782 /// '=' follows the keyword (see baz in the example above) for a field. If
5783 /// not provided, all fields must be followed by a '='.
// See the comment above for the grammar. Fields may appear in any order;
// each keyword may appear at most once; parsing stops at the closing '>'.
static ParseResult parseStruct(
    AsmParser& parser, ArrayRef<StringRef> keywords,
    ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs,
    ArrayRef<bool> parseEqual = {}) {
  assert(keywords.size() == parseFuncs.size());
  assert(parseEqual.empty() || parseEqual.size() == keywords.size());
  // Tracks which fields have already been parsed, to reject duplicates.
  SmallVector<bool> seen(keywords.size(), false);
  while (failed(parser.parseOptionalGreater())) {
    bool foundOne = false;
    for (const auto& it : llvm::enumerate(keywords)) {
      size_t index = it.index();
      StringRef keyword = it.value();
      if (succeeded(parser.parseOptionalKeyword(keyword))) {
        if (seen[index]) {
          return parser.emitError(parser.getCurrentLocation())
                 << "duplicated `" << keyword << "` entry";
        }
        if (parseEqual.empty() || parseEqual[index]) {
          if (failed(parser.parseEqual())) return failure();
        }
        if (failed(parseFuncs[index]())) return failure();
        // A missing comma means this was the last field: the next token must
        // be the closing '>'. (Duplicate detection does not apply to this
        // final field, since 'seen' is set only on the comma path.)
        if (failed(parser.parseOptionalComma())) return parser.parseGreater();
        seen[index] = true;
        foundOne = true;
      }
    }
    // No keyword matched at this position: report the allowed set.
    if (!foundOne) {
      auto parseError = parser.emitError(parser.getCurrentLocation())
                        << "expected one of: ";
      llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
        parseError << '`' << kw << '`';
      });
      return parseError;
    }
  }
  return success();
}
5821
5822 // Helpers to print an optional array or integer field, to simplify writing
5823 // attribute printers.
5824 template <typename T>
printField(AsmPrinter & printer,StringRef name,T field,StringRef & separator)5825 static void printField(AsmPrinter& printer, StringRef name, T field,
5826 StringRef& separator) {
5827 if (field != 0) {
5828 printer << separator << name << " = " << field;
5829 separator = ", ";
5830 }
5831 }
5832 template <typename T>
printField(AsmPrinter & printer,StringRef name,ArrayRef<T> field,StringRef & separator)5833 static void printField(AsmPrinter& printer, StringRef name, ArrayRef<T> field,
5834 StringRef& separator) {
5835 if (!field.empty()) {
5836 printer << separator << name << " = [";
5837 llvm::interleaveComma(field, printer);
5838 printer << "]";
5839 separator = ", ";
5840 }
5841 }
// Prints "<field = value, ...>" from a pack of (name, value) pairs; zero or
// empty fields are elided by printField above.
// NOTE(review): the 'name' parameter is unused here — presumably the
// attribute mnemonic is emitted by the generated framework; confirm.
template <typename... Ts>
static void printStruct(AsmPrinter& printer, StringRef name,
                        Ts... printFields) {
  printer << "<";
  StringRef separator = "";
  // Fold expression to print each entry in the parameter pack.
  // TODO(stablehlo-team): this can be simplified when TF moves to C++17.
  using unused = int[];
  (void)unused{0, (printField(printer, std::get<0>(printFields),
                              std::get<1>(printFields), separator),
                   0)...};
  printer << ">";
}
5855
5856 // Custom printer and parser for ScatterDimensionNumbersAttr.
// Prints e.g. <update_window_dims = [...], ..., index_vector_dim = N>;
// zero/empty fields are elided by printField via printStruct.
void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
  printStruct(printer, "scatter",
              std::make_pair("update_window_dims", getUpdateWindowDims()),
              std::make_pair("inserted_window_dims", getInsertedWindowDims()),
              std::make_pair("scatter_dims_to_operand_dims",
                             getScatterDimsToOperandDims()),
              std::make_pair("index_vector_dim", getIndexVectorDim()));
}
parse(AsmParser & parser,Type type)5865 Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5866 if (failed(parser.parseLess())) return {};
5867 SmallVector<int64_t> updateWindowDims;
5868 SmallVector<int64_t> insertedWindowDims;
5869 SmallVector<int64_t> scatterDimsToOperandDims;
5870 int64_t indexVectorDim = 0;
5871
5872 if (failed(parseStruct(
5873 parser,
5874 {"update_window_dims", "inserted_window_dims",
5875 "scatter_dims_to_operand_dims", "index_vector_dim"},
5876 {[&]() { return parseDims(parser, updateWindowDims); },
5877 [&]() { return parseDims(parser, insertedWindowDims); },
5878 [&]() { return parseDims(parser, scatterDimsToOperandDims); },
5879 [&]() { return parser.parseInteger(indexVectorDim); }}))) {
5880 parser.emitError(parser.getCurrentLocation())
5881 << "failed parsing scatter dimension numbers attribute";
5882 return {};
5883 }
5884
5885 return ScatterDimensionNumbersAttr::get(
5886 parser.getContext(), updateWindowDims, insertedWindowDims,
5887 scatterDimsToOperandDims, indexVectorDim);
5888 }
5889
5890 // Custom printer and parser for GatherDimensionNumbersAttr.
// Prints e.g. <offset_dims = [...], ..., index_vector_dim = N>; zero/empty
// fields are elided by printField via printStruct.
void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
  printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
              std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
              std::make_pair("start_index_map", getStartIndexMap()),
              std::make_pair("index_vector_dim", getIndexVectorDim()));
}
5897
parse(AsmParser & parser,Type type)5898 Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5899 if (failed(parser.parseLess())) return {};
5900
5901 SmallVector<int64_t> offsetDims;
5902 SmallVector<int64_t> collapsedSliceDims;
5903 SmallVector<int64_t> startIndexMap;
5904 int64_t indexVectorDim = 0;
5905
5906 if (failed(parseStruct(
5907 parser,
5908 {"offset_dims", "collapsed_slice_dims", "start_index_map",
5909 "index_vector_dim"},
5910 {[&]() { return parseDims(parser, offsetDims); },
5911 [&]() { return parseDims(parser, collapsedSliceDims); },
5912 [&]() { return parseDims(parser, startIndexMap); },
5913 [&]() { return parser.parseInteger(indexVectorDim); }}))) {
5914 parser.emitError(parser.getCurrentLocation())
5915 << "failed parsing gather dimension numbers attribute";
5916 return {};
5917 }
5918
5919 return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims,
5920 collapsedSliceDims, startIndexMap,
5921 indexVectorDim);
5922 }
5923
5924 // Custom printer and parser for DotDimensionNumbersAttr.
print(AsmPrinter & printer) const5925 void DotDimensionNumbersAttr::print(AsmPrinter& printer) const {
5926 printStruct(
5927 printer, "dot",
5928 std::make_pair("lhs_batching_dimensions", getLhsBatchingDimensions()),
5929 std::make_pair("rhs_batching_dimensions", getRhsBatchingDimensions()),
5930 std::make_pair("lhs_contracting_dimensions",
5931 getLhsContractingDimensions()),
5932 std::make_pair("rhs_contracting_dimensions",
5933 getRhsContractingDimensions()));
5934 }
5935
parse(AsmParser & parser,Type type)5936 Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
5937 if (failed(parser.parseLess())) return {};
5938
5939 SmallVector<int64_t> lhsBatchingDimensions;
5940 SmallVector<int64_t> rhsBatchingDimensions;
5941 SmallVector<int64_t> lhsContractingDimensions;
5942 SmallVector<int64_t> rhsContractingDimensions;
5943
5944 if (failed(parseStruct(
5945 parser,
5946 {"lhs_batching_dimensions", "rhs_batching_dimensions",
5947 "lhs_contracting_dimensions", "rhs_contracting_dimensions"},
5948 {[&]() { return parseDims(parser, lhsBatchingDimensions); },
5949 [&]() { return parseDims(parser, rhsBatchingDimensions); },
5950 [&]() { return parseDims(parser, lhsContractingDimensions); },
5951 [&]() { return parseDims(parser, rhsContractingDimensions); }}))) {
5952 parser.emitError(parser.getCurrentLocation())
5953 << "failed parsing dot dimension numbers attribute";
5954 return {};
5955 }
5956 return DotDimensionNumbersAttr::get(
5957 parser.getContext(), lhsBatchingDimensions, rhsBatchingDimensions,
5958 lhsContractingDimensions, rhsContractingDimensions);
5959 }
5960
namespace {
// Sentinels for the non-spatial dimensions of a convolution. They are encoded
// as negative values so that a single int64_t vector can hold both these
// sentinels and (non-negative) spatial dimension indices while
// printing/parsing convolution dimension numbers.
enum NonSpatialDim : int64_t {
  IOBatch = -1,    // Input or output batch dimension
  IOFeature = -2,  // Input or output feature dimension
  KIFeature = -3,  // Kernel input feature dimension
  KOFeature = -4,  // Kernel output feature dimensions.
};

// DenseMapInfo implementation so NonSpatialDim can be used as a key in a
// SmallDenseMap; every operation forwards to the int64_t implementation.
struct DenseMapInfoNonSpatialDim {
  static inline NonSpatialDim getEmptyKey() {
    return NonSpatialDim(DenseMapInfo<int64_t>::getEmptyKey());
  }

  static inline NonSpatialDim getTombstoneKey() {
    return NonSpatialDim(DenseMapInfo<int64_t>::getTombstoneKey());
  }

  static unsigned getHashValue(const NonSpatialDim& key) {
    return DenseMapInfo<int64_t>::getHashValue(key);
  }

  static bool isEqual(const NonSpatialDim& lhs, const NonSpatialDim& rhs) {
    return lhs == rhs;
  }
};

// Returns the single-character mnemonic used by the textual convolution
// dimension-numbers format: 'b'atch, 'f'eature, kernel 'i'nput feature,
// kernel 'o'utput feature.
char nonSpatialDimToString(NonSpatialDim dim) {
  switch (dim) {
    case IOBatch:
      return 'b';
    case IOFeature:
      return 'f';
    case KIFeature:
      return 'i';
    case KOFeature:
      return 'o';
  }
  llvm_unreachable("Unknown NonSpatialDim");
}
}  // namespace
6001
6002 // Custom printer and parser for convolution attribute.
// Prints ConvDimensionNumbersAttr in the compressed
// `[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]` form: one bracketed list per
// operand (input x kernel -> output), where letters name non-spatial
// dimensions and digits are spatial dimension indices.
void printConvolutionDimensions(AsmPrinter& p, ConvDimensionNumbersAttr dnums) {
  // TODO(b/202040055): we should check the attribute invariant and print the
  // "raw" form if they are violated, otherwise we'll crash here.
  constexpr int64_t kUnknownDim = std::numeric_limits<int64_t>::min();
  // Prints one bracketed list given the spatial dims and the
  // (tensor-position, kind) pairs of the non-spatial dims.
  auto printDim =
      [&](ArrayRef<int64_t> spatialDims,
          ArrayRef<std::pair<int64_t, NonSpatialDim>> nonSpatialDims) {
        // The printed list is as long as the largest dimension index + 1.
        int64_t numDims = 0;
        if (!spatialDims.empty()) {
          numDims =
              *std::max_element(spatialDims.begin(), spatialDims.end()) + 1;
        }
        for (const auto& dim : nonSpatialDims) {
          numDims = std::max(numDims, dim.first + 1);
        }

        llvm::SmallVector<int64_t> dims(numDims, kUnknownDim);
        // Fill each element of dims with a (< 0) NonSpatialDim enum or a (>=0)
        // spatial dimension index.
        for (const std::pair<int64_t, NonSpatialDim>& nonSpatialDim :
             nonSpatialDims) {
          dims[nonSpatialDim.first] = nonSpatialDim.second;
        }
        for (const auto& spatialDim : llvm::enumerate(spatialDims)) {
          dims[spatialDim.value()] = static_cast<int64_t>(spatialDim.index());
        }

        // Each dimension numbers will be printed as a comma separated list
        // surrounded by square brackets, e.g., [b, 0, 1, 2, f].
        // Positions never mentioned print as '?'.
        p << '[';
        llvm::interleaveComma(dims, p, [&](int64_t dim) {
          if (dim == kUnknownDim) {
            p << "?";
          } else if (dim >= 0) {
            p << dim;
          } else {
            p << nonSpatialDimToString(static_cast<NonSpatialDim>(dim));
          }
        });
        p << ']';
      };

  printDim(dnums.getInputSpatialDimensions(),
           {{dnums.getInputBatchDimension(), IOBatch},
            {dnums.getInputFeatureDimension(), IOFeature}});
  p << "x";
  printDim(dnums.getKernelSpatialDimensions(),
           {{dnums.getKernelInputFeatureDimension(), KIFeature},
            {dnums.getKernelOutputFeatureDimension(), KOFeature}});
  p << "->";
  printDim(dnums.getOutputSpatialDimensions(),
           {{dnums.getOutputBatchDimension(), IOBatch},
            {dnums.getOutputFeatureDimension(), IOFeature}});
}
6057
// Overload used by declarative assembly formats: the Operation* parameter is
// unused and exists only to match the custom-directive printer signature.
void printConvolutionDimensions(AsmPrinter& p, Operation*,
                                ConvDimensionNumbersAttr dnums) {
  printConvolutionDimensions(p, dnums);
}
6062
6063 // Custom printer and parser for ConvDimensionNumbersAttr.
print(AsmPrinter & printer) const6064 void ConvDimensionNumbersAttr::print(AsmPrinter& printer) const {
6065 printer << "<";
6066 printConvolutionDimensions(printer, *this);
6067 printer << ">";
6068 }
6069
6070 // If the attribute is written with `#stablehlo.conv raw<`, we parse it as
6071 // a struct instead of the compressed format. This enables writing tests
6072 // covering impossible/invalid internal representation for the attribute.
parseConvolutionDimensionsRaw(AsmParser & parser,ConvDimensionNumbersAttr & dnums)6073 static ParseResult parseConvolutionDimensionsRaw(
6074 AsmParser& parser, ConvDimensionNumbersAttr& dnums) {
6075 int64_t inputBatchDimension = 0;
6076 int64_t inputFeatureDimension = 0;
6077 SmallVector<int64_t> inputSpatialDimensions;
6078 int64_t kernelInputFeatureDimension = 0;
6079 int64_t kernelOutputFeatureDimension = 0;
6080 SmallVector<int64_t> kernelSpatialDimensions;
6081 int64_t outBatchDimension = 0;
6082 int64_t outputFeatureDimension = 0;
6083 SmallVector<int64_t> outputSpatialDimensions;
6084 if (failed(parseStruct(
6085 parser,
6086 {"input_batch_dimension", "input_feature_dimension",
6087 "input_spatial_dimensions", "kernel_input_feature_dimension",
6088 "kernel_output_feature_dimension", "kernel_spatial_dimensions",
6089 "output_batch_dimension", "output_feature_dimension",
6090 "output_spatial_dimensions"},
6091 {
6092 [&]() { return parser.parseInteger(inputBatchDimension); },
6093 [&]() { return parser.parseInteger(inputFeatureDimension); },
6094 [&]() { return parseDims(parser, inputSpatialDimensions); },
6095 [&]() {
6096 return parser.parseInteger(kernelInputFeatureDimension);
6097 },
6098 [&]() {
6099 return parser.parseInteger(kernelOutputFeatureDimension);
6100 },
6101 [&]() { return parseDims(parser, kernelSpatialDimensions); },
6102 [&]() { return parser.parseInteger(outBatchDimension); },
6103 [&]() { return parser.parseInteger(outputFeatureDimension); },
6104 [&]() { return parseDims(parser, outputSpatialDimensions); },
6105 }))) {
6106 parser.emitError(parser.getCurrentLocation())
6107 << "failed parsing dot dimension numbers attribute";
6108 return failure();
6109 }
6110 dnums = ConvDimensionNumbersAttr::get(
6111 parser.getBuilder().getContext(), inputBatchDimension,
6112 inputFeatureDimension, inputSpatialDimensions,
6113 kernelInputFeatureDimension, kernelOutputFeatureDimension,
6114 kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
6115 outputSpatialDimensions);
6116 return success();
6117 }
6118
// Parses the compressed convolution dimension-numbers form, e.g.
// `[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]`, into `dnums`. Each bracketed
// list mixes single-letter non-spatial dimensions ('b', 'f', 'i', 'o'),
// non-negative spatial dimension indices, and '?' for unknown positions.
// On failure, emits a diagnostic and returns failure().
ParseResult parseConvolutionDimensions(AsmParser& parser,
                                       ConvDimensionNumbersAttr& dnums) {
  // Parsing a single set of dim numbers gives the spatial dimensions as a
  // single ArrayRef<int64_t> and a list of non-spatial dimensions as
  // IntegerAttrs (indexed by the NonSpatialDim enum).
  using parse_dim_result_t =
      std::pair<llvm::SmallVector<int64_t>,
                llvm::SmallDenseMap<NonSpatialDim, int64_t, 4,
                                    DenseMapInfoNonSpatialDim>>;

  // Note that the allowed_non_spatial_dims is a set (as opposed to unordered
  // set) because its used to print a list of allowed non spatial dims in the
  // error messages, so making it a set keeps the error messages deterministic.
  auto parseDims =
      [&](std::set<NonSpatialDim, std::greater<>> allowedNonSpatialDims,
          parse_dim_result_t& parsedDims) -> ParseResult {
    auto& spatialDims = std::get<0>(parsedDims);
    auto& nonSpatialDims = std::get<1>(parsedDims);
    // `parsedDims` is reused across the three calls below; reset it.
    spatialDims.clear();
    nonSpatialDims.clear();

    // Parse the starting [
    if (parser.parseLSquare()) {
      return failure();
    }

    // Maps parsed spatial dimension number -> its position in the list.
    llvm::SmallDenseMap<int64_t, int64_t> spatialDimsMap;
    constexpr int64_t kInvalidDimension = -1;
    // Keep track of the maximum spatial dimension parsed as we expect to see
    // all the dimensions from 0 to maximum dimension parsed.
    int64_t maxParsedSpatialDim = kInvalidDimension;

    int64_t index = 0;
    do {
      int64_t spatialDim;
      auto dimLocation = parser.getCurrentLocation();
      OptionalParseResult parseResult = parser.parseOptionalInteger(spatialDim);
      if (parseResult.hasValue()) {
        if (parseResult.getValue().failed()) {
          return failure();
        }
        // We were successful in parsing an integer. Check if it is a valid
        // dimension (non-negative and no duplicate) and add its index to the
        // spatial dims map.
        if (spatialDim < 0)
          return parser.emitError(dimLocation)
                 << "Unexpected dimension " << spatialDim;
        if (!spatialDimsMap
                 .insert(std::pair<int64_t, int64_t>(spatialDim, index))
                 .second)
          return parser.emitError(dimLocation)
                 << "Duplicate entries for spatial dimension " << spatialDim;
        maxParsedSpatialDim = std::max(spatialDim, maxParsedSpatialDim);
      } else if (!parser.parseOptionalQuestion()) {
        // Do nothing other than increment `index` at the bottom of the loop;
        // '?' means "unknown dimension", and it's not represented in the
        // return value of this function.
      } else {
        // We did not parse an integer or question mark. We expect a keyword
        // token.
        StringRef keyword;
        if (parser.parseKeyword(&keyword)) {
          return failure();
        }
        // Non-spatial dimensions are always a single character.
        if (keyword.size() != 1 || allowedNonSpatialDims.empty()) {
          return parser.emitError(dimLocation, "Unexpected keyword ")
                 << keyword;
        }
        // Check if the keyword matches one of the allowed non-spatial dims.
        // If so, add it to the non_spatial dims and remove it from the
        // allowed set so that it won't be allowed again.
        bool isAllowed = false;
        for (NonSpatialDim allowed : allowedNonSpatialDims) {
          if (keyword[0] == nonSpatialDimToString(allowed)) {
            nonSpatialDims.insert({allowed, index});
            allowedNonSpatialDims.erase(allowed);
            isAllowed = true;
            break;
          }
        }

        if (!isAllowed) {
          mlir::InFlightDiagnostic diag =
              parser.emitError(dimLocation, "Unexpected dimension ");
          diag << keyword << ", expecting ";
          llvm::interleaveComma(
              allowedNonSpatialDims, diag,
              [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
          return diag;
        }
      }
      index++;
    } while (parser.parseOptionalComma().succeeded());

    // Make sure all expected non-spatial dimensions are parsed.
    if (!allowedNonSpatialDims.empty()) {
      mlir::InFlightDiagnostic diag =
          parser.emitError(parser.getCurrentLocation(), "Expected dimensions ");
      llvm::interleaveComma(
          allowedNonSpatialDims, diag,
          [&](NonSpatialDim dim) { diag << nonSpatialDimToString(dim); });
      diag << " not specified";
      return diag;
    }

    // parse ending ]
    if (parser.parseRSquare()) {
      return failure();
    }

    // Number of expected spatial dimensions is one more than the maximum parsed
    // spatial dimension. For example, if we parse [0, 3, 2, b, i, 1], then the
    // maximum parsed spatial dimension is 3 and the number of expected spatial
    // dimensions is 4.
    int64_t numSpatialDimensions = maxParsedSpatialDim + 1;
    spatialDims.resize(numSpatialDimensions);
    // Store spatial dimensions in a vector which maps spatial dim (vector
    // index) -> index in the tensor dimensions. For example, for parsed
    // dimension numbers [0, 3, 2, b, i, 1] the spatial dimension vector would
    // be [0, 5, 2, 1].
    //
    // Get all the unspecified spatial dimensions to throw a more descriptive
    // error later.
    llvm::SmallVector<int64_t> unspecifiedSpatialDims;
    constexpr int kPrintUnspecifiedDimsMax = 10;
    for (int dim = 0; dim < numSpatialDimensions; ++dim) {
      auto it = spatialDimsMap.find(dim);
      if (it == spatialDimsMap.end()) {
        // Have an upper bound on the number of unspecified dimensions to print
        // in the error message.
        if (unspecifiedSpatialDims.size() < kPrintUnspecifiedDimsMax)
          unspecifiedSpatialDims.push_back(dim);
        continue;
      }
      spatialDims[dim] = it->second;
    }

    // Verify that we got all spatial dimensions between 0 and maximum parsed
    // spatial dimension.
    if (!unspecifiedSpatialDims.empty()) {
      mlir::InFlightDiagnostic diag = parser.emitError(
          parser.getCurrentLocation(), "Expected spatial dimensions ");
      llvm::interleaveComma(unspecifiedSpatialDims, diag);
      diag << " not specified";
      return diag;
    }

    return success();
  };

  // Input dims, then `x`, kernel dims, `->`, output dims.
  parse_dim_result_t parsedDims;
  if (parseDims({IOBatch, IOFeature}, parsedDims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> inputSpatialDimensions = parsedDims.first;
  int64_t inputBatchDimension = parsedDims.second[IOBatch];
  int64_t inputFeatureDimension = parsedDims.second[IOFeature];
  if (parser.parseKeyword("x")) return failure();
  if (parseDims({KIFeature, KOFeature}, parsedDims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> kernelSpatialDimensions = parsedDims.first;
  int64_t kernelInputFeatureDimension = parsedDims.second[KIFeature];
  int64_t kernelOutputFeatureDimension = parsedDims.second[KOFeature];
  if (parser.parseArrow()) {
    return failure();
  }
  if (parseDims({IOBatch, IOFeature}, parsedDims)) {
    return failure();
  }
  llvm::SmallVector<int64_t> outputSpatialDimensions = parsedDims.first;
  const int64_t outBatchDimension = parsedDims.second[IOBatch];
  const int64_t outputFeatureDimension = parsedDims.second[IOFeature];
  dnums = ConvDimensionNumbersAttr::get(
      parser.getBuilder().getContext(), inputBatchDimension,
      inputFeatureDimension, inputSpatialDimensions,
      kernelInputFeatureDimension, kernelOutputFeatureDimension,
      kernelSpatialDimensions, outBatchDimension, outputFeatureDimension,
      outputSpatialDimensions);

  return success();
}
6301
parse(AsmParser & parser,Type type)6302 Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
6303 if (failed(parser.parseLess())) return {};
6304 ConvDimensionNumbersAttr dnums;
6305 if (succeeded(parser.parseOptionalKeyword("raw"))) {
6306 if (failed(parseConvolutionDimensionsRaw(parser, dnums))) return {};
6307 return dnums;
6308 }
6309 if (failed(parseConvolutionDimensions(parser, dnums))) return {};
6310 if (failed(parser.parseGreater())) return {};
6311 return dnums;
6312 }
6313
// Custom printer and parser for ArgResultAliasAttr.
// Field keywords used in the textual form of ArgResultAliasAttr.
constexpr char kMustAlias[] = "must_alias";
constexpr char kResult[] = "result_index";
constexpr char kArgTupleIndices[] = "tuple_indices";
6318
print(AsmPrinter & printer) const6319 void ArgResultAliasAttr::print(AsmPrinter& printer) const {
6320 printer << "<";
6321
6322 // The attribute can have empty tuple indices. Only print argument tuple
6323 // indices if they are non-empty.
6324 if (!getArgTupleIndices().empty())
6325 printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";
6326
6327 // Print the result index followed by any result tuple indices if present.
6328 printer << kResult << " = [";
6329 printer << getResultIndex();
6330 if (!getResultTupleIndices().empty()) {
6331 printer << ", " << getResultTupleIndices();
6332 }
6333 printer << "]";
6334
6335 // Print the "must_alias" keyword if this is a must alias, otherwise skip.
6336 if (getIsMustAlias()) printer << ", " << kMustAlias;
6337
6338 printer << ">";
6339 }
6340
parse(AsmParser & parser,Type type)6341 Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
6342 if (failed(parser.parseLess())) return {};
6343 llvm::SmallVector<int64_t> argTupleIndices;
6344 // The first element of result indices holds the aliased result index and the
6345 // remaining elements are the result tuple indices.
6346 llvm::SmallVector<int64_t> resultIndices;
6347 bool isMustAlias = false;
6348
6349 // This conveys to parseStruct that keyword "must_alias" (3rd field) is not
6350 // followed by a "=", but other fields are.
6351 llvm::SmallVector<bool, 3> parseEqual = {true, true, false};
6352
6353 if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
6354 {[&]() { return parseDims(parser, argTupleIndices); },
6355 [&]() {
6356 // Since the first element is the index of result,
6357 // at least one element is expected.
6358 return parseDimsWithMinimumElements(
6359 parser, resultIndices, /*minElements=*/1);
6360 },
6361 [&]() {
6362 // always succeeds if the keyword "must_alias" was
6363 // parsed
6364 isMustAlias = true;
6365 return success();
6366 }},
6367 parseEqual))) {
6368 parser.emitError(parser.getCurrentLocation())
6369 << "failed parsing argument-result alias attribute";
6370 return {};
6371 }
6372
6373 int64_t resultIndex = resultIndices[0];
6374 auto resultTupleIndices =
6375 ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()};
6376
6377 return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices,
6378 resultIndex, resultTupleIndices, isMustAlias);
6379 }
6380
6381 // Returns the element type pointed to by `indices` in type `t`. If the indices
6382 // are invalid, returns nullptr.
getTypeFromTupleIndices(Type type,ArrayRef<int64_t> indices)6383 static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
6384 Type current = type;
6385 for (auto index : indices) {
6386 TupleType tupleType = current.dyn_cast<TupleType>();
6387 if (!tupleType || index >= static_cast<int64_t>(tupleType.size()))
6388 return {};
6389 current = tupleType.getType(index);
6390 }
6391 return current;
6392 }
6393
// Verifies an ArgResultAliasAttr named `attrName` attached to argument
// `argIndex` of `op`:
//  * `op` must implement FunctionOpInterface,
//  * all argument/result (tuple) indices must be non-negative,
//  * the result index must be within the function's result count,
//  * the tuple indices must resolve to actual nested types, and
//  * the aliased argument and result types must be shape-compatible with
//    equal element types.
static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
                                              ArgResultAliasAttr aliasAttr,
                                              unsigned argIndex,
                                              Operation* op) {
  // The attribute can only be applied to function-like operations.
  if (!isa<mlir::FunctionOpInterface>(op))
    return op->emitOpError() << "attribute " << attrName
                             << " can only be used on function-like operations";

  // Verify there are no negative indices.
  auto tupleIndices = llvm::concat<const int64_t>(
      aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices());
  if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) ||
      aliasAttr.getResultIndex() < 0)
    return op->emitOpError()
           << "attribute " << attrName
           << " expects all argument and result indices to be >= 0";

  // Verify that the result index is not out of range. Since the attribute is a
  // function argument attribute, the argument index is always correct when this
  // verifier is called.
  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
  ArrayRef<Type> argTypes = funcOp.getArgumentTypes();
  ArrayRef<Type> resultTypes = funcOp.getResultTypes();
  if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size()))
    return op->emitOpError()
           << "attribute " << attrName
           << " result index is out of range, must be <" << resultTypes.size();

  // Verify that argument and result types pointed to by the indices are valid
  // and compatible.
  Type argType = getTypeFromTupleIndices(argTypes[argIndex],
                                         aliasAttr.getArgTupleIndices());
  if (!argType)
    return op->emitOpError()
           << "attribute " << attrName << " argument tuple indices are invalid";
  Type resultType =
      getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()],
                              aliasAttr.getResultTupleIndices());
  if (!resultType)
    return op->emitOpError()
           << "attribute " << attrName << " result tuple indices are invalid";

  // Aliased values must occupy the same memory layout: compatible shapes and
  // identical element types.
  if (failed(mlir::verifyCompatibleShape(argType, resultType)) ||
      getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType))
    return op->emitOpError() << "attribute " << attrName
                             << " aliases do not have compatible types, "
                             << argType << " vs. " << resultType;
  return success();
}
6444
6445 namespace {
6446 // Custom formatting for convolution window attributes.
printWindowAttribute(OpAsmPrinter & p,DenseElementsAttr attribute)6447 void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) {
6448 if (attribute.getElementType().isInteger(/*width=*/1)) {
6449 // boolean attribute.
6450 llvm::interleaveComma(attribute.getValues<bool>(), p,
6451 [&](bool b) { p << (b ? 1 : 0); });
6452 return;
6453 }
6454 if (attribute.getType().getRank() == 2) {
6455 // Padding is Nx2 attribute.
6456 auto it = attribute.value_begin<int64_t>();
6457 std::vector<std::pair<int64_t, int64_t>> values(attribute.getNumElements() /
6458 2);
6459 for (auto& item : values) {
6460 int64_t first = *it;
6461 ++it;
6462 int64_t second = *it;
6463 ++it;
6464 item = {first, second};
6465 }
6466 llvm::interleaveComma(
6467 values, p, [&](const std::pair<int64_t, int64_t> pair) {
6468 p << '[' << pair.first << ", " << pair.second << ']';
6469 });
6470 } else {
6471 llvm::interleaveComma(attribute.getValues<int64_t>(), p);
6472 }
6473 }
6474 } // namespace
6475
printWindowAttributes(OpAsmPrinter & p,Operation *,llvm::Optional<DenseIntElementsAttr> windowStrides,llvm::Optional<DenseIntElementsAttr> padding,llvm::Optional<DenseIntElementsAttr> lhsDilation,llvm::Optional<DenseIntElementsAttr> rhsDilation,llvm::Optional<DenseElementsAttr> windowReversal)6476 void printWindowAttributes(OpAsmPrinter& p, Operation* /*op*/,
6477 llvm::Optional<DenseIntElementsAttr> windowStrides,
6478 llvm::Optional<DenseIntElementsAttr> padding,
6479 llvm::Optional<DenseIntElementsAttr> lhsDilation,
6480 llvm::Optional<DenseIntElementsAttr> rhsDilation,
6481 llvm::Optional<DenseElementsAttr> windowReversal) {
6482 using pair_t = std::pair<DenseElementsAttr, StringRef>;
6483 std::array<pair_t, 5> printedAttributes = {{
6484 {windowStrides ? *windowStrides : nullptr, "stride"},
6485 {padding ? *padding : nullptr, "pad"},
6486 {lhsDilation ? *lhsDilation : nullptr, "lhs_dilate"},
6487 {rhsDilation ? *rhsDilation : nullptr, "rhs_dilate"},
6488 {windowReversal ? *windowReversal : nullptr, "reverse"},
6489 }};
6490
6491 // Do not print attributes that do no exist.
6492 auto nonNullAttributes = llvm::make_filter_range(
6493 printedAttributes,
6494 [](const pair_t& a) { return static_cast<bool>(a.first); });
6495
6496 llvm::interleaveComma(nonNullAttributes, p, [&](const pair_t& a) {
6497 p << a.second << " = [";
6498 printWindowAttribute(p, a.first);
6499 p << "]";
6500 });
6501 }
6502
// Parses the optional convolution window attributes written as a
// comma-separated list of `name = [...]` entries (stride, pad, lhs_dilate,
// rhs_dilate, reverse), in any order, each at most once. Any attribute that
// is not mentioned is left unset. On failure, emits a diagnostic and returns
// failure().
ParseResult parseWindowAttributes(OpAsmParser& parser,
                                  DenseIntElementsAttr& windowStrides,
                                  DenseIntElementsAttr& padding,
                                  DenseIntElementsAttr& lhsDilation,
                                  DenseIntElementsAttr& rhsDilation,
                                  DenseElementsAttr& windowReversal) {
  StringRef attributeName;

  llvm::StringSet<> allowedAttributeNames{
      {"stride", "pad", "lhs_dilate", "rhs_dilate", "reverse"}};

  while (parser.parseOptionalKeyword(&attributeName).succeeded()) {
    // Verify that the attribute name is valid and erase it.
    // Erasing also rejects a second occurrence of the same keyword.
    if (!allowedAttributeNames.erase(attributeName)) {
      return parser.emitError(parser.getCurrentLocation(),
                              "Unexpected keyword ")
             << attributeName;
    }

    if (parser.parseEqual()) {
      return failure();
    }

    // parse the attribute value. We need to support either 1D and Nx2 array of
    // integers to parse.
    llvm::SmallVector<int64_t> values;
    auto int64Parser = [&]() {
      return parser.parseInteger(values.emplace_back(0));
    };

    if (attributeName == "pad") {
      // Parse 2D array of integers.
      // Helper to parse an array of two integer elements such as [e0, e1].
      auto innerParser = [&]() -> ParseResult {
        size_t numOldElements = values.size();
        if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Square,
                                           int64Parser))
          return failure();
        size_t numParsedElements = values.size() - numOldElements;
        constexpr size_t kExpectedElements = 2;
        if (numParsedElements != kExpectedElements)
          return parser.emitError(parser.getCurrentLocation())
                 << "Expected array with " << kExpectedElements
                 << " elements, got " << numParsedElements
                 << " elements instead";
        return success();
      };

      if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                         innerParser)) {
        return failure();
      }
      const int64_t size = static_cast<int64_t>(values.size());
      // values should be filled with the Nx2 padding values.
      assert(size % 2 == 0);
      auto ty = RankedTensorType::get({size / 2, 2},
                                      parser.getBuilder().getIntegerType(64));
      padding = DenseIntElementsAttr::get(ty, values);
    } else {
      // Parse 1D array of integers.
      if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                         int64Parser)) {
        return failure();
      }
      const int64_t size = static_cast<int64_t>(values.size());
      if (attributeName == "reverse") {
        // Reversal is stored as an i1 tensor: nonzero parses as true.
        auto ty = RankedTensorType::get({size},
                                        parser.getBuilder().getIntegerType(1));
        auto boolVector = llvm::to_vector<4>(
            llvm::map_range(values, [](int64_t v) { return v != 0; }));
        windowReversal = DenseElementsAttr::get(ty, boolVector);
      } else {
        auto attr = parser.getBuilder().getI64TensorAttr(values);

        if (attributeName == "stride") {
          windowStrides = attr;
        } else if (attributeName == "lhs_dilate") {
          lhsDilation = attr;
        } else if (attributeName == "rhs_dilate") {
          rhsDilation = attr;
        } else {
          // Unreachable: every name surviving the allow-list check above is
          // handled by one of the branches.
          llvm_unreachable("Unexpected attribute name");
        }
      }
    }
    // continue parsing if there is a comma at the end.
    if (parser.parseOptionalComma().failed()) break;
  }
  return success();
}
6593
6594 //===----------------------------------------------------------------------===//
6595 // Builder utilities
6596 //===----------------------------------------------------------------------===//
6597
6598 // Builds the region `body` for stablehlo.sort's comparator: for each type in
6599 // `element_types`, create two block arguments, one for lhs and one for rhs, and
6600 // generates stablehlo.compare op to compare them with the given `direction`.
6601 //
6602 // Note that this right now only does comparision on the first pair of block
6603 // arguments.
buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,ComparisonDirection direction,llvm::Optional<StringRef> compareType,Region * body,OpBuilder * builder)6604 static void buildSortComparisonBody(llvm::ArrayRef<Type> elementTypes,
6605 ComparisonDirection direction,
6606 llvm::Optional<StringRef> compareType,
6607 Region* body, OpBuilder* builder) {
6608 OpBuilder::InsertionGuard insertionPointGurad(*builder);
6609
6610 Location loc = body->getLoc();
6611 Block* block = builder->createBlock(body);
6612 // Add two arguments for each element type.
6613 for (Type elementType : elementTypes) {
6614 TensorType tensorType = RankedTensorType::get({}, elementType);
6615 block->addArguments({tensorType, tensorType},
6616 SmallVector<Location, 2>(2, loc));
6617 }
6618
6619 ComparisonType typeAttr;
6620 if (compareType)
6621 typeAttr = symbolizeComparisonType(*compareType).value();
6622 else
6623 typeAttr = ComparisonType::NOTYPE;
6624 Value compare = builder->create<CompareOp>(
6625 loc, block->getArgument(0), block->getArgument(1), direction, typeAttr);
6626
6627 builder->create<ReturnOp>(loc, compare);
6628 }
6629
createSortOp(PatternRewriter * rewriter,const Location & loc,const llvm::ArrayRef<Value> & operands,const llvm::ArrayRef<Type> & elementTypes,int64_t dimension,bool isStable,ComparisonDirection direction)6630 SortOp createSortOp(PatternRewriter* rewriter, const Location& loc,
6631 const llvm::ArrayRef<Value>& operands,
6632 const llvm::ArrayRef<Type>& elementTypes, int64_t dimension,
6633 bool isStable, ComparisonDirection direction) {
6634 assert(!operands.empty() && "No operands to sort");
6635 // Create the sort op.
6636 auto sortOp = rewriter->create<SortOp>(loc, operands, dimension, isStable);
6637
6638 // Use TOTALORDER comparison type instead of the default comparison if the
6639 // element type is of type float.
6640 llvm::Optional<StringRef> compareType = llvm::None;
6641 for (auto const& elementType : elementTypes)
6642 if (elementType.isa<FloatType>()) {
6643 compareType.emplace("TOTALORDER");
6644 break;
6645 }
6646 buildSortComparisonBody(elementTypes, direction, compareType,
6647 &sortOp.comparator(), rewriter);
6648 return sortOp;
6649 }
6650
6651 //===----------------------------------------------------------------------===//
6652 // StableHLO Dialect Hooks
6653 //===----------------------------------------------------------------------===//
6654
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)6655 Operation* StablehloDialect::materializeConstant(OpBuilder& builder,
6656 Attribute value, Type type,
6657 Location loc) {
6658 auto elementsAttr = value.dyn_cast<ElementsAttr>();
6659 // HLO dialect constants only support ElementsAttr unlike standard dialect
6660 // constant which supports all attributes.
6661 if (!elementsAttr) return nullptr;
6662 // HLO dialect constants require the type of value and result to match.
6663 if (type != elementsAttr.getType()) return nullptr;
6664
6665 return builder.create<ConstantOp>(loc, type, elementsAttr);
6666 }
6667
verifyRegionArgAttribute(Operation * op,unsigned,unsigned argIndex,NamedAttribute attr)6668 LogicalResult StablehloDialect::verifyRegionArgAttribute(
6669 Operation* op, unsigned /*regionIndex*/, unsigned argIndex,
6670 NamedAttribute attr) {
6671 if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
6672 if (failed(
6673 verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op)))
6674 return failure();
6675 }
6676 return success();
6677 }
6678
verifyOperationAttribute(Operation * op,NamedAttribute attr)6679 LogicalResult StablehloDialect::verifyOperationAttribute(Operation* op,
6680 NamedAttribute attr) {
6681 if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>()) {
6682 if (!isa<mlir::FunctionOpInterface>(op))
6683 return op->emitOpError()
6684 << "attribute " << attr.getName()
6685 << " can only be used on function-like operations";
6686 }
6687 return success();
6688 }
6689
6690 } // namespace stablehlo
6691 } // namespace mlir
6692