1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
17 
18 #include <numeric>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "mlir/IR/Matchers.h"  // from @llvm-project
28 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
29 #include "mlir/IR/TypeRange.h"  // from @llvm-project
30 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/core/util/tensor_format.h"
37 
38 namespace mlir {
39 namespace TF {
40 namespace {
41 
42 // Returns a 1D 64-bit dense elements attribute with the given values.
43 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
44                                                Builder *builder) {
45   RankedTensorType ty = RankedTensorType::get(
46       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
47   return DenseIntElementsAttr::get(ty, values);
48 }
49 
50 // Returns a 1-d i64 elements attribute populated with numbers from start
51 // (inclusive) to end (exclusive).
52 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
53                                                      Builder *builder) {
54   int size = end - start;
55 
56   SmallVector<int64_t, 4> vals;
57   vals.resize(size);
58   std::iota(vals.begin(), vals.end(), start);
59 
60   TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
61   return DenseIntElementsAttr::get(ty, vals);
62 }
63 
64 // Returns a scalar f32 elements attribute with the given value.
65 static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) {
66   return DenseElementsAttr::get(
67       RankedTensorType::get({}, builder->getF32Type()),
68       FloatAttr::get(builder->getF32Type(), value));
69 }
70 
71 // Returns a TF_CastOp to F32. This function is used for CastOps that are
72 // intermediate nodes in a TableGen pattern result. In such a case, the
73 // destination type is not inferred and must be given explicitly.
74 //
75 // Preconditions: The given value must have a ShapedType.
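// For example, a tensor<4xbf16> input is cast to tensor<4xf32>.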
76 static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x,
77                                BoolAttr truncate) {
78   auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
79   if (!x_type) llvm_unreachable("unsupported type");
80   Type type = x_type.clone(builder->getF32Type());
81   return builder->create<CastOp>(loc, type, x, truncate);
82 }
83 
84 static APFloat ConvertToAPFloat(double val, Type type) {
85   if (type.getIntOrFloatBitWidth() == 32) {
86     return APFloat(static_cast<float>(val));
87   }
88 
89   return APFloat(val);
90 }
91 
92 // Returns an int, float, or complex DenseElementsAttr of scalar shape with
93 // the given element type and value.
94 template <typename T>
95 static DenseElementsAttr GetScalarOfType(Type ty, T raw_value) {
96   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
97   if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
98     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
99     return DenseElementsAttr::get(scalar_ty, attr);
100   } else if (auto int_ty = ty.dyn_cast_or_null<IntegerType>()) {
101     IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
102     return DenseElementsAttr::get(scalar_ty, attr);
103   } else if (auto complex_ty = ty.dyn_cast_or_null<ComplexType>()) {
104     Type complex_element_ty = complex_ty.getElementType();
105     if (complex_element_ty.isF32()) {
106       return DenseElementsAttr::get(
107           scalar_ty, static_cast<std::complex<float>>(raw_value));
108     } else if (complex_element_ty.isF64()) {
109       return DenseElementsAttr::get(
110           scalar_ty, static_cast<std::complex<double>>(raw_value));
111     }
112   }
113   llvm_unreachable("unsupported type");
114 }
115 
116 // Return true if the passed quantized type is unsigned.
117 bool QuantizedTypeIsUnsigned(Type type) {
118   return TypeSwitch<Type, bool>(type)
119       .Case<mlir::TF::Qint8Type>([](Type) { return false; })
120       .Case<mlir::TF::Qint16Type>([](Type) { return false; })
121       .Case<mlir::TF::Qint32Type>([](Type) { return false; })
122       .Case<mlir::TF::Quint8Type>([](Type) { return true; })
123       .Case<mlir::TF::Quint16Type>([](Type) { return true; })
124       .Default([](Type) {
125         llvm_unreachable("QuantizedTypeIsUnsigned: not a quantized type");
126         return false;
127       });
128 }
129 
130 // Returns the half_range value that is used by DequantizeOp. half_range is
131 // used to offset the quantized representation before it gets scaled. For
132 // signed quantized types, this offset is half the type's range.
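// For example, a signed tf.qint8 input gets half_range = 128, while an
// unsigned tf.quint8 input gets half_range = 0.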
133 static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) {
134   auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
135   if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType");
136   bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType());
137   float half_range = is_unsigned ? 0 : 128;
138   return GetScalarOfType(builder->getF32Type(), half_range);
139 }
140 
141 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
142 // op.
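// For example, with rank = 4 and data_format = "NHWC" the feature dimension is
// 3, so the reduction indices are [0, 1, 2]; with "NCHW" they are [0, 2, 3].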
143 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
144                                                     StringAttr data_format,
145                                                     Builder *builder) {
146   tensorflow::TensorFormat format;
147   if (!FormatFromString(data_format.getValue().str(), &format)) return {};
148 
149   // Reduce along all dimensions except the feature dimension.
150   int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
151   llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
152   std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
153   std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
154             feature_dim + 1);
155   return GetI64ElementsAttr(dims_to_reduce, builder);
156 }
157 
158 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"
159 
160 // Infers ExpandDims op output type for the given input type `ty` and dimension
161 // to expand at the given `axis`.
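// For example, expanding tensor<2x3xf32> at axis -1 (equivalently, axis 2)
// yields tensor<2x3x1xf32>; unranked inputs are returned unchanged.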
162 Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
163   auto ranked_ty = ty.dyn_cast<RankedTensorType>();
164 
165   // Unranked type.
166   if (!ranked_ty) return ty;
167 
168   auto shape = llvm::to_vector<4>(ranked_ty.getShape());
169   if (axis < 0) axis += ranked_ty.getRank() + 1;
170 
171   shape.insert(shape.begin() + axis, 1);
172   return RankedTensorType::get(shape, ranked_ty.getElementType());
173 }
174 
175 // Concatenates individual Values, each of rank 1 and size 1, into a single
176 // tensor of rank 1.
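// For example, three tensor<1xi64> values are concatenated along axis 0 into a
// single tensor<3xi64>.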
177 Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
178                     ArrayRef<Value> vals) {
179   int64_t length = vals.size();
180   auto type = RankedTensorType::get({length}, dtype);
181   auto axis = rewriter.create<ConstOp>(
182       loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
183   return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
184 }
185 
186 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
187 //
188 // Note that to improve the parallelism, AddN op uses tree-based reduction.
189 // For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
190 //
191 //                 0     1     2     3     4
192 //                 |     |     |     |     |
193 //                 -------     -------     |
194 //                    |           |        |
195 //                    5           6        |
196 //                    |           |        |
197 //                    -------------        |
198 //                          |              |
199 //                          7              |
200 //                          |              |
201 //                          ----------------
202 //                                 |
203 //                                 8
204 //
205 // Example:
206 //
207 //   %result = "tf.AddN"(%0, %1, %2)
208 //
209 // is lowered to:
210 //
211 //   %sum0 = "tf.AddV2"(%0, %1)
212 //   %result = "tf.AddV2"(%sum0, %2)
213 //
214 // While
215 //
216 //   %result = "tf.AddN"(%0, %1, %2, %3, %4)
217 //
218 // is lowered to:
219 //
220 //   %sum0 = "tf.AddV2"(%0, %1)
221 //   %sum1 = "tf.AddV2"(%2, %3)
222 //   %sum2 = "tf.AddV2"(%sum0, %sum1)
223 //   %result = "tf.AddV2"(%sum2, %4)
224 //
225 class LowerAddNOp : public RewritePattern {
226  public:
227   explicit LowerAddNOp(MLIRContext *context)
228       : RewritePattern(AddNOp::getOperationName(), 1, context,
229                        {AddV2Op::getOperationName()}) {}
230 
231   LogicalResult matchAndRewrite(Operation *op,
232                                 PatternRewriter &rewriter) const override {
233     auto addn_op = cast<AddNOp>(op);
234 
235     // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
236     // support variant type so variant types require special handling.
237     if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
238       return failure();
239     llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
240                                          addn_op.inputs().end());
241 
242     int64_t n = operands.size();
243     // Keep doing tree-based reduction while there is more than one operand.
244     while (n > 1) {
245       for (int64_t i = 0; i < n; i += 2) {
246         // Add two adjacent operands if applicable.
247         operands[i / 2] =
248             (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
249                                                    operands[i], operands[i + 1])
250                         : operands[i];
251       }
252       n = (n + 1) / 2;
253     }
254 
255     rewriter.replaceOp(addn_op, operands[0]);
256     return success();
257   }
258 };
259 
260 // Lowers DynamicStitch op with constant indices and with static input and
261 // output shapes using Reshape, UnPack and Pack op.
262 //
263 //   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
264 //   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
265 //     tensor<2x2xi32>}
266 //   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
267 //     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
268 //     -> tensor<5x2xf32>
269 //
270 // is lowered to
271 //
272 //   %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
273 //   %inp0 = "tf.Reshape"(%arg0, %shape)
274 //     : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
275 //   %inp1 = "tf.Reshape"(%arg1, %shape)
276 //     : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
277 //   %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
278 //     : (tensor<1x2xf32>) -> tensor<2xf32>
279 //   %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
280 //     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
281 //     tensor<2xf32>)
282 //   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
283 //   %0 = "tf.Pack"(%items1#3, %items1#2, %items1#1, %items1#0, %items0, %axis)
284 //     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
285 //        tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
286 //
287 template <typename OpT>
288 class LowerDynamicStitchOp : public RewritePattern {
289  public:
290   explicit LowerDynamicStitchOp(MLIRContext *context)
291       : RewritePattern(
292             OpT::getOperationName(), 1, context,
293             {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
294              UnpackOp::getOperationName(), PackOp::getOperationName()}) {}
295 
296   LogicalResult matchAndRewrite(Operation *src_op,
297                                 PatternRewriter &rewriter) const override {
298     auto op = cast<OpT>(src_op);
299 
300     // Static output type is used to compute intermediate values. Note that the
301     // output type doesn't have to be static but if input types and indices are
302     // constant, then the output type can be statically determined.
303     RankedTensorType out_ty =
304         op.getType().template dyn_cast<RankedTensorType>();
305     if (!out_ty || !out_ty.hasStaticShape()) return failure();
306 
307     // Extract all the constant indices' attributes and verify that the data
308     // operands have static shapes.
309     SmallVector<DenseIntElementsAttr, 4> indices;
310     indices.reserve(op.N());
311     for (auto it : llvm::zip(op.indices(), op.data())) {
312       Value index = std::get<0>(it);
313       Value data = std::get<1>(it);
314 
315       DenseIntElementsAttr index_attr;
316       if (!matchPattern(index, m_Constant(&index_attr))) return failure();
317       indices.push_back(index_attr);
318 
319       RankedTensorType data_ty =
320           data.getType().template dyn_cast<RankedTensorType>();
321       if (!data_ty || !data_ty.hasStaticShape()) return failure();
322     }
323 
324     // Compute the type of each item and the shape to use while reshaping the
325     // inputs so that they can be unpacked into individual items.
326     ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
327     auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
328 
329     SmallVector<int64_t, 4> packed_shape;
330     packed_shape.push_back(-1);
331     packed_shape.append(item_shape.begin(), item_shape.end());
332     Location loc = op.getLoc();
333     auto packed_shape_val = rewriter.create<ConstOp>(
334         loc, GetI64ElementsAttr(packed_shape, &rewriter));
335 
336     // Prepare each output item by unpacking the data and then placing it at
337     // the specified index.
338     SmallVector<Value, 8> values(out_ty.getDimSize(0));
339     for (auto it : llvm::zip(indices, op.data())) {
340       DenseIntElementsAttr index_attr = std::get<0>(it);
341       Value data = std::get<1>(it);
342 
343       auto reshaped_data =
344           rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
345       auto num_items = reshaped_data.getType()
346                            .template cast<RankedTensorType>()
347                            .getShape()[0];
348       auto items = rewriter.create<UnpackOp>(
349           loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
350           /*axis=*/0);
351       for (auto index_item : llvm::zip(index_attr, items.getResults())) {
352         int64_t output_index = std::get<0>(index_item).getSExtValue();
353         Value item = std::get<1>(index_item);
354         values[output_index] = item;
355       }
356     }
357 
358     rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
359     return success();
360   }
361 };
362 
363 // This pattern performs a manual FakeQuant conversion, mapping values between
364 // floating point and quantized space. It is designed to reproduce TF's
365 // implementation, mirroring the previous XLA implementation.
366 //
367 // 1. Compute proper quantized bounds. This involves nudging the input bounds.
368 // 2. Convert the input values to quantized space, rounding values.
369 // 3. Convert back into floating point space.
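//
// Roughly, in pseudocode (scale corresponds to quant_to_float below):
//   scale = (max - min) / (quant_max - quant_min)
//   nudged_zero_point = clip(round(quant_min - min / scale), quant_min, quant_max)
//   nudged_min = (quant_min - nudged_zero_point) * scale
//   nudged_max = (quant_max - nudged_zero_point) * scale
//   output = floor((clip(x, nudged_min, nudged_max) - nudged_min) / scale + 0.5)
//            * scale + nudged_min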
370 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
371  public:
372   explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
373       : RewritePattern(
374             FakeQuantWithMinMaxVarsOp::getOperationName(), 1, context,
375             {AddV2Op::getOperationName(), SubOp::getOperationName(),
376              ConstOp::getOperationName(), MulOp::getOperationName(),
377              FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
378              DivOp::getOperationName(), RoundOp::getOperationName()}) {}
379 
380   LogicalResult matchAndRewrite(Operation *src_op,
381                                 PatternRewriter &rewriter) const override {
382     auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);
383 
384     auto input = op.inputs();
385     auto input_ty = input.getType().cast<ShapedType>();
386     auto element_ty = input_ty.getElementType();
387     auto scalar_ty = RankedTensorType::get({}, element_ty);
388 
389     auto num_bits = op.num_bits();
390     auto narrow_range = op.narrow_range();
391     const double bits_min = narrow_range ? 1 : 0;
392     const double bits_max = (1 << num_bits) - 1;
393 
394     auto float_min = op.min();
395     auto float_max = op.max();
396 
397     auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);
398 
399     // Compute the range when quantized.
400     auto quant_min = rewriter.create<ConstOp>(
401         op.getLoc(), DenseElementsAttr::get(
402                          scalar_ty, ConvertToAPFloat(bits_min, element_ty)));
403 
404     auto quant_max = rewriter.create<ConstOp>(
405         op.getLoc(), DenseElementsAttr::get(
406                          scalar_ty, ConvertToAPFloat(bits_max, element_ty)));
407 
408     auto quant_diff = rewriter.create<ConstOp>(
409         op.getLoc(),
410         DenseElementsAttr::get(
411             scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));
412 
413     auto quant_to_float =
414         rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);
415 
416     auto float_to_quant =
417         rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);
418 
419     // During quantization, the quantized min/max values may not line up
420     // perfectly with the specified min/max. Nudge them into the right range.
421     auto min_scaled =
422         rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
423     auto min_scaled_sub =
424         rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);
425 
426     auto mid_rounded =
427         rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);
428 
429     auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
430         op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);
431 
432     auto quant_min_sub =
433         rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
434     auto quant_max_sub =
435         rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);
436 
437     auto nudged_float_min =
438         rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);
439 
440     auto nudged_float_max =
441         rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);
442 
443     // Now quantize the input value with the approximated min/max values.
444 
445     // Move the input value into quantized space.
446     Value quantized_input = rewriter.create<ClipByValueOp>(
447         op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);
448 
449     quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
450                                              quantized_input, nudged_float_min);
451 
452     quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
453                                              quantized_input, float_to_quant);
454 
455     // Round the quantized input to nearest, with ties going toward the
456     auto half_val = rewriter.create<ConstOp>(
457         op.getLoc(),
458         DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
459 
460     quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
461                                                quantized_input, half_val);
462 
463     quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
464 
465     // Convert back into floating point space.
466     Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
467                                           quantized_input, quant_to_float);
468 
469     output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
470                                       nudged_float_min);
471 
472     rewriter.replaceOp(op, {output});
473     return success();
474   }
475 };
476 
477 // Lowers InvertPermutation op to TensorScatterUpdate op.
478 //
479 // Example:
480 //
481 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
482 //   "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
483 //
484 // is lowered to
485 //
486 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
487 //   %start = "tf.Const"() {value = dense<0> : tensor<i32>}
488 //   %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
489 //   %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
490 //   %updates = "tf.Range"(%start, %limit, %delta) :
491 //     (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
492 //   %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
493 //   %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32>, tensor<2xi32>) ->
494 //     tensor<5x1xi32>
495 //   "tf.TensorScatterUpdate"(%x, %indices, %updates) :
496 //     (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
497 //
498 class LowerInvertPermutationOp : public RewritePattern {
499  public:
500   explicit LowerInvertPermutationOp(MLIRContext *context)
501       : RewritePattern(
502             InvertPermutationOp::getOperationName(), 1, context,
503             {ConstOp::getOperationName(), RangeOp::getOperationName(),
504              ReshapeOp::getOperationName(),
505              TensorScatterUpdateOp::getOperationName()}) {}
506 
507   LogicalResult matchAndRewrite(Operation *src_op,
508                                 PatternRewriter &rewriter) const override {
509     auto op = cast<InvertPermutationOp>(src_op);
510 
511     Location loc = op.getLoc();
512     auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
513     // x input must have static shape.
514     if (!x_type || !x_type.hasStaticShape()) {
515       return failure();
516     }
517     Type int_type = x_type.getElementType();  // Could be i32 or i64.
518 
519     auto result_type = x_type;
520     auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
521     Value limit = rewriter.create<ConstOp>(
522         loc, GetScalarOfType(int_type, x_type.getShape()[0]));
523     auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
524     // Construct a sequence of numbers [0, 1, ... len(x)-1].
525     auto updates =
526         rewriter.create<RangeOp>(loc, result_type, start, limit, delta);
527 
528     auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
529     auto shape = rewriter.create<ConstOp>(
530         loc, DenseElementsAttr::get(
531                  shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
532     auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);
533 
534     rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
535                                                        indices, updates);
536     return success();
537   }
538 };
539 
540 // Approximates lgamma using Lanczos' approximation from
541 // "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
542 // Analysis series B. Vol. 1:
543 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z)
544 // t(z) = z + kLanczosGamma + 1/2
545 // A(z) = kBaseLanczosCoeff
546 //       + sigma(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
547 //
548 // Coefficients for the Lanczos approximation of the gamma function. The
549 // coefficients are uniquely determined by the choice of g and n
550 // (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
551 // correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
552 // evaluated and [7, 9] seemed to be the least sensitive to the quality of the
553 // log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
554 // lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
555 static constexpr double kLanczosGamma = 7;  // aka g
556 static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
557 static constexpr std::array<double, 8> kLanczosCoefficients = {
558     676.520368121885098567009190444019, -1259.13921672240287047156078755283,
559     771.3234287776530788486528258894,   -176.61502916214059906584551354,
560     12.507343278686904814458936853,     -0.13857109526572011689554707,
561     9.984369578019570859563e-6,         1.50563273514931155834e-7};
562 
563 class LowerLgammaOp : public RewritePattern {
564  public:
565   explicit LowerLgammaOp(MLIRContext *context)
566       : RewritePattern(LgammaOp::getOperationName(), 1, context,
567                        {
568                            CastOp::getOperationName(),
569                            ConstOp::getOperationName(),
570                            NegOp::getOperationName(),
571                            SubOp::getOperationName(),
572                            SelectV2Op::getOperationName(),
573                            LessOp::getOperationName(),
574                            AddV2Op::getOperationName(),
575                            DivOp::getOperationName(),
576                            SubOp::getOperationName(),
577                            LogOp::getOperationName(),
578                            Log1pOp::getOperationName(),
579                            IsInfOp::getOperationName(),
580                            MulOp::getOperationName(),
581                            FloorOp::getOperationName(),
582                            AbsOp::getOperationName(),
583                            GreaterOp::getOperationName(),
584                            SinOp::getOperationName(),
585                            IsFiniteOp::getOperationName(),
586                        }) {}
587 
588   LogicalResult matchAndRewrite(Operation *src_op,
589                                 PatternRewriter &rewriter) const override {
590     auto op = cast<LgammaOp>(src_op);
591 
592     Location loc = op.getLoc();
593     Value input = op.x();
594     TensorType original_tensor_type = op.x().getType().cast<TensorType>();
595 
596     // The approximation is not precise enough for float16. Do the computation
597     // in float32 for that case.
598     TensorType tensor_type = original_tensor_type;
599     FloatType float_type = tensor_type.getElementType().cast<FloatType>();
600     bool needs_cast = float_type.getWidth() < 32;
601     if (needs_cast) {
602       MLIRContext *context = rewriter.getContext();
603       float_type = FloatType::getF32(context);
604       if (original_tensor_type.hasRank()) {
605         tensor_type =
606             RankedTensorType::get(original_tensor_type.getShape(), float_type);
607       } else {
608         tensor_type = UnrankedTensorType::get(float_type);
609       }
610       input = rewriter.create<CastOp>(loc, tensor_type, input);
611     }
612 
613     // Helper lambda function for creating a ConstOp for a tensor filled with
614     // the given constant float value.
615     auto create_const_op = [&rewriter, loc, tensor_type,
616                             float_type](double value) {
617       return rewriter.create<ConstOp>(
618           loc, DenseElementsAttr::get(tensor_type,
619                                       FloatAttr::get(float_type, value)));
620     };
621 
622     Value one_half = create_const_op(0.5);
623     Value one = create_const_op(1.0);
624     Value infinity = create_const_op(std::numeric_limits<double>::infinity());
625     Value pi = create_const_op(M_PI);
626     Value log_pi = create_const_op(std::log(M_PI));
627     Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
628     Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
629     Value log_lanczos_gamma_plus_one_half =
630         create_const_op(std::log(kLanczosGamma + 0.5));
631     Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);
632 
633     Value minus_input = rewriter.create<NegOp>(loc, input);
634     Value input_minus_one = rewriter.create<SubOp>(loc, input, one);
635 
636     // If the input is less than 0.5 use Euler's reflection formula:
637     // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
638     Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
639     Type tensor_bool_type = need_to_reflect.getType();
640     Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
641                                           input_minus_one);
642 
643     Value x = base_lanczos_coeff;
644     for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
645       Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
646       Value index = create_const_op(static_cast<double>(i));
647       Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
648       Value z_plus_index_plus_one =
649           rewriter.create<AddV2Op>(loc, z_plus_index, one);
650       Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
651                                           z_plus_index_plus_one);
652       x = rewriter.create<AddV2Op>(loc, x, incr);
653     }
654 
655     // To improve accuracy on platforms with less-precise log implementations,
656     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
657     // the device.
658     // log(t) = log(kLanczosGamma + 0.5 + z)
659     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
660     Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
661     Value z_div_lanczos_gamma_plus_one_half =
662         rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
663     Value log1p_z_div_lanczos_gamma_plus_one_half =
664         rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
665     Value log_t =
666         rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
667                                  log1p_z_div_lanczos_gamma_plus_one_half);
668 
669     // Compute the final result (modulo reflection).  t(z) may be large, and we
670     // need to be careful not to overflow to infinity in the first term of
671     //
672     //   (z + 1/2) * log(t(z)) - t(z).
673     //
674     // Therefore we compute this as
675     //
676     //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
677     //
678     // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
679     Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
680     Value one_half_minus_t_div_log_t =
681         rewriter.create<SubOp>(loc, one_half, t_div_log_t);
682     Value z_plus_one_half_minus_t_div_log_t =
683         rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
684     Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
685         rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
686     Value log_x = rewriter.create<LogOp>(loc, x);
687     Value log_y_rhs = rewriter.create<AddV2Op>(
688         loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
689     Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);
690 
691     // Compute the reflected value, used when x < 0.5:
692     //
693     //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
694     //
695     // (The abs is because lgamma is the log of the absolute value of the gamma
696     // function.)
697     //
698     // We have to be careful when computing the final term above. gamma(x) goes
699     // to +/-inf at every integer x < 0, and this is controlled by the
700     // sin(pi * x) term.  The slope is large, so precision is particularly
701     // important.
702     //
703     // Because abs(sin(pi * x)) has period 1, we can equivalently use
704     // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
705     // is more numerically accurate: It doesn't overflow to inf like pi * x can,
706     // and if x is an integer, it evaluates to 0 exactly, which is significant
707     // because we then take the log of this value, and log(0) is inf.
708     //
709     // We don't have a frac(x) primitive in XLA and computing it is tricky, but
710     // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
711     // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
712     //
713     // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
714     // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
715     // [0, 1] is symmetric across the line Y=0.5.
716     Value abs_input = rewriter.create<AbsOp>(loc, input);
717     Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
718     Value abs_frac_input =
719         rewriter.create<SubOp>(loc, abs_input, abs_input_floor);
720 
721     // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
722     // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
723     Value one_minus_abs_frac_input =
724         rewriter.create<SubOp>(loc, one, abs_frac_input);
725     Value abs_frac_input_gt_one_half =
726         rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
727     Value reduced_frac_input =
728         rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
729                                     one_minus_abs_frac_input, abs_frac_input);
730     Value pi_mul_reduced_frac_input =
731         rewriter.create<MulOp>(loc, pi, reduced_frac_input);
732     Value sin_pi_mul_reduced_frac_input =
733         rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
734     Value reflection_denom =
735         rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);
736 
737     // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
738     // then it "wins" and the result is +/-inf.
739     Value is_finite =
740         rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
741     Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
742     Value log_pi_minus_reflection_denom =
743         rewriter.create<SubOp>(loc, log_pi, reflection_denom);
744     Value reflection_if_finite =
745         rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
746     Value reflection = rewriter.create<SelectV2Op>(
747         loc, is_finite, reflection_if_finite, neg_reflection_denom);
748 
749     Value result =
750         rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);
751 
752     // lgamma(+/-inf) = +inf.
753     Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
754     result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);
755 
756     if (needs_cast) {
757       result = rewriter.create<CastOp>(loc, original_tensor_type, result);
758     }
759 
760     rewriter.replaceOp(op, result);
761     return success();
762   }
763 };
764 
765 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
766 // ExpandDims op.
767 //
768 // Sample result with 2 inputs to pack:
769 //
770 //   %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
771 //   %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
772 //   %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
773 //   %result = "tf.ConcatV2"(%inp0, %inp1, %axis) { N = 2 : i64 }
774 //
775 class LowerPackOp : public RewritePattern {
776  public:
777   explicit LowerPackOp(MLIRContext *context)
778       : RewritePattern(
779             PackOp::getOperationName(), 1, context,
780             {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
781              ExpandDimsOp::getOperationName()}) {}
782 
783   LogicalResult matchAndRewrite(Operation *src_op,
784                                 PatternRewriter &rewriter) const override {
785     auto op = cast<PackOp>(src_op);
786 
787     Location loc = op.getLoc();
788     auto axis_value = rewriter.create<ConstOp>(
789         loc,
790         DenseElementsAttr::get(
791             RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
792     int64_t axis = op.axis();
793 
794     Type prev_input_ty, inferred_ty;
795     SmallVector<Value, 4> expanded_inputs;
796     expanded_inputs.reserve(op.N());
797     for (Value input : op.values()) {
798       // If the input type is different from the previous input type, infer the
799       // output type. Otherwise, use the already inferred output type from the
800       // previous iteration.
801       Type input_ty = input.getType();
802       if (input_ty != prev_input_ty) {
803         inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
804         prev_input_ty = input_ty;
805       }
806       expanded_inputs.push_back(
807           rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
808     }
809 
810     rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
811                                             axis_value);
812     return success();
813   }
814 };
815 
816 // Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
817 //
818 // Before rewrite:
819 //   output = SpaceToBatchND(input, block_shape, paddings)
820 // Let:
821 //   [batch] + spatial_shape + remaining_shape = input.shape
822 //   M = spatial_shape.rank
823 // After rewrite:
824 //   padded = zero-pad input with paddings
825 //     The spatial_shape component of input.shape pads with paddings[*, 0]
826 //     before each dimension, and paddings[*, 1] after each dimension.
827 //   reshaped = reshape padded to:
828 //     [batch]
829 //     + [padded.shape[1]/block_shape[0], block_shape[0], ...,
830 //        padded.shape[M]/block_shape[M-1], block_shape[M-1]]
831 //     + remaining_shape
832 //   permuted = transpose reshaped to:
833 //     block_shape
834 //     + [batch]
835 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
836 //     + remaining_shape
837 //   result = reshape permuted to:
838 //     [batch * product(block_shape)]
839 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
840 //     + remaining_shape
841 class LowerSpaceToBatchNDOp : public RewritePattern {
842  public:
843   explicit LowerSpaceToBatchNDOp(MLIRContext *context)
844       : RewritePattern(SpaceToBatchNDOp::getOperationName(), 1, context,
845                        {
846                            CastOp::getOperationName(),
847                            ConstOp::getOperationName(),
848                            ConcatV2Op::getOperationName(),
849                            AddV2Op::getOperationName(),
850                            PadOp::getOperationName(),
851                            SplitOp::getOperationName(),
852                            UnpackOp::getOperationName(),
853                            DivOp::getOperationName(),
854                            MulOp::getOperationName(),
855                            ReshapeOp::getOperationName(),
856                            TransposeOp::getOperationName(),
857                        }) {}
858 
859   LogicalResult matchAndRewrite(Operation *src_op,
860                                 PatternRewriter &rewriter) const override {
861     auto op = cast<SpaceToBatchNDOp>(src_op);
862 
863     Location loc = op.getLoc();
864     auto input_type = op.input().getType().cast<TensorType>();
865     auto element_type = input_type.getElementType();
866     if (!input_type.hasStaticShape()) {
867       return failure();
868     }
869     ArrayRef<int64_t> input_shape = input_type.getShape();
870     auto block_shape_type = op.block_shape().getType().cast<TensorType>();
871     if (!block_shape_type.hasStaticShape()) {
872       return failure();
873     }
874     auto paddings_type = op.paddings().getType().cast<ShapedType>();
875     if (!paddings_type.hasRank()) {
876       return failure();
877     }
878 
879     int64_t input_rank = input_type.getRank();
880     int64_t block_rank = block_shape_type.getNumElements();
881     int64_t remaining_rank = input_rank - 1 - block_rank;
882     if (remaining_rank < 0) {
883       // TODO(b/157475606): Move this check to ::Verify
884       return failure();
885     }
886 
887     auto block_shape_i64_type = RankedTensorType::get(
888         block_shape_type.getShape(), rewriter.getIntegerType(64));
889     auto block_shape_i64 =
890         rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());
891 
892     auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
893                                                    rewriter.getIntegerType(64));
894     auto paddings_i64 =
895         rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());
896 
897     auto pad00 = rewriter.create<ConstOp>(
898         loc, DenseElementsAttr::get<int64_t>(
899                  RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
900                  {0, 0}));
901     SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
902     full_paddings_list.append(remaining_rank, pad00);
903     auto full_paddings_type =
904         RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
905     auto zero_i64 = rewriter.create<ConstOp>(
906         loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
907     // Extends paddings to all dimensions of input by adding 0s to non-block
908     // dimensions.
909     auto full_paddings = rewriter.create<ConcatV2Op>(
910         loc, full_paddings_type, full_paddings_list, zero_i64);
911 
912     // Compute the result type here instead of using shape inference because the
913     // full_paddings won't be available as a constant for shape inference.
914     ElementsAttr block_shape;
915     ElementsAttr paddings;
916     llvm::SmallVector<int64_t, 4> block_shape_ints;
917     auto padded_shape = llvm::to_vector<4>(input_shape);
918     if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
919         matchPattern(op.paddings(), m_Constant(&paddings))) {
920       for (uint64_t i = 0; i < block_rank; i++) {
921         int64_t paddings_sum =
922             paddings.getValue({i, 0}).cast<IntegerAttr>().getInt() +
923             paddings.getValue({i, 1}).cast<IntegerAttr>().getInt();
924         int64_t block_shape_i =
925             block_shape.getValue({i}).cast<IntegerAttr>().getInt();
926         padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
927         block_shape_ints.push_back(block_shape_i);
928       }
929     } else {
930       for (int i = 0; i < block_rank; i++) {
931         padded_shape[i + 1] = ShapedType::kDynamicSize;
932       }
933       block_shape_ints.resize(block_shape_type.getNumElements(), -1);
934     }
935 
936     auto padded_type = RankedTensorType::get(padded_shape, element_type);
937     // padded = pad(input, full_paddings)
938     auto padded =
939         rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);
940 
941     auto paddings_sum_type =
942         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
943     // paddings_sum = paddings[*,0] + paddings[*,1]
944     auto paddings_split = rewriter.create<UnpackOp>(
945         loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
946         rewriter.getI64IntegerAttr(1));
947     auto paddings_sum = rewriter.create<AddV2Op>(
948         loc, paddings_split.getResult(0), paddings_split.getResult(1));
949 
950     auto input_shape_tensor = rewriter.create<ConstOp>(
951         loc,
952         DenseElementsAttr::get(
953             RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
954             input_shape));
955 
956     // padded_shape_tensor is the shape of padded.
957     auto padded_shape_tensor =
958         rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
959 
960     auto zero_i32 = rewriter.create<ConstOp>(
961         loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
962     SmallVector<Type, 4> padded_shape_splits_types(
963         input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
964     SmallVector<Value, 4> padded_shape_splits(
965         rewriter
966             .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
967                              padded_shape_tensor)
968             .output());
969 
970     SmallVector<Type, 4> block_shape_splits_types(
971         block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
972     SmallVector<Value, 4> block_shape_splits(
973         rewriter
974             .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
975                              block_shape_i64)
976             .output());
977 
978     SmallVector<int64_t, 4> outer_shape_ints;
979     SmallVector<Value, 4> outer_shape_vals;
980     for (int64_t i = 0; i < block_rank; ++i) {
981       // TODO(b/157475606): Insert tf.Assert that the following division has
982       // remainder 0.
983       outer_shape_vals.push_back(rewriter.create<DivOp>(
984           loc, padded_shape_splits[1 + i], block_shape_splits[i]));
985 
986       auto padded_shape_i = padded_shape[1 + i];
987       auto block_shape_ints_i = block_shape_ints[i];
988 
989       // Compute the outer_shape constant values to infer the reshape.
990       if (padded_shape_i == -1 || block_shape_ints_i == -1) {
991         outer_shape_ints.push_back(-1);
992       } else {
993         outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
994       }
995     }
996 
997     SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
998     SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
999     for (int64_t i = 0; i < block_rank; ++i) {
1000       reshaped_shape_vals.push_back(outer_shape_vals[i]);
1001       reshaped_shape_vals.push_back(block_shape_splits[i]);
1002 
1003       reshaped_shape_ints.push_back(outer_shape_ints[i]);
1004       reshaped_shape_ints.push_back(block_shape_ints[i]);
1005     }
1006     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1007       reshaped_shape_vals.push_back(padded_shape_splits[i]);
1008       reshaped_shape_ints.push_back(padded_shape[i]);
1009     }
1010     auto reshaped_shape = ValuesToRank1(
1011         rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
1012 
1013     auto reshaped = rewriter.create<ReshapeOp>(
1014         loc, RankedTensorType::get(reshaped_shape_ints, element_type), padded,
1015         reshaped_shape);
1016 
1017     SmallVector<int64_t, 6> permutation_vals;
1018     for (int64_t i = 0; i < block_rank; ++i) {
1019       permutation_vals.push_back(2 + 2 * i);
1020     }
1021     permutation_vals.push_back(0);
1022     for (int64_t i = 0; i < block_rank; ++i) {
1023       permutation_vals.push_back(1 + 2 * i);
1024     }
1025     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1026       permutation_vals.push_back(block_rank + i);
1027     }
1028     auto permutation = rewriter.create<ConstOp>(
1029         loc, GetI64ElementsAttr(permutation_vals, &rewriter));
1030 
1031     auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
1032     auto output_batch = padded_shape_splits[0];
1033     for (int64_t i = 0; i < block_rank; ++i) {
1034       output_batch =
1035           rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
1036     }
1037     SmallVector<Value, 4> output_shape_vals{output_batch};
1038     for (int64_t i = 0; i < block_rank; ++i) {
1039       output_shape_vals.push_back(outer_shape_vals[i]);
1040     }
1041     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1042       output_shape_vals.push_back(padded_shape_splits[i]);
1043     }
1044     auto output_shape = ValuesToRank1(
1045         rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
1046 
1047     // Sometimes the result type is more specific than what the reshape builder
1048     // can infer.
1049     auto result_type = op.getResult().getType();
1050     rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
1051                                            output_shape);
1052 
1053     return success();
1054   }
1055 };
1056 
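// Lowers BatchToSpaceND op with constant block_shape and crops and a
// statically shaped input into a sequence of Reshape, Transpose, Reshape and
// Slice ops (roughly the inverse of the SpaceToBatchND lowering above).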
1057 class LowerBatchToSpaceND : public RewritePattern {
1058  public:
1059   explicit LowerBatchToSpaceND(MLIRContext *context)
1060       : RewritePattern(BatchToSpaceNDOp::getOperationName(), 1, context,
1061                        {
1062                            ConstOp::getOperationName(),
1063                            ReshapeOp::getOperationName(),
1064                            SliceOp::getOperationName(),
1065                            TransposeOp::getOperationName(),
1066                        }) {}
1067 
1068   LogicalResult matchAndRewrite(Operation *src_op,
1069                                 PatternRewriter &rewriter) const override {
1070     auto op = cast<BatchToSpaceNDOp>(src_op);
1071     auto input = op.input();
1072     auto input_ty = input.getType().cast<ShapedType>();
1073     auto element_ty = input_ty.getElementType();
1074     if (!input_ty.hasStaticShape()) {
1075       return failure();
1076     }
1077 
1078     const int input_rank = input_ty.getRank();
1079     auto input_shape = input_ty.getShape();
1080 
1081     DenseIntElementsAttr block_shape;
1082     DenseIntElementsAttr crops;
1083     if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
1084         !matchPattern(op.crops(), m_Constant(&crops))) {
1085       return failure();
1086     }
1087 
1088     auto block_shape_ty = block_shape.getType();
1089     if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
1090       return failure();
1091     }
1092 
1093     const int block_rank = block_shape_ty.getShape().front();
1094     auto remainder_shape = input_shape.drop_front(1 + block_rank);
1095 
1096     const int64_t batch_size = input_shape[0];
1097 
1098     // Compute the product of the block_shape values.
1099     int64_t block_num_elems = 1;
1100 
1101     for (auto val : block_shape.getIntValues()) {
1102       block_num_elems *= val.getSExtValue();
1103     }
1104 
1105     if (block_num_elems <= 0) {
1106       op.emitOpError()
1107           << "The product of the block dimensions must be positive";
1108       return failure();
1109     }
1110 
1111     // 1. Reshape `input` to `reshaped` of shape:
1112     //      [block_shape[0], ..., block_shape[M-1],
1113     //       batch / prod(block_shape),
1114     //       input_shape[1], ..., input_shape[N-1]]
1115     std::vector<int64_t> reshaped_shape;
1116     for (auto val : block_shape) {
1117       reshaped_shape.push_back(val.getSExtValue());
1118     }
1119     reshaped_shape.resize(input_rank + block_rank);
1120 
1121     reshaped_shape[block_rank] = batch_size / block_num_elems;
1122     std::copy(input_shape.begin() + 1, input_shape.end(),
1123               reshaped_shape.begin() + block_rank + 1);
1124 
1125     auto reshaped = rewriter.create<TF::ReshapeOp>(
1126         op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
1127         rewriter.create<ConstOp>(op.getLoc(),
1128                                  rewriter.getI64TensorAttr(reshaped_shape)));
1129 
1130     // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
1131     //      [batch / prod(block_shape),
1132     //
1133     //       input_shape[1], block_shape[0],
1134     //       ...,
1135     //       input_shape[M], block_shape[M-1],
1136     //
1137     //       input_shape[M+1], ..., input_shape[N-1]]
1138     std::vector<int64_t> permutation(reshaped_shape.size());
1139     permutation[0] = block_rank;
1140     for (int i = 0; i < block_rank; ++i) {
1141       permutation[1 + 2 * i] = block_rank + 1 + i;
1142       permutation[1 + 2 * i + 1] = i;
1143     }
1144     std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1145               1 + block_rank * 2);
1146 
1147     std::vector<int64_t> transpose_shape(permutation.size());
1148     for (auto it : llvm::enumerate(permutation)) {
1149       transpose_shape[it.index()] = reshaped_shape[it.value()];
1150     }
1151 
1152     auto permuted = rewriter.create<TF::TransposeOp>(
1153         op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
1154         reshaped,
1155         rewriter.create<ConstOp>(op.getLoc(),
1156                                  rewriter.getI64TensorAttr(permutation)));
1157 
1158     // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
1159     //      [batch / prod(block_shape),
1160     //
1161     //       input_shape[1] * block_shape[0],
1162     //       ...,
1163     //       input_shape[M] * block_shape[M-1],
1164     //
1165     //       input_shape[M+1],
1166     //       ...,
1167     //       input_shape[N-1]]
1168     std::vector<int64_t> reshaped_permuted_shape(input_rank);
1169     auto block_shape_values = llvm::to_vector<4>(block_shape.getIntValues());
1170     reshaped_permuted_shape[0] = batch_size / block_num_elems;
1171     for (int i = 0; i < block_rank; ++i) {
1172       reshaped_permuted_shape[1 + i] =
1173           block_shape_values[i].getSExtValue() * input_shape[1 + i];
1174     }
1175     std::copy(remainder_shape.begin(), remainder_shape.end(),
1176               reshaped_permuted_shape.begin() + 1 + block_rank);
1177 
1178     auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
1179         op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
1180         permuted,
1181         rewriter.create<ConstOp>(
1182             op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
1183 
1184     // 4. Crop the start and end of dimensions `[1, ..., M]` of
1185     //    `reshaped_permuted` according to `crops` to produce the output of
1186     //    shape:
1187     //      [batch / prod(block_shape),
1188     //
1189     //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
1190     //       ...,
1191     //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
1192     //
1193     //       input_shape[M+1], ..., input_shape[N-1]]
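     //
     //    Continuing the running example above, with crops = [[0, 1], [1, 0]]
     //    (illustrative values) the output shape is [1, 3, 3, 1]; with zero
     //    crops it stays [1, 4, 4, 1].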
1194     std::vector<int64_t> start_indices(input_rank, 0);
1195     std::vector<int64_t> slice_sizes = reshaped_permuted_shape;
1196     std::vector<int64_t> strides(input_rank, 1);
1197     auto crop_values = llvm::to_vector<4>(crops.getIntValues());
1198     for (int i = 0; i < block_rank; ++i) {
1199       int64_t crop_start = crop_values[i * 2].getSExtValue();
1200       int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
1201 
1202       if (crop_start < 0 || crop_end < 0) {
1203         op.emitOpError() << "Crops must be non-negative";
1204         return failure();
1205       }
1206 
1207       start_indices[i + 1] = crop_start;
1208       slice_sizes[i + 1] -= crop_start + crop_end;
1209 
1210       if (slice_sizes[i + 1] < 0) {
1211         op.emitOpError() << "Cropped size must be non-negative: start: " << crop_start
1212                          << " end: " << crop_end << " size " << reshaped_permuted_shape[1 + i];
1213         return failure();
1214       }
1215     }
1216 
1217     rewriter.replaceOpWithNewOp<TF::SliceOp>(
1218         op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
1219         rewriter.create<ConstOp>(op.getLoc(),
1220                                  rewriter.getI64TensorAttr(start_indices)),
1221         rewriter.create<ConstOp>(op.getLoc(),
1222                                  rewriter.getI64TensorAttr(slice_sizes)));
1223     return success();
1224   }
1225 };
1226 
1227 // Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
1228 // since we currently don't have an implementation that can use this
1229 // information. Adds appropriate casts where necessary to align element types
1230 // of operands and result for `MatMulOp`.
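     // For example, schematically (illustrative types):
     //   %p = tf.SparseMatMul(%a : tensor<3x5xbf16>, %b : tensor<5x7xf32>)
     // is rewritten to
     //   %a_f32 = tf.Cast(%a) : tensor<3x5xbf16> -> tensor<3x5xf32>
     //   %p = tf.MatMul(%a_f32, %b), forwarding transpose_a/transpose_b.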
1231 class LowerSparseMatMulOp : public RewritePattern {
1232  public:
1233   explicit LowerSparseMatMulOp(MLIRContext *context)
1234       : RewritePattern(
1235             SparseMatMulOp::getOperationName(), 1, context,
1236             {CastOp::getOperationName(), MatMulOp::getOperationName()}) {}
1237 
1238   LogicalResult matchAndRewrite(Operation *src_op,
1239                                 PatternRewriter &rewriter) const override {
1240     auto op = cast<SparseMatMulOp>(src_op);
1241 
1242     // Result type must be f32 for applying the pattern (currently this is
1243     // required by the op anyway but this might change).
1244     if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
1245       return failure();
1246     }
1247     MLIRContext *context = rewriter.getContext();
1248     llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
1249     for (Value &operand : operands) {
1250       TensorType tensor_type = operand.getType().cast<TensorType>();
1251       Type element_type = tensor_type.getElementType();
1252       if (element_type.isF32()) continue;
1253       // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
1254       // must be bf16 here.
1255       assert(element_type.isBF16());
1256       Type tensor_type_f32;
1257       if (tensor_type.hasRank()) {
1258         tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
1259                                                 FloatType::getF32(context));
1260       } else {
1261         tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
1262       }
1263       // Add cast to f32 to conform with element type of result.
1264       operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
1265     }
1266     Value result = rewriter.create<MatMulOp>(
1267         op.getLoc(), op.product().getType(), operands[0], operands[1],
1268         op.transpose_a(), op.transpose_b());
1269 
1270     rewriter.replaceOp(op, {result});
1271     return success();
1272   }
1273 };
1274 
1275 // Lowers the _UnaryOpsComposition op to the series of original TensorFlow ops
1276 // that were fused together.
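     // For example, an op with op_names = ["Abs", "Sqrt", "Tanh"] (illustrative
     // values) is rewritten to the chain tf.Abs -> tf.Sqrt -> tf.Tanh applied to
     // the input, with every op in the chain using the original result type.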
1277 class Lower_UnaryOpsComposition
1278     : public OpRewritePattern<_UnaryOpsCompositionOp> {
1279  public:
1280   using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;
1281 
1282   LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
1283                                 PatternRewriter &rewriter) const override {
1284     Value result = op.x();
1285     for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
1286       std::string full_name = "tf." + op_name.str();
1287       // All ops in the sequence have the same result type as the original
1288       // op's result.
1289       OperationState state(op.getLoc(), full_name, /*operands=*/{result},
1290                            /*types=*/{op.getType()}, /*attributes=*/{});
1291       Operation *new_op = rewriter.createOperation(state);
1292       result = new_op->getResult(0);
1293     }
1294     rewriter.replaceOp(op, {result});
1295     return success();
1296   }
1297 };
1298 
1299 // Lowers ResizeNearestNeighbor to an indices computation with a gather along
1300 // the combined spatial dimensions. Indices generated along the width/height
1301 // axes could be used to gather along each of the W and H dimensions of the
1302 // input image array separately. To reduce to a single gather, these indices
1303 // are combined, so a single gather can be performed along the combined
1304 // spatial dimensions.
1305 //
1306 // Images must take the shape [b, h, w, c] and size is a rank-1 length-2 tensor
1307 // containing the height and width values for the output tensor. This lowering
1308 // should work with a dynamically shaped image array.
1309 //
1310 // For example, a scaling with image shape [1, 3, 3, 1] to [2, 2] and unaligned
1311 // corners would generate a [0, 1] lookup along both the x and y direction.
1312 // Then when combined to form the 1-D spatial index the values would be
1313 // [0, 1, 3, 4] which would gather along the reshaped image tensor of shape
1314 // [1, 9, 1], reshaped to the final [1, 2, 2, 1].
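     //
     // Concretely, in that example the y indices are scaled by the input width
     // and kept as a column, the x indices are kept as a row, and a broadcast
     // add combines them:
     //   y_indices = [[0], [1]] * 3 = [[0], [3]]
     //   x_indices = [[0, 1]]
     //   indices   = y_indices + x_indices = [[0, 1], [3, 4]]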
1315 class LowerResizeNearestNeighbor : public RewritePattern {
1316  public:
1317   explicit LowerResizeNearestNeighbor(MLIRContext *context)
1318       : RewritePattern(ResizeNearestNeighborOp::getOperationName(), 1, context,
1319                        {
1320                            BroadcastToOp::getOperationName(),
1321                            ConstOp::getOperationName(),
1322                            DivOp::getOperationName(),
1323                            PackOp::getOperationName(),
1324                            RangeOp::getOperationName(),
1325                            ReshapeOp::getOperationName(),
1326                            ShapeOp::getOperationName(),
1327                            SplitOp::getOperationName(),
1328                            TransposeOp::getOperationName(),
1329                        }) {}
1330 
1331   LogicalResult matchAndRewrite(Operation *src_op,
1332                                 PatternRewriter &rewriter) const override {
1333     auto op = cast<ResizeNearestNeighborOp>(src_op);
1334     auto loc = op.getLoc();
1335     auto result_ty = op.getType().cast<ShapedType>();
1336 
1337     auto input = op.images();
1338     auto input_ty = input.getType().cast<ShapedType>();
1339     auto input_element_ty = input_ty.getElementType();
1340     auto out_size = op.size();
1341     auto out_size_ty = out_size.getType().cast<ShapedType>();
1342     auto out_size_element_ty = out_size_ty.getElementType();
1343 
1344     // Input should be rank 4.
1345     if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1346       return failure();
1347     }
1348 
1349     // Check that out_size is rank-1, length-2. Otherwise the size is not legal.
1350     if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
1351         out_size_ty.getShape()[0] != 2) {
1352       return failure();
1353     }
1354 
1355     // Extract the output width / height dim size.
1356     int out_height_constant = -1;
1357     int out_width_constant = -1;
1358     DenseIntElementsAttr out_size_cst;
1359     if (matchPattern(out_size, m_Constant(&out_size_cst))) {
1360       llvm::SmallVector<int64_t, 2> cst_size;
1361       for (auto val : out_size_cst.getIntValues()) {
1362         cst_size.push_back(val.getSExtValue());
1363       }
1364 
1365       out_height_constant = cst_size[0];
1366       out_width_constant = cst_size[1];
1367 
1368       if (out_height_constant < 0 || out_width_constant < 0) return failure();
1369     }
1370 
1371     int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
1372                               ? -1
1373                               : out_height_constant * out_width_constant;
1374 
1375     // Input rank should be 4 (already checked above). Might be able to drop
1376     // this requirement entirely as it's an input requirement.
1377     if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1378       return failure();
1379     }
1380 
1381     int batch_cst = input_ty.getShape()[0];
1382     int channels_cst = input_ty.getShape()[3];
1383 
1384     int in_y_cst = input_ty.getShape()[1];
1385     int in_x_cst = input_ty.getShape()[2];
1386     int in_spatial_cst =
1387         in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;
1388 
1389     // TODO(suderman): Add support for these optional parameters.
1390     if (op.align_corners() || op.half_pixel_centers()) {
1391       return failure();
1392     }
1393 
1394     auto one =
1395         rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));
1396 
1397     // Extract the image shape.
1398     Value input_shape = rewriter.create<ShapeOp>(
1399         loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
1400     input_shape = rewriter.create<CastOp>(
1401         loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);
1402 
1403     auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
1404     auto split_image_shape = rewriter.create<UnpackOp>(
1405         loc,
1406         TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
1407         input_shape);
1408 
1409     // Extract the separate components from the input shape.
1410     auto batch = split_image_shape.getResult(0);
1411     auto in_y = split_image_shape.getResult(1);
1412     auto in_x = split_image_shape.getResult(2);
1413     auto channels = split_image_shape.getResult(3);
1414 
1415     auto in_count = rewriter.create<MulOp>(
1416         loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);
1417 
1418     // Unpack and separate the out width/height.
1419     auto split_out_size = rewriter.create<UnpackOp>(
1420         loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);
1421 
1422     auto out_y = split_out_size.getResult(0);
1423     auto out_x = split_out_size.getResult(1);
1424 
1425     auto out_count = rewriter.create<MulOp>(
1426         loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);
1427 
1428     // Generate what the final output shape will look like.
1429     auto out_shape = rewriter.create<PackOp>(
1430         loc, RankedTensorType::get({4}, out_size_element_ty),
1431         ValueRange({batch, out_y, out_x, channels}));
1432 
1433     // Compute the indices along the vertical dimension.
1434     auto in_y_f32 = rewriter.create<CastOp>(
1435         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
1436     auto out_y_f32 = rewriter.create<CastOp>(
1437         loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);
1438 
1439     Value y_scale = rewriter.create<DivOp>(
1440         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
1441         out_y_f32);
1442 
1443     Value zero_f32 = rewriter.create<ConstOp>(
1444         loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
1445     Value one_f32 = rewriter.create<ConstOp>(
1446         loc, GetScalarOfType(rewriter.getF32Type(), 1.0));
1447 
1448     Value y_range = rewriter.create<RangeOp>(
1449         loc,
1450         RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1451         zero_f32, out_y_f32, one_f32);
1452 
1453     y_range = rewriter.create<MulOp>(
1454         loc,
1455         RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1456         y_range, y_scale);
1457 
1458     y_range = rewriter.create<CastOp>(
1459         loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
1460         y_range);
1461 
1462     y_range = rewriter.create<ReshapeOp>(
1463         loc,
1464         RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1465         y_range,
1466         rewriter.create<PackOp>(loc,
1467                                 RankedTensorType::get({2}, out_size_element_ty),
1468                                 ValueRange({out_y, one})));
1469 
1470     Value y_indices = rewriter.create<MulOp>(
1471         loc,
1472         RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1473         y_range, in_x);
1474 
1475     // Compute the indices for the nearest neighbor lookup across the width
1476     // dim.
1477     auto in_x_f32 = rewriter.create<CastOp>(
1478         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
1479     auto out_x_f32 = rewriter.create<CastOp>(
1480         loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);
1481 
1482     Value x_scale = rewriter.create<DivOp>(
1483         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
1484         out_x_f32);
1485 
1486     Value x_range = rewriter.create<RangeOp>(
1487         loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1488         zero_f32, out_x_f32, one_f32);
1489 
1490     x_range = rewriter.create<MulOp>(
1491         loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1492         x_range, x_scale);
1493 
1494     x_range = rewriter.create<CastOp>(
1495         loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
1496         x_range);
1497 
1498     Value x_indices = rewriter.create<ReshapeOp>(
1499         loc,
1500         RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
1501         x_range,
1502         rewriter.create<PackOp>(loc,
1503                                 RankedTensorType::get({2}, out_size_element_ty),
1504                                 ValueRange({one, out_x})));
1505 
1506     // Generate the combined index array, reshape to be 1-D.
1507     Value indices = rewriter.create<AddV2Op>(
1508         loc,
1509         RankedTensorType::get({out_height_constant, out_width_constant},
1510                               out_size_element_ty),
1511         y_indices, x_indices);
1512 
1513     indices = rewriter.create<ReshapeOp>(
1514         loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
1515         indices,
1516         rewriter.create<ReshapeOp>(
1517             loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
1518             rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));
1519 
1520     // Group the spatial indices and gather along that combined index.
1521     Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
1522         loc,
1523         RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
1524                               input_element_ty),
1525         input,
1526         rewriter.create<PackOp>(loc,
1527                                 RankedTensorType::get({3}, out_size_element_ty),
1528                                 ValueRange({batch, in_count, channels})));
1529 
1530     Value gathered_values = rewriter.create<GatherV2Op>(
1531         loc,
1532         RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
1533                               input_element_ty),
1534         input_collapsed_spatial, indices, /*axis=*/one);
1535 
1536     gathered_values =
1537         rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);
1538 
1539     rewriter.replaceOp(op, gathered_values);
1540     return success();
1541   }
1542 };
1543 
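     // Lowers RollOp with constant shift and axis operands into a sequence of
     // Slice and ConcatV2 ops: for each rolled dimension the input is split at
     // the (size - shift) boundary and the two pieces are concatenated in
     // swapped order.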
1544 struct LowerRollOp : public RewritePattern {
1545   explicit LowerRollOp(MLIRContext *context)
1546       : RewritePattern(
1547             RollOp::getOperationName(), 1, context,
1548             {ConstOp::getOperationName(), SliceOp::getOperationName(),
1549              ConcatV2Op::getOperationName()}) {}
1550 
1551   LogicalResult matchAndRewrite(Operation *op,
1552                                 PatternRewriter &rewriter) const override {
1553     auto tf_roll_op = cast<RollOp>(op);
1554 
1555     auto input_ty = tf_roll_op.input().getType().dyn_cast<RankedTensorType>();
1556     if (!input_ty || !input_ty.hasStaticShape()) {
1557       return rewriter.notifyMatchFailure(
1558           op, "require the type of input to have static shapes");
1559     }
1560 
1561     DenseIntElementsAttr shift_attr;
1562     Value shift = tf_roll_op.shift();
1563     auto shift_ranked_attr_type = shift.getType().dyn_cast<RankedTensorType>();
1564     if (!shift_ranked_attr_type ||
1565         !matchPattern(shift, m_Constant(&shift_attr))) {
1566       return failure();
1567     }
1568 
1569     DenseIntElementsAttr axis_attr;
1570     Value axis = tf_roll_op.axis();
1571     auto axis_ranked_attr_type = axis.getType().dyn_cast<RankedTensorType>();
1572     if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) {
1573       return failure();
1574     }
1575 
1576     // Combine duplicate axes and make sure they are in [0, rank(input)).
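     // For example (illustrative values), with a rank-2 input, axis = [0, 0, -1]
     // and shift = [2, 3, 1] are combined into shift_map = [5, 1]; each entry is
     // then reduced modulo the corresponding dimension size below.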
1577     auto input_shape = input_ty.getShape();
1578     int input_rank = input_shape.size();
1579     SmallVector<int32_t, 4> shift_map(input_rank, 0);
1580     for (int i = 0; i < axis_attr.getNumElements(); ++i) {
1581       int32_t axis_i = axis_attr.getValue<int32_t>(i);
1582       if (axis_i < 0) axis_i += input_rank;
1583       int32_t shift_i = shift_attr.getValue<int32_t>(i);
1584       shift_map[axis_i] += shift_i;
1585     }
1586 
1587     SmallVector<int32_t, 4> adjusted_axis;
1588     SmallVector<int32_t, 4> adjusted_shift;
1589     for (int i = 0; i < input_rank; ++i) {
1590       int32_t input_dims_i = input_shape[i];
1591       int32_t shift_i = shift_map[i] % input_dims_i;
1592       if (shift_i < 0) shift_i += input_dims_i;
1593       if (shift_i == 0) continue;
1594       adjusted_axis.push_back(i);
1595       adjusted_shift.push_back(shift_i);
1596     }
1597 
1598     // Convert rolling in each dimension to two Slice ops and one Concat op.
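     // For example, rolling a dimension of size 5 by a shift of 2 turns
     // [a, b, c, d, e] into [d, e, a, b, c]: the first slice takes the last
     // `shift` elements (begin = 5 - 2 = 3, size = 2), the second slice takes
     // the remaining elements (begin = 0, size = 3), and concatenating them
     // along that dimension yields the rolled result.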
1599     auto axis_type =
1600         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
1601     auto create_slice_op = [&](int32_t axis_i, int32_t begin_i, int32_t size_i,
1602                                Value input) {
1603       SmallVector<int64_t, 4> begin_values(input_rank, 0);
1604       begin_values[axis_i] = begin_i;
1605       auto begin_attr = DenseIntElementsAttr::get(axis_type, begin_values);
1606       auto begin =
1607           rewriter.create<ConstOp>(op->getLoc(), axis_type, begin_attr);
1608 
1609       SmallVector<int64_t, 4> output_shape;
1610       output_shape.append(input_shape.begin(), input_shape.end());
1611       output_shape[axis_i] = size_i;
1612       auto size_attr = DenseIntElementsAttr::get(axis_type, output_shape);
1613       auto size = rewriter.create<ConstOp>(op->getLoc(), axis_type, size_attr);
1614 
1615       auto slice_op_ty =
1616           RankedTensorType::get(output_shape, input_ty.getElementType());
1617       return rewriter.create<SliceOp>(op->getLoc(), slice_op_ty, input, begin,
1618                                       size);
1619     };
1620 
1621     auto result = tf_roll_op.input();
1622     auto scalar_type =
1623         mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
1624     for (int i = 0; i < adjusted_axis.size(); ++i) {
1625       int32_t axis_i = adjusted_axis[i];
1626       int32_t shift_i = adjusted_shift[i];
1627       auto slice_op_1 = create_slice_op(axis_i, input_shape[axis_i] - shift_i,
1628                                         shift_i, result);
1629       auto slice_op_2 =
1630           create_slice_op(axis_i, 0, input_shape[axis_i] - shift_i, result);
1631 
1632       auto dim_attr = DenseIntElementsAttr::get(scalar_type, {axis_i});
1633       auto concat_dim =
1634           rewriter.create<ConstOp>(op->getLoc(), scalar_type, dim_attr);
1635       auto concat_op = rewriter.create<ConcatV2Op>(
1636           op->getLoc(), input_ty,
1637           ArrayRef<Value>({slice_op_1.output(), slice_op_2.output()}),
1638           concat_dim);
1639       result = concat_op.getResult();
1640     }
1641 
1642     rewriter.replaceOp(op, result);
1643     return success();
1644   }
1645 };
1646 
1647 // Decomposes Softmax and LogSoftmax to primitive TF ops, using the following
1648 // formulas:
1649 //
1650 //     softmax = div(exp(logits), sum(exp(logits)))
1651 //     log_softmax = sub(logits, log(sum(exp(logits))))
1652 //
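     // In the emitted code these formulas are applied to max-shifted logits for
     // numerical stability (a sketch of the equivalent computation):
     //
     //     shifted = sub(logits, max(logits))
     //     softmax = div(exp(shifted), sum(exp(shifted)))
     //     log_softmax = sub(shifted, log(sum(exp(shifted))))
     //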
1653 // TODO(jpienaar): Evaluate benefit of templating here.
1654 template <typename OpTy, bool use_log = true>
1655 class LowerSoftmaxOp : public OpRewritePattern<OpTy> {
1656  public:
1657   using OpRewritePattern<OpTy>::OpRewritePattern;
1658 
1659   LogicalResult matchAndRewrite(OpTy op,
1660                                 PatternRewriter &rewriter) const override {
1661     Value logits = op.logits();
1662     auto loc = op.getLoc();
1663 
1664     // Note that the TensorFlow Softmax op verifies that the input rank is
1665     // greater than or equal to one so the following sequence is valid.
1666     auto reduce_dim =
1667         rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({-1}, &rewriter));
1668 
1669     // The exponentials of the input values, and hence their sum, can be very
1670     // large here. Division by a large denominator is numerically unstable. To
1671     // improve numerical stability, subtract the max element of each batch from
1672     // every element in that batch so that the maximum input value is zero. It
1673     // can be shown that shifting all inputs in a batch by a common value gives
1674     // a mathematically equivalent softmax result.
1675     auto max_logits =
1676         rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
1677                                    /*keep_dims=*/rewriter.getBoolAttr(true));
1678     auto shifted_logits = rewriter.create<TF::SubOp>(loc, logits, max_logits);
1679 
1680     // Exponentiate the inputs.
1681     Value exp = rewriter.create<TF::ExpOp>(loc, shifted_logits);
1682 
1683     // Compute summation of the exponentials.
1684     Value sum =
1685         rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
1686                                    /*keep_dims=*/rewriter.getBoolAttr(true));
1687 
1688     if (use_log) {
1689       Value log = rewriter.create<TF::LogOp>(loc, sum);
1690       rewriter.replaceOpWithNewOp<TF::SubOp>(op, shifted_logits, log);
1691     } else {
1692       rewriter.replaceOpWithNewOp<TF::DivOp>(op, exp, sum);
1693     }
1694     return success();
1695   }
1696 };
1697 
1698 }  // namespace
1699 
1700 void PopulateLoweringTFPatterns(MLIRContext *context,
1701                                 OwningRewritePatternList *patterns) {
1702   // clang-format off
1703   patterns->insert<
1704       LowerAddNOp,
1705       ConvertFakeQuantWithMinMaxVarsOp,
1706       LowerDynamicStitchOp<DynamicStitchOp>,
1707       LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1708       LowerInvertPermutationOp,
1709       LowerLgammaOp,
1710       LowerPackOp,
1711       LowerBatchToSpaceND,
1712       LowerSpaceToBatchNDOp,
1713       LowerResizeNearestNeighbor,
1714       LowerSparseMatMulOp,
1715       Lower_UnaryOpsComposition,
1716       LowerRollOp>(context);
1717   // clang-format on
1718   populateWithGenerated(*patterns);
1719 }
1720 
1721 void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
1722                                          OwningRewritePatternList *patterns) {
1723   // clang-format off
1724   patterns->insert<
1725       ConvertFakeQuantWithMinMaxVarsOp,
1726       LowerAddNOp,
1727       LowerBatchToSpaceND,
1728       LowerDynamicStitchOp<DynamicStitchOp>,
1729       LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1730       LowerInvertPermutationOp,
1731       LowerPackOp,
1732       LowerResizeNearestNeighbor,
1733       LowerSoftmaxOp<TF::LogSoftmaxOp, /*use_log=*/true>,
1734       LowerSoftmaxOp<TF::SoftmaxOp, /*use_log=*/false>,
1735       LowerSpaceToBatchNDOp,
1736       LowerSparseMatMulOp,
1737       Lower_UnaryOpsComposition,
1738       LowerRollOp>(context);
1739   // clang-format on
1740 
1741   // Populate the relevant generated patterns.
1742   // clang-format off
1743   patterns->insert<
1744       LowerBiasAddGradOp,
1745       LowerDivNoNanOp,
1746       LowerEmptyOp,
1747       LowerFakeQuantWithMinMaxArgs,
1748       LowerFillOp,
1749       LowerIsNanOp,
1750       LowerL2LossOp,
1751       LowerMulNoNanOp,
1752       LowerPadOp,
1753       LowerReciprocal,
1754       LowerRintOp,
1755       LowerRoundOpOnFloatTensor,
1756       LowerRoundOpOnIntTensor,
1757       LowerRsqrtGradOp,
1758       LowerScatterNdOp,
1759       LowerSeluOp,
1760       LowerSeluGradOp,
1761       LowerSizeOp,
1762       LowerSoftmaxCrossEntropyWithLogitsOp,
1763       LowerSparseSoftmaxCrossEntropyWithLogitsOp,
1764       LowerSqrtGradOp,
1765       LowerSquareOp,
1766       LowerSquaredDifferenceOpOnRealTensors,
1767       LowerSquaredDifferenceOpOneComplexTensors,
1768       LowerTanhGradOp,
1769       LowerXdivyOp,
1770       LowerXlog1pyOp,
1771       LowerXlogyOp>(context);
1772   // clang-format on
1773 }
1774 
1775 void PopulateLoweringQuantizedPatterns(MLIRContext *context,
1776                                        OwningRewritePatternList *patterns) {
1777   // clang-format off
1778   patterns->insert<
1779       LowerDequantizeOp>(context);
1780   // clang-format on
1781 }
1782 
1783 }  // namespace TF
1784 }  // namespace mlir
1785