/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect to XLA dialect.

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"

namespace mlir {
namespace mhlo {
namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";

class LegalizeTF : public LegalizeTFBase<LegalizeTF> {
 public:
  explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
                      llvm::Optional<StringRef> tf2xla_fallback_device_type,
                      bool prefer_tf2xla) {
    allow_partial_conversion_ = allow_partial_conversion;
    legalize_chlo_ = legalize_chlo;
    prefer_tf2xla_ = prefer_tf2xla;
    use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
    if (tf2xla_fallback_device_type.hasValue()) {
      device_type_ = tf2xla_fallback_device_type.getValue().str();
    }
  }

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;
};

/// Returns the feature dimension for the given format and input type.
static size_t GetFeatureDimension(tensorflow::TensorFormat format,
                                  RankedTensorType input_ty) {
  return GetTensorFeatureDimIndex(input_ty.getRank(), format);
}

// Gets all integer values from the given attribute and pushes them to
// `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns a 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}
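
// For example, values = {2, 3} yields an attribute roughly equivalent to
// `dense<[2, 3]> : tensor<2xi64>` in MLIR assembly.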

// Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
  RankedTensorType ty =
      RankedTensorType::get(static_cast<int64_t>(attr.size()),
                            IntegerType::get(attr.getContext(), 64));
  return DenseIntElementsAttr::get(ty, attr.getValue());
}

// Returns a 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns a 1-d i64 elements attribute populated with numbers from `start` to
// `end`, exclusive.
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}
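
// For example, GetI64ElementsAttrForSeq(0, 3, builder) is roughly
// `dense<[0, 1, 2]> : tensor<3xi64>`.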

// Returns a 1-d i64 elements attribute populated with `val` repeated `size`
// times.
static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
                                                       Builder *builder) {
  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, val);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
    return IntegerType::get(ctx, 32);
  return input_type;
}

// Returns the axis in HLO format from a TF elements attr with exactly one
// element, or from an IntegerAttr, containing the axis in TensorFlow format.
// Unlike HLO, the TensorFlow format supports negative indexing.
static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
                                        Builder *b) {
  IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
  if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
    SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
    intAttr = elementAttr.getValue<IntegerAttr>(index);
  }

  assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");

  int64_t axis = intAttr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}
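
// For instance, with rank = 4 a TensorFlow axis of -1 maps to the HLO axis 3.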

// If `value` is an IntegerAttr, returns the integer value for the HLO axis
// corresponding to the tensorflow axis. In particular, the tensorflow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
  return axis < 0 ? axis + rank : axis;
}

/// Returns a `ConvertOp` that casts the elements to an i64 type while
/// retaining the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Creates an unpack op along the 0th dimension of the tensor. The `value`
// input must be a ranked tensor.
static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
                                             PatternRewriter *rewriter) {
  auto indices_type = value.getType().cast<RankedTensorType>();
  int num_outputs = indices_type.getShape().front();
  SmallVector<Type, 2> unpacked_indices_type(
      num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
  auto unpacked_indices = rewriter->create<TF::UnpackOp>(
      loc, unpacked_indices_type, value,
      IntegerAttr::get(rewriter->getIntegerType(64), 0));
  return unpacked_indices;
}
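
// For example, unpacking a value of type tensor<3xi32> this way produces three
// scalar results of type tensor<i32>.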

// Returns size of dimension at the specified index, if ranked tensor.
// Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(
    llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
  return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
      sizes.begin(), sizes.end()));
}

// Returns int, float, or complex scalar DenseElementsAttr attribute with the
// given element type and the value.
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
                                    OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
}

// Returns a limit scalar const op for the given type.
// Requires FloatType or IntegerType.
static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
                                         hlo::ScalarLimit limit,
                                         OpBuilder *builder) {
  return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
}

// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
static Value SliceInMinorDims(Location loc, Value v,
                              ArrayRef<int64_t> minor_starts,
                              ArrayRef<int64_t> minor_limits,
                              OpBuilder *builder) {
  auto type = v.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
  int64_t major_dims = type.getRank() - minor_starts.size();
  std::copy(minor_starts.begin(), minor_starts.end(),
            slice_starts.begin() + major_dims);
  auto slice_limits = llvm::to_vector<4>(type.getShape());
  std::copy(minor_limits.begin(), minor_limits.end(),
            slice_limits.begin() + major_dims);
  llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
  return builder->create<SliceOp>(loc, v,
                                  GetI64ElementsAttr(slice_starts, builder),
                                  GetI64ElementsAttr(slice_limits, builder),
                                  GetI64ElementsAttr(slice_strides, builder));
}
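
// For example, slicing a tensor<4x8x8xf32> value with minor_starts = {0, 2}
// and minor_limits = {8, 6} keeps the major dimension whole and yields a
// tensor<4x8x4xf32> slice.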

// Creates a vector of index values:
//  [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
// with length `rank`.
static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
    Location loc, ArrayRef<Value> minor_indices, int64_t rank,
    OpBuilder *builder) {
  auto zero =
      GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
                           loc, 0, builder)
          .output();
  llvm::SmallVector<Value, 4> indices(rank, zero);
  std::copy(minor_indices.begin(), minor_indices.end(),
            indices.begin() + (rank - minor_indices.size()));
  return indices;
}

// Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
// and the minor dimensions have the provided offsets and sizes.
static Value DynamicSliceInMinorDims(Location loc, Value v,
                                     ArrayRef<Value> minor_starts,
                                     ArrayRef<int64_t> minor_sizes,
                                     OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto slice_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  int64_t major_dims = type.getRank() - minor_starts.size();
  auto slice_sizes = llvm::to_vector<4>(type.getShape());
  std::copy(minor_sizes.begin(), minor_sizes.end(),
            slice_sizes.begin() + major_dims);
  auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
  return builder->create<mhlo::DynamicSliceOp>(
      loc, slice_type, v, slice_starts,
      GetI64ElementsAttr(slice_sizes, builder));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided offsets.
static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
                                           ArrayRef<Value> minor_starts,
                                           OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto dus_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
                                               llvm::makeArrayRef(dus_starts));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided static offsets.
static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                    ArrayRef<int64_t> minor_starts,
                                    OpBuilder *builder) {
  llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
    dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                         minor_starts[i], builder);
  }
  return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the resulting type from a broadcast between two types for statically
// shaped types. This is to be used for legacy lowerings that both use non
// left-padded broadcasting and static shapes. Its use should not be permitted
// in new code.
// May return nullptr on invalid static broadcast dimensions.
// ABSL_DEPRECATED()
static RankedTensorType GetStaticBroadcastType(
    RankedTensorType x, RankedTensorType y,
    DenseIntElementsAttr broadcast_dimensions_attr) {
  auto element_type = x.getElementType();
  auto shape_x = x.getShape();
  auto shape_y = y.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0; i < shape_x.size(); i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      out_shape[i] = std::max(x_val, y_val);
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  // Explicit broadcast dimensions.
  for (const APInt &int_value : broadcast_dimensions_attr) {
    broadcast_dimensions.push_back(int_value.getSExtValue());
  }
  if (broadcast_dimensions.size() != shape_small.size()) {
    return nullptr;
  }
  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    out_shape[index_pair.value()] = std::max(old_value, new_value);
  }
  return RankedTensorType::get(out_shape, element_type);
}
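
// For example, broadcasting tensor<4x1xf32> against tensor<3xf32> with
// broadcast_dimensions = [1] produces the type tensor<4x3xf32>.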

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
                                   DenseIntElementsAttr broadcast_dims,
                                   OpBuilder &builder) {
  auto x_type = x.getType().cast<RankedTensorType>();
  auto y_type = y.getType().cast<RankedTensorType>();
  auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
  if (!result_type) {
    emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
                   << " with broadcast_dims = " << broadcast_dims;
    return nullptr;
  }
  auto larger_broadcast_dims =
      GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
  if (x_type.getRank() < y_type.getRank()) {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y,
                                           larger_broadcast_dims);
    }
  } else {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x,
                                           larger_broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
    }
  }
  return builder.create<BinaryOp>(loc, x, y);
}

// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
  Builder b(value_type.getContext());
  int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
  return RankedTensorType::get({dim}, b.getIndexType());
}

// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
//   [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
                                     Value broadcast_from, int64_t feature_dim,
                                     OpBuilder &builder) {
  auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
  auto to_type = broadcast_to.getType().cast<RankedTensorType>();
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, broadcast_from, result_extents, broadcast_dims);
}

// Broadcasts `input` to the shape of `broadcast_to` value following
// TF::BroadcastTo semantics.
//
// Requires that input is a ranked tensor.
//
// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
// supports unranked inputs in the lowering.
static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
                                OpBuilder &builder) {
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto to_type = broadcast_to.getType().cast<TensorType>();
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  int64_t rank = input.getType().cast<RankedTensorType>().getRank();
  auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, input, result_extents, broadcast_dims);
}

// Creates a batch dot using mhlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
               bool transpose_rhs, int64_t num_batch_dims,
               ArrayAttr precision_config, OpBuilder *builder) {
  auto batch_dimensions = GetI64ElementsAttr(
      llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims)), builder);
  auto lhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_lhs ? num_batch_dims : num_batch_dims + 1}),
      builder);
  auto rhs_contracting_dimensions = GetI64ElementsAttr(
      llvm::makeArrayRef({transpose_rhs ? num_batch_dims + 1 : num_batch_dims}),
      builder);
  auto dimension_numbers = DotDimensionNumbers::get(
      /*lhs_batching_dimensions=*/batch_dimensions,
      /*rhs_batching_dimensions=*/batch_dimensions,
      /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
      /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
      builder->getContext());
  auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
  auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
  auto shape = llvm::to_vector<4>(lhs_shape);
  shape[shape.size() - 2] =
      transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
  shape[shape.size() - 1] =
      transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
  Type element_type = getElementTypeOrSelf(lhs.getType());
  return builder->create<DotGeneralOp>(
      loc, RankedTensorType::get(shape, element_type), lhs, rhs,
      dimension_numbers, precision_config);
}
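
// For example, with num_batch_dims = 1 and no transposes, a batch dot of a
// tensor<2x3x4xf32> lhs with a tensor<2x4x5xf32> rhs contracts lhs dimension 2
// with rhs dimension 1 and produces a tensor<2x3x5xf32>.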

// Builds body for reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
                            OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Block *block = builder->createBlock(body);

  // Block arguments are scalars of the given element type.
  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
  block->addArguments({type, type});

  Location loc = body->getLoc();
  auto reducer =
      builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
  builder->create<ReturnOp>(loc, reducer.getResult());
}
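
// For example, BuildReduceBody<AddOp> with an f32 element type populates the
// region with a block roughly of the form:
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %sum = mhlo.add %lhs, %rhs : tensor<f32>
//     "mhlo.return"(%sum) : (tensor<f32>) -> ()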

// Builds a set of operations for applying reduction on the input value. A
// tf.sum op is created and will be legalized to HLO ops automatically.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                             int lower_limit, int upper_limit,
                                             OpBuilder *builder) {
  auto i32_type = builder->getIntegerType(32);
  auto key_type = RankedTensorType::get({num_elements}, i32_type);
  auto shape_tensor = builder->create<mhlo::ConstOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<mhlo::ConstOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
                                             shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates a mhlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Values, 4> old_values = init_values;
// SmallVector<Values, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to values to control the
// number of iterations, but that is transparent to `body_fn`, which does not
// need to care about that.
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Prepare the initial tuple for the while op.
  auto init_tuple =
      builder->create<mhlo::TupleOp>(loc, init_values_with_loop_iv);
  auto tuple_type = init_tuple.getType();

  // Create the while op.
  auto while_op = builder->create<mhlo::WhileOp>(
      loc, tuple_type, SmallVector<Value>{init_tuple});

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region. It should take one
    // argument of the loop's tuple type.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    BlockArgument arg = block->addArgument(tuple_type);

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
    auto upper_limit = builder->create<mhlo::ConstOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    StringAttr compare_direction = StringAttr::get(builder->getContext(), "LT");
    Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
                                                     compare_direction);

    builder->create<mhlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region. It should take one
    // argument of the loop's tuple type.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    BlockArgument arg = block->addArgument(tuple_type);

    SmallVector<Value, 4> old_values;  // From the previous iteration
    SmallVector<Value, 4> new_values;  // Generated by this iteration
    old_values.reserve(value_count);
    new_values.reserve(value_count);

    // Unpack the tuple value from the last iteration.
    for (int i = 0; i < value_count; ++i)
      old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<chlo::BroadcastAddOp>(
        loc, old_values[0], one, scalar_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    Value updated_tuple = builder->create<mhlo::TupleOp>(loc, new_values);

    builder->create<mhlo::ReturnOp>(loc, updated_tuple);
  }

  // TODO(jpienaar): Support multi-operand while op.
  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(
        builder->create<GetTupleElementOp>(loc, while_op.getResult(0), i + 1));
}
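
// Illustrative usage sketch (hypothetical `loc`, `builder`, and `init`
// values): accumulate the iteration counter into a running total over ten
// iterations.
//
//   SmallVector<Value, 4> results;
//   CreateWhile32(loc, /*num_iterations=*/10,
//                 [&](Location loc, Value iteration, ArrayRef<Value> old_values,
//                     SmallVectorImpl<Value> *new_values, OpBuilder *b) {
//                   auto dims = GetI64ElementsAttr({}, b);
//                   new_values->push_back(b->create<chlo::BroadcastAddOp>(
//                       loc, old_values[0], iteration, dims));
//                 },
//                 /*init_values=*/{init}, &results, &builder);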

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b,
                                           tensorflow::TensorFormat format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// FFT op utilities.
//===----------------------------------------------------------------------===//

// Returns the 1D i64 elements attribute populated with the inner-most dim of
// the value.
static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
                                                 Builder *builder) {
  if (type.getRank() == 0) {
    return builder->getI64TensorAttr({});
  }
  return builder->getI64TensorAttr(type.getShape().back());
}

// Returns true if the inner-most dim is static.
bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
  if (!type.hasRank()) {
    return false;
  }
  return !type.isDynamicDim(type.getShape().size() - 1);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true, returns an ElementsAttr to transpose a
// 2D matrix. Otherwise, returns an ElementsAttr for the identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op utilities.
//===----------------------------------------------------------------------===//

// Gets the size of the dimension `dim_from_end` from the end of `input`.
// Requires that `input` is a tensor.
static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
  // Note: the verifier enforces that `input` is a ranked tensor.
  auto input_type = input.getType().cast<TensorType>();
  auto input_shape = input_type.getShape();
  int dim = (input_shape.size() - 1) - dim_from_end;
  return input_shape[dim];
}

// Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
// have the same size as the last two dimensions of `input` (the second-to-last
// dimension and last dimension, respectively). The element type of the
// resulting RankedTensorType matches the element type of `num_lower`.
// Requires that `input` is a tensor.
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
  // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
  int dim_0 = GetDimensionSizeFromEnd(input, 1);
  int dim_1 = GetDimensionSizeFromEnd(input, 0);
  auto element_type = num_lower.getType().cast<TensorType>().getElementType();
  return RankedTensorType::get({dim_0, dim_1}, element_type);
}

// Creates a HLO ConvertOp, converting `input` to have the same element type as
// `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
                             Value elem_type_tensor) {
  auto element_type =
      elem_type_tensor.getType().cast<TensorType>().getElementType();
  return builder->create<mhlo::ConvertOp>(loc, input, element_type);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices input attribute of rank two and returns the specified column.
//
// Always returns 64 bit integer attribute regardless of bitwidth of the input
// attribute.
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getIntValues())) {
    if (static_cast<int>(it.index() % shape[1]) == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(input.getContext(), 64);
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}
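
// For example, slicing column 1 of dense<[[0, 1], [2, 3]]> yields roughly
// dense<[1, 3]> : tensor<2xi64>.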

// Returns the interior padding to use in a HLO Pad op based on the padding
// attribute of the TensorFlow PadV2 op. PadV2 has no interior padding, so the
// result is always zero.
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; size-1 tensors broadcast up to any rank. Dynamic dimensions must
// be broadcast with a size-1 tensor or another dynamic dimension.
// Returns false if either value is unranked.
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}
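
// For example, tensor<1x4xf32> and tensor<3x4xf32> are broadcast compatible,
// whereas tensor<?x4xf32> and tensor<3x4xf32> are not: the dynamic dimension
// is not paired with a size-1 dimension.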

// Returns a new TensorType with the same rank and dimensions as the input but
// with an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//

static DenseElementsAttr GetEpsilonValue(Type ty) {
  auto element_ty = ty.cast<TensorType>().getElementType();
  auto scalar_ty = RankedTensorType::get({}, element_ty);
  if (element_ty.isF16()) {
    uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
        Eigen::NumTraits<Eigen::half>::epsilon());
    auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isBF16()) {
    uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
        Eigen::NumTraits<Eigen::bfloat16>::epsilon());
    auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF32()) {
    auto value = APFloat(std::numeric_limits<float>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF64()) {
    auto value = APFloat(std::numeric_limits<double>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  }
  llvm_unreachable("unsupported element type for tf.SoftPlus");
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//
static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        StringRef direction, Region *body,
                                        OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  block->addArguments({input_type, index_type, input_type, index_type});

  Value lhs_val = block->getArgument(0);
  Value lhs_index = block->getArgument(1);
  Value rhs_val = block->getArgument(2);
  Value rhs_index = block->getArgument(3);

  ImplicitLocOpBuilder b(body->getLoc(), *builder);
  StringAttr compare_direction = StringAttr::get(b.getContext(), direction);
  Value compare_dt = b.create<CompareOp>(lhs_val, rhs_val, compare_direction);
  Value selected_input =
      b.create<SelectOp>(input_type, compare_dt, lhs_val, rhs_val);

  Value compare_eq = b.create<CompareOp>(lhs_val, rhs_val,
                                         StringAttr::get(b.getContext(), "EQ"));
  Value min_index = b.create<MinOp>(lhs_index, rhs_index);
  Value min_val_index =
      b.create<SelectOp>(index_type, compare_dt, lhs_index, rhs_index);
  Value selected_index =
      b.create<SelectOp>(index_type, compare_eq, min_index, min_val_index);

  Value return_values[] = {selected_input, selected_index};
  b.create<ReturnOp>(return_values);
}

//===----------------------------------------------------------------------===//
// PartitionedCall op utilities.
//===----------------------------------------------------------------------===//

// Verify that the arguments to be passed into the function are the same types
// as the function parameter types.
static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
                                SymbolRefAttr func) {
  auto module = op->getParentOfType<ModuleOp>();
  auto function =
      dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
  FunctionType function_ty = function.getType();

  for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
    if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
      // Argument type and input type mismatch.
      return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  if (!input_ty) return false;
  auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
  if (!start_indices_ty) return false;

  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  bool is_constant_start =
      matchPattern(start_indices, m_Constant(&constant_start_indices));

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    // A slice_size of -1 means "all elements from start_index to the end".
    // In order to support these semantics, we need to know both the start
    // index and the shape of the input dimension.
    if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
  }
  return true;
}

// TF slice size can be -1, which represents all elements from start_index to
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return hlo::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}
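
// For example, slicing a tensor<8xf32> with constant start_indices = [2] and
// slice_sizes = [-1] normalizes the HLO slice sizes to [6].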

//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//

// Builds the region `body` for mhlo.sort's comparator: for each type in
// `element_types`, creates two block arguments, one for lhs and one for rhs,
// and generates an mhlo.compare op to compare them with the given `direction`.
//
// Note that this right now only does comparison on the first pair of block
// arguments.
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
                                    StringRef direction,
                                    llvm::Optional<StringRef> compare_type,
                                    Region *body, OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Block *block = builder->createBlock(body);
  // Add two arguments for each element type.
  for (Type element_type : element_types) {
    TensorType tensor_type = RankedTensorType::get({}, element_type);
    block->addArguments({tensor_type, tensor_type});
  }

  Location loc = body->getLoc();
  StringAttr compare_direction = builder->getStringAttr(direction);
  StringAttr type_attr;
  if (compare_type) type_attr = builder->getStringAttr(*compare_type);
  Value compare = builder->create<mhlo::CompareOp>(
      loc, block->getArgument(0), block->getArgument(1), compare_direction,
      type_attr);

  builder->create<mhlo::ReturnOp>(loc, compare);
}
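
// For example, for element_types = {f32} and direction = "LT" the generated
// comparator block is roughly:
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
//         : (tensor<f32>, tensor<f32>) -> tensor<i1>
//     "mhlo.return"(%0) : (tensor<i1>) -> ()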

//===----------------------------------------------------------------------===//
// XlaGather op utilities.
//===----------------------------------------------------------------------===//

bool HasValidGatherDims(StringAttr attr) {
  ::xla::GatherDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
  ::xla::GatherDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertGatherDimensionNumbers(dims, builder);
}

//===----------------------------------------------------------------------===//
// XlaDot op utilities.
//===----------------------------------------------------------------------===//

bool HasValidDotDims(StringAttr attr) {
  ::xla::DotDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

DotDimensionNumbers GetDotDimNumsAttr(StringAttr attr, Builder *builder) {
  ::xla::DotDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertDotDimensionNumbers(dims, builder);
}

bool HasValidPrecisionConfig(StringAttr attr) {
  ::xla::PrecisionConfig precision;
  return precision.ParseFromString(attr.getValue().str());
}

mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) {
  ::xla::PrecisionConfig precision;
  if (!precision.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertPrecisionConfig(&precision, builder);
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//

NamedAttribute GetConvDimensionNumbersAttr(
    ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
    Builder *builder) {
  int64_t num_spatial_dims = spatial_dim_indices.size();
  int64_t num_dims = num_spatial_dims + 2;

  IntegerAttr batch_dim =
      builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
  IntegerAttr feature_dim =
      builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
  DenseIntElementsAttr spatial_dims =
      GetI64ElementsAttr(spatial_dim_indices, builder);

  // The filter's data_format is always HWIO, so the input channels dimension
  // comes after all spatial dimensions.
  IntegerAttr kernel_input_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims);
  IntegerAttr kernel_output_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims + 1);
  DenseIntElementsAttr kernel_spatial_dimensions =
      GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);

  return builder->getNamedAttr(
      "dimension_numbers",
      ConvDimensionNumbers::get(
          batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
          kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
          feature_dim, spatial_dims, builder->getContext()));
}
1103 
1104 // Converts a TF::BiasAddOp to HLO.
1105 // This differs from a normal TF::AddOp with respect to how the data_format
1106 // is handled, which can optionally require a general broadcast of the
1107 // 'bias' term in a way that is not compatible with the standard left-padded
1108 // broadcast semantics (i.e. NCHW will broadcast into dimension 1).
1109 // The correct 'bias' broadcast will be synthesized manually.
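// For example (illustrative), with data_format NCHW a bias of shape [C] is
// broadcast along dimension 1 of a [N, C, H, W] value, whereas with NHWC it is
// broadcast along the trailing dimension.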
1110 class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
1111  public:
1112   using OpRewritePattern::OpRewritePattern;
1113   LogicalResult matchAndRewrite(TF::BiasAddOp op,
1114                                 PatternRewriter &rewriter) const override {
1115     Location loc = op.getLoc();
1116     tensorflow::TensorFormat data_format;
1117     if (!FormatFromString(op.data_format().str(), &data_format))
1118       return op.emitOpError("invalid data format");
1119 
1120     auto feature_dim = GetFeatureDimension(
1121         data_format, op.value().getType().cast<RankedTensorType>());
1122     auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
1123                                                   feature_dim, rewriter);
1124     Value add = rewriter.create<AddOp>(loc, op.value(), bias_broadcast);
1125     if (add.getType() != op.getType()) {
1126       add = rewriter.create<tensor::CastOp>(loc, op.getType(), add);
1127     }
1128     rewriter.replaceOp(op, {add});
1129     return success();
1130   }
1131 };
1132 
1133 // Convert TF::GatherV2Op to mhlo::DynamicGatherOp
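// For example (illustrative), gathering from params of rank 3 along axis = 1
// with indices of rank 2 produces offset_dims = [0, 3],
// collapsed_slice_dims = [1], start_index_map = [1] and index_vector_dim = 2,
// with slice_sizes = [d0, 1, d2] taken from the (possibly dynamic) params
// dimensions.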
1134 class ConvertGatherV2OpDynamic : public OpRewritePattern<TF::GatherV2Op> {
1135   using OpRewritePattern<TF::GatherV2Op>::OpRewritePattern;
1136   // TODO(disc): To recover static special case's performance with folding and
1137   // canonicalization.
1138   LogicalResult matchAndRewrite(TF::GatherV2Op op,
1139                                 PatternRewriter &rewriter) const override {
1140     Location loc = op.getLoc();
1141     Value params = op.params();
1142     // params and indices of GatherV2Op must be ranked
1143     auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1144     Value indices = op.indices();
1145     auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1146     if (!params_ty || !indices_ty) return failure();
1147 
1148     // TODO(disc): Remove this constraint once fold and canonicalization
1149     // implemented.
1150     if (params_ty.hasStaticShape() && indices_ty.hasStaticShape())
1151       return failure();
1152 
1153     int64_t params_rank = params_ty.getRank();
1154     int64_t indices_rank = indices_ty.getRank();
1155 
1156     // axis
1157     DenseIntElementsAttr axis_attr;
1158     // axis must be const for GatherOp
1159     if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
1160 
1161     int64_t axis = (*axis_attr.begin()).getSExtValue();
1162     if (axis < 0) axis += params_rank;
1163 
1164     // slice_sizes
1165     SmallVector<int64_t, 4> slice_sizes;
1166     slice_sizes.reserve(params_rank);
1167     for (int64_t dim_idx = 0; dim_idx < params_rank; ++dim_idx) {
1168       if (dim_idx == axis) {
1169         slice_sizes.push_back(1);
1170       } else {
1171         // potentially dynamic
1172         int64_t dim_size = params_ty.getDimSize(dim_idx);
1173         slice_sizes.push_back(dim_size);
1174       }
1175     }
1176     SmallVector<Value, 4> slice_sizes_vals;
1177     for (int64_t dim_idx = 0; dim_idx < params_rank; ++dim_idx) {
1178       if (dim_idx == axis) {
1179         slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1180             loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1181       } else {
1182         int64_t dim_size = params_ty.getDimSize(dim_idx);
1183         if (dim_size != ShapedType::kDynamicSize) {
1184           slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1185               loc,
1186               rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1187         } else {
1188           slice_sizes_vals.push_back(rewriter.create<IndexCastOp>(
1189               loc, rewriter.create<tensor::DimOp>(loc, params, dim_idx),
1190               indices_ty.getElementType()));
1191         }
1192       }
1193     }
1194     Value slice_sizes_value = rewriter.create<tensor::FromElementsOp>(
1195         loc, indices_ty.getElementType(), slice_sizes_vals);
1196     // offset_dims
1197     SmallVector<int64_t, 4> offset_dims;
1198     for (int64_t dim_idx = 0; dim_idx < params_rank; dim_idx++) {
1199       if (dim_idx < axis) {
1200         offset_dims.push_back(dim_idx);
1201       } else if (dim_idx >= axis + 1) {
1202         offset_dims.push_back(dim_idx + indices_rank - 1);
1203       }
1204     }
1205     // collapsed_slice_dims
1206     SmallVector<int64_t, 4> collapsed_slice_dims(1, axis);
1207     // start_index_map
1208     SmallVector<int64_t, 4> start_index_map(1, axis);
1209     // index_vector_dim
1210     int64_t index_vector_dim = indices_rank;
1211     auto dims_attr = GatherDimensionNumbers::get(
1212         /*offset_dims=*/GetI64ElementsAttr(offset_dims, &rewriter),
1213         /*collapsed_slice_dims=*/
1214         GetI64ElementsAttr(collapsed_slice_dims, &rewriter),
1215         /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1216         /*index_vector_dim=*/rewriter.getI64IntegerAttr(index_vector_dim),
1217         rewriter.getContext());
1218 
1219     rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1220         op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1221         dims_attr);
1222     return success();
1223   }
1224 };
1225 
1226 // Converts tf.Conv2D to mhlo.dynamic_conv.
1227 // TODO(disc): To recover static special case's performance with adding folding,
1228 // canonicalization func and removing ConvertConvOp.
1229 template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
1230 class ConvertConvDynamic : public OpRewritePattern<OpT> {
1231  public:
1232   using OpRewritePattern<OpT>::OpRewritePattern;
1233 
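  // Computes the low/high padding of one spatial dimension as SSA values for
  // the VALID and SAME padding modes (EXPLICIT padding is handled by the
  // caller). Worked example (illustrative) for SAME: input_size = 10,
  // filter_size = 3, stride = 2, dilation_rate = 1 gives
  // effective_filter_size = 3, output_size = (10 + 2 - 1) / 2 = 5 and
  // padding_needed = max(0, (5 - 1) * 2 + 3 - 10) = 1, so padding_low = 0 and
  // padding_high = 1.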
1234   bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size,
1235                         Value filter_size, int64_t dilation_rate,
1236                         int64_t stride, tensorflow::Padding padding_type,
1237                         Type shape_scalar_type, Value *padding_low,
1238                         Value *padding_high) const {
1239     // Stride must be > 0
1240     if (stride <= 0) return false;
1241     // Dilation rate must be >= 1
1242     if (dilation_rate < 1) return false;
1243 
1244     Location loc = op.getLoc();
1245     switch (padding_type) {
1246       case tensorflow::Padding::VALID: {
1247         auto zero = rewriter.create<ConstantIntOp>(loc, 0, shape_scalar_type);
1248         *padding_low = *padding_high = zero;
1249         break;
1250       }
1251       case tensorflow::Padding::EXPLICIT:
1252         break;
1253       case tensorflow::Padding::SAME: {
1254         auto zero = rewriter.create<ConstantIntOp>(loc, 0, shape_scalar_type);
1255         auto one = rewriter.create<ConstantIntOp>(loc, 1, shape_scalar_type);
1256         auto two = rewriter.create<ConstantIntOp>(loc, 2, shape_scalar_type);
1257         // See also the parallel implementation in
1258         // GetWindowedOutputSizeFromDimsV2. effective_filter_size = (filter_size
1259         // - 1) * dilation_rate + 1
1260         Value stride_value =
1261             rewriter.create<ConstantIntOp>(loc, stride, shape_scalar_type);
1262         Value dilation_rate_value = rewriter.create<ConstantIntOp>(
1263             loc, dilation_rate, shape_scalar_type);
1264         Value effective_filter_size_op = rewriter.create<AddIOp>(
1265             loc, one,
1266             rewriter.create<MulIOp>(
1267                 loc, dilation_rate_value,
1268                 rewriter.create<SubIOp>(loc, filter_size, one)));
1269         // output_size = (input_size + stride - 1) / stride;
1270         Value output_size = rewriter.create<UnsignedDivIOp>(
1271             loc,
1272             rewriter.create<AddIOp>(
1273                 loc, input_size,
1274                 rewriter.create<SubIOp>(loc, stride_value, one)),
1275             stride_value);
1276         // std::max(int64{0}, (output_size - 1) * stride +
1277         //     effective_filter_size - input_size);
1278         Value padding_needed = rewriter.create<SubIOp>(
1279             loc,
1280             rewriter.create<AddIOp>(
1281                 loc, effective_filter_size_op,
1282                 rewriter.create<MulIOp>(
1283                     loc, stride_value,
1284                     rewriter.create<SubIOp>(loc, output_size, one))),
1285             input_size);
1286         Value cond = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge,
1287                                                    padding_needed, zero);
1288         padding_needed = rewriter.create<mlir::SelectOp>(
1289             loc, padding_needed.getType(), cond, padding_needed, zero);
1290         *padding_low =
1291             rewriter.create<UnsignedDivIOp>(loc, padding_needed, two);
1292         *padding_high =
1293             rewriter.create<SubIOp>(loc, padding_needed, *padding_low);
1294         break;
1295       }
1296     }
1297     return true;
1298   }
1299 
1300   LogicalResult matchAndRewriteDynamicConv(OpT op,
1301                                            PatternRewriter &rewriter) const {
1302     tensorflow::TensorFormat data_format;
1303     if (!FormatFromString(op.data_format().str(), &data_format))
1304       return op.emitOpError("invalid data format");
1305 
1306     tensorflow::Padding padding;
1307     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1308       return failure();
1309 
1310     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1311     auto filter_ty =
1312         op.filter().getType().template dyn_cast<RankedTensorType>();
1313     auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1314     if (!input_ty || !filter_ty || !result_ty) return failure();
1315     // TODO(disc): Remove this constraint once fold and canonicalization
1316     // implemented.
1317     if (input_ty.hasStaticShape() && filter_ty.hasStaticShape())
1318       return failure();
1319 
1320     ArrayRef<Attribute> dilations = op.dilations().getValue();
1321     ArrayRef<Attribute> strides = op.strides().getValue();
1322     ArrayRef<Attribute> explicit_paddings;
1323     if (padding == tensorflow::Padding::EXPLICIT) {
1324       // The EXPLICIT padding mode and its associated attribute are limited to
1325       // Conv2D, so fetch the attribute by identifier instead of using a getter.
1326       explicit_paddings =
1327           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1328     }
1329 
1330     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1331     SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1332     SmallVector<int64_t, num_spatial_dims> window_strides;
1333     SmallVector<Value, num_spatial_dims * 2> paddings;
1334 
1335     auto get_int = [](Attribute attr) {
1336       return attr.template cast<IntegerAttr>().getInt();
1337     };
1338 
1339     constexpr int num_dims = num_spatial_dims + 2;
1340 
1341     Location loc = op.getLoc();
1342     auto shape_scalar_type = rewriter.getIntegerType(32);
1343 
1344     auto get_const = [&](int64_t val) {
1345       return rewriter.create<mlir::ConstantIntOp>(loc, val, shape_scalar_type);
1346     };
1347     auto get_dim_value = [&](Value val, int64_t dim) {
1348       Value dim_value = rewriter.create<tensor::DimOp>(loc, val, dim);
1349       return rewriter.create<IndexCastOp>(loc, dim_value, shape_scalar_type);
1350     };
1351 
1352     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1353       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1354       spatial_dim_indices.push_back(dim);
1355 
1356       const int64_t dilation = get_int(dilations[dim]);
1357       rhs_dilations.push_back(dilation);
1358       const int64_t stride = get_int(strides[dim]);
1359       window_strides.push_back(stride);
1360 
1361       Value pad_low, pad_high;
1362       if (padding == tensorflow::Padding::EXPLICIT) {
1363         pad_low = get_const(get_int(explicit_paddings[2 * dim]));
1364         pad_high = get_const(get_int(explicit_paddings[2 * dim + 1]));
1365       } else {
1366         auto input_size = get_dim_value(op.input(), dim);
1367         auto filter_size = get_dim_value(op.filter(), i);
1368         if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation,
1369                               stride, padding, shape_scalar_type, &pad_low,
1370                               &pad_high)) {
1371           return failure();
1372         }
1373       }
1374       paddings.push_back(pad_low);
1375       paddings.push_back(pad_high);
1376     }
1377     auto rhs_dilations_attr = rewriter.getNamedAttr(
1378         "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1379 
1380     auto window_strides_attr = rewriter.getNamedAttr(
1381         "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1382 
1383     auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1384         spatial_dim_indices, data_format, &rewriter);
1385 
1386     const int64_t input_channels =
1387         GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1388     // Filters data_format is always HWIO so input channels dimension is after
1389     // all spatial dimensions.
1390     const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1391     // TensorFlow convolution op verifies that the number of input channels is
1392     // divisible by the number of filter channels.
1393     // For depthwise convolution the feature_group_count argument would be set
1394     // to the input feature dimension.
1395     const int64_t feature_group_count =
1396         depthwise_conv ? input_channels : input_channels / filter_channels;
1397     auto feature_group_count_attr = rewriter.getNamedAttr(
1398         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1399 
1400     auto batch_group_count_attr = rewriter.getNamedAttr(
1401         "batch_group_count", rewriter.getI64IntegerAttr(1));
1402 
1403     Value paddings_op = rewriter.create<tensor::FromElementsOp>(
1404         op.getLoc(), rewriter.getI32Type(), paddings);
1405 
1406     SmallVector<Value, 3> operands(op.getOperands());
1407     operands.push_back(paddings_op);
1408     // Reshape the filter to {spatial_dims..., 1, in_channels *
1409     // channel_multiplier}.
1410     if (depthwise_conv) {
1411       ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1412       llvm::SmallVector<int64_t, num_dims> new_shape(
1413           filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1414       new_shape.push_back(1);
1415       new_shape.push_back(filter_shape[num_spatial_dims] *
1416                           filter_shape[num_spatial_dims + 1]);
1417       operands[1] = rewriter.create<mhlo::ReshapeOp>(
1418           op.getLoc(),
1419           RankedTensorType::get(new_shape, filter_ty.getElementType()),
1420           operands[1]);
1421     }
1422     NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1423                               dimension_numbers_attr, feature_group_count_attr,
1424                               batch_group_count_attr};
1425     rewriter.replaceOpWithNewOp<mhlo::DynamicConvOp>(op, op.getType(), operands,
1426                                                      llvm::makeArrayRef(attrs));
1427     return success();
1428   }
1429 
1430   LogicalResult matchAndRewrite(OpT op,
1431                                 PatternRewriter &rewriter) const override {
1432     return matchAndRewriteDynamicConv(op, rewriter);
1433   }
1434 };
1435 
1436 using ConvertConv2DDynamic =
1437     ConvertConvDynamic<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1438 
1439 // Converts the TensorFlow conv op in template to the generic HLO conv op by
1440 // converting TensorFlow op attributes to HLO op attributes.
1441 //
1442 // Sample result for Conv2D:
1443 //
1444 //   %conv = "mhlo.convolution"(%input, %filter) {
1445 //     strides = [1, 2],
1446 //     paddings = [[1, 0], [1, 1]],
1447 //     ...
1448 //   }
1449 //
1450 // This pattern is not defined using declarative rewrite rules, as computing
1451 // the paddings attribute requires multiple source op attributes and result op
1452 // attributes. Defining it as a declarative rewrite rule would introduce some
1453 // duplication in the C++ helper methods.
1454 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1455 class ConvertConvOp : public OpRewritePattern<OpTy> {
1456  public:
1457   using OpRewritePattern<OpTy>::OpRewritePattern;
1458 
1459   LogicalResult matchAndRewrite(OpTy op,
1460                                 PatternRewriter &rewriter) const override {
1461     tensorflow::TensorFormat data_format;
1462     if (!FormatFromString(op.data_format().str(), &data_format))
1463       return op.emitOpError("invalid data format");
1464 
1465     tensorflow::Padding padding;
1466     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1467       return failure();
1468 
1469     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1470     auto filter_ty =
1471         op.filter().getType().template dyn_cast<RankedTensorType>();
1472     auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1473 
1474     // Input, filter and the result needs to have static shape for calculation
1475     // of HLO paddings and feature group count attributes.
1476     for (RankedTensorType ty : {input_ty, filter_ty, result_ty})
1477       if (!ty || !ty.hasStaticShape()) return failure();
1478 
1479     ArrayRef<Attribute> dilations = op.dilations().getValue();
1480     ArrayRef<Attribute> strides = op.strides().getValue();
1481     ArrayRef<Attribute> explicit_paddings;
1482     if (padding == tensorflow::Padding::EXPLICIT) {
1483       // EXPLICIT padding mode and the associated attribute is limited to
1484       // Conv2D. So, fetch attribute by identifier instead of the
1485       // op.explicit_paddings() attribute getter.
1486       explicit_paddings =
1487           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1488     }
1489 
1490     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1491     SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1492     SmallVector<int64_t, num_spatial_dims> window_strides;
1493     SmallVector<int64_t, num_spatial_dims * 2> paddings;
1494 
1495     auto get_int = [](Attribute attr) {
1496       return attr.template cast<IntegerAttr>().getInt();
1497     };
1498 
1499     constexpr int num_dims = num_spatial_dims + 2;
1500     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1501       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1502       spatial_dim_indices.push_back(dim);
1503 
1504       const int64_t dilation = get_int(dilations[dim]);
1505       rhs_dilations.push_back(dilation);
1506       const int64_t stride = get_int(strides[dim]);
1507       window_strides.push_back(stride);
1508 
1509       int64_t pad_low, pad_high;
1510       if (padding == tensorflow::Padding::EXPLICIT) {
1511         pad_low = get_int(explicit_paddings[2 * dim]);
1512         pad_high = get_int(explicit_paddings[2 * dim + 1]);
1513       } else {
1514         int64_t output_size;
1515         int64_t pad_low_int64;
1516         int64_t pad_high_int64;
1517         tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1518             input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
1519             padding, &output_size, &pad_low_int64, &pad_high_int64);
1520         if (!status.ok()) return failure();
1521         pad_low = pad_low_int64;
1522         pad_high = pad_high_int64;
1523       }
1524       paddings.push_back(pad_low);
1525       paddings.push_back(pad_high);
1526     }
1527 
1528     auto rhs_dilations_attr = rewriter.getNamedAttr(
1529         "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1530 
1531     auto window_strides_attr = rewriter.getNamedAttr(
1532         "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1533 
1534     auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1535         spatial_dim_indices, data_format, &rewriter);
1536 
1537     const int64_t input_channels =
1538         GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1539     // Filters data_format is always HWIO so input channels dimension is after
1540     // all spatial dimensions.
1541     const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1542     // TensorFlow convolution op verifies that the number of input channels is
1543     // divisible by the number of filter channels.
1544     // For depthwise convolution the feature_group_count argument would be set
1545     // to the input feature dimension.
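    // For example (illustrative): 8 input channels with an HWIO filter whose
    // input-channel dimension is also 8 give feature_group_count = 1 (a
    // regular convolution), while a filter input-channel dimension of 2 gives
    // feature_group_count = 4 (a grouped convolution).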
1546     const int64_t feature_group_count =
1547         depthwise_conv ? input_channels : input_channels / filter_channels;
1548     auto feature_group_count_attr = rewriter.getNamedAttr(
1549         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1550 
1551     auto batch_group_count_attr = rewriter.getNamedAttr(
1552         "batch_group_count", rewriter.getI64IntegerAttr(1));
1553 
1554     RankedTensorType paddings_ty = RankedTensorType::get(
1555         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1556     auto paddings_attr = rewriter.getNamedAttr(
1557         "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1558 
1559     SmallVector<Value, 2> operands(op.getOperands());
1560     // Reshape the filter to {spatial_dims..., 1, in_channels *
1561     // channel_multiplier}.
1562     if (depthwise_conv) {
1563       ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1564       llvm::SmallVector<int64_t, num_dims> new_shape(
1565           filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1566       new_shape.push_back(1);
1567       new_shape.push_back(filter_shape[num_spatial_dims] *
1568                           filter_shape[num_spatial_dims + 1]);
1569       operands[1] = rewriter.create<mhlo::ReshapeOp>(
1570           op.getLoc(),
1571           RankedTensorType::get(new_shape, filter_ty.getElementType()),
1572           operands[1]);
1573     }
1574     NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
1575                               dimension_numbers_attr, feature_group_count_attr,
1576                               batch_group_count_attr, paddings_attr};
1577     rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
1578                                         llvm::makeArrayRef(attrs));
1579     return success();
1580   }
1581 };
1582 
1583 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1584 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1585 using ConvertDepthConv2DOp =
1586     ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1587                   /*depthwise_conv=*/true>;
1588 
1589 // Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const.
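// The [rank, 2] paddings operand is transposed to [2, rank], flattened to
// [2 * rank], and sliced into the low ([0, rank)) and high ([rank, 2 * rank))
// padding operands of mhlo.dynamic_pad. For example (illustrative), paddings
// [[1, 2], [3, 4]] yield low padding [1, 3] and high padding [2, 4].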
1590 class ConvertPadOpDynamic : public OpRewritePattern<TF::PadV2Op> {
1591  public:
1592   using OpRewritePattern::OpRewritePattern;
1593   // TODO(disc): To recover static special case's performance with folding and
1594   // canonicalization.
1595   LogicalResult matchAndRewrite(TF::PadV2Op op,
1596                                 PatternRewriter &rewriter) const override {
1597     Location loc = op.getLoc();
1598     auto input = op.input();
1599     auto paddings = op.paddings();
1600     auto constant_values = op.constant_values();
1601     auto input_type = input.getType().dyn_cast<RankedTensorType>();
1602     auto paddings_type = paddings.getType().dyn_cast<RankedTensorType>();
1603     if (!input_type || !paddings_type || !paddings_type.hasStaticShape())
1604       return failure();
1605 
1606     // TODO(disc): Remove this constraint once fold and canonicalization is
1607     // implemented.
1608     if (input_type.hasStaticShape()) return failure();
1609 
1610     int input_rank = input_type.getRank();
1611     // interior padding
1612     std::vector<int64_t> interior_values(input_rank, 0);
1613     auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter);
1614 
1615     Value interior_padding_tensor =
1616         rewriter.create<mhlo::ConstOp>(loc, interior_attr);
1617     Type paddings_elem_ty = paddings_type.getElementType();
1618     if (!paddings_elem_ty.isInteger(64)) {
1619       interior_padding_tensor = rewriter.create<mhlo::ConvertOp>(
1620           loc, interior_padding_tensor, paddings_elem_ty);
1621     }
1622     llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
1623     auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
1624     Value transposed_paddings = rewriter.create<mhlo::TransposeOp>(
1625         loc, RankedTensorType::get(transposed_shape, paddings_elem_ty),
1626         paddings, transpose_attr);
1627     Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
1628         loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
1629         transposed_paddings);
1630 
1631     auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter);
1632     auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1633     auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1634     Value left_padding_tensor = rewriter.create<mhlo::SliceOp>(
1635         loc, reshaped_paddings, left_padding_start_attr,
1636         left_padding_limit_attr, left_padding_stride_attr);
1637 
1638     auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1639     auto right_padding_limit_attr =
1640         GetI64ElementsAttr({2 * input_rank}, &rewriter);
1641     auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1642     Value right_padding_tensor = rewriter.create<mhlo::SliceOp>(
1643         loc, reshaped_paddings, right_padding_start_attr,
1644         right_padding_limit_attr, right_padding_stride_attr);
1645 
1646     rewriter.replaceOpWithNewOp<mhlo::DynamicPadOp>(
1647         op, op.getType(), input, constant_values, left_padding_tensor,
1648         right_padding_tensor, interior_padding_tensor);
1649 
1650     return success();
1651   }
1652 };
1653 
1654 class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {
1655   using OpRewritePattern<TF::GatherNdOp>::OpRewritePattern;
1656   // Converts tf.GatherNdOp to mhlo.DynamicGatherOp.
1657   // Here we leave 'slice_sizes' as an Attr, without defining a new
1658   // DynamicGatherOp, since GatherDimensionNumbers already provides enough
1659   // information for shape inference and code generation of mhlo::GatherOp. '?'
1660   // will be filled into slice_sizes for dimensions that are dynamic sized.
1661   // TODO(disc): To recover static special case's performance with folding and
1662   // canonicalization.
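  // For example (illustrative), params of shape [?, 4, 5] and indices of shape
  // [?, 2] give num_index_dims = 2, slice_sizes = [1, 1, 5],
  // collapsed_slice_dims = [0, 1], offset_dims = [1], start_index_map = [0, 1]
  // and index_vector_dim = 1.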
1663   LogicalResult matchAndRewrite(TF::GatherNdOp op,
1664                                 PatternRewriter &rewriter) const override {
1665     Location loc = op.getLoc();
1666     auto params = op.params();
1667     auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1668     auto indices = op.indices();
1669     auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1670     if (!params_ty || !indices_ty) return failure();
1671     auto params_rank = params_ty.getRank();
1672     auto indices_rank = indices_ty.getRank();
1673     int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1);
1674     // The last dim of indices of GatherNdOp must have a static size.
1675     if (num_index_dims == ShapedType::kDynamicSize) return failure();
1676 
1677     SmallVector<int64_t, 4> slice_sizes;
1678     slice_sizes.reserve(params_rank);
1679     for (int64_t i = 0; i < params_rank; ++i) {
1680       if (i < num_index_dims) {
1681         slice_sizes.push_back(1);
1682       } else {
1683         // potentially dynamic
1684         int64_t dim_size = params_ty.getDimSize(i);
1685         slice_sizes.push_back(dim_size);
1686       }
1687     }
1688     SmallVector<Value, 4> slice_sizes_vals;
1689     Value slice_sizes_value = nullptr;
1690     for (int64_t i = 0; i < params_rank; ++i) {
1691       if (i < num_index_dims) {
1692         slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1693             loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1694       } else {
1695         int64_t dim_size = params_ty.getDimSize(i);
1696         if (dim_size != ShapedType::kDynamicSize) {
1697           slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1698               loc,
1699               rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1700         } else {
1701           slice_sizes_vals.push_back(rewriter.create<IndexCastOp>(
1702               loc, rewriter.create<tensor::DimOp>(loc, params, i),
1703               indices_ty.getElementType()));
1704         }
1705       }
1706     }
1707     slice_sizes_value =
1708         rewriter.create<tensor::FromElementsOp>(loc, slice_sizes_vals);
1709 
1710     // collapsed_slice_dims
1711     SmallVector<int64_t, 4> collapsed_slice_dims;
1712     collapsed_slice_dims.reserve(num_index_dims);
1713     for (int64_t i = 0; i < num_index_dims; ++i) {
1714       collapsed_slice_dims.push_back(i);
1715     }
1716     // offset_dims
1717     SmallVector<int64_t, 4> offset_dims;
1718     offset_dims.reserve(params_rank - num_index_dims);
1719     for (int64_t i = num_index_dims; i < params_rank; i++) {
1720       offset_dims.push_back(i + indices_rank - 1 - num_index_dims);
1721     }
1722     // start_index_map
1723     SmallVector<int64_t, 4> start_index_map;
1724     start_index_map.reserve(num_index_dims);
1725     for (int64_t i = 0; i < num_index_dims; i++) {
1726       start_index_map.push_back(i);
1727     }
1728     // index_vector_dim
1729     int64_t index_vector_dim = indices_rank - 1;
1730 
1731     auto dims_attr = GatherDimensionNumbers::get(
1732         /*offset_dims=*/GetI64ElementsAttr(offset_dims, &rewriter),
1733         /*collapsed_slice_dims=*/
1734         GetI64ElementsAttr(collapsed_slice_dims, &rewriter),
1735         /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1736         /*index_vector_dim=*/rewriter.getI64IntegerAttr(index_vector_dim),
1737         rewriter.getContext());
1738     // TODO(disc): Remove this if-statement once fold and canonicalization is
1739     // implemented.
1740     if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
1741       rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
1742           op, op.getType(), op.params(), op.indices(), dims_attr,
1743           GetI64ElementsAttr(slice_sizes, &rewriter));
1744     } else {
1745       rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1746           op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1747           dims_attr);
1748     }
1749     return success();
1750   }
1751 };
1752 
1753 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
1754 // division can result in strange behavior.
1755 //
1756 //      floordiv = cast(floordiv(cast(left), cast(right)))
1757 //
1758 //   %left_cast = cast(%left)
1759 //   %right_cast = cast(%right)
1760 //   %div = div(%left_cast, %right_cast)
1761 //   %floored = floor(%div)
1762 //   %floored_cast = cast(%floored)
1763 //
1764 // Required to manually specify the intermediate types.
1765 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1766  public:
1767   using OpRewritePattern::OpRewritePattern;
1768 
1769   LogicalResult matchAndRewrite(TF::FloorDivOp op,
1770                                 PatternRewriter &rewriter) const override {
1771     auto l = op.x();
1772     auto r = op.y();
1773     auto element_type = getElementTypeOrSelf(l.getType());
1774     if (!element_type.isBF16()) return failure();
1775 
1776     auto out_type = op.z().getType().cast<TensorType>();
1777 
1778     l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1779     r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1780 
1781     auto intermediate = rewriter.create<TF::FloorDivOp>(
1782         op.getLoc(),
1783         ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1784         r);
1785 
1786     auto floor_op =
1787         rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1788     rewriter.replaceOp(op, floor_op.getResult());
1789     return success();
1790   }
1791 };
1792 
1793 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1794  public:
1795   using OpRewritePattern::OpRewritePattern;
1796 
1797   LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1798                                 PatternRewriter &rewriter) const override {
1799     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1800     auto output_type = op.output().getType();
1801     if (!input_type) {
1802       return rewriter.notifyMatchFailure(op, "requires ranked input shape");
1803     }
1804     llvm::SmallVector<int64_t, 4> broadcast_dimensions;
1805     if (input_type.getRank() > 0) {
1806       auto ranked_output_type = output_type.dyn_cast<RankedTensorType>();
1807       if (!ranked_output_type) {
1808         return rewriter.notifyMatchFailure(op, "requires ranked output shape");
1809       }
1810       auto rank_diff = ranked_output_type.getRank() - input_type.getRank();
1811       // The tf.BroadcastTo op performs "right-aligned" numpy-style
1812       // broadcasting.
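      // For example (illustrative), broadcasting a rank-2 input to a rank-4
      // output yields rank_diff = 2 and broadcast_dimensions = [2, 3].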
1813       broadcast_dimensions = llvm::to_vector<4>(
1814           llvm::seq<int64_t>(rank_diff, ranked_output_type.getRank()));
1815     }
1816     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1817         op, output_type, op.input(), op.shape(),
1818         rewriter.getI64TensorAttr(broadcast_dimensions));
1819     return success();
1820   }
1821 };
1822 
1823 /// Converts a TF::RollOp to HLO. Only supports 0-D axis and shift, and the
1824 /// axis has to be a constant.
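/// The roll is implemented by concatenating two copies of the input along
/// `axis` and slicing a window of the original size starting at
/// axis_size - shift (with shift normalized into [0, axis_size)). For example
/// (illustrative), rolling [0, 1, 2, 3, 4] by shift = 2 yields [3, 4, 0, 1, 2].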
1825 class ConvertRollOp : public OpRewritePattern<TF::RollOp> {
1826  public:
1827   using OpRewritePattern::OpRewritePattern;
1828   LogicalResult matchAndRewrite(TF::RollOp op,
1829                                 PatternRewriter &rewriter) const override {
1830     auto shift_ty = op.shift().getType().dyn_cast<RankedTensorType>();
1831     if (!shift_ty || shift_ty.getRank() != 0) {
1832       return rewriter.notifyMatchFailure(
1833           op, "require the type of shift to be 0D tensor");
1834     }
1835 
1836     APInt val;
1837     if (!matchPattern(op.axis(), m_ConstantInt(&val))) {
1838       return rewriter.notifyMatchFailure(op, "require axis to be constant");
1839     }
1840     int axis = val.getSExtValue();
1841 
1842     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1843     if (!input_ty || !input_ty.hasStaticShape()) {
1844       return rewriter.notifyMatchFailure(
1845           op, "require the type of input to have static shapes");
1846     }
1847     ArrayRef<int64_t> input_shape = input_ty.getShape();
1848     int input_rank = input_ty.getRank();
1849     if (axis < 0) axis += input_rank;
1850 
1851     // Adjust large offsets into [0, axis_size). This also makes negative
1852     // offsets positive.
1853     // offset = ((offset % axis_size) + axis_size) % axis_size
1854     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1855     Value offset = op.shift();
1856     auto axis_size = b.create<mhlo::ConstOp>(b.getIntegerAttr(
1857         getElementTypeOrSelf(offset.getType()), input_shape[axis]));
1858     offset = b.create<RemOp>(
1859         b.create<AddOp>(b.create<RemOp>(offset, axis_size), axis_size),
1860         axis_size);
1861 
1862     // Stack two copies of the dimension, then slice from the calculated
1863     // offset. This also works if shift is not constant.
1864     // DynamicSliceOp requires the sizes being integer, and we can get the
1865     // information from input shape.
1866     auto concat = b.create<ConcatenateOp>(ValueRange{op.input(), op.input()},
1867                                           b.getI64IntegerAttr(axis));
1868     Value zero = b.create<mhlo::ConstOp>(
1869         b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0));
1870     SmallVector<Value> slice_begin_indices(input_rank, zero);
1871     slice_begin_indices[axis] = b.create<SubOp>(axis_size, offset);
1872     rewriter.replaceOpWithNewOp<DynamicSliceOp>(
1873         op, input_ty, concat, slice_begin_indices,
1874         rewriter.getI64TensorAttr(input_shape));
1875     return success();
1876   }
1877 };
1878 
1879 /// Converts a TF::LeakyReluOp to HLO.
1880 /// LeakyRelu(x) = alpha * x if x < 0 else x.
1881 class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
1882  public:
1883   using OpRewritePattern::OpRewritePattern;
1884   LogicalResult matchAndRewrite(TF::LeakyReluOp op,
1885                                 PatternRewriter &rewriter) const override {
1886     Location loc = op.getLoc();
1887     float alpha = op.alpha().convertToFloat();
1888     Value features = op.features();
1889     auto featureType = features.getType().cast<RankedTensorType>();
1890     ArrayRef<int64_t> featureShape = featureType.getShape();
1891     Type eltType = featureType.getElementType();
1892 
1893     auto alphaVal = rewriter.create<mhlo::ConstOp>(
1894         loc, rewriter.getFloatAttr(eltType, alpha));
1895 
1896     // Broadcast `alpha` to match the shape of feature.
1897     auto featureShapeAttr = DenseIntElementsAttr::get(
1898         RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
1899         featureShape);
1900     auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
1901         loc, featureType, alphaVal, featureShapeAttr);
1902 
1903     Attribute zeroAttr = rewriter.getZeroAttr(featureType);
1904     Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
1905 
1906     Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
1907         loc, features.getType(), features, broadcastAlphaVal);
1908 
1909     StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
1910     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1911         loc, features, zeroVal, compare_direction);
1912 
1913     rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1914                                           features, leakyActivationVal);
1915     return success();
1916   }
1917 };
1918 
1919 /// Converts a TF::LeakyReluGradOp to HLO.
1920 /// LeakyReluGrad(gradient, inputs) = gradient if input > 0
1921 /// else alpha * gradient.
1922 class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
1923  public:
1924   using OpRewritePattern::OpRewritePattern;
1925   LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
1926                                 PatternRewriter &rewriter) const override {
1927     Location loc = op.getLoc();
1928     float alpha = op.alpha().convertToFloat();
1929     Value gradients = op.gradients();
1930     Value features = op.features();
1931     auto featureType = features.getType().cast<RankedTensorType>();
1932     ArrayRef<int64_t> featureShape = featureType.getShape();
1933     Type eltType = featureType.getElementType();
1934 
1935     auto alphaVal = rewriter.create<mhlo::ConstOp>(
1936         loc, rewriter.getFloatAttr(eltType, alpha));
1937     auto featureShapeAttr = DenseIntElementsAttr::get(
1938         RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
1939         featureShape);
1940     auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
1941         loc, featureType, alphaVal, featureShapeAttr);
1942 
1943     Attribute zeroAttr = rewriter.getZeroAttr(featureType);
1944     Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
1945 
1946     Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
1947         loc, features.getType(), gradients, broadcastAlphaVal);
1948 
1949     StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
1950 
1951     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1952         loc, features, zeroVal, compare_direction);
1953 
1954     rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1955                                           gradients, leakyGradientVal);
1956     return success();
1957   }
1958 };
1959 
1960 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1961 // For a Rank-2 input, it creates the following ops:
1962 //   %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1963 //   %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1964 //   %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1965 //   %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1966 //   %5 = "mhlo.broadcast"(%4)
1967 //   %6 = "mhlo.select"(%3, %input, %5)
1968 //   %7 = "mhlo.reduce"(%6, %4) ( {
1969 //   ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1970 //     %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1971 //     "mhlo.return"(%9) : (tensor<f32>) -> ()
1972 //   }) {dimensions = dense<0> : tensor<1xi64>}
1973 //
1974 // If the input's rank N is greater than 2, we will reshape it to R2 first and
1975 // create the above ops, then reshape it back to rank N/2.
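// For example (illustrative), a [3, 4, 3, 4] input is reshaped to [12, 12],
// the masked reduction extracts the diagonal, and the result is reshaped back
// to [3, 4].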
1976 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
1977  public:
1978   using OpRewritePattern::OpRewritePattern;
1979 
1980   LogicalResult matchAndRewrite(TF::DiagPartOp op,
1981                                 PatternRewriter &rewriter) const override {
1982     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1983     if (!input_type || !input_type.hasStaticShape()) return failure();
1984     int64_t num_dims = input_type.getRank();
1985     if (num_dims < 2 || num_dims % 2 != 0) return failure();
1986     const int64_t out_dims = num_dims / 2;
1987 
1988     int64_t new_size = 1;
1989     llvm::SmallVector<int64_t, 4> new_dims;
1990     for (int i = 0; i < out_dims; i++) {
1991       if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1992         return op.emitOpError("invalid dimensions size");
1993       new_size *= input_type.getDimSize(i);
1994       new_dims.push_back(input_type.getDimSize(i));
1995     }
1996     Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1997         op.getLoc(),
1998         RankedTensorType::get({new_size, new_size},
1999                               input_type.getElementType()),
2000         op.input());
2001     auto iota_type = RankedTensorType::get({new_size, new_size},
2002                                            rewriter.getIntegerType(32));
2003     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
2004                                          rewriter.getI64IntegerAttr(0));
2005     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
2006                                          rewriter.getI64IntegerAttr(1));
2007     Value compare = rewriter.create<CompareOp>(
2008         op.getLoc(), iota0, iota1,
2009         StringAttr::get(rewriter.getContext(), "EQ"));
2010     Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
2011                                       0, &rewriter);
2012     Value zero_matrix = rewriter.create<BroadcastOp>(
2013         op.getLoc(), reshaped_input.getType(), zero,
2014         GetI64ElementsAttr({new_size, new_size}, &rewriter));
2015     Value masked =
2016         rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
2017                                   compare, reshaped_input, zero_matrix);
2018     auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
2019                                             GetI64ElementsAttr({0}, &rewriter));
2020     assert(!input_type.getElementType().isInteger(1) &&
2021            "data type should not be i1");
2022     BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
2023                            &rewriter);
2024     rewriter.replaceOpWithNewOp<ReshapeOp>(
2025         op, RankedTensorType::get(new_dims, input_type.getElementType()),
2026         reduce.getResult(0));
2027     return success();
2028   }
2029 };
2030 
2031 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
2032 class ConvertMatrixDiagPartV3Op
2033     : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
2034   using Shape = llvm::SmallVector<int64_t, 4>;
2035 
2036   // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the diagonal(s)
2037   // with k. This can be either a single value (for a single diagonal) or a
2038   // tuple of two values (starting and ending diagonal, for a band).
2039   LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
2040     DenseIntElementsAttr kattr;
2041     if (!matchPattern(op.k(), m_Constant(&kattr))) {
2042       return failure();
2043     }
2044     DenseIntElementsAttr::iterator it = kattr.begin();
2045     (*k)[0] = (*it).getSExtValue();
2046     it++;
2047     if (it == kattr.end()) {
2048       // Handle input like e.g. "k = 5", in which case we extract a single
2049       // diagonal.
2050       (*k)[1] = (*k)[0];
2051     } else {
2052       // Handle input like e.g. "k = [-1, 1]", in which case we extract a
2053       // band (multiple diagonals).
2054       (*k)[1] = (*it).getSExtValue();
2055     }
2056     return success();
2057   }
2058 
2059   // Utility method for broadcasting integer constants to a given shape.
2060   BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
2061                                 int int_size, PatternRewriter &rewriter) const {
2062     return rewriter.create<BroadcastOp>(
2063         loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
2064         GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
2065                              &rewriter),
2066         GetI64ElementsAttr(shape, &rewriter));
2067   }
2068 
2069  public:
2070   using OpRewritePattern::OpRewritePattern;
2071 
2072   LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
2073                                 PatternRewriter &rewriter) const override {
2074     Location loc = op.getLoc();
2075     ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
2076     auto element_type = input_type.getElementType();
2077 
2078     // Align is a string specifying how superdiagonals and subdiagonals should
2079     // be aligned/padded for diagonals that are shorter than max_diag_len. The
2080     // format is "{super}_{sub}", with {super} the superdiagonal alignment and
2081     // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
2082     // left, "RIGHT" means rows will be padded to the right. The default is
2083     // "RIGHT_LEFT".
2084     StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
2085     enum Alignment { kLeft, kRight };
2086 
2087     // default is RIGHT_LEFT
2088     Alignment superdiagonal_align = kRight;
2089     Alignment subdiagonal_align = kLeft;
2090 
2091     if (align == "RIGHT_LEFT") {
2092       superdiagonal_align = kRight;
2093       subdiagonal_align = kLeft;
2094     } else if (align == "RIGHT_RIGHT") {
2095       superdiagonal_align = kRight;
2096       subdiagonal_align = kRight;
2097     } else if (align == "LEFT_RIGHT") {
2098       superdiagonal_align = kLeft;
2099       subdiagonal_align = kRight;
2100     } else if (align == "LEFT_LEFT") {
2101       superdiagonal_align = kLeft;
2102       subdiagonal_align = kLeft;
2103     } else {
2104       return failure();  // unsupported alignment
2105     }
2106 
2107     // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
2108     // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
2109     if (!input_type || !input_type.hasStaticShape()) return failure();
2110     int64_t num_dims = input_type.getRank();
2111     if (num_dims < 2) return failure();
2112     int64_t rows = input_type.getDimSize(num_dims - 2);  // rows
2113     int64_t cols = input_type.getDimSize(num_dims - 1);  // cols
2114 
2115     // We extract the diagonals from k[0] up to and including k[1].
2116     // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
2117     // the main diagonal). It's negative for subdiagonals (under and to the left
2118     // of the main diagonal) and positive for superdiagonals (above and to the
2119     // right of the main diagonal).
2120     int64_t k[2];
2121     if (failed(ExtractK(op, &k))) return failure();
2122     int num_diags = k[1] - k[0] + 1;
2123 
2124     // Shifting diagonals away from the main diagonal might shorten them. This
2125     // is the longest diagonal we will see. We make this the last dimension of
2126     // the output shape.
2127     int64_t max_diag_len =
2128         std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
2129                  cols + std::min(-k[0], static_cast<int64_t>(0)));
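    // For example (illustrative): rows = 4, cols = 5 and k = [-1, 2] give
    // num_diags = 4 and max_diag_len = min(4 + min(2, 0), 5 + min(1, 0)) = 4.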
2130 
2131     // The first dimension is the index vector dimension we'll use for gather.
2132     // It's 1 here, but will be 2 once we glue x and y together.
2133     Shape indices_shape({1, num_diags, max_diag_len});
2134 
2135     RankedTensorType iota_type =
2136         RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
2137     Value iotaM =
2138         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
2139     Value iotaN =
2140         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
2141 
2142     // Broadcast constants, of the same shape as iotaM and iotaN.
2143     Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
2144     Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
2145     Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
2146     Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
2147     Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
2148     Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
2149     Value b_max_diag_len =
2150         BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
2151 
2152     // d = k[1] - m
2153     // (A.k.a. the number of the diagonal, depending on m. Note that we
2154     //  subtract m here. This means we start with the superdiagonals and
2155     //  move downwards towards the subdiagonals. So the start indices will
2156     //  be decreasing.)
2157     Value d = rewriter.create<SubOp>(loc, b_k1, iotaM);
2158     Value neg_d = rewriter.create<NegOp>(loc, d);
2159 
2160     // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
2161     // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
2162     Value diag_len_d = rewriter.create<MinOp>(
2163         loc,
2164         rewriter.create<AddOp>(loc, b_rows,
2165                                rewriter.create<MinOp>(loc, d, b_zero)),
2166         rewriter.create<SubOp>(loc, b_cols,
2167                                rewriter.create<MaxOp>(loc, d, b_zero)));
2168 
2169     // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
2170     Value cmp;
2171     if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
2172       cmp = b_true;
2173     } else if (superdiagonal_align == kRight) {
2174       // offset = d>=0 ? max_diag_len - diag_len_d : 0
2175       cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
2176     } else if (subdiagonal_align == kRight) {
2177       // offset = d<=0 ? max_diag_len - diag_len_d : 0
2178       cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
2179     } else {
2180       // offset = 0
2181       cmp = b_false;
2182     }
2183 
2184     // This offset shifts the diagonals to the "left" or "right", depending
2185     // on alignment.
2186     Value offset = rewriter.create<SelectOp>(
2187         loc, b_zero.getType(), cmp,
2188         rewriter.create<SubOp>(loc, b_max_diag_len, diag_len_d), b_zero);
2189 
2190     // x = max(d, 0) - offset
2191     // y = max(-d, 0) - offset
2192     Value x = rewriter.create<SubOp>(
2193         loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
2194     Value y = rewriter.create<SubOp>(
2195         loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
2196 
2197     Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
2198     Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
2199 
2200     // GatherOp is happy about letting us index out of bounds values, but those
2201     // values will be undefined. So we mask them later. Set up the boolean
2202     // expression that tells us which entries, in the output shape, are out of
2203     // bounds and thus become the padding_value.
2204     Value x_in_bounds = rewriter.create<AndOp>(
2205         loc,
2206         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
2207                                             b_zero),
2208         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
2209     Value y_in_bounds = rewriter.create<AndOp>(
2210         loc,
2211         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
2212                                             b_zero),
2213         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
2214     Value in_bounds = rewriter.create<ReshapeOp>(
2215         loc,
2216         RankedTensorType::get(Shape({num_diags, max_diag_len}),
2217                               rewriter.getIntegerType(1)),
2218         rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
2219 
2220     // Now combine x and y into the index data structure needed for gather.
2221     Shape concat_shape({2, num_diags, max_diag_len});
2222     Value start_indices = rewriter.create<ConcatenateOp>(
2223         loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
2224         mlir::ValueRange({n_plus_y, n_plus_x}),
2225         mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
2226 
2227     // Shape of the final output. (Except for dimension folding in the
2228     // single diagonal case.)
2229     Shape output_shape;
2230     for (int i = 0; i < num_dims - 2; i++) {
2231       output_shape.push_back(input_type.getDimSize(i));
2232     }
2233     output_shape.push_back(num_diags);
2234     output_shape.push_back(max_diag_len);
2235     auto output_type = RankedTensorType::get(output_shape, element_type);
2236 
2237     // A slice is the shape of what GatherOp copies per lookup. So the last
2238     // two dimensions (M, N in the matrix-diag-part docs) are where we go
2239     // through entry by entry.
2240     ArrayRef<int64_t> input_shape = input_type.getShape();
2241     Shape slice_sizes(input_shape.begin(), input_shape.end());
2242     int slice_dimensions = slice_sizes.size();
2243     slice_sizes[slice_dimensions - 2] = 1;
2244     slice_sizes[slice_dimensions - 1] = 1;
2245 
2246     // Dimensions of the input we won't see in the output (M and N).
2247     SmallVector<int64_t, 2> collapsed_dims(
2248         {slice_dimensions - 2, slice_dimensions - 1});
2249 
2250     // Which dimensions (in the input) the two offset "columns" map to.
2251     SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
2252 
2253     // Gather the diagonal entries.
2254     // TODO(kramm): For a single diagonal, this might be slower than the
2255     //              mask + sum approach. Special-case num_diags==1?
2256     auto dims_attr = GatherDimensionNumbers::get(
2257         /*offset_dims=*/GetI64ElementsAttrForSeq(0, num_dims - 2, &rewriter),
2258         /*collapsed_slice_dims=*/GetI64ElementsAttr(collapsed_dims, &rewriter),
2259         /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
2260         /*index_vector_dim=*/rewriter.getI64IntegerAttr(0),
2261         rewriter.getContext());
2262     Value gather = rewriter.create<mhlo::GatherOp>(
2263         loc, output_type, op.input(), start_indices, dims_attr,
2264         GetI64ElementsAttr(slice_sizes, &rewriter));
2265 
2266     // We now need to broadcast the "in_bounds" boolean expression, as well as
2267     // the padding value, to do the final select.
2268     Shape broadcast_bounds;
2269     for (int i = 0; i < output_shape.size() - 2; i++) {
2270       broadcast_bounds.push_back(output_shape[i]);
2271     }
2272     Value b_in_bounds = rewriter.create<BroadcastOp>(
2273         loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
2274         in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
2275     Value b_padding = rewriter.create<BroadcastOp>(
2276         loc, output_type, op.padding_value(),
2277         GetI64ElementsAttr(output_shape, &rewriter));
2278 
2279     // Replace all out-of-bounds values in the result with padding_value.
2280     Value result = rewriter.create<SelectOp>(loc, output_type, b_in_bounds,
2281                                              gather, b_padding);
2282 
2283     if (num_diags == 1) {
2284       // matrix_diag_part folds away the 1-sized band dimension if we only
2285       // extract a single diagonal.
2286       result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
2287     }
2288 
2289     rewriter.replaceOp(op, result);
2290     return success();
2291   }
2292 };
2293 
2294 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
2295 // depending on arity of the op.
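// For illustration (names abbreviated): a two-operand einsum such as
//   %r = "tf.Einsum"(%a, %b) {equation = "ab,bc->ac"}
// lowers to mhlo.einsum carrying the same equation string, while a
// one-operand einsum lowers to mhlo.unary_einsum.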
2296 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
2297  public:
2298   using OpRewritePattern::OpRewritePattern;
2299 
2300   LogicalResult matchAndRewrite(TF::EinsumOp op,
2301                                 PatternRewriter &rewriter) const override {
2302     StringAttr equation = op->getAttrOfType<StringAttr>("equation");
2303     if (op.N() == 1) {
2304       rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
2305           op, op.getType(), *op.inputs().begin(), equation);
2306     } else if (op.N() == 2) {
2307       ValueRange inputs = op.inputs();
2308       rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
2309                                             inputs[1], equation);
2310     } else {
2311       // TensorFlow EinsumOp verifies that the number of operands is at most
2312       // two.
2313       return failure();
2314     }
2315     return success();
2316   }
2317 };
2318 
2319 // Bypasses IdentityN op.
2320 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
2321  public:
2322   using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
2323   LogicalResult matchAndRewrite(TF::IdentityNOp op,
2324                                 PatternRewriter &rewriter) const override {
2325     rewriter.replaceOp(op, op.getOperands());
2326     return success();
2327   }
2328 };
2329 
2330 template <typename OpTy>
2331 class ConvertFFTOp : public OpRewritePattern<OpTy> {
2332  public:
2333   using OpRewritePattern<OpTy>::OpRewritePattern;
2334   LogicalResult matchAndRewrite(OpTy op,
2335                                 PatternRewriter &rewriter) const override {
2336     auto input_ty = op.input().getType().template cast<ShapedType>();
2337     if (!input_ty.hasRank()) {
2338       return failure();
2339     }
2340     auto input_shape = input_ty.getShape();
2341     DenseIntElementsAttr fft_length_attr;
2342     if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
2343       return failure();
2344     }
2345     int64_t fft_length;
2346     if (fft_length_attr.getNumElements() != 0) {
2347       fft_length = fft_length_attr.getValue<IntegerAttr>(0).getInt();
2348     } else {
2349       return failure();
2350     }
2351 
2352     std::string fft_string = "RFFT";
2353     if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
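      // For IRFFT the innermost input dimension holds fft_length / 2 + 1
      // complex entries (the real FFT of a length-fft_length signal), so the
      // expected input length is adjusted accordingly.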
2354       fft_length = fft_length / 2 + 1;
2355       fft_string = "IRFFT";
2356     }
2357     Location loc = op.getLoc();
2358 
2359     // The inner-most dim cannot be dynamic.
2360     if (input_ty.isDynamicDim(input_shape.size() - 1)) {
2361       return failure();
2362     }
2363 
2364     auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
2365     expected_shape.push_back(fft_length);
2366 
2367     // Zero pad or truncate the last axis
2368     Value reshaped = op.input();
2369     SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
2370     SmallVector<int64_t, 4> strides(input_shape.size(), 1);
2371 
2372     // Last dim larger than fft_length, slice the input
2373     if (input_shape.back() > fft_length) {
2374       reshaped = rewriter.create<SliceOp>(
2375           op.getLoc(),
2376           RankedTensorType::get(expected_shape, input_ty.getElementType()),
2377           op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
2378           GetI64ElementsAttr(expected_shape, &rewriter),
2379           GetI64ElementsAttr(strides, &rewriter));
2380 
2381       // Last dim smaller than fft_length, zero-pad the input
2382     } else if (input_ty.getShape().back() < fft_length) {
2383       SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
2384       SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
2385       padding.push_back(fft_length - input_shape.back());
2386       Value zero =
2387           GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
2388       reshaped = rewriter.create<PadOp>(
2389           loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
2390           op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
2391           GetI64ElementsAttr(padding, &rewriter),
2392           GetI64ElementsAttr(no_padding, &rewriter));
2393     }
2394 
2395     rewriter.replaceOpWithNewOp<FftOp>(op, op.getType(), reshaped, fft_string,
2396                                        rewriter.getI64TensorAttr(fft_length));
2397     return success();
2398   }
2399 };
2400 
2401 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
2402 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
2403 
2404 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
2405 // BatchNormGradOp for training and a sequence of binary ops for inference.
2406 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
2407 template <typename FusedBatchNormGradOpT>
2408 class ConvertFusedBatchNormGradBase
2409     : public OpRewritePattern<FusedBatchNormGradOpT> {
2410  public:
2411   using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
2412 
2413   LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
2414                                 PatternRewriter &rewriter) const override {
2415     Location loc = op.getLoc();
2416     Value grad = op.y_backprop();
2417     Value act = op.x();
2418     Value scale = op.scale();
2419     Value mean = op.reserve_space_1();
2420     Value var = op.reserve_space_2();
2421 
2422     // TODO(b/141785544): Update this to not require static shapes.
2423     // The activation shape needs to be static to convert negative indices in
2424     // TensorFlow to absolute indices required by HLO.
2425     RankedTensorType act_type =
2426         act.getType().template dyn_cast<RankedTensorType>();
2427     if (!act_type) return failure();
2428     Type act_ele_type = act_type.getElementType();
2429     // To support mixed precision, the statistics type, which may be more
2430     // precise than the input types, is used for this op.
2431     Type kernel_type =
2432         scale.getType().template cast<TensorType>().getElementType();
2433     grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
2434     act = rewriter.create<ConvertOp>(loc, act, kernel_type);
2435 
2436     tensorflow::TensorFormat data_format;
2437     if (!FormatFromString(op.data_format().str(), &data_format))
2438       return op.emitOpError("invalid data format");
2439 
2440     auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act);
2441     auto feature_dim = feature_dim_attr.getValue().getSExtValue();
2442 
2443     // Gets the result values.
2444     Value x_backprop, scale_backprop, offset_backprop;
2445     if (op.is_training()) {  // training
2446       // TODO(b/145536565): handle GPU logic separately.
2447       // Infers the output type with the converted `act`.
2448       Type feature_type = RankedTensorType::get(
2449           {GetDimSize(act_type, feature_dim)}, kernel_type);
2450       Type result_type = TupleType::get(
2451           rewriter.getContext(), {act.getType(), feature_type, feature_type});
2452 
2453       auto training_op = rewriter.create<BatchNormGradOp>(
2454           loc, result_type, act, scale, mean, var, grad, op.epsilon(),
2455           feature_dim);
2456 
2457       x_backprop =
2458           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
2459 
2460       scale_backprop =
2461           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);
2462 
2463       offset_backprop =
2464           rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
2465     } else {  // inference
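      // Summary of the computation below (a standard batch-norm backward pass
      // with fixed statistics):
      //   x_backprop      = y_backprop * scale * rsqrt(var + epsilon)
      //   scale_backprop  = sum(y_backprop * (x - mean)) * rsqrt(var + epsilon)
      //   offset_backprop = sum(y_backprop)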
2466       SmallVector<int64_t, 4> non_feature_dims;
2467       for (int64_t i = 0; i < act_type.getRank(); ++i) {
2468         if (i == feature_dim) continue;
2469         non_feature_dims.push_back(i);
2470       }
2471       auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
2472       auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2473 
2474       // scratch1 = rsqrt(var + epsilon)
2475       RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
2476       auto epsilon = rewriter.create<ConstOp>(
2477           loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
2478       auto add_op = rewriter.create<chlo::BroadcastAddOp>(
2479           loc, var, epsilon.getResult(), scalar_broadcast_dims);
2480 
2481       Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
2482 
2483       // scratch2 = sum(y_backprop * (x - mean))
2484       auto sub_op = rewriter.create<mhlo::SubOp>(
2485           loc, act,
2486           Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
2487       auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
2488       Value scratch2 =
2489           ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
2490 
2491       // x_backprop = y_backprop * (scale * scratch1)
2492       auto scaled_grad =
2493           rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
2494       x_backprop = rewriter.create<mhlo::MulOp>(
2495           loc, grad,
2496           Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
2497                                   rewriter));
2498 
2499       // scale_backprop = scratch2 * scratch1
2500       scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
2501 
2502       // offset_backprop = sum(y_backprop)
2503       offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
2504     }
2505 
2506     x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
2507     Value last_val[2];
2508     if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
2509       // It doesn't matter what values we provide for the last 2 results.
2510       last_val[0] = last_val[1] = op.x();
2511     } else {
2512       auto const_val = rewriter.create<ConstOp>(
2513           op.getLoc(),
2514           DenseElementsAttr::get<float>(
2515               RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
2516               0.0));
2517       auto maybe_cast = [&](Value val, Type t) -> Value {
2518         if (val.getType() == t) return val;
2519         return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
2520       };
2521       last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
2522       last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
2523     }
2524     rewriter.replaceOp(
2525         op, {/*x_backprop=*/x_backprop,
2526              /*scale_backprop=*/scale_backprop,
2527              /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
2528     return success();
2529   }
2530 };
2531 
2532 using ConvertFusedBatchNormGradOp =
2533     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
2534 using ConvertFusedBatchNormGradV2Op =
2535     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
2536 using ConvertFusedBatchNormGradV3Op =
2537     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
2538 
2539 // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
2540 // HLO BatchNormInferenceOp, depending on the value of the 'is_training'
2541 // parameter.
2542 template <typename FusedBatchNormOpT>
2543 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
2544  public:
2545   using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
2546 
2547   LogicalResult matchAndRewrite(FusedBatchNormOpT op,
2548                                 PatternRewriter &rewriter) const override {
2549     tensorflow::TensorFormat data_format;
2550     if (!FormatFromString(op.data_format().str(), &data_format))
2551       return op.emitOpError("invalid data format");
2552 
2553     auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.x());
2554 
2555     auto input_type_tensor = op.x().getType().template cast<TensorType>();
2556     auto input_element_type = input_type_tensor.getElementType();
2557 
2558     auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
2559     auto scale_element_type = scale_type_tensor.getElementType();
2560 
2561     auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
2562     auto mean_element_type = mean_type_tensor.getElementType();
2563     // In the training case, dimensions of input tensors must be static.
2564     if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
2565                              !scale_type_tensor.hasStaticShape() ||
2566                              !mean_type_tensor.hasStaticShape()))
2567       return failure();
2568 
2569     // TODO(b/69928690): Support mixed precision in the XLA batch
2570     // normalization operators. As a workaround, create a new x with the same
2571     // element type as scale (which may be more precise than the input type).
2572     Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
2573                                                             scale_element_type);
2574     TensorType bn_train_input_type_tensor =
2575         bn_train_input.getType().template cast<TensorType>();
2576 
2577     if (op.is_training()) {
2578       // Training case.
2579       auto operand_shape = bn_train_input_type_tensor.getShape();
2580       // The mean and variance are each 1 dimensional arrays the size of the
2581       // feature dimension, with the same element type as the operand (x).
2582       // This shape must be constructed manually because the mean and variance
2583       // inputs are empty in the training case.
2584       Type mean_var_type = RankedTensorType::get(
2585           {operand_shape[feature_dim.getInt()]}, scale_element_type);
2586       // Op result type is a tuple of 3 values: output with the same shape as
2587       // the input, batch_mean, and batch_var.
2588       SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
2589                                             mean_var_type, mean_var_type};
2590       Type result_type = TupleType::get(rewriter.getContext(), operand_types);
2591 
2592       auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
2593           op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
2594           op.epsilon(), feature_dim.getInt());
2595       // HLO op outputs a tuple of tensors. Extract those results.
2596       auto bn_train_op_result = bn_train_op.getResult();
2597       Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
2598           op.getLoc(), bn_train_op_result, 0);
2599       Value batch_mean = rewriter.create<mhlo::GetTupleElementOp>(
2600           op.getLoc(), bn_train_op_result, 1);
2601       Value reserve_space_1 = batch_mean;
2602       Value batch_variance = rewriter.create<mhlo::GetTupleElementOp>(
2603           op.getLoc(), bn_train_op_result, 2);
2604 
2605       // Apply Bessel's correction on the variance.
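      // I.e., corrected_variance = batch_variance * N / max(N - 1, 1), where N
      // is the number of input elements contributing to each per-feature
      // statistic (total input elements / number of features).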
2606       int total_input_size = bn_train_input_type_tensor.getNumElements();
2607       int total_scale_size = scale_type_tensor.getNumElements();
2608       int sample_size = total_input_size / total_scale_size;
2609       int sample_size_minus_one = std::max(1, sample_size - 1);
2610       double factor = static_cast<double>(sample_size) /
2611                       static_cast<double>(sample_size_minus_one);
2612       auto factor_const_op = rewriter.create<mhlo::ConstOp>(
2613           op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2614 
2615       Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2616           op.getLoc(), batch_variance.getType(), batch_variance,
2617           factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2618 
2619       // Convert back to input type to stay aligned with expected output type
2620       // for TF op.
2621       y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2622                                                input_element_type);
2623 
2624       float exponential_avg_factor =
2625           op.exponential_avg_factor().convertToFloat();
2626       if (exponential_avg_factor != 1.0f) {
2627         auto alpha = rewriter.create<mhlo::ConstOp>(
2628             op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2629                                                1.0f - exponential_avg_factor));
2630         auto beta = rewriter.create<mhlo::ConstOp>(
2631             op.getLoc(),
2632             rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2633 
2634         // new_running_mean = alpha * old_mean + beta * batch_mean.
2635         auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2636             op.getLoc(), op.mean().getType(), alpha, op.mean(),
2637             /*broadcast_dimensions=*/DenseIntElementsAttr());
2638         auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2639             op.getLoc(), batch_mean.getType(), beta, batch_mean,
2640             /*broadcast_dimensions=*/DenseIntElementsAttr());
2641         batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2642             op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2643             /*broadcast_dimensions=*/DenseIntElementsAttr());
2644 
2645         // new_running_variance = alpha * old_variance + beta * batch_variance.
2646         auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2647             op.getLoc(), op.variance().getType(), alpha, op.variance(),
2648             /*broadcast_dimensions=*/DenseIntElementsAttr());
2649         auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2650             op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2651             /*broadcast_dimensions=*/DenseIntElementsAttr());
2652         corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2653             op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2654             /*broadcast_dimensions=*/DenseIntElementsAttr());
2655       }
2656 
2657       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2658         // FusedBatchNormV2 expects 5 outputs.
2659         // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2660         // They are used to pass the per-batch mean and variance to the
2661         // gradient. Here we maintain the same behavior by setting them to the
2662         // mean and variance calculated by BatchNormTraining.
2663         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2664                                 /*batch_variance=*/corrected_variance,
2665                                 /*reserve_space_1=*/reserve_space_1,
2666                                 /*reserve_space_2=*/batch_variance});
2667       } else {  // TF::FusedBatchNormV3Op
2668         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2669         // last reserve_space_3 output.
2670         auto reserve_space_3_type =
2671             op.getResult(5).getType().template cast<TensorType>();
2672         int num_elements = reserve_space_3_type.hasStaticShape()
2673                                ? reserve_space_3_type.getNumElements()
2674                                : 0;
2675         auto const_attr_type = RankedTensorType::get(
2676             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2677         Value dummy_const = rewriter.create<ConstOp>(
2678             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2679         if (const_attr_type != reserve_space_3_type)
2680           dummy_const = rewriter.create<tensor::CastOp>(
2681               op.getLoc(), reserve_space_3_type, dummy_const);
2682         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2683                                 /*batch_variance=*/corrected_variance,
2684                                 /*reserve_space_1=*/reserve_space_1,
2685                                 /*reserve_space_2=*/batch_variance,
2686                                 /*reserve_space_3=*/dummy_const});
2687       }
2688     } else {  // Inference case.
2689       auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2690           op.getLoc(),
2691           /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2692           op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2693           feature_dim.getInt());
2694 
2695       // Convert back to input type to stay aligned with expected output type
2696       // for TF op.
2697       auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2698                                                     input_element_type);
2699 
2700       // The mean, variance, and reserved space outputs of the batch norm op are
2701       // not used for inference. It doesn't matter what values we provide for
2702       // the last 5 results as long as they are of the same type. Forward
2703       // input mean and variance to output mean, variance, reserved_space_1 and
2704       // reserved_space_2.
2705       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2706         rewriter.replaceOp(op, {/*y=*/y_out,
2707                                 /*batch_mean=*/op.mean(),
2708                                 /*batch_variance=*/op.variance(),
2709                                 /*reserve_space_1=*/op.mean(),
2710                                 /*reserve_space_2=*/op.variance()});
2711       } else {
2712         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2713         // last reserve_space_3 output.
2714         auto reserve_space_3_type =
2715             op.getResult(5).getType().template cast<TensorType>();
2716         int num_elements = reserve_space_3_type.hasStaticShape()
2717                                ? reserve_space_3_type.getNumElements()
2718                                : 0;
2719         auto const_attr_type = RankedTensorType::get(
2720             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2721         Value dummy_const = rewriter.create<ConstOp>(
2722             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2723         if (const_attr_type != reserve_space_3_type)
2724           dummy_const = rewriter.create<tensor::CastOp>(
2725               op.getLoc(), reserve_space_3_type, dummy_const);
2726         rewriter.replaceOp(op, {/*y=*/y_out,
2727                                 /*batch_mean=*/op.mean(),
2728                                 /*batch_variance=*/op.variance(),
2729                                 /*reserve_space_1=*/op.mean(),
2730                                 /*reserve_space_2=*/op.variance(),
2731                                 /*reserve_space_3=*/dummy_const});
2732       }
2733     }
2734     return success();
2735   }
2736 };
2737 
2738 using ConvertFusedBatchNormV2Op =
2739     ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2740 using ConvertFusedBatchNormV3Op =
2741     ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2742 
2743 using PaddingArray =
2744     std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
2745 
2746 // Returns padding values for ReduceWindow op as a vector of pairs.
2747 //
2748 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2749 // dimensions to be equal to the size of window dimensions and window strides.
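// For illustration (hypothetical values): an NHWC input of shape [1, 4, 4, 1]
// with window [1, 3, 3, 1] and strides [1, 2, 2, 1] under SAME padding yields
// {(0, 0), (0, 1), (0, 1), (0, 0)}, i.e. one extra row/column of padding at
// the high end of each spatial dimension.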
2750 template <int num_dims>
2751 static PaddingArray GetReduceWindowPaddingAsArray(
2752     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2753     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2754   if (padding == "VALID") {
2755     return PaddingArray(num_dims, std::make_pair(0, 0));
2756   }
2757   assert(padding == "SAME");
2758   llvm::SmallVector<tensorflow::int64, num_dims> input_shape, window_shape,
2759       strides;
2760   input_shape.reserve(input_dims.size());
2761   window_shape.reserve(window_dims.size());
2762   strides.reserve(window_strides.size());
2763 
2764   for (const auto &dim : input_dims) input_shape.push_back(dim);
2765   for (Attribute attr : window_dims)
2766     window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2767   for (Attribute attr : window_strides)
2768     strides.push_back(attr.cast<IntegerAttr>().getInt());
2769 
2770   PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2771                                              ::xla::Padding::kSame);
2772   return paddings;
2773 }
2774 
2775 // Same as GetReduceWindowPaddingAsArray but returns padding as
2776 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
2777 template <int num_dims>
2778 static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2779     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2780     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2781   if (padding == "VALID") return {};
2782   assert(padding == "SAME");
2783   PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2784       input_dims, window_dims, window_strides, padding, builder);
2785   int64_t rank = paddings.size();
2786   llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2787   for (int i = 0; i < rank; i++) {
2788     flatten_paddings[2 * i] = paddings[i].first;
2789     flatten_paddings[2 * i + 1] = paddings[i].second;
2790   }
2791   return DenseIntElementsAttr::get(
2792       RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2793       flatten_paddings);
2794 }
2795 
2796 // Helper function for dividing each entry of `pooled` by the count of its
2797 // corresponding window, i.e., the number of non-padding entries of the window
2798 // which an `AvgPool` operation performed on an `input_shape`-tensor would map
2799 // to this entry, depending on `ksize` and `strides`. This function is used for
2800 // `AvgPool` and `AvgPoolGrad` legalizations.
2801 // `zero` is passed as a parameter because it can be reused from caller level.
2802 // `pooled` must have `RankedTensorType`.
2803 template <typename OpTy, int num_dims>
2804 Operation *AvgPoolDivideByCount(
2805     Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2806     const SmallVector<int64_t, num_dims> &ksize,
2807     const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2808     PatternRewriter &rewriter) {
2809   Location loc = op.getLoc();
2810   RankedTensorType pooled_type =
2811       pooled.getType().template cast<RankedTensorType>();
2812   Type element_type = pooled_type.getElementType();
2813   Operation *result = nullptr;
2814   RankedTensorType orig_input_type =
2815       RankedTensorType::get(input_shape, element_type);
2816 
2817   if (op.padding() == "VALID") {
2818     // All window counts are equal here because we don't have padding
2819     // (each entry of `pooled` corresponds to a window that consists of
2820     //  original input entries only).
2821     int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2822                                            std::multiplies<int64_t>());
2823     // Divide `pooled` by window counts.
2824     Value divisor =
2825         GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2826     auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2827     result = rewriter.create<chlo::BroadcastDivOp>(
2828         loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2829   } else {
2830     assert(op.padding() == "SAME");
2831     // For SAME padding, only original entries that contributed to a window
2832     // are counted for the average of this window, not padded entries.
2833 
2834     // Build all-ones tensor of same shape as the original input.
2835     ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2836     auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
2837 
2838     // Get padding for the input.
2839     DenseIntElementsAttr input_padding_attr =
2840         GetReduceWindowPaddingAsAttr<num_dims>(
2841             input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2842 
2843     // Count the 1's in each window, using the same padding as for the input,
2844     // which gives us the window counts by which `pooled` needs to be divided.
2845     auto divisor = rewriter.create<ReduceWindowOp>(
2846         loc, pooled_type,
2847         /*operand=*/all_ones_tensor,
2848         /*init_value=*/zero,
2849         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2850         /*window_strides=*/GetI64ElementsAttr(op.strides()),
2851         /*base_dilations=*/DenseIntElementsAttr(),
2852         /*window_dilations=*/DenseIntElementsAttr(),
2853         /*padding=*/input_padding_attr);
2854     BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2855 
2856     // Divide `pooled` by window counts.
2857     result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled,
2858                                           divisor.getResult(0));
2859   }
2860   return result;
2861 }
2862 
2863 Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
2864 Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2865 
2866 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2867 // dimensions with add as the reduction function. The reduction result is
2868 // then divided by the number of elements in the window.
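// Conceptually (1-D sketch, VALID padding): averaging [1, 2, 3, 4] with
// window 2 and stride 2 first reduces with add to [3, 7] and then divides by
// the window count 2, giving [1.5, 3.5].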
2869 template <typename OpTy, int num_dims>
2870 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2871  public:
2872   using OpRewritePattern<OpTy>::OpRewritePattern;
2873 
2874   LogicalResult matchAndRewrite(OpTy op,
2875                                 PatternRewriter &rewriter) const override {
2876     Value input_value = GetAvgPoolInput(op);
2877     auto input_type =
2878         input_value.getType().template dyn_cast<RankedTensorType>();
2879     if (!input_type) return failure();
2880 
2881     // We will do accumulation first; use a larger bitwidth if suitable.
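    // (For example, f16 inputs are typically accumulated in a wider float type
    // and converted back afterwards; the exact promotion is whatever
    // GetSumAccumulationType returns.)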
2882     Type input_element_type = input_type.getElementType();
2883     Type sum_element_type = GetSumAccumulationType(input_element_type);
2884     Type result_type;
2885 
2886     // The result type for reduction and division with the proper element type.
2887     if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2888       result_type =
2889           RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2890     else
2891       result_type = UnrankedTensorType::get(sum_element_type);
2892 
2893     // Convert if we need to enlarge the element type's bitwidth.
2894     if (input_element_type != sum_element_type)
2895       input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2896                                                sum_element_type);
2897 
2898     // Create the ReduceWindow op.
2899     Value init =
2900         GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2901     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2902         input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2903         &rewriter);
2904     auto reduce = rewriter.create<ReduceWindowOp>(
2905         op.getLoc(), result_type, input_value, init,
2906         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2907         /*base_dilations=*/DenseIntElementsAttr(),
2908         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2909     BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2910 
2911     // Count the number of elements in the window. The following calculation
2912     // is only valid when there is no padding.
2913     SmallVector<int64_t, num_dims> input_shape(
2914         llvm::to_vector<num_dims>(input_type.getShape()));
2915     SmallVector<int64_t, num_dims> ksize, strides;
2916     GetI64ArrayAttrValues(op.ksize(), &ksize);
2917     GetI64ArrayAttrValues(op.strides(), &strides);
2918 
2919     Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2920         reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter);
2921 
2922     // Convert back if we enlarged the element type's bitwidth.
2923     Value result = result_op->getOpResult(0);
2924     if (input_element_type != sum_element_type)
2925       result =
2926           rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2927 
2928     rewriter.replaceOp(op, result);
2929     return success();
2930   }
2931 };
2932 
2933 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2934 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2935 
2936 // `AvgPoolGradOp` is converted to the following operations:
2937 // 1. Divide each entry of the output gradient (the gradient for the previous
2938 //    layer in backpropagation order) by the count of the corresponding window
2939 //    (i.e., the number of non-padding entries of the window which `AvgPool`
2940 //    has mapped to this entry in forward propagation).
2941 // 2. Add appropriate interior and exterior padding for step 3 (see example
2942 //    below).
2943 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2944 //    as windows) and stride 1 in each dimension. This is implemented as a
2945 //    `ReduceWindowOp` with `AddOp` as body.
2946 //
2947 // Example:
2948 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2949 // and SAME padding with 0's. It is defined by
2950 //    f(x) = [ (x_1 + x_2 + x_3) / 3 ]      ( x = (x_1, x_2, x_3, x_4) )
2951 //           [ (x_3 + x_4 + 0)   / 2 ]      (the 0 results from right padding)
2952 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2953 // for the average, which is why the second denominator is 2 and not 3.
2954 // The Jacobian Df is
2955 //    [ 1/3  1/3  1/3  0   ]
2956 //    [ 0    0    1/2  1/2 ]
2957 //
2958 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2959 // needs the original input shape and not the tensor as argument).
2960 // Let v = [ 4  6 ]^T  be the output gradient (^T = transposed). Then the
2961 // average pool gradient is given by
2962 //    Df^T * v = [ 4/3  4/3  13/3  3 ]^T
2963 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2964 // structure of Df by using the 3-step approach from above:
2965 // 1. Divide output gradient v by window counts: [ 4/3  6/2 ]^T
2966 // 2. Add appropriate padding: [ 0  0  4/3  0  3  0 ]^T
2967 // 3. Convolve with kernel [ 1  1  1 ]: [ 4/3  4/3  13/3  3 ]^T
2968 //
2969 // Note that the padding in step 2. is chosen in such a way that the subsequent
2970 // convolution produces the gradient. Higher dimensions, different padding, and
2971 // different windows/strides work in a similar way, the main difference is in
2972 // the computation of the paddings in step 2.
2973 //
2974 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2975 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2976 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2977 // examples for different cases.
2978 template <typename OpTy, int num_dims>
2979 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2980   using DimVector = SmallVector<int64_t, num_dims>;
2981 
2982  public:
2983   using OpRewritePattern<OpTy>::OpRewritePattern;
2984 
2985   LogicalResult matchAndRewrite(OpTy op,
2986                                 PatternRewriter &rewriter) const override {
2987     Location loc = op.getLoc();
2988     tensorflow::TensorFormat data_format;
2989     if (!FormatFromString(op.data_format().str(), &data_format)) {
2990       return op.emitOpError("invalid data format");
2991     }
2992     // `out_grad` is the gradient that was propagated via backpropagation from
2993     // the output layer.
2994     Value out_grad = op.grad();
2995     auto out_grad_type =
2996         out_grad.getType().template dyn_cast<RankedTensorType>();
2997     if (!out_grad_type) {
2998       return failure();
2999     }
3000     Type element_type = out_grad_type.getElementType();
3001     DenseIntElementsAttr orig_input_shape_attr;
3002     if (!matchPattern(op.orig_input_shape(),
3003                       m_Constant(&orig_input_shape_attr))) {
3004       return failure();
3005     }
3006     auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
3007     DimVector orig_input_shape(orig_input_shape_values.begin(),
3008                                orig_input_shape_values.end());
3009     DimVector ksize, strides;
3010     GetI64ArrayAttrValues(op.ksize(), &ksize);
3011     GetI64ArrayAttrValues(op.strides(), &strides);
3012     Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
3013 
3014     auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
3015         out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
3016 
3017     // Get same padding as for original input.
3018     PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
3019         orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
3020 
3021     // Add padding around `out_grad_divided` values in such a way that the
3022     // subsequent `ReduceWindowOp` produces the gradient.
3023     DimVector out_grad_shape(
3024         llvm::to_vector<num_dims>(out_grad_type.getShape()));
3025     DimVector low_padding(num_dims, 0);
3026     DimVector high_padding(num_dims, 0);
3027     DimVector interior_padding(num_dims, 0);
3028     constexpr int num_spatial_dims = num_dims - 2;
3029     for (int i = 0; i < num_spatial_dims; ++i) {
3030       int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
3031       int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
3032                                            orig_padding[dim].first +
3033                                            orig_padding[dim].second;
3034       // Set interior padding such that neighboring entries from
3035       // `out_grad_divided` have distance `strides[dim]` from each other in
3036       // every dimension.
3037       interior_padding[dim] = strides[dim] - 1;
3038       // Set exterior padding in the same way as for convolution gradient
3039       // computation.
3040       auto status = ::xla::ConvGradExtractAndVerifyDimension(
3041           /*input_size=*/orig_input_shape_padded_in_dim,
3042           /*filter_size=*/ksize[dim],
3043           /*output_size=*/out_grad_shape[dim],
3044           /*dilation=*/1,
3045           /*stride=*/strides[dim],
3046           /*padding=*/::xla::Padding::kValid);
3047       if (!status.ok()) {
3048         return failure();
3049       }
3050       ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
3051           status.ValueOrDie();
3052       // Subtract the original exterior padding since it doesn't contribute to
3053       // the gradient. Note that we save one `PadOp` and some unnecessary kernel
3054       // computations, compared to the `xla::AvgPoolGrad` implementation, by
3055       // subtracting the original exterior padding before `ReduceWindowOp`
3056       // instead of trimming the result of `ReduceWindowOp` (the final result is
3057       // the same because all strides are 1).
3058       low_padding[dim] =
3059           conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
3060       high_padding[dim] =
3061           conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
3062 
3063       // Update `out_grad_shape` to result shape of following `PadOp`.
3064       out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
3065                             (out_grad_shape[dim] - 1) * strides[dim] + 1;
3066     }
3067     Value reduce_window_input = rewriter.create<PadOp>(
3068         loc, RankedTensorType::get(out_grad_shape, element_type),
3069         /*operand=*/out_grad_divided->getOpResult(0),
3070         /*padding_value=*/zero,
3071         /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
3072         /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
3073         /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
3074 
3075     // Compute result by convolving `reduce_window_input` with an all-ones
3076     // kernel, using `ReduceWindowOp` with `AddOp` body.
3077 
3078     Type sum_element_type = GetSumAccumulationType(element_type);
3079     if (element_type != sum_element_type) {
3080       // Convert to appropriate sum accumulation type to avoid precision loss.
3081       reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
3082                                                        sum_element_type);
3083       zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
3084     }
3085     auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
3086     auto reduce_window_op = rewriter.create<ReduceWindowOp>(
3087         loc, RankedTensorType::get(orig_input_shape, sum_element_type),
3088         /*operand=*/reduce_window_input,
3089         /*init_value=*/zero,
3090         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
3091         /*window_strides=*/ones,
3092         /*base_dilations=*/DenseIntElementsAttr(),
3093         /*window_dilations=*/DenseIntElementsAttr(),
3094         /*padding=*/DenseIntElementsAttr());
3095     BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
3096                            &rewriter);
3097     Value result = reduce_window_op.getResult(0);
3098 
3099     if (element_type != sum_element_type) {
3100       // Convert back to original element type.
3101       result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
3102     }
3103     rewriter.replaceOp(op, {result});
3104     return success();
3105   }
3106 };
3107 
3108 using ConvertAvgPool2DGradOp =
3109     ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
3110 using ConvertAvgPool3DGradOp =
3111     ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
3112 
3113 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
3114 // dimensions with max as the reduction function.
3115 //
3116 // Sample result for VALID padding mode:
3117 //
3118 //   %init = constant dense<...> : tensor<i32>
3119 //   %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
3120 //               {window_dimensions = ..., window_strides = ... }
3121 //
3122 template <typename OpTy, int num_dims>
3123 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
3124  public:
3125   using OpRewritePattern<OpTy>::OpRewritePattern;
3126 
3127   LogicalResult matchAndRewrite(OpTy op,
3128                                 PatternRewriter &rewriter) const override {
3129     Type element_type =
3130         op.input().getType().template cast<TensorType>().getElementType();
3131     if (!element_type.isSignlessIntOrFloat()) return failure();
3132     tensorflow::Padding padding;
3133     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
3134       return failure();
3135     if (padding == tensorflow::Padding::EXPLICIT) {
3136       return failure();
3137     }
3138     Location loc = op.getLoc();
3139     ConstOp init = GetScalarLimitConstOfType(element_type, loc,
3140                                              hlo::kInfinityLowest, &rewriter);
3141 
3142     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
3143     if (!input_ty) return failure();
3144     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
3145         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
3146     auto reduce = rewriter.create<ReduceWindowOp>(
3147         loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
3148         GetI64ElementsAttr(op.strides()),
3149         /*base_dilations=*/DenseIntElementsAttr(),
3150         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
3151     BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
3152 
3153     rewriter.replaceOp(op, reduce.getResult(0));
3154     return success();
3155   }
3156 };
3157 
3158 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
3159 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
3160 
3161 // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on
3162 // the condition only.
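// E.g. a rank-1 condition of shape [B] selecting between tensors of shape
// [B, ...] is broadcast along the remaining dimensions before the select.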
3163 class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {
3164  public:
3165   using OpRewritePattern::OpRewritePattern;
3166 
3167   LogicalResult matchAndRewrite(TF::SelectOp op,
3168                                 PatternRewriter &rewriter) const override {
3169     // This lowering only works on ranked types.
3170     auto cond_type = op.condition().getType().dyn_cast<RankedTensorType>();
3171     auto then_type = op.t().getType().dyn_cast<RankedTensorType>();
3172     auto else_type = op.e().getType().dyn_cast<RankedTensorType>();
3173     if (!cond_type || !then_type || !else_type) {
3174       return failure();
3175     }
3176 
3177     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3178     Value cond_shape = b.createOrFold<shape::ShapeOfOp>(op.condition());
3179     Value then_shape = b.createOrFold<shape::ShapeOfOp>(op.t());
3180     Value else_shape = b.createOrFold<shape::ShapeOfOp>(op.e());
3181 
3182     // First check that the `then` and `else` shapes are equal.
3183     Value assumption =
3184         b.createOrFold<shape::CstrEqOp>(ValueRange{then_shape, else_shape});
3185     // For a vector cond we also verify that the majormost dim of `then` matches
3186     // the vector size. To do that, split off the first dim of `then`.
3187     bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1;
3188     Value then_shape_split = then_shape;
3189     if (needs_broadcast) {
3190       Value const_one = b.create<ConstantIndexOp>(1);
3191       Type extents = shape::getExtentTensorType(b.getContext());
3192       SmallVector<Value, 2> then_split;
3193       b.createOrFold<shape::SplitAtOp>(then_split, TypeRange{extents, extents},
3194                                        then_shape, const_one);
3195       then_shape_split = then_split[0];
3196     }
3197     // If the condition is not a scalar, check that it matches the other shapes.
3198     if (cond_type.getRank() > 0) {
3199       Value eq_cstr = b.createOrFold<shape::CstrEqOp>(
3200           ValueRange{cond_shape, then_shape_split});
3201       auto witness = shape::WitnessType::get(b.getContext());
3202       assumption = b.createOrFold<shape::AssumingAllOp>(
3203           witness, ValueRange{assumption, eq_cstr});
3204     }
3205     auto result_type = op.getResult().getType().cast<TensorType>();
3206     auto assuming_op =
3207         b.create<shape::AssumingOp>(ArrayRef<Type>{result_type}, assumption);
3208 
3209     OpBuilder::InsertionGuard guard(b);
3210     b.createBlock(&assuming_op.doRegion());
3211 
3212     // Broadcast the cond if necessary.
3213     Value cond = op.condition();
3214     if (needs_broadcast) {
3215       Value result_extents = b.create<shape::ToExtentTensorOp>(
3216           GetExtentsTensorTypeFor(result_type), then_shape);
3217       cond = b.create<mhlo::DynamicBroadcastInDimOp>(
3218           RankedTensorType::get(result_type.getShape(), b.getI1Type()), cond,
3219           result_extents, GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b));
3220     }
3221     Value select = b.create<mhlo::SelectOp>(result_type, cond, op.t(), op.e());
3222     b.create<shape::AssumingYieldOp>(select);
3223     rewriter.replaceOp(op, {assuming_op.getResult(0)});
3224     return success();
3225   }
3226 };
3227 
3228 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
3229 //
3230 //     sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
3231 //
3232 // Sample result with 2-d f16 inputs with B batches of N elements each.
3233 //
3234 //    // Create an array of 0.5 the shape of the input array.
3235 //    %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
3236 //    %half_array = "mhlo.broadcast"(half)
3237 //                           {broadcast_sizes = dense<2> : tensor<1xi64>}
3238 //                           : (tensor<f32>) -> tensor<2xf32>
3239 //
3240 //    // Compute Tanh of half the logits of the values.
3241 //    %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
3242 //    %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
3243 //
3244 //    // Halve the result of Tanh and add 0.5.
3245 //    %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
3246 //    %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
3247 //
3248 class ConvertSigmoidOp : public RewritePattern {
3249  public:
3250   explicit ConvertSigmoidOp(MLIRContext *context)
3251       : RewritePattern(
3252             TF::SigmoidOp::getOperationName(), 0, context,
3253             {mhlo::ConstOp::getOperationName(),
3254              shape::ShapeOfOp::getOperationName(),
3255              shape::ToExtentTensorOp::getOperationName(),
3256              mhlo::DynamicBroadcastInDimOp::getOperationName(),
3257              mhlo::MulOp::getOperationName(), mhlo::TanhOp::getOperationName(),
3258              mhlo::AddOp::getOperationName()}) {}
3259 
3260   LogicalResult matchAndRewrite(Operation *sigmoid_op,
3261                                 PatternRewriter &rewriter) const override {
3262     auto op = cast<TF::SigmoidOp>(sigmoid_op);
3263     Location loc = op.getLoc();
3264 
3265     // Create constant half with shape and element type same as the operand.
3266     Value operand = op.getOperand();
3267     auto operand_ty = operand.getType().cast<TensorType>();
3268     auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
3269     ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
3270     auto scalar_half = rewriter.create<ConstOp>(loc, attr);
3271     auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
3272 
3273     auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
3274     auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
3275     auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
3276     auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
3277 
3278     rewriter.replaceOp(op, add_op.getResult());
3279     return success();
3280   }
3281 };
3282 
3283 // Converts the tf.Slice op into mhlo.real_dynamic_slice
3284 // TODO(disc): To recover static special case's performance with folding and
3285 // canonicalization.
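// For each dimension i the slice limit is computed as
//   end[i] = (size[i] == -1) ? dim_size(input, i) : begin[i] + size[i]
// with unit strides, following tf.Slice's convention that size == -1 means
// "all remaining elements in that dimension".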
3286 class ConvertSliceOpDynamic : public OpRewritePattern<TF::SliceOp> {
3287  public:
3288   using OpRewritePattern::OpRewritePattern;
3289 
3290   LogicalResult matchAndRewrite(TF::SliceOp op,
3291                                 PatternRewriter &rewriter) const override {
3292     Location loc = op.getLoc();
3293     Value input = op.input();
3294     Value begin_indices = op.begin();
3295     Value sizes = op.size();
3296 
3297     auto input_ty = input.getType().dyn_cast<RankedTensorType>();
3298     auto begin_type = begin_indices.getType().dyn_cast<RankedTensorType>();
3299     auto size_type = sizes.getType().dyn_cast<RankedTensorType>();
3300 
3301     if (!input_ty || !begin_type || !size_type ||
3302         !begin_type.hasStaticShape() || !size_type.hasStaticShape() ||
3303         begin_type.getRank() != 1 || size_type.getRank() != 1) {
3304       return failure();
3305     }
3306     // TODO(disc): remove static shape check once folding/canonicalization func
3307     // added
3308     DenseIntElementsAttr size_attr;
3309     if (matchPattern(op.size(), m_Constant(&size_attr))) {
3310       return failure();
3311     }
3312 
3313     int rank = begin_type.getDimSize(0);
3314     auto shape_scalar_type = begin_type.getElementType();
3315     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
3316     SmallVector<Value, 4> stride_values(rank, one);
3317     SmallVector<Value, 4> end_values;
3318     SmallVector<Value, 4> begin_values;
3319     end_values.reserve(rank);
3320     for (int i = 0; i < rank; ++i) {
3321       SmallVector<Value, 4> indices;
3322       indices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
3323       auto begin_value =
3324           rewriter.create<tensor::ExtractOp>(loc, begin_indices, indices);
3325       auto size_value = rewriter.create<tensor::ExtractOp>(loc, sizes, indices);
3326       Value minus_one = rewriter.create<IndexCastOp>(
3327           loc, rewriter.create<ConstantIndexOp>(loc, -1), shape_scalar_type);
3328       auto is_minus_one = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq,
3329                                                   size_value, minus_one);
3330       Value end_value = rewriter.create<AddIOp>(loc, begin_value, size_value);
3331       auto dim_value = rewriter.create<IndexCastOp>(
3332           loc, rewriter.create<tensor::DimOp>(loc, input, i),
3333           shape_scalar_type);
3334       end_value = rewriter.create<mlir::SelectOp>(loc, is_minus_one, dim_value,
3335                                                   end_value);
3336       auto end_value_casted =
3337           rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), end_value);
3338       end_values.push_back(end_value_casted);
3339 
3340       auto begin_value_casted = rewriter.create<IndexCastOp>(
3341           loc, rewriter.getIndexType(), begin_value);
3342       begin_values.push_back(begin_value_casted);
3343     }
3344 
3345     auto start_indices = rewriter.create<tensor::FromElementsOp>(
3346         loc, rewriter.getIndexType(), begin_values);
3347     auto end_indices = rewriter.create<tensor::FromElementsOp>(
3348         loc, rewriter.getIndexType(), end_values);
3349     auto stride_indices = rewriter.create<tensor::FromElementsOp>(
3350         loc, rewriter.getIndexType(), stride_values);
3351 
3352     auto d_slice = rewriter.create<mhlo::RealDynamicSliceOp>(
3353         loc, op.getOperation()->getResult(0).getType(), input, start_indices,
3354         end_indices, stride_indices);
3355     rewriter.replaceOp(op, d_slice.getOperation()->getResults());
3356     return success();
3357   }
3358 };
3359 
3360 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
3361                                            Value *out_lhs, Value *out_rhs,
3362                                            PatternRewriter *rewriter) {
3363   // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
3364   // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
3365   // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
3366   // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
3367   // To perform the matmul, we need to first broadcast lhs and rhs to a common
3368   // set of leading dimensions before doing the actual matmul.
3369   // That's what the code below does.
3370   // In particular, we populate out_lhs and out_rhs to have dimension structure:
3371   // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
3372   // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
3373   // To do this, we need to calculate those output shapes, which involves
3374   // slicing off the leading batch dims of each operand, broadcasting them,
3375   // then concatenating the broadcasted leading dims back to the row/col dims.
3376   // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
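  // As a concrete example, for lhs of shape [2, 1, 3, 4] and rhs of shape
  // [5, 4, 6], the broadcast batch shape is [2, 5], so out_lhs has shape
  // [2, 5, 3, 4], out_rhs has shape [2, 5, 4, 6], and the matmul result has
  // shape [2, 5, 3, 6].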
3377 
3378   // TODO(silvasean): Reduce duplication across reified shape calculations and
3379   // the static computation of output types needed to create ops.
3380   Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
3381   Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
3382   Value const_neg2 =
3383       rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
3384   auto shape_type = shape::ShapeType::get(rewriter->getContext());
3385   auto lhs_splitted = rewriter->create<shape::SplitAtOp>(
3386       loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2);
3387   auto rhs_splitted = rewriter->create<shape::SplitAtOp>(
3388       loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2);
3389   auto lhs_type = lhs.getType().cast<RankedTensorType>();
3390   auto rhs_type = rhs.getType().cast<RankedTensorType>();
3391   // The last two dimensions are the matrix row/col dimensions. Don't broadcast
3392   // them.
3393   SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
3394   mlir::OpTrait::util::getBroadcastedShape(
3395       lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2),
3396       result_batch_shape_compile_time_extents);
3397   auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
3398       loc, shape_type, lhs_splitted.head(), rhs_splitted.head(),
3399       /*error=*/nullptr);
3400   // Lambda which handles the broadcasting of one side to the common
3401   // leading-batch dimensions.
3402   auto broadcast_one_side = [&](Value side, RankedTensorType type,
3403                                 Value tail_shape, Value *out_side) {
3404     ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
3405     auto result_shape = result_batch_shape_compile_time_extents;
3406     result_shape.append(matrix_dims.begin(), matrix_dims.end());
3407     auto result_type =
3408         RankedTensorType::get(result_shape, type.getElementType());
3409     auto shape =
3410         rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
3411     auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
3412         loc,
3413         RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
3414                               rewriter->getIndexType()),
3415         shape);
3416     *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
3417                                                     shape_tensor);
3418   };
3419   broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
3420   broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
3421 }
3422 
3423 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
3424  public:
3425   // TODO(hinsu): Legalize this op to the Einsum op. The HLO Einsum op needs to
3426   // be moved to CHLO and it is missing legalization to MHLO. Once that is done,
3427   // this pattern's benefit can be changed back to one, and the fallback
3428   // lowering pattern for the op can be removed.
3429   //
3430   // Set benefit of this pattern to zero to prefer the fallback pattern when
3431   // available and applicable. That pattern avoids broadcast on operands and is
3432   // therefore faster.
3433   //
3434   // Native legalization for BatchMatMulV3 needs to be added as well.
3435   explicit ConvertBatchMatMulV2Op(MLIRContext *context)
3436       : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
3437 
3438   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
3439                                 PatternRewriter &rewriter) const override {
3440     Value lhs = op.x();
3441     Value rhs = op.y();
3442     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
3443     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
3444     if (!lhs_type || !rhs_type) return failure();
3445     if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
3446       lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
3447     }
3448     if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
3449       rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
3450     }
3451 
3452     // Broadcast both operands.
3453     BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
3454                                    &rewriter);
3455     lhs_type = lhs.getType().cast<RankedTensorType>();
3456     rhs_type = rhs.getType().cast<RankedTensorType>();
3457     assert(lhs_type.getRank() == rhs_type.getRank());
3458     int64_t rank = lhs_type.getRank();
3459     auto batch_dimensions = GetI64ElementsAttr(
3460         llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
3461     auto lhs_contracting_dimensions = GetI64ElementsAttr(
3462         llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
3463     auto rhs_contracting_dimensions = GetI64ElementsAttr(
3464         llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
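    // For example, with rank 4 and adj_x = adj_y = false, this yields batching
    // dimensions [0, 1], an lhs contracting dimension of 3, and an rhs
    // contracting dimension of 2.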
3465     auto dimension_numbers = DotDimensionNumbers::get(
3466         /*lhs_batching_dimensions=*/batch_dimensions,
3467         /*rhs_batching_dimensions=*/batch_dimensions,
3468         /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
3469         /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
3470         rewriter.getContext());
3471     // TODO(silvasean): Emit shape checks for contracting dimensions.
3472     // (The batch dimensions are checked by the broadcasting logic)
3473     rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
3474                                               dimension_numbers,
3475                                               /*precision_config=*/nullptr);
3476     return success();
3477   }
3478 };
3479 
3480 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
3481 // split has fully static shape and the dimension to split is a constant.
3482 //
3483 // The main logic of this pattern is to calculate the start and end index range
3484 // for each slice; this varies only along the dimension to be split. For all
3485 // other dimensions, each resultant slice's index range covers the input
3486 // tensor's full range. Strides for all resultant slices are one.
3487 //
3488 // For example, the following source IR:
3489 //
3490 //   %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3491 //   %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
3492 //                (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
3493 //
3494 // will be converted into:
3495 //
3496 //   %0 = "mhlo.slice"(%input) {
3497 //             limit_indices = dense<[4, 2]> : tensor<2xi64>,
3498 //             start_indices = dense<0> : tensor<2xi64>,
3499 //             strides = dense<1> : tensor<2xi64>} :
3500 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3501 //   %1 = "mhlo.slice"(%input) {
3502 //             limit_indices = dense<4> : tensor<2xi64>,
3503 //              start_indices = dense<[0, 2]> : tensor<2xi64>,
3504 //            strides = dense<1> : tensor<2xi64>} :
3505 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3506 //    %2 = "mhlo.slice"(%input) {
3507 //            limit_indices = dense<[4, 6]> : tensor<2xi64>,
3508 //            start_indices = dense<[0, 4]> : tensor<2xi64>,
3509 //             strides = dense<1> : tensor<2xi64>} :
3510 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3511 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
3512 // applicable.
3513 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
3514  public:
3515   using OpRewritePattern::OpRewritePattern;
3516 
3517   LogicalResult matchAndRewrite(TF::SplitOp op,
3518                                 PatternRewriter &rewriter) const override {
3519     // We can only split along static dimensions.
3520     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3521     if (!input_type) return failure();
3522 
3523     // We can only match when the split dimension is a constant scalar.
3524     DenseIntElementsAttr split_dim_attr;
3525     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3526       return failure();
3527 
3528     // Get the dimension we are splitting at. Offset properly if it's negative.
3529     int64_t input_rank = input_type.getRank();
3530     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3531     if (dim_index < 0) dim_index += input_rank;
3532 
3533     // Calculate the dimension size for each slice along the split dimension.
3534     int64_t input_dim_size = input_type.getDimSize(dim_index);
3535     // If we are splitting along the dynamic dimension then we cannot compute
3536     // the static dimension length.
3537     if (TensorType::isDynamic(input_dim_size)) return failure();
3538 
3539     int64_t num_splits = op.getNumResults();
3540     int64_t slice_size = input_dim_size / num_splits;
3541 
3542     // Get each slice's type.
3543     auto slice_shape = llvm::to_vector<4>(input_type.getShape());
3544     slice_shape[dim_index] = slice_size;
3545     Type slice_type =
3546         RankedTensorType::get(slice_shape, input_type.getElementType());
3547 
3548     // Parameters for constructing each slice.
3549     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3550     auto end_indices = llvm::to_vector<4>(input_type.getShape());
3551     SmallVector<int64_t, 4> strides(input_rank, 1);
3552 
3553     // All HLO slice results used to replace the original tf.Split op.
3554     SmallVector<Value, 4> slices;
3555     slices.reserve(num_splits);
3556 
3557     for (int i = 0; i < num_splits; ++i) {
3558       begin_indices[dim_index] = i * slice_size;
3559       end_indices[dim_index] = (i + 1) * slice_size;
3560       slices.push_back(
3561           rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
3562                                    GetI64ElementsAttr(begin_indices, &rewriter),
3563                                    GetI64ElementsAttr(end_indices, &rewriter),
3564                                    GetI64ElementsAttr(strides, &rewriter)));
3565     }
3566 
3567     rewriter.replaceOp(op, slices);
3568     return success();
3569   }
3570 };
3571 
3572 // Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops when
3573 // the dimension to split is a constant.
3574 // TODO(disc): Recover the static special case's performance with folding and
3575 // canonicalization, then delete ConvertSplitOp.
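// For illustration only (a sketch; exact IR details may differ), splitting a
// tensor<4x?xf32> into two results along dimension 1 computes
//   %dim = tensor.dim %input, 1  and  %slice_size = %dim / 2
// and then emits one "mhlo.real_dynamic_slice" per result, with begin indices
// [0, i * %slice_size], end indices [4, (i + 1) * %slice_size], and strides of
// all ones.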
3576 class ConvertSplitOpDynamic : public OpRewritePattern<TF::SplitOp> {
3577  public:
3578   using OpRewritePattern::OpRewritePattern;
3579 
3580   LogicalResult matchAndRewrite(TF::SplitOp op,
3581                                 PatternRewriter &rewriter) const override {
3582     Location loc = op.getLoc();
3583     Value input = op.value();
3584     auto input_type = input.getType().dyn_cast<RankedTensorType>();
3585     if (!input_type) return failure();
3586     // We can only match when the split dimension is a constant scalar.
3587     DenseIntElementsAttr split_dim_attr;
3588     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3589       return failure();
3590 
3591     // Get the dimension we are splitting at. Offset properly if it's negative.
3592     int64_t input_rank = input_type.getRank();
3593     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3594     if (dim_index < 0) dim_index += input_rank;
3595 
3596     // TODO(disc): Remove this static shape check once folding/canonicalization
3597     // is added and ConvertSplitOp is deleted. This pattern only handles
3598     // splitting along a dynamic dimension; the static case is handled by
3599     // ConvertSplitOp above.
3600     int64_t c_input_dim_size = input_type.getDimSize(dim_index);
3601     if (!TensorType::isDynamic(c_input_dim_size)) return failure();
3602 
3603     Value input_dim_size =
3604         rewriter.create<tensor::DimOp>(loc, input, dim_index);
3605     // Calculate the dimension size for each slice along the split dimension.
3606     int num_splits = op.getNumResults();
3607     Value num_splits_value =
3608         rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(num_splits));
3609     Value slice_size =
3610         rewriter.create<SignedDivIOp>(loc, input_dim_size, num_splits_value);
3611 
3612     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
3613     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
3614 
3615     SmallVector<Value, 4> begin_indices(input_rank, zero);
3616     SmallVector<Value, 4> end_indices;
3617     end_indices.reserve(input_rank);
3618     SmallVector<Value, 4> strides(input_rank, one);
3619     for (int i = 0; i < input_rank; ++i) {
3620       end_indices.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
3621     }
3622 
3623     // All HLO d_slice results used to replace the original tf.Split op.
3624     SmallVector<Value, 4> slices;
3625     slices.reserve(num_splits);
3626 
3627     for (int i = 0; i < num_splits; ++i) {
3628       begin_indices[dim_index] = rewriter.create<MulIOp>(
3629           loc, slice_size, rewriter.create<ConstantIndexOp>(loc, i));
3630       end_indices[dim_index] = rewriter.create<MulIOp>(
3631           loc, slice_size, rewriter.create<ConstantIndexOp>(loc, i + 1));
3632       auto begin_value = rewriter.create<tensor::FromElementsOp>(
3633           loc, rewriter.getIndexType(), begin_indices);
3634       auto end_value = rewriter.create<tensor::FromElementsOp>(
3635           loc, rewriter.getIndexType(), end_indices);
3636       auto stride_value = rewriter.create<tensor::FromElementsOp>(
3637           loc, rewriter.getIndexType(), strides);
3638       slices.push_back(rewriter.create<RealDynamicSliceOp>(
3639           loc, op.getOperation()->getResult(i).getType(), input, begin_value,
3640           end_value, stride_value));
3641     }
3642 
3643     rewriter.replaceOp(op, slices);
3644     return success();
3645   }
3646 };
3647 
3648 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
3649 // be split has fully static shape and the dimension to split and split sizes
3650 // are constants.
3651 //
3652 // This is similar to the conversion for the tf.Split op, except that the size
3653 // of each chunk along the dimension to split is explicitly given as an op
3654 // operand and the sizes are not necessarily the same.
3655 //
3656 // For example, given the following IR:
3657 //
3658 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
3659 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
3660 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
3661 //                   (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
3662 //                   (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
3663 //
3664 // We will generate the following slices:
3665 // %0 = "mhlo.slice"(%input) {
3666 //        limit_indices = dense<[4, 1]> : tensor<2xi64>,
3667 //        start_indices = dense<0> : tensor<2xi64>,
3668 //        strides = dense<1> : tensor<2xi64>} :
3669 //        (tensor<4x6xf32>) -> tensor<4x1xf32>
3670 // %1 = "mhlo.slice"(%input) {
3671 //        limit_indices = dense<[4, 3]> : tensor<2xi64>,
3672 //        start_indices = dense<[0, 1]> : tensor<2xi64>,
3673 //        strides = dense<1> : tensor<2xi64>} :
3674 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3675 // %2 = "mhlo.slice"(%input) {
3676 //        limit_indices = dense<[4, 6]> : tensor<2xi64>,
3677 //        start_indices = dense<[0, 3]> : tensor<2xi64>,
3678 //        strides = dense<1> : tensor<2xi64>} :
3679 //        (tensor<4x6xf32>) -> tensor<4x3xf32>
3680 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
3681  public:
3682   using OpRewritePattern::OpRewritePattern;
3683 
3684   LogicalResult matchAndRewrite(TF::SplitVOp op,
3685                                 PatternRewriter &rewriter) const override {
3686     // We can only split along static dimensions.
3687     // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3688     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3689     if (!input_type) return failure();
3690 
3691     // We can only match when the split dimension is a constant scalar.
3692     DenseIntElementsAttr split_dim_attr;
3693     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3694       return failure();
3695 
3696     // We can only match when the split sizes are a constant int vector.
3697     DenseIntElementsAttr split_sizes_attr;
3698     if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3699       return failure();
3700 
3701     // Get each chunk's size along the dimension to split. It may contain
3702     // dynamic sizes, and we need to update it if so.
3703     SmallVector<int64_t, 4> split_sizes;
3704     int64_t total_dim_size = 0;  // Total dimension size assigned to splits
3705     llvm::Optional<int> dynamic_dim_index;
3706     split_sizes.reserve(
3707         split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3708     for (auto dim : llvm::enumerate(split_sizes_attr)) {
3709       int64_t dim_val = dim.value().getSExtValue();
3710       split_sizes.push_back(dim_val);
3711       if (dim_val == ShapedType::kDynamicSize) {
3712         // We cannot have more than one dynamic dimension.
3713         assert(!dynamic_dim_index && "invalid split sizes");
3714         dynamic_dim_index = dim.index();
3715       } else {
3716         total_dim_size += dim_val;
3717       }
3718     }
3719 
3720     // Get the dimension we are splitting at. Offset properly if it's negative.
3721     int64_t input_rank = input_type.getRank();
3722     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3723     if (dim_index < 0) dim_index += input_rank;
3724 
3725     int64_t input_dim_size = input_type.getDimSize(dim_index);
3726     if (TensorType::isDynamic(input_dim_size)) return failure();
3727 
3728     assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3729             (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3730            "invalid split sizes");
3731 
3732     // Update the dynamic dimension with calculated concrete size.
3733     if (dynamic_dim_index)
3734       split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3735 
3736     // Parameters for constructing each slice.
3737     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3738     auto end_indices = llvm::to_vector<4>(input_type.getShape());
3739     SmallVector<int64_t, 4> strides(input_rank, 1);
3740 
3741     // All HLO slice results used to replace the original tf.Split op.
3742     SmallVector<Value, 4> slices;
3743     slices.reserve(op.getNumResults());
3744 
3745     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3746       end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3747       slices.push_back(rewriter.create<mhlo::SliceOp>(
3748           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3749           GetI64ElementsAttr(end_indices, &rewriter),
3750           GetI64ElementsAttr(strides, &rewriter)));
3751       // Prepare the begin index for the next slice.
3752       begin_indices[dim_index] = end_indices[dim_index];
3753     }
3754 
3755     rewriter.replaceOp(op, slices);
3756     return success();
3757   }
3758 };
3759 
3760 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3761 // negative strides and Reshape op to update the output shape. Indices and
3762 // strides operands are converted to attributes with non-negative indexing.
3763 //
3764 // If the begin input is not a compile time constant, the begin input needs to
3765 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3766 // case, strides must have a known value of 1 (otherwise we have insufficient
3767 // information to conform to XLA's op semantics).
3768 //
3769 // For example, with an op like the following,
3770 //   tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3771 //     : tensor<AxBxf32> -> tensor<Pxf32>
3772 //
3773 // If the %begin input is constant, output would be:
3774 //   %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3775 //   %sliced = "mhlo.Slice" (%input)
3776 //             {start_indices = ..., limit_indices = ..., strides = ...}
3777 //   %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
3778 //
3779 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3780  public:
3781   using OpRewritePattern::OpRewritePattern;
3782 
3783   LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3784                                          ArrayRef<int64_t> begin_indices,
3785                                          ArrayRef<int64_t> end_indices,
3786                                          ArrayRef<int64_t> strides,
3787                                          RankedTensorType input_ty,
3788                                          PatternRewriter &rewriter) const {
3789     SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3790         dims_to_reverse;
3791     int64_t input_rank = input_ty.getRank();
3792     ArrayRef<int64_t> input_shape = input_ty.getShape();
3793     hlo_begin_indices.reserve(input_rank);
3794     hlo_end_indices.reserve(input_rank);
3795     hlo_strides.reserve(input_rank);
3796 
3797     int64_t indices_elements = begin_indices.size();
3798     if (input_rank < indices_elements) return failure();
3799 
3800     // Convert from TensorFlow negative or out of range indices and strides
3801     // values to legal HLO Slice attributes.
3802     for (int i = 0, e = indices_elements; i != e; i++) {
3803       int64_t begin = begin_indices[i];
3804       int64_t end = end_indices[i];
3805       int64_t stride = strides[i];
3806 
3807       if (stride < 0) {
3808         // Negative stride means that the output values are computed starting
3809         // from end until begin. Mark the dimension for reversal before slice
3810         // and compute indices for the reversed input.
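        // For example, with input_shape[i] = 6, begin = 5, end = 1 and
        // stride = -2 (selecting elements 5 and 3), the dimension is reversed
        // and the slice becomes begin = 0, end = 4, stride = 2 on the reversed
        // input.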
3811         dims_to_reverse.push_back(i);
3812         begin = (input_shape[i] - 1) - begin;
3813         end = (input_shape[i] - 1) - end;
3814         stride = -stride;
3815       }
3816 
3817       // Unlike TensorFlow, HLO requires begin and end values to be within
3818       // range.
3819       begin = std::max(int64_t(0), begin);
3820       end = std::max(begin, end);
3821       end = std::min(end, input_shape[i]);
3822 
3823       hlo_begin_indices.push_back(begin);
3824       hlo_end_indices.push_back(end);
3825       hlo_strides.push_back(stride);
3826     }
3827 
3828     Location loc = op.getLoc();
3829     Value input = op.input();
3830     if (!dims_to_reverse.empty())
3831       input = rewriter.create<ReverseOp>(
3832           loc, input_ty, op.input(),
3833           GetI64ElementsAttr(dims_to_reverse, &rewriter));
3834     auto sliced = rewriter.create<SliceOp>(
3835         loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3836         GetI64ElementsAttr(hlo_end_indices, &rewriter),
3837         GetI64ElementsAttr(hlo_strides, &rewriter));
3838 
3839     // Reshape slice result so that the shape is updated depending on
3840     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3841     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3842     return success();
3843   }
3844 
3845   LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3846                                         RankedTensorType input_ty,
3847                                         RankedTensorType result_ty,
3848                                         PatternRewriter &rewriter) const {
3849     // If begin and end values are dynamic, we can only support this lowering
3850     // if strides are a known value of 1.
3851     DenseIntElementsAttr sparse_strides_attr;
3852     if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3853       return rewriter.notifyMatchFailure(
3854           op,
3855           "requires that strides are known when begin/end values are dynamic");
3856     }
3857     SmallVector<int64_t, 4> strides;
3858     int64_t stride_value;
3859     for (const APInt &stride : sparse_strides_attr) {
3860       if ((stride_value = stride.getSExtValue()) != 1) {
3861         return rewriter.notifyMatchFailure(op,
3862                                            "requires that strides are all 1 "
3863                                            "when begin/end values are dynamic");
3864       }
3865       strides.push_back(stride_value);
3866     }
3867 
3868     ArrayRef<int64_t> input_shape = input_ty.getShape();
3869     int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3870 
3871     // When begin/end values are dynamic, we can only support shrinking a major
3872     // axis. For instance, if there are 4 dims, we can support a
3873     // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3874     // others.
3875     bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3876     if (!shrink_axis_mask_ok)
3877       return rewriter.notifyMatchFailure(
3878           op,
3879           "requires that shrink_axis_mask, if set, refer to a major axis "
3880           "dimension (when begin/end values are dynamic)");
3881 
3882     // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3883     // to the last dimension.
3884     int ellipsis_mask = op.ellipsis_mask();
3885     if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3886       return rewriter.notifyMatchFailure(
3887           op,
3888           "requires that ellipsis_mask, if set, refer to the last dimension of "
3889           "input (when begin/end values are dynamic)");
3890 
3891     uint64_t begin_mask = op.begin_mask();
3892     if (begin_mask)
3893       return rewriter.notifyMatchFailure(
3894           op,
3895           "requires that begin_mask is either set to 0 or not set when "
3896           "begin/end values are dynamic");
3897     uint64_t end_mask = op.end_mask();
3898     if (end_mask)
3899       return rewriter.notifyMatchFailure(
3900           op,
3901           "requires that end_mask is either set to 0 or not set when begin/end "
3902           "values are dynamic");
3903     uint64_t new_axis_mask = op.new_axis_mask();
3904     if (new_axis_mask)
3905       return rewriter.notifyMatchFailure(
3906           op,
3907           "requires that new_axis_mask is either set to 0 or not set when "
3908           "begin/end values are dynamic");
3909 
3910     // In this case where the begin and end values are dynamic, the number of
3911     // output elements has to be equal to the number of input elements that
3912     // are sliced.
3913     int output_elements = result_ty.getNumElements();
3914     int input_elements_sliced = 1;
3915 
3916     // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3917     // verifier.
3918     int64_t slicing_dim_size =
3919         op.begin().getType().cast<RankedTensorType>().getShape()[0];
3920     const int input_rank = input_shape.size();
3921     for (int d = slicing_dim_size; d < input_rank; ++d) {
3922       // We only support slicing major dimensions, so minor dimensions after
3923       // slicing dimensions are all sliced with their full sizes.
3924       input_elements_sliced *= input_shape[d];
3925     }
3926     if (input_elements_sliced != output_elements) {
3927       return rewriter.notifyMatchFailure(
3928           op,
3929           "requires the number of output elements to be equal to the number of "
3930           "input elements sliced (when begin/end values are dynamic)");
3931     }
3932 
3933     SmallVector<Value, 4> slice_begin_indices;
3934     // For the dimensions that are to be sliced, all have slice sizes of 1.
3935     SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
3936     auto begin_element_ty =
3937         op.begin().getType().cast<ShapedType>().getElementType();
3938     // Scalar tensor type.
3939     TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3940     Location loc = op.getLoc();
3941     auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3942     for (int d = 0; d < slicing_dim_size; ++d) {
3943       auto index = rewriter.create<SliceOp>(
3944           loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3945           GetI64ElementsAttr({d + 1}, &rewriter),
3946           GetI64ElementsAttr({1}, &rewriter));
3947       // Convert index to scalar.
3948       auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3949       // If the index is negative, wrap it around with dimension size.
3950       auto index_negative =
3951           rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3952       auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3953                                             input_shape[d], &rewriter);
3954       auto wrapped_index =
3955           rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3956       auto final_index = rewriter.create<SelectOp>(
3957           loc, type, index_negative, wrapped_index, reshaped_index);
3958       slice_begin_indices.push_back(final_index);
3959     }
3960 
3961     // For non-slice dims, get the full slice of that dimension.
3962     for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
3963       slice_sizes.push_back(input_shape[d]);
3964       slice_begin_indices.push_back(zero);
3965     }
3966 
3967     auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3968     // This must be an xla DynamicSlice op due to the inputs that aren't
3969     // constant.
3970     auto sliced = rewriter.create<DynamicSliceOp>(
3971         loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
3972 
3973     // Reshape slice result so that the shape is updated depending on
3974     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3975     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3976     return success();
3977   }
3978 
3979   LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3980                                 PatternRewriter &rewriter) const override {
3981     // Input shape needs to be static to convert negative indices in TensorFlow
3982     // to absolute indices required by HLO.
3983     //
3984     // TODO(hinsu): Relax this constraint for ops without negative indices and
3985     // strides.
3986     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3987     if (!input_ty || !input_ty.hasStaticShape()) return failure();
3988 
3989     // Output shape needs to be static to apply 'new_axis_mask' or
3990     // 'shrink_axis_mask' by reshaping tensor after slice.
3991     //
3992     // TODO(hinsu): Relax this constraint for ops without the above masks.
3993     auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3994     if (!result_ty || !result_ty.hasStaticShape()) return failure();
3995 
3996     DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3997     if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3998         !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3999       return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
4000     }
4001 
4002     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
4003     if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
4004       return failure();
4005     }
4006     return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
4007                                     input_ty, rewriter);
4008   }
4009 };
4010 
4011 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
4012 //
4013 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
4014 // the reverse: it propagates the gradient of the sliced tensor back to the
4015 // original input tensor by padding with zeros. The main logic is calculating
4016 // the indices and strides for the padding.
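// As a small worked example: slicing an input of shape [6] with begin = 1,
// end = 5, stride = 2 produces 2 elements, so the incoming gradient of shape
// [2] is padded back to shape [6] with padding_low = 1, interior padding = 1,
// and padding_high = 2, placing the gradient values at indices 1 and 3 and
// zeros elsewhere.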
4017 class ConvertStridedSliceGradOp
4018     : public OpRewritePattern<TF::StridedSliceGradOp> {
4019  public:
4020   using OpRewritePattern::OpRewritePattern;
4021 
4022   LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
4023                                 PatternRewriter &rewriter) const override {
4024     // We need constant input shape to perform padding calculations later.
4025     DenseIntElementsAttr input_shape_attr;
4026     if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
4027       return failure();
4028 
4029     // We also need constant begin/end indices and strides to perform padding
4030     // calculations.
4031     // Bounded shape after performing strided slice
4032     SmallVector<int64_t, 4> shape;
4033     // Bounded begin, end, and strides for strided slice
4034     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
4035     if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
4036                                          &strides))
4037       return failure();
4038 
4039     Value grad = op.dy();
4040     Type element_type = grad.getType().cast<ShapedType>().getElementType();
4041 
4042     // Perform reshape to undo any new/shrink axes done by strided slice.
4043     grad = rewriter.create<mhlo::ReshapeOp>(
4044         op.getLoc(), RankedTensorType::get(shape, element_type), grad);
4045 
4046     SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
4047     SmallVector<int64_t, 4> dims_to_reverse;
4048     padding_low.reserve(shape.size());
4049     padding_high.reserve(shape.size());
4050     padding_interm.reserve(shape.size());
4051 
4052     // Prepare padding parameters for each dimension.
4053     for (int i = 0, e = shape.size(); i < e; ++i) {
4054       int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
4055       if (strides[i] > 0) {
4056         padding_low.push_back(begin_indices[i]);
4057         padding_interm.push_back(strides[i] - 1);
4058 
4059         // Pad the upper dimension up to the expected input shape. It's not
4060         // sufficient simply to use end_indices[i] to compute the padding in
4061         // cases where the stride does not divide evenly into the interval
4062         // between begin_indices[i] and end_indices[i].
4063         int64_t size =
4064             padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
4065         padding_high.push_back(input_dim - size);
4066       } else {
4067         dims_to_reverse.push_back(i);
4068         padding_high.push_back(input_dim - begin_indices[i] - 1);
4069         padding_interm.push_back(-strides[i] - 1);
4070 
4071         // Pad the lower dimension up to the expected input shape.
4072         int64_t size =
4073             padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
4074         padding_low.push_back(input_dim - size);
4075       }
4076     }
4077 
4078     if (!dims_to_reverse.empty()) {
4079       grad = rewriter.create<mhlo::ReverseOp>(
4080           op.getLoc(), grad.getType(), grad,
4081           GetI64ElementsAttr(dims_to_reverse, &rewriter));
4082     }
4083 
4084     auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
4085     rewriter.replaceOpWithNewOp<mhlo::PadOp>(
4086         op, op.getType(), grad, zero,
4087         GetI64ElementsAttr(padding_low, &rewriter),
4088         GetI64ElementsAttr(padding_high, &rewriter),
4089         GetI64ElementsAttr(padding_interm, &rewriter));
4090     return success();
4091   }
4092 };
4093 
4094 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
4095 /// offset applied to generate the range values. The output tensor needs to
4096 /// have a static shape.
4097 ///
4098 /// For example an op like the following:
4099 ///   %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
4100 ///      : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
4101 ///
4102 /// Output would be:
4103 ///   %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
4104 ///   %scaled = "mhlo.multiply"(%iota, %delta)
4105 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
4106 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
4107 ///   %result = "mhlo.add"(%scaled, %offset)
4108 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
4109 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
4110 ///
4111 /// Implementation is defined in C++ due to no type interface for the iota op.
4112 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
4113   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
4114 
4115   LogicalResult matchAndRewrite(TF::RangeOp op,
4116                                 PatternRewriter &rewriter) const override {
4117     auto result = op.getResult();
4118     auto result_type = result.getType();
4119     if (!result_type.cast<ShapedType>().hasStaticShape()) {
4120       return failure();
4121     }
4122 
4123     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4124                                         rewriter.getI64IntegerAttr(0));
4125     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4126         op.getLoc(), result_type, iota, op.delta(),
4127         hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
4128     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4129         op, result_type, scaled, op.start(),
4130         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4131     return success();
4132   }
4133 };
4134 
4135 // Converts RangeOp for cases where the length is a dynamic value. The shape of
4136 // the resulting tensor is computed, then the start and delta are used with the
4137 // dynamic_iota value to compute the final range value.
4138 //
4139 // For example, the resulting range op value:
4140 //   %range = "tf.range"(%start, %limit, %delta)
4141 //
4142 // Is converted to the following:
4143 //   %start + %delta * iota(ceil(abs((%limit - %start) / %delta)))
4144 //
4145 // Implementation is defined in C++ due to the complicated type behavior.
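// For example, %start = 3, %limit = 18 and %delta = 5 give a length of
// ceil(abs((18 - 3) / 5)) = 3 and the resulting values [3, 8, 13].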
4146 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
4147   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
4148 
4149   LogicalResult matchAndRewrite(TF::RangeOp op,
4150                                 PatternRewriter &rewriter) const override {
4151     auto result = op.getResult();
4152     auto result_type = result.getType().cast<ShapedType>();
4153     if (result_type.hasStaticShape()) {
4154       return failure();
4155     }
4156 
4157     Value start = op.start();
4158     Value delta = op.delta();
4159     Value limit = op.limit();
4160 
4161     // To compute the length we need to use floating point calculations so that
4162     // ceil can be computed for the number of steps.
4163     auto compute_element_type =
4164         getElementTypeOrSelf(start.getType()).isa<FloatType>()
4165             ? getElementTypeOrSelf(start.getType())
4166             : rewriter.getF64Type();
4167     auto compute_type = RankedTensorType::get(
4168         limit.getType().cast<ShapedType>().getShape(), compute_element_type);
4169 
4170     // Compute the length of the sequence we are going to need. This includes
4171     // some conversion to float for the operations.
4172     //
4173     // %size = ceil(abs((%limit - %start) / %delta))
4174     auto range = rewriter.create<mhlo::SubOp>(op.getLoc(), limit, start);
4175     auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
4176 
4177     // Delta is not necessarily the same type as start and limit.
4178     auto abs_cast =
4179         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
4180     auto delta_cast =
4181         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
4182 
4183     // Compute the total number of integer steps and convert to the HLO
4184     // dimension tensor.
4185     auto normalized =
4186         rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
4187     auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
4188     auto steps = rewriter.create<mhlo::ConvertOp>(
4189         op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
4190     auto reshape = rewriter.create<mhlo::ReshapeOp>(
4191         op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
4192 
4193     // Using the resulting length compute the correct range value:
4194     //
4195     // %range = %start + %delta * iota(%size)
4196     auto out_scalar_type =
4197         RankedTensorType::get({}, getElementTypeOrSelf(result_type));
4198     auto start_out_cast =
4199         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
4200     auto delta_out_cast =
4201         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
4202 
4203     auto iota = rewriter.create<DynamicIotaOp>(
4204         op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
4205     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4206         op.getLoc(), result_type, iota, delta_out_cast,
4207         hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
4208     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4209         op, result_type, scaled, start_out_cast,
4210         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
4211     return success();
4212   }
4213 };
4214 
4215 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
4216   auto int_attr = attr.cast<DenseIntElementsAttr>();
4217   auto type = val.getType().cast<ShapedType>();
4218 
4219   SmallVector<int64_t, 6> axis;
4220   axis.reserve(int_attr.getNumElements());
4221 
4222   int64_t rank = type.getRank();
4223   for (auto val : int_attr.getValues<APInt>()) {
4224     axis.push_back((val.getSExtValue() + rank) % rank);
4225   }
4226 
4227   return builder->getI64TensorAttr(axis);
4228 }
4229 
4230 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
4231 /// and offset applied to generate the linspace values. The output tensor needs
4232 /// to have a static shape.  The implementation is defined in C++ because there
4233 /// is no type inference for the iota op.
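/// For example, start = 0, stop = 10 and num = 5 give a step of
/// (10 - 0) / (5 - 1) = 2.5 and the values [0, 2.5, 5, 7.5, 10].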
4234 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
4235   using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
4236 
4237   LogicalResult matchAndRewrite(TF::LinSpaceOp op,
4238                                 PatternRewriter &rewriter) const override {
4239     auto result = op.getResult();
4240     auto result_type = result.getType().dyn_cast<ShapedType>();
4241     if (!result_type || !result_type.hasStaticShape()) {
4242       return failure();
4243     }
4244 
4245     DenseIntElementsAttr num_attr;
4246     if (!matchPattern(op.num(), m_Constant(&num_attr))) {
4247       return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
4248     }
4249 
4250     if (num_attr.begin() == num_attr.end()) {
4251       return rewriter.notifyMatchFailure(op, "Num must not be empty");
4252     }
4253     int64_t num = (*num_attr.begin()).getSExtValue();
4254 
4255     // Calculate the scaling that needs to be applied to the iota.
4256     auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
4257         op.getLoc(), op.start().getType(), op.stop(), op.start(),
4258         hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
4259     Value step_denominator = rewriter.create<ConvertOp>(
4260         op.getLoc(), op.num(), result_type.getElementType());
4261     if (num > 1) {
4262       Value one = GetScalarConstOfType(result_type.getElementType(),
4263                                        op.getLoc(), 1, &rewriter);
4264       step_denominator = rewriter.create<chlo::BroadcastSubOp>(
4265           op.getLoc(), step_denominator.getType(), step_denominator, one,
4266           hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
4267     }
4268     auto step = rewriter.create<chlo::BroadcastDivOp>(
4269         op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
4270         hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
4271                                         step_denominator));
4272 
4273     // Scale the iota and add the offset.
4274     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4275                                         rewriter.getI64IntegerAttr(0));
4276     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4277         op.getLoc(), result_type, iota, step,
4278         hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
4279     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4280         op, result_type, scaled, op.start(),
4281         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4282     return success();
4283   }
4284 };
4285 
4286 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
4287 /// ReductionOp.
4288 /// `is_accumulation` controls whether it uses higher precision for the actual
4289 /// reduction. This is set to false for ops like max where there are no
4290 /// precision concerns.
4291 //
4292 // The Derived class should have a static method to return the initial value to
4293 // use for reduction:
4294 //   static Value GetInitialValue(Type reduce_element_type, Location loc,
4295 //                                PatternRewriter *rewriter);
4296 // The reduce_element_type is guaranteed to be a float, int, or complex type
4297 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
4298 template <typename Derived, typename OpTy, typename ReductionOp,
4299           bool is_accumulation = true>
4300 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
4301   using OpRewritePattern<OpTy>::OpRewritePattern;
4302 
4303   LogicalResult matchAndRewrite(OpTy op,
4304                                 PatternRewriter &rewriter) const override {
4305     // TODO(b/141785544): Update this to not require ranked shapes.
4306     // Input shape needs to be ranked to convert negative indices in TensorFlow
4307     // to absolute indices required by HLO.
4308     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4309     if (!input_ty) return failure();
4310     ArrayRef<int64_t> input_shape = input_ty.getShape();
4311 
4312     DenseIntElementsAttr dimensions;
4313     if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
4314       return failure();
4315 
4316     // Build the final shape from input_shape and dimensions using a bitmap
4317     // to mark the reduced dimensions.
4318     SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
4319     SmallVector<int64_t, 4> xla_dimensions;
4320     for (const APInt &index_raw : dimensions.getValues<APInt>()) {
4321       int64_t index = index_raw.getSExtValue();
4322       int64_t rank = input_shape.size();
4323       if ((index < -rank || index >= rank)) return failure();
4324       index = (index + rank) % rank;
4325       reduced_dimensions_bitmap[index] = true;
4326       xla_dimensions.push_back(index);
4327     }
4328 
4329     Location loc = op.getLoc();
4330     Type element_type = input_ty.getElementType();
4331 
4332     // Only float, int, and complex types are currently supported.
4333     if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
4334         !element_type.isa<ComplexType>()) {
4335       return rewriter.notifyMatchFailure(
4336           op, "element type must be float, int, or complex type");
4337     }
4338 
4339     // Convert to an accumulation type to not lose precision when doing
4340     // repeated arithmetic operations.
4341     Type reduce_element_type =
4342         is_accumulation ? GetAccumulationType(element_type) : element_type;
4343     auto casted_input =
4344         rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
4345 
4346     // Each reduction op can have a different initial value.
4347     Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
4348 
4349     auto reduction = rewriter.create<ReduceOp>(
4350         loc, casted_input.getResult(), init,
4351         GetI64ElementsAttr(xla_dimensions, &rewriter));
4352     BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
4353                                  &rewriter);
4354     Value result = reduction.getResult(0);
4355 
4356     // The mean op needs to divide by the product of the reduced dimensions.
4357     if (std::is_same<OpTy, TF::MeanOp>::value) {
4358       Value in_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4359       Value divisor_count = rewriter.create<ConstantIndexOp>(loc, 1);
4360       for (size_t i = 0; i < input_shape.size(); ++i) {
4361         if (reduced_dimensions_bitmap[i]) {
4362           Value index = rewriter.create<ConstantIndexOp>(loc, i);
4363           auto dim = rewriter.create<tensor::ExtractOp>(loc, in_shape, index);
4364           divisor_count = rewriter.create<MulIOp>(loc, divisor_count, dim);
4365         }
4366       }
4367       // HLO ops are only defined on tensors, so we cast the divisor from
4368       // index -> i64 -> tensor<1xi64> -> tensor<i64> -> tensor<reduction type>
4369       auto divisor_casted = rewriter.create<IndexCastOp>(
4370           loc, rewriter.getI64Type(), divisor_count);
4371       auto divisor_tensor = rewriter.create<tensor::FromElementsOp>(
4372           loc, rewriter.getI64Type(), ValueRange{divisor_casted});
4373       auto divisor_reshaped = rewriter.create<mhlo::ReshapeOp>(
4374           loc, RankedTensorType::get({}, rewriter.getI64Type()),
4375           divisor_tensor);
4376       auto divisor = rewriter.create<ConvertOp>(
4377           loc, RankedTensorType::get({}, reduce_element_type),
4378           divisor_reshaped);
4379       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
4380       result = rewriter.create<chlo::BroadcastDivOp>(loc, result, divisor,
4381                                                      broadcast_dims);
4382     }
4383 
4384     result = rewriter.create<ConvertOp>(loc, result, element_type);
4385 
4386     // Need to reshape back after the reduction if we're keeping the reduced
4387     // dimensions. Note that we do this through successive (nominally 1)
4388     // applications of the TF ExpandDims op vs a more labor intensive
4389     // reshape. Various code generation techniques benefit from the knowledge
4390     // that this is a restricted form of shape manipulation that is just adding
4391     // unit dims.
4392     if (op.keep_dims()) {
4393       for (auto dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) {
4394         if (dim_is_reduced.value()) {
4395           auto index_attr = GetI32ElementsAttr(
4396               {static_cast<int>(dim_is_reduced.index())}, &rewriter);
4397           Value index = rewriter.create<ConstantOp>(loc, index_attr);
4398           result = rewriter.create<TF::ExpandDimsOp>(loc, result, index);
4399         }
4400       }
4401     }
4402     rewriter.replaceOp(op, {result});
4403 
4404     return success();
4405   }
4406 };
4407 
4408 // Converts Mean op to HLO Reduce op.
4409 //
4410 //   %init = constant dense<...> : tensor<T>
4411 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4412 //               {dimensions = ...}
4413 //   %divisor = constant dense<...> : tensor<T>
4414 //   %mean = "mhlo.divide"(%sum, %divisor)
4415 class ConvertMeanOp
4416     : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
4417  public:
4418   using GenericConvertReductionOp::GenericConvertReductionOp;
4419   static Value GetInitialValue(Type reduce_element_type, Location loc,
4420                                PatternRewriter *rewriter) {
4421     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4422   }
4423 };
4424 
4425 // Converts Sum op to HLO Reduce op.
4426 //
4427 //   %init = constant dense<...> : tensor<T>
4428 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4429 //               {dimensions = ...}
4430 class ConvertSumOp
4431     : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
4432  public:
4433   using GenericConvertReductionOp::GenericConvertReductionOp;
4434 
4435   static Value GetInitialValue(Type reduce_element_type, Location loc,
4436                                PatternRewriter *rewriter) {
4437     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4438   }
4439 };
4440 
4441 // Converts Max op to HLO Reduce op.
4442 //
4443 //   %init = constant dense<...> : tensor<T>
4444 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
4445 //               {dimensions = ...}
4446 class ConvertMaxOp
4447     : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
4448                                        /* is_accumulation= */ false> {
4449  public:
4450   using GenericConvertReductionOp::GenericConvertReductionOp;
4451 
4452   static Value GetInitialValue(Type reduce_element_type, Location loc,
4453                                PatternRewriter *rewriter) {
4454     return GetScalarLimitConstOfType(reduce_element_type, loc,
4455                                      hlo::kInfinityLowest, rewriter);
4456   }
4457 };
4458 
4459 // Converts Min op to HLO Reduce op.
4460 //
4461 //   %init = constant dense<...> : tensor<T>
4462 //   %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
4463 //               {dimensions = ...}
4464 class ConvertMinOp
4465     : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
4466                                        /* is_accumulation= */ false> {
4467  public:
4468   using GenericConvertReductionOp::GenericConvertReductionOp;
4469 
4470   static Value GetInitialValue(Type reduce_element_type, Location loc,
4471                                PatternRewriter *rewriter) {
4472     return GetScalarLimitConstOfType(reduce_element_type, loc,
4473                                      hlo::kInfinityMax, rewriter);
4474   }
4475 };
4476 
4477 // Converts Prod op to HLO Reduce op.
4478 //
4479 //   %init = constant dense<...> : tensor<T>
4480 //   %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
4481 //               {dimensions = ...}
4482 class ConvertProdOp
4483     : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
4484  public:
4485   using GenericConvertReductionOp::GenericConvertReductionOp;
4486 
4487   static Value GetInitialValue(Type reduce_element_type, Location loc,
4488                                PatternRewriter *rewriter) {
4489     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4490   }
4491 };
4492 
4493 // Converts All op to HLO Reduce op.
4494 //
4495 //   %init = constant dense<...> : tensor<T>
4496 //   %all = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
4497 //               {dimensions = ...}
4498 class ConvertAllOp
4499     : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
4500  public:
4501   using GenericConvertReductionOp::GenericConvertReductionOp;
4502   static Value GetInitialValue(Type reduce_element_type, Location loc,
4503                                PatternRewriter *rewriter) {
4504     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4505   }
4506 };
4507 
4508 // Converts Any op to HLO Reduce op.
4509 //
4510 //   %init = constant dense<...> : tensor<T>
4511 //   %any = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
4512 //               {dimensions = ...}
4513 class ConvertAnyOp
4514     : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
4515  public:
4516   using GenericConvertReductionOp::GenericConvertReductionOp;
4517   static Value GetInitialValue(Type reduce_element_type, Location loc,
4518                                PatternRewriter *rewriter) {
4519     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4520   }
4521 };
4522 
4523 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
4524 // a reduction on the original input and the corresponding index. The reduction
4525 // sub-computation selects the max (or min) value and the index for the value.
4526 //   Derived: is the resulting derived class of this class.
4527 //   OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
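//
// As a hedged sketch (hypothetical shapes): tf.ArgMax over dimension 1 of a
// tensor<4x8xf32> with i32 output roughly becomes
//
//   %iota = "mhlo.iota"() {iota_dimension = 1} : () -> tensor<4x8xi32>
//   %init = constant dense<...> : tensor<f32>        (lowest representable value)
//   %index_init = constant dense<0> : tensor<i32>
//   %reduce:2 = "mhlo.reduce"(%input, %iota, %init, %index_init)
//                 {dimensions = [1]}
//
// where the reduction body keeps the larger value and its index, and the
// second result (a tensor<4xi32>) replaces the original op.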
4528 template <typename Derived, typename OpTy>
4529 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
4530   using OpRewritePattern<OpTy>::OpRewritePattern;
4531 
4532   LogicalResult matchAndRewrite(OpTy op,
4533                                 PatternRewriter &rewriter) const override {
4534     RankedTensorType input_type =
4535         op.input().getType().template dyn_cast<RankedTensorType>();
4536     if (!input_type) {
4537       return failure();
4538     }
4539 
4540     Type input_element_type = input_type.getElementType();
4541     // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
4542     // tf.ArgMax doesn't support complex data types, this check can be removed.
4543     if (!input_element_type.isSignlessIntOrFloat()) return failure();
4544 
4545     Location loc = op.getLoc();
4546     Value init_value =
4547         Derived::GetInitialValue(input_element_type, loc, rewriter);
4548 
4549     RankedTensorType output_type =
4550         op.output().getType().template dyn_cast<RankedTensorType>();
4551     if (!output_type) {
4552       return rewriter.notifyMatchFailure(op, "requires known rank");
4553     }
4554 
4555     Type index_element_type = output_type.getElementType();
4556     Value index_init_value =
4557         GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
4558 
4559     RankedTensorType index_type =
4560         RankedTensorType::get(input_type.getShape(), index_element_type);
4561 
4562     llvm::Optional<int64_t> optional_axis =
4563         GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
4564     if (!optional_axis.hasValue())
4565       return rewriter.notifyMatchFailure(op, "required axis");
4566     int64_t axis = optional_axis.getValue();
4567 
4568     IntegerAttr iota_dimension =
4569         IntegerAttr::get(rewriter.getIntegerType(64), axis);
4570     Value index_values =
4571         rewriter.create<IotaOp>(loc, index_type, iota_dimension);
4572 
4573     std::vector<int64_t> dimensions = input_type.getShape();
4574     dimensions.erase(dimensions.begin() + axis);
4575     ArrayRef<int64_t> reduction_result_shape(dimensions);
4576 
4577     Value operands[] = {op.input(), index_values};
4578     Value init_values[] = {init_value, index_init_value};
4579     DenseIntElementsAttr reduction_dimensions =
4580         GetI64ElementsAttr({axis}, &rewriter);
4581 
4582     auto reduction = rewriter.create<ReduceOp>(
4583         loc, llvm::ArrayRef<Value>(operands),
4584         llvm::ArrayRef<Value>(init_values), reduction_dimensions);
4585     StringRef direction = Derived::GetDirection();
4586     BuildArgMinMaxReductionBody(input_element_type, index_element_type,
4587                                 direction, &reduction.body(), &rewriter);
4588 
4589     rewriter.replaceOp(op, {reduction.getResult(1)});
4590     return success();
4591   }
4592 };
4593 
4594 // Converts tensorflow ArgMax op to mhlo operations. The actual
4595 // implementation is in class ConvertArgMinMaxOp:
4596 //
4597 //   %init_index = constant dense<...> : tensor<T>
4598 //   %init = constant dense<...> : tensor<T>
4599 //   %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4600 //                              %init_index) ["mhlo.arg_max"]
4601 class ConvertArgMaxOp
4602     : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
4603  public:
4604   using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4605 
4606   static Value GetInitialValue(Type reduce_element_type, Location loc,
4607                                PatternRewriter &rewriter) {
4608     return GetScalarLimitConstOfType(reduce_element_type, loc,
4609                                      hlo::kInfinityLowest, &rewriter);
4610   }
4611 
4612   static StringRef GetDirection() { return "GE"; }
4613 };
4614 
4615 // Converts tensorflow ArgMin op to mhlo operations. The actual
4616 // implementation is in class ConvertArgMinMaxOp:
4617 //
4618 //   %init_index = constant dense<...> : tensor<T>
4619 //   %init = constant dense<...> : tensor<T>
4620 //   %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4621 //                              %init_index) ["mhlo.arg_min"]
4622 class ConvertArgMinOp
4623     : public ConvertArgMinMaxOp<ConvertArgMinOp, TF::ArgMinOp> {
4624  public:
4625   using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4626 
4627   static Value GetInitialValue(Type reduce_element_type, Location loc,
4628                                PatternRewriter &rewriter) {
4629     return GetScalarLimitConstOfType(reduce_element_type, loc,
4630                                      hlo::kInfinityMax, &rewriter);
4631   }
4632 
4633   static StringRef GetDirection() { return "LE"; }
4634 };
4635 
4636 // Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with
4637 // assignment:
4638 //
4639 //   %result = "mhlo.scatter"(%tensor, %indices, %updates)
4640 //     { dimensions = ... }
4641 //
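// As a hedged illustration (hypothetical shapes): scattering updates of type
// tensor<2x4xf32> into a tensor<4x4xf32> using indices tensor<2x1xi32> yields
// scatter dimension numbers of roughly
//   update_window_dims = [1], inserted_window_dims = [0],
//   scatter_dims_to_operand_dims = [0], index_vector_dim = 1,
// with the update computation supplied by the derived pattern below
// (assignment, add, sub, min or max).
//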
4642 template <typename Derived, typename OpTy>
4643 class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
4644  public:
4645   using OpRewritePattern<OpTy>::OpRewritePattern;
4646 
4647   LogicalResult matchAndRewrite(OpTy op,
4648                                 PatternRewriter &rewriter) const override {
4649     auto tensor_ty =
4650         op.tensor().getType().template dyn_cast<RankedTensorType>();
4651     auto indices_ty =
4652         op.indices().getType().template dyn_cast<RankedTensorType>();
4653     auto updates_ty =
4654         op.updates().getType().template dyn_cast<RankedTensorType>();
4655 
4656     if (!tensor_ty || !indices_ty || !updates_ty) return failure();
4657     // Last dimension of the indices needs to be known at compile time for
4658     // computation of the 'update_window_dims' attribute in the dimensions
4659     // struct.
4660     int64_t num_index_dims = indices_ty.getShape().back();
4661     if (ShapedType::isDynamic(num_index_dims)) return failure();
4662 
4663     int64_t tensor_rank = tensor_ty.getRank();
4664     int64_t indices_rank = indices_ty.getRank();
4665     int64_t updates_rank = updates_ty.getRank();
4666 
4667     int64_t window_dims = tensor_rank - num_index_dims;
4668     auto dims_attr = ScatterDimensionNumbers::get(
4669         GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
4670                                  &rewriter),
4671         GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
4672         GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
4673         rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
4674 
4675     Location loc = op.getLoc();
4676     auto scatter = rewriter.create<ScatterOp>(
4677         loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
4678     Derived::BuildScatterBody(tensor_ty.getElementType(),
4679                               &scatter.update_computation(), loc, rewriter);
4680 
4681     rewriter.replaceOp(op, scatter.getResult());
4682     return success();
4683   }
4684 };
4685 
4686 class ConvertTensorScatterUpdateOp
4687     : public ConvertTensorScatterOp<ConvertTensorScatterUpdateOp,
4688                                     TF::TensorScatterUpdateOp> {
4689  public:
4690   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4691 
4692   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4693                                OpBuilder &builder) {
4694     OpBuilder::InsertionGuard guard(builder);
4695     Block *block = builder.createBlock(region);
4696     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4697     block->addArguments({type, type});
4698     builder.create<ReturnOp>(loc, block->getArgument(1));
4699   }
4700 };
4701 
4702 class ConvertTensorScatterAddOp
4703     : public ConvertTensorScatterOp<ConvertTensorScatterAddOp,
4704                                     TF::TensorScatterAddOp> {
4705  public:
4706   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4707 
4708   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4709                                OpBuilder &builder) {
4710     OpBuilder::InsertionGuard guard(builder);
4711     Block *block = builder.createBlock(region);
4712     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4713     block->addArguments({type, type});
4714     auto add_op = builder.create<AddOp>(loc, block->getArgument(0),
4715                                         block->getArgument(1));
4716     builder.create<ReturnOp>(loc, add_op.getResult());
4717   }
4718 };
4719 
4720 class ConvertTensorScatterSubOp
4721     : public ConvertTensorScatterOp<ConvertTensorScatterSubOp,
4722                                     TF::TensorScatterSubOp> {
4723  public:
4724   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4725 
4726   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4727                                OpBuilder &builder) {
4728     OpBuilder::InsertionGuard guard(builder);
4729     Block *block = builder.createBlock(region);
4730     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4731     block->addArguments({type, type});
4732     auto sub_op = builder.create<SubOp>(loc, block->getArgument(0),
4733                                         block->getArgument(1));
4734     builder.create<ReturnOp>(loc, sub_op.getResult());
4735   }
4736 };
4737 
4738 class ConvertTensorScatterMinOp
4739     : public ConvertTensorScatterOp<ConvertTensorScatterMinOp,
4740                                     TF::TensorScatterMinOp> {
4741  public:
4742   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4743 
4744   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4745                                OpBuilder &builder) {
4746     OpBuilder::InsertionGuard guard(builder);
4747     Block *block = builder.createBlock(region);
4748     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4749     block->addArguments({type, type});
4750     auto min_op = builder.create<MinOp>(loc, block->getArgument(0),
4751                                         block->getArgument(1));
4752     builder.create<ReturnOp>(loc, min_op.getResult());
4753   }
4754 };
4755 
4756 class ConvertTensorScatterMaxOp
4757     : public ConvertTensorScatterOp<ConvertTensorScatterMaxOp,
4758                                     TF::TensorScatterMaxOp> {
4759  public:
4760   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4761 
4762   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4763                                OpBuilder &builder) {
4764     OpBuilder::InsertionGuard guard(builder);
4765     Block *block = builder.createBlock(region);
4766     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4767     block->addArguments({type, type});
4768     auto max_op = builder.create<MaxOp>(loc, block->getArgument(0),
4769                                         block->getArgument(1));
4770     builder.create<ReturnOp>(loc, max_op.getResult());
4771   }
4772 };
4773 
4774 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
4775 //   For shape [S1, S2] and multiples [M1, M2],
4776 //     MS1 = M1 * S1; MS2 = M2 * S2
4777 //
4778 //   %broadcast = mhlo.broadcast_in_dim(%input) {
4779 //     broadcast_dimensions = [1, 3]
4780 //   }
4781 //   %result = "mhlo.reshape"(%broadcast) : (tensor<M1xS1xM2xS2xf32>)
4782 //      -> tensor<MS1xMS2xf32>
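//
// As a hedged, concrete example (hypothetical shapes): tiling a
// tensor<2x3xf32> with multiples [3, 2] broadcasts the input into
// tensor<3x2x2x3xf32> with broadcast_dimensions = [1, 3] and then reshapes
// the result to tensor<6x6xf32>.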
4783 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
4784  public:
4785   using OpRewritePattern::OpRewritePattern;
4786 
4787   LogicalResult matchAndRewrite(TF::TileOp op,
4788                                 PatternRewriter &rewriter) const override {
4789     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
4790     if (!input_ty || !input_ty.hasStaticShape()) return failure();
4791     ArrayRef<int64_t> input_shape = input_ty.getShape();
4792     Type element_type = input_ty.getElementType();
4793 
4794     DenseIntElementsAttr multiples;
4795     if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
4796         multiples.getType().getRank() != 1)
4797       return failure();
4798 
4799     const int64_t input_shape_size = input_shape.size();
4800     if (multiples.getNumElements() != input_shape_size) return failure();
4801 
4802     SmallVector<int64_t, 8> broadcasted_shape;
4803     SmallVector<int64_t, 4> broadcast_dimensions;
4804     broadcasted_shape.reserve(input_shape.size() * 2);
4805     broadcast_dimensions.reserve(input_shape.size());
4806     for (auto multiple_and_input :
4807          llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4808       int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4809       int64_t input_size = std::get<1>(multiple_and_input);
4810 
4811       if (multiple < 0) return failure();
4812 
4813       // Line input up with the next dimension in broadcasted_shape
4814       // when broadcasting.
4815       int64_t broadcast_dim;
4816       int64_t output_size = input_size * multiple;
4817       if (input_size == 1 || multiple == 1) {
4818         // Special case for when normal broadcasting will just work.
4819         broadcast_dim = broadcasted_shape.size();
4820         broadcasted_shape.push_back(output_size);
4821       } else {
4822         // Tiling will happen for this dimension during the ReshapeOp below.
4823         broadcasted_shape.push_back(multiple);
4824         broadcast_dim = broadcasted_shape.size();
4825         broadcasted_shape.push_back(input_size);
4826       }
4827       broadcast_dimensions.push_back(broadcast_dim);
4828     }
4829     Location loc = op.getLoc();
4830     Type broadcasted_type =
4831         RankedTensorType::get(broadcasted_shape, element_type);
4832     Type output_type = op.getType();
4833 
4834     Value result = rewriter.create<BroadcastInDimOp>(
4835         loc, broadcasted_type, op.input(),
4836         GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4837 
4838     if (output_type != broadcasted_type) {
4839       result = rewriter.create<ReshapeOp>(loc, output_type, result);
4840     }
4841 
4842     rewriter.replaceOp(op, {result});
4843 
4844     return success();
4845   }
4846 };
4847 
4848 // Converts the tf.Tile op into mhlo.dynamic_broadcast_in_dim and mhlo.dynamic_reshape.
4849 // TODO(disc): Recover the static special case's performance with folding and
4850 // canonicalization.
4851 class ConvertTileOpDynamic : public OpRewritePattern<TF::TileOp> {
4852  public:
4853   using OpRewritePattern::OpRewritePattern;
4854   // clang-format off
4855   // Converts Tile op to HLO DBroadcastInDim and DReshape ops.
4856   //   For shape [S1, S2] and multiples [M1, M2],
4857   //     MS1 = M1 * S1; MS2 = M2 * S2
4858   //
4859   //   %out_dim_size = [M1, S1, M2, S2]
4860   //   %broadcast_dimensions = [1, 3];
4861   //   %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %broadcast_dimensions);
4862   //   %shape = [MS1, MS2]
4863   //   %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor<M1xS1xM2xS2xf32>) -> tensor<MS1xMS2xf32>
4864   // clang-format on
4865   LogicalResult matchAndRewrite(TF::TileOp op,
4866                                 PatternRewriter &rewriter) const final {
4867     Location loc = op.getLoc();
4868     Value input = op.input();
4869     Value multiples = op.multiples();
4870     auto input_ty = input.getType().dyn_cast<RankedTensorType>();
4871     if (!input_ty) return failure();
4872     // TODO(disc): Remove this constraint once fold and canonicalization
4873     // implemented.
4874     if (input_ty.hasStaticShape()) return failure();
4875 
4876     Type element_type = input_ty.getElementType();
4877     int64_t input_rank = input_ty.getRank();
4878     SmallVector<Value, 4> input_shape_values;
4879     for (int64_t i = 0; i < input_rank; ++i) {
4880       auto dim_size = input_ty.getDimSize(i);
4881       if (dim_size == ShapedType::kDynamicSize) {
4882         input_shape_values.push_back(
4883             rewriter.create<tensor::DimOp>(loc, input, i));
4884       } else {
4885         input_shape_values.push_back(
4886             rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(dim_size)));
4887       }
4888     }
4889 
4890     auto multiples_ty = multiples.getType().dyn_cast<RankedTensorType>();
4891     int64_t multiples_rank = multiples_ty.getRank();
4892     // The rank of the multiples input of tf.TileOp must be 1.
4893     if (multiples_rank != 1) return failure();
4894     // The multiples input of tf.TileOp must be statically shaped.
4895     if ((!multiples_ty.hasStaticShape()) ||
4896         (multiples_ty.getDimSize(0) != input_rank)) {
4897       return failure();
4898     }
4899     // %out_dim_size
4900     SmallVector<Value, 4> out_dim_size;
4901     out_dim_size.reserve(input_rank * 2);
4902     for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4903       Value index =
4904           rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(dim_idx));
4905       Value multiples_size =
4906           rewriter.create<tensor::ExtractOp>(loc, multiples, ValueRange{index});
4907       Value multiples_size_casted = rewriter.create<IndexCastOp>(
4908           loc, rewriter.getIndexType(), multiples_size);
4909       out_dim_size.push_back(multiples_size_casted);
4910       out_dim_size.push_back(input_shape_values[dim_idx]);
4911     }
4912     SmallVector<int64_t, 4> broadcast_dimensions;
4913     broadcast_dimensions.reserve(input_rank);
4914     for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4915       broadcast_dimensions.push_back(1 + 2 * dim_idx);
4916     }
4917     auto broadcast_dims_attr =
4918         GetI64ElementsAttr(broadcast_dimensions, &rewriter);
4919 
4920     Value out_dim_size_tensor = rewriter.create<tensor::FromElementsOp>(
4921         loc, rewriter.getIndexType(), out_dim_size);
4922     SmallVector<int64_t, 4> broadcast_shape(input_rank * 2,
4923                                             ShapedType::kDynamicSize);
4924     RankedTensorType broadcast_type =
4925         RankedTensorType::get(broadcast_shape, element_type);
4926     Value broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
4927         loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr);
4928 
4929     // %shape = [MS1, MS2]
4930     SmallVector<Value, 4> shape_values;
4931     shape_values.reserve(input_rank);
4932     for (int64_t i = 0; i < input_rank; ++i) {
4933       Value dim_size_value = rewriter.create<mlir::MulIOp>(
4934           loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]);
4935       shape_values.push_back(dim_size_value);
4936     }
4937     Value shape = rewriter.create<tensor::FromElementsOp>(
4938         loc, rewriter.getIndexType(), shape_values);
4939     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
4940                                                         broadcast, shape);
4941     return success();
4942   }
4943 };
4944 
4945 template <typename OpTy, int num_dims>
4946 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4947  public:
4948   using OpRewritePattern<OpTy>::OpRewritePattern;
4949 
4950   LogicalResult matchAndRewrite(OpTy op,
4951                                 PatternRewriter &rewriter) const override {
4952     Location loc = op.getLoc();
4953 
4954     Type element_type =
4955         op.orig_input().getType().template cast<TensorType>().getElementType();
4956 
4957     // Compute paddings using the original input and kernel shapes and strides.
4958     // The same padding is used here as for the ReduceWindow op, since the
4959     // MaxPool op is lowered to the ReduceWindow op.
4960     auto input_ty =
4961         op.orig_input().getType().template dyn_cast<RankedTensorType>();
4962     if (!input_ty) return failure();
4963     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4964         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4965 
4966     auto result = rewriter.create<SelectAndScatterOp>(
4967         loc, op.getType(), op.orig_input(), op.grad(),
4968         GetScalarConstOfType(element_type, loc, 0, &rewriter),
4969         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4970         paddings_attr);
4971 
4972     BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4973     {
4974       OpBuilder::InsertionGuard guard(rewriter);
4975       Block *block = rewriter.createBlock(&result.select());
4976 
4977       // Block arguments are scalars of the given element type.
4978       Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4979       block->addArguments({type, type});
4980 
4981       auto reducer = rewriter.create<CompareOp>(
4982           loc, block->getArgument(0), block->getArgument(1),
4983           StringAttr::get(rewriter.getContext(), "GE"));
4984       rewriter.create<ReturnOp>(loc, reducer.getResult());
4985     }
4986 
4987     rewriter.replaceOp(op, {result});
4988 
4989     return success();
4990   }
4991 };
4992 
4993 using ConvertMaxPool2DGradOp =
4994     ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4995 using ConvertMaxPool3DGradOp =
4996     ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4997 
4998 // Converts tf.Conv?DBackpropInputOp into:
4999 //   %rev_filter = "mhlo.reverse"(%filter)
5000 //   %result = "mhlo.convolution"(%out_backprop, %rev_filter)
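//
// Roughly speaking (a sketch, not exhaustive): the forward stride becomes the
// lhs_dilation of this convolution (e.g. a 2-D forward convolution with
// stride 2 yields lhs_dilation = [2, 2] and window_strides = [1, 1] here),
// the filter is reversed in its spatial dimensions, and the filter's
// input/output feature dimensions are swapped in the ConvDimensionNumbers.
// Grouped convolutions additionally reshape and transpose the filter from
// [..., filter_in_depth, out_depth] to [..., in_depth, out_depth / G].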
5001 template <typename OpTy, int num_spatial_dims>
5002 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
5003  public:
5004   using OpRewritePattern<OpTy>::OpRewritePattern;
5005 
5006   LogicalResult matchAndRewrite(OpTy op,
5007                                 PatternRewriter &rewriter) const override {
5008     // Unpack all of the attributes.
5009     tensorflow::TensorFormat data_format;
5010     if (!FormatFromString(op.data_format().str(), &data_format))
5011       return op.emitOpError("invalid data format");
5012 
5013     tensorflow::Padding padding;
5014     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5015       return failure();
5016 
5017     auto out_backprop_ty =
5018         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5019     auto filter_ty =
5020         op.filter().getType().template dyn_cast<RankedTensorType>();
5021 
5022     for (RankedTensorType ty : {out_backprop_ty, filter_ty})
5023       if (!ty || !ty.hasStaticShape()) return failure();
5024 
5025     DenseIntElementsAttr input_shape_attr;
5026     if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
5027         input_shape_attr.getType().getRank() != 1)
5028       return failure();
5029 
5030     auto input_shape = input_shape_attr.getValues<int32_t>();
5031 
5032     auto dilations_attr = GetI64ElementsAttr(op.dilations());
5033     std::vector<int> dilations{
5034         dilations_attr.template getValues<int64_t>().begin(),
5035         dilations_attr.template getValues<int64_t>().end()};
5036     auto strides_attr = GetI64ElementsAttr(op.strides());
5037     std::vector<tensorflow::int32> strides{
5038         strides_attr.template getValues<int64_t>().begin(),
5039         strides_attr.template getValues<int64_t>().end()};
5040 
5041     std::vector<tensorflow::int64> explicit_paddings;
5042     if (padding == tensorflow::Padding::EXPLICIT) {
5043       // EXPLICIT padding mode and the associated attribute is limited to
5044       // Conv2DBackpropInput. So, fetch attribute by identifier instead of the
5045       // op.explicit_paddings() attribute getter.
5046       ArrayRef<Attribute> explicit_paddings_attr =
5047           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5048       explicit_paddings.reserve(explicit_paddings_attr.size());
5049       for (Attribute explicit_padding : explicit_paddings_attr)
5050         explicit_paddings.push_back(
5051             explicit_padding.cast<IntegerAttr>().getInt());
5052     }
5053 
5054     constexpr int num_dims = num_spatial_dims + 2;
5055     ArrayRef<int64_t> filter_shape = filter_ty.getShape();
5056 
5057     // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5058     tensorflow::ConvBackpropDimensions dims;
5059     if (!tensorflow::ConvBackpropComputeDimensionsV2(
5060              /*label=*/"", num_spatial_dims,
5061              ToTensorShape<int32_t, num_dims>(input_shape),
5062              ToTensorShape<int64_t, num_dims>(filter_shape),
5063              ToTensorShape<int64_t, num_dims>(out_backprop_ty.getShape()),
5064              dilations, strides, padding, explicit_paddings, data_format, &dims)
5065              .ok()) {
5066       return failure();
5067     }
5068 
5069     // Compute ConvDimensionNumbers, dilation, and padding.
5070     SmallVector<int64_t, num_spatial_dims> spatial_dims;
5071     SmallVector<int64_t, num_spatial_dims> lhs_dilation;
5072     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5073     SmallVector<int64_t, num_spatial_dims * 2> paddings;
5074 
5075     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5076       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
5077       spatial_dims.push_back(dim);
5078       const auto &spatial_dim_i = dims.spatial_dims[i];
5079       lhs_dilation.push_back(spatial_dim_i.stride);
5080       rhs_dilation.push_back(dilations[dim]);
5081       paddings.push_back(spatial_dim_i.pad_before);
5082       paddings.push_back(spatial_dim_i.pad_after);
5083     }
5084 
5085     RankedTensorType paddings_ty = RankedTensorType::get(
5086         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5087     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5088 
5089     auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
5090 
5091     Value filter = op.filter();
5092 
5093     const int feature_dim =
5094         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5095     const int64_t in_depth = *(input_shape.begin() + feature_dim);
5096     const int64_t filter_in_depth = filter_shape[num_spatial_dims];
5097     const int64_t feature_group_count = in_depth / filter_in_depth;
5098 
5099     if (feature_group_count != 1) {
5100       // 1. Reshape filter from
5101       //   [H, W, ..., filter_in_depth, out_depth] to
5102       //   [H, W, ..., filter_in_depth, G, out_depth / G].
5103       auto new_shape = llvm::to_vector<6>(filter_shape);
5104       new_shape.back() = feature_group_count;
5105       new_shape.push_back(filter_shape.back() / feature_group_count);
5106       Type filter_element_ty = filter_ty.getElementType();
5107       auto ty = RankedTensorType::get(new_shape, filter_element_ty);
5108       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5109 
5110       // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
5111       llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
5112       std::iota(perm.begin(), perm.end(), 0);
5113       std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
5114       std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
5115       ty = RankedTensorType::get(new_shape, filter_element_ty);
5116       filter = rewriter.create<TransposeOp>(
5117           op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
5118 
5119       // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
5120       new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
5121       new_shape[num_spatial_dims + 1] = new_shape.back();
5122       new_shape.pop_back();
5123       ty = RankedTensorType::get(new_shape, filter_element_ty);
5124       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5125     }
5126 
5127     auto kernel_spatial_dims_attr =
5128         GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
5129 
5130     // Mirror the filter in the spatial dimensions.
5131     filter = rewriter.create<ReverseOp>(op.getLoc(), filter,
5132                                         kernel_spatial_dims_attr);
5133 
5134     const int batch_dim =
5135         tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5136     auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
5137     auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
5138 
5139     // activation gradients
5140     //   = gradients (with padding and dilation) <conv> mirrored_weights
5141     Value result = rewriter.create<ConvOp>(
5142         op.getLoc(), op.getType(), op.out_backprop(), filter,
5143         /*window_strides=*/
5144         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5145                                    &rewriter),
5146         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
5147         GetI64ElementsAttr(rhs_dilation, &rewriter),
5148         /*window_reversal=*/nullptr,
5149         ConvDimensionNumbers::get(
5150             /*input_batch_dimension=*/batch_dim_attr,
5151             /*input_feature_dimension=*/feature_dim_attr,
5152             /*input_spatial_dimensions=*/spatial_dims_attr,
5153             // TF filter shape is [ H, W, ..., inC, outC ]
5154             // Transpose the input and output features for computing the
5155             // gradient.
5156             /*kernel_input_feature_dimension=*/
5157             rewriter.getI64IntegerAttr(num_spatial_dims + 1),
5158             /*kernel_output_feature_dimension=*/
5159             rewriter.getI64IntegerAttr(num_spatial_dims),
5160             /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
5161             /*output_batch_dimension=*/batch_dim_attr,
5162             /*output_feature_dimension=*/feature_dim_attr,
5163             /*output_spatial_dimensions=*/spatial_dims_attr,
5164             rewriter.getContext()),
5165         rewriter.getI64IntegerAttr(feature_group_count),
5166         /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
5167         /*precision_config=*/ArrayAttr());
5168 
5169     rewriter.replaceOp(op, {result});
5170 
5171     return success();
5172   }
5173 };
5174 
5175 using ConvertConv2DBackpropInputOp =
5176     ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
5177                                /*num_spatial_dims=*/2>;
5178 using ConvertConv3DBackpropInputOp =
5179     ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
5180                                /*num_spatial_dims=*/3>;
5181 
5182 // Converts tf.Conv?DBackpropFilterOp into:
5183 //   %result = "mhlo.convolution"(%input, %out_backprop)
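//
// Roughly speaking (a sketch): the activations form the LHS with their batch
// and feature dimensions swapped, the gradients form the RHS with their batch
// dimension acting as the kernel input feature, the forward stride becomes
// rhs_dilation (e.g. stride 2 gives rhs_dilation = [2, 2]), and the forward
// dilation becomes this convolution's window stride.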
5184 template <typename OpTy, int num_spatial_dims>
5185 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
5186  public:
5187   using OpRewritePattern<OpTy>::OpRewritePattern;
5188 
5189   LogicalResult matchAndRewrite(OpTy op,
5190                                 PatternRewriter &rewriter) const override {
5191     // Unpack all of the attributes.
5192     tensorflow::TensorFormat data_format;
5193     if (!FormatFromString(op.data_format().str(), &data_format))
5194       return op.emitOpError("invalid data format");
5195 
5196     tensorflow::Padding padding;
5197     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5198       return failure();
5199 
5200     auto out_backprop_ty =
5201         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5202     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
5203 
5204     for (RankedTensorType ty : {out_backprop_ty, input_ty})
5205       if (!ty || !ty.hasStaticShape()) return failure();
5206 
5207     ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
5208     ArrayRef<int64_t> input_shape = input_ty.getShape();
5209 
5210     DenseIntElementsAttr filter_shape_attr;
5211     if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
5212         filter_shape_attr.getType().getRank() != 1)
5213       return failure();
5214 
5215     auto dilations_attr = GetI64ElementsAttr(op.dilations());
5216     std::vector<int> dilations{
5217         dilations_attr.template getValues<int64_t>().begin(),
5218         dilations_attr.template getValues<int64_t>().end()};
5219     auto strides_attr = GetI64ElementsAttr(op.strides());
5220     std::vector<tensorflow::int32> strides{
5221         strides_attr.template getValues<int64_t>().begin(),
5222         strides_attr.template getValues<int64_t>().end()};
5223 
5224     std::vector<tensorflow::int64> explicit_paddings;
5225     if (padding == tensorflow::Padding::EXPLICIT) {
5226       // EXPLICIT padding mode and the associated attribute is limited to
5227       // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the
5228       // op.explicit_paddings() attribute getter.
5229       ArrayRef<Attribute> explicit_paddings_attr =
5230           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5231       explicit_paddings.reserve(explicit_paddings_attr.size());
5232       for (Attribute explicit_padding : explicit_paddings_attr)
5233         explicit_paddings.push_back(
5234             explicit_padding.cast<IntegerAttr>().getInt());
5235     }
5236 
5237     constexpr int num_dims = num_spatial_dims + 2;
5238     auto filter_shape = filter_shape_attr.getValues<int32_t>();
5239 
5240     // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5241     tensorflow::ConvBackpropDimensions dims;
5242     if (!tensorflow::ConvBackpropComputeDimensionsV2(
5243              /*label=*/"", num_spatial_dims,
5244              ToTensorShape<int64_t, num_dims>(input_shape),
5245              ToTensorShape<int32_t, num_dims>(filter_shape),
5246              ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
5247              strides, padding, explicit_paddings, data_format, &dims)
5248              .ok()) {
5249       return failure();
5250     }
5251 
5252     // The activations (inputs) form the LHS of the convolution.
5253     // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
5254     // For the gradient computation, we need to:
5255     // 1. In the case of group convolution, move the num_groups dimension before
5256     // the batch dimension.
5257     // 2. Swap the roles of the batch and feature dimensions.
5258     const int feature_dim =
5259         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5260     const int64_t in_depth = input_shape[feature_dim];
5261     const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
5262     const int64_t batch_group_count = in_depth / filter_in_depth;
5263 
5264     // Compute ConvDimensionNumbers, dilation, and padding.
5265     SmallVector<int64_t, num_spatial_dims> spatial_dims;
5266     SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
5267     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5268     SmallVector<int64_t, num_spatial_dims * 2> paddings;
5269     SmallVector<int64_t, num_spatial_dims> window_strides;
5270 
5271     // The filter gradients are computed by a convolution of the input
5272     // activations and the output gradients, with some appropriate padding.
5273     // See the comment at the top of conv_grad_ops.h for details.
5274 
5275     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5276       const int64_t dim =
5277           tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
5278       kernel_spatial_dims.push_back(dim);
5279       // Besides padding the input, we will also expand output_rows to
5280       //    expanded_out_rows = (output_rows - 1) * stride + 1
5281       // with zeros in between:
5282       //
5283       //      a . . . b . . . c . . . d . . . e
5284       //
5285       // This is done by specifying the window dilation factors in the
5286       // convolution HLO below.
5287       const auto &spatial_dim_i = dims.spatial_dims[i];
5288       rhs_dilation.push_back(spatial_dim_i.stride);
5289       window_strides.push_back(dilations[dim]);
5290 
5291       // We will also need to pad the input with zeros such that after the
5292       // convolution, we get the right size for the filter.
5293       // The padded_in_rows should be such that when we convolve this with the
5294       // expanded_out_rows as a filter, we should get filter_rows back.
5295 
5296       const int64_t padded_in_size =
5297           spatial_dim_i.expanded_output_size +
5298           (spatial_dim_i.filter_size - 1) * dilations[dim];
5299 
5300       // However it can be smaller than input_rows: in this
5301       // case it means some of the inputs are not used.
5302       //
5303       // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
5304       //
5305       // INPUT =  [ A  B  C ]
5306       //
5307       // FILTER = [ x y ]
5308       //
5309       // and the output will only have one column: a = A * x + B * y
5310       //
5311       // and input "C" is not used at all.
5312       //
5313       // We apply negative padding in this case.
5314       const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
5315 
5316       // + For the EXPLICIT padding, we pad the top/left side with the explicit
5317       //   padding and pad the bottom/right side with the remaining space.
5318       // + For the VALID padding, we don't pad anything on the top/left side
5319       //   and pad the bottom/right side with the remaining space.
5320       // + For the SAME padding, we pad top/left side the same as bottom/right
5321       //   side.
5322       //
5323       // In addition, if the padded input size is smaller than the input size,
5324       // we need to ignore some trailing elements of the input. We do this by
5325       // applying negative padding on the right/bottom.
5326       const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
5327                                      ? explicit_paddings[2 * dim]
5328                                  : padding == tensorflow::Padding::SAME
5329                                      ? std::max<int64_t>(pad_total / 2, 0)
5330                                      : 0;
5331       paddings.push_back(pad_before);
5332       paddings.push_back(pad_total - pad_before);
5333     }
5334 
5335     RankedTensorType paddings_ty = RankedTensorType::get(
5336         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5337     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5338     auto kernel_spatial_dims_attr =
5339         GetI64ElementsAttr(kernel_spatial_dims, &rewriter);
5340 
5341     const int batch_dim =
5342         tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5343     auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
5344     auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
5345 
5346     Value result = rewriter.create<ConvOp>(
5347         op.getLoc(), op.getType(), op.input(), op.out_backprop(),
5348         /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
5349         /*padding=*/paddings_attr, /*lhs_dilation=*/
5350         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5351                                    &rewriter),
5352         GetI64ElementsAttr(rhs_dilation, &rewriter),
5353         /*window_reversal=*/nullptr,
5354         ConvDimensionNumbers::get(
5355             // Swap batch_dim and feature_dim in the activations.
5356             /*input_batch_dimension=*/feature_dim_attr,
5357             /*input_feature_dimension=*/batch_dim_attr,
5358             /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
5359             // The gradients become the RHS of the convolution.
5360             // The gradients have shape [batch, out_rows, out_cols, ...,
5361             // out_depth] where the batch becomes the input feature for the
5362             // convolution.
5363             /*kernel_input_feature_dimension=*/batch_dim_attr,
5364             /*kernel_output_feature_dimension=*/feature_dim_attr,
5365             /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
5366             /*output_batch_dimension=*/
5367             rewriter.getI64IntegerAttr(num_spatial_dims),
5368             /*output_feature_dimension=*/
5369             rewriter.getI64IntegerAttr(num_spatial_dims + 1),
5370             /*output_spatial_dimensions=*/
5371             GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter),
5372             rewriter.getContext()),
5373         /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
5374         rewriter.getI64IntegerAttr(batch_group_count),
5375         /*precision_config=*/ArrayAttr());
5376 
5377     rewriter.replaceOp(op, {result});
5378 
5379     return success();
5380   }
5381 };
5382 
5383 using ConvertConv2DBackpropFilterOp =
5384     ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
5385                                 /*num_spatial_dims=*/2>;
5386 using ConvertConv3DBackpropFilterOp =
5387     ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
5388                                 /*num_spatial_dims=*/3>;
5389 
5390 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
5391  public:
5392   using OpRewritePattern::OpRewritePattern;
5393 
5394   LogicalResult matchAndRewrite(TF::OneHotOp op,
5395                                 PatternRewriter &rewriter) const override {
5396     auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
5397     if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
5398     ArrayRef<int64_t> indices_shape = indices_ty.getShape();
5399     Type element_type = indices_ty.getElementType();
5400 
5401     DenseIntElementsAttr depth_attr;
5402     if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
5403       return failure();
5404     }
5405 
5406     int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
5407     int64_t axis = op.axis();
5408     if (axis == -1) axis = indices_shape.size();
5409 
5410     llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
5411     std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
5412     std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
5413 
5414     llvm::SmallVector<int64_t, 4> output_dims =
5415         llvm::to_vector<4>(indices_shape);
5416     output_dims.insert(output_dims.begin() + axis, depth);
5417 
5418     Location loc = op.getLoc();
5419 
5420     // The iota result is the effective output shape of the computation,
5421     // and indices must be broadcast into it. At this point, this computation
5422     // would need to be reworked quite a bit to support dynamic shapes, so
5423     // just using static broadcasting.
5424     auto index_type = RankedTensorType::get(output_dims, element_type);
5425     auto iota = rewriter.create<IotaOp>(
5426         loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
5427     auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
5428         loc, index_type, op.indices(),
5429         GetI64ElementsAttr(broadcast_dims, &rewriter));
5430 
5431     Value compare = rewriter.create<mhlo::CompareOp>(
5432         loc, broadcast_indices, iota,
5433         StringAttr::get(rewriter.getContext(), "EQ"));
5434     Value on_value = rewriter.create<BroadcastOp>(
5435         loc, op.getType(), op.on_value(),
5436         GetI64ElementsAttr(output_dims, &rewriter));
5437     Value off_value = rewriter.create<BroadcastOp>(
5438         loc, op.getType(), op.off_value(),
5439         GetI64ElementsAttr(output_dims, &rewriter));
5440     Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
5441                                              on_value, off_value);
5442 
5443     rewriter.replaceOp(op, {result});
5444 
5445     return success();
5446   }
5447 };
5448 
5449 // Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
5450 // get_tuple_element ops.
5451 //
5452 // All HLO infeed ops expect a HLO token type operand and produce a tuple
5453 // containing a token. This HLO token type is used to order multiple infeed
5454 // operations within a computation. The token type can come from other
5455 // infeed/outfeed/send/recv ops or can be generated using create_token op with
5456 // no operands. Here we emit a create_token op to generate the token type
5457 // operand of infeed.
5458 //
5459 // For example the following IR:
5460 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
5461 //
5462 // would be lowered to
5463 //
5464 // %token = "mhlo.create_token"() : () -> !mhlo.token
5465 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
5466 //      (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
5467 //      !mhlo.token>
5468 // %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0}
5469 // %0#0 = "mhlo.get_tuple_element"(%data) {index = 0}
5470 // %0#1 = "mhlo.get_tuple_element"(%data) {index = 1}
5471 //
5472 class ConvertInfeedDequeueTupleOp
5473     : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
5474  public:
5475   using OpRewritePattern::OpRewritePattern;
5476 
5477   FailureOr<std::vector<int64_t>> GetTPUInfeedLayoutFromAPI(
5478       RankedTensorType t) const {
5479     // Call the TPU API to determine the right infeed layout. Note that
5480     // this can fail if we're not running on a TPU-enabled node.
5481     // TODO(kramm): Move this into a separate pass. See b/181724526
5482     xla::Shape old_shape = xla::TypeToShape(t);
5483     XLA_Shape old_shape_c = {};
5484     XLA_Shape new_shape_c = {};
5485     TfTpu_ExecutorApiFn *executor = tensorflow::tpu::ExecutorApiFn();
5486     if (!tensorflow::tpu::IsInitialized(executor)) {
5487       return failure();
5488     }
5489     ApiConverter::ToC(old_shape, &old_shape_c);
5490     executor->TpuTransferManager_GetInfeedLayoutFn(&old_shape_c, &new_shape_c);
5491     xla::Shape new_shape = ApiConverter::FromC(&new_shape_c);
5492     ApiConverter::Free(&old_shape_c);
5493     ApiConverter::Free(&new_shape_c);
5494 
5495     xla::Layout layout = new_shape.layout();
5496     auto minor_to_major = layout.minor_to_major();
5497     return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
5498   }
5499 
5500   FailureOr<Attribute> GetLayout(const Type &type,
5501                                  PatternRewriter &rewriter) const {
5502     auto i64_type = rewriter.getIntegerType(64);
5503     if (type.isa<TupleType>()) {
5504       TupleType tuple_type = type.dyn_cast<TupleType>();
5505       std::vector<mlir::Attribute> v;
5506       for (const mlir::Type &t : tuple_type.getTypes()) {
5507         auto layout = GetLayout(t, rewriter);
5508         if (failed(layout)) return failure();
5509         v.push_back(layout.getValue());
5510       }
5511       ArrayRef<Attribute> shape(v);
5512       return rewriter.getArrayAttr(shape);
5513     } else if (RankedTensorType t = type.dyn_cast<RankedTensorType>()) {
5514       if (!t.hasStaticShape()) return failure();
5515       auto layout = GetTPUInfeedLayoutFromAPI(t);
5516       std::vector<int64_t> minor_to_major;
5517       if (succeeded(layout)) {
5518         minor_to_major = layout.getValue();
5519       } else {
5520         /* If we're not running on a TPU node, we might not be able to
5521          * actually call the part of the TPU API that gives us layout.
5522          * This happens e.g. for unit tests. Below we just create a reasonable
5523          * layout.  We sort by dimension size, which makes the layout agree with
5524          * the "correct" TPU layout in surprisingly many cases.
5525          * Note that the corresponding InfeedEnqueue op will be generated
5526          * through another path, and might still generate an (incompatible)
5527          * layout using the TPU API. Running legalize_tf.cc on non-TPU nodes
5528          * thus is a potential source of bugs.
5529          */
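        // For example (hypothetical shape): a tensor<4x128xf32> gets
        // minor_to_major = [1, 0] from this fallback, since dimension 1
        // (size 128) sorts before dimension 0 (size 4).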
5530         minor_to_major.resize(t.getRank());
5531         std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
5532         std::sort(minor_to_major.begin(), minor_to_major.end(),
5533                   [=](int64_t a, int64_t b) {
5534                     int da = t.getDimSize(a);
5535                     int db = t.getDimSize(b);
5536                     return da > db || (da == db && a > b);
5537                   });
5538       }
5539       std::vector<Attribute> elements;
5540       for (int64_t i = 0; i < minor_to_major.size(); i++) {
5541         elements.push_back(
5542             rewriter.getIntegerAttr(i64_type, minor_to_major[i]));
5543       }
5544       return rewriter.getArrayAttr(elements);
5545     } else {
5546       return rewriter.getUnitAttr();  // e.g. tokens
5547     }
5548   }
5549 
5550   LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
5551                                 PatternRewriter &rewriter) const override {
5552     std::vector<Type> result_types(op.outputs().size());
5553     for (auto idx_and_output : llvm::enumerate(op.outputs())) {
5554       result_types[idx_and_output.index()] = (idx_and_output.value().getType());
5555     }
5556     // Infeed takes a single token operand. Generate the token using
5557     // create_token op to pass to the infeed op.
5558     auto token = rewriter.create<CreateTokenOp>(
5559         op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
5560 
5561     // Emit infeed op.
5562     // The result type of infeed is a tuple(tuple(result types), token type).
5563     auto data_tuple_type =
5564         mlir::TupleType::get(rewriter.getContext(), result_types);
5565     auto data_and_token_type = mlir::TupleType::get(
5566         rewriter.getContext(), {data_tuple_type, token.getType()});
5567 
5568     auto layout = GetLayout(data_and_token_type, rewriter);
5569     if (failed(layout)) return failure();
5570 
5571     auto data_and_token = rewriter.create<InfeedOp>(
5572         op.getLoc(), data_and_token_type, token,
5573         /*infeed_config=*/rewriter.getStringAttr(""),
5574         /*layout=*/layout.getValue().cast<ArrayAttr>());
5575 
5576     if (op._XlaSharding().hasValue()) {
5577       // _XlaSharding attribute in TF is a serialized string of the OpSharding
5578       // proto, so convert to a text form here.
5579       ::xla::OpSharding sharding_proto;
5580       if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
5581         return failure();
5582 
5583       // Token is a control signal and not real data, so arbitrarily assign
5584       // the token to device 0.
5585       if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
5586         *sharding_proto.add_tuple_shardings() =
5587             ::xla::sharding_builder::AssignDevice(0);
5588         data_and_token->setAttr(
5589             kShardingAttr,
5590             rewriter.getStringAttr(sharding_proto.SerializeAsString()));
5591       } else {
5592         data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
5593       }
5594     }
5595 
5596     // The infeed instruction produces a tuple of the infeed data and a token
5597     // type. Emit get_tuple_element to get infeed data tuple.
5598     auto data_tuple = rewriter.create<GetTupleElementOp>(
5599         op.getLoc(), data_tuple_type, data_and_token,
5600         rewriter.getI32IntegerAttr(0));
5601 
5602     // Emit get_tuple_element for each result.
5603     std::vector<Value> results;
5604     for (auto idx_and_type : llvm::enumerate(result_types)) {
5605       auto tuple_element = rewriter.create<GetTupleElementOp>(
5606           op.getLoc(), idx_and_type.value(), data_tuple,
5607           rewriter.getI32IntegerAttr(idx_and_type.index()));
5608       results.push_back(tuple_element);
5609     }
5610     rewriter.replaceOp(op, ValueRange(results));
5611     return success();
5612   }
5613 };
5614 
5615 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
5616 // ops.
5617 //
5618 // XLA HLO outfeed op expects a token, which we generate by emitting an
5619 // create_token op.
5620 //
5621 // For example the following IR:
5622 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5623 //      ()
5624 //
5625 // would be lowered to
5626 //
5627 // %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5628 //      tuple<tensor<3xi32>, tensor<4xf32>>
5629 // %token = "mhlo.create_token"() : () -> !mhlo.token
5630 // %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
5631 //      (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
5632 //
5633 class ConvertOutfeedEnqueueTupleOp
5634     : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
5635  public:
5636   using OpRewritePattern::OpRewritePattern;
5637 
5638   LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
5639                                 PatternRewriter &rewriter) const override {
5640     auto token_type = mhlo::TokenType::get(rewriter.getContext());
5641     auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
5642     auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
5643     rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
5644                                /*outfeed_config=*/rewriter.getStringAttr(""));
5645     rewriter.eraseOp(op);
5646     return success();
5647   }
5648 };
5649 
5650 // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
5651 //
5652 // tf.TopKV2 sorts along last dimension of the input tensor and then returns
5653 // the top K components' values and indices. This is translated into a few
5654 // ops in XLA HLO: first generating an integer sequence for the indices,
5655 // then sorting both the original input tensor and the indices together, and
5656 // finally slicing out the top K components.
5657 //
5658 // For example, for the following IR:
5659 //
5660 // %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
5661 // %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
5662 //                                 (tensor<16x8xf32>, tensor<16x8xi32>)
5663 //
5664 // We will get:
5665 //
5666 // %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
5667 // %2 = "mhlo.sort"(%input, %1) ( {
5668 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
5669 //      %arg3: tensor<i32>, %arg4: tensor<i32>):
5670 //   %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
5671 //   "mhlo.return"(%7) : (tensor<i1>) -> ()
5672 // }) {dimension = 1 : i64, is_stable = true} : ...
5673 // %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
5674 // %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
5675 // %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
5676 //                           start_indices dense<0> : tensor<2xi64>,
5677 //                           strides = dense<1> : tensor<2xi64>} :
5678 //                              (tensor<16x16xf32>) -> tensor<16x8xf32>
5679 // %6 = "mhlo.slice"(%4) ...
5680 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
5681  public:
5682   using OpRewritePattern::OpRewritePattern;
5683 
5684   LogicalResult matchAndRewrite(TF::TopKV2Op op,
5685                                 PatternRewriter &rewriter) const override {
5686     // We can only match when the `k` operand is a constant scalar.
5687     DenseIntElementsAttr k_attr;
5688     if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
5689 
5690     // The last dimension of the input tensor's shape should be known so we can
5691     // have clamped end_indices for slices.
5692     TensorType input_type = op.input().getType().cast<TensorType>();
5693     if (!input_type.hasRank()) return failure();
5694     int64_t input_rank = input_type.getRank();
5695     int64_t last_dim_index = input_rank - 1;
5696     int64_t last_dim_size = input_type.getDimSize(last_dim_index);
5697     if (last_dim_size == ShapedType::kDynamicSize) return failure();
5698 
5699     // Create an Iota op for the indices.
5700     auto i32_type = rewriter.getIntegerType(32);
5701     Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
5702     Value iota_op = rewriter.create<mhlo::IotaOp>(
5703         op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
5704 
5705     // Create the sort op. It takes two inputs, one for the original input, the
5706     // other for the indices.
5707     auto sort_op = rewriter.create<mhlo::SortOp>(
5708         op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
5709         /*is_stable=*/true);
5710 
5711     // Use TOTALORDER comparison type instead of the default comparison if the
5712     // element type is of type float.
5713     llvm::Optional<StringRef> compare_type;
5714     if (input_type.getElementType().isa<FloatType>())
5715       compare_type.emplace("TOTALORDER");
5716     BuildSortComparisonBody({input_type.getElementType(), i32_type},
5717                             /*direction=*/"GT", compare_type,
5718                             &sort_op.comparator(), &rewriter);
5719 
5720     // Get the sorted input and index tuple element.
5721     auto tuple_first_element = sort_op.getResult(0);
5722     auto tuple_second_element = sort_op.getResult(1);
5723 
5724     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
5725     auto end_indices = llvm::to_vector<4>(input_type.getShape());
5726     end_indices.back() =
5727         std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
5728     SmallVector<int64_t, 4> strides(input_rank, 1);
5729 
5730     // Get the slice for the top K elements.
5731 
5732     Value values = rewriter.create<mhlo::SliceOp>(
5733         op.getLoc(), tuple_first_element,
5734         GetI64ElementsAttr(begin_indices, &rewriter),
5735         GetI64ElementsAttr(end_indices, &rewriter),
5736         GetI64ElementsAttr(strides, &rewriter));
5737 
5738     Value indices = rewriter.create<mhlo::SliceOp>(
5739         op.getLoc(), tuple_second_element,
5740         GetI64ElementsAttr(begin_indices, &rewriter),
5741         GetI64ElementsAttr(end_indices, &rewriter),
5742         GetI64ElementsAttr(strides, &rewriter));
5743 
5744     rewriter.replaceOp(op, {values, indices});
5745     return success();
5746   }
5747 };
5748 
5749 // Converts tf.Unpack to a series of XLA HLO slice ops.
5750 //
5751 // Each slice takes one element along the dimension to unpack and takes the full
5752 // range for all other dimensions. Each slice is then reshaped to drop the
5753 // dimension to unpack (which is always of size 1).
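//
// For example (an illustrative sketch; shapes chosen arbitrarily), the IR:
//
// %0:2 = "tf.Unpack"(%input) {axis = 1 : i64}
//     : (tensor<4x2x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>)
//
// would become two mhlo.slice ops producing tensor<4x1x6xf32> values (their
// start/limit indices differing only in dimension 1), each followed by a
// tf.Squeeze that drops the size-1 axis.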
5754 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
5755 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
5756  public:
5757   using OpRewritePattern::OpRewritePattern;
5758 
5759   LogicalResult matchAndRewrite(TF::UnpackOp op,
5760                                 PatternRewriter &rewriter) const override {
5761     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5762     if (!value_type) return failure();
5763 
5764     int64_t value_rank = value_type.getRank();
5765     int64_t axis = op.axis();
5766     if (axis < 0) axis += value_rank;
5767 
5768     // Parameters for constructing each slice.
5769     SmallVector<int64_t, 4> begin_indices(value_rank, 0);
5770     auto end_indices = llvm::to_vector<4>(value_type.getShape());
5771     SmallVector<int64_t, 4> strides(value_rank, 1);
5772 
5773     // All HLO slice+squeeze results used to replace the original tf.Unpack op.
5774     SmallVector<Value, 4> results;
5775     results.reserve(op.getNumResults());
5776 
5777     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
5778       begin_indices[axis] = i;
5779       end_indices[axis] = i + 1;
5780 
5781       auto slice_op = rewriter.create<mhlo::SliceOp>(
5782           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
5783           GetI64ElementsAttr(end_indices, &rewriter),
5784           GetI64ElementsAttr(strides, &rewriter));
5785       // Reshape to drop the axis dimension.
5786       auto result =
5787           rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
5788                                          rewriter.getI64ArrayAttr(op.axis()));
5789       results.push_back(result);
5790     }
5791 
5792     rewriter.replaceOp(op, results);
5793     return success();
5794   }
5795 };
5796 
5797 // Converts tf.Unpack with a dynamic input shape to XLA HLO dynamic slice ops.
5798 // TODO(disc): To recover static special case's performance with folding and
5799 // canonicalization.
5800 class ConvertUnpackOpDynamic : public OpRewritePattern<TF::UnpackOp> {
5801  public:
5802   using OpRewritePattern::OpRewritePattern;
5803 
5804   LogicalResult matchAndRewrite(TF::UnpackOp op,
5805                                 PatternRewriter &rewriter) const override {
5806     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5807     if (!value_type) return failure();
5808     // TODO(disc): Remove this constraint once fold and canonicalization
5809     // implemented.
5810     if (value_type.hasStaticShape()) return failure();
5811 
5812     int64_t value_rank = value_type.getRank();
5813     int64_t axis = op.axis();
5814     if (axis < 0) axis += value_rank;
5815     Location loc = op.getLoc();
5816 
5817     auto shape_scalar_type = rewriter.getIntegerType(32);
5818     // Parameters for constructing each slice.
5819     SmallVector<Value, 4> begin_indices, end_indices, strides;
5820     begin_indices.reserve(value_rank);
5821     end_indices.reserve(value_rank);
5822     strides.reserve(value_rank);
5823     // final output shape
5824     SmallVector<Value, 4> shape_values;
5825     shape_values.reserve(value_rank - 1);
5826     // Slice shape before the reshape; e.g. {?, 1, ?, ?} if axis = 1.
5827     SmallVector<int64_t, 4> slice_shape(value_rank, ShapedType::kDynamicSize);
5828     for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) {
5829       int64_t dim_size = value_type.getDimSize(dim_idx);
5830       if (dim_size == ShapedType::kDynamicSize) {
5831         Value dim_i = rewriter.create<IndexCastOp>(
5832             loc, rewriter.create<tensor::DimOp>(loc, op.getOperand(), dim_idx),
5833             shape_scalar_type);
5834         end_indices.push_back(dim_i);
5835         if (dim_idx != axis) {
5836           shape_values.push_back(dim_i);
5837         }
5838       } else {
5839         Value dim_i = rewriter.create<ConstantOp>(
5840             loc, shape_scalar_type,
5841             rewriter.getIntegerAttr(shape_scalar_type, dim_size));
5842         end_indices.push_back(dim_i);
5843         if (dim_idx != axis) {
5844           shape_values.push_back(dim_i);
5845           slice_shape[dim_idx] = dim_size;
5846         } else {
5847           slice_shape[dim_idx] = 1;
5848         }
5849       }
5850       begin_indices.push_back(rewriter.create<ConstantIntOp>(loc, 0, 32));
5851       strides.push_back(rewriter.create<ConstantIntOp>(loc, 1, 32));
5852     }
5853 
5854     SmallVector<Value, 4> results;
5855     results.reserve(op.getNumResults());
5856     for (int64_t i = 0; i < op.getNumResults(); ++i) {
5857       begin_indices[axis] = rewriter.create<ConstantIntOp>(loc, i, 32);
5858       end_indices[axis] = rewriter.create<ConstantIntOp>(loc, i + 1, 32);
5859       Value slice_op = rewriter.create<RealDynamicSliceOp>(
5860           loc, RankedTensorType::get(slice_shape, value_type.getElementType()),
5861           op.value(),
5862           rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5863                                                   begin_indices),
5864           rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5865                                                   end_indices),
5866           rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5867                                                   strides));
5868       // Reshape to drop the axis dimension.
5869       Value new_shape = rewriter.create<tensor::FromElementsOp>(
5870           loc, rewriter.getI32Type(), shape_values);
5871       Value reshape_op = rewriter.create<DynamicReshapeOp>(loc, op.getType(i),
5872                                                            slice_op, new_shape);
5873       results.push_back(reshape_op);
5874     }
5875 
5876     rewriter.replaceOp(op, results);
5877     return success();
5878   }
5879 };
5880 
5881 // Converts the tf.Sign op into mhlo.sign
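//
// Roughly (a sketch of the emitted ops), the result is
//   select(compare(x, x, NE), broadcast(0), sign(x))
// so that NaN inputs map to 0 and all other values map to mhlo.sign(x).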
5882 // TODO(disc): To recover static special case's performance with folding and
5883 // canonicalization.
5884 class ConvertSignOpDynamic : public OpRewritePattern<TF::SignOp> {
5885  public:
5886   using OpRewritePattern::OpRewritePattern;
5887   LogicalResult matchAndRewrite(TF::SignOp op,
5888                                 PatternRewriter &rewriter) const override {
5889     Location loc = op.getLoc();
5890     Value x = op.x();
5891     auto x_type = x.getType().dyn_cast<RankedTensorType>();
5892     if (!x_type) return failure();
5893     // TODO(disc): Remove this constraint once fold and canonicalization
5894     // implemented.
5895     if (x_type.hasStaticShape()) return failure();
5896 
5897     Value hlo_sign = rewriter.create<mhlo::SignOp>(loc, x);
5898     const StringAttr kNe = rewriter.getStringAttr(
5899         mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
5900     Value hlo_cmp = rewriter.create<mhlo::CompareOp>(loc, x, x, kNe);
5901 
5902     auto zero =
5903         GetScalarConstOfType(x_type.getElementType(), loc, 0, &rewriter);
5904     Value shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), x);
5905 
5906     auto broadcast_dims_attr =
5907         GetI64ElementsAttr(ArrayRef<int64_t>({}), &rewriter);
5908     Value broadcasted_zero = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
5909         loc, x_type, zero, shape_op, broadcast_dims_attr);
5910 
5911     auto hlo_select = rewriter.create<mhlo::SelectOp>(
5912         loc, hlo_cmp, broadcasted_zero, hlo_sign);
5913 
5914     rewriter.replaceOp(op, hlo_select.getResult());
5915     return success();
5916   }
5917 };
5918 
5919 // Converts the tf.SigmoidGradOp
5920 // TODO(disc): To recover static special case's performance with folding and
5921 // canonicalization.
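//
// Since y = sigmoid(x), the derivative is dy/dx = y * (1 - y), so the lowering
// below multiplies the incoming gradient dy by y * (1 - y) using broadcasted
// chlo ops.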
5922 class ConvertSigmoidGradOpDynamic : public OpRewritePattern<TF::SigmoidGradOp> {
5923  public:
5924   using OpRewritePattern::OpRewritePattern;
5925 
5926   LogicalResult matchAndRewrite(TF::SigmoidGradOp op,
5927                                 PatternRewriter &rewriter) const override {
5928     Location loc = op.getLoc();
5929     Value y = op.y();
5930     Value dy = op.dy();
5931     auto tp_y = y.getType().dyn_cast<RankedTensorType>();
5932     auto tp_dy = dy.getType().dyn_cast<RankedTensorType>();
5933     if (!tp_y || !tp_dy) return failure();
5934 
5935     // TODO(disc): Remove this constraint once fold and canonicalization
5936     // implemented.
5937     if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure();
5938 
5939     Attribute attr;
5940     Type elem_tp = tp_y.getElementType();
5941     if (elem_tp.isSignlessInteger()) {
5942       attr = rewriter.getIntegerAttr(elem_tp, 1);
5943     } else {
5944       assert(elem_tp.isa<FloatType>());
5945       attr = rewriter.getFloatAttr(elem_tp, 1);
5946     }
5947     Value one = rewriter.create<mhlo::ConstOp>(
5948         loc, DenseElementsAttr::get(RankedTensorType::get({}, elem_tp), attr));
5949 
5950     auto v0 = rewriter.create<chlo::BroadcastMulOp>(
5951         loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y));
5952     auto v1 = rewriter.create<chlo::BroadcastSubOp>(
5953         loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y));
5954     auto result = rewriter.create<chlo::BroadcastMulOp>(
5955         loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1));
5956 
5957     rewriter.replaceOp(op, result.getOperation()->getResults());
5958     return success();
5959   }
5960 };
5961 
5962 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
5963 //
5964 // TF unsorted segment reduction op performs the following calculation:
5965 //
5966 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
5967 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
5968 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
5969 // Dm+1, ..., Dn]. Then
5970 //   output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
5971 //      <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
5972 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
5973 //
5974 // The op will be translated to XLA HLO scatter with the following parameters:
5975 // * Update window dims is [segment_id_rank, data_rank).
5976 // * Inserted window dims is {0}.
5977 // * Scatter dims to operand dims mapping is {0}.
5978 // * Index vector dim is segment_id_rank.
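//
// For example (an illustrative case), tf.UnsortedSegmentSum over data of shape
// [4, 8] with segment_ids of shape [4] and num_segments = 3 produces an output
// of shape [3, 8]; the generated scatter then uses update window dims [1],
// inserted window dims [0], scatter dims to operand dims [0], and index vector
// dim 1.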
5979 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
5980 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
5981   using OpRewritePattern<OpTy>::OpRewritePattern;
5982 
5983   LogicalResult matchAndRewrite(OpTy op,
5984                                 PatternRewriter &rewriter) const override {
5985     auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
5986     if (!data_type) return failure();
5987     int64_t data_rank = data_type.getRank();
5988 
5989     auto segment_ids_type =
5990         op.segment_ids().getType().template dyn_cast<RankedTensorType>();
5991     if (!segment_ids_type) return failure();
5992     int64_t segment_ids_rank = segment_ids_type.getRank();
5993 
5994     DenseIntElementsAttr num_segments_attr;
5995     if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
5996       return failure();
5997 
5998     // The final shape for TF unsorted segment reduction op is [num_segments] +
5999     // data_shape[segment_ids_rank:].
6000     SmallVector<int64_t, 4> output_shape;
6001     output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
6002     auto suffix = data_type.getShape().drop_front(segment_ids_rank);
6003     output_shape.append(suffix.begin(), suffix.end());
6004     auto output_type =
6005         RankedTensorType::get(output_shape, data_type.getElementType());
6006 
6007     // Broadcast the initial value for the reduction. This will become the
6008     // 'operand' parameter of the final scatter op.
6009     Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
6010                                                 op.getLoc(), &rewriter);
6011     auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
6012         op.getLoc(), output_type, init,
6013         GetI64ElementsAttr(output_shape, &rewriter));
6014 
6015     // Parameters for the generated scatter op.
6016     SmallVector<int64_t, 1> inserted_window_dims(1, 0);
6017     SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
6018     int64_t index_vector_dim = segment_ids_rank;
6019 
6020     // Put all parameters in a StructAttr.
6021     auto dims_attr = ScatterDimensionNumbers::get(
6022         GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
6023         GetI64ElementsAttr(inserted_window_dims, &rewriter),
6024         GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
6025         rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
6026 
6027     auto scatter =
6028         rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
6029                                    op.segment_ids(), op.data(), dims_attr);
6030     BuildReduceBody<ReductionOp>(data_type.getElementType(),
6031                                  &scatter.update_computation(), &rewriter);
6032 
6033     rewriter.replaceOp(op, scatter.getResult());
6034     return success();
6035   }
6036 };
6037 
6038 class ConvertUnsortedSegmentMaxOp
6039     : public GenericConvertUnsortedSegmentReductionOp<
6040           ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
6041  public:
6042   using GenericConvertUnsortedSegmentReductionOp::
6043       GenericConvertUnsortedSegmentReductionOp;
6044 
6045   static Value GetInitialValue(Type reduce_element_type, Location loc,
6046                                PatternRewriter *rewriter) {
6047     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
6048                                      rewriter);
6049   }
6050 };
6051 
6052 class ConvertUnsortedSegmentMinOp
6053     : public GenericConvertUnsortedSegmentReductionOp<
6054           ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
6055  public:
6056   using GenericConvertUnsortedSegmentReductionOp::
6057       GenericConvertUnsortedSegmentReductionOp;
6058 
6059   static Value GetInitialValue(Type reduce_element_type, Location loc,
6060                                PatternRewriter *rewriter) {
6061     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
6062                                      rewriter);
6063   }
6064 };
6065 
6066 class ConvertUnsortedSegmentProdOp
6067     : public GenericConvertUnsortedSegmentReductionOp<
6068           ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
6069  public:
6070   using GenericConvertUnsortedSegmentReductionOp::
6071       GenericConvertUnsortedSegmentReductionOp;
6072 
6073   static Value GetInitialValue(Type reduce_element_type, Location loc,
6074                                PatternRewriter *rewriter) {
6075     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
6076   }
6077 };
6078 
6079 class ConvertUnsortedSegmentSumOp
6080     : public GenericConvertUnsortedSegmentReductionOp<
6081           ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
6082  public:
6083   using GenericConvertUnsortedSegmentReductionOp::
6084       GenericConvertUnsortedSegmentReductionOp;
6085 
6086   static Value GetInitialValue(Type reduce_element_type, Location loc,
6087                                PatternRewriter *rewriter) {
6088     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
6089   }
6090 };
6091 
6092 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
6093 //
6094 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
6095 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
6096 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
6097 // translated into an HLO while op to first emulate shuffling indices using
6098 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
6099 // with the shuffled indices.
6100 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
6101  public:
6102   using OpRewritePattern::OpRewritePattern;
6103 
6104   LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
6105                                 PatternRewriter &rewriter) const override {
6106     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
6107     if (!input_type) return failure();
6108 
6109     int64_t input_rank = input_type.getRank();
6110     int64_t first_dim_size = input_type.getDimSize(0);
6111     if (ShapedType::isDynamic(first_dim_size)) return failure();
6112 
6113     // We are shuffling along the first dimension. If its size is <= 1, then
6114     // shuffling is a no-op.
6115     if (first_dim_size <= 1) {
6116       rewriter.replaceOp(op, op.value());
6117       return success();
6118     }
6119 
6120     // For vectors, shuffle values by sorting instead of the obvious
6121     // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
6122     // but not easily parallelizable. For a sufficiently parallel architecture,
6123     // it is faster to sort many times than to Fisher-Yates shuffle once.
6124     if (input_rank == 1) {
6125       // Shuffle values by assigning each value a random key and sorting the
6126       // keys. Keys can collide causing detectable patterns in the shuffled
6127       // output. Collisions translate into more ascending sub-sequences in the
6128       // shuffled output than would be expected by chance. To avoid collisions,
6129       // the number of possible key values must be sufficiently large.
6130 
6131       // How are more than 2^32 keys created? In each loop iteration, the
6132       // algorithm sorts by random keys. Conceptually, the earlier iterations
6133       // are sorting on the lower-order bits of larger keys that are never
6134       // actually assembled.
6135 
6136       // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
6137       // the number of possible keys and n is the number of values. If d = n^2,
6138       // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
6139       // as n goes to infinity is zero.
6140 
6141       // This implementation ensures that the key-space is greater than or equal
6142       // to the cube of the number of values. The risk of collisions can be
6143       // further reduced by increasing Exponent at the expense of
6144       // performance.
6145 
6146       // For Exponent = 2, the expected number of collisions per shuffle is
6147       // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
6148       // about 1/2.
6149 
6150       // For Exponent = 3, the expected number of collisions per shuffle is
6151       // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
6152       // about 1/3255.
6153 
6154       // For Exponent = 4, the expected number of collisions per shuffle is
6155       // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
6156       // about 1/132622.
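      // As an illustrative data point: with exponent = 3 and n = 1,000,000
      // values, rounds = ceil(3 * ln(1e6) / ln(2^32 - 1)) = 2, i.e. two
      // key-generation + sort passes below.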
6157       constexpr int exponent = 3;
6158       int64_t num_elements = input_type.getNumElements();
6159       uint32_t u32_max = std::numeric_limits<uint32_t>::max();
6160       int rounds =
6161           std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
6162 
6163       Value current = op.value();
6164       for (int i = 0; i < rounds; ++i) {
6165         auto keys =
6166             CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
6167                                /*upper_limit=*/u32_max, &rewriter);
6168         auto sorted = rewriter.create<mhlo::SortOp>(
6169             op.getLoc(), llvm::ArrayRef<Value>{keys, current});
6170         auto i32_type = rewriter.getIntegerType(32);
6171         BuildSortComparisonBody({i32_type, input_type.getElementType()},
6172                                 /*direction=*/"LT", llvm::None,
6173                                 &sorted.comparator(), &rewriter);
6174         current = sorted.getResult(1);
6175       }
6176       rewriter.replaceOp(op, current);
6177       return success();
6178     }
6179 
6180     // The Fisher-Yates algorithm.
6181 
6182     // Generate range(n) as the initial value for the indices to be swapped.
6183     auto indices_type =
6184         RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
6185     Value indices = rewriter.create<mhlo::IotaOp>(
6186         op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
6187 
6188     // Generate random numbers to be used as swaps for the indices.
6189     Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
6190                                      first_dim_size, &rewriter);
6191 
6192     // While loop body to perform index swaps.
6193     auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
6194                             SmallVectorImpl<Value> *new_values,
6195                             OpBuilder *builder) {
6196       Value swaps = old_values[0];
6197       Value indices = old_values[1];
6198 
6199       auto vec1_i32_type =
6200           RankedTensorType::get({1}, builder->getIntegerType(32));
6201       auto scalar_i32_type =
6202           RankedTensorType::get({}, builder->getIntegerType(32));
6203       auto scalar_i64_type =
6204           RankedTensorType::get({}, builder->getIntegerType(64));
6205 
6206       auto scalar_one =
6207           DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
6208 
6209       // We need to swap the indices[i] with indices[swaps[i]]. First get
6210       // these index values.
6211       Value source_index = builder->create<mhlo::DynamicSliceOp>(
6212           loc, vec1_i32_type, indices, i, scalar_one);
6213       Value swap_index = builder->create<mhlo::ReshapeOp>(
6214           loc, scalar_i32_type,
6215           builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
6216                                                 scalar_one));
6217       Value target_index = builder->create<mhlo::DynamicSliceOp>(
6218           loc, vec1_i32_type, indices, swap_index, scalar_one);
6219 
6220       // Then perform the swap.
6221       // indices[i] <- indices[swaps[i]]
6222       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
6223           loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
6224       // indices[swaps[i]] <- indices[i]
6225       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
6226           loc, indices.getType(), indices, source_index,
6227           llvm::makeArrayRef(swap_index));
6228 
6229       // Update new values.
6230       new_values->assign({swaps, indices});
6231     };
6232 
6233     // Create a while op to swap indices.
6234     SmallVector<Value, 2> while_output;
6235     CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
6236                   &while_output, &rewriter);
6237     Value swapped_indices = while_output[1];
6238 
6239     // Gather the data using the swapped indices as the shuffled order.
6240     ArrayRef<int64_t> input_shape = input_type.getShape();
6241     SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
6242     slice_sizes[0] = 1;
6243     auto dims_attr = GatherDimensionNumbers::get(
6244         /*offset_dims=*/GetI64ElementsAttrForSeq(1, input_rank, &rewriter),
6245         /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
6246         /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
6247         /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
6248         rewriter.getContext());
6249     rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
6250         op, op.getType(), op.value(), swapped_indices, dims_attr,
6251         GetI64ElementsAttr(slice_sizes, &rewriter));
6252 
6253     return success();
6254   }
6255 };
6256 
6257 // Converts an XlaSharding op to an XLA HLO Sharding custom call with sharding attributes.
6258 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
6259  public:
6260   using OpRewritePattern::OpRewritePattern;
6261 
6262   LogicalResult matchAndRewrite(TF::XlaShardingOp op,
6263                                 PatternRewriter &rewriter) const override {
6264     // TODO(b/148313088): define sharding attribute struct in MLIR instead of
6265     // using a string.
6266     if (!op._XlaSharding().hasValue()) return failure();
6267 
6268     auto custom_call = rewriter.create<mhlo::CustomCallOp>(
6269         op.getLoc(), op.getType(), op.input(),
6270         /*call_target_name=*/rewriter.getStringAttr("Sharding"),
6271         /*has_side_effect=*/rewriter.getBoolAttr(false),
6272         /*backend_config=*/rewriter.getStringAttr(""),
6273         /*api_version=*/
6274         mhlo::CustomCallApiVersionAttr::get(
6275             rewriter.getContext(),
6276             mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL));
6277     custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
6278     rewriter.replaceOp(op, custom_call.getResult(0));
6279 
6280     return success();
6281   }
6282 };
6283 
6284 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
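//
// For example (illustrative shapes), with x : tensor<4x3xf32>, i : tensor<2xi32>
// and v : tensor<2x3xf32>, the rows of v are split apart and written into x by
// two chained mhlo.dynamic_update_slice ops, each placing one 1x3 row at the
// row index taken from i.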
6285 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
6286  public:
6287   using OpRewritePattern::OpRewritePattern;
6288 
6289   LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
6290                                 PatternRewriter &rewriter) const override {
6291     auto input = op.x();
6292     auto indices = op.i();
6293     auto updates = op.v();
6294 
6295     // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
6296     // on the contents of `x`.
6297     auto input_type = input.getType().cast<ShapedType>();
6298     auto updates_type = updates.getType().cast<ShapedType>();
6299     auto indices_type = indices.getType().cast<ShapedType>();
6300     if (!indices_type.hasStaticShape()) return failure();
6301 
6302     if (indices_type.getRank() != 1) return failure();
6303 
6304     SmallVector<Type, 4> unpacked_indices_type(
6305         indices_type.getDimSize(0),
6306         RankedTensorType::get({}, indices_type.getElementType()));
6307     // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
6308     // required to have matching types. This rewrite rule creates
6309     // DynamicUpdateSlice ops where the first "start index" is always i32 and
6310     // subsequent ones are constructed based on zero_attr. Thus the type
6311     // for zero_attr needs to be i32 as well.
6312     auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
6313     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6314         op.getLoc(), unpacked_indices_type, indices, zero_attr);
6315 
6316     SmallVector<int64_t, 4> split_updates_shape;
6317     split_updates_shape.append(updates_type.getShape().begin(),
6318                                updates_type.getShape().end());
6319     split_updates_shape.front() = 1;
6320     SmallVector<Type, 4> split_updates_type;
6321     split_updates_type.resize(
6322         updates_type.getShape().front(),
6323         RankedTensorType::get(split_updates_shape,
6324                               updates_type.getElementType()));
6325 
6326     auto cst =
6327         rewriter.create<mhlo::ConstOp>(op.getLoc(), zero_attr).getResult();
6328     auto split_updates = rewriter.create<TF::SplitOp>(
6329         op.getLoc(), split_updates_type, cst, updates);
6330 
6331     SmallVector<Value, 6> input_indices;
6332     input_indices.resize(input_type.getRank(), cst);
6333 
6334     SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
6335     SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
6336     SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
6337                                    updates_type.getShape().end());
6338 
6339     for (auto pair :
6340          llvm::zip(unpacked_indices.output(), split_updates.output())) {
6341       input_indices.front() = std::get<0>(pair);
6342       input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
6343           op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
6344     }
6345 
6346     rewriter.replaceOp(op, input);
6347     return success();
6348   }
6349 };
6350 
6351 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
6352 class ConvertXlaDynamicUpdateSliceOp
6353     : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
6354  public:
6355   using OpRewritePattern::OpRewritePattern;
6356 
6357   LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
6358                                 PatternRewriter &rewriter) const override {
6359     auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
6360     if (!indices_type || !indices_type.hasStaticShape() ||
6361         indices_type.getShape().size() != 1)
6362       return failure();
6363 
6364     SmallVector<Type, 4> unpacked_indices_type(
6365         indices_type.getDimSize(0),
6366         RankedTensorType::get({}, indices_type.getElementType()));
6367     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6368         op.getLoc(), unpacked_indices_type, op.indices(),
6369         IntegerAttr::get(rewriter.getIntegerType(64), 0));
6370     rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
6371         op, op.getType(), op.input(), op.update(), unpacked_indices.output());
6372     return success();
6373   }
6374 };
6375 
6376 // Converts a TF XlaAllReduce op to AllReduce HLO.
6377 class ConvertXlaAllReduceOp : public OpRewritePattern<TF::XlaAllReduceOp> {
6378   using OpRewritePattern::OpRewritePattern;
6379 
6380   LogicalResult matchAndRewrite(TF::XlaAllReduceOp op,
6381                                 PatternRewriter &rewriter) const override {
6382     DenseIntElementsAttr group_assignment;
6383     if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
6384       return failure();
6385     auto replica_groups =
6386         hlo::ConvertElementsAttr(group_assignment, rewriter.getIntegerType(64))
6387             .cast<DenseIntElementsAttr>();
6388     if (replica_groups.getType().getRank() != 2) return failure();
6389 
6390     Location loc = op.getLoc();
6391     Type element_type = getElementTypeOrSelf(op.input().getType());
6392 
6393     auto all_reduce = rewriter.create<AllReduceOp>(
6394         loc, op.getType(), op.input(), replica_groups, ChannelHandle());
6395     StringRef reduce_op = op.reduce_op();
6396     if (reduce_op == "Add") {
6397       BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
6398                              &rewriter);
6399     } else if (reduce_op == "Mul") {
6400       BuildReduceBody<MulOp>(element_type, &all_reduce.computation(),
6401                              &rewriter);
6402     } else if (reduce_op == "Min") {
6403       BuildReduceBody<MinOp>(element_type, &all_reduce.computation(),
6404                              &rewriter);
6405     } else if (reduce_op == "Max") {
6406       BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(),
6407                              &rewriter);
6408     } else {
6409       // For Mean, first sum the inputs across replicas in the same group; the
6410       // sum is then divided by the number of replicas in each group below.
6411       assert(reduce_op == "Mean");
6412       BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
6413                              &rewriter);
6414     }
6415     Value result = all_reduce.getResult();
6416 
6417     // For mean, divide the merge result by group size.
6418     if (reduce_op == "Mean") {
6419       int64_t replica_group_size = replica_groups.getType().getDimSize(1);
6420       auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
6421                                           &rewriter);
6422       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
6423       result = rewriter.create<chlo::BroadcastDivOp>(
6424           loc, result, divisor.getResult(), broadcast_dims);
6425     }
6426 
6427     rewriter.replaceOp(op, {result});
6428     return success();
6429   }
6430 };
6431 
6432 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
6433 // semantics for static and dynamic cases.
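//
// For example (illustrative), clipping t : tensor<4x4xf32> with scalar
// clip_value_min / clip_value_max first broadcasts both scalars to
// tensor<4x4xf32> with tf.BroadcastTo (driven by tf.Shape of t) and then emits
// mhlo.clamp(min, t, max).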
6434 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
6435  public:
6436   using OpRewritePattern::OpRewritePattern;
6437 
6438   LogicalResult matchAndRewrite(TF::ClipByValueOp op,
6439                                 PatternRewriter &rewriter) const override {
6440     Value input = op.t();
6441     Value min = op.clip_value_min();
6442     Value max = op.clip_value_max();
6443 
6444     auto input_ty = input.getType().cast<ShapedType>();
6445     auto min_ty = min.getType().cast<ShapedType>();
6446     auto max_ty = max.getType().cast<ShapedType>();
6447 
6448     if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
6449       return failure();
6450     }
6451 
6452     auto shape = rewriter.create<TF::ShapeOp>(
6453         op.getLoc(),
6454         RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
6455         input);
6456 
6457     if (min_ty != input_ty) {
6458       min =
6459           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
6460     }
6461 
6462     if (max_ty != input_ty) {
6463       max =
6464           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
6465     }
6466 
6467     rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
6468     return success();
6469   }
6470 };
6471 
6472 // Converts ConstOp to XLA's constant operation and introduces a tensor cast if
6473 // needed.
6474 class ConvertConstOp : public OpRewritePattern<TF::ConstOp> {
6475  public:
6476   using OpRewritePattern::OpRewritePattern;
6477 
6478   LogicalResult matchAndRewrite(TF::ConstOp op,
6479                                 PatternRewriter &rewriter) const override {
6480     // Convert only for valid HLO tensors.
6481     auto ty = op.getType().dyn_cast<TensorType>();
6482     if (!ty || !ty.getElementType().isa<FloatType, IntegerType, ComplexType>())
6483       return failure();
6484 
6485     Location loc = op.getLoc();
6486     Value result = rewriter.create<mhlo::ConstOp>(loc, op.value());
6487     if (result.getType() != op.getType())
6488       result = rewriter.create<tensor::CastOp>(loc, op.getType(), result);
6489     rewriter.replaceOp(op, result);
6490     return success();
6491   }
6492 };
6493 
6494 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
6495 // setting appropriate window dimensions, with the given aggregation op as the
6496 // reduction function. The input tensor needs to have a static shape, and 'axis'
6497 // must be const. The TableGen pattern is not used for this rewrite because it
6498 // involves regions.
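//
// For example (illustrative), a non-exclusive, non-reversed Cumsum over a
// tensor<8xf32> along axis 0 maps to a reduce_window with window_dimensions =
// [8], window_strides = [1] and padding = [[7, 0]], so output element i is the
// sum of input elements 0..i.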
6499 template <typename OpT, typename AggregationOp>
6500 class ConvertCumOp : public OpRewritePattern<OpT> {
6501   using OpRewritePattern<OpT>::OpRewritePattern;
6502 
6503   LogicalResult matchAndRewrite(OpT op,
6504                                 PatternRewriter &rewriter) const override {
6505     auto input = op.x();
6506     auto input_type = input.getType().template dyn_cast<ShapedType>();
6507     if (!input_type || !input_type.hasStaticShape()) {
6508       return failure();
6509     }
6510 
6511     ArrayRef<int64_t> input_shape = input_type.getShape();
6512     int64_t rank = input_shape.size();
6513 
6514     // We can only match when the axis is a constant scalar.
6515     DenseIntElementsAttr axis_attr;
6516     if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
6517       return failure();
6518     }
6519 
6520     // Get the dimension to apply the reduction on, and offset properly if it is
6521     // negative.
6522     int64_t axis = (*axis_attr.begin()).getSExtValue();
6523     if (axis < 0) {
6524       axis += rank;
6525     }
6526 
6527     // If we're supposed to sum things up in the reverse direction, we reverse
6528     // the input and then later reverse the output.
6529     if (op.reverse()) {
6530       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6531       input = rewriter.create<ReverseOp>(
6532           op.getLoc(), op.getType(), input,
6533           GetI64ElementsAttr(dims_to_reverse, &rewriter));
6534     }
6535 
6536     // Convert if we need to enlarge the element type's bitwidth to avoid
6537     // precision loss.
6538     Type input_element_type = input_type.getElementType();
6539 
6540     // TODO(hinsu): Handle complex element types.
6541     if (!input_element_type.isIntOrFloat()) return failure();
6542 
6543     Type sum_element_type = GetSumAccumulationType(input_element_type);
6544     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
6545 
6546     SmallVector<int64_t, 4> window_dims(rank, 1);
6547     SmallVector<int64_t, 4> window_strides(rank, 1);
6548     window_dims[axis] = input_shape[axis];
6549 
6550     SmallVector<int64_t, 8> paddings(rank * 2, 0);
6551     paddings[axis * 2] =
6552         std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
6553     auto paddings_attr = DenseIntElementsAttr::get(
6554         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
6555         paddings);
6556 
6557     int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
6558     Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
6559                                       &rewriter);
6560 
6561     auto reduce = rewriter.create<ReduceWindowOp>(
6562         op.getLoc(), input_type, input, init,
6563         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
6564         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
6565         /*base_dilations=*/DenseIntElementsAttr(),
6566         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
6567     BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
6568     Value result = reduce.getResult(0);
6569 
6570     if (op.exclusive()) {
6571       // In "exclusive" mode, the output starts with the "init" value (0 for
6572       // Cumsum, 1 for Cumprod). There is no way to express that as a
6573       // ReduceWindowOp, so run the normal operation, then use a PadOp to add the
6574       // init "column" on the left and cut away the last column on the right.
6575       llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
6576       llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
6577       llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
6578       low_padding[axis] = 1;
6579       high_padding[axis] = -1;
6580       result = rewriter.create<PadOp>(
6581           op.getLoc(), op.getType(), result, init,
6582           GetI64ElementsAttr(low_padding, &rewriter),
6583           GetI64ElementsAttr(high_padding, &rewriter),
6584           GetI64ElementsAttr(interior_padding, &rewriter));
6585     }
6586 
6587     // Convert back if we enlarged the element type's bitwidth.
6588     result =
6589         rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
6590 
6591     if (op.reverse()) {
6592       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6593       result = rewriter.create<ReverseOp>(
6594           op.getLoc(), op.getType(), result,
6595           GetI64ElementsAttr(dims_to_reverse, &rewriter));
6596     }
6597 
6598     rewriter.replaceOp(op, result);
6599     return success();
6600   }
6601 };
6602 
6603 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
6604 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
6605 
6606 // Converts the TensorFlow ShapeOp to a sequence of Shape dialect and Standard
6607 // dialect lowerings. This involves extracting the shape type, extracting and
6608 // converting each dimension to a known integer type, and repacking into a final
6609 // tensor.
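//
// For example (illustrative), %s = "tf.Shape"(%arg) : (tensor<?x4xf32>) ->
// tensor<2xi32> becomes a shape.shape_of producing tensor<2xindex>, followed
// by an index_cast to tensor<2xi32>.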
6610 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
6611  public:
6612   using OpRewritePattern::OpRewritePattern;
6613 
6614   LogicalResult matchAndRewrite(TF::ShapeOp op,
6615                                 PatternRewriter &rewriter) const override {
6616     Value input = op.input();
6617 
6618     auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
6619     if (!result_ty) {
6620       return failure();
6621     }
6622 
6623     auto index_tensor =
6624         RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
6625     auto shape_op =
6626         rewriter.create<shape::ShapeOfOp>(op.getLoc(), index_tensor, input);
6627     rewriter.replaceOpWithNewOp<IndexCastOp>(op, shape_op, result_ty);
6628     return success();
6629   }
6630 };
6631 
6632 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
6633  public:
6634   using OpRewritePattern::OpRewritePattern;
6635 
6636   LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
6637                                 PatternRewriter &rewriter) const override {
6638     auto input = op.input();
6639     auto input_ty = input.getType().cast<ShapedType>();
6640     auto result_ty = op.getType().cast<ShapedType>();
6641     if (!result_ty.hasRank() || !input_ty.hasRank() ||
6642         result_ty.hasStaticShape()) {
6643       return failure();
6644     }
6645 
6646     DenseIntElementsAttr expand_dims_attr;
6647     if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
6648       return failure();
6649     }
6650 
6651     auto shape = rewriter.create<shape::ShapeOfOp>(
6652         op.getLoc(),
6653         RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
6654         input);
6655     auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues());
6656 
6657     llvm::SmallVector<Value, 4> dims;
6658     dims.resize(result_ty.getRank());
6659 
6660     auto inserted_dim = expand_dims[0].getSExtValue();
6661 
6662     // Handle the negative value use case.
6663     if (inserted_dim < 0) {
6664       inserted_dim += result_ty.getRank();
6665       // The value is out of range even after offsetting, so fail to match.
6666       if (inserted_dim < 0) {
6667         return failure();
6668       }
6669     }
6670 
6671     dims[inserted_dim] = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
6672 
6673     for (int i = 0; i < dims.size() - 1; i++) {
6674       // Add the extracted dim.
6675       Value index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
6676       Value dim = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape, index);
6677       dims[i >= inserted_dim ? i + 1 : i] = dim;
6678     }
6679 
6680     auto from_extents =
6681         rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6682     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
6683                                                         from_extents);
6684     return success();
6685   }
6686 };
6687 
6688 // Converts a TF QR op to HLO.
6689 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
6690  public:
6691   using OpRewritePattern::OpRewritePattern;
6692 
6693   LogicalResult matchAndRewrite(TF::QrOp op,
6694                                 PatternRewriter &rewriter) const override {
6695     // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van
6696     // Loan.
    // def qr_blocked(a, block_size):
6697     //   m = a.shape[0]
6698     //   n = a.shape[1]
6699     //   q = np.eye(m)
6700     //   for i in xrange(0, min(m, n), block_size):
6701     //     k = min(block_size, min(m, n) - i)
6702     //     (a, vs, taus) = qr(a[i:, i:i+k])
6703     //     y = vs
6704     //     w = ComputeWYRepresentation(vs, taus, m-i, k)
6705     //     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
6706     //     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
6707     //   return (q, a)
6708     auto type = op.input().getType().dyn_cast<RankedTensorType>();
6709     if (!type || !type.hasStaticShape()) return failure();
6710     // The block size is chosen to match old bridge lowering.
6711     constexpr int64_t kBlockSize = 128;
6712     Value a = op.input();
6713     int64_t m = type.getDimSize(type.getRank() - 2);
6714     int64_t n = type.getDimSize(type.getRank() - 1);
6715     int64_t p = std::min(m, n);
6716     auto batch_dims = type.getShape().drop_back(2);
6717     auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
6718     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6719                                          rewriter.getI64IntegerAttr(0));
6720     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6721                                          rewriter.getI64IntegerAttr(1));
6722     Value compare = rewriter.create<CompareOp>(
6723         op.getLoc(), iota0, iota1,
6724         StringAttr::get(rewriter.getContext(), "EQ"));
6725     Value identity_matrix =
6726         rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
6727     auto q_shape = llvm::to_vector<4>(type.getShape());
6728     q_shape.back() = m;
6729     Value q = rewriter.create<BroadcastOp>(
6730         op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
6731         identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
6732     auto precision_config = rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"});
6733     for (int64_t i = 0; i < p; i += kBlockSize) {
6734       int64_t k = std::min(kBlockSize, p - i);
6735       auto a_block =
6736           SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
6737       Value r_block;
6738       Value taus;
6739       Value vs;
6740       QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
6741       a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
6742 
6743       // Compute the I-WY block representation of a product of Householder
6744       // matrices.
6745       Value w =
6746           ComputeWYRepresentation(op.getLoc(), type.getElementType(),
6747                                   batch_dims, vs, taus, m - i, k, &rewriter);
6748       auto y = vs;
6749 
6750       // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
6751       Value a_panel =
6752           SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
6753       auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
6754                                batch_dims.size(), precision_config, &rewriter);
6755       a_update = BatchDot(op.getLoc(), y, false, a_update, false,
6756                           batch_dims.size(), precision_config, &rewriter);
6757       a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
6758       a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
6759                                  &rewriter);
6760 
6761       // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
6762       Value q_panel =
6763           SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
6764       Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
6765                                 batch_dims.size(), precision_config, &rewriter);
6766       q_update = BatchDot(op.getLoc(), q_update, false, y, true,
6767                           batch_dims.size(), precision_config, &rewriter);
6768       q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
6769       q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
6770     }
6771     // full_matrices is false when only a partial result is needed. Slice to the
6772     // needed dimensions here.
6773     if (!op.full_matrices()) {
6774       q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
6775       a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
6776     }
6777     rewriter.replaceOp(op, {q, a});
6778     return success();
6779   }
6780 
6781  private:
6782   // Computes a Householder reflection of the form:
6783   // H = I - tau v v.T.
6784   // such that
6785   // H . ( x1  ) = ( x1   )
6786   //     ( x2  )   ( x2   )
6787   //     ( ... )   ( ...  )
6788   //     ( xk  )   ( beta )
6789   //     ( ... )   ( 0    )
6790   //     ( ... )   ( 0    )
6791   // Unlike the usual formulation, we allow the caller to supply 'k' rather than
6792   // only providing the relevant part of 'x' to maintain XLA's static shape
6793   // invariant. In addition, the implementation supports batching.
6794   // Pseudo-code, without batching:
6795   //   alpha = x[k]
6796   //   x_copy = np.copy(x)
6797   //   x_copy[:k+1] = 0
6798   //   xnorm = norm2(x_copy)
6799   //   if xnorm == 0:
6800   //     beta = alpha
6801   //     tau = 0
6802   //     v = np.zeros_like(x)
6803   //   else:
6804   //     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
6805   //     tau = (beta - alpha) / beta
6806   //     v = x / (alpha - beta)
6807   //   v[k] = 1
6808   //   return (v, tau, beta)
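  // The pseudo-code above transcribes into a runnable (unbatched) NumPy sketch;
  // this is illustrative only, with np.hypot standing in for dlapy2 and v[:k]
  // zeroed as in the implementation below:
  //   import numpy as np
  //   def house(x, k):
  //     alpha = x[k]
  //     x_after_k = np.copy(x)
  //     x_after_k[:k + 1] = 0          # keep only x[k+1:]
  //     xnorm = np.linalg.norm(x_after_k)
  //     if xnorm == 0:
  //       beta, tau, v = alpha, 0.0, np.zeros_like(x)
  //     else:
  //       beta = -np.sign(alpha) * np.hypot(alpha, xnorm)
  //       tau = (beta - alpha) / beta
  //       v = x_after_k / (alpha - beta)
  //     v[k] = 1
  //     return v, tau, beta
  //   # With H = np.eye(len(x)) - tau * np.outer(v, v): (H @ x)[k] == beta,
  //   # (H @ x)[k+1:] vanishes, and entries before k are left unchanged.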
6809   void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
6810              const int64_t m, OpBuilder *builder, Value *v, Value *tau,
6811              Value *beta) const {
6812     auto x_type = x.getType().cast<RankedTensorType>();
6813 
6814     llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
6815     std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
6816     const int64_t minor_dim = batch_dims.size();
6817 
6818     Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
6819     Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
6820 
6821     // alpha = x[k]
6822     Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
6823     alpha = builder->create<ReshapeOp>(
6824         loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
6825 
6826     // Compute x[k+1:] (padded with zeros in elements 0..k)
6827     Value iota = builder->create<IotaOp>(
6828         loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
6829         builder->getI64IntegerAttr(0));
6830     Value gtk = builder->create<chlo::BroadcastCompareOp>(
6831         loc, iota, k, GetI64ElementsAttr({}, builder),
6832         StringAttr::get(builder->getContext(), "GT"));
6833     gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
6834     Value x_after_k = builder->create<chlo::BroadcastMulOp>(
6835         loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
6836     Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
6837     // sigma = np.dot(x[k+1:], x[k+1:])
6838     auto sigma = builder->create<ReduceOp>(
6839         loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
6840     BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
6841     // mu = np.sqrt(x[k]*x[k] + sigma)
6842     Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
6843     Value mu = builder->create<SqrtOp>(
6844         loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
6845 
6846     Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
6847         loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
6848         StringAttr::get(builder->getContext(), "EQ"));
6849     Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
6850         loc, alpha, zero, GetI64ElementsAttr({}, builder),
6851         StringAttr::get(builder->getContext(), "LT"));
6852     auto batch_size_one = builder->create<BroadcastOp>(
6853         loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
6854     Value signed_mu = builder->create<chlo::BroadcastMulOp>(
6855         loc,
6856         builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
6857                                   batch_size_one,
6858                                   builder->create<NegOp>(loc, batch_size_one)),
6859         mu, GetI64ElementsAttr({}, builder));
6860     *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
6861                                       alpha, signed_mu);
6862     *tau = builder->create<DivOp>(
6863         loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
6864     Value zero_tau = builder->create<BroadcastOp>(
6865         loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
6866     *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
6867                                      zero_tau, *tau);
6868     Value divisor = builder->create<SubOp>(loc, alpha, *beta);
6869     divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
6870                                         batch_size_one, divisor);
6871 
6872     Value eqk = builder->create<chlo::BroadcastCompareOp>(
6873         loc, iota, k, GetI64ElementsAttr({}, builder),
6874         StringAttr::get(builder->getContext(), "EQ"));
6875     eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
6876     llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
6877     e_k_shape.push_back(m);
6878     auto e_k = builder->create<BroadcastOp>(
6879         loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
6880         GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
6881                            builder));
6882 
6883     // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
6884     // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
6885     // Note that the add performs a degenerate broadcast.
6886     *v = builder->create<chlo::BroadcastAddOp>(
6887         loc, e_k,
6888         StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
6889                                      GetI64ElementsAttr(batch_dim_ids, builder),
6890                                      *builder),
6891         /*broadcast_dimensions=*/nullptr);
6892   }
6893 
6894   // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
6895   // Loan "Matrix Computations", 4th Edition. This is an unblocked
6896   // implementation used as an inner routine of the blocked implementation.
6897   // Algorithm is adapted slightly so the shapes inside the loop are static, at
6898   // the cost of some redundant computation. Since this is used as an inner
6899   // block kernel, it accumulates the Householder transformations (vs, taus)
6900   // rather than the matrix q. Equivalent Python code, without batching:
  // def qr(a):
6901   //   m = a.shape[0]
6902   //   n = a.shape[1]
6903   //   vs = np.zeros([m, n])
6904   //   taus = np.zeros([n])
6905   //   for j in xrange(min(m, n)):
6906   //     v, tau, beta = house(a[:, j], j)
6907   //     # Unusually, we apply the Householder transformation to the entirety of
6908   //     # a, wasting FLOPs to maintain the static shape invariant that XLA
6909   //     # requires. For columns that precede j this has no effect.
6910   //     a[:, :] -= tau * np.dot(v[:, np.newaxis],
6911   //                              np.dot(v[np.newaxis, :], a[:, :]))
6912   //     # Form column j explicitly rather than relying on the precision of the
6913   //     # Householder update.
6914   //     a[j, j] = beta
6915   //     a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
6916   //     vs[:, j] = v
6917   //     taus[j] = tau
6918   //   return (a, vs, taus)
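  // The loop above, as a runnable (unbatched) NumPy sketch for reference; it is
  // illustrative only and reuses the hypothetical house() helper sketched in
  // the comment before House() above:
  //   def qr_block(a):
  //     a = np.array(a, dtype=float)
  //     m, n = a.shape
  //     vs = np.zeros((m, n))
  //     taus = np.zeros(n)
  //     for j in range(min(m, n)):
  //       v, tau, beta = house(a[:, j], j)
  //       a -= tau * np.outer(v, v @ a)  # applies H_j to all of a
  //       a[j, j] = beta                 # form column j explicitly
  //       a[j + 1:, j] = 0
  //       vs[:, j] = v
  //       taus[j] = tau
  //     return a, vs, taus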
6919   void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
6920                PatternRewriter *rewriter) const {
6921     auto a_type = a.getType().cast<RankedTensorType>();
6922     const int num_dims = a_type.getRank();
6923     assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
6924 
6925     const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
6926     const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
6927 
6928     const int64_t num_batch_dims = num_dims - 2;
6929     auto batch_dims = a_type.getShape().take_front(num_batch_dims);
6930     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6931     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6932 
6933     auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6934                           SmallVectorImpl<Value> *new_values,
6935                           OpBuilder *builder) {
6936       auto a = old_values[0];
6937       auto vs = old_values[1];
6938       auto taus = old_values[2];
6939 
6940       // v, tau, beta = house(a[:, j], j)
6941       auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
6942       auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
6943       x_collapsed_shape.push_back(m);
6944       auto x_collapsed = builder->create<ReshapeOp>(
6945           loc,
6946           RankedTensorType::get(x_collapsed_shape,
6947                                 getElementTypeOrSelf(x.getType())),
6948           x);
6949       Value v, tau, beta;
6950       House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
6951 
6952       auto shape = llvm::to_vector<4>(batch_dims);
6953       shape.append({1, m});
6954       auto v_broadcast = builder->create<ReshapeOp>(
6955           loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
6956           v);
6957       // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6958       //                          np.dot(v[np.newaxis, :], a[:, :]))
6959       auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
6960       auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
6961                           precision, builder);
6962       vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
6963                      precision, builder);
6964       auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
6965           loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
6966           *builder);
6967       a = builder->create<SubOp>(loc, a, tau_x_vva);
6968 
6969       // It is more precise to populate column 'j' explicitly, rather than
6970       // computing it implicitly by applying the Householder transformation.
6971       // a[j,j] = beta
6972       // a[j+1:,j] = np.zeros([m-j-1], dtype=a.dtype)
6973       auto iota = builder->create<IotaOp>(
6974           loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
6975           builder->getI64IntegerAttr(0));
6976       Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
6977           loc, iota, j, GetI64ElementsAttr({}, builder),
6978           StringAttr::get(builder->getContext(), "LT"));
6979       predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
6980                                                     a_type.getElementType());
6981       Value mask = builder->create<chlo::BroadcastCompareOp>(
6982           loc, iota, j, GetI64ElementsAttr({}, builder),
6983           StringAttr::get(builder->getContext(), "EQ"));
6984       mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
6985       llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
6986       broadcast_mask_shape[a_type.getRank() - 2] = m;
6987       mask = builder->create<BroadcastOp>(
6988           loc,
6989           RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
6990           mask,
6991           GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
6992                              builder));
6993       Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
6994           loc, x, predecessor_mask,
6995           GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
6996       Value masked_beta = StaticBinaryBroadcast<MulOp>(
6997           loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
6998           *builder);
6999       Value new_x =
7000           builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
7001       // Update a[:,j]
7002       llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
7003       std::iota(dim_ids.begin(), dim_ids.end(), 0);
7004       new_x = builder->create<BroadcastInDimOp>(
7005           loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
7006       const int64_t minor_dim = num_batch_dims;
7007       auto iota_mn = builder->create<IotaOp>(
7008           loc,
7009           RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
7010           builder->getI64IntegerAttr(minor_dim + 1));
7011       Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
7012           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
7013           StringAttr::get(builder->getContext(), "EQ"));
7014       a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
7015 
7016       // vs[:, j] = v
7017       llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
7018       std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
7019       Value vs_zeros =
7020           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
7021       vs_zeros = builder->create<BroadcastOp>(
7022           loc, vs.getType(), vs_zeros,
7023           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
7024                              builder));
7025       auto vs_update = builder->create<SelectOp>(
7026           loc, vs.getType(), xa_mask,
7027           StaticBinaryBroadcast<AddOp>(
7028               loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
7029               *builder),
7030           vs_zeros);
7031       vs = builder->create<AddOp>(loc, vs, vs_update);
7032 
7033       // taus[j] = tau
7034       llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
7035       std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
7036 
7037       auto iota_shape = llvm::to_vector<4>(batch_dims);
7038       iota_shape.push_back(n);
7039       auto iota_n = builder->create<IotaOp>(
7040           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
7041           builder->getI64IntegerAttr(minor_dim));
7042       Value taus_zeros =
7043           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
7044       taus_zeros = builder->create<BroadcastOp>(
7045           loc, taus.getType(), taus_zeros,
7046           GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
7047                              builder));
7048       Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
7049           loc, iota_n, j, GetI64ElementsAttr({}, builder),
7050           StringAttr::get(builder->getContext(), "EQ"));
7051       auto taus_update = builder->create<SelectOp>(
7052           loc, taus.getType(), taus_mask,
7053           StaticBinaryBroadcast<AddOp>(
7054               loc, taus_zeros, tau,
7055               GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
7056           taus_zeros);
7057       taus = builder->create<AddOp>(loc, taus, taus_update);
7058       new_values->assign({a, vs, taus});
7059     };
7060 
7061     Value zero =
7062         GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
7063     *vs = rewriter->create<BroadcastOp>(
7064         loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
7065     auto taus_shape = llvm::to_vector<4>(batch_dims);
7066     taus_shape.push_back(n);
7067     *taus = rewriter->create<BroadcastOp>(
7068         loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
7069         GetI64ElementsAttr(taus_shape, rewriter));
7070 
7071     SmallVector<Value, 4> while_output;
7072     CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
7073                   &while_output, rewriter);
7074     *r = while_output[0];
7075     *vs = while_output[1];
7076     *taus = while_output[2];
7077   }
7078 
7079   // Computes W and Y such that I + W Y.T is equivalent to the sequence of
7080   // Householder transformations given by vs and taus (the -tau factors are
7081   // folded into the columns of W).
7082   // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
7083   // Y = np.zeros([m, n])
7084   // W = np.zeros([m, n])
7085   // Y[:, 0] = vs[:, 0]
7086   // W[:, 0] = -taus[0] * vs[:, 0]
7087   // for j in xrange(1, n):
7088   //   v = vs[:, j]
7089   //   z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
7090   //   W[:, j] = z
7091   //   Y[:, j] = v
7092   // return W
7093   // There is no need to return Y since at termination of the loop it is equal
7094   // to vs.
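  // The algorithm above, as a runnable (unbatched) NumPy sketch for reference
  // (illustrative only):
  //   def compute_wy(vs, taus):
  //     m, n = vs.shape
  //     w = np.zeros((m, n))
  //     y = np.zeros((m, n))
  //     y[:, 0] = vs[:, 0]
  //     w[:, 0] = -taus[0] * vs[:, 0]
  //     for j in range(1, n):
  //       v = vs[:, j]
  //       w[:, j] = -taus[j] * v - taus[j] * (w @ (y.T @ v))
  //       y[:, j] = v
  //     return w
  //   # With H_j = np.eye(m) - taus[j] * np.outer(vs[:, j], vs[:, j]), the
  //   # product H_0 @ H_1 @ ... @ H_{n-1} matches np.eye(m) + w @ vs.T up to
  //   # round-off, which is the form used in the blocked updates above.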
7095   Value ComputeWYRepresentation(Location loc, Type data_type,
7096                                 ArrayRef<int64_t> batch_dims, Value vs,
7097                                 Value taus, int64_t m, int64_t n,
7098                                 PatternRewriter *rewriter) const {
7099     int64_t n_index = batch_dims.size() + 1;
7100     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
7101     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
7102 
7103     auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
7104                        SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
7105       // w has shape [..., m, n]
7106       auto w = old_values[0];
7107       const auto vs = old_values[1];
7108       const auto taus = old_values[2];
7109 
7110       // Want j values in range [1, ... n).
7111       j = builder->create<AddOp>(
7112           loc, j,
7113           GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
7114                                builder));
7115       // v has shape [..., m, 1]
7116       auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
7117       // beta has shape [..., 1]
7118       auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
7119 
7120       auto iota_shape = llvm::to_vector<4>(batch_dims);
7121       iota_shape.append({m, n});
7122       auto iota_mn = builder->create<IotaOp>(
7123           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
7124           builder->getI64IntegerAttr(n_index));
7125 
7126       // y has shape [..., m, n]
7127       Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
7128                                         0, builder);
7129       zero = builder->create<BroadcastOp>(
7130           loc, vs.getType(), zero,
7131           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
7132                              builder));
7133       auto compare = builder->create<chlo::BroadcastCompareOp>(
7134           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
7135           StringAttr::get(builder->getContext(), "GE"));
7136       auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
7137 
7138       // yv has shape [..., n, 1]
7139       auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
7140       auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
7141                          builder);
7142       // wyv has shape [..., m, 1]
7143       auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
7144                           precision, builder);
7145 
7146       // z = -beta * (v + wyv)
7147       auto neg_beta = builder->create<NegOp>(loc, beta);
7148       auto v_wyv = builder->create<AddOp>(loc, v, wyv);
7149       auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
7150       beta_broadcast_dims.push_back(n_index);
7151       auto z = StaticBinaryBroadcast<MulOp>(
7152           loc, neg_beta, v_wyv,
7153           GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
7154 
7155       w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
7156       new_values->assign({w, vs, taus});
7157     };
7158 
7159     Value w =
7160         GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
7161     auto w_shape = llvm::to_vector<4>(batch_dims);
7162     w_shape.append({m, n});
7163     w = rewriter->create<BroadcastOp>(loc,
7164                                       RankedTensorType::get(w_shape, data_type),
7165                                       w, GetI64ElementsAttr(w_shape, rewriter));
7166     auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
7167     auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
7168     auto neg_beta = rewriter->create<NegOp>(loc, beta);
7169     auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
7170     beta_broadcast_dims.push_back(n_index);
7171     auto bv = StaticBinaryBroadcast<MulOp>(
7172         loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
7173         *rewriter);
7174     w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
7175 
7176     SmallVector<Value, 4> while_output;
7177     CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
7178     return while_output[0];
7179   }
7180 };
7181 
7182 // TF.PrintOp to mhlo.PrintOp
7183 class ConvertPrintOp : public OpRewritePattern<TF::PrintOp> {
7184  public:
7185   using OpRewritePattern::OpRewritePattern;
7186   LogicalResult matchAndRewrite(TF::PrintOp op,
7187                                 PatternRewriter &rewriter) const final {
7188     auto input = op.input();
7189     rewriter.replaceOpWithNewOp<mhlo::PrintOp>(op, op.getType(), input);
7190     return success();
7191   }
7192 };
7193 
7194 // Emits debug information which includes the number of ops of each type which
7195 // failed to legalize.
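// For example (illustrative only; the op names are hypothetical), failing to
// legalize two tf.Foo ops and one tf.Bar op yields a diagnostic of the form:
//   The following operations cannot be legalized: tf.Bar (count: 1);
//   tf.Foo (count: 2). These legalization failure(s) may be due to missing TF
//   to HLO lowerings and/or unsupported attributes, etc.
// followed by an "is not legalizable" op error for one offending op (or one of
// each op type when vlog level >= 1).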
7196 void EmitLegalizationErrors(Operation *op,
7197                             const DenseSet<Operation *> &nonlegalized_ops) {
7198   // Track the legalization failures by mapping op name to information about
7199   // that failure: the number of unlegalized occurrences of the op, and one
7200   // example operation that failed.
7201   std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
7202   DenseSet<Operation *> error_ops;
7203   for (Operation *nonlegalized_op : nonlegalized_ops) {
7204     // Increment count of this legalization failure.
7205     StringRef op_name = nonlegalized_op->getName().getStringRef();
7206     // If this emplace is successful, it's the first time we've encountered
7207     // this op type. Initialize count to 0 so that after increment, it is 1.
7208     auto insertion_result = op_name_to_error_info.emplace(
7209         op_name, std::make_pair(0, nonlegalized_op));
7210     ++insertion_result.first->second.first;
7211   }
7212   std::vector<std::string> error_messages;
7213   error_messages.reserve(op_name_to_error_info.size());
7214   for (const auto &op_info : op_name_to_error_info) {
7215     error_messages.push_back(
7216         llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
7217   }
7218   Location loc = op->getLoc();
7219   emitError(loc) << "The following operations cannot be legalized: "
7220                  << llvm::join(error_messages, "; ")
7221                  << ". These legalization failure(s) may be due to missing TF "
7222                     "to HLO lowerings and/or unsupported attributes, etc.";
7223   // Emit more information about the missing ops. This error message
7224   // contains useful details beyond the op name (input and output shapes,
7225   // attributes, etc.).
7226   if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
7227     emitError(loc)
7228         << "Emitting more detail about one op that failed to legalize...";
7229   } else if (VLOG_IS_ON(1)) {
7230     emitError(loc) << "Emitting more detail about one of each type of op "
7231                       "that failed to legalize...";
7232   }
7233   for (const auto &op_info : op_name_to_error_info) {
7234     op_info.second.second->emitOpError() << "is not legalizable";
7235     if (!VLOG_IS_ON(1)) break;
7236   }
7237 }
7238 
7239 // Performs the lowering to XLA dialect.
7240 void LegalizeTF::runOnFunction() {
7241   llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
7242   if (use_tf2xla_fallback_) {
7243     tf2xla_fallback_device_type = device_type_;
7244   }
7245   if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
7246                         legalize_chlo_, tf2xla_fallback_device_type,
7247                         prefer_tf2xla_))) {
7248     signalPassFailure();
7249   }
7250 }
7251 
7252 // Patterns whose root op is in the set `include_ops` are moved from the set
7253 // `from` to the returned set. This is used to partition patterns by op so they
7254 // can be cleanly migrated from the old bridge to the MLIR bridge.
7255 OwningRewritePatternList PatternsIncludeOps(
7256     OwningRewritePatternList &from,
7257     const llvm::DenseSet<mlir::TypeID> &include_ops) {
7258   OwningRewritePatternList to(from.getContext());
7259   // Filter NativePatterns.
7260   for (auto &pattern : from.getNativePatterns()) {
7261     Optional<OperationName> pat_op_name = pattern->getRootKind();
7262     // If the pattern does not have a specific operation, always include it.
7263     // If the pattern's root op is in include_ops, include it as well.
7264     bool include =
7265         !pat_op_name ||
7266         include_ops.count(pat_op_name->getAbstractOperation()->typeID);
7267     if (include) to.add(std::move(pattern));
7268   }
7269 
7270   // Don't filter PDLPatterns.
7271   to.add(std::move(from.getPDLPatterns()));
7272 
7273   return to;
7274 }
7275 
7276 /// Returns ops that should use MLIR legalization only in the case of
7277 /// prefer_tf2xla. All other ops not in this list should use XlaOpKernel
7278 /// legalization only or not be legalized by the new bridge.
7279 const llvm::DenseSet<mlir::TypeID> &MlirPreferredOps() {
7280   // The static variable is a pointer in order to avoid destruction upon thread
7281   // termination.
7282 
7283   // clang-format off
7284   static const llvm::DenseSet<mlir::TypeID>* ops =
7285       new llvm::DenseSet<mlir::TypeID>{
7286     // Ops that are legalized in the old bridge using MlirXlaOpKernel
7287     TypeID::get<TF::AbsOp>(),
7288     TypeID::get<TF::AtanOp>(),
7289     TypeID::get<TF::AvgPool3DOp>(),
7290     TypeID::get<TF::BiasAddGradOp>(),
7291     TypeID::get<TF::CeilOp>(),
7292     TypeID::get<TF::CheckNumericsOp>(),
7293     TypeID::get<TF::ComplexOp>(),
7294     TypeID::get<TF::CosOp>(),
7295     TypeID::get<TF::DiagPartOp>(),
7296     TypeID::get<TF::DivOp>(),
7297     TypeID::get<TF::EinsumOp>(),
7298     TypeID::get<TF::ExpOp>(),
7299     TypeID::get<TF::Expm1Op>(),
7300     TypeID::get<TF::FakeQuantWithMinMaxArgsOp>(),
7301     TypeID::get<TF::FloorOp>(),
7302     TypeID::get<TF::GreaterEqualOp>(),
7303     TypeID::get<TF::IFFTOp>(),
7304     TypeID::get<TF::ImagOp>(),
7305     TypeID::get<TF::IsFiniteOp>(),
7306     TypeID::get<TF::IsInfOp>(),
7307     TypeID::get<TF::IsNanOp>(),
7308     TypeID::get<TF::LessEqualOp>(),
7309     TypeID::get<TF::LgammaOp>(),
7310     TypeID::get<TF::Log1pOp>(),
7311     TypeID::get<TF::LogicalOrOp>(),
7312     TypeID::get<TF::LogSoftmaxOp>(),
7313     TypeID::get<TF::MatrixBandPartOp>(),
7314     TypeID::get<TF::MaxPool3DGradOp>(),
7315     TypeID::get<TF::PreventGradientOp>(),
7316     TypeID::get<TF::RandomShuffleOp>(),
7317     TypeID::get<TF::RealOp>(),
7318     TypeID::get<TF::ReciprocalOp>(),
7319     TypeID::get<TF::ReluOp>(),
7320     TypeID::get<TF::Relu6Op>(),
7321     TypeID::get<TF::ReluGradOp>(),
7322     TypeID::get<TF::RsqrtOp>(),
7323     TypeID::get<TF::SelectOp>(),
7324     TypeID::get<TF::SigmoidOp>(),
7325     TypeID::get<TF::SignOp>(),
7326     TypeID::get<TF::SoftmaxOp>(),
7327     TypeID::get<TF::SqrtOp>(),
7328     TypeID::get<TF::SqrtGradOp>(),
7329     TypeID::get<TF::SquaredDifferenceOp>(),
7330     TypeID::get<TF::TanhOp>(),
7331     TypeID::get<TF::TanhGradOp>(),
7332     TypeID::get<TF::XlogyOp>(),
7333     TypeID::get<TF::ZetaOp>(),
7334 
7335     // Ops that have no XlaOpKernel.
7336     TypeID::get<TF::RiscAddOp>(),
7337     TypeID::get<TF::RiscDotOp>(),
7338 
7339     // Const op has a simple legalization and it is much more efficient to lower
7340     // within MLIR.
7341     TypeID::get<TF::ConstOp>(),
7342 
7343     // AssertOp with string operands is not supported by the fallback.
7344     TypeID::get<TF::AssertOp>(),
7345 
7346     // The TF2XLA fallback pattern doesn't support these ops, as the MLIR HLO
7347     // builder doesn't override the necessary builder methods. These ops have
7348     // simple lowering patterns, so this should be safe.
7349     TypeID::get<TF::CrossReplicaSumOp>(),
7350     TypeID::get<TF::InfeedDequeueTupleOp>(),
7351     TypeID::get<TF::OutfeedEnqueueTupleOp>(),
7352     TypeID::get<TF::XlaShardingOp>(),
7353   };
7354   // clang-format on
7355   return *ops;
7356 }
7357 
7358 }  // end namespace
7359 
7360 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
7361 
7362 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
7363                          bool legalize_chlo,
7364                          llvm::Optional<StringRef> tf2xla_fallback_device_type,
7365                          bool prefer_tf2xla) {
7366   MLIRContext *context = op->getContext();
7367   OwningRewritePatternList legalize_lower_patterns(context);
7368   // Note that the `OperationConverter` orders patterns lexicographically by:
7369   // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
7370   //    to arrive at conversion target). This requires relevant patterns to
7371   //    specify the list of ops they generate, which most patterns
7372   //    implemented in C++ do not, so this comparison does not work in those
7373   //    cases.
7374   // 2) Descending pattern benefit.
7375   // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
7376   // 4) Order of patterns in `OwningRewritePatternList`.
7377 
7378   // Add TF->HLO legalization patterns.
7379   PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
7380 
7381   // Add TF->TF lowering patterns.
7382   TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
7383 
7384   if (tf2xla_fallback_device_type && prefer_tf2xla) {
7385     VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
7386                "either native MLIR legalization, or TF2XLA fallback "
7387                "legalization, with a preference toward TF2XLA.";
7388   } else if (tf2xla_fallback_device_type) {
7389     VLOG(1) << "TF to XLA legalization patterns include all native patterns "
7390                "and TF2XLA fallback patterns.";
7391   } else {
7392     VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
7393   }
7394 
7395   // Set patterns to legalize_lower_patterns, where in the prefer_tf2xla case
7396   // only patterns whose ops are in the set MlirPreferredOps are kept.
7397   OwningRewritePatternList patterns =
7398       (tf2xla_fallback_device_type && prefer_tf2xla)
7399           ? PatternsIncludeOps(legalize_lower_patterns, MlirPreferredOps())
7400           : std::move(legalize_lower_patterns);
7401 
7402   if (tf2xla_fallback_device_type) {
7403     // Add TF->HLO legalization patterns via TF2XLA fallback.
7404     PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
7405                                          patterns, context, prefer_tf2xla);
7406   }
7407 
7408   // Populate with CHLO->HLO lowerings to account for TF ops legalized to
7409   // CHLO first.
7410   if (legalize_chlo) {
7411     chlo::PopulateDecomposeChloPatterns(context, &patterns);
7412     chlo::PopulateChloBroadcastingPatterns(context, &patterns);
7413   }
7414   // ConstantLike op is convenient to create splat constants, but is
7415   // canonicalized to plain HLO constant if statically shaped. Add the
7416   // canonicalization pattern to pattern list to enable multi-hop lowering.
7417   chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
7418 
7419   ConversionTarget target(*context);
7420   if (legalize_chlo) {
7421     target.addIllegalDialect<chlo::HloClientDialect>();
7422   } else {
7423     target.addLegalDialect<chlo::HloClientDialect>();
7424   }
7425   target.addLegalDialect<MhloDialect>();
7426   target.addLegalDialect<StandardOpsDialect>();
7427   target.addLegalDialect<tensor::TensorDialect>();
7428   target.addLegalDialect<shape::ShapeDialect>();
7429   target.addLegalOp<CallOp>();
7430 
7431   if (!allow_partial_conversion) {
7432     // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
7433     target.addLegalOp<ModuleOp, FuncOp, ::mlir::ReturnOp>();
7434     DenseSet<Operation *> nonlegalized_ops;
7435     LogicalResult result = applyPartialConversion(
7436         op, target, std::move(patterns), &nonlegalized_ops);
7437     // In order to enforce that the conversion result is fully converted,
7438     // fail if there are any nonlegalized ops in the set.
7439     if (failed(result) || !nonlegalized_ops.empty()) {
7440       EmitLegalizationErrors(op, nonlegalized_ops);
7441       return failure();
7442     }
7443     return result;
7444   }
7445 
7446   return applyPartialConversion(op, target, std::move(patterns));
7447 }
7448 
7449 void PopulateLegalizeTfPatterns(MLIRContext *context,
7450                                 OwningRewritePatternList *patterns) {
7451   populateWithGenerated(*patterns);
7452   // clang-format off
7453   patterns->insert<
7454     ConvertAllOp,
7455     ConvertAnyOp,
7456     ConvertArgMaxOp,
7457     ConvertArgMinOp,
7458     ConvertBatchMatMulV2Op,
7459     ConvertBiasAddOp,
7460     ConvertBroadcastToOp,
7461     ConvertBF16FloorDivOp,
7462     ConvertClipByValueOp,
7463     ConvertConstOp,
7464     ConvertConv2DOp,
7465     ConvertConv3DOp,
7466     ConvertDepthConv2DOp,
7467     ConvertConv2DBackpropFilterOp,
7468     ConvertConv3DBackpropFilterOp,
7469     ConvertConv2DBackpropInputOp,
7470     ConvertConv3DBackpropInputOp,
7471     ConvertCumprodOp,
7472     ConvertCumsumOp,
7473     ConvertDiagPartOp,
7474     ConvertDynamicExpandDimsOp,
7475     ConvertEinsumOp,
7476     ConvertRFFTOp,
7477     ConvertIRFFTOp,
7478     ConvertFusedBatchNormGradOp,
7479     ConvertFusedBatchNormGradV2Op,
7480     ConvertFusedBatchNormGradV3Op,
7481     ConvertFusedBatchNormV2Op,
7482     ConvertFusedBatchNormV3Op,
7483     ConvertInfeedDequeueTupleOp,
7484     ConvertIdentityNOp,
7485     ConvertInplaceUpdateOp,
7486     ConvertLinSpaceOp,
7487     ConvertMaxOp,
7488     ConvertMinOp,
7489     ConvertAvgPool2DOp,
7490     ConvertAvgPool3DOp,
7491     ConvertAvgPool2DGradOp,
7492     ConvertAvgPool3DGradOp,
7493     ConvertMaxPool2DOp,
7494     ConvertMaxPool3DOp,
7495     ConvertMaxPool2DGradOp,
7496     ConvertMaxPool3DGradOp,
7497     ConvertMeanOp,
7498     ConvertOneHotOp,
7499     ConvertOutfeedEnqueueTupleOp,
7500     ConvertProdOp,
7501     ConvertQrOp,
7502     ConvertDynamicRangeOp,
7503     ConvertMatrixDiagPartV3Op,
7504     ConvertRangeOp,
7505     ConvertSelectOp,
7506     ConvertSigmoidOp,
7507     ConvertShapeOp,
7508     ConvertSplitOp,
7509     ConvertSplitVOp,
7510     ConvertStridedSliceOp,
7511     ConvertStridedSliceGradOp,
7512     ConvertSumOp,
7513     ConvertTensorScatterAddOp,
7514     ConvertTensorScatterSubOp,
7515     ConvertTensorScatterMinOp,
7516     ConvertTensorScatterMaxOp,
7517     ConvertTensorScatterUpdateOp,
7518     ConvertTileOp,
7519     ConvertTopKV2Op,
7520     ConvertUnpackOp,
7521     ConvertUnsortedSegmentMaxOp,
7522     ConvertUnsortedSegmentMinOp,
7523     ConvertUnsortedSegmentProdOp,
7524     ConvertUnsortedSegmentSumOp,
7525     ConvertRandomShuffleOp,
7526     ConvertXlaShardingOp,
7527     ConvertXlaDynamicUpdateSliceOp,
7528     ConvertXlaAllReduceOp,
7529     ConvertRollOp,
7530     ConvertLeakyReluOp,
7531     ConvertLeakyReluGradOp,
7532     ConvertPrintOp,
7533     ConvertSplitOpDynamic,
7534     ConvertSliceOpDynamic,
7535     ConvertTileOpDynamic,
7536     ConvertUnpackOpDynamic,
7537     ConvertSignOpDynamic,
7538     ConvertSigmoidGradOpDynamic,
7539     ConvertGatherV2OpDynamic,
7540     ConvertConv2DDynamic,
7541     ConvertPadOpDynamic,
7542     ConvertGatherNdOpDynamic>(context);
7543   // clang-format on
7544 }
7545 
7546 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
7547     bool allow_partial_conversion, bool legalize_chlo,
7548     llvm::Optional<StringRef> tf2xla_fallback_device_type, bool prefer_tf2xla) {
7549   return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
7550                                       tf2xla_fallback_device_type,
7551                                       prefer_tf2xla);
7552 }
7553 
7554 }  // end namespace mhlo
7555 }  // end namespace mlir
7556