// (Code-browser navigation chrome removed during extraction cleanup.)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16
#ifndef STABLEHLO_DIALECT_STABLEHLO_OPS
#define STABLEHLO_DIALECT_STABLEHLO_OPS

include "dialect/Base.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
26
def StableHLO_Dialect : Dialect {
  let name = "stablehlo";
  let cppNamespace = "::mlir::stablehlo";

  let description = [{
    StableHLO is an operation set that expresses ML computations. It has been
    originally bootstrapped from the MHLO dialect and enhances it with additional
    functionality, including serialization and versioning, to be used as
    a portability layer between ML frameworks and ML compilers.
  }];

  // Generated accessors use the raw attribute/operand names (no get/set prefix).
  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
  // Attribute and type printers/parsers are hand-written, not ODS-generated.
  let useDefaultAttributePrinterParser = 0;
  let useDefaultTypePrinterParser = 0;
}
42
// Base class for all StableHLO operations.
class StableHLO_Op<string mnemonic, list<Trait> traits> :
    Op<StableHLO_Dialect, mnemonic, traits> {
}
46
include "dialect/StablehloEnums.td"
include "dialect/StablehloAttrs.td"
49
// Base class for StableHLO ops that implement shape reification via
// InferShapedTypeOpInterface::reifyReturnTypeShapes.
class StableHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> :
    StableHLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
    ["reifyReturnTypeShapes"]>]> {
}
54
//===----------------------------------------------------------------------===//
// StableHLO nullary op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_ConstantOp : StableHLO_Op<"constant",
    [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Constant operator";
  let description = [{
    Represents a constant value.
  }];
  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    HLO_StaticShapeTensor:$output
  );

  let builders = [
    OpBuilder<(ins "Attribute":$value)>];

  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;

  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
  }];
}
83
def StableHLO_IotaOp : StableHLO_Op<"iota", [NoSideEffect]> {
  let summary = "Iota operator";
  let description = [{
    Creates a rank 1 array of values starting at zero and incrementing by one.
  }];
  // Dimension along which the values are generated.
  let arguments = (ins I64Attr:$iota_dimension);

  let results = (outs HLO_IntFpOrComplexTensor:$output);

  let hasVerifier = 1;
}
95
def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [NoSideEffect]> {
  let summary = "Create linear increasing values from 0 to length - 1.";
  let description = [{
    Produces an HLO Tensor of the specified shape, with an incremental set of
    values along the specified dimension starting at 0.

    Requires:
    - The output length of the tensor result.
  }];

  // The result shape is supplied at runtime via $output_shape.
  let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension);
  let results = (outs HLO_Tensor:$result);
}
109
def StableHLO_CreateTokenOp : StableHLO_Op<"create_token", [NoSideEffect]> {
  let summary = "Create Token operator";

  let description = [{
    Produces a HLO token. Tokens are used for ordering side-effecting operations.
    This is exported to HLO as an AfterAll operation with no operands to
    generate a token.

    Example:

    ```mlir
    %1 = stablehlo.create_token : !stablehlo.token
    ```
  }];

  let results = (outs HLO_Token:$output);

  let assemblyFormat = "attr-dict `:` type(results)";
}
129
//===----------------------------------------------------------------------===//
// StableHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

