1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <functional>
21 #include <iterator>
22 #include <limits>
23 #include <numeric>
24 #include <string>
25 #include <tuple>
26 #include <type_traits>
27
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/Sequence.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/ADT/StringSwitch.h"
38 #include "llvm/ADT/iterator_range.h"
39 #include "llvm/Support/Casting.h"
40 #include "llvm/Support/FormatVariadic.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
43 #include "mlir/Dialect/Traits.h" // from @llvm-project
44 #include "mlir/IR/Attributes.h" // from @llvm-project
45 #include "mlir/IR/Builders.h" // from @llvm-project
46 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
47 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
48 #include "mlir/IR/Diagnostics.h" // from @llvm-project
49 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
50 #include "mlir/IR/Identifier.h" // from @llvm-project
51 #include "mlir/IR/Location.h" // from @llvm-project
52 #include "mlir/IR/MLIRContext.h" // from @llvm-project
53 #include "mlir/IR/Matchers.h" // from @llvm-project
54 #include "mlir/IR/OpDefinition.h" // from @llvm-project
55 #include "mlir/IR/OpImplementation.h" // from @llvm-project
56 #include "mlir/IR/PatternMatch.h" // from @llvm-project
57 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
58 #include "mlir/IR/Types.h" // from @llvm-project
59 #include "mlir/IR/Value.h" // from @llvm-project
60 #include "mlir/Parser.h" // from @llvm-project
61 #include "mlir/Support/LLVM.h" // from @llvm-project
62 #include "mlir/Support/LogicalResult.h" // from @llvm-project
63 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
65 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
66 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
69 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
70 #include "tensorflow/core/framework/kernel_shape_util.h"
71 #include "tensorflow/core/platform/logging.h"
72 #include "tensorflow/core/util/padding.h"
73 #include "tensorflow/core/util/tensor_format.h"
74
75 namespace mlir {
76 namespace TF {
77
78 namespace {
79 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc"
80 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
81 } // namespace
82
83 //===----------------------------------------------------------------------===//
84 // AddOp
85 //===----------------------------------------------------------------------===//
86
87 void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
88 MLIRContext *context) {
89 results.insert<AddToAddV2>(context);
90 }
91
92 //===----------------------------------------------------------------------===//
93 // AddNOp
94 //===----------------------------------------------------------------------===//
95
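// AddN with a single operand is an identity; the fold below simply returns
// that operand. Illustrative example: %sum = "tf.AddN"(%x) folds to %x.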
96 OpFoldResult AddNOp::fold(ArrayRef<Attribute> operands) {
97 if (operands.size() == 1) return *inputs().begin();
98 return {};
99 }
100
101 //===----------------------------------------------------------------------===//
102 // AddV2Op
103 //===----------------------------------------------------------------------===//
104
105 void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
106 MLIRContext *context) {
107 results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
108 }
109
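// The fold below delegates to IdentityArithmeticOpFolder which, as its name
// suggests, is assumed to fold additions with an identity operand, e.g.
// (illustrative) "tf.AddV2"(%x, %zero_cst) -> %x when the types allow it.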
110 OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
111 return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
112 }
113
114 //===----------------------------------------------------------------------===//
115 // AllOp
116 //===----------------------------------------------------------------------===//
117
118 static LogicalResult Verify(AllOp op) {
119 return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
120 op.getLoc());
121 }
122
123 //===----------------------------------------------------------------------===//
124 // AnyOp
125 //===----------------------------------------------------------------------===//
126
127 static LogicalResult Verify(AnyOp op) {
128 return VerifyReductionInputAndDims(op.input(), op.reduction_indices(),
129 op.getLoc());
130 }
131
132 //===----------------------------------------------------------------------===//
133 // AssertOp
134 //===----------------------------------------------------------------------===//
135
136 namespace {
137
138 // Removes Assert with constant true predicate.
139 struct AssertWithTrue : public OpRewritePattern<AssertOp> {
140 using OpRewritePattern<AssertOp>::OpRewritePattern;
141
142 LogicalResult matchAndRewrite(AssertOp op,
143 PatternRewriter &rewriter) const override {
144 ElementsAttr cst;
145 if (matchPattern(op.condition(), m_Constant(&cst))) {
146 if (cst.getValue<BoolAttr>({}).getValue()) {
147 rewriter.eraseOp(op);
148 return success();
149 }
150 }
151 return failure();
152 }
153 };
154 } // namespace
155
156 void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
157 MLIRContext *context) {
158 results.insert<AssertWithTrue>(context);
159 }
160
161 //===----------------------------------------------------------------------===//
162 // BatchMatMulV2Op & BatchMatMulOp
163 //===----------------------------------------------------------------------===//
164
165 template <typename OpT,
166 typename std::enable_if<llvm::is_one_of<
167 OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
168 static LogicalResult Verify(OpT op) {
169 if (!HasRankAtLeast(op.x(), 2)) {
170 return op.emitOpError("requires lhs operand to have rank at least two");
171 }
172 if (!HasRankAtLeast(op.y(), 2)) {
173 return op.emitOpError("requires rhs operand to have rank at least two");
174 }
175
176 RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
177 RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());
178
179 if (!x_ty || !y_ty) return success();
180
181 ArrayRef<int64_t> x_shape = x_ty.getShape();
182 ArrayRef<int64_t> y_shape = y_ty.getShape();
183
184 llvm::SmallVector<int64_t, 4> result_batch_shape;
185 llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
186 llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
187
188 // Check compatibility of batch dimensions if both input shapes are known.
189 // BatchMatMul should have exactly the same batch dimensions and
190 // BatchMatMulV2 should have broadcastable batch dimensions.
191 //
192 // The last two dimensions are non-batch dimensions that don't need to
193 // participate in batch dimension compatibility check.
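// Illustrative example: for BatchMatMulV2, lhs tensor<5x1x3x2x4xf32> and rhs
// tensor<3x4x6xf32> have broadcastable batch dims [5, 1, 3] and [3], so the
// output is tensor<5x1x3x2x6xf32>; the same operands would be rejected for
// BatchMatMul, which requires identical batch dims.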
194 if (std::is_same<OpT, BatchMatMulOp>()) {
195 for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
196 int64_t x_dim = std::get<0>(dim_pairs);
197 int64_t y_dim = std::get<1>(dim_pairs);
198 if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
199 x_dim != y_dim) {
200 return op.emitOpError()
201 << "found mismatching batch dimensions for lhs shape " << x_ty
202 << " and rhs shape " << y_ty;
203 }
204 }
205 } else {
206 if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
207 result_batch_shape))
208 return op.emitOpError()
209 << "found incompatible broadcast batch dimensions for lhs shape "
210 << x_ty << " and rhs shape " << y_ty;
211 }
212
213 RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
214 if (!output_ty) return success();
215
216 int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
217 if (output_ty.getRank() != expected_output_rank)
218 return op.emitOpError()
219 << "found invalid output rank, expected " << expected_output_rank
220 << " but got " << output_ty.getRank();
221
222 // Check output batch dim with potential broadcasting.
223 ArrayRef<int64_t> output_shape = output_ty.getShape();
224 for (int i = 0; i < result_batch_shape.size(); ++i) {
225 if (output_shape[i] != ShapedType::kDynamicSize &&
226 output_shape[i] != result_batch_shape[i])
227 return op.emitOpError()
228 << "has mismatching input batch dimension "
229 << result_batch_shape[i] << " and output batch dimension "
230 << output_shape[i];
231 }
232
233 // Check the output shape for non-batch dimensions, following the documentation below.
234 // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
235 int64_t x_row_dim = x_shape[x_shape.size() - 2];
236 int64_t x_col_dim = x_shape[x_shape.size() - 1];
237 int64_t y_row_dim = y_shape[y_shape.size() - 2];
238 int64_t y_col_dim = y_shape[y_shape.size() - 1];
239 int64_t out_row_dim = output_shape[output_shape.size() - 2];
240 int64_t out_col_dim = output_shape[output_shape.size() - 1];
241
242 int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
243 int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;
244
245 if (expected_out_row_dim != ShapedType::kDynamicSize &&
246 out_row_dim != ShapedType::kDynamicSize &&
247 out_row_dim != expected_out_row_dim)
248 return op.emitOpError()
249 << "found invalid output dimension on row, expected "
250 << expected_out_row_dim << " but got " << out_row_dim;
251 if (expected_out_col_dim != ShapedType::kDynamicSize &&
252 out_col_dim != ShapedType::kDynamicSize &&
253 out_col_dim != expected_out_col_dim)
254 return op.emitOpError()
255 << "found invalid output dimension on col, expected "
256 << expected_out_col_dim << " but got " << out_col_dim;
257
258 return success();
259 }
260
261 void BatchMatMulOp::getCanonicalizationPatterns(
262 OwningRewritePatternList &results, MLIRContext *context) {
263 results.insert<BatchMatMulToV2>(context);
264 }
265
266 void BatchMatMulV2Op::getCanonicalizationPatterns(
267 OwningRewritePatternList &results, MLIRContext *context) {
268 results.insert<BatchMatMulV2ToMatMul>(context);
269 }
270
271 //===----------------------------------------------------------------------===//
272 // BatchToSpaceOp
273 //===----------------------------------------------------------------------===//
274
275 static LogicalResult Verify(BatchToSpaceOp op) {
276 // Op already has a constraint that block_size >= 2.
277 int64_t block_size = op.block_size();
278
279 llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
280 auto input_type = op.input().getType().cast<TensorType>();
281 if (input_type.hasRank()) {
282 if (input_type.getRank() != 4)
283 return op.emitOpError()
284 << "requires input to be a 4D tensor, but got " << input_type;
285
286 int64_t input_batch = input_type.getDimSize(0);
287 if (input_batch != ShapedType::kDynamicSize &&
288 input_batch % (block_size * block_size) != 0) {
289 return op.emitOpError()
290 << "requires input batch (dimension 0) to be evenly divisible "
291 "by (block_size * block_size), but got input batch "
292 << input_batch << " and block_size " << block_size;
293 }
294
295 input_shape.assign(input_type.getShape().begin(),
296 input_type.getShape().end());
297 }
298
299 auto crops_type = op.crops().getType().cast<TensorType>();
300 if (crops_type.hasRank()) {
301 if (crops_type.getRank() != 2)
302 return op.emitOpError()
303 << "requires crops to be a 2D tensor, but got " << crops_type;
304
305 auto dim_of_size = [&](int64_t dim, int64_t size) {
306 if (crops_type.isDynamicDim(dim)) return true;
307 return crops_type.getDimSize(dim) == size;
308 };
309 if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
310 return op.emitOpError()
311 << "requires crops to be a tensor<2x2>, but got " << crops_type;
312 }
313
314 DenseIntElementsAttr crops_attr;
315 // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
316 // and flattened as [crop_top, crop_bottom, crop_left, crop_right]
317 llvm::SmallVector<int64_t, 4> crops_values;
318 if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
319 assert(crops_attr.getNumElements() == 4 &&
320 "tf.BatchToSpace crops must have 4 elements");
321
322 auto crops_range = crops_attr.getIntValues();
323 for (const auto &crops_value : crops_range) {
324 int64_t crops_value_int = crops_value.getSExtValue();
325 if (crops_value_int < 0)
326 return op.emitOpError()
327 << "requires all crop values to be nonnegative, but got "
328 << crops_attr;
329
330 crops_values.push_back(crops_value_int);
331 }
332 }
333
334 auto output_type = op.output().getType().cast<TensorType>();
335 if (output_type.hasRank()) {
336 if (output_type.getRank() != 4)
337 return op.emitOpError()
338 << "requires output to be a 4D tensor, but got " << output_type;
339
340 auto static_dims = [](int64_t dim_a, int64_t dim_b) {
341 return dim_a != ShapedType::kDynamicSize &&
342 dim_b != ShapedType::kDynamicSize;
343 };
344
345 auto output_shape = output_type.getShape();
346
347 // output batch = input batch / (block_size * block_size).
348 int64_t input_batch = input_shape[0];
349 int64_t output_batch = output_shape[0];
350 if (static_dims(input_batch, output_batch) &&
351 (output_batch * block_size * block_size) != input_batch)
352 return op.emitOpError()
353 << "requires output batch (dimension 0) to be equal to input "
354 "batch (dimension 0) / (block_size * block_size), but got "
355 "output batch "
356 << output_batch << ", input batch " << input_batch
357 << ", and block_size " << block_size;
358
359 auto check_spatial_dim = [&](int64_t spatial_dim_index,
360 llvm::StringRef dim_name,
361 llvm::StringRef crop_a_name,
362 llvm::StringRef crop_b_name) -> LogicalResult {
363 int64_t input_dim = input_shape[spatial_dim_index];
364 int64_t output_dim = output_shape[spatial_dim_index];
365 if (!static_dims(input_dim, output_dim)) return success();
366
367 int64_t input_dim_pad = input_dim * block_size;
368 // If crops are unknown, the maximum output spatial dim size is the input
369 // spatial dim size * block_size, as the minimum crop is 0.
370 if (crops_values.empty() && output_dim > input_dim * block_size)
371 return op.emitOpError()
372 << "requires output " << dim_name << " (dimension "
373 << spatial_dim_index << ") to be less than or equal to input "
374 << dim_name << " (dimension " << spatial_dim_index
375 << ") * block_size, but got output " << dim_name << " "
376 << output_dim << ", input " << dim_name << " " << input_dim
377 << ", and block_size " << block_size;
378
379 if (!crops_values.empty()) {
380 // output spatial dim = input spatial dim * block_size - crops.
381 int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
382 int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
383 if (output_dim != input_dim_pad - crop_a - crop_b)
384 return op.emitOpError()
385 << "requires output " << dim_name << " (dimension "
386 << spatial_dim_index << ") to be equal to input " << dim_name
387 << " (dimension " << spatial_dim_index << ") * block_size - "
388 << crop_a_name << " - " << crop_b_name << ", but got output "
389 << dim_name << " " << output_dim << ", input " << dim_name
390 << " " << input_dim << ", " << crop_a_name << " " << crop_a
391 << ", " << crop_b_name << " " << crop_b << ", and block_size "
392 << block_size;
393 }
394
395 return success();
396 };
397
398 if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
399 failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
400 return failure();
401
402 int64_t input_depth = input_shape[3];
403 int64_t output_depth = output_shape[3];
404 if (static_dims(input_depth, output_depth) && output_depth != input_depth)
405 return op.emitOpError()
406 << "requires output depth (dimension 3) to be equal to input "
407 "depth (dimension 3), but got output depth "
408 << output_depth << " and input depth " << input_depth;
409 }
410
411 return success();
412 }
413
414 void BatchToSpaceOp::getCanonicalizationPatterns(
415 OwningRewritePatternList &results, MLIRContext *context) {
416 results.insert<BatchToSpaceToBatchToSpaceND>(context);
417 }
418
419 //===----------------------------------------------------------------------===//
420 // BatchToSpaceNDOp
421 //===----------------------------------------------------------------------===//
422
423 static LogicalResult Verify(BatchToSpaceNDOp op) {
424 auto block_shape_ty = op.block_shape().getType().cast<ShapedType>();
425 auto crops_ty = op.crops().getType().cast<ShapedType>();
426
427 if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) {
428 const int block_rank = block_shape_ty.getShape().front();
429 if (crops_ty.getRank() != 2 || crops_ty.getShape().front() != block_rank ||
430 crops_ty.getShape()[1] != 2) {
431 op.emitOpError() << "crops should have shape [" << block_rank
432 << ", 2] instead of " << crops_ty.getShape();
433 return failure();
434 }
435 }
436
437 return success();
438 }
439
440 //===----------------------------------------------------------------------===//
441 // BiasAddOp
442 //===----------------------------------------------------------------------===//
443
444 // Verifies that:
445 // * The value and bias operands have valid ranks or are unranked.
446 // * The channel dimension of the value operand and the length of bias match
447 //   if they are not unknown.
448 //
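// Illustrative example: a value of tensor<1x32x32x64xf32> with `NHWC` data
// format requires a bias of tensor<64xf32>.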
449 static LogicalResult Verify(BiasAddOp op) {
450 absl::string_view data_format(op.data_format().data(),
451 op.data_format().size());
452 tensorflow::TensorFormat format;
453 bool is_valid = FormatFromString(data_format, &format);
454 DCHECK(is_valid) << data_format;
455 if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
456 if (!HasRankAtLeast(op.value(), 2))
457 return op.emitOpError(
458 "requires value operand to have rank at least two with `NHWC` data "
459 "format");
460 } else {
461 // Op definition requires data_format to be either NHWC or NCHW.
462 DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
463 if (!HasRankAtLeast(op.value(), 3))
464 return op.emitOpError(
465 "requires value operand to have rank at least three with `NCHW` data "
466 "format");
467 }
468
469 if (!IsOfRankOrUnranked(op.bias(), 1))
470 return op.emitOpError("requires bias operand to have rank exactly one");
471
472 RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
473 RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
474 if (!bias_ty || !value_ty) return success();
475
476 int64_t feature_dim_idx =
477 tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format);
478 int64_t feature_dim = value_ty.getDimSize(feature_dim_idx);
479 int64_t bias_len = bias_ty.getDimSize(0);
480 if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) {
481 return op.emitOpError()
482 << "requires channel dimension and feature dimension to match; "
483 "found "
484 << feature_dim << " and " << bias_len << ", respectively";
485 }
486 return success();
487 }
488
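// Describes a possible fusion of this BiasAdd as an output kernel of a
// preceding contraction (e.g. a matmul or convolution), forwarding the bias
// (operand 1) as an additional fusion argument; see the implementation below.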
489 Optional<ContractionFusion> BiasAddOp::GetContractionFusion() {
490 // Only NHWC in f32 is supported for fusion.
491 if (data_format() != "NHWC" || !T().isF32()) return None;
492
493 return ContractionFusion("BiasAdd", /*additional_arguments=*/{1});
494 }
495
496 LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
497 return ::mlir::TF::UpdateDataFormat(data_format, this);
498 }
499
500 StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) {
501 // Keep current data format if no GPUs are available or if explicit placement
502 // does not allow using a GPU for this operation.
503 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
504 return data_format();
505
506 // Prefer NHWC for GPU devices.
507 return "NHWC";
508 }
509
510 //===----------------------------------------------------------------------===//
511 // BiasAddGradOp
512 //===----------------------------------------------------------------------===//
513
514 // Verifies that:
515 // * The out_backprop operand has a valid rank or is unranked.
516 //
517 static LogicalResult Verify(BiasAddGradOp op) {
518 absl::string_view data_format(op.data_format().data(),
519 op.data_format().size());
520 tensorflow::TensorFormat format;
521 bool is_valid = FormatFromString(data_format, &format);
522 DCHECK(is_valid) << data_format;
523 if (format == tensorflow::TensorFormat::FORMAT_NHWC) {
524 if (!HasRankAtLeast(op.out_backprop(), 2))
525 return op.emitOpError(
526 "requires out_backprop operand to have rank at least two with `NHWC` "
527 "data format");
528 } else {
529 // Op definition requires data_format to be either NHWC or NCHW.
530 DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW);
531 if (!HasRankAtLeast(op.out_backprop(), 3))
532 return op.emitOpError(
533 "requires out_backprop operand to have rank at least three with "
534 "`NCHW` data format");
535 }
536
537 return success();
538 }
539
540 //===----------------------------------------------------------------------===//
541 // BiasAddV1Op
542 //===----------------------------------------------------------------------===//
543
544 void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
545 MLIRContext *context) {
546 results.insert<BiasAddV1ToBiasAdd>(context);
547 }
548
549 //===----------------------------------------------------------------------===//
550 // BitcastOp
551 //===----------------------------------------------------------------------===//
552
553 void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
554 MLIRContext *context) {
555 results.insert<BitcastSameType, BitcastNested>(context);
556 }
557
558 //===----------------------------------------------------------------------===//
559 // BroadcastToOp
560 //===----------------------------------------------------------------------===//
561
562 static LogicalResult Verify(BroadcastToOp op) {
563 // TODO(antiagainst): check that
564 // * The 'shape' input is an 1-D int tensor.
565 // * Each dimension pair of the source and target shapes are either equal
566 // or one of them is one.
567 return success();
568 }
569
570 OpFoldResult BroadcastToOp::fold(ArrayRef<Attribute> operands) {
571 Value input = this->input();
572
573 // Fold broadcast if operand and result types are the same and all dimensions
574 // are statically known (no-op broadcast).
575 auto result_ty = getType().dyn_cast<ShapedType>();
576 if (result_ty && result_ty.hasStaticShape() && result_ty == input.getType()) {
577 return input;
578 }
579
580 return {};
581 }
582
583 //===----------------------------------------------------------------------===//
584 // BroadcastGradientArgsOp
585 //===----------------------------------------------------------------------===//
586
587 namespace {
588 // Returns `true` if both s0 & s1 are defined via constant op, and fills
589 // s0_shape & s1_shape.
590 bool ExtractInputConstShape(BroadcastGradientArgsOp op,
591 DenseIntElementsAttr &s0, DenseIntElementsAttr &s1,
592 SmallVectorImpl<int64_t> &s0_shape,
593 SmallVectorImpl<int64_t> &s1_shape) {
594 if (!matchPattern(op.s0(), m_Constant(&s0))) return false;
595 if (!matchPattern(op.s1(), m_Constant(&s1))) return false;
596
597 for (auto s : s0.getIntValues()) s0_shape.push_back(s.getSExtValue());
598 for (auto s : s1.getIntValues()) s1_shape.push_back(s.getSExtValue());
599
600 return true;
601 }
602
603 // Calculates r0 & r1 output based on inputs and calculated broadcasted shape.
604 //
605 // For the given bcasted_shape, s0_shape and s1_shape, the broadcasted dimension
606 // is calculated and pushed back to its corresponding result, r0 or r1. For example,
607 // for s0_shape [1,4] and s1_shape [4, 4], bcasted_shape is computed to be
608 // [4,4] - this leads to the result of r0 to be [0] as the first dimension of s0
609 // is broadcasted, and r1 to be <> as no broadcasting is happening for s1.
610 void GetOutputShapeForBroadcastGradientArgs(ArrayRef<int64_t> bcasted_shape,
611 ArrayRef<int64_t> s0_shape,
612 ArrayRef<int64_t> s1_shape,
613 SmallVectorImpl<int64_t> &r0,
614 SmallVectorImpl<int64_t> &r1) {
615 r0.clear();
616 r1.clear();
617
618 // No broadcasting is required if both the shapes are equal.
619 if (s0_shape == s1_shape) return;
620
621 for (int i = bcasted_shape.size(); i > 0; --i) {
622 int idx = bcasted_shape.size() - i;
623 int s0_idx = i > s0_shape.size() ? -1 : s0_shape.size() - i;
624 int s1_idx = i > s1_shape.size() ? -1 : s1_shape.size() - i;
625 if (s0_idx == -1) {
626 r0.push_back(idx);
627 if (s1_shape[s1_idx] == 1) r1.push_back(idx);
628 } else if (s1_idx == -1) {
629 r1.push_back(idx);
630 if (s0_shape[s0_idx] == 1) r0.push_back(idx);
631 } else if (s0_shape[s0_idx] != s1_shape[s1_idx]) {
632 if (s0_shape[s0_idx] != bcasted_shape[idx])
633 r0.push_back(idx);
634 else
635 r1.push_back(idx);
636 } else if (s0_shape[s0_idx] == 1) {
637 // This op is used to compute the gradient dimensions requiring reduction
638 // to match the input dimensions. In case both the dimensions are one,
639 // reducing the dimension has no effect. We choose to reduce such
640 // dimensions to match the TensorFlow kernel behavior. However, note that
641 // the TF behavior in this case is inconsistent with the case with the
642 // same shapes.
643 r0.push_back(idx);
644 r1.push_back(idx);
645 }
646 }
647 }
648 } // namespace
649
650 // Verifies that:
651 // * The input shapes are broadcast compatible.
652 // * The output shape dimensions match the expected dimension sizes computed
653 //   from the input shapes.
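// Illustrative example: for s0 = [2, 1, 5] and s1 = [5], the broadcasted shape
// is [2, 1, 5], so r0 = [1] (s0 is broadcast along dimension 1) and
// r1 = [0, 1] (s1 is broadcast along dimensions 0 and 1).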
654 static LogicalResult Verify(BroadcastGradientArgsOp op) {
655 SmallVector<int64_t, 4> s0_shape, s1_shape;
656 DenseIntElementsAttr s0, s1;
657 if (!ExtractInputConstShape(op, s0, s1, s0_shape, s1_shape)) return success();
658
659 // If both shapes are known constants, also validate the output shapes against them.
660 SmallVector<int64_t, 4> bcasted_shape;
661 if (!OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape))
662 return op.emitOpError() << "requires broadcast compatible shape tensors "
663 "for 's0' and 's1', but got "
664 << s0 << " and " << s1;
665
666 SmallVector<int64_t, 4> r0, r1;
667 GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
668 r1);
669
670 // Verify that output types are of rank one and matches the computed result
671 // shape.
672 auto r0_ty = op.r0().getType().dyn_cast<RankedTensorType>();
673 auto r1_ty = op.r1().getType().dyn_cast<RankedTensorType>();
674 if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size())
675 return op.emitOpError() << "requires dimension 0 size of 'r0' to be "
676 << r0.size() << " but got " << r0_ty.getShape()[0];
677 if (r1_ty && r1_ty.hasStaticShape() && r1_ty.getDimSize(0) != r1.size())
678 return op.emitOpError() << "requires dimension 0 size of 'r1' to be "
679 << r1.size() << " but got " << r1_ty.getShape()[0];
680
681 return success();
682 }
683
684 LogicalResult BroadcastGradientArgsOp::fold(
685 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
686 SmallVector<int64_t, 4> s0_shape, s1_shape;
687 DenseIntElementsAttr s0, s1;
688 if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape))
689 return failure();
690
691 // Fold BroadcastGradientArgs into two constants if both of the inputs have
692 // known shape.
693 SmallVector<int64_t, 4> bcasted_shape;
694 // Verifier should already ensure the broadcast compatibility.
695 bool bcast_compatible =
696 OpTrait::util::getBroadcastedShape(s0_shape, s1_shape, bcasted_shape);
697 assert(bcast_compatible);
698 (void)bcast_compatible;
699
700 SmallVector<int64_t, 4> r0, r1;
701 GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0,
702 r1);
703
704 auto build_out_dense_element = [](SmallVectorImpl<int64_t> &shape,
705 Type input_type) {
706 Type element_type = input_type.cast<mlir::TensorType>().getElementType();
707 RankedTensorType type = RankedTensorType::get(
708 {static_cast<int64_t>(shape.size())}, element_type);
709 // Input can only be i32 or i64. For i32, downcast to an int32_t array.
710 if (element_type.isInteger(32)) {
711 SmallVector<int32_t, 4> i32_shape;
712 for (auto s : shape) i32_shape.push_back(static_cast<int32_t>(s));
713 return DenseIntElementsAttr::get(type, i32_shape);
714 } else {
715 assert(element_type.isInteger(64));
716 return DenseIntElementsAttr::get(type, shape);
717 }
718 };
719
720 results.push_back(build_out_dense_element(r0, this->s0().getType()));
721 results.push_back(build_out_dense_element(r1, this->s1().getType()));
722
723 return success();
724 }
725
726 //===----------------------------------------------------------------------===//
727 // CaseOp
728 //===----------------------------------------------------------------------===//
729
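// Folds tf.Case with a constant branch index into a direct call of the
// selected branch. Illustrative example:
//
//   %r = "tf.Case"(%index, %arg) {branches = [@br0, @br1], ...}
//
// with %index constant 1 becomes:
//
//   %r = "tf.PartitionedCall"(%arg) {f = @br1, ...}
//
// Out-of-range indices are clamped to the last branch, matching the rewrite
// below.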
730 class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
731 public:
732 explicit FoldConstantCaseOp(MLIRContext *context)
733 : OpRewritePattern<TF::CaseOp>(context) {}
734 LogicalResult matchAndRewrite(TF::CaseOp op,
735 PatternRewriter &rewriter) const override;
736 };
737
738 LogicalResult FoldConstantCaseOp::matchAndRewrite(
739 TF::CaseOp op, PatternRewriter &rewriter) const {
740 // Extract the constant cond value.
741 DenseIntElementsAttr branch;
742 if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
743
744 int index = *branch.getValues<int>().begin();
745 if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;
746
747 auto func = op.branches()[index].cast<SymbolRefAttr>();
748 auto empty = rewriter.getStringAttr("");
749 auto call_op = rewriter.create<PartitionedCallOp>(
750 op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
751 /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
752 CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
753 rewriter.replaceOp(op, call_op.getResults());
754 return success();
755 }
756
757 void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
758 MLIRContext *context) {
759 results.insert<FoldConstantCaseOp, DropAttributes<CaseOp>>(context);
760 }
761
762 static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) {
763 if (!IsOfRankOrUnranked(branch_index, 0))
764 return op->emitOpError()
765 << "expects 'branch_index' to be a scalar, but got "
766 << branch_index.getType();
767 return success();
768 }
769
770 static LogicalResult VerifyCaseOrIfOpBranchFunctions(
771 Operation *op, ArrayRef<Attribute> branches,
772 llvm::function_ref<std::string(unsigned branch_index)> branch_name) {
773 SmallVector<FunctionType, 2> branch_types;
774 branch_types.reserve(branches.size());
775
776 // Branch functions have one fewer operand than the op, as the first operand
777 // is elided (`cond` of `tf.If` and `branch_index` of `tf.Case`).
778 TypeRangeWithDesc input{op->getOperands().drop_front().getTypes(), "input"};
779 TypeRangeWithDesc result{op->getResultTypes(), "result"};
780
781 for (auto branch : llvm::enumerate(branches)) {
782 auto branch_func = SymbolTable::lookupNearestSymbolFrom<FuncOp>(
783 op, branch.value().cast<SymbolRefAttr>());
784 if (!branch_func)
785 return op->emitOpError()
786 << "expects " << branch_name(branch.index()) << " ("
787 << branch.value() << ") to point to a defined function";
788
789 FunctionType branch_type = branch_func.getType();
790 std::string desc = branch_name(branch.index()) + " input";
791 TypeRangeWithDesc branch_input{branch_type.getInputs(), desc};
792 if (failed(VerifyTypeRangesAreCompatible(op, branch_input, input)))
793 return failure();
794
795 desc = branch_name(branch.index()) + " result";
796 TypeRangeWithDesc branch_result{branch_type.getResults(), desc};
797 if (failed(VerifyTypeRangesAreCompatible(op, branch_result, result)))
798 return failure();
799
800 branch_types.push_back(branch_type);
801 }
802
803 // If branches have incompatible input types that means that no tensor can
804 // serve as input to all the functions. Hence, the op is invalid.
805 int expected_num_inputs = op->getNumOperands() - 1;
806 for (int i = 0; i < expected_num_inputs; ++i) {
807 SmallVector<Type, 2> branch_input_i_types;
808 branch_input_i_types.reserve(branches.size());
809 llvm::transform(
810 branch_types, std::back_inserter(branch_input_i_types),
811 [i](FunctionType &branch_type) { return branch_type.getInput(i); });
812 if (!AreCastCompatible(branch_input_i_types)) {
813 std::string input_types_str;
814 llvm::raw_string_ostream os(input_types_str);
815 llvm::interleaveComma(branch_input_i_types, os);
816 return op->emitOpError()
817 << "expects all branch input type(s) (" << os.str()
818 << ") at index " << i << " to be cast compatible";
819 }
820 }
821
822 return success();
823 }
824
825 static LogicalResult Verify(CaseOp op) {
826 if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
827 auto branch_name = [](unsigned index) {
828 return llvm::formatv("branch #{0}", index).str();
829 };
830 return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(),
831 branch_name);
832 }
833
834 //===----------------------------------------------------------------------===//
835 // CaseRegionOp
836 //===----------------------------------------------------------------------===//
837
838 static LogicalResult Verify(CaseRegionOp op) {
839 if (op.branches().empty())
840 return op.emitOpError() << "expects to have at least 1 region";
841
842 if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure();
843
844 TypeRangeWithDesc results{op.getResultTypes(), "result"};
845
846 for (auto region_and_idx : llvm::enumerate(op.branches())) {
847 std::string description =
848 llvm::formatv("branch #{0} result", region_and_idx.index()).str();
849 Operation *yield = region_and_idx.value().front().getTerminator();
850 TypeRangeWithDesc branch_results{yield->getOperandTypes(), description};
851 if (failed(VerifyTypeRangesAreCompatible(op, branch_results, results)))
852 return failure();
853 }
854
855 return success();
856 }
857
858 namespace {
859 // Eliminate values that pass through the CaseRegionOp or IfRegionOp branches.
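// Illustrative example:
//
//   %0, %1 = "tf.IfRegion"(%cond) ({
//     ... "tf.Yield"(%a, %extern) ...
//   }, {
//     ... "tf.Yield"(%b, %extern) ...
//   })
//
// Both branches yield the same externally defined value %extern as the second
// result, so uses of %1 are replaced with %extern and the op is rebuilt with
// a single result.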
860 template <class CaseOrIfRegionOp>
861 class CaseOrIfRegionEliminatePassThrough
862 : public OpRewritePattern<CaseOrIfRegionOp> {
863 using OpRewritePattern<CaseOrIfRegionOp>::OpRewritePattern;
864
865 LogicalResult matchAndRewrite(CaseOrIfRegionOp op,
866 PatternRewriter &rewriter) const override {
867 RegionRange branches = op.getRegions();
868 SmallVector<Type, 4> new_result_types;
869 // Maps pass through results to extern values.
870 llvm::SmallDenseMap<Value, Value, 4> result_to_extern_value;
871
872 for (auto result : op.getResults()) {
873 unsigned index = result.getResultNumber();
874 Region *first_branch = *branches.begin();
875 Operation *first_terminator = first_branch->front().getTerminator();
876 Value returned_val = first_terminator->getOperand(index);
877
878 // Pass-through values are defined outside the branch region. Keep the
879 // types of non-pass-through results to create a new op later, if
880 // required.
881 if (returned_val.getParentBlock() == &first_branch->front()) {
882 new_result_types.push_back(result.getType());
883 continue;
884 }
885 // Check if the same extern value is returned in each branch.
886 for (Region *region : branches.drop_front()) {
887 Operation *terminator = region->front().getTerminator();
888 if (terminator->getOperand(index) != returned_val) return failure();
889 }
890 result_to_extern_value[result] = returned_val;
891 }
892
893 // If no pass through values are found, no change is required.
894 if (result_to_extern_value.empty()) return failure();
895
896 // Create new case/if region op.
897 auto new_op = rewriter.create<CaseOrIfRegionOp>(
898 op.getLoc(), new_result_types, op.getOperand(), op.getAttrs(),
899 op.getNumRegions());
900
901 int next_index = 0;
902 for (auto result : op.getResults()) {
903 if (!result_to_extern_value.count(result)) {
904 result.replaceAllUsesWith(new_op.getResult(next_index++));
905 continue;
906 }
907 result.replaceAllUsesWith(result_to_extern_value[result]);
908 for (Region *branch : branches)
909 branch->front().getTerminator()->eraseOperand(next_index);
910 }
911
912 // Move region bodies to the new op.
913 for (auto region_index : llvm::seq<int>(0, branches.size()))
914 new_op.getRegion(region_index).takeBody(op.getRegion(region_index));
915
916 op.erase();
917 return success();
918 }
919 };
920 } // namespace
921
922 void CaseRegionOp::getCanonicalizationPatterns(
923 OwningRewritePatternList &results, MLIRContext *context) {
924 results.insert<CaseOrIfRegionEliminatePassThrough<TF::CaseRegionOp>>(context);
925 }
926
927 //===----------------------------------------------------------------------===//
928 // CastOp
929 //===----------------------------------------------------------------------===//
930
931 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
932 // Cast with the same type is a no-op.
933 Value operand = getOperand();
934 if (getType() == operand.getType()) return operand;
935 return {};
936 }
937
938 //===----------------------------------------------------------------------===//
939 // ConcatOp and ConcatV2Op
940 //===----------------------------------------------------------------------===//
941
942 template <typename OpT,
943 typename std::enable_if<llvm::is_one_of<
944 OpT, ConcatOp, ConcatV2Op>::value>::type * = nullptr>
945 static LogicalResult Verify(OpT op) {
946 // TODO(hinsu): Convert variadic length attributes to derived attributes.
947 Operation::operand_range values = op.values();
948
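// Note: the axis is ODS operand group 0 for tf.Concat (axis first) and group 1
// for tf.ConcatV2 (axis last); this assumes the ODS operand ordering of those
// ops.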
949 int axis_idx = std::is_same<OpT, ConcatOp>() ? 0 : 1;
950 Value axis = *op.getODSOperands(axis_idx).begin();
951 if (!HasRankAtMost(axis, 1)) {
952 return op.emitOpError(
953 "requires axis to be of scalar type (or vector type for older "
954 "versions)");
955 }
956
957 return VerifyTypesCompatibility(values,
958 /*mask_one_dim=*/true, op.getOperation());
959 }
960
961 void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
962 MLIRContext *context) {
963 results.insert<ConvertToConcatV2>(context);
964 }
965
966 namespace {
967
968 // Hoist coefficient-wise unary operation out of the Concat op:
969 //
970 // %0 = "tf.Log1p"(%arg_0)
971 // %1 = "tf.Log1p"(%arg_1)
972 // ...
973 // %n = "tf.Log1p"(%arg_n)
974 // %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis)
975 //
976 // Rewrite it to:
977 //
978 // %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis)
979 // %1 = "tf.Log1p"(%0)
980 class HoistCwiseUnaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
981 public:
982 explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context)
983 : OpRewritePattern<TF::ConcatV2Op>(context) {}
984 LogicalResult matchAndRewrite(TF::ConcatV2Op op,
985 PatternRewriter &rewriter) const override;
986 };
987
988 LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite(
989 TF::ConcatV2Op op, PatternRewriter &rewriter) const {
990 auto loc = op.getLoc();
991
992 // All concat operands must be defined by ops.
993 Operation *first_arg_op = op.values().front().getDefiningOp();
994 if (first_arg_op == nullptr) return failure();
995
996 // All concat operands must be produced by the coeff-wise unary operation.
997 if (!first_arg_op->hasTrait<OpTrait::TF::CwiseUnary>()) return failure();
998
999 // All concat operands must be defined by the op of same kind.
1000 bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool {
1001 Operation *arg_op = arg.getDefiningOp();
1002 return arg_op && arg_op->getName() == first_arg_op->getName();
1003 });
1004 if (!args_same_op) return failure();
1005
1006 // Collect unary operations operands.
1007 auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value {
1008 return arg.getDefiningOp()->getOperand(0);
1009 });
1010 SmallVector<Value, 8> unary_ops_args(unary_operands);
1011
1012 // Concatenate unary ops operands.
1013 auto concat_unary_operands =
1014 rewriter.create<ConcatV2Op>(loc, op.getType(), unary_ops_args, op.axis());
1015
1016 // Replace the original concat with a unary op.
1017 OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(),
1018 concat_unary_operands.getResult(),
1019 op.getResult().getType(),
1020 ArrayRef<NamedAttribute>());
1021 Operation *new_unary_op = rewriter.createOperation(new_unary_op_state);
1022
1023 rewriter.replaceOp(op, new_unary_op->getResults());
1024
1025 return success();
1026 }
1027
1028 // Hoist coefficient-wise binary operation out of the Concat op:
1029 //
1030 // %0 = tf.Mul(%lhs_0, %rhs_0)
1031 // %1 = tf.Mul(%lhs_1, %rhs_1)
1032 // ...
1033 // %n = tf.Mul(%lhs_n, %rhs_n)
1034 // %m = tf.ConcatV2(%0, %1, ..., %n, %axis)
1035 //
1036 // Rewrite it to:
1037 //
1038 // %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis)
1039 // %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis)
1040 // %2 = tf.Mul(%0, %1)
1041 //
1042 // If a minor fraction of the Concat inputs are not of the same binary op kind
1043 // (tf.Mul in the above example), we will synthesize the binary ops for those
1044 // inputs. e.g. if we instead have %1 = %lhs_1, then we would synthesize a
1045 // tf.Mul op over it and a scalar const tensor 1.0. For now this only applies to
1046 // float32 tensors.
1047 // TODO(hongm): Implement this op synthesis optimization for other dtypes if
1048 // needed.
1049 //
1050 // Because coefficient-wise binary operations support implicit broadcasting, we
1051 // should be very careful with this optimization, and do not accidentally
1052 // produce incorrect concat operations.
1053 class HoistCwiseBinaryOutOfConcat : public OpRewritePattern<TF::ConcatV2Op> {
1054 public:
1055 explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context)
1056 : OpRewritePattern<TF::ConcatV2Op>(context) {}
1057 LogicalResult matchAndRewrite(TF::ConcatV2Op op,
1058 PatternRewriter &rewriter) const override;
1059
1060 private:
1061 struct HoistParams {
1062 SmallVector<Value, 8> lhs_args;
1063 SmallVector<Value, 8> rhs_args;
1064 int64_t lhs_axis;
1065 int64_t rhs_axis;
1066 Type lhs_concat_type;
1067 Type rhs_concat_type;
1068 int scalar_operand_idx; // can be 0 or 1 for the binary op's operands.
1069 };
1070
1071 // Returns parameters of a binary op hoisting out of concatenation if all of
1072 // the operands are in one of the compatible configurations.
1073 // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul),
1074 // except from the ones in `exceptions`. In that case, we can synthesize that
1075 // binary op kind for the values in `exceptions`.
1076 Optional<HoistParams> GetHoistParams(
1077 TF::ConcatV2Op op, int64_t axis,
1078 const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const;
1079 };
1080
1081 LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
1082 TF::ConcatV2Op op, PatternRewriter &rewriter) const {
1083 auto loc = op.getLoc();
1084
1085 // Axis must be a constant scalar value.
1086 DenseIntElementsAttr axis_attr;
1087 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure();
1088 if (axis_attr.getNumElements() != 1) return failure();
1089 int64_t axis =
1090 axis_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
1091 // TODO(ezhulenev): Compute axis from rank. e.g. It might be common to concat
1092 // on the channels dim for NCHW layout as axis=-2.
1093 if (axis < 0) return failure();
1094
1095 // All concat operands must be defined by ops of the same kind (e.g. tf.Mul),
1096 // or by other ops that we can convert to that op kind
1097 // (e.g. converting op A to tf.Mul(A, 1.0)).
1098 // TODO(hongm): generalize the code here to support cases where the first arg
1099 // has no defining op (e.g. might be a block arg).
1100 Operation *first_arg_op = op.values().front().getDefiningOp();
1101 if (first_arg_op == nullptr) return failure();
1102
1103 // All concat operands must be produced by the coeff-wise binary operation.
1104 if (!first_arg_op->hasTrait<OpTrait::TF::CwiseBinary>()) return failure();
1105
1106 // All concat operands must be defined by the op of same kind, except for a
1107 // minor portion which we track in `exceptions`.
1108 // Map from the operands to operand indices.
1109 llvm::SmallDenseMap<Value, unsigned, 4> exceptions;
1110 unsigned operand_idx = 0;
1111 for (Value arg : op.values()) {
1112 Operation *arg_op = arg.getDefiningOp();
1113 if (arg_op && arg_op->getName() == first_arg_op->getName()) {
1114 ++operand_idx;
1115 continue;
1116 }
1117 exceptions[arg] = operand_idx++;
1118 }
1119 // Recall that inputs to the concat op that are not produced by a binary op
1120 // of the `first_arg_op` kind (e.g. tf.Mul) are stored in `exceptions`. If
1121 // there are too many exceptions, it might not be cost effective to apply the
1122 // concat hoisting optimization here.
1123 // Setting the threshold to be 50% as a simple cost model heuristic. e.g. If 1
1124 // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1
1125 // out of 3, we do.
1126 const float exception_pct_threshold = 0.5;
1127 if (static_cast<float>(op.values().size()) * exception_pct_threshold <=
1128 exceptions.size())
1129 return failure();
1130
1131 // Compute binary operands hoist parameters.
1132 auto hoist_params = GetHoistParams(op, axis, exceptions);
1133 if (!hoist_params.hasValue()) return failure();
1134
1135 // Process `exceptions`: For each value there, synthesize a binary op of the
1136 // above kind, so that the concat hoisting optimization can still apply.
1137 if (!exceptions.empty()) {
1138 int identity_val;
1139 if (isa<AddOp>(first_arg_op) || isa<SubOp>(first_arg_op))
1140 identity_val = 0;
1141 else if (isa<MulOp>(first_arg_op) || isa<DivOp>(first_arg_op) ||
1142 isa<RealDivOp>(first_arg_op))
1143 identity_val = 1;
1144 else
1145 return failure();
1146 DenseElementsAttr const_attr;
1147 auto scalar_tensor_type =
1148 first_arg_op->getOperand(hoist_params->scalar_operand_idx)
1149 .getType()
1150 .dyn_cast<ShapedType>();
1151 Type scalar_dtype = scalar_tensor_type.getElementType();
1152 if (scalar_dtype.isa<FloatType>())
1153 const_attr = DenseElementsAttr::get(scalar_tensor_type,
1154 static_cast<float>(identity_val));
1155 else
1156 return failure();
1157
1158 // All checks have passed; now prepare for the rewrite.
1159 auto identity_const = rewriter.create<TF::ConstOp>(loc, const_attr);
1160 for (const auto &kv : exceptions) {
1161 assert(!hoist_params->lhs_args[kv.second]);
1162 assert(!hoist_params->rhs_args[kv.second]);
1163
1164 if (hoist_params->scalar_operand_idx == 1) {
1165 hoist_params->lhs_args[kv.second] = kv.first;
1166 hoist_params->rhs_args[kv.second] = identity_const;
1167 } else {
1168 assert(hoist_params->scalar_operand_idx == 0);
1169 hoist_params->lhs_args[kv.second] = identity_const;
1170 hoist_params->rhs_args[kv.second] = kv.first;
1171 }
1172 }
1173 }
1174
1175 // New lhs and rhs concatenation axis.
1176 auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
1177 auto lhs_axis = rewriter.create<TF::ConstOp>(
1178 loc, DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis));
1179 auto rhs_axis = rewriter.create<TF::ConstOp>(
1180 loc, DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis));
1181
1182 // Concatenate binary ops operands on the new axis.
1183 auto lhs_concat = rewriter.create<ConcatV2Op>(
1184 loc, hoist_params->lhs_concat_type, hoist_params->lhs_args, lhs_axis);
1185 auto rhs_concat = rewriter.create<ConcatV2Op>(
1186 loc, hoist_params->rhs_concat_type, hoist_params->rhs_args, rhs_axis);
1187
1188 // Replace original concat with a binary op.
1189 OperationState new_binary_op_state(
1190 loc, first_arg_op->getName().getStringRef(),
1191 {lhs_concat.getResult(), rhs_concat.getResult()},
1192 op.getResult().getType(), ArrayRef<NamedAttribute>());
1193 Operation *new_binary_op = rewriter.createOperation(new_binary_op_state);
1194
1195 rewriter.replaceOp(op, new_binary_op->getResults());
1196
1197 return success();
1198 }
1199
1200 Optional<HoistCwiseBinaryOutOfConcat::HoistParams>
1201 HoistCwiseBinaryOutOfConcat::GetHoistParams(
1202 TF::ConcatV2Op op, int64_t axis,
1203 const llvm::SmallDenseMap<Value, unsigned, 4> &exceptions) const {
1204 assert(axis >= 0);
1205 // Collects lhs or rhs arguments of concat op operands.
1206 auto args = [&](int operand_idx) -> SmallVector<Value, 8> {
1207 auto range = llvm::map_range(op.values(), [&](Value arg) {
1208 if (exceptions.count(arg)) return Value();
1209 return arg.getDefiningOp()->getOperand(operand_idx);
1210 });
1211 return {range.begin(), range.end()};
1212 };
1213
1214 // Returns true if the binary ops' operands at index `operand_idx` are all
1215 // tensors of rank `axis + 1` whose axis dim has size `1`.
1216 auto is_all_tensors = [&](int operand_idx, int axis) -> bool {
1217 return llvm::all_of(op.values(), [&](Value arg) -> bool {
1218 if (exceptions.count(arg)) return true;
1219 auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1220 auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1221 return ranked && ranked.getRank() == (axis + 1) &&
1222 ranked.getShape()[axis] == 1;
1223 });
1224 };
1225
1226 // Returns true if the binary ops' operands at index `operand_idx` are all scalars.
1227 auto is_all_scalars = [&](int operand_idx) -> bool {
1228 return llvm::all_of(op.values(), [&](Value arg) -> bool {
1229 if (exceptions.count(arg)) return true;
1230 auto operand = arg.getDefiningOp()->getOperand(operand_idx);
1231 auto ranked = operand.getType().dyn_cast<RankedTensorType>();
1232 return ranked && ranked.hasRank() && ranked.getRank() == 0;
1233 });
1234 };
1235
1236 // Concat result type must be a ranked tensor.
1237 auto ranked = op.getType().dyn_cast<RankedTensorType>();
1238 if (!ranked) return None;
1239
1240 // TODO(ezhulenev): Add support for more valid concat patterns.
1241
1242 // Tensor + Scalar: [..., 1] + [] <- scalar
1243 // ^
1244 // \- axis is the innermost dimension.
1245 //
1246 // Concatenate tensor arguments on the same axis as the original operation,
1247 // and concatenate scalars into the vector.
1248 if (is_all_tensors(0, axis) && is_all_scalars(1)) {
1249 std::array<int64_t, 1> rhs_dims{static_cast<int64_t>(op.values().size())};
1250 auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType());
1251 return HoistParams{args(0),
1252 args(1),
1253 axis,
1254 0,
1255 op.getType(),
1256 rhs_type,
1257 /*scalar_operand_idx=*/1};
1258 } else if (is_all_tensors(1, axis) && is_all_scalars(0)) {
1259 std::array<int64_t, 1> lhs_dims{static_cast<int64_t>(op.values().size())};
1260 auto lhs_type = RankedTensorType::get(lhs_dims, ranked.getElementType());
1261 return HoistParams{args(0),
1262 args(1),
1263 0,
1264 axis,
1265 lhs_type,
1266 op.getType(),
1267 /*scalar_operand_idx=*/0};
1268 }
1269 return None;
1270 }
1271
1272 } // namespace
1273
1274 void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
1275 MLIRContext *context) {
1276 results.insert<HoistCwiseBinaryOutOfConcat, HoistCwiseUnaryOutOfConcat>(
1277 context);
1278 }
1279
1280 //===----------------------------------------------------------------------===//
1281 // CumsumOp and CumprodOp
1282 //===----------------------------------------------------------------------===//
1283
1284 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1285 OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
1286 static LogicalResult Verify(OpT op) {
1287 if (!IsOfRankOrUnranked(op.axis(), 0))
1288 return op.emitOpError("requires scalar axis operand");
1289
1290 DenseIntElementsAttr axis_attr;
1291 if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
1292 auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
1293 if (input_ty) {
1294 int64_t rank = input_ty.getRank();
1295 assert(axis_attr.getNumElements() == 1 &&
1296 "scalar attribute should have exactly one element");
1297 int64_t axis = (*axis_attr.begin()).getSExtValue();
1298 if (axis < -rank || axis >= rank) {
1299 return op.emitError()
1300 << "axis operand should be within range [" << -rank << ", "
1301 << rank << "); actual value: " << axis;
1302 }
1303 }
1304 }
1305
1306 return success();
1307 }
1308
1309 //===----------------------------------------------------------------------===//
1310 // ConcatOffsetOp
1311 //===----------------------------------------------------------------------===//
1312
1313 static LogicalResult Verify(ConcatOffsetOp op) {
1314 if (op.N() < 2)
1315 return op.emitOpError() << "requires N to be at least 2, got " << op.N();
1316
1317 if (op.shape().size() != op.offset().size())
1318 return op.emitOpError()
1319 << "requires sizes of shapes and offsets to be the same, got sizes "
1320 << op.shape().size() << " and " << op.offset().size();
1321
1322 auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
1323 if (ranked_dim && ranked_dim.getRank() != 0)
1324 return op.emitOpError()
1325 << "requires concat_dim to be a scalar, got tensor of rank "
1326 << ranked_dim.getRank();
1327
1328 int64_t num_dims = -1;
1329 for (auto shape_offset_idx :
1330 llvm::enumerate(llvm::zip(op.shape(), op.offset()))) {
1331 Value shape = std::get<0>(shape_offset_idx.value());
1332 Value offset = std::get<1>(shape_offset_idx.value());
1333 const size_t idx = shape_offset_idx.index();
1334
1335 if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
1336 return op.emitOpError() << "requires operand and result " << idx
1337 << " to have compatible shapes";
1338
1339 auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
1340 if (!ranked_shape) continue;
1341
1342 if (ranked_shape.getRank() != 1)
1343 return op.emitOpError() << "requires shape tensor operand " << idx
1344 << " to be of rank 1, got tensor of rank "
1345 << ranked_shape.getRank();
1346
1347 if (!ranked_shape.hasStaticShape()) continue;
1348
1349 int64_t ranked_shape_dim = ranked_shape.getDimSize(0);
1350 if (num_dims == -1)
1351 num_dims = ranked_shape_dim;
1352 else if (ranked_shape_dim != num_dims)
1353 return op.emitOpError()
1354 << "requires shape tensor (rank 1) operand " << idx
1355 << " to be of length " << num_dims
1356 << ", got tensor (rank 1) of length " << ranked_shape_dim;
1357 }
1358
1359 return success();
1360 }
1361
LogicalResult ConcatOffsetOp::fold(ArrayRef<Attribute> operands,
1363 SmallVectorImpl<OpFoldResult> &results) {
  // ConcatOffset must have its first operand be concat_dim and at least two
  // shape tensors in its variadic shapes operand.
1366 if (operands.size() < 3) return failure();
1367
1368 // Check concat_dim is a scalar.
1369 auto concat_dim_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
1370 if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0)
1371 return failure();
1372
1373 llvm::SmallVector<DenseIntElementsAttr, 4> shapes;
1374 shapes.reserve(operands.size() - 1);
1375 for (Attribute shape : llvm::drop_begin(operands, 1))
1376 if (auto shape_attr = shape.dyn_cast_or_null<DenseIntElementsAttr>())
1377 shapes.push_back(shape_attr);
1378 else
1379 return failure();
1380
1381 // Check all shapes are vectors of the same length.
1382 if (shapes.front().getType().getRank() != 1) return success();
1383 const int64_t num_dims = shapes.front().getNumElements();
1384 for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1))
1385 if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims)
1386 return failure();
1387
1388 // Check concat_dim is within [-num_dims, num_dims).
1389 int32_t concat_dim = (*concat_dim_attr.getValues<int32_t>().begin());
1390 if (concat_dim < 0) concat_dim += num_dims;
1391 if (concat_dim >= num_dims || concat_dim < 0) return failure();
1392
  // Check that all elements, except the one at concat_dim, match across all
  // shape tensors.
1394 SmallVector<int32_t, 4> shape0;
1395 shape0.reserve(num_dims);
1396 for (int32_t dim : shapes.front().getValues<int32_t>()) shape0.push_back(dim);
1397
1398 for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) {
1399 for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) {
1400 if (dims_and_idx.index() == concat_dim) continue;
1401
1402 if (std::get<0>(dims_and_idx.value()) !=
1403 std::get<1>(dims_and_idx.value()).getSExtValue())
1404 return failure();
1405 }
1406 }
1407
1408 // Compute an exclusive cumulative sum of elements at concat_dim.
1409 results.reserve(shapes.size());
1410 SmallVector<int32_t, 4> cumulative_sum(num_dims, 0);
1411 RankedTensorType offset_type =
1412 RankedTensorType::get({num_dims}, IntegerType::get(getContext(), 32));
1413 for (DenseIntElementsAttr shape : shapes) {
1414 results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum));
1415 cumulative_sum[concat_dim] += shape.getValue<int32_t>(concat_dim);
1416 }
1417
1418 return success();
1419 }
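
// Worked example (a sketch added for clarity): with a constant concat_dim of 1
// and constant shape operands [2, 3, 7] and [2, 5, 7], the two results fold to
// the exclusive cumulative sums [0, 0, 0] and [0, 3, 0] along dimension 1.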
1420
1421 //===----------------------------------------------------------------------===//
1422 // ConstOp
1423 //===----------------------------------------------------------------------===//
1424
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
1426 assert(operands.empty() && "constant has no operands");
1427
1428 // Return the held attribute value.
1429 return value();
1430 }
1431
1432 // Builds a constant op with the specified attribute `value`. The result
1433 // op's type is deduced from `value`; if `value` is of scalar type,
1434 // wraps it up with a tensor type of empty shape.
1435 // TODO(jpienaar): This one differs from the autogenerated one as it takes an
1436 // attribute but always creates an ElementsAttr internally.
void ConstOp::build(OpBuilder &builder, OperationState &result,
1438 Attribute value) {
1439 ShapedType type;
1440 if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1441 return ConstOp::build(builder, result, elem_attr);
1442 } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
1443 // All TensorFlow types must be tensor types. In the build() method,
1444 // we want to provide more flexibility by allowing attributes of scalar
1445 // types. But we need to wrap it up with ElementsAttr to construct
1446 // valid TensorFlow constants.
1447 type = RankedTensorType::get(/*shape=*/{}, value.getType());
1448 return ConstOp::build(builder, result, DenseElementsAttr::get(type, value));
1449 }
1450 // TODO(jpienaar): support other TensorFlow specific types.
1451 llvm_unreachable("unsupported attribute type for building tf.Const");
1452 }
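
// Illustrative example (not from the original source): building a tf.Const
// from a scalar FloatAttr holding 1.0 wraps it into a splat DenseElementsAttr
// of type tensor<f32>, so the result is a valid zero-rank TensorFlow constant.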
1453
void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
1455 Attribute value) {
1456 // Handle the case where the type and value are already tensors.
1457 if (type.isa<TensorType>() && value.isa<ElementsAttr>()) {
1458 result.addTypes(type);
1459 result.addAttribute("value", value);
1460 return;
1461 }
1462
1463 // Otherwise, default to the attribute builder.
1464 ConstOp::build(builder, result, value);
1465 assert(type == result.types[0] && "type mismatch in construction");
1466 }
1467
LogicalResult ConstOp::inferReturnTypes(
1469 MLIRContext *context, Optional<Location> location, ValueRange operands,
1470 DictionaryAttr attributes, RegionRange regions,
1471 SmallVectorImpl<Type> &inferredReturnTypes) {
1472 auto value = attributes.get("value");
1473 if (!value) return emitOptionalError(location, "missing attribute 'value'");
1474 if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
1475 inferredReturnTypes.assign({elem_attr.getType()});
1476 return success();
1477 }
1478 return emitOptionalError(location,
1479 "attribute 'value' failed to satisfy constraint: "
1480 "constant vector/tensor");
1481 }
1482
1483 //===----------------------------------------------------------------------===//
1484 // Conv2DOp and Conv3DOp
1485 //===----------------------------------------------------------------------===//
1486
static LogicalResult VerifyConvOpAttributes(
1488 int num_dims, ArrayRef<Attribute> strides, ArrayRef<Attribute> dilations,
1489 llvm::Optional<mlir::Location> location) {
1490 int64_t strides_size = strides.size();
1491 if (strides_size != num_dims)
1492 return emitOptionalError(
1493 location, "requires strides attribute length to be ", num_dims);
1494 auto is_not_positive = [](Attribute val) {
1495 return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
1496 };
1497 if (llvm::any_of(strides, is_not_positive))
1498 return emitOptionalError(location, "requires positive strides");
1499
1500 int64_t dilations_size = dilations.size();
1501 if (dilations_size != num_dims)
1502 return emitOptionalError(
1503 location, "requires dilations attribute length to be ", num_dims);
1504 if (llvm::any_of(dilations, is_not_positive))
1505 return emitOptionalError(location, "requires positive dilations");
1506
1507 return success();
1508 }
1509
1510 // Verifies that,
1511 // * Number of input channels is divisible by the number of filter input
1512 // channels
1513 template <typename OpT, typename std::enable_if<llvm::is_one_of<
1514 OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
1516 int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
1517 int num_dims = 2 + num_spatial_dims;
1518
1519 int64_t input_channels = -1;
1520 if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
1521 absl::string_view data_format(op.data_format().data(),
1522 op.data_format().size());
1523 tensorflow::TensorFormat format;
1524 auto is_valid = FormatFromString(data_format, &format);
1525 DCHECK(is_valid) << data_format;
1526 int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format);
1527 input_channels = ty.getDimSize(idx);
1528 }
1529
1530 int64_t filter_channels = -1;
1531 if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
1532 int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
1533 num_dims, tensorflow::FORMAT_HWIO);
1534 filter_channels = ty.getDimSize(idx);
1535 }
1536
1537 if (input_channels != -1 && filter_channels != -1 &&
1538 input_channels % filter_channels != 0)
1539 return op.emitOpError()
1540 << "requires the number of input channels to be divisible by the "
1541 "number of filter input channels; found "
1542 << input_channels << " and " << filter_channels << ", respectively";
1543
1544 return success();
1545 }
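
// Example (added for clarity): an input with 8 feature channels and a filter
// expecting 4 input channels verifies because 8 is divisible by 4 (e.g. a
// grouped convolution), while 8 input channels with a 3-channel filter is
// rejected.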
1546
LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
1548 auto perm = GetDataFormatPermutation(this->data_format(), data_format);
1549 if (perm.empty()) return failure();
1550
1551 // Update data_format attribute and result types.
1552 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1553
1554 // Update convolution attributes.
1555 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1556 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1557 (*this)->setAttr("explicit_paddings",
1558 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1559
1560 return success();
1561 }
1562
1563 // Verifies the inferred return type of the given operation.
1564 template <typename OpT,
1565 typename std::enable_if<llvm::is_one_of<
1566 OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type * = nullptr>
static LogicalResult inferConvReturnTypes(
1568 OpT op, llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes,
1569 llvm::Optional<mlir::Location> location,
1570 ArrayRef<Attribute> explicit_padding) {
1571 const int64_t num_spatial_dims = std::is_same<OpT, Conv2DOpAdaptor>() ? 2 : 3;
1572 const int64_t num_dims = 2 + num_spatial_dims;
1573 const Value input = op.input();
1574 const Value filter = op.filter();
1575 const TensorType input_ty = input.getType().template cast<TensorType>();
1576 const TensorType filter_ty = filter.getType().template cast<TensorType>();
1577 const StringRef paddings = op.padding().getValue();
1578
1579 ArrayRef<Attribute> strides = op.strides().getValue();
1580 StringRef data_format = op.data_format().getValue();
1581 ArrayRef<Attribute> dilations = op.dilations().getValue();
1582
1583 tensorflow::TensorFormat format;
1584 auto data_format_is_valid = FormatFromString(data_format.str(), &format);
1585 if (!data_format_is_valid) {
1586 return emitOptionalError(location, "Invalid data format provided");
1587 }
1588 tensorflow::Padding padding;
1589 auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding);
1590 if (!padding_is_valid.ok()) {
1591 return emitOptionalError(location, "Invalid padding format provided");
1592 }
1593 auto get_int = [](Attribute attr) {
1594 return attr.template cast<IntegerAttr>().getInt();
1595 };
1596
1597 // Necessary sanity checks.
1598 // Verifies that,
1599 // * Ranks of operands and result are valid
1600 // * Length of explicit_paddings attribute is valid and has non negative
1601 // elements
1602 // * strides and dilations attributes have positive elements
1603 if (!IsOfRankOrUnranked(input, num_dims) ||
1604 !IsOfRankOrUnranked(filter, num_dims))
1605 return emitOptionalError(location, "requires operands to be ", num_dims,
1606 "D tensor");
1607
1608 if (padding == tensorflow::Padding::EXPLICIT) {
1609 if (explicit_padding.size() == 0) {
1610 return emitOptionalError(location,
1611 "requires attribute 'explicit_paddings' with "
1612 "'EXPLICIT' padding mode");
1613 }
1614 if (explicit_padding.size() != num_dims * 2) {
1615 return emitOptionalError(
1616 location, "requires explicit_paddings attribute length to be ",
1617 num_dims * 2);
1618 }
1619 auto is_negative = [](Attribute val) {
1620 return val.cast<IntegerAttr>().getValue().getSExtValue() < 0;
1621 };
1622 if (llvm::any_of(explicit_padding, is_negative))
1623 return emitOptionalError(location,
1624 "requires non negative explicit paddings");
1625 }
1626
1627 if (failed(VerifyConvOpAttributes(num_dims, strides, dilations, location))) {
1628 return failure();
1629 }
1630
  // The output always has rank `num_dims`. All dimensions are initialized to
  // the dynamic size and can be partially inferred.
1633 SmallVector<int64_t, 4> return_shape(num_dims, ShapedType::kDynamicSize);
1634 // Output batch and channel dimension can be obtained using utilities from
1635 // tensorflow/core/util/tensor_format.h.
1636 if (input_ty.hasRank()) {
1637 return_shape[GetTensorBatchDimIndex(num_dims, format)] =
1638 input_ty.getDimSize(GetTensorBatchDimIndex(num_dims, format));
1639 }
1640 if (filter_ty.hasRank()) {
1641 return_shape[GetTensorFeatureDimIndex(num_dims, format)] =
1642 filter_ty.getDimSize(GetFilterTensorOutputChannelsDimIndex(
1643 num_dims, tensorflow::FORMAT_HWIO));
1644 }
1645 // Spatial dimensions can be inferred only when both input and filter are
1646 // ranked because we need to get their spatial dimensions.
1647 if (input_ty.hasRank() && filter_ty.hasRank()) {
1648 // Checks the size of each of the output spatial dimensions.
1649 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1650 const int64_t dim = GetTensorSpatialDimIndex(num_dims, format, i);
1651 int64_t stride = get_int(strides[dim]);
1652 tensorflow::int64 expected_output_size;
1653 tensorflow::int64 pad_low;
1654 tensorflow::int64 pad_high;
1655 // Retrieve padding, if defined explicitly.
1656 if (padding == tensorflow::Padding::EXPLICIT) {
1657 pad_low = get_int(explicit_padding[2 * dim]);
1658 pad_high = get_int(explicit_padding[2 * dim + 1]);
1659 }
1660 // Skip if input or filter size is dynamic.
1661 if (input_ty.isDynamicDim(dim) || filter_ty.isDynamicDim(i)) continue;
1662 // Calculate the expected_output_size.
1663 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1664 input_ty.getDimSize(dim), filter_ty.getDimSize(i),
1665 get_int(dilations[dim]), stride, padding, &expected_output_size,
1666 &pad_low, &pad_high);
1667 // Return failure if expected_output_size could not be calculated.
1668 if (!status.ok()) return failure();
1669 return_shape[dim] = expected_output_size;
1670 }
1671 }
1672
1673 inferredReturnTypes.assign(
1674 {RankedTensorType::get(return_shape, input_ty.getElementType())});
1675 return success();
1676 }
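
// Worked example (a sketch, not from the original source): for an NHWC input
// of type tensor<1x32x32x3xf32>, a filter of type tensor<3x3x3x16xf32>, unit
// strides and dilations, and VALID padding, the inferred return type is
// tensor<1x30x30x16xf32>; with SAME padding it is tensor<1x32x32x16xf32>.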
1677
LogicalResult Conv2DOp::inferReturnTypes(
1679 mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
1680 mlir::ValueRange operands, mlir::DictionaryAttr attributes,
1681 mlir::RegionRange regions,
1682 llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
1683 Conv2DOpAdaptor op(operands, attributes);
1684 ArrayRef<Attribute> explicit_padding;
1685 ArrayAttr explicit_pad =
1686 attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1687 if (!explicit_pad) {
1688 explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1689 }
1690 explicit_padding = explicit_pad.getValue();
1691
1692 return inferConvReturnTypes(op, inferredReturnTypes, location,
1693 explicit_padding);
1694 }
1695
StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
1699 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1700 return data_format();
1701
1702 // Input must be a tensor.
1703 auto input_ty = input().getType().dyn_cast<TensorType>();
1704 if (!input_ty) return data_format();
1705
  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
1708 const bool is_f16 = input_ty.getElementType().isF16();
1709 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1710
  // For the f32/f16 data types the decision depends on the filter size in the
  // spatial dimensions; for other data types we keep the current data format.
1713 if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16())
1714 return data_format();
1715
1716 // Keep current data format if filter rank is unknown or not equal to 4.
1717 auto filter_ty = filter().getType().dyn_cast<RankedTensorType>();
1718 if (!filter_ty || filter_ty.getRank() != 4) return data_format();
1719
1720 const int64_t d0 = filter_ty.getDimSize(0);
1721 const int64_t d1 = filter_ty.getDimSize(1);
1722
1723 auto all_ones = [](ArrayAttr arr) -> bool {
1724 return llvm::all_of(arr, [](Attribute attr) -> bool {
1725 return attr.cast<IntegerAttr>().getInt() == 1;
1726 });
1727 };
1728
  // Convolutions with a 1x1 filter and with all strides and dilations equal to
  // one can be computed as a GEMM in the NHWC data format, and can be up to
  // ~2x faster than a convolution in NCHW.
1732 const bool one_by_one = d0 == 1 && d1 == 1;
1733 const bool trivial_strides = all_ones(strides());
1734 const bool trivial_dilations = all_ones(dilations());
1735
1736 // TODO(ezhulenev): This might lead to excessive transposes in the final IR,
1737 // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1.
1738 // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all
1739 // users can efficiently use NHWC data format?
1740 if (one_by_one && trivial_strides && trivial_dilations) {
1741 return "NHWC";
1742 }
1743
  // If the filter's spatial dimensions are unknown or not 1x1, we prefer NCHW
  // because it's the fastest option on NVIDIA GPUs with cuDNN library support.
1746 return "NCHW";
1747 }
1748
1749 //===----------------------------------------------------------------------===//
1750 // Conv2dBackpropFilterOp
1751 //===----------------------------------------------------------------------===//
1752
LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
1754 StringRef src_data_format = this->data_format();
1755
1756 auto perm = GetDataFormatPermutation(src_data_format, data_format);
1757 if (perm.empty()) return failure();
1758
1759 // Update data_format attribute and result types.
1760 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1761
1762 // Update convolution attributes.
1763 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1764 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1765 (*this)->setAttr("explicit_paddings",
1766 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1767
1768 // Permute filter sizes operand.
1769 OpBuilder builder(getOperation());
1770 auto filter_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1771 getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format),
1772 StringAttr::get(getContext(), data_format));
1773 setOperand(1, filter_sizes_permuted);
1774
1775 return success();
1776 }
1777
StringRef Conv2DBackpropFilterOp::GetOptimalLayout(
1779 const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
1782 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1783 return data_format();
1784
1785 // Input must be a tensor.
1786 auto input_ty = input().getType().dyn_cast<TensorType>();
1787 if (!input_ty) return data_format();
1788
  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
1791 const bool is_f16 = input_ty.getElementType().isF16();
1792 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1793
1794 // Otherwise always use "NCHW".
1795 return "NCHW";
1796 }
1797
1798 //===----------------------------------------------------------------------===//
1799 // Conv2DBackpropInputOp
1800 //===----------------------------------------------------------------------===//
1801
static LogicalResult Verify(Conv2DBackpropInputOp op) {
1803 int num_spatial_dims = 2;
1804 int num_dims = 2 + num_spatial_dims;
1805
1806 if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) ||
1807 !IsOfRankOrUnranked(op.filter(), num_dims))
1808 return op.emitOpError()
1809 << "requires operands to be " << num_dims << "D tensor";
1810 if (!IsOfRankOrUnranked(op.getResult(), num_dims))
1811 return op.emitOpError()
1812 << "requires result to be " << num_dims << "D tensor";
1813
1814 llvm::Optional<mlir::Location> location = op.getLoc();
1815 ArrayRef<Attribute> strides = op.strides().getValue();
1816 ArrayRef<Attribute> dilations = op.dilations().getValue();
1817 LogicalResult verify_result =
1818 VerifyConvOpAttributes(num_dims, strides, dilations, location);
1819 if (failed(verify_result)) {
1820 return verify_result;
1821 }
1822
1823 return success();
1824 }
1825
LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) {
1827 StringRef src_data_format = this->data_format();
1828
1829 auto perm = GetDataFormatPermutation(src_data_format, data_format);
1830 if (perm.empty()) return failure();
1831
1832 // Update data_format attribute and result types.
1833 if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
1834
1835 // Update convolution attributes.
1836 (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
1837 (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
1838 (*this)->setAttr("explicit_paddings",
1839 ShuffleArrayAttr(explicit_paddings(), perm, 2));
1840
1841 // Permute input sizes operand.
1842 OpBuilder builder(getOperation());
1843 auto input_sizes_permuted = builder.create<TF::DataFormatVecPermuteOp>(
1844 getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format),
1845 StringAttr::get(getContext(), data_format));
1846 setOperand(0, input_sizes_permuted);
1847
1848 return success();
1849 }
1850
StringRef Conv2DBackpropInputOp::GetOptimalLayout(
1852 const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
1855 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
1856 return data_format();
1857
1858 // Filter must be a tensor.
1859 auto filter_ty = filter().getType().dyn_cast<TensorType>();
1860 if (!filter_ty) return data_format();
1861
  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
1864 const bool is_f16 = filter_ty.getElementType().isF16();
1865 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
1866
1867 // Otherwise always use "NCHW".
1868 return "NCHW";
1869 }
1870
1871 //===----------------------------------------------------------------------===//
1872 // Conv3DOp
1873 //===----------------------------------------------------------------------===//
1874
LogicalResult Conv3DOp::inferReturnTypes(
1876 mlir::MLIRContext *context, llvm::Optional<mlir::Location> location,
1877 mlir::ValueRange operands, mlir::DictionaryAttr attributes,
1878 mlir::RegionRange regions,
1879 llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
1880 Conv3DOpAdaptor op(operands, attributes);
1881 ArrayRef<Attribute> explicit_padding;
1882 ArrayAttr explicit_pad =
1883 attributes.get("explicit_paddings").dyn_cast_or_null<::mlir::ArrayAttr>();
1884 if (!explicit_pad) {
1885 explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({});
1886 }
1887 explicit_padding = explicit_pad.getValue();
1888
1889 return inferConvReturnTypes(op, inferredReturnTypes, location,
1890 explicit_padding);
1891 }
1892
1893 //===----------------------------------------------------------------------===//
1894 // DataFormatVecPermuteOp
1895 //===----------------------------------------------------------------------===//
1896
static LogicalResult Verify(DataFormatVecPermuteOp op) {
1898 auto input_ty = op.x().getType().dyn_cast<RankedTensorType>();
1899 if (!input_ty) return success();
1900
1901 int rank = input_ty.getRank();
1902 if (rank != 1 && rank != 2)
1903 return op.emitOpError("requires input of rank 1 or 2");
1904
1905 if (rank == 1) {
1906 int64_t dim0 = input_ty.getDimSize(0);
1907 if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
1908 return op.emitOpError("requires 1D input of size 4 or size 2");
1909 }
1910
1911 if (rank == 2) {
1912 int64_t dim0 = input_ty.getDimSize(0);
1913 if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
1914 return op.emitOpError(
1915 "requires first dimensions of 2D input to be of size 4");
1916
1917 int64_t dim1 = input_ty.getDimSize(1);
1918 if (dim1 != ShapedType::kDynamicSize && dim1 != 2)
1919 return op.emitOpError(
1920 "requires second dimensions of 2D input to be of size 2");
1921 }
1922
1923 return success();
1924 }
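
// Example (added for clarity): inputs of type tensor<4xi32>, tensor<2xi32> or
// tensor<4x2xi32> pass this verifier, while tensor<3xi32> or tensor<4x3xi32>
// inputs are rejected.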
1925
1926 //===----------------------------------------------------------------------===//
1927 // DivOp
1928 //===----------------------------------------------------------------------===//
1929
void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1931 MLIRContext *context) {
1932 results.insert<DivWithSqrtDivisor>(context);
1933 }
1934
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
1936 return IdentityArithmeticOpFolder<DivOp>(*this, operands);
1937 }
1938
1939 //===----------------------------------------------------------------------===//
1940 // DynamicStitchOp
1941 //===----------------------------------------------------------------------===//
1942
static LogicalResult Verify(DynamicStitchOp op) {
1944 if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
1945
1946 if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
1947 if (out_ty.getRank() == 0) {
1948 return op.emitOpError("requires non scalar output");
1949 }
1950 }
1951
1952 llvm::SmallDenseSet<int64_t, 8> index_values;
1953 bool all_indices_const = true;
1954 int32_t max_index = -1;
1955 llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
1956 for (auto it : llvm::zip(op.indices(), op.data())) {
1957 Value index = std::get<0>(it);
1958
1959 DenseIntElementsAttr index_attr;
1960 if (matchPattern(index, m_Constant(&index_attr))) {
1961 for (int32_t index : index_attr.getValues<int32_t>()) {
1962 if (index < 0)
1963 return op.emitOpError()
1964 << "requires non-negative index values; found " << index;
1965 max_index = std::max(index, max_index);
1966 index_values.insert(index);
1967 }
1968 } else {
1969 all_indices_const = false;
1970 }
1971
1972 Value data = std::get<1>(it);
1973 RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
1974 RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
1975 if (!index_ty || !data_ty) continue;
1976
1977 int64_t index_rank = index_ty.getRank();
1978 ArrayRef<int64_t> data_shape = data_ty.getShape();
1979 ArrayRef<int64_t> index_shape = index_ty.getShape();
1980 if (failed(mlir::verifyCompatibleShape(index_shape,
1981 data_shape.take_front(index_rank))))
1982 return op.emitOpError() << "requires shape of data with type " << data_ty
1983 << " to have prefix matching with shape of the "
1984 "corresponding index type "
1985 << index_ty;
1986
1987 ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank);
1988 if (!inferred_item_shape) {
1989 inferred_item_shape = llvm::to_vector<4>(item_shape);
1990 continue;
1991 }
1992
1993 if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape)))
1994 return op.emitOpError() << "has inconsistent shaped data and index "
1995 "pairs; inferred item shapes ["
1996 << llvm::makeArrayRef(*inferred_item_shape)
1997 << "] and [" << item_shape << "] don't match";
1998 for (int i = 0, e = item_shape.size(); i < e; ++i) {
1999 int64_t &inferred_dim = (*inferred_item_shape)[i];
2000 int64_t dim = item_shape[i];
2001 if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim;
2002 }
2003 }
2004
2005 // If all indices are constants, then verify that they cover all indices in
2006 // the range [0, max_index] and the output type is legal.
2007 if (all_indices_const) {
2008 for (int32_t i = 0; i <= max_index; i++) {
2009 if (!index_values.count(i))
2010 return op.emitOpError() << "missing index " << i;
2011 }
2012
2013 if (inferred_item_shape) {
2014 SmallVector<int64_t, 4> expected_shape;
2015 expected_shape.push_back(max_index + 1);
2016 expected_shape.append(inferred_item_shape->begin(),
2017 inferred_item_shape->end());
2018
2019 auto out_ty = op.getType().cast<TensorType>();
2020 auto expected_out_ty =
2021 RankedTensorType::get(expected_shape, out_ty.getElementType());
2022
2023 if (!AreCastCompatible({out_ty, expected_out_ty})) {
2024 return op.emitOpError() << "has invalid output type; should be "
2025 "compatible with inferred type "
2026 << expected_out_ty;
2027 }
2028 }
2029 }
2030
2031 return success();
2032 }
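
// Worked example (a sketch added for clarity): with constant indices [0] and
// [1, 2] and data operands of types tensor<1x5xf32> and tensor<2x5xf32>, the
// inferred item shape is [5], max_index is 2, and the result type must be
// compatible with tensor<3x5xf32>.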
2033
2034 //===----------------------------------------------------------------------===//
2035 // EinsumOp
2036 //===----------------------------------------------------------------------===//
2037
2038 // Verifies that,
2039 // * Arity of the op is at most two.
2040 //
2041 // TODO(hinsu): Verify einsum equation attribute.
static LogicalResult Verify(EinsumOp op) {
2043 if (op.N() > 2) {
2044 return op.emitOpError("supports at most two operands");
2045 }
2046 return success();
2047 }
2048
2049 //===----------------------------------------------------------------------===//
2050 // EmptyOp
2051 //===----------------------------------------------------------------------===//
2052
OpFoldResult EmptyOp::fold(ArrayRef<Attribute> operands) {
2054 assert(operands.size() == 1 && "empty op has one operand");
2055
2056 Attribute attr = operands.front();
2057 if (!attr) return {};
2058
2059 auto int_attr = attr.cast<DenseIntElementsAttr>();
2060 SmallVector<int64_t, 6> out_shape;
2061 for (const auto val : int_attr.getValues<int32_t>()) {
2062 out_shape.push_back(val);
2063 }
2064
2065 auto type = getResult().getType().cast<ShapedType>();
2066 auto etype = type.getElementType();
2067
  // We cannot fold if the result type does not have a static shape.
2069 if (!type.hasStaticShape()) return {};
2070
2071 if (auto float_type = etype.dyn_cast<FloatType>()) {
2072 auto out_type = RankedTensorType::get(out_shape, float_type);
2073 return DenseElementsAttr::get(out_type,
2074 {APFloat(float_type.getFloatSemantics())});
2075 }
2076
2077 if (auto int_type = etype.dyn_cast<IntegerType>()) {
2078 auto out_type = RankedTensorType::get(out_shape, etype);
2079 APInt val(int_type.getWidth(), 0, int_type.getSignedness());
2080 return DenseElementsAttr::get(out_type, val);
2081 }
2082
2083 return {};
2084 }
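
// Example (added for clarity): a tf.Empty whose shape operand folds to the
// constant [2, 3] and whose result type is the static tensor<2x3xf32> folds to
// a splat DenseElementsAttr of 0.0 with that type.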
2085
2086 //===----------------------------------------------------------------------===//
2087 // EmptyTensorListOp
2088 //===----------------------------------------------------------------------===//
2089
static LogicalResult Verify(EmptyTensorListOp op) {
2091 if (!IsOfRankOrUnranked(op.element_shape(), 0) &&
2092 !IsOfRankOrUnranked(op.element_shape(), 1)) {
2093 return op.emitOpError("requires element_shape operand to be 0D/1D tensor");
2094 }
2095
2096 if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) {
2097 return op.emitOpError("requires max_num_elements operand to be 0D tensor");
2098 }
2099 return success();
2100 }
2101
2102 //===----------------------------------------------------------------------===//
2103 // EqualOp
2104 //===----------------------------------------------------------------------===//
2105
static LogicalResult Verify(EqualOp op) {
  // If inputs are allowed to have incompatible types, there is nothing to do.
2108 if (!op.incompatible_shape_error()) return success();
2109
2110 // Otherwise, check inputs are broadcastable.
2111 return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
2112 op.getOperation());
2113 }
2114
void EqualOp::build(OpBuilder &builder, OperationState &result, Value x,
2116 Value y, BoolAttr incompatible_shape_error) {
2117 auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y,
2118 incompatible_shape_error);
2119 return build(builder, result, result_type, x, y, incompatible_shape_error);
2120 }
2121
2122 //===----------------------------------------------------------------------===//
2123 // ExpandDimsOp
2124 //===----------------------------------------------------------------------===//
2125
Type InferExpandDimsOpType(Value input, Value dim) {
2127 Type element_ty = input.getType().cast<TensorType>().getElementType();
2128 auto unranked_ty = UnrankedTensorType::get(element_ty);
2129
2130 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
2131 if (!input_ty) return unranked_ty;
2132
2133 DenseIntElementsAttr dim_attr;
2134 if (!matchPattern(dim, m_Constant(&dim_attr)) ||
2135 dim_attr.getNumElements() != 1)
2136 return unranked_ty;
2137 int64_t dim_val = (*dim_attr.begin()).getSExtValue();
2138 int64_t input_rank = input_ty.getRank();
2139
2140 if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty;
2141 if (dim_val < 0) dim_val += input_rank + 1;
2142
2143 SmallVector<int64_t, 4> shape = llvm::to_vector<4>(input_ty.getShape());
2144 shape.insert(shape.begin() + dim_val, 1);
2145 return RankedTensorType::get(shape, element_ty);
2146 }
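
// Example (added for clarity): for an input of type tensor<2x3xf32> and a
// constant dim of -1, the inferred type is tensor<2x3x1xf32>; a non-constant
// dim or a dim outside [-3, 3] yields an unranked tensor type instead.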
2147
void ExpandDimsOp::build(OpBuilder &builder, OperationState &result,
2149 Value input, Value dim) {
2150 return build(builder, result, InferExpandDimsOpType(input, dim), input, dim);
2151 }
2152
2153 //===----------------------------------------------------------------------===//
2154 // FakeQuantWithMinMaxArgsOp
2155 //===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) {
  // TODO(fengliuai): move the following to a utility method.
2158 const llvm::fltSemantics &semantics = op.min().getSemantics();
2159 float rmin, rmax;
2160 if (&semantics == &APFloat::IEEEsingle()) {
2161 rmin = op.min().convertToFloat();
2162 rmax = op.max().convertToFloat();
2163 } else {
2164 rmin = op.min().convertToDouble();
2165 rmax = op.max().convertToDouble();
2166 }
2167 // Range boundaries must be valid.
2168 if (rmin >= rmax) {
2169 return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) +
2170 "," + Twine(std::to_string(rmax)) + "]");
2171 }
2172 int64_t num_bits = op.num_bits();
2173 if (num_bits < 2 || num_bits > 16) {
2174 return op.emitOpError(
2175 "requires num_bits to be between 2 and 16, inclusive");
2176 }
2177 return success();
2178 }
2179
2180 //===----------------------------------------------------------------------===//
2181 // FakeQuantWithMinMaxVarsOp
2182 //===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) {
2184 auto min = GetRankedTensorTypeForOperand(op.min());
2185 if (min && !IsOfRankedFloatTensorType(min, 0))
2186 return op.emitOpError("requires min to be a 0d float tensor");
2187
2188 auto max = GetRankedTensorTypeForOperand(op.max());
2189 if (max && !IsOfRankedFloatTensorType(max, 0))
2190 return op.emitOpError("requires max to be a 0d float tensor");
2191
2192 int64_t num_bits = op.num_bits();
2193 if (num_bits < 2 || num_bits > 16) {
2194 return op.emitOpError(
2195 "requires num_bits to be between 2 and 16, inclusive");
2196 }
2197 return success();
2198 }
2199
2200 //===----------------------------------------------------------------------===//
2201 // FakeQuantWithMinMaxVarsPerChannelOp
2202 //===----------------------------------------------------------------------===//
static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
2204 auto min = GetRankedTensorTypeForOperand(op.min());
2205 if (min && !IsOfRankedFloatTensorType(min, 1))
2206 return op.emitOpError("requires min to be a 1d float tensor");
2207
2208 auto max = GetRankedTensorTypeForOperand(op.max());
2209 if (max && !IsOfRankedFloatTensorType(max, 1))
2210 return op.emitOpError("requires max to be a 1d float tensor");
2211
2212 Value inputs = op.inputs();
2213 if (!HasRankAtLeast(inputs, 1))
2214 return op.emitError("requires inputs to be at least 1d float tensor");
2215
2216 int64_t num_bits = op.num_bits();
2217 if (num_bits < 2 || num_bits > 16) {
2218 return op.emitOpError(
2219 "requires num_bits to be between 2 and 16, inclusive");
2220 }
2221
2222 auto inputs_type = inputs.getType().dyn_cast<RankedTensorType>();
2223 if (!inputs_type) return success();
2224 int depth = inputs_type.getDimSize(inputs_type.getRank() - 1);
2225 if ((min && min.getDimSize(0) != depth) ||
2226 (max && max.getDimSize(0) != depth)) {
2227 return op.emitOpError(
2228 "requires min and max to have same size as last dimension of inputs");
2229 }
2230
2231 return success();
2232 }
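
// Example (added for clarity): for inputs of type tensor<8x16x4xf32>, both min
// and max must be 1D float tensors of size 4 (the last dimension of inputs);
// tensor<3xf32> operands would be rejected.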
2233
2234 //===----------------------------------------------------------------------===//
2235 // FillOp
2236 //===----------------------------------------------------------------------===//
2237
static LogicalResult Verify(FillOp op) {
2239 if (!IsOfRankOrUnranked(op.dims(), 1))
2240 return op.emitOpError() << "requires dims to be a 1D tensor";
2241 if (!IsOfRankOrUnranked(op.value(), 0))
2242 return op.emitOpError() << "requires value to be a scalar";
2243
2244 return success();
2245 }
2246
static ShapedType InferFillOpType(Value dims, Value value) {
2248 Type etype = value.getType().cast<ShapedType>().getElementType();
2249
2250 DenseIntElementsAttr dims_attr;
2251 if (!matchPattern(dims, m_Constant(&dims_attr))) {
2252 return UnrankedTensorType::get(etype);
2253 }
2254
2255 llvm::SmallVector<int64_t, 4> shape;
2256 shape.reserve(dims_attr.getNumElements());
2257 for (const APInt dim : dims_attr.getValues<APInt>()) {
2258 shape.push_back(dim.getSExtValue());
2259 }
2260 return RankedTensorType::get(shape, etype);
2261 }
2262
void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
2264 Value value) {
2265 FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
2266 }
2267
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "fill op has two operands");
2270
2271 auto type = getType().cast<ShapedType>();
  // The DenseElementsAttr used in this folder only supports int and float
  // types.
  // TODO(hinsu): Handle complex types once there is an attribute kind for
  // complex.
2276 if (!type.getElementType().isIntOrFloat()) return {};
2277
2278 auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
2279 if (!value) return {};
2280
2281 if (type.hasStaticShape())
2282 return DenseElementsAttr::get(type, value.getValue({}));
2283
2284 auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
2285 if (!dims) return {};
2286
2287 llvm::SmallVector<int64_t, 4> shape;
2288 shape.reserve(dims.getNumElements());
2289 for (const APInt dim : dims.getValues<APInt>()) {
2290 shape.push_back(dim.getSExtValue());
2291 }
2292 type = RankedTensorType::get(shape, type.getElementType());
2293
2294 return DenseElementsAttr::get(type, value.getValue({}));
2295 }
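
// Example (added for clarity): a tf.Fill whose dims operand folds to [2, 3]
// and whose value operand folds to a scalar 9.0 : f32, with result type
// tensor<2x3xf32>, folds to a splat attribute filled with 9.0.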
2296
2297 //===----------------------------------------------------------------------===//
2298 // FusedBatchNormGradOp
2299 //===----------------------------------------------------------------------===//
2300
2301 // TODO(b/150954845): Add benchmarks to verify that layout preference didn't
2302 // change in the latest GPU generations.
2303
LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) {
2305 return ::mlir::TF::UpdateDataFormat(data_format, this);
2306 }
2307
StringRef FusedBatchNormGradV3Op::GetOptimalLayout(
2309 const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
2312 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
2313 return data_format();
2314
  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
2317 auto x_ty = x().getType().cast<TensorType>();
2318 const bool is_f16 = x_ty.getElementType().isF16();
2319 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2320
2321 // For all other data types prefer NCHW.
2322 return "NCHW";
2323 }
2324
2325 //===----------------------------------------------------------------------===//
2326 // FusedBatchNormOp
2327 //===----------------------------------------------------------------------===//
2328
static LogicalResult Verify(FusedBatchNormOp op) {
2330 auto x = GetRankedTensorTypeForOperand(op.x());
2331 if (x && !IsOfRankedFloatTensorType(x, 4))
2332 return op.emitOpError("requires x to be a 4D float tensor");
2333
2334 auto scale = GetRankedTensorTypeForOperand(op.scale());
2335 if (scale && !IsOfRankedFloatTensorType(scale, 1))
2336 return op.emitOpError("requires scale to be a 1D float tensor");
2337
2338 auto offset = GetRankedTensorTypeForOperand(op.offset());
2339 if (offset && !IsOfRankedFloatTensorType(offset, 1))
2340 return op.emitOpError("requires offset to be a 1D float tensor");
2341
2342 auto mean = GetRankedTensorTypeForOperand(op.mean());
2343 if (mean && !IsOfRankedFloatTensorType(mean, 1))
2344 return op.emitOpError("requires mean to be a 1D float tensor");
2345
2346 auto variance = GetRankedTensorTypeForOperand(op.variance());
2347 if (variance && !IsOfRankedFloatTensorType(variance, 1))
2348 return op.emitOpError("requires variance to be a 1D float tensor");
2349
2350 // TODO(antiagainst): check attributes
2351
2352 return success();
2353 }
2354
2355 //===----------------------------------------------------------------------===//
2356 // FusedBatchNormV2Op / FusedBatchNormV3Op
2357 //===----------------------------------------------------------------------===//
2358
2359 template <class Op>
static LogicalResult InferenceFoldOperandsPermutation(
2361 ArrayRef<int64_t> permutation, Op *op) {
  // FusedBatchNorm in training mode is a layout-sensitive operation, and
  // should already have an optimal data format assigned.
2364 if (op->is_training()) return failure();
2365 return ::mlir::TF::FoldOperandsPermutation(permutation, op);
2366 }
2367
2368 template <class Op>
static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
2370 // In inference mode FusedBatchNorm is not sensitive to data layout.
2371 if (!op->is_training()) return op->data_format();
2372
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using the GPU for this operation.
2375 if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
2376 return op->data_format();
2377
  // For the f16 data type, the NHWC data format is up to ~2x faster on devices
  // with Tensor Core support.
2380 auto x_ty = op->x().getType().template cast<TensorType>();
2381 const bool is_f16 = x_ty.getElementType().isF16();
2382 if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
2383
2384 // For all other data types prefer NCHW.
2385 return "NCHW";
2386 }
2387
LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation(
2389 ArrayRef<int64_t> permutation) {
2390 return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2391 }
2392
LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) {
2394 return ::mlir::TF::UpdateDataFormat(data_format, this);
2395 }
2396
StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) {
2398 return ::mlir::TF::GetOptimalLayout(devices, this);
2399 }
2400
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
2402 ArrayRef<int64_t> permutation) {
2403 return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this);
2404 }
2405
LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) {
2407 return ::mlir::TF::UpdateDataFormat(data_format, this);
2408 }
2409
StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) {
2411 return ::mlir::TF::GetOptimalLayout(devices, this);
2412 }
2413
2414 //===----------------------------------------------------------------------===//
2415 // GatherV2Op
2416 //===----------------------------------------------------------------------===//
2417
static LogicalResult Verify(GatherV2Op op) {
2419 int64_t batch_dims = op.batch_dims();
2420 if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
2421 int64_t rank = ty.getRank();
2422 if (batch_dims > rank || batch_dims < -rank)
2423 return op.emitOpError()
2424 << "batch_dims (" << batch_dims << ") must be in range [" << -rank
2425 << ", " << rank + 1 << ")";
2426 if (batch_dims < 0) batch_dims += rank;
2427 }
2428
2429 if (!HasRankAtMost(op.axis(), 1))
2430 return op.emitOpError("requires axis to have rank at most 1");
2431
2432 DenseIntElementsAttr axis_attr;
2433 if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
2434 int64_t axis = (*axis_attr.begin()).getSExtValue();
2435 if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
2436 int64_t rank = ty.getRank();
2437 if (axis >= rank || axis < -rank)
2438 return op.emitOpError() << "axis (" << axis << ") must be in range ["
2439 << -rank << ", " << rank << ")";
2440 if (axis < 0) axis += rank;
2441 }
2442
2443 if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) {
2444 return op.emitOpError() << "requires axis (" << axis
2445 << ") to be greater than or equal to batch_dims ("
2446 << batch_dims << ")";
2447 }
2448 }
2449 return success();
2450 }
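
// Example (added for clarity): with params of rank 3, indices of rank 2 and
// batch_dims = 1, a constant axis of 1 or 2 verifies, while axis = 0 is
// rejected because it is smaller than batch_dims.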
2451
2452 //===----------------------------------------------------------------------===//
2453 // IfOp
2454 //===----------------------------------------------------------------------===//
2455
static LogicalResult Verify(IfOp op) {
2457 auto branch_name = [](unsigned index) -> std::string {
2458 return index == 0 ? "'then_branch'" : "'else_branch'";
2459 };
2460 return VerifyCaseOrIfOpBranchFunctions(
2461 op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name);
2462 }
2463
2464 //===----------------------------------------------------------------------===//
2465 // IfOp canonicalization.
2466 //===----------------------------------------------------------------------===//
2467
2468 namespace {
2469 class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
2470 public:
  explicit FoldConstantIfOp(MLIRContext *context)
2472 : OpRewritePattern<TF::IfOp>(context) {}
2473 LogicalResult matchAndRewrite(TF::IfOp op,
2474 PatternRewriter &rewriter) const override;
2475
2476 private:
2477 template <typename T>
2478 struct CallOpType {
2479 using CallOp = T;
2480 };
2481 };
2482
LogicalResult FoldConstantIfOp::matchAndRewrite(
2484 TF::IfOp op, PatternRewriter &rewriter) const {
2485 // Extract the constant cond value.
2486 DenseIntElementsAttr cond_attr;
2487 if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2488
2489 // Cond value must be a scalar.
2490 if (cond_attr.getNumElements() != 1) return failure();
2491
2492 // Select a branch function.
2493 bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
2494 FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr();
2495
2496 // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp.
2497 auto rewrite = [&](auto op_type) {
2498 auto empty = rewriter.getStringAttr("");
2499 auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
2500 op.getLoc(), op.getResultTypes(), op.input(), func,
2501 /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
2502 CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
2503 rewriter.replaceOp(op, call_op.getResults());
2504 };
2505
2506 if (op.is_stateless())
2507 rewrite(CallOpType<PartitionedCallOp>{});
2508 else
2509 rewrite(CallOpType<StatefulPartitionedCallOp>{});
2510
2511 return success();
2512 }
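
// Illustrative example (a sketch, not from the original source): a stateless
// tf.If whose cond operand folds to the scalar constant true is rewritten into
// a tf.PartitionedCall of its then_branch function, with device and
// underscored attributes copied over; a stateful tf.If becomes a
// tf.StatefulPartitionedCall instead.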
2513 } // anonymous namespace
2514
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2516 MLIRContext *context) {
2517 results.insert<FoldConstantIfOp, DropAttributes<IfOp>>(context);
2518 }
2519
2520 //===----------------------------------------------------------------------===//
2521 // IfRegionOp
2522 //===----------------------------------------------------------------------===//
2523
static LogicalResult Verify(IfRegionOp op) {
2525 TypeRange then_types =
2526 op.then_branch().front().getTerminator()->getOperandTypes();
2527 TypeRange else_types =
2528 op.else_branch().front().getTerminator()->getOperandTypes();
2529
2530 TypeRangeWithDesc results{op.getResultTypes(), "result"};
2531 TypeRangeWithDesc then_results{then_types, "then result"};
2532 TypeRangeWithDesc else_results{else_types, "else result"};
2533
2534 if (failed(VerifyTypeRangesAreCompatible(op, then_results, results)))
2535 return failure();
2536 if (failed(VerifyTypeRangesAreCompatible(op, else_results, results)))
2537 return failure();
2538 return success();
2539 }
2540
2541 namespace {
2542 class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
2543 public:
  explicit FoldConstantIfRegionOp(MLIRContext *context)
2545 : OpRewritePattern<TF::IfRegionOp>(context) {}
2546 LogicalResult matchAndRewrite(TF::IfRegionOp op,
2547 PatternRewriter &rewriter) const override;
2548 };
2549
LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
2551 TF::IfRegionOp op, PatternRewriter &rewriter) const {
2552 // Extract the constant cond value.
2553 DenseIntElementsAttr cond_attr;
2554 if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
2555
2556 // IfRegion condition should always be a scalar. Select the region to fold to.
2557 bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
  Region &region = cond ? op.then_branch() : op.else_branch();
2559
2560 // If the IfRegion is stateless but the region being inlined itself is not
2561 // stateless, then inlining the region could cause a loss of information.
  // However, it's probably better to fold the IfRegion instead of having the
  // dead branch stay.
2564
2565 // Inline the region in place of the IfRegion op, and forward the yield
2566 // inputs to the IfRegion op results. This is possible only if the yield
2567 // types match the result types.
2568 auto yield = cast<YieldOp>(region.front().getTerminator());
2569 auto updated_results = llvm::to_vector<4>(yield.getOperands());
2570
2571 // If the yield types do not match the IfRegion result types, add appropriate
2572 // casts.
2573 rewriter.setInsertionPoint(yield);
2574 for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
2575 auto &updated_result = std::get<1>(it);
2576 Type result_type = std::get<0>(it);
2577 if (result_type != updated_result.getType()) {
2578 updated_result =
2579 rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
2580 /*Truncate=*/rewriter.getBoolAttr(false));
2581 }
2582 }
2583 // Inline the region into the block containing the IfRegion.
  rewriter.mergeBlockBefore(&region.front(), op);
2585 rewriter.eraseOp(yield);
2586 rewriter.replaceOp(op, updated_results);
2587 return success();
2588 }
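
// Illustrative example (a sketch): a tf.IfRegion whose cond folds to false has
// its else_branch region inlined in place of the op; where a yielded value's
// type differs from the corresponding result type, a non-truncating tf.Cast is
// inserted before forwarding it.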
2589 } // anonymous namespace
2590
void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2592 MLIRContext *context) {
2593 results.insert<FoldConstantIfRegionOp,
2594 CaseOrIfRegionEliminatePassThrough<TF::IfRegionOp>>(context);
2595 }
2596
2597 //===----------------------------------------------------------------------===//
2598 // InvertPermutationOp
2599 //===----------------------------------------------------------------------===//
2600
2601 // Verifies that the input is 1D.
static LogicalResult Verify(InvertPermutationOp op) {
2603 auto x_type = op.x().getType().cast<TensorType>();
2604 if (!x_type.hasRank()) return success();
2605 if (x_type.getShape().size() != 1)
2606 return op.emitOpError() << "requires input x to be 1-dimensional";
2607
2608 return success();
2609 }
2610
2611 //===----------------------------------------------------------------------===//
2612 // LeakyReluOp
2613 //===----------------------------------------------------------------------===//
2614
OpFoldResult LeakyReluOp::fold(ArrayRef<Attribute> operands) {
2616 assert(operands.size() == 1 && "leaky relu has one operand");
2617
2618 // leaky_relu(x, alpha: 1) -> x
2619 if (alpha().convertToFloat() == 1.0f) return getOperand();
2620
2621 auto calculate = [&](FloatAttr arg) {
2622 APFloat val = arg.getValue();
2623 if (val.isNegative()) val = alpha() * val;
2624 return FloatAttr::get(arg.getType(), val);
2625 };
2626
2627 if (auto arg = operands[0].dyn_cast_or_null<FloatAttr>()) {
2628 return calculate(arg);
2629 } else if (auto arg = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
2630 if (auto elementAttr = arg.getSplatValue().dyn_cast<FloatAttr>())
2631 return DenseElementsAttr::get(arg.getType(), calculate(elementAttr));
2632 }
2633 return {};
2634 }
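
// Examples (added for clarity): leaky_relu(x, alpha = 1.0) folds to x, and a
// splat constant input of -2.0 with alpha = 0.5 folds to a splat of -1.0;
// non-negative constant inputs fold to themselves.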
2635
Optional<ContractionFusion> LeakyReluOp::GetContractionFusion() {
2637 // Only f32 is supported for fusion.
2638 if (!T().isF32()) return None;
2639
2640 NamedAttribute alpha(Identifier::get("alpha", getContext()), alphaAttr());
2641 return ContractionFusion("LeakyRelu", /*additional_arguments=*/{},
2642 /*additional_attributes=*/{alpha});
2643 }
2644
2645 //===----------------------------------------------------------------------===//
2646 // LogOp
2647 //===----------------------------------------------------------------------===//
2648
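// Added commentary (sketch of the registered patterns, which are defined in
// the generated canonicalization include): LogOfSoftmax rewrites
// log(tf.Softmax(x)) to tf.LogSoftmax(x), and LogToLog1p rewrites
// log(1 + x) to tf.Log1p(x).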
void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  results.insert<LogOfSoftmax, LogToLog1p>(context);
}

//===----------------------------------------------------------------------===//
// LogicalNotOp
//===----------------------------------------------------------------------===//

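// Added commentary (sketch): these patterns fold tf.LogicalNot of a comparison
// into the complementary comparison, e.g. !(a == b) -> tf.NotEqual(a, b) and
// !(a < b) -> tf.GreaterEqual(a, b).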
void LogicalNotOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<LogicalNotOfEqual, LogicalNotOfNotEqual, LogicalNotOfGreater,
                 LogicalNotOfGreaterEqual, LogicalNotOfLess,
                 LogicalNotOfLessEqual>(context);
}

//===----------------------------------------------------------------------===//
// MatrixBandPartOp
//===----------------------------------------------------------------------===//

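// Verifies that `input` has rank at least 2 and that the band bounds
// `num_lower` and `num_upper` are scalars (added descriptive comment).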
static LogicalResult Verify(MatrixBandPartOp op) {
  if (!HasRankAtLeast(op.input(), 2)) {
    return op.emitOpError()
           << "requires `input` to have rank of at least 2, but found "
           << op.input().getType();
  }
  if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
    return op.emitOpError()
           << "requires `num_lower` to have 0 dimensions, but found "
           << op.num_lower().getType();
  }
  if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
    return op.emitOpError()
           << "requires `num_upper` to have 0 dimensions, but found "
           << op.num_upper().getType();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MatrixSetDiagOp
//===----------------------------------------------------------------------===//
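// Added note (sketch): MatrixSetDiagToV3 upgrades tf.MatrixSetDiag to the more
// general tf.MatrixSetDiagV3 form, which subsumes it.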
void MatrixSetDiagOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MatrixSetDiagToV3>(context);
}

//===----------------------------------------------------------------------===//
// MatrixSetDiagV2Op
//===----------------------------------------------------------------------===//

void MatrixSetDiagV2Op::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<MatrixSetDiagV2ToV3>(context);
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

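// Builds a Max op with an inferred result type. Illustrative example (added):
// an input of tensor<4x8xf32> reduced over indices {1} with keep_dims = false
// yields a tensor<4xf32> result.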
void MaxOp::build(OpBuilder &builder, OperationState &result, Value input,
                  Value reduction_indices, BoolAttr keep_dims) {
  Type out_ty =
      InferReductionOpType(input, reduction_indices, keep_dims, &builder);
  build(builder, result, out_ty, input, reduction_indices, keep_dims);
}

//===----------------------------------------------------------------------===//
// MaxPoolOp
//===----------------------------------------------------------------------===//

LogicalResult MaxPoolOp::FoldOperandsPermutation(
    ArrayRef<int64_t> permutation) {
  return ::mlir::TF::FoldOperandsPermutation(
      permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
}

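// Added example (illustrative, assuming the NHWC -> NCHW permutation is
// {0, 3, 1, 2}): switching data_format from NHWC to NCHW permutes the window
// attributes accordingly, e.g. ksize [1, 3, 3, 1] becomes [1, 1, 3, 3].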
LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
  StringRef src_data_format = data_format();

  auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
  if (perm.empty()) return failure();

  // Update data_format attribute and result types.
  if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
    return failure();

  stridesAttr(ShuffleArrayAttr(strides(), perm));
  explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
  ksizeAttr(ShuffleArrayAttr(ksize(), perm));

  return success();
}

StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
  // Keep the current data format if no GPUs are available or if explicit
  // placement does not allow using a GPU for this operation.
  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
    return data_format();

  // Otherwise default to NCHW, the layout generally preferred by GPU kernels.
  return "NCHW";
}

//===----------------------------------------------------------------------===//
// MaxPoolGradOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(MaxPoolGradOp op) {
  if (!IsOfRankOrUnranked(op.orig_input(), 4)) {
    return op.emitOpError() << "requires orig_input to be rank 4";
  }
  if (!IsOfRankOrUnranked(op.orig_output(), 4)) {
    return op.emitOpError() << "requires orig_output to be rank 4";
  }
  if (!IsOfRankOrUnranked(op.grad(), 4)) {
    return op.emitOpError() << "requires grad to be rank 4";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

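// Added example (a trace of the index remapping below, not a layout claim):
// with permutation = {0, 3, 1, 2}, reduction indices {1, 2} are rewritten to
// {permutation[1], permutation[2]} = {3, 1}.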
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Reduction indices must be defined by a constant operation.
  auto reduction_op =
      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
  if (!reduction_op) return failure();

  auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
  if (!reductions_value) return failure();

  // Prepare new reduction indices according to operand permutation.
  SmallVector<int32_t, 4> shuffled_reduction;
  llvm::transform(reductions_value.getIntValues(),
                  std::back_inserter(shuffled_reduction),
                  [&](APInt idx) { return permutation[idx.getSExtValue()]; });

  // Add a constant operation with the new reduction indices.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
                                          builder.getIntegerType(32));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
  auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new reduction indices.
  setOperand(1, shuffled_reduction_op);

  return success();
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

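// Added note (sketch): IdentityArithmeticOpFolder folds multiplications by the
// multiplicative identity, e.g. x * 1 -> x when a constant splat of ones
// appears on either side and the types allow it.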
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<MulOp>(*this, operands);
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc"