• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// This is the legalization pattern definition file for TF to XLA.

include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

// Tensors of signed integer (including i1 boolean) element types.
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;

// IEEE compliant floating point tensors.
def IEEEFloatTensor : TensorOf<[F16, F32, F64]>;
29
//===----------------------------------------------------------------------===//
// BatchNorm op patterns.
//===----------------------------------------------------------------------===//

// Builds the feature-dimension attribute from the data format attribute ($0)
// and the input value ($1).
def FeatureDimension : NativeCodeCall<
    "getFeatureDimensionAttr($_builder, $0.getValue(), $1)">;
// Match a bool attribute that holds false / true, respectively.
def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;

// Casts value $1 to an i64 element type, anchored at the location of $0.
def CastValueToI64: NativeCodeCall<
  "CastValueToI64($0.getLoc(), $1, &$_builder)">;

// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is
// the corresponding value of ranked tensor type whose axis is referred in $0.
def GetHLOAxisFromTFAxis : NativeCodeCall<
  "GetHLOAxisFromTFAxis("
  "$0, $1.getType().cast<RankedTensorType>().getRank(), &$_builder)">;

// Same as the above but with $1 of type operand_range from variadic TensorFlow
// input.
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
  "GetHLOAxisFromTFAxis("
  "$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
  "&$_builder)">;

// Converts an ElementsAttr ($0) to a DenseIntElementsAttr with i64 elements.
def CastElementsToI64Elements : NativeCodeCall<
  "hlo::ConvertElementsAttr("
    "$0.cast<ElementsAttr>(), $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
