/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
   Copyright 2022 The StableHLO Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_DIALECT_STABLEHLO_OPS
#define STABLEHLO_DIALECT_STABLEHLO_OPS

include "dialect/Base.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

// The StableHLO dialect. Generated C++ lives in ::mlir::stablehlo; attribute
// and type printing/parsing are custom (not the ODS defaults).
def StableHLO_Dialect : Dialect {
  let name = "stablehlo";
  let cppNamespace = "::mlir::stablehlo";

  let description = [{
    StableHLO is an operation set that expresses ML computations. It was
    originally bootstrapped from the MHLO dialect and enhances it with
    additional functionality, including serialization and versioning, to be
    used as a portability layer between ML frameworks and ML compilers.
  }];

  let emitAccessorPrefix = kEmitAccessorPrefix_Raw;
  let useDefaultAttributePrinterParser = 0;
  let useDefaultTypePrinterParser = 0;
}

// Base class for every StableHLO operation: binds the op to StableHLO_Dialect.
class StableHLO_Op<string mnemonic, list<Trait> traits> :
    Op<StableHLO_Dialect, mnemonic, traits> {
}

// NOTE(review): these includes follow StableHLO_Op, presumably because the
// enum/attribute definitions reference the dialect declared above — confirm.
include "dialect/StablehloEnums.td"
include "dialect/StablehloAttrs.td"

// Base class for StableHLO ops that declare the `reifyReturnTypeShapes` method
// of InferShapedTypeOpInterface; each such op must define it in C++.
class StableHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> :
    StableHLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
    ["reifyReturnTypeShapes"]>]> {
}

//===----------------------------------------------------------------------===//
// StableHLO nullary op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_ConstantOp : StableHLO_Op<"constant",
    [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Constant operator";
  let description = [{
    Represents a constant value.
  }];
  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    HLO_StaticShapeTensor:$output
  );

  let builders = [
    OpBuilder<(ins "Attribute":$value)>];

  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;

  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
  }];
}

def StableHLO_IotaOp : StableHLO_Op<"iota", [NoSideEffect]> {
  let summary = "Iota operator";
  let description = [{
    Creates a tensor of the result shape holding values that start at zero and
    increment by one along the `iota_dimension` dimension.

    See https://www.tensorflow.org/xla/operation_semantics#iota.
  }];
  let arguments = (ins I64Attr:$iota_dimension);

  let results = (outs HLO_IntFpOrComplexTensor:$output);

  let hasVerifier = 1;
}

def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [NoSideEffect]> {
  let summary = "Creates linearly increasing values from 0 to length - 1.";
  let description = [{
    Produces an HLO Tensor of the specified shape, with an incremental set of
    values along the specified dimension starting at 0.

    Requires:
    - `output_shape`: the shape of the result tensor.
    - `iota_dimension`: the dimension along which the values increase.
  }];

  let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension);
  let results = (outs HLO_Tensor:$result);
}

