/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect to XLA dialect.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/Dialect/Traits.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace xla_hlo {
namespace {

class LegalizeTF : public FunctionPass<LegalizeTF> {
 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF &) {}
  explicit LegalizeTF(bool allow_partial_conversion) {
    allow_partial_conversion_ = allow_partial_conversion;
  }

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;

 private:
  Option<bool> allow_partial_conversion_{
      *this, "allow-partial-conversion",
      llvm::cl::desc("Allow operations that can't be legalized."),
      llvm::cl::init(false)};
};

/// Returns true if the given TF data format string is the default format.
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }

/// Returns the feature dimension for the given format and input type.
static size_t GetFeatureDimension(StringAttr format,
                                  RankedTensorType inputType) {
  return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1;
}
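
// For example, with a rank-4 input, format "NHWC" yields feature dimension 3
// (the last dimension) and any other format, e.g. "NCHW", yields dimension 1.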

// Gets all integer values from the given attribute and pushes them to
// `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
  RankedTensorType ty =
      RankedTensorType::get(static_cast<int64_t>(attr.size()),
                            IntegerType::get(64, attr.getContext()));
  return DenseIntElementsAttr::get(ty, attr.getValue());
}

// Returns 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isInteger(8) || input_type.isInteger(16))
    return IntegerType::get(32, ctx);
  return input_type;
}
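
// For example, bf16 and f16 sums accumulate in f32, i8 and i16 sums
// accumulate in i32, and all other types accumulate in the input type itself.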

// Returns the axis in HLO format from a TF elements attr with exactly one
// element containing the axis in the TensorFlow format. Unlike HLO, the
// TensorFlow format supports negative indexing.
static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank,
                                        Builder *b) {
  SmallVector<uint64_t, 1> index(attr.getType().getRank(), 0);
  int64_t axis = attr.getValue<IntegerAttr>(index).getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

static IntegerAttr GetHLOAxisFromTFAxis(IntegerAttr attr, int64_t rank,
                                        Builder *b) {
  int64_t axis = attr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}
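
// For example, for a rank-4 tensor, a TensorFlow axis of -1 maps to HLO
// axis 3, while a non-negative axis is returned unchanged.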

// If `value` is a constant scalar integer, returns the integer value for the
// HLO axis corresponding to the TensorFlow axis. In particular, the TensorFlow
// axis can be negative, in which case the corresponding HLO axis is
// (axis + rank-of-the-tensor). Returns llvm::None otherwise.
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
  return axis < 0 ? axis + rank : axis;
}

/// Returns a `ConvertOp` that casts the elements to an i64 type while
/// retaining the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Returns the size of the dimension at the specified index if `ty` is a
// ranked tensor. Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(
      llvm::SmallVector<tensorflow::int64, 4>(sizes.begin(), sizes.end()));
}

// Returns minimal value for the given int or float element type.
static ConstOp GetMinValueForType(Type ty, Location loc,
                                  PatternRewriter *rewriter) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  DenseElementsAttr attr;
  if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
    APFloat neg_inf =
        APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true);
    attr = DenseElementsAttr::get(scalar_ty, neg_inf);
  } else {
    auto int_ty = ty.cast<IntegerType>();
    APInt min_val = APInt::getSignedMinValue(int_ty.getWidth());
    attr = DenseElementsAttr::get(scalar_ty, min_val);
  }
  return rewriter->create<ConstOp>(loc, attr);
}

// Returns maximal value for the given int or float element type.
static ConstOp GetMaxValueForType(Type ty, Location loc,
                                  PatternRewriter *rewriter) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  DenseElementsAttr attr;
  if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
    APFloat pos_inf =
        APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false);
    attr = DenseElementsAttr::get(scalar_ty, pos_inf);
  } else {
    auto int_ty = ty.cast<IntegerType>();
    APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth());
    attr = DenseElementsAttr::get(scalar_ty, max_val);
  }
  return rewriter->create<ConstOp>(loc, attr);
}

// Returns an int or float scalar constant op with the given element type and
// value.
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
                                    PatternRewriter *rewriter) {
  return rewriter->create<ConstOp>(loc, xla::GetScalarOfType(ty, raw_value));
}

// Builds the body for a reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
                            OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Block *block = builder->createBlock(body);

  // Block arguments are scalars of the given element type.
  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
  block->addArguments({type, type});

  Location loc = body->getLoc();
  auto reducer =
      builder->create<Op>(loc, block->getArgument(0), block->getArgument(1),
                          /*broadcast_dimensions=*/nullptr);
  builder->create<ReturnOp>(loc, reducer.getResult());
}

// Builds a region taking two arguments and returning the second argument as
// the result. Corresponds to the function f(x, y) = y.
// Used in Scatter op's computation to update specific elements.
static void BuildBinaryAssignmentRegion(Type element_type, Region *region,
                                        OpBuilder *builder) {}

// Builds a set of operations for applying reduction on the input value. A
// tf.Sum op is created here and will be legalized to HLO ops automatically by
// other patterns in this pass.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates an xla_hlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
static xla_hlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                                int lower_limit,
                                                int upper_limit,
                                                OpBuilder *builder) {
  auto i32_type = builder->getIntegerType(32);
  auto key_type = RankedTensorType::get({num_elements}, i32_type);
  auto shape_tensor = builder->create<xla_hlo::ConstOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<xla_hlo::ConstOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<xla_hlo::ConstOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<xla_hlo::RngUniformOp>(loc, key_type, lower, upper,
                                                shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates an xla_hlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Value, 4> old_values = init_values;
// SmallVector<Value, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to the values to control
// the number of iterations, but that is transparent to `body_fn`, which does
// not need to care about it.
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Prepare the initial tuple for the while op.
  auto init_tuple =
      builder->create<xla_hlo::TupleOp>(loc, init_values_with_loop_iv);
  auto tuple_type = init_tuple.getType();

  // Create the while op.
  auto while_op = builder->create<xla_hlo::WhileOp>(loc, init_tuple);

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region. It should take one
    // argument of the loop's tuple type.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    BlockArgument arg = block->addArgument(tuple_type);

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
    auto upper_limit = builder->create<xla_hlo::ConstOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    StringAttr compare_direction = StringAttr::get("LT", builder->getContext());
    Value compare = builder->create<xla_hlo::CompareOp>(
        loc, loop_iv, upper_limit,
        /*broadcast_dimensions=*/nullptr, compare_direction);

    builder->create<xla_hlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region. It should take one
    // argument of the loop's tuple type.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    BlockArgument arg = block->addArgument(tuple_type);

    SmallVector<Value, 4> old_values;  // From the previous iteration
    SmallVector<Value, 4> new_values;  // Generated by this iteration
    old_values.reserve(value_count);
    new_values.reserve(value_count);

    // Unpack the tuple value from the last iteration.
    for (int i = 0; i < value_count; ++i)
      old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
    auto no_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<xla_hlo::AddOp>(loc, old_values[0], one,
                                                    no_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    Value updated_tuple = builder->create<xla_hlo::TupleOp>(loc, new_values);

    builder->create<xla_hlo::ReturnOp>(loc, updated_tuple);
  }

  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(
        builder->create<GetTupleElementOp>(loc, while_op, i + 1));
}
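
// For illustration only, a hypothetical call that adds the loop counter into
// a running total each iteration (`loc`, `builder`, and `zero` are assumed to
// exist in the caller):
//
//   SmallVector<Value, 4> results;
//   CreateWhile32(
//       loc, /*num_iterations=*/10,
//       [](Location loc, Value iteration, ArrayRef<Value> old_values,
//          SmallVectorImpl<Value> *new_values, OpBuilder *b) {
//         auto sum = b->create<xla_hlo::AddOp>(
//             loc, old_values[0], iteration,
//             /*broadcast_dimensions=*/GetI64ElementsAttr({}, b));
//         new_values->push_back(sum);
//       },
//       /*init_values=*/{zero}, &results, builder);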

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// Bias op utilities.
//===----------------------------------------------------------------------===//

// Returns a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Requires the input to be a ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
                                                    StringAttr format,
                                                    Value input) {
  auto inputType = input.getType().cast<RankedTensorType>();
  size_t featureDim = GetFeatureDimension(format, inputType);
  RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
  return DenseIntElementsAttr::get(type, featureDim);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true, returns an ElementsAttr to transpose
// a 2D matrix. Otherwise, returns an ElementsAttr for the identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices an input attribute of rank two and returns the specified column.
//
// Always returns a 64-bit integer attribute regardless of the bitwidth of the
// input attribute.
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getIntValues())) {
    if (it.index() % shape[1] == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(64, input.getContext());
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}
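
// For example, slicing column 1 of the 2x2 attribute [[0, 1], [2, 3]] returns
// the 64-bit attribute dense<[1, 3]> of type tensor<2xi64>.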

// Returns the interior padding to use in an HLO Pad op based on the paddings
// attribute of the TensorFlow PadV2 op.
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(64, tf_padding.getContext());
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}
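
// TF PadV2 has no notion of interior (dilation) padding, so for a paddings
// attribute of shape [3, 2] this returns dense<[0, 0, 0]>: one zero per
// dimension of the padded value.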

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; this broadcasts size-1 tensors up to any rank. Dynamic
// dimensions must be broadcasted with a size-1 tensor or another dynamic
// dimension. Returns false if either value is unranked.
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}
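
// For example, shapes [4, 1] and [3] are compatible, and so are [?, 4] and
// [1, 4] (a dynamic dimension paired with size 1), but [?, 4] and [2, 4] are
// not. Note that only the dynamic-dimension constraints are checked here.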

// Returns a new TensorType with the same rank and dimensions as the input but
// an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns a 1-d i64 elements attribute populated with numbers from `start`
// (inclusive) to `end` (exclusive).
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}
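
// For example, GetI64ElementsAttrForSeq(0, 4, builder) returns the attribute
// dense<[0, 1, 2, 3]> of type tensor<4xi64>.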

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//

static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        StringRef direction, Region *body,
                                        OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  block->addArguments({input_type, index_type, input_type, index_type});

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(direction, builder->getContext());
  Value compare = builder->create<CompareOp>(
      loc, block->getArgument(0), block->getArgument(2),
      /*broadcast_dimensions=*/nullptr, compare_direction);

  Value selected_input = builder->create<SelectOp>(
      loc, input_type, compare, block->getArgument(0), block->getArgument(2));
  Value selected_index = builder->create<SelectOp>(
      loc, index_type, compare, block->getArgument(1), block->getArgument(3));

  Value return_values[] = {selected_input, selected_index};
  builder->create<ReturnOp>(loc, return_values);
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    for (int64_t i = 0; i < input_rank; ++i) {
      int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
      int64_t input_size = input_shape[i];
      if (slice_size < 0 || (input_size != -1 && slice_size > input_size)) {
        return false;
      }
    }
    return true;
  }

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    if (start_index < 0) return false;
    // A slice_size of -1 means "all elements from start_index to the end".
    // We can't support this semantics for dynamic shapes.
    if (slice_size == -1) {
      if (input_size == -1) return false;
      slice_size = input_size - start_index;
    }
    if (input_size != -1 && start_index + slice_size > input_size) {
      return false;
    }
  }

  return true;
}
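
// For example, slicing a tensor<4x?xf32> with constant start indices [1, 0]
// and slice sizes [2, 3] can be translated, but slice sizes [2, -1] cannot:
// the -1 ("to the end") is unresolvable in the dynamic dimension.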

// A TF slice size can be -1, which represents all elements from start_index
// to the end. HLO slice sizes can't be -1, so we need to translate a TF slice
// size of -1 into an explicit HLO slice size.
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}
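
// For example, for an input of shape [4, 8] with constant start indices
// [1, 2] and TF slice sizes [-1, 3], the returned HLO slice sizes are
// [3, 3]: the -1 becomes input_size - start_index = 4 - 1.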

//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//

// Builds the region `body` for xla_hlo.sort's comparator: for each type in
// `element_types`, create two block arguments, one for lhs and one for rhs,
// and generate an xla_hlo.compare op to compare them with the given
// `direction`.
//
// Note that this currently only does comparison on the first pair of block
// arguments.
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
                                    StringRef direction, Region *body,
                                    OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Block *block = builder->createBlock(body);
  // Add two arguments for each element type.
  for (Type element_type : element_types) {
    TensorType tensor_type = RankedTensorType::get({}, element_type);
    block->addArguments({tensor_type, tensor_type});
  }

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(direction, builder->getContext());
  Value compare = builder->create<xla_hlo::CompareOp>(
      loc, block->getArgument(0), block->getArgument(1),
      /*broadcast_dimensions=*/nullptr, compare_direction);

  builder->create<xla_hlo::ReturnOp>(loc, compare);
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//

NamedAttribute GetConvDimensionNumbersAttr(
    ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
    Builder *builder) {
  int64_t num_spatial_dims = spatial_dim_indices.size();
  int64_t num_dims = num_spatial_dims + 2;

  IntegerAttr batch_dim =
      builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
  IntegerAttr feature_dim =
      builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
  DenseIntElementsAttr spatial_dims =
      GetI64ElementsAttr(spatial_dim_indices, builder);

  // The filter data_format is always HWIO, so the input channels dimension
  // comes after all spatial dimensions.
  IntegerAttr kernel_input_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims);
  IntegerAttr kernel_output_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims + 1);
  DenseIntElementsAttr kernel_spatial_dimensions =
      GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);

  return builder->getNamedAttr(
      "dimension_numbers",
      ConvDimensionNumbers::get(
          batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
          kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
          feature_dim, spatial_dims, builder->getContext()));
}
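
// For example, for a 2D convolution in NHWC format, spatial_dim_indices is
// [1, 2], the batch dimension is 0 and the feature dimension is 3, while the
// HWIO kernel has input feature dimension 2, output feature dimension 3, and
// spatial dimensions [0, 1].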

// Converts the templated TensorFlow conv op to the generic HLO conv op by
// converting TensorFlow op attributes to HLO op attributes.
//
// Sample result for Conv2D:
//
//   %conv = "xla_hlo.conv"(%input, %filter) {
//     strides = [1, 2],
//     paddings = [[1, 0], [1, 1]],
//     ...
//   }
//
// This pattern is not defined using declarative rewrite rules, as computing
// the paddings attribute requires multiple source op attributes and result op
// attributes; defining it as a declarative rewrite rule would introduce
// duplication in the C++ helper methods.
template <typename OpT, int num_spatial_dims>
class ConvertConv : public OpRewritePattern<OpT> {
 public:
  using OpRewritePattern<OpT>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpT op,
                                     PatternRewriter &rewriter) const override {
    tensorflow::TensorFormat format;
    std::string data_format = op.data_format().str();
    if (!FormatFromString(data_format, &format)) return Pattern::matchFailure();

    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
    auto filter_ty =
        op.filter().getType().template dyn_cast<RankedTensorType>();
    auto result_ty = op.getType().template dyn_cast<RankedTensorType>();

    // Input, filter and result need to have static shapes for the calculation
    // of the HLO paddings and feature group count attributes.
    for (RankedTensorType ty : {input_ty, filter_ty, result_ty}) {
      if (!ty || !ty.hasStaticShape()) return Pattern::matchFailure();
    }

    int num_dims = num_spatial_dims + 2;
    tensorflow::Padding padding;
    if (!GetPaddingFromString(op.padding().str(), &padding).ok())
      return Pattern::matchFailure();

    auto get_int = [](Attribute attr) {
      return attr.template cast<IntegerAttr>().getInt();
    };

    SmallVector<int64_t, 4> spatial_dim_indices;
    SmallVector<int64_t, 4> rhs_dilations;
    SmallVector<int64_t, 4> window_strides;
    SmallVector<int64_t, 8> paddings;

    ArrayRef<Attribute> dilations = op.dilations().getValue();
    ArrayRef<Attribute> strides = op.strides().getValue();
    ArrayRef<Attribute> explicit_paddings;
    if (padding == tensorflow::Padding::EXPLICIT) {
      // The EXPLICIT padding mode and the associated attribute are limited to
      // Conv2D, so fetch the attribute by identifier instead of using the
      // op.explicit_paddings() attribute getter.
      explicit_paddings =
          op.template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
    }

    for (int i = 0; i < num_spatial_dims; ++i) {
      int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
      spatial_dim_indices.push_back(dim);

      int64_t stride = get_int(strides[dim]);
      int64_t dilation = get_int(dilations[dim]);
      window_strides.push_back(stride);
      rhs_dilations.push_back(dilation);

      int64_t pad_low, pad_high;
      if (padding == tensorflow::Padding::EXPLICIT) {
        pad_low = get_int(explicit_paddings[2 * dim]);
        pad_high = get_int(explicit_paddings[2 * dim + 1]);
      } else {
        tensorflow::int64 output_size;
        tensorflow::int64 pad_low_int64;
        tensorflow::int64 pad_high_int64;
        tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
            input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
            padding, &output_size, &pad_low_int64, &pad_high_int64);
        if (!status.ok()) return Pattern::matchFailure();
        pad_low = pad_low_int64;
        pad_high = pad_high_int64;
      }
      paddings.push_back(pad_low);
      paddings.push_back(pad_high);
    }

    auto rhs_dilations_attr = rewriter.getNamedAttr(
        "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));

    auto window_strides_attr = rewriter.getNamedAttr(
        "window_strides", GetI64ElementsAttr(window_strides, &rewriter));

    auto dimension_numbers_attr =
        GetConvDimensionNumbersAttr(spatial_dim_indices, format, &rewriter);

    int64_t input_channels =
        GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, format));
    // The filter data_format is always HWIO, so the input channels dimension
    // comes after all spatial dimensions.
    int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
    // The TensorFlow convolution op verifies that the number of input channels
    // is divisible by the number of filter channels.
    int64_t feature_group_count = input_channels / filter_channels;
    auto feature_group_count_attr = rewriter.getNamedAttr(
        "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));

    auto batch_group_count_attr = rewriter.getNamedAttr(
        "batch_group_count", rewriter.getI64IntegerAttr(1));

    RankedTensorType paddings_ty = RankedTensorType::get(
        {num_spatial_dims, 2}, rewriter.getIntegerType(64));
    auto paddings_attr = rewriter.getNamedAttr(
        "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));

    SmallVector<Value, 2> operands(op.getOperands());
    NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
                              dimension_numbers_attr, feature_group_count_attr,
                              batch_group_count_attr, paddings_attr};
    rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
                                        llvm::makeArrayRef(attrs));
    return Pattern::matchSuccess();
  }
};

using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;

// Converts a BF16 FloorDiv op to have casting operators on either end, as
// BF16 division can result in strange behavior:
//
//      floordiv = cast(floordiv(cast(left), cast(right)))
//
//   %left_cast = cast(%left)
//   %right_cast = cast(%right)
//   %div = div(%left_cast, %right_cast)
//   %floored = floor(%div)
//   %floored_cast = cast(%floored)
//
// Required to manually specify the intermediate types.
class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::FloorDivOp op,
                                     PatternRewriter &rewriter) const override {
    auto l = op.x();
    auto r = op.y();
    auto element_type = getElementTypeOrSelf(l.getType());
    if (!element_type.isBF16()) return matchFailure();

    auto out_type = op.z().getType().cast<TensorType>();

    l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
    r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());

    auto intermediate = rewriter.create<TF::FloorDivOp>(
        op.getLoc(),
        ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
        r);

    auto floor_op =
        rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
    rewriter.replaceOp(op, floor_op.getResult());
    return Pattern::matchSuccess();
  }
};

// Converts a TensorFlow EinsumOp to either an HLO EinsumOp or UnaryEinsumOp
// depending on the arity of the op.
class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::EinsumOp op,
                                     PatternRewriter &rewriter) const override {
    StringAttr equation = op.getAttrOfType<StringAttr>("equation");
    if (op.N() == 1) {
      rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
          op, op.getType(), *op.inputs().begin(), equation);
    } else if (op.N() == 2) {
      ValueRange inputs = op.inputs();
      rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
                                            inputs[1], equation);
    } else {
      // The TensorFlow EinsumOp verifies that the number of operands is at
      // most two.
      return Pattern::matchFailure();
    }
    return Pattern::matchSuccess();
  }
};

// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
// BatchNormGradOp for training and a sequence of binary ops for inference.
// TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
template <typename FusedBatchNormGradOpT>
class ConvertFusedBatchNormGradBase
    : public OpRewritePattern<FusedBatchNormGradOpT> {
 public:
  using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(FusedBatchNormGradOpT op,
                                     PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value grad = op.y_backprop();
    Value act = op.x();
    Value scale = op.scale();
    Value mean = op.reserve_space_1();
    Value var = op.reserve_space_2();

    // TODO(b/141785544): Update this to not require static shapes.
    // The activation shape needs to be static to convert negative indices in
    // TensorFlow to absolute indices required by HLO.
    RankedTensorType act_type =
        act.getType().template dyn_cast<RankedTensorType>();
    if (!act_type) return Pattern::matchFailure();
    Type act_ele_type = act_type.getElementType();
    // To support mixed precision, the statistics type, which may be more
    // precise than the input types, is used for this op.
    Type kernel_type =
        scale.getType().template cast<TensorType>().getElementType();
    grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
    act = rewriter.create<ConvertOp>(loc, act, kernel_type);

    auto feature_dim_attr =
        getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act);
    auto feature_dim = feature_dim_attr.getValue().getSExtValue();

    // Gets the result values.
    Value x_backprop, scale_backprop, offset_backprop;
    if (op.is_training()) {  // training
      // TODO(b/145536565): handle GPU logic separately.
      // Infers the output type with the converted `act`.
      Type feature_type = RankedTensorType::get(
          {GetDimSize(act_type, feature_dim)}, kernel_type);
      Type result_type = TupleType::get(
          {act.getType(), feature_type, feature_type}, rewriter.getContext());

      auto training_op = rewriter.create<BatchNormGradOp>(
          loc, result_type, act, scale, mean, var, grad, op.epsilon(),
          feature_dim_attr.getValue());

      x_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);

      scale_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);

      offset_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
    } else {  // inference
      SmallVector<int64_t, 4> non_feature_dims;
      for (int64_t i = 0; i < act_type.getRank(); ++i) {
        if (i == feature_dim) continue;
        non_feature_dims.push_back(i);
      }
      auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
      auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter);
      auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter);

      // scratch1 = rsqrt(var + epsilon)
      RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
      auto epsilon = rewriter.create<ConstOp>(
          loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
      auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(),
                                           no_broadcast_dims);
      Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);

      // scratch2 = sum(y_backprop * (x - mean))
      auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims);
      auto weighted_grad =
          rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims);
      Value scratch2 =
          ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);

      // x_backprop = y_backprop * (scale * scratch1)
      auto scaled_grad =
          rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims);
      x_backprop =
          rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims);

      // scale_backprop = scratch2 * scratch1
      scale_backprop =
          rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims);

      // offset_backprop = sum(y_backprop)
      offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
    }

    x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
    // It doesn't matter what values we provide for the last 2 results.
    rewriter.replaceOp(op,
                       {/*x_backprop=*/x_backprop,
                        /*scale_backprop=*/scale_backprop,
                        /*offset_backprop=*/offset_backprop, op.x(), op.x()});
    return Pattern::matchSuccess();
  }
};

using ConvertFusedBatchNormGradOp =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
using ConvertFusedBatchNormGradV2Op =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
using ConvertFusedBatchNormGradV3Op =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;

// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
// HLO BatchNormInferenceOp, depending on the value of the 'is_training'
// parameter.
class ConvertFusedBatchNormV3Op
    : public OpRewritePattern<TF::FusedBatchNormV3Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::FusedBatchNormV3Op op,
                                     PatternRewriter &rewriter) const override {
    auto feature_dim =
        getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());

    auto input_type_tensor = op.x().getType().dyn_cast<TensorType>();
    auto input_element_type = input_type_tensor.getElementType();

    auto scale_type_tensor = op.scale().getType().dyn_cast<TensorType>();
    auto scale_element_type = scale_type_tensor.getElementType();
    // In the training case, dimensions of input tensors must be static.
    if (op.is_training() && ((!input_type_tensor.hasStaticShape()) ||
                             (!scale_type_tensor.hasStaticShape()))) {
      return matchFailure();
    }

    // TODO(b/69928690): Support mixed precision in the XLA batch
    // normalization operators. As a workaround, create a new x with the same
    // element type as scale (which may be more precise than the input type).
    Value bn_train_input = rewriter.create<xla_hlo::ConvertOp>(
        op.getLoc(), op.x(), scale_element_type);
    TensorType bn_train_input_type_tensor =
        bn_train_input.getType().cast<TensorType>();

    if (op.is_training()) {
      // Training case.
      auto operand_shape = bn_train_input_type_tensor.getShape();
      // The mean and variance are each a 1-dimensional array the size of the
      // feature dimension, with the same element type as the operand (x).
      // This shape must be constructed manually because the mean and variance
      // inputs are empty in the training case.
      Type mean_var_type = RankedTensorType::get(
          {operand_shape[feature_dim.getInt()]}, scale_element_type);
      // The op's result type is a tuple of 3 values: the output with the same
      // shape as the input, batch_mean, and batch_var.
      SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
                                            mean_var_type, mean_var_type};
      Type result_type = TupleType::get(operand_types, rewriter.getContext());

      auto bn_train_op = rewriter.create<xla_hlo::BatchNormTrainingOp>(
          op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
          op.epsilon(), feature_dim.getValue());
      // The HLO op outputs a tuple of tensors. Extract those results.
      auto bn_train_op_result = bn_train_op.getResult();
      Value y_out = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 0);
      Value batch_mean = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 1);
      Value batch_variance = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 2);

      // Apply Bessel's correction on the variance: scale it by N / (N - 1),
      // where N is the per-feature sample size.
      int total_input_size = bn_train_input_type_tensor.getNumElements();
      int total_scale_size = scale_type_tensor.getNumElements();
      int sample_size = total_input_size / total_scale_size;
      int sample_size_minus_one = std::max(1, sample_size - 1);
      double factor = static_cast<double>(sample_size) /
                      static_cast<double>(sample_size_minus_one);
      auto factor_const_op = rewriter.create<xla_hlo::ConstOp>(
          op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));

      auto corrected_variance = rewriter.create<xla_hlo::MulOp>(
          op.getLoc(), batch_variance.getType(), batch_variance,
          factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());

      // Convert back to the input type to stay aligned with the expected
      // output type for the TF op.
      y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), y_out,
                                                  input_element_type);

      // The TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are
      // currently marked as "reserved spaces 1 and 2". They are used to
      // pass the per-batch mean and variance to the gradient. Here we
      // maintain the same behavior by setting them to the mean and variance
      // calculated by BatchNormTraining. Output 5 is unused; it doesn't
      // matter what we pass there.
      rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
                              /*batch_variance=*/corrected_variance,
                              /*reserve_space_1=*/batch_mean,
                              /*reserve_space_2=*/corrected_variance,
                              /*reserve_space_3=*/op.x()});
    } else {  // Inference case.
      auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
          op.getLoc(),
          /*result_type=*/bn_train_input_type_tensor, bn_train_input,
          op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
          feature_dim.getValue());

      // Convert back to the input type to stay aligned with the expected
      // output type for the TF op.
      auto y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), bn_train_op,
                                                       input_element_type);

      // The mean, variance, and reserved space outputs of the batch norm op
      // are not used for inference. It doesn't matter what values we provide
      // for the last 5 results.
      rewriter.replaceOp(
          op, {/*y=*/y_out, /*batch_mean=*/op.x(),
               /*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(),
               /*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()});
    }
    return Pattern::matchSuccess();
  }
};

// Returns the padding attribute for a ReduceWindow op with the given params.
//
// Requires the padding to be either 'SAME' or 'VALID' and the number of input
// dimensions to be equal to the size of the window dimensions and window
// strides.
static DenseIntElementsAttr GetReduceWindowPadding(
    llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
    ArrayAttr window_strides, StringRef padding, Builder *builder) {
  if (padding == "VALID") return {};
  DCHECK_EQ(padding.str(), "SAME");

  llvm::SmallVector<tensorflow::int64, 4> input_shape, window_shape, strides;
  input_shape.reserve(input_dims.size());
  window_shape.reserve(window_dims.size());
  strides.reserve(window_strides.size());

  for (const auto &dim : input_dims) input_shape.push_back(dim);
  for (Attribute attr : window_dims)
    window_shape.push_back(attr.cast<IntegerAttr>().getInt());
  for (Attribute attr : window_strides)
    strides.push_back(attr.cast<IntegerAttr>().getInt());

  std::vector<std::pair<tensorflow::int64, tensorflow::int64>> paddings =
      ::xla::MakePadding(input_shape, window_shape, strides,
                         ::xla::Padding::kSame);
  int64_t rank = paddings.size();
  llvm::SmallVector<int64_t, 8> flatten_paddings(rank * 2);
  for (int i = 0; i < rank; i++) {
    flatten_paddings[2 * i] = paddings[i].first;
    flatten_paddings[2 * i + 1] = paddings[i].second;
  }
  return DenseIntElementsAttr::get(
      RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
      flatten_paddings);
}

// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with add as the reduction function. The reduction result is
// then divided by the number of elements in the window.
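//
// Sample result for VALID padding mode (an illustrative sketch; an f16 input
// is assumed, accumulated in f32 by GetSumAccumulationType):
//
//   %cast = "xla_hlo.convert"(%input)
//           : (tensor<1x4x4x1xf16>) -> tensor<1x4x4x1xf32>
//   %init = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
//   %sum = "xla_hlo.reduce_window"(%cast, %init) ["xla_hlo.add"]
//          {window_dimensions = ..., window_strides = ...}
//   %count = xla_hlo.constant dense<4.000000e+00> : tensor<f32>
//   %avg = "xla_hlo.div"(%sum, %count)
//   %result = "xla_hlo.convert"(%avg) : (...) -> tensor<...xf16>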
class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::AvgPoolOp op,
                                     PatternRewriter &rewriter) const override {
    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return matchFailure();

    // TODO(b/147217034): support other data formats.
    if (!IsDefaultDataFormat(op.data_format())) return matchFailure();
    // TODO(b/147217034): support "SAME" padding.
    if (op.padding() != "VALID") return matchFailure();

    // We will do accumulation first; use a larger bitwidth if suitable.
    Type input_element_type = input_type.getElementType();
    Type sum_element_type = GetSumAccumulationType(input_element_type);
    Type result_type;

    // The result type for reduction and division with the proper element type.
    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>())
      result_type =
          RankedTensorType::get(ranked_type.getShape(), sum_element_type);
    else
      result_type = UnrankedTensorType::get(sum_element_type);

    Value input_value = op.value();

    // Convert if we need to enlarge the element type's bitwidth.
    if (input_element_type != sum_element_type)
      input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
                                               sum_element_type);

    // Create the ReduceWindow op.
    Value init =
        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
    DenseIntElementsAttr paddings_attr =
        GetReduceWindowPadding(input_type.getShape(), op.ksize(), op.strides(),
                               op.padding(), &rewriter);
    auto reduce = rewriter.create<ReduceWindowOp>(
        op.getLoc(), result_type, input_value, init,
        GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);

    // Count the number of elements in the window. The following calculation
    // is only valid when there is no padding.
    SmallVector<int64_t, 4> ksize;
    GetI64ArrayAttrValues(op.ksize(), &ksize);
    int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1,
                                    std::multiplies<int64_t>());

    // Divide by the number of elements in the window.
    Value divisor =
        GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
    auto batch_dims =
        GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter);
    Value result = rewriter.create<DivOp>(op.getLoc(), result_type, reduce,
                                          divisor, batch_dims);

    // Convert back if we enlarged the element type's bitwidth.
    if (input_element_type != sum_element_type)
      result =
          rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};

// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with max as the reduction function.
//
// Sample result for VALID padding mode:
//
//   %init = constant dense<...> : tensor<i32>
//   %max_pool = "xla_hlo.reduce_window"(%inp, %init) ["xla_hlo.max"]
//               {window_dimensions = ..., window_strides = ...}
//
class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::MaxPoolOp op,
                                     PatternRewriter &rewriter) const override {
    Type element_type =
        op.input().getType().cast<TensorType>().getElementType();
    if (!element_type.isIntOrFloat()) return matchFailure();
    Location loc = op.getLoc();
    ConstOp init = GetMinValueForType(element_type, loc, &rewriter);

    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty) return matchFailure();
    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
        input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
    auto reduce = rewriter.create<ReduceWindowOp>(
        loc, op.getType(), op.input(), init.getResult(),
        GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
    BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);

    rewriter.replaceOp(op, reduce.getResult());
    return matchSuccess();
  }
};

// Converts SelectV2 to HLO Select op and necessary BroadcastInDim ops on
// operands.
//
// For example, the following source IR:
//
//   %select = "tf.SelectV2"(%condition, %t, %e) :
//               (tensor<1xi1>, tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
//
// will be converted into:
//
//   %pred = "xla_hlo.broadcast_in_dim"(%cond)
//             {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
//               (tensor<1xi1>) -> tensor<2xi1>
//   %on_false = "xla_hlo.broadcast_in_dim"(%e)
//                 {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
//                   (tensor<1xi32>) -> tensor<2xi32>
//   %select = "xla_hlo.select"(%pred, %t, %on_false) :
//               (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::SelectV2Op op,
                                     PatternRewriter &rewriter) const override {
    llvm::SmallVector<int64_t, 4> broadcast_then_else_shape;
    auto ranked_then_type = op.t().getType().dyn_cast<RankedTensorType>();
    auto ranked_else_type = op.e().getType().dyn_cast<RankedTensorType>();
    auto ranked_cond_type =
        op.condition().getType().dyn_cast<RankedTensorType>();
    if (!ranked_then_type || !ranked_then_type.hasStaticShape() ||
        !ranked_else_type || !ranked_else_type.hasStaticShape() ||
        !ranked_cond_type || !ranked_cond_type.hasStaticShape())
      return matchFailure();

    if (!OpTrait::util::getBroadcastedShape(ranked_then_type.getShape(),
                                            ranked_else_type.getShape(),
                                            broadcast_then_else_shape))
      return matchFailure();

    llvm::SmallVector<int64_t, 4> broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(broadcast_then_else_shape,
                                            ranked_cond_type.getShape(),
                                            broadcast_shape))
      return matchFailure();

    auto broadcast_or_self = [&](Value value) {
      RankedTensorType type = value.getType().cast<RankedTensorType>();
      auto output_type =
          RankedTensorType::get(broadcast_shape, type.getElementType());
      if (output_type == type) return value;

      int64_t rank = type.getRank();
      SmallVector<int64_t, 4> broadcast_dimensions(rank);
      std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
                broadcast_shape.size() - rank);

      return rewriter
          .create<BroadcastInDimOp>(
              op.getLoc(), output_type, value,
              GetI64ElementsAttr(broadcast_dimensions, &rewriter))
          .getResult();
    };

    // HLO SelectOp supports broadcasting the predicate only when it is a
    // scalar, so a non-scalar condition is broadcasted explicitly here.
    Value pred = ranked_cond_type.getRank() == 0
                     ? op.condition()
                     : broadcast_or_self(op.condition());
    Value on_true = broadcast_or_self(op.t());
    Value on_false = broadcast_or_self(op.e());

    rewriter.replaceOpWithNewOp<SelectOp>(op, on_true.getType(), pred, on_true,
                                          on_false);

    return matchSuccess();
  }
};

// Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
//
//     sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
//
// Sample result with a 1-d f32 input of 2 elements:
//
//    // Create an array of 0.5 the shape of the input array.
//    %half = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
//    %half_array = "xla_hlo.broadcast"(%half)
//                           {broadcast_sizes = dense<2> : tensor<1xi64>}
//                           : (tensor<f32>) -> tensor<2xf32>
//
//    // Compute Tanh of half the logits of the values.
//    %halved_logits = xla_hlo.mul %logits, %half_array : tensor<2xf32>
//    %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
//
//    // Halve the result of Tanh and add 0.5.
//    %halved_tanh = xla_hlo.mul %tanh, %half_array : tensor<2xf32>
//    %sigmoid = xla_hlo.add %halved_tanh, %half_array : tensor<2xf32>
//
class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::SigmoidOp op,
                                     PatternRewriter &rewriter) const override {
    auto operand = op.getOperand();

    auto scalar_half = rewriter.create<ConstOp>(
        op.getLoc(),
        rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));

    auto shaped_type = operand.getType().cast<ShapedType>();
    auto half_array = rewriter.create<BroadcastOp>(
        op.getLoc(), shaped_type, scalar_half,
        DenseIntElementsAttr::get(
            RankedTensorType::get({shaped_type.getRank()},
                                  rewriter.getIntegerType(64)),
            shaped_type.getShape()));

    auto scaled_input = rewriter.create<MulOp>(
        op.getLoc(), operand, half_array,
        /*broadcast_dimensions=*/DenseIntElementsAttr());
    auto tanh_op =
        rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
    auto mul_op =
        rewriter.create<MulOp>(op.getLoc(), tanh_op, half_array,
                               /*broadcast_dimensions=*/DenseIntElementsAttr());
    auto add_op =
        rewriter.create<AddOp>(op.getLoc(), mul_op, half_array,
                               /*broadcast_dimensions=*/DenseIntElementsAttr());

    rewriter.replaceOp(op, add_op.getResult());
    return matchSuccess();
  }
};

// Converts Softmax and LogSoftmax to HLO ops, computing softmax with the
// following formulas:
//
//     softmax = div(exp(logits), sum(exp(logits)))
//
//     log_softmax = sub(logits, log(sum(exp(logits))))
//
// Sample result with 2-d f16 inputs, with B batches of N elements each.
//
//    %reduce_dim = tf.Const dense<[1]> : tensor<1xi64>
//
//    // Subtract the per-batch max from each element to improve numerical
//    // stability.
//    %max = "tf.Max"(%input, %reduce_dim)
//           : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
//    %sub = "xla_hlo.sub"(%input, %max) {broadcast_dimensions = 0}
//            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
//
//    %exp = "xla_hlo.exp"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
//    %sum = "tf.Sum"(%exp, %reduce_dim)
//            : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
//
//    // Softmax computation:
//    %softmax = "xla_hlo.div"(%exp, %sum) {broadcast_dimensions = 0}
//            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
template <typename OpTy, bool use_log = true>
class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
 public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpTy op,
                                     PatternRewriter &rewriter) const override {
    Value logits = op.logits();

    // The softmax converter requires ranked input type because the XLA reduce
    // ops used while lowering require a dimensions attribute to reduce along.
    RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
    if (!type) return Pattern::matchFailure();

    auto loc = op.getLoc();
    int rank = type.getRank();

    // Note that the TensorFlow Softmax op verifies that the input rank is
    // greater than or equal to one so both of the following sequences are
    // valid.
    auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter);
    auto reduce_dim = rewriter.create<TF::ConstOp>(
        loc, GetI64ElementsAttr({rank - 1}, &rewriter));

    // The exponentials of the input values, and then their sum, can be very
    // large here, and division by a large denominator is numerically
    // unstable. To improve numerical stability, subtract each batch's max
    // element before exponentiation so that the maximum input value is zero.
    // Softmax computed after adding or subtracting a common value from all
    // inputs in a batch is mathematically equivalent.
    auto max_logits =
        rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
                                   /*keep_dims=*/rewriter.getBoolAttr(false));
    auto shifted_logits =
        rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims);

    // Exponentiate the inputs.
    Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);

    // Compute summation of the exponentials.
    auto exp_sum =
        rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
                                   /*keep_dims=*/rewriter.getBoolAttr(false));
    Value sum = exp_sum.getResult();

    if (use_log) {
      Value log = rewriter.create<LogOp>(loc, sum);
      rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims);
    } else {
      rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims);
    }
    return Pattern::matchSuccess();
  }
};

// Converts Size to HLO ops, computing the size of a ranked input tensor.
// TODO(b/145253252): Update this to not require ranked input tensor shapes.
//
// The main logic of this pattern is to calculate the size by multiplying every
// dimension of the input tensor's shape together.
//
// For example, the following source IR:
//
//   %size = "tf.Size"(%input) : (tensor<2x?x8xf32>) -> tensor<i32>
//
// will be converted into:
//
//   %const = xla_hlo.constant dense<1> : tensor<i32>
//   %dim_0 = "xla_hlo.get_dimension_size"(%input) {dimension = 0 : i32} :
//                                         (tensor<2x?x8xf32>) -> tensor<i32>
//   %prod_0 = xla_hlo.mul %const, %dim_0 : tensor<i32>
//   %dim_1 = "xla_hlo.get_dimension_size"(%input) {dimension = 1 : i32} :
//                                         (tensor<2x?x8xf32>) -> tensor<i32>
//   %prod_1 = xla_hlo.mul %prod_0, %dim_1 : tensor<i32>
//   %dim_2 = "xla_hlo.get_dimension_size"(%input) {dimension = 2 : i32} :
//                                         (tensor<2x?x8xf32>) -> tensor<i32>
//   %size = xla_hlo.mul %prod_1, %dim_2 : tensor<i32>
class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::SizeOp op,
                                     PatternRewriter &rewriter) const override {
    Value input = op.input();
    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
    if (!input_ty) return Pattern::matchFailure();

    const int64_t rank = input_ty.getRank();
    auto result_type = op.getResult().getType();
    Operation *size =
        GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
                             op.getLoc(), 1, &rewriter);
    for (int64_t i = 0; i < rank; ++i) {
      auto dim = rewriter.create<GetDimensionSizeOp>(
          op.getLoc(), result_type, input,
          rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
      size = rewriter.create<MulOp>(
          op.getLoc(), size->getResult(0), dim.getResult(),
          /*broadcast_dimensions=*/DenseIntElementsAttr());
    }
    rewriter.replaceOp(op, size->getResult(0));

    return Pattern::matchSuccess();
  }
};

// Converts the tf.Split op into a series of HLO slice ops when the tensor to
// be split has fully static shape and the dimension to split is a constant.
//
// The main logic of this pattern is to calculate the index start and end
// range for each slice. And this happens only on the dimension to be split;
// for all other dimensions, all resultant slices' index start and end range
// covers the input tensor's full range. Strides for all resultant slices are
// all one.
//
// For example, the following source IR:
//
//   %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
//   %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
//                (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
//
// will be converted into:
//
//   %0 = "xla_hlo.slice"(%input) {
//            limit_indices = dense<[4, 2]> : tensor<2xi64>,
//            start_indices = dense<0> : tensor<2xi64>,
//            strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x2xf32>
//   %1 = "xla_hlo.slice"(%input) {
//            limit_indices = dense<4> : tensor<2xi64>,
//            start_indices = dense<[0, 2]> : tensor<2xi64>,
//            strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x2xf32>
//   %2 = "xla_hlo.slice"(%input) {
//            limit_indices = dense<[4, 6]> : tensor<2xi64>,
//            start_indices = dense<[0, 4]> : tensor<2xi64>,
//            strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x2xf32>
// TODO(antiagainst): consider lowering into TF ops so the pattern can be more
// broadly applicable.
class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::SplitOp op,
                                     PatternRewriter &rewriter) const override {
    // We can only split along static dimensions.
    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return matchFailure();

    // We can only match when the split dimension is a constant scalar.
    DenseIntElementsAttr split_dim_attr;
    if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
      return matchFailure();

    // Get the dimension we are splitting at. Offset properly if it's negative.
    int64_t input_rank = input_type.getRank();
    int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
    if (dim_index < 0) dim_index += input_rank;

    // Calculate the dimension size for each slice along the split dimension.
    int64_t input_dim_size = input_type.getDimSize(dim_index);
    // If we are splitting along the dynamic dimension then we cannot compute
    // the static dimension length.
    if (TensorType::isDynamic(input_dim_size)) return matchFailure();

    int64_t num_splits = op.getNumResults();
    int64_t slice_size = input_dim_size / num_splits;

    // Get each slice's type.
    auto slice_shape = llvm::to_vector<4>(input_type.getShape());
    slice_shape[dim_index] = slice_size;
    Type slice_type =
        RankedTensorType::get(slice_shape, input_type.getElementType());

    // Parameters for constructing each slice.
    SmallVector<int64_t, 4> begin_indices(input_rank, 0);
    auto end_indices = llvm::to_vector<4>(input_type.getShape());
    SmallVector<int64_t, 4> strides(input_rank, 1);

    // All HLO slice results used to replace the original tf.Split op.
    SmallVector<Value, 4> slices;
    slices.reserve(num_splits);

    for (int i = 0; i < num_splits; ++i) {
      begin_indices[dim_index] = i * slice_size;
      end_indices[dim_index] = (i + 1) * slice_size;
      slices.push_back(
          rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
                                   GetI64ElementsAttr(begin_indices, &rewriter),
                                   GetI64ElementsAttr(end_indices, &rewriter),
                                   GetI64ElementsAttr(strides, &rewriter)));
    }

    rewriter.replaceOp(op, slices);
    return matchSuccess();
  }
};

// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
// be split has fully static shape and the dimension to split and split sizes
// are constants.
//
// This is similar to the conversion for the tf.Split op, except that the size
// of each chunk on the dimension to split is explicitly given as an op
// operand and the sizes are not necessarily the same.
//
// For example, given the following IR:
//
// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
// %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
//                   (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
//                   (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
//
// we will generate the following slices:
// %0 = "xla_hlo.slice"(%input) {
//        limit_indices = dense<[4, 1]> : tensor<2xi64>,
//        start_indices = dense<0> : tensor<2xi64>,
//        strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x1xf32>
// %1 = "xla_hlo.slice"(%input) {
//        limit_indices = dense<[4, 3]> : tensor<2xi64>,
//        start_indices = dense<[0, 1]> : tensor<2xi64>,
//        strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x2xf32>
// %2 = "xla_hlo.slice"(%input) {
//        limit_indices = dense<[4, 6]> : tensor<2xi64>,
//        start_indices = dense<[0, 3]> : tensor<2xi64>,
//        strides = dense<1> : tensor<2xi64>} :
//        (tensor<4x6xf32>) -> tensor<4x3xf32>
class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::SplitVOp op,
                                     PatternRewriter &rewriter) const override {
    // We can only split along static dimensions.
    // TODO(b/145731001): enhance to support dynamic-shaped inputs.
    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return matchFailure();

    // We can only match when the split dimension is a constant scalar.
    DenseIntElementsAttr split_dim_attr;
    if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
      return matchFailure();

    // We can only match when the split sizes are a constant int vector.
    DenseIntElementsAttr split_sizes_attr;
    if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
      return matchFailure();

    // Get each chunk's size along the dimension to split. It may contain
    // dynamic sizes and we need to update it if so.
    SmallVector<int64_t, 4> split_sizes;
    int64_t total_dim_size = 0;  // Total dimension size assigned to splits
    llvm::Optional<int> dynamic_dim_index;
    split_sizes.reserve(
        split_sizes_attr.getType().cast<ShapedType>().getNumElements());
    for (auto dim : llvm::enumerate(split_sizes_attr)) {
      int64_t dim_val = dim.value().getSExtValue();
      split_sizes.push_back(dim_val);
      if (dim_val == ShapedType::kDynamicSize) {
        // We cannot have more than one dynamic dimension.
        assert(!dynamic_dim_index && "invalid split sizes");
        dynamic_dim_index = dim.index();
      } else {
        total_dim_size += dim_val;
      }
    }

    // Get the dimension we are splitting at. Offset properly if it's negative.
    int64_t input_rank = input_type.getRank();
    int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
    if (dim_index < 0) dim_index += input_rank;

    int64_t input_dim_size = input_type.getDimSize(dim_index);
    if (TensorType::isDynamic(input_dim_size)) return matchFailure();

    assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
            (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
           "invalid split sizes");

    // Update the dynamic dimension with the calculated concrete size.
    if (dynamic_dim_index)
      split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;

    // Parameters for constructing each slice.
    SmallVector<int64_t, 4> begin_indices(input_rank, 0);
    auto end_indices = llvm::to_vector<4>(input_type.getShape());
    SmallVector<int64_t, 4> strides(input_rank, 1);

    // All HLO slice results used to replace the original tf.SplitV op.
    SmallVector<Value, 4> slices;
    slices.reserve(op.getNumResults());

    for (int i = 0; i < op.getNumResults(); ++i) {
      end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
      slices.push_back(rewriter.create<xla_hlo::SliceOp>(
          op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
          GetI64ElementsAttr(end_indices, &rewriter),
          GetI64ElementsAttr(strides, &rewriter)));
      // Prepare the begin index for the next slice.
      begin_indices[dim_index] = end_indices[dim_index];
    }

    rewriter.replaceOp(op, slices);
    return matchSuccess();
  }
};

// Converts StridedSlice op to HLO Slice op along with Reverse op to handle
// negative strides and Reshape op to update the output shape. Indices and
// strides operands are converted to attributes with non-negative indexing.
//
// For example, with an op like the following:
//   tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
//     : tensor<AxBxf32> -> tensor<Pxf32>
//
// the output would be:
//   %reversed = "xla_hlo.reverse"(%input) {dimensions = ...}
//   %sliced = "xla_hlo.slice"(%reversed)
//             {start_indices = ..., limit_indices = ..., strides = ...}
//   %output = "xla_hlo.reshape"(%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
//
class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::StridedSliceOp op,
                                     PatternRewriter &rewriter) const override {
    // Input shape needs to be static to convert negative indices in TensorFlow
    // to absolute indices required by HLO.
    //
    // TODO(hinsu): Relax this constraint for ops without negative indices and
    // strides.
    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> input_shape = input_ty.getShape();

    // Output shape needs to be static to apply 'new_axis_mask' or
    // 'shrink_axis_mask' by reshaping tensor after slice.
    //
    // TODO(hinsu): Relax this constraint for ops without the above masks.
    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
    if (!result_ty || !result_ty.hasStaticShape()) return matchFailure();

    SmallVector<int64_t, 4> begin_indices, end_indices, strides;
    if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides))
      return matchFailure();

    SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
        dims_to_reverse;
    int64_t input_rank = input_ty.getRank();
    hlo_begin_indices.reserve(input_rank);
    hlo_end_indices.reserve(input_rank);
    hlo_strides.reserve(input_rank);

    int64_t indices_elements = begin_indices.size();
    if (input_rank < indices_elements) return matchFailure();

    // Convert from TensorFlow negative or out of range indices and strides
    // values to legal HLO Slice attributes.
    for (int i = 0, e = indices_elements; i != e; i++) {
      int64_t begin = begin_indices[i];
      int64_t end = end_indices[i];
      int64_t stride = strides[i];

      if (stride < 0) {
        // Negative stride means that the output values are computed starting
        // from end until begin. Mark the dimension for reversal before slice
        // and compute indices for the reversed input.
        dims_to_reverse.push_back(i);
        begin = (input_shape[i] - 1) - begin;
        end = (input_shape[i] - 1) - end;
        stride = -stride;
      }

      // Unlike TensorFlow, HLO requires begin and end values to be within
      // range.
      begin = std::max(int64_t(0), begin);
      end = std::max(begin, end);
      end = std::min(end, input_shape[i]);

      hlo_begin_indices.push_back(begin);
      hlo_end_indices.push_back(end);
      hlo_strides.push_back(stride);
    }

    Location loc = op.getLoc();
    Value input = op.input();
    if (!dims_to_reverse.empty())
      input = rewriter.create<ReverseOp>(
          loc, input_ty, op.input(),
          GetI64ElementsAttr(dims_to_reverse, &rewriter));
    auto sliced = rewriter.create<SliceOp>(
        loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
        GetI64ElementsAttr(hlo_end_indices, &rewriter),
        GetI64ElementsAttr(hlo_strides, &rewriter));

    // Reshape slice result so that the shape is updated depending on
    // 'new_axis_mask' or 'shrink_axis_mask' attributes.
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
    return matchSuccess();
  }
};

// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
//
// tf.StridedSlice takes a slice of an input tensor. tf.StridedSliceGrad does
// the reverse: it propagates the gradient of the sliced tensor back to the
// original input tensor by padding with zeros. The main logic is calculating
// the indices and strides for padding.
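//
// For example (an illustrative case): for a gradient of shape tensor<2xf32>
// coming from slicing a tensor<5xf32> with begin index 1, end index 3 and
// stride 1, the lowering pads the reshaped gradient with one zero below and
// two zeros above:
//
//   %grad = "xla_hlo.reshape"(%dy) : (tensor<2xf32>) -> tensor<2xf32>
//   %zero = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
//   %result = "xla_hlo.pad"(%grad, %zero) {edge_padding_low = dense<1>,
//             edge_padding_high = dense<2>, interior_padding = dense<0>}
//           : (tensor<2xf32>, tensor<f32>) -> tensor<5xf32>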
class ConvertStridedSliceGradOp
    : public OpRewritePattern<TF::StridedSliceGradOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::StridedSliceGradOp op,
                                     PatternRewriter &rewriter) const override {
    // We need constant input shape to perform padding calculations later.
    DenseIntElementsAttr input_shape_attr;
    if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
      return matchFailure();

    // We also need constant begin/end indices and strides to perform padding
    // calculations.
    // Bounded shape after performing strided slice
    SmallVector<int64_t, 4> shape;
    // Bounded begin, end, and strides for strided slice
    SmallVector<int64_t, 4> begin_indices, end_indices, strides;
    if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
                                         &strides))
      return matchFailure();

    Value grad = op.dy();
    Type element_type = grad.getType().cast<ShapedType>().getElementType();

    // Perform reshape to undo any new/shrink axes done by strided slice.
    grad = rewriter.create<xla_hlo::ReshapeOp>(
        op.getLoc(), RankedTensorType::get(shape, element_type), grad);

    SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
    SmallVector<int64_t, 4> dims_to_reverse;
    padding_low.reserve(shape.size());
    padding_high.reserve(shape.size());
    padding_interm.reserve(shape.size());

    // Prepare padding parameters for each dimension.
    for (int i = 0, e = shape.size(); i < e; ++i) {
      int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
      if (strides[i] > 0) {
        padding_low.push_back(begin_indices[i]);
        padding_interm.push_back(strides[i] - 1);

        // Pad the upper dimension up to the expected input shape. It's not
        // sufficient simply to use end_indices[i] to compute the padding in
        // cases where the stride does not divide evenly into the interval
        // between begin_indices[i] and end_indices[i].
        int64_t size =
            padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
        padding_high.push_back(input_dim - size);
      } else {
        dims_to_reverse.push_back(i);
        padding_high.push_back(input_dim - begin_indices[i] - 1);
        padding_interm.push_back(-strides[i] - 1);

        // Pad the lower dimension up to the expected input shape.
        int64_t size =
            padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
        padding_low.push_back(input_dim - size);
      }
    }

    if (!dims_to_reverse.empty()) {
      grad = rewriter.create<xla_hlo::ReverseOp>(
          op.getLoc(), grad.getType(), grad,
          GetI64ElementsAttr(dims_to_reverse, &rewriter));
    }

    auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
    rewriter.replaceOpWithNewOp<xla_hlo::PadOp>(
        op, op.getType(), grad, zero,
        GetI64ElementsAttr(padding_low, &rewriter),
        GetI64ElementsAttr(padding_high, &rewriter),
        GetI64ElementsAttr(padding_interm, &rewriter));
    return matchSuccess();
  }
};

/// Converts the RangeOp tensorflow op to a xla_hlo.iota op with a scaling and
/// offset applied to generate the range values. The output tensor needs to
/// have a static shape.
///
/// For example, an op like the following:
///   %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
///      : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
///
/// Output would be:
///   %iota = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
///   %scaled = "xla_hlo.mul"(%iota, %delta)
///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
///   %result = "xla_hlo.add"(%scaled, %start)
///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
///
/// The implementation is defined in C++ because there is no type interface
/// for the iota op.
class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
  using OpRewritePattern<TF::RangeOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::RangeOp op,
                                     PatternRewriter &rewriter) const override {
    auto result = op.getResult();
    auto result_type = result.getType();
    if (!result_type.cast<ShapedType>().hasStaticShape()) {
      return matchFailure();
    }

    auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
                                        rewriter.getI64IntegerAttr(0));
    auto scaled = rewriter.create<MulOp>(
        op.getLoc(), result_type, iota, op.delta(),
        xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
    rewriter.replaceOpWithNewOp<AddOp>(
        op, result_type, scaled, op.start(),
        xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
    return matchSuccess();
  }
};

/// Converts a generic OpTy tensorflow op to a xla_hlo.reduce op over
/// ReductionOp.
/// `is_accumulation` controls whether it uses higher precision for the actual
/// reduction. This is set to false for ops like max where there are no
/// precision concerns.
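///
/// For example (an illustrative sketch, assuming the accumulation type of f16
/// is f32), an f16 sum reduction over dimension 1 is computed in f32 and cast
/// back:
///
///   %cast = "xla_hlo.convert"(%input) : (tensor<4x8xf16>) -> tensor<4x8xf32>
///   %sum = "xla_hlo.reduce"(%cast, %init) ["xla_hlo.add"]
///          {dimensions = dense<1> : tensor<1xi64>}
///   %result = "xla_hlo.convert"(%sum) : (tensor<4xf32>) -> tensor<4xf16>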
template <typename Derived, typename OpTy, typename ReductionOp,
          bool is_accumulation = true>
class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpTy op,
                                     PatternRewriter &rewriter) const override {
    // TODO(b/141785544): Update this to not require static shapes.
    // Input shape needs to be ranked to convert negative indices in TensorFlow
    // to absolute indices required by HLO.
    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
    if (!input_ty) return this->matchFailure();
    ArrayRef<int64_t> input_shape = input_ty.getShape();

    DenseIntElementsAttr dimensions;
    if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
      return this->matchFailure();

    // Build the final shape from input_shape and dimensions using a bitmap
    // to mark the reduced dimensions.
    SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
    SmallVector<int64_t, 4> xla_dimensions;
    for (APInt index_raw : dimensions.getValues<APInt>()) {
      int64_t index = index_raw.getSExtValue();
      int64_t rank = input_shape.size();
      if (index < -rank || index >= rank) return this->matchFailure();
      index = (index + rank) % rank;
      reduced_dimensions_bitmap[index] = true;
      xla_dimensions.push_back(index);
    }

    Location loc = op.getLoc();
    Type element_type = input_ty.getElementType();
    // Convert to an accumulation type to not lose precision when doing
    // repeated arithmetic operations.
    Type reduce_element_type =
        is_accumulation ? GetAccumulationType(element_type) : element_type;
    auto casted_input =
        rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);

    // Each reduction op can have a different initial value.
    Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);

    auto reduction = rewriter.create<ReduceOp>(
        loc, casted_input.getResult(), init,
        GetI64ElementsAttr(xla_dimensions, &rewriter));
    BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
                                 &rewriter);
    Value result = reduction.getResult(0);

    // The mean op needs to divide by the product of the reduced dimensions.
    if (std::is_same<OpTy, TF::MeanOp>::value) {
      int64_t divisor_count = 1;
      for (size_t i = 0; i < input_shape.size(); ++i) {
        if (reduced_dimensions_bitmap[i]) {
          if (TensorType::isDynamic(input_shape[i])) {
            return this->matchFailure();
          }
          divisor_count *= input_shape[i];
        }
      }
      auto divisor = GetScalarConstOfType(reduce_element_type, loc,
                                          divisor_count, &rewriter);
      auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
      result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
                                      broadcast_dims);
    }

    result = rewriter.create<ConvertOp>(loc, result, element_type);

    // Need to reshape back after the reduction if we're keeping the reduced
    // dimensions.
    if (op.keep_dims()) {
      result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
    }
    rewriter.replaceOp(op, {result});

    return this->matchSuccess();
  }
};

// Converts Mean op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
//               {dimensions = ...}
//   %divisor = constant dense<...> : tensor<T>
//   %mean = "xla_hlo.div"(%sum, %divisor)
class ConvertMeanOp
    : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;
  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
  }
};

// Converts Sum op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
//               {dimensions = ...}
class ConvertSumOp
    : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
  }
};

// Converts Max op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"]
//               {dimensions = ...}
class ConvertMaxOp
    : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
                                       /*is_accumulation=*/false> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMinValueForType(reduce_element_type, loc, rewriter);
  }
};

// Converts Min op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.min"]
//               {dimensions = ...}
class ConvertMinOp
    : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
                                       /*is_accumulation=*/false> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMaxValueForType(reduce_element_type, loc, rewriter);
  }
};

// Converts Prod op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.mul"]
//               {dimensions = ...}
class ConvertProdOp
    : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
  }
};

// Converts All op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %all = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"]
//               {dimensions = ...}
class ConvertAllOp
    : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;
  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
  }
};

// Converts Any op to HLO Reduce op.
//
//   %init = constant dense<...> : tensor<T>
//   %any = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"]
//               {dimensions = ...}
class ConvertAnyOp
    : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
 public:
  using GenericConvertReductionOp::GenericConvertReductionOp;
  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
  }
};

// Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform
// a reduction on the original input and the corresponding index. The reduction
// sub-computation selects the max (or min) value and the index for the value.
//   Derived: is the resulting derived class of this class.
//   OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
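//
// For example (an illustrative sketch, for ArgMax over dimension 1 of a
// tensor<2x3xf32> input with an assumed i32 output type):
//
//   %iota = "xla_hlo.iota"() {iota_dimension = 1 : i64}
//           : () -> tensor<2x3xi32>
//   %reduce:2 = "xla_hlo.reduce"(%input, %iota, %init, %index_init)
//               {dimensions = dense<1> : tensor<1xi64>}
//
// where the reduction body keeps the larger value together with its index;
// result 1 holds the arg max indices.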
template <typename Derived, typename OpTy>
class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpTy op,
                                     PatternRewriter &rewriter) const override {
    RankedTensorType input_type =
        op.input().getType().template dyn_cast<RankedTensorType>();
    if (!input_type) {
      return this->matchFailure();
    }

    Type input_element_type = input_type.getElementType();
    // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
    // tf.ArgMax doesn't support complex data types, this check can be removed.
    if (!input_element_type.isIntOrFloat()) return this->matchFailure();

    Location loc = op.getLoc();
    Value init_value =
        Derived::GetInitialValue(input_element_type, loc, rewriter);

    RankedTensorType output_type =
        op.output().getType().template dyn_cast<RankedTensorType>();
    if (!output_type) {
      return this->matchFailure();
    }

    Type index_element_type = output_type.getElementType();
    Value index_init_value =
        GetScalarConstOfType(index_element_type, loc, 0, &rewriter);

    RankedTensorType index_type =
        RankedTensorType::get(input_type.getShape(), index_element_type);

    llvm::Optional<int64_t> optional_axis =
        GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
    if (!optional_axis.hasValue()) {
      return this->matchFailure();
    }
    int64_t axis = optional_axis.getValue();

    IntegerAttr iota_dimension =
        IntegerAttr::get(rewriter.getIntegerType(64), axis);
    Value index_values =
        rewriter.create<IotaOp>(loc, index_type, iota_dimension);

    std::vector<int64_t> dimensions = input_type.getShape();
    dimensions.erase(dimensions.begin() + axis);
    ArrayRef<int64_t> reduction_result_shape(dimensions);

    Value operands[] = {op.input(), index_values};
    Value init_values[] = {init_value, index_init_value};
    DenseIntElementsAttr reduction_dimensions =
        GetI64ElementsAttr({axis}, &rewriter);

    auto reduction = rewriter.create<ReduceOp>(
        loc, llvm::ArrayRef<Value>(operands),
        llvm::ArrayRef<Value>(init_values), reduction_dimensions);
    StringRef direction = Derived::GetDirection();
    BuildArgMinMaxReductionBody(input_element_type, index_element_type,
                                direction, &reduction.body(), &rewriter);

    rewriter.replaceOp(op, {reduction.getResult(1)});
    return this->matchSuccess();
  }
};

// Converts tensorflow ArgMax op to xla_hlo operations. The actual
// implementation is in class ConvertArgMinMaxOp:
//
//   %init_index = constant dense<...> : tensor<T>
//   %init = constant dense<...> : tensor<T>
//   %reduce = "xla_hlo.reduce"(%selected_input, %select_index, %init,
//                              %init_index) ["xla_hlo.arg_max"]
class ConvertArgMaxOp
    : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
 public:
  using ConvertArgMinMaxOp::ConvertArgMinMaxOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter &rewriter) {
    return GetMinValueForType(reduce_element_type, loc, &rewriter);
  }

  static StringRef GetDirection() { return "GT"; }
};

// Converts TF TensorScatterUpdate op into Scatter Op with assignment:
//
//   %result = "xla_hlo.scatter"(%tensor, %indices, %updates)
//     { dimensions = ... }
//
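// The update computation region simply returns the update value, e.g.
// (an illustrative sketch for f32 elements):
//
//   ^bb0(%old_value: tensor<f32>, %update_value: tensor<f32>):
//     "xla_hlo.return"(%update_value) : (tensor<f32>) -> ()
//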
class ConvertTensorScatterUpdateOp
    : public OpRewritePattern<TF::TensorScatterUpdateOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
                                     PatternRewriter &rewriter) const override {
    auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
    auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
    auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();

    if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
    // The last dimension of the indices needs to be known at compile time for
    // computation of the 'update_window_dims' attribute in the dimensions
    // struct.
    int64_t num_index_dims = indices_ty.getShape().back();
    if (ShapedType::isDynamic(num_index_dims)) return matchFailure();

    int64_t tensor_rank = tensor_ty.getRank();
    int64_t indices_rank = indices_ty.getRank();
    int64_t updates_rank = updates_ty.getRank();

    int64_t window_dims = tensor_rank - num_index_dims;
    auto dims_attr = ScatterDimensionNumbers::get(
        GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
                                 &rewriter),
        GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
        GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
        rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());

    Location loc = op.getLoc();
    auto scatter = rewriter.create<ScatterOp>(
        loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);

    // Build the region to assign the new value.
    [&](Region *region) {
      OpBuilder::InsertionGuard guard(rewriter);
      Block *block = rewriter.createBlock(region);

      // Block arguments are scalars of the given element type.
      Type type =
          RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType());
      block->addArguments({type, type});
      rewriter.create<ReturnOp>(loc, block->getArgument(1));
    }(&scatter.update_computation());

    rewriter.replaceOp(op, scatter.getResult());
    return matchSuccess();
  }
};

// Converts Tile op to HLO BroadcastInDim and Reshape ops.
//   For shape [S1, S2] and multiples [M1, M2],
//     MS1 = M1 * S1; MS2 = M2 * S2
//
//   %broadcast = xla_hlo.broadcast_in_dim(%input) {
//     broadcast_dimensions = [1, 3]
//   }
//   %result = "xla_hlo.reshape"(%broadcast) : (tensor<M1xS1xM2xS2xf32>)
//      -> tensor<MS1xMS2xf32>
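//
// For example (an illustrative case), tiling a tensor<2x3xf32> by
// multiples [2, 1]:
//
//   %broadcast = "xla_hlo.broadcast_in_dim"(%input)
//       {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
//       : (tensor<2x3xf32>) -> tensor<2x2x3xf32>
//   %result = "xla_hlo.reshape"(%broadcast)
//       : (tensor<2x2x3xf32>) -> tensor<4x3xf32>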
class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::TileOp op,
                                     PatternRewriter &rewriter) const override {
    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> input_shape = input_ty.getShape();
    Type element_type = input_ty.getElementType();

    DenseIntElementsAttr multiples;
    if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
        multiples.getType().getRank() != 1)
      return matchFailure();

    if (multiples.getNumElements() != input_shape.size()) return matchFailure();

    SmallVector<int64_t, 8> broadcasted_shape;
    SmallVector<int64_t, 4> broadcast_dimensions;
    broadcasted_shape.reserve(input_shape.size() * 2);
    broadcast_dimensions.reserve(input_shape.size());
    for (auto multiple_and_input :
         llvm::zip(multiples.getValues<APInt>(), input_shape)) {
      int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
      int64_t input_size = std::get<1>(multiple_and_input);

      if (multiple < 0) return matchFailure();

      int64_t output_size = input_size * multiple;
      if (input_size == 1 || multiple == 1) {
        // Special case for when normal broadcasting will just work: line the
        // input up with the next dimension in broadcasted_shape.
        broadcast_dimensions.push_back(broadcasted_shape.size());
        broadcasted_shape.push_back(output_size);
      } else {
        // Tiling will happen for this dimension during the ReshapeOp below.
        // Push the multiple first so that it becomes the slower-varying
        // (outer) dimension: the input values must repeat as a whole along
        // this dimension, not element by element.
        broadcasted_shape.push_back(multiple);
        broadcast_dimensions.push_back(broadcasted_shape.size());
        broadcasted_shape.push_back(input_size);
      }
    }
    Location loc = op.getLoc();
    Type broadcasted_type =
        RankedTensorType::get(broadcasted_shape, element_type);
    Type output_type = op.getType();

    Value result = rewriter.create<BroadcastInDimOp>(
        loc, broadcasted_type, op.input(),
        GetI64ElementsAttr(broadcast_dimensions, &rewriter));

    if (output_type != broadcasted_type) {
      result = rewriter.create<ReshapeOp>(loc, output_type, result);
    }

    rewriter.replaceOp(op, {result});

    return matchSuccess();
  }
};

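// Converts tf.MaxPoolGrad to an xla_hlo.select_and_scatter op: the select
// region picks the (first) maximum of each window via a "GE" compare, and the
// scatter region adds the incoming gradient values. This mirrors the forward
// lowering of MaxPool to a ReduceWindow over max. A rough sketch (the window
// attributes are computed below and elided here):
//
//   %result = "xla_hlo.select_and_scatter"(%orig_input, %grad, %zero) ( {
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):  // select: "GE" compare
//     ...
//   }, {
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):  // scatter: add
//     ...
//   }) {window_dimensions = ..., window_strides = ..., padding = ...}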
class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::MaxPoolGradOp op,
                                     PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type element_type =
        op.orig_input().getType().cast<TensorType>().getElementType();

    // Compute paddings using the original input and kernel shape and strides.
    // The paddings are computed as for a ReduceWindow op, because the forward
    // MaxPool op is lowered to a ReduceWindow op.
    auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty) return matchFailure();
    DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
        input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);

    auto result = rewriter.create<SelectAndScatterOp>(
        loc, op.getType(), op.orig_input(), op.grad(),
        GetScalarConstOfType(element_type, loc, 0, &rewriter),
        GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
        paddings_attr);

    BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Block *block = rewriter.createBlock(&result.select());

      // Block arguments are scalars of the given element type.
      Type type = RankedTensorType::get(/*shape=*/{}, element_type);
      block->addArguments({type, type});

      auto reducer = rewriter.create<CompareOp>(
          loc, block->getArgument(0), block->getArgument(1),
          /*broadcast_dimensions=*/nullptr,
          StringAttr::get("GE", rewriter.getContext()));
      rewriter.create<ReturnOp>(loc, reducer.getResult());
    }

    rewriter.replaceOp(op, {result});

    return matchSuccess();
  }
};

// Converts tf.Conv2DBackpropInputOp into:
//   %rev_filter = "xla_hlo.reverse"(%filter)
//   %result = "xla_hlo.conv"(%out_backprop, %rev_filter)
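//
// The filter is reversed along its spatial dimensions, and the op's strides
// become the lhs_dilation of the convolution. A rough sketch (the padding and
// dilation attribute values are computed below and elided here):
//
//   %rev_filter = "xla_hlo.reverse"(%filter)
//       {dimensions = dense<[0, 1]> : tensor<2xi64>}
//   %result = "xla_hlo.conv"(%out_backprop, %rev_filter)
//       {lhs_dilation = <strides>, padding = ..., ...}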
class ConvertConv2DBackpropInputOp
    : public OpRewritePattern<TF::Conv2DBackpropInputOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::Conv2DBackpropInputOp op,
                                     PatternRewriter &rewriter) const override {
    // Unpack all of the attributes.
    tensorflow::TensorFormat data_format;
    if (!FormatFromString(op.data_format().str(), &data_format)) {
      return matchFailure();
    }
    tensorflow::Padding padding;
    if (!GetPaddingFromString(op.padding().str(), &padding).ok())
      return Pattern::matchFailure();

    auto out_backprop_ty =
        op.out_backprop().getType().dyn_cast<RankedTensorType>();
    if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
      return matchFailure();
    ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
    auto filter_ty = op.filter().getType().dyn_cast<RankedTensorType>();
    if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> filter_shape = filter_ty.getShape();
    int num_spatial_dims = 2;
    Location loc = op.getLoc();

    int num_dims = num_spatial_dims + 2;
    int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
    int feature_dim =
        tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);

    DenseIntElementsAttr input_shape_attr;
    if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
        input_shape_attr.getType().getRank() != 1) {
      return matchFailure();
    }
    auto input_shape =
        llvm::to_vector<4>(input_shape_attr.getValues<int32_t>());
    if (input_shape.size() != num_dims) return matchFailure();

    auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
    auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);

    auto strides_attr = GetI64ElementsAttr(op.strides());
    std::vector<tensorflow::int32> strides{
        strides_attr.getValues<int64_t>().begin(),
        strides_attr.getValues<int64_t>().end()};
    auto dilations_attr = GetI64ElementsAttr(op.dilations());
    std::vector<int> dilations{dilations_attr.getValues<int64_t>().begin(),
                               dilations_attr.getValues<int64_t>().end()};
    auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings());
    std::vector<tensorflow::int64> explicit_paddings{
        explicit_paddings_attr.getValues<int64_t>().begin(),
        explicit_paddings_attr.getValues<int64_t>().end()};

    int64_t in_depth = input_shape[feature_dim];
    int64_t filter_in_depth = filter_shape[num_spatial_dims];
    int64_t feature_group_count = in_depth / filter_in_depth;

    // Reuse dimension computation logic from conv_grad_shape_utils.cc.
    tensorflow::ConvBackpropDimensions dims;
    if (!tensorflow::ConvBackpropComputeDimensionsV2(
             "", num_spatial_dims, ToTensorShape<int>(input_shape),
             ToTensorShape<int64_t>(filter_shape),
             ToTensorShape<int64_t>(out_backprop_shape), dilations, strides,
             padding, explicit_paddings, data_format, &dims)
             .ok()) {
      return matchFailure();
    }

    // Compute ConvDimensionNumbers, dilation, and padding.
    SmallVector<int64_t, 4> kernel_spatial_dims(num_spatial_dims);
    SmallVector<int64_t, 4> conv_paddings(num_spatial_dims * 2);
    SmallVector<int64_t, 4> lhs_dilation(num_spatial_dims);
    SmallVector<int64_t, 4> rhs_dilation(num_spatial_dims);
    SmallVector<int64_t, 4> ones(num_spatial_dims, 1);
    SmallVector<int64_t, 4> spatial_dims(num_spatial_dims);
    for (int i = 0; i < num_spatial_dims; ++i) {
      int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      spatial_dims[i] = dim;
      kernel_spatial_dims[i] = i;

      conv_paddings[i * 2] = dims.spatial_dims[i].pad_before;
      conv_paddings[i * 2 + 1] = dims.spatial_dims[i].pad_after;
      lhs_dilation[i] = dims.spatial_dims[i].stride;
      rhs_dilation[i] = dilations[dim];
    }
    RankedTensorType paddings_ty = RankedTensorType::get(
        {num_spatial_dims, 2}, rewriter.getIntegerType(64));
    auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings);
    auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);

    Value filter = op.filter();

    if (feature_group_count != 1) {
      /*
      // TODO(parkers): Convert this code to mlir.
      filter = TransposeFilterForGroupConvolutionBackpropInput(
          filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
      */
      return matchFailure();
    }

    // Mirror the filter in the spatial dimensions.
    filter = rewriter.create<ReverseOp>(
        loc, filter, GetI64ElementsAttr(kernel_spatial_dims, &rewriter));

    // activation gradients
    //   = gradients (with padding and dilation) <conv> mirrored_weights
    Value result = rewriter.create<ConvOp>(
        loc, op.getType(), op.out_backprop(), filter,
        /*window_strides=*/GetI64ElementsAttr(ones, &rewriter),
        /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
        GetI64ElementsAttr(rhs_dilation, &rewriter),
        ConvDimensionNumbers::get(
            /*input_batch_dimension=*/batch_dim_attr,
            /*input_feature_dimension=*/feature_dim_attr,
            /*input_spatial_dimensions=*/spatial_dims_attr,
            // TF filter shape is [ H, W, ..., inC, outC ]
            // Transpose the input and output features for computing the
            // gradient.
            /*kernel_input_feature_dimension=*/
            rewriter.getI64IntegerAttr(num_spatial_dims + 1),
            /*kernel_output_feature_dimension=*/
            rewriter.getI64IntegerAttr(num_spatial_dims),
            /*kernel_spatial_dimensions=*/
            GetI64ElementsAttr(kernel_spatial_dims, &rewriter),
            /*output_batch_dimension=*/batch_dim_attr,
            /*output_feature_dimension=*/feature_dim_attr,
            /*output_spatial_dimensions=*/spatial_dims_attr,
            rewriter.getContext()),
        rewriter.getI64IntegerAttr(feature_group_count),
        /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
        /*precision_config=*/ArrayAttr());

    rewriter.replaceOp(op, {result});

    return matchSuccess();
  }
};

// Converts tf.Conv2DBackpropFilterOp into:
//   %result = "xla_hlo.conv"(%input, %out_backprop)
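//
// The activations form the LHS and the output gradients form the RHS of this
// convolution, with the batch and feature dimensions of the activations
// swapped in the dimension numbers. As a rough sketch (attribute values are
// computed below and elided here), the op's dilations become the
// window_strides of the convolution and its strides become the rhs_dilation.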
class ConvertConv2DBackpropFilterOp
    : public OpRewritePattern<TF::Conv2DBackpropFilterOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::Conv2DBackpropFilterOp op,
                                     PatternRewriter &rewriter) const override {
    // Unpack all of the attributes.
    tensorflow::TensorFormat data_format;
    if (!FormatFromString(op.data_format().str(), &data_format)) {
      return matchFailure();
    }
    tensorflow::Padding padding;
    if (!GetPaddingFromString(op.padding().str(), &padding).ok())
      return Pattern::matchFailure();

    auto out_backprop_ty =
        op.out_backprop().getType().dyn_cast<RankedTensorType>();
    if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
      return matchFailure();
    ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> input_shape = input_ty.getShape();

    DenseIntElementsAttr filter_shape_attr;
    if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
        filter_shape_attr.getType().getRank() != 1) {
      return matchFailure();
    }

    auto strides_attr = GetI64ElementsAttr(op.strides());
    std::vector<tensorflow::int32> strides{
        strides_attr.getValues<int64_t>().begin(),
        strides_attr.getValues<int64_t>().end()};
    auto dilations_attr = GetI64ElementsAttr(op.dilations());
    SmallVector<int, 4> dilations{dilations_attr.getValues<int64_t>().begin(),
                                  dilations_attr.getValues<int64_t>().end()};
    auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings());
    SmallVector<tensorflow::int64, 4> explicit_paddings{
        explicit_paddings_attr.getValues<int64_t>().begin(),
        explicit_paddings_attr.getValues<int64_t>().end()};

    int num_spatial_dims = 2;
    int num_dims = num_spatial_dims + 2;
    int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
    int feature_dim =
        tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);

    auto filter_shape =
        llvm::to_vector<4>(filter_shape_attr.getValues<int32_t>());
    if (filter_shape.size() != num_dims) return matchFailure();

    // Reuse dimension computation logic from conv_grad_shape_utils.cc.
    tensorflow::ConvBackpropDimensions dims;
    if (!tensorflow::ConvBackpropComputeDimensionsV2(
             "", num_spatial_dims, ToTensorShape<int64_t>(input_shape),
             ToTensorShape<int>(filter_shape),
             ToTensorShape<int64_t>(out_backprop_shape), dilations, strides,
             padding, explicit_paddings, data_format, &dims)
             .ok()) {
      return matchFailure();
    }

    // The activations (inputs) form the LHS of the convolution.
    // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
    // For the gradient computation, we need to:
    // 1. In the case of group convolution, move the num_groups dimension
    //    before the batch dimension.
    // 2. Swap the roles of the batch and feature dimensions.
    int64_t in_depth = input_shape[feature_dim];
    int64_t filter_in_depth = filter_shape[num_spatial_dims];
    int64_t feature_group_count = in_depth / filter_in_depth;
    if (feature_group_count != 1) {
      /*
      // TODO(parkers): translate this code to mlir.
      activations = TransposeInputForGroupConvolutionBackpropFilter(
          activations, input_shape, feature_group_count, batch_dim,
          feature_dim);
      */
      return matchFailure();
    }

    // Compute ConvDimensionNumbers, dilation, and padding.
    SmallVector<int64_t, 8> conv_padding(num_spatial_dims * 2);
    SmallVector<int64_t, 4> rhs_dilation(num_spatial_dims);
    SmallVector<int64_t, 4> window_strides(num_spatial_dims);
    SmallVector<int64_t, 4> lhs_dilation(num_spatial_dims, 1);
    SmallVector<int64_t, 4> spatial_dims(num_spatial_dims);
    SmallVector<int64_t, 4> kernel_spatial_dims(num_spatial_dims);

    // The filter gradients are computed by a convolution of the input
    // activations and the output gradients, with some appropriate padding.
    // See the comment at the top of conv_grad_ops.h for details.

    for (int64_t i = 0; i < num_spatial_dims; ++i) {
      int64_t dim =
          tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
      kernel_spatial_dims[i] = dim;
      // Besides padding the input, we will also expand output_rows to
      //    expanded_out_rows = (output_rows - 1) * stride + 1
      // with zeros in between:
      //
      //      a . . . b . . . c . . . d . . . e
      //
      // This is done by specifying the window dilation factors in the
      // convolution HLO below.
      rhs_dilation[i] = dims.spatial_dims[i].stride;
      window_strides[i] = dilations[dim];

      // We will also need to pad the input with zeros such that after the
      // convolution, we get the right size for the filter.
      // The padded_in_rows should be such that when we convolve this with the
      // expanded_out_rows as a filter, we should get filter_rows back.

      const int64_t padded_in_size =
          dims.spatial_dims[i].expanded_output_size +
          (dims.spatial_dims[i].filter_size - 1) * dilations[dim];

      // However it can be smaller than input_rows: in this
      // case it means some of the inputs are not used.
      //
      // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
      //
      // INPUT =  [ A  B  C ]
      //
      // FILTER = [ x y ]
      //
      // and the output will only have one column: a = A * x + B * y
      //
      // and input "C" is not used at all.
      //
      // We apply negative padding in this case.
      const int64_t pad_total =
          padded_in_size - dims.spatial_dims[i].input_size;

      // + For the EXPLICIT padding, we pad the top/left side with the explicit
      //   padding and pad the bottom/right side with the remaining space.
      // + For the VALID padding, we don't pad anything on the top/left side
      //   and pad the bottom/right side with the remaining space.
      // + For the SAME padding, we pad the top/left side the same as the
      //   bottom/right side.
      //
      // In addition, if the padded input size is smaller than the input size,
      // we need to ignore some trailing elements of the input. We do this by
      // applying negative padding on the right/bottom.
      const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
                                     ? explicit_paddings[2 * dim]
                                     : padding == tensorflow::Padding::SAME
                                           ? std::max<int64_t>(pad_total / 2, 0)
                                           : 0;
      conv_padding[i * 2] = pad_before;
      conv_padding[i * 2 + 1] = pad_total - pad_before;
    }

    RankedTensorType paddings_ty = RankedTensorType::get(
        {num_spatial_dims, 2}, rewriter.getIntegerType(64));
    auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_padding);
    auto out_spatial_dims_attr =
        GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
    auto kernel_spatial_dims_attr =
        GetI64ElementsAttr(kernel_spatial_dims, &rewriter);

    auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
    auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);

    Location loc = op.getLoc();
    Value result = rewriter.create<ConvOp>(
        loc, op.getType(), op.input(), op.out_backprop(),
        /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
        /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
        GetI64ElementsAttr(rhs_dilation, &rewriter),
        ConvDimensionNumbers::get(
            // Swap batch_dim and feature_dim in the activations.
            /*input_batch_dimension=*/feature_dim_attr,
            /*input_feature_dimension=*/batch_dim_attr,
            /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
            // The gradients become the RHS of the convolution.
            // The gradients have shape [batch, out_rows, out_cols, ...,
            // out_depth] where the batch becomes the input feature for the
            // convolution.
            /*kernel_input_feature_dimension=*/batch_dim_attr,
            /*kernel_output_feature_dimension=*/feature_dim_attr,
            /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
            /*output_batch_dimension=*/
            rewriter.getI64IntegerAttr(num_spatial_dims),
            /*output_feature_dimension=*/
            rewriter.getI64IntegerAttr(num_spatial_dims + 1),
            /*output_spatial_dimensions=*/out_spatial_dims_attr,
            rewriter.getContext()),
        rewriter.getI64IntegerAttr(feature_group_count),
        /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
        /*precision_config=*/ArrayAttr());

    rewriter.replaceOp(op, {result});

    return matchSuccess();
  }
};

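// Converts tf.OneHot to a sequence of XLA HLO iota, broadcast, compare, and
// select ops. A rough sketch with illustrative shapes (indices tensor<2xi32>,
// depth = 3, axis = 1):
//
//   %iota = "xla_hlo.iota"() {iota_dimension = 1} : () -> tensor<2x3xi32>
//   %cmp = "xla_hlo.compare"(%indices, %iota)
//       {broadcast_dimensions = [0], comparison_direction = "EQ"}
//       : (tensor<2xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
//   %on = "xla_hlo.broadcast"(%on_value) ...
//   %off = "xla_hlo.broadcast"(%off_value) ...
//   %result = "xla_hlo.select"(%cmp, %on, %off)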
class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::OneHotOp op,
                                     PatternRewriter &rewriter) const override {
    auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
    if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> indices_shape = indices_ty.getShape();
    Type element_type = indices_ty.getElementType();

    DenseIntElementsAttr depth_attr;
    if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
      return matchFailure();
    }

    int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
    int64_t axis = op.axis().getSExtValue();
    if (axis == -1) axis = indices_shape.size();

    llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
    std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
    std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);

    llvm::SmallVector<int64_t, 4> output_dims =
        llvm::to_vector<4>(indices_shape);
    output_dims.insert(output_dims.begin() + axis, depth);

    Location loc = op.getLoc();
    auto index_type = RankedTensorType::get(output_dims, element_type);
    Value compare = rewriter.create<CompareOp>(
        loc, op.indices(),
        rewriter.create<IotaOp>(
            loc, index_type,
            IntegerAttr::get(rewriter.getIntegerType(64), axis)),
        GetI64ElementsAttr(broadcast_dims, &rewriter),
        StringAttr::get("EQ", rewriter.getContext()));
    Value on_value = rewriter.create<BroadcastOp>(
        loc, op.getType(), op.on_value(),
        GetI64ElementsAttr(output_dims, &rewriter));
    Value off_value = rewriter.create<BroadcastOp>(
        loc, op.getType(), op.off_value(),
        GetI64ElementsAttr(output_dims, &rewriter));
    Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
                                             on_value, off_value);

    rewriter.replaceOp(op, {result});

    return matchSuccess();
  }
};

// Converts tf.InfeedDequeueTuple to XLA HLO after_all, infeed and
// get_tuple_element ops.
//
// All HLO infeed ops expect an HLO token type operand and produce a tuple
// containing a token. This HLO token type is used to order multiple infeed
// operations within a computation. The token type can come from other
// infeed/outfeed/send/recv ops or can be generated using an after_all op with
// no operands. Here we emit an after_all op to generate the token type operand
// of infeed.
//
// For example the following IR:
// %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
//
// would be lowered to
//
// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
// %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} :
//      (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
//      !xla_hlo.token>
// %data = "xla_hlo.get_tuple_element"(%data_and_token) {index = 0}
// %0#0 = "xla_hlo.get_tuple_element"(%data) {index = 0}
// %0#1 = "xla_hlo.get_tuple_element"(%data) {index = 1}
//
class ConvertInfeedDequeueTupleOp
    : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
                                     PatternRewriter &rewriter) const override {
    std::vector<Type> result_types(op.outputs().size());
    for (auto idx_and_output : llvm::enumerate(op.outputs())) {
      result_types[idx_and_output.index()] = (idx_and_output.value().getType());
    }
    // Infeed takes a single token operand. Generate the token using an
    // after_all op to pass to the infeed op.
    auto afterall = rewriter.create<AfterAllOp>(
        op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()),
        ValueRange());

    // Emit infeed op.
    // The result type of infeed is a tuple(tuple(result types), token type).
    auto data_tuple_type =
        mlir::TupleType::get(result_types, rewriter.getContext());
    auto data_and_token_type = mlir::TupleType::get(
        {data_tuple_type, afterall.getType()}, rewriter.getContext());

    auto data_and_token =
        rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, afterall,
                                  /*infeed_config=*/rewriter.getStringAttr(""));

    // The infeed instruction produces a tuple of the infeed data and a token
    // type. Emit get_tuple_element to get the infeed data tuple.
    auto data_tuple = rewriter.create<GetTupleElementOp>(
        op.getLoc(), data_tuple_type, data_and_token,
        rewriter.getI32IntegerAttr(0));

    // Emit get_tuple_element for each result.
    std::vector<Value> results;
    for (auto idx_and_type : llvm::enumerate(result_types)) {
      auto tuple_element = rewriter.create<GetTupleElementOp>(
          op.getLoc(), idx_and_type.value(), data_tuple,
          rewriter.getI32IntegerAttr(idx_and_type.index()));
      results.push_back(tuple_element);
    }
    rewriter.replaceOp(op, ValueRange(results));
    return matchSuccess();
  }
};

// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops.
//
// XLA HLO outfeed op expects a token, which we generate by emitting an
// after_all op.
//
// For example the following IR:
// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
//      ()
//
// would be lowered to
//
// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
//      tuple<tensor<3xi32>, tensor<4xf32>>
// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
//      (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
//
class ConvertOutfeedEnqueueTupleOp
    : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
                                     PatternRewriter &rewriter) const override {
    auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
    auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
    auto afterall =
        rewriter.create<AfterAllOp>(op.getLoc(), token_type, ValueRange());
    rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, afterall,
                               /*outfeed_config=*/rewriter.getStringAttr(""));
    rewriter.eraseOp(op);
    return matchSuccess();
  }
};

// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
//
// tf.TopKV2 sorts along the last dimension of the input tensor and then
// returns the top K components' values and indices. This is translated into a
// few ops in XLA HLO: first generating an integer sequence for the indices,
// then sorting both the original input tensor and the indices together, and
// finally slicing out the top K components.
//
// For example, for the following IR:
//
// %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
// %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
//                                 (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "xla_hlo.sort"(%input, %1) ( {
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
//      %arg3: tensor<i32>, %arg4: tensor<i32>):
//   %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
//   "xla_hlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
//                           start_indices = dense<0> : tensor<2xi64>,
//                           strides = dense<1> : tensor<2xi64>} :
//                              (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "xla_hlo.slice"(%4) ...
class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
                                     PatternRewriter &rewriter) const override {
    // We can only match when the `k` operand is a constant scalar.
    DenseIntElementsAttr k_attr;
    if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();

    // The last dimension of the input tensor's shape should be known so we can
    // have clamped end_indices for slices.
    TensorType input_type = op.input().getType().cast<TensorType>();
    if (!input_type.hasRank()) return matchFailure();
    int64_t input_rank = input_type.getRank();
    int64_t last_dim_index = input_rank - 1;
    int64_t last_dim_size = input_type.getDimSize(last_dim_index);
    if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();

    // Create an Iota op for the indices.
    auto i32_type = rewriter.getIntegerType(32);
    Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
    Value iota_op = rewriter.create<xla_hlo::IotaOp>(
        op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));

    // Create the sort op. It takes two inputs, one for the original input, the
    // other for the indices.
    auto sort_op = rewriter.create<xla_hlo::SortOp>(
        op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
        /*is_stable=*/true);
    BuildSortComparisonBody({input_type.getElementType(), i32_type},
                            /*direction=*/"GT", &sort_op.comparator(),
                            &rewriter);

    // Get the sorted input and index tuple elements.
    auto tuple_first_element =
        rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
    auto tuple_second_element =
        rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);

    SmallVector<int64_t, 4> begin_indices(input_rank, 0);
    auto end_indices = llvm::to_vector<4>(input_type.getShape());
    end_indices.back() =
        std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
    SmallVector<int64_t, 4> strides(input_rank, 1);

    // Get the slices for the top K elements.
    Value values = rewriter.create<xla_hlo::SliceOp>(
        op.getLoc(), tuple_first_element,
        GetI64ElementsAttr(begin_indices, &rewriter),
        GetI64ElementsAttr(end_indices, &rewriter),
        GetI64ElementsAttr(strides, &rewriter));

    Value indices = rewriter.create<xla_hlo::SliceOp>(
        op.getLoc(), tuple_second_element,
        GetI64ElementsAttr(begin_indices, &rewriter),
        GetI64ElementsAttr(end_indices, &rewriter),
        GetI64ElementsAttr(strides, &rewriter));

    rewriter.replaceOp(op, {values, indices});
    return matchSuccess();
  }
};

// Converts tf.Unpack to a series of XLA HLO slice ops.
//
// Each slice takes one element along the dimension to unpack and takes the
// full range for all other dimensions. Each slice is then reshaped to drop the
// dimension to unpack (which is always of size 1).
// TODO(antiagainst): consider changing this into a TF internal lowering pass.
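//
// For example (a sketch with illustrative shapes), unpacking along axis 0:
//
//   %0:2 = "tf.Unpack"(%input) {axis = 0} :
//       (tensor<2x4xf32>) -> (tensor<4xf32>, tensor<4xf32>)
//
// becomes, for each result i in {0, 1}:
//
//   %slice_i = "xla_hlo.slice"(%input) {start_indices = dense<[i, 0]>,
//       limit_indices = dense<[i + 1, 4]>, strides = dense<1>}
//       : (tensor<2x4xf32>) -> tensor<1x4xf32>
//   %res_i = "xla_hlo.reshape"(%slice_i) : (tensor<1x4xf32>) -> tensor<4xf32>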
class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::UnpackOp op,
                                     PatternRewriter &rewriter) const override {
    auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!value_type) return matchFailure();

    int64_t value_rank = value_type.getRank();
    int64_t axis = op.axis().getSExtValue();
    if (axis < 0) axis += value_rank;

    // Parameters for constructing each slice.
    SmallVector<int64_t, 4> begin_indices(value_rank, 0);
    auto end_indices = llvm::to_vector<4>(value_type.getShape());
    SmallVector<int64_t, 4> strides(value_rank, 1);

    // All HLO slice+reshape results used to replace the original tf.Unpack op.
    SmallVector<Value, 4> results;
    results.reserve(op.getNumResults());

    for (int i = 0; i < op.getNumResults(); ++i) {
      begin_indices[axis] = i;
      end_indices[axis] = i + 1;

      auto slice_op = rewriter.create<xla_hlo::SliceOp>(
          op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
          GetI64ElementsAttr(end_indices, &rewriter),
          GetI64ElementsAttr(strides, &rewriter));
      // Reshape to drop the axis dimension.
      auto reshape_op = rewriter.create<xla_hlo::ReshapeOp>(
          op.getLoc(), op.getType(i), slice_op);
      results.push_back(reshape_op);
    }

    rewriter.replaceOp(op, results);
    return matchSuccess();
  }
};

// Converts TF unsorted segment reduction ops to an XLA HLO scatter op.
//
// A TF unsorted segment reduction op performs the following calculation:
//
// Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
// [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
// shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
// Dm+1, ..., Dn]. Then
//   output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
//      <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
// where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
//
// The op will be translated to XLA HLO scatter with the following parameters:
// * Update window dims is [segment_id_rank, data_rank).
// * Inserted window dims is {0}.
// * Scatter dims to operand dims mapping is {0}.
// * Index vector dim is segment_id_rank.
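//
// For example (a sketch; shapes illustrative): for a tf.UnsortedSegmentSum
// with data tensor<4x3xf32>, segment_ids tensor<4xi32>, and num_segments = 2,
// the result is an xla_hlo.scatter onto a zero-initialized tensor<2x3xf32>
// with update_window_dims = [1], inserted_window_dims = [0],
// scatter_dims_to_operand_dims = [0], index_vector_dim = 1, and an AddOp
// region as the update computation.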
template <typename ConcreteClass, typename OpTy, typename ReductionOp>
class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpTy op,
                                     PatternRewriter &rewriter) const override {
    auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
    if (!data_type) return this->matchFailure();
    int64_t data_rank = data_type.getRank();

    auto segment_ids_type =
        op.segment_ids().getType().template dyn_cast<RankedTensorType>();
    if (!segment_ids_type) return this->matchFailure();
    int64_t segment_ids_rank = segment_ids_type.getRank();

    DenseIntElementsAttr num_segments_attr;
    if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
      return this->matchFailure();

    // The final shape for TF unsorted segment reduction op is [num_segments] +
    // data_shape[segment_ids_rank:].
    SmallVector<int64_t, 4> output_shape;
    output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
    auto suffix = data_type.getShape().drop_front(segment_ids_rank);
    output_shape.append(suffix.begin(), suffix.end());
    auto output_type =
        RankedTensorType::get(output_shape, data_type.getElementType());

    // Broadcast the initial value for the reduction. This will become the
    // 'operand' parameter of the final scatter op.
    Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
                                                op.getLoc(), &rewriter);
    auto broadcasted_init = rewriter.create<xla_hlo::BroadcastOp>(
        op.getLoc(), output_type, init,
        GetI64ElementsAttr(output_shape, &rewriter));

    // Parameters for the generated scatter op.
    SmallVector<int64_t, 1> inserted_window_dims(1, 0);
    SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
    int64_t index_vector_dim = segment_ids_rank;

    // Put all parameters in a StructAttr.
    auto dims_attr = ScatterDimensionNumbers::get(
        GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
        GetI64ElementsAttr(inserted_window_dims, &rewriter),
        GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
        rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());

    auto scatter =
        rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
                                   op.segment_ids(), op.data(), dims_attr);
    BuildReduceBody<ReductionOp>(data_type.getElementType(),
                                 &scatter.update_computation(), &rewriter);

    rewriter.replaceOp(op, scatter.getResult());
    return this->matchSuccess();
  }
};

class ConvertUnsortedSegmentMaxOp
    : public GenericConvertUnsortedSegmentReductionOp<
          ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
 public:
  using GenericConvertUnsortedSegmentReductionOp::
      GenericConvertUnsortedSegmentReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMinValueForType(reduce_element_type, loc, rewriter);
  }
};

class ConvertUnsortedSegmentMinOp
    : public GenericConvertUnsortedSegmentReductionOp<
          ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
 public:
  using GenericConvertUnsortedSegmentReductionOp::
      GenericConvertUnsortedSegmentReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetMaxValueForType(reduce_element_type, loc, rewriter);
  }
};

class ConvertUnsortedSegmentProdOp
    : public GenericConvertUnsortedSegmentReductionOp<
          ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
 public:
  using GenericConvertUnsortedSegmentReductionOp::
      GenericConvertUnsortedSegmentReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
  }
};

class ConvertUnsortedSegmentSumOp
    : public GenericConvertUnsortedSegmentReductionOp<
          ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
 public:
  using GenericConvertUnsortedSegmentReductionOp::
      GenericConvertUnsortedSegmentReductionOp;

  static Value GetInitialValue(Type reduce_element_type, Location loc,
                               PatternRewriter *rewriter) {
    return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
  }
};

// Converts tf.RandomShuffle op into a series of XLA HLO ops.
//
// tf.RandomShuffle shuffles tensors along the first dimension. If the input
// tensor's rank is 1, it is translated into HLO sort op(s) according to
// indices randomly generated via HLO rng_uniform ops. Otherwise, it is
// translated into an HLO while op that first shuffles the indices using HLO
// dynamic_slice and dynamic_update_slice ops, followed by a final HLO gather
// with the shuffled indices.
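//
// For the rank-1 case, each sorting round is roughly (a sketch, eliding the
// exact rng_uniform operands):
//
//   %keys = "xla_hlo.rng_uniform"(...) : ... -> tensor<Nxi32>
//   %sorted = "xla_hlo.sort"(%keys, %current) ( { ... "LT" compare on keys ... } )
//   %next = "xla_hlo.get_tuple_element"(%sorted) {index = 1}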
class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::RandomShuffleOp op,
                                     PatternRewriter &rewriter) const override {
    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return matchFailure();

    int64_t input_rank = input_type.getRank();
    int64_t first_dim_size = input_type.getDimSize(0);
    if (ShapedType::isDynamic(first_dim_size)) return matchFailure();

    // We are shuffling along the first dimension. If its size is <= 1, then
    // shuffling is a no-op.
    if (first_dim_size <= 1) {
      rewriter.replaceOp(op, op.value());
      return matchSuccess();
    }

    // For vectors, shuffle values by sorting instead of the obvious
    // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
    // but not easily parallelizable. For a sufficiently parallel architecture,
    // it is faster to sort many times than to run Fisher-Yates once.
    if (input_rank == 1) {
      // Shuffle values by assigning each value a random key and sorting the
      // keys. Keys can collide, causing detectable patterns in the shuffled
      // output. Collisions translate into more ascending sub-sequences in the
      // shuffled output than would be expected by chance. To avoid collisions,
      // the number of possible key values must be sufficiently large.

      // How are more than 2^32 keys created? In each loop iteration, the
      // algorithm sorts by random keys. Conceptually, the earlier iterations
      // are sorting on the lower-order bits of larger keys that are never
      // actually assembled.

      // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
      // the number of possible keys and n is the number of values. If d = n^2,
      // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
      // as n goes to infinity is zero.

      // This implementation ensures that the key space is greater than or
      // equal to the cube of the number of values. The risk of collisions can
      // be further reduced by increasing the exponent at the expense of
      // performance.

      // For exponent = 2, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
      // about 1/2.

      // For exponent = 3, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
      // about 1/3255.

      // For exponent = 4, the expected number of collisions per shuffle is
      // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
      // about 1/132622.
      constexpr int exponent = 3;
      int64_t num_elements = input_type.getNumElements();
      uint32_t u32_max = std::numeric_limits<uint32_t>::max();
      int rounds =
          std::ceil(exponent * std::log(num_elements) / std::log(u32_max));

      Value current = op.value();
      for (int i = 0; i < rounds; ++i) {
        auto keys =
            CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
                               /*upper_limit=*/u32_max, &rewriter);
        auto sorted = rewriter.create<xla_hlo::SortOp>(
            op.getLoc(), llvm::ArrayRef<Value>{keys, current});
        auto i32_type = rewriter.getIntegerType(32);
        BuildSortComparisonBody({i32_type, input_type.getElementType()},
                                /*direction=*/"LT", &sorted.comparator(),
                                &rewriter);
        current = rewriter.create<GetTupleElementOp>(op.getLoc(),
                                                     sorted.getResult(), 1);
      }
      rewriter.replaceOp(op, current);
      return matchSuccess();
    }

    // The Fisher-Yates algorithm.

    // Generate range(n) as the initial value for the indices to be swapped.
    // The iota dimension must be 0 as the indices tensor is 1-D.
    auto indices_type =
        RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
    Value indices = rewriter.create<xla_hlo::IotaOp>(
        op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));

    // Generate random numbers to be used as swaps for the indices.
    Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
                                     first_dim_size, &rewriter);

    // While loop body to perform index swaps.
    auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
                            SmallVectorImpl<Value> *new_values,
                            OpBuilder *builder) {
      Value swaps = old_values[0];
      Value indices = old_values[1];

      auto vec1_i32_type =
          RankedTensorType::get({1}, builder->getIntegerType(32));
      auto scalar_i32_type =
          RankedTensorType::get({}, builder->getIntegerType(32));
      auto scalar_i64_type =
          RankedTensorType::get({}, builder->getIntegerType(64));

      auto scalar_one =
          DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));

      // We need to swap the indices[i] with indices[swaps[i]]. First get
      // these index values.
      Value source_index = builder->create<xla_hlo::DynamicSliceOp>(
          loc, vec1_i32_type, indices, i, scalar_one);
      Value swap_index = builder->create<xla_hlo::ReshapeOp>(
          loc, scalar_i32_type,
          builder->create<xla_hlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
                                                   scalar_one));
      Value target_index = builder->create<xla_hlo::DynamicSliceOp>(
          loc, vec1_i32_type, indices, swap_index, scalar_one);

      // Then perform the swap.
      // indices[i] <- indices[swaps[i]]
      indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
          loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
      // indices[swaps[i]] <- indices[i]
      indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
          loc, indices.getType(), indices, source_index,
          llvm::makeArrayRef(swap_index));

      // Update the new values.
      new_values->assign({swaps, indices});
    };

    // Create a while op to swap indices.
    SmallVector<Value, 2> while_output;
    CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
                  &while_output, &rewriter);
    Value swapped_indices = while_output[1];

    // Gather the data using the swapped indices as the shuffled order. The
    // gather result has the same rank as the input: dimension 0 is the index
    // batch dimension and the remaining dimensions are offset dims.
    ArrayRef<int64_t> input_shape = input_type.getShape();
    SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
    slice_sizes[0] = 1;
    auto dims_attr = GatherDimensionNumbers::get(
        /*offset_dims=*/GetI64ElementsAttrForSeq(1, input_rank, &rewriter),
        /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
        /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
        /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
        rewriter.getContext());
    rewriter.replaceOpWithNewOp<xla_hlo::GatherOp>(
        op, op.getType(), op.value(), swapped_indices, dims_attr,
        GetI64ElementsAttr(slice_sizes, &rewriter));

    return matchSuccess();
  }
};

// Converts the tf.VariableShape op to an XLA HLO constant representing the
// variable's shape.
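//
// For example (a sketch):
//
//   %shape = "tf.VariableShape"(%resource) :
//       (tensor<!tf.resource<tensor<2x3xf32>>>) -> tensor<2xi32>
//
// becomes roughly
//
//   %shape = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>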
class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::VariableShapeOp op,
                                     PatternRewriter &rewriter) const override {
    // The input type should be a tensor<!tf.resource<resource-type>>. We need
    // to get the inner resource type.
    auto input_type = op.input().getType().cast<TensorType>();
    auto subtypes =
        input_type.getElementType().cast<TF::ResourceType>().getSubtypes();
    // The subtype can be missing, in which case we cannot convert.
    if (subtypes.empty()) return matchFailure();

    auto resource_type = subtypes[0].cast<TensorType>();
    if (!resource_type.hasStaticShape()) return matchFailure();

    auto resource_shape = resource_type.getShape();
    Attribute const_attr;

    // We need to match the original op result's element type.
    auto element_type = op.getType().cast<TensorType>().getElementType();
    unsigned bitwidth = element_type.cast<IntegerType>().getWidth();
    if (bitwidth == 32) {
      SmallVector<int32_t, 4> shape(resource_shape.begin(),
                                    resource_shape.end());
      const_attr = GetI32ElementsAttr(shape, &rewriter);
    } else {
      assert(bitwidth == 64);
      const_attr = GetI64ElementsAttr(resource_shape, &rewriter);
    }

    rewriter.replaceOpWithNewOp<xla_hlo::ConstOp>(op, const_attr);
    return matchSuccess();
  }
};

#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"

LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
  MLIRContext *context = op->getContext();

  // Add lowering patterns to the list.
  OwningRewritePatternList patterns;
  populateWithGenerated(context, &patterns);

  // Add patterns that lower some of the high-level TensorFlow ops to
  // lower-level TensorFlow ops, so we don't have to target all the TensorFlow
  // ops here for lowering to HLO.
  TF::PopulateLoweringTFPatterns(context, &patterns);
  patterns.insert<
      ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBF16FloorDivOp,
      ConvertConv2D, ConvertConv2DBackpropFilterOp,
      ConvertConv2DBackpropInputOp, ConvertEinsumOp,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
      ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
      ConvertInfeedDequeueTupleOp, ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp,
      ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
      ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
      ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
      ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
      ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
      ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
      ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
      ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
      ConvertRandomShuffleOp, ConvertVariableShapeOp>(op->getContext());

  ConversionTarget target(*context);
  target.addLegalDialect<XlaHloDialect>();

  if (!allow_partial_conversion) {
    // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
    target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
                      ::mlir::ReturnOp>();
    return applyFullConversion(op, target, patterns);
  }

  return applyPartialConversion(op, target, patterns);
}

/// Performs the lowering to XLA dialect.
void LegalizeTF::runOnFunction() {
  if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
    signalPassFailure();
}

static PassRegistration<LegalizeTF> pass(
    "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");

}  // end namespace

std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass(
    bool allow_partial_conversion) {
  return std::make_unique<LegalizeTF>(allow_partial_conversion);
}

}  // end namespace xla_hlo
}  // end namespace mlir