/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for MHLO ops.

#ifndef HLO_OPS
#define HLO_OPS

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"

// Common base for the op definitions in this file. It registers the op in
// HLO_Dialect under `mnemonic` with the given `traits`, and installs a
// verifier body that calls `Verify(*this)`.
class HLO_Op<string mnemonic, list<OpTrait> traits>
    : Op<HLO_Dialect, mnemonic, traits> {
  // Whether this operation has a custom conversion to HLO or not.
  // Defaults to 0; individual ops opt in with `let hasCustomHLOConverter = 1`.
  bit hasCustomHLOConverter = 0b0;

  // TODO(b/129012527) Much of this custom verification should be expressed as
  // type constraints.
  let verifier = [{ return Verify(*this); }];
}

// Cases of the "FusionKind" string enum attribute defined below.
def HLO_LOOP_FUSION   : StrEnumAttrCase<"kLoop">;
def HLO_INPUT_FUSION  : StrEnumAttrCase<"kInput">;
def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">;

// String enum attribute "FusionKind", generated into ::mlir::mhlo.
def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [
    HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION
]> {
  let cppNamespace = "::mlir::mhlo";
}

//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//

// Constant op: the `value` attribute and the `output` result must have the
// same type (AllTypesMatch), and the result is a statically shaped tensor.
def HLO_ConstOp : HLO_Op<"constant",
    [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]>,
    BASE_HLO_ConstOp {
  let arguments = (ins
    ElementsAttr:$value
  );

  let results = (outs
    HLO_StaticShapeTensor:$output
  );

  let builders = [
    OpBuilderDAG<(ins "Attribute":$value)>];

  let assemblyFormat = "attr-dict $value";

  let hasFolder = 1;

  // Constant has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

// Iota op: takes only the `iota_dimension` attribute, no operands.
def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
  let arguments = (ins I64Attr:$iota_dimension);

  let results = (outs HLO_IntFpOrComplexTensor:$output);

  // TODO(b/130357376): Iota has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {
  let summary = "Create linear increasing values from 0 to length -1.";
  let description = [{
    Produces an HLO Tensor of the specified shape, with an incremental set of
    values along the specified dimension starting at 0.

    Requires:
    - The output length of the tensor result.
  }];

  let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension);
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}


// Token-producing op with no operands. `summary` and `description` are set
// with `let` (overriding the inherited fields), matching HLO_DynamicIotaOp.
def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> {
  let summary = "Create Token operator";

  let description = [{
    Produces an HLO token. Tokens are used for ordering side-effecting operations.
    This is exported to HLO as an AfterAll operation with no operands to
    generate a token.
  }];

  let results = (outs HLO_Token:$output);
}

//===----------------------------------------------------------------------===//
// MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

// Base class for unary elementwise ops: one `TensorType` operand named
// `operand` and one `TensorType` result. On top of the caller's `traits` it
// appends InferShapedTypeOpInterface, InferFusibilityOpInterface and
// SameOperandsAndResultShape.
//
// The injected C++ declarations behave as follows:
//  - inferReturnTypeComponents always returns failure().
//  - reifyReturnTypeShapes delegates to
//    ::mlir::mhlo::deriveShapeFromFirstOperand.
//  - inferInputOutputShapeEquality always returns true.
//  - inferEffectiveWorkloadShape returns result 0 of the op.
class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
    Type TensorType>: HLO_Op<mnemonic,
    !listconcat(traits,
    [InferShapedTypeOpInterface, InferFusibilityOpInterface,
    SameOperandsAndResultShape])> {
  let arguments = (ins TensorType:$operand);
  let results = (outs TensorType);
  let extraClassDeclaration = [{
    static LogicalResult inferReturnTypeComponents(
        MLIRContext* context, Optional<Location> location,
        ValueRange operands, DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
      return failure();
    }
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
                                                       &reifiedReturnShapes);
    }
    bool inferInputOutputShapeEquality(int input, int output) {
      return true;
    }
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getResult(0);
    }
  }];
}

// Abs supports complex to real, so element type is not guaranteed to match.
def HLO_AbsOp : HLO_UnaryElementwiseOp<"abs",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>,
    BASE_HLO_AbsOp;

def HLO_CbrtOp : HLO_UnaryElementwiseOp<"cbrt",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>,
    BASE_HLO_CbrtOp;

def HLO_CeilOp : HLO_UnaryElementwiseOp<"ceil",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>,
    BASE_HLO_CeilOp;

// Only the shape is constrained to match between operand and result.
// NOTE(review): SameOperandsAndResultShape is also appended by
// HLO_UnaryElementwiseOp, so listing it here looks redundant.
def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert",
    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>,
    BASE_HLO_ConvertOp {
  let hasFolder = 1;
  let hasCustomHLOConverter = 1;

  let builders = [
    OpBuilderDAG<(ins "Value":$operand, "Type":$result_element_ty)>];
}

def HLO_ClzOp : HLO_UnaryElementwiseOp<"count_leading_zeros",
    [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
    BASE_HLO_ClzOp;

def HLO_CosOp : HLO_UnaryElementwiseOp<"cosine",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_CosOp;

def HLO_ExpOp : HLO_UnaryElementwiseOp<"exponential",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_ExpOp;

def HLO_Expm1Op : HLO_UnaryElementwiseOp<"exponential_minus_one",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_Expm1Op;

def HLO_FloorOp : HLO_UnaryElementwiseOp<"floor",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>,
    BASE_HLO_FloorOp;

// Complex operand, floating-point result (overrides the base class result).
def HLO_ImagOp : HLO_UnaryElementwiseOp<"imag",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_ComplexTensor>,
    BASE_HLO_ImagOp {
  let results = (outs HLO_FpTensor);
  let hasFolder = 1;
}

// Overrides both the operand (floating-point `x`) and the result (predicate
// `y`) declared by the base class.
def HLO_IsFiniteOp : HLO_UnaryElementwiseOp<"is_finite",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_Tensor>,
    BASE_HLO_IsFiniteOp {
  let arguments = (ins HLO_FpTensor:$x);
  let results = (outs HLO_PredTensor:$y);
}

def HLO_LogOp : HLO_UnaryElementwiseOp<"log",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_LogOp;

def HLO_Log1pOp : HLO_UnaryElementwiseOp<"log_plus_one",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_Log1pOp;

def HLO_LogisticOp : HLO_UnaryElementwiseOp<"logistic",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_LogisticOp;

def HLO_NotOp : HLO_UnaryElementwiseOp<"not",
    [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>,
    BASE_HLO_NotOp;

def HLO_NegOp : HLO_UnaryElementwiseOp<"negate",
    [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>,
    BASE_HLO_NegOp {
  let hasFolder = 1;
}

def HLO_PopulationCountOp : HLO_UnaryElementwiseOp<"popcnt",
    [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
    BASE_HLO_PopulationCountOp;

// Complex operand, floating-point result (overrides the base class result).
def HLO_RealOp : HLO_UnaryElementwiseOp<"real",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    HLO_ComplexTensor>,
    BASE_HLO_RealOp {
  let results = (outs HLO_FpTensor);
  let hasFolder = 1;
}

def HLO_RoundOp : HLO_UnaryElementwiseOp<"round_nearest_afz",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>,
    BASE_HLO_RoundOp {
  let hasFolder = 1;
}

def HLO_RsqrtOp : HLO_UnaryElementwiseOp<"rsqrt",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_RsqrtOp;

def HLO_SignOp : HLO_UnaryElementwiseOp<"sign",
    [NoSideEffect, SameOperandsAndResultType],
    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>,
    BASE_HLO_SignOp;

def HLO_SinOp : HLO_UnaryElementwiseOp<"sine",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_SinOp;

def HLO_SqrtOp : HLO_UnaryElementwiseOp<"sqrt",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_SqrtOp {
  let hasFolder = 1;
}

def HLO_TanhOp : HLO_UnaryElementwiseOp<"tanh",
    [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>,
    BASE_HLO_TanhOp;

//===----------------------------------------------------------------------===//
// MHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

// Base class for binary elementwise ops with `lhs` and `rhs` tensor operands
// and a single tensor result. It appends InferShapedTypeOpInterface,
// InferFusibilityOpInterface and SameOperandsAndResultShape to the caller's
// `traits`.
//
// The injected C++ declarations behave as follows:
//  - inferReturnTypeComponents always returns failure().
//  - reifyReturnTypeShapes delegates to
//    ::mlir::mhlo::deriveShapeFromFirstOperand.
//  - inferInputsShapeEquality and inferInputOutputShapeEquality always
//    return true.
//  - inferEffectiveWorkloadShape returns result 0 of the op.
//
// Parsing and printing are forwarded to the mlir::impl one-result helpers.
class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits>
    : HLO_Op<mnemonic,
             !listconcat(traits,
                         [InferShapedTypeOpInterface,
                          InferFusibilityOpInterface,
                          SameOperandsAndResultShape])> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs
  );
  let results = (outs HLO_Tensor);

  let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
  let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];

  let extraClassDeclaration = [{
    static LogicalResult inferReturnTypeComponents(
        MLIRContext* context, Optional<Location> location, ValueRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
      return failure();
    }
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(),
                                                       &reifiedReturnShapes);
    }
    bool inferInputsShapeEquality(int lhs, int rhs) {
      return true;
    }
    bool inferInputOutputShapeEquality(int input, int output) {
      return true;
    }
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getResult(0);
    }
  }];
}

def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_AddOp {
  let hasFolder = 1;
}

def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_Atan2Op;

// Two floating-point operands, complex result (overrides the base class
// operand and result types).
def HLO_ComplexOp : HLO_BinaryElementwiseOp<"complex",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
    BASE_HLO_ComplexOp {
  let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
  let results = (outs HLO_ComplexTensor);
  let hasFolder = 1;
}

def HLO_DivOp : HLO_BinaryElementwiseOp<"divide",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_DivOp {
  let hasFolder = 1;
}

def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MaxOp {
  let hasFolder = 1;
}

def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MinOp {
  let hasFolder = 1;
}

def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply",
    [Commutative, NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_MulOp {
  let hasFolder = 1;
}

def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_PowOp;

def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_RemOp {
  let hasFolder = 1;
}

def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftLeftOp;

def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftRightArithmeticOp;

def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_ShiftRightLogicalOp;

def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_SubOp {
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// MHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// Base class for and/or/xor: a commutative, side-effect-free binary
// elementwise op whose operands are restricted to predicate or integer
// tensors. All derived ops have a folder.
class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
    HLO_BinaryElementwiseOp<
      mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs
  );

  let hasFolder = 1;
}

def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;

//===----------------------------------------------------------------------===//
// MHLO communication op definitions.
//===----------------------------------------------------------------------===//

// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {

  let summary = "Infeed operator";

  let description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape, and returns a XlaOp
    of the data. Multiple Infeed operations are allowed in a computation, but
    there must be a total order among the Infeed operations.

    Attributes:
      layout: Array attribute. Same shape as the output of the infeed, except
              that every tensor is replaced by a minor_to_major array for the
              tensor's layout.

    See https://www.tensorflow.org/xla/operation_semantics#infeed.
  }];

  let arguments = (ins
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$infeed_config,
    OptionalAttr<ArrayAttr>:$layout
  );
  let results = (outs HLO_Tuple);
  let hasCustomHLOConverter = 1;
}

// OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'.
// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens.
def HLO_OutfeedOp : HLO_Op<"outfeed", []> {

  let summary = "Outfeed operator";

  let description = [{
    Generates outgoing data transfers for the given data. It takes data and a
    token type operand and produces a token type value. Tokens are used for
    ordering side-effecting operations.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$outfeed_config
  );
  let results = (outs HLO_Token);
  let hasCustomHLOConverter = 1;
}

def HLO_SendOp : HLO_Op<"send", []> {

  let summary = "Send operator";

  let description = [{
    Sends the given operand data to a Recv instruction in another computation
    that shares the same channel handle. Does not return any data. Similar to
    the Recv operation, Send operation represents synchronous communication,
    and is internally decomposed into 2 HLO instructions (Send and SendDone) to
    enable asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#send.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Token);
  let hasCustomHLOConverter = 1;
}

def HLO_RecvOp : HLO_Op<"recv", []> {

  let summary = "Recv operator";

  let description = [{
    Receives data of the given shape from a Send instruction in another
    computation that shares the same channel handle. Returns a tuple containing
    value for the received data and a token. Recv operation represents
    synchronous communication. However, the instruction is internally decomposed
    into 2 HLO instructions (Recv and RecvDone) to enable asynchronous data
    transfers.

    See https://www.tensorflow.org/xla/operation_semantics#recv.
  }];

  let arguments = (ins
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Tuple);
  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// MHLO parallelism related op definitions.
//===----------------------------------------------------------------------===//

// No operands; the single result is a tensor of unsigned 32-bit integers.
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
      BASE_HLO_ReplicaIdOp {
  let results = (outs TensorOf<[UI32]>);
}

//===----------------------------------------------------------------------===//
// MHLO control flow op definitions.
//===----------------------------------------------------------------------===//

def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {

  let summary = "AfterAll operator";

  let description = [{
    AfterAll takes a variadic number of tokens and produces a single token.
    Tokens are primitive types which can be threaded between side-effecting
    operations to enforce ordering. AfterAll can be used as a join of tokens
    for ordering an operation after a set of operations.

    See https://www.tensorflow.org/xla/operation_semantics#afterall.
  }];

  let arguments = (ins Variadic<HLO_Token>:$operands);
  let results = (outs HLO_Token);
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. IfOp maps to predicated
// conditional use of kConditional HLO.
def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> {
  let summary = "If operator";

  let description = [{
    Returns the result of executing either a true or false function depending on
    the result of a condition function.

    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_TensorOrTuple:$true_arg,
    HLO_TensorOrTuple:$false_arg
  );

  let regions = (region AnyRegion:$true_branch,
                        AnyRegion:$false_branch);

  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. CaseOp maps to indexed
// conditional use of kConditional HLO.
// Takes an i32 tensor `index` plus a variadic list of branch operands, and
// holds a variadic list of branch regions.
def HLO_CaseOp : HLO_Op<"case", [RecursiveSideEffects]>, BASE_HLO_CaseOp {
  let arguments = (ins
    I32Tensor:$index,
    Variadic<HLO_TensorOrTuple>:$branch_operands
  );
  let regions = (region VariadicRegion<AnyRegion>:$branches);
  let results = (outs Variadic<HLO_TensorOrTuple>);

  let hasCustomHLOConverter = 1;
}


// Single operand `val` whose type must equal the result type; carries a
// `cond` region and a `body` region.
def HLO_WhileOp : HLO_Op<"while",
    [RecursiveSideEffects, SameOperandsAndResultType]>,
    BASE_HLO_WhileOp {
  let arguments = (ins HLO_TensorOrTuple:$val);
  let regions = (region AnyRegion:$cond, AnyRegion:$body);
  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): WhileOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

// The `channel_id` attribute is optional; the reduction is given as a
// single-block `computation` region.
def HLO_AllReduceOp : HLO_Op<"all_reduce", [SameOperandsAndResultType]>,
    BASE_HLO_AllReduceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<ChannelHandle>:$channel_id
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

def HLO_AllToAllOp : HLO_Op<"all_to_all",
    [NoSideEffect, SameOperandsElementType, SameOperandsShape]>,
    BASE_HLO_AllToAllOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$split_dimension,
    I64Attr:$concat_dimension,
    I64Attr:$split_count,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}

// Variadic `operands` and `init_values` must have the same length
// (SameVariadicOperandSize). The injected C++ declarations report that the op
// is not fusible with its consumer and use operand 0 as the effective
// workload shape.
def HLO_ReduceOp : HLO_Op<"reduce", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">,
      InferFusibilityOpInterface
    ]>, BASE_HLO_ReduceOp {
  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$operands,
    Variadic<HLO_TensorOrTuple>:$init_values,
    I64ElementsAttr:$dimensions
  );

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  let results = (outs Variadic<HLO_TensorOrTuple>);

  let builders = [
    OpBuilderDAG<(ins "ValueRange":$operands, "ValueRange":$init_values,
      "DenseIntElementsAttr":$dimensions)>];

  let extraClassDeclaration = [{
    bool isFusibleWithConsumer() {
      return false;
    }
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getOperand(0);
    }
  }];

  let hasFolder = 1;

  // TODO(hinsu): Implement custom printer and parser.

  // TODO(b/129422361): ReduceOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// MHLO tuple op definitions.
//===----------------------------------------------------------------------===//
// The tuple operand is unnamed; `index` is a 32-bit integer attribute.
def HLO_GetTupleElementOp : HLO_Op<"get_tuple_element", [NoSideEffect]>,
    BASE_HLO_GetTupleElementOp {
  let arguments = (ins
    HLO_Tuple,
    I32Attr:$index
  );
  let results = (outs HLO_TensorOrTokenOrTuple);

  let builders = [
    OpBuilderDAG<(ins "Value":$value, "int32_t":$index)>];

  let hasFolder = 1;
}

def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
  let results = (outs HLO_Tuple);

  let builders = [
    OpBuilderDAG<(ins "ValueRange":$values)>];

  let hasCanonicalizer = 1;
}

// Both operands must have the same type; the result is a predicate tensor
// with the operands' shape. `compare_type` is optional.
def HLO_CompareOp : HLO_Op<"compare",
    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["reifyReturnTypeShapes"]>]>,
    BASE_HLO_CompareOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    HLO_ComparisonDirectionAttr:$comparison_direction,
    OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
  );
  let results = (outs HLO_PredTensor);

  let builders = [
    OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs,
      "StringAttr":$comparison_direction,
      CArg<"StringAttr", "{}">:$compare_type)>
  ];

  let hasFolder = 1;
  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// MHLO Slice definitions.
