/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Declarative rewrite rules that legalize TensorFlow dialect ops to XLA
// (the MHLO and CHLO dialects).

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

// Tensors with an I1, I8, I16, I32 or I64 element type. Note that I1 is part
// of the list, so boolean tensors also satisfy this constraint.
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;

// Tensors whose element type is one of the IEEE floating point formats
// F16, F32 or F64.
def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;

//===----------------------------------------------------------------------===//
// BatchNorm op patterns.
//===----------------------------------------------------------------------===//

// Builds the feature-dimension attribute from the value of a data-format
// attribute ($0) and the tensor value it describes ($1).
def FeatureDimension : NativeCodeCall<
    "getFeatureDimensionAttr($_builder, $0.getValue(), $1)">;

// Attribute constraints matching a boolean attribute whose value is false or
// true, respectively.
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;

// Invokes the C++ helper CastValueToI64 on the value $1 (presumably producing
// an i64-typed value), using the location of $0 for anything it creates.
def CastValueToI64 : NativeCodeCall<
    "CastValueToI64($0.getLoc(), $1, &$_builder)">;

// For GetHLOAxisFromTFAxis: $0 is an ElementsAttr with exactly one integer
// element, and $1 is the value of ranked tensor type whose axis $0 refers to.
def GetHLOAxisFromTFAxis : NativeCodeCall<
    "GetHLOAxisFromTFAxis("
    "$0, $1.getType().cast<RankedTensorType>().getRank(), &$_builder)">;

// Variant of GetHLOAxisFromTFAxis for variadic TensorFlow inputs: $1 is an
// operand range, and the rank is read from the range's first value.
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
    "GetHLOAxisFromTFAxis("
    "$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
    "&$_builder)">;

// Re-encodes an ElementsAttr with a 64-bit integer element type and returns it
// as a DenseIntElementsAttr.
def CastElementsToI64Elements : NativeCodeCall<
    "hlo::ConvertElementsAttr("
    "$0.cast<ElementsAttr>(), $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;

// Inference-mode FusedBatchNorm (is_training == false) becomes
// BatchNormInference. The four trailing results (batch mean / variance and the
// two reserve spaces) are required to be unused by the HasNoUseOf constraints,
// so the value used to replace them ($x) is irrelevant.
def : Pattern<
    (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
        $exponential_avg_factor, $data_format, FalseBoolAttr:$is_training),
    [
      (HLO_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance,
          $epsilon, (FeatureDimension $data_format, $x)),
      /*batch_mean=*/(replaceWithValue $x),
      /*batch_variance=*/(replaceWithValue $x),
      /*reserve_space_1=*/(replaceWithValue $x),
      /*reserve_space_2=*/(replaceWithValue $x)
    ],
    [
      (HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)
    ]>;

//===----------------------------------------------------------------------===//
// Assert op pattern.
//===----------------------------------------------------------------------===//

// Neither HLO nor XLA supports assertions, so TF Assert is erased: the rewrite
// produces no replacement ops.
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;

//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//

// Constraint that holds when the C++ helper AreBroadcastCompatible accepts the
// types of the two values as broadcastable against each other.
def AreBroadcastCompatible
  : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
               "types must be broadcastable">;

// Lowers FromOp(l, r) one-to-one onto the broadcasting op ToOp, passing the
// broadcast dimensions computed from the two operands.
class DirectBinaryPat<Op FromOp, Op ToOp>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;

// TF binary op -> broadcasting CHLO op.
foreach binPair = [
    [TF_AddV2Op, HLOClient_BroadcastAddOp],
    [TF_Atan2Op, HLOClient_BroadcastAtan2Op],
    [TF_ComplexOp, HLOClient_BroadcastComplexOp],
    [TF_DivOp, HLOClient_BroadcastDivOp],
    [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
    [TF_MaximumOp, HLOClient_BroadcastMaxOp],
    [TF_MinimumOp, HLOClient_BroadcastMinOp],
    [TF_MulOp, HLOClient_BroadcastMulOp],
    [TF_PowOp, HLOClient_BroadcastPowOp],
    [TF_RealDivOp, HLOClient_BroadcastDivOp],
    [TF_SubOp, HLOClient_BroadcastSubOp],
    [TF_ZetaOp, HLOClient_BroadcastZetaOp]
  ] in {
  def : DirectBinaryPat<binPair[0], binPair[1]>;
}

// RightShift whose shift amount is a SignedIntTensor becomes an arithmetic
// right shift.
def LowerRightShiftSigned
  : Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r),
        (HLOClient_BroadcastShiftRightArithmeticOp
            $l, $r, (BinBroadcastDimensions $l, $r)),
        [(SignedIntTensor $r)]>;

// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op
// supports unsigned integers.

// FloorDiv on IEEE floating point tensors:
//
//   return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
          (HLO_FloorOp
              (HLOClient_BroadcastDivOp $l, $r,
                  (BinBroadcastDimensions $l, $r))),
          [(IEEEFloatTensor $l)]>;

// Performs a substitution of FloorDiv for signed integer tensors, which
// requires additional correction for a negative numerator / denominator.
// Equivalent pseudocode is shown below:
//
//   if ((x < 0) != (y < 0)) {
//     T abs_x = std::abs(x);
//     T abs_y = std::abs(y);
//     return -(abs_x + abs_y - 1) / abs_y;
//   } else {
//     return x / y;
//   }
//
// The outer select picks x / y when the sign tests of the two operands agree
// and the corrected quotient otherwise. Every broadcasting op receives the
// broadcast dimensions computed from the two values it actually combines.
//
// NOTE: This should be optimized for unsigned integers.
// NOTE(review): operands are constrained to ranked tensors (AnyRankedTensor);
// the 0 / 1 constants come from GetScalarOfType (presumably rank-0), so static
// shapes do not appear to be required - confirm before relaxing further.
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_SelectOp
           // (x < 0) == (y < 0)
           (HLOClient_BroadcastCompareOp
            (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
             (HLO_DEFAULT_COMPARISON_TYPE)),
            (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
             (HLO_DEFAULT_COMPARISON_TYPE)),
            (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
            (HLO_DEFAULT_COMPARISON_TYPE)),
           // x / y
           (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
           // -(abs_x + (abs_y - 1)) / abs_y
           (HLOClient_BroadcastDivOp
            (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
             (HLOClient_BroadcastSubOp (HLO_AbsOp $r),
              (HLO_ConstOp (GetScalarOfType<1> $r)),
              (NullDenseIntElementsAttr)),
             (BinBroadcastDimensions $l, $r))),
            (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
          [(SignedIntTensor $l)]>;

// Performs a substitution of FloorMod designed to correct for possibly negative
// values. Pseudocode shown below:
//
//   T trunc_mod = std::fmod(x, y);
//   return trunc_mod != 0 && ((y < 0) != (trunc_mod < 0)) ? trunc_mod + y
//                                                         : trunc_mod;
//
// The `trunc_mod != 0` test compares $rem with the scalar $l_zeros, so its
// broadcast dimensions must be derived from ($rem, $l_zeros), the two values
// actually compared. This mirrors the ($rem, $r_zeros) comparison below. Using
// ($l, $rem) instead would (assuming BinBroadcastDimensions sizes its result
// by the lower-ranked argument) yield rank($l) dimensions for a rank-0 operand
// whenever 0 < rank($l) < rank($r).
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
      (HLO_SelectOp
       (HLOClient_BroadcastAndOp
        // trunc_mod != 0
        (HLOClient_BroadcastCompareOp
         (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
         (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
         (BinBroadcastDimensions $rem, $l_zeros), HLO_COMPARISON_DIRECTION_NE,
         (HLO_DEFAULT_COMPARISON_TYPE)),
        // (y < 0) != (trunc_mod < 0)
        (HLOClient_BroadcastCompareOp
         (HLOClient_BroadcastCompareOp:$r_cmp $r,
          (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
          (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
          (HLO_DEFAULT_COMPARISON_TYPE)),
         (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros,
          (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT,
          (HLO_DEFAULT_COMPARISON_TYPE)),
         (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE,
         (HLO_DEFAULT_COMPARISON_TYPE)),
        (NullDenseIntElementsAttr)),
       // trunc_mod + y
       (HLOClient_BroadcastAddOp $r,
        $rem, (BinBroadcastDimensions $r, $rem)),
       // trunc_mod
       $rem)>;


// Permutation attribute for TF_TransposeOp on a 2-D operand, computed by the
// C++ helper Get2DTransposePerm from a transpose flag attribute (presumably
// identity when the flag is false and a swap otherwise - confirm in helper).
def Get2DTransposePerm: NativeCodeCall<
  "Get2DTransposePerm($0, &$_builder)">;

// RiscAdd maps directly onto the non-broadcasting HLO add.
def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>;

// RiscDot: transpose each operand as requested by its flag, then HLO dot with
// no precision config.
def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b),
          (HLO_DotOp
          (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))),
          (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
          /*precision_config=*/(NullArrayAttr))>;

//===----------------------------------------------------------------------===//
// Logical & bitwise binary op patterns.
//===----------------------------------------------------------------------===//

// Lowers a logical / bitwise binary TF op one-to-one onto a broadcasting CHLO
// op. Only SignedIntTensor operands match; that constraint includes I1,
// presumably so the boolean LogicalAnd / LogicalOr ops are covered too.
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
        [(SignedIntTensor $l)]>;

foreach logicalPair = [
    [TF_LogicalAndOp, HLOClient_BroadcastAndOp],
    [TF_LogicalOrOp, HLOClient_BroadcastOrOp],
    [TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
    [TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
    [TF_BitwiseXorOp, HLOClient_BroadcastXorOp]
  ] in {
  def : DirectLogicalBinaryPat<logicalPair[0], logicalPair[1]>;
}

//===----------------------------------------------------------------------===//
// Compare op patterns.
//===----------------------------------------------------------------------===//

// Ordering comparisons: a broadcasting compare in the given direction, using
// the default comparison type.
class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (HLOClient_BroadcastCompareOp
            $l, $r, (BinBroadcastDimensions $l, $r), direction,
            (HLO_DEFAULT_COMPARISON_TYPE))>;

def : DirectComparePat<TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT>;
def : DirectComparePat<TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE>;
def : DirectComparePat<TF_LessOp, HLO_COMPARISON_DIRECTION_LT>;
def : DirectComparePat<TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE>;

// (In)equality comparisons. Only matched when incompatible_shape_error is
// true and the lhs is an HLO_Tensor.
class EqualityPat<Op FromOp, StrEnumAttrCase direction>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r,
            TrueBoolAttr:$incompatible_shape_error),
        (HLOClient_BroadcastCompareOp
            $l, $r, (BinBroadcastDimensions $l, $r), direction,
            (HLO_DEFAULT_COMPARISON_TYPE)),
        [(HLO_Tensor $l)]>;

def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;

//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//

// Attribute predicate: the ElementsAttr holds exactly one element.
def OneElementAttrPred
  : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() == 1">;

// An ElementsAttr restricted to a single element.
def OneElementAttr
  : ElementsAttrBase<And<[ElementsAttr.predicate, OneElementAttrPred]>,
                     "Scalar ElementsAttr">;

// Constraint on an operand range: its first value has a ranked tensor type.
def HasRankedFirstOperand
  : Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;

// Constraint: the value has a ranked tensor type. Despite the name, the check
// is for RankedTensorType rather than any shaped type.
def IsShapedTensor
  : Constraint<CPred<"$0.getType().isa<RankedTensorType>()">>;

// ConcatV2 -> Concatenate. A TensorFlow axis may be negative and wraps around;
// an HLO axis does not wrap and is always non-negative. The conversion uses the
// rank of the first input, so only that input needs to be ranked.
//
// The defining op of `axis` is matched (via ConstantLikeMatcher) in the source
// pattern because, during the conversion, the original Concat operands still
// refer to the old ops even if an HLO constant op has been introduced as the
// replacement for the TensorFlow constant op.
def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)),
          (HLO_ConcatenateOp $inputs,
              (GetHLOAxisFromTFAxisVariadic $axis, $inputs)),
          [(HasRankedFirstOperand $inputs)]>;

//===----------------------------------------------------------------------===//
// CollectivePermute op patterns.
//===----------------------------------------------------------------------===//

// The constant source/target pairs are re-encoded as 64-bit integer elements.
def : Pat<(TF_CollectivePermuteOp $input,
              (ConstantLikeMatcher ElementsAttr:$source_target_pairs)),
          (HLO_CollectivePermuteOp $input,
              (CastElementsToI64Elements $source_target_pairs))>;

//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
//===----------------------------------------------------------------------===//

// The constant group assignment is re-encoded as 64-bit integer elements.
def : Pat<(TF_CrossReplicaSumOp $input,
              (ConstantLikeMatcher ElementsAttr:$group_assignment)),
          (HLO_CrossReplicaSumOp $input,
              (CastElementsToI64Elements $group_assignment))>;

//===----------------------------------------------------------------------===//
// All2All op patterns.
//===----------------------------------------------------------------------===//

// Note the HLO op takes split_dimension before concat_dimension, the reverse
// of the TF operand order; the group assignment moves to the last position.
def : Pat<(TF_AllToAllOp AnyRankedTensor:$input,
              (ConstantLikeMatcher ElementsAttr:$group_assignment),
              I64Attr:$concat_dimension, $split_dimension, $split_count),
          (HLO_AllToAllOp $input, $split_dimension, $concat_dimension,
              $split_count, (CastElementsToI64Elements $group_assignment))>;

//===----------------------------------------------------------------------===//
// FFT op patterns.
//===----------------------------------------------------------------------===//

// Builds the FFT length attribute from the value's shaped type via the C++
// helper GetInnerDimFromValue (presumably its innermost dimension size).
def GetInnerDimFromValue : NativeCodeCall<
  "GetInnerDimFromValue($0.getType().cast<ShapedType>(), &$_builder)">;

// Constraint delegated to the C++ helper CheckInnerDimStatic on the value's
// shaped type (presumably: innermost dimension is static).
def CheckInnerDimStatic
  : Constraint<CPred<"CheckInnerDimStatic($0.getType().cast<ShapedType>(), &$_builder)">>;

// FFT / IFFT map onto HLO FFT with the corresponding FFT type; the length comes
// from the result value.
def : Pat<(TF_FFTOp:$res $input),
          (HLO_FftOp $input, HLO_FFT_TYPE_FFT, (GetInnerDimFromValue $res)),
          [(CheckInnerDimStatic $input)]>;

def : Pat<(TF_IFFTOp:$res $input),
          (HLO_FftOp $input, HLO_FFT_TYPE_IFFT, (GetInnerDimFromValue $res)),
          [(CheckInnerDimStatic $input)]>;

//===----------------------------------------------------------------------===//
// GatherV2 op patterns.
//===----------------------------------------------------------------------===//

// $params and $indices must be ranked so that the $axis and $batch_dims
// attributes can be converted from the TensorFlow axis format, which supports
// negative indexing, to the HLO format.
// GatherV2 with a constant axis -> TorchIndexSelect.
def LegalizeGatherV2
  : Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices,
            (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims),
        (HLO_TorchIndexSelectOp $params, $indices,
            (GetHLOAxisFromTFAxis $axis, $params),
            (GetHLOAxisFromTFAxis $batch_dims, $indices))>;

//===----------------------------------------------------------------------===//
// Pad op patterns.
//===----------------------------------------------------------------------===//

// Extracts column `column` of a 2-D padding ElementsAttr via the C++ helper.
// In the PadV2 pattern, columns 0 and 1 feed the two edge-padding arguments of
// HLO_PadOp (presumably low and high padding, respectively).
class SliceDenseIntElementsAttrColumn2D<string column> : NativeCodeCall<
  "SliceDenseIntElementsAttrColumn2D($0.cast<ElementsAttr>(), " # column # " )">;

// Slices an ElementsAttr at `index` along `axis` via the C++ helper.
// NOTE(review): not referenced by any pattern visible in this file - confirm
// before removing.
class SliceDenseIntElementsAttr<string index, string axis> : NativeCodeCall<
  "SliceDenseIntElementsAttr($0.cast<ElementsAttr>(), " # index # ", " # axis # ")">;

// Interior padding attribute derived from the TF padding attribute.
def GetInteriorPadding : NativeCodeCall<
  "GetInteriorPadding($0.cast<ElementsAttr>())">;

// PadV2 with a constant padding tensor; $c is the padding value.
def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c),
          (HLO_PadOp $input, $c,
              (SliceDenseIntElementsAttrColumn2D<"0"> $padding),
              (SliceDenseIntElementsAttrColumn2D<"1"> $padding),
              (GetInteriorPadding $padding))>;

//===----------------------------------------------------------------------===//
// Identity op patterns.
//===----------------------------------------------------------------------===//

// Identity-like ops are removed by forwarding their operand.
foreach passThroughOp = [TF_IdentityOp, TF_StopGradientOp] in {
  def : Pat<(passThroughOp $op), (replaceWithValue $op)>;
}

// TODO(b/32223192): Support CheckNumerics in HLO.
// These ops are also forwarded; their $msg argument is ignored.
foreach passThroughOp = [TF_PreventGradientOp, TF_CheckNumericsOp] in {
  def : Pat<(passThroughOp $op, $msg), (replaceWithValue $op)>;
}

//===----------------------------------------------------------------------===//
// MatMul op patterns.
//===----------------------------------------------------------------------===//

// MatMul: transpose each operand as requested by its flag, then HLO dot with no
// precision config.
def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
          (HLO_DotOp
              (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))),
              (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
              /*precision_config=*/(NullArrayAttr))>;

//===----------------------------------------------------------------------===//
// MatrixBandPart op pattern.
//===----------------------------------------------------------------------===//

// I64 integer attribute holding the C++ expression `x`.
// NOTE(review): not referenced by any pattern visible in this file.
class getIntegerAttr<string x> : NativeCodeCall<
  "$_builder.getI64IntegerAttr(" # x # ")">;

// Integer attribute, typed with $1's element type, holding the size of the
// dimension of $0 that is `dimFromEnd` positions from the end of its shape
// (presumably 0 = innermost), as computed by GetDimensionSizeFromEnd.
class GetDimensionSizeFromEnd<string dimFromEnd> : NativeCodeCall<
  "$_builder.getIntegerAttr(getElementTypeOrSelf($1.getType()), "
  " GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
  >;

// TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern.
// Until then the op is created in C++, because the expected output type cannot
// be inferred: an mhlo::IotaOp at $0's owner location, typed by
// Get2DTensorType($1, $2), iterating along dimension `dim`.
class createIotaOp<string dim> : NativeCodeCall<
  "$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
  "Get2DTensorType($1, $2), $_builder.getI64IntegerAttr(" # dim # "))">;

// Created in C++ because the generated Convert op has no way to take shape
// information as an input: in the MatrixBandPart lowering, ConvertOp is not a
// root operation and its type cannot be inferred, so it is built manually.
// NOTE(review): not referenced by any pattern visible in this file.
def createConvertOp : NativeCodeCall<
  "CreateConvertOp(&($_builder), $0.getOwner()->getLoc(), $1, $2)">;

// Performs a substitution of MatrixBandPartOp for XLA HLO ops.
// Pseudocode is shown below, given a tensor `input` with k dimensions
// [I, J, K, ..., M, N] and two integers, `num_lower` and `num_upper`:
//
//   iota_m = { M x N matrix with 0,1,...,M-1 along the M dimension }
//   iota_n = { M x N matrix with 0,1,...,N-1 along the N dimension }
//   num_lower_or_m = (num_lower < 0) ? M : num_lower
//   num_upper_or_n = (num_upper < 0) ? N : num_upper
//   offset = iota_n - iota_m
//   indicator = (-num_lower_or_m <= offset) & (offset <= num_upper_or_n)
//   zero_matrix = { [I, J, K,...M, N] zero matrix }
//   return (indicator ? input : zero_matrix)
//
// So element (m, n) of each inner matrix is kept iff m - n <= num_lower_or_m
// and n - m <= num_upper_or_n. Replacing a negative bound by M (resp. N) makes
// that side of the test always true.
// NOTE(review): the M x N shape of the iotas assumes Get2DTensorType($input,
// $num_lower) builds the trailing 2-D type of `input` - confirm in C++.
//
// TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp.
def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
                $num_upper),
              [// M: size of the second-to-last dimension of `input`.
               (HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"1"> $input, $num_lower)),
               // N: size of the last dimension of `input`.
               (HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"0"> $input, $num_upper)),
               // num_lower_or_m = (num_lower < 0) ? M : num_lower
               (HLO_SelectOp:$num_lower_or_m
                (HLO_CompareOp
                 $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
                 HLO_COMPARISON_DIRECTION_LT, (HLO_DEFAULT_COMPARISON_TYPE)
                ),
                $m_dim,
                $num_lower
               ),
               // num_upper_or_n = (num_upper < 0) ? N : num_upper
               (HLO_SelectOp:$num_upper_or_n
                (HLO_CompareOp
                 $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT,
                 (HLO_DEFAULT_COMPARISON_TYPE)
                ),
                $n_dim,
                $num_upper
               ),
               // indicator ? input : zero_matrix
               (TF_SelectV2Op
                (HLO_AndOp
                 // -num_lower_or_m <= offset, where
                 // offset = iota(dim 1) - iota(dim 0) = iota_n - iota_m.
                 (HLOClient_BroadcastCompareOp
                  (HLO_NegOp $num_lower_or_m),
                  (HLO_SubOp:$offset
                   (createIotaOp<"1"> $op, $input, $num_lower),
                   (createIotaOp<"0"> $op, $input, $num_lower)
                  ),
                  (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
                  (HLO_DEFAULT_COMPARISON_TYPE)
                 ),
                 // offset <= num_upper_or_n
                 (HLOClient_BroadcastCompareOp $offset, $num_upper_or_n,
                  (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
                  (HLO_DEFAULT_COMPARISON_TYPE)
                 )
                ),
                $input,
                (HLO_ConstOp (ConstantSplat<"0"> $input))
               )]>;

//===----------------------------------------------------------------------===//
// Nullary op patterns.
//===----------------------------------------------------------------------===//

// TF constants become HLO constants, wrapped in a tensor.cast to the TF
// result type. Only matched when the result is an HLO_Tensor.
def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
          (Tensor_CastOp (HLO_ConstOp $value)),
          [(HLO_Tensor $res)]>;

//===----------------------------------------------------------------------===//
// Elu op patterns.
//===----------------------------------------------------------------------===//

// Elu(x) = x > 0 ? x : expm1(x)
def : Pat<(TF_EluOp AnyRankedTensor:$features),
          (HLO_SelectOp
              (HLOClient_BroadcastCompareOp
                  $features,
                  (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
                  (BinBroadcastDimensions $zero, $features),
                  HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
              $features,
              (HLO_Expm1Op $features))>;

// EluGrad(gradients, features) =
//     features > 0 ? gradients : gradients * (features + 1)
// ($features presumably holds the forward Elu outputs, in which case
// features + 1 == exp(x) on the negative side.)
def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients,
              AnyRankedTensor:$features),
          (HLO_SelectOp
              (HLOClient_BroadcastCompareOp
                  $features,
                  (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
                  (BinBroadcastDimensions $zero, $features),
                  HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
              $gradients,
              (HLO_MulOp
                  $gradients,
                  (HLOClient_BroadcastAddOp
                      $features,
                      (HLO_ConstOp:$one (GetScalarOfType<1> $features)),
                      (BinBroadcastDimensions $one, $features))))>;

//===----------------------------------------------------------------------===//
// Relu op patterns.
//===----------------------------------------------------------------------===//

// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will
// require HLO canonicalization of min and max on a tensor to ClampOp.

// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
// Relu(x) = max(0, x), with the scalar zero broadcast against the input.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
          (HLOClient_BroadcastMaxOp
              (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
              (BinBroadcastDimensions $zero, $input)),
          [(TF_SintOrFpTensor $input)]>;

// Relu6(x) = min(max(x, 0), 6), expressed as an HLO clamp with scalar bounds
// 0 and 6.
// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_Relu6Op AnyRankedTensor:$input),
          (HLO_ClampOp
              (HLO_ConstOp (GetScalarOfType<0> $input)),
              $input,
              (HLO_ConstOp (GetScalarOfType<6> $input))),
          [(TF_SintOrFpTensor $input)]>;

// ReluGrad(gradients, features) = gradients * (features > 0), written as
// select(features > 0, gradients, 0).
//
// $gradients must have a static shape so that the two value operands of the
// select (gradients and the zero splat) have the same shape. $features must be
// ranked so that broadcast dimensions can be computed for the compare.
//
// TODO(hinsu): Relax $gradients static shape requirement when there is a way
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients,
              AnyRankedTensor:$features),
          (HLO_SelectOp
              (HLOClient_BroadcastCompareOp $features,
                  (HLO_ConstOp (GetScalarOfType<0> $features)),
                  (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT,
                  (HLO_DEFAULT_COMPARISON_TYPE)),
              $gradients,
              (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;

//===----------------------------------------------------------------------===//
// Slice op patterns.
//===----------------------------------------------------------------------===//

// Passes $1 through CastValueToI64, unpacks the result along dimension 0 with
// UnpackTensorAlongZeroDim and yields the unpacked values. $0 supplies the
// location.
def CastToI64AndUnpackTensor : NativeCodeCall<
  "UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">;

// Constraint delegated to the C++ helper: input $0, start indices $1 and
// constant slice sizes $2 can be expressed as an HLO dynamic slice.
def CanBeTranslatedToDynamicSlice : Constraint<CPred<
  "CanBeTranslatedToDynamicSlice($0, $1, $2.cast<DenseIntElementsAttr>())">>;

// Converts TF slice sizes $2 (for input $0 and start indices $1) into HLO slice
// sizes via the C++ helper (presumably resolving TF's -1 "to the end" entries).
def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
    "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast<DenseIntElementsAttr>(),"
    "&$_builder)">;

// Slice with constant sizes -> DynamicSlice, with the start-index tensor cast
// and unpacked into separate index values.
def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
              (ConstantLikeMatcher AnyAttr:$slice_sizes)),
          (HLO_DynamicSliceOp $input,
              (CastToI64AndUnpackTensor $op, $starting_indices),
              (TFSliceSizes2HLOSliceSizes $input, $starting_indices,
                  $slice_sizes)),
          [(CanBeTranslatedToDynamicSlice $input, $starting_indices,
              $slice_sizes)]>;

//===----------------------------------------------------------------------===//
// PartitionedCall and LegacyCall op patterns.
//===----------------------------------------------------------------------===//

// Constraint delegated to the C++ helper ArgTypesMatchCallee, given the call op
// (the owner of $0's first result), its arguments $1 and callee symbol $2.
def ArgTypesMatchCallee : Constraint<
    CPred<"ArgTypesMatchCallee($0[0].getOwner(), $1, $2)">>;

// (Stateful)PartitionedCall with a flat symbol callee becomes a std call; the
// config, config_proto and executor_type operands are dropped.
foreach partitionedCallOp = [TF_PartitionedCallOp,
                             TF_StatefulPartitionedCallOp] in {
  def : Pat<(partitionedCallOp:$op $args, FlatSymbolRefAttr:$f,
                $config, $config_proto, $executor_type),
            (CallOp $f, $args),
            [(ArgTypesMatchCallee $op, $args, $f)]>;
}

// The extra attr on LegacyCall is _disable_call_shape_inference, which the
// bridge ignores.
def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr),
          (CallOp $f, $args),
          [(ArgTypesMatchCallee $op, $args, $f)]>;

//===----------------------------------------------------------------------===//
// Reverse op patterns.
//===----------------------------------------------------------------------===//

// Handles axis conversion for TF reverse.
// Converts the TF axis constant $1 for the value $0 via the C++ helper
// (presumably normalizing negative axes against $0's rank).
def ConvertAxisAttr : NativeCodeCall<
    "ConvertAxisAttr($0, $1.cast<ElementsAttr>(), &$_builder)">;

// ReverseV2 with a constant axis list -> HLO reverse.
def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values,
              (ConstantLikeMatcher ElementsAttr:$axis)),
          (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;

//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//

// Elementwise unary TF op -> equivalent MHLO / CHLO op, for HLO_Tensor inputs.
foreach unaryPair = [
    [TF_AbsOp, HLO_AbsOp],
    [TF_AcosOp, HLOClient_AcosOp],
    [TF_AcoshOp, HLOClient_AcoshOp],
    [TF_AsinOp, HLOClient_AsinOp],
    [TF_AsinhOp, HLOClient_AsinhOp],
    [TF_AtanOp, HLOClient_AtanOp],
    [TF_AtanhOp, HLOClient_AtanhOp],
    [TF_CeilOp, HLO_CeilOp],
    [TF_CoshOp, HLOClient_CoshOp],
    [TF_ComplexAbsOp, HLO_AbsOp],
    [TF_ConjOp, HLOClient_ConjOp],
    [TF_CosOp, HLO_CosOp],
    [TF_DigammaOp, HLOClient_DigammaOp],
    [TF_ExpOp, HLO_ExpOp],
    [TF_Expm1Op, HLO_Expm1Op],
    [TF_ErfOp, HLOClient_ErfOp],
    [TF_ErfcOp, HLOClient_ErfcOp],
    [TF_FloorOp, HLO_FloorOp],
    [TF_ImagOp, HLO_ImagOp],
    [TF_InvertOp, HLO_NotOp],
    [TF_IsFiniteOp, HLO_IsFiniteOp],
    [TF_IsInfOp, HLOClient_IsInfOp],
    [TF_LgammaOp, HLOClient_LgammaOp],
    [TF_LogOp, HLO_LogOp],
    [TF_Log1pOp, HLO_Log1pOp],
    [TF_LogicalNotOp, HLO_NotOp],
    [TF_NegOp, HLO_NegOp],
    [TF_RealOp, HLO_RealOp],
    [TF_RsqrtOp, HLO_RsqrtOp],
    [TF_SigmoidOp, HLO_LogisticOp],
    [TF_SinhOp, HLOClient_SinhOp],
    [TF_SinOp, HLO_SinOp],
    [TF_SqrtOp, HLO_SqrtOp],
    [TF_TanhOp, HLO_TanhOp],
    [TF_TanOp, HLOClient_TanOp]
  ] in
  def : Pat<(unaryPair[0] HLO_Tensor:$input), (unaryPair[1] $input)>;

// Angle(x) = atan2(imag(x), real(x)).
def : Pat<(TF_AngleOp $x),
          (HLO_Atan2Op (HLO_ImagOp $x), (HLO_RealOp $x))>;

// TODO(bixia): Lower Cast with a Complex type source operand or with
// Truncate=True for floating point value conversions.
// Cast with Truncate=false on pred / int / fp tensors -> HLO convert.
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
          (HLO_ConvertOp $arg)>;

// Transpose with a constant permutation, re-encoded as 64-bit integers.
def : Pat<(TF_TransposeOp:$res $arg,
              (ConstantLikeMatcher ElementsAttr:$permutation)),
          (HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>;

// The results of these shape-changing ops must have a static shape, because
// HLO does not yet support dynamic reshaping ops. The second operand is not
// needed: the target shape is taken from the result type.
//
// TODO(hinsu): Update once HLO supports dynamic reshaping ops.
foreach reshapeLikeOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp] in {
  def : Pat<(reshapeLikeOp:$res $arg, $ignored),
            (HLO_ReshapeOp $arg),
            [(AnyStaticShapeTensor $res)]>;
}

// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>;

// Both values have integer or float element types of identical bit width.
def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
  "getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
  "getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
  "getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
  "getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
  "element types must be integers or floats of same width">;

// TODO(mgester): Due to restrictions of xla::BitcastConvertType we currently
// only lower if both input and output types are int or float and have same width

def : Pat<(TF_BitcastOp:$res HLO_Tensor:$arg),
          (HLO_BitcastConvertOp $arg),
          [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>;

// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic
// and going to MHLO.

//===----------------------------------------------------------------------===//
// Random ops.
//===----------------------------------------------------------------------===//

// RandomUniform / RandomStandardNormal -> RngUniform / RngNormal, with constant
// parameters 0.0 and 1.0 of the op's dtype and the shape operand passed through
// CastValueToI64. $shape must be ranked (IsShapedTensor checks
// RankedTensorType).
// NOTE: the NativeCodeCall strings name `old` directly (presumably the C++
// variable DRR emits for the `$old` binding), so `$old` must not be renamed.
foreach rngPair = [[TF_RandomUniformOp, HLO_RngUniformOp],
                   [TF_RandomStandardNormalOp, HLO_RngNormalOp]] in {
  // TODO(b/148269299): handle random number generator seeds/states correctly.
  def : Pat<(rngPair[0]:$old $shape, $seed, $seed2),
            (rngPair[1]
                (HLO_ConstOp
                    (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
                (HLO_ConstOp
                    (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
                (CastValueToI64 $old, $shape)),
            [(IsShapedTensor $shape)]>;
}

//===----------------------------------------------------------------------===//
// Sigmoid grad op.
//===----------------------------------------------------------------------===//

// SigmoidGrad($l, $r) = ($r * $l) * (1 - $l); $l is presumably the sigmoid
// output y and $r the incoming gradient dy.
// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the
// shape of $l instead of having it as a constant.
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_MulOp
              (HLO_MulOp $r, $l),
              (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>;

//===----------------------------------------------------------------------===//
// Softplus op.
//===----------------------------------------------------------------------===//

// Epsilon constant for the value's type, via GetEpsilonValue (presumably the
// machine epsilon of the element type).
def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">;

// Softplus(x) = log(1 + exp(x)), evaluated with two cut-offs around
// threshold = log(epsilon) + 2:
//   x > -threshold : x
//   x <  threshold : exp(x)
//   otherwise      : log1p(exp(x))
def : Pattern<(TF_SoftplusOp AnyTensor:$features),
              [
                // exp(x), shared by the last two cases.
                (HLO_ExpOp:$features_exp $features),
                // threshold = log(epsilon) + 2
                (HLOClient_BroadcastAddOp:$threshold
                    (HLO_LogOp (HLO_ConstOp (EpsilonValue $features))),
                    (HLO_ConstOp (GetScalarOfType<2> $features)),
                    (NullDenseIntElementsAttr)),
                (HLO_SelectOp:$output
                    (HLOClient_BroadcastCompareOp
                        $features, (HLO_NegOp $threshold),
                        (NullDenseIntElementsAttr),
                        HLO_COMPARISON_DIRECTION_GT,
                        (HLO_DEFAULT_COMPARISON_TYPE)),
                    $features,
                    (HLO_SelectOp
                        (HLOClient_BroadcastCompareOp
                            $features, $threshold,
                            (NullDenseIntElementsAttr),
                            HLO_COMPARISON_DIRECTION_LT,
                            (HLO_DEFAULT_COMPARISON_TYPE)),
                        $features_exp,
                        (HLO_Log1pOp $features_exp))),
                (replaceWithValue $output)
              ]>;

//===----------------------------------------------------------------------===//
// XlaReplicaId op.
//===----------------------------------------------------------------------===//

// The HLO replica id is converted to the TF result type with a non-truncating
// TF cast.
def : Pat<(TF_XlaReplicaIdOp),
          (TF_CastOp (HLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>;

//===----------------------------------------------------------------------===//
// XlaGather op.
//===----------------------------------------------------------------------===//

// Builds the HLO gather dimension-numbers attribute from the TF
// dimension_numbers attribute via GetGatherDimNumsAttr.
def ToGatherDimNumsAttr : NativeCodeCall<
    "GetGatherDimNumsAttr($0, &$_builder)">;

// Constraint delegated to the C++ helper HasValidGatherDims.
def HasValidGatherDims : Constraint<CPred<"HasValidGatherDims($0)">>;

// XlaGather with constant slice sizes -> HLO gather; the slice sizes are
// re-encoded as 64-bit integers.
def : Pat<(TF_XlaGatherOp $operand, $start_indices,
              (ConstantLikeMatcher ElementsAttr:$slice_sizes),
              $dimension_numbers, $indices_are_sorted),
          (HLO_GatherOp $operand, $start_indices,
              (ToGatherDimNumsAttr $dimension_numbers),
              (CastElementsToI64Elements $slice_sizes),
              $indices_are_sorted),
          [(HasValidGatherDims $dimension_numbers)]>;