def StableHLO_CreateTokenOp : StableHLO_Op<"create_token", [NoSideEffect]> {
  let summary = "Create Token operator";

  let description = [{
    Produces an HLO token. Tokens are used for ordering side-effecting
    operations. This is exported to HLO as an AfterAll operation with no
    operands to generate a token.

    Example:

    ```mlir
    %1 = stablehlo.create_token : !stablehlo.token
    ```
  }];

  let results = (outs HLO_Token:$output);

  let assemblyFormat = "attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// StableHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

// Base class for unary elementwise ops. The result has the same shape as the
// operand; `ResultType` defaults to `OperandType` unless an op overrides it.
// Result shapes are reified from the operand, and the custom assembly format
// is delegated to parseUnaryOp / printUnaryOp.
class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
    Type OperandType, Type ResultType = OperandType> : StableHLO_Op<mnemonic, traits # [Elementwise,
    InferShapedTypeOpInterface, SameOperandsAndResultShape]> {
  let arguments = (ins OperandType:$operand);
  let results = (outs ResultType:$result);
  let extraClassDeclaration = [{
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    // Relax the strict default implementation with one that allows
    // for StableHLO-specific differences.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];
  let extraClassDefinition = [{
    ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) {
      return ::mlir::stablehlo::parseUnaryOp(parser, result);
    }
    void $cppClass::print(OpAsmPrinter &p) {
      ::mlir::stablehlo::printUnaryOp(getOperation(), p);
    }
  }];
  let hasCustomAssemblyFormat = 1;
}

// Abs supports complex to real, so element type is not guaranteed to match.
def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",
    [NoSideEffect,
     DeclareOpInterfaceMethods<InferTypeOpInterface>],
     TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> {
  let summary = "Absolute value operator";
  let description = [{
    Returns `abs(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Cube root operator";
  let description = [{
    Returns element-wise cube root of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_CeilOp: StableHLO_UnaryElementwiseOp<"ceil",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Ceil operator";
  let description = [{
    Returns `Ceil(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_ConvertOp : StableHLO_UnaryElementwiseOp<"convert",
    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor> {
  let summary = "Convert operator";
  let description = [{
    Performs element-wise conversion of values from one type to another, e.g.
    float to int.

    See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
  }];
  let builders = [
    OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)>];
}

def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
  let summary = "Count-leading-zeros (Clz) operator";
  let description = [{
    Returns the number of leading zeros in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Cos operator";
  let description = [{
    Returns `Cos(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Exponential operator";
  let description = [{
    Returns `e^(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Exponential minus one operator";
  let description = [{
    Returns `e^(operand) - 1` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_FloorOp: StableHLO_UnaryElementwiseOp<"floor",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Floor operator";
  let description = [{
    Returns `Floor(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_ImagOp: StableHLO_UnaryElementwiseOp<"imag",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_FpOrComplexTensor> {
  let summary = "Imag operator";
  let description = [{
    Returns `Imag(operand)` element-wise.
  }];
  let results = (outs HLO_FpTensor);
}

def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>], HLO_Tensor> {
  let summary = "IsFinite operator";
  let description = [{
    Tests whether each element of operand is finite, i.e., is not positive or
    negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
    the same shape as the input, where each element is nonzero (i.e. true) if
    and only if the corresponding input element is finite.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
  let arguments = (ins HLO_FpTensor:$x);
  let results = (outs HLO_PredTensor:$y);
}

def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Logarithm operator";
  let description = [{
    Returns `log(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Log1p operator";
  let description = [{
    Returns `log(operand+1)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Logistic operator";
  let description = [{
    Returns `logistic(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_PredOrIntTensor> {
  let summary = "Not operator";
  let description = [{
    Returns bitwise-NOT of `operand` element-wise. The input tensor must be
    of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-NOT is equivalent to logical-NOT.
  }];
}

def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
  let summary = "Negation operator";
  let description = [{
    Returns `-operand` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_PopulationCountOp: StableHLO_UnaryElementwiseOp<"popcnt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
  let summary = "PopulationCount operator";
  let description = [{
    Returns the number of bits set in each operand element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_RealOp: StableHLO_UnaryElementwiseOp<"real",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_FpOrComplexTensor> {
  let summary = "Real operator";
  let description = [{
    Returns `Real(operand)` element-wise.
  }];
  let results = (outs HLO_FpTensor);
}

def StableHLO_RoundOp: StableHLO_UnaryElementwiseOp<"round_nearest_afz",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Round operator, ties away from zero";
  let description = [{
    Returns `Round(operand)` element-wise, rounding to nearest integer with
    half-way cases rounding away from zero.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_even",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> {
  let summary = "Round operator, ties to even";
  let description = [{
    Returns `Round(operand)` element-wise, rounding to nearest integer with
    half-way cases rounding towards even numbers.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Reciprocal Square-root operator";
  let description = [{
    Returns `1.0 / sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType],
    TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> {
  let summary = "Sign operator";
  let description = [{
    Returns `sign(operand)` element-wise, where

    ```
    sign(x) = -1  : x < 0
            = -0  : x = -0
            = NaN : x = NaN
            = +0  : x = +0
            = 1   : x > 0
    ```

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Sin operator";
  let description = [{
    Returns `Sin(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
  let summary = "Square-root operator";
  let description = [{
    Returns `sqrt(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType],
    HLO_FpOrComplexTensor> {
  let summary = "Tanh operator";
  let description = [{
    Returns `tanh(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}
//===----------------------------------------------------------------------===//
// StableHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

// Base class for binary elementwise ops that supply their own assembly format
// (or use the generic one). Operands and result share one shape; the result
// shape is reified from the first operand.
class StableHLO_BinaryElementwiseOpNoAssembly<string mnemonic, list<Trait> traits> :
    StableHLO_Op<mnemonic, traits # [InferShapedTypeOpInterface,
    SameOperandsAndResultShape, Elementwise]> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs
  );

  let extraClassDeclaration = [{
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    // Relax the strict default implementation with one that allows
    // for StableHLO-specific differences.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];

  let results = (outs HLO_Tensor:$result);
}

// Binary elementwise op whose custom assembly format is delegated to
// parseBinaryOp / printBinaryOp.
class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits> :
    StableHLO_BinaryElementwiseOpNoAssembly<mnemonic, traits> {
  let extraClassDefinition = [{
    ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) {
      return ::mlir::stablehlo::parseBinaryOp(parser, result);
    }
    void $cppClass::print(OpAsmPrinter &p) {
      ::mlir::stablehlo::printBinaryOp(getOperation(), p);
    }
  }];
  let hasCustomAssemblyFormat = 1;
}

def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Addition operator";
  let description = [{
    Returns `lhs + rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Atan2 operator";
  let description = [{
    Returns `atan2(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOpNoAssembly<"complex", [NoSideEffect,
    SameOperandsElementType, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Complex operator";
  let description = [{
    Performs element-wise conversion of a pair of real and imaginary values to
    a complex value.
  }];
  let arguments = (ins HLO_Fp32Or64Tensor:$lhs, HLO_Fp32Or64Tensor:$rhs);
  let results = (outs HLO_ComplexTensor:$result);

  // TODO(b/241767457): Remove parens when cleaning up BinaryOps.
  let assemblyFormat = "`(`operands`)` attr-dict `:` `(`type(operands)`)` `->` type($result)";
}

def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Division operator";
  let description = [{
    Returns `lhs / rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Maximum operator";
  let description = [{
    Returns `max(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Minimum operator";
  let description = [{
    Returns `min(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply",
      [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Multiplication operator";
  let description = [{
    Returns `lhs * rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Power operator";
  let description = [{
    Returns `lhs ^ rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Remainder operator";
  let description = [{
    Returns `lhs % rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift Left operator";
  let description = [{
    Returns `lhs << rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_ShiftRightArithmeticOp : StableHLO_BinaryElementwiseOp<"shift_right_arithmetic",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift right arithmetic operator";
  let description = [{
    Returns arithmetic `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_logical",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Shift right logical operator";
  let description = [{
    Returns logical `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract",
      [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Subtraction operator";
  let description = [{
    Returns `lhs - rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

//===----------------------------------------------------------------------===//
// StableHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// Base class for commutative bitwise/logical binary ops; operands are
// restricted to predicate or integer tensors.
class StableHLO_BinaryBiwiseOrLogicalElementwiseOp<string mnemonic> :
        StableHLO_BinaryElementwiseOp<mnemonic,
          [Commutative, NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs
  );
}

def StableHLO_AndOp: StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> {
  let summary = "And operator";
  let description = [{
    Returns bitwise-AND of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-AND is equivalent to logical-AND.
  }];
}

def StableHLO_OrOp: StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> {
  let summary = "Or operator";
  let description = [{
    Returns bitwise-OR of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-OR is equivalent to logical-OR.
  }];
}

def StableHLO_XorOp : StableHLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> {
  let summary = "Xor operator";
  let description = [{
    Returns bitwise-XOR of `lhs` and `rhs` element-wise. The input tensors must
    be of type integer `HLO_Int` or boolean `HLO_Pred`.

    Note: For boolean tensor, the bitwise-XOR is equivalent to logical-XOR.
  }];
}

//===----------------------------------------------------------------------===//
// StableHLO communication op definitions.
//===----------------------------------------------------------------------===//

// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def StableHLO_InfeedOp : StableHLO_Op<"infeed", []> {

  let summary = "Infeed operator";

  let description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape, and returns the
    data. Multiple Infeed operations are allowed in a computation, but there
    must be a total order among the Infeed operations.

    Attributes:
      layout:  Array attribute. Each element of the array is a minor_to_major
               array corresponding to the shape of the data read from the infeed
               interface.

    See https://www.tensorflow.org/xla/operation_semantics#infeed.
  }];

  let arguments = (ins
    HLO_Token:$token,
    DefaultValuedStrAttr<StrAttr, "">:$infeed_config,
    OptionalAttr<ArrayAttr>:$layout
  );
  let results = (outs Variadic<HLO_TensorOrToken>);
  let hasVerifier = 1;
}

// OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'.
// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens.
def StableHLO_OutfeedOp : StableHLO_Op<"outfeed", []> {

  let summary = "Outfeed operator";

  let description = [{
    Generates outgoing data transfers for the given data. It takes data and a
    token type operand and produces a token type value. Tokens are used for
    ordering side-effecting operations.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];

  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    HLO_Token:$token,
    DefaultValuedStrAttr<StrAttr, "">:$outfeed_config
  );
  let results = (outs HLO_Token);
}

def StableHLO_SendOp : StableHLO_Op<"send", []> {

  let summary = "Send operator";

  let description = [{
    Sends the given operand data to a Recv instruction in another computation
    that shares the same channel handle. Does not return any data; the only
    result is a token. Similar to the Recv operation, Send operation represents
    synchronous communication, and is internally decomposed into 2 HLO
    instructions (Send and SendDone) to enable asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#send.
  }];

  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    HLO_Token:$token,
    StableHLO_ChannelHandle:$channel_handle,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Token);
}

def StableHLO_RecvOp : StableHLO_Op<"recv", []> {

  let summary = "Recv operator";

  let description = [{
    Receives data of the given shape from a Send instruction in another
    computation that shares the same channel handle. Returns the received data
    and a token as separate results. Recv operation represents synchronous
    communication. However, the instruction is internally decomposed into 2
    HLO instructions (Recv and RecvDone) to enable asynchronous data
    transfers.

    See https://www.tensorflow.org/xla/operation_semantics#recv.
  }];

  let arguments = (ins
    HLO_Token:$token,
    StableHLO_ChannelHandle:$channel_handle,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs Variadic<HLO_TensorOrToken>);
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// StableHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_ReplicaIdOp : StableHLO_Op<"replica_id", [NoSideEffect,
    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "ReplicaId operator";
  let description = [{
    Returns the unique ID (uint32 scalar) of the replica.

    The unique ID of each replica is an unsigned integer in the interval [0, N),
    where N is the number of replicas. Since all the replicas are running the
    same program, a ReplicaId() call in the program will return a different
    value on each replica.

    See https://www.tensorflow.org/xla/operation_semantics#replicaid.

    Example:

    ```mlir
    %0 = stablehlo.replica_id : tensor<ui32>
    ```
  }];
  let results = (outs TensorOf<[UI32]>);

  let assemblyFormat = "attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// StableHLO control flow op definitions.
//===----------------------------------------------------------------------===//

def StableHLO_AfterAllOp : StableHLO_Op<"after_all", [NoSideEffect]> {

  let summary = "AfterAll operator";

  let description = [{
    AfterAll takes a variadic number of tokens and produces a single token.
    Tokens are primitive types which can be threaded between side-effecting
    operations to enforce ordering. AfterAll can be used as a join of tokens
    for ordering an operation after a set of operations.

    See https://www.tensorflow.org/xla/operation_semantics#afterall.

    Example:

    ```mlir
    %0 = stablehlo.after_all %arg0, %arg1 : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
    ```
  }];

  let arguments = (ins Variadic<HLO_Token>:$operands);
  let results = (outs HLO_Token);

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. IfOp maps to predicated
// conditional use of kConditional HLO.
def StableHLO_IfOp: StableHLO_Op<"if", [
    RecursiveSideEffects,
    SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = "If operator";

  let description = [{
    Executes the function `true_branch` if `pred` is true or `false_branch` if
    `pred` is false, and returns the result.

    The type of the returned values of `true_branch` and `false_branch`
    functions must be the same and equal to the types of the values returned by
    the operation.

    Note that only one of two functions will be executed depending on the value
    of `pred`.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred
  );

  let regions = (region SizedRegion<1>:$true_branch,
                        SizedRegion<1>:$false_branch);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let hasVerifier = 1;
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. CaseOp maps to indexed
// conditional use of kConditional HLO.
def StableHLO_CaseOp: StableHLO_Op<"case", [
      RecursiveSideEffects,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]> {
  let summary = "Switch-Case operator";
  let description = [{
    Returns the result of executing `branches[index]`. If `index` is < 0 or >=
    N, then `branches[N-1]` is executed as the default branch.

    The type of the returned values of each branch must be the same and equal
    to the types of the values returned by the operation.

    Note that only one of the branches will be executed depending on the value
    of index.
  }];

  let arguments = (ins
    I32Tensor:$index
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let hasVerifier = 1;
}


def StableHLO_WhileOp: StableHLO_Op<"while", [
      RecursiveSideEffects,
      HLO_PairwiseSameOperandAndResultType,
      SingleBlockImplicitTerminator<"ReturnOp">,
      OpAsmOpInterface
    ]> {
  let summary = "While operator";
  let description = [{
    Returns the result of repeatedly executing the `body` region for as long as
    the `cond` region returns true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
  let arguments = (ins Variadic<HLO_TensorOrToken>:$operand);

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let results = (outs Variadic<HLO_TensorOrToken>);

  let extraClassDeclaration = [{
    // Method of OpAsmOpInterface used during custom printing to name the block
    // arguments in the nested regions. We name both the condition and the body
    // regions entry arguments the same way, with an `iterArg` prefix. Since the
    // two regions are side-by-side they will have the same name, which allows
    // us to print them once and share it for the two regions, and still be able
    // to parse them back.
    void getAsmBlockArgumentNames(Region &region, OpAsmSetValueNameFn setNameFn) {
      for (BlockArgument arg : region.getArguments())
        setNameFn(arg, "iterArg");
    }
  }];
  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;
}

def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", [SameOperandsAndResultElementType]> {

  // `let` overrides the inherited fields; this is the idiomatic ODS form.
  let summary = "AllGather operator";

  let description = [{
    Performs concatenation across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allgather.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$all_gather_dim,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}

def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
    [HLO_CompatibleOperandsAndResultType]> {
  let summary = "AllReduce operator";
  let description = [{
    Performs a custom reduction across replicas.

    See https://www.tensorflow.org/xla/operation_semantics#allreduce.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle,
    UnitAttr:$use_global_device_ids
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  // use_global_device_ids is rarely used, so we add a simplified
  // builder method for convenience.
  let builders = [
    OpBuilder<(ins
      "::mlir::Type":$result_type, "::mlir::Value":$operand,
      "::mlir::DenseIntElementsAttr":$replica_groups,
      "::mlir::stablehlo::ChannelHandleAttr":$channel_handle)>];
}

def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
    [SameOperandsAndResultElementType]> {
  let summary = "ReduceScatter operator";
  let description = [{
     Performs all_reduce followed by a scatter.

     See https://www.tensorflow.org/xla/operation_semantics#reducescatter.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$scatter_dimension,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<StableHLO_ChannelHandle>:$channel_handle
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}

def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
    [NoSideEffect, SameOperandsElementType, SameOperandsShape,
     InferTensorType]> {
  let summary = "AllToAll operator";
  let description = [{
    Splits `operand` into `split_count` blocks along `split_dimension`,
    scatters the blocks to the participating replicas, and concatenates the
    received blocks along `concat_dimension`.

    See https://www.tensorflow.org/xla/operation_semantics#alltoall.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$split_dimension,
    I64Attr:$concat_dimension,
    I64Attr:$split_count,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}

def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]> {
  let summary = "Reduce operator";
  let description = [{
    Returns the result of executing a reduction function on one or more arrays
    in parallel.

    See https://www.tensorflow.org/xla/operation_semantics#reduce.
  }];
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    Variadic<HLO_Tensor>:$init_values,
    I64ElementsAttr:$dimensions
  );

  let results = (outs Variadic<HLO_Tensor>);

  let builders = [
    OpBuilder<(ins "ValueRange":$operands, "ValueRange":$init_values,
      "DenseIntElementsAttr":$dimensions)>];

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);
}

//===----------------------------------------------------------------------===//
// StableHLO tuple op definitions.
1063//===----------------------------------------------------------------------===// 1064def StableHLO_GetTupleElementOp: StableHLO_Op<"get_tuple_element", [NoSideEffect, 1065 DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 1066 let summary = "GetTupleElement operator"; 1067 let description = [{ 1068 Returns a member of a tuple specified by an index. 1069 1070 See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. 1071 }]; 1072 let arguments = (ins 1073 HLO_Tuple, 1074 I32Attr:$index 1075 ); 1076 1077 let results = (outs HLO_TensorOrTokenOrTuple); 1078 1079 let hasVerifier = 1; 1080} 1081 1082def StableHLO_TupleOp : StableHLO_Op<"tuple", [NoSideEffect, 1083 DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 1084 let summary = "XLA's tuple op"; 1085 let description = [{ 1086 Groups a set of tensor inputs into a single tuple object. 1087 1088 See https://www.tensorflow.org/xla/operation_semantics#tuple. 1089 }]; 1090 let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val); 1091 let results = (outs HLO_Tuple); 1092 1093 let hasVerifier = 1; 1094} 1095 1096def StableHLO_CompareOp: StableHLO_Op<"compare", [NoSideEffect, SameOperandsElementType, 1097 SameOperandsAndResultShape, Elementwise, InferTensorTypeWithReify]> { 1098 let summary = "Comparison operator"; 1099 let description = [{ 1100 Compares `lhs` and `rhs` elementwise according to `comparison_direction` 1101 and `compare_type`. If unspecified, `compare_type` is FLOAT for float element 1102 types, SIGNED for signed element types and UNSIGNED for unsigned element 1103 types. 1104 1105 See 1106 https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
1107 1108 Example: 1109 1110 ```mlir 1111 %0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> 1112 %1 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> 1113 ``` 1114 }]; 1115 let arguments = (ins 1116 HLO_Tensor:$lhs, 1117 HLO_Tensor:$rhs, 1118 StableHLO_ComparisonDirectionAttr:$comparison_direction, 1119 OptionalAttr<StableHLO_ComparisonTypeAttr>:$compare_type 1120 ); 1121 let results = (outs HLO_PredTensor); 1122 1123 let builders = [ 1124 OpBuilder<(ins "Value":$lhs, "Value":$rhs, 1125 "::mlir::stablehlo::ComparisonDirection":$comparison_direction, 1126 CArg<"::mlir::stablehlo::ComparisonType", 1127 "::mlir::stablehlo::ComparisonType::NOTYPE">:$compare_type)>, 1128 ]; 1129 1130 let extraClassDeclaration = [{ 1131 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1132 return succeeded(mlir::verifyCompatibleShapes(l, r)); 1133 } 1134 }]; 1135 1136 let assemblyFormat = [{ 1137 $comparison_direction `,` $lhs `,` $rhs (`,` $compare_type^)? 1138 attr-dict `:` functional-type(operands, results) 1139 }]; 1140} 1141 1142//===----------------------------------------------------------------------===// 1143// StableHLO Slice definitions. 
1144//===----------------------------------------------------------------------===// 1145 1146def StableHLO_SliceOp: StableHLO_Op< 1147 "slice", 1148 [NoSideEffect, SameOperandsAndResultElementType, 1149 AllTypesMatch<["start_indices", "limit_indices", "strides"]>, 1150 DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 1151 let arguments = (ins 1152 HLO_Tensor:$operand, 1153 I64ElementsAttr:$start_indices, 1154 I64ElementsAttr:$limit_indices, 1155 I64ElementsAttr:$strides 1156 ); 1157 1158 let results = (outs HLO_Tensor); 1159} 1160 1161def StableHLO_DynamicSliceOp: StableHLO_Op<"dynamic_slice", 1162 [NoSideEffect, AllElementTypesMatch<["operand", "result"]>, 1163 InferTensorType]> { 1164 let summary = "Dynamic Slice operator"; 1165 let description = [{ 1166 Extracts a sub-array from the input array at dynamic start_indices. 1167 1168 See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. 1169 }]; 1170 let arguments = (ins 1171 HLO_Tensor:$operand, 1172 Variadic<HLO_ScalarIntTensor>:$start_indices, 1173 I64ElementsAttr:$slice_sizes 1174 ); 1175 1176 let results = (outs HLO_Tensor:$result); 1177 let hasVerifier = 1; 1178} 1179 1180def StableHLO_DynamicUpdateSliceOp: StableHLO_Op<"dynamic_update_slice", 1181 [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>, 1182 AllShapesMatch<["operand", "result"]>]> { 1183 let summary = "Dynamic Update Slice operator"; 1184 let description = [{ 1185 DynamicUpdateSlice generates a result which is the value of the input array 1186 operand, with a slice update overwritten at start_indices. 1187 1188 See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. 
1189 1190 Example: 1191 1192 ```mlir 1193 %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2 1194 : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32> 1195 ``` 1196 }]; 1197 let arguments = (ins 1198 HLO_Tensor:$operand, 1199 HLO_Tensor:$update, 1200 Variadic<HLO_ScalarIntTensor>:$start_indices 1201 ); 1202 let results = (outs HLO_Tensor:$result); 1203 let hasVerifier = 1; 1204 1205 let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; 1206} 1207 1208 1209//===----------------------------------------------------------------------===// 1210// StableHLO Other op definitions. 1211//===----------------------------------------------------------------------===// 1212 1213def StableHLO_BatchNormGradOp : StableHLO_Op<"batch_norm_grad", [NoSideEffect, 1214 AllShapesMatch<["scale", "mean", "variance", "grad_scale", 1215 "grad_offset"]>, 1216 AllShapesMatch<["operand", "grad_output"]>, 1217 AllElementTypesMatch<["operand", "grad_scale", "grad_offset"]>, 1218 AllTypesMatch<["operand", "grad_operand"]>]> { 1219 let summary = "Batch Normalization Gradient"; 1220 let description = [{ 1221 Calculates gradients of batch norm. 
1222 1223 See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad 1224 }]; 1225 1226 let arguments = (ins 1227 RankedTensorOf<[HLO_Float]>:$operand, 1228 1DTensorOf<[HLO_Float]>:$scale, 1229 1DTensorOf<[HLO_Float]>:$mean, 1230 1DTensorOf<[HLO_Float]>:$variance, 1231 RankedTensorOf<[HLO_Float]>:$grad_output, 1232 F32Attr:$epsilon, 1233 I64Attr:$feature_index 1234 ); 1235 1236 let results = (outs Variadic<HLO_TensorOrToken>); 1237 let results = (outs 1238 RankedTensorOf<[HLO_Float]>:$grad_operand, 1239 1DTensorOf<[HLO_Float]>:$grad_scale, 1240 1DTensorOf<[HLO_Float]>:$grad_offset); 1241 1242 let hasVerifier = 1; 1243} 1244 1245def StableHLO_BatchNormInferenceOp : StableHLO_Op<"batch_norm_inference", 1246 [NoSideEffect, AllTypesMatch<["operand", "result"]>, 1247 AllShapesMatch<["scale", "offset", "mean", "variance"]>]> { 1248 let summary = "Batch Normalization for Inference"; 1249 let description = [{ 1250 Normalizes an array across batch and spatial dimensions. 1251 1252 See https://www.tensorflow.org/xla/operation_semantics#batchnorminference 1253 }]; 1254 1255 let arguments = (ins 1256 RankedTensorOf<[HLO_Float]>:$operand, 1257 1DTensorOf<[HLO_Float]>:$scale, 1258 1DTensorOf<[HLO_Float]>:$offset, 1259 1DTensorOf<[HLO_Float]>:$mean, 1260 1DTensorOf<[HLO_Float]>:$variance, 1261 F32Attr:$epsilon, 1262 I64Attr:$feature_index 1263 ); 1264 1265 let results = (outs RankedTensorOf<[HLO_Float]>:$result); 1266 1267 let hasVerifier = 1; 1268} 1269 1270def StableHLO_BatchNormTrainingOp : StableHLO_Op<"batch_norm_training", 1271 [NoSideEffect, AllTypesMatch<["operand", "output"]>, 1272 AllElementTypesMatch<["operand", "batch_mean", "batch_var"]>, 1273 AllShapesMatch<["scale", "offset", "batch_mean", "batch_var"]>]> { 1274 let summary = "Batch Normalization for Training"; 1275 let description = [{ 1276 Normalizes an array across batch and spatial dimensions. 
1277 1278 See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining 1279 }]; 1280 1281 let arguments = (ins 1282 RankedTensorOf<[HLO_Float]>:$operand, 1283 1DTensorOf<[HLO_Float]>:$scale, 1284 1DTensorOf<[HLO_Float]>:$offset, 1285 F32Attr:$epsilon, 1286 I64Attr:$feature_index 1287 ); 1288 1289 let results = (outs 1290 RankedTensorOf<[HLO_Float]>:$output, 1291 1DTensorOf<[HLO_Float]>:$batch_mean, 1292 1DTensorOf<[HLO_Float]>:$batch_var); 1293 1294 let hasVerifier = 1; 1295} 1296 1297def StableHLO_BitcastConvertOp : StableHLO_ShapedInterfaceOp<"bitcast_convert", 1298 [NoSideEffect]> { 1299 let summary = "BitcastConvert operator"; 1300 let description = [{ 1301 Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast 1302 operation from a data shape to a target shape. The dimensions must match, 1303 and the conversion is an element-wise one. Bitcast is implemented as a 1304 low-level cast, so machines with different floating-point representations 1305 will give different results. 1306 1307 See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. 1308 1309 Example: 1310 1311 ```mlir 1312 %0 = stablehlo.bitcast_convert %arg0 : (tensor<2xi32>) -> tensor<2xf32> 1313 ``` 1314 }]; 1315 1316 let arguments = (ins HLO_Tensor:$operand); 1317 let results = (outs HLO_Tensor); 1318 let hasVerifier = 1; 1319 1320 let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; 1321} 1322 1323def StableHLO_BroadcastOp : StableHLO_ShapedInterfaceOp<"broadcast", 1324 [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> { 1325 let summary = "Broadcast a tensor to a higher rank by prepending dimensions"; 1326 let description = [{ 1327 Broadcasts the operand tensor to a higher rank by prepending 1328 `broadcast_sizes` to the dimensions. The current values of the operand are 1329 copied into the other dimensions. 
1330 1331 This is a more limited form of broadcasting, that corresponds to the XLA 1332 client Broadcast method. For a more general form of broadcasting, see the 1333 BroadcastInDimOp. 1334 1335 See https://www.tensorflow.org/xla/operation_semantics#broadcast. 1336 }]; 1337 let arguments = (ins 1338 HLO_Tensor:$operand, 1339 I64ElementsAttr:$broadcast_sizes 1340 ); 1341 1342 let results = (outs HLO_Tensor); 1343 1344 let hasVerifier = 1; 1345} 1346 1347def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim", 1348 [NoSideEffect, SameOperandsAndResultElementType]> { 1349 let summary = "Broadcast a tensor into the given shape by adding dimensions."; 1350 let description = [{ 1351 Broadcasts the `operand` tensor to a higher rank. This is not the limited 1352 form of broadcasting exposed as the XLA client broadcast op, but rather the 1353 more powerful "InDim" broadcasting, which is closer to the HLO broadcast op 1354 and exposed in the XLA client BroadcastInDim method. 1355 1356 `broadcast_dimensions` maps the operand dimension number to the target shape 1357 dimension number. It must have the same size as the rank of the operand. The 1358 mapped dimensions must either be the same size or the dimension being 1359 broadcast from must be size 1 (degenerate broadcasting). 1360 1361 For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The 1362 The scalar value will be broadcast to every element in the target shape. 1363 1364 See https://www.tensorflow.org/xla/broadcasting. 
1365 }]; 1366 let arguments = (ins 1367 HLO_Tensor:$operand, 1368 BroadcastDimAttr:$broadcast_dimensions 1369 ); 1370 1371 let results = (outs HLO_StaticShapeTensor); 1372 1373 let hasVerifier = 1; 1374} 1375 1376def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp< 1377 "dynamic_broadcast_in_dim", [NoSideEffect]> { 1378 let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; 1379 let description = [{ 1380 This is a generalization of the BroadcastInDimOp which accepts its output 1381 dimensions as an argument. It should eventually supercede the statically 1382 shaped original, but is being phased as a separate op in order to support 1383 compatibility with lowerings and translations that precede dynamic shapes. 1384 1385 The op accepts optional attributes to express static knowledge about the 1386 expanding behavior of dimensions. If not specified, all dimensions are 1387 assumed to be possibly expanding. The sets of dimensions that are known to 1388 be expanding and the set of dimensions that are known to be non-expanding 1389 must be disjoint and they must be a subset of the operand's dimensions. 
}];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions,
    OptionalAttr<BroadcastDimAttr>:$known_expanding_dimensions,
    OptionalAttr<BroadcastDimAttr>:$known_nonexpanding_dimensions
  );

  let results = (outs HLO_Tensor);

  // Convenience builder that leaves both optional known_*_dimensions
  // attributes unset by forwarding empty attributes to the full builder.
  let builders = [
    OpBuilder<(ins
      "Type":$result_type, "Value":$operand, "Value":$output_dimensions,
      "DenseIntElementsAttr":$broadcast_dimensions), [{
      build($_builder, $_state, result_type, operand, output_dimensions,
          broadcast_dimensions, /*known_expanding_dimensions=*/{},
          /*known_nonexpanding_dimensions=*/{});
    }]>
  ];

  let hasVerifier = 1;
}

// Note: There is no HLO_CallOp because the standard call operation mlir::func::CallOp
// is used instead. A mlir::func::CallOp is exported to a HLO call instruction
// directly.

def StableHLO_CholeskyOp : StableHLO_Op<"cholesky",
      [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> {
  let summary = "Cholesky operator";
  let description = [{
    Computes the Cholesky decomposition of a batch of symmetric (Hermitian)
    positive definite matrices.

    If lower is true, computes lower-triangular matrices l such that
    `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such
    that `a=Transpose(u).u`.

    Input data is read only from the lower/upper triangle of a, depending on the
    value of lower. Values from the other triangle are ignored. Output data is
    returned in the same triangle; the values in the other triangle are
    implementation-defined and may be anything.

    If the rank of a is greater than 2, a is treated as a batch of matrices, where
    all except the minor 2 dimensions are batch dimensions.

    If a is not symmetric (Hermitian) positive definite, the result is
    implementation-defined.

    See https://www.tensorflow.org/xla/operation_semantics#cholesky.
  }];
  // `lower` defaults to false, i.e. the upper-triangular variant described above.
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );

  let results = (outs HLO_FpOrComplexTensor);
}

def StableHLO_ClampOp : StableHLO_ShapedInterfaceOp<"clamp", [NoSideEffect,
  SameOperandsAndResultElementType, HLO_BroadcastingElementwise,
  InferTensorType]> {
  let summary = "Clamp operator";
  let description = [{
    Clamps an operand to within the range between a minimum and maximum value.

    Note: All three arrays must be the same shape. Alternatively, as a
          restricted form of broadcasting, min and/or max can be a scalar (0D
          tensor) of the element type of the tensor operand.

    See https://www.tensorflow.org/xla/operation_semantics#clamp.

    Example:

    ```mlir
    %0 = stablehlo.clamp %arg0, %arg1, %arg2 : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
    ```
  }];

  // Operand order is (min, operand, max), matching the example above.
  let arguments = (ins
    HLO_Tensor:$min,
    HLO_Tensor:$operand,
    HLO_Tensor:$max
  );
  let results = (outs HLO_Tensor);

  let hasVerifier = 1;

  // Inferred and declared result types are compatible when the lists have the
  // same length and every pair passes hlo::isCompatibleForHloTypeInference.
  let extraClassDeclaration = [{
    // Method from InferTypeOpInterface interface.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      if (l.size() != r.size()) return false;
      for (auto [lt, rt] : llvm::zip(l, r))
        if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt))
          return false;
      return true;
    }
  }];

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",
    [NoSideEffect, SameOperandsAndResultElementType,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "XLA's concatenate op";
  let description = [{
    Concatenates a set of tensors along the specified dimension.
1499 1500 See https://www.tensorflow.org/xla/operation_semantics#concatenate. 1501 }]; 1502 1503 let arguments = (ins 1504 Variadic<HLO_Tensor>:$val, 1505 I64Attr: $dimension 1506 ); 1507 1508 let results = (outs HLO_Tensor); 1509 1510 let hasVerifier = 1; 1511 1512 let extraClassDeclaration = [{ 1513 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1514 return succeeded(mlir::verifyCompatibleShapes(l, r)); 1515 } 1516 }]; 1517} 1518 1519def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute", 1520 [NoSideEffect, HLO_CompatibleOperandsAndResultType]> { 1521 let summary = "CollectivePermute operator"; 1522 let description = [{ 1523 CollectivePermute is a collective operation that sends and receives data 1524 cross replicas. 1525 Note that there are the following restrictions on the source_target_pair: 1526 - Any two pairs should not have the same target replica id, and they should 1527 not have the same source replica id. 1528 - If a replica id is not a target in any pair, then the output on that 1529 replica is a tensor consists of 0(s) with the same shape as the input. 1530 1531 See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. 1532 1533 }]; 1534 1535 let arguments = (ins 1536 HLO_Tensor:$operand, 1537 I64ElementsAttr:$source_target_pairs 1538 ); 1539 let results = (outs HLO_Tensor); 1540 let hasVerifier = 1; 1541} 1542 1543def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [NoSideEffect]> { 1544 let summary = "Convolution operator"; 1545 let description = [{ 1546 Computes a convolution of the kind used in neural networks. 1547 1548 See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. 
}];
  // lhs/rhs plus the shared convolution attribute set (window strides,
  // padding, dilations, window_reversal, dimension_numbers, ...).
  let arguments = !con(
    (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs),
    StableHLO_ConvolutionAttributes.attributes);

  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  // Set the inherited Op field with `let`, like every other op in this file,
  // rather than redeclaring it as a `code` field.
  let extraClassDeclaration = [{
    // Returns true if `window_reversal` is present and at least one of its
    // entries is true.
    bool hasWindowReversal() {
      auto reversal = window_reversalAttr();
      return reversal && llvm::any_of(reversal.getValues<bool>(),
                                      [](bool v) { return v; });
    }
  }];

  let assemblyFormat = [{
    `(`operands`)`
       `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
       `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
                                               $lhs_dilation, $rhs_dilation,
                                               $window_reversal) `}`
       attr-dict `:` functional-type(operands, results)
  }];
}

def StableHLO_CrossReplicaSumOp : StableHLO_Op<"cross-replica-sum",
    [NoSideEffect, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Sums input across replicated instances.";
  let description = [{
     For each of the replica groups, operands of the group devices are summed
     so that each device has the sum.

     For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
     Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
     and `B, D, F, H` as group 1. Thus we get the outputs:
     `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.

     See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups
  );

  let results = (outs HLO_Tensor);
}

def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
  let summary = "CustomCall operator";
  let description = [{
    A custom call invokes code external to XLA. The `operands` are passed to the
    external code, and the external code is expected to produce a result of the
    given type. The exact mechanism is backend-specific. For example, in the CPU
    backend, a call instruction is emitted which targets a symbol with the name
    `call_target_name`.

    `call_target_name` and `backend_config` can be arbitrary strings, but
    `call_target_name` should be short as it may be used in labels.
    `backend_config` can encode arbitrarily large amounts of information.

    `has_side_effect` must be true if the custom call has side-effects.
    `api_version` specifies the version of the API used by the custom call
    function.

    A custom call may apply functions within the scope of the parent module.
    They can be referenced using `called_computations` attribute.

    A custom call can also have layout constraints on operands and results which
    can be specified as optional `operand_layouts` and `result_layouts`
    attributes. The layout attribute is an array of rank-1 index tensors and the
    i-th layout attribute specifies the layout for i-th operand/result.

    The `operand_layouts` & `result_layouts` attributes can be specified under
    the following constraints:
    1) Either both `operand_layouts` and `result_layouts` are specified or none.
    2) None of the operands are of tuple type.
    3) None of the results are of tuple type except the common case of single
       tuple result packing non-tuple values is allowed. In this case the i-th
       `result_layouts` attribute specifies the layout of i-th element in the
       result tuple.

    See https://www.tensorflow.org/xla/operation_semantics#customcall.
  }];
  let arguments = (ins
    Variadic<HLO_TensorOrTokenOrTuple>:$operands,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedStrAttr<StrAttr, "">:$backend_config,
    // TODO(b/189822916): Remove this field when all clients are migrated to
    // the status-returning API.
    DefaultValuedAttr<
        StableHLO_CustomCallApiVersionAttr,
        "::mlir::stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL">:
        $api_version,
    DefaultValuedAttr<StableHLO_FlatSymbolRefArrayAttr, "{}">:$called_computations,
    OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts,
    OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts
  );
  let results = (outs Variadic<HLO_TensorOrTokenOrTuple>);
  let hasVerifier = 1;
}

def StableHLO_DotOp: StableHLO_Op<"dot",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Dot operator";
  let description = [{
    Performs dot products between vectors, vector/matrix and matrix/matrix
    multiplication.

    See https://www.tensorflow.org/xla/operation_semantics#dot.
  }];
  let arguments = (
    ins HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StableHLO_PrecisionConfigAttr:$precision_config
  );
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;

  // Inferred and declared result types only need compatible shapes.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}

def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general", [NoSideEffect]> {
  let summary = "General Dot operator";
  let description = [{
    Performs general dot products between vectors, vector/matrix and
    matrix/matrix multiplication.

    See https://www.tensorflow.org/xla/operation_semantics#dotgeneral.
}];
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StableHLO_DotDimensionNumbers:$dot_dimension_numbers,
    StableHLO_PrecisionConfigAttr:$precision_config
  );

  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}

// Define Base Einsum op within the HLO dialect as these are client ops and
// therefore this class is not common between HLO and LHLO ops.
// (This is a plain mixin class, so it declares its fields with `string`; the
// defs below inherit it alongside StableHLO_Op.)
class BASE_EinsumOp {
  string summary = "Einsum operator";

  string description = [{
    Returns a tensor whose elements are defined by equation, which is written
    in a shorthand form inspired by the Einstein summation convention.
  }];
}

def StableHLO_EinsumOp: StableHLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}

def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);
}

def StableHLO_FftOp: StableHLO_Op<"fft", [InferTensorType, NoSideEffect]> {
  let summary = "Fast fourier transform operator";
  let description = [{
    Returns the fast-fourier-transform of the input array.

    See
    https://www.tensorflow.org/xla/operation_semantics#fft.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    StableHLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );

  let results = (outs HLO_Tensor);
}

def StableHLO_GatherOp: StableHLO_Op<"gather", [InferTensorTypeWithReify, NoSideEffect]> {
  let summary = "Gather operator";
  let description = [{
    Stitches together several slices of `operand` from offsets specified in
    `start_indices` (each slice at a potentially different runtime offset).

    See https://www.tensorflow.org/xla/operation_semantics#gather.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    StableHLO_GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );

  let results = (outs HLO_Tensor);

  // Inferred and declared result types only need compatible shapes.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}

def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size", [NoSideEffect]> {
  let summary = "GetDimensionSize operator";
  let description = [{
    Returns the size of the given dimension of the operand.

    See
    https://www.tensorflow.org/xla/operation_semantics#getdimensionsize.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$dimension
  );
  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
  // XLA semantics is available. This limitation is because of the current XLA
  // implementation.
  let results = (outs I32Tensor);

  let hasVerifier = 1;
}

def StableHLO_MapOp: StableHLO_ShapedInterfaceOp<"map",
      [RecursiveSideEffects, SameOperandsAndResultShape,
       SingleBlockImplicitTerminator<"ReturnOp">]> {
  let summary = "Map operator";
  let description = [{
    Applies a scalar function over the given operand arrays, producing an array
    of the same dimensions where each element is the result of the mapped function
    applied to the corresponding elements in the input arrays.

    The mapped function is an arbitrary computation with the restriction that it
    has N inputs of scalar type T and a single output with type S. The output has
    the same dimensions as the operands except that the element type T is replaced
    with S.

    See https://www.tensorflow.org/xla/operation_semantics#map.
  }];
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    I64ElementsAttr:$dimensions
  );
  // The scalar function applied elementwise, as a single-block region.
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasVerifier = 1;
}

def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
      [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Reshape operator";
  let description = [{
    Reshapes the dimensions of `operand` into a new configuration.

    See https://www.tensorflow.org/xla/operation_semantics#reshape.
1827 1828 Example: 1829 1830 ```mlir 1831 %0 = stablehlo.reshape %arg0 : (tensor<2xf32>) -> tensor<1x2xf32> 1832 ``` 1833 }]; 1834 1835 let arguments = (ins HLO_Tensor:$operand); 1836 1837 let results = (outs HLO_StaticShapeTensor); 1838 let hasVerifier = 1; 1839 1840 let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; 1841} 1842 1843def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> { 1844 let summary = "Reshape a tensor to a given, possibly dynamic, shape."; 1845 let description = [{ 1846 Reshapes `operand` to `output_shape`. 1847 1848 Requires: 1849 - The length of `output_shape` is equal to the rank of `result`. 1850 - The number of elements in `operand` (that is, the product of extents of 1851 its shape) is equal to the number of elements in `output_shape` (that is, 1852 the product of values in `output_shape`). 1853 1854 Example: 1855 1856 ```mlir 1857 %0 = stablehlo.dynamic_reshape %arg0, %shape : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> 1858 ``` 1859 }]; 1860 1861 let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape); 1862 let results = (outs HLO_Tensor:$result); 1863 1864 let hasVerifier = 1; 1865 1866 let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; 1867} 1868 1869def StableHLO_ScatterOp: StableHLO_Op<"scatter", [SameVariadicOperandSize, RecursiveSideEffects]> { 1870 let summary = "Scatter operator"; 1871 let description = [{ 1872 Generates a result which is the value of the input array `operand`, 1873 with several slices (at indices specified by `scatter_indices`) 1874 updated with the values in `updates` using `update_computation`. 1875 1876 See https://www.tensorflow.org/xla/operation_semantics#scatter. 
1877 }]; 1878 let arguments = (ins 1879 Variadic<HLO_Tensor>:$operands, 1880 TensorOf<[AnyInteger, Index]>:$scatter_indices, 1881 Variadic<HLO_Tensor>:$updates, 1882 StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, 1883 DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted, 1884 DefaultValuedAttr<BoolAttr, "false">:$unique_indices 1885 ); 1886 1887 let regions = (region SizedRegion<1>:$update_computation); 1888 1889 let results = (outs Variadic<HLO_Tensor>); 1890 1891 let hasVerifier = 1; 1892} 1893 1894def StableHLO_SelectOp: StableHLO_Op<"select", [NoSideEffect, HLO_BroadcastingElementwise, 1895 InferTensorTypeWithReify]> { 1896 let summary = "Select operator"; 1897 let description = [{ 1898 Constructs an output tensor from the elements of `on_true` and `on_false` 1899 based on the values of `pred`. All three operands must be of the same shape 1900 with the exception of `pred`, which may also be a scalar in which case it is 1901 broadcasted. 1902 1903 See https://www.tensorflow.org/xla/operation_semantics#select. 1904 }]; 1905 let arguments = (ins 1906 HLO_PredTensor:$pred, 1907 HLO_Tensor:$on_true, 1908 HLO_Tensor:$on_false 1909 ); 1910 1911 let results = (outs HLO_Tensor); 1912 1913 let hasVerifier = 1; 1914 1915 let extraClassDeclaration = [{ 1916 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1917 return succeeded(mlir::verifyCompatibleShapes(l, r)); 1918 } 1919 }]; 1920} 1921 1922def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter", 1923 [RecursiveSideEffects]> { 1924 let summary = "SelectAndScatter operator"; 1925 let description = [{ 1926 Runs a windowed selection `select` function over `operand` with shape 1927 `window_dimensions` and stride `window_strides`. This will produce an amount 1928 of selected locations whose shape matches `source`. These are then scattered 1929 to the output which is initialized with `init_value`. 
1930 Multiple scattered elements which land in the same output location are 1931 combined using the `scatter` function. 1932 1933 See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. 1934 }]; 1935 let arguments = (ins 1936 HLO_Tensor:$operand, 1937 HLO_Tensor:$source, 1938 HLO_Tensor:$init_value, 1939 OptionalAttr<I64ElementsAttr>:$window_dimensions, 1940 OptionalAttr<I64ElementsAttr>:$window_strides, 1941 OptionalAttr<I64ElementsAttr>:$padding 1942 ); 1943 1944 let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); 1945 1946 let results = (outs HLO_Tensor); 1947 1948 let hasVerifier = 1; 1949} 1950 1951def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [NoSideEffect, 1952 DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 1953 let summary = "SetDimensionSize operator"; 1954 let description = [{ 1955 Sets the dynamic size of operand's given dimension. Pass through the operand 1956 as result, with dynamic dimension tracked by the compiler. Padded values 1957 will be ignored by downstream reduction ops. 1958 1959 See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. 1960 }]; 1961 let arguments = (ins 1962 HLO_Tensor:$operand, 1963 I32Tensor:$size, 1964 I64Attr:$dimension 1965 ); 1966 let results = (outs HLO_Tensor); 1967 1968 let extraClassDeclaration = [{ 1969 // Method from InferTypeOpInterface interface. 1970 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 1971 if (l.size() != r.size()) return false; 1972 for (auto [lt, rt] : llvm::zip(l, r)) 1973 if (!mlir::hlo::isCompatibleForHloTypeInference(lt, rt)) 1974 return false; 1975 return true; 1976 } 1977 }]; 1978 1979 let hasVerifier = 1; 1980} 1981 1982def StableHLO_SortOp : StableHLO_Op<"sort", [RecursiveSideEffects, 1983 SameOperandsAndResultShape]> { 1984 let summary = "Sort operator"; 1985 let description = [{ 1986 Sorts the given `operands` at the given `dimension` with the given 1987 `comparator`. 
1988 1989 See https://www.tensorflow.org/xla/operation_semantics#sort. 1990 }]; 1991 let arguments = (ins 1992 Variadic<HLO_Tensor>:$operands, 1993 DefaultValuedAttr<I64Attr, "-1">:$dimension, 1994 DefaultValuedAttr<BoolAttr, "false">:$is_stable 1995 ); 1996 1997 let results = (outs Variadic<HLO_Tensor>); 1998 1999 let regions = (region SizedRegion<1>:$comparator); 2000 2001 let builders = [ 2002 OpBuilder<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension, 2003 CArg<"bool", "false">:$is_stable)>]; 2004 2005 let hasVerifier = 1; 2006} 2007 2008def StableHLO_ReverseOp: StableHLO_Op<"reverse", 2009 [NoSideEffect, HLO_CompatibleOperandsAndResultType]> { 2010 let summary = "Reverse operator"; 2011 let description = [{ 2012 Reverses the specified dimensions of `operand` according to the given 2013 `dimensions`. 2014 2015 See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. 2016 }]; 2017 let arguments = (ins 2018 HLO_Tensor:$operand, 2019 I64ElementsAttr:$dimensions 2020 ); 2021 2022 let results = (outs HLO_Tensor); 2023} 2024 2025def StableHLO_PadOp: StableHLO_ShapedInterfaceOp<"pad", 2026 [NoSideEffect, SameOperandsAndResultElementType, InferTensorType]> { 2027 let summary = "Pad operator"; 2028 let description = [{ 2029 Pads edges and between the elements of `operand` with the `padding_value` 2030 according to the configuration parameters described below. 2031 2032 `edge_padding_low` and `edge_padding_high` specify the amount of padding 2033 added at the low-end (next to index 0) and the high-end (next to the 2034 highest index) of each dimension respectively. The amount of edge 2035 padding can be negative -- the absolute value of negative padding indicates 2036 the number of elements to remove from the specified dimension. 2037 2038 `interior_padding` specifies the amount of padding (non-negative) added 2039 between any two elements in each dimension. 
Interior padding occurs 2040 logically before edge padding, so in the case of negative edge padding, 2041 elements are removed from the interior-padded operand. 2042 2043 This operation is a no-op if, for all dimensions, the edge padding pairs are 2044 all (0, 0) and the interior padding values are all 0. The figure below shows 2045 examples of different `edge_padding` and `interior_padding` values for a 2046 two-dimensional array. 2047 2048  2049 2050 }]; 2051 let arguments = (ins 2052 HLO_Tensor:$operand, 2053 HLO_Tensor:$padding_value, 2054 I64ElementsAttr: $edge_padding_low, 2055 I64ElementsAttr: $edge_padding_high, 2056 I64ElementsAttr: $interior_padding 2057 ); 2058 2059 let results = (outs HLO_Tensor); 2060} 2061 2062def StableHLO_TraceOp: StableHLO_Op<"trace", []> { 2063 let summary = "Trace operator"; 2064 let description = [{ 2065 Emits a logging message `tag` with the `operand`. 2066 2067 Example: 2068 2069 ```mlir 2070 stablehlo.trace %arg0, "In test code." : tensor<5x1x5xi32> 2071 ``` 2072 }]; 2073 let arguments = (ins 2074 HLO_Tensor:$operand, 2075 StrAttr:$tag 2076 ); 2077 let assemblyFormat = "$operand `,` $tag attr-dict `:` type($operand)"; 2078} 2079 2080def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose", 2081 [NoSideEffect, SameOperandsAndResultElementType, 2082 DeclareOpInterfaceMethods<InferTypeOpInterface>]> { 2083 let summary = "Transpose operator"; 2084 let description = [{ 2085 Permutes the dimensions of `operand` according to the given `permutation`. 2086 2087 `res_dimensions[i] = operand_dimensions[permutation[i]]` 2088 2089 See https://www.tensorflow.org/xla/operation_semantics#transpose. 2090 }]; 2091 let arguments = (ins 2092 HLO_Tensor:$operand, 2093 I64ElementsAttr:$permutation 2094 ); 2095 let results = (outs HLO_Tensor); 2096 2097 let extraClassDeclaration = [{ 2098 // Method from InferTypeOpInterface interface. 
2099 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 2100 return succeeded(mlir::verifyCompatibleShapes(l, r)); 2101 } 2102 }]; 2103} 2104 2105def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve", 2106 [NoSideEffect, SameOperandsAndResultElementType]> { 2107 let summary = "TriangularSolve operator"; 2108 let description = [{ 2109 Solves systems of linear equations with lower or upper triangular 2110 coefficient matrices by forward- or back-substitution. Broadcasting along 2111 leading dimensions, this routine solves one of the matrix systems 2112 op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where 2113 op(a) is either op(a) = a, or op(a) = Transpose(a), or 2114 op(a) = Conj(Transpose(a)). 2115 2116 Input data is read only from the lower/upper triangle of a, depending on the 2117 value of lower. Values from the other triangle are ignored. Output data is 2118 returned in the same triangle; the values in the other triangle are 2119 implementation-defined and may be anything. 2120 2121 If the rank of a and b are greater than 2, they are treated as batches of 2122 matrices, where all except the minor 2 dimensions are batch dimensions. a 2123 and b must have equal batch dimensions. 2124 2125 See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. 
2126 }]; 2127 let arguments = (ins 2128 HLO_FpOrComplexTensor:$a, 2129 HLO_FpOrComplexTensor:$b, 2130 BoolAttr:$left_side, 2131 BoolAttr:$lower, 2132 BoolAttr:$unit_diagonal, 2133 StableHLO_TransposeAttr:$transpose_a 2134 ); 2135 let results = (outs HLO_FpOrComplexTensor); 2136 2137 let hasVerifier = 1; 2138} 2139 2140def StableHLO_ReduceWindowOp: StableHLO_Op<"reduce_window", [ 2141 RecursiveSideEffects, 2142 SameVariadicOperandSize, 2143 SingleBlockImplicitTerminator<"ReturnOp"> 2144 ]> { 2145 let summary = "ReduceWindow operator"; 2146 let description = [{ 2147 Returns the result of executing a reduction function over all elements in 2148 each window of one or more arrays in parallel. 2149 2150 See https://www.tensorflow.org/xla/operation_semantics#reducewindow. 2151 }]; 2152 2153 // TODO(hinsu): Verify that padding attribute is 2-d and the remaining 2154 // attributes are 1-d. Attributes' leading dimension should match rank of the 2155 // operands. 2156 let arguments = (ins 2157 Variadic<HLO_Tensor>:$operands, 2158 Variadic<HLO_Tensor>:$init_values, 2159 I64ElementsAttr:$window_dimensions, 2160 // If strides or dilations attributes are missing then the default value is 2161 // one for each of the operand dimensions. Similarly, padding values are zero 2162 // for both low and high in each of the dimensions, if not specified. 2163 OptionalAttr<I64ElementsAttr>:$window_strides, 2164 OptionalAttr<I64ElementsAttr>:$base_dilations, 2165 OptionalAttr<I64ElementsAttr>:$window_dilations, 2166 OptionalAttr<I64ElementsAttr>:$padding 2167 ); 2168 2169 let results = (outs Variadic<HLO_Tensor>); 2170 2171 // TODO(hinsu): Verify that the attached body arguments and results are 2172 // compatible with reduce op's operands. 2173 let regions = (region SizedRegion<1>:$body); 2174 2175 // Builder for non-variadic version of the operation. 
2176 let builders = [ 2177 OpBuilder<(ins "Type":$result_type, "Value":$operand, 2178 "Value":$init_value, 2179 "DenseIntElementsAttr":$window_dimensions, 2180 "DenseIntElementsAttr":$window_strides, 2181 "DenseIntElementsAttr":$base_dilations, 2182 "DenseIntElementsAttr":$window_dilations, 2183 "DenseIntElementsAttr":$padding), 2184 [{ 2185 build($_builder, $_state, TypeRange(result_type), ValueRange(operand), 2186 ValueRange(init_value), window_dimensions, window_strides, 2187 base_dilations, window_dilations, padding); 2188 }]> 2189 ]; 2190 2191 let hasVerifier = 1; 2192 // TODO(hinsu): Implement custom printer and parser. 2193 2194 let extraClassDeclaration = [{ 2195 // Get the operation used for reduction applied to `result_index`th result. 2196 Operation *getReductionOp(int result_index); 2197 }]; 2198} 2199 2200def StableHLO_ReturnOp : StableHLO_Op<"return", [NoSideEffect, Terminator]> { 2201 let summary = [{ 2202 The `hlo.return` operation terminates a region and returns values. 2203 2204 Example: 2205 2206 ```mlir 2207 %0 = stablehlo.reduce %arg0, %arg1 { 2208 ... 2209 stablehlo.return %1 : tensor<f32> 2210 } 2211 ``` 2212 }]; 2213 2214 let arguments = (ins 2215 Variadic<HLO_TensorOrTokenOrTuple >:$results 2216 ); 2217 2218 let assemblyFormat = "$results attr-dict (`:` type($results)^)?"; 2219} 2220 2221def StableHLO_TorchIndexSelectOp : StableHLO_Op<"torch_index_select", [NoSideEffect]> { 2222 let arguments = (ins 2223 HLO_Tensor:$operand, 2224 HLO_Tensor:$index, 2225 I64Attr:$dim, 2226 I64Attr:$batch_dims 2227 ); 2228 2229 let results = (outs HLO_Tensor); 2230 2231 // TODO(hinsu): Canonicalize to lower this client side HLO op to server 2232 // side HLO ops. 2233} 2234 2235def StableHLO_OptimizationBarrierOp : StableHLO_Op<"optimization_barrier", 2236 [NoSideEffect, HLO_PairwiseSameOperandAndResultType]> { 2237 let summary = [{ 2238 The `stablehlo.optimization_barrier` op blocks optimizations. 
2239 2240 Example: 2241 2242 ```mlir 2243 %0:2 = stablehlo.optimization_barrier %arg0, %arg1 : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) 2244 ``` 2245 }]; 2246 2247 let description = [{ 2248 Blocks any optimization pass from moving computations across the barrier. 2249 2250 Ensures that all inputs are evaluated before any operators that depend on the barrier's outputs. 2251 See 2252 https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier 2253 }]; 2254 2255 let arguments = (ins Variadic<HLO_TensorOrToken>:$operand); 2256 2257 let results = (outs Variadic<HLO_TensorOrToken>); 2258 2259 // TODO(b/241767462): Enhance type printing to condense pairwise ops. 2260 let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; 2261} 2262 2263//===----------------------------------------------------------------------===// 2264// StableHLO RNG Operators. 2265//===----------------------------------------------------------------------===// 2266 2267def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> { 2268 let summary = "RNG with uniform distribution."; 2269 let description = [{ 2270 Constructs an output of a given shape with random numbers generated 2271 following the given `rng_distribution` with two parameters: 2272 `UNIFORM`: the uniform distribution over the interval `[a,b)`. The parameters 2273 and output element type have to be a boolean type, an integral type or a 2274 floating point types, and the types have to be consistent. 2275 2276 See https://www.tensorflow.org/xla/operation_semantics#rnguniform. 2277 2278 `NORMAL`: the normal distribution with parameters `mu` (=`a`) and 2279 `sigma` (=`b`). The parameters and output shape have to have a 2280 floating point elemental type. The parameters furthermore have 2281 to be scalar valued. 2282 2283 See https://www.tensorflow.org/xla/operation_semantics#rngnormal. 
2284 }]; 2285 let arguments = (ins 2286 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a, 2287 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b, 2288 HLO_DimensionTensor:$shape, 2289 StableHLO_RngDistributionAttr:$rng_distribution 2290 ); 2291 2292 let results = (outs HLO_PredIntOrFpTensor:$result); 2293 2294 let hasVerifier = 1; 2295 2296 let extraClassDeclaration = [{ 2297 // Returns whether the return types are compatible. 2298 static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { 2299 return succeeded(::mlir::verifyCompatibleShapes(l, r)); 2300 } 2301 }]; 2302} 2303 2304def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [NoSideEffect]> { 2305 let summary = "Uniform random number generator operator"; 2306 let description = [{ 2307 Returns an output with a given shape filled with uniform random bits using 2308 the specified algorithm (or backend default) and returns an updated state 2309 (with the same shape as initial state) and the generated random data. 2310 2311 See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. 2312 }]; 2313 let arguments = (ins 2314 StableHLO_RngAlgorithmAttr:$rng_algorithm, 2315 HLO_IntOrFpTensor:$initial_state 2316 ); 2317 2318 let results = (outs 2319 HLO_IntOrFpTensor:$output_state, 2320 HLO_IntOrFpTensor:$output 2321 ); 2322 2323 let hasVerifier = 1; 2324} 2325 2326//===----------------------------------------------------------------------===// 2327// StableHLO Quantize Operator. 2328//===----------------------------------------------------------------------===// 2329 2330// TODO(b/230662142): Implement unknown scales/zero_point cases. 
def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<"uniform_quantize",
      [NoSideEffect], TensorOf<[F32, BF16, HLO_QuantizedInt]>,
      HLO_QuantizedIntTensor> {
  let summary = "Uniform quantize operator";
  let description = [{
    Converts floating point tensors or uniform quantized integer tensors to
    uniform quantized integer tensors according to the quantization parameters
    defined by the output type.

    Example:

    ```mlir
    %0 = stablehlo.uniform_quantize %arg0 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
    ```
  }];
}

def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<"uniform_dequantize",
      [InferTensorType, NoSideEffect], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> {
  let summary = "Uniform dequantize operator";
  let description = [{
    Converts quantized array of integers to floating-points according to the
    quantization parameters defined by the input type.

    Example:

    ```mlir
    %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
    ```
  }];
}

// NOTE(review): unlike the other value-computing ops in this file, this op does
// not declare `NoSideEffect`; confirm whether that omission is intentional.
def StableHLO_ReducePrecisionOp :
    StableHLO_Op<"reduce_precision", [HLO_CompatibleOperandsAndResultType]> {
  let summary = "Reduce precision operator";
  let description = [{
    Models the effect of converting floating-point values to a lower-precision
    format (such as IEEE-FP16) and back to the original format. The number of
    exponent and mantissa bits in the lower-precision format can be specified
    arbitrarily, although all bit sizes may not be supported on all hardware
    implementations.

    See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
  }];
  let arguments = (ins
    HLO_FpTensor:$operand,
    I32Attr:$exponent_bits,
    I32Attr:$mantissa_bits
  );
  let hasVerifier = 1;
  let results = (outs HLO_FpTensor:$output);
}

def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
      "real_dynamic_slice",
      [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
       AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let summary = "Real Dynamic Slice operator";
  let description = [{
    The dynamic shape version of SliceOp. Extracts a sub-array from the input
    array according to start_indices, limit_indices and strides. Expect
    start_indices/limit_indices/strides to be statically shaped and matching
    the rank of the input.

    Example:

    ```mlir
    %0 = stablehlo.real_dynamic_slice %input, %start, %limit, %strides
           : (tensor<256x?xf32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<256x?xf32>
    ```
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$start_indices,
    HLO_DimensionTensor:$limit_indices,
    HLO_DimensionTensor:$strides
  );
  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
      [NoSideEffect, AllElementTypesMatch<["operand", "padding_value", "result"]>,
      AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
  let summary = "Dynamic Pad operator";
  // A single description: previously `description` was assigned twice and the
  // second (shorter) assignment silently replaced this one, dropping the
  // example from the generated docs.
  let description = [{
    The dynamic shape version of PadOp. Pads the edges of `operand` with the
    `padding_value` and according to the passed configuration. Unlike PadOp,
    the amounts of padding added at the low-end, high-end and interior are
    passed through the input tensors `edge_padding_low`, `edge_padding_high`
    and `interior_padding` rather than as attributes. Expect
    edge_padding_low/edge_padding_high/interior_padding to be statically shaped
    and matching the rank of the input.

    See https://www.tensorflow.org/xla/operation_semantics#pad.

    Example:

    ```mlir
    %0 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
           : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
    ```
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    HLO_DimensionTensor:$edge_padding_low,
    HLO_DimensionTensor:$edge_padding_high,
    HLO_DimensionTensor:$interior_padding
  );
  let results = (outs HLO_Tensor:$result);
  let hasVerifier = 1;

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
      [InferTensorTypeWithReify, NoSideEffect]> {
  // `let` overrides the fields inherited from `Op`, consistent with every
  // other op in this file.
  let summary = "Dynamic Gather operator";
  let description = [{
    The dynamic shape version of GatherOp. Stitches together several slices of
    an input array.
  }];

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    HLO_IntTensor:$slice_sizes,
    StableHLO_GatherDimensionNumbers:$dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );
  let results = (outs HLO_Tensor);

  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return succeeded(mlir::verifyCompatibleShapes(l, r));
    }
  }];
}

def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [NoSideEffect]> {
  let summary = "Dynamic Convolution operator";
  let description = [{
    The dynamic shape version of ConvOp. Computes a convolution with dynamic padding.
  }];

  let arguments = !con(
    (ins
       HLO_Tensor:$lhs,
       HLO_Tensor:$rhs,
       HLO_Tensor:$d_padding),
    StableHLO_ConvolutionAttributes.attributes);
  let results = (outs HLO_Tensor);
}

def StableHLO_ComputeReshapeShapeOp :
    StableHLO_Op<"compute_reshape_shape", [NoSideEffect]> {
  let summary = "Compute input for reshape with any dynamic dim resolved";

  let description = [{
    This operation handles the dynamic aspect of a TF/NumPy/CHLO reshape. The
    dynamic aspect is that a single extent can be -1 and that dimension will
    instead be computed. This handles the computation and can then be passed to
    an HLO DynamicReshapeOp to replicate the TF/NumPy reshape behavior.

    This op has undefined behavior if the dimensions do not evenly divide the
    number of elements, or if there are multiple -1 values. It is an identity op
    if no dimensions are -1.

    Schematically (operand values shown inline):

    ```
    %0 = stablehlo.compute_reshape_shape 12, [2, -1] -> [2, 6]
    ```
  }];

  let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
  let results = (outs 1DTensorOf<[AnyInteger, Index]>:$result);

  // TODO(b/241767462): Use functional-type for type printing for consistency.
  let assemblyFormat = "$num_elements `,` $dynamic_shape attr-dict `:` type($num_elements) `,` type($dynamic_shape) `->` type($result)";
}

def StableHLO_CstrReshapableOp :
    StableHLO_Op<"cstr_reshapable", [NoSideEffect]> {
  // Previously copy-pasted from ComputeReshapeShapeOp; this op produces a
  // witness, not a reshape input.
  let summary = "Check that a shape is a valid reshape for a number of elements";

  let description = [{
    This operation creates a witness on the constraint that a given shape would
    be a valid reshape for the given number of elements.

    Schematically (operand values shown inline):

    ```
    %0 = stablehlo.cstr_reshapable 12, [2, -1] -> success
    %1 = stablehlo.cstr_reshapable 13, [2, -1] -> failure
    ```
  }];

  let arguments = (ins Index:$num_elements, 1DTensorOf<[AnyInteger, Index]>:$dynamic_shape);
  let results = (outs Shape_WitnessType:$result);

  // TODO(b/241767462): Use functional-type for type printing for consistency.
  let assemblyFormat = "$num_elements `,` $dynamic_shape attr-dict `:` type($num_elements) `,` type($dynamic_shape)";
}

#endif // STABLEHLO_DIALECT_STABLEHLO_OPS