/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <memory>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31
32 namespace mlir {
33 namespace {
34
35 class ConvertResultsBroadcastableShapeOp : public RewritePattern {
36 public:
ConvertResultsBroadcastableShapeOp()37 ConvertResultsBroadcastableShapeOp()
38 : RewritePattern(1, MatchAnyOpTypeTag()) {}
39
40 LogicalResult matchAndRewrite(Operation* op,
41 PatternRewriter& rewriter) const override;
42
43 private:
44 template <typename Op>
45 LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
46
47 LogicalResult RewriteOp(
48 Operation* op, PatternRewriter& rewriter,
49 const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
50 SmallVectorImpl<int64_t>&)>&
51 get_broadcasted_shape) const;
52
53 LogicalResult RewriteBatchMatMulV2Op(Operation* op,
54 PatternRewriter& rewriter) const;
55 };
56
57 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
58 public:
59 void runOnFunction() override;
60 };
61
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const62 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
63 Operation* op, PatternRewriter& rewriter) const {
64 if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
65 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
66
67 // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
68 // incompatible_shape_error is `true` (what is also checked by the verifier).
69 if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
70 if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
71 if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
72
73 return failure();
74 }
75
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const76 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
77 Operation* op, PatternRewriter& rewriter) const {
78 auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
79 if (!matmul_op) return failure();
80
81 // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
82 // shape of op's first/left-hand-side operand and `shape_y` is the shape of
83 // op's second/right-hand-side operand.
84 const auto get_broadcasted_shape =
85 [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
86 SmallVectorImpl<int64_t>& result_shape) {
87 if (shape_x.size() < 2 || shape_y.size() < 2) {
88 return false;
89 }
90
91 // Checks outer dimensions (i.e., the dimensions higher than 2D) are
92 // broadcastable. If true, then get the broadcasted shape for outer
93 // dimension.
94 if (!OpTrait::util::getBroadcastedShape(
95 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
96 return false;
97 }
98
99 const int x_row =
100 matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
101 const int x_col =
102 !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
103
104 const int y_row =
105 matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
106 const int y_col =
107 !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
108
109 // Checks that matrix multiply can perform a valid contraction.
110 if (x_col != y_row) {
111 result_shape.clear();
112 return false;
113 }
114
115 result_shape.push_back(x_row);
116 result_shape.push_back(y_col);
117 return true;
118 };
119
120 return RewriteOp(op, rewriter, get_broadcasted_shape);
121 }
122
123 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const124 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
125 Operation* op, PatternRewriter& rewriter) const {
126 auto eq_op = llvm::dyn_cast_or_null<Op>(op);
127 if (eq_op && eq_op.incompatible_shape_error())
128 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
129 return failure();
130 }
131
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const132 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
133 Operation* op, PatternRewriter& rewriter,
134 const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
135 SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
136 const {
137 if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
138 return failure();
139
140 // Check that the result shape is fully defined.
141 auto result_type =
142 op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
143 if (!result_type || !result_type.hasStaticShape()) return failure();
144
145 bool changed = false;
146 for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
147 // Check that the i'th operand is a broadcast.
148 auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
149 op->getOpOperand(i).get().getDefiningOp());
150 if (!broadcast) continue;
151
152 // Check that the operand of the broadcast has fully defined shape.
153 auto broadcast_arg_type =
154 broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
155 if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
156
157 // Check that the other argument has fully defined shape.
158 auto argument_type = op->getOpOperand(1 - i)
159 .get()
160 .getType()
161 .dyn_cast_or_null<RankedTensorType>();
162 if (!argument_type || !argument_type.hasStaticShape()) continue;
163
164 // Get the unbroadcasted shapes in the operand order.
165 std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
166 operand_shapes[i] = broadcast_arg_type.getShape();
167 operand_shapes[1 - i] = argument_type.getShape();
168
169 // Check that the input of the broadcast and the other operand is broadcast
170 // compatible.
171 llvm::SmallVector<int64_t, 4> broadcasted_shape;
172 if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
173 broadcasted_shape))
174 continue;
175
176 // Check that an implicit broadcast between the operand of the broadcast and
177 // the other argument would result in the same type as the result type.
178 if (broadcasted_shape != result_type.getShape()) continue;
179
180 // Update the operand of the op to be the operand of the broadcast.
181 rewriter.updateRootInPlace(
182 op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
183 changed = true;
184 }
185 return success(changed);
186 }
187
runOnFunction()188 void BroadcastFoldPass::runOnFunction() {
189 OwningRewritePatternList patterns;
190 auto func = getFunction();
191
192 patterns.insert<ConvertResultsBroadcastableShapeOp>();
193 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
194 }
195
196 } // namespace
197
198 namespace TF {
CreateBroadcastFoldPass()199 std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
200 return absl::make_unique<BroadcastFoldPass>();
201 }
202 } // namespace TF
203
204 static PassRegistration<BroadcastFoldPass> pass(
205 "tf-broadcast-fold",
206 "Fold explicit broadcasts into the following operations if they support "
207 "implicit broadcasting on their operand.");
208
209 } // namespace mlir
210