// Base class for unary elementwise ops. `ResultType` defaults to `OperandType`
// but may be overridden when the result element type differs (e.g. abs/real/imag
// on complex operands produce real results).
class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
    Type OperandType, Type ResultType = OperandType> : StableHLO_Op<mnemonic, traits # [Elementwise,
    InferShapedTypeOpInterface, SameOperandsAndResultShape]> {
  let arguments = (ins OperandType:$operand);
  let results = (outs ResultType:$result);
  let extraClassDeclaration = [{
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    // Relax the strict default implementation with one that allows
    // for StableHLO-specific differences.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];
  let extraClassDefinition = [{
    ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) {
      return ::mlir::stablehlo::parseUnaryOp(parser, result);
    }
    void $cppClass::print(OpAsmPrinter &p) {
      ::mlir::stablehlo::printUnaryOp(getOperation(), p);
    }
  }];
  let hasCustomAssemblyFormat = 1;
}
168
// Abs supports complex to real, so element type is not guaranteed to match.
def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",
    [NoSideEffect,
     DeclareOpInterfaceMethods<InferTypeOpInterface>],
     TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> {
  let summary = "Absolute value operator";
  let description = [{
    Returns `abs(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
182
def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Cubic root operator";
  let description = [{
    Returns element-wise cubic root of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_CeilOp: StableHLO_UnaryElementwiseOp<"ceil",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Ceil operator";
  let description = [{
    Returns `Ceil(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_ConvertOp : StableHLO_UnaryElementwiseOp<"convert",
    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor> {
  let summary = "Convert operator";
  let description = [{
    Performs element-wise conversion of values from one type to another, e.g.
    float to int.

    See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
  }];
  // Convenience builder: result type is derived from the operand shape plus
  // the requested element type.
  let builders = [
    OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)>];
}
215
def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
  let summary = "Count-leading-zeros (Clz) operator";
  let description = [{
    Returns the number of leading zeros in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
226
def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Cos operator";
  let description = [{
    Returns `Cos(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
237
def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Exponential operator";
  let description = [{
    Returns `e^(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Exponential minus one operator";
  let description = [{
    Returns `e^(operand) - 1` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_FloorOp: StableHLO_UnaryElementwiseOp<"floor",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Floor operator";
  let description = [{
    Returns `Floor(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_ImagOp: StableHLO_UnaryElementwiseOp<"imag",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_FpOrComplexTensor> {
  let summary = "Imag operator";
  let description = [{
    Returns `Imag(operand)` element-wise.
  }];
  // The result is always real-valued, even for complex operands.
  let results = (outs HLO_FpTensor);
}
277
def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_Tensor> {
  let summary = "IsFinite operator";
  let description = [{
    Tests whether each element of operand is finite, i.e., is not positive or
    negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
    the same shape as the input, where each element is nonzero (i.e. true) if
    and only if the corresponding input element is finite.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
  let arguments = (ins HLO_FpTensor:$x);
  let results = (outs HLO_PredTensor:$y);
}
293
def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Logarithm operator";
  let description = [{
    Returns `log(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Log1p operator";
  let description = [{
    Returns `log(operand+1)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Logistic operator";
  let description = [{
    Returns `logistic(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_PredOrIntTensor> {
  let summary = "Not operator";
  let description = [{
    Returns bitwise-NOT of `operand` element-wise. The input tensor must be
    of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-NOT is equivalent to logical-NOT.
  }];
}
334
def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
  let summary = "Negation operator";
  let description = [{
    Returns `-operand` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
345
def StableHLO_PopulationCountOp: StableHLO_UnaryElementwiseOp<"popcnt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
  let summary = "PopulationCount operator";
  let description = [{
    Returns the number of bits set in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_RealOp: StableHLO_UnaryElementwiseOp<"real",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_FpOrComplexTensor> {
  let summary = "Real operator";
  let description = [{
    Returns `Real(operand)` element-wise.
  }];
  // The result is always real-valued, even for complex operands.
  let results = (outs HLO_FpTensor);
}
365
def StableHLO_RoundOp: StableHLO_UnaryElementwiseOp<"round_nearest_afz",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Round operator, ties away from zero";
  let description = [{
    Returns `Round(operand)` element-wise, rounding to nearest integer with
    half-way cases rounding away from zero.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
377
def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_even",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Round operator, ties to even";
  let description = [{
    Returns `Round(operand)` element-wise, rounding to nearest integer with
    half-way cases rounding towards even numbers.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
389
def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Reciprocal Square-root operator";
  let description = [{
    Returns `1.0 / sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType],
    TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> {
  let summary = "Sign operator";
  let description = [{
    Returns `sign(operand)` element-wise, where

    ```
    sign(x) = -1  : x < 0
            = -0  : x = -0
            = NaN : x = NaN
            = +0  : x = +0
            = 1   : x > 0
    ```

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
419
def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Sin operator";
  let description = [{
    Returns `Sin(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
430
def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Square-root operator";
  let description = [{
    Returns `sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
441
def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType],
    HLO_FpOrComplexTensor> {
  let summary = "Tanh operator";
  let description = [{
    Returns `tanh(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
//===----------------------------------------------------------------------===//
// StableHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

// Base class for binary elementwise ops that do NOT use the shared custom
// assembly format (see StableHLO_BinaryElementwiseOp below for the variant
// that does).
class StableHLO_BinaryElementwiseOpNoAssembly<string mnemonic, list<Trait> traits> :
    StableHLO_Op<mnemonic, traits # [InferShapedTypeOpInterface,
    SameOperandsAndResultShape, Elementwise]> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs
  );

  let extraClassDeclaration = [{
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    // Relax the strict default implementation with one that allows
    // for StableHLO-specific differences.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];

  let results = (outs HLO_Tensor:$result);
}
487
// Binary elementwise op base class that adds the shared hand-written
// parser/printer (parseBinaryOp/printBinaryOp).
class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits> :
    StableHLO_BinaryElementwiseOpNoAssembly<mnemonic, traits> {
  let extraClassDefinition = [{
    ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) {
      return ::mlir::stablehlo::parseBinaryOp(parser, result);
    }
    void $cppClass::print(OpAsmPrinter &p) {
      ::mlir::stablehlo::printBinaryOp(getOperation(), p);
    }
  }];
  let hasCustomAssemblyFormat = 1;
}
500
def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Addition operator";
  let description = [{
    Returns `lhs + rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
511
def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Atan2 operator";
  let description = [{
    Returns `atan2(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
522
def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOpNoAssembly<"complex", [NoSideEffect,
    SameOperandsElementType, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Complex operator";
  let description = [{
    Performs element-wise conversion of a pair of real and imaginary values to
    a complex value.
  }];
  let arguments = (ins HLO_Fp32Or64Tensor:$lhs, HLO_Fp32Or64Tensor:$rhs);
  let results = (outs HLO_ComplexTensor:$result);

  // TODO(b/241767457): Remove parens when cleaning up BinaryOps.
  let assemblyFormat = "`(`operands`)` attr-dict `:` `(`type(operands)`)` `->` type($result)";
}
536
def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Division operator";
  let description = [{
    Returns `lhs / rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
547
def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Maximum operator";
  let description = [{
    Returns `max(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
558
def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Minimum operator";
  let description = [{
    Returns `min(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
569
def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Multiplication operator";
  let description = [{
    Returns `lhs * rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
580
def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Power operator";
  let description = [{
    Returns `lhs ^ rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Remainder operator";
  let description = [{
    Returns `lhs % rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
601
def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift Left operator";
  let description = [{
    Returns `lhs << rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
612
def StableHLO_ShiftRightArithmeticOp : StableHLO_BinaryElementwiseOp<"shift_right_arithmetic",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift right arithmetic operator";
  let description = [{
    Returns arithmetic `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
623
def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_logical",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift right logical operator";
  let description = [{
    Returns logical `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
634
def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Subtraction operator";
  let description = [{
    Returns `lhs - rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
645
//===----------------------------------------------------------------------===//
// StableHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// NOTE(review): "Biwise" in the class name is a typo for "Bitwise"; the name is
// kept as-is because And/Or/Xor below reference it. Rename in a follow-up that
// updates all users together.
class StableHLO_BinaryBiwiseOrLogicalElementwiseOp<string mnemonic> :
        StableHLO_BinaryElementwiseOp<mnemonic,
          [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs
  );
}
659
def StableHLO_AndOp: StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> {
  let summary = "And operator";
  let description = [{
    Returns bitwise-AND of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-AND is equivalent to logical-AND.
  }];
}
669
def StableHLO_OrOp: StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> {
  let summary = "Or operator";
  let description = [{
    Returns bitwise-OR of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-OR is equivalent to logical-OR.
  }];
}
679
def StableHLO_XorOp : StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> {
  let summary = "Xor operator";
  let description = [{
    Returns bitwise-XOR of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-XOR is equivalent to logical-XOR.
  }];
}
689
//===----------------------------------------------------------------------===//
// StableHLO communication op definitions.
//===----------------------------------------------------------------------===//

// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def StableHLO_InfeedOp : StableHLO_Op<"infeed", []> {

  let summary = "Infeed operator";

  let description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape, and returns a XlaOp
    of the data. Multiple Infeed operations are allowed in a computation, but
    there must be a total order among the Infeed operations.

    Attributes:
      layout:  Array attribute. Each element of the array is a minor_to_major
               array corresponding to the shape of the data read from the infeed
               interface.

    See https://www.tensorflow.org/xla/operation_semantics#infeed.
  }];

  let arguments = (ins
    HLO_Token:$token,
    DefaultValuedStrAttr<StrAttr, "">:$infeed_config,
    OptionalAttr<ArrayAttr>:$layout
  );
  let results = (outs Variadic<HLO_TensorOrToken>);
  let hasVerifier = 1;
}
722
// OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'.
// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens.
def StableHLO_OutfeedOp : StableHLO_Op<"outfeed", []> {

  let summary = "Outfeed operator";

  let description = [{
    Generates outgoing data transfers for the given data. It takes data and a
    token type operand and produces a token type value. Tokens are used for
    ordering side-effecting operations.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];

  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    HLO_Token:$token,
    DefaultValuedStrAttr<StrAttr, "">:$outfeed_config
  );
  let results = (outs HLO_Token);
}
744
def StableHLO_SendOp : StableHLO_Op<"send", []> {

  let summary = "Send operator";

  let description = [{
    Sends the given operand data to a Recv instruction in another computation
    that shares the same channel handle. Does not return any data. Similar to
    the Recv operation, Send operation represents synchronous communication,
    and is internally decomposed into 2 HLO instructions (Send and SendDone) to
    enable asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#send.
  }];

  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    HLO_Token:$token,
    StableHLO_ChannelHandle:$channel_handle,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Token);
}
768
def StableHLO_RecvOp : StableHLO_Op<"recv", []> {

  let summary = "Recv operator";

  let description = [{
    Receives data of the given shape from a Send instruction in another
    computation that shares the same channel handle. Returns a tuple containing
    value for the received data and a token. Recv operation represents
    synchronous communication. However, the instruction is internally decomposed
    into 2 HLO instructions (Recv and RecvDone) to enable asynchronous data
    transfers.

    See https://www.tensorflow.org/xla/operation_semantics#recv.
  }];

  let arguments = (ins
    HLO_Token:$token,
    StableHLO_ChannelHandle:$channel_handle,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs Variadic<HLO_TensorOrToken>);
  let hasVerifier = 1;
}
793
//===----------------------------------------------------------------------===//
// StableHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_ReplicaIdOp : StableHLO_Op<"replica_id", [NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "ReplicaId operator";
  let description = [{
    Returns the unique ID (uint32 scalar) of the replica.

    The unique ID of each replica is an unsigned integer in the interval [0, N),
    where N is the number of replicas. Since all the replicas are running the
    same program, a ReplicaId() call in the program will return a different
    value on each replica.

    See https://www.tensorflow.org/xla/operation_semantics#replicaid.

    Example:

    ```mlir
    %0 = stablehlo.replica_id : tensor<ui32>
    ```
  }];
  let results = (outs TensorOf<[UI32]>);

  let assemblyFormat = "attr-dict `:` type(results)";
}
821
//===----------------------------------------------------------------------===//
// StableHLO control flow op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_AfterAllOp : StableHLO_Op<"after_all", [NoSideEffect]> {

  let summary = "AfterAll operator";

  let description = [{
    AfterAll takes a variadic number of tokens and produces a single token.
    Tokens are primitive types which can be threaded between side-effecting
    operations to enforce ordering. AfterAll can be used as a join of tokens
    for ordering an operation after a set of operations.

    See https://www.tensorflow.org/xla/operation_semantics#afterall.

    Example:

    ```mlir
    %0 = stablehlo.after_all %arg0, %arg1 : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
    ```
  }];

  let arguments = (ins Variadic<HLO_Token>:$operands);
  let results = (outs HLO_Token);

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
850
851// Xla Client API has two separate calls for indexed and predicated conditional,
852// although both eventually map to kConditional HLO. IfOp maps to predicated
853// conditional use of kConditional HLO.
def StableHLO_IfOp: StableHLO_Op<"if", [
    RecursiveSideEffects,
    SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = "If operator";

  let description = [{
    Executes the function `true_branch` if `pred` is true or `false_branch` if
    pred is false, and returns the result.

    The type of the returned values of `true_branch` and `false_branch`
    functions must be the same and equal to the types of the values returned by
    the operation.

    Note that only one of two functions will be executed depending on the value
    of `pred`.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred
  );

  // Each branch is a single block, implicitly terminated by stablehlo.return.
  let regions = (region SizedRegion<1>:$true_branch,
                        SizedRegion<1>:$false_branch);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let hasVerifier = 1;
}
882
883// Xla Client API has two separate calls for indexed and predicated conditional,
884// although both eventually map to kConditional HLO. CaseOp maps to indexed
885// conditional use of kConditional HLO.
def StableHLO_CaseOp: StableHLO_Op<"case", [
      RecursiveSideEffects,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]> {
  let summary = "Switch-Case operator";
  let description = [{
    Returns the result of executing `branches[index]`. If `index` is < 0 or >=
    N, then `branches[N-1]` is executed as the default branch.

    The type of the returned values of each branch must be the same and equal
    to the types of the values returned by the operation.

    Note that only one of the branches will be executed depending on the value
    of index.
  }];

  let arguments = (ins
    I32Tensor:$index
  );

  // Variadic number of branches; each is a single block implicitly terminated
  // by stablehlo.return.
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let hasVerifier = 1;
}
912
913
def StableHLO_WhileOp: StableHLO_Op<"while", [
      RecursiveSideEffects,
      HLO_PairwiseSameOperandAndResultType,
      SingleBlockImplicitTerminator<"ReturnOp">,
      OpAsmOpInterface
    ]> {
  let summary = "While operator";
  let description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
  // Loop-carried values; HLO_PairwiseSameOperandAndResultType ties each
  // result type to the corresponding operand type.
  let arguments = (ins Variadic<HLO_TensorOrToken>:$operand);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let extraClassDeclaration = [{
    // Method of OpAsmOpInterface used during custom printing to name the block
    // arguments in the nested regions. We name both the condition and the body
    // regions entry arguments the same way, with a `iterArg` prefix. Since the
    // two regions are side-by-side they will have the same name, which allows
    // us to print them once and share it for the two regions, and still be able
    // to parse them back.
    void getAsmBlockArgumentNames(Region &region, OpAsmSetValueNameFn setNameFn) {
      for (BlockArgument arg : region.getArguments())
        setNameFn(arg, "iterArg");
    }
  }];
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}
948
def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", [SameOperandsAndResultElementType]> {

  // Use `let` (not `string` field redeclarations) so the inherited Op fields
  // are overridden, consistent with every other op in this file.
  let summary = "AllGather operator";

  let description = [{
    Performs concatenation across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allgather
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$all_gather_dim,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
968
def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
    [HLO_CompatibleOperandsAndResultType]> {
  let summary = "AllReduce operator";
  let description = [{
    Performs a custom reduction across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allreduce.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle,
    UnitAttr:$use_global_device_ids
  );
  // The custom reduction computation applied across replicas.
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  // use_global_device_ids is rarely used, so we add a simplified
  // builder method for convenience.
  let builders = [
    OpBuilder<(ins
      "::mlir::Type":$result_type, "::mlir::Value":$operand,
      "::mlir::DenseIntElementsAttr":$replica_groups,
      "::mlir::stablehlo::ChannelHandleAttr":$channel_handle)>];
}
994
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
    [SameOperandsAndResultElementType]> {
  let summary = "ReduceScatter operator";
  let description = [{
     Performs all_reduce followed by a scatter.

     See https://www.tensorflow.org/xla/operation_semantics#reducescatter
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$scatter_dimension,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
  );
  // The reduction computation, as in all_reduce.
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
1014
def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
    [NoSideEffect, SameOperandsElementType, SameOperandsShape,
     InferTensorType]> {
  // Summary/description were missing; every other op in this section has them.
  let summary = "AllToAll operator";
  let description = [{
    Splits the operand along `split_dimension` into `split_count` slices,
    scatters the slices across the replicas of each group, and concatenates
    the received slices along `concat_dimension`.

    See https://www.tensorflow.org/xla/operation_semantics#alltoall.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$split_dimension,
    I64Attr:$concat_dimension,
    I64Attr:$split_count,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}
1028
def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]> {
  let summary = "Reduce operator";
  let description = [{
    Returns the result of executing a reduction function on one or more arrays
    in parallel.

    See https://www.tensorflow.org/xla/operation_semantics#reduce.
  }];
  // SameVariadicOperandSize: operands and init_values must have the same
  // number of tensors.
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    Variadic<HLO_Tensor>:$init_values,
    I64ElementsAttr:$dimensions
  );

  let results = (outs Variadic<HLO_Tensor>);

  let builders = [
    OpBuilder<(ins "ValueRange":$operands, "ValueRange":$init_values,
      "DenseIntElementsAttr":$dimensions)>];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);
}
1060
1061//===----------------------------------------------------------------------===//
1062// StableHLO tuple op definitions.
1063//===----------------------------------------------------------------------===//
def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [NoSideEffect,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "GetTupleElement operator";
  let description = [{
    Returns a member of a tuple specified by an index.

    See https://www.tensorflow.org/xla/operation_semantics#gettupleelement.
  }];
  // The tuple operand is unnamed; `index` is a static attribute selecting the
  // element to extract.
  let arguments = (ins
    HLO_Tuple,
    I32Attr:$index
  );

  let results = (outs HLO_TensorOrTokenOrTuple);

  let hasVerifier = 1;
}
1081
def StableHLO_TupleOp : StableHLO_Op<"tuple", [NoSideEffect,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "XLA's tuple op";
  let description = [{
     Groups a set of tensor inputs into a single tuple object.

     See https://www.tensorflow.org/xla/operation_semantics#tuple.
   }];
  // Inputs may themselves be tensors, tokens, or nested tuples.
  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
  let results = (outs HLO_Tuple);

  let hasVerifier = 1;
}
1095
def StableHLO_CompareOp: StableHLO_Op<"compare", [NoSideEffect, SameOperandsElementType,
    SameOperandsAndResultShape, Elementwise, InferTensorTypeWithReify]> {
  let summary = "Comparison operator";
  let description = [{
    Compares `lhs` and `rhs` elementwise according to `comparison_direction`
    and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
    types, SIGNED for signed element types and UNSIGNED for unsigned element
    types.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.

    Example:

    ```mlir
    %0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
    %1 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
    ```
  }];
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StableHLO_ComparisonDirectionAttr:$comparison_direction,
    OptionalAttr<StableHLO_ComparisonTypeAttr>:$compare_type
  );
  let results = (outs HLO_PredTensor);

  // Builder defaulting compare_type to NOTYPE (i.e. derived from element type).
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
      "::mlir::stablehlo::ComparisonDirection":$comparison_direction,
      CArg<"::mlir::stablehlo::ComparisonType",
      "::mlir::stablehlo::ComparisonType::NOTYPE">:$compare_type)>,
  ];

  let extraClassDeclaration = [{
    // Relaxed type-compatibility check used with InferTensorTypeWithReify:
    // accepts any result types whose shapes are compatible.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];

  let assemblyFormat = [{
    $comparison_direction `,` $lhs `,` $rhs (`,` $compare_type^)?
      attr-dict `:` functional-type(operands, results)
  }];
}
1141
1142//===----------------------------------------------------------------------===//
1143// StableHLO Slice definitions.
1144//===----------------------------------------------------------------------===//
1145
def StableHLO_SliceOp: StableHLO_Op<
      "slice",
      [NoSideEffect, SameOperandsAndResultElementType,
       AllTypesMatch<["start_indices", "limit_indices", "strides"]>,
       DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  // Summary/description were missing; add them for consistency with the rest
  // of the file.
  let summary = "Slice operator";
  let description = [{
    Extracts a strided sub-array from the operand, bounded per dimension by
    `start_indices` (inclusive) and `limit_indices` (exclusive) with step
    `strides`.

    See https://www.tensorflow.org/xla/operation_semantics#slice.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );

  let results = (outs HLO_Tensor);
}
1160
def StableHLO_DynamicSliceOp: StableHLO_Op<"dynamic_slice",
      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
       InferTensorType]> {
  let summary = "Dynamic Slice operator";
  let description = [{
    Extracts a sub-array from the input array at dynamic start_indices.

    See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
  }];
  // start_indices are runtime scalar values, while slice_sizes is a static
  // attribute.
  let arguments = (ins
    HLO_Tensor:$operand,
    Variadic<HLO_ScalarIntTensor>:$start_indices,
    I64ElementsAttr:$slice_sizes
  );

  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;
}
1179
def StableHLO_DynamicUpdateSliceOp: StableHLO_Op<"dynamic_update_slice",
      [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
       AllShapesMatch<["operand", "result"]>]> {
  let summary = "Dynamic Update Slice operator";
  let description = [{
    DynamicUpdateSlice generates a result which is the value of the input array
    operand, with a slice update overwritten at start_indices.

    See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.

    Example:

    ```mlir
    %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2
           : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
    ```
  }];
  // start_indices are runtime scalar values giving the update's position.
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$update,
    Variadic<HLO_ScalarIntTensor>:$start_indices
  );
  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
1207
1208
1209//===----------------------------------------------------------------------===//
1210// StableHLO Other op definitions.
1211//===----------------------------------------------------------------------===//
1212
def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [NoSideEffect,
    AllShapesMatch<["scale", "mean", "variance", "grad_scale",
        "grad_offset"]>,
    AllShapesMatch<["operand", "grad_output"]>,
    AllElementTypesMatch<["operand", "grad_scale", "grad_offset"]>,
    AllTypesMatch<["operand", "grad_operand"]>]> {
  let summary = "Batch Normalization Gradient";
  let description = [{
    Calculates gradients of batch norm.

    See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad
  }];

  let arguments = (ins
    RankedTensorOf<[HLO_Float]>:$operand,
    1DTensorOf<[HLO_Float]>:$scale,
    1DTensorOf<[HLO_Float]>:$mean,
    1DTensorOf<[HLO_Float]>:$variance,
    RankedTensorOf<[HLO_Float]>:$grad_output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  // Note: the duplicate, unnamed `let results = (outs Variadic<...>)` line was
  // removed — TableGen rejects setting the same field twice, and the traits
  // above reference the named results below.
  let results = (outs
      RankedTensorOf<[HLO_Float]>:$grad_operand,
      1DTensorOf<[HLO_Float]>:$grad_scale,
      1DTensorOf<[HLO_Float]>:$grad_offset);

  let hasVerifier = 1;
}
1244
def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference",
    [NoSideEffect, AllTypesMatch<["operand", "result"]>,
    AllShapesMatch<["scale", "offset", "mean", "variance"]>]> {
  let summary = "Batch Normalization for Inference";
  let description = [{
    Normalizes an array across batch and spatial dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
  }];

  // scale/offset/mean/variance are 1-D tensors over the feature dimension
  // selected by feature_index; epsilon guards the variance denominator.
  let arguments = (ins
    RankedTensorOf<[HLO_Float]>:$operand,
    1DTensorOf<[HLO_Float]>:$scale,
    1DTensorOf<[HLO_Float]>:$offset,
    1DTensorOf<[HLO_Float]>:$mean,
    1DTensorOf<[HLO_Float]>:$variance,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs RankedTensorOf<[HLO_Float]>:$result);

  let hasVerifier = 1;
}
1269
def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training",
    [NoSideEffect, AllTypesMatch<["operand", "output"]>,
    AllElementTypesMatch<["operand", "batch_mean", "batch_var"]>,
    AllShapesMatch<["scale", "offset", "batch_mean", "batch_var"]>]> {
  let summary = "Batch Normalization for Training";
  let description = [{
    Normalizes an array across batch and spatial dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
  }];

  let arguments = (ins
    RankedTensorOf<[HLO_Float]>:$operand,
    1DTensorOf<[HLO_Float]>:$scale,
    1DTensorOf<[HLO_Float]>:$offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  // Also returns the per-batch mean and variance computed during training.
  let results = (outs
      RankedTensorOf<[HLO_Float]>:$output,
      1DTensorOf<[HLO_Float]>:$batch_mean,
      1DTensorOf<[HLO_Float]>:$batch_var);

  let hasVerifier = 1;
}
1296
def StableHLO_BitcastConvertOp : StableHLO_ShapedInterfaceOp<"bitcast_convert",
    [NoSideEffect]> {
  let summary = "BitcastConvert operator";
  let description = [{
    Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
    operation from a data shape to a target shape. The dimensions must match,
    and the conversion is an element-wise one. Bitcast is implemented as a
    low-level cast, so machines with different floating-point representations
    will give different results.

    See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.

    Example:

    ```mlir
    %0 = stablehlo.bitcast_convert %arg0 : (tensor<2xi32>) -> tensor<2xf32>
    ```
  }];

  // Operand and result element types may differ; the bits are reinterpreted.
  let arguments = (ins HLO_Tensor:$operand);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
1322
def StableHLO_BroadcastOp : StableHLO_ShapedInterfaceOp<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> {
  let summary = "Broadcast a tensor to a higher rank by prepending dimensions";
  let description = [{
    Broadcasts the operand tensor to a higher rank by prepending
    `broadcast_sizes` to the dimensions. The current values of the operand are
    copied into the other dimensions.

    This is a more limited form of broadcasting, that corresponds to the XLA
    client Broadcast method. For a more general form of broadcasting, see the
    BroadcastInDimOp.

    See https://www.tensorflow.org/xla/operation_semantics#broadcast.
  }];
  // broadcast_sizes become new leading dimensions of the result.
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$broadcast_sizes
  );

  let results = (outs HLO_Tensor);

  let hasVerifier = 1;
}
1346
def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
      [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Broadcast a tensor into the given shape by adding dimensions.";
  let description = [{
    Broadcasts the `operand` tensor to a higher rank. This is not the limited
    form of broadcasting exposed as the XLA client broadcast op, but rather the
    more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
    and exposed in the XLA client BroadcastInDim method.

    `broadcast_dimensions` maps the operand dimension number to the target shape
    dimension number. It must have the same size as the rank of the operand. The
    mapped dimensions must either be the same size or the dimension being
    broadcast from must be size 1 (degenerate broadcasting).

    For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
    scalar value will be broadcast to every element in the target shape.

    See https://www.tensorflow.org/xla/broadcasting.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    BroadcastDimAttr:$broadcast_dimensions
  );

  // The target shape must be static, since it is carried by the result type.
  let results = (outs HLO_StaticShapeTensor);

  let hasVerifier = 1;
}
1375
def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
    "dynamic_broadcast_in_dim", [NoSideEffect]> {
  let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
  let description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supersede the statically
    shaped original, but is being phased as a separate op in order to support
    compatibility with lowerings and translations that precede dynamic shapes.

    The op accepts optional attributes to express static knowledge about the
    expanding behavior of dimensions. If not specified, all dimensions are
    assumed to be possibly expanding. The sets of dimensions that are known to
    be expanding and the set of dimensions that are known to be non-expanding
    must be disjoint and they must be a subset of the operand's dimensions.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions,
    OptionalAttr<BroadcastDimAttr>:$known_expanding_dimensions,
    OptionalAttr<BroadcastDimAttr>:$known_nonexpanding_dimensions
  );

  let results = (outs HLO_Tensor);

  // Convenience builder that leaves the expansion-knowledge attributes unset.
  let builders = [
    OpBuilder<(ins
        "Type":$result_type, "Value":$operand, "Value":$output_dimensions,
        "DenseIntElementsAttr":$broadcast_dimensions), [{
      build($_builder, $_state, result_type, operand, output_dimensions,
          broadcast_dimensions, /*known_expanding_dimensions=*/{},
          /*known_nonexpanding_dimensions=*/{});
    }]>
  ];

  let hasVerifier = 1;
}
1413
1414// Note: There is no HLO_CallOp because the standard call operation mlir::func::CallOp
1415// is used instead. A mlir::func::CallOp is exported to a HLO call instruction
1416// directly.
1417
def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
      [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> {
  let summary = "Cholesky operator";
  let description = [{
  Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
  positive definite matrices.

  If lower is true, computes lower-triangular matrices l such that
  `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such
  that `a=Transpose(u).u`.

  Input data is read only from the lower/upper triangle of a, depending on the
  value of lower. Values from the other triangle are ignored. Output data is
  returned in the same triangle; the values in the other triangle are
  implementation-defined and may be anything.

  If the rank of a is greater than 2, a is treated as a batch of matrices, where
  all except the minor 2 dimensions are batch dimensions.

  If a is not symmetric (Hermitian) positive definite, the result is
  implementation-defined.

    See https://www.tensorflow.org/xla/operation_semantics#cholesky.
  }];
  // `lower` defaults to false, i.e. the upper-triangular factorization.
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );

  let results = (outs HLO_FpOrComplexTensor);
}
1449
def StableHLO_ClampOp : StableHLO_ShapedInterfaceOp<"clamp", [NoSideEffect,
  SameOperandsAndResultElementType, HLO_BroadcastingElementwise,
  InferTensorType]> {
  let summary = "Clamp operator";
  let description = [{
    Clamps an operand to within the range between a minimum and maximum value.

    Note: All three arrays must be the same shape. Alternatively, as a
          restricted form of broadcasting, min and/or max can be a scalar (0D
          tensor) of the element type of the tensor operand.

    See https://www.tensorflow.org/xla/operation_semantics#clamp.

    Example:

    ```mlir
    %0 = stablehlo.clamp %arg0, %arg1, %arg2 : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
    ```
  }];

  let arguments = (ins
    HLO_Tensor:$min,
    HLO_Tensor:$operand,
    HLO_Tensor:$max
  );
  let results = (outs HLO_Tensor);

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    // Method from InferTypeOpInterface interface.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
1492
def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",
    [NoSideEffect, SameOperandsAndResultElementType,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "XLA's concatenate op";
  let description = [{
     Concatenates a set of tensors along the specified dimension.

     See https://www.tensorflow.org/xla/operation_semantics#concatenate.
   }];

  // `dimension` is the axis along which the operands are joined.
  let arguments = (ins
    Variadic<HLO_Tensor>:$val,
    I64Attr: $dimension
  );

  let results = (outs HLO_Tensor);

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    // Accepts shape-compatible (not necessarily identical) return types.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
1518
def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "CollectivePermute operator";
  let description = [{
    CollectivePermute is a collective operation that sends and receives data
    across replicas.
    Note that there are the following restrictions on the source_target_pair:
    - Any two pairs should not have the same target replica id, and they should
    not have the same source replica id.
    - If a replica id is not a target in any pair, then the output on that
    replica is a tensor consisting of 0(s) with the same shape as the input.

    See https://www.tensorflow.org/xla/operation_semantics#collectivepermute.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$source_target_pairs
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
1542
def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [NoSideEffect]> {
  let summary = "Convolution operator";
  let description = [{
    Computes a convolution of the kind used in neural networks.

    See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
  }];
  // The shared convolution attributes (window, padding, dimension numbers,
  // etc.) are concatenated onto lhs/rhs.
  let arguments = !con(
    (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs),
    StableHLO_ConvolutionAttributes.attributes);

  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  // Use `let` (not a `code` field redeclaration) to override the inherited
  // extraClassDeclaration, consistent with the other ops in this file.
  let extraClassDeclaration = [{
    // Returns true if the window_reversal attribute is present and any of its
    // entries is true.
    bool hasWindowReversal() {
      auto reversal = window_reversalAttr();
      return reversal && llvm::any_of(reversal.getValues<bool>(),
                                      [](bool v) { return v; });
    }
  }];

  let assemblyFormat = [{
    `(`operands`)`
       `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
       `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
                                                 $lhs_dilation, $rhs_dilation,
                                                 $window_reversal) `}`
       attr-dict `:` functional-type(operands, results)
  }];
}
1576
def StableHLO_CrossReplicaSumOp : StableHLO_Op<"cross-replica-sum",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Sums input across replicated instances.";
  let description = [{
     For each of the replica groups, operands of the group devices are summed
     so that each device has the sum.

     For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
     Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
     and `B, D, F, H` as group 1. Thus we get the outputs:
     `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.

     See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
   }];

  // replica_groups partitions the replicas; the sum is taken per group.
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups
  );

  let results = (outs HLO_Tensor);
}
1599
def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "CustomCall operator";
  let description = [{
    A custom call invokes code external to XLA. The `args` are passed to the
    external code, and the external code is expected to produce a result of the
    given type. The exact mechanism is backend-specific. For example, in the CPU
    backend, a call instruction is emitted which targets a symbol with the name
    `call_target_name`.

    `call_target_name` and `backend_config` can be arbitrary strings, but
    `call_target_name` should be short as it may be used in labels.
    `backend_config` can encode arbitrarily large amounts of information.

    `has_side_effect` must be true if the custom call has side-effects.
    `api_version` specifies the version of the API used by the custom call
    function.

    A custom call may apply functions within the scope of the parent module.
    They can be referenced using `called_computations` attribute.

    A custom call can also have layout constraints on operands and results which
    can be specified as optional `operand_layouts` and `result_layouts`
    attributes. The layout attribute is an array of rank-1 index tensors and the
    i-th layout attribute specifies the layout for i-th operand/result.

    The `operand_layouts` & `result_layouts` attributes can be specified under
    the following constraints:
    1) Either both `operand_layouts` and `result_layouts` are specified or none.
    2) None of the operands are of tuple type.
    3) None of the results are of tuple type except the common case of single
       tuple result packing non-tuple values is allowed. In this case the i-th
       `result_layouts` attribute specifies the layout of i-th element in the
       result tuple.

    See https://www.tensorflow.org/xla/operation_semantics#customcall.
  }];
  // NOTE(review): memory effects are computed dynamically via
  // MemoryEffectsOpInterface — presumably from has_side_effect; confirm in
  // the op's C++ implementation.
  let arguments = (ins
    Variadic<HLO_TensorOrTokenOrTuple>:$operands,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedStrAttr<StrAttr, "">:$backend_config,
    // TODO(b/189822916): Remove this field when all clients are migrated to
    // the status-returning API.
    DefaultValuedAttr<
        StableHLO_CustomCallApiVersionAttr,
        "::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL">:
        $api_version,
    DefaultValuedAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations,
    OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts,
    OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts
  );
  let results = (outs Variadic<HLO_TensorOrTokenOrTuple>);
  let hasVerifier = 1;
}
1655
def StableHLO_DotOp: StableHLO_Op<"dot",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Dot operator";
  let description = [{
    Performs dot products between vectors, vector/matrix and matrix/matrix
    multiplication.

    See https://www.tensorflow.org/xla/operation_semantics#dot.
  }];
  // precision_config controls the numeric precision/speed tradeoff of the
  // computation on the backend.
  let arguments = (
    ins HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StableHLO_PrecisionConfigAttr:$precision_config
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    // Accepts shape-compatible (not necessarily identical) return types.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
1679
def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general", [NoSideEffect]> {
  let summary = "General Dot operator";
  let description = [{
    Performs general dot products between vectors, vector/matrix and
    matrix/matrix multiplication.

    See https://www.tensorflow.org/xla/operation_semantics#dotgeneral.
  }];
  // dot_dimension_numbers specifies the batching and contracting dimensions
  // of lhs and rhs.
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StableHLO_DotDimensionNumbers:$dot_dimension_numbers,
    StableHLO_PrecisionConfigAttr:$precision_config
  );

  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
1698
// Define Base Einsum op within the HLO dialect as these are client ops and
// therefore this class is not common between HLO and LHLO ops.
// Mixin class: only supplies the shared `summary`/`description` fields for
// the two einsum ops below; it is not itself an operation definition.
class BASE_EinsumOp {
  string summary = "Einsum operator";

  string description = [{
    Returns a tensor whose elements are defined by equation, which is written
    in a shorthand form inspired by the Einstein summation convention.
  }];
}
1709
def StableHLO_EinsumOp: StableHLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    // The einsum equation string (see BASE_EinsumOp's description).
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}
1722
// Single-operand variant of einsum; shares docs with EinsumOp via BASE_EinsumOp.
def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);
}
1731
def StableHLO_FftOp: StableHLO_Op<"fft", [InferTensorType, NoSideEffect]> {
  let summary = "Fast fourier transform operator";
  let description = [{
    Returns the fast-fourier-transform of the input array.

    See
    https://www.tensorflow.org/xla/operation_semantics#fft.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    // Transform variant; the enum attribute comes from the included
    // StableHLO enum/attr definitions. (Spacing normalized to the file's
    // `Type:$name` convention.)
    StableHLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );

  let results = (outs HLO_Tensor);
}
1748
def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify, NoSideEffect]> {
  let summary = "Gather operator";
  let description = [{
    Stitches together several slices of `operand` from offsets specified in
    `start_indices` (each slice at a potentially different runtime offset).

    See https://www.tensorflow.org/xla/operation_semantics#gather.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    // Defined in dialect/StablehloAttrs.td.
    StableHLO_GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    // Whether `start_indices` are known sorted (see the XLA gather docs).
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );

  let results = (outs HLO_Tensor);

  // Relaxes the InferTensorTypeWithReify check: a declared result type is
  // accepted as long as its shape is compatible with the inferred one.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
1774
def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size", [NoSideEffect]> {
  let summary = "GetDimensionSize operator";
  let description = [{
    Returns the size of the given dimension of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#getdimensionsize.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    // Static index of the dimension to query.
    I64Attr:$dimension
  );
  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
  // XLA semantics is available. This limitation is because of the current XLA
  // implementation.
  let results = (outs I32Tensor);

  let hasVerifier = 1;
}
1794
def StableHLO_MapOp: StableHLO_ShapedInterfaceOp<"map",
      [RecursiveSideEffects, SameOperandsAndResultShape,
       SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = "Map operator";
  let description = [{
  Applies a scalar function over the given operands arrays, producing an array
  of the same dimensions where each element is the result of the mapped function
  applied to the corresponding elements in the input arrays.

  The mapped function is an arbitrary computation with the restriction that it
  has N inputs of scalar type T and a single output with type S. The output has
  the same dimensions as the operands except that the element type T is replaced
  with S.

  See https://www.tensorflow.org/xla/operation_semantics#map.
  }];
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    I64ElementsAttr:$dimensions
  );
  // The scalar function applied to corresponding elements of the operands;
  // its single block is implicitly terminated by stablehlo's ReturnOp.
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}
1819
def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
      [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Reshape operator";
  let description = [{
    Reshapes the dimensions of `operand` into a new configuration.

    See https://www.tensorflow.org/xla/operation_semantics#reshape.

    Example:

    ```mlir
    %0 = stablehlo.reshape %arg0 : (tensor<2xf32>) -> tensor<1x2xf32>
    ```
  }];

  let arguments = (ins HLO_Tensor:$operand);

  // Result must be statically shaped; runtime-computed target shapes use
  // stablehlo.dynamic_reshape instead.
  let results = (outs HLO_StaticShapeTensor);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
1842
def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> {
  let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
  let description = [{
    Reshapes `operand` to `output_shape`.

    Requires:
    - The length of `output_shape` is equal to the rank of `result`.
    - The number of elements in `operand` (that is, the product of extents of
      its shape) is equal to the number of elements in `output_shape` (that is,
      the product of values in `output_shape`).

    Example:

    ```mlir
    %0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    ```
  }];

  // `output_shape` is a rank-1 dimension tensor holding the target extents.
  let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape);
  let results = (outs HLO_Tensor:$result);

  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
1868
def StableHLO_ScatterOp: StableHLO_Op<"scatter", [SameVariadicOperandSize, RecursiveSideEffects]> {
  let summary = "Scatter operator";
  let description = [{
    Generates a result which is the value of the input array `operand`,
    with several slices (at indices specified by `scatter_indices`)
    updated with the values in `updates` using `update_computation`.

    See https://www.tensorflow.org/xla/operation_semantics#scatter.
  }];
  // Note: SameVariadicOperandSize requires $operands and $updates to have
  // the same arity.
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    TensorOf<[AnyInteger, Index]>:$scatter_indices,
    Variadic<HLO_Tensor>:$updates,
    StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  // Combines an existing value with an update (see description).
  let regions = (region SizedRegion<1>:$update_computation);

  let results = (outs Variadic<HLO_Tensor>);

  let hasVerifier = 1;
}
1893
def StableHLO_SelectOp: StableHLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise,
    InferTensorTypeWithReify]> {
  let summary = "Select operator";
  let description = [{
    Constructs an output tensor from the elements of `on_true` and `on_false`
    based on the values of `pred`. All three operands must be of the same shape
    with the exception of `pred`, which may also be a scalar in which case it is
    broadcasted.

    See https://www.tensorflow.org/xla/operation_semantics#select.
  }];
  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_Tensor:$on_true,
    HLO_Tensor:$on_false
  );

  let results = (outs HLO_Tensor);

  let hasVerifier = 1;

  // Relaxes the InferTensorTypeWithReify check: a declared result type is
  // accepted as long as its shape is compatible with the inferred one.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
1921
def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
      [RecursiveSideEffects]> {
  let summary = "SelectAndScatter operator";
  let description = [{
    Runs a windowed selection `select` function over `operand` with shape
    `window_dimensions` and stride `window_strides`. This will produce an amount
    of selected locations whose shape matches `source`. These are then scattered
    to the output which is initialized with `init_value`.
    Multiple scattered elements which land in the same output location are
    combined using the `scatter` function.

    See https://www.tensorflow.org/xla/operation_semantics#selectandscatter.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$source,
    HLO_Tensor:$init_value,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  // `select` picks an element within each window; `scatter` combines values
  // that land in the same output location (see description).
  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);

  let results = (outs HLO_Tensor);

  let hasVerifier = 1;
}
1950
def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "SetDimensionSize operator";
  let description = [{
    Sets the dynamic size of operand's given dimension. Pass through the operand
    as result, with dynamic dimension tracked by the compiler. Padded values
    will be ignored by downstream reduction ops.

    See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    // Runtime size of dimension `dimension`.
    I32Tensor:$size,
    I64Attr:$dimension
  );
  let results = (outs HLO_Tensor);

  // Unlike the shape-only check used by other ops in this file, this op
  // compares types element-by-element with HLO type-inference compatibility.
  let extraClassDeclaration = [{
    // Method from InferTypeOpInterface interface.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];

  let hasVerifier = 1;
}
1981
def StableHLO_SortOp : StableHLO_Op<"sort", [RecursiveSideEffects,
                                 SameOperandsAndResultShape]> {
  let summary = "Sort operator";
  let description = [{
    Sorts the given `operands` at the given `dimension` with the given
    `comparator`.

    See https://www.tensorflow.org/xla/operation_semantics#sort.
  }];
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let results = (outs Variadic<HLO_Tensor>);

  // Comparison function that defines the sort order.
  let regions = (region SizedRegion<1>:$comparator);

  // Convenience builder mirroring the attribute defaults above (-1, false).
  let builders = [
    OpBuilder<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension,
      CArg<"bool", "false">:$is_stable)>];

  let hasVerifier = 1;
}
2007
def StableHLO_ReverseOp: StableHLO_Op<"reverse",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Reverse operator";
  let description = [{
    Reverses the specified dimensions of `operand` according to the given
    `dimensions`.

    See https://www.tensorflow.org/xla/operation_semantics#rev_reverse.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    // Dimensions to reverse.
    I64ElementsAttr:$dimensions
  );

  let results = (outs HLO_Tensor);
}
2024
def StableHLO_PadOp: StableHLO_ShapedInterfaceOp<"pad",
      [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> {
  let summary = "Pad operator";
  let description = [{
    Pads edges and between the elements of `operand` with the `padding_value`
    according to the configuration parameters described below.

    `edge_padding_low` and `edge_padding_high` specify the amount of padding
    added at the low-end (next to index 0) and the high-end (next to the
    highest index) of each dimension respectively. The amount of edge
    padding can be negative -- the absolute value of negative padding indicates
    the number of elements to remove from the specified dimension.

    `interior_padding` specifies the amount of padding (non-negative) added
    between any two elements in each dimension. Interior padding occurs
    logically before edge padding, so in the case of negative edge padding,
    elements are removed from the interior-padded operand.

    This operation is a no-op if, for all dimensions, the edge padding pairs are
    all (0, 0) and the interior padding values are all 0. The figure below shows
    examples of different `edge_padding` and `interior_padding` values for a
    two-dimensional array.

    ![Examples](https://www.tensorflow.org/xla/images/ops_pad.png)

  }];
  // Spacing normalized to the file's `Type:$name` convention.
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    I64ElementsAttr:$edge_padding_low,
    I64ElementsAttr:$edge_padding_high,
    I64ElementsAttr:$interior_padding
  );

  let results = (outs HLO_Tensor);
}
2061
def StableHLO_TraceOp: StableHLO_Op<"trace", []> {
  let summary = "Trace operator";
  let description = [{
    Emits a logging message `tag` with the `operand`.

    Example:

    ```mlir
    stablehlo.trace %arg0, "In test code." : tensor<5x1x5xi32>
    ```
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$tag
  );
  // Note: this op produces no results (no `results` field is declared).
  let assemblyFormat = "$operand `,` $tag attr-dict `:` type($operand)";
}
2079
def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
      [NoSideEffect, SameOperandsAndResultElementType,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Transpose operator";
  let description = [{
    Permutes the dimensions of `operand` according to the given `permutation`.

    `res_dimensions[i] = operand_dimensions[permutation[i]]`

    See https://www.tensorflow.org/xla/operation_semantics#transpose.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$permutation
  );
  let results = (outs HLO_Tensor);

  // Relaxes the InferTypeOpInterface check: a declared result type is
  // accepted as long as its shape is compatible with the inferred one.
  let extraClassDeclaration = [{
    // Method from  InferTypeOpInterface interface.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
2104
def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "TriangularSolve operator";
  let description = [{
    Solves systems of linear equations with lower or upper triangular
    coefficient matrices by forward- or back-substitution. Broadcasting along
    leading dimensions, this routine solves one of the matrix systems
    op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where
    op(a) is either op(a) = a, or op(a) = Transpose(a), or
    op(a) = Conj(Transpose(a)).

    Input data is read only from the lower/upper triangle of a, depending on the
    value of lower. Values from the other triangle are ignored. Output data is
    returned in the same triangle; the values in the other triangle are
    implementation-defined and may be anything.

    If the rank of a and b are greater than 2, they are treated as batches of
    matrices, where all except the minor 2 dimensions are batch dimensions. a
    and b must have equal batch dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
  }];
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    // Selects between op(a) * x = b (true) and x * op(a) = b (false).
    BoolAttr:$left_side,
    // Which triangle of `a` is read (see description).
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    // Selects op(a): identity, transpose, or conjugate-transpose.
    StableHLO_TransposeAttr:$transpose_a
  );
  let results = (outs HLO_FpOrComplexTensor);

  let hasVerifier = 1;
}
2139
def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]> {
  let summary = "ReduceWindow operator";
  let description = [{
    Returns the result of executing a reduction function over all elements in
    each window of one or more arrays in parallel.

    See https://www.tensorflow.org/xla/operation_semantics#reducewindow.
  }];

  // TODO(hinsu): Verify that padding attribute is 2-d and the remaining
  // attributes are 1-d. Attributes' leading dimension should match rank of the
  // operands.
  // Note: SameVariadicOperandSize requires $operands and $init_values to have
  // the same arity.
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    Variadic<HLO_Tensor>:$init_values,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the operand dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let results = (outs Variadic<HLO_Tensor>);

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  // Builder for non-variadic version of the operation.
  let builders = [
    OpBuilder<(ins "Type":$result_type, "Value":$operand,
      "Value":$init_value,
      "DenseIntElementsAttr":$window_dimensions,
      "DenseIntElementsAttr":$window_strides,
      "DenseIntElementsAttr":$base_dilations,
      "DenseIntElementsAttr":$window_dilations,
      "DenseIntElementsAttr":$padding),
    [{
      build($_builder, $_state, TypeRange(result_type), ValueRange(operand),
            ValueRange(init_value), window_dimensions, window_strides,
            base_dilations, window_dilations, padding);
    }]>
  ];

  let hasVerifier = 1;
  // TODO(hinsu): Implement custom printer and parser.

  let extraClassDeclaration = [{
     // Get the operation used for reduction applied to `result_index`th result.
     Operation *getReductionOp(int result_index);
  }];
}
2199
def StableHLO_ReturnOp : StableHLO_Op<"return", [NoSideEffect, Terminator]> {
  // Fixed: the summary referred to `hlo.return`, but this op's mnemonic is
  // `stablehlo.return` (as the example itself shows).
  let summary = [{
    The `stablehlo.return` operation terminates a region and returns values.

    Example:

    ```mlir
    %0 = stablehlo.reduce %arg0, %arg1 {
      ...
      stablehlo.return %1 : tensor<f32>
    }
    ```
  }];

  let arguments = (ins
    Variadic<HLO_TensorOrTokenOrTuple>:$results
  );

  let assemblyFormat = "$results attr-dict (`:` type($results)^)?";
}
2220
def StableHLO_TorchIndexSelectOp : StableHLO_Op<"torch_index_select", [NoSideEffect]> {
  // Added summary/description for consistency: this was the only op in the
  // file without documentation fields.
  let summary = "TorchIndexSelect operator";
  let description = [{
    Selects slices of `operand` along dimension `dim` using the entries of
    `index`, with the leading `batch_dims` dimensions treated as batch
    dimensions. Named after PyTorch's `index_select` client operation.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$index,
    I64Attr:$dim,
    I64Attr:$batch_dims
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}
2234
def StableHLO_OptimizationBarrierOp : StableHLO_Op<"optimization_barrier",
      [NoSideEffect, HLO_PairwiseSameOperandAndResultType]> {
  let summary = [{
    The `stablehlo.optimization_barrier` op blocks optimizations.

    Example:

    ```mlir
    %0:2 = stablehlo.optimization_barrier %arg0, %arg1 : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>)
    ```
  }];

  let description = [{
    Blocks any optimization pass from moving computations across the barrier.

    Ensures that all inputs are evaluated before any operators that depend on the barrier's outputs.
    See
    https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier
  }];

  // Each result has the same type as the operand at the same position
  // (HLO_PairwiseSameOperandAndResultType).
  let arguments = (ins Variadic<HLO_TensorOrToken>:$operand);

  let results = (outs Variadic<HLO_TensorOrToken>);

  // TODO(b/241767462): Enhance type printing to condense pairwise ops.
  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
2262
2263//===----------------------------------------------------------------------===//
2264// StableHLO RNG Operators.
2265//===----------------------------------------------------------------------===//
2266
def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> {
  let summary = "RNG with uniform distribution.";
  let description = [{
    Constructs an output of a given shape with random numbers generated
    following the given `rng_distribution` with two parameters:
      `UNIFORM`: the uniform distribution over the interval `[a,b)`. The parameters
                 and output element type have to be a boolean type, an integral type or a
                 floating point types, and the types have to be consistent.

                 See https://www.tensorflow.org/xla/operation_semantics#rnguniform.

      `NORMAL`: the normal distribution with parameters `mu` (=`a`) and
                `sigma` (=`b`). The parameters and output shape have to have a
                floating point elemental type. The parameters furthermore have
                to be scalar valued.

                See https://www.tensorflow.org/xla/operation_semantics#rngnormal.
  }];
  let arguments = (ins
    // `a` and `b` are 0-D (scalar) distribution parameters; their element
    // type must match the result's (AllElementTypesMatch above).
    0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a,
    0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b,
    HLO_DimensionTensor:$shape,
    StableHLO_RngDistributionAttr:$rng_distribution
  );

  let results = (outs HLO_PredIntOrFpTensor:$result);

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    // Returns whether the return types are compatible.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(::mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
2303
def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [NoSideEffect]> {
  let summary = "Uniform random number generator operator";
  let description = [{
    Returns an output with a given shape filled with uniform random bits using
    the specified algorithm (or backend default) and returns an updated state
    (with the same shape as initial state) and the generated random data.

    See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.
  }];
  let arguments = (ins
    StableHLO_RngAlgorithmAttr:$rng_algorithm,
    HLO_IntOrFpTensor:$initial_state
  );

  // Two results: the advanced RNG state and the generated bits.
  let results = (outs
      HLO_IntOrFpTensor:$output_state,
      HLO_IntOrFpTensor:$output
      );

  let hasVerifier = 1;
}
2325
2326//===----------------------------------------------------------------------===//
// StableHLO Quantize Operators.
2328//===----------------------------------------------------------------------===//
2329
// TODO(b/230662142): Implement unknown scales/zero_point cases.
// Unary elementwise op: input is f32/bf16/quantized, result is quantized
// (see the type constraints in the base-class template arguments).
def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<"uniform_quantize",
      [NoSideEffect], TensorOf<[F32, BF16, HLO_QuantizedInt]>,
      HLO_QuantizedIntTensor> {
  let summary = "Uniform quantize operator";
  let description = [{
    Converts floating point tensors or uniform quantized integer tensors to
    uniform quantized integer tensors according to the quantization parameters
    defined by the output type.

    Example:

    ```mlir
    %0 = stablehlo.uniform_quantize %arg0 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
    ```
  }];
}
2347
// Inverse of uniform_quantize: quantized input, f32/bf16 result (see the
// type constraints in the base-class template arguments).
def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<"uniform_dequantize",
      [InferTensorType, NoSideEffect], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> {
  let summary = "Uniform dequantize operator";
  let description = [{
    Converts quantized array of integers to floating-points according to the
    quantization parameters defined by the input type.

    Example:

    ```mlir
    %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
    ```
  }];
}
2362
def StableHLO_ReducePrecisionOp :
    StableHLO_Op<"reduce_precision", [HLO_CompatibleOperandsAndResultType]> {
  let summary = "Reduce precision operator";
  // Fixed: the description text was mangled (stray spaces around hyphens and
  // a missing space before a parenthesis, apparently from an automated
  // reformat); restored readable prose with unchanged meaning.
  let description = [{
    Models the effect of converting floating-point values to a lower-precision
    format (such as IEEE-FP16) and back to the original format. The number of
    exponent and mantissa bits in the lower-precision format can be specified
    arbitrarily, although all bit sizes may not be supported on all hardware
    implementations.

    See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
  }];
  let arguments = (ins
    HLO_FpTensor:$operand,
    // Number of exponent bits in the lower-precision format.
    I32Attr:$exponent_bits,
    // Number of mantissa bits in the lower-precision format.
    I32Attr:$mantissa_bits
  );
  let hasVerifier = 1;
  let results = (outs HLO_FpTensor:$output);
}
2384
def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
      "real_dynamic_slice",
      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
       AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let summary = "Real Dynamic Slice operator";
  let description = [{
    The dynamic shape version of SliceOp. Extracts a sub-array from the input
    array according to start_indices, limit_indices and strides. Expect
    start_indices/limit_indices/strides to be statically shaped and matching
    the rank of the input.

    Example:

    ```mlir
    %0 = stablehlo.real_dynamic_slice %input, %start, %limit, %strides
           : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32>
    ```
  }];
  // The three index operands must all have the same type (AllTypesMatch).
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$start_indices,
    HLO_DimensionTensor:$limit_indices,
    HLO_DimensionTensor:$strides
  );
  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
2414
def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
      [NoSideEffect, AllElementTypesMatch<["operand", "padding_value", "result"]>,
      AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
  let summary = "Dynamic Pad operator";
  // Fixed: this record assigned `description` twice; the second, terser text
  // silently overrode the detailed one (with the example). The two texts are
  // merged into a single description here.
  let description = [{
    The dynamic shape version of PadOp. Pads the edges of `operand` with the
    `padding_value` and according to the passed configuration: the amount of
    padding added at low-end/high-end/interior of each dimension is passed
    through input tensors. Expect
    edge_padding_low/edge_padding_high/interior_padding to be statically shaped
    and matching the rank of the input.

    See https://www.tensorflow.org/xla/operation_semantics#pad.

    Example:

    ```mlir
    %0 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
           : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
    ```
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    HLO_DimensionTensor:$edge_padding_low,
    HLO_DimensionTensor:$edge_padding_high,
    HLO_DimensionTensor:$interior_padding
  );
  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
2450
def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
                                [InferTensorTypeWithReify, NoSideEffect]> {
  // Normalized `string summary/description` to `let ...` for consistency
  // with every other op definition in this file.
  let summary = "Dynamic Gather operator";
  let description = [{
    The dynamic shape version of GatherOp. Stitches together several slices of
    an input array.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    // Unlike GatherOp's static I64ElementsAttr, slice sizes are a runtime
    // tensor operand here.
    HLO_IntTensor:$slice_sizes,
    StableHLO_GatherDimensionNumbers:$dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );
  let results = (outs HLO_Tensor);

  // Relaxes the InferTensorTypeWithReify check: a declared result type is
  // accepted as long as its shape is compatible with the inferred one.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}
2474
def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [NoSideEffect]> {
  let summary = "Dynamic Convolution operator";
  let description = [{
    The dynamic shape version of ConvOp. Computes a convolution with dynamic padding.
  }];

  // Concatenates the explicit operands with the shared convolution attribute
  // list (StableHLO_ConvolutionAttributes, defined with the other attrs).
  let arguments = !con(
    (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs,
       // Runtime padding amounts (the "dynamic padding" in the description).
       HLO_Tensor:$d_padding),
    StableHLO_ConvolutionAttributes.attributes);
  let results = (outs HLO_Tensor);
}
2489
def StableHLO_ComputeReshapeShapeOp :
    StableHLO_Op<"compute_reshape_shape", [NoSideEffect]> {
  // Normalized `string summary/description` to `let ...` for consistency
  // with the other op definitions in this file.
  let summary = "Compute input for reshape with any dynamic dim resolved";

  // Fixed: the example used the `hlo.` prefix for a `stablehlo.` op, and the
  // code fence was missing the ```mlir tag the other examples use.
  let description = [{
    This operation handles the dynamic aspect of a TF/NumPy/CHLO reshape. The
    dynamic aspect is that a single extent can be -1 and that dimension will
    instead be computed. This handles the computation and can then be passed to
    an HLO DynamicReshapeOp to replicate the TF/NumPy reshape behavior.

    This op has undefined behavior if the dimensions do not evenly divide the
    number of elements, or if there are multiple -1 values. It is an identity op
    if no dimensions are -1.

    ```mlir
    %0 = stablehlo.compute_reshape_shape 12, [2, -1] -> [2, 6]
    ```
  }];

  let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
  let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result);

  // TODO (b/241767462): Use functional-type for type printing for consistency.
  let assemblyFormat = "$num_elements `,` $dynamic_shape attr-dict `:` type($num_elements) `,` type($dynamic_shape) `->` type($result)";
}
2515
def StableHLO_CstrReshapableOp :
    StableHLO_Op<"cstr_reshapable", [NoSideEffect]> {
  // Fixed: the summary was copy-pasted from compute_reshape_shape and did not
  // describe this op; it now matches the description below. Also normalized
  // `string summary/description` to `let ...` and tagged the code fence.
  let summary = "Create a witness that a shape is reshapable to a number of elements";

  let description = [{
    This operation creates a witness on the constraint that a given shape would
    be a valid reshape for the given number of elements.

    ```mlir
    %0 = stablehlo.cstr_reshapable 12, [2, -1] -> success
    %1 = stablehlo.cstr_reshapable 13, [2, -1] -> failure
    ```
  }];

  let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
  let results = (outs Shape_WitnessType:$result);

  // TODO (b/241767462): Use functional-type for type printing for consistency.
  let assemblyFormat = "$num_elements `,` $dynamic_shape attr-dict `:` type($num_elements) `,` type($dynamic_shape)";
}
2536
2537#endif // STABLEHLO_DIALECT_STABLEHLO_OPS
2538