1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <iterator>
22 #include <limits>
23 #include <numeric>
24 #include <string>
25 #include <tuple>
26 #include <type_traits>
27 
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/Sequence.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/ADT/StringSwitch.h"
38 #include "llvm/ADT/iterator_range.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/FormatVariadic.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
43 #include "mlir/Dialect/Traits.h"  // from @llvm-project
44 #include "mlir/IR/Attributes.h"  // from @llvm-project
45 #include "mlir/IR/Builders.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
48 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
49 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
50 #include "mlir/IR/Identifier.h"  // from @llvm-project
51 #include "mlir/IR/Location.h"  // from @llvm-project
52 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
53 #include "mlir/IR/Matchers.h"  // from @llvm-project
54 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
55 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
56 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
57 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
58 #include "mlir/IR/Types.h"  // from @llvm-project
59 #include "mlir/IR/Value.h"  // from @llvm-project
60 #include "mlir/Parser.h"  // from @llvm-project
61 #include "mlir/Support/LLVM.h"  // from @llvm-project
62 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
63 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
69 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
70 #include "tensorflow/core/framework/kernel_shape_util.h"
71 #include "tensorflow/core/platform/logging.h"
72 #include "tensorflow/core/util/padding.h"
73 #include "tensorflow/core/util/tensor_format.h"
74 
75 namespace mlir {
76 namespace TF {
77 
78 namespace {
79 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
80 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
81 }  // namespace
82 
83 //===----------------------------------------------------------------------===//
84 // AddOp
85 //===----------------------------------------------------------------------===//
86 
87 void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
88                                         MLIRContext *context) {
89   results.insert<AddToAddV2>(context);
90 }
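// Illustrative IR for the AddToAddV2 pattern (a sketch; operand types are
// examples, not taken from this file):
//
//   %sum = "tf.Add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//
// is rewritten to
//
//   %sum = "tf.AddV2"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>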
91 
92 //===----------------------------------------------------------------------===//
93 // AddNOp
94 //===----------------------------------------------------------------------===//
95 
96 OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
97   if (operands.size() == 1) return *inputs().begin();
98   return {};
99 }
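// For example, a single-operand AddN folds to its operand (types here are
// illustrative):
//
//   %0 = "tf.AddN"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
//
// simply becomes %arg0.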
100 
101 //===----------------------------------------------------------------------===//
102 // AddV2Op
103 //===----------------------------------------------------------------------===//
104 
105 void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
106                                           MLIRContext *context) {
107   results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
108 }
109 
110 OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
111   return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // AllOp
116 //===----------------------------------------------------------------------===//
117 
118 static LogicalResult Verify(AllOp op) {
119   return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
120                                      op.getLoc());
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // AnyOp
125 //===----------------------------------------------------------------------===//
126 
127 static LogicalResult Verify(AnyOp op) {
128   return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
129                                      op.getLoc());
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // AssertOp
134 //===----------------------------------------------------------------------===//
135 
136 namespace {
137 
138 // Removes Assert with constant true predicate.
139 struct AssertWithTrue : public OpRewritePattern<AssertOp> {
140   using OpRewritePattern<AssertOp>::OpRewritePattern;
141 
142   LogicalResult matchAndRewrite(AssertOp op,
143                                 PatternRewriter &rewriter) const override {
144     ElementsAttr cst;
145     if (matchPattern(op.condition(), m_Constant(&cst))) {
146       if (cst.getValue<BoolAttr>({}).getValue()) {
147         rewriter.eraseOp(op);
148         return success();
149       }
150     }
151     return failure();
152   }
153 };
154 }  // namespace
155 
156 void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
157                                            MLIRContext *context) {
158   results.insert<AssertWithTrue>(context);
159 }
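// Illustrative IR (a sketch): an assert whose predicate is a constant true
//
//   %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
//   "tf.Assert"(%true, %msg) : (tensor<i1>, tensor<!tf.string>) -> ()
//
// is simply erased by the AssertWithTrue pattern above.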
160 
161 //===----------------------------------------------------------------------===//
162 // BatchMatMulV2Op & BatchMatMulOp
163 //===----------------------------------------------------------------------===//
164 
165 template <typename OpT,
166           typename std::enable_if<llvm::is_one_of<
167               OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
168 static LogicalResult Verify(OpT op) {
169   if (!HasRankAtLeast(op.x(), 2)) {
170     return op.emitOpError("requires lhs operand to have rank at least two");
171   }
172   if (!HasRankAtLeast(op.y(), 2)) {
173     return op.emitOpError("requires rhs operand to have rank at least two");
174   }
175 
176   RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
177   RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());
178 
179   if (!x_ty || !y_ty) return success();
180 
181   ArrayRef<int64_t> x_shape = x_ty.getShape();
182   ArrayRef<int64_t> y_shape = y_ty.getShape();
183 
184   llvm::SmallVector<int64_t, 4> result_batch_shape;
185   llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
186   llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
187 
188   // Check compatibility of batch dimensions if both input shapes are known.
189   // BatchMatMul should have exactly the same batch dimensions and
190   // BatchMatMulV2 should have broadcastable batch dimensions.
191   //
192   // The last two dimensions are non-batch dimensions that don't need to
193   // participate in batch dimension compatibility check.
194   if (std::is_same<OpT, BatchMatMulOp>()) {
195     for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
196       int64_t x_dim = std::get<0>(dim_pairs);
197       int64_t y_dim = std::get<1>(dim_pairs);
198       if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
199           x_dim != y_dim) {
200         return op.emitOpError()
201                << "found mismatching batch dimensions for lhs shape " << x_ty
202                << " and rhs shape " << y_ty;
203       }
204     }
205   } else {
206     if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
207                                             result_batch_shape))
208       return op.emitOpError()
209              << "found incompatible broadcast batch dimensions for lhs shape "
210              << x_ty << " and rhs shape " << y_ty;
211   }
212 
213   RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
214   if (!output_ty) return success();
215 
216   int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
217   if (output_ty.getRank() != expected_output_rank)
218     return op.emitOpError()
219            << "found invalid output rank, expected " << expected_output_rank
220            << " but got " << output_ty.getRank();
221 
222   // Check output batch dim with potential broadcasting.
223   ArrayRef<int64_t> output_shape = output_ty.getShape();
224   for (int i = 0; i < result_batch_shape.size(); ++i) {
225     if (output_shape[i] != ShapedType::kDynamicSize &&
226         output_shape[i] != result_batch_shape[i])
227       return op.emitOpError()
228              << "has mismatching input batch dimension "
229              << result_batch_shape[i] << " and output batch dimension "
230              << output_shape[i];
231   }
232 
233   // Check output shape for non-batch dimension, following documentation below.
234   // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
235   int64_t x_row_dim = x_shape[x_shape.size() - 2];
236   int64_t x_col_dim = x_shape[x_shape.size() - 1];
237   int64_t y_row_dim = y_shape[y_shape.size() - 2];
238   int64_t y_col_dim = y_shape[y_shape.size() - 1];
239   int64_t out_row_dim = output_shape[output_shape.size() - 2];
240   int64_t out_col_dim = output_shape[output_shape.size() - 1];
241 
242   int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
243   int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;
244 
245   if (expected_out_row_dim != ShapedType::kDynamicSize &&
246       out_row_dim != ShapedType::kDynamicSize &&
247       out_row_dim != expected_out_row_dim)
248     return op.emitOpError()
249            << "found invalid output dimension on row, expected "
250            << expected_out_row_dim << " but got " << out_row_dim;
251   if (expected_out_col_dim != ShapedType::kDynamicSize &&
252       out_col_dim != ShapedType::kDynamicSize &&
253       out_col_dim != expected_out_col_dim)
254     return op.emitOpError()
255            << "found invalid output dimension on col, expected "
256            << expected_out_col_dim << " but got " << out_col_dim;
257 
258   return success();
259 }
260 
261 void BatchMatMulOp::getCanonicalizationPatterns(
262     OwningRewritePatternList &results, MLIRContext *context) {
263   results.insert<BatchMatMulToV2>(context);
264 }
265 
266 void BatchMatMulV2Op::getCanonicalizationPatterns(
267     OwningRewritePatternList &results, MLIRContext *context) {
268   results.insert<BatchMatMulV2ToMatMul>(context);
269 }
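// For example (a sketch; shapes are illustrative), a BatchMatMulV2 with no
// batch dimensions
//
//   %0 = "tf.BatchMatMulV2"(%a, %b)
//       : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
//
// can be rewritten by BatchMatMulV2ToMatMul to a plain MatMul:
//
//   %0 = "tf.MatMul"(%a, %b)
//       : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>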
270 
271 //===----------------------------------------------------------------------===//
272 // BatchToSpaceOp
273 //===----------------------------------------------------------------------===//
274 
275 static LogicalResult Verify(BatchToSpaceOp op) {
276   // Op already has a constraint that block_size >= 2.
277   int64_t block_size = op.block_size();
278 
279   llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
280   auto input_type = op.input().getType().cast<TensorType>();
281   if (input_type.hasRank()) {
282     if (input_type.getRank() != 4)
283       return op.emitOpError()
284              << "requires input to be a 4D tensor, but got " << input_type;
285 
286     int64_t input_batch = input_type.getDimSize(0);
287     if (input_batch != ShapedType::kDynamicSize &&
288         input_batch % (block_size * block_size) != 0) {
289       return op.emitOpError()
290              << "requires input batch (dimension 0) to be evenly divisible "
291                 "by (block_size * block_size), but got input batch "
292              << input_batch << " and block_size " << block_size;
293     }
294 
295     input_shape.assign(input_type.getShape().begin(),
296                        input_type.getShape().end());
297   }
298 
299   auto crops_type = op.crops().getType().cast<TensorType>();
300   if (crops_type.hasRank()) {
301     if (crops_type.getRank() != 2)
302       return op.emitOpError()
303              << "requires crops to be a 2D tensor, but got " << crops_type;
304 
305     auto dim_of_size = [&](int64_t dim, int64_t size) {
306       if (crops_type.isDynamicDim(dim)) return true;
307       return crops_type.getDimSize(dim) == size;
308     };
309     if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
310       return op.emitOpError()
311              << "requires crops to be a tensor<2x2>, but got " << crops_type;
312   }
313 
314   DenseIntElementsAttr crops_attr;
315   // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
316   // and flattened as [crop_top, crop_bottom, crop_left, crop_right]
317   llvm::SmallVector<int64_t, 4> crops_values;
318   if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
319     assert(crops_attr.getNumElements() == 4 &&
320            "tf.BatchToSpace crops must have 4 elements");
321 
322     auto crops_range = crops_attr.getIntValues();
323     for (const auto &crops_value : crops_range) {
324       int64_t crops_value_int = crops_value.getSExtValue();
325       if (crops_value_int < 0)
326         return op.emitOpError()
327                << "requires all crop values to be nonnegative, but got "
328                << crops_attr;
329 
330       crops_values.push_back(crops_value_int);
331     }
332   }
333 
334   auto output_type = op.output().getType().cast<TensorType>();
335   if (output_type.hasRank()) {
336     if (output_type.getRank() != 4)
337       return op.emitOpError()
338              << "requires output to be a 4D tensor, but got " << output_type;
339 
340     auto static_dims = [](int64_t dim_a, int64_t dim_b) {
341       return dim_a != ShapedType::kDynamicSize &&
342              dim_b != ShapedType::kDynamicSize;
343     };
344 
345     auto output_shape = output_type.getShape();
346 
347     // output batch = input batch / (block_size * block_size).
348     int64_t input_batch = input_shape[0];
349     int64_t output_batch = output_shape[0];
350     if (static_dims(input_batch, output_batch) &&
351         (output_batch * block_size * block_size) != input_batch)
352       return op.emitOpError()
353              << "requires output batch (dimension 0) to be equal to input "
354                 "batch (dimension 0) / (block_size * block_size), but got "
355                 "output batch "
356              << output_batch << ", input batch " << input_batch
357              << ", and block_size " << block_size;
358 
359     auto check_spatial_dim = [&](int64_t spatial_dim_index,
360                                  llvm::StringRef dim_name,
361                                  llvm::StringRef crop_a_name,
362                                  llvm::StringRef crop_b_name) -> LogicalResult {
363       int64_t input_dim = input_shape[spatial_dim_index];
364       int64_t output_dim = output_shape[spatial_dim_index];
365       if (!static_dims(input_dim, output_dim)) return success();
366 
367       int64_t input_dim_pad = input_dim * block_size;
368       // If crops are unknown, the maximum output spatial dim size is input
369       // spatial dim size * block_size, since crop values are at least 0.
370       if (crops_values.empty() && output_dim > input_dim * block_size)
371         return op.emitOpError()
372                << "requires output " << dim_name << " (dimension "
373                << spatial_dim_index << ") to be less than or equal to input "
374                << dim_name << " (dimension " << spatial_dim_index
375                << ") * block_size, but got output " << dim_name << " "
376                << output_dim << ", input " << dim_name << " " << input_dim
377                << ", and block_size " << block_size;
378 
379       if (!crops_values.empty()) {
380         // output spatial dim = input spatial dim * block_size - crops.
381         int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
382         int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
383         if (output_dim != input_dim_pad - crop_a - crop_b)
384           return op.emitOpError()
385                  << "requires output " << dim_name << " (dimension "
386                  << spatial_dim_index << ") to be equal to input " << dim_name
387                  << " (dimension " << spatial_dim_index << ") * block_size - "
388                  << crop_a_name << " - " << crop_b_name << ", but got output "
389                  << dim_name << " " << output_dim << ", input " << dim_name
390                  << " " << input_dim << ", " << crop_a_name << " " << crop_a
391                  << ", " << crop_b_name << " " << crop_b << ", and block_size "
392                  << block_size;
393       }
394 
395       return success();
396     };
397 
398     if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
399         failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
400       return failure();
401 
402     int64_t input_depth = input_shape[3];
403     int64_t output_depth = output_shape[3];
404     if (static_dims(input_depth, output_depth) && output_depth != input_depth)
405       return op.emitOpError()
406              << "requires output depth (dimension 3) to be equal to input "
407                 "depth (dimension 3), but got output depth "
408              << output_depth << " and input depth " << input_depth;
409   }
410 
411   return success();
412 }
413 
414 void BatchToSpaceOp::getCanonicalizationPatterns(
415     OwningRewritePatternList &results, MLIRContext *context) {
416   results.insert<BatchToSpaceToBatchToSpaceND>(context);
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // BatchToSpaceNDOp
421 //===----------------------------------------------------------------------===//
422 
423 static LogicalResult Verify(BatchToSpaceNDOp op) {
424   auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
425   auto crops_ty = op.crops().getType().cast<ShapedType>();
426 
427   if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
428     const int block_rank = block_shape_ty.getShape().front();
429     if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
430         crops_ty.getShape()[1] != 2) {
431       op.emitOpError() << "crops should have shape [" << block_rank
432                        << ", 2] instead of " << crops_ty.getShape();
433       return failure();
434     }
435   }
436 
437   return success();
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // BiasAddOp
442 //===----------------------------------------------------------------------===//
443 
444 // Verifies that,
445 // * the value and bias operands have valid ranks or are unranked.
446 // * the channel dimension of the value operand and the length of bias match
447 //   if they are not unknown.
448 //
449 static LogicalResult Verify(BiasAddOp op) {
450   absl::string_view data_format(op.data_format().data(),
451                                 op.data_format().size());
452   tensorflow::TensorFormat format;
453   bool is_valid = FormatFromString(data_format, &format);
454   DCHECK(is_valid) << data_format;
455   if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
456     if (!HasRankAtLeast(op.value(), 2))
457       return op.emitOpError(
458           "requires value operand to have rank at least two with `NHWC` data "
459           "format");
460   } else {
461     // Op definition requires data_format to be either NHWC or NCHW.
462     DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
463     if (!HasRankAtLeast(op.value(), 3))
464       return op.emitOpError(
465           "requires value operand to have rank at least three with `NCHW` data "
466           "format");
467   }
468 
469   if (!IsOfRankOrUnranked(op.bias(), 1))
470     return op.emitOpError("requires bias operand to have rank exactly one");
471 
472   RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
473   RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
474   if (!bias_ty || !value_ty) return success();
475 
476   int64_t feature_dim_idx =
477       tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
478   int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
479   int64_t bias_len = bias_ty.getDimSize(0);
480   if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
481     return op.emitOpError()
482            << "requires channel dimension and feature dimension to match; "
483               "found "
484            << feature_dim << " and " << bias_len << ", respectively";
485   }
486   return success();
487 }
488 
489 Optional<ContractionFusion> BiasAddOp::GetContractionFusion() {
490   // Only NHWC in f32 is supported for fusion.
491   if (data_format() != "NHWC" || !T().isF32()) return None;
492 
493   return ContractionFusion("BiasAdd", /*additional_arguments=*/{1});
494 }
495 
496 LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
497   return ::mlir::TF::UpdateDataFormat(data_format, this);
498 }
499 
500 StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) {
501   // Keep current data format if no GPUs are available or if explicit placement
502   // does not allow using a GPU for this operation.
503   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
504     return data_format();
505 
506   // Prefer NHWC for GPU devices.
507   return "NHWC";
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // BiasAddGradOp
512 //===----------------------------------------------------------------------===//
513 
514 // Verifies that,
515 // * the out_backprop operand has a valid rank or is unranked.
516 //
517 static LogicalResult Verify(BiasAddGradOp op) {
518   absl::string_view data_format(op.data_format().data(),
519                                 op.data_format().size());
520   tensorflow::TensorFormat format;
521   bool is_valid = FormatFromString(data_format, &format);
522   DCHECK(is_valid) << data_format;
523   if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
524     if (!HasRankAtLeast(op.out_backprop(), 2))
525       return op.emitOpError(
526           "requires out_backprop operand to have rank at least two with `NHWC` "
527           "data format");
528   } else {
529     // Op definition requires data_format to be either NHWC or NCHW.
530     DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
531     if (!HasRankAtLeast(op.out_backprop(), 3))
532       return op.emitOpError(
533           "requires out_backprop operand to have rank at least three with "
534           "`NCHW` data format");
535   }
536 
537   return success();
538 }
539 
540 //===----------------------------------------------------------------------===//
541 // BiasAddV1Op
542 //===----------------------------------------------------------------------===//
543 
544 void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
545                                               MLIRContext *context) {
546   results.insert<BiasAddV1ToBiasAdd>(context);
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // BitcastOp
551 //===----------------------------------------------------------------------===//
552 
553 void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
554                                             MLIRContext *context) {
555   results.insert<BitcastSameType, BitcastNested>(context);
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // BroadcastToOp
560 //===----------------------------------------------------------------------===//
561 
562 static LogicalResult Verify(BroadcastToOp op) {
563   // TODO(antiagainst): check that
564   // * The 'shape' input is an 1-D int tensor.
565   // * Each dimension pair of the source and target shapes are either equal
566   //   or one of them is one.
567   return success();
568 }
569 
570 OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
571   Value input = this->input();
572 
573   // Fold broadcast if operand and result types are the same and all dimensions
574   // are statically known (no-op broadcast).
575   auto result_ty = getType().dyn_cast<ShapedType>();
576   if (result_ty && result_ty.hasStaticShape() && result_ty == input.getType()) {
577     return input;
578   }
579 
580   return {};
581 }
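// For example (illustrative types), broadcasting a tensor to its own static
// shape
//
//   %0 = "tf.BroadcastTo"(%arg0, %shape)
//       : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
//
// folds to %arg0.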
582 
583 //===----------------------------------------------------------------------===//
584 // BroadcastGradientArgsOp
585 //===----------------------------------------------------------------------===//
586 
587 namespace {
588 // Returns `true` if both s0 & s1 are defined via constant op, and fills
589 // s0_shape & s1_shape.
590 bool ExtractInputConstShape(BroadcastGradientArgsOp op,
591                             DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
592                             SmallVectorImpl<int64_t> &s0_shape,
593                             SmallVectorImpl<int64_t> &s1_shape) {
594   if (!matchPattern(op.s0(), m_Constant(&s0))) return false;
595   if (!matchPattern(op.s1(), m_Constant(&s1))) return false;
596 
597   for (auto s : s0.getIntValues()) s0_shape.push_back(s.getSExtValue());
598   for (auto s : s1.getIntValues()) s1_shape.push_back(s.getSExtValue());
599 
600   return true;
601 }
602 
603 // Calculates r0 & r1 output based on inputs and calculated broadcasted shape.
604 //
605 // For given bcasted_shape, s0_shape and s1_shape, the broadcasted dimension is
606 // calculated and pushed back to its corresponding result, r0 or r1. For example,
607 // for s0_shape [1,4] and s1_shape [4, 4], bcasted_shape is computed to be
608 // [4,4] - this leads to the result of r0 to be [0] as the first dimension of s0
609 // is broadcasted, and r1 to be <> as no broadcasting is happening for s1.
610 void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
611                                             ArrayRef<int64_t> s0_shape,
612                                             ArrayRef<int64_t> s1_shape,
613                                             SmallVectorImpl<int64_t> &r0,
614                                             SmallVectorImpl<int64_t> &r1) {
615   r0.clear();
616   r1.clear();
617 
618   // No broadcasting is required if both the shapes are equal.
619   if (s0_shape == s1_shape) return;
620 
621   for (int i = bcasted_shape.size(); i > 0; --i) {
622     int idx = bcasted_shape.size() - i;
623     int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i;
624     int s1_idx = i > s1_shape.size() ? -1 : s1_shape.size() - i;
625     if (s0_idx == -1) {
626       r0.push_back(idx);
627       if (s1_shape[s1_idx] == 1) r1.push_back(idx);
628     } else if (s1_idx == -1) {
629       r1.push_back(idx);
630       if (s0_shape[s0_idx] == 1) r0.push_back(idx);
631     } else if (s0_shape[s0_idx] != s1_shape[s1_idx]) {
632       if (s0_shape[s0_idx] != bcasted_shape[idx])
633         r0.push_back(idx);
634       else
635         r1.push_back(idx);
636     } else if (s0_shape[s0_idx] == 1) {
637       // This op is used to compute the gradient dimensions requiring reduction
638       // to match the input dimensions. In case both the dimensions are one,
639       // reducing the dimension has no effect. We choose to reduce such
640       // dimensions to match the TensorFlow kernel behavior. However, note that
641       // the TF behavior in this case is inconsistent with the case with the
642       // same shapes.
643       r0.push_back(idx);
644       r1.push_back(idx);
645     }
646   }
647 }
648 }  // namespace
649 
650 // Verifies that,
651 // * the input shapes are broadcast compatible.
652 // * the output shape dimensions match the expected dimension sizes for the
653 // input shapes.
654 static LogicalResult Verify(BroadcastGradientArgsOp op) {
655   SmallVector<int64_t, 4> s0_shape, s1_shape;
656   DenseIntElementsAttr s0, s1;
657   if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();
658 
659   // If both shapes are known constants, try to validate the output shapes as well.
660   SmallVector<int64_t, 4> bcasted_shape;
661   if (!OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape))
662     return op.emitOpError() << "requires broadcast compatible shape tensors "
663                                "for 's0' and 's1', but got "
664                             << s0 << " and " << s1;
665 
666   SmallVector<int64_t, 4> r0, r1;
667   GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
668                                          r1);
669 
670   // Verify that output types are of rank one and matches the computed result
671   // shape.
672   auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>();
673   auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>();
674   if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
675     return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
676                             << r0.size() << " but got " << r0_ty.getShape()[0];
677   if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size())
678     return op.emitOpError() << "requires dimension 0 size of 'r1' to be "
679                             << r1.size() << " but got " << r1_ty.getShape()[0];
680 
681   return success();
682 }
683 
684 LogicalResult BroadcastGradientArgsOp::fold(
685     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
686   SmallVector<int64_t, 4> s0_shape, s1_shape;
687   DenseIntElementsAttr s0, s1;
688   if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
689     return failure();
690 
691   // Fold BroadcastGradientArgs into two constants if both of the inputs have
692   // known shape.
693   SmallVector<int64_t, 4> bcasted_shape;
694   // Verifier should already ensure the broadcast compatibility.
695   bool bcast_compatible =
696       OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape);
697   assert(bcast_compatible);
698   (void)bcast_compatible;
699 
700   SmallVector<int64_t, 4> r0, r1;
701   GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
702                                          r1);
703 
704   auto build_out_dense_element = [](SmallVectorImpl<int64_t> &shape,
705                                     Type input_type) {
706     Type element_type = input_type.cast<mlir::TensorType>().getElementType();
707     RankedTensorType type = RankedTensorType::get(
708         {static_cast<int64_t>(shape.size())}, element_type);
709     // Input could only be i32 or i64. For i32, downcast to int32_t array.
710     if (element_type.isInteger(32)) {
711       SmallVector<int32_t, 4> i32_shape;
712       for (auto s : shape) i32_shape.push_back(static_cast<int32_t>(s));
713       return DenseIntElementsAttr::get(type, i32_shape);
714     } else {
715       assert(element_type.isInteger(64));
716       return DenseIntElementsAttr::get(type, shape);
717     }
718   };
719 
720   results.push_back(build_out_dense_element(r0, this->s0().getType()));
721   results.push_back(build_out_dense_element(r1, this->s1().getType()));
722 
723   return success();
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // CaseOp
728 //===----------------------------------------------------------------------===//
729 
730 class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
731  public:
732   explicit FoldConstantCaseOp(MLIRContext *context)
733       : OpRewritePattern<TF::CaseOp>(context) {}
734   LogicalResult matchAndRewrite(TF::CaseOp op,
735                                 PatternRewriter &rewriter) const override;
736 };
737 
738 LogicalResult FoldConstantCaseOp::matchAndRewrite(
739     TF::CaseOp op, PatternRewriter &rewriter) const {
740   // Extract the constant cond value.
741   DenseIntElementsAttr branch;
742   if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
743 
744   int index = *branch.getValues<int>().begin();
745   if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;
746 
747   auto func = op.branches()[index].cast<SymbolRefAttr>();
748   auto empty = rewriter.getStringAttr("");
749   auto call_op = rewriter.create<PartitionedCallOp>(
750       op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
751       /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
752   CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
753   rewriter.replaceOp(op, call_op.getResults());
754   return success();
755 }
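// Illustrative IR for FoldConstantCaseOp (a sketch; @branch0/@branch1 are
// hypothetical function names):
//
//   %idx = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
//   %0 = "tf.Case"(%idx, %arg0) {branches = [@branch0, @branch1]}
//       : (tensor<i32>, tensor<f32>) -> tensor<f32>
//
// is rewritten to a direct call of the selected branch:
//
//   %0 = "tf.PartitionedCall"(%arg0) {f = @branch1, config = "",
//                                     config_proto = "", executor_type = ""}
//       : (tensor<f32>) -> tensor<f32>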
756 
757 void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
758                                          MLIRContext *context) {
759   results.insert<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
760 }
761 
762 static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) {
763   if (!IsOfRankOrUnranked(branch_index, 0))
764     return op->emitOpError()
765            << "expects 'branch_index' to be a scalar, but got "
766            << branch_index.getType();
767   return success();
768 }
769 
770 static LogicalResult VerifyCaseOrIfOpBranchFunctions(
771     Operation *op, ArrayRef<Attribute> branches,
772     llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
773   SmallVector<FunctionType, 2> branch_types;
774   branch_types.reserve(branches.size());
775 
776   // Functions have one fewer operand than the op, as the first operand is elided
777   // (`cond` of `tf.If` and `branch_index` of `tf.Case`).
778   TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"};
779   TypeRangeWithDesc result{op->getResultTypes(), "result"};
780 
781   for (auto branch : llvm::enumerate(branches)) {
782     auto branch_func = SymbolTable::lookupNearestSymbolFrom<FuncOp>(
783         op, branch.value().cast<SymbolRefAttr>());
784     if (!branch_func)
785       return op->emitOpError()
786              << "expects " << branch_name(branch.index()) << " ("
787              << branch.value() << ") to point to a defined function";
788 
789     FunctionType branch_type = branch_func.getType();
790     std::string desc = branch_name(branch.index()) + " input";
791     TypeRangeWithDesc branch_input{branch_type.getInputs(), desc};
792     if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input)))
793       return failure();
794 
795     desc = branch_name(branch.index()) + " result";
796     TypeRangeWithDesc branch_result{branch_type.getResults(), desc};
797     if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result)))
798       return failure();
799 
800     branch_types.push_back(branch_type);
801   }
802 
803   // If branches have incompatible input types that means that no tensor can
804   // serve as input to all the functions. Hence, the op is invalid.
805   int expected_num_inputs = op->getNumOperands() - 1;
806   for (int i = 0; i < expected_num_inputs; ++i) {
807     SmallVector<Type, 2> branch_input_i_types;
808     branch_input_i_types.reserve(branches.size());
809     llvm::transform(
810         branch_types, std::back_inserter(branch_input_i_types),
811         [i](FunctionType &branch_type) { return branch_type.getInput(i); });
812     if (!AreCastCompatible(branch_input_i_types)) {
813       std::string input_types_str;
814       llvm::raw_string_ostream os(input_types_str);
815       llvm::interleaveComma(branch_input_i_types, os);
816       return op->emitOpError()
817              << "expects all branch input type(s) (" << os.str()
818              << ") at index " << i << " to be cast compatible";
819     }
820   }
821 
822   return success();
823 }
824 
825 static LogicalResult Verify(CaseOp op) {
826   if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
827   auto branch_name = [](unsigned index) {
828     return llvm::formatv("branch #{0}", index).str();
829   };
830   return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(),
831                                          branch_name);
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // CaseRegionOp
836 //===----------------------------------------------------------------------===//
837 
838 static LogicalResult Verify(CaseRegionOp op) {
839   if (op.branches().empty())
840     return op.emitOpError() << "expects to have at least 1 region";
841 
842   if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
843 
844   TypeRangeWithDesc results{op.getResultTypes(), "result"};
845 
846   for (auto region_and_idx : llvm::enumerate(op.branches())) {
847     std::string description =
848         llvm::formatv("branch #{0} result", region_and_idx.index()).str();
849     Operation *yield = region_and_idx.value().front().getTerminator();
850     TypeRangeWithDesc branch_results{yield->getOperandTypes(), description};
851     if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results)))
852       return failure();
853   }
854 
855   return success();
856 }
857 
858 namespace {
859 // Eliminate values that pass through the CaseRegionOp or IfRegionOp branches.
860 template <class CaseOrIfRegionOp>
861 class CaseOrIfRegionEliminatePassThrough
862     : public OpRewritePattern<CaseOrIfRegionOp> {
863   using OpRewritePattern<CaseOrIfRegionOp>::OpRewritePattern;
864 
865   LogicalResult matchAndRewrite(CaseOrIfRegionOp op,
866                                 PatternRewriter &rewriter) const override {
867     RegionRange branches = op.getRegions();
868     SmallVector<Type, 4> new_result_types;
869     // Maps pass through results to extern values.
870     llvm::SmallDenseMap<Value, Value, 4> result_to_extern_value;
871 
872     for (auto result : op.getResults()) {
873       unsigned index = result.getResultNumber();
874       Region *first_branch = *branches.begin();
875       Operation *first_terminator = first_branch->front().getTerminator();
876       Value returned_val = first_terminator->getOperand(index);
877 
878       // Pass through values would be defined outside the branch region. Keep
879       // the type of non pass through results to create a new op later, if
880       // required.
881       if (returned_val.getParentBlock() == &first_branch->front()) {
882         new_result_types.push_back(result.getType());
883         continue;
884       }
885       // Check if the same extern value is returned in each branch.
886       for (Region *region : branches.drop_front()) {
887         Operation *terminator = region->front().getTerminator();
888         if (terminator->getOperand(index) != returned_val) return failure();
889       }
890       result_to_extern_value[result] = returned_val;
891     }
892 
893     // If no pass through values are found, no change is required.
894     if (result_to_extern_value.empty()) return failure();
895 
896     // Create new case/if region op.
897     auto new_op = rewriter.create<CaseOrIfRegionOp>(
898         op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(),
899         op.getNumRegions());
900 
901     int next_index = 0;
902     for (auto result : op.getResults()) {
903       if (!result_to_extern_value.count(result)) {
904         result.replaceAllUsesWith(new_op.getResult(next_index++));
905         continue;
906       }
907       result.replaceAllUsesWith(result_to_extern_value[result]);
908       for (Region *branch : branches)
909         branch->front().getTerminator()->eraseOperand(next_index);
910     }
911 
912     // Move region bodies to the new op.
913     for (auto region_index : llvm::seq<int>(0, branches.size()))
914       new_op.getRegion(region_index).takeBody(op.getRegion(region_index));
915 
916     op.erase();
917     return success();
918   }
919 };
920 }  // namespace
921 
922 void CaseRegionOp::getCanonicalizationPatterns(
923     OwningRewritePatternList &results, MLIRContext *context) {
924   results.insert<CaseOrIfRegionEliminatePassThrough<TF::CaseRegionOp>>(context);
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // CastOp
929 //===----------------------------------------------------------------------===//
930 
931 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
932   // Cast with the same type is a no-op.
933   Value operand = getOperand();
934   if (getType() == operand.getType()) return operand;
935   return {};
936 }
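// For example, a cast between identical types (types are illustrative)
//
//   %0 = "tf.Cast"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
//
// folds to %arg0.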
937 
938 //===----------------------------------------------------------------------===//
939 // ConcatOp and ConcatV2Op
940 //===----------------------------------------------------------------------===//
941 
942 template <typename OpT,
943           typename std::enable_if<llvm::is_one_of<
944               OpT, ConcatOp, ConcatV2Op>::value>::type * = nullptr>
945 static LogicalResult Verify(OpT op) {
946   // TODO(hinsu): Convert variadic length attributes to derived attributes.
947   Operation::operand_range values = op.values();
948 
949   int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
950   Value axis = *op.getODSOperands(axis_idx).begin();
951   if (!HasRankAtMost(axis, 1)) {
952     return op.emitOpError(
953         "requires axis to be of scalar type (or vector type for older "
954         "versions)");
955   }
956 
957   return VerifyTypesCompatibility(values,
958                                   /*mask_one_dim=*/true, op.getOperation());
959 }
960 
961 void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
962                                            MLIRContext *context) {
963   results.insert<ConvertToConcatV2>(context);
964 }
965 
966 namespace {
967 
968 // Hoist coefficient-wise unary operation out of the Concat op:
969 //
970 //   %0 = "tf.Log1p"(%arg_0)
971 //   %1 = "tf.Log1p"(%arg_1)
972 //   ...
973 //   %n = "tf.Log1p"(%arg_n)
974 //   %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis)
975 //
976 // Rewrite it to:
977 //
978 //   %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis)
979 //   %1 = "tf.Log1p"(%0)
980 class HoistCwiseUnaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
981  public:
982   explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context)
983       : OpRewritePattern<TF::ConcatV2Op>(context) {}
984   LogicalResult matchAndRewrite(TF::ConcatV2Op op,
985                                 PatternRewriter &rewriter) const override;
986 };
987 
988 LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
989     TF::ConcatV2Op op, PatternRewriter &rewriter) const {
990   auto loc = op.getLoc();
991 
992   // All concat operands must be defined by ops.
993   Operation *first_arg_op = op.values().front().getDefiningOp();
994   if (first_arg_op == nullptr) return failure();
995 
996   // All concat operands must be produced by the coeff-wise unary operation.
997   if (!first_arg_op->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
998 
999   // All concat operands must be defined by the op of same kind.
1000   bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
1001     Operation *arg_op = arg.getDefiningOp();
1002     return arg_op && arg_op->getName() == first_arg_op->getName();
1003   });
1004   if (!args_same_op) return failure();
1005 
1006   // Collect unary operations operands.
1007   auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value {
1008     return arg.getDefiningOp()->getOperand(0);
1009   });
1010   SmallVector<Value, 8> unary_ops_args(unary_operands);
1011 
1012   // Concatenate unary ops operands.
1013   auto concat_unary_operands =
1014       rewriter.create<ConcatV2Op>(loc, op.getType(), unary_ops_args, op.axis());
1015 
1016   // Replace original concat with an unary op.
1017   OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(),
1018                                     concat_unary_operands.getResult(),
1019                                     op.getResult().getType(),
1020                                     ArrayRef<NamedAttribute>());
1021   Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
1022 
1023   rewriter.replaceOp(op, new_unary_op->getResults());
1024 
1025   return success();
1026 }
1027 
1028 // Hoist coefficient-wise binary operation out of the Concat op:
1029 //
1030 //   %0 = tf.Mul(%lhs_0, %rhs_0)
1031 //   %1 = tf.Mul(%lhs_1, %rhs_1)
1032 //   ...
1033 //   %n = tf.Mul(%lhs_n, %rhs_n)
1034 //   %m = tf.ConcatV2(%0, %1, ..., %n, %axis)
1035 //
1036 // Rewrite it to:
1037 //
1038 //   %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis)
1039 //   %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
1040 //   %2 = tf.Mul(%0, %1)
1041 //
1042 // If a minor fraction of the Concat inputs are not of the same binary op kind
1043 // (tf.Mul in the above example), we will synthesize the binary ops for those
1044 // inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
1045 // tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
1046 // float32 tensors.
1047 // TODO(hongm): Implement this op synthesis optimization for other dtypes if
1048 // needed.
1049 //
1050 // Because coefficient-wise binary operations support implicit broadcasting, we
1051 // should be very careful with this optimization so as not to accidentally
1052 // produce incorrect concat operations.
1053 class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
1054  public:
1055   explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context)
1056       : OpRewritePattern<TF::ConcatV2Op>(context) {}
1057   LogicalResult matchAndRewrite(TF::ConcatV2Op op,
1058                                 PatternRewriter &rewriter) const override;
1059 
1060  private:
1061   struct HoistParams {
1062     SmallVector<Value, 8> lhs_args;
1063     SmallVector<Value, 8> rhs_args;
1064     int64_t lhs_axis;
1065     int64_t rhs_axis;
1066     Type lhs_concat_type;
1067     Type rhs_concat_type;
1068     int scalar_operand_idx;  // can be 0 or 1 for the binary op's operands.
1069   };
1070 
1071   // Returns parameters of a binary op hoisting out of concatenation if all of
1072   // the operands are in one of the compatible configurations.
1073   // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
1074   // except from the ones in `exceptions`. In that case, we can synthesize that
1075   // binary op kind for the values in `exceptions`.
1076   Optional<HoistParams> GetHoistParams(
1077       TF::ConcatV2Op op, int64_t axis,
1078       const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
1079 };
1080 
1081 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
1082     TF::ConcatV2Op op, PatternRewriter &rewriter) const {
1083   auto loc = op.getLoc();
1084 
1085   // Axis must be a constant scalar value.
1086   DenseIntElementsAttr axis_attr;
1087   if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
1088   if (axis_attr.getNumElements() != 1) return failure();
1089   int64_t axis =
1090       axis_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
1091   // TODO(ezhulenev): Compute axis from rank. e.g. It might be common to concat
1092   // on the channels dim for NCHW layout as axis=-2.
1093   if (axis < 0) return failure();
1094 
1095   // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
1096   // or some other ops that we might convert to using the same op kind above
1097   // (e.g. converting op A to tf.Mul(A, 1.0))
1098   // TODO(hongm): generalize the code here to support cases where the first arg
1099   // has no defining op (e.g. might be a block arg).
1100   Operation *first_arg_op = op.values().front().getDefiningOp();
1101   if (first_arg_op == nullptr) return failure();
1102 
1103   // All concat operands must be produced by the coeff-wise binary operation.
1104   if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();
1105 
1106   // All concat operands must be defined by the op of same kind, except for a
1107   // minor portion which we track in `exceptions`.
1108   // Map from the operands to operand indices.
1109   llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
1110   unsigned operand_idx = 0;
1111   for (Value arg : op.values()) {
1112     Operation *arg_op = arg.getDefiningOp();
1113     if (arg_op && arg_op->getName() == first_arg_op->getName()) {
1114       ++operand_idx;
1115       continue;
1116     }
1117     exceptions[arg] = operand_idx++;
1118   }
1119   // Recall those inputs to the concat op that are not produced by a binary op
1120   // of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
1121   // there are too many exceptions, it might not be cost effective to apply the
1122   // concat hoisting optimization here.
1123   // Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1
1124   // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1
1125   // out of 3, we do.
1126   const float exception_pct_threshold = 0.5;
1127   if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
1128       exceptions.size())
1129     return failure();
1130 
1131   // Compute binary operands hoist parameters.
1132   auto hoist_params = GetHoistParams(op, axis, exceptions);
1133   if (!hoist_params.hasValue()) return failure();
1134 
1135   // Process `exceptions`: For each value there, synthesize a binary op of the
1136   // above kind, so that the concat hoisting optimization can still apply.
1137   if (!exceptions.empty()) {
1138     int identity_val;
1139     if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
1140       identity_val = 0;
1141     else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
1142              isa<RealDivOp>(first_arg_op))
1143       identity_val = 1;
1144     else
1145       return failure();
1146     DenseElementsAttr const_attr;
1147     auto scalar_tensor_type =
1148         first_arg_op->getOperand(hoist_params->scalar_operand_idx)
1149             .getType()
1150             .dyn_cast<ShapedType>();
1151     Type scalar_dtype = scalar_tensor_type.getElementType();
1152     if (scalar_dtype.isa<FloatType>())
1153       const_attr = DenseElementsAttr::get(scalar_tensor_type,
1154                                           static_cast<float>(identity_val));
1155     else
1156       return failure();
1157 
1158     // All checks have passed; now prepare for the rewrite.
1159     auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
1160     for (const auto &kv : exceptions) {
1161       assert(!hoist_params->lhs_args[kv.second]);
1162       assert(!hoist_params->rhs_args[kv.second]);
1163 
1164       if (hoist_params->scalar_operand_idx == 1) {
1165         hoist_params->lhs_args[kv.second] = kv.first;
1166         hoist_params->rhs_args[kv.second] = identity_const;
1167       } else {
1168         assert(hoist_params->scalar_operand_idx == 0);
1169         hoist_params->lhs_args[kv.second] = identity_const;
1170         hoist_params->rhs_args[kv.second] = kv.first;
1171       }
1172     }
1173   }
1174 
1175   // New lhs and rhs concatenation axis.
1176   auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
1177   auto lhs_axis = rewriter.create<TF::ConstOp>(
1178       loc, DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis));
1179   auto rhs_axis = rewriter.create<TF::ConstOp>(
1180       loc, DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis));
1181 
1182   // Concatenate binary ops operands on the new axis.
1183   auto lhs_concat = rewriter.create<ConcatV2Op>(
1184       loc, hoist_params->lhs_concat_type, hoist_params->lhs_args, lhs_axis);
1185   auto rhs_concat = rewriter.create<ConcatV2Op>(
1186       loc, hoist_params->rhs_concat_type, hoist_params->rhs_args, rhs_axis);
1187 
1188   // Replace original concat with a binary op.
1189   OperationState new_binary_op_state(
1190       loc, first_arg_op->getName().getStringRef(),
1191       {lhs_concat.getResult(), rhs_concat.getResult()},
1192       op.getResult().getType(), ArrayRef<NamedAttribute>());
1193   Operation *new_binary_op = rewriter.createOperation(new_binary_op_state);
1194 
1195   rewriter.replaceOp(op, new_binary_op->getResults());
1196 
1197   return success();
1198 }
1199 
1200 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
1201 HoistCwiseBinaryOutOfConcat::GetHoistParams(
1202     TF::ConcatV2Op op, int64_t axis,
1203     const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
1204   assert(axis >= 0);
1205   // Collects lhs or rhs arguments of concat op operands.
1206   auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
1207     auto range = llvm::map_range(op.values(), [&](Value arg) {
1208       if (exceptions.count(arg)) return Value();
1209       return arg.getDefiningOp()->getOperand(operand_idx);
1210     });
1211     return {range.begin(), range.end()};
1212   };
1213 
1214   // Returns true if all binary op operands at index `operand_idx` are tensors
1215   // of rank `axis + 1` whose `axis` dimension has size 1.
1216   auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
1217     return llvm::all_of(op.values(), [&](Value arg) -> bool {
1218       if (exceptions.count(arg)) return true;
1219       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1220       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1221       return ranked && ranked.getRank() == (axis + 1) &&
1222              ranked.getShape()[axis] == 1;
1223     });
1224   };
1225 
1226   // Returns true if all binary op operands at index `operand_idx` are scalars.
1227   auto is_all_scalars = [&](int operand_idx) -> bool {
1228     return llvm::all_of(op.values(), [&](Value arg) -> bool {
1229       if (exceptions.count(arg)) return true;
1230       auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1231       auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1232       return ranked && ranked.hasRank() && ranked.getRank() == 0;
1233     });
1234   };
1235 
1236   // Concat result type must be a ranked tensor.
1237   auto ranked = op.getType().dyn_cast<RankedTensorType>();
1238   if (!ranked) return None;
1239 
1240   // TODO(ezhulenev): Add support for more valid concat patterns.
1241 
1242   // Tensor + Scalar: [..., 1] + []  <- scalar
1243   //                        ^
1244   //                        \- axis is the innermost dimension.
1245   //
1246   // Concatenate tensor arguments on the same axis as the original operation,
1247   // and concatenate scalars into the vector.
1248   if (is_all_tensors(0, axis) && is_all_scalars(1)) {
1249     std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
1250     auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
1251     return HoistParams{args(0),
1252                        args(1),
1253                        axis,
1254                        0,
1255                        op.getType(),
1256                        rhs_type,
1257                        /*scalar_operand_idx=*/1};
1258   } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
1259     std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
1260     auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
1261     return HoistParams{args(0),
1262                        args(1),
1263                        0,
1264                        axis,
1265                        lhs_type,
1266                        op.getType(),
1267                        /*scalar_operand_idx=*/0};
1268   }
1269   return None;
1270 }
1271 
1272 }  // namespace
1273 
1274 void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
1275                                              MLIRContext *context) {
1276   results.insert<HoistCwiseBinaryOutOfConcat, HoistCwiseUnaryOutOfConcat>(
1277       context);
1278 }
1279 
1280 //===----------------------------------------------------------------------===//
1281 // CumsumOp and CumprodOp
1282 //===----------------------------------------------------------------------===//
1283 
1284 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1285                             OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
1286 static LogicalResult Verify(OpT op) {
1287   if (!IsOfRankOrUnranked(op.axis(), 0))
1288     return op.emitOpError("requires scalar axis operand");
1289 
1290   DenseIntElementsAttr axis_attr;
1291   if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
1292     auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
1293     if (input_ty) {
1294       int64_t rank = input_ty.getRank();
1295       assert(axis_attr.getNumElements() == 1 &&
1296              "scalar attribute should have exactly one element");
1297       int64_t axis = (*axis_attr.begin()).getSExtValue();
1298       if (axis < -rank || axis >= rank) {
1299         return op.emitError()
1300                << "axis operand should be within range [" << -rank << ", "
1301                << rank << "); actual value: " << axis;
1302       }
1303     }
1304   }
1305 
1306   return success();
1307 }
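// Worked example (added for illustration, not in the original source): for an
// input `x` of type tensor<8x16xf32> (rank 2), a constant `axis` must lie in
// [-2, 2), so axis = -2, -1, 0, or 1 verifies and axis = 2 is rejected.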
1308 
1309 //===----------------------------------------------------------------------===//
1310 // ConcatOffsetOp
1311 //===----------------------------------------------------------------------===//
1312 
1313 static LogicalResult Verify(ConcatOffsetOp op) {
1314   if (op.N() < 2)
1315     return op.emitOpError() << "requires N to be at least 2, got " << op.N();
1316 
1317   if (op.shape().size() != op.offset().size())
1318     return op.emitOpError()
1319            << "requires sizes of shapes and offsets to be the same, got sizes "
1320            << op.shape().size() << " and " << op.offset().size();
1321 
1322   auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
1323   if (ranked_dim && ranked_dim.getRank() != 0)
1324     return op.emitOpError()
1325            << "requires concat_dim to be a scalar, got tensor of rank "
1326            << ranked_dim.getRank();
1327 
1328   int64_t num_dims = -1;
1329   for (auto shape_offset_idx :
1330        llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
1331     Value shape = std::get<0>(shape_offset_idx.value());
1332     Value offset = std::get<1>(shape_offset_idx.value());
1333     const size_t idx = shape_offset_idx.index();
1334 
1335     if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
1336       return op.emitOpError() << "requires operand and result " << idx
1337                               << " to have compatible shapes";
1338 
1339     auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
1340     if (!ranked_shape) continue;
1341 
1342     if (ranked_shape.getRank() != 1)
1343       return op.emitOpError() << "requires shape tensor operand " << idx
1344                               << " to be of rank 1, got tensor of rank "
1345                               << ranked_shape.getRank();
1346 
1347     if (!ranked_shape.hasStaticShape()) continue;
1348 
1349     int64_t ranked_shape_dim = ranked_shape.getDimSize(0);
1350     if (num_dims == -1)
1351       num_dims = ranked_shape_dim;
1352     else if (ranked_shape_dim != num_dims)
1353       return op.emitOpError()
1354              << "requires shape tensor (rank 1) operand " << idx
1355              << " to be of length " << num_dims
1356              << ", got tensor (rank 1) of length " << ranked_shape_dim;
1357   }
1358 
1359   return success();
1360 }
1361 
1362 LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
1363                                    SmallVectorImpl<OpFoldResult> &results) {
1364   // ConcatOffset must have concat_dim as its first operand and at least two
1365   // shape tensors in its variadic shapes operand.
1366   if (operands.size() < 3) return failure();
1367 
1368   // Check concat_dim is a scalar.
1369   auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1370   if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
1371     return failure();
1372 
1373   llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
1374   shapes.reserve(operands.size() - 1);
1375   for (Attribute shape : llvm::drop_begin(operands, 1))
1376     if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
1377       shapes.push_back(shape_attr);
1378     else
1379       return failure();
1380 
1381   // Check all shapes are vectors of the same length.
1382   if (shapes.front().getType().getRank() != 1) return failure();
1383   const int64_t num_dims = shapes.front().getNumElements();
1384   for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1))
1385     if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims)
1386       return failure();
1387 
1388   // Check concat_dim is within [-num_dims, num_dims).
1389   int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin());
1390   if (concat_dim < 0) concat_dim += num_dims;
1391   if (concat_dim >= num_dims || concat_dim < 0) return failure();
1392 
1393   // Check all elements except the one at concat_dim match across shape tensors.
1394   SmallVector<int32_t, 4> shape0;
1395   shape0.reserve(num_dims);
1396   for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim);
1397 
1398   for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) {
1399     for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) {
1400       if (dims_and_idx.index() == concat_dim) continue;
1401 
1402       if (std::get<0>(dims_and_idx.value()) !=
1403           std::get<1>(dims_and_idx.value()).getSExtValue())
1404         return failure();
1405     }
1406   }
1407 
1408   // Compute an exclusive cumulative sum of elements at concat_dim.
1409   results.reserve(shapes.size());
1410   SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
1411   RankedTensorType offset_type =
1412       RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
1413   for (DenseIntElementsAttr shape : shapes) {
1414     results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
1415     cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
1416   }
1417 
1418   return success();
1419 }
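// Worked example (added for illustration, not in the original source): with
// concat_dim = 1 and two shape tensors [2, 3, 5] and [2, 7, 5], the exclusive
// cumulative sum along dimension 1 produces the offsets
//   [0, 0, 0] and [0, 3, 0],
// i.e. the position of each input within the concatenated result.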
1420 
1421 //===----------------------------------------------------------------------===//
1422 // ConstOp
1423 //===----------------------------------------------------------------------===//
1424 
1425 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
1426   assert(operands.empty() && "constant has no operands");
1427 
1428   // Return the held attribute value.
1429   return value();
1430 }
1431 
1432 // Builds a constant op with the specified attribute `value`. The result
1433 // op's type is deduced from `value`; if `value` is of scalar type,
1434 // wraps it up with a tensor type of empty shape.
1435 // TODO(jpienaar): This one differs from the autogenerated one as it takes an
1436 // attribute but always creates an ElementsAttr internally.
1437 void ConstOp::build(OpBuilder &builder, OperationState &result,
1438                     Attribute value) {
1439   ShapedType type;
1440   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1441     return ConstOp::build(builder, result, elem_attr);
1442   } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
1443     // All TensorFlow types must be tensor types. In the build() method,
1444     // we want to provide more flexibility by allowing attributes of scalar
1445     // types. But we need to wrap it up with ElementsAttr to construct
1446     // valid TensorFlow constants.
1447     type = RankedTensorType::get(/*shape=*/{}, value.getType());
1448     return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
1449   }
1450   // TODO(jpienaar): support other TensorFlow specific types.
1451   llvm_unreachable("unsupported attribute type for building tf.Const");
1452 }
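// Example (added for illustration, not in the original source): building a
// tf.Const from a scalar FloatAttr with value 1.0 wraps it into a
// DenseElementsAttr of type tensor<f32>, so the resulting constant is a
// rank-0 tensor rather than a bare scalar attribute.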
1453 
1454 void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
1455                     Attribute value) {
1456   // Handle the case where the type and value are already tensors.
1457   if (type.isa<TensorType>() && value.isa<ElementsAttr>()) {
1458     result.addTypes(type);
1459     result.addAttribute("value", value);
1460     return;
1461   }
1462 
1463   // Otherwise, default to the attribute builder.
1464   ConstOp::build(builder, result, value);
1465   assert(type == result.types[0] && "type mismatch in construction");
1466 }
1467 
1468 LogicalResult ConstOp::inferReturnTypes(
1469     MLIRContext *context, Optional<Location> location, ValueRange operands,
1470     DictionaryAttr attributes, RegionRange regions,
1471     SmallVectorImpl<Type> &inferredReturnTypes) {
1472   auto value = attributes.get("value");
1473   if (!value) return emitOptionalError(location, "missing attribute 'value'");
1474   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1475     inferredReturnTypes.assign({elem_attr.getType()});
1476     return success();
1477   }
1478   return emitOptionalError(location,
1479                            "attribute 'value' failed to satisfy constraint: "
1480                            "constant vector/tensor");
1481 }
1482 
1483 //===----------------------------------------------------------------------===//
1484 // Conv2DOp and Conv3DOp
1485 //===----------------------------------------------------------------------===//
1486 
1487 static LogicalResult VerifyConvOpAttributes(
1488     int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations,
1489     llvm::Optional<mlir::Location> location) {
1490   int64_t strides_size = strides.size();
1491   if (strides_size != num_dims)
1492     return emitOptionalError(
1493         location, "requires strides attribute length to be ", num_dims);
1494   auto is_not_positive = [](Attribute val) {
1495     return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
1496   };
1497   if (llvm::any_of(strides, is_not_positive))
1498     return emitOptionalError(location, "requires positive strides");
1499 
1500   int64_t dilations_size = dilations.size();
1501   if (dilations_size != num_dims)
1502     return emitOptionalError(
1503         location, "requires dilations attribute length to be ", num_dims);
1504   if (llvm::any_of(dilations, is_not_positive))
1505     return emitOptionalError(location, "requires positive dilations");
1506 
1507   return success();
1508 }
1509 
1510 // Verifies that,
1511 // * Number of input channels is divisible by the number of filter input
1512 //   channels
1513 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1514                             OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
1515 static LogicalResult Verify(OpT op) {
1516   int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
1517   int num_dims = 2 + num_spatial_dims;
1518 
1519   int64_t input_channels = -1;
1520   if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
1521     absl::string_view data_format(op.data_format().data(),
1522                                   op.data_format().size());
1523     tensorflow::TensorFormat format;
1524     auto is_valid = FormatFromString(data_format, &format);
1525     DCHECK(is_valid) << data_format;
1526     int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format);
1527     input_channels = ty.getDimSize(idx);
1528   }
1529 
1530   int64_t filter_channels = -1;
1531   if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
1532     int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
1533         num_dims, tensorflow::FORMAT_HWIO);
1534     filter_channels = ty.getDimSize(idx);
1535   }
1536 
1537   if (input_channels != -1 && filter_channels != -1 &&
1538       input_channels % filter_channels != 0)
1539     return op.emitOpError()
1540            << "requires the number of input channels to be divisible by the "
1541               "number of filter input channels; found "
1542            << input_channels << " and " << filter_channels << ", respectively";
1543 
1544   return success();
1545 }
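// Example (added for illustration, not in the original source): an input with
// 64 channels and a filter with 16 input channels passes this check
// (64 % 16 == 0, a grouped convolution), while 64 input channels with 48
// filter input channels would be rejected.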
1546 
1547 LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
1548   auto perm = GetDataFormatPermutation(this->data_format(), data_format);
1549   if (perm.empty()) return failure();
1550 
1551   // Update data_format attribute and result types.
1552   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1553 
1554   // Update convolution attributes.
1555   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1556   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1557   (*this)->setAttr("explicit_paddings",
1558                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1559 
1560   return success();
1561 }
1562 
1563 // Verifies the inferred return type of the given operation.
1564 template <typename OpT,
1565           typename std::enable_if<llvm::is_one_of<
1566               OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr>
1567 static LogicalResult inferConvReturnTypes(
1568     OpT op, llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes,
1569     llvm::Optional<mlir::Location> location,
1570     ArrayRef<Attribute> explicit_padding) {
1571   const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3;
1572   const int64_t num_dims = 2 + num_spatial_dims;
1573   const Value input = op.input();
1574   const Value filter = op.filter();
1575   const TensorType input_ty = input.getType().template cast<TensorType>();
1576   const TensorType filter_ty = filter.getType().template cast<TensorType>();
1577   const StringRef paddings = op.padding().getValue();
1578 
1579   ArrayRef<Attribute> strides = op.strides().getValue();
1580   StringRef data_format = op.data_format().getValue();
1581   ArrayRef<Attribute> dilations = op.dilations().getValue();
1582 
1583   tensorflow::TensorFormat format;
1584   auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1585   if (!data_format_is_valid) {
1586     return emitOptionalError(location, "Invalid data format provided");
1587   }
1588   tensorflow::Padding padding;
1589   auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1590   if (!padding_is_valid.ok()) {
1591     return emitOptionalError(location, "Invalid padding format provided");
1592   }
1593   auto get_int = [](Attribute attr) {
1594     return attr.template cast<IntegerAttr>().getInt();
1595   };
1596 
1597   // Necessary sanity checks.
1598   // Verifies that,
1599   // * Ranks of operands and result are valid
1600   // * Length of explicit_paddings attribute is valid and has non-negative
1601   //   elements
1602   // * strides and dilations attributes have positive elements
1603   if (!IsOfRankOrUnranked(input, num_dims) ||
1604       !IsOfRankOrUnranked(filter, num_dims))
1605     return emitOptionalError(location, "requires operands to be ", num_dims,
1606                              "D tensor");
1607 
1608   if (padding == tensorflow::Padding::EXPLICIT) {
1609     if (explicit_padding.size() == 0) {
1610       return emitOptionalError(location,
1611                                "requires attribute 'explicit_paddings' with "
1612                                "'EXPLICIT' padding mode");
1613     }
1614     if (explicit_padding.size() != num_dims * 2) {
1615       return emitOptionalError(
1616           location, "requires explicit_paddings attribute length to be ",
1617           num_dims * 2);
1618     }
1619     auto is_negative = [](Attribute val) {
1620       return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
1621     };
1622     if (llvm::any_of(explicit_padding, is_negative))
1623       return emitOptionalError(location,
1624                                "requires non negative explicit paddings");
1625   }
1626 
1627   if (failed(VerifyConvOpAttributes(num_dims, strides, dilations, location))) {
1628     return failure();
1629   }
1630 
1631   // Output always has `num_dims` rank. All dimensions are initialized to
1632   // dynamic size and can be partially inferred.
1633   SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
1634   // Output batch and channel dimension can be obtained using utilities from
1635   // tensorflow/core/util/tensor_format.h.
1636   if (input_ty.hasRank()) {
1637     return_shape[GetTensorBatchDimIndex(num_dims, format)] =
1638         input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
1639   }
1640   if (filter_ty.hasRank()) {
1641     return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
1642         filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
1643             num_dims, tensorflow::FORMAT_HWIO));
1644   }
1645   // Spatial dimensions can be inferred only when both input and filter are
1646   // ranked because we need to get their spatial dimensions.
1647   if (input_ty.hasRank() && filter_ty.hasRank()) {
1648     // Checks the size of each of the output spatial dimensions.
1649     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1650       const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
1651       int64_t stride = get_int(strides[dim]);
1652       tensorflow::int64 expected_output_size;
1653       tensorflow::int64 pad_low;
1654       tensorflow::int64 pad_high;
1655       // Retrieve padding, if defined explicitly.
1656       if (padding == tensorflow::Padding::EXPLICIT) {
1657         pad_low = get_int(explicit_padding[2 * dim]);
1658         pad_high = get_int(explicit_padding[2 * dim + 1]);
1659       }
1660       // Skip if input or filter size is dynamic.
1661       if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
1662       // Calculate the expected_output_size.
1663       tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1664           input_ty.getDimSize(dim), filter_ty.getDimSize(i),
1665           get_int(dilations[dim]), stride, padding, &expected_output_size,
1666           &pad_low, &pad_high);
1667       // Return failure if expected_output_size could not be calculated.
1668       if (!status.ok()) return failure();
1669       return_shape[dim] = expected_output_size;
1670     }
1671   }
1672 
1673   inferredReturnTypes.assign(
1674       {RankedTensorType::get(return_shape, input_ty.getElementType())});
1675   return success();
1676 }
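// Worked example (added for illustration, not in the original source): for a
// spatial dimension with input size 224, filter size 7, stride 2, and
// dilation 1, the inferred output size is
//   'VALID' padding: floor((224 - 7) / 2) + 1 = 109
//   'SAME'  padding: ceil(224 / 2)            = 112
// per the usual TensorFlow windowed output size computation.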
1677 
1678 LogicalResult Conv2DOp::inferReturnTypes(
1679     mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
1680     mlir::ValueRange operands, mlir::DictionaryAttr attributes,
1681     mlir::RegionRange regions,
1682     llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
1683   Conv2DOpAdaptor op(operands, attributes);
1684   ArrayRef<Attribute> explicit_padding;
1685   ArrayAttr explicit_pad =
1686       attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1687   if (!explicit_pad) {
1688     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1689   }
1690   explicit_padding = explicit_pad.getValue();
1691 
1692   return inferConvReturnTypes(op, inferredReturnTypes, location,
1693                               explicit_padding);
1694 }
1695 
1696 StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
1697   // Keep the current data format if no GPUs are available or if explicit
1698   // placement does not allow using the GPU for this operation.
1699   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1700     return data_format();
1701 
1702   // Input must be a tensor.
1703   auto input_ty = input().getType().dyn_cast<TensorType>();
1704   if (!input_ty) return data_format();
1705 
1706   // For the f16 data type on devices with Tensor Core support, the NHWC data
1707   // format is up to ~2x faster.
1708   const bool is_f16 = input_ty.getElementType().isF16();
1709   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1710 
1711   // For f32/f16 data types the decision depends on the filter size in spatial
1712   // dimensions; for other data types we keep the current data format.
1713   if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
1714     return data_format();
1715 
1716   // Keep current data format if filter rank is unknown or not equal to 4.
1717   auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
1718   if (!filter_ty || filter_ty.getRank() != 4) return data_format();
1719 
1720   const int64_t d0 = filter_ty.getDimSize(0);
1721   const int64_t d1 = filter_ty.getDimSize(1);
1722 
1723   auto all_ones = [](ArrayAttr arr) -> bool {
1724     return llvm::all_of(arr, [](Attribute attr) -> bool {
1725       return attr.cast<IntegerAttr>().getInt() == 1;
1726     });
1727   };
1728 
1729   // Convolutions with a 1x1 filter and all-ones strides and dilations can be
1730   // computed as a GEMM in NHWC data format, and can be up to ~2x faster than
1731   // convolution in NCHW.
1732   const bool one_by_one = d0 == 1 && d1 == 1;
1733   const bool trivial_strides = all_ones(strides());
1734   const bool trivial_dilations = all_ones(dilations());
1735 
1736   // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
1737   // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
1738   // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
1739   // users can efficiently use NHWC data format?
1740   if (one_by_one && trivial_strides && trivial_dilations) {
1741     return "NHWC";
1742   }
1743 
1744   // If the filter's spatial dimensions are unknown or not 1x1, we prefer NCHW
1745   // because it's the fastest option on NVIDIA GPUs with cuDNN library support.
1746   return "NCHW";
1747 }
1748 
1749 //===----------------------------------------------------------------------===//
1750 // Conv2DBackpropFilterOp
1751 //===----------------------------------------------------------------------===//
1752 
1753 LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
1754   StringRef src_data_format = this->data_format();
1755 
1756   auto perm = GetDataFormatPermutation(src_data_format, data_format);
1757   if (perm.empty()) return failure();
1758 
1759   // Update data_format attribute and result types.
1760   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1761 
1762   // Update convolution attributes.
1763   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1764   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1765   (*this)->setAttr("explicit_paddings",
1766                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1767 
1768   // Permute filter sizes operand.
1769   OpBuilder builder(getOperation());
1770   auto filter_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1771       getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format),
1772       StringAttr::get(getContext(), data_format));
1773   setOperand(1, filter_sizes_permuted);
1774 
1775   return success();
1776 }
1777 
1778 StringRef Conv2DBackpropFilterOp::GetOptimalLayout(
1779     const RuntimeDevices &devices) {
1780   // Keep the current data format if no GPUs are available or if explicit
1781   // placement does not allow using the GPU for this operation.
1782   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1783     return data_format();
1784 
1785   // Input must be a tensor.
1786   auto input_ty = input().getType().dyn_cast<TensorType>();
1787   if (!input_ty) return data_format();
1788 
1789   // For the f16 data type on devices with Tensor Core support, the NHWC data
1790   // format is up to ~2x faster.
1791   const bool is_f16 = input_ty.getElementType().isF16();
1792   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1793 
1794   // Otherwise always use "NCHW".
1795   return "NCHW";
1796 }
1797 
1798 //===----------------------------------------------------------------------===//
1799 // Conv2DBackpropInputOp
1800 //===----------------------------------------------------------------------===//
1801 
1802 static LogicalResult Verify(Conv2DBackpropInputOp op) {
1803   int num_spatial_dims = 2;
1804   int num_dims = 2 + num_spatial_dims;
1805 
1806   if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) ||
1807       !IsOfRankOrUnranked(op.filter(), num_dims))
1808     return op.emitOpError()
1809            << "requires operands to be " << num_dims << "D tensor";
1810   if (!IsOfRankOrUnranked(op.getResult(), num_dims))
1811     return op.emitOpError()
1812            << "requires result to be " << num_dims << "D tensor";
1813 
1814   llvm::Optional<mlir::Location> location = op.getLoc();
1815   ArrayRef<Attribute> strides = op.strides().getValue();
1816   ArrayRef<Attribute> dilations = op.dilations().getValue();
1817   LogicalResult verify_result =
1818       VerifyConvOpAttributes(num_dims, strides, dilations, location);
1819   if (failed(verify_result)) {
1820     return verify_result;
1821   }
1822 
1823   return success();
1824 }
1825 
1826 LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) {
1827   StringRef src_data_format = this->data_format();
1828 
1829   auto perm = GetDataFormatPermutation(src_data_format, data_format);
1830   if (perm.empty()) return failure();
1831 
1832   // Update data_format attribute and result types.
1833   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1834 
1835   // Update convolution attributes.
1836   (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1837   (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1838   (*this)->setAttr("explicit_paddings",
1839                    ShuffleArrayAttr(explicit_paddings(), perm, 2));
1840 
1841   // Permute input sizes operand.
1842   OpBuilder builder(getOperation());
1843   auto input_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1844       getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format),
1845       StringAttr::get(getContext(), data_format));
1846   setOperand(0, input_sizes_permuted);
1847 
1848   return success();
1849 }
1850 
1851 StringRef Conv2DBackpropInputOp::GetOptimalLayout(
1852     const RuntimeDevices &devices) {
1853   // Keep the current data format if no GPUs are available or if explicit
1854   // placement does not allow using the GPU for this operation.
1855   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1856     return data_format();
1857 
1858   // Filter must be a tensor.
1859   auto filter_ty = filter().getType().dyn_cast<TensorType>();
1860   if (!filter_ty) return data_format();
1861 
1862   // For the f16 data type on devices with Tensor Core support, the NHWC data
1863   // format is up to ~2x faster.
1864   const bool is_f16 = filter_ty.getElementType().isF16();
1865   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1866 
1867   // Otherwise always use "NCHW".
1868   return "NCHW";
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // Conv3DOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 LogicalResult Conv3DOp::inferReturnTypes(
1876     mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
1877     mlir::ValueRange operands, mlir::DictionaryAttr attributes,
1878     mlir::RegionRange regions,
1879     llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
1880   Conv3DOpAdaptor op(operands, attributes);
1881   ArrayRef<Attribute> explicit_padding;
1882   ArrayAttr explicit_pad =
1883       attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1884   if (!explicit_pad) {
1885     explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1886   }
1887   explicit_padding = explicit_pad.getValue();
1888 
1889   return inferConvReturnTypes(op, inferredReturnTypes, location,
1890                               explicit_padding);
1891 }
1892 
1893 //===----------------------------------------------------------------------===//
1894 // DataFormatVecPermuteOp
1895 //===----------------------------------------------------------------------===//
1896 
1897 static LogicalResult Verify(DataFormatVecPermuteOp op) {
1898   auto input_ty = op.x().getType().dyn_cast<RankedTensorType>();
1899   if (!input_ty) return success();
1900 
1901   int rank = input_ty.getRank();
1902   if (rank != 1 && rank != 2)
1903     return op.emitOpError("requires input of rank 1 or 2");
1904 
1905   if (rank == 1) {
1906     int64_t dim0 = input_ty.getDimSize(0);
1907     if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
1908       return op.emitOpError("requires 1D input of size 4 or size 2");
1909   }
1910 
1911   if (rank == 2) {
1912     int64_t dim0 = input_ty.getDimSize(0);
1913     if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
1914       return op.emitOpError(
1915           "requires first dimensions of 2D input to be of size 4");
1916 
1917     int64_t dim1 = input_ty.getDimSize(1);
1918     if (dim1 != ShapedType::kDynamicSize && dim1 != 2)
1919       return op.emitOpError(
1920           "requires second dimensions of 2D input to be of size 2");
1921   }
1922 
1923   return success();
1924 }
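// Example (added for illustration, not in the original source): permuting the
// 1D size-4 vector [n, h, w, c] from "NHWC" to "NCHW" yields [n, c, h, w];
// the size-2 form is expected to carry only the spatial dimensions.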
1925 
1926 //===----------------------------------------------------------------------===//
1927 // DivOp
1928 //===----------------------------------------------------------------------===//
1929 
1930 void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1931                                         MLIRContext *context) {
1932   results.insert<DivWithSqrtDivisor>(context);
1933 }
1934 
1935 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1936   return IdentityArithmeticOpFolder<DivOp>(*this, operands);
1937 }
1938 
1939 //===----------------------------------------------------------------------===//
1940 // DynamicStitchOp
1941 //===----------------------------------------------------------------------===//
1942 
1943 static LogicalResult Verify(DynamicStitchOp op) {
1944   if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
1945 
1946   if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
1947     if (out_ty.getRank() == 0) {
1948       return op.emitOpError("requires non scalar output");
1949     }
1950   }
1951 
1952   llvm::SmallDenseSet<int64_t, 8> index_values;
1953   bool all_indices_const = true;
1954   int32_t max_index = -1;
1955   llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
1956   for (auto it : llvm::zip(op.indices(), op.data())) {
1957     Value index = std::get<0>(it);
1958 
1959     DenseIntElementsAttr index_attr;
1960     if (matchPattern(index, m_Constant(&index_attr))) {
1961       for (int32_t index : index_attr.getValues<int32_t>()) {
1962         if (index < 0)
1963           return op.emitOpError()
1964                  << "requires non-negative index values; found " << index;
1965         max_index = std::max(index, max_index);
1966         index_values.insert(index);
1967       }
1968     } else {
1969       all_indices_const = false;
1970     }
1971 
1972     Value data = std::get<1>(it);
1973     RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
1974     RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
1975     if (!index_ty || !data_ty) continue;
1976 
1977     int64_t index_rank = index_ty.getRank();
1978     ArrayRef<int64_t> data_shape = data_ty.getShape();
1979     ArrayRef<int64_t> index_shape = index_ty.getShape();
1980     if (failed(mlir::verifyCompatibleShape(index_shape,
1981                                            data_shape.take_front(index_rank))))
1982       return op.emitOpError() << "requires shape of data with type " << data_ty
1983                               << " to have prefix matching with shape of the "
1984                                  "corresponding index type "
1985                               << index_ty;
1986 
1987     ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank);
1988     if (!inferred_item_shape) {
1989       inferred_item_shape = llvm::to_vector<4>(item_shape);
1990       continue;
1991     }
1992 
1993     if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape)))
1994       return op.emitOpError() << "has inconsistent shaped data and index "
1995                                  "pairs; inferred item shapes ["
1996                               << llvm::makeArrayRef(*inferred_item_shape)
1997                               << "] and [" << item_shape << "] don't match";
1998     for (int i = 0, e = item_shape.size(); i < e; ++i) {
1999       int64_t &inferred_dim = (*inferred_item_shape)[i];
2000       int64_t dim = item_shape[i];
2001       if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim;
2002     }
2003   }
2004 
2005   // If all indices are constants, then verify that they cover all indices in
2006   // the range [0, max_index] and the output type is legal.
2007   if (all_indices_const) {
2008     for (int32_t i = 0; i <= max_index; i++) {
2009       if (!index_values.count(i))
2010         return op.emitOpError() << "missing index " << i;
2011     }
2012 
2013     if (inferred_item_shape) {
2014       SmallVector<int64_t, 4> expected_shape;
2015       expected_shape.push_back(max_index + 1);
2016       expected_shape.append(inferred_item_shape->begin(),
2017                             inferred_item_shape->end());
2018 
2019       auto out_ty = op.getType().cast<TensorType>();
2020       auto expected_out_ty =
2021           RankedTensorType::get(expected_shape, out_ty.getElementType());
2022 
2023       if (!AreCastCompatible({out_ty, expected_out_ty})) {
2024         return op.emitOpError() << "has invalid output type; should be "
2025                                    "compatible with inferred type "
2026                                 << expected_out_ty;
2027       }
2028     }
2029   }
2030 
2031   return success();
2032 }
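// Worked example (added for illustration, not in the original source): with
// constant indices {0} and {1, 2} and data operands of types tensor<1x8xf32>
// and tensor<2x8xf32>, the inferred item shape is [8], max_index is 2, and
// the verifier expects the result to be compatible with tensor<3x8xf32>.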
2033 
2034 //===----------------------------------------------------------------------===//
2035 // EinsumOp
2036 //===----------------------------------------------------------------------===//
2037 
2038 // Verifies that,
2039 // * Arity of the op is at most two.
2040 //
2041 // TODO(hinsu): Verify einsum equation attribute.
2042 static LogicalResult Verify(EinsumOp op) {
2043   if (op.N() > 2) {
2044     return op.emitOpError("supports at most two operands");
2045   }
2046   return success();
2047 }
2048 
2049 //===----------------------------------------------------------------------===//
2050 // EmptyOp
2051 //===----------------------------------------------------------------------===//
2052 
2053 OpFoldResult EmptyOp::fold(ArrayRef<Attribute> operands) {
2054   assert(operands.size() == 1 && "empty op has one operand");
2055 
2056   Attribute attr = operands.front();
2057   if (!attr) return {};
2058 
2059   auto int_attr = attr.cast<DenseIntElementsAttr>();
2060   SmallVector<int64_t, 6> out_shape;
2061   for (const auto val : int_attr.getValues<int32_t>()) {
2062     out_shape.push_back(val);
2063   }
2064 
2065   auto type = getResult().getType().cast<ShapedType>();
2066   auto etype = type.getElementType();
2067 
2068   // We cannot fold if the result type does not have a static shape.
2069   if (!type.hasStaticShape()) return {};
2070 
2071   if (auto float_type = etype.dyn_cast<FloatType>()) {
2072     auto out_type = RankedTensorType::get(out_shape, float_type);
2073     return DenseElementsAttr::get(out_type,
2074                                   {APFloat(float_type.getFloatSemantics())});
2075   }
2076 
2077   if (auto int_type = etype.dyn_cast<IntegerType>()) {
2078     auto out_type = RankedTensorType::get(out_shape, etype);
2079     APInt val(int_type.getWidth(), 0, int_type.getSignedness());
2080     return DenseElementsAttr::get(out_type, val);
2081   }
2082 
2083   return {};
2084 }
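// Example (added for illustration, not in the original source): with a
// constant shape operand [2, 3] and an f32 element type, this fold produces a
// zero splat of type tensor<2x3xf32>; non-static result types are left alone.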
2085 
2086 //===----------------------------------------------------------------------===//
2087 // EmptyTensorListOp
2088 //===----------------------------------------------------------------------===//
2089 
2090 static LogicalResult Verify(EmptyTensorListOp op) {
2091   if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2092       !IsOfRankOrUnranked(op.element_shape(), 1)) {
2093     return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2094   }
2095 
2096   if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
2097     return op.emitOpError("requires max_num_elements operand to be 0D tensor");
2098   }
2099   return success();
2100 }
2101 
2102 //===----------------------------------------------------------------------===//
2103 // EqualOp
2104 //===----------------------------------------------------------------------===//
2105 
2106 static LogicalResult Verify(EqualOp op) {
2107   // If inputs are allowed to have incompatible types, there is nothing to do.
2108   if (!op.incompatible_shape_error()) return success();
2109 
2110   // Otherwise, check inputs are broadcastable.
2111   return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
2112       op.getOperation());
2113 }
2114 
2115 void EqualOp::build(OpBuilder &builder, OperationState &result, Value x,
2116                     Value y, BoolAttr incompatible_shape_error) {
2117   auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
2118                                           incompatible_shape_error);
2119   return build(builder, result, result_type, x, y, incompatible_shape_error);
2120 }
2121 
2122 //===----------------------------------------------------------------------===//
2123 // ExpandDimsOp
2124 //===----------------------------------------------------------------------===//
2125 
2126 Type InferExpandDimsOpType(Value input, Value dim) {
2127   Type element_ty = input.getType().cast<TensorType>().getElementType();
2128   auto unranked_ty = UnrankedTensorType::get(element_ty);
2129 
2130   auto input_ty = input.getType().dyn_cast<RankedTensorType>();
2131   if (!input_ty) return unranked_ty;
2132 
2133   DenseIntElementsAttr dim_attr;
2134   if (!matchPattern(dim, m_Constant(&dim_attr)) ||
2135       dim_attr.getNumElements() != 1)
2136     return unranked_ty;
2137   int64_t dim_val = (*dim_attr.begin()).getSExtValue();
2138   int64_t input_rank = input_ty.getRank();
2139 
2140   if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty;
2141   if (dim_val < 0) dim_val += input_rank + 1;
2142 
2143   SmallVector<int64_t, 4> shape = llvm::to_vector<4>(input_ty.getShape());
2144   shape.insert(shape.begin() + dim_val, 1);
2145   return RankedTensorType::get(shape, element_ty);
2146 }
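// Example (added for illustration, not in the original source): for an input
// of type tensor<2x3xf32> and a constant dim of -1, the negative dim is
// rebased to 2 and the inferred result type is tensor<2x3x1xf32>.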
2147 
2148 void ExpandDimsOp::build(OpBuilder &builder, OperationState &result,
2149                          Value input, Value dim) {
2150   return build(builder, result, InferExpandDimsOpType(input, dim), input, dim);
2151 }
2152 
2153 //===----------------------------------------------------------------------===//
2154 // FakeQuantWithMinMaxArgsOp
2155 //===----------------------------------------------------------------------===//
2156 static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
2157   // TODO(fengliuai): move the following to a utility method.
2158   const llvm::fltSemantics &semantics = op.min().getSemantics();
2159   float rmin, rmax;
2160   if (&semantics == &APFloat::IEEEsingle()) {
2161     rmin = op.min().convertToFloat();
2162     rmax = op.max().convertToFloat();
2163   } else {
2164     rmin = op.min().convertToDouble();
2165     rmax = op.max().convertToDouble();
2166   }
2167   // Range boundaries must be valid.
2168   if (rmin >= rmax) {
2169     return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
2170                           "," + Twine(std::to_string(rmax)) + "]");
2171   }
2172   int64_t num_bits = op.num_bits();
2173   if (num_bits < 2 || num_bits > 16) {
2174     return op.emitOpError(
2175         "requires num_bits to be between 2 and 16, inclusive");
2176   }
2177   return success();
2178 }
2179 
2180 //===----------------------------------------------------------------------===//
2181 // FakeQuantWithMinMaxVarsOp
2182 //===----------------------------------------------------------------------===//
2183 static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
2184   auto min = GetRankedTensorTypeForOperand(op.min());
2185   if (min && !IsOfRankedFloatTensorType(min, 0))
2186     return op.emitOpError("requires min to be a 0d float tensor");
2187 
2188   auto max = GetRankedTensorTypeForOperand(op.max());
2189   if (max && !IsOfRankedFloatTensorType(max, 0))
2190     return op.emitOpError("requires max to be a 0d float tensor");
2191 
2192   int64_t num_bits = op.num_bits();
2193   if (num_bits < 2 || num_bits > 16) {
2194     return op.emitOpError(
2195         "requires num_bits to be between 2 and 16, inclusive");
2196   }
2197   return success();
2198 }
2199 
2200 //===----------------------------------------------------------------------===//
2201 // FakeQuantWithMinMaxVarsPerChannelOp
2202 //===----------------------------------------------------------------------===//
2203 static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
2204   auto min = GetRankedTensorTypeForOperand(op.min());
2205   if (min && !IsOfRankedFloatTensorType(min, 1))
2206     return op.emitOpError("requires min to be a 1d float tensor");
2207 
2208   auto max = GetRankedTensorTypeForOperand(op.max());
2209   if (max && !IsOfRankedFloatTensorType(max, 1))
2210     return op.emitOpError("requires max to be a 1d float tensor");
2211 
2212   Value inputs = op.inputs();
2213   if (!HasRankAtLeast(inputs, 1))
2214     return op.emitError("requires inputs to be at least 1d float tensor");
2215 
2216   int64_t num_bits = op.num_bits();
2217   if (num_bits < 2 || num_bits > 16) {
2218     return op.emitOpError(
2219         "requires num_bits to be between 2 and 16, inclusive");
2220   }
2221 
2222   auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>();
2223   if (!inputs_type) return success();
2224   int depth = inputs_type.getDimSize(inputs_type.getRank() - 1);
2225   if ((min && min.getDimSize(0) != depth) ||
2226       (max && max.getDimSize(0) != depth)) {
2227     return op.emitOpError(
2228         "requires min and max to have same size as last dimension of inputs");
2229   }
2230 
2231   return success();
2232 }
2233 
2234 //===----------------------------------------------------------------------===//
2235 // FillOp
2236 //===----------------------------------------------------------------------===//
2237 
2238 static LogicalResult Verify(FillOp op) {
2239   if (!IsOfRankOrUnranked(op.dims(), 1))
2240     return op.emitOpError() << "requires dims to be a 1D tensor";
2241   if (!IsOfRankOrUnranked(op.value(), 0))
2242     return op.emitOpError() << "requires value to be a scalar";
2243 
2244   return success();
2245 }
2246 
2247 static ShapedType InferFillOpType(Value dims, Value value) {
2248   Type etype = value.getType().cast<ShapedType>().getElementType();
2249 
2250   DenseIntElementsAttr dims_attr;
2251   if (!matchPattern(dims, m_Constant(&dims_attr))) {
2252     return UnrankedTensorType::get(etype);
2253   }
2254 
2255   llvm::SmallVector<int64_t, 4> shape;
2256   shape.reserve(dims_attr.getNumElements());
2257   for (const APInt dim : dims_attr.getValues<APInt>()) {
2258     shape.push_back(dim.getSExtValue());
2259   }
2260   return RankedTensorType::get(shape, etype);
2261 }
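// Example (added for illustration, not in the original source): with a
// constant dims operand [2, 3] and a value of type tensor<f32>, the inferred
// fill type is tensor<2x3xf32>; without a constant dims operand the result
// type stays unranked.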
2262 
2263 void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
2264                    Value value) {
2265   FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
2266 }
2267 
2268 OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
2269   assert(operands.size() == 2 && "fill op has two operands");
2270 
2271   auto type = getType().cast<ShapedType>();
2272   // DenseElementsAttr that is used in this folder only supports int and float
2273   // types.
2274   // TODO(hinsu): Handle complex types once there is an attribute kind for
2275   // complex.
2276   if (!type.getElementType().isIntOrFloat()) return {};
2277 
2278   auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
2279   if (!value) return {};
2280 
2281   if (type.hasStaticShape())
2282     return DenseElementsAttr::get(type, value.getValue({}));
2283 
2284   auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2285   if (!dims) return {};
2286 
2287   llvm::SmallVector<int64_t, 4> shape;
2288   shape.reserve(dims.getNumElements());
2289   for (const APInt dim : dims.getValues<APInt>()) {
2290     shape.push_back(dim.getSExtValue());
2291   }
2292   type = RankedTensorType::get(shape, type.getElementType());
2293 
2294   return DenseElementsAttr::get(type, value.getValue({}));
2295 }
2296 
2297 //===----------------------------------------------------------------------===//
2298 // FusedBatchNormGradOp
2299 //===----------------------------------------------------------------------===//
2300 
2301 // TODO(b/150954845): Add benchmarks to verify that layout preference didn't
2302 // change in the latest GPU generations.
2303 
2304 LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
2305   return ::mlir::TF::UpdateDataFormat(data_format, this);
2306 }
2307 
2308 StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
2309     const RuntimeDevices &devices) {
2310   // Keep the current data format if no GPUs are available or if explicit
2311   // placement does not allow using the GPU for this operation.
2312   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
2313     return data_format();
2314 
2315   // For the f16 data type on devices with Tensor Core support, the NHWC data
2316   // format is up to ~2x faster.
2317   auto x_ty = x().getType().cast<TensorType>();
2318   const bool is_f16 = x_ty.getElementType().isF16();
2319   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2320 
2321   // For all other data types prefer NCHW.
2322   return "NCHW";
2323 }
2324 
2325 //===----------------------------------------------------------------------===//
2326 // FusedBatchNormOp
2327 //===----------------------------------------------------------------------===//
2328 
2329 static LogicalResult Verify(FusedBatchNormOp op) {
2330   auto x = GetRankedTensorTypeForOperand(op.x());
2331   if (x && !IsOfRankedFloatTensorType(x, 4))
2332     return op.emitOpError("requires x to be a 4D float tensor");
2333 
2334   auto scale = GetRankedTensorTypeForOperand(op.scale());
2335   if (scale && !IsOfRankedFloatTensorType(scale, 1))
2336     return op.emitOpError("requires scale to be a 1D float tensor");
2337 
2338   auto offset = GetRankedTensorTypeForOperand(op.offset());
2339   if (offset && !IsOfRankedFloatTensorType(offset, 1))
2340     return op.emitOpError("requires offset to be a 1D float tensor");
2341 
2342   auto mean = GetRankedTensorTypeForOperand(op.mean());
2343   if (mean && !IsOfRankedFloatTensorType(mean, 1))
2344     return op.emitOpError("requires mean to be a 1D float tensor");
2345 
2346   auto variance = GetRankedTensorTypeForOperand(op.variance());
2347   if (variance && !IsOfRankedFloatTensorType(variance, 1))
2348     return op.emitOpError("requires variance to be a 1D float tensor");
2349 
2350   // TODO(antiagainst): check attributes
2351 
2352   return success();
2353 }
2354 
2355 //===----------------------------------------------------------------------===//
2356 // FusedBatchNormV2Op / FusedBatchNormV3Op
2357 //===----------------------------------------------------------------------===//
2358 
2359 template <class Op>
2360 static LogicalResult InferenceFoldOperandsPermutation(
2361     ArrayRef<int64_t> permutation, Op *op) {
2362   // FusedBatchNorm in training mode is a layout sensitive operation, and should
2363   // already have been assigned an optimal data format.
2364   if (op->is_training()) return failure();
2365   return ::mlir::TF::FoldOperandsPermutation(permutation, op);
2366 }
2367 
2368 template <class Op>
2369 static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
2370   // In inference mode FusedBatchNorm is not sensitive to data layout.
2371   if (!op->is_training()) return op->data_format();
2372 
2373   // Keep the current data format if no GPUs are available or if explicit
2374   // placement does not allow using the GPU for this operation.
2375   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
2376     return op->data_format();
2377 
2378   // For the f16 data type on devices with Tensor Core support, the NHWC data
2379   // format is up to ~2x faster.
2380   auto x_ty = op->x().getType().template cast<TensorType>();
2381   const bool is_f16 = x_ty.getElementType().isF16();
2382   if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2383 
2384   // For all other data types prefer NCHW.
2385   return "NCHW";
2386 }
2387 
2388 LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
2389     ArrayRef<int64_t> permutation) {
2390   return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2391 }
2392 
2393 LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
2394   return ::mlir::TF::UpdateDataFormat(data_format, this);
2395 }
2396 
2397 StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
2398   return ::mlir::TF::GetOptimalLayout(devices, this);
2399 }
2400 
2401 LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
2402     ArrayRef<int64_t> permutation) {
2403   return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2404 }
2405 
2406 LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
2407   return ::mlir::TF::UpdateDataFormat(data_format, this);
2408 }
2409 
2410 StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
2411   return ::mlir::TF::GetOptimalLayout(devices, this);
2412 }
2413 
2414 //===----------------------------------------------------------------------===//
2415 // GatherV2Op
2416 //===----------------------------------------------------------------------===//
2417 
2418 static LogicalResult Verify(GatherV2Op op) {
2419   int64_t batch_dims = op.batch_dims();
2420   if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
2421     int64_t rank = ty.getRank();
2422     if (batch_dims > rank || batch_dims < -rank)
2423       return op.emitOpError()
2424              << "batch_dims (" << batch_dims << ") must be in range [" << -rank
2425              << ", " << rank + 1 << ")";
2426     if (batch_dims < 0) batch_dims += rank;
2427   }
2428 
2429   if (!HasRankAtMost(op.axis(), 1))
2430     return op.emitOpError("requires axis to have rank at most 1");
2431 
2432   DenseIntElementsAttr axis_attr;
2433   if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
2434     int64_t axis = (*axis_attr.begin()).getSExtValue();
2435     if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
2436       int64_t rank = ty.getRank();
2437       if (axis >= rank || axis < -rank)
2438         return op.emitOpError() << "axis (" << axis << ") must be in range ["
2439                                 << -rank << ", " << rank << ")";
2440       if (axis < 0) axis += rank;
2441     }
2442 
2443     if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) {
2444       return op.emitOpError() << "requires axis (" << axis
2445                               << ") to be greater than or equal to batch_dims ("
2446                               << batch_dims << ")";
2447     }
2448   }
2449   return success();
2450 }
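
// Worked example of the checks above (hypothetical shapes): for indices of
// rank 3, batch_dims must lie in [-3, 3]; a negative value such as -1 is
// normalized to 2. If axis is a constant -2 and params has rank 4, axis is
// normalized to 2, which satisfies the axis >= batch_dims requirement only
// when batch_dims <= 2.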
2451 
2452 //===----------------------------------------------------------------------===//
2453 // IfOp
2454 //===----------------------------------------------------------------------===//
2455 
2456 static LogicalResult Verify(IfOp op) {
2457   auto branch_name = [](unsigned index) -> std::string {
2458     return index == 0 ? "'then_branch'" : "'else_branch'";
2459   };
2460   return VerifyCaseOrIfOpBranchFunctions(
2461       op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name);
2462 }
2463 
2464 //===----------------------------------------------------------------------===//
2465 // IfOp canonicalization.
2466 //===----------------------------------------------------------------------===//
2467 
2468 namespace {
2469 class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
2470  public:
2471   explicit FoldConstantIfOp(MLIRContext *context)
2472       : OpRewritePattern<TF::IfOp>(context) {}
2473   LogicalResult matchAndRewrite(TF::IfOp op,
2474                                 PatternRewriter &rewriter) const override;
2475 
2476  private:
2477   template <typename T>
2478   struct CallOpType {
2479     using CallOp = T;
2480   };
2481 };
2482 
2483 LogicalResult FoldConstantIfOp::matchAndRewrite(
2484     TF::IfOp op, PatternRewriter &rewriter) const {
2485   // Extract the constant cond value.
2486   DenseIntElementsAttr cond_attr;
2487   if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2488 
2489   // Cond value must be a scalar.
2490   if (cond_attr.getNumElements() != 1) return failure();
2491 
2492   // Select a branch function.
2493   bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
2494   FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
2495 
2496   // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
2497   auto rewrite = [&](auto op_type) {
2498     auto empty = rewriter.getStringAttr("");
2499     auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
2500         op.getLoc(), op.getResultTypes(), op.input(), func,
2501         /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
2502     CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
2503     rewriter.replaceOp(op, call_op.getResults());
2504   };
2505 
2506   if (op.is_stateless())
2507     rewrite(CallOpType<PartitionedCallOp>{});
2508   else
2509     rewrite(CallOpType<StatefulPartitionedCallOp>{});
2510 
2511   return success();
2512 }
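
// Sketch of the rewrite above on hypothetical IR (names and types are
// illustrative only): a stateless tf.If with a constant true condition
//
//   %r = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else,
//                              is_stateless = true}
//       : (tensor<i1>, tensor<f32>) -> tensor<f32>
//
// becomes a direct call to the selected branch
//
//   %r = "tf.PartitionedCall"(%arg) {f = @then, config = "",
//                                    config_proto = "", executor_type = ""}
//       : (tensor<f32>) -> tensor<f32>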
2513 }  // anonymous namespace
2514 
2515 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2516                                        MLIRContext *context) {
2517   results.insert<FoldConstantIfOp, DropAttributes<IfOp>>(context);
2518 }
2519 
2520 //===----------------------------------------------------------------------===//
2521 // IfRegionOp
2522 //===----------------------------------------------------------------------===//
2523 
2524 static LogicalResult Verify(IfRegionOp op) {
2525   TypeRange then_types =
2526       op.then_branch().front().getTerminator()->getOperandTypes();
2527   TypeRange else_types =
2528       op.else_branch().front().getTerminator()->getOperandTypes();
2529 
2530   TypeRangeWithDesc results{op.getResultTypes(), "result"};
2531   TypeRangeWithDesc then_results{then_types, "then result"};
2532   TypeRangeWithDesc else_results{else_types, "else result"};
2533 
2534   if (failed(VerifyTypeRangesAreCompatible(op, then_results, results)))
2535     return failure();
2536   if (failed(VerifyTypeRangesAreCompatible(op, else_results, results)))
2537     return failure();
2538   return success();
2539 }
2540 
2541 namespace {
2542 class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
2543  public:
2544   explicit FoldConstantIfRegionOp(MLIRContext *context)
2545       : OpRewritePattern<TF::IfRegionOp>(context) {}
2546   LogicalResult matchAndRewrite(TF::IfRegionOp op,
2547                                 PatternRewriter &rewriter) const override;
2548 };
2549 
2550 LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
2551     TF::IfRegionOp op, PatternRewriter &rewriter) const {
2552   // Extract the constant cond value.
2553   DenseIntElementsAttr cond_attr;
2554   if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2555 
2556   // IfRegion condition should always be a scalar. Select the region to fold to.
2557   bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
2558   Region &region = cond ? op.then_branch() : op.else_branch();
2559 
2560   // If the IfRegion is stateless but the region being inlined itself is not
2561   // stateless, then inlining the region could cause a loss of information.
2562   // However, it is probably better to fold the IfRegion than to keep the dead
2563   // branch around.
2564 
2565   // Inline the region in place of the IfRegion op, and forward the yield
2566   // inputs to the IfRegion op results. This is possible only if the yield
2567   // types match the result types.
2568   auto yield = cast<YieldOp>(region.front().getTerminator());
2569   auto updated_results = llvm::to_vector<4>(yield.getOperands());
2570 
2571   // If the yield types do not match the IfRegion result types, add appropriate
2572   // casts.
2573   rewriter.setInsertionPoint(yield);
2574   for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
2575     auto &updated_result = std::get<1>(it);
2576     Type result_type = std::get<0>(it);
2577     if (result_type != updated_result.getType()) {
2578       updated_result =
2579           rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
2580                                       /*Truncate=*/rewriter.getBoolAttr(false));
2581     }
2582   }
2583   // Inline the region into the block containing the IfRegion.
2584   rewriter.mergeBlockBefore(&region.front(), op);
2585   rewriter.eraseOp(yield);
2586   rewriter.replaceOp(op, updated_results);
2587   return success();
2588 }
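
// Sketch of the region folding above on hypothetical IR: with a constant true
// condition, the then region of
//
//   %r = "tf.IfRegion"(%cond) ({ ... "tf.Yield"(%t) }, { ... "tf.Yield"(%e) })
//
// is inlined in front of the op, a tf.Cast is inserted when the yielded type
// differs from the result type, and %r is replaced by the (possibly cast)
// value of %t.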
2589 }  // anonymous namespace
2590 
2591 void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2592                                              MLIRContext *context) {
2593   results.insert<FoldConstantIfRegionOp,
2594                  CaseOrIfRegionEliminatePassThrough<TF::IfRegionOp>>(context);
2595 }
2596 
2597 //===----------------------------------------------------------------------===//
2598 // InvertPermutationOp
2599 //===----------------------------------------------------------------------===//
2600 
2601 // Verifies that the input is 1D.
2602 static LogicalResult Verify(InvertPermutationOp op) {
2603   auto x_type = op.x().getType().cast<TensorType>();
2604   if (!x_type.hasRank()) return success();
2605   if (x_type.getShape().size() != 1)
2606     return op.emitOpError() << "requires input x to be 1-dimensional";
2607 
2608   return success();
2609 }
2610 
2611 //===----------------------------------------------------------------------===//
2612 // LeakyReluOp
2613 //===----------------------------------------------------------------------===//
2614 
2615 OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
2616   assert(operands.size() == 1 && "leaky relu has one operand");
2617 
2618   // leaky_relu(x, alpha: 1) -> x
2619   if (alpha().convertToFloat() == 1.0f) return getOperand();
2620 
2621   auto calculate = [&](FloatAttr arg) {
2622     APFloat val = arg.getValue();
2623     if (val.isNegative()) val = alpha() * val;
2624     return FloatAttr::get(arg.getType(), val);
2625   };
2626 
2627   if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
2628     return calculate(arg);
2629   } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
2630     if (auto elementAttr = arg.getSplatValue().dyn_cast<FloatAttr>())
2631       return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
2632   }
2633   return {};
2634 }
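
// Worked example of the folding above (hypothetical constants): with
// alpha = 0.5, a splat input of -2.0 folds to a splat of -1.0, a splat of 3.0
// folds to itself, and any input folds directly to the operand when
// alpha == 1.0.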
2635 
2636 Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
2637   // Only f32 is supported for fusion.
2638   if (!T().isF32()) return None;
2639 
2640   NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
2641   return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
2642                            /*additional_attributes=*/{alpha});
2643 }
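
// The descriptor above is a hint for contraction fusion: it names the
// "LeakyRelu" activation and forwards its "alpha" attribute so that, assuming
// a later fusion pass, a preceding contraction (e.g. a convolution or matmul)
// can absorb the activation instead of materializing an intermediate tensor.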
2644 
2645 //===----------------------------------------------------------------------===//
2646 // LogOp
2647 //===----------------------------------------------------------------------===//
2648 
2649 void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2650                                         MLIRContext *context) {
2651   results.insert<LogOfSoftmax, LogToLog1p>(context);
2652 }
2653 
2654 //===----------------------------------------------------------------------===//
2655 // LogicalNotOp
2656 //===----------------------------------------------------------------------===//
2657 
2658 void LogicalNotOp::getCanonicalizationPatterns(
2659     OwningRewritePatternList &results, MLIRContext *context) {
2660   results.insert<LogicalNotOfEqual, LogicalNotOfNotEqual, LogicalNotOfGreater,
2661                  LogicalNotOfGreaterEqual, LogicalNotOfLess,
2662                  LogicalNotOfLessEqual>(context);
2663 }
2664 
2665 //===----------------------------------------------------------------------===//
2666 // MatrixBandPartOp
2667 //===----------------------------------------------------------------------===//
2668 
2669 static LogicalResult Verify(MatrixBandPartOp op) {
2670   if (!HasRankAtLeast(op.input(), 2)) {
2671     return op.emitOpError()
2672            << "requires `input` to have rank of at least 2, but found "
2673            << op.input().getType();
2674   }
2675   if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
2676     return op.emitOpError()
2677            << "requires `num_lower` to have 0 dimensions, but found "
2678            << op.num_lower().getType();
2679   }
2680   if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
2681     return op.emitOpError()
2682            << "requires `num_upper` to have 0 dimensions, but found "
2683            << op.num_upper().getType();
2684   }
2685   return success();
2686 }
2687 
2688 //===----------------------------------------------------------------------===//
2689 // MatrixSetDiagOp
2690 //===----------------------------------------------------------------------===//
2691 //
2692 void MatrixSetDiagOp::getCanonicalizationPatterns(
2693     OwningRewritePatternList &results, MLIRContext *context) {
2694   results.insert<MatrixSetDiagToV3>(context);
2695 }
2696 
2697 //===----------------------------------------------------------------------===//
2698 // MatrixSetDiagV2Op
2699 //===----------------------------------------------------------------------===//
2700 
2701 void MatrixSetDiagV2Op::getCanonicalizationPatterns(
2702     OwningRewritePatternList &results, MLIRContext *context) {
2703   results.insert<MatrixSetDiagV2ToV3>(context);
2704 }
2705 
2706 //===----------------------------------------------------------------------===//
2707 // MaxOp
2708 //===----------------------------------------------------------------------===//
2709 
2710 void MaxOp::build(OpBuilder &builder, OperationState &result, Value input,
2711                   Value reduction_indices, BoolAttr keep_dims) {
2712   Type out_ty =
2713       InferReductionOpType(input, reduction_indices, keep_dims, &builder);
2714   build(builder, result, out_ty, input, reduction_indices, keep_dims);
2715 }
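
// Example of the inferred result type (hypothetical shapes): reducing a
// tensor<2x8x16xf32> input over constant reduction_indices [1] yields
// tensor<2x16xf32> with keep_dims = false and tensor<2x1x16xf32> with
// keep_dims = true.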
2716 
2717 //===----------------------------------------------------------------------===//
2718 // MaxPoolOp
2719 //===----------------------------------------------------------------------===//
2720 
2721 LogicalResult MaxPoolOp::FoldOperandsPermutation(
2722     ArrayRef<int64_t> permutation) {
2723   return ::mlir::TF::FoldOperandsPermutation(
2724       permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
2725 }
2726 
2727 LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
2728   StringRef src_data_format = data_format();
2729 
2730   auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
2731   if (perm.empty()) return failure();
2732 
2733   // Update data_format attribute and result types.
2734   if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
2735     return failure();
2736 
2737   stridesAttr(ShuffleArrayAttr(strides(), perm));
2738   explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
2739   ksizeAttr(ShuffleArrayAttr(ksize(), perm));
2740 
2741   return success();
2742 }
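
// Example of the attribute shuffling above, assuming the NHWC -> NCHW
// permutation is {0, 3, 1, 2}: ksize [1, 3, 3, 1] and strides [1, 2, 2, 1]
// become ksize [1, 1, 3, 3] and strides [1, 1, 2, 2].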
2743 
2744 StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
2745   // Keep the current data format if no GPUs are available or if explicit
2746   // placement does not allow using a GPU for this operation.
2747   if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
2748     return data_format();
2749 
2750   // Defaults to NCHW.
2751   return "NCHW";
2752 }
2753 
2754 //===----------------------------------------------------------------------===//
2755 // MaxPoolGradOp
2756 //===----------------------------------------------------------------------===//
2757 
2758 static LogicalResult Verify(MaxPoolGradOp op) {
2759   if (!IsOfRankOrUnranked(op.orig_input(), 4)) {
2760     return op.emitOpError() << "requires orig_input to be rank 4";
2761   }
2762   if (!IsOfRankOrUnranked(op.orig_output(), 4)) {
2763     return op.emitOpError() << "requires orig_output to be rank 4";
2764   }
2765   if (!IsOfRankOrUnranked(op.grad(), 4)) {
2766     return op.emitOpError() << "requires grad to be rank 4";
2767   }
2768   return success();
2769 }
2770 
2771 //===----------------------------------------------------------------------===//
2772 // MeanOp
2773 //===----------------------------------------------------------------------===//
2774 
2775 LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
2776   // Reduction indices must be defined by a constant operation.
2777   auto reduction_op =
2778       dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
2779   if (!reduction_op) return failure();
2780 
2781   auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
2782   if (!reductions_value) return failure();
2783 
2784   // Prepare new reduction indices according to operand permutation.
2785   SmallVector<int32_t, 4> shuffled_reduction;
2786   llvm::transform(reductions_value.getIntValues(),
2787                   std::back_inserter(shuffled_reduction),
2788                   [&](APInt idx) { return permutation[idx.getSExtValue()]; });
2789 
2790   // Add a constant operation with the new reduction indices.
2791   OpBuilder builder(getOperation());
2792   auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
2793                                           builder.getIntegerType(32));
2794   auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
2795   auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
2796 
2797   // Use new reduction indices.
2798   setOperand(1, shuffled_reduction_op);
2799 
2800   return success();
2801 }
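
// Example of the rewrite above (hypothetical permutation): with a layout
// permutation of {0, 3, 1, 2}, constant reduction_indices [1, 2] are replaced
// by a new constant [3, 1], i.e. each index i becomes permutation[i].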
2802 
2803 //===----------------------------------------------------------------------===//
2804 // MulOp
2805 //===----------------------------------------------------------------------===//
2806 
2807 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
2808   return IdentityArithmeticOpFolder<MulOp>(*this, operands);
2809 }
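
// The identity folder above removes multiplications by the multiplicative
// identity when operand and result types agree, e.g. (sketch) tf.Mul(%x, 1.0)
// folds to %x.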
2810 
2811 }  // namespace TF
2812 }  // namespace mlir
2813 
2814 //===----------------------------------------------------------------------===//
2815 // TableGen'd op method definitions
2816 //===----------------------------------------------------------------------===//
2817 
2818 #define GET_OP_CLASSES
2819 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc"
2820