58
// Lowers inference-mode FusedBatchNorm (is_training == false) to
// HLO BatchNormInference. Only the first result (y) is materialized.
def : Pattern<
    (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon,
                         $exponential_avg_factor, $data_format,
                         FalseBoolAttr:$is_training),
    [(HLO_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance,
                               $epsilon, (FeatureDimension $data_format, $x)),
     // We already guaranteed that the last four results have no use so it
     // does not matter what value we provide here for replacement.
     /*batch_mean=*/(replaceWithValue $x),
     /*batch_variance=*/(replaceWithValue $x),
     /*reserve_space_1=*/(replaceWithValue $x),
     /*reserve_space_2=*/(replaceWithValue $x)],
    [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
     (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>;
73
//===----------------------------------------------------------------------===//
// Assert op pattern.
//===----------------------------------------------------------------------===//

// HLO and XLA don't support Assertions, so erase the op (no replacement ops).
def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>;
80
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//

// Check that two values can be broadcasted together
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
    "types must be broadcastable">;

// Maps a TF binary op directly to the corresponding CHLO broadcasting op,
// attaching the broadcast dimensions computed from the two operands.
class DirectBinaryPat<Op FromOp, Op ToOp>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;

foreach fromToBinPair = [[TF_AddV2Op, HLOClient_BroadcastAddOp],
                         [TF_Atan2Op, HLOClient_BroadcastAtan2Op],
                         [TF_ComplexOp, HLOClient_BroadcastComplexOp],
                         [TF_DivOp, HLOClient_BroadcastDivOp],
                         [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp],
                         [TF_MaximumOp, HLOClient_BroadcastMaxOp],
                         [TF_MinimumOp, HLOClient_BroadcastMinOp],
                         [TF_MulOp, HLOClient_BroadcastMulOp],
                         [TF_PowOp, HLOClient_BroadcastPowOp],
                         [TF_RealDivOp, HLOClient_BroadcastDivOp],
                         [TF_SubOp, HLOClient_BroadcastSubOp],
                         [TF_ZetaOp, HLOClient_BroadcastZetaOp]] in
  def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
106
// Signed right shift maps to arithmetic (sign-extending) shift.
def LowerRightShiftSigned :
  Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r),
      (HLOClient_BroadcastShiftRightArithmeticOp $l, $r,
       (BinBroadcastDimensions $l, $r)),
      [(SignedIntTensor $r)]>;

// TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op
// supports unsigned integers.
115
// Performs a substitution of FloorDiv, pseudo code below:
//
//  return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
          (HLO_FloorOp
           (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
          [(IEEEFloatTensor $l)]>;

// Performs a substitution of FloorDiv for integer tensors, which requires
// additional correction for a negative numerator / denominator. Equivalent
// pseudocode is shown below:
//
// if ((x < 0) != (y < 0)) {
//   T abs_x = std::abs(x);
//   T abs_y = std::abs(y);
//   return -(abs_x + abs_y - 1) / abs_y;
// } else {
//   return x / y;
// }
//
// BroadcastToDimensions is used to compute the broadcast attr to higher
// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r')
// without returning the broadcast of 'r' to broadcast('l', 'r').
//
// NOTE: This should be optimized for unsigned integers.
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
        (HLO_SelectOp
         (HLOClient_BroadcastCompareOp
          (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
           (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
           (HLO_DEFAULT_COMPARISON_TYPE)),
          (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
           (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
           (HLO_DEFAULT_COMPARISON_TYPE)),
          (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
          (HLO_DEFAULT_COMPARISON_TYPE)),
         (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
         (HLOClient_BroadcastDivOp
          (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
                       (HLOClient_BroadcastSubOp (HLO_AbsOp $r),
                        (HLO_ConstOp (GetScalarOfType<1> $r)),
                        (NullDenseIntElementsAttr)),
                     (BinBroadcastDimensions $l, $r))),
          (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))),
        [(SignedIntTensor $l)]>;
163
// Performs a substitution of FloorMod designed to correct for possibly negative
// values. Pseudocode shown below:
//
//   T trunc_mod = std::fmod(x, y);
//   return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
//                                                     : trunc_mod;
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
      (HLO_SelectOp
       (HLOClient_BroadcastAndOp
        (HLOClient_BroadcastCompareOp
         (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
         (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
         (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE,
         (HLO_DEFAULT_COMPARISON_TYPE)),
        (HLOClient_BroadcastCompareOp
         (HLOClient_BroadcastCompareOp:$r_cmp $r,
          (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
          (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
          (HLO_DEFAULT_COMPARISON_TYPE)),
         (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros,
          (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT,
          (HLO_DEFAULT_COMPARISON_TYPE)),
         (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE,
         (HLO_DEFAULT_COMPARISON_TYPE)),
        (NullDenseIntElementsAttr)),
       (HLOClient_BroadcastAddOp $r,
        $rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
192
193
// Builds the rank-2 permutation attribute for a transpose from the boolean
// transpose flag $0.
def Get2DTransposePerm: NativeCodeCall<
  "Get2DTransposePerm($0, &$_builder)">;

def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>;

// RiscDot lowers to a plain Dot; the transpose flags are materialized as
// explicit TF_TransposeOp operands.
def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b),
          (HLO_DotOp
          (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))),
          (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
          /*precision_config=*/(NullArrayAttr))>;
204
//===----------------------------------------------------------------------===//
// Logical & bitwise binary op patterns.
//===----------------------------------------------------------------------===//

// Same as DirectBinaryPat, but restricted to signed integer (or i1) tensors.
class DirectLogicalBinaryPat<Op FromOp, Op ToOp>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (ToOp $l, $r, (BinBroadcastDimensions $l, $r)),
        [(SignedIntTensor $l)]>;

foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
                         [TF_LogicalOrOp, HLOClient_BroadcastOrOp],
                         [TF_BitwiseAndOp, HLOClient_BroadcastAndOp],
                         [TF_BitwiseOrOp, HLOClient_BroadcastOrOp],
                         [TF_BitwiseXorOp, HLOClient_BroadcastXorOp]] in
  def : DirectLogicalBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
220
//===----------------------------------------------------------------------===//
// Compare op patterns.
//===----------------------------------------------------------------------===//

// Maps a TF comparison op to a CHLO broadcasting compare with the given
// comparison direction.
class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
  : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
        (HLOClient_BroadcastCompareOp
           $l, $r, (BinBroadcastDimensions $l, $r), direction,
           (HLO_DEFAULT_COMPARISON_TYPE))>;

def : DirectComparePat<TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT>;
def : DirectComparePat<TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE>;
def : DirectComparePat<TF_LessOp, HLO_COMPARISON_DIRECTION_LT>;
def : DirectComparePat<TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE>;

// Equality ops additionally carry the incompatible_shape_error attribute;
// only the error-raising (true) form is lowered here.
class EqualityPat<Op FromOp, StrEnumAttrCase direction>
    : Pat<(FromOp AnyTensor:$l, AnyTensor:$r,
           TrueBoolAttr:$incompatible_shape_error),
        (HLOClient_BroadcastCompareOp
         $l, $r, (BinBroadcastDimensions $l, $r), direction,
         (HLO_DEFAULT_COMPARISON_TYPE)),
        [(HLO_Tensor $l)]>;

def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
def : EqualityPat<TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE>;
246
//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//

// Matches an ElementsAttr that holds exactly one element.
def OneElementAttrPred
  : CPred<"$_self.cast<ElementsAttr>().getType().getNumElements() == 1">;

def OneElementAttr
  : ElementsAttrBase<And<[ElementsAttr.predicate, OneElementAttrPred]>,
                     "Scalar ElementsAttr">;

// Requires the first operand of a variadic operand range to be ranked.
def HasRankedFirstOperand
  : Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;

// Requires $0 to have a ranked tensor type.
def IsShapedTensor
  : Constraint<CPred<"$0.getType().isa<RankedTensorType>()">>;

// This pattern converts TensorFlow axis format to HLO axis format which
// doesn't wrap around like TensorFlow and is always positive. For this
// conversion, use the first input to get inputs rank. Other inputs need not be
// ranked.
// Defining op for `axis` is TensorFlow constant op in the pattern as during
// the conversion, original Concat op operands still refers to the old ops even
// if HLO constant op is introduced as an replacement for the TensorFlow
// Constant op.
def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)),
          (HLO_ConcatenateOp $inputs,
            (GetHLOAxisFromTFAxisVariadic $axis, $inputs)),
          [(HasRankedFirstOperand $inputs)]>;
276
//===----------------------------------------------------------------------===//
// CollectivePermute op patterns.
//===----------------------------------------------------------------------===//

// The constant source/target pairs are converted to i64 elements for HLO.
def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)),
          (HLO_CollectivePermuteOp $input,
            (CastElementsToI64Elements $source_target_pairs))>;

//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
//===----------------------------------------------------------------------===//

def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)),
          (HLO_CrossReplicaSumOp $input,
            (CastElementsToI64Elements $group_assignment))>;

//===----------------------------------------------------------------------===//
// All2All op patterns.
//===----------------------------------------------------------------------===//

def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count),
          (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>;
299
//===----------------------------------------------------------------------===//
// FFT op patterns.
//===----------------------------------------------------------------------===//

// Extracts the innermost dimension size of $0's shaped type as the FFT length.
def GetInnerDimFromValue : NativeCodeCall<
  "GetInnerDimFromValue($0.getType().cast<ShapedType>(), &$_builder)">;

// Holds when the innermost dimension of $0's shaped type is static.
def CheckInnerDimStatic
  : Constraint<CPred<"CheckInnerDimStatic($0.getType().cast<ShapedType>(), &$_builder)">>;

def : Pat<(TF_FFTOp:$res $input),
          (HLO_FftOp $input, HLO_FFT_TYPE_FFT, (GetInnerDimFromValue $res)),
          [(CheckInnerDimStatic $input)]>;

def : Pat<(TF_IFFTOp:$res $input),
          (HLO_FftOp $input, HLO_FFT_TYPE_IFFT, (GetInnerDimFromValue $res)),
          [(CheckInnerDimStatic $input)]>;
317
//===----------------------------------------------------------------------===//
// GatherV2 op patterns.
//===----------------------------------------------------------------------===//

// Here, $params and $indices needs to be ranked so that $axis and $batch_dims
// attributes can be converted from TensorFlow axis format supporting negative
// indexing to the HLO format.
def LegalizeGatherV2 :
  Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices,
        (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims),
      (HLO_TorchIndexSelectOp $params, $indices,
        (GetHLOAxisFromTFAxis $axis, $params),
        (GetHLOAxisFromTFAxis $batch_dims, $indices))>;
331
//===----------------------------------------------------------------------===//
// Pad op patterns.
//===----------------------------------------------------------------------===//

// Slices one column out of a 2D DenseIntElementsAttr ($0).
class SliceDenseIntElementsAttrColumn2D<string column> : NativeCodeCall<
  "SliceDenseIntElementsAttrColumn2D($0.cast<ElementsAttr>(), " # column # " )">;

// Slices the element at `index` along `axis` out of a DenseIntElementsAttr.
class SliceDenseIntElementsAttr<string index, string axis> : NativeCodeCall<
  "SliceDenseIntElementsAttr($0.cast<ElementsAttr>(), " # index # ", " # axis # ")">;

// Interior padding attribute based on the TF padding.
def GetInteriorPadding : NativeCodeCall <
  "GetInteriorPadding($0.cast<ElementsAttr>())">;

// PadV2 lowers to HLO_PadOp; column 0 / 1 of the TF padding matrix supply the
// low / high edge padding amounts.
def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c),
          (HLO_PadOp $input, $c,
           (SliceDenseIntElementsAttrColumn2D<"0"> $padding),
           (SliceDenseIntElementsAttrColumn2D<"1"> $padding),
           (GetInteriorPadding $padding))>;
351
//===----------------------------------------------------------------------===//
// Identity op patterns.
//===----------------------------------------------------------------------===//

// Identity-like ops are erased by forwarding their operand.
foreach src = [TF_IdentityOp, TF_StopGradientOp] in
  def : Pat<(src $op), (replaceWithValue $op)>;

// TODO(b/32223192): Support CheckNumerics in HLO.
foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in
  def : Pat<(src $op, $msg), (replaceWithValue $op)>;
362
//===----------------------------------------------------------------------===//
// MatMul op patterns.
//===----------------------------------------------------------------------===//

// MatMul lowers to a plain Dot; the transpose_a/transpose_b flags become
// explicit TF_TransposeOp operands with rank-2 permutations.
def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b),
          (HLO_DotOp
          (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))),
          (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))),
          /*precision_config=*/(NullArrayAttr))>;
