/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <memory>
17
18 #include "absl/memory/memory.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Traits.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/PatternMatch.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31
32 namespace mlir {
33 namespace {
34
// Rewrite pattern that folds an explicit tf.BroadcastTo feeding an operand of
// a binary op into an implicit broadcast performed by that op itself, when the
// op supports implicit broadcasting and the resulting type is unchanged.
// Matches any op (MatchAnyOpTypeTag) and filters inside matchAndRewrite.
class ConvertResultsBroadcastableShapeOp : public RewritePattern {
 public:
  ConvertResultsBroadcastableShapeOp(MLIRContext* context)
      : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}

  LogicalResult matchAndRewrite(Operation* op,
                                PatternRewriter& rewriter) const override;

 private:
  // Handles tf.Equal / tf.NotEqual (the template parameter `Op`), which only
  // broadcast implicitly when incompatible_shape_error is true.
  template <typename Op>
  LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;

  // Shared rewrite: drops a tf.BroadcastTo on either operand of the binary
  // `op` when `get_broadcasted_shape` proves the implicit broadcast of the
  // un-broadcasted operands yields exactly the op's (static) result shape.
  LogicalResult RewriteOp(
      Operation* op, PatternRewriter& rewriter,
      const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
                               SmallVectorImpl<int64_t>&)>&
          get_broadcasted_shape) const;

  // Handles tf.BatchMatMulV2, whose broadcasting applies only to the outer
  // (batch) dimensions while the two innermost dims follow matmul rules.
  LogicalResult RewriteBatchMatMulV2Op(Operation* op,
                                       PatternRewriter& rewriter) const;
};
56
// Function pass that greedily applies ConvertResultsBroadcastableShapeOp over
// a function, folding explicit tf.BroadcastTo ops into consumers that support
// implicit broadcasting. Registered as `tf-broadcast-fold`.
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
 public:
  StringRef getArgument() const final { return "tf-broadcast-fold"; }

  StringRef getDescription() const final {
    return "Fold explicit broadcasts into the following operations if they "
           "support implicit broadcasting on their operand.";
  }

  void runOnFunction() override;
};
68
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const69 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
70 Operation* op, PatternRewriter& rewriter) const {
71 if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
72 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
73
74 // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
75 // incompatible_shape_error is `true` (what is also checked by the verifier).
76 if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
77 if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
78 if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
79
80 return failure();
81 }
82
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const83 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
84 Operation* op, PatternRewriter& rewriter) const {
85 auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
86 if (!matmul_op) return failure();
87
88 // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
89 // shape of op's first/left-hand-side operand and `shape_y` is the shape of
90 // op's second/right-hand-side operand.
91 const auto get_broadcasted_shape =
92 [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
93 SmallVectorImpl<int64_t>& result_shape) {
94 if (shape_x.size() < 2 || shape_y.size() < 2) {
95 return false;
96 }
97
98 // Checks outer dimensions (i.e., the dimensions higher than 2D) are
99 // broadcastable. If true, then get the broadcasted shape for outer
100 // dimension.
101 if (!OpTrait::util::getBroadcastedShape(
102 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
103 return false;
104 }
105
106 const int x_row =
107 matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
108 const int x_col =
109 !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
110
111 const int y_row =
112 matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
113 const int y_col =
114 !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
115
116 // Checks that matrix multiply can perform a valid contraction.
117 if (x_col != y_row) {
118 result_shape.clear();
119 return false;
120 }
121
122 result_shape.push_back(x_row);
123 result_shape.push_back(y_col);
124 return true;
125 };
126
127 return RewriteOp(op, rewriter, get_broadcasted_shape);
128 }
129
130 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const131 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
132 Operation* op, PatternRewriter& rewriter) const {
133 auto eq_op = llvm::dyn_cast_or_null<Op>(op);
134 if (eq_op && eq_op.incompatible_shape_error())
135 return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
136 return failure();
137 }
138
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const139 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
140 Operation* op, PatternRewriter& rewriter,
141 const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
142 SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
143 const {
144 if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
145 return failure();
146
147 // Check that the result shape is fully defined.
148 auto result_type =
149 op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
150 if (!result_type || !result_type.hasStaticShape()) return failure();
151
152 bool changed = false;
153 for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
154 // Check that the i'th operand is a broadcast.
155 auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
156 op->getOpOperand(i).get().getDefiningOp());
157 if (!broadcast) continue;
158
159 // Check that the operand of the broadcast has fully defined shape.
160 auto broadcast_arg_type =
161 broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
162 if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
163
164 // Check that the other argument has fully defined shape.
165 auto argument_type = op->getOpOperand(1 - i)
166 .get()
167 .getType()
168 .dyn_cast_or_null<RankedTensorType>();
169 if (!argument_type || !argument_type.hasStaticShape()) continue;
170
171 // Get the unbroadcasted shapes in the operand order.
172 std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
173 operand_shapes[i] = broadcast_arg_type.getShape();
174 operand_shapes[1 - i] = argument_type.getShape();
175
176 // Check that the input of the broadcast and the other operand is broadcast
177 // compatible.
178 llvm::SmallVector<int64_t, 4> broadcasted_shape;
179 if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
180 broadcasted_shape))
181 continue;
182
183 // Check that an implicit broadcast between the operand of the broadcast and
184 // the other argument would result in the same type as the result type.
185 if (broadcasted_shape != result_type.getShape()) continue;
186
187 // Update the operand of the op to be the operand of the broadcast.
188 rewriter.updateRootInPlace(
189 op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
190 changed = true;
191 }
192 return success(changed);
193 }
194
runOnFunction()195 void BroadcastFoldPass::runOnFunction() {
196 OwningRewritePatternList patterns(&getContext());
197 auto func = getFunction();
198
199 patterns.insert<ConvertResultsBroadcastableShapeOp>(func.getContext());
200 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
201 }
202
203 } // namespace
204
205 namespace TF {
CreateBroadcastFoldPass()206 std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
207 return absl::make_unique<BroadcastFoldPass>();
208 }
209 } // namespace TF
210
// Registers the pass with the global registry so `tf-broadcast-fold` is
// selectable from pass-pipeline command lines (e.g. tf-opt).
static PassRegistration<BroadcastFoldPass> pass;
212
213 } // namespace mlir
214