/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"

#include <array>
#include <cmath>
#include <complex>
#include <limits>
#include <numeric>
#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace TF {
namespace {

// Returns a 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns a 1-d i64 elements attribute populated with numbers from start
// (inclusive) to end (exclusive).
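// For example, GetI64ElementsAttrForSeq(1, 4, builder) returns the attribute
// dense<[1, 2, 3]> : tensor<3xi64>.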
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns an f32 scalar DenseElementsAttr holding the given value.
static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) {
  return DenseElementsAttr::get(
      RankedTensorType::get({}, builder->getF32Type()),
      FloatAttr::get(builder->getF32Type(), value));
}

// Returns a TF_CastOp to F32. This function is used for CastOps that are
// intermediate nodes in a TableGen pattern result. In such a case, the
// destination type is not inferred and must be given explicitly.
//
// Preconditions: The given value must have a ShapedType.
static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x,
                               BoolAttr truncate) {
  auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
  if (!x_type) llvm_unreachable("unsupported type");
  Type type = x_type.clone(builder->getF32Type());
  return builder->create<CastOp>(loc, type, x, truncate);
}

static APFloat ConvertToAPFloat(double val, Type type) {
  if (type.getIntOrFloatBitWidth() == 32) {
    return APFloat(static_cast<float>(val));
  }

  return APFloat(val);
}

// Returns a scalar-shaped int, float, or complex DenseElementsAttr with the
// given element type and value.
template <typename T>
static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
  if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
    FloatAttr attr = FloatAttr::get(float_ty, raw_value);
    return DenseElementsAttr::get(scalar_ty, attr);
  } else if (auto int_ty = ty.dyn_cast_or_null<IntegerType>()) {
    IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
    return DenseElementsAttr::get(scalar_ty, attr);
  } else if (auto complex_ty = ty.dyn_cast_or_null<ComplexType>()) {
    Type complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<float>>(raw_value));
    } else if (complex_element_ty.isF64()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<double>>(raw_value));
    }
  }
  llvm_unreachable("unsupported type");
}

// Returns true if the passed quantized type is unsigned.
bool QuantizedTypeIsUnsigned(Type type) {
  return TypeSwitch<Type, bool>(type)
      .Case<mlir::TF::Qint8Type>([](Type) { return false; })
      .Case<mlir::TF::Qint16Type>([](Type) { return false; })
      .Case<mlir::TF::Qint32Type>([](Type) { return false; })
      .Case<mlir::TF::Quint8Type>([](Type) { return true; })
      .Case<mlir::TF::Quint16Type>([](Type) { return true; })
      .Default([](Type) {
        llvm_unreachable("QuantizedTypeIsUnsigned: not a quantized type");
        return false;
      });
}

// Returns the half_range value that is used by DequantizeOp. half_range is
// used to offset the quantized representation before it gets scaled. In the
// case of signed quantized types, this offset is half the type's range.
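// For example, a qint8 input is dequantized with half_range = 128.0, while a
// quint8 input uses half_range = 0.0.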
static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) {
  auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
  if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType");
  bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType());
  float half_range = is_unsigned ? 0 : 128;
  return GetScalarOfType(builder->getF32Type(), half_range);
}

// Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
// op.
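// For example, for a rank-4 gradient with "NHWC" data format the feature
// dimension index is 3, so the returned reduction indices are [0, 1, 2].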
DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
                                                    StringAttr data_format,
                                                    Builder *builder) {
  tensorflow::TensorFormat format;
  if (!FormatFromString(data_format.getValue().str(), &format)) return {};

  // Reduce along all dimensions except the feature dimension.
  int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
  llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
  std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
  std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
            feature_dim + 1);
  return GetI64ElementsAttr(dims_to_reduce, builder);
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"

// Infers ExpandDims op output type for the given input type `ty` and dimension
// to expand at the given `axis`.
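// For example, expanding tensor<2x3xf32> at axis = 1 (or axis = -2) yields
// tensor<2x1x3xf32>; unranked inputs are returned unchanged.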
Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
  auto ranked_ty = ty.dyn_cast<RankedTensorType>();

  // Unranked type.
  if (!ranked_ty) return ty;

  auto shape = llvm::to_vector<4>(ranked_ty.getShape());
  if (axis < 0) axis += ranked_ty.getRank() + 1;

  shape.insert(shape.begin() + axis, 1);
  return RankedTensorType::get(shape, ranked_ty.getElementType());
}

// Converts individual Values to a tensor of rank 1. Each input Value has rank 1
// and size 1.
Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
                    ArrayRef<Value> vals) {
  int64_t length = vals.size();
  auto type = RankedTensorType::get({length}, dtype);
  auto axis = rewriter.create<ConstOp>(
      loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
  return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
}

// Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
//
// Note that to improve the parallelism, AddN op uses tree-based reduction.
// For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
//
//                 0     1     2     3     4
//                 |     |     |     |     |
//                 -------     -------     |
//                    |           |        |
//                    5           6        |
//                    |           |        |
//                    -------------        |
//                          |              |
//                          7              |
//                          |              |
//                          ----------------
//                                 |
//                                 8
//
// Example:
//
//   %result = "tf.AddN"(%0, %1, %2)
//
// is lowered to:
//
//   %sum0 = "tf.AddV2"(%0, %1)
//   %result = "tf.AddV2"(%sum0, %2)
//
// While
//
//   %result = "tf.AddN"(%0, %1, %2, %3, %4)
//
// is lowered to:
//
//   %sum0 = "tf.AddV2"(%0, %1)
//   %sum1 = "tf.AddV2"(%2, %3)
//   %sum2 = "tf.AddV2"(%sum0, %sum1)
//   %result = "tf.AddV2"(%sum2, %4)
//
class LowerAddNOp : public RewritePattern {
 public:
  explicit LowerAddNOp(MLIRContext *context)
      : RewritePattern(AddNOp::getOperationName(), 1, context,
                       {AddV2Op::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto addn_op = cast<AddNOp>(op);

    // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
    // support variant type so variant types require special handling.
    if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
      return failure();
    llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
                                         addn_op.inputs().end());

    int64_t n = operands.size();
    // Keep doing tree-based reduction while there is more than one operand.
    while (n > 1) {
      for (int64_t i = 0; i < n; i += 2) {
        // Add two adjacent operands if applicable.
        operands[i / 2] =
            (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
                                                   operands[i], operands[i + 1])
                        : operands[i];
      }
      n = (n + 1) / 2;
    }

    rewriter.replaceOp(addn_op, operands[0]);
    return success();
  }
};

// Lowers DynamicStitch op with constant indices and with static input and
// output shapes using Reshape, UnPack and Pack op.
//
//   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
//   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
//                tensor<2x2xi32>}
//   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
//     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
//     -> tensor<5x2xf32>
//
// is lowered to
//
//   %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
//   %inp0 = "tf.Reshape"(%arg0, %shape)
//     : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
//   %inp1 = "tf.Reshape"(%arg1, %shape)
//     : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
//   %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
//     : (tensor<1x2xf32>) -> tensor<2xf32>
//   %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
//     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
//                             tensor<2xf32>)
//   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
//   %0 = "tf.Pack"(%items1#3, %items1#2, %items1#1, %items1#0, %items0, %axis)
//     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
//        tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
//
template <typename OpT>
class LowerDynamicStitchOp : public RewritePattern {
 public:
  explicit LowerDynamicStitchOp(MLIRContext *context)
      : RewritePattern(
            OpT::getOperationName(), 1, context,
            {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
             UnpackOp::getOperationName(), PackOp::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<OpT>(src_op);

    // Static output type is used to compute intermediate values. Note that the
    // output type doesn't have to be static but if input types and indices are
    // constant, then the output type can be statically determined.
    RankedTensorType out_ty =
        op.getType().template dyn_cast<RankedTensorType>();
    if (!out_ty || !out_ty.hasStaticShape()) return failure();

    // Extract out all the constant indices' attributes and verify that data
    // types are static.
    SmallVector<DenseIntElementsAttr, 4> indices;
    indices.reserve(op.N());
    for (auto it : llvm::zip(op.indices(), op.data())) {
      Value index = std::get<0>(it);
      Value data = std::get<1>(it);

      DenseIntElementsAttr index_attr;
      if (!matchPattern(index, m_Constant(&index_attr))) return failure();
      indices.push_back(index_attr);

      RankedTensorType data_ty =
          data.getType().template dyn_cast<RankedTensorType>();
      if (!data_ty || !data_ty.hasStaticShape()) return failure();
    }

    // Compute the type of each of the items and the shape to use while
    // reshaping inputs so that they can be unpacked to extract out individual
    // items.
    ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
    auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());

    SmallVector<int64_t, 4> packed_shape;
    packed_shape.push_back(-1);
    packed_shape.append(item_shape.begin(), item_shape.end());
    Location loc = op.getLoc();
    auto packed_shape_val = rewriter.create<ConstOp>(
        loc, GetI64ElementsAttr(packed_shape, &rewriter));

    // Prepare each of the output items by unpacking data and then putting it
    // at the specified index.
    SmallVector<Value, 8> values(out_ty.getDimSize(0));
    for (auto it : llvm::zip(indices, op.data())) {
      DenseIntElementsAttr index_attr = std::get<0>(it);
      Value data = std::get<1>(it);

      auto reshaped_data =
          rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
      auto num_items = reshaped_data.getType()
                           .template cast<RankedTensorType>()
                           .getShape()[0];
      auto items = rewriter.create<UnpackOp>(
          loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
          /*axis=*/0);
      for (auto index_item : llvm::zip(index_attr, items.getResults())) {
        int64_t output_index = std::get<0>(index_item).getSExtValue();
        Value item = std::get<1>(index_item);
        values[output_index] = item;
      }
    }

    rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
    return success();
  }
};

// This pattern performs a manual conversion of FakeQuant, converting between
// floating point and quantized space. It is designed to reproduce TF's
// implementation, mirroring the previous XLA implementation.
//
// 1. Compute proper quantized bounds. This involves nudging the input bounds.
// 2. Convert the input values to quantized space, rounding values.
// 3. Convert back into floating point space.
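//
// As an illustrative example: with num_bits = 8 and narrow_range = false the
// quantized range is [0, 255]. If min = 0.0 and max = 1.0, the bounds already
// line up with that range (step = 1/255, zero point = 0), so no nudging is
// needed and the op effectively rounds each input to the nearest multiple of
// 1/255 within [0, 1].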
class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
 public:
  explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
      : RewritePattern(
            FakeQuantWithMinMaxVarsOp::getOperationName(), 1, context,
            {AddV2Op::getOperationName(), SubOp::getOperationName(),
             ConstOp::getOperationName(), MulOp::getOperationName(),
             FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
             DivOp::getOperationName(), RoundOp::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);

    auto input = op.inputs();
    auto input_ty = input.getType().cast<ShapedType>();
    auto element_ty = input_ty.getElementType();
    auto scalar_ty = RankedTensorType::get({}, element_ty);

    auto num_bits = op.num_bits();
    auto narrow_range = op.narrow_range();
    const double bits_min = narrow_range ? 1 : 0;
    const double bits_max = (1 << num_bits) - 1;

    auto float_min = op.min();
    auto float_max = op.max();

    auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);

    // Compute the range when quantized.
    auto quant_min = rewriter.create<ConstOp>(
        op.getLoc(), DenseElementsAttr::get(
                         scalar_ty, ConvertToAPFloat(bits_min, element_ty)));

    auto quant_max = rewriter.create<ConstOp>(
        op.getLoc(), DenseElementsAttr::get(
                         scalar_ty, ConvertToAPFloat(bits_max, element_ty)));

    auto quant_diff = rewriter.create<ConstOp>(
        op.getLoc(),
        DenseElementsAttr::get(
            scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));

    auto quant_to_float =
        rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);

    auto float_to_quant =
        rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);

    // During quantization, the quantized min/max values may not line up
    // perfectly with the specified min/max. Nudge them into the right range.
    auto min_scaled =
        rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
    auto min_scaled_sub =
        rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);

    auto mid_rounded =
        rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);

    auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
        op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);

    auto quant_min_sub =
        rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
    auto quant_max_sub =
        rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);

    auto nudged_float_min =
        rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);

    auto nudged_float_max =
        rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);

    // Now quantize the input value with the approximated min/max values.

    // Move the input value into quantized space.
    Value quantized_input = rewriter.create<ClipByValueOp>(
        op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);

    quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
                                             quantized_input, nudged_float_min);

    quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
                                             quantized_input, float_to_quant);

    // Round the quantized input, with ties rounding toward the positive
    // direction.
    auto half_val = rewriter.create<ConstOp>(
        op.getLoc(),
        DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));

    quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
                                               quantized_input, half_val);

    quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);

    // Convert back into floating point space.
    Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
                                          quantized_input, quant_to_float);

    output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
                                      nudged_float_min);

    rewriter.replaceOp(op, {output});
    return success();
  }
};

// Lowers InvertPermutation op to TensorScatterUpdate op.
//
// Example:
//
//   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
//   "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
//
// is lowered to
//
//   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
//   %start = "tf.Const"() {value = dense<0> : tensor<i32>}
//   %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
//   %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
//   %updates = "tf.Range"(%start, %limit, %delta) :
//     (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
//   %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
//   %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32>, tensor<2xi32>) ->
//     tensor<5x1xi32>
//   "tf.TensorScatterUpdate"(%x, %indices, %updates) :
//     (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
//
class LowerInvertPermutationOp : public RewritePattern {
 public:
  explicit LowerInvertPermutationOp(MLIRContext *context)
      : RewritePattern(
            InvertPermutationOp::getOperationName(), 1, context,
            {ConstOp::getOperationName(), RangeOp::getOperationName(),
             ReshapeOp::getOperationName(),
             TensorScatterUpdateOp::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<InvertPermutationOp>(src_op);

    Location loc = op.getLoc();
    auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
    // x input must have static shape.
    if (!x_type || !x_type.hasStaticShape()) {
      return failure();
    }
    Type int_type = x_type.getElementType();  // Could be i32 or i64.

    auto result_type = x_type;
    auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
    Value limit = rewriter.create<ConstOp>(
        loc, GetScalarOfType(int_type, x_type.getShape()[0]));
    auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
    // Construct a sequence of numbers [0, 1, ... len(x)-1].
    auto updates =
        rewriter.create<RangeOp>(loc, result_type, start, limit, delta);

    auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
    auto shape = rewriter.create<ConstOp>(
        loc, DenseElementsAttr::get(
                 shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
    auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);

    rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
                                                       indices, updates);
    return success();
  }
};

// Approximates lgamma using Lanczos' approximation from
// "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
// Analysis series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
//   t(z) = z + kLanczosGamma + 1/2
//   A(z) = kBaseLanczosCoeff
//          + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//
// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n
// (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
// correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
// evaluated and [7, 9] seemed to be the least sensitive to the quality of the
// log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
// lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
static constexpr double kLanczosGamma = 7;  // aka g
static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
static constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};

class LowerLgammaOp : public RewritePattern {
 public:
  explicit LowerLgammaOp(MLIRContext *context)
      : RewritePattern(LgammaOp::getOperationName(), 1, context,
                       {
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           NegOp::getOperationName(),
                           SubOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           LessOp::getOperationName(),
                           AddV2Op::getOperationName(),
                           DivOp::getOperationName(),
                           SubOp::getOperationName(),
                           LogOp::getOperationName(),
                           Log1pOp::getOperationName(),
                           IsInfOp::getOperationName(),
                           MulOp::getOperationName(),
                           FloorOp::getOperationName(),
                           AbsOp::getOperationName(),
                           GreaterOp::getOperationName(),
                           SinOp::getOperationName(),
                           IsFiniteOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<LgammaOp>(src_op);

    Location loc = op.getLoc();
    Value input = op.x();
    TensorType original_tensor_type = op.x().getType().cast<TensorType>();

    // The approximation is not precise enough for float16. Do the computation
    // in float32 for that case.
    TensorType tensor_type = original_tensor_type;
    FloatType float_type = tensor_type.getElementType().cast<FloatType>();
    bool needs_cast = float_type.getWidth() < 32;
    if (needs_cast) {
      MLIRContext *context = rewriter.getContext();
      float_type = FloatType::getF32(context);
      if (original_tensor_type.hasRank()) {
        tensor_type =
            RankedTensorType::get(original_tensor_type.getShape(), float_type);
      } else {
        tensor_type = UnrankedTensorType::get(float_type);
      }
      input = rewriter.create<CastOp>(loc, tensor_type, input);
    }

    // Helper lambda function for creating a ConstOp for a tensor filled with
    // the given constant float value.
    auto create_const_op = [&rewriter, loc, tensor_type,
                            float_type](double value) {
      return rewriter.create<ConstOp>(
          loc, DenseElementsAttr::get(tensor_type,
                                      FloatAttr::get(float_type, value)));
    };

    Value one_half = create_const_op(0.5);
    Value one = create_const_op(1.0);
    Value infinity = create_const_op(std::numeric_limits<double>::infinity());
    Value pi = create_const_op(M_PI);
    Value log_pi = create_const_op(std::log(M_PI));
    Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
    Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
    Value log_lanczos_gamma_plus_one_half =
        create_const_op(std::log(kLanczosGamma + 0.5));
    Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);

    Value minus_input = rewriter.create<NegOp>(loc, input);
    Value input_minus_one = rewriter.create<SubOp>(loc, input, one);

    // If the input is less than 0.5 use Euler's reflection formula:
    //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
    Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
    Type tensor_bool_type = need_to_reflect.getType();
    Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
                                          input_minus_one);

    Value x = base_lanczos_coeff;
    for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
      Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
      Value index = create_const_op(static_cast<double>(i));
      Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
      Value z_plus_index_plus_one =
          rewriter.create<AddV2Op>(loc, z_plus_index, one);
      Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
                                          z_plus_index_plus_one);
      x = rewriter.create<AddV2Op>(loc, x, incr);
    }

    // To improve accuracy on platforms with less-precise log implementations,
    // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p
    // on the device.
    // log(t) = log(kLanczosGamma + 0.5 + z)
    //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
    Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
    Value z_div_lanczos_gamma_plus_one_half =
        rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
    Value log1p_z_div_lanczos_gamma_plus_one_half =
        rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
    Value log_t =
        rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
                                 log1p_z_div_lanczos_gamma_plus_one_half);

    // Compute the final result (modulo reflection). t(z) may be large, and we
    // need to be careful not to overflow to infinity in the first term of
    //
    //   (z + 1/2) * log(t(z)) - t(z).
    //
    // Therefore we compute this as
    //
    //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
    //
    // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
    Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
    Value one_half_minus_t_div_log_t =
        rewriter.create<SubOp>(loc, one_half, t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t =
        rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
    Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
        rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
    Value log_x = rewriter.create<LogOp>(loc, x);
    Value log_y_rhs = rewriter.create<AddV2Op>(
        loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
    Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);

    // Compute the reflected value, used when x < 0.5:
    //
    //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
    //
    // (The abs is because lgamma is the log of the absolute value of the gamma
    // function.)
    //
    // We have to be careful when computing the final term above. gamma(x) goes
    // to +/-inf at every integer x < 0, and this is controlled by the
    // sin(pi * x) term. The slope is large, so precision is particularly
    // important.
    //
    // Because abs(sin(pi * x)) has period 1, we can equivalently use
    // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
    // is more numerically accurate: It doesn't overflow to inf like pi * x can,
    // and if x is an integer, it evaluates to 0 exactly, which is significant
    // because we then take the log of this value, and log(0) is inf.
    //
    // We don't have a frac(x) primitive in XLA and computing it is tricky, but
    // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
    // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
    //
    // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is
    // close to 1. To remedy this, we can use the fact that sin(pi * x) in the
    // domain [0, 1] is symmetric across the line Y=0.5.
    Value abs_input = rewriter.create<AbsOp>(loc, input);
    Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
    Value abs_frac_input =
        rewriter.create<SubOp>(loc, abs_input, abs_input_floor);

    // Convert values of abs_frac_input > 0.5 to (1 - abs_frac_input) to
    // improve precision of pi * abs_frac_input for values of abs_frac_input
    // close to 1.
    Value one_minus_abs_frac_input =
        rewriter.create<SubOp>(loc, one, abs_frac_input);
    Value abs_frac_input_gt_one_half =
        rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
    Value reduced_frac_input =
        rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
                                    one_minus_abs_frac_input, abs_frac_input);
    Value pi_mul_reduced_frac_input =
        rewriter.create<MulOp>(loc, pi, reduced_frac_input);
    Value sin_pi_mul_reduced_frac_input =
        rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
    Value reflection_denom =
        rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);

    // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
    // then it "wins" and the result is +/-inf.
    Value is_finite =
        rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
    Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
    Value log_pi_minus_reflection_denom =
        rewriter.create<SubOp>(loc, log_pi, reflection_denom);
    Value reflection_if_finite =
        rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
    Value reflection = rewriter.create<SelectV2Op>(
        loc, is_finite, reflection_if_finite, neg_reflection_denom);

    Value result =
        rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);

    // lgamma(+/-inf) = +inf.
    Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
    result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);

    if (needs_cast) {
      result = rewriter.create<CastOp>(loc, original_tensor_type, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowers Pack op to ConcatV2 op after changing shape of the inputs with
// ExpandDims op.
//
// Sample result with 2 inputs to pack:
//
//   %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
//   %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
//   %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
//   %result = "tf.ConcatV2"(%inp0, %inp1, %axis) { N = 2 : i64 }
//
class LowerPackOp : public RewritePattern {
 public:
  explicit LowerPackOp(MLIRContext *context)
      : RewritePattern(
            PackOp::getOperationName(), 1, context,
            {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
             ExpandDimsOp::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<PackOp>(src_op);

    Location loc = op.getLoc();
    auto axis_value = rewriter.create<ConstOp>(
        loc,
        DenseElementsAttr::get(
            RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
    int64_t axis = op.axis();

    Type prev_input_ty, inferred_ty;
    SmallVector<Value, 4> expanded_inputs;
    expanded_inputs.reserve(op.N());
    for (Value input : op.values()) {
      // If input type is different than the previous input type, infer the
      // output type. Otherwise, use the already inferred output type from the
      // previous iteration.
      Type input_ty = input.getType();
      if (input_ty != prev_input_ty) {
        inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
        prev_input_ty = input_ty;
      }
      expanded_inputs.push_back(
          rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
    }

    rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
                                            axis_value);
    return success();
  }
};

// Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
//
// Before rewrite:
//   output = SpaceToBatchND(input, block_shape, paddings)
// Let:
//   [batch] + spatial_shape + remaining_shape = input.shape
//   M = spatial_shape.rank
// After rewrite:
//   padded = zero-pad input with paddings
//     The spatial_shape component of input.shape pads with paddings[*, 0]
//     before each dimension, and paddings[*, 1] after each dimension.
//   reshaped = reshape padded to:
//     [batch]
//     + [padded.shape[1]/block_shape[0], block_shape[0], ...,
//        padded.shape[M]/block_shape[M-1], block_shape[M-1]]
//     + remaining_shape
//   permuted = transpose reshaped to:
//     block_shape
//     + [batch]
//     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
//     + remaining_shape
//   result = reshape permuted to:
//     [batch * product(block_shape)]
//     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
//     + remaining_shape
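//
// For example, with input shape [1, 4, 4, 1], block_shape = [2, 2], and zero
// paddings: padded is [1, 4, 4, 1], reshaped is [1, 2, 2, 2, 2, 1], permuted
// is [2, 2, 1, 2, 2, 1], and the result shape is [4, 2, 2, 1].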
class LowerSpaceToBatchNDOp : public RewritePattern {
 public:
  explicit LowerSpaceToBatchNDOp(MLIRContext *context)
      : RewritePattern(SpaceToBatchNDOp::getOperationName(), 1, context,
                       {
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           ConcatV2Op::getOperationName(),
                           AddV2Op::getOperationName(),
                           PadOp::getOperationName(),
                           SplitOp::getOperationName(),
                           UnpackOp::getOperationName(),
                           DivOp::getOperationName(),
                           MulOp::getOperationName(),
                           ReshapeOp::getOperationName(),
                           TransposeOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<SpaceToBatchNDOp>(src_op);

    Location loc = op.getLoc();
    auto input_type = op.input().getType().cast<TensorType>();
    auto element_type = input_type.getElementType();
    if (!input_type.hasStaticShape()) {
      return failure();
    }
    ArrayRef<int64_t> input_shape = input_type.getShape();
    auto block_shape_type = op.block_shape().getType().cast<TensorType>();
    if (!block_shape_type.hasStaticShape()) {
      return failure();
    }
    auto paddings_type = op.paddings().getType().cast<ShapedType>();
    if (!paddings_type.hasRank()) {
      return failure();
    }

    int64_t input_rank = input_type.getRank();
    int64_t block_rank = block_shape_type.getNumElements();
    int64_t remaining_rank = input_rank - 1 - block_rank;
    if (remaining_rank < 0) {
      // TODO(b/157475606): Move this check to ::Verify
      return failure();
    }

    auto block_shape_i64_type = RankedTensorType::get(
        block_shape_type.getShape(), rewriter.getIntegerType(64));
    auto block_shape_i64 =
        rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());

    auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
                                                   rewriter.getIntegerType(64));
    auto paddings_i64 =
        rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());

    auto pad00 = rewriter.create<ConstOp>(
        loc, DenseElementsAttr::get<int64_t>(
                 RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
                 {0, 0}));
    SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
    full_paddings_list.append(remaining_rank, pad00);
    auto full_paddings_type =
        RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
    auto zero_i64 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
    // Extends paddings to all dimensions of input by adding 0s to non-block
    // dimensions.
    auto full_paddings = rewriter.create<ConcatV2Op>(
        loc, full_paddings_type, full_paddings_list, zero_i64);

    // Compute the result type here instead of using shape inference because
    // the full_paddings won't be available as a constant for shape inference.
    ElementsAttr block_shape;
    ElementsAttr paddings;
    llvm::SmallVector<int64_t, 4> block_shape_ints;
    auto padded_shape = llvm::to_vector<4>(input_shape);
    if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
        matchPattern(op.paddings(), m_Constant(&paddings))) {
      for (uint64_t i = 0; i < block_rank; i++) {
        int64_t paddings_sum =
            paddings.getValue({i, 0}).cast<IntegerAttr>().getInt() +
            paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
        int64_t block_shape_i =
            block_shape.getValue({i}).cast<IntegerAttr>().getInt();
        padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
        block_shape_ints.push_back(block_shape_i);
      }
    } else {
      for (int i = 0; i < block_rank; i++) {
        padded_shape[i + 1] = ShapedType::kDynamicSize;
      }
      block_shape_ints.resize(block_shape_type.getNumElements(), -1);
    }

    auto padded_type = RankedTensorType::get(padded_shape, element_type);
    // padded = pad(input, full_paddings)
    auto padded =
        rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);

    auto paddings_sum_type =
        RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
    // paddings_sum = paddings[*,0] + paddings[*,1]
    auto paddings_split = rewriter.create<UnpackOp>(
        loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
        rewriter.getI64IntegerAttr(1));
    auto paddings_sum = rewriter.create<AddV2Op>(
        loc, paddings_split.getResult(0), paddings_split.getResult(1));

    auto input_shape_tensor = rewriter.create<ConstOp>(
        loc,
        DenseElementsAttr::get(
            RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
            input_shape));

    // padded_shape_tensor is the shape of padded.
    auto padded_shape_tensor =
        rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);

    auto zero_i32 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
    SmallVector<Type, 4> padded_shape_splits_types(
        input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
    SmallVector<Value, 4> padded_shape_splits(
        rewriter
            .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
                             padded_shape_tensor)
            .output());

    SmallVector<Type, 4> block_shape_splits_types(
        block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
    SmallVector<Value, 4> block_shape_splits(
        rewriter
            .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
                             block_shape_i64)
            .output());

    SmallVector<int64_t, 4> outer_shape_ints;
    SmallVector<Value, 4> outer_shape_vals;
    for (int64_t i = 0; i < block_rank; ++i) {
      // TODO(b/157475606): Insert tf.Assert that the following division has
      // remainder 0.
      outer_shape_vals.push_back(rewriter.create<DivOp>(
          loc, padded_shape_splits[1 + i], block_shape_splits[i]));

      auto padded_shape_i = padded_shape[1 + i];
      auto block_shape_ints_i = block_shape_ints[i];

      // Compute the outer_shape constant values to infer the reshape.
      if (padded_shape_i == -1 || block_shape_ints_i == -1) {
        outer_shape_ints.push_back(-1);
      } else {
        outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
      }
    }

    SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
    SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
    for (int64_t i = 0; i < block_rank; ++i) {
      reshaped_shape_vals.push_back(outer_shape_vals[i]);
      reshaped_shape_vals.push_back(block_shape_splits[i]);

      reshaped_shape_ints.push_back(outer_shape_ints[i]);
      reshaped_shape_ints.push_back(block_shape_ints[i]);
    }
    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
      reshaped_shape_vals.push_back(padded_shape_splits[i]);
      reshaped_shape_ints.push_back(padded_shape[i]);
    }
    auto reshaped_shape = ValuesToRank1(
        rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);

    auto reshaped = rewriter.create<ReshapeOp>(
        loc, RankedTensorType::get(reshaped_shape_ints, element_type), padded,
        reshaped_shape);

    SmallVector<int64_t, 6> permutation_vals;
    for (int64_t i = 0; i < block_rank; ++i) {
      permutation_vals.push_back(2 + 2 * i);
    }
    permutation_vals.push_back(0);
    for (int64_t i = 0; i < block_rank; ++i) {
      permutation_vals.push_back(1 + 2 * i);
    }
    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
      permutation_vals.push_back(block_rank + i);
    }
    auto permutation = rewriter.create<ConstOp>(
        loc, GetI64ElementsAttr(permutation_vals, &rewriter));

    auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
    auto output_batch = padded_shape_splits[0];
    for (int64_t i = 0; i < block_rank; ++i) {
      output_batch =
          rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
    }
    SmallVector<Value, 4> output_shape_vals{output_batch};
    for (int64_t i = 0; i < block_rank; ++i) {
      output_shape_vals.push_back(outer_shape_vals[i]);
    }
    for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
      output_shape_vals.push_back(padded_shape_splits[i]);
    }
    auto output_shape = ValuesToRank1(
        rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);

    // Sometimes the result type is more specific than what the reshape
    // builder can infer.
    auto result_type = op.getResult().getType();
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
                                           output_shape);

    return success();
  }
};

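// Lowers BatchToSpaceND to a reshape, transpose, and reshape of the input,
// followed by a Slice that applies the crops. Requires the input to have a
// static shape and `block_shape` and `crops` to be constants.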
class LowerBatchToSpaceND : public RewritePattern {
 public:
  explicit LowerBatchToSpaceND(MLIRContext *context)
      : RewritePattern(BatchToSpaceNDOp::getOperationName(), 1, context,
                       {
                           ConstOp::getOperationName(),
                           ReshapeOp::getOperationName(),
                           SliceOp::getOperationName(),
                           TransposeOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<BatchToSpaceNDOp>(src_op);
    auto input = op.input();
    auto input_ty = input.getType().cast<ShapedType>();
    auto element_ty = input_ty.getElementType();
    if (!input_ty.hasStaticShape()) {
      return failure();
    }

    const int input_rank = input_ty.getRank();
    auto input_shape = input_ty.getShape();

    DenseIntElementsAttr block_shape;
    DenseIntElementsAttr crops;
    if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
        !matchPattern(op.crops(), m_Constant(&crops))) {
      return failure();
    }

    auto block_shape_ty = block_shape.getType();
    if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
      return failure();
    }

    const int block_rank = block_shape_ty.getShape().front();
    auto remainder_shape = input_shape.drop_front(1 + block_rank);

    const int64_t batch_size = input_shape[0];

    // Compute the product of the block_shape values.
    int64_t block_num_elems = 1;

    for (auto val : block_shape.getIntValues()) {
      block_num_elems *= val.getSExtValue();
    }

    if (block_num_elems <= 0) {
      op.emitOpError()
          << "The product of the block dimensions must be positive";
      return failure();
    }

    // 1. Reshape `input` to `reshaped` of shape:
    //      [block_shape[0], ..., block_shape[M-1],
    //       batch / prod(block_shape),
    //       input_shape[1], ..., input_shape[N-1]]
    std::vector<int64_t> reshaped_shape;
    for (auto val : block_shape) {
      reshaped_shape.push_back(val.getSExtValue());
    }
    reshaped_shape.resize(input_rank + block_rank);

    reshaped_shape[block_rank] = batch_size / block_num_elems;
    std::copy(input_shape.begin() + 1, input_shape.end(),
              reshaped_shape.begin() + block_rank + 1);

    auto reshaped = rewriter.create<TF::ReshapeOp>(
        op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
        rewriter.create<ConstOp>(op.getLoc(),
                                 rewriter.getI64TensorAttr(reshaped_shape)));

    // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
    //      [batch / prod(block_shape),
    //
    //       input_shape[1], block_shape[0],
    //       ...,
    //       input_shape[M], block_shape[M-1],
    //
    //       input_shape[M+1], ..., input_shape[N-1]]
    std::vector<int64_t> permutation(reshaped_shape.size());
    permutation[0] = block_rank;
    for (int i = 0; i < block_rank; ++i) {
      permutation[1 + 2 * i] = block_rank + 1 + i;
      permutation[1 + 2 * i + 1] = i;
    }
    std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
              1 + block_rank * 2);

    std::vector<int64_t> transpose_shape(permutation.size());
    for (auto it : llvm::enumerate(permutation)) {
      transpose_shape[it.index()] = reshaped_shape[it.value()];
    }

    auto permuted = rewriter.create<TF::TransposeOp>(
        op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
        reshaped,
        rewriter.create<ConstOp>(op.getLoc(),
                                 rewriter.getI64TensorAttr(permutation)));

    // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
    //      [batch / prod(block_shape),
    //
    //       input_shape[1] * block_shape[0],
    //       ...,
    //       input_shape[M] * block_shape[M-1],
    //
    //       input_shape[M+1],
    //       ...,
    //       input_shape[N-1]]
    std::vector<int64_t> reshaped_permuted_shape(input_rank);
    auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
    reshaped_permuted_shape[0] = batch_size / block_num_elems;
    for (int i = 0; i < block_rank; ++i) {
      reshaped_permuted_shape[1 + i] =
          block_shape_values[i].getSExtValue() * input_shape[1 + i];
    }
    std::copy(remainder_shape.begin(), remainder_shape.end(),
              reshaped_permuted_shape.begin() + 1 + block_rank);

    auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
        op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
        permuted,
        rewriter.create<ConstOp>(
            op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));

    // 4. Crop the start and end of dimensions `[1, ..., M]` of
    //    `reshaped_permuted` according to `crops` to produce the output of
    //    shape:
    //      [batch / prod(block_shape),
    //
    //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
    //       ...,
    //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
    //
    //       input_shape[M+1], ..., input_shape[N-1]]
    std::vector<int64_t> start_indices(input_rank, 0);
    std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
    std::vector<int64_t> strides(input_rank, 1);
    auto crop_values = llvm::to_vector<4>(crops.getIntValues());
    for (int i = 0; i < block_rank; ++i) {
      int64_t crop_start = crop_values[i * 2].getSExtValue();
      int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();

      if (crop_start < 0 || crop_end < 0) {
        op.emitOpError() << "Crops must be non-negative";
        return failure();
      }

      start_indices[i + 1] = crop_start;
      slice_sizes[i + 1] -= crop_start + crop_end;

      if (slice_sizes[i + 1] < 0) {
        op.emitOpError() << "Cropped size must be non-negative: start: "
                         << crop_start << " end: " << crop_end << " size "
                         << reshaped_permuted_shape[1 + i];
      }
    }

    rewriter.replaceOpWithNewOp<TF::SliceOp>(
        op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
        rewriter.create<ConstOp>(op.getLoc(),
                                 rewriter.getI64TensorAttr(start_indices)),
        rewriter.create<ConstOp>(op.getLoc(),
                                 rewriter.getI64TensorAttr(slice_sizes)));
    return success();
  }
};

// Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
// since we currently don't have an implementation that can use this
// information. Adds appropriate casts where necessary to align element types
// of operands and result for `MatMulOp`.
class LowerSparseMatMulOp : public RewritePattern {
 public:
  explicit LowerSparseMatMulOp(MLIRContext *context)
      : RewritePattern(
            SparseMatMulOp::getOperationName(), 1, context,
            {CastOp::getOperationName(), MatMulOp::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<SparseMatMulOp>(src_op);

    // Result type must be f32 for applying the pattern (currently this is
    // required by the op anyway but this might change).
    if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
      return failure();
    }
    MLIRContext *context = rewriter.getContext();
    llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
    for (Value &operand : operands) {
      TensorType tensor_type = operand.getType().cast<TensorType>();
      Type element_type = tensor_type.getElementType();
      if (element_type.isF32()) continue;
      // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
      // must be bf16 here.
      assert(element_type.isBF16());
      Type tensor_type_f32;
      if (tensor_type.hasRank()) {
        tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
                                                FloatType::getF32(context));
      } else {
        tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
      }
      // Add cast to f32 to conform with element type of result.
      operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
    }
    Value result = rewriter.create<MatMulOp>(
        op.getLoc(), op.product().getType(), operands[0], operands[1],
        op.transpose_a(), op.transpose_b());

    rewriter.replaceOp(op, {result});
    return success();
  }
};

// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
// were fused together.
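// For example, if op_names is ["Abs", "Sqrt"], the composition is lowered to
//   %0 = "tf.Abs"(%x)
//   %1 = "tf.Sqrt"(%0)
// with %1 replacing the original result.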
class Lower_UnaryOpsComposition
    : public OpRewritePattern<_UnaryOpsCompositionOp> {
 public:
  using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
                                PatternRewriter &rewriter) const override {
    Value result = op.x();
    for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
      std::string full_name = "tf." + op_name.str();
      // All ops in the sequence have the same result type as the original op.
      OperationState state(op.getLoc(), full_name, /*operands=*/{result},
                           /*types=*/{op.getType()}, /*attributes=*/{});
      Operation *op = rewriter.createOperation(state);
      result = op->getResult(0);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};

// Lowers ResizeNearestNeighbor to an indices computation with a gather along
// the combined spatial dimensions. Indices generated along the height and
// width dimensions could be used to gather along the H and W dimensions of
// the input image separately. To reduce to a single gather, these indices are
// combined, so a single gather can be performed along the combined spatial
// dimensions.
//
// Images must take the shape [b, h, w, c] and size is a rank-1 length-2 tensor
// containing the height and width values for the output tensor. This lowering
// should work with a dynamic images array.
//
// For example, a scaling with image shape [1, 3, 3, 1] to [2, 2] and unaligned
// corners would generate a [0, 1] lookup along both the x and y direction.
// Then when combined to form the 1-D spatial index the values would be
// [0, 1, 3, 4] which would gather along the reshaped image tensor of shape
// [1, 9, 1], reshaped to the final [1, 2, 2, 1].
1315 class LowerResizeNearestNeighbor : public RewritePattern {
1316 public:
LowerResizeNearestNeighbor(MLIRContext * context)1317 explicit LowerResizeNearestNeighbor(MLIRContext *context)
1318 : RewritePattern(ResizeNearestNeighborOp::getOperationName(), 1, context,
1319 {
1320 BroadcastToOp::getOperationName(),
1321 ConstOp::getOperationName(),
1322 DivOp::getOperationName(),
1323 PackOp::getOperationName(),
1324 RangeOp::getOperationName(),
1325 ReshapeOp::getOperationName(),
1326 ShapeOp::getOperationName(),
1327 SplitOp::getOperationName(),
1328 TransposeOp::getOperationName(),
1329 }) {}

  LogicalResult matchAndRewrite(Operation *src_op,
                                PatternRewriter &rewriter) const override {
    auto op = cast<ResizeNearestNeighborOp>(src_op);
    auto loc = op.getLoc();
    auto result_ty = op.getType().cast<ShapedType>();

    auto input = op.images();
    auto input_ty = input.getType().cast<ShapedType>();
    auto input_element_ty = input_ty.getElementType();
    auto out_size = op.size();
    auto out_size_ty = out_size.getType().cast<ShapedType>();
    auto out_size_element_ty = out_size_ty.getElementType();

    // Input should be rank 4.
    if (!input_ty.hasRank() || input_ty.getRank() != 4) {
      return failure();
    }

    // Check that out_size is rank-1, length-2. Otherwise the size is not
    // legal.
    if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
        out_size_ty.getShape()[0] != 2) {
      return failure();
    }

    // Extract the output height / width dim sizes.
    int out_height_constant = -1;
    int out_width_constant = -1;
    DenseIntElementsAttr out_size_cst;
    if (matchPattern(out_size, m_Constant(&out_size_cst))) {
      llvm::SmallVector<int64_t, 2> cst_size;
      for (auto val : out_size_cst.getIntValues()) {
        cst_size.push_back(val.getSExtValue());
      }

      out_height_constant = cst_size[0];
      out_width_constant = cst_size[1];

      if (out_height_constant < 0 || out_width_constant < 0) return failure();
    }

    int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
                              ? -1
                              : out_height_constant * out_width_constant;

    // Input rank should be 4. Might be able to drop this requirement entirely
    // as it's an input requirement.
    if (!input_ty.hasRank() || input_ty.getRank() != 4) {
      return failure();
    }

    int batch_cst = input_ty.getShape()[0];
    int channels_cst = input_ty.getShape()[3];

    int in_y_cst = input_ty.getShape()[1];
    int in_x_cst = input_ty.getShape()[2];
    int in_spatial_cst =
        in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;

    // TODO(suderman): Add support for these optional parameters.
    if (op.align_corners() || op.half_pixel_centers()) {
      return failure();
    }

    auto one =
        rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));

    // Extract the image shape.
    Value input_shape = rewriter.create<ShapeOp>(
        loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
    input_shape = rewriter.create<CastOp>(
        loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);

    auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
    auto split_image_shape = rewriter.create<UnpackOp>(
        loc,
        TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
        input_shape);

    // Extract the separate components from the input shape.
    auto batch = split_image_shape.getResult(0);
    auto in_y = split_image_shape.getResult(1);
    auto in_x = split_image_shape.getResult(2);
    auto channels = split_image_shape.getResult(3);

    auto in_count = rewriter.create<MulOp>(
        loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);

    // Unpack and separate the out height/width.
    auto split_out_size = rewriter.create<UnpackOp>(
        loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);

    auto out_y = split_out_size.getResult(0);
    auto out_x = split_out_size.getResult(1);

    auto out_count = rewriter.create<MulOp>(
        loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);

    // Generate what the final output shape will look like.
    auto out_shape = rewriter.create<PackOp>(
        loc, RankedTensorType::get({4}, out_size_element_ty),
        ValueRange({batch, out_y, out_x, channels}));

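    // For each output row index i in [0, out_h), the nearest-neighbor source
    // row is int(i * in_h / out_h); the analogous formula is used for columns
    // further below. Row indices are additionally scaled by the input width
    // so that row and column lookups can be summed into a single flattened
    // spatial index.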
    // Compute the indices along the vertical dimension.
    auto in_y_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
    auto out_y_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);

    Value y_scale = rewriter.create<DivOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
        out_y_f32);

    Value zero_f32 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
    Value one_f32 = rewriter.create<ConstOp>(
        loc, GetScalarOfType(rewriter.getF32Type(), 1.0));

    Value y_range = rewriter.create<RangeOp>(
        loc,
        RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
        zero_f32, out_y_f32, one_f32);

    y_range = rewriter.create<MulOp>(
        loc,
        RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
        y_range, y_scale);

    y_range = rewriter.create<CastOp>(
        loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
        y_range);

    y_range = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
        y_range,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({2}, out_size_element_ty),
                                ValueRange({out_y, one})));

    Value y_indices = rewriter.create<MulOp>(
        loc,
        RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
        y_range, in_x);

    // Compute the indices for the nearest neighbor lookup across the width
    // dim.
    auto in_x_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
    auto out_x_f32 = rewriter.create<CastOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);

    Value x_scale = rewriter.create<DivOp>(
        loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
        out_x_f32);

    Value x_range = rewriter.create<RangeOp>(
        loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
        zero_f32, out_x_f32, one_f32);

    x_range = rewriter.create<MulOp>(
        loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
        x_range, x_scale);

    x_range = rewriter.create<CastOp>(
        loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
        x_range);

    Value x_indices = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
        x_range,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({2}, out_size_element_ty),
                                ValueRange({one, out_x})));

    // Generate the combined index array, reshaped to be 1-D.
    Value indices = rewriter.create<AddV2Op>(
        loc,
        RankedTensorType::get({out_height_constant, out_width_constant},
                              out_size_element_ty),
        y_indices, x_indices);

    indices = rewriter.create<ReshapeOp>(
        loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
        indices,
        rewriter.create<ReshapeOp>(
            loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
            rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));

    // Group the spatial indices and gather along that combined index.
    Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
        loc,
        RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
                              input_element_ty),
        input,
        rewriter.create<PackOp>(loc,
                                RankedTensorType::get({3}, out_size_element_ty),
                                ValueRange({batch, in_count, channels})));

    Value gathered_values = rewriter.create<GatherV2Op>(
        loc,
        RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
                              input_element_ty),
        input_collapsed_spatial, indices, /*axis=*/one);

    gathered_values =
        rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);

    rewriter.replaceOp(op, gathered_values);
    return success();
  }
};

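// Lowers Roll to a sequence of Slice and ConcatV2 ops: for every axis with a
// non-zero normalized shift, the input is split into the trailing `shift`
// elements and the leading `dim - shift` elements, and the two pieces are
// concatenated in swapped order. For illustration, rolling [0, 1, 2, 3, 4] by
// a shift of 2 along axis 0 slices out [3, 4] and [0, 1, 2] and concatenates
// them into [3, 4, 0, 1, 2]. Requires a statically shaped input and constant
// shift and axis operands.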
struct LowerRollOp : public RewritePattern {
  explicit LowerRollOp(MLIRContext *context)
      : RewritePattern(
            RollOp::getOperationName(), 1, context,
            {ConstOp::getOperationName(), SliceOp::getOperationName(),
             ConcatV2Op::getOperationName()}) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto tf_roll_op = cast<RollOp>(op);

    auto input_ty = tf_roll_op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_ty || !input_ty.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require the type of input to have static shapes");
    }

    DenseIntElementsAttr shift_attr;
    Value shift = tf_roll_op.shift();
    auto shift_ranked_attr_type = shift.getType().dyn_cast<RankedTensorType>();
    if (!shift_ranked_attr_type ||
        !matchPattern(shift, m_Constant(&shift_attr))) {
      return failure();
    }

    DenseIntElementsAttr axis_attr;
    Value axis = tf_roll_op.axis();
    auto axis_ranked_attr_type = axis.getType().dyn_cast<RankedTensorType>();
    if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) {
      return failure();
    }

    // Combine duplicate axes and make sure they are in the [0, rank(input))
    // range.
    auto input_shape = input_ty.getShape();
    int input_rank = input_shape.size();
    SmallVector<int32_t, 4> shift_map(input_rank, 0);
    for (int i = 0; i < axis_attr.getNumElements(); ++i) {
      int32_t axis_i = axis_attr.getValue<int32_t>(i);
      if (axis_i < 0) axis_i += input_rank;
      int32_t shift_i = shift_attr.getValue<int32_t>(i);
      shift_map[axis_i] += shift_i;
    }

    SmallVector<int32_t, 4> adjusted_axis;
    SmallVector<int32_t, 4> adjusted_shift;
    for (int i = 0; i < input_rank; ++i) {
      int32_t input_dims_i = input_shape[i];
      int32_t shift_i = shift_map[i] % input_dims_i;
      if (shift_i < 0) shift_i += input_dims_i;
      if (shift_i == 0) continue;
      adjusted_axis.push_back(i);
      adjusted_shift.push_back(shift_i);
    }

    // Convert rolling in each dimension to two Slice ops and one Concat op.
    auto axis_type =
        RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
    auto create_slice_op = [&](int32_t axis_i, int32_t begin_i, int32_t size_i,
                               Value input) {
      SmallVector<int64_t, 4> begin_values(input_rank, 0);
      begin_values[axis_i] = begin_i;
      auto begin_attr = DenseIntElementsAttr::get(axis_type, begin_values);
      auto begin =
          rewriter.create<ConstOp>(op->getLoc(), axis_type, begin_attr);

      SmallVector<int64_t, 4> output_shape;
      output_shape.append(input_shape.begin(), input_shape.end());
      output_shape[axis_i] = size_i;
      auto size_attr = DenseIntElementsAttr::get(axis_type, output_shape);
      auto size = rewriter.create<ConstOp>(op->getLoc(), axis_type, size_attr);

      auto slice_op_ty =
          RankedTensorType::get(output_shape, input_ty.getElementType());
      return rewriter.create<SliceOp>(op->getLoc(), slice_op_ty, input, begin,
                                      size);
    };

    auto result = tf_roll_op.input();
    auto scalar_type =
        mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
    for (int i = 0; i < adjusted_axis.size(); ++i) {
      int32_t axis_i = adjusted_axis[i];
      int32_t shift_i = adjusted_shift[i];
      auto slice_op_1 = create_slice_op(axis_i, input_shape[axis_i] - shift_i,
                                        shift_i, result);
      auto slice_op_2 =
          create_slice_op(axis_i, 0, input_shape[axis_i] - shift_i, result);

      auto dim_attr = DenseIntElementsAttr::get(scalar_type, {axis_i});
      auto concat_dim =
          rewriter.create<ConstOp>(op->getLoc(), scalar_type, dim_attr);
      auto concat_op = rewriter.create<ConcatV2Op>(
          op->getLoc(), input_ty,
          ArrayRef<Value>({slice_op_1.output(), slice_op_2.output()}),
          concat_dim);
      result = concat_op.getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Decomposes Softmax and LogSoftmax to primitive TF ops, using the following
// formulas:
//
// softmax = div(exp(logits), sum(exp(logits)))
// log_softmax = sub(logits, log(sum(exp(logits))))
//
// TODO(jpienaar): Evaluate benefit of templating here.
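//
// Note that the lowering below emits the numerically stable variants
//
// softmax = div(exp(logits - max(logits)), sum(exp(logits - max(logits))))
// log_softmax = sub(logits - max(logits),
//                   log(sum(exp(logits - max(logits)))))
//
// which are mathematically equivalent to the formulas above.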
template <typename OpTy, bool use_log = true>
class LowerSoftmaxOp : public OpRewritePattern<OpTy> {
 public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Value logits = op.logits();
    auto loc = op.getLoc();

    // Note that the TensorFlow Softmax op verifies that the input rank is
    // greater than or equal to one so the following sequence is valid.
    auto reduce_dim =
        rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({-1}, &rewriter));

    // The exponentials of the input values, and hence their sum, can be very
    // large here, and division by a large denominator is numerically
    // unstable. To improve numerical stability, subtract the max element of
    // each batch from that batch so that the maximum input value is zero. It
    // can be shown that shifting all inputs in a batch by a common value
    // yields a mathematically equivalent softmax.
    auto max_logits =
        rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
                                   /*keep_dims=*/rewriter.getBoolAttr(true));
    auto shifted_logits = rewriter.create<TF::SubOp>(loc, logits, max_logits);

    // Exponentiate the inputs.
    Value exp = rewriter.create<TF::ExpOp>(loc, shifted_logits);

    // Compute summation of the exponentials.
    Value sum =
        rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
                                   /*keep_dims=*/rewriter.getBoolAttr(true));

    if (use_log) {
      Value log = rewriter.create<TF::LogOp>(loc, sum);
      rewriter.replaceOpWithNewOp<TF::SubOp>(op, shifted_logits, log);
    } else {
      rewriter.replaceOpWithNewOp<TF::DivOp>(op, exp, sum);
    }
    return success();
  }
};

} // namespace

void PopulateLoweringTFPatterns(MLIRContext *context,
                                OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
    LowerAddNOp,
    ConvertFakeQuantWithMinMaxVarsOp,
    LowerDynamicStitchOp<DynamicStitchOp>,
    LowerDynamicStitchOp<ParallelDynamicStitchOp>,
    LowerInvertPermutationOp,
    LowerLgammaOp,
    LowerPackOp,
    LowerBatchToSpaceND,
    LowerSpaceToBatchNDOp,
    LowerResizeNearestNeighbor,
    LowerSparseMatMulOp,
    Lower_UnaryOpsComposition,
    LowerRollOp>(context);
  // clang-format on
  populateWithGenerated(*patterns);
}

void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
    ConvertFakeQuantWithMinMaxVarsOp,
    LowerAddNOp,
    LowerBatchToSpaceND,
    LowerDynamicStitchOp<DynamicStitchOp>,
    LowerDynamicStitchOp<ParallelDynamicStitchOp>,
    LowerInvertPermutationOp,
    LowerPackOp,
    LowerResizeNearestNeighbor,
    LowerSoftmaxOp<TF::LogSoftmaxOp, /*use_log=*/true>,
    LowerSoftmaxOp<TF::SoftmaxOp, /*use_log=*/false>,
    LowerSpaceToBatchNDOp,
    LowerSparseMatMulOp,
    Lower_UnaryOpsComposition,
    LowerRollOp>(context);
  // clang-format on

  // Populate the relevant generated patterns.
  // clang-format off
  patterns->insert<
    LowerBiasAddGradOp,
    LowerDivNoNanOp,
    LowerEmptyOp,
    LowerFakeQuantWithMinMaxArgs,
    LowerFillOp,
    LowerIsNanOp,
    LowerL2LossOp,
    LowerMulNoNanOp,
    LowerPadOp,
    LowerReciprocal,
    LowerRintOp,
    LowerRoundOpOnFloatTensor,
    LowerRoundOpOnIntTensor,
    LowerRsqrtGradOp,
    LowerScatterNdOp,
    LowerSeluOp,
    LowerSeluGradOp,
    LowerSizeOp,
    LowerSoftmaxCrossEntropyWithLogitsOp,
    LowerSparseSoftmaxCrossEntropyWithLogitsOp,
    LowerSqrtGradOp,
    LowerSquareOp,
    LowerSquaredDifferenceOpOnRealTensors,
    LowerSquaredDifferenceOpOneComplexTensors,
    LowerTanhGradOp,
    LowerXdivyOp,
    LowerXlog1pyOp,
    LowerXlogyOp>(context);
  // clang-format on
}

void PopulateLoweringQuantizedPatterns(MLIRContext *context,
                                       OwningRewritePatternList *patterns) {
  // clang-format off
  patterns->insert<
    LowerDequantizeOp>(context);
  // clang-format on
}
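
// For illustration only, a pass that wants these lowerings would typically
// build a pattern set and apply it greedily, roughly as follows (sketch;
// surrounding pass boilerplate omitted):
//
//   OwningRewritePatternList patterns(&getContext());
//   mlir::TF::PopulateTFLoweringBeforeHLOPatterns(&getContext(), &patterns);
//   (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));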

} // namespace TF
} // namespace mlir