1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect to XLA dialect.
17
18 #include <cctype>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <limits>
23 #include <numeric>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/Optional.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/ErrorHandling.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
35 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
36 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
37 #include "mlir/Dialect/Traits.h" // from @llvm-project
38 #include "mlir/IR/Attributes.h" // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
41 #include "mlir/IR/Diagnostics.h" // from @llvm-project
42 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
43 #include "mlir/IR/MLIRContext.h" // from @llvm-project
44 #include "mlir/IR/Matchers.h" // from @llvm-project
45 #include "mlir/IR/Operation.h" // from @llvm-project
46 #include "mlir/IR/PatternMatch.h" // from @llvm-project
47 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
48 #include "mlir/IR/Types.h" // from @llvm-project
49 #include "mlir/Pass/Pass.h" // from @llvm-project
50 #include "mlir/Support/LogicalResult.h" // from @llvm-project
51 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
53 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
54 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
55 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
56 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
57 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
58 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
59 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
60 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
61 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
62 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
63 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
64 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
65 #include "tensorflow/compiler/xla/client/padding.h"
66 #include "tensorflow/compiler/xla/client/sharding_builder.h"
67 #include "tensorflow/compiler/xla/xla_data.pb.h"
68 #include "tensorflow/core/framework/kernel_shape_util.h"
69 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
70 #include "tensorflow/core/platform/bfloat16.h"
71 #include "tensorflow/core/tpu/tpu_api.h"
72 #include "tensorflow/core/util/padding.h"
73 #include "tensorflow/core/util/tensor_format.h"
74 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
75
76 namespace mlir {
77 namespace mhlo {
78 namespace {
79
80 constexpr char kShardingAttr[] = "mhlo.sharding";
81
82 class LegalizeTF : public LegalizeTFBase<LegalizeTF> {
83 public:
84   explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo,
85 llvm::Optional<StringRef> tf2xla_fallback_device_type,
86 bool prefer_tf2xla) {
87 allow_partial_conversion_ = allow_partial_conversion;
88 legalize_chlo_ = legalize_chlo;
89 prefer_tf2xla_ = prefer_tf2xla;
90 use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue();
91 if (tf2xla_fallback_device_type.hasValue()) {
92 device_type_ = tf2xla_fallback_device_type.getValue().str();
93 }
94 }
95 /// Performs the lowering to XLA dialect.
96 void runOnFunction() override;
97 };
98
99 /// Returns the feature dimension for the given format and input type.
100 static size_t GetFeatureDimension(tensorflow::TensorFormat format,
101 RankedTensorType input_ty) {
102 return GetTensorFeatureDimIndex(input_ty.getRank(), format);
103 }
104
105 // Gets all integer values from the given attribute and pushes them to `values`.
106 void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
107 auto array_attr = attr.cast<ArrayAttr>();
108 values->reserve(array_attr.getValue().size());
109 for (Attribute val : array_attr.getValue())
110 values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
111 }
112
113 // Returns 1D 64-bit dense elements attribute with the given values.
114 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
115 Builder *builder) {
116 RankedTensorType ty = RankedTensorType::get(
117 {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
118 return DenseIntElementsAttr::get(ty, values);
119 }
120
121 // Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
122 static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
123 RankedTensorType ty =
124 RankedTensorType::get(static_cast<int64_t>(attr.size()),
125 IntegerType::get(attr.getContext(), 64));
126 return DenseIntElementsAttr::get(ty, attr.getValue());
127 }
128
129 // Returns 1D 32-bit dense elements attribute with the given values.
130 static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
131 Builder *builder) {
132 RankedTensorType ty = RankedTensorType::get(
133 {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
134 return DenseIntElementsAttr::get(ty, values);
135 }
136
137 // Returns a 1-d i64 elements attribute populated with the numbers in the
138 // half-open range [start, end).
139 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
140 Builder *builder) {
141 int size = end - start;
142
143 SmallVector<int64_t, 4> vals;
144 vals.resize(size);
145 std::iota(vals.begin(), vals.end(), start);
146
147 TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
148 return DenseIntElementsAttr::get(ty, vals);
149 }
150
151 // Returns a 1-d i64 elements attribute populated with `val` repeated `size`
152 // times.
153 static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
154 Builder *builder) {
155 TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
156 return DenseIntElementsAttr::get(ty, val);
157 }
158
159 // Returns the corresponding type that should be used for performing sum
160 // accumulation over the given input type.
161 Type GetSumAccumulationType(Type input_type) {
162 MLIRContext *ctx = input_type.getContext();
163 if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
164 if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
165 return IntegerType::get(ctx, 32);
166 return input_type;
167 }
168
169 // Returns the axis in HLO format, given a TF attribute (an IntegerAttr or an
170 // elements attr with exactly one element) holding the axis in TensorFlow
171 // format, which unlike HLO supports negative indexing.
172 static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
173 Builder *b) {
174 IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
175 if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
176 SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
177 intAttr = elementAttr.getValue<IntegerAttr>(index);
178 }
179
180 assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");
181
182 int64_t axis = intAttr.getInt();
183 if (axis < 0) {
184 axis += rank;
185 }
186 return b->getI64IntegerAttr(axis);
187 }
188
189 // If `value` is an IntegerAttr, returns the integer value for the HLO axis
190 // corresponding to the tensorflow axis. In particular, the tensorflow axis can
191 // be negative, in which case, the corresponding HLO axis is
192 // (axis + rank-of-the-tensor).
193 static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
194 int64_t rank) {
195 DenseIntElementsAttr attrs;
196 if (!matchPattern(value, m_Constant(&attrs)) ||
197 attrs.getType().getRank() != 0) {
198 return llvm::None;
199 }
200 int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
201 return axis < 0 ? axis + rank : axis;
202 }
203
204 /// Returns a `ConvertOp` that casts the elements to an i64 type while
205 /// retaining the shape of the input value.
206 static ConvertOp CastValueToI64(Location loc, Value value,
207 PatternRewriter *rewriter) {
208 return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
209 }
210
211 // Creates an unpack op along the 0th dimension of the tensor. The `value` input
212 // must be a ranked tensor.
213 static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
214 PatternRewriter *rewriter) {
215 auto indices_type = value.getType().cast<RankedTensorType>();
216 int num_outputs = indices_type.getShape().front();
217 SmallVector<Type, 2> unpacked_indices_type(
218 num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
219 auto unpacked_indices = rewriter->create<TF::UnpackOp>(
220 loc, unpacked_indices_type, value,
221 IntegerAttr::get(rewriter->getIntegerType(64), 0));
222 return unpacked_indices;
223 }
224
225 // Returns the size of the dimension at the specified index if the type is a
226 // ranked tensor; otherwise, returns -1.
227 //
228 // Aborts if the type is ranked but doesn't have the dimension.
229 int64_t GetDimSize(Type ty, int64_t index) {
230 RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
231 if (!ranked_ty) return -1;
232
233 return ranked_ty.getDimSize(index);
234 }
235
236 template <typename T, int num_dims>
237 tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
238 return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
239 sizes.begin(), sizes.end()));
240 }
241
242 template <typename T, int num_dims>
243 tensorflow::TensorShape ToTensorShape(
244 llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
245 return tensorflow::TensorShape(llvm::SmallVector<tensorflow::int64, num_dims>(
246 sizes.begin(), sizes.end()));
247 }
248
249 // Returns int, float, or complex scalar DenseElementsAttr attribute with the
250 // given element type and the value.
251 static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
252 OpBuilder *builder) {
253 return builder->create<ConstOp>(loc, hlo::GetScalarOfType(ty, raw_value));
254 }
255
256 // Returns a limit scalar const op for the given type.
257 // Requires FloatType or IntegerType
258 static ConstOp GetScalarLimitConstOfType(Type ty, Location loc,
259 hlo::ScalarLimit limit,
260 OpBuilder *builder) {
261 return builder->create<ConstOp>(loc, hlo::GetScalarLimitOfType(ty, limit));
262 }
263
264 // Creates an mhlo::SliceOp where the major dimensions have full size, and
265 // the minor dimensions have the provided offsets and sizes.
266 static Value SliceInMinorDims(Location loc, Value v,
267 ArrayRef<int64_t> minor_starts,
268 ArrayRef<int64_t> minor_limits,
269 OpBuilder *builder) {
270 auto type = v.getType().cast<RankedTensorType>();
271 llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
272 int64_t major_dims = type.getRank() - minor_starts.size();
273 std::copy(minor_starts.begin(), minor_starts.end(),
274 slice_starts.begin() + major_dims);
275 auto slice_limits = llvm::to_vector<4>(type.getShape());
276 std::copy(minor_limits.begin(), minor_limits.end(),
277 slice_limits.begin() + major_dims);
278 llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
279 return builder->create<SliceOp>(loc, v,
280 GetI64ElementsAttr(slice_starts, builder),
281 GetI64ElementsAttr(slice_limits, builder),
282 GetI64ElementsAttr(slice_strides, builder));
283 }
284
285 // Creates a vector of index values:
286 // [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
287 // with length `rank`.
288 static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
289 Location loc, ArrayRef<Value> minor_indices, int64_t rank,
290 OpBuilder *builder) {
291 auto zero =
292 GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
293 loc, 0, builder)
294 .output();
295 llvm::SmallVector<Value, 4> indices(rank, zero);
296 std::copy(minor_indices.begin(), minor_indices.end(),
297 indices.begin() + (rank - minor_indices.size()));
298 return indices;
299 }
300
301 // Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
302 // and the minor dimensions have the provided offsets and sizes.
303 static Value DynamicSliceInMinorDims(Location loc, Value v,
304 ArrayRef<Value> minor_starts,
305 ArrayRef<int64_t> minor_sizes,
306 OpBuilder *builder) {
307 if (minor_starts.empty()) return v;
308 auto type = v.getType().cast<RankedTensorType>();
309 auto slice_starts = CreateFullIndexVectorFromMinorIndices(
310 loc, minor_starts, type.getRank(), builder);
311 int64_t major_dims = type.getRank() - minor_starts.size();
312 auto slice_sizes = llvm::to_vector<4>(type.getShape());
313 std::copy(minor_sizes.begin(), minor_sizes.end(),
314 slice_sizes.begin() + major_dims);
315 auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
316 return builder->create<mhlo::DynamicSliceOp>(
317 loc, slice_type, v, slice_starts,
318 GetI64ElementsAttr(slice_sizes, builder));
319 }
320
321 // Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
322 // offsets, and the minor dimensions have the provided offsets.
323 static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
324 ArrayRef<Value> minor_starts,
325 OpBuilder *builder) {
326 if (minor_starts.empty()) return v;
327 auto type = v.getType().cast<RankedTensorType>();
328 auto dus_starts = CreateFullIndexVectorFromMinorIndices(
329 loc, minor_starts, type.getRank(), builder);
330 return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
331 llvm::makeArrayRef(dus_starts));
332 }
333
334 // Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
335 // offsets, and the minor dimensions have the provided static offsets.
336 static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
337 ArrayRef<int64_t> minor_starts,
338 OpBuilder *builder) {
339 llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
340 for (uint64_t i = 0; i < minor_starts.size(); ++i) {
341 dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
342 minor_starts[i], builder);
343 }
344 return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
345 }
346
347 // Deprecated: This is maintained to aid in porting old code that is not yet
348 // dynamic shape aware and uses broadcasting modes that CHLO does not support.
349 // Gets the resulting type from a broadcast between two types for statically
350 // shaped types. This is to be used for legacy lowerings that both use non
351 // left-padded broadcasting and static shapes. Its use should not be permitted
352 // in new code.
353 // May return nullptr on invalid static broadcast dimensions.
354 // ABSL_DEPRECATED()
355 static RankedTensorType GetStaticBroadcastType(
356 RankedTensorType x, RankedTensorType y,
357 DenseIntElementsAttr broadcast_dimensions_attr) {
358 auto element_type = x.getElementType();
359 auto shape_x = x.getShape();
360 auto shape_y = y.getShape();
361
362 if (shape_x.size() == shape_y.size()) {
363 llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
364 for (int i = 0; i < shape_x.size(); i++) {
365 auto x_val = shape_x[i];
366 auto y_val = shape_y[i];
367 out_shape[i] = std::max(x_val, y_val);
368 }
369 return RankedTensorType::get(out_shape, element_type);
370 }
371
372 auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
373 auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
374
375 llvm::SmallVector<int64_t, 4> broadcast_dimensions;
376 // Explicit broadcast dimensions.
377 for (const APInt &int_value : broadcast_dimensions_attr) {
378 broadcast_dimensions.push_back(int_value.getSExtValue());
379 }
380 if (broadcast_dimensions.size() != shape_small.size()) {
381 return nullptr;
382 }
383 llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
384 shape_large.end());
385
386 // Update according to the broadcast dimensions.
387 for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
388 auto old_value = out_shape[index_pair.value()];
389 auto new_value = shape_small[index_pair.index()];
390 out_shape[index_pair.value()] = std::max(old_value, new_value);
391 }
392 return RankedTensorType::get(out_shape, element_type);
393 }
394
395 // Deprecated: This is maintained to aid in porting old code that is not yet
396 // dynamic shape aware and uses broadcasting modes that CHLO does not support.
397 // Applies static binary broadcasting to a binary elementwise op.
398 // This is a legacy helper to provide general broadcasting support in legacy,
399 // static shaped code that relies on non-left-padded broadcasting semantics.
400 template <typename BinaryOp>
401 static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
402 DenseIntElementsAttr broadcast_dims,
403 OpBuilder &builder) {
404 auto x_type = x.getType().cast<RankedTensorType>();
405 auto y_type = y.getType().cast<RankedTensorType>();
406 auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
407 if (!result_type) {
408 emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
409 << " with broadcast_dims = " << broadcast_dims;
410 return nullptr;
411 }
412 auto larger_broadcast_dims =
413 GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
414 if (x_type.getRank() < y_type.getRank()) {
415 if (x_type != result_type) {
416 x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
417 }
418 if (y_type != result_type) {
419 y = builder.create<BroadcastInDimOp>(loc, result_type, y,
420 larger_broadcast_dims);
421 }
422 } else {
423 if (x_type != result_type) {
424 x = builder.create<BroadcastInDimOp>(loc, result_type, x,
425 larger_broadcast_dims);
426 }
427 if (y_type != result_type) {
428 y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
429 }
430 }
431 return builder.create<BinaryOp>(loc, x, y);
432 }
433
434 // Gets a 1D tensor type suitable for expressing extents of the given tensor
435 // value type. If the value type is ranked, the result will be statically
436 // shaped. Otherwise, it will have a dynamic dimension.
437 static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
438 Builder b(value_type.getContext());
439 int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
440 return RankedTensorType::get({dim}, b.getIndexType());
441 }
442
443 // Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
444 // value (broadcast_from) along that feature dimension. This is a shortcut
445 // for the cases where a 1D tensor must be broadcast along a specific feature
446 // dimension, which can vary based on data layout, etc.
447 //
448 // The extent of `broadcast_from` dim0 must be equal to the extent of the
449 // feature_dim of `broadcast_to`.
450 //
451 // Example:
452 // [1x2x3x4], [2], 1 -> [1x2x3x4]
453 // TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
454 // consistency. Possibly also rename for clarity.
455 static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
456 Value broadcast_from, int64_t feature_dim,
457 OpBuilder &builder) {
458 auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
459 auto to_type = broadcast_to.getType().cast<RankedTensorType>();
460 auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
461 auto result_extents_type = GetExtentsTensorTypeFor(to_type);
462 auto result_extents = builder.create<shape::ToExtentTensorOp>(
463 loc, result_extents_type, result_shape);
464 return builder.create<DynamicBroadcastInDimOp>(
465 loc, to_type, broadcast_from, result_extents, broadcast_dims);
466 }
467
468 // Broadcasts `input` to the shape of `broadcast_to` value following
469 // TF::BroadcastTo semantics.
470 //
471 // Requires that input is a ranked tensor.
472 //
473 // TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
474 // supports unranked inputs in the lowering.
475 static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
476 OpBuilder &builder) {
477 auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
478 auto to_type = broadcast_to.getType().cast<TensorType>();
479 auto result_extents_type = GetExtentsTensorTypeFor(to_type);
480 auto result_extents = builder.create<shape::ToExtentTensorOp>(
481 loc, result_extents_type, result_shape);
482 int64_t rank = input.getType().cast<RankedTensorType>().getRank();
483 auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
484 return builder.create<DynamicBroadcastInDimOp>(
485 loc, to_type, input, result_extents, broadcast_dims);
486 }
487
488 // Creates a batch dot using mhlo::DotGeneralOp.
489 Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
490 bool transpose_rhs, int64_t num_batch_dims,
491 ArrayAttr precision_config, OpBuilder *builder) {
492 auto batch_dimensions = GetI64ElementsAttr(
493 llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims)), builder);
494 auto lhs_contracting_dimensions = GetI64ElementsAttr(
495 llvm::makeArrayRef({transpose_lhs ? num_batch_dims : num_batch_dims + 1}),
496 builder);
497 auto rhs_contracting_dimensions = GetI64ElementsAttr(
498 llvm::makeArrayRef({transpose_rhs ? num_batch_dims + 1 : num_batch_dims}),
499 builder);
500 auto dimension_numbers = DotDimensionNumbers::get(
501 /*lhs_batching_dimensions=*/batch_dimensions,
502 /*rhs_batching_dimensions=*/batch_dimensions,
503 /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
504 /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
505 builder->getContext());
506 auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
507 auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
508 auto shape = llvm::to_vector<4>(lhs_shape);
509 shape[shape.size() - 2] =
510 transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
511 shape[shape.size() - 1] =
512 transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
513 Type element_type = getElementTypeOrSelf(lhs.getType());
514 return builder->create<DotGeneralOp>(
515 loc, RankedTensorType::get(shape, element_type), lhs, rhs,
516 dimension_numbers, precision_config);
517 }
518
519 // Builds body for reduce op by using the template binary op as the
520 // reducer op.
521 template <typename Op>
522 static void BuildReduceBody(Type element_type, Region *body,
523 OpBuilder *builder) {
524 OpBuilder::InsertionGuard guard(*builder);
525 Block *block = builder->createBlock(body);
526
527 // Block arguments are scalars of the given element type.
528 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
529 block->addArguments({type, type});
530
531 Location loc = body->getLoc();
532 auto reducer =
533 builder->create<Op>(loc, block->getArgument(0), block->getArgument(1));
534 builder->create<ReturnOp>(loc, reducer.getResult());
535 }
536
537 // Builds a set of operations for applying reduction on the input value. A
538 // tf.Sum op is created and will be legalized to HLO ops automatically.
539 static Value ApplyReduction(Location loc, Value input,
540 DenseIntElementsAttr reduce_dims,
541 OpBuilder *builder) {
542 auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
543 return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
544 builder->getBoolAttr(false));
545 }
546
547 // Creates a mhlo.rng_uniform op with `builder` to generate `num_elements`
548 // 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
549 static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
550 int lower_limit, int upper_limit,
551 OpBuilder *builder) {
552 auto i32_type = builder->getIntegerType(32);
553 auto key_type = RankedTensorType::get({num_elements}, i32_type);
554 auto shape_tensor = builder->create<mhlo::ConstOp>(
555 loc, GetI64ElementsAttr({num_elements}, builder));
556
557 auto lower = builder->create<mhlo::ConstOp>(
558 loc, builder->getI32IntegerAttr(lower_limit));
559 auto upper = builder->create<mhlo::ConstOp>(
560 loc, builder->getI32IntegerAttr(upper_limit));
561
562 return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
563 shape_tensor);
564 }
565
566 using WhileBodyFnType = llvm::function_ref<void(
567 Location loc, Value iteration, ArrayRef<Value> old_values,
568 SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;
569
570 // Creates an mhlo.while op with `builder` to loop `num_iterations` times,
571 // each time calling the given `body_fn` on a set of values to generate a new
572 // set of values. Returns the final set of values via `final_values`. The
573 // initial set of values is passed in via `init_values`.
574 //
575 // This effectively does:
576 //
577 // ```c++
578 // SmallVector<Values, 4> old_values = init_values;
579 // SmallVector<Values, 4> new_values;
580 // for (int i = 0; i < num_iterations; ++i) {
581 // body_fn(old_values, &new_values, ...);
582 // old_values = new_values;
583 // }
584 // ```
585 //
586 // Under the hood an induction variable is prepended to values to control the
587 // number of iterations, but that is transparent to `body_fn`, which does not
588 // need to care about that.
589 static void CreateWhile32(Location loc, int num_iterations,
590 WhileBodyFnType body_fn, ArrayRef<Value> init_values,
591 SmallVectorImpl<Value> *final_values,
592 OpBuilder *builder) {
593 int value_count = init_values.size() + 1;
594
595 // Prepend a loop induction variable to the initial values.
596 SmallVector<Value, 2> init_values_with_loop_iv;
597 init_values_with_loop_iv.reserve(value_count);
598 // The initial value for the loop induction variable is 0.
599 init_values_with_loop_iv.push_back(
600 builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
601 init_values_with_loop_iv.append(init_values.begin(), init_values.end());
602
603 // Prepare the initial tuple for the while op.
604 auto init_tuple =
605 builder->create<mhlo::TupleOp>(loc, init_values_with_loop_iv);
606 auto tuple_type = init_tuple.getType();
607
608 // Create the while op.
609 auto while_op = builder->create<mhlo::WhileOp>(
610 loc, tuple_type, SmallVector<Value>{init_tuple});
611
612 {
613 OpBuilder::InsertionGuard guard(*builder);
614
615 // Build up the only block in the condition region. It should take one
616 // argument of the loop's tuple type.
617 Region &condition = while_op.cond();
618 Block *block = builder->createBlock(&condition);
619 BlockArgument arg = block->addArgument(tuple_type);
620
621 // Get the loop induction variable and compare it against the upper limit.
622 auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
623 auto upper_limit = builder->create<mhlo::ConstOp>(
624 loc, builder->getI32IntegerAttr(num_iterations));
625 StringAttr compare_direction = StringAttr::get(builder->getContext(), "LT");
626 Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
627 compare_direction);
628
629 builder->create<mhlo::ReturnOp>(loc, compare);
630 }
631
632 {
633 OpBuilder::InsertionGuard guard(*builder);
634
635 // Build up the only block in the body region. It should take one
636 // argument of the loop's tuple type.
637 Region &body = while_op.body();
638 Block *block = builder->createBlock(&body);
639 BlockArgument arg = block->addArgument(tuple_type);
640
641 SmallVector<Value, 4> old_values; // From the previous iteration
642 SmallVector<Value, 4> new_values; // Generated by this iteration
643 old_values.reserve(value_count);
644 new_values.reserve(value_count);
645
646 // Unpack the tuple value from the last iteration.
647 for (int i = 0; i < value_count; ++i)
648 old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));
649
650 // Feed all values excluding the loop induction variable to body_fn.
651 body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
652 &new_values, builder);
653
654 // Increment the loop induction variable by one.
655 auto one =
656 builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
657 auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
658 auto plus_one = builder->create<chlo::BroadcastAddOp>(
659 loc, old_values[0], one, scalar_broadcast_dims);
660 // Prepend with the updated loop induction variable.
661 new_values.insert(new_values.begin(), plus_one);
662
663 Value updated_tuple = builder->create<mhlo::TupleOp>(loc, new_values);
664
665 builder->create<mhlo::ReturnOp>(loc, updated_tuple);
666 }
667
668 // TODO(jpienaar): Support multi-operand while op.
669 final_values->reserve(init_values.size());
670 for (int i = 0, e = init_values.size(); i < e; ++i)
671 final_values->push_back(
672 builder->create<GetTupleElementOp>(loc, while_op.getResult(0), i + 1));
673 }
674
675 //===----------------------------------------------------------------------===//
676 // BatchNorm op utilities.
677 //===----------------------------------------------------------------------===//
678
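// Returns an i64 IntegerAttr holding the feature dimension index of `input`
// for the given tensor format.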
679 static IntegerAttr getFeatureDimensionAttr(Builder &b,
680 tensorflow::TensorFormat format,
681 Value input) {
682 return b.getI64IntegerAttr(
683 GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
684 }
685
686 //===----------------------------------------------------------------------===//
687 // FFT op utilities.
688 //===----------------------------------------------------------------------===//
689
690 // Returns the 1D i64 elements attribute populated with the inner-most dim of
691 // the value.
692 static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
693 Builder *builder) {
694 if (type.getRank() == 0) {
695 return builder->getI64TensorAttr({});
696 }
697 return builder->getI64TensorAttr(type.getShape().back());
698 }
699
700 // Returns true if the inner-most dim is static.
701 bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
702 if (!type.hasRank()) {
703 return false;
704 }
705 return !type.isDynamicDim(type.getShape().size() - 1);
706 }
707
708 //===----------------------------------------------------------------------===//
709 // MatMul op utilities.
710 //===----------------------------------------------------------------------===//
711
712 // If the 'transpose' attribute is true, returns an ElementsAttr that transposes
713 // a 2D matrix. Otherwise, returns an ElementsAttr for the identity permutation.
714 static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
715 if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
716 return GetI64ElementsAttr({0, 1}, b);
717 }
718
719 //===----------------------------------------------------------------------===//
720 // MatrixBandPart op utilities.
721 //===----------------------------------------------------------------------===//
722
723 // Gets the size of the dimension `dim_from_end` from the end of `input`.
724 // Requires that `input` is a tensor.
725 static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
726 // Note: the verifier enforces that `input` is a ranked tensor.
727 auto input_type = input.getType().cast<TensorType>();
728 auto input_shape = input_type.getShape();
729 int dim = (input_shape.size() - 1) - dim_from_end;
730 return input_shape[dim];
731 }
732
733 // Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
734 // have the same size as the last two dimensions of `input` (the second-to-last
735 // dimension and last dimension, respectively). The element type of the
736 // resulting RankedTensorType matches the element type of `num_lower`.
737 // Requires that `input` is a tensor.
738 static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
739 // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
740 int dim_0 = GetDimensionSizeFromEnd(input, 1);
741 int dim_1 = GetDimensionSizeFromEnd(input, 0);
742 auto element_type = num_lower.getType().cast<TensorType>().getElementType();
743 return RankedTensorType::get({dim_0, dim_1}, element_type);
744 }
745
746 // Creates a HLO ConvertOp, converting `input` to have the same element type as
747 // `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
748 static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
749 Value elem_type_tensor) {
750 auto element_type =
751 elem_type_tensor.getType().cast<TensorType>().getElementType();
752 return builder->create<mhlo::ConvertOp>(loc, input, element_type);
753 }
754
755 //===----------------------------------------------------------------------===//
756 // Pad op utilities.
757 //===----------------------------------------------------------------------===//
758
759 // Slices input attribute of rank two and returns the specified column.
760 //
761 // Always returns 64 bit integer attribute regardless of bitwidth of the input
762 // attribute.
763 static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
764 ElementsAttr input, int column) {
765 auto int_attr = input.cast<DenseIntElementsAttr>();
766 auto shaped_type = int_attr.getType();
767 auto shape = shaped_type.getShape();
768
769 if (shape.size() != 2) return DenseIntElementsAttr();
770
771 llvm::SmallVector<int64_t, 4> values;
772 values.reserve(shaped_type.getNumElements() / shape[1]);
773
774 for (auto it : llvm::enumerate(int_attr.getIntValues())) {
775 if (static_cast<int>(it.index() % shape[1]) == column) {
776 values.push_back(it.value().getSExtValue());
777 }
778 }
779
780 auto element_type = IntegerType::get(input.getContext(), 64);
781 return DenseIntElementsAttr::get(
782 RankedTensorType::get({shape[0]}, element_type), values);
783 }
784
785 // Returns the interior padding to use in an HLO Pad op, based on the padding
786 // attribute of the TensorFlow PadV2 op.
787 static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
788 auto length = tf_padding.getType().getShape()[0];
789 auto element_type = IntegerType::get(tf_padding.getContext(), 64);
790 return DenseIntElementsAttr::get<int64_t>(
791 RankedTensorType::get({length}, element_type), 0);
792 }
793
794 //===----------------------------------------------------------------------===//
795 // Binary op utilities.
796 //===----------------------------------------------------------------------===//
797
798 // Returns whether the two values are guaranteed to be broadcastable to the
799 // same shape; this broadcasts size-1 tensors up to any rank. Dynamic dimensions
800 // must be broadcast with a size-1 tensor or another dynamic dimension.
801 // Returns false if either value is unranked.
802 static bool AreBroadcastCompatible(Value x, Value y) {
803 auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
804 auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
805 if (!x_rankless || !y_rankless) {
806 return false;
807 }
808
809 // Check that the shapes can be broadcasted.
810 auto shape_x = x_rankless.getShape();
811 auto shape_y = y_rankless.getShape();
812
813 int rank_diff = shape_x.size() - shape_y.size();
814 int offset_x = rank_diff > 0 ? rank_diff : 0;
815 int offset_y = rank_diff < 0 ? -rank_diff : 0;
816 for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
817 int index_x = i + offset_x;
818 int index_y = i + offset_y;
819 if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
820 (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
821 return false;
822 }
823 }
824
825 return true;
826 }
827
828 // Returns a new TensorType with the same rank and dimensions as the input but
829 // with an updated element type.
830 static Type ChangeTensorElementType(Builder *b, Type tensor_type,
831 Type element_type) {
832 RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
833 if (ranked_type) {
834 return RankedTensorType::get(ranked_type.getShape(), element_type);
835 }
836
837 return UnrankedTensorType::get(element_type);
838 }
839
840 //===----------------------------------------------------------------------===//
841 // Softmax op utilities.
842 //===----------------------------------------------------------------------===//
843
844 // Returns the type to use for accumulating the given type.
845 static Type GetAccumulationType(Type ty) {
846 // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
847 // repeated floating point additions.
848 return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
849 }
850
851 //===----------------------------------------------------------------------===//
852 // Softplus op utilities.
853 //===----------------------------------------------------------------------===//
854
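// Returns a scalar DenseElementsAttr holding the machine epsilon of the
// element type of `ty` (f16, bf16, f32, or f64).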
855 static DenseElementsAttr GetEpsilonValue(Type ty) {
856 auto element_ty = ty.cast<TensorType>().getElementType();
857 auto scalar_ty = RankedTensorType::get({}, element_ty);
858 if (element_ty.isF16()) {
859 uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
860 Eigen::NumTraits<Eigen::half>::epsilon());
861 auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
862 return DenseElementsAttr::get(scalar_ty, value);
863 } else if (element_ty.isBF16()) {
864 uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
865 Eigen::NumTraits<Eigen::bfloat16>::epsilon());
866 auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
867 return DenseElementsAttr::get(scalar_ty, value);
868 } else if (element_ty.isF32()) {
869 auto value = APFloat(std::numeric_limits<float>::epsilon());
870 return DenseElementsAttr::get(scalar_ty, value);
871 } else if (element_ty.isF64()) {
872 auto value = APFloat(std::numeric_limits<double>::epsilon());
873 return DenseElementsAttr::get(scalar_ty, value);
874 }
875 llvm_unreachable("unsupported element type for tf.SoftPlus");
876 }
877
878 //===----------------------------------------------------------------------===//
879 // ArgMax/ArgMin op utilities.
880 //===----------------------------------------------------------------------===//
881
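// Builds the reduction body used by the ArgMax/ArgMin lowerings: compares the
// current value against the running best with the given `direction`, keeps
// the selected value, and on ties keeps the smaller index.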
882 static void BuildArgMinMaxReductionBody(Type input_element_type,
883 Type index_element_type,
884 StringRef direction, Region *body,
885 OpBuilder *builder) {
886 OpBuilder::InsertionGuard insertion_point_guard(*builder);
887
888 Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
889 Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
890 Block *block = builder->createBlock(body);
891 block->addArguments({input_type, index_type, input_type, index_type});
892
893 Value lhs_val = block->getArgument(0);
894 Value lhs_index = block->getArgument(1);
895 Value rhs_val = block->getArgument(2);
896 Value rhs_index = block->getArgument(3);
897
898 ImplicitLocOpBuilder b(body->getLoc(), *builder);
899 StringAttr compare_direction = StringAttr::get(b.getContext(), direction);
900 Value compare_dt = b.create<CompareOp>(lhs_val, rhs_val, compare_direction);
901 Value selected_input =
902 b.create<SelectOp>(input_type, compare_dt, lhs_val, rhs_val);
903
904 Value compare_eq = b.create<CompareOp>(lhs_val, rhs_val,
905 StringAttr::get(b.getContext(), "EQ"));
906 Value min_index = b.create<MinOp>(lhs_index, rhs_index);
907 Value min_val_index =
908 b.create<SelectOp>(index_type, compare_dt, lhs_index, rhs_index);
909 Value selected_index =
910 b.create<SelectOp>(index_type, compare_eq, min_index, min_val_index);
911
912 Value return_values[] = {selected_input, selected_index};
913 b.create<ReturnOp>(return_values);
914 }
915
916 //===----------------------------------------------------------------------===//
917 // PartitionedCall op utilities.
918 //===----------------------------------------------------------------------===//
919
920 // Verifies that the arguments to be passed into the function are the same
921 // types as the function parameter types.
922 static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
923 SymbolRefAttr func) {
924 auto module = op->getParentOfType<ModuleOp>();
925 auto function =
926 dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
927 FunctionType function_ty = function.getType();
928
929 for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
930 if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
931 // Argument type and input type mismatch.
932 return false;
933 }
934 }
935 return true;
936 }
937
938 //===----------------------------------------------------------------------===//
939 // Slice op utilities.
940 //===----------------------------------------------------------------------===//
941
942 static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
943 DenseIntElementsAttr slice_sizes) {
944 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
945 if (!input_ty) return false;
946 auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
947 if (!start_indices_ty) return false;
948
949 int64_t input_rank = input_ty.getRank();
950 ArrayRef<int64_t> input_shape = input_ty.getShape();
951 DenseIntElementsAttr constant_start_indices;
952 bool is_constant_start =
953 matchPattern(start_indices, m_Constant(&constant_start_indices));
954
955 for (int64_t i = 0; i < input_rank; ++i) {
956 int64_t input_size = input_shape[i];
957 int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
958 // A slice_size of -1 means "all elements from start_index to the end".
959 // In order to support these semantics, we need to know both the start index
960 // and the shape of the input dimension.
961 if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
962 }
963 return true;
964 }
965
966 // TF slice size can be -1, which represents all elements from start_index to
967 // the end. HLO slice size can't be -1. As such, we need to translate TF slice
968 // size -1 to HLO slice size.
969 static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
970 Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
971 Builder *builder) {
972 DenseIntElementsAttr constant_start_indices;
973 if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
974 return hlo::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
975 .cast<DenseIntElementsAttr>();
976 }
977
978 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
979 int64_t input_rank = input_ty.getRank();
980 ArrayRef<int64_t> input_shape = input_ty.getShape();
981 SmallVector<int64_t, 4> normalized_sizes;
982
983 for (int64_t i = 0; i < input_rank; ++i) {
984 int64_t input_size = input_shape[i];
985 int64_t start_index =
986 constant_start_indices.getValue<IntegerAttr>(i).getInt();
987 int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
988 normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
989 : slice_size);
990 }
991
992 return GetI64ElementsAttr(normalized_sizes, builder);
993 }
994
995 //===----------------------------------------------------------------------===//
996 // Sort op utilities.
997 //===----------------------------------------------------------------------===//
998
999 // Builds the region `body` for mhlo.sort's comparator: for each type in
1000 // `element_types`, create two block arguments, one for lhs and one for rhs, and
1001 // generates mhlo.compare op to compare them with the given `direction`.
1002 //
1003 // Note that this right now only does comparison on the first pair of block
1004 // arguments.
1005 static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
1006 StringRef direction,
1007 llvm::Optional<StringRef> compare_type,
1008 Region *body, OpBuilder *builder) {
1009 OpBuilder::InsertionGuard insertion_point_guard(*builder);
1010
1011 Block *block = builder->createBlock(body);
1012 // Add two arguments for each element type.
1013 for (Type element_type : element_types) {
1014 TensorType tensor_type = RankedTensorType::get({}, element_type);
1015 block->addArguments({tensor_type, tensor_type});
1016 }
1017
1018 Location loc = body->getLoc();
1019 StringAttr compare_direction = builder->getStringAttr(direction);
1020 StringAttr type_attr;
1021 if (compare_type) type_attr = builder->getStringAttr(*compare_type);
1022 Value compare = builder->create<mhlo::CompareOp>(
1023 loc, block->getArgument(0), block->getArgument(1), compare_direction,
1024 type_attr);
1025
1026 builder->create<mhlo::ReturnOp>(loc, compare);
1027 }
1028
1029 //===----------------------------------------------------------------------===//
1030 // XlaGather op utilities.
1031 //===----------------------------------------------------------------------===//
1032
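// Returns true if `attr` holds a valid serialized xla::GatherDimensionNumbers
// proto.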
1033 bool HasValidGatherDims(StringAttr attr) {
1034 ::xla::GatherDimensionNumbers dims;
1035 return dims.ParseFromString(attr.getValue().str());
1036 }
1037
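// Parses the serialized xla::GatherDimensionNumbers in `attr` and converts it
// to the corresponding MLIR attribute; returns a null attribute on parse
// failure.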
1038 GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) {
1039 ::xla::GatherDimensionNumbers dims;
1040 if (!dims.ParseFromString(attr.getValue().str())) return {};
1041 return ::xla::ConvertGatherDimensionNumbers(dims, builder);
1042 }
1043
1044 //===----------------------------------------------------------------------===//
1045 // XlaDot op utilities.
1046 //===----------------------------------------------------------------------===//
1047
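// Returns true if `attr` holds a valid serialized xla::DotDimensionNumbers
// proto.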
1048 bool HasValidDotDims(StringAttr attr) {
1049 ::xla::DotDimensionNumbers dims;
1050 return dims.ParseFromString(attr.getValue().str());
1051 }
1052
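// Parses the serialized xla::DotDimensionNumbers in `attr` and converts it to
// the corresponding MLIR attribute; returns a null attribute on parse failure.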
1053 DotDimensionNumbers GetDotDimNumsAttr(StringAttr attr, Builder *builder) {
1054 ::xla::DotDimensionNumbers dims;
1055 if (!dims.ParseFromString(attr.getValue().str())) return {};
1056 return ::xla::ConvertDotDimensionNumbers(dims, builder);
1057 }
1058
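// Returns true if `attr` holds a valid serialized xla::PrecisionConfig proto.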
1059 bool HasValidPrecisionConfig(StringAttr attr) {
1060 ::xla::PrecisionConfig precision;
1061 return precision.ParseFromString(attr.getValue().str());
1062 }
1063
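// Parses the serialized xla::PrecisionConfig in `attr` and converts it to an
// MLIR array attribute; returns a null attribute on parse failure.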
1064 mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) {
1065 ::xla::PrecisionConfig precision;
1066 if (!precision.ParseFromString(attr.getValue().str())) return {};
1067 return ::xla::ConvertPrecisionConfig(&precision, builder);
1068 }
1069
1070 //===----------------------------------------------------------------------===//
1071 // Op converters.
1072 //===----------------------------------------------------------------------===//
1073
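// Builds the `dimension_numbers` named attribute for an HLO convolution from
// the given spatial dimension indices and TensorFlow data format.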
1074 NamedAttribute GetConvDimensionNumbersAttr(
1075 ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
1076 Builder *builder) {
1077 int64_t num_spatial_dims = spatial_dim_indices.size();
1078 int64_t num_dims = num_spatial_dims + 2;
1079
1080 IntegerAttr batch_dim =
1081 builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
1082 IntegerAttr feature_dim =
1083 builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
1084 DenseIntElementsAttr spatial_dims =
1085 GetI64ElementsAttr(spatial_dim_indices, builder);
1086
1087 // The filter's data_format is always HWIO, so the input channels dimension
1088 // comes after all spatial dimensions.
1089 IntegerAttr kernel_input_feature_dim =
1090 builder->getI64IntegerAttr(num_spatial_dims);
1091 IntegerAttr kernel_output_feature_dim =
1092 builder->getI64IntegerAttr(num_spatial_dims + 1);
1093 DenseIntElementsAttr kernel_spatial_dimensions =
1094 GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);
1095
1096 return builder->getNamedAttr(
1097 "dimension_numbers",
1098 ConvDimensionNumbers::get(
1099 batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
1100 kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
1101 feature_dim, spatial_dims, builder->getContext()));
1102 }
1103
1104 // Converts a TF::BiasAddOp to HLO.
1105 // This differs from a normal TF::AddOp with respect to how the data_format
1106 // is handled, which can optionally require a general broadcast of the
1107 // 'bias' term in a way that is not compatible with the standard left-padded
1108 // broadcast semantics (i.e. NCHW will broadcast into dimension 1).
1109 // The correct 'bias' broadcast will be synthesized manually.
1110 class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
1111 public:
1112 using OpRewritePattern::OpRewritePattern;
1113   LogicalResult matchAndRewrite(TF::BiasAddOp op,
1114 PatternRewriter &rewriter) const override {
1115 Location loc = op.getLoc();
1116 tensorflow::TensorFormat data_format;
1117 if (!FormatFromString(op.data_format().str(), &data_format))
1118 return op.emitOpError("invalid data format");
1119
1120 auto feature_dim = GetFeatureDimension(
1121 data_format, op.value().getType().cast<RankedTensorType>());
1122 auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
1123 feature_dim, rewriter);
1124 Value add = rewriter.create<AddOp>(loc, op.value(), bias_broadcast);
1125 if (add.getType() != op.getType()) {
1126 add = rewriter.create<tensor::CastOp>(loc, op.getType(), add);
1127 }
1128 rewriter.replaceOp(op, {add});
1129 return success();
1130 }
1131 };
1132
1133 // Converts a TF::GatherV2Op to mhlo::DynamicGatherOp.
1134 class ConvertGatherV2OpDynamic : public OpRewritePattern<TF::GatherV2Op> {
1135 using OpRewritePattern<TF::GatherV2Op>::OpRewritePattern;
1136 // TODO(disc): To recover static special case's performance with folding and
1137 // canonicalization.
1138   LogicalResult matchAndRewrite(TF::GatherV2Op op,
1139 PatternRewriter &rewriter) const override {
1140 Location loc = op.getLoc();
1141 Value params = op.params();
1142 // params and indices of GatherV2Op must be ranked
1143 auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1144 Value indices = op.indices();
1145 auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1146 if (!params_ty || !indices_ty) return failure();
1147
1148 // TODO(disc): Remove this constraint once fold and canonicalization
1149 // implemented.
1150 if (params_ty.hasStaticShape() && indices_ty.hasStaticShape())
1151 return failure();
1152
1153 int64_t params_rank = params_ty.getRank();
1154 int64_t indices_rank = indices_ty.getRank();
1155
1156 // axis
1157 DenseIntElementsAttr axis_attr;
1158 // axis must be const for GatherOp
1159 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
1160
1161 int64_t axis = (*axis_attr.begin()).getSExtValue();
1162 if (axis < 0) axis += params_rank;
1163
1164 // slice_sizes
1165 SmallVector<int64_t, 4> slice_sizes;
1166 slice_sizes.reserve(params_rank);
1167 for (int64_t dim_idx = 0; dim_idx < params_rank; ++dim_idx) {
1168 if (dim_idx == axis) {
1169 slice_sizes.push_back(1);
1170 } else {
1171 // potentially dynamic
1172 int64_t dim_size = params_ty.getDimSize(dim_idx);
1173 slice_sizes.push_back(dim_size);
1174 }
1175 }
1176 SmallVector<Value, 4> slice_sizes_vals;
1177 for (int64_t dim_idx = 0; dim_idx < params_rank; ++dim_idx) {
1178 if (dim_idx == axis) {
1179 slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1180 loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1181 } else {
1182 int64_t dim_size = params_ty.getDimSize(dim_idx);
1183 if (dim_size != ShapedType::kDynamicSize) {
1184 slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1185 loc,
1186 rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1187 } else {
1188 slice_sizes_vals.push_back(rewriter.create<IndexCastOp>(
1189 loc, rewriter.create<tensor::DimOp>(loc, params, dim_idx),
1190 indices_ty.getElementType()));
1191 }
1192 }
1193 }
1194 Value slice_sizes_value = rewriter.create<tensor::FromElementsOp>(
1195 loc, indices_ty.getElementType(), slice_sizes_vals);
1196 // offset_dims
1197 SmallVector<int64_t, 4> offset_dims;
1198 for (int64_t dim_idx = 0; dim_idx < params_rank; dim_idx++) {
1199 if (dim_idx < axis) {
1200 offset_dims.push_back(dim_idx);
1201 } else if (dim_idx >= axis + 1) {
1202 offset_dims.push_back(dim_idx + indices_rank - 1);
1203 }
1204 }
1205 // collapsed_slice_dims
1206 SmallVector<int64_t, 4> collapsed_slice_dims(1, axis);
1207 // start_index_map
1208 SmallVector<int64_t, 4> start_index_map(1, axis);
1209 // index_vector_dim
1210 int64_t index_vector_dim = indices_rank;
1211 auto dims_attr = GatherDimensionNumbers::get(
1212 /*offset_dims=*/GetI64ElementsAttr(offset_dims, &rewriter),
1213 /*collapsed_slice_dims=*/
1214 GetI64ElementsAttr(collapsed_slice_dims, &rewriter),
1215 /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1216 /*index_vector_dim=*/rewriter.getI64IntegerAttr(index_vector_dim),
1217 rewriter.getContext());
1218
1219 rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1220 op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1221 dims_attr);
1222 return success();
1223 }
1224 };
1225
1226 // Converts tf.Conv2D to mhlo.dynamic_conv.
1227 // TODO(disc): To recover static special case's performance with adding folding,
1228 // canonicalization func and removing ConvertConvOp.
1229 template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
1230 class ConvertConvDynamic : public OpRewritePattern<OpT> {
1231 public:
1232 using OpRewritePattern<OpT>::OpRewritePattern;
1233
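  // Emits the low/high padding for one spatial dimension as scalar SSA values,
  // following the formulas of GetWindowedOutputSizeFromDimsV2 (referenced in
  // the SAME case below) but on dynamic sizes. A worked example with
  // illustrative numbers only: input_size = 10, filter_size = 3,
  // dilation_rate = 1, stride = 2, SAME padding:
  //   effective_filter_size = (3 - 1) * 1 + 1 = 3
  //   output_size           = (10 + 2 - 1) / 2 = 5
  //   padding_needed        = max(0, (5 - 1) * 2 + 3 - 10) = 1
  //   padding_low = 0, padding_high = 1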
1234   bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size,
1235 Value filter_size, int64_t dilation_rate,
1236 int64_t stride, tensorflow::Padding padding_type,
1237 Type shape_scalar_type, Value *padding_low,
1238 Value *padding_high) const {
1239 // Stride must be > 0
1240 if (stride <= 0) return false;
1241 // Dilation rate must be >= 1
1242 if (dilation_rate < 1) return false;
1243
1244 Location loc = op.getLoc();
1245 switch (padding_type) {
1246 case tensorflow::Padding::VALID: {
1247 auto zero = rewriter.create<ConstantIntOp>(loc, 0, shape_scalar_type);
1248 *padding_low = *padding_high = zero;
1249 break;
1250 }
1251 case tensorflow::Padding::EXPLICIT:
1252 break;
1253 case tensorflow::Padding::SAME: {
1254 auto zero = rewriter.create<ConstantIntOp>(loc, 0, shape_scalar_type);
1255 auto one = rewriter.create<ConstantIntOp>(loc, 1, shape_scalar_type);
1256 auto two = rewriter.create<ConstantIntOp>(loc, 2, shape_scalar_type);
1257         // See also the parallel implementation in
1258         // GetWindowedOutputSizeFromDimsV2:
1259         //   effective_filter_size = (filter_size - 1) * dilation_rate + 1
1260 Value stride_value =
1261 rewriter.create<ConstantIntOp>(loc, stride, shape_scalar_type);
1262 Value dilation_rate_value = rewriter.create<ConstantIntOp>(
1263 loc, dilation_rate, shape_scalar_type);
1264 Value effective_filter_size_op = rewriter.create<AddIOp>(
1265 loc, one,
1266 rewriter.create<MulIOp>(
1267 loc, dilation_rate_value,
1268 rewriter.create<SubIOp>(loc, filter_size, one)));
1269 // output_size = (input_size + stride - 1) / stride;
1270 Value output_size = rewriter.create<UnsignedDivIOp>(
1271 loc,
1272 rewriter.create<AddIOp>(
1273 loc, input_size,
1274 rewriter.create<SubIOp>(loc, stride_value, one)),
1275 stride_value);
1276 // std::max(int64{0}, (output_size - 1) * stride +
1277 // effective_filter_size - input_size);
1278 Value padding_needed = rewriter.create<SubIOp>(
1279 loc,
1280 rewriter.create<AddIOp>(
1281 loc, effective_filter_size_op,
1282 rewriter.create<MulIOp>(
1283 loc, stride_value,
1284 rewriter.create<SubIOp>(loc, output_size, one))),
1285 input_size);
1286 Value cond = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge,
1287 padding_needed, zero);
1288 padding_needed = rewriter.create<mlir::SelectOp>(
1289 loc, padding_needed.getType(), cond, padding_needed, zero);
1290 *padding_low =
1291 rewriter.create<UnsignedDivIOp>(loc, padding_needed, two);
1292 *padding_high =
1293 rewriter.create<SubIOp>(loc, padding_needed, *padding_low);
1294 break;
1295 }
1296 }
1297 return true;
1298 }
1299
1300   LogicalResult matchAndRewriteDynamicConv(OpT op,
1301 PatternRewriter &rewriter) const {
1302 tensorflow::TensorFormat data_format;
1303 if (!FormatFromString(op.data_format().str(), &data_format))
1304 return op.emitOpError("invalid data format");
1305
1306 tensorflow::Padding padding;
1307 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1308 return failure();
1309
1310 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1311 auto filter_ty =
1312 op.filter().getType().template dyn_cast<RankedTensorType>();
1313 auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1314 if (!input_ty || !filter_ty || !result_ty) return failure();
1315 // TODO(disc): Remove this constraint once fold and canonicalization
1316 // implemented.
1317 if (input_ty.hasStaticShape() && filter_ty.hasStaticShape())
1318 return failure();
1319
1320 ArrayRef<Attribute> dilations = op.dilations().getValue();
1321 ArrayRef<Attribute> strides = op.strides().getValue();
1322 ArrayRef<Attribute> explicit_paddings;
1323 if (padding == tensorflow::Padding::EXPLICIT) {
1324 // EXPLICIT padding mode and the associated attribute is attached to
1325 // Conv2D.
1326 explicit_paddings =
1327 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1328 }
1329
1330 SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1331 SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1332 SmallVector<int64_t, num_spatial_dims> window_strides;
1333 SmallVector<Value, num_spatial_dims * 2> paddings;
1334
1335 auto get_int = [](Attribute attr) {
1336 return attr.template cast<IntegerAttr>().getInt();
1337 };
1338
1339 constexpr int num_dims = num_spatial_dims + 2;
1340
1341 Location loc = op.getLoc();
1342 auto shape_scalar_type = rewriter.getIntegerType(32);
1343
1344 auto get_const = [&](int64_t val) {
1345 return rewriter.create<mlir::ConstantIntOp>(loc, val, shape_scalar_type);
1346 };
1347 auto get_dim_value = [&](Value val, int64_t dim) {
1348 Value dim_value = rewriter.create<tensor::DimOp>(loc, val, dim);
1349 return rewriter.create<IndexCastOp>(loc, dim_value, shape_scalar_type);
1350 };
1351
1352 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1353 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1354 spatial_dim_indices.push_back(dim);
1355
1356 const int64_t dilation = get_int(dilations[dim]);
1357 rhs_dilations.push_back(dilation);
1358 const int64_t stride = get_int(strides[dim]);
1359 window_strides.push_back(stride);
1360
1361 Value pad_low, pad_high;
1362 if (padding == tensorflow::Padding::EXPLICIT) {
1363 pad_low = get_const(get_int(explicit_paddings[2 * dim]));
1364 pad_high = get_const(get_int(explicit_paddings[2 * dim + 1]));
1365 } else {
1366 auto input_size = get_dim_value(op.input(), dim);
1367 auto filter_size = get_dim_value(op.filter(), i);
1368 if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation,
1369 stride, padding, shape_scalar_type, &pad_low,
1370 &pad_high)) {
1371 return failure();
1372 }
1373 }
1374 paddings.push_back(pad_low);
1375 paddings.push_back(pad_high);
1376 }
1377 auto rhs_dilations_attr = rewriter.getNamedAttr(
1378 "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1379
1380 auto window_strides_attr = rewriter.getNamedAttr(
1381 "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1382
1383 auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1384 spatial_dim_indices, data_format, &rewriter);
1385
1386 const int64_t input_channels =
1387 GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1388     // The filter's data_format is always HWIO, so the input channels dimension
1389     // comes after all spatial dimensions.
1390 const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1391 // TensorFlow convolution op verifies that the number of input channels is
1392 // divisible by the number of filter channels.
1393 // For depthwise convolution the feature_group_count argument would be set
1394 // to the input feature dimension.
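    // For example (hypothetical sizes): a depthwise conv with 32 input
    // channels and channel_multiplier 2 has an HWIO filter of shape
    // [h, w, 32, 2], so feature_group_count = 32 and the filter is reshaped
    // below to [h, w, 1, 64]; a regular conv with 32 input channels and a
    // filter with 8 input channels would use feature_group_count = 4.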
1395 const int64_t feature_group_count =
1396 depthwise_conv ? input_channels : input_channels / filter_channels;
1397 auto feature_group_count_attr = rewriter.getNamedAttr(
1398 "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1399
1400 auto batch_group_count_attr = rewriter.getNamedAttr(
1401 "batch_group_count", rewriter.getI64IntegerAttr(1));
1402
1403 Value paddings_op = rewriter.create<tensor::FromElementsOp>(
1404 op.getLoc(), rewriter.getI32Type(), paddings);
1405
1406 SmallVector<Value, 3> operands(op.getOperands());
1407 operands.push_back(paddings_op);
1408     // Reshape the filter to
1409     // {spatial_dims..., 1, in_channels * channel_multiplier}.
1410 if (depthwise_conv) {
1411 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1412 llvm::SmallVector<int64_t, num_dims> new_shape(
1413 filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1414 new_shape.push_back(1);
1415 new_shape.push_back(filter_shape[num_spatial_dims] *
1416 filter_shape[num_spatial_dims + 1]);
1417 operands[1] = rewriter.create<mhlo::ReshapeOp>(
1418 op.getLoc(),
1419 RankedTensorType::get(new_shape, filter_ty.getElementType()),
1420 operands[1]);
1421 }
1422 NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1423 dimension_numbers_attr, feature_group_count_attr,
1424 batch_group_count_attr};
1425 rewriter.replaceOpWithNewOp<mhlo::DynamicConvOp>(op, op.getType(), operands,
1426 llvm::makeArrayRef(attrs));
1427 return success();
1428 }
1429
1430   LogicalResult matchAndRewrite(OpT op,
1431 PatternRewriter &rewriter) const override {
1432 return matchAndRewriteDynamicConv(op, rewriter);
1433 }
1434 };
1435
1436 using ConvertConv2DDynamic =
1437 ConvertConvDynamic<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1438
1439 // Converts the TensorFlow conv op in template to the generic HLO conv op by
1440 // converting TensorFlow op attributes to HLO op attributes.
1441 //
1442 // Sample result for Conv2D:
1443 //
1444 // %conv = "mhlo.convolution"(%input, %filter) {
1445 // strides = [1, 2],
1446 // paddings = [[1, 0], [1, 1]],
1447 // ...
1448 // }
1449 //
1450 // This pattern is not defined using declarative rewrite rules, as computing
1451 // the paddings attribute requires multiple source op attributes and result op
1452 // attributes. Defining it as a declarative rewrite rule would introduce
1453 // duplication in the C++ helper methods.
1454 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1455 class ConvertConvOp : public OpRewritePattern<OpTy> {
1456 public:
1457 using OpRewritePattern<OpTy>::OpRewritePattern;
1458
1459   LogicalResult matchAndRewrite(OpTy op,
1460 PatternRewriter &rewriter) const override {
1461 tensorflow::TensorFormat data_format;
1462 if (!FormatFromString(op.data_format().str(), &data_format))
1463 return op.emitOpError("invalid data format");
1464
1465 tensorflow::Padding padding;
1466 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1467 return failure();
1468
1469 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1470 auto filter_ty =
1471 op.filter().getType().template dyn_cast<RankedTensorType>();
1472 auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1473
1474 // Input, filter and the result needs to have static shape for calculation
1475 // of HLO paddings and feature group count attributes.
1476 for (RankedTensorType ty : {input_ty, filter_ty, result_ty})
1477 if (!ty || !ty.hasStaticShape()) return failure();
1478
1479 ArrayRef<Attribute> dilations = op.dilations().getValue();
1480 ArrayRef<Attribute> strides = op.strides().getValue();
1481 ArrayRef<Attribute> explicit_paddings;
1482 if (padding == tensorflow::Padding::EXPLICIT) {
1483 // EXPLICIT padding mode and the associated attribute is limited to
1484 // Conv2D. So, fetch attribute by identifier instead of the
1485 // op.explicit_paddings() attribute getter.
1486 explicit_paddings =
1487 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1488 }
1489
1490 SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1491 SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1492 SmallVector<int64_t, num_spatial_dims> window_strides;
1493 SmallVector<int64_t, num_spatial_dims * 2> paddings;
1494
1495 auto get_int = [](Attribute attr) {
1496 return attr.template cast<IntegerAttr>().getInt();
1497 };
1498
1499 constexpr int num_dims = num_spatial_dims + 2;
1500 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1501 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1502 spatial_dim_indices.push_back(dim);
1503
1504 const int64_t dilation = get_int(dilations[dim]);
1505 rhs_dilations.push_back(dilation);
1506 const int64_t stride = get_int(strides[dim]);
1507 window_strides.push_back(stride);
1508
1509 int64_t pad_low, pad_high;
1510 if (padding == tensorflow::Padding::EXPLICIT) {
1511 pad_low = get_int(explicit_paddings[2 * dim]);
1512 pad_high = get_int(explicit_paddings[2 * dim + 1]);
1513 } else {
1514 int64_t output_size;
1515 int64_t pad_low_int64;
1516 int64_t pad_high_int64;
1517 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1518 input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
1519 padding, &output_size, &pad_low_int64, &pad_high_int64);
1520 if (!status.ok()) return failure();
1521 pad_low = pad_low_int64;
1522 pad_high = pad_high_int64;
1523 }
1524 paddings.push_back(pad_low);
1525 paddings.push_back(pad_high);
1526 }
1527
1528 auto rhs_dilations_attr = rewriter.getNamedAttr(
1529 "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1530
1531 auto window_strides_attr = rewriter.getNamedAttr(
1532 "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1533
1534 auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1535 spatial_dim_indices, data_format, &rewriter);
1536
1537 const int64_t input_channels =
1538 GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1539     // The filter's data_format is always HWIO, so the input channels dimension
1540     // comes after all spatial dimensions.
1541 const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1542 // TensorFlow convolution op verifies that the number of input channels is
1543 // divisible by the number of filter channels.
1544 // For depthwise convolution the feature_group_count argument would be set
1545 // to the input feature dimension.
1546 const int64_t feature_group_count =
1547 depthwise_conv ? input_channels : input_channels / filter_channels;
1548 auto feature_group_count_attr = rewriter.getNamedAttr(
1549 "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1550
1551 auto batch_group_count_attr = rewriter.getNamedAttr(
1552 "batch_group_count", rewriter.getI64IntegerAttr(1));
1553
1554 RankedTensorType paddings_ty = RankedTensorType::get(
1555 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1556 auto paddings_attr = rewriter.getNamedAttr(
1557 "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1558
1559 SmallVector<Value, 2> operands(op.getOperands());
1560     // Reshape the filter to
1561     // {spatial_dims..., 1, in_channels * channel_multiplier}.
1562 if (depthwise_conv) {
1563 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1564 llvm::SmallVector<int64_t, num_dims> new_shape(
1565 filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1566 new_shape.push_back(1);
1567 new_shape.push_back(filter_shape[num_spatial_dims] *
1568 filter_shape[num_spatial_dims + 1]);
1569 operands[1] = rewriter.create<mhlo::ReshapeOp>(
1570 op.getLoc(),
1571 RankedTensorType::get(new_shape, filter_ty.getElementType()),
1572 operands[1]);
1573 }
1574 NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1575 dimension_numbers_attr, feature_group_count_attr,
1576 batch_group_count_attr, paddings_attr};
1577 rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
1578 llvm::makeArrayRef(attrs));
1579 return success();
1580 }
1581 };
1582
1583 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1584 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1585 using ConvertDepthConv2DOp =
1586 ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1587 /*depthwise_conv=*/true>;
1588
1589 // Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const.
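//
// Illustrative data flow with hypothetical values, for a rank-2 input and
// paddings = [[1, 2], [3, 4]]:
//   transpose with permutation [1, 0]  -> [[1, 3], [2, 4]]
//   reshape to a vector of length 4    -> [1, 3, 2, 4]
//   slice [0, 2)                       -> left_padding_tensor  = [1, 3]
//   slice [2, 4)                       -> right_padding_tensor = [2, 4]
// The interior padding operand is a constant zero vector of length
// rank(input).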
1590 class ConvertPadOpDynamic : public OpRewritePattern<TF::PadV2Op> {
1591 public:
1592 using OpRewritePattern::OpRewritePattern;
1593 // TODO(disc): To recover static special case's performance with folding and
1594 // canonicalization.
1595   LogicalResult matchAndRewrite(TF::PadV2Op op,
1596 PatternRewriter &rewriter) const override {
1597 Location loc = op.getLoc();
1598 auto input = op.input();
1599 auto paddings = op.paddings();
1600 auto constant_values = op.constant_values();
1601 auto input_type = input.getType().dyn_cast<RankedTensorType>();
1602 auto paddings_type = paddings.getType().dyn_cast<RankedTensorType>();
1603 if (!input_type || !paddings_type || !paddings_type.hasStaticShape())
1604 return failure();
1605
1606 // TODO(disc): Remove this constraint once fold and canonicalization is
1607 // implemented.
1608 if (input_type.hasStaticShape()) return failure();
1609
1610 int input_rank = input_type.getRank();
1611 // interior padding
1612 std::vector<int64_t> interior_values(input_rank, 0);
1613 auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter);
1614
1615 Value interior_padding_tensor =
1616 rewriter.create<mhlo::ConstOp>(loc, interior_attr);
1617 Type paddings_elem_ty = paddings_type.getElementType();
1618 if (!paddings_elem_ty.isInteger(64)) {
1619 interior_padding_tensor = rewriter.create<mhlo::ConvertOp>(
1620 loc, interior_padding_tensor, paddings_elem_ty);
1621 }
1622 llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
1623 auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
1624 Value transposed_paddings = rewriter.create<mhlo::TransposeOp>(
1625 loc, RankedTensorType::get(transposed_shape, paddings_elem_ty),
1626 paddings, transpose_attr);
1627 Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
1628 loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
1629 transposed_paddings);
1630
1631 auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter);
1632 auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1633 auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1634 Value left_padding_tensor = rewriter.create<mhlo::SliceOp>(
1635 loc, reshaped_paddings, left_padding_start_attr,
1636 left_padding_limit_attr, left_padding_stride_attr);
1637
1638 auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1639 auto right_padding_limit_attr =
1640 GetI64ElementsAttr({2 * input_rank}, &rewriter);
1641 auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1642 Value right_padding_tensor = rewriter.create<mhlo::SliceOp>(
1643 loc, reshaped_paddings, right_padding_start_attr,
1644 right_padding_limit_attr, right_padding_stride_attr);
1645
1646 rewriter.replaceOpWithNewOp<mhlo::DynamicPadOp>(
1647 op, op.getType(), input, constant_values, left_padding_tensor,
1648 right_padding_tensor, interior_padding_tensor);
1649
1650 return success();
1651 }
1652 };
1653
1654 class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {
1655 using OpRewritePattern<TF::GatherNdOp>::OpRewritePattern;
1656 // Converts tf.GatherNdOp to mhlo.DynamicGatherOp.
1657 // Here we leave 'slice_sizes' as an Attr, without defining a new
1658   // DynamicGatherOp, since GatherDimensionNumbers already provides enough
1659   // information for shape inference and code generation of mhlo::GatherOp. '?'
1660   // will be filled into slice_sizes for dimensions that are dynamically sized.
1661 // TODO(disc): To recover static special case's performance with folding and
1662 // canonicalization.
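  //
  // A sketch of the derived dimension numbers, assuming a hypothetical rank-3
  // `params` and `indices` of shape [?, 2] (i.e. num_index_dims = 2):
  //   slice_sizes          = [1, 1, dim(params, 2)]
  //   collapsed_slice_dims = [0, 1]
  //   offset_dims          = [1]
  //   start_index_map      = [0, 1]
  //   index_vector_dim     = 1  (== rank of indices - 1)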
1663   LogicalResult matchAndRewrite(TF::GatherNdOp op,
1664 PatternRewriter &rewriter) const override {
1665 Location loc = op.getLoc();
1666 auto params = op.params();
1667 auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1668 auto indices = op.indices();
1669 auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1670 auto params_rank = params_ty.getRank();
1671 auto indices_rank = indices_ty.getRank();
1672 int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1);
1673 if (!params_ty || !indices_ty) return failure();
1674     // The last dim of indices of GatherNdOp must be statically shaped.
1675 if (num_index_dims == ShapedType::kDynamicSize) return failure();
1676
1677 SmallVector<int64_t, 4> slice_sizes;
1678 slice_sizes.reserve(params_rank);
1679 for (int64_t i = 0; i < params_rank; ++i) {
1680 if (i < num_index_dims) {
1681 slice_sizes.push_back(1);
1682 } else {
1683 // potentially dynamic
1684 int64_t dim_size = params_ty.getDimSize(i);
1685 slice_sizes.push_back(dim_size);
1686 }
1687 }
1688 SmallVector<Value, 4> slice_sizes_vals;
1689 Value slice_sizes_value = nullptr;
1690 for (int64_t i = 0; i < params_rank; ++i) {
1691 if (i < num_index_dims) {
1692 slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1693 loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1694 } else {
1695 int64_t dim_size = params_ty.getDimSize(i);
1696 if (dim_size != ShapedType::kDynamicSize) {
1697 slice_sizes_vals.push_back(rewriter.create<ConstantOp>(
1698 loc,
1699 rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1700 } else {
1701 slice_sizes_vals.push_back(rewriter.create<IndexCastOp>(
1702 loc, rewriter.create<tensor::DimOp>(loc, params, i),
1703 indices_ty.getElementType()));
1704 }
1705 }
1706 }
1707 slice_sizes_value =
1708 rewriter.create<tensor::FromElementsOp>(loc, slice_sizes_vals);
1709
1710 // collapsed_slice_dims
1711 SmallVector<int64_t, 4> collapsed_slice_dims;
1712 collapsed_slice_dims.reserve(num_index_dims);
1713 for (int64_t i = 0; i < num_index_dims; ++i) {
1714 collapsed_slice_dims.push_back(i);
1715 }
1716 // offset_dims
1717 SmallVector<int64_t, 4> offset_dims;
1718 offset_dims.reserve(params_rank - num_index_dims);
1719 for (int64_t i = num_index_dims; i < params_rank; i++) {
1720 offset_dims.push_back(i + indices_rank - 1 - num_index_dims);
1721 }
1722 // start_index_map
1723 SmallVector<int64_t, 4> start_index_map;
1724     start_index_map.reserve(num_index_dims);
1725 for (int64_t i = 0; i < num_index_dims; i++) {
1726 start_index_map.push_back(i);
1727 }
1728 // index_vector_dim
1729 int64_t index_vector_dim = indices_rank - 1;
1730
1731 auto dims_attr = GatherDimensionNumbers::get(
1732 /*offset_dims=*/GetI64ElementsAttr(offset_dims, &rewriter),
1733 /*collapsed_slice_dims=*/
1734 GetI64ElementsAttr(collapsed_slice_dims, &rewriter),
1735 /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
1736 /*index_vector_dim=*/rewriter.getI64IntegerAttr(index_vector_dim),
1737 rewriter.getContext());
1738 // TODO(disc): Remove this if-statement once fold and canonicalization is
1739 // implemented.
1740 if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
1741 rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
1742 op, op.getType(), op.params(), op.indices(), dims_attr,
1743 GetI64ElementsAttr(slice_sizes, &rewriter));
1744 } else {
1745 rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1746 op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1747 dims_attr);
1748 }
1749 return success();
1750 }
1751 };
1752
1753 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
1754 // division can result in strange behavior.
1755 //
1756 // floordiv = cast(floordiv(cast(left), cast(right)))
1757 //
1758 // %left_cast = cast(%left)
1759 // %right_cast = cast(%right)
1760 // %div = div(%left_cast, %right_cast)
1761 // %floored = floor(%div)
1762 // %floored_cast = cast(%floored)
1763 //
1764 // Required to manually specify the intermediate types.
1765 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1766 public:
1767 using OpRewritePattern::OpRewritePattern;
1768
1769   LogicalResult matchAndRewrite(TF::FloorDivOp op,
1770 PatternRewriter &rewriter) const override {
1771 auto l = op.x();
1772 auto r = op.y();
1773 auto element_type = getElementTypeOrSelf(l.getType());
1774 if (!element_type.isBF16()) return failure();
1775
1776 auto out_type = op.z().getType().cast<TensorType>();
1777
1778 l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1779 r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1780
1781 auto intermediate = rewriter.create<TF::FloorDivOp>(
1782 op.getLoc(),
1783 ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1784 r);
1785
1786 auto floor_op =
1787 rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1788 rewriter.replaceOp(op, floor_op.getResult());
1789 return success();
1790 }
1791 };
1792
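// Converts tf.BroadcastTo to mhlo.dynamic_broadcast_in_dim. Because
// tf.BroadcastTo broadcasts "right-aligned", numpy-style, the broadcast
// dimensions are the trailing dimensions of the output. Illustrative
// (hypothetical) example: broadcasting an input of type tensor<?x4xf32> to a
// rank-3 output uses broadcast_dimensions = [1, 2].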
1793 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1794 public:
1795 using OpRewritePattern::OpRewritePattern;
1796
1797   LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1798 PatternRewriter &rewriter) const override {
1799 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1800 auto output_type = op.output().getType();
1801 if (!input_type) {
1802 return rewriter.notifyMatchFailure(op, "requires ranked input shape");
1803 }
1804 llvm::SmallVector<int64_t, 4> broadcast_dimensions;
1805 if (input_type.getRank() > 0) {
1806 auto ranked_output_type = output_type.dyn_cast<RankedTensorType>();
1807 if (!ranked_output_type) {
1808 return rewriter.notifyMatchFailure(op, "requires ranked output shape");
1809 }
1810 auto rank_diff = ranked_output_type.getRank() - input_type.getRank();
1811 // The tf.BroadcastTo op performs "right-aligned" numpy-style
1812 // broadcasting.
1813 broadcast_dimensions = llvm::to_vector<4>(
1814 llvm::seq<int64_t>(rank_diff, ranked_output_type.getRank()));
1815 }
1816 rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1817 op, output_type, op.input(), op.shape(),
1818 rewriter.getI64TensorAttr(broadcast_dimensions));
1819 return success();
1820 }
1821 };
1822
1823 /// Converts a TF::RollOp to HLO. Only supports the 0-D axis and shift case,
1824 /// and axis has to be a constant.
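/// An illustrative example with hypothetical values: rolling a tensor<5xi32>
/// input by shift = 2 along axis = 0 concatenates the input with itself
/// (shape [10]) and dynamic-slices 5 elements starting at
/// axis_size - offset = 5 - 2 = 3, yielding elements [3, 4, 0, 1, 2] of the
/// original input, which matches tf.roll semantics.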
1825 class ConvertRollOp : public OpRewritePattern<TF::RollOp> {
1826 public:
1827 using OpRewritePattern::OpRewritePattern;
1828   LogicalResult matchAndRewrite(TF::RollOp op,
1829 PatternRewriter &rewriter) const override {
1830 auto shift_ty = op.shift().getType().dyn_cast<RankedTensorType>();
1831 if (!shift_ty || shift_ty.getRank() != 0) {
1832 return rewriter.notifyMatchFailure(
1833 op, "require the type of shift to be 0D tensor");
1834 }
1835
1836 APInt val;
1837 if (!matchPattern(op.axis(), m_ConstantInt(&val))) {
1838 return rewriter.notifyMatchFailure(op, "require axis to be constant");
1839 }
1840 int axis = val.getSExtValue();
1841
1842 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1843 if (!input_ty || !input_ty.hasStaticShape()) {
1844 return rewriter.notifyMatchFailure(
1845 op, "require the type of input to have static shapes");
1846 }
1847 ArrayRef<int64_t> input_shape = input_ty.getShape();
1848 int input_rank = input_ty.getRank();
1849 if (axis < 0) axis += input_rank;
1850
1851 // Adjust large offsets into [0, axis_size). This also makes negative
1852 // offsets positive.
1853 // offset = ((offset % axis_size) + axis_size) % axis_size
1854 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1855 Value offset = op.shift();
1856 auto axis_size = b.create<mhlo::ConstOp>(b.getIntegerAttr(
1857 getElementTypeOrSelf(offset.getType()), input_shape[axis]));
1858 offset = b.create<RemOp>(
1859 b.create<AddOp>(b.create<RemOp>(offset, axis_size), axis_size),
1860 axis_size);
1861
1862 // Stack two copies of the dimension, then slice from the calculated
1863 // offset. This also works if shift is not constant.
1864     // DynamicSliceOp requires the slice sizes to be integer constants, and we
1865     // can get that information from the input shape.
1866 auto concat = b.create<ConcatenateOp>(ValueRange{op.input(), op.input()},
1867 b.getI64IntegerAttr(axis));
1868 Value zero = b.create<mhlo::ConstOp>(
1869 b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0));
1870 SmallVector<Value> slice_begin_indices(input_rank, zero);
1871 slice_begin_indices[axis] = b.create<SubOp>(axis_size, offset);
1872 rewriter.replaceOpWithNewOp<DynamicSliceOp>(
1873 op, input_ty, concat, slice_begin_indices,
1874 rewriter.getI64TensorAttr(input_shape));
1875 return success();
1876 }
1877 };
1878
1879 /// Converts a TF::LeakyReluOp to HLO.
1880 /// LeakyRelu(x) = alpha * x if x < 0 else x.
1881 class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
1882 public:
1883 using OpRewritePattern::OpRewritePattern;
1884   LogicalResult matchAndRewrite(TF::LeakyReluOp op,
1885 PatternRewriter &rewriter) const override {
1886 Location loc = op.getLoc();
1887 float alpha = op.alpha().convertToFloat();
1888 Value features = op.features();
1889 auto featureType = features.getType().cast<RankedTensorType>();
1890 ArrayRef<int64_t> featureShape = featureType.getShape();
1891 Type eltType = featureType.getElementType();
1892
1893 auto alphaVal = rewriter.create<mhlo::ConstOp>(
1894 loc, rewriter.getFloatAttr(eltType, alpha));
1895
1896 // Broadcast `alpha` to match the shape of feature.
1897 auto featureShapeAttr = DenseIntElementsAttr::get(
1898 RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
1899 featureShape);
1900 auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
1901 loc, featureType, alphaVal, featureShapeAttr);
1902
1903 Attribute zeroAttr = rewriter.getZeroAttr(featureType);
1904 Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
1905
1906 Value leakyActivationVal = rewriter.create<mhlo::MulOp>(
1907 loc, features.getType(), features, broadcastAlphaVal);
1908
1909 StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
1910 Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1911 loc, features, zeroVal, compare_direction);
1912
1913 rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1914 features, leakyActivationVal);
1915 return success();
1916 }
1917 };
1918
1919 /// Converts a TF::LeakyReluGradOp to HLO.
1920 /// LeakyReluGrad(gradient, inputs) = gradient if input > 0
1921 /// else alpha * gradient.
1922 class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
1923 public:
1924 using OpRewritePattern::OpRewritePattern;
1925   LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
1926 PatternRewriter &rewriter) const override {
1927 Location loc = op.getLoc();
1928 float alpha = op.alpha().convertToFloat();
1929 Value gradients = op.gradients();
1930 Value features = op.features();
1931 auto featureType = features.getType().cast<RankedTensorType>();
1932 ArrayRef<int64_t> featureShape = featureType.getShape();
1933 Type eltType = featureType.getElementType();
1934
1935 auto alphaVal = rewriter.create<mhlo::ConstOp>(
1936 loc, rewriter.getFloatAttr(eltType, alpha));
1937 auto featureShapeAttr = DenseIntElementsAttr::get(
1938 RankedTensorType::get(featureShape.size(), rewriter.getIntegerType(64)),
1939 featureShape);
1940 auto broadcastAlphaVal = rewriter.create<mhlo::BroadcastOp>(
1941 loc, featureType, alphaVal, featureShapeAttr);
1942
1943 Attribute zeroAttr = rewriter.getZeroAttr(featureType);
1944 Value zeroVal = rewriter.create<ConstantOp>(loc, featureType, zeroAttr);
1945
1946 Value leakyGradientVal = rewriter.create<mhlo::MulOp>(
1947 loc, features.getType(), gradients, broadcastAlphaVal);
1948
1949 StringAttr compare_direction = StringAttr::get(rewriter.getContext(), "GT");
1950
1951 Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1952 loc, features, zeroVal, compare_direction);
1953
1954 rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1955 gradients, leakyGradientVal);
1956 return success();
1957 }
1958 };
1959
1960 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1961 // For a Rank-2 input, it creates the following ops:
1962 // %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1963 // %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1964 // %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1965 // %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1966 // %5 = "mhlo.broadcast"(%4)
1967 // %6 = "mhlo.select"(%3, %input, %5)
1968 // %7 = "mhlo.reduce"(%6, %4) ( {
1969 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1970 // %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1971 // "mhlo.return"(%9) : (tensor<f32>) -> ()
1972 // }) {dimensions = dense<0> : tensor<1xi64>}
1973 //
1974 // If the input's rank N is greater than 2, we will reshape it to R2 first and
1975 // create the above ops, then reshape it back to rank N/2.
1976 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
1977 public:
1978 using OpRewritePattern::OpRewritePattern;
1979
1980   LogicalResult matchAndRewrite(TF::DiagPartOp op,
1981 PatternRewriter &rewriter) const override {
1982 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1983 if (!input_type || !input_type.hasStaticShape()) return failure();
1984 int64_t num_dims = input_type.getRank();
1985 if (num_dims < 2 || num_dims % 2 != 0) return failure();
1986 const int64_t out_dims = num_dims / 2;
1987
1988 int64_t new_size = 1;
1989 llvm::SmallVector<int64_t, 4> new_dims;
1990 for (int i = 0; i < out_dims; i++) {
1991 if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1992 return op.emitOpError("invalid dimensions size");
1993 new_size *= input_type.getDimSize(i);
1994 new_dims.push_back(input_type.getDimSize(i));
1995 }
1996 Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1997 op.getLoc(),
1998 RankedTensorType::get({new_size, new_size},
1999 input_type.getElementType()),
2000 op.input());
2001 auto iota_type = RankedTensorType::get({new_size, new_size},
2002 rewriter.getIntegerType(32));
2003 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
2004 rewriter.getI64IntegerAttr(0));
2005 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
2006 rewriter.getI64IntegerAttr(1));
2007 Value compare = rewriter.create<CompareOp>(
2008 op.getLoc(), iota0, iota1,
2009 StringAttr::get(rewriter.getContext(), "EQ"));
2010 Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
2011 0, &rewriter);
2012 Value zero_matrix = rewriter.create<BroadcastOp>(
2013 op.getLoc(), reshaped_input.getType(), zero,
2014 GetI64ElementsAttr({new_size, new_size}, &rewriter));
2015 Value masked =
2016 rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
2017 compare, reshaped_input, zero_matrix);
2018 auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
2019 GetI64ElementsAttr({0}, &rewriter));
2020 assert(!input_type.getElementType().isInteger(1) &&
2021 "data type should not be i1");
2022 BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
2023 &rewriter);
2024 rewriter.replaceOpWithNewOp<ReshapeOp>(
2025 op, RankedTensorType::get(new_dims, input_type.getElementType()),
2026 reduce.getResult(0));
2027 return success();
2028 }
2029 };
2030
2031 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
2032 class ConvertMatrixDiagPartV3Op
2033 : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
2034 using Shape = llvm::SmallVector<int64_t, 4>;
2035
2036   // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the diagonal(s)
2037 // with k. This can be either a single value (for a single diagonal) or a
2038 // tuple of two values (starting and ending diagonal, for a band).
2039   LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
2040 DenseIntElementsAttr kattr;
2041 if (!matchPattern(op.k(), m_Constant(&kattr))) {
2042 return failure();
2043 }
2044 DenseIntElementsAttr::iterator it = kattr.begin();
2045 (*k)[0] = (*it).getSExtValue();
2046 it++;
2047 if (it == kattr.end()) {
2048 // Handle input like e.g. "k = 5", in which case we extract a single
2049 // diagonal.
2050 (*k)[1] = (*k)[0];
2051 } else {
2052 // Handle input like e.g. "k = [-1, 1]", in which case we extract a
2053 // band (multiple diagonals).
2054 (*k)[1] = (*it).getSExtValue();
2055 }
2056 return success();
2057 }
2058
2059 // Utility method for broadcasting integer constants to a given shape.
2060   BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
2061 int int_size, PatternRewriter &rewriter) const {
2062 return rewriter.create<BroadcastOp>(
2063 loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
2064 GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
2065 &rewriter),
2066 GetI64ElementsAttr(shape, &rewriter));
2067 }
2068
2069 public:
2070 using OpRewritePattern::OpRewritePattern;
2071
2072   LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
2073 PatternRewriter &rewriter) const override {
2074 Location loc = op.getLoc();
2075 ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
2076 auto element_type = input_type.getElementType();
2077
2078 // Align is a string specifying how superdiagonals and subdiagonals should
2079 // be aligned/padded for diagonals that are shorter than max_diag_len. The
2080 // format is "{super}_{sub}", with {super} the superdiagonal alignment and
2081 // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
2082     // left, "RIGHT" means rows will be padded to the right. The default is
2083 // "RIGHT_LEFT".
2084 StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
2085 enum Alignment { kLeft, kRight };
2086
2087 // default is RIGHT_LEFT
2088 Alignment superdiagonal_align = kRight;
2089 Alignment subdiagonal_align = kLeft;
2090
2091 if (align == "RIGHT_LEFT") {
2092 superdiagonal_align = kRight;
2093 subdiagonal_align = kLeft;
2094 } else if (align == "RIGHT_RIGHT") {
2095 superdiagonal_align = kRight;
2096 subdiagonal_align = kRight;
2097 } else if (align == "LEFT_RIGHT") {
2098 superdiagonal_align = kLeft;
2099 subdiagonal_align = kRight;
2100 } else if (align == "LEFT_LEFT") {
2101 superdiagonal_align = kLeft;
2102 subdiagonal_align = kLeft;
2103 } else {
2104 return failure(); // unsupported alignment
2105 }
2106
2107 // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
2108 // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
2109 if (!input_type || !input_type.hasStaticShape()) return failure();
2110 int64_t num_dims = input_type.getRank();
2111 if (num_dims < 2) return failure();
2112 int64_t rows = input_type.getDimSize(num_dims - 2); // rows
2113 int64_t cols = input_type.getDimSize(num_dims - 1); // cols
2114
2115 // We extract the diagonals from k[0] up to and including k[1].
2116 // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
2117 // the main diagonal). It's negative for subdiagonals (under and to the left
2118 // of the main diagonal) and positive for superdiagonals (above and to the
2119 // right of the main diagonal).
2120 int64_t k[2];
2121 if (failed(ExtractK(op, &k))) return failure();
2122 int num_diags = k[1] - k[0] + 1;
2123
2124 // Shifting diagonals away from the main diagonal might shorten them. This
2125 // is the longest diagonal we will see. We make this the last dimension of
2126 // the output shape.
2127 int64_t max_diag_len =
2128 std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
2129 cols + std::min(-k[0], static_cast<int64_t>(0)));
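    // Illustrative example with hypothetical sizes: for a 4x5 matrix and
    // k = [-1, 1], num_diags = 3 and
    // max_diag_len = min(4 + min(1, 0), 5 + min(1, 0)) = min(4, 5) = 4.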
2130
2131 // The first dimension is the index vector dimension we'll use for gather.
2132 // It's 1 here, but will be 2 once we glue x and y together.
2133 Shape indices_shape({1, num_diags, max_diag_len});
2134
2135 RankedTensorType iota_type =
2136 RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
2137 Value iotaM =
2138 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
2139 Value iotaN =
2140 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
2141
2142     // Broadcasted constants, of the same shape as iotaM and iotaN.
2143 Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
2144 Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
2145 Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
2146 Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
2147 Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
2148 Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
2149 Value b_max_diag_len =
2150 BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
2151
2152 // d = k[1] - m
2153 // (A.k.a. the number of the diagonal, depending on m. Note that we
2154 // subtract m here. This means we start with the superdiagonals and
2155 // move downwards towards the subdiagonals. So the start indices will
2156 // be decreasing.)
2157 Value d = rewriter.create<SubOp>(loc, b_k1, iotaM);
2158 Value neg_d = rewriter.create<NegOp>(loc, d);
2159
2160 // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
2161 // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
2162 Value diag_len_d = rewriter.create<MinOp>(
2163 loc,
2164 rewriter.create<AddOp>(loc, b_rows,
2165 rewriter.create<MinOp>(loc, d, b_zero)),
2166 rewriter.create<SubOp>(loc, b_cols,
2167 rewriter.create<MaxOp>(loc, d, b_zero)));
2168
2169 // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
2170 Value cmp;
2171 if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
2172 cmp = b_true;
2173 } else if (superdiagonal_align == kRight) {
2174 // offset = d>=0 ? max_diag_len - diag_len_d : 0
2175 cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
2176 } else if (subdiagonal_align == kRight) {
2177 // offset = d<=0 ? max_diag_len - diag_len_d : 0
2178 cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
2179 } else {
2180 // offset = 0
2181 cmp = b_false;
2182 }
2183
2184 // This offset shifts the diagonals to the "left" or "right", depending
2185 // on alignment.
2186 Value offset = rewriter.create<SelectOp>(
2187 loc, b_zero.getType(), cmp,
2188 rewriter.create<SubOp>(loc, b_max_diag_len, diag_len_d), b_zero);
2189
2190 // x = max(d, 0) - offset
2191 // y = max(-d, 0) - offset
2192 Value x = rewriter.create<SubOp>(
2193 loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
2194 Value y = rewriter.create<SubOp>(
2195 loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
2196
2197 Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
2198 Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
2199
2200     // GatherOp is happy to let us index out-of-bounds values, but those
2201 // values will be undefined. So we mask them later. Set up the boolean
2202 // expression that tells us which entries, in the output shape, are out of
2203 // bounds and thus become the padding_value.
2204 Value x_in_bounds = rewriter.create<AndOp>(
2205 loc,
2206 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
2207 b_zero),
2208 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
2209 Value y_in_bounds = rewriter.create<AndOp>(
2210 loc,
2211 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
2212 b_zero),
2213 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
2214 Value in_bounds = rewriter.create<ReshapeOp>(
2215 loc,
2216 RankedTensorType::get(Shape({num_diags, max_diag_len}),
2217 rewriter.getIntegerType(1)),
2218 rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
2219
2220 // Now combine x and y into the index data structure needed for gather.
2221 Shape concat_shape({2, num_diags, max_diag_len});
2222 Value start_indices = rewriter.create<ConcatenateOp>(
2223 loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
2224 mlir::ValueRange({n_plus_y, n_plus_x}),
2225 mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
2226
2227 // Shape of the final output. (Except for dimension folding in the
2228 // single diagonal case.)
2229 Shape output_shape;
2230 for (int i = 0; i < num_dims - 2; i++) {
2231 output_shape.push_back(input_type.getDimSize(i));
2232 }
2233 output_shape.push_back(num_diags);
2234 output_shape.push_back(max_diag_len);
2235 auto output_type = RankedTensorType::get(output_shape, element_type);
2236
2237 // A slice is the shape of what GatherOp copies per lookup. So the last
2238 // two dimensions (M, N in the matrix-diag-part docs) are where we go
2239 // through entry by entry.
2240 ArrayRef<int64_t> input_shape = input_type.getShape();
2241 Shape slice_sizes(input_shape.begin(), input_shape.end());
2242 int slice_dimensions = slice_sizes.size();
2243 slice_sizes[slice_dimensions - 2] = 1;
2244 slice_sizes[slice_dimensions - 1] = 1;
2245
2246 // Dimensions of the input we won't see in the output (M and N).
2247 SmallVector<int64_t, 2> collapsed_dims(
2248 {slice_dimensions - 2, slice_dimensions - 1});
2249
2250 // Which dimensions (in the input) the two offset "columns" map to.
2251 SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
2252
2253 // Gather the diagonal entries.
2254 // TODO(kramm): For a single diagonal, this might be slower than the
2255 // mask + sum approach. Special-case num_diags==1?
2256 auto dims_attr = GatherDimensionNumbers::get(
2257 /*offset_dims=*/GetI64ElementsAttrForSeq(0, num_dims - 2, &rewriter),
2258 /*collapsed_slice_dims=*/GetI64ElementsAttr(collapsed_dims, &rewriter),
2259 /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter),
2260 /*index_vector_dim=*/rewriter.getI64IntegerAttr(0),
2261 rewriter.getContext());
2262 Value gather = rewriter.create<mhlo::GatherOp>(
2263 loc, output_type, op.input(), start_indices, dims_attr,
2264 GetI64ElementsAttr(slice_sizes, &rewriter));
2265
2266 // We now need to broadcast the "in_bounds" boolean expression, as well as
2267 // the padding value, to do the final select.
2268 Shape broadcast_bounds;
2269 for (int i = 0; i < output_shape.size() - 2; i++) {
2270 broadcast_bounds.push_back(output_shape[i]);
2271 }
2272 Value b_in_bounds = rewriter.create<BroadcastOp>(
2273 loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
2274 in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
2275 Value b_padding = rewriter.create<BroadcastOp>(
2276 loc, output_type, op.padding_value(),
2277 GetI64ElementsAttr(output_shape, &rewriter));
2278
2279 // Replace all out-of-bounds values in the result with padding_value.
2280 Value result = rewriter.create<SelectOp>(loc, output_type, b_in_bounds,
2281 gather, b_padding);
2282
2283 if (num_diags == 1) {
2284 // matrix_diag_part folds away the 1-sized band dimension if we only
2285 // extract a single diagonal.
2286 result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
2287 }
2288
2289 rewriter.replaceOp(op, result);
2290 return success();
2291 }
2292 };
2293
2294 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
2295 // depending on arity of the op.
2296 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
2297 public:
2298 using OpRewritePattern::OpRewritePattern;
2299
2300   LogicalResult matchAndRewrite(TF::EinsumOp op,
2301 PatternRewriter &rewriter) const override {
2302 StringAttr equation = op->getAttrOfType<StringAttr>("equation");
2303 if (op.N() == 1) {
2304 rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
2305 op, op.getType(), *op.inputs().begin(), equation);
2306 } else if (op.N() == 2) {
2307 ValueRange inputs = op.inputs();
2308 rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
2309 inputs[1], equation);
2310 } else {
2311       // TensorFlow EinsumOp verifies that the number of operands is at most
2312       // two.
2313 return failure();
2314 }
2315 return success();
2316 }
2317 };
2318
2319 // Bypasses IdentityN op.
2320 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
2321 public:
2322 using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
2323   LogicalResult matchAndRewrite(TF::IdentityNOp op,
2324 PatternRewriter &rewriter) const override {
2325 rewriter.replaceOp(op, op.getOperands());
2326 return success();
2327 }
2328 };
2329
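// Converts TF RFFT/IRFFT ops to mhlo's FFT op. The innermost input dimension
// must be static; it is sliced or zero-padded to the target length first
// (fft_length for RFFT, fft_length / 2 + 1 for IRFFT). Hypothetical shape
// example: an RFFT over tensor<3x5xf32> with fft_length = 8 zero-pads the
// input to tensor<3x8xf32> before emitting the FFT.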
2330 template <typename OpTy>
2331 class ConvertFFTOp : public OpRewritePattern<OpTy> {
2332 public:
2333 using OpRewritePattern<OpTy>::OpRewritePattern;
2334   LogicalResult matchAndRewrite(OpTy op,
2335 PatternRewriter &rewriter) const override {
2336 auto input_ty = op.input().getType().template cast<ShapedType>();
2337 if (!input_ty.hasRank()) {
2338 return failure();
2339 }
2340 auto input_shape = input_ty.getShape();
2341 DenseIntElementsAttr fft_length_attr;
2342 if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
2343 return failure();
2344 }
2345 int64_t fft_length;
2346 if (fft_length_attr.getNumElements() != 0) {
2347 fft_length = fft_length_attr.getValue<IntegerAttr>(0).getInt();
2348 } else {
2349 return failure();
2350 }
2351
2352 std::string fft_string = "RFFT";
2353 if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
2354 fft_length = fft_length / 2 + 1;
2355 fft_string = "IRFFT";
2356 }
2357 Location loc = op.getLoc();
2358
2359 // The inner-most dim cannot be dynamic.
2360 if (input_ty.isDynamicDim(input_shape.size() - 1)) {
2361 return failure();
2362 }
2363
2364 auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
2365 expected_shape.push_back(fft_length);
2366
2367 // Zero pad or truncate the last axis
2368 Value reshaped = op.input();
2369 SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
2370 SmallVector<int64_t, 4> strides(input_shape.size(), 1);
2371
2372 // Last dim larger than fft_length, slice the input
2373 if (input_shape.back() > fft_length) {
2374 reshaped = rewriter.create<SliceOp>(
2375 op.getLoc(),
2376 RankedTensorType::get(expected_shape, input_ty.getElementType()),
2377 op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
2378 GetI64ElementsAttr(expected_shape, &rewriter),
2379 GetI64ElementsAttr(strides, &rewriter));
2380
2381 // Last dim smaller than fft_length, zero-pad the input
2382 } else if (input_ty.getShape().back() < fft_length) {
2383 SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
2384 SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
2385 padding.push_back(fft_length - input_shape.back());
2386 Value zero =
2387 GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
2388 reshaped = rewriter.create<PadOp>(
2389 loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
2390 op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
2391 GetI64ElementsAttr(padding, &rewriter),
2392 GetI64ElementsAttr(no_padding, &rewriter));
2393 }
2394
2395 rewriter.replaceOpWithNewOp<FftOp>(op, op.getType(), reshaped, fft_string,
2396 rewriter.getI64TensorAttr(fft_length));
2397 return success();
2398 }
2399 };
2400
2401 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
2402 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
2403
2404 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
2405 // BatchNormGradOp for training and a sequence of binary ops for inference.
2406 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
2407 template <typename FusedBatchNormGradOpT>
2408 class ConvertFusedBatchNormGradBase
2409 : public OpRewritePattern<FusedBatchNormGradOpT> {
2410 public:
2411 using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
2412
2413   LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
2414 PatternRewriter &rewriter) const override {
2415 Location loc = op.getLoc();
2416 Value grad = op.y_backprop();
2417 Value act = op.x();
2418 Value scale = op.scale();
2419 Value mean = op.reserve_space_1();
2420 Value var = op.reserve_space_2();
2421
2422 // TODO(b/141785544): Update this to not require static shapes.
2423 // activation shape needs to be static to convert negative indices in
2424 // TensorFlow to absolute indices required by HLO.
2425 RankedTensorType act_type =
2426 act.getType().template dyn_cast<RankedTensorType>();
2427 if (!act_type) return failure();
2428 Type act_ele_type = act_type.getElementType();
2429     // To support mixed precision, the statistics type, which may be more
2430     // precise than the input types, is used for this op.
2431 Type kernel_type =
2432 scale.getType().template cast<TensorType>().getElementType();
2433 grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
2434 act = rewriter.create<ConvertOp>(loc, act, kernel_type);
2435
2436 tensorflow::TensorFormat data_format;
2437 if (!FormatFromString(op.data_format().str(), &data_format))
2438 return op.emitOpError("invalid data format");
2439
2440 auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act);
2441 auto feature_dim = feature_dim_attr.getValue().getSExtValue();
2442
2443 // Gets the result values.
2444 Value x_backprop, scale_backprop, offset_backprop;
2445 if (op.is_training()) { // training
2446 // TODO(b/145536565): handle GPU logic separately.
2447 // Infers the output type with the converted `act`.
2448 Type feature_type = RankedTensorType::get(
2449 {GetDimSize(act_type, feature_dim)}, kernel_type);
2450 Type result_type = TupleType::get(
2451 rewriter.getContext(), {act.getType(), feature_type, feature_type});
2452
2453 auto training_op = rewriter.create<BatchNormGradOp>(
2454 loc, result_type, act, scale, mean, var, grad, op.epsilon(),
2455 feature_dim);
2456
2457 x_backprop =
2458 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);
2459
2460 scale_backprop =
2461 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);
2462
2463 offset_backprop =
2464 rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
2465 } else { // inference
2466 SmallVector<int64_t, 4> non_feature_dims;
2467 for (int64_t i = 0; i < act_type.getRank(); ++i) {
2468 if (i == feature_dim) continue;
2469 non_feature_dims.push_back(i);
2470 }
2471 auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
2472 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2473
2474 // scratch1 = rsqrt(var + epsilon)
2475 RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
2476 auto epsilon = rewriter.create<ConstOp>(
2477 loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
2478 auto add_op = rewriter.create<chlo::BroadcastAddOp>(
2479 loc, var, epsilon.getResult(), scalar_broadcast_dims);
2480
2481 Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
2482
2483 // scratch2 = sum(y_backprop * (x - mean))
2484 auto sub_op = rewriter.create<mhlo::SubOp>(
2485 loc, act,
2486 Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
2487 auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
2488 Value scratch2 =
2489 ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
2490
2491 // x_backprop = y_backprop * (scale * scratch1)
2492 auto scaled_grad =
2493 rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
2494 x_backprop = rewriter.create<mhlo::MulOp>(
2495 loc, grad,
2496 Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
2497 rewriter));
2498
2499 // scale_backprop = scratch2 * scratch1
2500 scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
2501
2502 // offset_backprop = sum(y_backprop)
2503 offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
2504 }
2505
2506 x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
2507 Value last_val[2];
2508 if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
2509 // It doesn't matter what values we provide for the last 2 results.
2510 last_val[0] = last_val[1] = op.x();
2511 } else {
2512 auto const_val = rewriter.create<ConstOp>(
2513 op.getLoc(),
2514 DenseElementsAttr::get<float>(
2515 RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
2516 0.0));
2517 auto maybe_cast = [&](Value val, Type t) -> Value {
2518 if (val.getType() == t) return val;
2519 return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
2520 };
2521 last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
2522 last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
2523 }
2524 rewriter.replaceOp(
2525 op, {/*x_backprop=*/x_backprop,
2526 /*scale_backprop=*/scale_backprop,
2527 /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
2528 return success();
2529 }
2530 };
2531
2532 using ConvertFusedBatchNormGradOp =
2533 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
2534 using ConvertFusedBatchNormGradV2Op =
2535 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
2536 using ConvertFusedBatchNormGradV3Op =
2537 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
2538
2539 // Converts a TensorFlow FusedBatchNormV2Op or FusedBatchNormV3Op to either an
2540 // HLO BatchNormTrainingOp or an HLO BatchNormInferenceOp, depending on the
2541 // value of the 'is_training' parameter.
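//
// As a rough sketch (shapes, names, and attribute values here are illustrative
// only), the training path is expected to produce IR along the lines of:
//
//   %bn = "mhlo.batch_norm_training"(%x, %scale, %offset)
//       {epsilon = 1.0e-03 : f32, feature_index = 3 : i64}
//       : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>)
//       -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
//   %y = "mhlo.get_tuple_element"(%bn) {index = 0 : i32} : (...) -> ...
//   %batch_mean = "mhlo.get_tuple_element"(%bn) {index = 1 : i32} : (...) -> ...
//   %batch_var = "mhlo.get_tuple_element"(%bn) {index = 2 : i32} : (...) -> ...
//
// followed by Bessel's correction of the variance and, if requested, the
// exponential moving average update of the running statistics (see below).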
2542 template <typename FusedBatchNormOpT>
2543 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
2544 public:
2545 using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
2546
2547 LogicalResult matchAndRewrite(FusedBatchNormOpT op,
2548 PatternRewriter &rewriter) const override {
2549 tensorflow::TensorFormat data_format;
2550 if (!FormatFromString(op.data_format().str(), &data_format))
2551 return op.emitOpError("invalid data format");
2552
2553 auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.x());
2554
2555 auto input_type_tensor = op.x().getType().template cast<TensorType>();
2556 auto input_element_type = input_type_tensor.getElementType();
2557
2558 auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
2559 auto scale_element_type = scale_type_tensor.getElementType();
2560
2561 auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
2562 auto mean_element_type = mean_type_tensor.getElementType();
2563 // In the training case, dimensions of input tensors must be static.
2564 if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
2565 !scale_type_tensor.hasStaticShape() ||
2566 !mean_type_tensor.hasStaticShape()))
2567 return failure();
2568
2569 // TODO(b/69928690): Support mixed precision in the XLA batch
2570 // normalization operators. As a workaround, create a new x with the same
2571 // element type as scale (which may be more precise than the input type).
2572 Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
2573 scale_element_type);
2574 TensorType bn_train_input_type_tensor =
2575 bn_train_input.getType().template cast<TensorType>();
2576
2577 if (op.is_training()) {
2578 // Training case.
2579 auto operand_shape = bn_train_input_type_tensor.getShape();
2580 // The mean and variance are each a 1-dimensional array the size of the
2581 // feature dimension, with the same element type as the operand (x).
2582 // This shape must be constructed manually because the mean and variance
2583 // inputs are empty in the training case.
2584 Type mean_var_type = RankedTensorType::get(
2585 {operand_shape[feature_dim.getInt()]}, scale_element_type);
2586 // The op result type is a tuple of 3 values: the output with the same shape
2587 // as the input, batch_mean, and batch_var.
2588 SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
2589 mean_var_type, mean_var_type};
2590 Type result_type = TupleType::get(rewriter.getContext(), operand_types);
2591
2592 auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
2593 op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
2594 op.epsilon(), feature_dim.getInt());
2595 // HLO op outputs a tuple of tensors. Extract those results.
2596 auto bn_train_op_result = bn_train_op.getResult();
2597 Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
2598 op.getLoc(), bn_train_op_result, 0);
2599 Value batch_mean = rewriter.create<mhlo::GetTupleElementOp>(
2600 op.getLoc(), bn_train_op_result, 1);
2601 Value reserve_space_1 = batch_mean;
2602 Value batch_variance = rewriter.create<mhlo::GetTupleElementOp>(
2603 op.getLoc(), bn_train_op_result, 2);
2604
2605 // Apply Bessel's correction on the variance.
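// The batch norm op computes the biased (population) variance, while TF's
// batch_variance output is the unbiased estimate, i.e. variance * N / (N - 1)
// with N being the number of elements contributing to each feature's
// statistics.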
2606 int total_input_size = bn_train_input_type_tensor.getNumElements();
2607 int total_scale_size = scale_type_tensor.getNumElements();
2608 int sample_size = total_input_size / total_scale_size;
2609 int sample_size_minus_one = std::max(1, sample_size - 1);
2610 double factor = static_cast<double>(sample_size) /
2611 static_cast<double>(sample_size_minus_one);
2612 auto factor_const_op = rewriter.create<mhlo::ConstOp>(
2613 op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2614
2615 Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2616 op.getLoc(), batch_variance.getType(), batch_variance,
2617 factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2618
2619 // Convert back to input type to stay aligned with expected output type
2620 // for TF op.
2621 y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2622 input_element_type);
2623
2624 float exponential_avg_factor =
2625 op.exponential_avg_factor().convertToFloat();
2626 if (exponential_avg_factor != 1.0f) {
2627 auto alpha = rewriter.create<mhlo::ConstOp>(
2628 op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2629 1.0f - exponential_avg_factor));
2630 auto beta = rewriter.create<mhlo::ConstOp>(
2631 op.getLoc(),
2632 rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2633
2634 // new_running_mean = alpha * old_mean + beta * batch_mean.
2635 auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2636 op.getLoc(), op.mean().getType(), alpha, op.mean(),
2637 /*broadcast_dimensions=*/DenseIntElementsAttr());
2638 auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2639 op.getLoc(), batch_mean.getType(), beta, batch_mean,
2640 /*broadcast_dimensions=*/DenseIntElementsAttr());
2641 batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2642 op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2643 /*broadcast_dimensions=*/DenseIntElementsAttr());
2644
2645 // new_running_variance = alpha * old_variance + beta * batch_variance.
2646 auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2647 op.getLoc(), op.variance().getType(), alpha, op.variance(),
2648 /*broadcast_dimensions=*/DenseIntElementsAttr());
2649 auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2650 op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2651 /*broadcast_dimensions=*/DenseIntElementsAttr());
2652 corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2653 op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2654 /*broadcast_dimensions=*/DenseIntElementsAttr());
2655 }
2656
2657 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2658 // FusedBatchNormV2 expects 5 outputs.
2659 // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2660 // They are used to pass the per-batch mean and variance to the
2661 // gradient. Here we maintain the same behavior by setting them to the
2662 // mean and variance calculated by BatchNormTraining.
2663 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2664 /*batch_variance=*/corrected_variance,
2665 /*reserve_space_1=*/reserve_space_1,
2666 /*reserve_space_2=*/batch_variance});
2667 } else { // TF::FusedBatchNormV3Op
2668 // For FusedBatchNormV3Op, also create a constant tensor to forward to
2669 // last reserve_space_3 output.
2670 auto reserve_space_3_type =
2671 op.getResult(5).getType().template cast<TensorType>();
2672 int num_elements = reserve_space_3_type.hasStaticShape()
2673 ? reserve_space_3_type.getNumElements()
2674 : 0;
2675 auto const_attr_type = RankedTensorType::get(
2676 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2677 Value dummy_const = rewriter.create<ConstOp>(
2678 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2679 if (const_attr_type != reserve_space_3_type)
2680 dummy_const = rewriter.create<tensor::CastOp>(
2681 op.getLoc(), reserve_space_3_type, dummy_const);
2682 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2683 /*batch_variance=*/corrected_variance,
2684 /*reserve_space_1=*/reserve_space_1,
2685 /*reserve_space_2=*/batch_variance,
2686 /*reserve_space_3=*/dummy_const});
2687 }
2688 } else { // Inference case.
2689 auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2690 op.getLoc(),
2691 /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2692 op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2693 feature_dim.getInt());
2694
2695 // Convert back to input type to stay aligned with expected output type
2696 // for TF op.
2697 auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2698 input_element_type);
2699
2700 // The mean, variance, and reserved space outputs of the batch norm op are
2701 // not used for inference. It doesn't matter what values we provide for
2702 // the last 5 results as long as they are of the same type. Forward
2703 // input mean and variance to the output mean, variance, reserve_space_1, and
2704 // reserve_space_2.
2705 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2706 rewriter.replaceOp(op, {/*y=*/y_out,
2707 /*batch_mean=*/op.mean(),
2708 /*batch_variance=*/op.variance(),
2709 /*reserve_space_1=*/op.mean(),
2710 /*reserve_space_2=*/op.variance()});
2711 } else {
2712 // For FusedBatchNormV3Op, also create a constant tensor to forward to
2713 // last reserve_space_3 output.
2714 auto reserve_space_3_type =
2715 op.getResult(5).getType().template cast<TensorType>();
2716 int num_elements = reserve_space_3_type.hasStaticShape()
2717 ? reserve_space_3_type.getNumElements()
2718 : 0;
2719 auto const_attr_type = RankedTensorType::get(
2720 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2721 Value dummy_const = rewriter.create<ConstOp>(
2722 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2723 if (const_attr_type != reserve_space_3_type)
2724 dummy_const = rewriter.create<tensor::CastOp>(
2725 op.getLoc(), reserve_space_3_type, dummy_const);
2726 rewriter.replaceOp(op, {/*y=*/y_out,
2727 /*batch_mean=*/op.mean(),
2728 /*batch_variance=*/op.variance(),
2729 /*reserve_space_1=*/op.mean(),
2730 /*reserve_space_2=*/op.variance(),
2731 /*reserve_space_3=*/dummy_const});
2732 }
2733 }
2734 return success();
2735 }
2736 };
2737
2738 using ConvertFusedBatchNormV2Op =
2739 ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2740 using ConvertFusedBatchNormV3Op =
2741 ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2742
2743 using PaddingArray =
2744 std::vector<std::pair<tensorflow::int64, tensorflow::int64>>;
2745
2746 // Returns padding values for ReduceWindow op as a vector of pairs.
2747 //
2748 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2749 // dimensions to be equal to the size of window dimensions and window strides.
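//
// For example (a sketch of the usual SAME-padding arithmetic that
// ::xla::MakePadding performs): an input dimension of size 4 with window size 3
// and stride 2 yields ceil(4 / 2) = 2 output elements, needs
// (2 - 1) * 2 + 3 - 4 = 1 total padding element, and therefore gets the pair
// (0, 1): no padding on the low side and one element on the high side.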
2750 template <int num_dims>
2751 static PaddingArray GetReduceWindowPaddingAsArray(
2752 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2753 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2754 if (padding == "VALID") {
2755 return PaddingArray(num_dims, std::make_pair(0, 0));
2756 }
2757 assert(padding == "SAME");
2758 llvm::SmallVector<tensorflow::int64, num_dims> input_shape, window_shape,
2759 strides;
2760 input_shape.reserve(input_dims.size());
2761 window_shape.reserve(window_dims.size());
2762 strides.reserve(window_strides.size());
2763
2764 for (const auto &dim : input_dims) input_shape.push_back(dim);
2765 for (Attribute attr : window_dims)
2766 window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2767 for (Attribute attr : window_strides)
2768 strides.push_back(attr.cast<IntegerAttr>().getInt());
2769
2770 PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2771 ::xla::Padding::kSame);
2772 return paddings;
2773 }
2774
2775 // Same as GetReduceWindowPaddingAsArray but returns padding as
2776 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
2777 template <int num_dims>
2778 static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2779 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2780 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2781 if (padding == "VALID") return {};
2782 assert(padding == "SAME");
2783 PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2784 input_dims, window_dims, window_strides, padding, builder);
2785 int64_t rank = paddings.size();
2786 llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2787 for (int i = 0; i < rank; i++) {
2788 flatten_paddings[2 * i] = paddings[i].first;
2789 flatten_paddings[2 * i + 1] = paddings[i].second;
2790 }
2791 return DenseIntElementsAttr::get(
2792 RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2793 flatten_paddings);
2794 }
2795
2796 // Helper function for dividing each entry of `pooled` by the count of its
2797 // corresponding window, i.e., the number of non-padding entries of the window
2798 // which an `AvgPool` operation performed on an `input_shape`-tensor would map
2799 // to this entry, depending on `ksize` and `strides`. This function is used for
2800 // `AvgPool` and `AvgPoolGrad` legalizations.
2801 // `zero` is passed as a parameter because it can be reused from caller level.
2802 // `pooled` must have `RankedTensorType`.
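// For instance, in the 1-D example given in the `ConvertAvgPoolGradOp` comment
// further below (input size 4, window 3, stride 2, SAME padding), the two
// windows cover 3 and 2 non-padding entries respectively, so `pooled` is
// divided elementwise by [3, 2].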
2803 template <typename OpTy, int num_dims>
2804 Operation *AvgPoolDivideByCount(
2805 Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2806 const SmallVector<int64_t, num_dims> &ksize,
2807 const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2808 PatternRewriter &rewriter) {
2809 Location loc = op.getLoc();
2810 RankedTensorType pooled_type =
2811 pooled.getType().template cast<RankedTensorType>();
2812 Type element_type = pooled_type.getElementType();
2813 Operation *result = nullptr;
2814 RankedTensorType orig_input_type =
2815 RankedTensorType::get(input_shape, element_type);
2816
2817 if (op.padding() == "VALID") {
2818 // All window counts are equal here because we don't have padding
2819 // (each entry of `pooled` corresponds to a window that consists of
2820 // original input entries only).
2821 int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2822 std::multiplies<int64_t>());
2823 // Divide `pooled` by window counts.
2824 Value divisor =
2825 GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2826 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2827 result = rewriter.create<chlo::BroadcastDivOp>(
2828 loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2829 } else {
2830 assert(op.padding() == "SAME");
2831 // For SAME padding, only original entries that contributed to a window
2832 // are counted for the average of this window, not padded entries.
2833
2834 // Build all-ones tensor of same shape as the original input.
2835 ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2836 auto all_ones_tensor = rewriter.create<ConstOp>(loc, splat);
2837
2838 // Get padding for the input.
2839 DenseIntElementsAttr input_padding_attr =
2840 GetReduceWindowPaddingAsAttr<num_dims>(
2841 input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2842
2843 // Count the 1's in each window, using the same padding as for the input,
2844 // which gives us the window counts by which `pooled` needs to be divided.
2845 auto divisor = rewriter.create<ReduceWindowOp>(
2846 loc, pooled_type,
2847 /*operand=*/all_ones_tensor,
2848 /*init_value=*/zero,
2849 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2850 /*window_strides=*/GetI64ElementsAttr(op.strides()),
2851 /*base_dilations=*/DenseIntElementsAttr(),
2852 /*window_dilations=*/DenseIntElementsAttr(),
2853 /*padding=*/input_padding_attr);
2854 BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2855
2856 // Divide `pooled` by window counts.
2857 result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled,
2858 divisor.getResult(0));
2859 }
2860 return result;
2861 }
2862
2863 Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
2864 Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2865
2866 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2867 // dimensions with add as the reduction function. The reduction result is
2868 // then divided by the number of elements in the window.
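//
// A rough sketch of the resulting IR for a VALID 2x2 average pool with stride 2
// on an NHWC input (names and shapes illustrative):
//
//   %sum = "mhlo.reduce_window"(%input, %zero) ( { ... mhlo.add body ... } ) {
//       window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
//       window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> }
//       : (tensor<1x4x4x8xf32>, tensor<f32>) -> tensor<1x2x2x8xf32>
//   %avg = "chlo.broadcast_divide"(%sum, %count)
//       : (tensor<1x2x2x8xf32>, tensor<f32>) -> tensor<1x2x2x8xf32>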
2869 template <typename OpTy, int num_dims>
2870 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2871 public:
2872 using OpRewritePattern<OpTy>::OpRewritePattern;
2873
2874 LogicalResult matchAndRewrite(OpTy op,
2875 PatternRewriter &rewriter) const override {
2876 Value input_value = GetAvgPoolInput(op);
2877 auto input_type =
2878 input_value.getType().template dyn_cast<RankedTensorType>();
2879 if (!input_type) return failure();
2880
2881 // We will do accumulation first; use a larger bitwidth if suitable.
2882 Type input_element_type = input_type.getElementType();
2883 Type sum_element_type = GetSumAccumulationType(input_element_type);
2884 Type result_type;
2885
2886 // The result type for reduction and division with the proper element type.
2887 if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2888 result_type =
2889 RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2890 else
2891 result_type = UnrankedTensorType::get(sum_element_type);
2892
2893 // Convert if we need to enlarge the element type's bitwidth.
2894 if (input_element_type != sum_element_type)
2895 input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2896 sum_element_type);
2897
2898 // Create the ReduceWindow op.
2899 Value init =
2900 GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2901 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2902 input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2903 &rewriter);
2904 auto reduce = rewriter.create<ReduceWindowOp>(
2905 op.getLoc(), result_type, input_value, init,
2906 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2907 /*base_dilations=*/DenseIntElementsAttr(),
2908 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2909 BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2910
2911 // Count the number of elements in the window. The following calculation
2912 // is only valid when there is no padding.
2913 SmallVector<int64_t, num_dims> input_shape(
2914 llvm::to_vector<num_dims>(input_type.getShape()));
2915 SmallVector<int64_t, num_dims> ksize, strides;
2916 GetI64ArrayAttrValues(op.ksize(), &ksize);
2917 GetI64ArrayAttrValues(op.strides(), &strides);
2918
2919 Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2920 reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter);
2921
2922 // Convert back if we enlarged the element type's bitwidth.
2923 Value result = result_op->getOpResult(0);
2924 if (input_element_type != sum_element_type)
2925 result =
2926 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2927
2928 rewriter.replaceOp(op, result);
2929 return success();
2930 }
2931 };
2932
2933 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2934 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2935
2936 // `AvgPoolGradOp` is converted to the following operations:
2937 // 1. Divide each entry of the output gradient (the gradient for the previous
2938 // layer in backpropagation order) by the count of the corresponding window
2939 // (i.e., the number of non-padding entries of the window which `AvgPool`
2940 // has mapped to this entry in forward propagation).
2941 // 2. Add appropriate interior and exterior padding for step 3 (see example
2942 // below).
2943 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2944 // as windows) and stride 1 in each dimension. This is implemented as a
2945 // `ReduceWindowOp` with `AddOp` as body.
2946 //
2947 // Example:
2948 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2949 // and SAME padding with 0's. It is defined by
2950 // f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) )
2951 // [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding)
2952 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2953 // for the average, which is why the second denominator is 2 and not 3.
2954 // The Jacobian Df is
2955 // [ 1/3 1/3 1/3 0 ]
2956 // [ 0 0 1/2 1/2 ]
2957 //
2958 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2959 // needs the original input shape and not the tensor as argument).
2960 // Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the
2961 // average pool gradient is given by
2962 // Df^T * v = [ 4/3 4/3 13/3 3 ]^T
2963 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2964 // structure of Df by using the 3-step approach from above:
2965 // 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T
2966 // 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T
2967 // 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 13/3 3 ]^T
2968 //
2969 // Note that the padding in step 2. is chosen in such a way that the subsequent
2970 // convolution produces the gradient. Higher dimensions, different padding, and
2971 // different windows/strides work in a similar way, the main difference is in
2972 // the computation of the paddings in step 2.
2973 //
2974 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2975 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2976 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2977 // examples for different cases.
2978 template <typename OpTy, int num_dims>
2979 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2980 using DimVector = SmallVector<int64_t, num_dims>;
2981
2982 public:
2983 using OpRewritePattern<OpTy>::OpRewritePattern;
2984
2985 LogicalResult matchAndRewrite(OpTy op,
2986 PatternRewriter &rewriter) const override {
2987 Location loc = op.getLoc();
2988 tensorflow::TensorFormat data_format;
2989 if (!FormatFromString(op.data_format().str(), &data_format)) {
2990 return op.emitOpError("invalid data format");
2991 }
2992 // `out_grad` is the gradient that was propagated via backpropagation from
2993 // the output layer.
2994 Value out_grad = op.grad();
2995 auto out_grad_type =
2996 out_grad.getType().template dyn_cast<RankedTensorType>();
2997 if (!out_grad_type) {
2998 return failure();
2999 }
3000 Type element_type = out_grad_type.getElementType();
3001 DenseIntElementsAttr orig_input_shape_attr;
3002 if (!matchPattern(op.orig_input_shape(),
3003 m_Constant(&orig_input_shape_attr))) {
3004 return failure();
3005 }
3006 auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
3007 DimVector orig_input_shape(orig_input_shape_values.begin(),
3008 orig_input_shape_values.end());
3009 DimVector ksize, strides;
3010 GetI64ArrayAttrValues(op.ksize(), &ksize);
3011 GetI64ArrayAttrValues(op.strides(), &strides);
3012 Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
3013
3014 auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
3015 out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
3016
3017 // Get same padding as for original input.
3018 PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
3019 orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
3020
3021 // Add padding around `out_grad_divided` values in such a way that the
3022 // subsequent `ReduceWindowOp` produces the gradient.
3023 DimVector out_grad_shape(
3024 llvm::to_vector<num_dims>(out_grad_type.getShape()));
3025 DimVector low_padding(num_dims, 0);
3026 DimVector high_padding(num_dims, 0);
3027 DimVector interior_padding(num_dims, 0);
3028 constexpr int num_spatial_dims = num_dims - 2;
3029 for (int i = 0; i < num_spatial_dims; ++i) {
3030 int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
3031 int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
3032 orig_padding[dim].first +
3033 orig_padding[dim].second;
3034 // Set interior padding such that neighboring entries from
3035 // `out_grad_divided` have distance `strides[dim]` from each other in
3036 // every dimension.
3037 interior_padding[dim] = strides[dim] - 1;
3038 // Set exterior padding in the same way as for convolution gradient
3039 // computation.
3040 auto status = ::xla::ConvGradExtractAndVerifyDimension(
3041 /*input_size=*/orig_input_shape_padded_in_dim,
3042 /*filter_size=*/ksize[dim],
3043 /*output_size=*/out_grad_shape[dim],
3044 /*dilation=*/1,
3045 /*stride=*/strides[dim],
3046 /*padding=*/::xla::Padding::kValid);
3047 if (!status.ok()) {
3048 return failure();
3049 }
3050 ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
3051 status.ValueOrDie();
3052 // Subtract the original exterior padding since it doesn't contribute to
3053 // the gradient. Note that we save one `PadOp` and some unnecessary kernel
3054 // computations, compared to the `xla::AvgPoolGrad` implementation, by
3055 // subtracting the original exterior padding before `ReduceWindowOp`
3056 // instead of trimming the result of `ReduceWindowOp` (the final result is
3057 // the same because all strides are 1).
3058 low_padding[dim] =
3059 conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
3060 high_padding[dim] =
3061 conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
3062
3063 // Update `out_grad_shape` to result shape of following `PadOp`.
3064 out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
3065 (out_grad_shape[dim] - 1) * strides[dim] + 1;
3066 }
3067 Value reduce_window_input = rewriter.create<PadOp>(
3068 loc, RankedTensorType::get(out_grad_shape, element_type),
3069 /*operand=*/out_grad_divided->getOpResult(0),
3070 /*padding_value=*/zero,
3071 /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
3072 /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
3073 /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
3074
3075 // Compute result by convolving `reduce_window_input` with an all-ones
3076 // kernel, using `ReduceWindowOp` with `AddOp` body.
3077
3078 Type sum_element_type = GetSumAccumulationType(element_type);
3079 if (element_type != sum_element_type) {
3080 // Convert to appropriate sum accumulation type to avoid precision loss.
3081 reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
3082 sum_element_type);
3083 zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
3084 }
3085 auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
3086 auto reduce_window_op = rewriter.create<ReduceWindowOp>(
3087 loc, RankedTensorType::get(orig_input_shape, sum_element_type),
3088 /*operand=*/reduce_window_input,
3089 /*init_value=*/zero,
3090 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
3091 /*window_strides=*/ones,
3092 /*base_dilations=*/DenseIntElementsAttr(),
3093 /*window_dilations=*/DenseIntElementsAttr(),
3094 /*padding=*/DenseIntElementsAttr());
3095 BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
3096 &rewriter);
3097 Value result = reduce_window_op.getResult(0);
3098
3099 if (element_type != sum_element_type) {
3100 // Convert back to original element type.
3101 result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
3102 }
3103 rewriter.replaceOp(op, {result});
3104 return success();
3105 }
3106 };
3107
3108 using ConvertAvgPool2DGradOp =
3109 ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
3110 using ConvertAvgPool3DGradOp =
3111 ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
3112
3113 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
3114 // dimensions with max as the reduction function.
3115 //
3116 // Sample result for VALID padding mode:
3117 //
3118 // %init = constant dense<...> : tensor<i32>
3119 // %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
3120 // {window_dimensions = ..., window_strides = ... }
3121 //
3122 template <typename OpTy, int num_dims>
3123 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
3124 public:
3125 using OpRewritePattern<OpTy>::OpRewritePattern;
3126
3127 LogicalResult matchAndRewrite(OpTy op,
3128 PatternRewriter &rewriter) const override {
3129 Type element_type =
3130 op.input().getType().template cast<TensorType>().getElementType();
3131 if (!element_type.isSignlessIntOrFloat()) return failure();
3132 tensorflow::Padding padding;
3133 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
3134 return failure();
3135 if (padding == tensorflow::Padding::EXPLICIT) {
3136 return failure();
3137 }
3138 Location loc = op.getLoc();
3139 ConstOp init = GetScalarLimitConstOfType(element_type, loc,
3140 hlo::kInfinityLowest, &rewriter);
3141
3142 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
3143 if (!input_ty) return failure();
3144 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
3145 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
3146 auto reduce = rewriter.create<ReduceWindowOp>(
3147 loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
3148 GetI64ElementsAttr(op.strides()),
3149 /*base_dilations=*/DenseIntElementsAttr(),
3150 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
3151 BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
3152
3153 rewriter.replaceOp(op, reduce.getResult(0));
3154 return success();
3155 }
3156 };
3157
3158 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
3159 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
3160
3161 // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on
3162 // the condition only.
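//
// A sketch of the broadcasting case (shapes illustrative): a rank-1 condition
// used with rank-2 `then`/`else` operands is broadcast along the leading
// dimension inside a shape.assuming region before the select:
//
//   %cond_b = "mhlo.dynamic_broadcast_in_dim"(%cond, %extents)
//       {broadcast_dimensions = dense<0> : tensor<1xi64>}
//       : (tensor<2xi1>, tensor<2xindex>) -> tensor<2x3xi1>
//   %result = "mhlo.select"(%cond_b, %t, %e)
//       : (tensor<2x3xi1>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>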
3163 class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {
3164 public:
3165 using OpRewritePattern::OpRewritePattern;
3166
3167 LogicalResult matchAndRewrite(TF::SelectOp op,
3168 PatternRewriter &rewriter) const override {
3169 // This lowering only works on ranked types.
3170 auto cond_type = op.condition().getType().dyn_cast<RankedTensorType>();
3171 auto then_type = op.t().getType().dyn_cast<RankedTensorType>();
3172 auto else_type = op.e().getType().dyn_cast<RankedTensorType>();
3173 if (!cond_type || !then_type || !else_type) {
3174 return failure();
3175 }
3176
3177 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
3178 Value cond_shape = b.createOrFold<shape::ShapeOfOp>(op.condition());
3179 Value then_shape = b.createOrFold<shape::ShapeOfOp>(op.t());
3180 Value else_shape = b.createOrFold<shape::ShapeOfOp>(op.e());
3181
3182 // First check that the `then` and `else` shapes are equal.
3183 Value assumption =
3184 b.createOrFold<shape::CstrEqOp>(ValueRange{then_shape, else_shape});
3185 // For a vector cond we also verify that the majormost dim of `then` matches
3186 // the vector size. To do that, split off the first dim of `then`.
3187 bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1;
3188 Value then_shape_split = then_shape;
3189 if (needs_broadcast) {
3190 Value const_one = b.create<ConstantIndexOp>(1);
3191 Type extents = shape::getExtentTensorType(b.getContext());
3192 SmallVector<Value, 2> then_split;
3193 b.createOrFold<shape::SplitAtOp>(then_split, TypeRange{extents, extents},
3194 then_shape, const_one);
3195 then_shape_split = then_split[0];
3196 }
3197 // If the condition is not a scalar, check that it matches the other shapes.
3198 if (cond_type.getRank() > 0) {
3199 Value eq_cstr = b.createOrFold<shape::CstrEqOp>(
3200 ValueRange{cond_shape, then_shape_split});
3201 auto witness = shape::WitnessType::get(b.getContext());
3202 assumption = b.createOrFold<shape::AssumingAllOp>(
3203 witness, ValueRange{assumption, eq_cstr});
3204 }
3205 auto result_type = op.getResult().getType().cast<TensorType>();
3206 auto assuming_op =
3207 b.create<shape::AssumingOp>(ArrayRef<Type>{result_type}, assumption);
3208
3209 OpBuilder::InsertionGuard guard(b);
3210 b.createBlock(&assuming_op.doRegion());
3211
3212 // Broadcast the cond if necessary.
3213 Value cond = op.condition();
3214 if (needs_broadcast) {
3215 Value result_extents = b.create<shape::ToExtentTensorOp>(
3216 GetExtentsTensorTypeFor(result_type), then_shape);
3217 cond = b.create<mhlo::DynamicBroadcastInDimOp>(
3218 RankedTensorType::get(result_type.getShape(), b.getI1Type()), cond,
3219 result_extents, GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b));
3220 }
3221 Value select = b.create<mhlo::SelectOp>(result_type, cond, op.t(), op.e());
3222 b.create<shape::AssumingYieldOp>(select);
3223 rewriter.replaceOp(op, {assuming_op.getResult(0)});
3224 return success();
3225 }
3226 };
3227
3228 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
3229 //
3230 // sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
3231 //
3232 // Sample result with 2-d f16 inputs with B batches of N elements each.
3233 //
3234 // // Create an array of 0.5 the shape of the input array.
3235 // %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
3236 // %half_array = "mhlo.broadcast"(%half)
3237 // {broadcast_sizes = dense<2> : tensor<1xi64>}
3238 // : (tensor<f32>) -> tensor<2xf32>
3239 //
3240 // // Compute Tanh of half the logits of the values.
3241 // %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
3242 // %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
3243 //
3244 // // Halve the result of Tanh and add 0.5.
3245 // %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
3246 // %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
3247 //
3248 class ConvertSigmoidOp : public RewritePattern {
3249 public:
3250 explicit ConvertSigmoidOp(MLIRContext *context)
3251 : RewritePattern(
3252 TF::SigmoidOp::getOperationName(), 0, context,
3253 {mhlo::ConstOp::getOperationName(),
3254 shape::ShapeOfOp::getOperationName(),
3255 shape::ToExtentTensorOp::getOperationName(),
3256 mhlo::DynamicBroadcastInDimOp::getOperationName(),
3257 mhlo::MulOp::getOperationName(), mhlo::TanhOp::getOperationName(),
3258 mhlo::AddOp::getOperationName()}) {}
3259
3260 LogicalResult matchAndRewrite(Operation *sigmoid_op,
3261 PatternRewriter &rewriter) const override {
3262 auto op = cast<TF::SigmoidOp>(sigmoid_op);
3263 Location loc = op.getLoc();
3264
3265 // Create constant half with shape and element type same as the operand.
3266 Value operand = op.getOperand();
3267 auto operand_ty = operand.getType().cast<TensorType>();
3268 auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
3269 ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
3270 auto scalar_half = rewriter.create<ConstOp>(loc, attr);
3271 auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
3272
3273 auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
3274 auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
3275 auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
3276 auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
3277
3278 rewriter.replaceOp(op, add_op.getResult());
3279 return success();
3280 }
3281 };
3282
3283 // Converts the tf.Slice op into mhlo.real_dynamic_slice.
3284 // TODO(disc): Recover the static special case's performance with folding and
3285 // canonicalization.
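//
// As a sketch of the index computation (values illustrative): for begin =
// [1, 0] and size = [2, -1] on a 2-D input, the pattern computes the limit
// indices as [1 + 2, dim(%input, 1)] (a size of -1 selects the remainder of
// that dimension), uses strides of 1, and feeds the begin/limit/stride index
// tensors into mhlo.real_dynamic_slice.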
3286 class ConvertSliceOpDynamic : public OpRewritePattern<TF::SliceOp> {
3287 public:
3288 using OpRewritePattern::OpRewritePattern;
3289
3290 LogicalResult matchAndRewrite(TF::SliceOp op,
3291 PatternRewriter &rewriter) const override {
3292 Location loc = op.getLoc();
3293 Value input = op.input();
3294 Value begin_indices = op.begin();
3295 Value sizes = op.size();
3296
3297 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
3298 auto begin_type = begin_indices.getType().dyn_cast<RankedTensorType>();
3299 auto size_type = sizes.getType().dyn_cast<RankedTensorType>();
3300
3301 if (!input_ty || !begin_type || !size_type ||
3302 !begin_type.hasStaticShape() || !size_type.hasStaticShape() ||
3303 begin_type.getRank() != 1 || size_type.getRank() != 1) {
3304 return failure();
3305 }
3306 // TODO(disc): Remove this check (which bails out when `size` is a constant)
3307 // once folding/canonicalization support is added.
3308 DenseIntElementsAttr size_attr;
3309 if (matchPattern(op.size(), m_Constant(&size_attr))) {
3310 return failure();
3311 }
3312
3313 int rank = begin_type.getDimSize(0);
3314 auto shape_scalar_type = begin_type.getElementType();
3315 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
3316 SmallVector<Value, 4> stride_values(rank, one);
3317 SmallVector<Value, 4> end_values;
3318 SmallVector<Value, 4> begin_values;
3319 end_values.reserve(rank);
3320 for (int i = 0; i < rank; ++i) {
3321 SmallVector<Value, 4> indices;
3322 indices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
3323 auto begin_value =
3324 rewriter.create<tensor::ExtractOp>(loc, begin_indices, indices);
3325 auto size_value = rewriter.create<tensor::ExtractOp>(loc, sizes, indices);
3326 Value minus_one = rewriter.create<IndexCastOp>(
3327 loc, rewriter.create<ConstantIndexOp>(loc, -1), shape_scalar_type);
3328 auto is_minus_one = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq,
3329 size_value, minus_one);
3330 Value end_value = rewriter.create<AddIOp>(loc, begin_value, size_value);
3331 auto dim_value = rewriter.create<IndexCastOp>(
3332 loc, rewriter.create<tensor::DimOp>(loc, input, i),
3333 shape_scalar_type);
3334 end_value = rewriter.create<mlir::SelectOp>(loc, is_minus_one, dim_value,
3335 end_value);
3336 auto end_value_casted =
3337 rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), end_value);
3338 end_values.push_back(end_value_casted);
3339
3340 auto begin_value_casted = rewriter.create<IndexCastOp>(
3341 loc, rewriter.getIndexType(), begin_value);
3342 begin_values.push_back(begin_value_casted);
3343 }
3344
3345 auto start_indices = rewriter.create<tensor::FromElementsOp>(
3346 loc, rewriter.getIndexType(), begin_values);
3347 auto end_indices = rewriter.create<tensor::FromElementsOp>(
3348 loc, rewriter.getIndexType(), end_values);
3349 auto stride_indices = rewriter.create<tensor::FromElementsOp>(
3350 loc, rewriter.getIndexType(), stride_values);
3351
3352 auto d_slice = rewriter.create<mhlo::RealDynamicSliceOp>(
3353 loc, op.getOperation()->getResult(0).getType(), input, start_indices,
3354 end_indices, stride_indices);
3355 rewriter.replaceOp(op, d_slice.getOperation()->getResults());
3356 return success();
3357 }
3358 };
3359
3360 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
3361 Value *out_lhs, Value *out_rhs,
3362 PatternRewriter *rewriter) {
3363 // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
3364 // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
3365 // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
3366 // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
3367 // To perform the matmul, we need to first broadcast lhs and rhs to a common
3368 // set of leading dimensions before doing the actual matmul.
3369 // That's what the code below does.
3370 // In particular, we populate out_lhs and out_rhs to have dimension structure:
3371 // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
3372 // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
3373 // To do this, we need to calculate those output shapes, which involves
3374 // slicing off the leading batch dims of each operand, broadcasting them,
3375 // then concatenating the broadcasted leading dims back to the row/col dims.
3376 // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
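// For example (shapes illustrative): an lhs of shape [2, 1, 3, 4] and an rhs
// of shape [5, 4, 6] have batch dims [2, 1] and [5], which broadcast to
// [2, 5], so out_lhs gets shape [2, 5, 3, 4] and out_rhs [2, 5, 4, 6].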
3377
3378 // TODO(silvasean): Reduce duplication across reified shape calculations and
3379 // the static computation of output types needed to create ops.
3380 Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
3381 Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
3382 Value const_neg2 =
3383 rewriter->create<ConstantOp>(loc, rewriter->getIndexAttr(-2));
3384 auto shape_type = shape::ShapeType::get(rewriter->getContext());
3385 auto lhs_splitted = rewriter->create<shape::SplitAtOp>(
3386 loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2);
3387 auto rhs_splitted = rewriter->create<shape::SplitAtOp>(
3388 loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2);
3389 auto lhs_type = lhs.getType().cast<RankedTensorType>();
3390 auto rhs_type = rhs.getType().cast<RankedTensorType>();
3391 // The last two dimensions are the matrix row/col dimensions. Don't broadcast
3392 // them.
3393 SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
3394 mlir::OpTrait::util::getBroadcastedShape(
3395 lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2),
3396 result_batch_shape_compile_time_extents);
3397 auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
3398 loc, shape_type, lhs_splitted.head(), rhs_splitted.head(),
3399 /*error=*/nullptr);
3400 // Lambda which handles the broadcasting of one side to the common
3401 // leading-batch dimensions.
3402 auto broadcast_one_side = [&](Value side, RankedTensorType type,
3403 Value tail_shape, Value *out_side) {
3404 ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
3405 auto result_shape = result_batch_shape_compile_time_extents;
3406 result_shape.append(matrix_dims.begin(), matrix_dims.end());
3407 auto result_type =
3408 RankedTensorType::get(result_shape, type.getElementType());
3409 auto shape =
3410 rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
3411 auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
3412 loc,
3413 RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
3414 rewriter->getIndexType()),
3415 shape);
3416 *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
3417 shape_tensor);
3418 };
3419 broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
3420 broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
3421 }
3422
3423 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
3424 public:
3425 // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
3426 // to CHLO and it is missing legalization to MHLO. Once that is done, this
3427 // pattern's benefit can be changed back to one, and the fallback
3428 // lowering pattern for the op can be removed.
3429 //
3430 // Set benefit of this pattern to zero to prefer the fallback pattern when
3431 // available and applicable. That pattern avoids broadcast on operands and is
3432 // therefore faster.
3433 //
3434 // Native legalization for BatchMatMulV3 needs to be added as well.
3435 explicit ConvertBatchMatMulV2Op(MLIRContext *context)
3436 : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
3437
3438 LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
3439 PatternRewriter &rewriter) const override {
3440 Value lhs = op.x();
3441 Value rhs = op.y();
3442 auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
3443 auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
3444 if (!lhs_type || !rhs_type) return failure();
3445 if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
3446 lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
3447 }
3448 if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
3449 rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
3450 }
3451
3452 // Broadcast both operands.
3453 BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
3454 &rewriter);
3455 lhs_type = lhs.getType().cast<RankedTensorType>();
3456 rhs_type = rhs.getType().cast<RankedTensorType>();
3457 assert(lhs_type.getRank() == rhs_type.getRank());
3458 int64_t rank = lhs_type.getRank();
3459 auto batch_dimensions = GetI64ElementsAttr(
3460 llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
3461 auto lhs_contracting_dimensions = GetI64ElementsAttr(
3462 llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
3463 auto rhs_contracting_dimensions = GetI64ElementsAttr(
3464 llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
3465 auto dimension_numbers = DotDimensionNumbers::get(
3466 /*lhs_batching_dimensions=*/batch_dimensions,
3467 /*rhs_batching_dimensions=*/batch_dimensions,
3468 /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
3469 /*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
3470 rewriter.getContext());
3471 // TODO(silvasean): Emit shape checks for contracting dimensions.
3472 // (The batch dimensions are checked by the broadcasting logic)
3473 rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
3474 dimension_numbers,
3475 /*precision_config=*/nullptr);
3476 return success();
3477 }
3478 };
3479
3480 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
3481 // split has fully static shape and the dimension to split is a constant.
3482 //
3483 // The main logic of this pattern is to calculate the index start and end range
3484 // for each slice, which happens only on the dimension to be split; for all
3485 // other dimensions, each resultant slice's index start and end range covers the
3486 // input tensor's full range. Strides for all resultant slices are all one.
3487 //
3488 // For example, the following source IR:
3489 //
3490 // %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3491 // %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
3492 // (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
3493 //
3494 // will be converted into:
3495 //
3496 // %0 = "mhlo.slice"(%input) {
3497 // limit_indices = dense<[4, 2]> : tensor<2xi64>,
3498 // start_indices = dense<0> : tensor<2xi64>,
3499 // strides = dense<1> : tensor<2xi64>} :
3500 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3501 // %1 = "mhlo.slice"(%input) {
3502 // limit_indices = dense<4> : tensor<2xi64>,
3503 // start_indices = dense<[0, 2]> : tensor<2xi64>,
3504 // strides = dense<1> : tensor<2xi64>} :
3505 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3506 // %2 = "mhlo.slice"(%input) {
3507 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
3508 // start_indices = dense<[0, 4]> : tensor<2xi64>,
3509 // strides = dense<1> : tensor<2xi64>} :
3510 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3511 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
3512 // applicable.
3513 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
3514 public:
3515 using OpRewritePattern::OpRewritePattern;
3516
3517 LogicalResult matchAndRewrite(TF::SplitOp op,
3518 PatternRewriter &rewriter) const override {
3519 // We can only split along static dimensions.
3520 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3521 if (!input_type) return failure();
3522
3523 // We can only match when the split dimension is a constant scalar.
3524 DenseIntElementsAttr split_dim_attr;
3525 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3526 return failure();
3527
3528 // Get the dimension we are splitting at. Offset properly if it's negative.
3529 int64_t input_rank = input_type.getRank();
3530 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3531 if (dim_index < 0) dim_index += input_rank;
3532
3533 // Calculate the dimension size for each slice along the split dimension.
3534 int64_t input_dim_size = input_type.getDimSize(dim_index);
3535 // If we are splitting along the dynamic dimension then we cannot compute
3536 // the static dimension length.
3537 if (TensorType::isDynamic(input_dim_size)) return failure();
3538
3539 int64_t num_splits = op.getNumResults();
3540 int64_t slice_size = input_dim_size / num_splits;
3541
3542 // Get each slice's type.
3543 auto slice_shape = llvm::to_vector<4>(input_type.getShape());
3544 slice_shape[dim_index] = slice_size;
3545 Type slice_type =
3546 RankedTensorType::get(slice_shape, input_type.getElementType());
3547
3548 // Parameters for constructing each slice.
3549 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3550 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3551 SmallVector<int64_t, 4> strides(input_rank, 1);
3552
3553 // All HLO slice results used to replace the original tf.Split op.
3554 SmallVector<Value, 4> slices;
3555 slices.reserve(num_splits);
3556
3557 for (int i = 0; i < num_splits; ++i) {
3558 begin_indices[dim_index] = i * slice_size;
3559 end_indices[dim_index] = (i + 1) * slice_size;
3560 slices.push_back(
3561 rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
3562 GetI64ElementsAttr(begin_indices, &rewriter),
3563 GetI64ElementsAttr(end_indices, &rewriter),
3564 GetI64ElementsAttr(strides, &rewriter)));
3565 }
3566
3567 rewriter.replaceOp(op, slices);
3568 return success();
3569 }
3570 };
3571
3572 // Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops when
3573 // the dimension to split is a constant.
3574 // TODO(disc): Recover the static special case's performance with folding and
3575 // canonicalization, then delete ConvertSplitOp.
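//
// As a sketch (shapes illustrative): splitting a tensor<4x?xf32> into three
// outputs along dimension 1 computes slice_size = dim(%input, 1) / 3 at
// runtime and, for output i, emits an mhlo.real_dynamic_slice whose begin and
// limit index tensors are built with tensor.from_elements from
// i * slice_size and (i + 1) * slice_size along the split dimension.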
3576 class ConvertSplitOpDynamic : public OpRewritePattern<TF::SplitOp> {
3577 public:
3578 using OpRewritePattern::OpRewritePattern;
3579
3580 LogicalResult matchAndRewrite(TF::SplitOp op,
3581 PatternRewriter &rewriter) const override {
3582 Location loc = op.getLoc();
3583 Value input = op.value();
3584 auto input_type = input.getType().dyn_cast<RankedTensorType>();
3585 if (!input_type) return failure();
3586 // We can only match when the split dimension is a constant scalar.
3587 DenseIntElementsAttr split_dim_attr;
3588 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3589 return failure();
3590
3591 // Get the dimension we are splitting at. Offset properly if it's negative.
3592 int64_t input_rank = input_type.getRank();
3593 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3594 if (dim_index < 0) dim_index += input_rank;
3595
3596 // TODO(disc): Remove this static shape check once folding/canonicalization is
3597 // added and ConvertSplitOp is deleted. This pattern only handles splitting
3598 // along a dynamic dimension; splitting along a static dimension is handled by
3599 // the ConvertSplitOp pattern above.
3600 int64_t c_input_dim_size = input_type.getDimSize(dim_index);
3601 if (!TensorType::isDynamic(c_input_dim_size)) return failure();
3602
3603 Value input_dim_size =
3604 rewriter.create<tensor::DimOp>(loc, input, dim_index);
3605 // Calculate the dimension size for each slice along the split dimension.
3606 int num_splits = op.getNumResults();
3607 Value num_splits_value =
3608 rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(num_splits));
3609 Value slice_size =
3610 rewriter.create<SignedDivIOp>(loc, input_dim_size, num_splits_value);
3611
3612 Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
3613 Value one = rewriter.create<ConstantIndexOp>(loc, 1);
3614
3615 SmallVector<Value, 4> begin_indices(input_rank, zero);
3616 SmallVector<Value, 4> end_indices;
3617 end_indices.reserve(input_rank);
3618 SmallVector<Value, 4> strides(input_rank, one);
3619 for (int i = 0; i < input_rank; ++i) {
3620 end_indices.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
3621 }
3622
3623 // All HLO d_slice results used to replace the original tf.Split op.
3624 SmallVector<Value, 4> slices;
3625 slices.reserve(num_splits);
3626
3627 for (int i = 0; i < num_splits; ++i) {
3628 begin_indices[dim_index] = rewriter.create<MulIOp>(
3629 loc, slice_size, rewriter.create<ConstantIndexOp>(loc, i));
3630 end_indices[dim_index] = rewriter.create<MulIOp>(
3631 loc, slice_size, rewriter.create<ConstantIndexOp>(loc, i + 1));
3632 auto begin_value = rewriter.create<tensor::FromElementsOp>(
3633 loc, rewriter.getIndexType(), begin_indices);
3634 auto end_value = rewriter.create<tensor::FromElementsOp>(
3635 loc, rewriter.getIndexType(), end_indices);
3636 auto stride_value = rewriter.create<tensor::FromElementsOp>(
3637 loc, rewriter.getIndexType(), strides);
3638 slices.push_back(rewriter.create<RealDynamicSliceOp>(
3639 loc, op.getOperation()->getResult(i).getType(), input, begin_value,
3640 end_value, stride_value));
3641 }
3642
3643 rewriter.replaceOp(op, slices);
3644 return success();
3645 }
3646 };
3647
3648 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
3649 // be split has fully static shape and the dimension to split and split sizes
3650 // are constants.
3651 //
3652 // This is similar to the conversion for tf.Split op other than that the size of
3653 // each chunk on the dimension to split is explicitly given as an op operand
3654 // and they are not necessarily the same.
3655 //
3656 // For example, given the following IR:
3657 //
3658 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
3659 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
3660 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
3661 // (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
3662 // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
3663 //
3664 // We will generate the following slices:
3665 // %0 = "mhlo.slice"(%input) {
3666 // limit_indices = dense<[4, 1]> : tensor<2xi64>,
3667 // start_indices = dense<0> : tensor<2xi64>,
3668 // strides = dense<1> : tensor<2xi64>} :
3669 // (tensor<4x6xf32>) -> tensor<4x1xf32>
3670 // %1 = "mhlo.slice"(%input) {
3671 // limit_indices = dense<[4, 3]> : tensor<2xi64>,
3672 // start_indices = dense<[0, 1]> : tensor<2xi64>,
3673 // strides = dense<1> : tensor<2xi64>} :
3674 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3675 // %2 = "mhlo.slice"(%input) {
3676 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
3677 // start_indices = dense<[0, 3]> : tensor<2xi64>,
3678 // strides = dense<1> : tensor<2xi64>} :
3679 // (tensor<4x6xf32>) -> tensor<4x3xf32>
3680 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
3681 public:
3682 using OpRewritePattern::OpRewritePattern;
3683
3684 LogicalResult matchAndRewrite(TF::SplitVOp op,
3685 PatternRewriter &rewriter) const override {
3686 // We can only split along static dimensions.
3687 // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3688 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3689 if (!input_type) return failure();
3690
3691 // We can only match when the split dimension is a constant scalar.
3692 DenseIntElementsAttr split_dim_attr;
3693 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3694 return failure();
3695
3696 // We can only match when the split sizes are a constant int vector.
3697 DenseIntElementsAttr split_sizes_attr;
3698 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3699 return failure();
3700
3701 // Get each chunk's size along the dimension to split. It may contain at most
3702 // one dynamic size, which is resolved below once the other sizes are known.
3703 SmallVector<int64_t, 4> split_sizes;
3704 int64_t total_dim_size = 0; // Total dimension size assigned to splits
3705 llvm::Optional<int> dynamic_dim_index;
3706 split_sizes.reserve(
3707 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3708 for (auto dim : llvm::enumerate(split_sizes_attr)) {
3709 int64_t dim_val = dim.value().getSExtValue();
3710 split_sizes.push_back(dim_val);
3711 if (dim_val == ShapedType::kDynamicSize) {
3712 // We cannot have more than one dynamic dimension.
3713 assert(!dynamic_dim_index && "invalid split sizes");
3714 dynamic_dim_index = dim.index();
3715 } else {
3716 total_dim_size += dim_val;
3717 }
3718 }
3719
3720 // Get the dimension we are splitting at. Offset properly if it's negative.
3721 int64_t input_rank = input_type.getRank();
3722 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3723 if (dim_index < 0) dim_index += input_rank;
3724
3725 int64_t input_dim_size = input_type.getDimSize(dim_index);
3726 if (TensorType::isDynamic(input_dim_size)) return failure();
3727
3728 assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3729 (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3730 "invalid split sizes");
3731
3732 // Update the dynamic dimension with calculated concrete size.
3733 if (dynamic_dim_index)
3734 split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3735
3736 // Parameters for constructing each slice.
3737 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3738 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3739 SmallVector<int64_t, 4> strides(input_rank, 1);
3740
3741 // All HLO slice results used to replace the original tf.SplitV op.
3742 SmallVector<Value, 4> slices;
3743 slices.reserve(op.getNumResults());
3744
3745 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3746 end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3747 slices.push_back(rewriter.create<mhlo::SliceOp>(
3748 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3749 GetI64ElementsAttr(end_indices, &rewriter),
3750 GetI64ElementsAttr(strides, &rewriter)));
3751 // Prepare the begin index for the next slice.
3752 begin_indices[dim_index] = end_indices[dim_index];
3753 }
3754
3755 rewriter.replaceOp(op, slices);
3756 return success();
3757 }
3758 };
3759
3760 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3761 // negative strides and Reshape op to update the output shape. Indices and
3762 // strides operands are converted to attributes with non-negative indexing.
3763 //
3764 // If the begin input is not a compile time constant, the begin input needs to
3765 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3766 // case, strides must have a known value of 1 (otherwise we have insufficient
3767 // information to conform to XLA's op semantics).
3768 //
3769 // For example with an op like following,
3770 // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3771 // : tensor<AxBxf32> -> tensor<Pxf32>
3772 //
3773 // If the %begin input is constant, output would be:
3774 // %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3775 // %sliced = "mhlo.Slice" (%input)
3776 // {start_indices = ..., limit_indices = ..., strides = ...}
3777 // %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
3778 //
3779 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3780 public:
3781 using OpRewritePattern::OpRewritePattern;
3782
3783 LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3784 ArrayRef<int64_t> begin_indices,
3785 ArrayRef<int64_t> end_indices,
3786 ArrayRef<int64_t> strides,
3787 RankedTensorType input_ty,
3788 PatternRewriter &rewriter) const {
3789 SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3790 dims_to_reverse;
3791 int64_t input_rank = input_ty.getRank();
3792 ArrayRef<int64_t> input_shape = input_ty.getShape();
3793 hlo_begin_indices.reserve(input_rank);
3794 hlo_end_indices.reserve(input_rank);
3795 hlo_strides.reserve(input_rank);
3796
3797 int64_t indices_elements = begin_indices.size();
3798 if (input_rank < indices_elements) return failure();
3799
3800 // Convert TensorFlow's negative or out-of-range index and stride values to
3801 // legal HLO Slice attributes.
3802 for (int i = 0, e = indices_elements; i != e; i++) {
3803 int64_t begin = begin_indices[i];
3804 int64_t end = end_indices[i];
3805 int64_t stride = strides[i];
3806
3807 if (stride < 0) {
3808 // Negative stride means that the output values are computed starting
3809 // from end until begin. Mark the dimension for reversal before slice
3810 // and compute indices for the reversed input.
3811 dims_to_reverse.push_back(i);
3812 begin = (input_shape[i] - 1) - begin;
3813 end = (input_shape[i] - 1) - end;
3814 stride = -stride;
3815 }
3816
3817 // Unlike TensorFlow, HLO requires begin and end values to be within
3818 // range.
3819 begin = std::max(int64_t(0), begin);
3820 end = std::max(begin, end);
3821 end = std::min(end, input_shape[i]);
3822
3823 hlo_begin_indices.push_back(begin);
3824 hlo_end_indices.push_back(end);
3825 hlo_strides.push_back(stride);
3826 }
3827
3828 Location loc = op.getLoc();
3829 Value input = op.input();
3830 if (!dims_to_reverse.empty())
3831 input = rewriter.create<ReverseOp>(
3832 loc, input_ty, op.input(),
3833 GetI64ElementsAttr(dims_to_reverse, &rewriter));
3834 auto sliced = rewriter.create<SliceOp>(
3835 loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3836 GetI64ElementsAttr(hlo_end_indices, &rewriter),
3837 GetI64ElementsAttr(hlo_strides, &rewriter));
3838
3839 // Reshape slice result so that the shape is updated depending on
3840 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3841 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3842 return success();
3843 }
3844
3845 LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3846 RankedTensorType input_ty,
3847 RankedTensorType result_ty,
3848 PatternRewriter &rewriter) const {
3849 // If begin and end values are dynamic, we can only support this lowering
3850 // if strides are a known value of 1.
3851 DenseIntElementsAttr sparse_strides_attr;
3852 if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3853 return rewriter.notifyMatchFailure(
3854 op,
3855 "requires that strides are known when begin/end values are dynamic");
3856 }
3857 SmallVector<int64_t, 4> strides;
3858 int64_t stride_value;
3859 for (const APInt &stride : sparse_strides_attr) {
3860 if ((stride_value = stride.getSExtValue()) != 1) {
3861 return rewriter.notifyMatchFailure(op,
3862 "requires that strides are all 1 "
3863 "when begin/end values are dynamic");
3864 }
3865 strides.push_back(stride_value);
3866 }
3867
3868 ArrayRef<int64_t> input_shape = input_ty.getShape();
3869 int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3870
3871 // When begin/end values are dynamic, we can only support shrinking a major
3872 // axis. For instance, if there are 4 dims, we can support a
3873 // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3874 // other.
3875 bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3876 if (!shrink_axis_mask_ok)
3877 return rewriter.notifyMatchFailure(
3878 op,
3879 "requires that shrink_axis_mask, if set, refer to a major axis "
3880 "dimension (when begin/end values are dynamic)");
3881
3882 // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3883 // to the last dimension.
3884 int ellipsis_mask = op.ellipsis_mask();
3885 if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3886 return rewriter.notifyMatchFailure(
3887 op,
3888 "requires that ellipsis_mask, if set, refer to the last dimension of "
3889 "input (when begin/end values are dynamic)");
3890
3891 uint64_t begin_mask = op.begin_mask();
3892 if (begin_mask)
3893 return rewriter.notifyMatchFailure(
3894 op,
3895 "requires that begin_mask is either set to 0 or not set when "
3896 "begin/end values are dynamic");
3897 uint64_t end_mask = op.end_mask();
3898 if (end_mask)
3899 return rewriter.notifyMatchFailure(
3900 op,
3901 "requires that end_mask is either set to 0 or not set when begin/end "
3902 "values are dynamic");
3903 uint64_t new_axis_mask = op.new_axis_mask();
3904 if (new_axis_mask)
3905 return rewriter.notifyMatchFailure(
3906 op,
3907 "requires that new_axis_mask is either set to 0 or not set when "
3908 "begin/end values are dynamic");
3909
3910 // In this case where the begin and end values are dynamic, the number of
3911 // output elements has to be equal to the number of input elements that
3912 // are sliced.
3913 int output_elements = result_ty.getNumElements();
3914 int input_elements_sliced = 1;
3915
3916 // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3917 // verifier.
3918 int64_t slicing_dim_size =
3919 op.begin().getType().cast<RankedTensorType>().getShape()[0];
3920 const int input_rank = input_shape.size();
3921 for (int d = slicing_dim_size; d < input_rank; ++d) {
3922 // We only support slicing major dimensions, so minor dimensions after
3923 // slicing dimensions are all sliced with their full sizes.
3924 input_elements_sliced *= input_shape[d];
3925 }
3926 if (input_elements_sliced != output_elements) {
3927 return rewriter.notifyMatchFailure(
3928 op,
3929 "requires the number of output elements to be equal to the number of "
3930 "input elements sliced (when begin/end values are dynamic)");
3931 }
3932
3933 SmallVector<Value, 4> slice_begin_indices;
3934 // For the dimensions that are to be sliced, all have slice sizes of 1.
3935 SmallVector<int64_t, 4> slice_sizes(slicing_dim_size, 1);
3936 auto begin_element_ty =
3937 op.begin().getType().cast<ShapedType>().getElementType();
3938 // Scalar tensor type.
3939 TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3940 Location loc = op.getLoc();
3941 auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3942 for (int d = 0; d < slicing_dim_size; ++d) {
3943 auto index = rewriter.create<SliceOp>(
3944 loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3945 GetI64ElementsAttr({d + 1}, &rewriter),
3946 GetI64ElementsAttr({1}, &rewriter));
3947 // Convert index to scalar.
3948 auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3949 // If the index is negative, wrap it around with dimension size.
3950 auto index_negative =
3951 rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3952 auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3953 input_shape[d], &rewriter);
3954 auto wrapped_index =
3955 rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3956 auto final_index = rewriter.create<SelectOp>(
3957 loc, type, index_negative, wrapped_index, reshaped_index);
3958 slice_begin_indices.push_back(final_index);
3959 }
3960
3961 // For non-slice dims, get the full slice of that dimension.
3962 for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
3963 slice_sizes.push_back(input_shape[d]);
3964 slice_begin_indices.push_back(zero);
3965 }
3966
3967 auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3968 // This must be an XLA DynamicSlice op because the slice begin indices are
3969 // not compile-time constants.
3970 auto sliced = rewriter.create<DynamicSliceOp>(
3971 loc, op.getType(), op.input(), slice_begin_indices, slice_sizes_attr);
3972
3973 // Reshape slice result so that the shape is updated depending on
3974 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3975 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3976 return success();
3977 }
3978
3979 LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3980 PatternRewriter &rewriter) const override {
3981 // Input shape needs to be static to convert negative indices in TensorFlow
3982 // to absolute indices required by HLO.
3983 //
3984 // TODO(hinsu): Relax this constraint for ops without negative indices and
3985 // strides.
3986 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3987 if (!input_ty || !input_ty.hasStaticShape()) return failure();
3988
3989 // Output shape needs to be static to apply 'new_axis_mask' or
3990 // 'shrink_axis_mask' by reshaping tensor after slice.
3991 //
3992 // TODO(hinsu): Relax this constraint for ops without the above masks.
3993 auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3994 if (!result_ty || !result_ty.hasStaticShape()) return failure();
3995
3996 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3997 if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3998 !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3999 return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
4000 }
4001
4002 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
4003 if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
4004 return failure();
4005 }
4006 return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
4007 input_ty, rewriter);
4008 }
4009 };
4010
4011 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
4012 //
4013 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
4014 // the reverse: it propagates the gradient of the sliced tensor back to the
4015 // original input tensor by padding it with zeros. The main logic is calculating
4016 // the indices and strides for the padding.
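//
// For example (illustrative only; the actual padding amounts depend on the
// constant begin/end indices and strides):
// %grad = "mhlo.reshape"(%dy) : tensor<Pxf32> -> tensor<1xPxf32>
// %result = "mhlo.pad"(%grad, %zero) {edge_padding_low = ...,
// edge_padding_high = ..., interior_padding = ...}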
4017 class ConvertStridedSliceGradOp
4018 : public OpRewritePattern<TF::StridedSliceGradOp> {
4019 public:
4020 using OpRewritePattern::OpRewritePattern;
4021
4022 LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
4023 PatternRewriter &rewriter) const override {
4024 // We need constant input shape to perform padding calculations later.
4025 DenseIntElementsAttr input_shape_attr;
4026 if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
4027 return failure();
4028
4029 // We also need constant begin/end indices and strides to perform padding
4030 // calculations.
4031 // Bounded shape after performing strided slice
4032 SmallVector<int64_t, 4> shape;
4033 // Bounded begin, end, and strides for strided slice
4034 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
4035 if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
4036 &strides))
4037 return failure();
4038
4039 Value grad = op.dy();
4040 Type element_type = grad.getType().cast<ShapedType>().getElementType();
4041
4042 // Perform reshape to undo any new/shrink axes done by strided slice.
4043 grad = rewriter.create<mhlo::ReshapeOp>(
4044 op.getLoc(), RankedTensorType::get(shape, element_type), grad);
4045
4046 SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
4047 SmallVector<int64_t, 4> dims_to_reverse;
4048 padding_low.reserve(shape.size());
4049 padding_high.reserve(shape.size());
4050 padding_interm.reserve(shape.size());
4051
4052 // Prepare padding parameters for each dimension.
4053 for (int i = 0, e = shape.size(); i < e; ++i) {
4054 int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
4055 if (strides[i] > 0) {
4056 padding_low.push_back(begin_indices[i]);
4057 padding_interm.push_back(strides[i] - 1);
4058
4059 // Pad the upper dimension up to the expected input shape. It's not
4060 // sufficient simply to use end_indices[i] to compute the padding in
4061 // cases where the stride does not divide evenly into the interval
4062 // between begin_indices[i] and end_indices[i].
4063 int64_t size =
4064 padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
4065 padding_high.push_back(input_dim - size);
4066 } else {
4067 dims_to_reverse.push_back(i);
4068 padding_high.push_back(input_dim - begin_indices[i] - 1);
4069 padding_interm.push_back(-strides[i] - 1);
4070
4071 // Pad the lower dimension up to the expected input shape.
4072 int64_t size =
4073 padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
4074 padding_low.push_back(input_dim - size);
4075 }
4076 }
4077
4078 if (!dims_to_reverse.empty()) {
4079 grad = rewriter.create<mhlo::ReverseOp>(
4080 op.getLoc(), grad.getType(), grad,
4081 GetI64ElementsAttr(dims_to_reverse, &rewriter));
4082 }
4083
4084 auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
4085 rewriter.replaceOpWithNewOp<mhlo::PadOp>(
4086 op, op.getType(), grad, zero,
4087 GetI64ElementsAttr(padding_low, &rewriter),
4088 GetI64ElementsAttr(padding_high, &rewriter),
4089 GetI64ElementsAttr(padding_interm, &rewriter));
4090 return success();
4091 }
4092 };
4093
4094 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
4095 /// offset applied to generate the range values. The output tensor needs to
4096 /// have a static shape.
4097 ///
4098 /// For example an op like the following:
4099 /// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
4100 /// : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
4101 ///
4102 /// Output would be:
4103 /// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
4104 /// %scaled = "mhlo.multiply"(%iota, %delta)
4105 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
4106 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
4107 /// %result = "mhlo.add"(%scaled, %offset)
4108 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
4109 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
4110 ///
4111 /// The implementation is defined in C++ because there is no type interface for the iota op.
4112 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
4113 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
4114
4115 LogicalResult matchAndRewrite(TF::RangeOp op,
4116 PatternRewriter &rewriter) const override {
4117 auto result = op.getResult();
4118 auto result_type = result.getType();
4119 if (!result_type.cast<ShapedType>().hasStaticShape()) {
4120 return failure();
4121 }
4122
4123 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4124 rewriter.getI64IntegerAttr(0));
4125 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4126 op.getLoc(), result_type, iota, op.delta(),
4127 hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
4128 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4129 op, result_type, scaled, op.start(),
4130 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4131 return success();
4132 }
4133 };
4134
4135 // Converts RangeOp for cases where the length is a dynamic value. The shape of
4136 // the resulting tensor is computed, and then the start and delta are used with
4137 // the dynamic_iota value to compute the final range values.
4138 //
4139 // For example, the resulting range op value:
4140 // %range = "tf.range"(%start, %limit, %delta)
4141 //
4142 // Is converted to the following.
4143 // %start + %delta * iota(ceil(abs((%limit - %start) / %delta))
4144 //
4145 // Implementation is defined in C++ due to the complicated type behavior.
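//
// The number of steps is computed in floating point (f64 unless the operands
// are already floating point) so that ceil is well defined, then converted to
// an i64 shape operand for the dynamic iota.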
4146 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
4147 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
4148
4149 LogicalResult matchAndRewrite(TF::RangeOp op,
4150 PatternRewriter &rewriter) const override {
4151 auto result = op.getResult();
4152 auto result_type = result.getType().cast<ShapedType>();
4153 if (result_type.hasStaticShape()) {
4154 return failure();
4155 }
4156
4157 Value start = op.start();
4158 Value delta = op.delta();
4159 Value limit = op.limit();
4160
4161 // To compute the length we need to use floating point calculations so that
4162 // ceil can be computed for the number of steps.
4163 auto compute_element_type =
4164 getElementTypeOrSelf(start.getType()).isa<FloatType>()
4165 ? getElementTypeOrSelf(start.getType())
4166 : rewriter.getF64Type();
4167 auto compute_type = RankedTensorType::get(
4168 limit.getType().cast<ShapedType>().getShape(), compute_element_type);
4169
4170 // Compute the length of the sequence we are going to need. This includes
4171 // some conversion to float for the operations.
4172 //
4173 // %size = ceil(abs((%limit - %start) / %delta))
4174 auto range = rewriter.create<mhlo::SubOp>(op.getLoc(), limit, start);
4175 auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
4176
4177 // Delta is not necessarily the same type as start and limit.
4178 auto abs_cast =
4179 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
4180 auto delta_cast =
4181 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
4182
4183 // Compute the total number of integer steps and convert to the HLO
4184 // dimension tensor.
4185 auto normalized =
4186 rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
4187 auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
4188 auto steps = rewriter.create<mhlo::ConvertOp>(
4189 op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
4190 auto reshape = rewriter.create<mhlo::ReshapeOp>(
4191 op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
4192
4193 // Using the resulting length compute the correct range value:
4194 //
4195 // %range = %start + %delta * iota(%size)
4196 auto out_scalar_type =
4197 RankedTensorType::get({}, getElementTypeOrSelf(result_type));
4198 auto start_out_cast =
4199 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
4200 auto delta_out_cast =
4201 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
4202
4203 auto iota = rewriter.create<DynamicIotaOp>(
4204 op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
4205 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4206 op.getLoc(), result_type, iota, delta_out_cast,
4207 hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
4208 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4209 op, result_type, scaled, start_out_cast,
4210 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
4211 return success();
4212 }
4213 };
4214
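// Converts an axis attribute to non-negative axes by wrapping negative values
// modulo the rank of `val`. For example, for a rank-4 value, axes [-1, 1]
// become [3, 1].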
4215 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
4216 auto int_attr = attr.cast<DenseIntElementsAttr>();
4217 auto type = val.getType().cast<ShapedType>();
4218
4219 SmallVector<int64_t, 6> axis;
4220 axis.reserve(int_attr.getNumElements());
4221
4222 int64_t rank = type.getRank();
4223 for (auto val : int_attr.getValues<APInt>()) {
4224 axis.push_back((val.getSExtValue() + rank) % rank);
4225 }
4226
4227 return builder->getI64TensorAttr(axis);
4228 }
4229
4230 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
4231 /// and offset applied to generate the linspace values. The output tensor needs
4232 /// to have a static shape. The implementation is defined in C++ because there
4233 /// is no type inference for the iota op.
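///
/// For example (illustrative), a tf.LinSpace producing tensor<5xf32> lowers to
/// roughly:
/// %step = (%stop - %start) / (%num - 1)
/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
/// %result = %start + %step * %iota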
4234 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
4235 using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
4236
4237 LogicalResult matchAndRewrite(TF::LinSpaceOp op,
4238 PatternRewriter &rewriter) const override {
4239 auto result = op.getResult();
4240 auto result_type = result.getType().dyn_cast<ShapedType>();
4241 if (!result_type || !result_type.hasStaticShape()) {
4242 return failure();
4243 }
4244
4245 DenseIntElementsAttr num_attr;
4246 if (!matchPattern(op.num(), m_Constant(&num_attr))) {
4247 return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
4248 }
4249
4250 if (num_attr.begin() == num_attr.end()) {
4251 return rewriter.notifyMatchFailure(op, "Num must not be empty");
4252 }
4253 int64_t num = (*num_attr.begin()).getSExtValue();
4254
4255 // Calculate the scaling that needs to be applied to the iota.
4256 auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
4257 op.getLoc(), op.start().getType(), op.stop(), op.start(),
4258 hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
4259 Value step_denominator = rewriter.create<ConvertOp>(
4260 op.getLoc(), op.num(), result_type.getElementType());
4261 if (num > 1) {
4262 Value one = GetScalarConstOfType(result_type.getElementType(),
4263 op.getLoc(), 1, &rewriter);
4264 step_denominator = rewriter.create<chlo::BroadcastSubOp>(
4265 op.getLoc(), step_denominator.getType(), step_denominator, one,
4266 hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
4267 }
4268 auto step = rewriter.create<chlo::BroadcastDivOp>(
4269 op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
4270 hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
4271 step_denominator));
4272
4273 // Scale the iota and add the offset.
4274 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4275 rewriter.getI64IntegerAttr(0));
4276 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4277 op.getLoc(), result_type, iota, step,
4278 hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
4279 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4280 op, result_type, scaled, op.start(),
4281 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4282 return success();
4283 }
4284 };
4285
4286 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
4287 /// ReductionOp.
4288 /// `is_accumulation` controls whether higher precision is used for the actual
4289 /// reduction. This is set to false for ops like max where there are no
4290 /// precision concerns.
4291 //
4292 // The Derived class should have a static method to return the initial value to
4293 // use for reduction:
4294 // static Value GetInitialValue(Type reduce_element_type, Location loc,
4295 // PatternRewriter *rewriter);
4296 // The reduce_element_type is guaranteed to be a float, int, or complex type
4297 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
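//
// When `is_accumulation` is true, the reduction is computed in the accumulation
// type returned by GetAccumulationType and the result is converted back to the
// original element type afterwards.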
4298 template <typename Derived, typename OpTy, typename ReductionOp,
4299 bool is_accumulation = true>
4300 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
4301 using OpRewritePattern<OpTy>::OpRewritePattern;
4302
4303 LogicalResult matchAndRewrite(OpTy op,
4304 PatternRewriter &rewriter) const override {
4305 // TODO(b/141785544): Update this to not require ranked shapes.
4306 // Input shape needs to be ranked to convert negative indices in TensorFlow
4307 // to absolute indices required by HLO.
4308 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4309 if (!input_ty) return failure();
4310 ArrayRef<int64_t> input_shape = input_ty.getShape();
4311
4312 DenseIntElementsAttr dimensions;
4313 if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
4314 return failure();
4315
4316 // Build the final shape from input_shape and dimensions using a bitmap
4317 // to mark the reduced dimensions.
4318 SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
4319 SmallVector<int64_t, 4> xla_dimensions;
4320 for (const APInt &index_raw : dimensions.getValues<APInt>()) {
4321 int64_t index = index_raw.getSExtValue();
4322 int64_t rank = input_shape.size();
4323 if ((index < -rank || index >= rank)) return failure();
4324 index = (index + rank) % rank;
4325 reduced_dimensions_bitmap[index] = true;
4326 xla_dimensions.push_back(index);
4327 }
4328
4329 Location loc = op.getLoc();
4330 Type element_type = input_ty.getElementType();
4331
4332 // Only float, int, and complex types are currently supported.
4333 if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
4334 !element_type.isa<ComplexType>()) {
4335 return rewriter.notifyMatchFailure(
4336 op, "element type must be float, int, or complex type");
4337 }
4338
4339 // Convert to an accumulation type to not lose precision when doing
4340 // repeated arithmetic operations.
4341 Type reduce_element_type =
4342 is_accumulation ? GetAccumulationType(element_type) : element_type;
4343 auto casted_input =
4344 rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
4345
4346 // Each reduction op can have a different initial value.
4347 Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
4348
4349 auto reduction = rewriter.create<ReduceOp>(
4350 loc, casted_input.getResult(), init,
4351 GetI64ElementsAttr(xla_dimensions, &rewriter));
4352 BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
4353 &rewriter);
4354 Value result = reduction.getResult(0);
4355
4356 // The mean op needs to divide by the product of the reduced dimensions.
4357 if (std::is_same<OpTy, TF::MeanOp>::value) {
4358 Value in_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4359 Value divisor_count = rewriter.create<ConstantIndexOp>(loc, 1);
4360 for (size_t i = 0; i < input_shape.size(); ++i) {
4361 if (reduced_dimensions_bitmap[i]) {
4362 Value index = rewriter.create<ConstantIndexOp>(loc, i);
4363 auto dim = rewriter.create<tensor::ExtractOp>(loc, in_shape, index);
4364 divisor_count = rewriter.create<MulIOp>(loc, divisor_count, dim);
4365 }
4366 }
4367 // HLO ops are only defined on tensors, so we cast the divisor from
4368 // index -> i64 -> tensor<1xi64> -> tensor<i64> -> tensor<reduction type>
4369 auto divisor_casted = rewriter.create<IndexCastOp>(
4370 loc, rewriter.getI64Type(), divisor_count);
4371 auto divisor_tensor = rewriter.create<tensor::FromElementsOp>(
4372 loc, rewriter.getI64Type(), ValueRange{divisor_casted});
4373 auto divisor_reshaped = rewriter.create<mhlo::ReshapeOp>(
4374 loc, RankedTensorType::get({}, rewriter.getI64Type()),
4375 divisor_tensor);
4376 auto divisor = rewriter.create<ConvertOp>(
4377 loc, RankedTensorType::get({}, reduce_element_type),
4378 divisor_reshaped);
4379 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
4380 result = rewriter.create<chlo::BroadcastDivOp>(loc, result, divisor,
4381 broadcast_dims);
4382 }
4383
4384 result = rewriter.create<ConvertOp>(loc, result, element_type);
4385
4386 // Need to reshape back after the reduction if we're keeping the reduced
4387 // dimensions. Note that we do this through successive (nominally 1)
4388 // applications of the TF ExpandDims op vs a more labor intensive
4389 // reshape. Various code generation techniques benefit from the knowledge
4390 // that this is a restricted form of shape manipulation that is just adding
4391 // unit dims.
4392 if (op.keep_dims()) {
4393 for (auto dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) {
4394 if (dim_is_reduced.value()) {
4395 auto index_attr = GetI32ElementsAttr(
4396 {static_cast<int>(dim_is_reduced.index())}, &rewriter);
4397 Value index = rewriter.create<ConstantOp>(loc, index_attr);
4398 result = rewriter.create<TF::ExpandDimsOp>(loc, result, index);
4399 }
4400 }
4401 }
4402 rewriter.replaceOp(op, {result});
4403
4404 return success();
4405 }
4406 };
4407
4408 // Converts Mean op to HLO Reduce op.
4409 //
4410 // %init = constant dense<...> : tensor<T>
4411 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4412 // {dimensions = ...}
4413 // %divisor = constant dense<...> : tensor<T>
4414 // %mean = "mhlo.divide"(%sum, %divisor)
4415 class ConvertMeanOp
4416 : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
4417 public:
4418 using GenericConvertReductionOp::GenericConvertReductionOp;
4419 static Value GetInitialValue(Type reduce_element_type, Location loc,
4420 PatternRewriter *rewriter) {
4421 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4422 }
4423 };
4424
4425 // Converts Sum op to HLO Reduce op.
4426 //
4427 // %init = constant dense<...> : tensor<T>
4428 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4429 // {dimensions = ...}
4430 class ConvertSumOp
4431 : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
4432 public:
4433 using GenericConvertReductionOp::GenericConvertReductionOp;
4434
4435 static Value GetInitialValue(Type reduce_element_type, Location loc,
4436 PatternRewriter *rewriter) {
4437 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4438 }
4439 };
4440
4441 // Converts Max op to HLO Reduce op.
4442 //
4443 // %init = constant dense<...> : tensor<T>
4444 // %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
4445 // {dimensions = ...}
4446 class ConvertMaxOp
4447 : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
4448 /* is_accumulation= */ false> {
4449 public:
4450 using GenericConvertReductionOp::GenericConvertReductionOp;
4451
4452 static Value GetInitialValue(Type reduce_element_type, Location loc,
4453 PatternRewriter *rewriter) {
4454 return GetScalarLimitConstOfType(reduce_element_type, loc,
4455 hlo::kInfinityLowest, rewriter);
4456 }
4457 };
4458
4459 // Converts Min op to HLO Reduce op.
4460 //
4461 // %init = constant dense<...> : tensor<T>
4462 // %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
4463 // {dimensions = ...}
4464 class ConvertMinOp
4465 : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
4466 /* is_accumulation= */ false> {
4467 public:
4468 using GenericConvertReductionOp::GenericConvertReductionOp;
4469
4470 static Value GetInitialValue(Type reduce_element_type, Location loc,
4471 PatternRewriter *rewriter) {
4472 return GetScalarLimitConstOfType(reduce_element_type, loc,
4473 hlo::kInfinityMax, rewriter);
4474 }
4475 };
4476
4477 // Converts Prod op to HLO Reduce op.
4478 //
4479 // %init = constant dense<...> : tensor<T>
4480 // %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
4481 // {dimensions = ...}
4482 class ConvertProdOp
4483 : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
4484 public:
4485 using GenericConvertReductionOp::GenericConvertReductionOp;
4486
4487 static Value GetInitialValue(Type reduce_element_type, Location loc,
4488 PatternRewriter *rewriter) {
4489 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4490 }
4491 };
4492
4493 // Converts All op to HLO Reduce op.
4494 //
4495 // %init = constant dense<...> : tensor<T>
4496 // %all = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
4497 // {dimensions = ...}
4498 class ConvertAllOp
4499 : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
4500 public:
4501 using GenericConvertReductionOp::GenericConvertReductionOp;
4502 static Value GetInitialValue(Type reduce_element_type, Location loc,
4503 PatternRewriter *rewriter) {
4504 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4505 }
4506 };
4507
4508 // Converts Any op to HLO Reduce op.
4509 //
4510 // %init = constant dense<...> : tensor<T>
4511 // %any = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
4512 // {dimensions = ...}
4513 class ConvertAnyOp
4514 : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
4515 public:
4516 using GenericConvertReductionOp::GenericConvertReductionOp;
4517 static Value GetInitialValue(Type reduce_element_type, Location loc,
4518 PatternRewriter *rewriter) {
4519 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4520 }
4521 };
4522
4523 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
4524 // a reduction on the original input and the corresponding index. The reduction
4525 // sub-computation selects the max (or min) value and the index for the value.
4526 // Derived: the class that derives from this class (CRTP).
4527 // OpTy: TF::ArgMaxOp or TF::ArgMinOp.
4528 template <typename Derived, typename OpTy>
4529 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
4530 using OpRewritePattern<OpTy>::OpRewritePattern;
4531
4532 LogicalResult matchAndRewrite(OpTy op,
4533 PatternRewriter &rewriter) const override {
4534 RankedTensorType input_type =
4535 op.input().getType().template dyn_cast<RankedTensorType>();
4536 if (!input_type) {
4537 return failure();
4538 }
4539
4540 Type input_element_type = input_type.getElementType();
4541 // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
4542 // tf.ArgMax doesn't support complex data types, this check can be removed.
4543 if (!input_element_type.isSignlessIntOrFloat()) return failure();
4544
4545 Location loc = op.getLoc();
4546 Value init_value =
4547 Derived::GetInitialValue(input_element_type, loc, rewriter);
4548
4549 RankedTensorType output_type =
4550 op.output().getType().template dyn_cast<RankedTensorType>();
4551 if (!output_type) {
4552 return rewriter.notifyMatchFailure(op, "requires known rank");
4553 }
4554
4555 Type index_element_type = output_type.getElementType();
4556 Value index_init_value =
4557 GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
4558
4559 RankedTensorType index_type =
4560 RankedTensorType::get(input_type.getShape(), index_element_type);
4561
4562 llvm::Optional<int64_t> optional_axis =
4563 GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
4564 if (!optional_axis.hasValue())
4565 return rewriter.notifyMatchFailure(op, "required axis");
4566 int64_t axis = optional_axis.getValue();
4567
4568 IntegerAttr iota_dimension =
4569 IntegerAttr::get(rewriter.getIntegerType(64), axis);
4570 Value index_values =
4571 rewriter.create<IotaOp>(loc, index_type, iota_dimension);
4572
4573 std::vector<int64_t> dimensions = input_type.getShape();
4574 dimensions.erase(dimensions.begin() + axis);
4575 ArrayRef<int64_t> reduction_result_shape(dimensions);
4576
4577 Value operands[] = {op.input(), index_values};
4578 Value init_values[] = {init_value, index_init_value};
4579 DenseIntElementsAttr reduction_dimensions =
4580 GetI64ElementsAttr({axis}, &rewriter);
4581
4582 auto reduction = rewriter.create<ReduceOp>(
4583 loc, llvm::ArrayRef<Value>(operands),
4584 llvm::ArrayRef<Value>(init_values), reduction_dimensions);
4585 StringRef direction = Derived::GetDirection();
4586 BuildArgMinMaxReductionBody(input_element_type, index_element_type,
4587 direction, &reduction.body(), &rewriter);
4588
4589 rewriter.replaceOp(op, {reduction.getResult(1)});
4590 return success();
4591 }
4592 };
4593
4594 // Converts tensorflow ArgMax op to mhlo operations. The actual
4595 // implementation is in class ConvertArgMinMaxOp:
4596 //
4597 // %init_index = constant dense<...> : tensor<T>
4598 // %init = constant dense<...> : tensor<T>
4599 // %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4600 // %init_index) ["mhlo.arg_max"]
4601 class ConvertArgMaxOp
4602 : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
4603 public:
4604 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4605
4606 static Value GetInitialValue(Type reduce_element_type, Location loc,
4607 PatternRewriter &rewriter) {
4608 return GetScalarLimitConstOfType(reduce_element_type, loc,
4609 hlo::kInfinityLowest, &rewriter);
4610 }
4611
4612 static StringRef GetDirection() { return "GE"; }
4613 };
4614
4615 // Converts tensorflow ArgMin op to mhlo operations. The actual
4616 // implementation is in class ConvertArgMinMaxOp:
4617 //
4618 // %init_index = constant dense<...> : tensor<T>
4619 // %init = constant dense<...> : tensor<T>
4620 // %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4621 // %init_index) ["mhlo.arg_min"]
4622 class ConvertArgMinOp
4623 : public ConvertArgMinMaxOp<ConvertArgMinOp, TF::ArgMinOp> {
4624 public:
4625 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4626
4627 static Value GetInitialValue(Type reduce_element_type, Location loc,
4628 PatternRewriter &rewriter) {
4629 return GetScalarLimitConstOfType(reduce_element_type, loc,
4630 hlo::kInfinityMax, &rewriter);
4631 }
4632
4633 static StringRef GetDirection() { return "LE"; }
4634 };
4635
4636 // Converts TF TensorScatterUpdate/Min/Max/Add/Sub ops into an mhlo.scatter op
4637 // with the corresponding update computation:
4638 //
4639 // %result = "mhlo.scatter"(%tensor, %indices, %updates)
4640 // { dimensions = ... }
4641 //
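// The scatter update computation (assign, add, sub, min, or max) is supplied by
// the Derived class through its static BuildScatterBody method.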
4642 template <typename Derived, typename OpTy>
4643 class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
4644 public:
4645 using OpRewritePattern<OpTy>::OpRewritePattern;
4646
4647 LogicalResult matchAndRewrite(OpTy op,
4648 PatternRewriter &rewriter) const override {
4649 auto tensor_ty =
4650 op.tensor().getType().template dyn_cast<RankedTensorType>();
4651 auto indices_ty =
4652 op.indices().getType().template dyn_cast<RankedTensorType>();
4653 auto updates_ty =
4654 op.updates().getType().template dyn_cast<RankedTensorType>();
4655
4656 if (!tensor_ty || !indices_ty || !updates_ty) return failure();
4657 // The last dimension of the indices needs to be known at compile time for
4658 // computation of the 'update_window_dims' attribute in the dimensions
4659 // struct.
4660 int64_t num_index_dims = indices_ty.getShape().back();
4661 if (ShapedType::isDynamic(num_index_dims)) return failure();
4662
4663 int64_t tensor_rank = tensor_ty.getRank();
4664 int64_t indices_rank = indices_ty.getRank();
4665 int64_t updates_rank = updates_ty.getRank();
4666
4667 int64_t window_dims = tensor_rank - num_index_dims;
4668 auto dims_attr = ScatterDimensionNumbers::get(
4669 GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
4670 &rewriter),
4671 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
4672 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
4673 rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
4674
4675 Location loc = op.getLoc();
4676 auto scatter = rewriter.create<ScatterOp>(
4677 loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
4678 Derived::BuildScatterBody(tensor_ty.getElementType(),
4679 &scatter.update_computation(), loc, rewriter);
4680
4681 rewriter.replaceOp(op, scatter.getResult());
4682 return success();
4683 }
4684 };
4685
4686 class ConvertTensorScatterUpdateOp
4687 : public ConvertTensorScatterOp<ConvertTensorScatterUpdateOp,
4688 TF::TensorScatterUpdateOp> {
4689 public:
4690 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4691
4692 static void BuildScatterBody(Type element_type, Region *region, Location loc,
4693 OpBuilder &builder) {
4694 OpBuilder::InsertionGuard guard(builder);
4695 Block *block = builder.createBlock(region);
4696 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4697 block->addArguments({type, type});
4698 builder.create<ReturnOp>(loc, block->getArgument(1));
4699 }
4700 };
4701
4702 class ConvertTensorScatterAddOp
4703 : public ConvertTensorScatterOp<ConvertTensorScatterAddOp,
4704 TF::TensorScatterAddOp> {
4705 public:
4706 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4707
4708 static void BuildScatterBody(Type element_type, Region *region, Location loc,
4709 OpBuilder &builder) {
4710 OpBuilder::InsertionGuard guard(builder);
4711 Block *block = builder.createBlock(region);
4712 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4713 block->addArguments({type, type});
4714 auto add_op = builder.create<AddOp>(loc, block->getArgument(0),
4715 block->getArgument(1));
4716 builder.create<ReturnOp>(loc, add_op.getResult());
4717 }
4718 };
4719
4720 class ConvertTensorScatterSubOp
4721 : public ConvertTensorScatterOp<ConvertTensorScatterSubOp,
4722 TF::TensorScatterSubOp> {
4723 public:
4724 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4725
4726 static void BuildScatterBody(Type element_type, Region *region, Location loc,
4727 OpBuilder &builder) {
4728 OpBuilder::InsertionGuard guard(builder);
4729 Block *block = builder.createBlock(region);
4730 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4731 block->addArguments({type, type});
4732 auto sub_op = builder.create<SubOp>(loc, block->getArgument(0),
4733 block->getArgument(1));
4734 builder.create<ReturnOp>(loc, sub_op.getResult());
4735 }
4736 };
4737
4738 class ConvertTensorScatterMinOp
4739 : public ConvertTensorScatterOp<ConvertTensorScatterMinOp,
4740 TF::TensorScatterMinOp> {
4741 public:
4742 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4743
4744 static void BuildScatterBody(Type element_type, Region *region, Location loc,
4745 OpBuilder &builder) {
4746 OpBuilder::InsertionGuard guard(builder);
4747 Block *block = builder.createBlock(region);
4748 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4749 block->addArguments({type, type});
4750 auto min_op = builder.create<MinOp>(loc, block->getArgument(0),
4751 block->getArgument(1));
4752 builder.create<ReturnOp>(loc, min_op.getResult());
4753 }
4754 };
4755
4756 class ConvertTensorScatterMaxOp
4757 : public ConvertTensorScatterOp<ConvertTensorScatterMaxOp,
4758 TF::TensorScatterMaxOp> {
4759 public:
4760 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4761
4762 static void BuildScatterBody(Type element_type, Region *region, Location loc,
4763 OpBuilder &builder) {
4764 OpBuilder::InsertionGuard guard(builder);
4765 Block *block = builder.createBlock(region);
4766 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4767 block->addArguments({type, type});
4768 auto max_op = builder.create<MaxOp>(loc, block->getArgument(0),
4769 block->getArgument(1));
4770 builder.create<ReturnOp>(loc, max_op.getResult());
4771 }
4772 };
4773
4774 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
4775 // For shape [S1, S2] and multiples [M1, M2],
4776 // MS1 = M1 * S1; MS2 = M2 * S2
4777 //
4778 // %broadcast = mhlo.broadcast_in_dim(%input) {
4779 // broadcast_dimensions = [0, 2]
4780 // }
4781 // %result = "mhlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>)
4782 // -> tensor<MS1xMS2xf32>
4783 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
4784 public:
4785 using OpRewritePattern::OpRewritePattern;
4786
4787 LogicalResult matchAndRewrite(TF::TileOp op,
4788 PatternRewriter &rewriter) const override {
4789 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
4790 if (!input_ty || !input_ty.hasStaticShape()) return failure();
4791 ArrayRef<int64_t> input_shape = input_ty.getShape();
4792 Type element_type = input_ty.getElementType();
4793
4794 DenseIntElementsAttr multiples;
4795 if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
4796 multiples.getType().getRank() != 1)
4797 return failure();
4798
4799 const int64_t input_shape_size = input_shape.size();
4800 if (multiples.getNumElements() != input_shape_size) return failure();
4801
4802 SmallVector<int64_t, 8> broadcasted_shape;
4803 SmallVector<int64_t, 4> broadcast_dimensions;
4804 broadcasted_shape.reserve(input_shape.size() * 2);
4805 broadcast_dimensions.reserve(input_shape.size());
4806 for (auto multiple_and_input :
4807 llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4808 int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4809 int64_t input_size = std::get<1>(multiple_and_input);
4810
4811 if (multiple < 0) return failure();
4812
4813 // Line input up with the next dimension in broadcasted_shape
4814 // when broadcasting.
4815 int64_t broadcast_dim;
4816 int64_t output_size = input_size * multiple;
4817 if (input_size == 1 || multiple == 1) {
4818 // Special case for when normal broadcasting will just work.
4819 broadcast_dim = broadcasted_shape.size();
4820 broadcasted_shape.push_back(output_size);
4821 } else {
4822 // Tiling will happen for this dimension during the ReshapeOp below.
4823 broadcasted_shape.push_back(multiple);
4824 broadcast_dim = broadcasted_shape.size();
4825 broadcasted_shape.push_back(input_size);
4826 }
4827 broadcast_dimensions.push_back(broadcast_dim);
4828 }
4829 Location loc = op.getLoc();
4830 Type broadcasted_type =
4831 RankedTensorType::get(broadcasted_shape, element_type);
4832 Type output_type = op.getType();
4833
4834 Value result = rewriter.create<BroadcastInDimOp>(
4835 loc, broadcasted_type, op.input(),
4836 GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4837
4838 if (output_type != broadcasted_type) {
4839 result = rewriter.create<ReshapeOp>(loc, output_type, result);
4840 }
4841
4842 rewriter.replaceOp(op, {result});
4843
4844 return success();
4845 }
4846 };
4847
4848 // Converts the tf.Tile op into mhlo.dynamic_broadcast_in_dim + mhlo.dynamic_reshape.
4849 // TODO(disc): Recover the static special case's performance with folding and
4850 // canonicalization.
4851 class ConvertTileOpDynamic : public OpRewritePattern<TF::TileOp> {
4852 public:
4853 using OpRewritePattern::OpRewritePattern;
4854 // clang-format off
4855 // Converts Tile op to HLO DBroadcastInDim and DReshape ops.
4856 // For shape [S1, S2] and multiples [M1, M2],
4857 // MS1 = M1 * S1; MS2 = M2 * S2
4858 //
4859 // %out_dim_size = [S1, M1, S2, M2]
4860 // %broadcast_dimensions = [1, 3];
4861 // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %broadcast_dimensions);
4862 // %shape = [MS1, MS2]
4863 // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor<S1xM1xS2xM2xf32>) -> tensor<MS1xMS2xf32>
4864 // clang-format on
4865 LogicalResult matchAndRewrite(TF::TileOp op,
4866 PatternRewriter &rewriter) const final {
4867 Location loc = op.getLoc();
4868 Value input = op.input();
4869 Value multiples = op.multiples();
4870 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
4871 if (!input_ty) return failure();
4872 // TODO(disc): Remove this constraint once folding and canonicalization are
4873 // implemented.
4874 if (input_ty.hasStaticShape()) return failure();
4875
4876 Type element_type = input_ty.getElementType();
4877 int64_t input_rank = input_ty.getRank();
4878 SmallVector<Value, 4> input_shape_values;
4879 for (int64_t i = 0; i < input_rank; ++i) {
4880 auto dim_size = input_ty.getDimSize(i);
4881 if (dim_size == ShapedType::kDynamicSize) {
4882 input_shape_values.push_back(
4883 rewriter.create<tensor::DimOp>(loc, input, i));
4884 } else {
4885 input_shape_values.push_back(
4886 rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(dim_size)));
4887 }
4888 }
4889
4890 auto multiples_ty = multiples.getType().dyn_cast<RankedTensorType>();
4891 int64_t multiples_rank = multiples_ty.getRank();
4892 // rank of multiples input of tf.TileOp must be 1
4893 if (multiples_rank != 1) return failure();
4894 // multiples input of tf.TileOp must have a static shape
4895 if ((!multiples_ty.hasStaticShape()) ||
4896 (multiples_ty.getDimSize(0) != input_rank)) {
4897 return failure();
4898 }
4899 // %out_dim_size
4900 SmallVector<Value, 4> out_dim_size;
4901 out_dim_size.reserve(input_rank * 2);
4902 for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4903 Value index =
4904 rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(dim_idx));
4905 Value multiples_size =
4906 rewriter.create<tensor::ExtractOp>(loc, multiples, ValueRange{index});
4907 Value multiples_size_casted = rewriter.create<IndexCastOp>(
4908 loc, rewriter.getIndexType(), multiples_size);
4909 out_dim_size.push_back(multiples_size_casted);
4910 out_dim_size.push_back(input_shape_values[dim_idx]);
4911 }
4912 SmallVector<int64_t, 4> broadcast_dimensions;
4913 broadcast_dimensions.reserve(input_rank);
4914 for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4915 broadcast_dimensions.push_back(1 + 2 * dim_idx);
4916 }
4917 auto broadcast_dims_attr =
4918 GetI64ElementsAttr(broadcast_dimensions, &rewriter);
4919
4920 Value out_dim_size_tensor = rewriter.create<tensor::FromElementsOp>(
4921 loc, rewriter.getIndexType(), out_dim_size);
4922 SmallVector<int64_t, 4> broadcast_shape(input_rank * 2,
4923 ShapedType::kDynamicSize);
4924 RankedTensorType broadcast_type =
4925 RankedTensorType::get(broadcast_shape, element_type);
4926 Value broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
4927 loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr);
4928
4929 // %shape = [MS1, MS2]
4930 SmallVector<Value, 4> shape_values;
4931 shape_values.reserve(input_rank);
4932 for (int64_t i = 0; i < input_rank; ++i) {
4933 Value dim_size_value = rewriter.create<mlir::MulIOp>(
4934 loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]);
4935 shape_values.push_back(dim_size_value);
4936 }
4937 Value shape = rewriter.create<tensor::FromElementsOp>(
4938 loc, rewriter.getIndexType(), shape_values);
4939 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
4940 broadcast, shape);
4941 return success();
4942 }
4943 };
4944
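// Converts tf.MaxPoolGrad and tf.MaxPool3DGrad to mhlo.select_and_scatter: the
// select region compares values with "GE" to locate the maxima, and the scatter
// region adds the incoming gradient values at those locations.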
4945 template <typename OpTy, int num_dims>
4946 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4947 public:
4948 using OpRewritePattern<OpTy>::OpRewritePattern;
4949
4950 LogicalResult matchAndRewrite(OpTy op,
4951 PatternRewriter &rewriter) const override {
4952 Location loc = op.getLoc();
4953
4954 Type element_type =
4955 op.orig_input().getType().template cast<TensorType>().getElementType();
4956
4957 // Compute paddings using the original input shape, kernel shape, and strides.
4958 // The same padding computation as for ReduceWindow is used here because the
4959 // MaxPool op is lowered to a ReduceWindow op.
4960 auto input_ty =
4961 op.orig_input().getType().template dyn_cast<RankedTensorType>();
4962 if (!input_ty) return failure();
4963 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4964 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4965
4966 auto result = rewriter.create<SelectAndScatterOp>(
4967 loc, op.getType(), op.orig_input(), op.grad(),
4968 GetScalarConstOfType(element_type, loc, 0, &rewriter),
4969 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4970 paddings_attr);
4971
4972 BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4973 {
4974 OpBuilder::InsertionGuard guard(rewriter);
4975 Block *block = rewriter.createBlock(&result.select());
4976
4977 // Block arguments are scalars of the given element type.
4978 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4979 block->addArguments({type, type});
4980
4981 auto reducer = rewriter.create<CompareOp>(
4982 loc, block->getArgument(0), block->getArgument(1),
4983 StringAttr::get(rewriter.getContext(), "GE"));
4984 rewriter.create<ReturnOp>(loc, reducer.getResult());
4985 }
4986
4987 rewriter.replaceOp(op, {result});
4988
4989 return success();
4990 }
4991 };
4992
4993 using ConvertMaxPool2DGradOp =
4994 ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4995 using ConvertMaxPool3DGradOp =
4996 ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4997
4998 // Converts tf.Conv?DBackpropInputOp into:
4999 // %rev_filter = "mhlo.reverse"(%filter)
5000 // %result = "mhlo.convolution"(%out_backprop, %rev_filter)
5001 template <typename OpTy, int num_spatial_dims>
5002 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
5003 public:
5004 using OpRewritePattern<OpTy>::OpRewritePattern;
5005
5006 LogicalResult matchAndRewrite(OpTy op,
5007 PatternRewriter &rewriter) const override {
5008 // Unpack all of the attributes.
5009 tensorflow::TensorFormat data_format;
5010 if (!FormatFromString(op.data_format().str(), &data_format))
5011 return op.emitOpError("invalid data format");
5012
5013 tensorflow::Padding padding;
5014 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5015 return failure();
5016
5017 auto out_backprop_ty =
5018 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5019 auto filter_ty =
5020 op.filter().getType().template dyn_cast<RankedTensorType>();
5021
5022 for (RankedTensorType ty : {out_backprop_ty, filter_ty})
5023 if (!ty || !ty.hasStaticShape()) return failure();
5024
5025 DenseIntElementsAttr input_shape_attr;
5026 if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
5027 input_shape_attr.getType().getRank() != 1)
5028 return failure();
5029
5030 auto input_shape = input_shape_attr.getValues<int32_t>();
5031
5032 auto dilations_attr = GetI64ElementsAttr(op.dilations());
5033 std::vector<int> dilations{
5034 dilations_attr.template getValues<int64_t>().begin(),
5035 dilations_attr.template getValues<int64_t>().end()};
5036 auto strides_attr = GetI64ElementsAttr(op.strides());
5037 std::vector<tensorflow::int32> strides{
5038 strides_attr.template getValues<int64_t>().begin(),
5039 strides_attr.template getValues<int64_t>().end()};
5040
5041 std::vector<tensorflow::int64> explicit_paddings;
5042 if (padding == tensorflow::Padding::EXPLICIT) {
5043 // The EXPLICIT padding mode and the associated attribute are limited to
5044 // Conv2DBackpropInput, so fetch the attribute by identifier instead of using
5045 // the op.explicit_paddings() attribute getter.
5046 ArrayRef<Attribute> explicit_paddings_attr =
5047 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5048 explicit_paddings.reserve(explicit_paddings_attr.size());
5049 for (Attribute explicit_padding : explicit_paddings_attr)
5050 explicit_paddings.push_back(
5051 explicit_padding.cast<IntegerAttr>().getInt());
5052 }
5053
5054 constexpr int num_dims = num_spatial_dims + 2;
5055 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
5056
5057 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5058 tensorflow::ConvBackpropDimensions dims;
5059 if (!tensorflow::ConvBackpropComputeDimensionsV2(
5060 /*label=*/"", num_spatial_dims,
5061 ToTensorShape<int32_t, num_dims>(input_shape),
5062 ToTensorShape<int64_t, num_dims>(filter_shape),
5063 ToTensorShape<int64_t, num_dims>(out_backprop_ty.getShape()),
5064 dilations, strides, padding, explicit_paddings, data_format, &dims)
5065 .ok()) {
5066 return failure();
5067 }
5068
5069 // Compute ConvDimensionNumbers, dilation, and padding.
5070 SmallVector<int64_t, num_spatial_dims> spatial_dims;
5071 SmallVector<int64_t, num_spatial_dims> lhs_dilation;
5072 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5073 SmallVector<int64_t, num_spatial_dims * 2> paddings;
5074
5075 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5076 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
5077 spatial_dims.push_back(dim);
5078 const auto &spatial_dim_i = dims.spatial_dims[i];
5079 lhs_dilation.push_back(spatial_dim_i.stride);
5080 rhs_dilation.push_back(dilations[dim]);
5081 paddings.push_back(spatial_dim_i.pad_before);
5082 paddings.push_back(spatial_dim_i.pad_after);
5083 }
5084
5085 RankedTensorType paddings_ty = RankedTensorType::get(
5086 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5087 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5088
5089 auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
5090
5091 Value filter = op.filter();
5092
5093 const int feature_dim =
5094 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5095 const int64_t in_depth = *(input_shape.begin() + feature_dim);
5096 const int64_t filter_in_depth = filter_shape[num_spatial_dims];
5097 const int64_t feature_group_count = in_depth / filter_in_depth;
5098
5099 if (feature_group_count != 1) {
5100 // 1. Reshape filter from
5101 // [H, W, ..., filter_in_depth, out_depth] to
5102 // [H, W, ..., filter_in_depth, G, out_depth / G].
5103 auto new_shape = llvm::to_vector<6>(filter_shape);
5104 new_shape.back() = feature_group_count;
5105 new_shape.push_back(filter_shape.back() / feature_group_count);
5106 Type filter_element_ty = filter_ty.getElementType();
5107 auto ty = RankedTensorType::get(new_shape, filter_element_ty);
5108 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5109
5110 // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
5111 llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
5112 std::iota(perm.begin(), perm.end(), 0);
5113 std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
5114 std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
5115 ty = RankedTensorType::get(new_shape, filter_element_ty);
5116 filter = rewriter.create<TransposeOp>(
5117 op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
5118
5119 // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
5120 new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
5121 new_shape[num_spatial_dims + 1] = new_shape.back();
5122 new_shape.pop_back();
5123 ty = RankedTensorType::get(new_shape, filter_element_ty);
5124 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5125 }
5126
5127 auto kernel_spatial_dims_attr =
5128 GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
5129
5130 // Mirror the filter in the spatial dimensions.
5131 filter = rewriter.create<ReverseOp>(op.getLoc(), filter,
5132 kernel_spatial_dims_attr);
5133
5134 const int batch_dim =
5135 tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5136 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
5137 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
5138
5139 // activation gradients
5140 // = gradients (with padding and dilation) <conv> mirrored_weights
5141 Value result = rewriter.create<ConvOp>(
5142 op.getLoc(), op.getType(), op.out_backprop(), filter,
5143 /*window_strides=*/
5144 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5145 &rewriter),
5146 /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
5147 GetI64ElementsAttr(rhs_dilation, &rewriter),
5148 /*window_reversal=*/nullptr,
5149 ConvDimensionNumbers::get(
5150 /*input_batch_dimension=*/batch_dim_attr,
5151 /*input_feature_dimension=*/feature_dim_attr,
5152 /*input_spatial_dimensions=*/spatial_dims_attr,
5153 // TF filter shape is [ H, W, ..., inC, outC ]
5154 // Transpose the input and output features for computing the
5155 // gradient.
5156 /*kernel_input_feature_dimension=*/
5157 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
5158 /*kernel_output_feature_dimension=*/
5159 rewriter.getI64IntegerAttr(num_spatial_dims),
5160 /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
5161 /*output_batch_dimension=*/batch_dim_attr,
5162 /*output_feature_dimension=*/feature_dim_attr,
5163 /*output_spatial_dimensions=*/spatial_dims_attr,
5164 rewriter.getContext()),
5165 rewriter.getI64IntegerAttr(feature_group_count),
5166 /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
5167 /*precision_config=*/ArrayAttr());
5168
5169 rewriter.replaceOp(op, {result});
5170
5171 return success();
5172 }
5173 };
5174
5175 using ConvertConv2DBackpropInputOp =
5176 ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
5177 /*num_spatial_dims=*/2>;
5178 using ConvertConv3DBackpropInputOp =
5179 ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
5180 /*num_spatial_dims=*/3>;
5181
5182 // Converts tf.Conv?DBackpropFilterOp into:
5183 // %result = "mhlo.convolution"(%input, %out_backprop)
5184 template <typename OpTy, int num_spatial_dims>
5185 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
5186 public:
5187 using OpRewritePattern<OpTy>::OpRewritePattern;
5188
5189 LogicalResult matchAndRewrite(OpTy op,
5190 PatternRewriter &rewriter) const override {
5191 // Unpack all of the attributes.
5192 tensorflow::TensorFormat data_format;
5193 if (!FormatFromString(op.data_format().str(), &data_format))
5194 return op.emitOpError("invalid data format");
5195
5196 tensorflow::Padding padding;
5197 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5198 return failure();
5199
5200 auto out_backprop_ty =
5201 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5202 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
5203
5204 for (RankedTensorType ty : {out_backprop_ty, input_ty})
5205 if (!ty || !ty.hasStaticShape()) return failure();
5206
5207 ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
5208 ArrayRef<int64_t> input_shape = input_ty.getShape();
5209
5210 DenseIntElementsAttr filter_shape_attr;
5211 if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
5212 filter_shape_attr.getType().getRank() != 1)
5213 return failure();
5214
5215 auto dilations_attr = GetI64ElementsAttr(op.dilations());
5216 std::vector<int> dilations{
5217 dilations_attr.template getValues<int64_t>().begin(),
5218 dilations_attr.template getValues<int64_t>().end()};
5219 auto strides_attr = GetI64ElementsAttr(op.strides());
5220 std::vector<tensorflow::int32> strides{
5221 strides_attr.template getValues<int64_t>().begin(),
5222 strides_attr.template getValues<int64_t>().end()};
5223
5224 std::vector<tensorflow::int64> explicit_paddings;
5225 if (padding == tensorflow::Padding::EXPLICIT) {
5226 // The EXPLICIT padding mode and the associated attribute are limited to
5227 // Conv2DBackpropFilter, so fetch the attribute by identifier instead of using
5228 // the op.explicit_paddings() attribute getter.
5229 ArrayRef<Attribute> explicit_paddings_attr =
5230 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5231 explicit_paddings.reserve(explicit_paddings_attr.size());
5232 for (Attribute explicit_padding : explicit_paddings_attr)
5233 explicit_paddings.push_back(
5234 explicit_padding.cast<IntegerAttr>().getInt());
5235 }
5236
5237 constexpr int num_dims = num_spatial_dims + 2;
5238 auto filter_shape = filter_shape_attr.getValues<int32_t>();
5239
5240 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5241 tensorflow::ConvBackpropDimensions dims;
5242 if (!tensorflow::ConvBackpropComputeDimensionsV2(
5243 /*label=*/"", num_spatial_dims,
5244 ToTensorShape<int64_t, num_dims>(input_shape),
5245 ToTensorShape<int32_t, num_dims>(filter_shape),
5246 ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
5247 strides, padding, explicit_paddings, data_format, &dims)
5248 .ok()) {
5249 return failure();
5250 }
5251
5252 // The activations (inputs) form the LHS of the convolution.
5253 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
5254 // For the gradient computation, we need to:
5255 // 1. In the case of group convolution, move the num_groups dimension before
5256 // the batch dimension
5257 // 2. Swap the roles of the batch and feature dimensions.
5258 const int feature_dim =
5259 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5260 const int64_t in_depth = input_shape[feature_dim];
5261 const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
5262 const int64_t batch_group_count = in_depth / filter_in_depth;
5263
5264 // Compute ConvDimensionNumbers, dilation, and padding.
5265 SmallVector<int64_t, num_spatial_dims> spatial_dims;
5266 SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
5267 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5268 SmallVector<int64_t, num_spatial_dims * 2> paddings;
5269 SmallVector<int64_t, num_spatial_dims> window_strides;
5270
5271 // The filter gradients are computed by a convolution of the input
5272 // activations and the output gradients, with some appropriate padding.
5273 // See the comment at the top of conv_grad_ops.h for details.
5274
5275 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5276 const int64_t dim =
5277 tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
5278 kernel_spatial_dims.push_back(dim);
5279 // Besides padding the input, we will also expand output_rows to
5280 // expanded_out_rows = (output_rows - 1) * stride + 1
5281 // with zeros in between:
5282 //
5283 // a . . . b . . . c . . . d . . . e
5284 //
5285 // This is done by specifying the window dilation factors in the
5286 // convolution HLO below.
5287 const auto &spatial_dim_i = dims.spatial_dims[i];
5288 rhs_dilation.push_back(spatial_dim_i.stride);
5289 window_strides.push_back(dilations[dim]);
5290
5291 // We will also need to pad the input with zeros such that after the
5292 // convolution, we get the right size for the filter.
5293 // The padded_in_rows should be such that when we convolve this with the
5294 // expanded_out_rows as a filter, we should get filter_rows back.
5295
5296 const int64_t padded_in_size =
5297 spatial_dim_i.expanded_output_size +
5298 (spatial_dim_i.filter_size - 1) * dilations[dim];
5299
5300 // However it can be smaller than input_rows: in this
5301 // case it means some of the inputs are not used.
5302 //
5303 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
5304 //
5305 // INPUT = [ A B C ]
5306 //
5307 // FILTER = [ x y ]
5308 //
5309 // and the output will only have one column: a = A * x + B * y
5310 //
5311 // and input "C" is not used at all.
5312 //
5313 // We apply negative padding in this case.
5314 const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
5315
5316 // + For the EXPLICIT padding, we pad the top/left side with the explicit
5317 // padding and pad the bottom/right side with the remaining space.
5318 // + For the VALID padding, we don't pad anything on the top/left side
5319 // and pad the bottom/right side with the remaining space.
5320 // + For the SAME padding, we pad top/left side the same as bottom/right
5321 // side.
5322 //
5323 // In addition, if the padded input size is smaller than the input size,
5324 // we need to ignore some trailing elements of the input. We do this by
5325 // applying negative padding on the right/bottom.
5326 const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
5327 ? explicit_paddings[2 * dim]
5328 : padding == tensorflow::Padding::SAME
5329 ? std::max<int64_t>(pad_total / 2, 0)
5330 : 0;
5331 paddings.push_back(pad_before);
5332 paddings.push_back(pad_total - pad_before);
5333 }
5334
5335 RankedTensorType paddings_ty = RankedTensorType::get(
5336 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5337 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5338 auto kernel_spatial_dims_attr =
5339 GetI64ElementsAttr(kernel_spatial_dims, &rewriter);
5340
5341 const int batch_dim =
5342 tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5343 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
5344 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
5345
5346 Value result = rewriter.create<ConvOp>(
5347 op.getLoc(), op.getType(), op.input(), op.out_backprop(),
5348 /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
5349 /*padding=*/paddings_attr, /*lhs_dilation=*/
5350 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5351 &rewriter),
5352 GetI64ElementsAttr(rhs_dilation, &rewriter),
5353 /*window_reversal=*/nullptr,
5354 ConvDimensionNumbers::get(
5355 // Swap batch_dim and feature_dim in the activations.
5356 /*input_batch_dimension=*/feature_dim_attr,
5357 /*input_feature_dimension=*/batch_dim_attr,
5358 /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
5359 // The gradients become the RHS of the convolution.
5360 // The gradients have shape [batch, out_rows, out_cols, ...,
5361 // out_depth] where the batch becomes the input feature for the
5362 // convolution.
5363 /*kernel_input_feature_dimension=*/batch_dim_attr,
5364 /*kernel_output_feature_dimension=*/feature_dim_attr,
5365 /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
5366 /*output_batch_dimension=*/
5367 rewriter.getI64IntegerAttr(num_spatial_dims),
5368 /*output_feature_dimension=*/
5369 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
5370 /*output_spatial_dimensions=*/
5371 GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter),
5372 rewriter.getContext()),
5373 /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
5374 rewriter.getI64IntegerAttr(batch_group_count),
5375 /*precision_config=*/ArrayAttr());
5376
5377 rewriter.replaceOp(op, {result});
5378
5379 return success();
5380 }
5381 };
5382
5383 using ConvertConv2DBackpropFilterOp =
5384 ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
5385 /*num_spatial_dims=*/2>;
5386 using ConvertConv3DBackpropFilterOp =
5387 ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
5388 /*num_spatial_dims=*/3>;
5389
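// Converts tf.OneHot to static HLO ops. A sketch of the lowering implemented
// below: an mhlo.iota along the one-hot axis provides per-position indices,
// the input indices are broadcast (mhlo.broadcast_in_dim) into the output
// shape and compared for equality against the iota, and an mhlo.select picks
// between the broadcast on_value and off_value.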
5390 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
5391 public:
5392 using OpRewritePattern::OpRewritePattern;
5393
5394 LogicalResult matchAndRewrite(TF::OneHotOp op,
5395 PatternRewriter &rewriter) const override {
5396 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
5397 if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
5398 ArrayRef<int64_t> indices_shape = indices_ty.getShape();
5399 Type element_type = indices_ty.getElementType();
5400
5401 DenseIntElementsAttr depth_attr;
5402 if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
5403 return failure();
5404 }
5405
5406 int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
5407 int64_t axis = op.axis();
5408 if (axis == -1) axis = indices_shape.size();
5409
5410 llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
5411 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
5412 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
5413
5414 llvm::SmallVector<int64_t, 4> output_dims =
5415 llvm::to_vector<4>(indices_shape);
5416 output_dims.insert(output_dims.begin() + axis, depth);
5417
5418 Location loc = op.getLoc();
5419
5420 // The iota result is the effective output shape of the computation,
5421 // and indices must be broadcast into it. At this point, this computation
5422 // would need to be reworked quite a bit to support dynamic shapes, so
5423 // just using static broadcasting.
5424 auto index_type = RankedTensorType::get(output_dims, element_type);
5425 auto iota = rewriter.create<IotaOp>(
5426 loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
5427 auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
5428 loc, index_type, op.indices(),
5429 GetI64ElementsAttr(broadcast_dims, &rewriter));
5430
5431 Value compare = rewriter.create<mhlo::CompareOp>(
5432 loc, broadcast_indices, iota,
5433 StringAttr::get(rewriter.getContext(), "EQ"));
5434 Value on_value = rewriter.create<BroadcastOp>(
5435 loc, op.getType(), op.on_value(),
5436 GetI64ElementsAttr(output_dims, &rewriter));
5437 Value off_value = rewriter.create<BroadcastOp>(
5438 loc, op.getType(), op.off_value(),
5439 GetI64ElementsAttr(output_dims, &rewriter));
5440 Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
5441 on_value, off_value);
5442
5443 rewriter.replaceOp(op, {result});
5444
5445 return success();
5446 }
5447 };
5448
5449 // Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
5450 // get_tuple_element ops.
5451 //
5452 // All HLO infeed ops expect a HLO token type operand and produce a tuple
5453 // containing a token. This HLO token type is used to order multiple infeed
5454 // operations within a computation. The token type can come from other
5455 // infeed/outfeed/send/recv ops or can be generated using create_token op with
5456 // no operands. Here we emit a create_token op to generate the token type
5457 // operand of infeed.
5458 //
5459 // For example the following IR:
5460 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
5461 //
5462 // would be lowered to
5463 //
5464 // %token = "mhlo.create_token"() : () -> !mhlo.token
5465 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
5466 // (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
5467 // !mhlo.token>
5468 // %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0}
5469 // %0#0 = "mhlo.get_tuple_element"(%data) {index = 0}
5470 // %0#1 = "mhlo.get_tuple_element"(%data) {index = 1}
5471 //
5472 class ConvertInfeedDequeueTupleOp
5473 : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
5474 public:
5475 using OpRewritePattern::OpRewritePattern;
5476
5477 FailureOr<std::vector<int64_t>> GetTPUInfeedLayoutFromAPI(
5478 RankedTensorType t) const {
5479 // Call the TPU API to determine the right infeed layout. Note that
5480 // this can fail if we're not running on a TPU-enabled node.
5481 // TODO(kramm): Move this into a separate pass. See b/181724526
5482 xla::Shape old_shape = xla::TypeToShape(t);
5483 XLA_Shape old_shape_c = {};
5484 XLA_Shape new_shape_c = {};
5485 TfTpu_ExecutorApiFn *executor = tensorflow::tpu::ExecutorApiFn();
5486 if (!tensorflow::tpu::IsInitialized(executor)) {
5487 return failure();
5488 }
5489 ApiConverter::ToC(old_shape, &old_shape_c);
5490 executor->TpuTransferManager_GetInfeedLayoutFn(&old_shape_c, &new_shape_c);
5491 xla::Shape new_shape = ApiConverter::FromC(&new_shape_c);
5492 ApiConverter::Free(&old_shape_c);
5493 ApiConverter::Free(&new_shape_c);
5494
5495 xla::Layout layout = new_shape.layout();
5496 auto minor_to_major = layout.minor_to_major();
5497 return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
5498 }
5499
5500 FailureOr<Attribute> GetLayout(const Type &type,
5501 PatternRewriter &rewriter) const {
5502 auto i64_type = rewriter.getIntegerType(64);
5503 if (type.isa<TupleType>()) {
5504 TupleType tuple_type = type.dyn_cast<TupleType>();
5505 std::vector<mlir::Attribute> v;
5506 for (const mlir::Type &t : tuple_type.getTypes()) {
5507 auto layout = GetLayout(t, rewriter);
5508 if (failed(layout)) return failure();
5509 v.push_back(layout.getValue());
5510 }
5511 ArrayRef<Attribute> shape(v);
5512 return rewriter.getArrayAttr(shape);
5513 } else if (RankedTensorType t = type.dyn_cast<RankedTensorType>()) {
5514 if (!t.hasStaticShape()) return failure();
5515 auto layout = GetTPUInfeedLayoutFromAPI(t);
5516 std::vector<int64_t> minor_to_major;
5517 if (succeeded(layout)) {
5518 minor_to_major = layout.getValue();
5519 } else {
5520 /* If we're not running on a TPU node, we might not be able to
5521 * actually call the part of the TPU API that gives us layout.
5522 * This happens e.g. for unit tests. Below we just create a reasonable
5523 * layout. We sort by dimension size, which makes the layout agree with
5524 * the "correct" TPU layout in surprisingly many cases.
5525 * Note that the corresponding InfeedEnqueue op will be generated
5526 * through another path, and might still generate an (incompatible)
5527 * layout using the TPU API. Running legalize_tf.cc on non-TPU nodes
5528 * thus is a potential source of bugs.
5529 */
5530 minor_to_major.resize(t.getRank());
5531 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
5532 std::sort(minor_to_major.begin(), minor_to_major.end(),
5533 [=](int64_t a, int64_t b) {
5534 int da = t.getDimSize(a);
5535 int db = t.getDimSize(b);
5536 return da > db || (da == db && a > b);
5537 });
5538 }
5539 std::vector<Attribute> elements;
5540 for (int64_t i = 0; i < minor_to_major.size(); i++) {
5541 elements.push_back(
5542 rewriter.getIntegerAttr(i64_type, minor_to_major[i]));
5543 }
5544 return rewriter.getArrayAttr(elements);
5545 } else {
5546 return rewriter.getUnitAttr(); // e.g. tokens
5547 }
5548 }
5549
5550 LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
5551 PatternRewriter &rewriter) const override {
5552 std::vector<Type> result_types(op.outputs().size());
5553 for (auto idx_and_output : llvm::enumerate(op.outputs())) {
5554 result_types[idx_and_output.index()] = (idx_and_output.value().getType());
5555 }
5556 // Infeed takes a single token operand. Generate the token using
5557 // create_token op to pass to the infeed op.
5558 auto token = rewriter.create<CreateTokenOp>(
5559 op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
5560
5561 // Emit infeed op.
5562 // The result type of infeed is a tuple(tuple(result types), token type).
5563 auto data_tuple_type =
5564 mlir::TupleType::get(rewriter.getContext(), result_types);
5565 auto data_and_token_type = mlir::TupleType::get(
5566 rewriter.getContext(), {data_tuple_type, token.getType()});
5567
5568 auto layout = GetLayout(data_and_token_type, rewriter);
5569 if (failed(layout)) return failure();
5570
5571 auto data_and_token = rewriter.create<InfeedOp>(
5572 op.getLoc(), data_and_token_type, token,
5573 /*infeed_config=*/rewriter.getStringAttr(""),
5574 /*layout=*/layout.getValue().cast<ArrayAttr>());
5575
5576 if (op._XlaSharding().hasValue()) {
5577 // _XlaSharding attribute in TF is a serialized string of the OpSharding
5578 // proto, so convert to a text form here.
5579 ::xla::OpSharding sharding_proto;
5580 if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
5581 return failure();
5582
5583 // The token is a control signal and not real data, so arbitrarily assign
5584 // the token to device 0.
5585 if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
5586 *sharding_proto.add_tuple_shardings() =
5587 ::xla::sharding_builder::AssignDevice(0);
5588 data_and_token->setAttr(
5589 kShardingAttr,
5590 rewriter.getStringAttr(sharding_proto.SerializeAsString()));
5591 } else {
5592 data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
5593 }
5594 }
5595
5596 // The infeed instruction produces a tuple of the infeed data and a token
5597 // type. Emit get_tuple_element to get infeed data tuple.
5598 auto data_tuple = rewriter.create<GetTupleElementOp>(
5599 op.getLoc(), data_tuple_type, data_and_token,
5600 rewriter.getI32IntegerAttr(0));
5601
5602 // Emit get_tuple_element for each result.
5603 std::vector<Value> results;
5604 for (auto idx_and_type : llvm::enumerate(result_types)) {
5605 auto tuple_element = rewriter.create<GetTupleElementOp>(
5606 op.getLoc(), idx_and_type.value(), data_tuple,
5607 rewriter.getI32IntegerAttr(idx_and_type.index()));
5608 results.push_back(tuple_element);
5609 }
5610 rewriter.replaceOp(op, ValueRange(results));
5611 return success();
5612 }
5613 };
5614
5615 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
5616 // ops.
5617 //
5618 // XLA HLO outfeed op expects a token, which we generate by emitting an
5619 // create_token op.
5620 //
5621 // For example the following IR:
5622 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5623 // ()
5624 //
5625 // would be lowered to
5626 //
5627 // %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5628 // tuple<tensor<3xi32>, tensor<4xf32>>
5629 // %token = "mhlo.create_token"() : () -> !mhlo.token
5630 // %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
5631 // (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
5632 //
5633 class ConvertOutfeedEnqueueTupleOp
5634 : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
5635 public:
5636 using OpRewritePattern::OpRewritePattern;
5637
5638 LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
5639 PatternRewriter &rewriter) const override {
5640 auto token_type = mhlo::TokenType::get(rewriter.getContext());
5641 auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
5642 auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
5643 rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
5644 /*outfeed_config=*/rewriter.getStringAttr(""));
5645 rewriter.eraseOp(op);
5646 return success();
5647 }
5648 };
5649
5650 // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
5651 //
5652 // tf.TopKV2 sorts along last dimension of the input tensor and then returns
5653 // the top K components' values and indices. This is translated into a few
5654 // ops in XLA HLO: first generating an integer sequence for the indices,
5655 // then sorting both the original input tensor and the indices together, and
5656 // finally slicing out the top K components.
5657 //
5658 // For example, for the following IR:
5659 //
5660 // %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
5661 // %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
5662 // (tensor<16x8xf32>, tensor<16x8xi32>)
5663 //
5664 // We will get:
5665 //
5666 // %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
5667 // %2 = "mhlo.sort"(%input, %1) ( {
5668 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
5669 // %arg3: tensor<i32>, %arg4: tensor<i32>):
5670 // %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
5671 // "mhlo.return"(%7) : (tensor<i1>) -> ()
5672 // }) {dimension = 1 : i64, is_stable = true} : ...
5673 // %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
5674 // %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
5675 // %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
5676 // start_indices = dense<0> : tensor<2xi64>,
5677 // strides = dense<1> : tensor<2xi64>} :
5678 // (tensor<16x16xf32>) -> tensor<16x8xf32>
5679 // %6 = "mhlo.slice"(%4) ...
5680 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
5681 public:
5682 using OpRewritePattern::OpRewritePattern;
5683
5684 LogicalResult matchAndRewrite(TF::TopKV2Op op,
5685 PatternRewriter &rewriter) const override {
5686 // We can only match when the `k` operand is a constant scalar.
5687 DenseIntElementsAttr k_attr;
5688 if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
5689
5690 // The last dimension of the input tensor's shape should be known so we can
5691 // have clamped end_indices for slices.
5692 TensorType input_type = op.input().getType().cast<TensorType>();
5693 if (!input_type.hasRank()) return failure();
5694 int64_t input_rank = input_type.getRank();
5695 int64_t last_dim_index = input_rank - 1;
5696 int64_t last_dim_size = input_type.getDimSize(last_dim_index);
5697 if (last_dim_size == ShapedType::kDynamicSize) return failure();
5698
5699 // Create an Iota op for indices.
5700 auto i32_type = rewriter.getIntegerType(32);
5701 Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
5702 Value iota_op = rewriter.create<mhlo::IotaOp>(
5703 op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
5704
5705 // Create the sort op. It takes two inputs, one for the original input, the
5706 // other for the indices.
5707 auto sort_op = rewriter.create<mhlo::SortOp>(
5708 op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
5709 /*is_stable=*/true);
5710
5711 // Use TOTALORDER comparison type instead of the default comparison if the
5712 // element type is of type float.
5713 llvm::Optional<StringRef> compare_type;
5714 if (input_type.getElementType().isa<FloatType>())
5715 compare_type.emplace("TOTALORDER");
5716 BuildSortComparisonBody({input_type.getElementType(), i32_type},
5717 /*direction=*/"GT", compare_type,
5718 &sort_op.comparator(), &rewriter);
5719
5720 // Get the sorted input and index tuple element.
5721 auto tuple_first_element = sort_op.getResult(0);
5722 auto tuple_second_element = sort_op.getResult(1);
5723
5724 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
5725 auto end_indices = llvm::to_vector<4>(input_type.getShape());
5726 end_indices.back() =
5727 std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
5728 SmallVector<int64_t, 4> strides(input_rank, 1);
5729
5730 // Get the slice for the top K elements.
5731
5732 Value values = rewriter.create<mhlo::SliceOp>(
5733 op.getLoc(), tuple_first_element,
5734 GetI64ElementsAttr(begin_indices, &rewriter),
5735 GetI64ElementsAttr(end_indices, &rewriter),
5736 GetI64ElementsAttr(strides, &rewriter));
5737
5738 Value indices = rewriter.create<mhlo::SliceOp>(
5739 op.getLoc(), tuple_second_element,
5740 GetI64ElementsAttr(begin_indices, &rewriter),
5741 GetI64ElementsAttr(end_indices, &rewriter),
5742 GetI64ElementsAttr(strides, &rewriter));
5743
5744 rewriter.replaceOp(op, {values, indices});
5745 return success();
5746 }
5747 };
5748
5749 // Converts tf.Unpack to a series of XLA HLO slice ops.
5750 //
5751 // Each slice takes one element along the dimension to unpack and takes the full
5752 // range for all other dimensions. Each slice is then reshaped to drop the
5753 // dimension to unpack (which is always of size 1).
5754 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
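// Illustrative example (schematic, with an assumed static 2x4 input unpacked
// along axis 0):
//   %0:2 = "tf.Unpack"(%input) {axis = 0 : i64}
//            : (tensor<2x4xf32>) -> (tensor<4xf32>, tensor<4xf32>)
// becomes two mhlo.slice ops that each produce a tensor<1x4xf32>, followed by
// a tf.Squeeze on the unpacked axis to yield the tensor<4xf32> results.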
5755 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
5756 public:
5757 using OpRewritePattern::OpRewritePattern;
5758
5759 LogicalResult matchAndRewrite(TF::UnpackOp op,
5760 PatternRewriter &rewriter) const override {
5761 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5762 if (!value_type) return failure();
5763
5764 int64_t value_rank = value_type.getRank();
5765 int64_t axis = op.axis();
5766 if (axis < 0) axis += value_rank;
5767
5768 // Parameters for constructing each slice.
5769 SmallVector<int64_t, 4> begin_indices(value_rank, 0);
5770 auto end_indices = llvm::to_vector<4>(value_type.getShape());
5771 SmallVector<int64_t, 4> strides(value_rank, 1);
5772
5773 // All HLO slice+squeeze results used to replace the original tf.Unpack op.
5774 SmallVector<Value, 4> results;
5775 results.reserve(op.getNumResults());
5776
5777 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
5778 begin_indices[axis] = i;
5779 end_indices[axis] = i + 1;
5780
5781 auto slice_op = rewriter.create<mhlo::SliceOp>(
5782 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
5783 GetI64ElementsAttr(end_indices, &rewriter),
5784 GetI64ElementsAttr(strides, &rewriter));
5785 // Reshape to drop the axis dimension.
5786 auto result =
5787 rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
5788 rewriter.getI64ArrayAttr(op.axis()));
5789 results.push_back(result);
5790 }
5791
5792 rewriter.replaceOp(op, results);
5793 return success();
5794 }
5795 };
5796
5797 // Converts tf.Unpack to a series of XLA HLO Slice ops.
5798 // TODO(disc): Recover the static special case's performance once folding and
5799 // canonicalization are implemented.
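// In the dynamic case each output is carved out with mhlo.real_dynamic_slice,
// whose start/limit/stride operands are built via tensor.from_elements from
// per-dimension tensor.dim values (or constants), and the unpacked axis is
// then dropped with mhlo.dynamic_reshape.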
5800 class ConvertUnpackOpDynamic : public OpRewritePattern<TF::UnpackOp> {
5801 public:
5802 using OpRewritePattern::OpRewritePattern;
5803
5804 LogicalResult matchAndRewrite(TF::UnpackOp op,
5805 PatternRewriter &rewriter) const override {
5806 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5807 if (!value_type) return failure();
5808 // TODO(disc): Remove this constraint once folding and canonicalization are
5809 // implemented.
5810 if (value_type.hasStaticShape()) return failure();
5811
5812 int64_t value_rank = value_type.getRank();
5813 int64_t axis = op.axis();
5814 if (axis < 0) axis += value_rank;
5815 Location loc = op.getLoc();
5816
5817 auto shape_scalar_type = rewriter.getIntegerType(32);
5818 // Parameters for constructing each slice.
5819 SmallVector<Value, 4> begin_indices, end_indices, strides;
5820 begin_indices.reserve(value_rank);
5821 end_indices.reserve(value_rank);
5822 strides.reserve(value_rank);
5823 // final output shape
5824 SmallVector<Value, 4> shape_values;
5825 shape_values.reserve(value_rank - 1);
5826 // slice shape before reshape, should be like {?, 1, ?, ?} if axis = 1
5827 SmallVector<int64_t, 4> slice_shape(value_rank, ShapedType::kDynamicSize);
5828 for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) {
5829 int64_t dim_size = value_type.getDimSize(dim_idx);
5830 if (dim_size == ShapedType::kDynamicSize) {
5831 Value dim_i = rewriter.create<IndexCastOp>(
5832 loc, rewriter.create<tensor::DimOp>(loc, op.getOperand(), dim_idx),
5833 shape_scalar_type);
5834 end_indices.push_back(dim_i);
5835 if (dim_idx != axis) {
5836 shape_values.push_back(dim_i);
5837 }
5838 } else {
5839 Value dim_i = rewriter.create<ConstantOp>(
5840 loc, shape_scalar_type,
5841 rewriter.getIntegerAttr(shape_scalar_type, dim_size));
5842 end_indices.push_back(dim_i);
5843 if (dim_idx != axis) {
5844 shape_values.push_back(dim_i);
5845 slice_shape[dim_idx] = dim_size;
5846 } else {
5847 slice_shape[dim_idx] = 1;
5848 }
5849 }
5850 begin_indices.push_back(rewriter.create<ConstantIntOp>(loc, 0, 32));
5851 strides.push_back(rewriter.create<ConstantIntOp>(loc, 1, 32));
5852 }
5853
5854 SmallVector<Value, 4> results;
5855 results.reserve(op.getNumResults());
5856 for (int64_t i = 0; i < op.getNumResults(); ++i) {
5857 begin_indices[axis] = rewriter.create<ConstantIntOp>(loc, i, 32);
5858 end_indices[axis] = rewriter.create<ConstantIntOp>(loc, i + 1, 32);
5859 Value slice_op = rewriter.create<RealDynamicSliceOp>(
5860 loc, RankedTensorType::get(slice_shape, value_type.getElementType()),
5861 op.value(),
5862 rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5863 begin_indices),
5864 rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5865 end_indices),
5866 rewriter.create<tensor::FromElementsOp>(loc, rewriter.getI32Type(),
5867 strides));
5868 // Reshape to drop the axis dimension.
5869 Value new_shape = rewriter.create<tensor::FromElementsOp>(
5870 loc, rewriter.getI32Type(), shape_values);
5871 Value reshape_op = rewriter.create<DynamicReshapeOp>(loc, op.getType(i),
5872 slice_op, new_shape);
5873 results.push_back(reshape_op);
5874 }
5875
5876 rewriter.replaceOp(op, results);
5877 return success();
5878 }
5879 };
5880
5881 // Converts the tf.Sign op into mhlo.sign
5882 // TODO(disc): Recover the static special case's performance once folding and
5883 // canonicalization are implemented.
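// The pattern below emits mhlo.sign and then, via an x != x comparison,
// selects a dynamically broadcast zero for NaN inputs, so NaN maps to 0
// rather than whatever mhlo.sign would produce for it.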
5884 class ConvertSignOpDynamic : public OpRewritePattern<TF::SignOp> {
5885 public:
5886 using OpRewritePattern::OpRewritePattern;
5887 LogicalResult matchAndRewrite(TF::SignOp op,
5888 PatternRewriter &rewriter) const override {
5889 Location loc = op.getLoc();
5890 Value x = op.x();
5891 auto x_type = x.getType().dyn_cast<RankedTensorType>();
5892 if (!x_type) return failure();
5893 // TODO(disc): Remove this constraint once folding and canonicalization are
5894 // implemented.
5895 if (x_type.hasStaticShape()) return failure();
5896
5897 Value hlo_sign = rewriter.create<mhlo::SignOp>(loc, x);
5898 const StringAttr kNe = rewriter.getStringAttr(
5899 mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
5900 Value hlo_cmp = rewriter.create<mhlo::CompareOp>(loc, x, x, kNe);
5901
5902 auto zero =
5903 GetScalarConstOfType(x_type.getElementType(), loc, 0, &rewriter);
5904 Value shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), x);
5905
5906 auto broadcast_dims_attr =
5907 GetI64ElementsAttr(ArrayRef<int64_t>({}), &rewriter);
5908 Value broadcasted_zero = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
5909 loc, x_type, zero, shape_op, broadcast_dims_attr);
5910
5911 auto hlo_select = rewriter.create<mhlo::SelectOp>(
5912 loc, hlo_cmp, broadcasted_zero, hlo_sign);
5913
5914 rewriter.replaceOp(op, hlo_select.getResult());
5915 return success();
5916 }
5917 };
5918
5919 // Converts the tf.SigmoidGradOp
5920 // TODO(disc): Recover the static special case's performance once folding and
5921 // canonicalization are implemented.
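// The lowering follows the sigmoid derivative: given y = sigmoid(x) and the
// incoming gradient dy, it computes grad = dy * y * (1 - y), built from
// chlo.broadcast_multiply and chlo.broadcast_subtract so that implicit
// broadcasting between dy and y works for dynamic shapes.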
5922 class ConvertSigmoidGradOpDynamic : public OpRewritePattern<TF::SigmoidGradOp> {
5923 public:
5924 using OpRewritePattern::OpRewritePattern;
5925
5926 LogicalResult matchAndRewrite(TF::SigmoidGradOp op,
5927 PatternRewriter &rewriter) const override {
5928 Location loc = op.getLoc();
5929 Value y = op.y();
5930 Value dy = op.dy();
5931 auto tp_y = y.getType().dyn_cast<RankedTensorType>();
5932 auto tp_dy = dy.getType().dyn_cast<RankedTensorType>();
5933 if (!tp_y || !tp_dy) return failure();
5934
5935 // TODO(disc): Remove this constraint once folding and canonicalization are
5936 // implemented.
5937 if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure();
5938
5939 Attribute attr;
5940 Type elem_tp = tp_y.getElementType();
5941 if (elem_tp.isSignlessInteger()) {
5942 attr = rewriter.getIntegerAttr(elem_tp, 1);
5943 } else {
5944 assert(elem_tp.isa<FloatType>());
5945 attr = rewriter.getFloatAttr(elem_tp, 1);
5946 }
5947 Value one = rewriter.create<mhlo::ConstOp>(
5948 loc, DenseElementsAttr::get(RankedTensorType::get({}, elem_tp), attr));
5949
5950 auto v0 = rewriter.create<chlo::BroadcastMulOp>(
5951 loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y));
5952 auto v1 = rewriter.create<chlo::BroadcastSubOp>(
5953 loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y));
5954 auto result = rewriter.create<chlo::BroadcastMulOp>(
5955 loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1));
5956
5957 rewriter.replaceOp(op, result.getOperation()->getResults());
5958 return success();
5959 }
5960 };
5961
5962 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
5963 //
5964 // TF unsorted segment reduction op performs the following calculation:
5965 //
5966 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
5967 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
5968 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
5969 // Dm+1, ..., Dn]. Then
5970 // output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
5971 // <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
5972 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
5973 //
5974 // The op will be translated to XLA HLO scatter with the following parameters:
5975 // * Update window dims is [segment_id_rank, data_rank).
5976 // * Inserted window dims is {0}.
5977 // * Scatter dims to operand dims mapping is {0}.
5978 // * Index vector dim is segment_id_rank.
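// A small worked example (illustrative only): with data of shape [4, 2],
// segment_ids = [1, 0, 1, 1] and num_segments = 2, the output has shape
// [2, 2]; output row 1 is the reduction of data rows 0, 2 and 3, output row 0
// is data row 1, and any segment id that never occurs keeps the reduction's
// initial value.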
5979 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
5980 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
5981 using OpRewritePattern<OpTy>::OpRewritePattern;
5982
5983 LogicalResult matchAndRewrite(OpTy op,
5984 PatternRewriter &rewriter) const override {
5985 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
5986 if (!data_type) return failure();
5987 int64_t data_rank = data_type.getRank();
5988
5989 auto segment_ids_type =
5990 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
5991 if (!segment_ids_type) return failure();
5992 int64_t segment_ids_rank = segment_ids_type.getRank();
5993
5994 DenseIntElementsAttr num_segments_attr;
5995 if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
5996 return failure();
5997
5998 // The final shape for TF unsorted segment reduction op is [num_segments] +
5999 // data_shape[segment_ids_rank:].
6000 SmallVector<int64_t, 4> output_shape;
6001 output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
6002 auto suffix = data_type.getShape().drop_front(segment_ids_rank);
6003 output_shape.append(suffix.begin(), suffix.end());
6004 auto output_type =
6005 RankedTensorType::get(output_shape, data_type.getElementType());
6006
6007 // Broadcast the initial value for reduction. This will become the
6008 // 'operand' parameter of the final scatter op.
6009 Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
6010 op.getLoc(), &rewriter);
6011 auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
6012 op.getLoc(), output_type, init,
6013 GetI64ElementsAttr(output_shape, &rewriter));
6014
6015 // Parameters for the generated scatter op.
6016 SmallVector<int64_t, 1> inserted_window_dims(1, 0);
6017 SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
6018 int64_t index_vector_dim = segment_ids_rank;
6019
6020 // Put all parameters in a StructAttr.
6021 auto dims_attr = ScatterDimensionNumbers::get(
6022 GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
6023 GetI64ElementsAttr(inserted_window_dims, &rewriter),
6024 GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
6025 rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
6026
6027 auto scatter =
6028 rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
6029 op.segment_ids(), op.data(), dims_attr);
6030 BuildReduceBody<ReductionOp>(data_type.getElementType(),
6031 &scatter.update_computation(), &rewriter);
6032
6033 rewriter.replaceOp(op, scatter.getResult());
6034 return success();
6035 }
6036 };
6037
6038 class ConvertUnsortedSegmentMaxOp
6039 : public GenericConvertUnsortedSegmentReductionOp<
6040 ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
6041 public:
6042 using GenericConvertUnsortedSegmentReductionOp::
6043 GenericConvertUnsortedSegmentReductionOp;
6044
6045 static Value GetInitialValue(Type reduce_element_type, Location loc,
6046 PatternRewriter *rewriter) {
6047 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
6048 rewriter);
6049 }
6050 };
6051
6052 class ConvertUnsortedSegmentMinOp
6053 : public GenericConvertUnsortedSegmentReductionOp<
6054 ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
6055 public:
6056 using GenericConvertUnsortedSegmentReductionOp::
6057 GenericConvertUnsortedSegmentReductionOp;
6058
6059 static Value GetInitialValue(Type reduce_element_type, Location loc,
6060 PatternRewriter *rewriter) {
6061 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
6062 rewriter);
6063 }
6064 };
6065
6066 class ConvertUnsortedSegmentProdOp
6067 : public GenericConvertUnsortedSegmentReductionOp<
6068 ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
6069 public:
6070 using GenericConvertUnsortedSegmentReductionOp::
6071 GenericConvertUnsortedSegmentReductionOp;
6072
6073 static Value GetInitialValue(Type reduce_element_type, Location loc,
6074 PatternRewriter *rewriter) {
6075 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
6076 }
6077 };
6078
6079 class ConvertUnsortedSegmentSumOp
6080 : public GenericConvertUnsortedSegmentReductionOp<
6081 ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
6082 public:
6083 using GenericConvertUnsortedSegmentReductionOp::
6084 GenericConvertUnsortedSegmentReductionOp;
6085
6086 static Value GetInitialValue(Type reduce_element_type, Location loc,
6087 PatternRewriter *rewriter) {
6088 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
6089 }
6090 };
6091
6092 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
6093 //
6094 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
6095 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
6096 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
6097 // translated into an HLO while op to first emulate shuffling indices using
6098 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
6099 // with the shuffled indices.
6100 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
6101 public:
6102 using OpRewritePattern::OpRewritePattern;
6103
6104 LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
6105 PatternRewriter &rewriter) const override {
6106 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
6107 if (!input_type) return failure();
6108
6109 int64_t input_rank = input_type.getRank();
6110 int64_t first_dim_size = input_type.getDimSize(0);
6111 if (ShapedType::isDynamic(first_dim_size)) return failure();
6112
6113 // We are shuffling along the first dimension. If its size is <= 1, then
6114 // shuffling is a no-op.
6115 if (first_dim_size <= 1) {
6116 rewriter.replaceOp(op, op.value());
6117 return success();
6118 }
6119
6120 // For vectors, shuffle values by sorting instead of the obvious
6121 // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
6122 // but not easily parallelizable. For a sufficiently parallel architecture,
6123 // it is faster to sort many times than to Fisher-Yates shuffle once.
6124 if (input_rank == 1) {
6125 // Shuffle values by assigning each value a random key and sorting the
6126 // keys. Keys can collide causing detectable patterns in the shuffled
6127 // output. Collisions translate into more ascending sub-sequences in the
6128 // shuffled output than would be expected by chance. To avoid collisions,
6129 // the number of possible key values must be sufficiently large.
6130
6131 // How are more than 2^32 keys created? In each loop iteration, the
6132 // algorithm sorts by random keys. Conceptually, the earlier iterations
6133 // are sorting on the lower-order bits of larger keys that are never
6134 // actually assembled.
6135
6136 // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
6137 // the number of possible keys and n is the number of values. If d = n^2,
6138 // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
6139 // as n goes to infinity is zero.
6140
6141 // This implementation ensures that the key-space is greater than or equal
6142 // to the cube of the number of values. The risk of collisions can be
6143 // further reduced by increasing Exponent at the expense of
6144 // performance.
6145
6146 // For Exponent = 2, the expected number of collisions per shuffle is
6147 // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
6148 // about 1/2.
6149
6150 // For Exponent = 3, the expected number of collisions per shuffle is
6151 // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
6152 // about 1/3255.
6153
6154 // For Exponent = 4, the expected number of collisions per shuffle is
6155 // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
6156 // about 1/132622.
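// As a rough sanity check of these figures (an approximation, not an exact
// derivation): expanding (1 - 1/d)^n to second order gives an expected
// collision count of about n^2 / (2d), so d = n^2 yields ~1/2 and d = n^3
// yields ~1/(2n); for n = 1625 that is ~1/3250, consistent with the 1/3255
// quoted above.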
6157 constexpr int exponent = 3;
6158 int64_t num_elements = input_type.getNumElements();
6159 uint32_t u32_max = std::numeric_limits<uint32_t>::max();
6160 int rounds =
6161 std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
6162
6163 Value current = op.value();
6164 for (int i = 0; i < rounds; ++i) {
6165 auto keys =
6166 CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
6167 /*upper_limit=*/u32_max, &rewriter);
6168 auto sorted = rewriter.create<mhlo::SortOp>(
6169 op.getLoc(), llvm::ArrayRef<Value>{keys, current});
6170 auto i32_type = rewriter.getIntegerType(32);
6171 BuildSortComparisonBody({i32_type, input_type.getElementType()},
6172 /*direction=*/"LT", llvm::None,
6173 &sorted.comparator(), &rewriter);
6174 current = sorted.getResult(1);
6175 }
6176 rewriter.replaceOp(op, current);
6177 return success();
6178 }
6179
6180 // The Fisher-Yates algorithm.
6181
6182 // Generate range(n) as the initial value for the indices to be swapped.
6183 auto indices_type =
6184 RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
6185 Value indices = rewriter.create<mhlo::IotaOp>(
6186 op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
6187
6188 // Generate random numbers to be used as swaps for the indices.
6189 Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
6190 first_dim_size, &rewriter);
6191
6192 // While loop body to perform index swaps.
6193 auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
6194 SmallVectorImpl<Value> *new_values,
6195 OpBuilder *builder) {
6196 Value swaps = old_values[0];
6197 Value indices = old_values[1];
6198
6199 auto vec1_i32_type =
6200 RankedTensorType::get({1}, builder->getIntegerType(32));
6201 auto scalar_i32_type =
6202 RankedTensorType::get({}, builder->getIntegerType(32));
6203 auto scalar_i64_type =
6204 RankedTensorType::get({}, builder->getIntegerType(64));
6205
6206 auto scalar_one =
6207 DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
6208
6209 // We need to swap the indices[i] with indices[swaps[i]]. First get
6210 // these index values.
6211 Value source_index = builder->create<mhlo::DynamicSliceOp>(
6212 loc, vec1_i32_type, indices, i, scalar_one);
6213 Value swap_index = builder->create<mhlo::ReshapeOp>(
6214 loc, scalar_i32_type,
6215 builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
6216 scalar_one));
6217 Value target_index = builder->create<mhlo::DynamicSliceOp>(
6218 loc, vec1_i32_type, indices, swap_index, scalar_one);
6219
6220 // Then perform the swap.
6221 // indices[i] <- indices[swaps[i]]
6222 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
6223 loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
6224 // indices[swaps[i]] <- indices[i]
6225 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
6226 loc, indices.getType(), indices, source_index,
6227 llvm::makeArrayRef(swap_index));
6228
6229 // Update new values.
6230 new_values->assign({swaps, indices});
6231 };
6232
6233 // Create a while op to swap indices.
6234 SmallVector<Value, 2> while_output;
6235 CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
6236 &while_output, &rewriter);
6237 Value swapped_indices = while_output[1];
6238
6239 // Gather the data using the swapped indices as the shuffled order.
6240 ArrayRef<int64_t> input_shape = input_type.getShape();
6241 SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
6242 slice_sizes[0] = 1;
6243 auto dims_attr = GatherDimensionNumbers::get(
6244 /*offset_dims=*/GetI64ElementsAttrForSeq(1, input_rank, &rewriter),
6245 /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
6246 /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
6247 /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
6248 rewriter.getContext());
6249 rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
6250 op, op.getType(), op.value(), swapped_indices, dims_attr,
6251 GetI64ElementsAttr(slice_sizes, &rewriter));
6252
6253 return success();
6254 }
6255 };
6256
6257 // Converts an XlaSharding op to an XLA HLO custom call with sharding attributes.
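// Concretely, the pattern below emits an mhlo.custom_call with call target
// "Sharding" and copies the op's _XlaSharding attribute onto it under
// kShardingAttr.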
6258 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
6259 public:
6260 using OpRewritePattern::OpRewritePattern;
6261
6262 LogicalResult matchAndRewrite(TF::XlaShardingOp op,
6263 PatternRewriter &rewriter) const override {
6264 // TODO(b/148313088): define sharding attribute struct in MLIR instead of
6265 // using a string.
6266 if (!op._XlaSharding().hasValue()) return failure();
6267
6268 auto custom_call = rewriter.create<mhlo::CustomCallOp>(
6269 op.getLoc(), op.getType(), op.input(),
6270 /*call_target_name=*/rewriter.getStringAttr("Sharding"),
6271 /*has_side_effect=*/rewriter.getBoolAttr(false),
6272 /*backend_config=*/rewriter.getStringAttr(""),
6273 /*api_version=*/
6274 mhlo::CustomCallApiVersionAttr::get(
6275 rewriter.getContext(),
6276 mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL));
6277 custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
6278 rewriter.replaceOp(op, custom_call.getResult(0));
6279
6280 return success();
6281 }
6282 };
6283
6284 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
6285 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
6286 public:
6287 using OpRewritePattern::OpRewritePattern;
6288
6289 LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
6290 PatternRewriter &rewriter) const override {
6291 auto input = op.x();
6292 auto indices = op.i();
6293 auto updates = op.v();
6294
6295 // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
6296 // on the contents of `x`.
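// Worked example (added for illustration; shapes are hypothetical): for x of
// shape [4, 3], i == [2, 0] and v of shape [2, 3], the code below unpacks `i`
// into the scalars 2 and 0, splits `v` into two [1, 3] slices, and emits two
// DynamicUpdateSlice ops that write v[0] into row 2 and v[1] into row 0 of x.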
6297 auto input_type = input.getType().cast<ShapedType>();
6298 auto updates_type = updates.getType().cast<ShapedType>();
6299 auto indices_type = indices.getType().cast<ShapedType>();
6300 if (!indices_type.hasStaticShape()) return failure();
6301
6302 if (indices_type.getRank() != 1) return failure();
6303
6304 SmallVector<Type, 4> unpacked_indices_type(
6305 indices_type.getDimSize(0),
6306 RankedTensorType::get({}, indices_type.getElementType()));
6307 // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
6308 // required to have matching types. This rewrite rule creates
6309 // DynamicUpdateSlice ops where the first "start index" is always i32 and
6310 // subsequent ones are constructed based on zero_attr. Thus the type
6311 // for zero_attr needs to be i32 as well.
6312 auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
6313 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6314 op.getLoc(), unpacked_indices_type, indices, zero_attr);
6315
6316 SmallVector<int64_t, 4> split_updates_shape;
6317 split_updates_shape.append(updates_type.getShape().begin(),
6318 updates_type.getShape().end());
6319 split_updates_shape.front() = 1;
6320 SmallVector<Type, 4> split_updates_type;
6321 split_updates_type.resize(
6322 updates_type.getShape().front(),
6323 RankedTensorType::get(split_updates_shape,
6324 updates_type.getElementType()));
6325
6326 auto cst =
6327 rewriter.create<mhlo::ConstOp>(op.getLoc(), zero_attr).getResult();
6328 auto split_updates = rewriter.create<TF::SplitOp>(
6329 op.getLoc(), split_updates_type, cst, updates);
6330
6331 SmallVector<Value, 6> input_indices;
6332 input_indices.resize(input_type.getRank(), cst);
6333
6334 SmallVector<int64_t, 6> starts(updates_type.getRank(), 0);
6335 SmallVector<int64_t, 6> strides(updates_type.getRank(), 1);
6336 SmallVector<int64_t, 6> limits(updates_type.getShape().begin(),
6337 updates_type.getShape().end());
6338
6339 for (auto pair :
6340 llvm::zip(unpacked_indices.output(), split_updates.output())) {
6341 input_indices.front() = std::get<0>(pair);
6342 input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
6343 op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
6344 }
6345
6346 rewriter.replaceOp(op, input);
6347 return success();
6348 }
6349 };
6350
6351 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
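// For example (added for illustration): for a rank-3 input, `indices` is a
// 1-D tensor with 3 elements; it is unpacked below into three scalar start
// indices that are fed directly to mhlo::DynamicUpdateSliceOp.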
6352 class ConvertXlaDynamicUpdateSliceOp
6353 : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
6354 public:
6355 using OpRewritePattern::OpRewritePattern;
6356
6357 LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
6358 PatternRewriter &rewriter) const override {
6359 auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
6360 if (!indices_type || !indices_type.hasStaticShape() ||
6361 indices_type.getShape().size() != 1)
6362 return failure();
6363
6364 SmallVector<Type, 4> unpacked_indices_type(
6365 indices_type.getDimSize(0),
6366 RankedTensorType::get({}, indices_type.getElementType()));
6367 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6368 op.getLoc(), unpacked_indices_type, op.indices(),
6369 IntegerAttr::get(rewriter.getIntegerType(64), 0));
6370 rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
6371 op, op.getType(), op.input(), op.update(), unpacked_indices.output());
6372 return success();
6373 }
6374 };
6375
6376 // Converts a TF XlaAllReduce op to AllReduce HLO.
6377 class ConvertXlaAllReduceOp : public OpRewritePattern<TF::XlaAllReduceOp> {
6378 using OpRewritePattern::OpRewritePattern;
6379
6380 LogicalResult matchAndRewrite(TF::XlaAllReduceOp op,
6381 PatternRewriter &rewriter) const override {
6382 DenseIntElementsAttr group_assignment;
6383 if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
6384 return failure();
6385 auto replica_groups =
6386 hlo::ConvertElementsAttr(group_assignment, rewriter.getIntegerType(64))
6387 .cast<DenseIntElementsAttr>();
6388 if (replica_groups.getType().getRank() != 2) return failure();
6389
6390 Location loc = op.getLoc();
6391 Type element_type = getElementTypeOrSelf(op.input().getType());
6392
6393 auto all_reduce = rewriter.create<AllReduceOp>(
6394 loc, op.getType(), op.input(), replica_groups, ChannelHandle());
6395 StringRef reduce_op = op.reduce_op();
6396 if (reduce_op == "Add") {
6397 BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
6398 &rewriter);
6399 } else if (reduce_op == "Mul") {
6400 BuildReduceBody<MulOp>(element_type, &all_reduce.computation(),
6401 &rewriter);
6402 } else if (reduce_op == "Min") {
6403 BuildReduceBody<MinOp>(element_type, &all_reduce.computation(),
6404 &rewriter);
6405 } else if (reduce_op == "Max") {
6406 BuildReduceBody<MaxOp>(element_type, &all_reduce.computation(),
6407 &rewriter);
6408 } else {
6409 // For mean, add replicas in the same group. Then divide the sum by the
6410 // number of replicas in each group below.
6411 assert(reduce_op == "Mean");
6412 BuildReduceBody<AddOp>(element_type, &all_reduce.computation(),
6413 &rewriter);
6414 }
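// Worked example (added for illustration; values are arbitrary): with
// group_assignment == [[0, 1, 2], [3, 4, 5]], replica_groups has shape
// [2, 3], so for reduce_op == "Mean" the summed result is divided below by
// the group size 3.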
6415 Value result = all_reduce.getResult();
6416
6417 // For mean, divide the merge result by group size.
6418 if (reduce_op == "Mean") {
6419 int64_t replica_group_size = replica_groups.getType().getDimSize(1);
6420 auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
6421 &rewriter);
6422 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
6423 result = rewriter.create<chlo::BroadcastDivOp>(
6424 loc, result, divisor.getResult(), broadcast_dims);
6425 }
6426
6427 rewriter.replaceOp(op, {result});
6428 return success();
6429 }
6430 };
6431
6432 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
6433 // semantics for static and dynamic cases.
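// Example (added for illustration; shapes are hypothetical): if `t` has shape
// [2, 3] and the clip bounds are scalars, both bounds are first broadcast to
// [2, 3] using the runtime shape of `t`, and the result is
// mhlo.clamp(min, t, max).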
6434 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
6435 public:
6436 using OpRewritePattern::OpRewritePattern;
6437
6438 LogicalResult matchAndRewrite(TF::ClipByValueOp op,
6439 PatternRewriter &rewriter) const override {
6440 Value input = op.t();
6441 Value min = op.clip_value_min();
6442 Value max = op.clip_value_max();
6443
6444 auto input_ty = input.getType().cast<ShapedType>();
6445 auto min_ty = min.getType().cast<ShapedType>();
6446 auto max_ty = max.getType().cast<ShapedType>();
6447
6448 if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
6449 return failure();
6450 }
6451
6452 auto shape = rewriter.create<TF::ShapeOp>(
6453 op.getLoc(),
6454 RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
6455 input);
6456
6457 if (min_ty != input_ty) {
6458 min =
6459 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
6460 }
6461
6462 if (max_ty != input_ty) {
6463 max =
6464 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
6465 }
6466
6467 rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
6468 return success();
6469 }
6470 };
6471
6472 // Converts ConstOp to XLA's constant operation and introduces a tensor cast if
6473 // needed.
6474 class ConvertConstOp : public OpRewritePattern<TF::ConstOp> {
6475 public:
6476 using OpRewritePattern::OpRewritePattern;
6477
6478 LogicalResult matchAndRewrite(TF::ConstOp op,
6479 PatternRewriter &rewriter) const override {
6480 // Convert only for valid HLO tensors.
6481 auto ty = op.getType().dyn_cast<TensorType>();
6482 if (!ty || !ty.getElementType().isa<FloatType, IntegerType, ComplexType>())
6483 return failure();
6484
6485 Location loc = op.getLoc();
6486 Value result = rewriter.create<mhlo::ConstOp>(loc, op.value());
6487 if (result.getType() != op.getType())
6488 result = rewriter.create<tensor::CastOp>(loc, op.getType(), result);
6489 rewriter.replaceOp(op, result);
6490 return success();
6491 }
6492 };
6493
6494 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
6495 // setting appropriate window dimensions, with the given aggregation op as the
6496 // reduction function. The input tensor needs to have a static shape, and 'axis'
6497 // must be const. The TableGen pattern is not used for this rewrite because it
6498 // involves regions.
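// Worked example (added for illustration): a 1-D Cumsum over [a, b, c, d]
// becomes a ReduceWindow with window_dims == [4], window_strides == [1] and a
// low padding of 3 on that axis, so the windows ending at each position yield
// the inclusive prefix sums [a, a+b, a+b+c, a+b+c+d]; Cumprod is the same
// with a multiply reduction and an init value of 1.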
6499 template <typename OpT, typename AggregationOp>
6500 class ConvertCumOp : public OpRewritePattern<OpT> {
6501 using OpRewritePattern<OpT>::OpRewritePattern;
6502
6503 LogicalResult matchAndRewrite(OpT op,
6504 PatternRewriter &rewriter) const override {
6505 auto input = op.x();
6506 auto input_type = input.getType().template dyn_cast<ShapedType>();
6507 if (!input_type || !input_type.hasStaticShape()) {
6508 return failure();
6509 }
6510
6511 ArrayRef<int64_t> input_shape = input_type.getShape();
6512 int64_t rank = input_shape.size();
6513
6514 // We can only match when the axis is a constant scalar.
6515 DenseIntElementsAttr axis_attr;
6516 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
6517 return failure();
6518 }
6519
6520 // Get the dimension to apply the reduction on, and offset properly if it is
6521 // negative.
6522 int64_t axis = (*axis_attr.begin()).getSExtValue();
6523 if (axis < 0) {
6524 axis += rank;
6525 }
6526
6527 // If we're supposed to sum things up in the reverse direction, we reverse
6528 // the input and then later reverse the output.
6529 if (op.reverse()) {
6530 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6531 input = rewriter.create<ReverseOp>(
6532 op.getLoc(), op.getType(), input,
6533 GetI64ElementsAttr(dims_to_reverse, &rewriter));
6534 }
6535
6536 // Convert if we need to enlarge the element type's bitwidth to avoid
6537 // precision loss.
6538 Type input_element_type = input_type.getElementType();
6539
6540 // TODO(hinsu): Handle complex element types.
6541 if (!input_element_type.isIntOrFloat()) return failure();
6542
6543 Type sum_element_type = GetSumAccumulationType(input_element_type);
6544 input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
6545
6546 SmallVector<int64_t, 4> window_dims(rank, 1);
6547 SmallVector<int64_t, 4> window_strides(rank, 1);
6548 window_dims[axis] = input_shape[axis];
6549
6550 SmallVector<int64_t, 8> paddings(rank * 2, 0);
6551 paddings[axis * 2] =
6552 std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
6553 auto paddings_attr = DenseIntElementsAttr::get(
6554 RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
6555 paddings);
6556
6557 int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
6558 Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
6559 &rewriter);
6560
6561 auto reduce = rewriter.create<ReduceWindowOp>(
6562 op.getLoc(), input_type, input, init,
6563 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
6564 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
6565 /*base_dilations=*/DenseIntElementsAttr(),
6566 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
6567 BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
6568 Value result = reduce.getResult(0);
6569
6570 if (op.exclusive()) {
6571 // In "exclusive" mode, the output starts with the "init" value (0 for
6572 // Cumsum, 1 for Cumprod). There is no way to express that as a
6573 // ReduceWindowOp, so run the normal operation, and then use a PadOp to add
6574 // the init "column" on the left and cut away the last column on the right.
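// For example (added for illustration): the exclusive cumsum of [a, b, c] is
// obtained from the inclusive result [a, a+b, a+b+c] by padding one init
// value on the left and trimming the last element via the -1 high padding,
// giving [0, a, a+b].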
6575 llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
6576 llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
6577 llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
6578 low_padding[axis] = 1;
6579 high_padding[axis] = -1;
6580 result = rewriter.create<PadOp>(
6581 op.getLoc(), op.getType(), result, init,
6582 GetI64ElementsAttr(low_padding, &rewriter),
6583 GetI64ElementsAttr(high_padding, &rewriter),
6584 GetI64ElementsAttr(interior_padding, &rewriter));
6585 }
6586
6587 // Convert back if we enlarged the element type's bitwidth.
6588 result =
6589 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
6590
6591 if (op.reverse()) {
6592 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6593 result = rewriter.create<ReverseOp>(
6594 op.getLoc(), op.getType(), result,
6595 GetI64ElementsAttr(dims_to_reverse, &rewriter));
6596 }
6597
6598 rewriter.replaceOp(op, result);
6599 return success();
6600 }
6601 };
6602
6603 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
6604 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
6605
6606 // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
6607 // dialect lowerings. This involves extracting the shape type, extracting and
6608 // converting each dimension to a known integer type, and repacking into a final
6609 // tensor.
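// Example (added for illustration; the type is hypothetical): for an input of
// type tensor<?x3xf32>, shape.shape_of produces a tensor<2xindex> holding
// [dim0, 3], which is then index_cast to the op's result element type.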
6610 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
6611 public:
6612 using OpRewritePattern::OpRewritePattern;
6613
6614 LogicalResult matchAndRewrite(TF::ShapeOp op,
6615 PatternRewriter &rewriter) const override {
6616 Value input = op.input();
6617
6618 auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
6619 if (!result_ty) {
6620 return failure();
6621 }
6622
6623 auto index_tensor =
6624 RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
6625 auto shape_op =
6626 rewriter.create<shape::ShapeOfOp>(op.getLoc(), index_tensor, input);
6627 rewriter.replaceOpWithNewOp<IndexCastOp>(op, shape_op, result_ty);
6628 return success();
6629 }
6630 };
6631
6632 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
6633 public:
6634 using OpRewritePattern::OpRewritePattern;
6635
6636 LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
6637 PatternRewriter &rewriter) const override {
6638 auto input = op.input();
6639 auto input_ty = input.getType().cast<ShapedType>();
6640 auto result_ty = op.getType().cast<ShapedType>();
6641 if (!result_ty.hasRank() || !input_ty.hasRank() ||
6642 result_ty.hasStaticShape()) {
6643 return failure();
6644 }
6645
6646 DenseIntElementsAttr expand_dims_attr;
6647 if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
6648 return failure();
6649 }
6650
6651 auto shape = rewriter.create<shape::ShapeOfOp>(
6652 op.getLoc(),
6653 RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
6654 input);
6655 auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getIntValues());
6656
6657 llvm::SmallVector<Value, 4> dims;
6658 dims.resize(result_ty.getRank());
6659
6660 auto inserted_dim = expand_dims[0].getSExtValue();
6661
6662 // Handle the negative value use case.
6663 if (inserted_dim < 0) {
6664 inserted_dim += result_ty.getRank();
6665 // The value is out of range even after adjustment, so fail the match.
6666 if (inserted_dim < 0) {
6667 return failure();
6668 }
6669 }
6670
6671 dims[inserted_dim] = rewriter.create<ConstantIndexOp>(op.getLoc(), 1);
6672
6673 for (int i = 0; i < dims.size() - 1; i++) {
6674 // Add the extracted dim.
6675 Value index = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
6676 Value dim = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape, index);
6677 dims[i >= inserted_dim ? i + 1 : i] = dim;
6678 }
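// For example (added for illustration; shapes are hypothetical): with an
// input of shape tensor<?x4xf32> and dim == 1, `dims` ends up as
// [extract(shape, 0), 1, extract(shape, 1)], and the dynamic reshape below
// produces a tensor<?x1x4xf32>.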
6679
6680 auto from_extents =
6681 rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6682 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
6683 from_extents);
6684 return success();
6685 }
6686 };
6687
6688 // Converts a TF QR op to HLO.
6689 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
6690 public:
6691 using OpRewritePattern::OpRewritePattern;
6692
6693 LogicalResult matchAndRewrite(TF::QrOp op,
6694 PatternRewriter &rewriter) const override {
6695 // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van
6696 // Loan. Pseudo-code: def qr_blocked(a, block_size):
6697 // m = a.shape[0]
6698 // n = a.shape[1]
6699 // q = np.eye(m)
6700 // for i in xrange(0, min(m, n), block_size):
6701 // k = min(block_size, min(m, n) - i)
6702 // (a, vs, taus) = qr(a[i:, i:i+k])
6703 // y = vs
6704 // w = ComputeWYRepresentation(vs, taus, m-i, k)
6705 // a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
6706 // q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
6707 // return (q, a)
6708 auto type = op.input().getType().dyn_cast<RankedTensorType>();
6709 if (!type || !type.hasStaticShape()) return failure();
6710 // The block size is chosen to match the old bridge lowering.
6711 constexpr int64_t kBlockSize = 128;
6712 Value a = op.input();
6713 int64_t m = type.getDimSize(type.getRank() - 2);
6714 int64_t n = type.getDimSize(type.getRank() - 1);
6715 int64_t p = std::min(m, n);
6716 auto batch_dims = type.getShape().drop_back(2);
6717 auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
6718 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6719 rewriter.getI64IntegerAttr(0));
6720 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6721 rewriter.getI64IntegerAttr(1));
6722 Value compare = rewriter.create<CompareOp>(
6723 op.getLoc(), iota0, iota1,
6724 StringAttr::get(rewriter.getContext(), "EQ"));
6725 Value identity_matrix =
6726 rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
6727 auto q_shape = llvm::to_vector<4>(type.getShape());
6728 q_shape.back() = m;
6729 Value q = rewriter.create<BroadcastOp>(
6730 op.getLoc(), RankedTensorType::get(q_shape, type.getElementType()),
6731 identity_matrix, GetI64ElementsAttr(batch_dims, &rewriter));
6732 auto precision_config = rewriter.getStrArrayAttr({"HIGHEST", "HIGHEST"});
6733 for (int64_t i = 0; i < p; i += kBlockSize) {
6734 int64_t k = std::min(kBlockSize, p - i);
6735 auto a_block =
6736 SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
6737 Value r_block;
6738 Value taus;
6739 Value vs;
6740 QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
6741 a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
6742
6743 // Compute the I-WY block representation of a product of Householder
6744 // matrices.
6745 Value w =
6746 ComputeWYRepresentation(op.getLoc(), type.getElementType(),
6747 batch_dims, vs, taus, m - i, k, &rewriter);
6748 auto y = vs;
6749
6750 // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
6751 Value a_panel =
6752 SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
6753 auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
6754 batch_dims.size(), precision_config, &rewriter);
6755 a_update = BatchDot(op.getLoc(), y, false, a_update, false,
6756 batch_dims.size(), precision_config, &rewriter);
6757 a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
6758 a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
6759 &rewriter);
6760
6761 // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
6762 Value q_panel =
6763 SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
6764 Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
6765 batch_dims.size(), precision_config, &rewriter);
6766 q_update = BatchDot(op.getLoc(), q_update, false, y, true,
6767 batch_dims.size(), precision_config, &rewriter);
6768 q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
6769 q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
6770 }
6771 // full_matrices is false when only a partial result is needed. Slice to the
6772 // needed dimensions here.
6773 if (!op.full_matrices()) {
6774 q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
6775 a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
6776 }
6777 rewriter.replaceOp(op, {q, a});
6778 return success();
6779 }
6780
6781 private:
6782 // Computes a Householder reflection of the form:
6783 // H = I - tau v v.T.
6784 // such that
6785 // H . ( x1 ) = ( x1 )
6786 // ( x2 ) = ( x2 )
6787 // ( ... ) = ( ... )
6788 // ( xk ) = ( beta )
6789 // ( ... ) ( 0 )
6790 // ( ... ) ( 0 )
6791 // Unlike the usual formulation, we allow the caller to supply 'k' rather than
6792 // only providing the relevant part of 'x' to maintain XLA's static shape
6793 // invariant. In addition, the implementation supports batching.
6794 // Pseudo-code, without batching:
6795 // alpha = x[k]
6796 // x_copy = np.copy(x)
6797 // x_copy[:k+1] = 0
6798 // xnorm = norm2(x_copy)
6799 // if xnorm == 0:
6800 // beta = alpha
6801 // tau = 0
6802 // v = np.zeros_like(x)
6803 // else:
6804 // beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
6805 // tau = (beta - alpha) / beta
6806 // v = x / (alpha - beta)
6807 // v[k] = 1
6808 // return (v, tau, beta)
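// Worked numeric example (added for illustration, no batching): for
// x == [3, 4, 0] and k == 0: alpha == 3, xnorm == 4, beta == -5,
// tau == (beta - alpha) / beta == 1.6 and v == [1, 0.5, 0]; then
// (I - tau * v * v.T) . x == [-5, 0, 0] == [beta, 0, 0] as expected.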
6809 void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
6810 const int64_t m, OpBuilder *builder, Value *v, Value *tau,
6811 Value *beta) const {
6812 auto x_type = x.getType().cast<RankedTensorType>();
6813
6814 llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
6815 std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
6816 const int64_t minor_dim = batch_dims.size();
6817
6818 Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
6819 Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
6820
6821 // alpha = x[k]
6822 Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
6823 alpha = builder->create<ReshapeOp>(
6824 loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
6825
6826 // Compute x[k+1:] (padded with zeros in elements 0..k)
6827 Value iota = builder->create<IotaOp>(
6828 loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
6829 builder->getI64IntegerAttr(0));
6830 Value gtk = builder->create<chlo::BroadcastCompareOp>(
6831 loc, iota, k, GetI64ElementsAttr({}, builder),
6832 StringAttr::get(builder->getContext(), "GT"));
6833 gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
6834 Value x_after_k = builder->create<chlo::BroadcastMulOp>(
6835 loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
6836 Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
6837 // sigma = np.dot(x[k+1:], x[k+1:])
6838 auto sigma = builder->create<ReduceOp>(
6839 loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
6840 BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
6841 // mu = np.sqrt(x[k]*x[k] + sigma)
6842 Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
6843 Value mu = builder->create<SqrtOp>(
6844 loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
6845
6846 Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
6847 loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
6848 StringAttr::get(builder->getContext(), "EQ"));
6849 Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
6850 loc, alpha, zero, GetI64ElementsAttr({}, builder),
6851 StringAttr::get(builder->getContext(), "LT"));
6852 auto batch_size_one = builder->create<BroadcastOp>(
6853 loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
6854 Value signed_mu = builder->create<chlo::BroadcastMulOp>(
6855 loc,
6856 builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
6857 batch_size_one,
6858 builder->create<NegOp>(loc, batch_size_one)),
6859 mu, GetI64ElementsAttr({}, builder));
6860 *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
6861 alpha, signed_mu);
6862 *tau = builder->create<DivOp>(
6863 loc, builder->create<SubOp>(loc, *beta, alpha), *beta);
6864 Value zero_tau = builder->create<BroadcastOp>(
6865 loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder));
6866 *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero,
6867 zero_tau, *tau);
6868 Value divisor = builder->create<SubOp>(loc, alpha, *beta);
6869 divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
6870 batch_size_one, divisor);
6871
6872 Value eqk = builder->create<chlo::BroadcastCompareOp>(
6873 loc, iota, k, GetI64ElementsAttr({}, builder),
6874 StringAttr::get(builder->getContext(), "EQ"));
6875 eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
6876 llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
6877 e_k_shape.push_back(m);
6878 auto e_k = builder->create<BroadcastOp>(
6879 loc, RankedTensorType::get(e_k_shape, x_type.getElementType()), eqk,
6880 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
6881 builder));
6882
6883 // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
6884 // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
6885 // Note that the add performs a degenerate broadcast.
6886 *v = builder->create<chlo::BroadcastAddOp>(
6887 loc, e_k,
6888 StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
6889 GetI64ElementsAttr(batch_dim_ids, builder),
6890 *builder),
6891 /*broadcast_dimensions=*/nullptr);
6892 }
6893
6894 // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
6895 // Loan "Matrix Computations", 4th Edition. This is an unblocked
6896 // implementation used as an inner routine of the blocked implementation.
6897 // Algorithm is adapted slightly so the shapes inside the loop are static, at
6898 // the cost of some redundant computation. Since this is used as an inner
6899 // block kernel, accumulates the Householder transformations (vs, taus) rather
6900 // than the matrix q. Equivalent Python code, without batching: def qr(a):
6901 // m = a.shape[0]
6902 // n = a.shape[1]
6903 // vs = np.zeros([m, n])
6904 // taus = np.zeros([n])
6905 // for j in xrange(min(m, n)):
6906 // v, tau, beta = house(a[:, j], j)
6907 // # Unusually, we apply the Householder transformation to the entirety of
6908 // # a, wasting FLOPs to maintain the static shape invariant that XLA
6909 // # requires. For columns that precede j this has no effect.
6910 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6911 // np.dot(v[np.newaxis, :], a[:, :]))
6912 // # Form column j explicitly rather than relying on the precision of the
6913 // # Householder update.
6914 // a[j, j] = beta
6915 // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
6916 // vs[:, j] = v
6917 // taus[j] = tau
6918 // return (a, vs, taus)
6919 void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
6920 PatternRewriter *rewriter) const {
6921 auto a_type = a.getType().cast<RankedTensorType>();
6922 const int num_dims = a_type.getRank();
6923 assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
6924
6925 const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
6926 const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
6927
6928 const int64_t num_batch_dims = num_dims - 2;
6929 auto batch_dims = a_type.getShape().take_front(num_batch_dims);
6930 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6931 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6932
6933 auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6934 SmallVectorImpl<Value> *new_values,
6935 OpBuilder *builder) {
6936 auto a = old_values[0];
6937 auto vs = old_values[1];
6938 auto taus = old_values[2];
6939
6940 // v, beta = house(a[:, j], j)
6941 auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
6942 auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
6943 x_collapsed_shape.push_back(m);
6944 auto x_collapsed = builder->create<ReshapeOp>(
6945 loc,
6946 RankedTensorType::get(x_collapsed_shape,
6947 getElementTypeOrSelf(x.getType())),
6948 x);
6949 Value v, tau, beta;
6950 House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
6951
6952 auto shape = llvm::to_vector<4>(batch_dims);
6953 shape.append({1, m});
6954 auto v_broadcast = builder->create<ReshapeOp>(
6955 loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
6956 v);
6957 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6958 // np.dot(v[np.newaxis, :], a[:, :]))
6959 auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
6960 auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
6961 precision, builder);
6962 vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
6963 precision, builder);
6964 auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
6965 loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
6966 *builder);
6967 a = builder->create<SubOp>(loc, a, tau_x_vva);
6968
6969 // It is more precise to populate column 'k' explicitly, rather than
6970 // computing it implicitly by applying the Householder transformation.
6971 // a[k,k] = beta
6972 // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
6973 auto iota = builder->create<IotaOp>(
6974 loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
6975 builder->getI64IntegerAttr(0));
6976 Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
6977 loc, iota, j, GetI64ElementsAttr({}, builder),
6978 StringAttr::get(builder->getContext(), "LT"));
6979 predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
6980 a_type.getElementType());
6981 Value mask = builder->create<chlo::BroadcastCompareOp>(
6982 loc, iota, j, GetI64ElementsAttr({}, builder),
6983 StringAttr::get(builder->getContext(), "EQ"));
6984 mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
6985 llvm::SmallVector<int64_t, 4> broadcast_mask_shape(a_type.getRank(), 1);
6986 broadcast_mask_shape[a_type.getRank() - 2] = m;
6987 mask = builder->create<BroadcastOp>(
6988 loc,
6989 RankedTensorType::get(broadcast_mask_shape, a_type.getElementType()),
6990 mask,
6991 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
6992 builder));
6993 Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
6994 loc, x, predecessor_mask,
6995 GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
6996 Value masked_beta = StaticBinaryBroadcast<MulOp>(
6997 loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
6998 *builder);
6999 Value new_x =
7000 builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
7001 // Update a[:,j]
7002 llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
7003 std::iota(dim_ids.begin(), dim_ids.end(), 0);
7004 new_x = builder->create<BroadcastInDimOp>(
7005 loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
7006 const int64_t minor_dim = num_batch_dims;
7007 auto iota_mn = builder->create<IotaOp>(
7008 loc,
7009 RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
7010 builder->getI64IntegerAttr(minor_dim + 1));
7011 Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
7012 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
7013 StringAttr::get(builder->getContext(), "EQ"));
7014 a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
7015
7016 // vs[:, j] = v
7017 llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
7018 std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
7019 Value vs_zeros =
7020 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
7021 vs_zeros = builder->create<BroadcastOp>(
7022 loc, vs.getType(), vs_zeros,
7023 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
7024 builder));
7025 auto vs_update = builder->create<SelectOp>(
7026 loc, vs.getType(), xa_mask,
7027 StaticBinaryBroadcast<AddOp>(
7028 loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
7029 *builder),
7030 vs_zeros);
7031 vs = builder->create<AddOp>(loc, vs, vs_update);
7032
7033 // taus[j] = tau
7034 llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
7035 std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
7036
7037 auto iota_shape = llvm::to_vector<4>(batch_dims);
7038 iota_shape.push_back(n);
7039 auto iota_n = builder->create<IotaOp>(
7040 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
7041 builder->getI64IntegerAttr(minor_dim));
7042 Value taus_zeros =
7043 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
7044 taus_zeros = builder->create<BroadcastOp>(
7045 loc, taus.getType(), taus_zeros,
7046 GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
7047 builder));
7048 Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
7049 loc, iota_n, j, GetI64ElementsAttr({}, builder),
7050 StringAttr::get(builder->getContext(), "EQ"));
7051 auto taus_update = builder->create<SelectOp>(
7052 loc, taus.getType(), taus_mask,
7053 StaticBinaryBroadcast<AddOp>(
7054 loc, taus_zeros, tau,
7055 GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
7056 taus_zeros);
7057 taus = builder->create<AddOp>(loc, taus, taus_update);
7058 new_values->assign({a, vs, taus});
7059 };
7060
7061 Value zero =
7062 GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
7063 *vs = rewriter->create<BroadcastOp>(
7064 loc, a_type, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
7065 auto taus_shape = llvm::to_vector<4>(batch_dims);
7066 taus_shape.push_back(n);
7067 *taus = rewriter->create<BroadcastOp>(
7068 loc, RankedTensorType::get(taus_shape, a_type.getElementType()), zero,
7069 GetI64ElementsAttr(taus_shape, rewriter));
7070
7071 SmallVector<Value, 4> while_output;
7072 CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
7073 &while_output, rewriter);
7074 *r = while_output[0];
7075 *vs = while_output[1];
7076 *taus = while_output[2];
7077 }
7078
7079 // Computes W and Y such that I-WY is equivalent to the sequence of
7080 // Householder
7081 // transformations given by vs and taus.
7082 // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
7083 // Y = np.zeros([m, n])
7084 // W = np.zeros([m, n])
7085 // Y[:, 0] = vs[:, 0]
7086 // W[:, 0] = -taus[0] * vs[:, 0]
7087 // for j in xrange(1, n):
7088 // v = vs[:, j]
7089 // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
7090 // W[:, j] = z
7091 // Y[:, j] = v
7092 // return W
7093 // There is no need to return Y since at termination of the loop it is equal
7094 // to vs.
7095 Value ComputeWYRepresentation(Location loc, Type data_type,
7096 ArrayRef<int64_t> batch_dims, Value vs,
7097 Value taus, int64_t m, int64_t n,
7098 PatternRewriter *rewriter) const {
7099 int64_t n_index = batch_dims.size() + 1;
7100 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
7101 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
7102
7103 auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
7104 SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
7105 // w has shape [..., m, n]
7106 auto w = old_values[0];
7107 const auto vs = old_values[1];
7108 const auto taus = old_values[2];
7109
7110 // Want j values in range [1, ... n).
7111 j = builder->create<AddOp>(
7112 loc, j,
7113 GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
7114 builder));
7115 // vs has shape [..., m, 1]
7116 auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
7117 // beta has shape [..., 1]
7118 auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
7119
7120 auto iota_shape = llvm::to_vector<4>(batch_dims);
7121 iota_shape.append({m, n});
7122 auto iota_mn = builder->create<IotaOp>(
7123 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
7124 builder->getI64IntegerAttr(n_index));
7125
7126 // y has shape [..., m, n]
7127 Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
7128 0, builder);
7129 zero = builder->create<BroadcastOp>(
7130 loc, vs.getType(), zero,
7131 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
7132 builder));
7133 auto compare = builder->create<chlo::BroadcastCompareOp>(
7134 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
7135 StringAttr::get(builder->getContext(), "GE"));
7136 auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
7137
7138 // yv has shape [..., n, 1]
7139 auto precision = builder->getStrArrayAttr({"HIGHEST", "HIGHEST"});
7140 auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
7141 builder);
7142 // wyv has shape [..., m, 1]
7143 auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
7144 precision, builder);
7145
7146 // z = -beta * (v + wyv)
7147 auto neg_beta = builder->create<NegOp>(loc, beta);
7148 auto v_wyv = builder->create<AddOp>(loc, v, wyv);
7149 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
7150 beta_broadcast_dims.push_back(n_index);
7151 auto z = StaticBinaryBroadcast<MulOp>(
7152 loc, neg_beta, v_wyv,
7153 GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
7154
7155 w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
7156 new_values->assign({w, vs, taus});
7157 };
7158
7159 Value w =
7160 GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
7161 auto w_shape = llvm::to_vector<4>(batch_dims);
7162 w_shape.append({m, n});
7163 w = rewriter->create<BroadcastOp>(loc,
7164 RankedTensorType::get(w_shape, data_type),
7165 w, GetI64ElementsAttr(w_shape, rewriter));
7166 auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
7167 auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
7168 auto neg_beta = rewriter->create<NegOp>(loc, beta);
7169 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
7170 beta_broadcast_dims.push_back(n_index);
7171 auto bv = StaticBinaryBroadcast<MulOp>(
7172 loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
7173 *rewriter);
7174 w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
7175
7176 SmallVector<Value, 4> while_output;
7177 CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
7178 return while_output[0];
7179 }
7180 };
7181
7182 // TF.PrintOp to mhlo.PrintOp
7183 class ConvertPrintOp : public OpRewritePattern<TF::PrintOp> {
7184 public:
7185 using OpRewritePattern::OpRewritePattern;
7186 LogicalResult matchAndRewrite(TF::PrintOp op,
7187 PatternRewriter &rewriter) const final {
7188 auto input = op.input();
7189 rewriter.replaceOpWithNewOp<mhlo::PrintOp>(op, op.getType(), input);
7190 return success();
7191 }
7192 };
7193
7194 // Emits debug information which includes the number of ops of each type which
7195 // failed to legalize.
7196 void EmitLegalizationErrors(Operation *op,
7197 const DenseSet<Operation *> &nonlegalized_ops) {
7198 // Track the legalization failures by mapping op name to information about
7199 // that failure: the number of unlegalized occurrences of the op, and one
7200 // example operation that failed.
7201 std::map<StringRef, std::pair<int, Operation *>> op_name_to_error_info;
7202 DenseSet<Operation *> error_ops;
7203 for (Operation *nonlegalized_op : nonlegalized_ops) {
7204 // Increment count of this legalization failure.
7205 StringRef op_name = nonlegalized_op->getName().getStringRef();
7206 // If this emplace is successful, it's the first time we've encountered
7207 // this op type. Initialize count to 0 so that after increment, it is 1.
7208 auto insertion_result = op_name_to_error_info.emplace(
7209 op_name, std::make_pair(0, nonlegalized_op));
7210 ++insertion_result.first->second.first;
7211 }
7212 std::vector<std::string> error_messages;
7213 error_messages.reserve(op_name_to_error_info.size());
7214 for (const auto &op_info : op_name_to_error_info) {
7215 error_messages.push_back(
7216 llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first));
7217 }
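// For example (added for illustration; op names are hypothetical): three
// unlegalized tf.Foo ops and one tf.Bar op produce the entries
// "tf.Bar (count: 1)" and "tf.Foo (count: 3)", which are joined with "; "
// in the diagnostic below.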
7218 Location loc = op->getLoc();
7219 emitError(loc) << "The following operations cannot be legalized: "
7220 << llvm::join(error_messages, "; ")
7221 << ". These legalization failure(s) may be due to missing TF "
7222 "to HLO lowerings and/or unsupported attributes, etc.";
7223 // Emit more information about the missing ops. This error message
7224 // contains useful details beyond the op name (input and output shapes,
7225 // attributes, etc.).
7226 if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) {
7227 emitError(loc)
7228 << "Emitting more detail about one op that failed to legalize...";
7229 } else if (VLOG_IS_ON(1)) {
7230 emitError(loc) << "Emitting more detail about one of each type of op "
7231 "that failed to legalize...";
7232 }
7233 for (const auto &op_info : op_name_to_error_info) {
7234 op_info.second.second->emitOpError() << "is not legalizable";
7235 if (!VLOG_IS_ON(1)) break;
7236 }
7237 }
7238
7239 // Performs the lowering to XLA dialect.
7240 void LegalizeTF::runOnFunction() {
7241 llvm::Optional<StringRef> tf2xla_fallback_device_type = llvm::None;
7242 if (use_tf2xla_fallback_) {
7243 tf2xla_fallback_device_type = device_type_;
7244 }
7245 if (failed(legalizeTF(getFunction(), allow_partial_conversion_,
7246 legalize_chlo_, tf2xla_fallback_device_type,
7247 prefer_tf2xla_))) {
7248 signalPassFailure();
7249 }
7250 }
7251
7252 // Patterns whose root op is in the set `include_ops` are moved from the set
7253 // `from` to the returned set. This is used to partition patterns by op so they
7254 // can be cleanly migrated from the old bridge to the MLIR bridge.
7255 OwningRewritePatternList PatternsIncludeOps(
7256 OwningRewritePatternList &from,
7257 const llvm::DenseSet<mlir::TypeID> &include_ops) {
7258 OwningRewritePatternList to(from.getContext());
7259 // Filter NativePatterns.
7260 for (auto &pattern : from.getNativePatterns()) {
7261 Optional<OperationName> pat_op_name = pattern->getRootKind();
7262 // If the pattern does not have a specific root operation, always include
7263 // it. If the pattern's root op is in include_ops, include it as well.
7264 bool include =
7265 !pat_op_name ||
7266 include_ops.count(pat_op_name->getAbstractOperation()->typeID);
7267 if (include) to.add(std::move(pattern));
7268 }
7269
7270 // Don't filter PDLPatterns.
7271 to.add(std::move(from.getPDLPatterns()));
7272
7273 return to;
7274 }
7275
7276 /// Returns ops that should use MLIR legalization only in the case of
7277 /// prefer_tf2xla. All other ops not in this list should use XlaOpKernel
7278 /// legalization only or not be legalized by the new bridge.
7279 const llvm::DenseSet<mlir::TypeID> &MlirPreferredOps() {
7280 // The static variable is a pointer in order to avoid destruction upon thread
7281 // termination.
7282
7283 // clang-format off
7284 static const llvm::DenseSet<mlir::TypeID>* ops =
7285 new llvm::DenseSet<mlir::TypeID>{
7286 // Ops that are legalized in the old bridge using MlirXlaOpKernel
7287 TypeID::get<TF::AbsOp>(),
7288 TypeID::get<TF::AtanOp>(),
7289 TypeID::get<TF::AvgPool3DOp>(),
7290 TypeID::get<TF::BiasAddGradOp>(),
7291 TypeID::get<TF::CeilOp>(),
7292 TypeID::get<TF::CheckNumericsOp>(),
7293 TypeID::get<TF::ComplexOp>(),
7294 TypeID::get<TF::CosOp>(),
7295 TypeID::get<TF::DiagPartOp>(),
7296 TypeID::get<TF::DivOp>(),
7297 TypeID::get<TF::EinsumOp>(),
7298 TypeID::get<TF::ExpOp>(),
7299 TypeID::get<TF::Expm1Op>(),
7300 TypeID::get<TF::FakeQuantWithMinMaxArgsOp>(),
7301 TypeID::get<TF::FloorOp>(),
7302 TypeID::get<TF::GreaterEqualOp>(),
7303 TypeID::get<TF::IFFTOp>(),
7304 TypeID::get<TF::ImagOp>(),
7305 TypeID::get<TF::IsFiniteOp>(),
7306 TypeID::get<TF::IsInfOp>(),
7307 TypeID::get<TF::IsNanOp>(),
7308 TypeID::get<TF::LessEqualOp>(),
7309 TypeID::get<TF::LgammaOp>(),
7310 TypeID::get<TF::Log1pOp>(),
7311 TypeID::get<TF::LogicalOrOp>(),
7312 TypeID::get<TF::LogSoftmaxOp>(),
7313 TypeID::get<TF::MatrixBandPartOp>(),
7314 TypeID::get<TF::MaxPool3DGradOp>(),
7315 TypeID::get<TF::PreventGradientOp>(),
7316 TypeID::get<TF::RandomShuffleOp>(),
7317 TypeID::get<TF::RealOp>(),
7318 TypeID::get<TF::ReciprocalOp>(),
7319 TypeID::get<TF::ReluOp>(),
7320 TypeID::get<TF::Relu6Op>(),
7321 TypeID::get<TF::ReluGradOp>(),
7322 TypeID::get<TF::RsqrtOp>(),
7323 TypeID::get<TF::SelectOp>(),
7324 TypeID::get<TF::SigmoidOp>(),
7325 TypeID::get<TF::SignOp>(),
7326 TypeID::get<TF::SoftmaxOp>(),
7327 TypeID::get<TF::SqrtOp>(),
7328 TypeID::get<TF::SqrtGradOp>(),
7329 TypeID::get<TF::SquaredDifferenceOp>(),
7330 TypeID::get<TF::TanhOp>(),
7331 TypeID::get<TF::TanhGradOp>(),
7332 TypeID::get<TF::XlogyOp>(),
7333 TypeID::get<TF::ZetaOp>(),
7334
7335 // Ops that have no XlaOpKernel.
7336 TypeID::get<TF::RiscAddOp>(),
7337 TypeID::get<TF::RiscDotOp>(),
7338
7339 // Const op has a simple legalization and it is much more efficient to lower
7340 // within MLIR.
7341 TypeID::get<TF::ConstOp>(),
7342
7343 // AssertOp with string types is not supported by the fallback.
7344 TypeID::get<TF::AssertOp>(),
7345
7346 // The TF2XLA fallback pattern doesn't support these ops, as the MLIR HLO
7347 // builder doesn't override the necessary builder methods. These ops have
7348 // simple lowering patterns, so handling them here should be safe.
7349 TypeID::get<TF::CrossReplicaSumOp>(),
7350 TypeID::get<TF::InfeedDequeueTupleOp>(),
7351 TypeID::get<TF::OutfeedEnqueueTupleOp>(),
7352 TypeID::get<TF::XlaShardingOp>(),
7353 };
7354 // clang-format on
7355 return *ops;
7356 }
7357
7358 } // end namespace
7359
7360 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
7361
7362 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
7363 bool legalize_chlo,
7364 llvm::Optional<StringRef> tf2xla_fallback_device_type,
7365 bool prefer_tf2xla) {
7366 MLIRContext *context = op->getContext();
7367 OwningRewritePatternList legalize_lower_patterns(context);
7368 // Note that the `OperationConverter` orders patterns lexicographically by:
7369 // 1) Ascending legalization depth (i.e., minimum number of patterns necessary
7370 //    to arrive at the conversion target). This requires patterns to
7371 //    specify the list of ops they generate, which most patterns
7372 //    implemented in C++ do not, so this comparison does not apply in
7373 //    those cases.
7374 // 2) Descending pattern benefit.
7375 // 3) Op specific patterns over patterns with MatchAnyOpTypeTag.
7376 // 4) Order of patterns in `OwningRewritePatternList`.
7377
7378 // Add TF->HLO legalization patterns.
7379 PopulateLegalizeTfPatterns(context, &legalize_lower_patterns);
7380
7381 // Add TF->TF lowering patterns.
7382 TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns);
7383
7384 if (tf2xla_fallback_device_type && prefer_tf2xla) {
7385 VLOG(1) << "TF to XLA legalization patterns are partitioned by op into "
7386 "either native MLIR legalization, or TF2XLA fallback "
7387 "legalization, with a preference toward TF2XLA.";
7388 } else if (tf2xla_fallback_device_type) {
7389 VLOG(1) << "TF to XLA legalization patterns include all native patterns "
7390 "and TF2XLA fallback patterns.";
7391 } else {
7392 VLOG(1) << "TF to XLA legalization patterns are native patterns only.";
7393 }
7394
7395 // Set patterns to legalize_lower_patterns, where in the prefer_tf2xla case
7396 // only patterns whose ops are in the set MlirPreferredOps are kept.
7397 OwningRewritePatternList patterns =
7398 (tf2xla_fallback_device_type && prefer_tf2xla)
7399 ? PatternsIncludeOps(legalize_lower_patterns, MlirPreferredOps())
7400 : std::move(legalize_lower_patterns);
7401
7402 if (tf2xla_fallback_device_type) {
7403 // Add TF->HLO legalization patterns via TF2XLA fallback.
7404 PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(),
7405 patterns, context, prefer_tf2xla);
7406 }
7407
7408 // Populate with CHLO->HLO lowerings to account for TF ops legalized to
7409 // CHLO first.
7410 if (legalize_chlo) {
7411 chlo::PopulateDecomposeChloPatterns(context, &patterns);
7412 chlo::PopulateChloBroadcastingPatterns(context, &patterns);
7413 }
7414 // ConstantLike op is convenient for creating splat constants, but is
7415 // canonicalized to a plain HLO constant if statically shaped. Add the
7416 // canonicalization pattern to the pattern list to enable multi-hop lowering.
7417 chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context);
7418
7419 ConversionTarget target(*context);
7420 if (legalize_chlo) {
7421 target.addIllegalDialect<chlo::HloClientDialect>();
7422 } else {
7423 target.addLegalDialect<chlo::HloClientDialect>();
7424 }
7425 target.addLegalDialect<MhloDialect>();
7426 target.addLegalDialect<StandardOpsDialect>();
7427 target.addLegalDialect<tensor::TensorDialect>();
7428 target.addLegalDialect<shape::ShapeDialect>();
7429 target.addLegalOp<CallOp>();
7430
7431 if (!allow_partial_conversion) {
7432 // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
7433 target.addLegalOp<ModuleOp, FuncOp, ::mlir::ReturnOp>();
7434 DenseSet<Operation *> nonlegalized_ops;
7435 LogicalResult result = applyPartialConversion(
7436 op, target, std::move(patterns), &nonlegalized_ops);
7437 // To enforce that the conversion is complete, fail if there are any
7438 // nonlegalized ops left in the set.
7439 if (failed(result) || !nonlegalized_ops.empty()) {
7440 EmitLegalizationErrors(op, nonlegalized_ops);
7441 return failure();
7442 }
7443 return result;
7444 }
7445
7446 return applyPartialConversion(op, target, std::move(patterns));
7447 }
7448
7449 void PopulateLegalizeTfPatterns(MLIRContext *context,
7450 OwningRewritePatternList *patterns) {
7451 populateWithGenerated(*patterns);
7452 // clang-format off
7453 patterns->insert<
7454 ConvertAllOp,
7455 ConvertAnyOp,
7456 ConvertArgMaxOp,
7457 ConvertArgMinOp,
7458 ConvertBatchMatMulV2Op,
7459 ConvertBiasAddOp,
7460 ConvertBroadcastToOp,
7461 ConvertBF16FloorDivOp,
7462 ConvertClipByValueOp,
7463 ConvertConstOp,
7464 ConvertConv2DOp,
7465 ConvertConv3DOp,
7466 ConvertDepthConv2DOp,
7467 ConvertConv2DBackpropFilterOp,
7468 ConvertConv3DBackpropFilterOp,
7469 ConvertConv2DBackpropInputOp,
7470 ConvertConv3DBackpropInputOp,
7471 ConvertCumprodOp,
7472 ConvertCumsumOp,
7473 ConvertDiagPartOp,
7474 ConvertDynamicExpandDimsOp,
7475 ConvertEinsumOp,
7476 ConvertRFFTOp,
7477 ConvertIRFFTOp,
7478 ConvertFusedBatchNormGradOp,
7479 ConvertFusedBatchNormGradV2Op,
7480 ConvertFusedBatchNormGradV3Op,
7481 ConvertFusedBatchNormV2Op,
7482 ConvertFusedBatchNormV3Op,
7483 ConvertInfeedDequeueTupleOp,
7484 ConvertIdentityNOp,
7485 ConvertInplaceUpdateOp,
7486 ConvertLinSpaceOp,
7487 ConvertMaxOp,
7488 ConvertMinOp,
7489 ConvertAvgPool2DOp,
7490 ConvertAvgPool3DOp,
7491 ConvertAvgPool2DGradOp,
7492 ConvertAvgPool3DGradOp,
7493 ConvertMaxPool2DOp,
7494 ConvertMaxPool3DOp,
7495 ConvertMaxPool2DGradOp,
7496 ConvertMaxPool3DGradOp,
7497 ConvertMeanOp,
7498 ConvertOneHotOp,
7499 ConvertOutfeedEnqueueTupleOp,
7500 ConvertProdOp,
7501 ConvertQrOp,
7502 ConvertDynamicRangeOp,
7503 ConvertMatrixDiagPartV3Op,
7504 ConvertRangeOp,
7505 ConvertSelectOp,
7506 ConvertSigmoidOp,
7507 ConvertShapeOp,
7508 ConvertSplitOp,
7509 ConvertSplitVOp,
7510 ConvertStridedSliceOp,
7511 ConvertStridedSliceGradOp,
7512 ConvertSumOp,
7513 ConvertTensorScatterAddOp,
7514 ConvertTensorScatterSubOp,
7515 ConvertTensorScatterMinOp,
7516 ConvertTensorScatterMaxOp,
7517 ConvertTensorScatterUpdateOp,
7518 ConvertTileOp,
7519 ConvertTopKV2Op,
7520 ConvertUnpackOp,
7521 ConvertUnsortedSegmentMaxOp,
7522 ConvertUnsortedSegmentMinOp,
7523 ConvertUnsortedSegmentProdOp,
7524 ConvertUnsortedSegmentSumOp,
7525 ConvertRandomShuffleOp,
7526 ConvertXlaShardingOp,
7527 ConvertXlaDynamicUpdateSliceOp,
7528 ConvertXlaAllReduceOp,
7529 ConvertRollOp,
7530 ConvertLeakyReluOp,
7531 ConvertLeakyReluGradOp,
7532 ConvertPrintOp,
7533 ConvertSplitOpDynamic,
7534 ConvertSliceOpDynamic,
7535 ConvertTileOpDynamic,
7536 ConvertUnpackOpDynamic,
7537 ConvertSignOpDynamic,
7538 ConvertSigmoidGradOpDynamic,
7539 ConvertGatherV2OpDynamic,
7540 ConvertConv2DDynamic,
7541 ConvertPadOpDynamic,
7542 ConvertGatherNdOpDynamic>(context);
7543 // clang-format on
7544 }
7545
7546 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
7547 bool allow_partial_conversion, bool legalize_chlo,
7548 llvm::Optional<StringRef> tf2xla_fallback_device_type, bool prefer_tf2xla) {
7549 return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo,
7550 tf2xla_fallback_device_type,
7551 prefer_tf2xla);
7552 }
7553
7554 } // end namespace mhlo
7555 } // end namespace mlir
7556