//===----------------------------------------------------------------------===//

// The three index attributes must all have the same type (AllTypesMatch).
def HLO_SliceOp : HLO_Op<"slice",
    [NoSideEffect, SameOperandsAndResultElementType,
     AllTypesMatch<["start_indices", "limit_indices", "strides"]>,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );
  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

// Start indices are variadic scalar integer tensors; slice sizes are a
// static attribute.
def HLO_DynamicSliceOp : HLO_Op<"dynamic-slice",
    [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]>,
    BASE_HLO_DynamicSliceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    Variadic<HLO_ScalarIntTensor>:$start_indices,
    I64ElementsAttr:$slice_sizes
  );
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
}

// `operand`, `update` and `result` share an element type; `operand` and
// `result` also share a shape.
def HLO_DynamicUpdateSliceOp : HLO_Op<"dynamic-update-slice",
    [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
     AllShapesMatch<["operand", "result"]>]>,
    BASE_HLO_DynamicUpdateSliceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$update,
    Variadic<HLO_ScalarIntTensor>:$start_indices
  );
  let results = (outs HLO_Tensor:$result);
}


//===----------------------------------------------------------------------===//
// MHLO Other op definitions.
//===----------------------------------------------------------------------===//

// Produces a tuple result.
def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
    BASE_HLO_BatchNormGradOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    HLO_Tensor:$grad_output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tuple);
}

// Produces a single tensor whose element type matches the operands'.
def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BatchNormInferenceOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tensor);
}

// Produces a tuple result.
def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>,
    BASE_HLO_BatchNormTrainingOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tuple);
}

// Only the shape is constrained to match between operand and result.
def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp {

  let arguments = (ins HLO_Tensor:$operand);
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}

def HLO_BroadcastOp : HLO_Op<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$broadcast_sizes
  );

  let results = (outs HLO_Tensor);
}

// The result must be a statically shaped tensor.
def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastInDimOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    BroadcastDimAttr:$broadcast_dimensions
  );

  let results = (outs HLO_StaticShapeTensor);

  let hasFolder = 1;
  // Only handles a static subset of the legacy format.
  let hasCustomHLOConverter = 1;
}

// Takes the output dimensions as an operand rather than relying on a static
// result type. `summary` and `description` are set with `let`, overriding
// the inherited fields.
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
    [NoSideEffect]> {
  let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
  let description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supersede the statically
    shaped original, but is being phased as a separate op in order to support
    compatibility with lowerings and translations that precede dynamic
    shapes.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}

// Note: There is no HLO_CallOp because the standard call operation mlir::CallOp
// is used instead. A mlir::CallOp is exported to a HLO call instruction
// directly.

// `lower` is a boolean attribute defaulting to "false".
def HLO_CholeskyOp : HLO_Op<"cholesky",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_CholeskyOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );
  let results = (outs HLO_FpOrComplexTensor);
}

// Operand order is (min, operand, max).
def HLO_ClampOp : HLO_Op<"clamp",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_ClampOp {
  let arguments = (ins
    HLO_Tensor:$min,
    HLO_Tensor:$operand,
    HLO_Tensor:$max
  );
  let results = (outs HLO_Tensor);
}