372
//===----------------------------------------------------------------------===//
// MatrixBandPart op pattern.
//===----------------------------------------------------------------------===//

class getIntegerAttr<string x>: NativeCodeCall<
  "$_builder.getI64IntegerAttr(" # x # ")">;

// Builds an attribute holding the size of the dimension `dimFromEnd` from the
// end of $0's shape, with the element type of $1.
class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
  "$_builder.getIntegerAttr(getElementTypeOrSelf($1.getType()), "
  "                         GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
  >;

// TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern
// For now, this op needs to be created in C++ because the expected output type
// cannot be inferred.
class createIotaOp<string dim>: NativeCodeCall<
  "$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
  "Get2DTensorType($1, $2), $_builder.getI64IntegerAttr(" # dim # "))">;

// This op needs to be created in C++ because the generated Convert Op has no
// way to specify shape information as an input. In the MatrixBandPart op
// lowering, ConvertOp is not a root operation and the appropriate types cannot
// be inferred, so we construct it manually.
def createConvertOp: NativeCodeCall<
  "CreateConvertOp(&($_builder), $0.getOwner()->getLoc(), $1, $2)">;

// Performs a substitution of MatrixBandPartOp for XLA HLO ops. Pseudocode is
// shown below, given a tensor `input` with k dimensions [I, J, K, ..., M, N]
// and two integers, `num_lower` and `num_upper`:
//
//   iota_m = { M x N matrix with 0,1,...M along the M dimension }
//   iota_n = { M x N matrix with 0,1,...N along the N dimension }
//   num_lower_or_m = (num_lower < 0) ? m : num_lower
//   num_upper_or_n = (num_upper < 0) ? n : num_upper
//   offset = iota_m - iota_n
//   indicator = (-num_lower_or_m < offset) & (offset < num_upper_or_n)
//   zero_matrix = { [I, J, K,...M, N] zero matrix }
//   return (indicator ? input : zero_matrix)
//
// TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp.
def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
               $num_upper),
         [(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"1"> $input, $num_lower)),
          (HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"0"> $input, $num_upper)),
          (HLO_SelectOp:$num_lower_or_m
           (HLO_CompareOp
            $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
            HLO_COMPARISON_DIRECTION_LT, (HLO_DEFAULT_COMPARISON_TYPE)
           ),
           $m_dim,
           $num_lower
          ),
          (HLO_SelectOp:$num_upper_or_n
           (HLO_CompareOp
            $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT,
            (HLO_DEFAULT_COMPARISON_TYPE)
           ),
           $n_dim,
           $num_upper
          ),
          (TF_SelectV2Op
           (HLO_AndOp
            (HLOClient_BroadcastCompareOp
             (HLO_NegOp $num_lower_or_m),
             (HLO_SubOp:$offset
              (createIotaOp<"1"> $op, $input, $num_lower),
              (createIotaOp<"0"> $op, $input, $num_lower)
             ),
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
             (HLO_DEFAULT_COMPARISON_TYPE)
            ),
            (HLOClient_BroadcastCompareOp $offset, $num_upper_or_n,
             (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
             (HLO_DEFAULT_COMPARISON_TYPE)
            )
           ),
           $input,
           (HLO_ConstOp (ConstantSplat<"0"> $input))
          )]>;
