/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering TensorFlow dialect to XLA dialect.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/Dialect/Traits.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Diagnostics.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace xla_hlo {
namespace {

class LegalizeTF : public FunctionPass<LegalizeTF> {
 public:
  LegalizeTF() = default;
  LegalizeTF(const LegalizeTF &) {}
  explicit LegalizeTF(bool allow_partial_conversion) {
    allow_partial_conversion_ = allow_partial_conversion;
  }

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;

 private:
  Option<bool> allow_partial_conversion_{
      *this, "allow-partial-conversion",
      llvm::cl::desc("Allow operations that can't be legalized."),
      llvm::cl::init(false)};
};

/// Returns if the given TF data format string is the default format.
static bool IsDefaultDataFormat(StringRef format) { return format == "NHWC"; }

/// Returns the feature dimension for the given format and input type.
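/// For example, a rank-4 input in the default "NHWC" format has feature
/// dimension 3, while any other format (e.g. "NCHW") yields dimension 1.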
static size_t GetFeatureDimension(StringAttr format,
                                  RankedTensorType inputType) {
  return IsDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1;
}

// Gets all integer values from the given attribute and pushes them to
// `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Converts an ArrayAttr to a 1D 64-bit dense elements attribute.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) {
  RankedTensorType ty =
      RankedTensorType::get(static_cast<int64_t>(attr.size()),
                            IntegerType::get(64, attr.getContext()));
  return DenseIntElementsAttr::get(ty, attr.getValue());
}

// Returns 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isInteger(8) || input_type.isInteger(16))
    return IntegerType::get(32, ctx);
  return input_type;
}

// Returns axis in HLO format from TF elements attr with exactly one element
// containing axis in the TensorFlow format. TensorFlow format supports negative
// indexing unlike HLO.
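// For example, a TF axis of -1 for a rank-4 tensor corresponds to HLO axis 3.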
static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank,
                                        Builder *b) {
  SmallVector<uint64_t, 1> index(attr.getType().getRank(), 0);
  int64_t axis = attr.getValue<IntegerAttr>(index).getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

static IntegerAttr GetHLOAxisFromTFAxis(IntegerAttr attr, int64_t rank,
                                        Builder *b) {
  int64_t axis = attr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

// If `value` is an IntegerAttr, returns the integer value for the HLO axis
// corresponding to the TensorFlow axis. In particular, the TensorFlow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValue<IntegerAttr>({}).getInt();
  return axis < 0 ? axis + rank : axis;
}

/// Returns a `ConvertOp` that casts the elements to an i64 type while retaining
/// the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Returns the size of the dimension at the specified index if the type is a
// ranked tensor. Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(
      llvm::SmallVector<tensorflow::int64, 4>(sizes.begin(), sizes.end()));
}

// Returns minimal value for the given int or float element type.
static ConstOp GetMinValueForType(Type ty, Location loc,
                                  PatternRewriter *rewriter) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  DenseElementsAttr attr;
  if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
    APFloat neg_inf =
        APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true);
    attr = DenseElementsAttr::get(scalar_ty, neg_inf);
  } else {
    auto int_ty = ty.cast<IntegerType>();
    APInt min_val = APInt::getSignedMinValue(int_ty.getWidth());
    attr = DenseElementsAttr::get(scalar_ty, min_val);
  }
  return rewriter->create<ConstOp>(loc, attr);
}

// Returns maximal value for the given int or float element type.
static ConstOp GetMaxValueForType(Type ty, Location loc,
                                  PatternRewriter *rewriter) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  DenseElementsAttr attr;
  if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
    APFloat pos_inf =
        APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false);
    attr = DenseElementsAttr::get(scalar_ty, pos_inf);
  } else {
    auto int_ty = ty.cast<IntegerType>();
    APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth());
    attr = DenseElementsAttr::get(scalar_ty, max_val);
  }
  return rewriter->create<ConstOp>(loc, attr);
}

// Returns int or float scalar DenseElementsAttr attribute with the given
// element type and the value.
static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
                                    PatternRewriter *rewriter) {
  return rewriter->create<ConstOp>(loc, xla::GetScalarOfType(ty, raw_value));
}

// Builds body for reduce op by using the template binary op as the
// reducer op.
template <typename Op>
static void BuildReduceBody(Type element_type, Region *body,
                            OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  Block *block = builder->createBlock(body);

  // Block arguments are scalars of the given element type.
  Type type = RankedTensorType::get(/*shape=*/{}, element_type);
  block->addArguments({type, type});

  Location loc = body->getLoc();
  auto reducer =
      builder->create<Op>(loc, block->getArgument(0), block->getArgument(1),
                          /*broadcast_dimensions=*/nullptr);
  builder->create<ReturnOp>(loc, reducer.getResult());
}

// Builds region taking two arguments and returning second argument as the
// result. Corresponds to the function f(x, y) = y.
// Used in Scatter op's computation to update specific elements.
static void BuildBinaryAssignmentRegion(Type element_type, Region *region,
                                        OpBuilder *builder) {}

// Builds a set of operations for applying reduction on the input value. A
// tf.sum op is created and will be legalized to HLO ops automatically.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates an xla_hlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
static xla_hlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
                                                int lower_limit,
                                                int upper_limit,
                                                OpBuilder *builder) {
  auto i32_type = builder->getIntegerType(32);
  auto key_type = RankedTensorType::get({num_elements}, i32_type);
  auto shape_tensor = builder->create<xla_hlo::ConstOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<xla_hlo::ConstOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<xla_hlo::ConstOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<xla_hlo::RngUniformOp>(loc, key_type, lower, upper,
                                                shape_tensor);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates an xla_hlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Values, 4> old_values = init_values;
// SmallVector<Values, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to values to control the
// number of iterations, but that is transparent to `body_fn`, which does not
// need to care about that.
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Prepare the initial tuple for the while op.
  auto init_tuple =
      builder->create<xla_hlo::TupleOp>(loc, init_values_with_loop_iv);
  auto tuple_type = init_tuple.getType();

  // Create the while op.
  auto while_op = builder->create<xla_hlo::WhileOp>(loc, init_tuple);

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region. It should take one
    // argument of the loop's tuple type.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    BlockArgument arg = block->addArgument(tuple_type);

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
    auto upper_limit = builder->create<xla_hlo::ConstOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    StringAttr compare_direction = StringAttr::get("LT", builder->getContext());
    Value compare = builder->create<xla_hlo::CompareOp>(
        loc, loop_iv, upper_limit,
        /*broadcast_dimensions=*/nullptr, compare_direction);

    builder->create<xla_hlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region. It should take one
    // argument of the loop's tuple type.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    BlockArgument arg = block->addArgument(tuple_type);

    SmallVector<Value, 4> old_values;  // From the previous iteration
    SmallVector<Value, 4> new_values;  // Generated by this iteration
    old_values.reserve(value_count);
    new_values.reserve(value_count);

    // Unpack the tuple value from the last iteration.
    for (int i = 0; i < value_count; ++i)
      old_values.push_back(builder->create<GetTupleElementOp>(loc, arg, i));

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, old_values[0], llvm::makeArrayRef(old_values).drop_front(),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
    auto no_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<xla_hlo::AddOp>(loc, old_values[0], one,
                                                    no_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    Value updated_tuple = builder->create<xla_hlo::TupleOp>(loc, new_values);

    builder->create<xla_hlo::ReturnOp>(loc, updated_tuple);
  }

  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(
        builder->create<GetTupleElementOp>(loc, while_op, i + 1));
}

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// Bias op utilities.
//===----------------------------------------------------------------------===//

// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd.
// Requires input to have ranked tensor.
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
                                                    StringAttr format,
                                                    Value input) {
  auto inputType = input.getType().cast<RankedTensorType>();
  size_t featureDim = GetFeatureDimension(format, inputType);
  RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
  return DenseIntElementsAttr::get(type, featureDim);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true, returns an ElementsAttr to transpose a
// 2D matrix. Otherwise, returns an ElementsAttr for the identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices input attribute of rank two and returns the specified column.
//
// Always returns a 64-bit integer attribute regardless of the bitwidth of the
// input attribute.
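// For example, slicing column 1 of the 2x2 attribute [[1, 2], [3, 4]] returns
// dense<[2, 4]> : tensor<2xi64>.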
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getIntValues())) {
    if (it.index() % shape[1] == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(64, input.getContext());
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}

// Returns the interior padding to use in an HLO Pad op based on the padding of
// the TensorFlow PadV2 op. PadV2 has no interior padding, so this is always a
// zero vector with one entry per padded dimension.
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(64, tf_padding.getContext());
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; size-1 dimensions broadcast up to any size. A dynamic dimension
// may only be paired with a size-1 dimension or another dynamic dimension.
// Returns false if either value is unranked.
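// For example, shapes [2, 1, 4] and [3, 4] are accepted, while [2, ?] and [3]
// are rejected because the dynamic dimension is paired with a static size
// other than 1.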
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}

// Returns a new TensorType with the same rank and dimensions as the input but
// an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns a 1-d i64 elements attribute populated with numbers from `start` to
// `end` (exclusive).
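// For example, GetI64ElementsAttrForSeq(0, 3, builder) returns
// dense<[0, 1, 2]> : tensor<3xi64>.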
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//

static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        StringRef direction, Region *body,
                                        OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  block->addArguments({input_type, index_type, input_type, index_type});

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(direction, builder->getContext());
  Value compare = builder->create<CompareOp>(
      loc, block->getArgument(0), block->getArgument(2),
      /*broadcast_dimensions=*/nullptr, compare_direction);

  Value selected_input = builder->create<SelectOp>(
      loc, input_type, compare, block->getArgument(0), block->getArgument(2));
  Value selected_index = builder->create<SelectOp>(
      loc, index_type, compare, block->getArgument(1), block->getArgument(3));

  Value return_values[] = {selected_input, selected_index};
  builder->create<ReturnOp>(loc, return_values);
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

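// Returns whether the given TF-style slice (start indices plus slice sizes,
// where a size of -1 means "to the end") can be lowered to an HLO DynamicSlice
// op. For example, an input of shape [4, 8] with constant start indices [1, 2]
// and slice sizes [2, -1] qualifies, whereas a negative start index or a slice
// extending past a statically-known dimension does not.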
static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    for (int64_t i = 0; i < input_rank; ++i) {
      int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
      int64_t input_size = input_shape[i];
      if (slice_size < 0 || (input_size != -1 && slice_size > input_size)) {
        return false;
      }
    }
    return true;
  }

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    if (start_index < 0) return false;
    // A slice_size of -1 means "all elements from start_index to the end".
    // We can't support this semantics for dynamic shapes.
    if (slice_size == -1) {
      if (input_size == -1) return false;
      slice_size = input_size - start_index;
    }
    if (input_size != -1 && start_index + slice_size > input_size) {
      return false;
    }
  }

  return true;
}

// TF slice size can be -1, which represents all elements from start_index to
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
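// For example, with an input of shape [4, 8], constant start indices [1, 2],
// and slice sizes [-1, 3], the returned HLO slice sizes are [3, 3].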
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return xla::ConvertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValue<IntegerAttr>(i).getInt();
    int64_t slice_size = slice_sizes.getValue<IntegerAttr>(i).getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}

//===----------------------------------------------------------------------===//
// Sort op utilities.
//===----------------------------------------------------------------------===//

// Builds the region `body` for xla_hlo.sort's comparator: for each type in
// `element_types`, create two block arguments, one for lhs and one for rhs, and
// generates xla_hlo.compare op to compare them with the given `direction`.
//
// Note that this right now only does comparison on the first pair of block
// arguments.
static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
                                    StringRef direction, Region *body,
                                    OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Block *block = builder->createBlock(body);
  // Add two arguments for each element type.
  for (Type element_type : element_types) {
    TensorType tensor_type = RankedTensorType::get({}, element_type);
    block->addArguments({tensor_type, tensor_type});
  }

  Location loc = body->getLoc();
  StringAttr compare_direction =
      StringAttr::get(direction, builder->getContext());
  Value compare = builder->create<xla_hlo::CompareOp>(
      loc, block->getArgument(0), block->getArgument(1),
      /*broadcast_dimensions=*/nullptr, compare_direction);

  builder->create<xla_hlo::ReturnOp>(loc, compare);
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//

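// Returns the "dimension_numbers" attribute for an HLO convolution from the
// given spatial dimension indices and TF tensor format. For example, a 2-D
// convolution in NHWC format yields input/output batch dimension 0, feature
// dimension 3, and spatial dimensions [1, 2], with kernel (HWIO) input feature
// dimension 2, output feature dimension 3, and spatial dimensions [0, 1].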
NamedAttribute GetConvDimensionNumbersAttr(
    ArrayRef<int64_t> spatial_dim_indices, tensorflow::TensorFormat format,
    Builder *builder) {
  int64_t num_spatial_dims = spatial_dim_indices.size();
  int64_t num_dims = num_spatial_dims + 2;

  IntegerAttr batch_dim =
      builder->getI64IntegerAttr(GetTensorBatchDimIndex(num_dims, format));
  IntegerAttr feature_dim =
      builder->getI64IntegerAttr(GetTensorFeatureDimIndex(num_dims, format));
  DenseIntElementsAttr spatial_dims =
      GetI64ElementsAttr(spatial_dim_indices, builder);

  // Filters data_format is always HWIO so input channels dimension is after
  // all spatial dimensions.
  IntegerAttr kernel_input_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims);
  IntegerAttr kernel_output_feature_dim =
      builder->getI64IntegerAttr(num_spatial_dims + 1);
  DenseIntElementsAttr kernel_spatial_dimensions =
      GetI64ElementsAttrForSeq(0, num_spatial_dims, builder);

  return builder->getNamedAttr(
      "dimension_numbers",
      ConvDimensionNumbers::get(
          batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim,
          kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim,
          feature_dim, spatial_dims, builder->getContext()));
}

// Converts the TensorFlow conv op in template to the generic HLO conv op by
// converting TensorFlow op attributes to HLO op attributes.
//
// Sample result for Conv2D:
//
//   %conv = "xla_hlo.conv"(%input, %filter) {
//     strides = [1, 2],
//     paddings = [[1, 0], [1, 1]],
//     ...
//   }
//
// This pattern is not defined using declarative rewrite rules, as computing
// the paddings attribute requires multiple source op attributes and result op
// attributes anyway. Defining it as a declarative rewrite rule would introduce
// duplication in the C++ helper methods.
template <typename OpT, int num_spatial_dims>
class ConvertConv : public OpRewritePattern<OpT> {
 public:
  using OpRewritePattern<OpT>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(OpT op,
                                     PatternRewriter &rewriter) const override {
    tensorflow::TensorFormat format;
    std::string data_format = op.data_format().str();
    if (!FormatFromString(data_format, &format)) return Pattern::matchFailure();

    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
    auto filter_ty =
        op.filter().getType().template dyn_cast<RankedTensorType>();
    auto result_ty = op.getType().template dyn_cast<RankedTensorType>();

    // Input, filter and the result need to have static shapes for calculation
    // of HLO paddings and feature group count attributes.
    for (RankedTensorType ty : {input_ty, filter_ty, result_ty}) {
      if (!ty || !ty.hasStaticShape()) return Pattern::matchFailure();
    }

    int num_dims = num_spatial_dims + 2;
    tensorflow::Padding padding;
    if (!GetPaddingFromString(op.padding().str(), &padding).ok())
      return Pattern::matchFailure();

    auto get_int = [](Attribute attr) {
      return attr.template cast<IntegerAttr>().getInt();
    };

    SmallVector<int64_t, 4> spatial_dim_indices;
    SmallVector<int64_t, 4> rhs_dilations;
    SmallVector<int64_t, 4> window_strides;
    SmallVector<int64_t, 8> paddings;

    ArrayRef<Attribute> dilations = op.dilations().getValue();
    ArrayRef<Attribute> strides = op.strides().getValue();
    ArrayRef<Attribute> explicit_paddings;
    if (padding == tensorflow::Padding::EXPLICIT) {
      // EXPLICIT padding mode and the associated attribute is limited to
      // Conv2D. So, fetch attribute by identifier instead of the
      // op.explicit_paddings() attribute getter.
      explicit_paddings =
          op.template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
    }

    for (int i = 0; i < num_spatial_dims; ++i) {
      int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
      spatial_dim_indices.push_back(dim);

      int64_t stride = get_int(strides[dim]);
      int64_t dilation = get_int(dilations[dim]);
      window_strides.push_back(stride);
      rhs_dilations.push_back(dilation);

      int64_t pad_low, pad_high;
      if (padding == tensorflow::Padding::EXPLICIT) {
        pad_low = get_int(explicit_paddings[2 * dim]);
        pad_high = get_int(explicit_paddings[2 * dim + 1]);
      } else {
        tensorflow::int64 output_size;
        tensorflow::int64 pad_low_int64;
        tensorflow::int64 pad_high_int64;
        tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
            input_ty.getDimSize(dim), filter_ty.getDimSize(i), dilation, stride,
            padding, &output_size, &pad_low_int64, &pad_high_int64);
        if (!status.ok()) return Pattern::matchFailure();
        pad_low = pad_low_int64;
        pad_high = pad_high_int64;
      }
      paddings.push_back(pad_low);
      paddings.push_back(pad_high);
    }

    auto rhs_dilations_attr = rewriter.getNamedAttr(
        "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));

    auto window_strides_attr = rewriter.getNamedAttr(
        "window_strides", GetI64ElementsAttr(window_strides, &rewriter));

    auto dimension_numbers_attr =
        GetConvDimensionNumbersAttr(spatial_dim_indices, format, &rewriter);

    int64_t input_channels =
        GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, format));
    // Filters data_format is always HWIO so input channels dimension is after
    // all spatial dimensions.
    int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
    // TensorFlow convolution op verifies that the number of input channels is
    // divisible by the number of filter channels.
    int64_t feature_group_count = input_channels / filter_channels;
    auto feature_group_count_attr = rewriter.getNamedAttr(
        "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));

    auto batch_group_count_attr = rewriter.getNamedAttr(
        "batch_group_count", rewriter.getI64IntegerAttr(1));

    RankedTensorType paddings_ty = RankedTensorType::get(
        {num_spatial_dims, 2}, rewriter.getIntegerType(64));
    auto paddings_attr = rewriter.getNamedAttr(
        "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));

    SmallVector<Value, 2> operands(op.getOperands());
    NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
                              dimension_numbers_attr, feature_group_count_attr,
                              batch_group_count_attr, paddings_attr};
    rewriter.replaceOpWithNewOp<ConvOp>(op, op.getType(), operands,
                                        llvm::makeArrayRef(attrs));
    return Pattern::matchSuccess();
  }
};

using ConvertConv2D = ConvertConv<TF::Conv2DOp, /*num_spatial_dims=*/2>;

// Converts BF16 FloorDiv op to have casting operators on either end as BF16
// division can result in strange behavior.
//
//      floordiv = cast(floordiv(cast(left), cast(right)))
//
//   %left_cast = cast(%left)
//   %right_cast = cast(%right)
//   %div = div(%left_cast, %right_cast)
//   %floored = floor(%div)
//   %floored_cast = cast(%floored)
//
// The intermediate types must be specified manually.
class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::FloorDivOp op,
                                     PatternRewriter &rewriter) const override {
    auto l = op.x();
    auto r = op.y();
    auto element_type = getElementTypeOrSelf(l.getType());
    if (!element_type.isBF16()) return matchFailure();

    auto out_type = op.z().getType().cast<TensorType>();

    l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
    r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());

    auto intermediate = rewriter.create<TF::FloorDivOp>(
        op.getLoc(),
        ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
        r);

    auto floor_op =
        rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
    rewriter.replaceOp(op, floor_op.getResult());
    return Pattern::matchSuccess();
  }
};

// Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
// depending on arity of the op.
class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::EinsumOp op,
                                     PatternRewriter &rewriter) const override {
    StringAttr equation = op.getAttrOfType<StringAttr>("equation");
    if (op.N() == 1) {
      rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
          op, op.getType(), *op.inputs().begin(), equation);
    } else if (op.N() == 2) {
      ValueRange inputs = op.inputs();
      rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
                                            inputs[1], equation);
    } else {
      // TensorFlow EinsumOp verifies that the number of operands is at most
      // two.
      return Pattern::matchFailure();
    }
    return Pattern::matchSuccess();
  }
};


// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
// BatchNormGradOp for training and a sequence of binary ops for inference.
// TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
template <typename FusedBatchNormGradOpT>
class ConvertFusedBatchNormGradBase
    : public OpRewritePattern<FusedBatchNormGradOpT> {
 public:
  using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(FusedBatchNormGradOpT op,
                                     PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value grad = op.y_backprop();
    Value act = op.x();
    Value scale = op.scale();
    Value mean = op.reserve_space_1();
    Value var = op.reserve_space_2();

    // TODO(b/141785544): Update this to not require static shapes.
    // The activation shape needs to be static to convert negative indices in
    // TensorFlow to absolute indices required by HLO.
    RankedTensorType act_type =
        act.getType().template dyn_cast<RankedTensorType>();
    if (!act_type) return Pattern::matchFailure();
    Type act_ele_type = act_type.getElementType();
    // To support mixed precision, the statistics type, which may be more
    // precise than the input types, is used for this op.
    Type kernel_type =
        scale.getType().template cast<TensorType>().getElementType();
    grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
    act = rewriter.create<ConvertOp>(loc, act, kernel_type);

    auto feature_dim_attr =
        getFeatureDimensionAttr(rewriter, op.data_formatAttr(), act);
    auto feature_dim = feature_dim_attr.getValue().getSExtValue();

    // Gets the result values.
    Value x_backprop, scale_backprop, offset_backprop;
    if (op.is_training()) {  // training
      // TODO(b/145536565): handle GPU logic separately.
      // Infers the output type with the converted `act`.
      Type feature_type = RankedTensorType::get(
          {GetDimSize(act_type, feature_dim)}, kernel_type);
      Type result_type = TupleType::get(
          {act.getType(), feature_type, feature_type}, rewriter.getContext());

      auto training_op = rewriter.create<BatchNormGradOp>(
          loc, result_type, act, scale, mean, var, grad, op.epsilon(),
          feature_dim_attr.getValue());

      x_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 0);

      scale_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 1);

      offset_backprop =
          rewriter.create<GetTupleElementOp>(loc, training_op.getResult(), 2);
    } else {  // inference
      SmallVector<int64_t, 4> non_feature_dims;
      for (int64_t i = 0; i < act_type.getRank(); ++i) {
        if (i == feature_dim) continue;
        non_feature_dims.push_back(i);
      }
      auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
      auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter);
      auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter);

      // scratch1 = rsqrt(var + epsilon)
      RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
      auto epsilon = rewriter.create<ConstOp>(
          loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
      auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(),
                                           no_broadcast_dims);
      Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);

      // scratch2 = sum(y_backprop * (x - mean))
      auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims);
      auto weighted_grad =
          rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims);
      Value scratch2 =
          ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);

      // x_backprop = y_backprop * (scale * scratch1)
      auto scaled_grad =
          rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims);
      x_backprop =
          rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims);

      // scale_backprop = scratch2 * scratch1
      scale_backprop =
          rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims);

      // offset_backprop = sum(y_backprop)
      offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
    }

    x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
    // It doesn't matter what values we provide for the last 2 results.
    rewriter.replaceOp(op,
                       {/*x_backprop=*/x_backprop,
                        /*scale_backprop=*/scale_backprop,
                        /*offset_backprop=*/offset_backprop, op.x(), op.x()});
    return Pattern::matchSuccess();
  }
};


using ConvertFusedBatchNormGradOp =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
using ConvertFusedBatchNormGradV2Op =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
using ConvertFusedBatchNormGradV3Op =
    ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;

// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
// HLO BatchNormInferenceOp, depending on the value of the 'is_training'
// parameter.
class ConvertFusedBatchNormV3Op
    : public OpRewritePattern<TF::FusedBatchNormV3Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::FusedBatchNormV3Op op,
                                     PatternRewriter &rewriter) const override {
    auto feature_dim =
        getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());

    auto input_type_tensor = op.x().getType().dyn_cast<TensorType>();
    auto input_element_type = input_type_tensor.getElementType();

    auto scale_type_tensor = op.scale().getType().dyn_cast<TensorType>();
    auto scale_element_type = scale_type_tensor.getElementType();
    // In the training case, dimensions of input tensors must be static.
    if (op.is_training() && ((!input_type_tensor.hasStaticShape()) ||
                             (!scale_type_tensor.hasStaticShape()))) {
      return matchFailure();
    }

    // TODO(b/69928690): Support mixed precision in the XLA batch
    // normalization operators. As a workaround, create a new x with the same
    // element type as scale (which may be more precise than the input type).
    Value bn_train_input = rewriter.create<xla_hlo::ConvertOp>(
        op.getLoc(), op.x(), scale_element_type);
    TensorType bn_train_input_type_tensor =
        bn_train_input.getType().cast<TensorType>();

    if (op.is_training()) {
      // Training case.
      auto operand_shape = bn_train_input_type_tensor.getShape();
      // The mean and variance are each 1-dimensional arrays of the size of the
      // feature dimension, with the same element type as the operand (x).
      // This shape must be constructed manually because the mean and variance
      // inputs are empty in the training case.
      Type mean_var_type = RankedTensorType::get(
          {operand_shape[feature_dim.getInt()]}, scale_element_type);
      // The op result type is a tuple of 3 values: the output with the same
      // shape as the input, batch_mean, and batch_var.
      SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
                                            mean_var_type, mean_var_type};
      Type result_type = TupleType::get(operand_types, rewriter.getContext());

      auto bn_train_op = rewriter.create<xla_hlo::BatchNormTrainingOp>(
          op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
          op.epsilon(), feature_dim.getValue());
      // HLO op outputs a tuple of tensors. Extract those results.
      auto bn_train_op_result = bn_train_op.getResult();
      Value y_out = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 0);
      Value batch_mean = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 1);
      Value batch_variance = rewriter.create<xla_hlo::GetTupleElementOp>(
          op.getLoc(), bn_train_op_result, 2);

      // Apply Bessel's correction on the variance.
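      // The correction scales the biased batch variance by
      // sample_size / (sample_size - 1), where sample_size is the number of
      // input elements per feature.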
      int total_input_size = bn_train_input_type_tensor.getNumElements();
      int total_scale_size = scale_type_tensor.getNumElements();
      int sample_size = total_input_size / total_scale_size;
      int sample_size_minus_one = std::max(1, sample_size - 1);
      double factor = static_cast<double>(sample_size) /
                      static_cast<double>(sample_size_minus_one);
      auto factor_const_op = rewriter.create<xla_hlo::ConstOp>(
          op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));

      auto corrected_variance = rewriter.create<xla_hlo::MulOp>(
          op.getLoc(), batch_variance.getType(), batch_variance,
          factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr());

      // Convert back to input type to stay aligned with expected output type
      // for TF op.
      y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), y_out,
                                                  input_element_type);

      // TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are
      // currently marked as "reserved spaces 1 and 2". They are used to
      // pass the per-batch mean and variance to the gradient. Here we
      // maintain the same behavior by setting them to the mean and variance
      // calculated by BatchNormTraining. Output 5 is unused; it doesn't
      // matter what we pass there.
      rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
                              /*batch_variance=*/corrected_variance,
                              /*reserve_space_1=*/batch_mean,
                              /*reserve_space_2=*/corrected_variance,
                              /*reserve_space_3=*/op.x()});
    } else {  // Inference case.
      auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
          op.getLoc(),
          /*result_type=*/bn_train_input_type_tensor, bn_train_input,
          op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
          feature_dim.getValue());

      // Convert back to input type to stay aligned with expected output type
      // for TF op.
      auto y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), bn_train_op,
                                                       input_element_type);

      // The mean, variance, and reserved space outputs of the batch norm op are
      // not used for inference. It doesn't matter what values we provide for
      // the last 5 results.
      rewriter.replaceOp(
          op, {/*y=*/y_out, /*batch_mean=*/op.x(),
               /*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(),
               /*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()});
    }
    return Pattern::matchSuccess();
  }
};

// Returns padding attribute for ReduceWindow op with given params.
//
// Requires padding to be either 'SAME' or 'VALID' and the number of input
// dimensions to be equal to the size of window dimensions and window strides.
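// For example, with input_dims = [4], a window dimension of 3, a stride of 1,
// and 'SAME' padding, the returned attribute is dense<[[1, 1]]>.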
static DenseIntElementsAttr GetReduceWindowPadding(
    llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
    ArrayAttr window_strides, StringRef padding, Builder *builder) {
  if (padding == "VALID") return {};
  DCHECK_EQ(padding.str(), "SAME");

  llvm::SmallVector<tensorflow::int64, 4> input_shape, window_shape, strides;
  input_shape.reserve(input_dims.size());
  window_shape.reserve(window_dims.size());
  strides.reserve(window_strides.size());

  for (const auto &dim : input_dims) input_shape.push_back(dim);
  for (Attribute attr : window_dims)
    window_shape.push_back(attr.cast<IntegerAttr>().getInt());
  for (Attribute attr : window_strides)
    strides.push_back(attr.cast<IntegerAttr>().getInt());

  std::vector<std::pair<tensorflow::int64, tensorflow::int64>> paddings =
      ::xla::MakePadding(input_shape, window_shape, strides,
                         ::xla::Padding::kSame);
  int64_t rank = paddings.size();
  llvm::SmallVector<int64_t, 8> flatten_paddings(rank * 2);
  for (int i = 0; i < rank; i++) {
    flatten_paddings[2 * i] = paddings[i].first;
    flatten_paddings[2 * i + 1] = paddings[i].second;
  }
  return DenseIntElementsAttr::get(
      RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
      flatten_paddings);
}

// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with add as the reduction function. The reduction result is
// then divided by the number of elements in the window.
class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  PatternMatchResult matchAndRewrite(TF::AvgPoolOp op,
                                     PatternRewriter &rewriter) const override {
    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return matchFailure();

    // TODO(b/147217034): support other data formats.
    if (!IsDefaultDataFormat(op.data_format())) return matchFailure();
    // TODO(b/147217034): support "SAME" padding.
    if (op.padding() != "VALID") return matchFailure();

    // We will do accumulation first; use a larger bitwidth if suitable.
    Type input_element_type = input_type.getElementType();
    Type sum_element_type = GetSumAccumulationType(input_element_type);
    Type result_type;

    // The result type for reduction and division with the proper element type.
    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>())
      result_type =
          RankedTensorType::get(ranked_type.getShape(), sum_element_type);
    else
      result_type = UnrankedTensorType::get(sum_element_type);

    Value input_value = op.value();

    // Convert if we need enlarge the element type's bitwidth.
    if (input_element_type != sum_element_type)
      input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
                                               sum_element_type);

    // Create the ReduceWindow op.
    Value init =
        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
    DenseIntElementsAttr paddings_attr =
        GetReduceWindowPadding(input_type.getShape(), op.ksize(), op.strides(),
                               op.padding(), &rewriter);
    auto reduce = rewriter.create<ReduceWindowOp>(
        op.getLoc(), result_type, input_value, init,
        GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
        /*base_dilations=*/DenseIntElementsAttr(),
        /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);

    // Count the number of elements in the window. The following calculation
    // is only valid when there is no padding.
    SmallVector<int64_t, 4> ksize;
    GetI64ArrayAttrValues(op.ksize(), &ksize);
    int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1,
                                    std::multiplies<int64_t>());

    // Divide by the number of elements in the window.
    Value divisor =
        GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
    auto batch_dims =
        GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter);
    Value result = rewriter.create<DivOp>(op.getLoc(), result_type, reduce,
                                          divisor, batch_dims);

    // Convert back if we enlarged the element type's bitwidth.
    if (input_element_type != sum_element_type)
      result =
          rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);

    rewriter.replaceOp(op, result);
    return matchSuccess();
  }
};

1236
1237 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
1238 // dimensions with max as the reduction function.
1239 //
1240 // Sample result for VALID padding mode:
1241 //
1242 // %init = constant dense<...> : tensor<i32>
1243 // %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"]
1244 // {window_dimensions = ..., window_strides = ... }
1245 //
1246 class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
1247 public:
1248 using OpRewritePattern::OpRewritePattern;
1249
matchAndRewrite(TF::MaxPoolOp op,PatternRewriter & rewriter) const1250 PatternMatchResult matchAndRewrite(TF::MaxPoolOp op,
1251 PatternRewriter &rewriter) const override {
1252 Type element_type =
1253 op.input().getType().cast<TensorType>().getElementType();
1254 if (!element_type.isIntOrFloat()) return matchFailure();
1255 Location loc = op.getLoc();
1256 ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
1257
1258 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1259 if (!input_ty) return matchFailure();
1260 DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
1261 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
1262 auto reduce = rewriter.create<ReduceWindowOp>(
1263 loc, op.getType(), op.input(), init.getResult(),
1264 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
1265 /*base_dilations=*/DenseIntElementsAttr(),
1266 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
1267 BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
1268
1269 rewriter.replaceOp(op, reduce.getResult());
1270 return matchSuccess();
1271 }
1272 };
1273
1274 // Converts SelectV2 to HLO Select op and necessary BroadcastInDim ops on
1275 // operands.
1276 //
1277 // For example, the following source IR:
1278 //
1279 // %select = "tf.SelectV2"(%condition, %t, %e) :
1280 // (tensor<1xi1>, tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
1281 //
1282 // will be converted into:
1283 //
1284 // %pred = "xla_hlo.broadcast_in_dim"(%cond)
1285 // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
1286 // (tensor<1xi1>) -> tensor<2xi1>
1287 // %on_false = "xla_hlo.broadcast_in_dim"(%e)
1288 // {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
1289 // (tensor<1xi32>) -> tensor<2xi32>
1290 // %select = "xla_hlo.select"(%pred, %t, %on_false) :
1291 // (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
1292 class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
1293 public:
1294 using OpRewritePattern::OpRewritePattern;
1295
1296   PatternMatchResult matchAndRewrite(TF::SelectV2Op op,
1297 PatternRewriter &rewriter) const override {
1298 llvm::SmallVector<int64_t, 4> broadcast_then_else_shape;
1299 auto ranked_then_type = op.t().getType().dyn_cast<RankedTensorType>();
1300 auto ranked_else_type = op.e().getType().dyn_cast<RankedTensorType>();
1301 auto ranked_cond_type =
1302 op.condition().getType().dyn_cast<RankedTensorType>();
1303 if (!ranked_then_type || !ranked_then_type.hasStaticShape() ||
1304 !ranked_else_type || !ranked_else_type.hasStaticShape() ||
1305 !ranked_cond_type || !ranked_cond_type.hasStaticShape())
1306 return matchFailure();
1307
1308 if (!OpTrait::util::getBroadcastedShape(ranked_then_type.getShape(),
1309 ranked_else_type.getShape(),
1310 broadcast_then_else_shape))
1311 return matchFailure();
1312
1313 llvm::SmallVector<int64_t, 4> broadcast_shape;
1314 if (!OpTrait::util::getBroadcastedShape(broadcast_then_else_shape,
1315 ranked_cond_type.getShape(),
1316 broadcast_shape))
1317 return matchFailure();
1318
1319 auto broadcast_or_self = [&](Value value) {
1320 RankedTensorType type = value.getType().cast<RankedTensorType>();
1321 auto output_type =
1322 RankedTensorType::get(broadcast_shape, type.getElementType());
1323 if (output_type == type) return value;
1324
1325 int64_t rank = type.getRank();
1326 SmallVector<int64_t, 4> broadcast_dimensions(rank);
1327 std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
1328 broadcast_shape.size() - rank);
1329
1330 return rewriter
1331 .create<BroadcastInDimOp>(
1332 op.getLoc(), output_type, value,
1333 GetI64ElementsAttr(broadcast_dimensions, &rewriter))
1334 .getResult();
1335 };
1336
1337 // HLO SelectOp supports broadcasting for predicate/condition if
1338 // predicate/condition is a scalar.
1339 Value pred = ranked_cond_type.getRank() == 0
1340 ? op.condition()
1341 : broadcast_or_self(op.condition());
1342 Value on_true = broadcast_or_self(op.t());
1343 Value on_false = broadcast_or_self(op.e());
1344
1345 rewriter.replaceOpWithNewOp<SelectOp>(op, on_true.getType(), pred, on_true,
1346 on_false);
1347
1348 return matchSuccess();
1349   }
1350 };
1351
1352 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
1353 //
1354 // sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
1355 //
1356 // Sample result for a 1-d f32 input with 2 elements.
1357 //
1358 // // Create an array of 0.5 the shape of the input array.
1359 // %half = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
1360 //    %half_array = "xla_hlo.broadcast"(%half)
1361 // {broadcast_sizes = dense<2> : tensor<1xi64>}
1362 // : (tensor<f32>) -> tensor<2xf32>
1363 //
1364 // // Compute Tanh of half the logits of the values.
1365 // %halved_logits = xla_hlo.mul %logits, %half_array : tensor<2xf32>
1366 // %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
1367 //
1368 //    // Halve the result of Tanh and add 0.5.
1369 // %halved_tanh = xla_hlo.mul %tanh, %half : tensor<2xf32>
1370 // %sigmoid = xla_hlo.add %halved_tanh, %half : tensor<2xf32>
1371 //
1372 class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
1373 public:
1374 using OpRewritePattern::OpRewritePattern;
1375
1376   PatternMatchResult matchAndRewrite(TF::SigmoidOp op,
1377 PatternRewriter &rewriter) const override {
1378 auto operand = op.getOperand();
1379
1380 auto scalar_one = rewriter.create<ConstOp>(
1381 op.getLoc(),
1382 rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
1383
1384 auto shaped_type = operand.getType().cast<ShapedType>();
1385 auto constant_ones = rewriter.create<BroadcastOp>(
1386 op.getLoc(), shaped_type, scalar_one,
1387 DenseIntElementsAttr::get(
1388 RankedTensorType::get({shaped_type.getRank()},
1389 rewriter.getIntegerType(64)),
1390 shaped_type.getShape()));
1391
1392 auto scaled_input = rewriter.create<MulOp>(
1393 op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
1394 auto tanh_op =
1395 rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
1396 auto mul_op =
1397 rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
1398                                /*broadcast_dimensions=*/DenseIntElementsAttr());
1399 auto add_op =
1400 rewriter.create<AddOp>(op.getLoc(), mul_op, constant_ones,
1401                                /*broadcast_dimensions=*/DenseIntElementsAttr());
1402
1403 rewriter.replaceOp(op, add_op.getResult());
1404 return matchSuccess();
1405 }
1406 };
1407
1408 // Converts Softmax and LogSoftmax to HLO ops, computing softmax with the
1409 // following formulas:
1410 //
1411 // softmax = div(exp(logits), sum(exp(logits)))
1412 //
1413 // log_softmax = sub(logits, log(sum(exp(logits))))
1414 //
1415 // Sample result with 2-d f16 inputs with B batches of N elements each.
1416 //
1417 // %reduce_dim = tf.Const dense<[1]> : tensor<1xi64>
1418 //
1419 // // Subtract each element by their batches' max to improve numerical
1420 // // stability.
1421 // %max = "tf.Max"(%input, %reduce_dim)
1422 // : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
1423 // %sub = "xla_hlo.sub"(%inp, %max) {broadcast_dimensions = 0}
1424 // : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
1425 //
1426 // %exp = "xla_hlo.exp"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
1427 // %sum = "tf.Sum"(%exp, %reduce_dim)
1428 // : (tensor<BxNxf32>, tensor<1xi64>) -> tensor<Bxf32>
1429 //
1430 // // Softmax computation:
1431 // %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0}
1432 // : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
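//
// When lowering tf.LogSoftmax (the use_log = true instantiation of the
// pattern below), a rough sketch of the final step, reusing %sub and the
// summed exponentials, is:
//
//    %log = "xla_hlo.log"(%sum_f16) : (tensor<Bxf16>) -> tensor<Bxf16>
//    %log_softmax = "xla_hlo.sub"(%sub, %log) {broadcast_dimensions = 0}
//            : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>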
1433 template <typename OpTy, bool use_log = true>
1434 class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
1435 public:
1436 using OpRewritePattern<OpTy>::OpRewritePattern;
1437
1438   PatternMatchResult matchAndRewrite(OpTy op,
1439 PatternRewriter &rewriter) const override {
1440 Value logits = op.logits();
1441
1442 // Softmax converter requires ranked type because the XLA reduce ops used
1443 // while lowering requires dimensions attribute to reduce along.
1444 RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
1445 if (!type) return Pattern::matchFailure();
1446
1447 auto loc = op.getLoc();
1448 int rank = type.getRank();
1449
1450 // Note that the TensorFlow Softmax op verifies that the input rank is
1451 // greater than or equal to one so both of the following sequences are
1452 // valid.
1453 auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter);
1454 auto reduce_dim = rewriter.create<TF::ConstOp>(
1455 loc, GetI64ElementsAttr({rank - 1}, &rewriter));
1456
1457 // Exponential of input values and then their sum can be very large here.
1458 // Division with large denominator is numerically unstable. To improve
1459 // numerical stability, subtract each batch with their max element so that
1460 // the maximum input value is zero. It can be shown that softmax computed
1461 // after adding or subtracting all inputs in a batch using a common value
1462     // gives a mathematically equivalent result.
1463 auto max_logits =
1464 rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
1465 /*keep_dims=*/rewriter.getBoolAttr(false));
1466 auto shifted_logits =
1467 rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims);
1468
1469 // Exponentiate the inputs.
1470 Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
1471
1472 // Compute summation of the exponentials.
1473 auto exp_sum =
1474 rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
1475 /*keep_dims=*/rewriter.getBoolAttr(false));
1476 Value sum = exp_sum.getResult();
1477
1478 if (use_log) {
1479 Value log = rewriter.create<LogOp>(loc, sum);
1480 rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims);
1481 } else {
1482 rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims);
1483 }
1484 return Pattern::matchSuccess();
1485 }
1486 };
1487
1488 // Converts Size to HLO ops, computing the size of a ranked input tensor.
1489 // TODO(b/145253252): Update this to not require ranked input tensor shapes.
1490 //
1491 // The main logic of this pattern is to calculate the size by multiplying every
1492 // dimension of the input tensor's shape together.
1493 //
1494 // For example, the following source IR:
1495 //
1496 // %size = "tf.Size"(%input) : (tensor<2x?x8xf32>) -> tensor<i32>
1497 //
1498 // will be converted into:
1499 //
1500 // %const = xla_hlo.constant dense<1> : tensor<i32>
1501 // %dim_0 = "xla_hlo.get_dimension_size"(%input) {dimension = 0 : i32} :
1502 // (tensor<2x?x8xf32>) -> tensor<i32>
1503 // %prod_0 = xla_hlo.mul %const, %dim_0 : tensor<i32>
1504 // %dim_1 = "xla_hlo.get_dimension_size"(%input) {dimension = 1 : i32} :
1505 // (tensor<2x?x8xf32>) -> tensor<i32>
1506 // %prod_1 = xla_hlo.mul %prod_0, %dim_1 : tensor<i32>
1507 // %dim_2 = "xla_hlo.get_dimension_size"(%input) {dimension = 2 : i32} :
1508 // (tensor<2x?x8xf32>) -> tensor<i32>
1509 // %size = xla_hlo.mul %prod_1, %dim_2 : tensor<i32>
1510 class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
1511 public:
1512 using OpRewritePattern::OpRewritePattern;
1513
1514   PatternMatchResult matchAndRewrite(TF::SizeOp op,
1515 PatternRewriter &rewriter) const override {
1516 Value input = op.input();
1517 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
1518 if (!input_ty) return Pattern::matchFailure();
1519
1520 const int64_t rank = input_ty.getRank();
1521 auto result_type = op.getResult().getType();
1522 Operation *size =
1523 GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
1524 op.getLoc(), 1, &rewriter);
1525 for (int64_t i = 0; i < rank; ++i) {
1526 auto dim = rewriter.create<GetDimensionSizeOp>(
1527 op.getLoc(), result_type, input,
1528 rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
1529 size = rewriter.create<MulOp>(
1530 op.getLoc(), size->getResult(0), dim.getResult(),
1531           /*broadcast_dimensions=*/DenseIntElementsAttr());
1532 }
1533 rewriter.replaceOp(op, size->getResult(0));
1534
1535 return Pattern::matchSuccess();
1536 }
1537 };
1538
1539 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
1540 // split has fully static shape and the dimension to split is a constant.
1541 //
1542 // The main logic of this pattern is to calculate the index start and end range
1543 // for each slice. This happens only along the dimension to be split; along all
1544 // other dimensions, each resultant slice covers the input tensor's full range.
1545 // Strides for all resultant slices are one.
1546 //
1547 // For example, the following source IR:
1548 //
1549 // %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
1550 // %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
1551 // (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
1552 //
1553 // will be converted into:
1554 //
1555 // %0 = "xla_hlo.slice"(%input) {
1556 // limit_indices = dense<[4, 2]> : tensor<2xi64>,
1557 // start_indices = dense<0> : tensor<2xi64>,
1558 // strides = dense<1> : tensor<2xi64>} :
1559 // (tensor<4x6xf32>) -> tensor<4x2xf32>
1560 // %1 = "xla_hlo.slice"(%input) {
1561 // limit_indices = dense<4> : tensor<2xi64>,
1562 // start_indices = dense<[0, 2]> : tensor<2xi64>,
1563 // strides = dense<1> : tensor<2xi64>} :
1564 // (tensor<4x6xf32>) -> tensor<4x2xf32>
1565 // %2 = "xla_hlo.slice"(%input) {
1566 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
1567 // start_indices = dense<[0, 4]> : tensor<2xi64>,
1568 // strides = dense<1> : tensor<2xi64>} :
1569 // (tensor<4x6xf32>) -> tensor<4x2xf32>
1570 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
1571 // applicable.
1572 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
1573 public:
1574 using OpRewritePattern::OpRewritePattern;
1575
1576   PatternMatchResult matchAndRewrite(TF::SplitOp op,
1577 PatternRewriter &rewriter) const override {
1578 // We can only split along static dimensions.
1579 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1580 if (!input_type) return matchFailure();
1581
1582 // We can only match when the split dimension is a constant scalar.
1583 DenseIntElementsAttr split_dim_attr;
1584 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
1585 return matchFailure();
1586
1587 // Get the dimension we are splitting at. Offset properly if it's negative.
1588 int64_t input_rank = input_type.getRank();
1589 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
1590 if (dim_index < 0) dim_index += input_rank;
1591
1592 // Calculate the dimension size for each slice along the split dimension.
1593 int64_t input_dim_size = input_type.getDimSize(dim_index);
1594 // If we are splitting along the dynamic dimension then we cannot compute
1595 // the static dimension length.
1596 if (TensorType::isDynamic(input_dim_size)) return matchFailure();
1597
1598 int64_t num_splits = op.getNumResults();
1599 int64_t slice_size = input_dim_size / num_splits;
1600
1601 // Get each slice's type.
1602 auto slice_shape = llvm::to_vector<4>(input_type.getShape());
1603 slice_shape[dim_index] = slice_size;
1604 Type slice_type =
1605 RankedTensorType::get(slice_shape, input_type.getElementType());
1606
1607 // Parameters for constructing each slice.
1608 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
1609 auto end_indices = llvm::to_vector<4>(input_type.getShape());
1610 SmallVector<int64_t, 4> strides(input_rank, 1);
1611
1612 // All HLO slice results used to replace the original tf.Split op.
1613 SmallVector<Value, 4> slices;
1614 slices.reserve(num_splits);
1615
1616 for (int i = 0; i < num_splits; ++i) {
1617 begin_indices[dim_index] = i * slice_size;
1618 end_indices[dim_index] = (i + 1) * slice_size;
1619 slices.push_back(
1620 rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
1621 GetI64ElementsAttr(begin_indices, &rewriter),
1622 GetI64ElementsAttr(end_indices, &rewriter),
1623 GetI64ElementsAttr(strides, &rewriter)));
1624 }
1625
1626 rewriter.replaceOp(op, slices);
1627 return matchSuccess();
1628 }
1629 };
1630
1631 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
1632 // be split has fully static shape and the dimension to split and split sizes
1633 // are constants.
1634 //
1635 // This is similar to the conversion for tf.Split op other than that the size of
1636 // each chunk on the dimension to split is explicitly given as an op operand
1637 // and they are not necessarily the same.
1638 //
1639 // For example, given the following IR:
1640 //
1641 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
1642 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
1643 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
1644 // (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
1645 // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
1646 //
1647 // We will generate the following slices:
1648 // %0 = "xla_hlo.slice"(%input) {
1649 // limit_indices = dense<[4, 1]> : tensor<2xi64>,
1650 // start_indices = dense<0> : tensor<2xi64>,
1651 // strides = dense<1> : tensor<2xi64>} :
1652 // (tensor<4x6xf32>) -> tensor<4x1xf32>
1653 // %1 = "xla_hlo.slice"(%input) {
1654 // limit_indices = dense<[4, 3]> : tensor<2xi64>,
1655 // start_indices = dense<[0, 1]> : tensor<2xi64>,
1656 // strides = dense<1> : tensor<2xi64>} :
1657 // (tensor<4x6xf32>) -> tensor<4x2xf32>
1658 // %2 = "xla_hlo.slice"(%input) {
1659 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
1660 // start_indices = dense<[0, 3]> : tensor<2xi64>,
1661 // strides = dense<1> : tensor<2xi64>} :
1662 // (tensor<4x6xf32>) -> tensor<4x3xf32>
1663 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
1664 public:
1665 using OpRewritePattern::OpRewritePattern;
1666
1667   PatternMatchResult matchAndRewrite(TF::SplitVOp op,
1668 PatternRewriter &rewriter) const override {
1669 // We can only split along static dimensions.
1670 // TODO(b/145731001): enhance to support dynamic-shaped inputs.
1671 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
1672 if (!input_type) return matchFailure();
1673
1674 // We can only match when the split dimension is a constant scalar.
1675 DenseIntElementsAttr split_dim_attr;
1676 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
1677 return matchFailure();
1678
1679 // We can only match when the split sizes is a constant int vector.
1680 DenseIntElementsAttr split_sizes_attr;
1681 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
1682 return matchFailure();
1683
1684     // Get each chunk's size along the dimension to split. It may contain
1685 // dynamic sizes and we need to update it if so.
1686 SmallVector<int64_t, 4> split_sizes;
1687 int64_t total_dim_size = 0; // Total dimension size assigned to splits
1688 llvm::Optional<int> dynamic_dim_index;
1689 split_sizes.reserve(
1690 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
1691 for (auto dim : llvm::enumerate(split_sizes_attr)) {
1692 int64_t dim_val = dim.value().getSExtValue();
1693 split_sizes.push_back(dim_val);
1694 if (dim_val == ShapedType::kDynamicSize) {
1695 // We cannot have more than one dynamic dimension.
1696 assert(!dynamic_dim_index && "invalid split sizes");
1697 dynamic_dim_index = dim.index();
1698 } else {
1699 total_dim_size += dim_val;
1700 }
1701 }
1702
1703 // Get the dimension we are splitting at. Offset properly if it's negative.
1704 int64_t input_rank = input_type.getRank();
1705 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
1706 if (dim_index < 0) dim_index += input_rank;
1707
1708 int64_t input_dim_size = input_type.getDimSize(dim_index);
1709 if (TensorType::isDynamic(input_dim_size)) return matchFailure();
1710
1711 assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
1712 (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
1713 "invalid split sizes");
1714
1715 // Update the dynamic dimension with calculated concrete size.
1716 if (dynamic_dim_index)
1717 split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
1718
1719 // Parameters for constructing each slice.
1720 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
1721 auto end_indices = llvm::to_vector<4>(input_type.getShape());
1722 SmallVector<int64_t, 4> strides(input_rank, 1);
1723
1724 // All HLO slice results used to replace the original tf.Split op.
1725 SmallVector<Value, 4> slices;
1726 slices.reserve(op.getNumResults());
1727
1728 for (int i = 0; i < op.getNumResults(); ++i) {
1729 end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
1730 slices.push_back(rewriter.create<xla_hlo::SliceOp>(
1731 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
1732 GetI64ElementsAttr(end_indices, &rewriter),
1733 GetI64ElementsAttr(strides, &rewriter)));
1734       // Prepare the begin index for the next slice.
1735 begin_indices[dim_index] = end_indices[dim_index];
1736 }
1737
1738 rewriter.replaceOp(op, slices);
1739 return matchSuccess();
1740 }
1741 };
1742
1743 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
1744 // negative strides and Reshape op to update the output shape. Indices and
1745 // strides operands are converted to attributes with non-negative indexing.
1746 //
1747 // For example with an op like following,
1748 // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
1749 // : tensor<AxBxf32> -> tensor<Pxf32>
1750 //
1751 // Output would be:
1752 // %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
1753 // %sliced = "xla_hlo.Slice" (%input)
1754 // {start_indices = ..., limit_indices = ..., strides = ...}
1755 // %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
1756 //
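//
// For instance (a hypothetical case with all-positive strides and no masks),
// slicing a tensor<4x8xf32> with constant begin = [1, 2], end = [3, 8] and
// strides = [1, 2] would produce roughly:
//
//   %sliced = "xla_hlo.slice"(%input)
//       {start_indices = dense<[1, 2]> : tensor<2xi64>,
//        limit_indices = dense<[3, 8]> : tensor<2xi64>,
//        strides = dense<[1, 2]> : tensor<2xi64>} :
//       (tensor<4x8xf32>) -> tensor<2x3xf32>
//   %output = "xla_hlo.reshape"(%sliced) : (tensor<2x3xf32>) -> tensor<2x3xf32>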
1757 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
1758 public:
1759 using OpRewritePattern::OpRewritePattern;
1760
1761   PatternMatchResult matchAndRewrite(TF::StridedSliceOp op,
1762 PatternRewriter &rewriter) const override {
1763 // Input shape needs to be static to convert negative indices in TensorFlow
1764 // to absolute indices required by HLO.
1765 //
1766 // TODO(hinsu): Relax this constraint for ops without negative indices and
1767 // strides.
1768 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1769 if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
1770 ArrayRef<int64_t> input_shape = input_ty.getShape();
1771
1772 // Output shape needs to be static to apply 'new_axis_mask' or
1773 // 'shrink_axis_mask' by reshaping tensor after slice.
1774 //
1775 // TODO(hinsu): Relax this constraint for ops without the above masks.
1776 auto result_ty = op.getType().dyn_cast<RankedTensorType>();
1777 if (!result_ty || !result_ty.hasStaticShape()) return matchFailure();
1778
1779 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
1780 if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides))
1781 return matchFailure();
1782
1783 SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
1784 dims_to_reverse;
1785 int64_t input_rank = input_ty.getRank();
1786 hlo_begin_indices.reserve(input_rank);
1787 hlo_end_indices.reserve(input_rank);
1788 hlo_strides.reserve(input_rank);
1789
1790 int64_t indices_elements = begin_indices.size();
1791 if (input_rank < indices_elements) return matchFailure();
1792
1793 // Convert from TensorFlow negative or out of range indices and strides
1794 // values to legal HLO Slice attributes.
1795 for (int i = 0, e = indices_elements; i != e; i++) {
1796 int64_t begin = begin_indices[i];
1797 int64_t end = end_indices[i];
1798 int64_t stride = strides[i];
1799
1800 if (stride < 0) {
1801 // Negative stride means that the output values are computed starting
1802 // from end until begin. Mark the dimension for reversal before slice
1803 // and compute indices for the reversed input.
1804 dims_to_reverse.push_back(i);
1805 begin = (input_shape[i] - 1) - begin;
1806 end = (input_shape[i] - 1) - end;
1807 stride = -stride;
1808 }
1809
1810 // Unlike TensorFlow, HLO requires begin and end values to be within
1811 // range.
1812 begin = std::max(int64_t(0), begin);
1813 end = std::max(begin, end);
1814 end = std::min(end, input_shape[i]);
1815
1816 hlo_begin_indices.push_back(begin);
1817 hlo_end_indices.push_back(end);
1818 hlo_strides.push_back(stride);
1819 }
1820
1821 Location loc = op.getLoc();
1822 Value input = op.input();
1823 if (!dims_to_reverse.empty())
1824 input = rewriter.create<ReverseOp>(
1825 loc, input_ty, op.input(),
1826 GetI64ElementsAttr(dims_to_reverse, &rewriter));
1827 auto sliced = rewriter.create<SliceOp>(
1828 loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
1829 GetI64ElementsAttr(hlo_end_indices, &rewriter),
1830 GetI64ElementsAttr(hlo_strides, &rewriter));
1831
1832 // Reshape slice result so that the shape is updated depending on
1833 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
1834 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
1835 return matchSuccess();
1836 }
1837 };
1838
1839 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
1840 //
1841 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
1842 // the reverse: it propagates the gradient of the sliced tensor back to the
1843 // original input tensor by padding it with zeros. The main logic is calculating the
1844 // indices and strides for padding.
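//
// A rough sketch of the generated IR, assuming a positive stride and a 1-D
// tensor<4xf32> input whose forward slice produced a tensor<2xf32>:
//
//   %grad = "xla_hlo.reshape"(%dy) : (tensor<2xf32>) -> tensor<2xf32>
//   %zero = xla_hlo.constant dense<0.0> : tensor<f32>
//   %result = "xla_hlo.pad"(%grad, %zero)
//       {edge_padding_low = ..., edge_padding_high = ..., interior_padding = ...}
//       : (tensor<2xf32>, tensor<f32>) -> tensor<4xf32>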
1845 class ConvertStridedSliceGradOp
1846 : public OpRewritePattern<TF::StridedSliceGradOp> {
1847 public:
1848 using OpRewritePattern::OpRewritePattern;
1849
1850   PatternMatchResult matchAndRewrite(TF::StridedSliceGradOp op,
1851 PatternRewriter &rewriter) const override {
1852 // We need constant input shape to perform padding calculations later.
1853 DenseIntElementsAttr input_shape_attr;
1854 if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
1855 return matchFailure();
1856
1857 // We also need constant begin/end indices and strides to perform padding
1858 // calculations.
1859 // Bounded shape after performing strided slice
1860 SmallVector<int64_t, 4> shape;
1861 // Bounded begin, end, and strides for strided slice
1862 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
1863 if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
1864 &strides))
1865 return matchFailure();
1866
1867 Value grad = op.dy();
1868 Type element_type = grad.getType().cast<ShapedType>().getElementType();
1869
1870     // Perform reshape to undo any new/shrink axes done by strided slice.
1871 grad = rewriter.create<xla_hlo::ReshapeOp>(
1872 op.getLoc(), RankedTensorType::get(shape, element_type), grad);
1873
1874 SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
1875 SmallVector<int64_t, 4> dims_to_reverse;
1876 padding_low.reserve(shape.size());
1877 padding_high.reserve(shape.size());
1878 padding_interm.reserve(shape.size());
1879
1880 // Prepare padding parameters for each dimension.
1881 for (int i = 0, e = shape.size(); i < e; ++i) {
1882 int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
1883 if (strides[i] > 0) {
1884 padding_low.push_back(begin_indices[i]);
1885 padding_interm.push_back(strides[i] - 1);
1886
1887 // Pad the upper dimension up to the expected input shape. It's not
1888 // sufficient simply to use end_indices[i] to compute the padding in
1889 // cases where the stride does not divide evenly into the interval
1890 // between begin_indices[i] and end_indices[i].
1891 int64_t size =
1892 padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
1893 padding_high.push_back(input_dim - size);
1894 } else {
1895 dims_to_reverse.push_back(i);
1896 padding_high.push_back(input_dim - begin_indices[i] - 1);
1897 padding_interm.push_back(-strides[i] - 1);
1898
1899 // Pad the lower dimension up to the expected input shape.
1900 int64_t size =
1901 padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
1902 padding_low.push_back(input_dim - size);
1903 }
1904 }
1905
1906 if (!dims_to_reverse.empty()) {
1907 grad = rewriter.create<xla_hlo::ReverseOp>(
1908 op.getLoc(), grad.getType(), grad,
1909 GetI64ElementsAttr(dims_to_reverse, &rewriter));
1910 }
1911
1912 auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
1913 rewriter.replaceOpWithNewOp<xla_hlo::PadOp>(
1914 op, op.getType(), grad, zero,
1915 GetI64ElementsAttr(padding_low, &rewriter),
1916 GetI64ElementsAttr(padding_high, &rewriter),
1917 GetI64ElementsAttr(padding_interm, &rewriter));
1918 return matchSuccess();
1919 }
1920 };
1921
1922 /// Converts the RangeOp tensorflow op to a xla_hlo.iota op with a scaling and
1923 /// offset applied to generate the range values. The output tensor needs to
1924 /// have a static shape.
1925 ///
1926 /// For example an op like the following:
1927 /// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
1928 /// : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
1929 ///
1930 /// Output would be:
1931 /// %iota = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
1932 /// %scaled = "xla_hlo.mul"(%iota, %delta)
1933 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
1934 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
1935 /// %result = "xla_hlo.add"(%scaled, %offset)
1936 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
1937 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
1938 ///
1939 /// Implementation is defined in C++ because the iota op has no type interface.
1940 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
1941 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
1942
1943   PatternMatchResult matchAndRewrite(TF::RangeOp op,
1944 PatternRewriter &rewriter) const override {
1945 auto result = op.getResult();
1946 auto result_type = result.getType();
1947 if (!result_type.cast<ShapedType>().hasStaticShape()) {
1948 return matchFailure();
1949 }
1950
1951 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
1952 rewriter.getI64IntegerAttr(0));
1953 auto scaled = rewriter.create<MulOp>(
1954 op.getLoc(), result_type, iota, op.delta(),
1955 xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
1956 rewriter.replaceOpWithNewOp<AddOp>(
1957 op, result_type, scaled, op.start(),
1958 xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
1959 return matchSuccess();
1960 }
1961 };
1962
1963 /// Converts a generic OpTy tensorflow op to a xla_hlo.reduce op over
1964 /// ReductionOp.
1965 /// `is_accumulation` controls whether it uses higher precision for the actual
1966 /// reduction. This is set to false for ops like max where there are no precision
1967 /// concerns.
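///
/// As a rough sketch, a tf.Sum over dimension 1 of a tensor<4x8xf16> (with
/// keep_dims = false, and assuming the accumulation type for f16 is f32)
/// would lower to something like:
///
///   %cast = "xla_hlo.convert"(%input) : (tensor<4x8xf16>) -> tensor<4x8xf32>
///   %init = xla_hlo.constant dense<0.0> : tensor<f32>
///   %sum = "xla_hlo.reduce"(%cast, %init) ["xla_hlo.add"]
///       {dimensions = dense<1> : tensor<1xi64>} : (...) -> tensor<4xf32>
///   %result = "xla_hlo.convert"(%sum) : (tensor<4xf32>) -> tensor<4xf16>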
1968 template <typename Derived, typename OpTy, typename ReductionOp,
1969 bool is_accumulation = true>
1970 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
1971 using OpRewritePattern<OpTy>::OpRewritePattern;
1972
1973   PatternMatchResult matchAndRewrite(OpTy op,
1974 PatternRewriter &rewriter) const override {
1975 // TODO(b/141785544): Update this to not require static shapes.
1976 // Input shape needs to be static to convert negative indices in TensorFlow
1977 // to absolute indices required by HLO.
1978 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1979 if (!input_ty) return this->matchFailure();
1980 ArrayRef<int64_t> input_shape = input_ty.getShape();
1981
1982 DenseIntElementsAttr dimensions;
1983 if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
1984 return this->matchFailure();
1985
1986 // Build the final shape from input_shape and dimensions using a bitmap
1987 // to mark the reduced dimensions.
1988 SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
1989 SmallVector<int64_t, 4> xla_dimensions;
1990 for (APInt index_raw : dimensions.getValues<APInt>()) {
1991 int64_t index = index_raw.getSExtValue();
1992 int64_t rank = input_shape.size();
1993 if ((index < -rank || index >= rank)) return this->matchFailure();
1994 index = (index + rank) % rank;
1995 reduced_dimensions_bitmap[index] = true;
1996 xla_dimensions.push_back(index);
1997 }
1998
1999 Location loc = op.getLoc();
2000 Type element_type = input_ty.getElementType();
2001 // Convert to an accumulation type to not lose precision when doing
2002 // repeated arithmetic operations.
2003 Type reduce_element_type =
2004 is_accumulation ? GetAccumulationType(element_type) : element_type;
2005 auto casted_input =
2006 rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
2007
2008 // Each reduction op can have a different initial value.
2009 Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
2010
2011 auto reduction = rewriter.create<ReduceOp>(
2012 loc, casted_input.getResult(), init,
2013 GetI64ElementsAttr(xla_dimensions, &rewriter));
2014 BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
2015 &rewriter);
2016 Value result = reduction.getResult(0);
2017
2018     // The mean op needs to divide by the product of the reduced dimension sizes.
2019 if (std::is_same<OpTy, TF::MeanOp>::value) {
2020 int64_t divisor_count = 1;
2021 for (size_t i = 0; i < input_shape.size(); ++i) {
2022 if (reduced_dimensions_bitmap[i]) {
2023 if (TensorType::isDynamic(input_shape[i])) {
2024 return this->matchFailure();
2025 }
2026 divisor_count *= input_shape[i];
2027 }
2028 }
2029 auto divisor = GetScalarConstOfType(reduce_element_type, loc,
2030 divisor_count, &rewriter);
2031 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2032 result = rewriter.create<DivOp>(loc, result, divisor.getResult(),
2033 broadcast_dims);
2034 }
2035
2036 result = rewriter.create<ConvertOp>(loc, result, element_type);
2037
2038 // Need to reshape back after the reduction if we're keeping the reduced
2039 // dimensions.
2040 if (op.keep_dims()) {
2041 result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
2042 }
2043 rewriter.replaceOp(op, {result});
2044
2045 return this->matchSuccess();
2046 }
2047 };
2048
2049 // Converts Mean op to HLO Reduce op.
2050 //
2051 // %init = constant dense<...> : tensor<T>
2052 // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
2053 // {dimensions = ...}
2054 // %divisor = constant dense<...> : tensor<T>
2055 // %mean = "xla_hlo.div"(%sum, %divisor)
2056 class ConvertMeanOp
2057 : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
2058 public:
2059 using GenericConvertReductionOp::GenericConvertReductionOp;
2060   static Value GetInitialValue(Type reduce_element_type, Location loc,
2061 PatternRewriter *rewriter) {
2062 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
2063 }
2064 };
2065
2066 // Converts Sum op to HLO Reduce op.
2067 //
2068 // %init = constant dense<...> : tensor<T>
2069 // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
2070 // {dimensions = ...}
2071 class ConvertSumOp
2072 : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
2073 public:
2074 using GenericConvertReductionOp::GenericConvertReductionOp;
2075
2076   static Value GetInitialValue(Type reduce_element_type, Location loc,
2077 PatternRewriter *rewriter) {
2078 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
2079 }
2080 };
2081
2082 // Converts Max op to HLO Reduce op.
2083 //
2084 // %init = constant dense<...> : tensor<T>
2085 // %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"]
2086 // {dimensions = ...}
2087 class ConvertMaxOp
2088 : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
2089 /* is_accumulation= */ false> {
2090 public:
2091 using GenericConvertReductionOp::GenericConvertReductionOp;
2092
2093   static Value GetInitialValue(Type reduce_element_type, Location loc,
2094 PatternRewriter *rewriter) {
2095 return GetMinValueForType(reduce_element_type, loc, rewriter);
2096 }
2097 };
2098
2099 // Converts Min op to HLO Reduce op.
2100 //
2101 // %init = constant dense<...> : tensor<T>
2102 // %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.min"]
2103 // {dimensions = ...}
2104 class ConvertMinOp
2105 : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
2106 /* is_accumulation= */ false> {
2107 public:
2108 using GenericConvertReductionOp::GenericConvertReductionOp;
2109
2110   static Value GetInitialValue(Type reduce_element_type, Location loc,
2111 PatternRewriter *rewriter) {
2112 return GetMaxValueForType(reduce_element_type, loc, rewriter);
2113 }
2114 };
2115
2116 // Converts Prod op to HLO Reduce op.
2117 //
2118 // %init = constant dense<...> : tensor<T>
2119 // %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.mul"]
2120 // {dimensions = ...}
2121 class ConvertProdOp
2122 : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
2123 public:
2124 using GenericConvertReductionOp::GenericConvertReductionOp;
2125
2126   static Value GetInitialValue(Type reduce_element_type, Location loc,
2127 PatternRewriter *rewriter) {
2128 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
2129 }
2130 };
2131
2132 // Converts All op to HLO Reduce op.
2133 //
2134 // %init = constant dense<...> : tensor<T>
2135 //   %all = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"]
2136 // {dimensions = ...}
2137 class ConvertAllOp
2138 : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
2139 public:
2140 using GenericConvertReductionOp::GenericConvertReductionOp;
2141   static Value GetInitialValue(Type reduce_element_type, Location loc,
2142 PatternRewriter *rewriter) {
2143 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
2144 }
2145 };
2146
2147 // Converts Any op to HLO Reduce op.
2148 //
2149 // %init = constant dense<...> : tensor<T>
2150 //   %any = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"]
2151 // {dimensions = ...}
2152 class ConvertAnyOp
2153 : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
2154 public:
2155 using GenericConvertReductionOp::GenericConvertReductionOp;
2156   static Value GetInitialValue(Type reduce_element_type, Location loc,
2157 PatternRewriter *rewriter) {
2158 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
2159 }
2160 };
2161
2162 // Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform
2163 // a reduction on the original input and the corresponding index. The reduction
2164 // sub-computation selects the max (or min) value and the index for the value.
2165 // Derived: is the resulting derived class of this class.
2166 // OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
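//
// As a rough sketch, an ArgMax over dimension 1 of a tensor<4x8xf32> with an
// i64 output would lower to something like (reduction body omitted):
//
//   %init = xla_hlo.constant dense<...> : tensor<f32>  // lowest f32 value
//   %index_init = xla_hlo.constant dense<0> : tensor<i64>
//   %iota = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x8xi64>
//   %reduce:2 = "xla_hlo.reduce"(%input, %iota, %init, %index_init)
//       {dimensions = dense<1> : tensor<1xi64>}
//   // the op is replaced by %reduce#1 : tensor<4xi64>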
2167 template <typename Derived, typename OpTy>
2168 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
2169 using OpRewritePattern<OpTy>::OpRewritePattern;
2170
2171   PatternMatchResult matchAndRewrite(OpTy op,
2172 PatternRewriter &rewriter) const override {
2173 RankedTensorType input_type =
2174 op.input().getType().template dyn_cast<RankedTensorType>();
2175 if (!input_type) {
2176 return this->matchFailure();
2177 }
2178
2179 Type input_element_type = input_type.getElementType();
2180 // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
2181 // tf.ArgMax doesn't support complex data types, this check can be removed.
2182 if (!input_element_type.isIntOrFloat()) return this->matchFailure();
2183
2184 Location loc = op.getLoc();
2185 Value init_value =
2186 Derived::GetInitialValue(input_element_type, loc, rewriter);
2187
2188 RankedTensorType output_type =
2189 op.output().getType().template dyn_cast<RankedTensorType>();
2190 if (!output_type) {
2191 return this->matchFailure();
2192 }
2193
2194 Type index_element_type = output_type.getElementType();
2195 Value index_init_value =
2196 GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
2197
2198 RankedTensorType index_type =
2199 RankedTensorType::get(input_type.getShape(), index_element_type);
2200
2201 llvm::Optional<int64_t> optional_axis =
2202 GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
2203 if (!optional_axis.hasValue()) {
2204 return this->matchFailure();
2205 }
2206 int64_t axis = optional_axis.getValue();
2207
2208 IntegerAttr iota_dimension =
2209 IntegerAttr::get(rewriter.getIntegerType(64), axis);
2210 Value index_values =
2211 rewriter.create<IotaOp>(loc, index_type, iota_dimension);
2212
2213 std::vector<int64_t> dimensions = input_type.getShape();
2214 dimensions.erase(dimensions.begin() + axis);
2215 ArrayRef<int64_t> reduction_result_shape(dimensions);
2216
2217 Value operands[] = {op.input(), index_values};
2218 Value init_values[] = {init_value, index_init_value};
2219 DenseIntElementsAttr reduction_dimensions =
2220 GetI64ElementsAttr({axis}, &rewriter);
2221
2222 auto reduction = rewriter.create<ReduceOp>(
2223 loc, llvm::ArrayRef<Value>(operands),
2224 llvm::ArrayRef<Value>(init_values), reduction_dimensions);
2225 StringRef direction = Derived::GetDirection();
2226 BuildArgMinMaxReductionBody(input_element_type, index_element_type,
2227 direction, &reduction.body(), &rewriter);
2228
2229 rewriter.replaceOp(op, {reduction.getResult(1)});
2230 return this->matchSuccess();
2231 }
2232 };
2233
2234 // Converts tensorflow ArgMax op to xla_hlo operations. The actual
2235 // implementation is in class ConvertArgMinMaxOp:
2236 //
2237 // %init_index = constant dense<...> : tensor<T>
2238 // %init = constant dense<...> : tensor<T>
2239 // %reduce = "xla_hlo.reduce"(%selected_input, %select_index, %init,
2240 // %init_index) ["xla_hlo.arg_max"]
2241 class ConvertArgMaxOp
2242 : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
2243 public:
2244 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
2245
2246   static Value GetInitialValue(Type reduce_element_type, Location loc,
2247 PatternRewriter &rewriter) {
2248 return GetMinValueForType(reduce_element_type, loc, &rewriter);
2249 }
2250
2251   static StringRef GetDirection() { return "GT"; }
2252 };
2253
2254 // Converts TF TensorScatterUpdate op into Scatter Op with assignment:
2255 //
2256 // %result = "xla_hlo.scatter"(%tensor, %indices, %updates)
2257 // { dimensions = ... }
2258 //
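// A rough sketch of the update computation region built below (assuming an
// f32 tensor element type); it simply returns the update value, so the
// scatter performs plain assignment:
//
//   ^bb0(%old: tensor<f32>, %update: tensor<f32>):
//     "xla_hlo.return"(%update) : (tensor<f32>) -> ()
//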
2259 class ConvertTensorScatterUpdateOp
2260 : public OpRewritePattern<TF::TensorScatterUpdateOp> {
2261 public:
2262 using OpRewritePattern::OpRewritePattern;
2263
2264   PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
2265 PatternRewriter &rewriter) const override {
2266 auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
2267 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2268 auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();
2269
2270 if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
2271     // Last dimension of the indices needs to be known at compile time for
2272 // computation of the 'update_window_dims' attribute in the dimensions
2273 // struct.
2274 int64_t num_index_dims = indices_ty.getShape().back();
2275 if (ShapedType::isDynamic(num_index_dims)) return matchFailure();
2276
2277 int64_t tensor_rank = tensor_ty.getRank();
2278 int64_t indices_rank = indices_ty.getRank();
2279 int64_t updates_rank = updates_ty.getRank();
2280
2281 int64_t window_dims = tensor_rank - num_index_dims;
2282 auto dims_attr = ScatterDimensionNumbers::get(
2283 GetI64ElementsAttrForSeq(updates_rank - window_dims, updates_rank,
2284 &rewriter),
2285 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
2286 GetI64ElementsAttrForSeq(0, num_index_dims, &rewriter),
2287 rewriter.getI64IntegerAttr(indices_rank - 1), rewriter.getContext());
2288
2289 Location loc = op.getLoc();
2290 auto scatter = rewriter.create<ScatterOp>(
2291 loc, op.getType(), op.tensor(), op.indices(), op.updates(), dims_attr);
2292
2293 // Build region to assign the new value.
2294 [&](Region *region) {
2295 OpBuilder::InsertionGuard guard(rewriter);
2296 Block *block = rewriter.createBlock(region);
2297
2298 // Block arguments are scalars of the given element type.
2299 Type type =
2300 RankedTensorType::get(/*shape=*/{}, tensor_ty.getElementType());
2301 block->addArguments({type, type});
2302 rewriter.create<ReturnOp>(loc, block->getArgument(1));
2303 }(&scatter.update_computation());
2304
2305 rewriter.replaceOp(op, scatter.getResult());
2306 return matchSuccess();
2307 }
2308 };
2309
2310 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
2311 // For shape [S1, S2] and multiples [M1, M2],
2312 // MS1 = M1 * S1; MS2 = M2 * S2
2313 //
2314 // %broadcast = xla_hlo.broadcast_in_dim(%input) {
2315 // broadcast_dimensions = [0, 2]
2316 // }
2317 // %result = "xla_hlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>)
2318 // -> tensor<MS1xMS2xf32>
2319 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
2320 public:
2321 using OpRewritePattern::OpRewritePattern;
2322
2323   PatternMatchResult matchAndRewrite(TF::TileOp op,
2324 PatternRewriter &rewriter) const override {
2325 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
2326 if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
2327 ArrayRef<int64_t> input_shape = input_ty.getShape();
2328 Type element_type = input_ty.getElementType();
2329
2330 DenseIntElementsAttr multiples;
2331 if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
2332 multiples.getType().getRank() != 1)
2333 return matchFailure();
2334
2335 if (multiples.getNumElements() != input_shape.size()) return matchFailure();
2336
2337 SmallVector<int64_t, 8> broadcasted_shape;
2338 SmallVector<int64_t, 4> broadcast_dimensions;
2339 broadcasted_shape.reserve(input_shape.size() * 2);
2340 broadcast_dimensions.reserve(input_shape.size());
2341 for (auto multiple_and_input :
2342 llvm::zip(multiples.getValues<APInt>(), input_shape)) {
2343 int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
2344 int64_t input_size = std::get<1>(multiple_and_input);
2345
2346 if (multiple < 0) return matchFailure();
2347
2348 // Line input up with the next dimension in broadcasted_shape
2349 // when broadcasting.
2350 broadcast_dimensions.push_back(broadcasted_shape.size());
2351 int64_t output_size = input_size * multiple;
2352 if (input_size == 1 || multiple == 1) {
2353 // Special case for when normal broadcasting will just work.
2354 broadcasted_shape.push_back(output_size);
2355 } else {
2356 // Tiling will happen for this dimension during the ReshapeOp below.
2357 broadcasted_shape.push_back(input_size);
2358 broadcasted_shape.push_back(multiple);
2359 }
2360 }
2361 Location loc = op.getLoc();
2362 Type broadcasted_type =
2363 RankedTensorType::get(broadcasted_shape, element_type);
2364 Type output_type = op.getType();
2365
2366 Value result = rewriter.create<BroadcastInDimOp>(
2367 loc, broadcasted_type, op.input(),
2368 GetI64ElementsAttr(broadcast_dimensions, &rewriter));
2369
2370 if (output_type != broadcasted_type) {
2371 result = rewriter.create<ReshapeOp>(loc, output_type, result);
2372 }
2373
2374 rewriter.replaceOp(op, {result});
2375
2376 return matchSuccess();
2377 }
2378 };
2379
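// Converts MaxPoolGrad op to HLO SelectAndScatter op: each gradient value is
// routed to the window position selected by a GE comparison over the original
// input, and routed values are accumulated with addition.
//
// A rough sketch of the generated IR (select and scatter regions omitted):
//
//   %zero = xla_hlo.constant dense<0.0> : tensor<f32>
//   %result = "xla_hlo.select_and_scatter"(%orig_input, %grad, %zero)
//       {window_dimensions = ..., window_strides = ..., padding = ...}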
2380 class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
2381 public:
2382 using OpRewritePattern::OpRewritePattern;
2383
2384   PatternMatchResult matchAndRewrite(TF::MaxPoolGradOp op,
2385 PatternRewriter &rewriter) const override {
2386 Location loc = op.getLoc();
2387
2388 Type element_type =
2389 op.orig_input().getType().cast<TensorType>().getElementType();
2390
2391 // Compute paddings using the original input and kernel shape and strides.
2392     // Here, the same padding computation as for the ReduceWindow op is used,
2393     // because the MaxPool op is lowered to the ReduceWindow op.
2394 auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
2395 if (!input_ty) return matchFailure();
2396 DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
2397 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
2398
2399 auto result = rewriter.create<SelectAndScatterOp>(
2400 loc, op.getType(), op.orig_input(), op.grad(),
2401 GetScalarConstOfType(element_type, loc, 0, &rewriter),
2402 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2403 paddings_attr);
2404
2405 BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
2406 {
2407 OpBuilder::InsertionGuard guard(rewriter);
2408 Block *block = rewriter.createBlock(&result.select());
2409
2410 // Block arguments are scalars of the given element type.
2411 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
2412 block->addArguments({type, type});
2413
2414 auto reducer = rewriter.create<CompareOp>(
2415 loc, block->getArgument(0), block->getArgument(1),
2416 /*broadcast_dimensions=*/nullptr,
2417 StringAttr::get("GE", rewriter.getContext()));
2418 rewriter.create<ReturnOp>(loc, reducer.getResult());
2419 }
2420
2421 rewriter.replaceOp(op, {result});
2422
2423 return matchSuccess();
2424 }
2425 };
2426
2427 // Converts tf.Conv2DBackpropInputOp into:
2428 // %rev_filter = "xla_hlo.reverse"(%filter)
2429 // %result = "xla_hlo.conv"(%out_backprop, %rev_filter)
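//
// A rough sketch of the parameters chosen below (assuming
// feature_group_count == 1): the filter is reversed along its spatial
// dimensions, the convolution's lhs_dilation is set to the forward strides,
// the padding comes from the backprop dimension computation, and the filter's
// input/output feature dimensions are swapped relative to the TF
// [H, W, inC, outC] layout.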
2430 class ConvertConv2DBackpropInputOp
2431 : public OpRewritePattern<TF::Conv2DBackpropInputOp> {
2432 public:
2433 using OpRewritePattern::OpRewritePattern;
2434
2435   PatternMatchResult matchAndRewrite(TF::Conv2DBackpropInputOp op,
2436 PatternRewriter &rewriter) const override {
2437 // Unpack all of the attributes.
2438 tensorflow::TensorFormat data_format;
2439 if (!FormatFromString(op.data_format().str(), &data_format)) {
2440 return matchFailure();
2441 }
2442 tensorflow::Padding padding;
2443 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2444 return Pattern::matchFailure();
2445
2446 auto out_backprop_ty =
2447 op.out_backprop().getType().dyn_cast<RankedTensorType>();
2448 if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
2449 return matchFailure();
2450 ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
2451 auto filter_ty = op.filter().getType().dyn_cast<RankedTensorType>();
2452 if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
2453 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
2454 int num_spatial_dims = 2;
2455 Location loc = op.getLoc();
2456
2457 int num_dims = num_spatial_dims + 2;
2458 int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
2459 int feature_dim =
2460 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
2461
2462 DenseIntElementsAttr input_shape_attr;
2463 if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
2464 input_shape_attr.getType().getRank() != 1) {
2465 return matchFailure();
2466 }
2467 auto input_shape =
2468 llvm::to_vector<4>(input_shape_attr.getValues<int32_t>());
2469 if (input_shape.size() != num_dims) return matchFailure();
2470
2471 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
2472 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
2473
2474 auto strides_attr = GetI64ElementsAttr(op.strides());
2475 std::vector<tensorflow::int32> strides{
2476 strides_attr.getValues<int64_t>().begin(),
2477 strides_attr.getValues<int64_t>().end()};
2478 auto dilations_attr = GetI64ElementsAttr(op.dilations());
2479 std::vector<int> dilations{dilations_attr.getValues<int64_t>().begin(),
2480 dilations_attr.getValues<int64_t>().end()};
2481 auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings());
2482 std::vector<tensorflow::int64> explicit_paddings{
2483 explicit_paddings_attr.getValues<int64_t>().begin(),
2484 explicit_paddings_attr.getValues<int64_t>().end()};
2485
2486 int64_t in_depth = input_shape[feature_dim];
2487 int64_t filter_in_depth = filter_shape[num_spatial_dims];
2488 int64_t feature_group_count = in_depth / filter_in_depth;
2489
2490 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
2491 tensorflow::ConvBackpropDimensions dims;
2492 if (!tensorflow::ConvBackpropComputeDimensionsV2(
2493 "", num_spatial_dims, ToTensorShape<int>(input_shape),
2494 ToTensorShape<int64_t>(filter_shape),
2495 ToTensorShape<int64_t>(out_backprop_shape), dilations, strides,
2496 padding, explicit_paddings, data_format, &dims)
2497 .ok()) {
2498 return matchFailure();
2499 }
2500
2501 // Compute ConvDimensionNumbers, dilation, and padding.
2502 SmallVector<int64_t, 4> kernel_spatial_dims(num_spatial_dims);
2503 SmallVector<int64_t, 4> conv_paddings(num_spatial_dims * 2);
2504 SmallVector<int64_t, 4> lhs_dilation(num_spatial_dims);
2505 SmallVector<int64_t, 4> rhs_dilation(num_spatial_dims);
2506 SmallVector<int64_t, 4> ones(num_spatial_dims, 1);
2507 SmallVector<int64_t, 4> spatial_dims(num_spatial_dims);
2508 for (int i = 0; i < num_spatial_dims; ++i) {
2509 int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
2510 spatial_dims[i] = dim;
2511 kernel_spatial_dims[i] = i;
2512
2513 conv_paddings[i * 2] = dims.spatial_dims[i].pad_before;
2514 conv_paddings[i * 2 + 1] = dims.spatial_dims[i].pad_after;
2515 lhs_dilation[i] = dims.spatial_dims[i].stride;
2516 rhs_dilation[i] = dilations[dim];
2517 }
2518 RankedTensorType paddings_ty = RankedTensorType::get(
2519 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
2520 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_paddings);
2521 auto spatial_dims_attr = GetI64ElementsAttr(spatial_dims, &rewriter);
2522
2523 Value filter = op.filter();
2524
2525 if (feature_group_count != 1) {
2526 /*
2527 // TODO(parkers): Convert this code to mlir.
2528 filter = TransposeFilterForGroupConvolutionBackpropInput(
2529 filter, filter_shape, feature_group_count, attrs.num_spatial_dims);
2530 */
2531 return matchFailure();
2532 }
2533
2534 // Mirror the filter in the spatial dimensions.
2535 filter = rewriter.create<ReverseOp>(
2536 loc, filter, GetI64ElementsAttr(kernel_spatial_dims, &rewriter));
2537
2538 // activation gradients
2539 // = gradients (with padding and dilation) <conv> mirrored_weights
2540 Value result = rewriter.create<ConvOp>(
2541 loc, op.getType(), op.out_backprop(), filter,
2542 /*window_strides=*/GetI64ElementsAttr(ones, &rewriter),
2543 /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
2544 GetI64ElementsAttr(rhs_dilation, &rewriter),
2545 ConvDimensionNumbers::get(
2546 /*input_batch_dimension=*/batch_dim_attr,
2547 /*input_feature_dimension=*/feature_dim_attr,
2548 /*input_spatial_dimensions=*/spatial_dims_attr,
2549 // TF filter shape is [ H, W, ..., inC, outC ]
2550 // Transpose the input and output features for computing the
2551 // gradient.
2552 /*kernel_input_feature_dimension=*/
2553 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
2554 /*kernel_output_feature_dimension=*/
2555 rewriter.getI64IntegerAttr(num_spatial_dims),
2556 /*kernel_spatial_dimensions=*/
2557 GetI64ElementsAttr(kernel_spatial_dims, &rewriter),
2558 /*output_batch_dimension=*/batch_dim_attr,
2559 /*output_feature_dimension=*/feature_dim_attr,
2560 /*output_spatial_dimensions=*/spatial_dims_attr,
2561 rewriter.getContext()),
2562 rewriter.getI64IntegerAttr(feature_group_count),
2563 /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
2564 /*precision_config=*/ArrayAttr());
2565
2566 rewriter.replaceOp(op, {result});
2567
2568 return matchSuccess();
2569 }
2570 };
2571
2572 // Converts tf.Conv2DBackpropFilterOp into:
2573 // %result = "xla_hlo.conv"(%input, %out_backprop)
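//
// A rough sketch of the setup below: the activations form the LHS and the
// output gradients form the RHS of the convolution, the batch and feature
// dimensions of the activations swap roles, and the output gradients are
// effectively dilated by the forward stride (see the expanded_out_rows
// comment further down).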
2574 class ConvertConv2DBackpropFilterOp
2575 : public OpRewritePattern<TF::Conv2DBackpropFilterOp> {
2576 public:
2577 using OpRewritePattern::OpRewritePattern;
2578
2579   PatternMatchResult matchAndRewrite(TF::Conv2DBackpropFilterOp op,
2580 PatternRewriter &rewriter) const override {
2581 // Unpack all of the attributes.
2582 tensorflow::TensorFormat data_format;
2583 if (!FormatFromString(op.data_format().str(), &data_format)) {
2584 return matchFailure();
2585 }
2586 tensorflow::Padding padding;
2587 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2588 return Pattern::matchFailure();
2589
2590 auto out_backprop_ty =
2591 op.out_backprop().getType().dyn_cast<RankedTensorType>();
2592 if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
2593 return matchFailure();
2594 ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
2595 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
2596 if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
2597 ArrayRef<int64_t> input_shape = input_ty.getShape();
2598
2599 DenseIntElementsAttr filter_shape_attr;
2600 if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
2601 filter_shape_attr.getType().getRank() != 1) {
2602 return matchFailure();
2603 }
2604
2605 auto strides_attr = GetI64ElementsAttr(op.strides());
2606 std::vector<tensorflow::int32> strides{
2607 strides_attr.getValues<int64_t>().begin(),
2608 strides_attr.getValues<int64_t>().end()};
2609 auto dilations_attr = GetI64ElementsAttr(op.dilations());
2610 SmallVector<int, 4> dilations{dilations_attr.getValues<int64_t>().begin(),
2611 dilations_attr.getValues<int64_t>().end()};
2612 auto explicit_paddings_attr = GetI64ElementsAttr(op.explicit_paddings());
2613 SmallVector<tensorflow::int64, 4> explicit_paddings{
2614 explicit_paddings_attr.getValues<int64_t>().begin(),
2615 explicit_paddings_attr.getValues<int64_t>().end()};
2616
2617 int num_spatial_dims = 2;
2618 int num_dims = num_spatial_dims + 2;
2619 int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
2620 int feature_dim =
2621 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
2622
2623 auto filter_shape =
2624 llvm::to_vector<4>(filter_shape_attr.getValues<int32_t>());
2625 if (filter_shape.size() != num_dims) return matchFailure();
2626
2627 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
2628 tensorflow::ConvBackpropDimensions dims;
2629 if (!tensorflow::ConvBackpropComputeDimensionsV2(
2630 "", num_spatial_dims, ToTensorShape<int64_t>(input_shape),
2631 ToTensorShape<int>(filter_shape),
2632 ToTensorShape<int64_t>(out_backprop_shape), dilations, strides,
2633 padding, explicit_paddings, data_format, &dims)
2634 .ok()) {
2635 return matchFailure();
2636 }
2637
2638 // The activations (inputs) form the LHS of the convolution.
2639 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
2640 // For the gradient computation, we need to:
2641 // 1. In the case of group convolution, move the num_groups dimension before
2642 // the batch dimension
2643 // 2. Swap the roles of the batch and feature dimensions.
2644 int64_t in_depth = input_shape[feature_dim];
2645 int64_t filter_in_depth = filter_shape[num_spatial_dims];
2646 int64_t feature_group_count = in_depth / filter_in_depth;
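    // For example (numbers chosen for illustration only): in_depth = 8 with
    // filter_in_depth = 2 would give feature_group_count = 4, i.e. a grouped
    // convolution, which is not handled by this pattern yet.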
2647 if (feature_group_count != 1) {
2648 /*
2649 // TODO(parkers): translate this code to mlir.
2650 activations = TransposeInputForGroupConvolutionBackpropFilter(
2651 activations, input_shape, feature_group_count, batch_dim,
2652 feature_dim);
2653 */
2654 return matchFailure();
2655 }
2656
2657 // Compute ConvDimensionNumbers, dilation, and padding.
2658 SmallVector<int64_t, 8> conv_padding(num_spatial_dims * 2);
2659 SmallVector<int64_t, 4> rhs_dilation(num_spatial_dims);
2660 SmallVector<int64_t, 4> window_strides(num_spatial_dims);
2661 SmallVector<int64_t, 4> lhs_dilation(num_spatial_dims, 1);
2662 SmallVector<int64_t, 4> spatial_dims(num_spatial_dims);
2663 SmallVector<int64_t, 4> kernel_spatial_dims(num_spatial_dims);
2664
2665 // The filter gradients are computed by a convolution of the input
2666 // activations and the output gradients, with some appropriate padding.
2667 // See the comment at the top of conv_grad_ops.h for details.
2668
2669 for (int64_t i = 0; i < num_spatial_dims; ++i) {
2670 int64_t dim =
2671 tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
2672 kernel_spatial_dims[i] = dim;
2673 // Besides padding the input, we will also expand output_rows to
2674 // expanded_out_rows = (output_rows - 1) * stride + 1
2675 // with zeros in between:
2676 //
2677 // a . . . b . . . c . . . d . . . e
2678 //
2679 // This is done by specifying the window dilation factors in the
2680 // convolution HLO below.
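      // (As a worked instance of the formula above: output_rows = 5 and
      // stride = 4 give expanded_out_rows = (5 - 1) * 4 + 1 = 17, matching
      // the drawing with three zeros between neighbouring values.)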
2681 rhs_dilation[i] = dims.spatial_dims[i].stride;
2682 window_strides[i] = dilations[dim];
2683
2684 // We will also need to pad the input with zeros such that after the
2685 // convolution, we get the right size for the filter.
2686 // The padded_in_rows should be such that when we convolve this with the
2687 // expanded_out_rows as a filter, we should get filter_rows back.
2688
2689 const int64_t padded_in_size =
2690 dims.spatial_dims[i].expanded_output_size +
2691 (dims.spatial_dims[i].filter_size - 1) * dilations[dim];
2692
2693 // However it can be smaller than input_rows: in this
2694 // case it means some of the inputs are not used.
2695 //
2696 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
2697 //
2698 // INPUT = [ A B C ]
2699 //
2700 // FILTER = [ x y ]
2701 //
2702 // and the output will only have one column: a = A * x + B * y
2703 //
2704 // and input "C" is not used at all.
2705 //
2706 // We apply negative padding in this case.
2707 const int64_t pad_total =
2708 padded_in_size - dims.spatial_dims[i].input_size;
2709
2710 // + For the EXPLICIT padding, we pad the top/left side with the explicit
2711 // padding and pad the bottom/right side with the remaining space.
2712 // + For the VALID padding, we don't pad anything on the top/left side
2713 // and pad the bottom/right side with the remaining space.
2714 // + For the SAME padding, we pad top/left side the same as bottom/right
2715 // side.
2716 //
2717 // In addition, if the padded input size is smaller than the input size,
2718       // we need to ignore some trailing elements of the input. We do this by
2719 // applying negative padding on the right/bottom.
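      // As a concrete instance of the negative padding case (reusing the
      // numbers from the example above): input_cols = 3, filter_cols = 2,
      // stride = 2, dilation = 1 and VALID padding give output_cols = 1, so
      //   expanded_output_size = (1 - 1) * 2 + 1 = 1
      //   padded_in_size       = 1 + (2 - 1) * 1 = 2
      //   pad_total            = 2 - 3 = -1
      // and hence pad_before = 0, pad_after = -1: one input column is
      // effectively sliced away on the right.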
2720 const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
2721 ? explicit_paddings[2 * dim]
2722 : padding == tensorflow::Padding::SAME
2723 ? std::max<int64_t>(pad_total / 2, 0)
2724 : 0;
2725 conv_padding[i * 2] = pad_before;
2726 conv_padding[i * 2 + 1] = pad_total - pad_before;
2727 }
2728
2729 RankedTensorType paddings_ty = RankedTensorType::get(
2730 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
2731 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, conv_padding);
2732 auto out_spatial_dims_attr =
2733 GetI64ElementsAttrForSeq(0, num_spatial_dims, &rewriter);
2734 auto kernel_spatial_dims_attr =
2735 GetI64ElementsAttr(kernel_spatial_dims, &rewriter);
2736
2737 auto batch_dim_attr = rewriter.getI64IntegerAttr(batch_dim);
2738 auto feature_dim_attr = rewriter.getI64IntegerAttr(feature_dim);
2739
2740 Location loc = op.getLoc();
2741 Value result = rewriter.create<ConvOp>(
2742 loc, op.getType(), op.input(), op.out_backprop(),
2743 /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
2744 /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
2745 GetI64ElementsAttr(rhs_dilation, &rewriter),
2746 ConvDimensionNumbers::get(
2747 // Swap batch_dim and feature_dim in the activations.
2748 /*input_batch_dimension=*/feature_dim_attr,
2749 /*input_feature_dimension=*/batch_dim_attr,
2750 /*input_spatial_dimensions=*/kernel_spatial_dims_attr,
2751 // The gradients become the RHS of the convolution.
2752 // The gradients have shape [batch, out_rows, out_cols, ...,
2753 // out_depth] where the batch becomes the input feature for the
2754 // convolution.
2755 /*kernel_input_feature_dimension=*/batch_dim_attr,
2756 /*kernel_output_feature_dimension=*/feature_dim_attr,
2757 /*kernel_spatial_dimensions=*/kernel_spatial_dims_attr,
2758 /*output_batch_dimension=*/
2759 rewriter.getI64IntegerAttr(num_spatial_dims),
2760 /*output_feature_dimension=*/
2761 rewriter.getI64IntegerAttr(num_spatial_dims + 1),
2762 /*output_spatial_dimensions=*/out_spatial_dims_attr,
2763 rewriter.getContext()),
2764 rewriter.getI64IntegerAttr(feature_group_count),
2765 /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
2766 /*precision_config=*/ArrayAttr());
2767
2768 rewriter.replaceOp(op, {result});
2769
2770 return matchSuccess();
2771 }
2772 };
2773
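// Converts tf.OneHot to a sequence of XLA HLO iota, compare, broadcast and
// select ops.
//
// For example (an illustrative sketch, not taken from an existing test), with
// a constant depth of 5:
//
//   %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value)
//       {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>,
//                            tensor<f32>) -> tensor<3x5xf32>
//
// is lowered to an "xla_hlo.iota" along the new axis, an "xla_hlo.compare"
// (EQ) of the indices against that iota with broadcast_dimensions = [0], two
// "xla_hlo.broadcast" ops expanding %on_value and %off_value to the result
// shape, and a final "xla_hlo.select".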
2774 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
2775 public:
2776 using OpRewritePattern::OpRewritePattern;
2777
2778   PatternMatchResult matchAndRewrite(TF::OneHotOp op,
2779 PatternRewriter &rewriter) const override {
2780 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
2781 if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure();
2782 ArrayRef<int64_t> indices_shape = indices_ty.getShape();
2783 Type element_type = indices_ty.getElementType();
2784
2785 DenseIntElementsAttr depth_attr;
2786 if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
2787 return matchFailure();
2788 }
2789
2790 int64_t depth = depth_attr.getValue<APInt>({}).getSExtValue();
2791 int64_t axis = op.axis().getSExtValue();
2792 if (axis == -1) axis = indices_shape.size();
2793
2794 llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
2795 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
2796 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
2797
2798 llvm::SmallVector<int64_t, 4> output_dims =
2799 llvm::to_vector<4>(indices_shape);
2800 output_dims.insert(output_dims.begin() + axis, depth);
2801
2802 Location loc = op.getLoc();
2803 auto index_type = RankedTensorType::get(output_dims, element_type);
2804 Value compare = rewriter.create<CompareOp>(
2805 loc, op.indices(),
2806 rewriter.create<IotaOp>(
2807 loc, index_type,
2808 IntegerAttr::get(rewriter.getIntegerType(64), axis)),
2809 GetI64ElementsAttr(broadcast_dims, &rewriter),
2810 StringAttr::get("EQ", rewriter.getContext()));
2811 Value on_value = rewriter.create<BroadcastOp>(
2812 loc, op.getType(), op.on_value(),
2813 GetI64ElementsAttr(output_dims, &rewriter));
2814 Value off_value = rewriter.create<BroadcastOp>(
2815 loc, op.getType(), op.off_value(),
2816 GetI64ElementsAttr(output_dims, &rewriter));
2817 Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
2818 on_value, off_value);
2819
2820 rewriter.replaceOp(op, {result});
2821
2822 return matchSuccess();
2823 }
2824 };
2825
2826 // Converts tf.InfeedDequeueTuple to XLA HLO after_all, infeed and
2827 // get_tuple_element ops.
2828 //
2829 // All HLO infeed ops expect an HLO token type operand and produce a tuple
2830 // containing a token. This HLO token type is used to order multiple infeed
2831 // operations within a computation. The token type can come from other
2832 // infeed/outfeed/send/recv ops or can be generated using an after_all op with
2833 // no operands. Here we emit an after_all op to generate the token type operand
2834 // of infeed.
2835 //
2836 // For example the following IR:
2837 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
2838 //
2839 // would be lowered to
2840 //
2841 // %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
2842 // %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} :
2843 // (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
2844 // !xla_hlo.token>
2845 // %data = "xla_hlo.get_tuple_element"(%data_and_token) {index = 0}
2846 // %0#0 = "xla_hlo.get_tuple_element"(%data) {index = 0}
2847 // %0#1 = "xla_hlo.get_tuple_element"(%data) {index = 1}
2848 //
2849 class ConvertInfeedDequeueTupleOp
2850 : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
2851 public:
2852 using OpRewritePattern::OpRewritePattern;
2853
2854   PatternMatchResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
2855 PatternRewriter &rewriter) const override {
2856 std::vector<Type> result_types(op.outputs().size());
2857 for (auto idx_and_output : llvm::enumerate(op.outputs())) {
2858 result_types[idx_and_output.index()] = (idx_and_output.value().getType());
2859 }
2860 // Infeed takes a single token operand. Generate the token using after_all
2861 // op to pass to the infeed op.
2862 auto afterall = rewriter.create<AfterAllOp>(
2863 op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()),
2864 ValueRange());
2865
2866 // Emit infeed op.
2867 // The result type of infeed is a tuple(tuple(result types), token type).
2868 auto data_tuple_type =
2869 mlir::TupleType::get(result_types, rewriter.getContext());
2870 auto data_and_token_type = mlir::TupleType::get(
2871 {data_tuple_type, afterall.getType()}, rewriter.getContext());
2872
2873 auto data_and_token =
2874 rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, afterall,
2875 /*infeed_config=*/rewriter.getStringAttr(""));
2876
2877 // The infeed instruction produces a tuple of the infeed data and a token
2878 // type. Emit get_tuple_element to get infeed data tuple.
2879 auto data_tuple = rewriter.create<GetTupleElementOp>(
2880 op.getLoc(), data_tuple_type, data_and_token,
2881 rewriter.getI32IntegerAttr(0));
2882
2883 // Emit get_tuple_element for each result.
2884 std::vector<Value> results;
2885 for (auto idx_and_type : llvm::enumerate(result_types)) {
2886 auto tuple_element = rewriter.create<GetTupleElementOp>(
2887 op.getLoc(), idx_and_type.value(), data_tuple,
2888 rewriter.getI32IntegerAttr(idx_and_type.index()));
2889 results.push_back(tuple_element);
2890 }
2891 rewriter.replaceOp(op, ValueRange(results));
2892 return matchSuccess();
2893 }
2894 };
2895
2896 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops.
2897 //
2898 // XLA HLO outfeed op expects a token, which we generate by emitting an
2899 // after_all op.
2900 //
2901 // For example the following IR:
2902 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
2903 // ()
2904 //
2905 // would be lowered to
2906 //
2907 // %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
2908 // tuple<tensor<3xi32>, tensor<4xf32>>
2909 // %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
2910 // %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
2911 // (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
2912 //
2913 class ConvertOutfeedEnqueueTupleOp
2914 : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
2915 public:
2916 using OpRewritePattern::OpRewritePattern;
2917
2918   PatternMatchResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
2919 PatternRewriter &rewriter) const override {
2920 auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
2921 auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
2922 auto afterall =
2923 rewriter.create<AfterAllOp>(op.getLoc(), token_type, ValueRange());
2924 rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, afterall,
2925 /*outfeed_config=*/rewriter.getStringAttr(""));
2926 rewriter.eraseOp(op);
2927 return matchSuccess();
2928 }
2929 };
2930
2931 // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
2932 //
2933 // tf.TopKV2 sorts along last dimension of the input tensor and then returns
2934 // the top K components' values and indices. This is translated into a few
2935 // ops in XLA HLO: first generating an integer sequence for the indices,
2936 // then sorting both the original input tensor and the indices together, and
2937 // finally slicing out the top K components.
2938 //
2939 // For example, for the following IR:
2940 //
2941 // %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
2942 // %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) ->
2943 // (tensor<16x8xf32>, tensor<16x8xi32>)
2944 //
2945 // We will get:
2946 //
2947 // %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
2948 // %2 = "xla_hlo.sort"(%input, %1) ( {
2949 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
2950 // %arg3: tensor<i32>, %arg4: tensor<i32>):
2951 // %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
2952 // "xla_hlo.return"(%7) : (tensor<i1>) -> ()
2953 // }) {dimension = 1 : i64, is_stable = true} : ...
2954 // %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
2955 // %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
2956 // %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
2957 // start_indices dense<0> : tensor<2xi64>,
2958 // strides = dense<1> : tensor<2xi64>} :
2959 // (tensor<16x16xf32>) -> tensor<16x8xf32>
2960 // %6 = "xla_hlo.slice"(%4) ...
2961 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
2962 public:
2963 using OpRewritePattern::OpRewritePattern;
2964
2965   PatternMatchResult matchAndRewrite(TF::TopKV2Op op,
2966 PatternRewriter &rewriter) const override {
2967 // We can only match when the `k` operand is a constant scalar.
2968 DenseIntElementsAttr k_attr;
2969 if (!matchPattern(op.k(), m_Constant(&k_attr))) return matchFailure();
2970
2971 // The last dimension of the input tensor's shape should be known so we can
2972 // have clamped end_indices for slices.
2973 TensorType input_type = op.input().getType().cast<TensorType>();
2974 if (!input_type.hasRank()) return matchFailure();
2975 int64_t input_rank = input_type.getRank();
2976 int64_t last_dim_index = input_rank - 1;
2977 int64_t last_dim_size = input_type.getDimSize(last_dim_index);
2978 if (last_dim_size == ShapedType::kDynamicSize) return matchFailure();
2979
2980     // Create an Iota op for indices.
2981 auto i32_type = rewriter.getIntegerType(32);
2982 Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
2983 Value iota_op = rewriter.create<xla_hlo::IotaOp>(
2984 op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
2985
2986 // Create the sort op. It takes two inputs, one for the original input, the
2987 // other for the indices.
2988 auto sort_op = rewriter.create<xla_hlo::SortOp>(
2989 op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
2990 /*is_stable=*/true);
2991 BuildSortComparisonBody({input_type.getElementType(), i32_type},
2992 /*direction=*/"GT", &sort_op.comparator(),
2993 &rewriter);
2994
2995 // Get the sorted input and index tuple element.
2996 auto tuple_first_element =
2997 rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
2998 auto tuple_second_element =
2999 rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);
3000
3001 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3002 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3003 end_indices.back() =
3004 std::min((*k_attr.begin()).getSExtValue(), last_dim_size);
3005 SmallVector<int64_t, 4> strides(input_rank, 1);
3006
3007 // Get the slice for the top K elements.
3008
3009 Value values = rewriter.create<xla_hlo::SliceOp>(
3010 op.getLoc(), tuple_first_element,
3011 GetI64ElementsAttr(begin_indices, &rewriter),
3012 GetI64ElementsAttr(end_indices, &rewriter),
3013 GetI64ElementsAttr(strides, &rewriter));
3014
3015 Value indices = rewriter.create<xla_hlo::SliceOp>(
3016 op.getLoc(), tuple_second_element,
3017 GetI64ElementsAttr(begin_indices, &rewriter),
3018 GetI64ElementsAttr(end_indices, &rewriter),
3019 GetI64ElementsAttr(strides, &rewriter));
3020
3021 rewriter.replaceOp(op, {values, indices});
3022 return matchSuccess();
3023 }
3024 };
3025
3026 // Converts tf.Unpack to a series of XLA HLO slice ops.
3027 //
3028 // Each slice takes one element along the dimension to unpack and takes the full
3029 // range for all other dimensions. Each slice is then reshaped to drop the
3030 // dimension to unpack (which is always of size 1).
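// For example (an illustrative sketch, not taken from an existing test):
//
//   %0:2 = "tf.Unpack"(%input) {axis = 0 : i64}
//       : (tensor<2x4xf32>) -> (tensor<4xf32>, tensor<4xf32>)
//
// is lowered to two "xla_hlo.slice" ops producing tensor<1x4xf32> values,
// each followed by an "xla_hlo.reshape" to tensor<4xf32>.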
3031 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
3032 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
3033 public:
3034 using OpRewritePattern::OpRewritePattern;
3035
3036   PatternMatchResult matchAndRewrite(TF::UnpackOp op,
3037 PatternRewriter &rewriter) const override {
3038 auto value_type = op.value().getType().cast<RankedTensorType>();
3039 if (!value_type) return matchFailure();
3040
3041 int64_t value_rank = value_type.getRank();
3042 int64_t axis = op.axis().getSExtValue();
3043 if (axis < 0) axis += value_rank;
3044
3045 // Parameters for constructing each slice.
3046 SmallVector<int64_t, 4> begin_indices(value_rank, 0);
3047 auto end_indices = llvm::to_vector<4>(value_type.getShape());
3048 SmallVector<int64_t, 4> strides(value_rank, 1);
3049
3050 // All HLO slice+reshape results used to replace the original tf.Unpack op.
3051 SmallVector<Value, 4> results;
3052 results.reserve(op.getNumResults());
3053
3054 for (int i = 0; i < op.getNumResults(); ++i) {
3055 begin_indices[axis] = i;
3056 end_indices[axis] = i + 1;
3057
3058 auto slice_op = rewriter.create<xla_hlo::SliceOp>(
3059 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3060 GetI64ElementsAttr(end_indices, &rewriter),
3061 GetI64ElementsAttr(strides, &rewriter));
3062 // Reshape to drop the axis dimension.
3063 auto reshape_op = rewriter.create<xla_hlo::ReshapeOp>(
3064 op.getLoc(), op.getType(i), slice_op);
3065 results.push_back(reshape_op);
3066 }
3067
3068 rewriter.replaceOp(op, results);
3069 return matchSuccess();
3070 }
3071 };
3072
3073 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
3074 //
3075 // TF unsorted segment reduction op performs the following calculation:
3076 //
3077 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
3078 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
3079 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
3080 // Dm+1, ..., Dn]. Then
3081 // output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
3082 // <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
3083 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
3084 //
3085 // The op will be translated to XLA HLO scatter with the following parameters:
3086 // * Update window dims is [segment_id_rank, data_rank).
3087 // * Inserted window dims is {0}.
3088 // * Scatter dims to operand dims mapping is {0}.
3089 // * Index vector dim is segment_id_rank.
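// As an illustration (shapes chosen for this comment only): with data of
// shape [4, 8], segment_ids of shape [4] and num_segments = 2, the pattern
// broadcasts the reduction's initial value to a [2, 8] operand and emits an
// "xla_hlo.scatter" with update_window_dims = [1], inserted_window_dims = [0],
// scatter_dims_to_operand_dims = [0] and index_vector_dim = 1, whose update
// computation applies <ReductionOp>.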
3090 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
3091 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
3092 using OpRewritePattern<OpTy>::OpRewritePattern;
3093
3094   PatternMatchResult matchAndRewrite(OpTy op,
3095 PatternRewriter &rewriter) const override {
3096 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
3097 if (!data_type) return this->matchFailure();
3098 int64_t data_rank = data_type.getRank();
3099
3100 auto segment_ids_type =
3101 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
3102 if (!segment_ids_type) return this->matchFailure();
3103 int64_t segment_ids_rank = segment_ids_type.getRank();
3104
3105 DenseIntElementsAttr num_segments_attr;
3106 if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
3107 return this->matchFailure();
3108
3109 // The final shape for TF unsorted segment reduction op is [num_segments] +
3110 // data_shape[segment_ids_rank:].
3111 SmallVector<int64_t, 4> output_shape;
3112 output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
3113 auto suffix = data_type.getShape().drop_front(segment_ids_rank);
3114 output_shape.append(suffix.begin(), suffix.end());
3115 auto output_type =
3116 RankedTensorType::get(output_shape, data_type.getElementType());
3117
3118     // Broadcast the initial value for reduction. This will become the
3119     // 'operand' parameter of the final scatter op.
3120 Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
3121 op.getLoc(), &rewriter);
3122 auto broadcasted_init = rewriter.create<xla_hlo::BroadcastOp>(
3123 op.getLoc(), output_type, init,
3124 GetI64ElementsAttr(output_shape, &rewriter));
3125
3126 // Parameters for the generated scatter op.
3127 SmallVector<int64_t, 1> inserted_window_dims(1, 0);
3128 SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
3129 int64_t index_vector_dim = segment_ids_rank;
3130
3131 // Put all parameters in a StructAttr.
3132 auto dims_attr = ScatterDimensionNumbers::get(
3133 GetI64ElementsAttrForSeq(segment_ids_rank, data_rank, &rewriter),
3134 GetI64ElementsAttr(inserted_window_dims, &rewriter),
3135 GetI64ElementsAttr(scatter_dims_to_operand_dims, &rewriter),
3136 rewriter.getI64IntegerAttr(index_vector_dim), rewriter.getContext());
3137
3138 auto scatter =
3139 rewriter.create<ScatterOp>(op.getLoc(), op.getType(), broadcasted_init,
3140 op.segment_ids(), op.data(), dims_attr);
3141 BuildReduceBody<ReductionOp>(data_type.getElementType(),
3142 &scatter.update_computation(), &rewriter);
3143
3144 rewriter.replaceOp(op, scatter.getResult());
3145 return this->matchSuccess();
3146 }
3147 };
3148
3149 class ConvertUnsortedSegmentMaxOp
3150 : public GenericConvertUnsortedSegmentReductionOp<
3151 ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
3152 public:
3153 using GenericConvertUnsortedSegmentReductionOp::
3154 GenericConvertUnsortedSegmentReductionOp;
3155
3156   static Value GetInitialValue(Type reduce_element_type, Location loc,
3157 PatternRewriter *rewriter) {
3158 return GetMinValueForType(reduce_element_type, loc, rewriter);
3159 }
3160 };
3161
3162 class ConvertUnsortedSegmentMinOp
3163 : public GenericConvertUnsortedSegmentReductionOp<
3164 ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
3165 public:
3166 using GenericConvertUnsortedSegmentReductionOp::
3167 GenericConvertUnsortedSegmentReductionOp;
3168
3169   static Value GetInitialValue(Type reduce_element_type, Location loc,
3170 PatternRewriter *rewriter) {
3171 return GetMaxValueForType(reduce_element_type, loc, rewriter);
3172 }
3173 };
3174
3175 class ConvertUnsortedSegmentProdOp
3176 : public GenericConvertUnsortedSegmentReductionOp<
3177 ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
3178 public:
3179 using GenericConvertUnsortedSegmentReductionOp::
3180 GenericConvertUnsortedSegmentReductionOp;
3181
3182   static Value GetInitialValue(Type reduce_element_type, Location loc,
3183 PatternRewriter *rewriter) {
3184 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
3185 }
3186 };
3187
3188 class ConvertUnsortedSegmentSumOp
3189 : public GenericConvertUnsortedSegmentReductionOp<
3190 ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
3191 public:
3192 using GenericConvertUnsortedSegmentReductionOp::
3193 GenericConvertUnsortedSegmentReductionOp;
3194
3195   static Value GetInitialValue(Type reduce_element_type, Location loc,
3196 PatternRewriter *rewriter) {
3197 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
3198 }
3199 };
3200
3201 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
3202 //
3203 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
3204 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
3205 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
3206 // translated into an HLO while op to first emulate shuffling indices using
3207 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
3208 // with the shuffled indices.
3209 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
3210 public:
3211 using OpRewritePattern::OpRewritePattern;
3212
3213   PatternMatchResult matchAndRewrite(TF::RandomShuffleOp op,
3214 PatternRewriter &rewriter) const override {
3215 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3216 if (!input_type) return matchFailure();
3217
3218 int64_t input_rank = input_type.getRank();
3219 int64_t first_dim_size = input_type.getDimSize(0);
3220 if (ShapedType::isDynamic(first_dim_size)) return matchFailure();
3221
3222 // We are shuffling along the first dimension. If its size is <= 1, then
3223 // shuffling is a no-op.
3224 if (first_dim_size <= 1) {
3225 rewriter.replaceOp(op, op.value());
3226 return matchSuccess();
3227 }
3228
3229 // For vectors, shuffle values by sorting instead of the obvious
3230 // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
3231 // but not easily parallelizable. For a sufficiently parallel architecture,
3232     // it is faster to sort many times than to run a Fisher-Yates shuffle once.
3233 if (input_rank == 1) {
3234 // Shuffle values by assigning each value a random key and sorting the
3235       // keys. Keys can collide, causing detectable patterns in the shuffled
3236       // output. Collisions translate into more ascending sub-sequences in the
3237 // shuffled output than would be expected by chance. To avoid collisions,
3238 // the number of possible key values must be sufficiently large.
3239
3240 // How are more than 2^32 keys created? In each loop iteration, the
3241 // algorithm sorts by random keys. Conceptually, the earlier iterations
3242 // are sorting on the lower-order bits of larger keys that are never
3243 // actually assembled.
3244
3245 // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
3246 // the number of possible keys and n is the number of values. If d = n^2,
3247 // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
3248 // as n goes to infinity is zero.
3249
3250 // This implementation ensures that the key-space is greater than or equal
3251 // to the cube of the number of values. The risk of collisions can be
3252 // further reduced by increasing Exponent at the expense of
3253 // performance.
3254
3255 // For Exponent = 2, the expected number of collisions per shuffle is
3256 // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
3257 // about 1/2.
3258
3259 // For Exponent = 3, the expected number of collisions per shuffle is
3260 // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
3261 // about 1/3255.
3262
3263 // For Exponent = 4, the expected number of collisions per shuffle is
3264 // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
3265 // about 1/132622.
3266 constexpr int exponent = 3;
3267 int64_t num_elements = input_type.getNumElements();
3268 uint32_t u32_max = std::numeric_limits<uint32_t>::max();
3269 int rounds =
3270 std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
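      // As a worked instance of the formula (numbers for illustration only):
      // with exponent = 3 and num_elements = 1024, rounds =
      // ceil(3 * ln(1024) / ln(2^32 - 1)) = ceil(0.94) = 1, so a single sort
      // pass already provides enough key bits.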
3271
3272 Value current = op.value();
3273 for (int i = 0; i < rounds; ++i) {
3274 auto keys =
3275 CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
3276 /*upper_limit=*/u32_max, &rewriter);
3277 auto sorted = rewriter.create<xla_hlo::SortOp>(
3278 op.getLoc(), llvm::ArrayRef<Value>{keys, current});
3279 auto i32_type = rewriter.getIntegerType(32);
3280 BuildSortComparisonBody({i32_type, input_type.getElementType()},
3281 /*direction=*/"LT", &sorted.comparator(),
3282 &rewriter);
3283 current = rewriter.create<GetTupleElementOp>(op.getLoc(),
3284 sorted.getResult(), 1);
3285 }
3286 rewriter.replaceOp(op, current);
3287 return matchSuccess();
3288 }
3289
3290 // The Fisher-Yates algorithm.
3291
3292 // Generate range(n) as the initial value for the indices to be swapped.
3293 auto indices_type =
3294 RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
3295 Value indices = rewriter.create<xla_hlo::IotaOp>(
3296 op.getLoc(), indices_type, rewriter.getI64IntegerAttr(first_dim_size));
3297
3298 // Generate random numbers to be used as swaps for the indices.
3299 Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
3300 first_dim_size, &rewriter);
3301
3302 // While loop body to perform index swaps.
3303 auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
3304 SmallVectorImpl<Value> *new_values,
3305 OpBuilder *builder) {
3306 Value swaps = old_values[0];
3307 Value indices = old_values[1];
3308
3309 auto vec1_i32_type =
3310 RankedTensorType::get({1}, builder->getIntegerType(32));
3311 auto scalar_i32_type =
3312 RankedTensorType::get({}, builder->getIntegerType(32));
3313 auto scalar_i64_type =
3314 RankedTensorType::get({}, builder->getIntegerType(64));
3315
3316 auto scalar_one =
3317 DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
3318
3319 // We need to swap the indices[i] with indices[swaps[i]]. First get
3320 // these index values.
3321 Value source_index = builder->create<xla_hlo::DynamicSliceOp>(
3322 loc, vec1_i32_type, indices, i, scalar_one);
3323 Value swap_index = builder->create<xla_hlo::ReshapeOp>(
3324 loc, scalar_i32_type,
3325 builder->create<xla_hlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
3326 scalar_one));
3327 Value target_index = builder->create<xla_hlo::DynamicSliceOp>(
3328 loc, vec1_i32_type, indices, swap_index, scalar_one);
3329
3330 // Then perform the swap.
3331 // indices[i] <- indices[swaps[i]]
3332 indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
3333 loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
3334 // indices[swaps[i]] <- indices[i]
3335 indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
3336 loc, indices.getType(), indices, source_index,
3337 llvm::makeArrayRef(swap_index));
3338
3339 // Update new values.
3340 new_values->assign({swaps, indices});
3341 };
3342
3343 // Create a while op to swap indices.
3344 SmallVector<Value, 2> while_output;
3345 CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
3346 &while_output, &rewriter);
3347     Value swapped_indices = while_output[1];
3348
3349 // Gather the data using the swapped indices as the shuffled order.
3350 ArrayRef<int64_t> input_shape = input_type.getShape();
3351 SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
3352 slice_sizes[0] = 1;
3353 auto dims_attr = GatherDimensionNumbers::get(
3354 /*offset_dims=*/GetI64ElementsAttrForSeq(1, first_dim_size, &rewriter),
3355 /*collapsed_slice_dims=*/GetI64ElementsAttr({0}, &rewriter),
3356 /*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
3357 /*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
3358 rewriter.getContext());
3359 rewriter.replaceOpWithNewOp<xla_hlo::GatherOp>(
3360         op, op.getType(), op.value(), swapped_indices, dims_attr,
3361 GetI64ElementsAttr(slice_sizes, &rewriter));
3362
3363 return matchSuccess();
3364 }
3365 };
3366
3367 // Converts tf.VariableShape op to a XLA HLO constant representing the variable
3368 // shape.
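// For example (an illustrative sketch): a tf.VariableShape op whose resource
// subtype is tensor<2x3xf32> and whose result type is tensor<2xi32> is
// replaced by an "xla_hlo.constant" holding dense<[2, 3]> : tensor<2xi32>.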
3369 class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
3370 public:
3371 using OpRewritePattern::OpRewritePattern;
3372
3373   PatternMatchResult matchAndRewrite(TF::VariableShapeOp op,
3374 PatternRewriter &rewriter) const override {
3375 // The input type should be a tensor<!tf.resource<resource-type>>. We need
3376 // to get the inner resource type.
3377 auto input_type = op.input().getType().cast<TensorType>();
3378 auto subtypes =
3379 input_type.getElementType().cast<TF::ResourceType>().getSubtypes();
3380 // It can be missing; then we cannot convert.
3381 if (subtypes.empty()) return matchFailure();
3382
3383 auto resource_type = subtypes[0].cast<TensorType>();
3384 if (!resource_type.hasStaticShape()) return matchFailure();
3385
3386 auto resource_shape = resource_type.getShape();
3387 Attribute const_attr;
3388
3389 // We need to match the original op result's element type.
3390 auto element_type = op.getType().cast<TensorType>().getElementType();
3391 unsigned bitwidth = element_type.cast<IntegerType>().getWidth();
3392 if (bitwidth == 32) {
3393 SmallVector<int32_t, 4> shape(resource_shape.begin(),
3394 resource_shape.end());
3395 const_attr = GetI32ElementsAttr(shape, &rewriter);
3396 } else {
3397 assert(bitwidth == 64);
3398 const_attr = GetI64ElementsAttr(resource_shape, &rewriter);
3399 }
3400
3401 rewriter.replaceOpWithNewOp<xla_hlo::ConstOp>(op, const_attr);
3402 return matchSuccess();
3403 }
3404 };
3405
3406 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
3407
3408 LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
3409 MLIRContext *context = op->getContext();
3410
3411 // Add lowering patterns to the list.
3412 OwningRewritePatternList patterns;
3413 populateWithGenerated(context, &patterns);
3414
3415 // Add patterns that lower some of the high level TensorFlow ops to lower
3416 // level TensorFlow ops. So, we don't have to target all the TensorFlow ops
3417 // here for lowering to HLO.
3418 TF::PopulateLoweringTFPatterns(context, &patterns);
3419 patterns.insert<
3420 ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBF16FloorDivOp,
3421 ConvertConv2D, ConvertConv2DBackpropFilterOp,
3422 ConvertConv2DBackpropInputOp, ConvertEinsumOp,
3423 ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
3424 ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op,
3425 ConvertInfeedDequeueTupleOp, ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp,
3426 ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
3427 ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertRangeOp,
3428 ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp,
3429 ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
3430 ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
3431 ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
3432 ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
3433 ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
3434 ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
3435 ConvertRandomShuffleOp, ConvertVariableShapeOp>(op->getContext());
3436
3437 ConversionTarget target(*context);
3438 target.addLegalDialect<XlaHloDialect>();
3439
3440 if (!allow_partial_conversion) {
3441 // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
3442 target.addLegalOp<CallOp, ModuleOp, FuncOp, ModuleTerminatorOp,
3443 ::mlir::ReturnOp>();
3444 return applyFullConversion(op, target, patterns);
3445 }
3446
3447 return applyPartialConversion(op, target, patterns);
3448 }
3449
3450 /// Performs the lowering to XLA dialect.
3451 void LegalizeTF::runOnFunction() {
3452 if (failed(legalizeTF(getFunction(), allow_partial_conversion_)))
3453 signalPassFailure();
3454 }
3455
3456 static PassRegistration<LegalizeTF> pass(
3457 "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect");
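// The registered pass can be exercised from the command line, for instance
// (a usage sketch assuming the standard tf-opt driver):
//   tf-opt -xla-legalize-tf input.mlir
// The allow-partial-conversion option controls whether unlegalized TF ops are
// tolerated in the output.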
3458
3459 } // end namespace
3460
3461 std::unique_ptr<OpPassBase<FuncOp>> createLegalizeTFPass(
3462 bool allow_partial_conversion) {
3463 return std::make_unique<LegalizeTF>(allow_partial_conversion);
3464 }
3465
3466 } // end namespace xla_hlo
3467 } // end namespace mlir
3468