// Variadic tensor inputs `val` joined along the `dimension` attribute.
def HLO_ConcatenateOp : HLO_Op<"concatenate",
    [NoSideEffect, SameOperandsAndResultElementType,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
    BASE_HLO_ConcatenateOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$val,
    I64Attr:$dimension
  );
  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  let hasFolder = 1;
}

def HLO_CollectivePermuteOp : HLO_Op<"collective_permute",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_CollectivePermuteOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$source_target_pairs
  );
  let results = (outs HLO_Tensor);
}

// The argument list is the two tensor operands concatenated (via !con) with
// the attribute list taken from ConvolutionAttributes.
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
  let arguments = !con(
    (ins
      HLO_Tensor:$lhs,
      HLO_Tensor:$rhs),
    ConvolutionAttributes.attributes);
  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

// The single tensor operand is unnamed.
def HLO_CopyOp : HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_CopyOp {
  let arguments = (ins HLO_Tensor);
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum",
    [NoSideEffect, SameOperandsAndResultType]>,
    BASE_HLO_CrossReplicaSumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}

// `custom_call`: carries no traits, so it is not marked side-effect free.
// `has_side_effect` defaults to false and `backend_config` to the empty
// string.
def HLO_CustomCallOp : HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$args,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedAttr<StrAttr, "">:$backend_config
  );
  let results = (outs Variadic<HLO_Tensor>);
  let hasCustomHLOConverter = 1;
}

// `dot`: binary op on `lhs` and `rhs` with a precision config attribute.
def HLO_DotOp : HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    HLO_PrecisionConfigAttr:$precision_config
  );
  let results = (outs HLO_Tensor);
}

// `dot_general`: like `dot`, with explicit `dot_dimension_numbers`.
// Verification goes through the `Verify(*this)` hook that HLO_Op already
// installs, so no per-op `verifier` override is needed here.
def HLO_DotGeneralOp : HLO_Op<"dot_general", [NoSideEffect]>,
    BASE_HLO_DotGeneralOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    DotDimensionNumbers:$dot_dimension_numbers,
    HLO_PrecisionConfigAttr:$precision_config
  );

  let results = (outs HLO_Tensor);
}

// Define Base Einsum op within the HLO dialect as these are client ops and
// therefore this class is not common between HLO and LHLO ops.
class BASE_EinsumOp {
  string summary = "Einsum operator";

  string description = [{
    Returns a tensor whose elements are defined by equation, which is written
    in a shorthand form inspired by the Einstein summation convention.
  }];
}

// `einsum`: binary form; the equation is passed as the `einsum_config` string.
def HLO_EinsumOp : HLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}

// `unary_einsum`: single-operand form of einsum.
def HLO_UnaryEinsumOp : HLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;

  // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
  // the HLO converter shouldn't be invoked.
  let hasCustomHLOConverter = 1;
}

// `fft`: operand plus an FFT type attribute and an `fft_length` attribute.
def HLO_FftOp : HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );

  let results = (outs HLO_Tensor);
}

// `gather`: `indices_are_sorted` is optional and defaults to false.
def HLO_GatherOp : HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
}

// `get_dimension_size`: the result is constrained to an i32 tensor (see the
// TODO below).
def HLO_GetDimensionSizeOp : HLO_Op<"get_dimension_size", [NoSideEffect]>,
    BASE_HLO_GetDimensionSizeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$dimension
  );
  // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
  // XLA semantics is available. This limitation is because of the current XLA
  // implementation.
  let results = (outs I32Tensor);

  let hasFolder = 1;
}

// `map`: holds one `computation` region whose single block is implicitly
// terminated by ReturnOp. Side effects are those of the nested ops.
def HLO_MapOp : HLO_Op<"map",
      [RecursiveSideEffects, SameOperandsElementType,
       SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]>,
    BASE_HLO_MapOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    I64ElementsAttr:$dimensions
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}