452
//===----------------------------------------------------------------------===//
// Nullary op patterns.
//===----------------------------------------------------------------------===//

// A cast is inserted so the result keeps the original TF constant's type.
def : Pat<(TF_ConstOp:$res ElementsAttr:$value),
          (Tensor_CastOp (HLO_ConstOp $value)),
          [(HLO_Tensor $res)]>;
460
//===----------------------------------------------------------------------===//
// Elu op patterns.
//===----------------------------------------------------------------------===//

// Elu(x) = x > 0 ? x : expm1(x)
def : Pat<(TF_EluOp AnyRankedTensor:$features),
          (HLO_SelectOp
           (HLOClient_BroadcastCompareOp
              $features,
              (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
              (BinBroadcastDimensions $zero, $features),
              HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
           $features,
           (HLO_Expm1Op $features))>;

// EluGrad(gradients, features) = features > 0 ? gradients
//                                             : gradients * (features + 1)
def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
           (HLO_SelectOp
            (HLOClient_BroadcastCompareOp
              $features,
              (HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
              (BinBroadcastDimensions $zero, $features),
              HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
            $gradients,
            (HLO_MulOp
             $gradients,
             (HLOClient_BroadcastAddOp
               $features,
               (HLO_ConstOp:$one (GetScalarOfType<1> $features)),
               (BinBroadcastDimensions $one, $features))))>;
489
//===----------------------------------------------------------------------===//
// Relu op patterns.
//===----------------------------------------------------------------------===//

// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will
// require HLO canonicalization of min and max on a tensor to ClampOp.

// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
          (HLOClient_BroadcastMaxOp
               (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
               (BinBroadcastDimensions $zero, $input)),
          [(TF_SintOrFpTensor $input)]>;

// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_Relu6Op AnyRankedTensor:$input),
          (HLO_ClampOp (HLO_ConstOp (GetScalarOfType<0> $input)), $input,
                       (HLO_ConstOp (GetScalarOfType<6> $input))),
          [(TF_SintOrFpTensor $input)]>;

// ReluGrad(gradients, features) = gradients * (features > 0)
//
// $gradients needs to be of static shape so that on_true and on_false operands
// of SelectOp have same shape.
//
// $features needs to be ranked for computation of the broadcast dimensions for
// CompareOp.
//
// TODO(hinsu): Relax $gradients static shape requirement when there is a way
// to create splat tensor of dynamic shape in HLO.
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features),
          (HLO_SelectOp
            (HLOClient_BroadcastCompareOp $features,
              (HLO_ConstOp (GetScalarOfType<0> $features)),
              (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT,
              (HLO_DEFAULT_COMPARISON_TYPE)),
            $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
529
//===----------------------------------------------------------------------===//
// Slice op patterns.
//===----------------------------------------------------------------------===//

// Casts $1 to i64 and unpacks it along dim 0 into individual scalar start
// indices, as required by HLO_DynamicSliceOp.
def CastToI64AndUnpackTensor: NativeCodeCall<
  "UnpackTensorAlongZeroDim($0.getLoc(), CastValueToI64($0.getLoc(), $1, &$_builder), &$_builder).output()">;

def CanBeTranslatedToDynamicSlice : Constraint<CPred<
  "CanBeTranslatedToDynamicSlice($0, $1, $2.cast<DenseIntElementsAttr>())">>;

// Converts TF slice sizes (which may contain -1 sentinels) into HLO slice
// sizes.
def TFSliceSizes2HLOSliceSizes : NativeCodeCall<
    "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast<DenseIntElementsAttr>(),"
    "&$_builder)">;

def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
           (ConstantLikeMatcher AnyAttr:$slice_sizes)),
          (HLO_DynamicSliceOp $input,
           (CastToI64AndUnpackTensor $op, $starting_indices),
           (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)),
          [(CanBeTranslatedToDynamicSlice $input, $starting_indices,
            $slice_sizes)]>;
551
//===----------------------------------------------------------------------===//
// PartitionedCall and LegacyCall op patterns.
//===----------------------------------------------------------------------===//

// Holds when the callee $2's argument types match the call operands $1.
def ArgTypesMatchCallee : Constraint<
    CPred<"ArgTypesMatchCallee($0[0].getOwner(), $1, $2)">>;

foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in {
  def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f,
             $config, $config_proto, $executor_type),
            (CallOp $f, $args),
          [(ArgTypesMatchCallee $op, $args, $f)]>;
}

// The extra attr on this op is _disable_call_shape_inference, which we ignore
// in the bridge.
def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr),
          (CallOp $f, $args),
        [(ArgTypesMatchCallee $op, $args, $f)]>;
571
//===----------------------------------------------------------------------===//
// Reverse op patterns.
//===----------------------------------------------------------------------===//

// Handles axis conversion for TF reverse.
def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast<ElementsAttr>(), &$_builder)">;

def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)),
    (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;
581
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//

// One-to-one mappings from TF unary ops to (M)HLO / CHLO unary ops.
foreach Mapping = [
                   [TF_AbsOp, HLO_AbsOp],
                   [TF_AcosOp, HLOClient_AcosOp],
                   [TF_AcoshOp, HLOClient_AcoshOp],
                   [TF_AsinOp, HLOClient_AsinOp],
                   [TF_AsinhOp, HLOClient_AsinhOp],
                   [TF_AtanOp, HLOClient_AtanOp],
                   [TF_AtanhOp, HLOClient_AtanhOp],
                   [TF_CeilOp, HLO_CeilOp],
                   [TF_CoshOp, HLOClient_CoshOp],
                   [TF_ComplexAbsOp, HLO_AbsOp],
                   [TF_ConjOp, HLOClient_ConjOp],
                   [TF_CosOp, HLO_CosOp],
                   [TF_DigammaOp, HLOClient_DigammaOp],
                   [TF_ExpOp, HLO_ExpOp],
                   [TF_Expm1Op, HLO_Expm1Op],
                   [TF_ErfOp, HLOClient_ErfOp],
                   [TF_ErfcOp, HLOClient_ErfcOp],
                   [TF_FloorOp, HLO_FloorOp],
                   [TF_ImagOp, HLO_ImagOp],
                   [TF_InvertOp, HLO_NotOp],
                   [TF_IsFiniteOp, HLO_IsFiniteOp],
                   [TF_IsInfOp, HLOClient_IsInfOp],
                   [TF_LgammaOp, HLOClient_LgammaOp],
                   [TF_LogOp, HLO_LogOp],
                   [TF_Log1pOp, HLO_Log1pOp],
                   [TF_LogicalNotOp, HLO_NotOp],
                   [TF_NegOp, HLO_NegOp],
                   [TF_RealOp, HLO_RealOp],
                   [TF_RsqrtOp, HLO_RsqrtOp],
                   [TF_SigmoidOp, HLO_LogisticOp],
                   [TF_SinhOp, HLOClient_SinhOp],
                   [TF_SinOp, HLO_SinOp],
                   [TF_SqrtOp, HLO_SqrtOp],
                   [TF_TanhOp, HLO_TanhOp],
                   [TF_TanOp, HLOClient_TanOp]
                  ] in {
 def : Pat<(Mapping[0] HLO_Tensor:$input),
           (Mapping[1] $input)>;
}

// Angle(x) = atan2(imag(x), real(x))
def : Pat<(TF_AngleOp $x), (HLO_Atan2Op (HLO_ImagOp $x), (HLO_RealOp $x))>;

// TODO(bixia): Lower Cast with a Complex type source operand or with
// Truncate=True for floating point value conversions.
def : Pat<(TF_CastOp HLO_PredIntOrFpTensor:$arg, ConstBoolAttrFalse),
          (HLO_ConvertOp $arg)>;

def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)),
          (HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>;

// Result of the following ops changing tensor shape needs to have static
// shape as HLO doesn't yet support dynamic reshaping ops.
//
// TODO(hinsu): Update once HLO supports dynamic reshaping ops.
foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in {
  def : Pat<(TfOp:$res $arg, $ignored),
            (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;
}

// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>;

def BothElementTypesSameWidthIntOrFloat : Constraint<CPred<
  "getElementTypeOrSelf($0.getType()).isIntOrFloat() && "
  "getElementTypeOrSelf($1.getType()).isIntOrFloat() && "
  "getElementTypeOrSelf($0.getType()).getIntOrFloatBitWidth() == "
  "getElementTypeOrSelf($1.getType()).getIntOrFloatBitWidth()">,
  "element types must be integers or floats of same width">;

// TODO(mgester): Due to restrictions of xla::BitcastConvertType we currently
// only lower if both input and output types are int or float and have same width

def : Pat<(TF_BitcastOp:$res HLO_Tensor:$arg),
          (HLO_BitcastConvertOp $arg),
          [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>;

// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic
// and going to MHLO.
662
663// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic
664// and going to MHLO.
665
666//===----------------------------------------------------------------------===//
667// Random ops.
668//===----------------------------------------------------------------------===//
669
// Lower tf.RandomUniform -> mhlo.rng_uniform and
// tf.RandomStandardNormal -> mhlo.rng_normal, with constant scalar
// distribution parameters 0.0 and 1.0 and the shape operand cast to i64.
// The $seed/$seed2 operands are matched but dropped (see TODO below).
foreach srcDstOpPair = [[TF_RandomUniformOp, HLO_RngUniformOp],
                        [TF_RandomStandardNormalOp, HLO_RngNormalOp]] in {
// TODO(b/148269299): handle random number generator seeds/states correctly.
// NOTE(review): `old.dtype()` inside the zero-argument NativeCodeCall refers
// to the C++ local that DRR generates for the `$old` binding rather than a
// `$`-substitution — fragile but apparently deliberate; confirm against the
// tblgen-generated matcher code before touching.
def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
          (srcDstOpPair[1]
            (HLO_ConstOp
              (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)),
            (HLO_ConstOp
              (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)),
            (CastValueToI64 $old, $shape)),
          [(IsShapedTensor $shape)]>;
}
682
683//===----------------------------------------------------------------------===//
684// Sigmoid grad op.
685//===----------------------------------------------------------------------===//
686
// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the
// shape of $l instead of having it as a constant.
//
// Lowers tf.SigmoidGrad to $r * $l * (1 - $l), where the constant 1 is
// splatted to $l's (static, ranked) tensor type.
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_MulOp
           (HLO_MulOp $r, $l),
           (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>;
693
694//===----------------------------------------------------------------------===//
695// Softplus op.
696//===----------------------------------------------------------------------===//
697
// Builds a scalar attribute via GetEpsilonValue($0.getType()) — presumably the
// machine epsilon of $0's element type; confirm in the C++ helper.
def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">;
699
// Numerically stable softplus, log(1 + exp(features)), computed piecewise:
//   threshold = log(epsilon) + 2
//   output = features                if features >  -threshold  (exp overflows)
//          = exp(features)           if features <   threshold  (log1p ~ exp)
//          = log1p(exp(features))    otherwise
def : Pattern<(TF_SoftplusOp AnyTensor:$features),
              [
                // exp(features), reused by both the comparison result branches.
                (HLO_ExpOp:$features_exp $features),
                // threshold = log(epsilon of features' element type) + 2.
                (HLOClient_BroadcastAddOp:$threshold
                 (HLO_LogOp (HLO_ConstOp (EpsilonValue $features))),
                 (HLO_ConstOp (GetScalarOfType<2> $features)),
                 (NullDenseIntElementsAttr)
                ),
                // Outer select: features > -threshold -> identity branch.
                (HLO_SelectOp:$output
                 (HLOClient_BroadcastCompareOp
                  $features,
                  (HLO_NegOp $threshold),
                  (NullDenseIntElementsAttr),
                  HLO_COMPARISON_DIRECTION_GT,
                  (HLO_DEFAULT_COMPARISON_TYPE)
                 ),
                 $features,
                 // Inner select: features < threshold -> exp(features),
                 // otherwise the generic log1p(exp(features)).
                 (HLO_SelectOp
                  (HLOClient_BroadcastCompareOp
                   $features,
                   $threshold,
                   (NullDenseIntElementsAttr),
                   HLO_COMPARISON_DIRECTION_LT,
                   (HLO_DEFAULT_COMPARISON_TYPE)
                  ),
                  $features_exp,
                  (HLO_Log1pOp $features_exp)
                 )
                ),
                (replaceWithValue $output)
              ]>;
731
732//===----------------------------------------------------------------------===//
733// XlaReplicaId op.
734//===----------------------------------------------------------------------===//
735
// tf.XlaReplicaId lowers to mhlo.replica_id wrapped in a non-truncating
// tf.Cast. NOTE(review): the cast presumably converts mhlo.replica_id's
// unsigned result to the TF op's signed result type — confirm op type
// definitions.
def : Pat<(TF_XlaReplicaIdOp),
          (TF_CastOp (HLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>;
738
739//===----------------------------------------------------------------------===//
740// XlaGather op.
741//===----------------------------------------------------------------------===//
742
// Converts the op's dimension_numbers attribute $0 into the gather dimension
// numbers attribute MHLO expects, via the C++ helper GetGatherDimNumsAttr.
def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">;
744
// Succeeds iff the C++ helper HasValidGatherDims($0) accepts the
// dimension_numbers attribute; guards the XlaGather lowering below.
def HasValidGatherDims : Constraint<CPred<"HasValidGatherDims($0)">>;
746
// tf.XlaGather with constant slice_sizes lowers to mhlo.gather: the dimension
// numbers attribute is converted via ToGatherDimNumsAttr, slice_sizes is cast
// to an i64 elements attribute, and indices_are_sorted is forwarded. Only
// fires when the dimension numbers validate (HasValidGatherDims).
def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes),
                          $dimension_numbers, $indices_are_sorted),
          (HLO_GatherOp $operand, $start_indices,
                        (ToGatherDimNumsAttr $dimension_numbers),
                        (CastElementsToI64Elements $slice_sizes),
                        $indices_are_sorted),
          [(HasValidGatherDims $dimension_numbers)]>;
754