// `reshape`: the result must be a statically shaped tensor with the operand's
// element type.
def HLO_ReshapeOp : HLO_Op<"reshape",
      [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp {
  let arguments = (ins HLO_Tensor:$operand);

  let results = (outs HLO_StaticShapeTensor);
  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let hasCustomHLOConverter = 1;
}

def HLO_DynamicReshapeOp : HLO_Op<"dynamic_reshape", [NoSideEffect]> {
  let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
  let description = [{
    Reshapes `operand` to `output_shape`.

    Requires:
    - The length of `output_shape` is equal to the rank of `result`.
    - The number of elements in `operand` (that is, the product of extents of
      its shape) is equal to the number of elements in `output_shape` (that is,
      the product of values in `output_shape`).
  }];

  let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape);
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}

// `scatter`: holds one `update_computation` region. Both boolean attributes
// are optional and default to false.
def HLO_ScatterOp : HLO_Op<"scatter", [RecursiveSideEffects]>,
    BASE_HLO_ScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scatter_indices,
    HLO_Tensor:$updates,
    ScatterDimensionNumbers:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let regions = (region SizedRegion<1>:$update_computation);

  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;

  let hasFolder = 1;
}

// TODO(jpienaar): Add broadcastable trait.
// `select`: declares the InferTypeOpInterface methods and, for
// InferShapedTypeOpInterface, additionally `reifyReturnTypeShapes`.
def HLO_SelectOp : HLO_Op<"select", [NoSideEffect,
      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                ["reifyReturnTypeShapes"]>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
    BASE_HLO_SelectOp {
  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_Tensor:$on_true,
    HLO_Tensor:$on_false
  );

  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

// `select_and_scatter`: holds two regions, `select` and `scatter`. All window
// attributes are optional.
def HLO_SelectAndScatterOp : HLO_Op<"select_and_scatter",
      [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$source,
    HLO_Tensor:$init_value,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);

  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

// `set_dimension_size`: `size` is an i32 tensor; `dimension` is an attribute.
def HLO_SetDimensionSizeOp : HLO_Op<"set_dimension_size", [NoSideEffect]>,
    BASE_HLO_SetDimensionSizeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I32Tensor:$size,
    I64Attr:$dimension
  );
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

// `sort`: variadic operands and results with one `comparator` region.
// `dimension` defaults to -1 and `is_stable` to false, both in the attribute
// definitions and in the custom builder.
def HLO_SortOp : HLO_Op<"sort",
      [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let results = (outs Variadic<HLO_Tensor>);

  let regions = (region SizedRegion<1>:$comparator);

  let builders = [
    OpBuilderDAG<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension,
      CArg<"bool", "false">:$is_stable)>];

  // TODO(b/129422361): SortOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

// `reverse`: result type equals the operand type.
def HLO_ReverseOp : HLO_Op<"reverse",
      [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$dimensions
  );

  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

// `pad`: operand, padding value and result share an element type.
def HLO_PadOp : HLO_Op<"pad",
      [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PadOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    I64ElementsAttr:$edge_padding_low,
    I64ElementsAttr:$edge_padding_high,
    I64ElementsAttr:$interior_padding
  );

  let results = (outs HLO_Tensor);

  let description = [{
    Pads the `operand` with `padding_value`. `edge_padding_low` and
    `edge_padding_high` give, per dimension, the amount of padding added at the
    low and high end of that dimension, and `interior_padding` the amount of
    padding added between any two elements of that dimension.

    See https://www.tensorflow.org/xla/operation_semantics#pad.
  }];

  // TODO(b/129422361): PadOp has a custom constructor for HLO.
  let hasCustomHLOConverter = 1;

  let hasFolder = 1;
}

// `trace`: takes a tensor and a string `tag`; declares no results and no
// traits.
def HLO_TraceOp : HLO_Op<"trace", []>, BASE_HLO_TraceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$tag
  );
  let hasCustomHLOConverter = 1;
}

// `transpose`: operand and result share an element type.
def HLO_TransposeOp : HLO_Op<"transpose",
      [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$permutation
  );
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

// `triangular_solve`: all boolean attributes are required (no defaults).
def HLO_TriangularSolveOp : HLO_Op<"triangular_solve",
      [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
  let results = (outs HLO_FpOrComplexTensor);
}

// `reduce_window`: holds one `body` region whose single block is implicitly
// terminated by ReturnOp.
def HLO_ReduceWindowOp : HLO_Op<"reduce_window", [
      RecursiveSideEffects,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]>, BASE_HLO_ReduceWindowOp {

  // TODO(hinsu): Verify that padding attribute is 2-d and the remaining
  // attributes are 1-d. Attributes' leading dimension should match rank of the
  // inputs.
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$init_value,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the input dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}

// `return`: region terminator taking a variadic list of tensors or tuples.
def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> {
  let summary = [{
    The `mhlo.return` operation terminates a region and returns values.
  }];

  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$results
  );

  // Disable conversion operator for return op as the op is not an actual XLA
  // instruction and is only used as a terminator for regions.
  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}

// `torch_index_select`: `dim` and `batch_dims` are required attributes.
def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
  let arguments = (ins
    HLO_Tensor:$input,
    HLO_Tensor:$index,
    I64Attr:$dim,
    I64Attr:$batch_dims
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}

//===----------------------------------------------------------------------===//
// MHLO RNG Operators.
1303//===----------------------------------------------------------------------===// 1304 1305def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { 1306 let arguments = (ins 1307 HLO_PredIntOrFpTensor:$a, 1308 HLO_PredIntOrFpTensor:$b, 1309 HLO_DimensionTensor:$shape 1310 ); 1311 1312 let results = (outs HLO_PredIntOrFpTensor); 1313 1314 let hasCustomHLOConverter = 1; 1315} 1316 1317def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { 1318 let arguments = (ins 1319 HLO_FpTensor:$mu, 1320 HLO_FpTensor:$sigma, 1321 HLO_DimensionTensor:$shape 1322 ); 1323 1324 let results = (outs HLO_FpTensor); 1325 1326 let hasCustomHLOConverter = 1; 1327} 1328 1329def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, 1330 BASE_HLO_RngBitGeneratorOp { 1331 let arguments = (ins 1332 // TODO(jpienaar): This could be an enum instead. 1333 I32Attr:$rng_algorithm, 1334 HLO_IntOrFpTensor:$initial_state 1335 ); 1336 1337 let results = (outs HLO_TensorOrTuple:$result); 1338 1339 // TODO(jpienaar): This should not be needed. 1340 let hasCustomHLOConverter = 1; 1341} 1342 1343//===----------------------------------------------------------------------===// 1344// MHLO Quantize Operator. 1345//===----------------------------------------------------------------------===// 1346def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>, 1347 BASE_HLO_DequantizeOp { 1348 let arguments = (ins 1349 TensorOf<[I32]>:$input, 1350 F32Attr:$min_range, 1351 F32Attr:$max_range, 1352 HLO_DequantizeModeAttr:$mode, 1353 BoolAttr:$transpose_output, 1354 DefaultValuedAttr<BoolAttr, "false">:$is_16bits 1355 ); 1356 1357 let results = (outs TensorOf<[BF16]>:$output); 1358 1359 let hasCustomHLOConverter = 1; 1360} 1361 1362def HLO_FusionOp : HLO_Op<"fusion", []> { 1363 let summary = "Fusion operator"; 1364 let description = [{ 1365 Models the fusion instruction. 
1366 1367 A fusion op is consists of a group of basic ops (represented as a region 1368 attached to it). It serves as a hint to the backend that it is beneficial 1369 to emit the contained ops into a single loop nest or kernel. 1370 }]; 1371 let regions = (region SizedRegion<1>:$fused_computation); 1372 1373 let arguments = (ins 1374 Variadic<HLO_TensorOrTuple>:$operands, 1375 OptionalAttr<HLO_FusionKindAttr>:$fusion_kind 1376 ); 1377 1378 let results = (outs 1379 Variadic<HLO_TensorOrTuple>:$results 1380 ); 1381 1382 // FusionOp has special conversion logic to HLO. 1383 let hasCustomHLOConverter = 1; 1384} 1385 1386// This is an op for purposes internal to XLA/GPU. 1387def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp { 1388 let arguments = (ins HLO_Tensor:$operand); 1389 let results = (outs HLO_Tensor); 1390 let hasCustomHLOConverter = 1; 1391} 1392 1393def HLO_ReducePrecisionOp : 1394 HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, 1395 BASE_HLO_ReducePrecisionOp { 1396 let arguments = (ins 1397 HLO_FpTensor:$operand, 1398 I32Attr:$exponent_bits, 1399 I32Attr:$mantissa_bits 1400 ); 1401 let results = (outs HLO_FpTensor:$output); 1402} 1403 1404#endif // HLO_OPS 1405