/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <memory>
17 
18 #include "absl/memory/memory.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/Traits.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 
32 namespace mlir {
33 namespace {
34 
35 class ConvertResultsBroadcastableShapeOp : public RewritePattern {
36  public:
ConvertResultsBroadcastableShapeOp()37   ConvertResultsBroadcastableShapeOp()
38       : RewritePattern(1, MatchAnyOpTypeTag()) {}
39 
40   LogicalResult matchAndRewrite(Operation* op,
41                                 PatternRewriter& rewriter) const override;
42 
43  private:
44   template <typename Op>
45   LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
46 
47   LogicalResult RewriteOp(
48       Operation* op, PatternRewriter& rewriter,
49       const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
50                                SmallVectorImpl<int64_t>&)>&
51           get_broadcasted_shape) const;
52 
53   LogicalResult RewriteBatchMatMulV2Op(Operation* op,
54                                        PatternRewriter& rewriter) const;
55 };
56 
57 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
58  public:
59   void runOnFunction() override;
60 };
61 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const62 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
63     Operation* op, PatternRewriter& rewriter) const {
64   if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
65     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
66 
67   // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
68   // incompatible_shape_error is `true` (what is also checked by the verifier).
69   if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
70   if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
71   if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
72 
73   return failure();
74 }
75 
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const76 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
77     Operation* op, PatternRewriter& rewriter) const {
78   auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
79   if (!matmul_op) return failure();
80 
81   // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
82   // shape of op's first/left-hand-side operand and `shape_y` is the shape of
83   // op's second/right-hand-side operand.
84   const auto get_broadcasted_shape =
85       [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
86           SmallVectorImpl<int64_t>& result_shape) {
87         if (shape_x.size() < 2 || shape_y.size() < 2) {
88           return false;
89         }
90 
91         // Checks outer dimensions (i.e., the dimensions higher than 2D) are
92         // broadcastable. If true, then get the broadcasted shape for outer
93         // dimension.
94         if (!OpTrait::util::getBroadcastedShape(
95                 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
96           return false;
97         }
98 
99         const int x_row =
100             matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
101         const int x_col =
102             !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
103 
104         const int y_row =
105             matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
106         const int y_col =
107             !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
108 
109         // Checks that matrix multiply can perform a valid contraction.
110         if (x_col != y_row) {
111           result_shape.clear();
112           return false;
113         }
114 
115         result_shape.push_back(x_row);
116         result_shape.push_back(y_col);
117         return true;
118       };
119 
120   return RewriteOp(op, rewriter, get_broadcasted_shape);
121 }
122 
123 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const124 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
125     Operation* op, PatternRewriter& rewriter) const {
126   auto eq_op = llvm::dyn_cast_or_null<Op>(op);
127   if (eq_op && eq_op.incompatible_shape_error())
128     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
129   return failure();
130 }
131 
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const132 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
133     Operation* op, PatternRewriter& rewriter,
134     const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
135                              SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
136     const {
137   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
138     return failure();
139 
140   // Check that the result shape is fully defined.
141   auto result_type =
142       op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
143   if (!result_type || !result_type.hasStaticShape()) return failure();
144 
145   bool changed = false;
146   for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
147     // Check that the i'th operand is a broadcast.
148     auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
149         op->getOpOperand(i).get().getDefiningOp());
150     if (!broadcast) continue;
151 
152     // Check that the operand of the broadcast has fully defined shape.
153     auto broadcast_arg_type =
154         broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
155     if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
156 
157     // Check that the other argument has fully defined shape.
158     auto argument_type = op->getOpOperand(1 - i)
159                              .get()
160                              .getType()
161                              .dyn_cast_or_null<RankedTensorType>();
162     if (!argument_type || !argument_type.hasStaticShape()) continue;
163 
164     // Get the unbroadcasted shapes in the operand order.
165     std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
166     operand_shapes[i] = broadcast_arg_type.getShape();
167     operand_shapes[1 - i] = argument_type.getShape();
168 
169     // Check that the input of the broadcast and the other operand is broadcast
170     // compatible.
171     llvm::SmallVector<int64_t, 4> broadcasted_shape;
172     if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
173                                broadcasted_shape))
174       continue;
175 
176     // Check that an implicit broadcast between the operand of the broadcast and
177     // the other argument would result in the same type as the result type.
178     if (broadcasted_shape != result_type.getShape()) continue;
179 
180     // Update the operand of the op to be the operand of the broadcast.
181     rewriter.updateRootInPlace(
182         op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
183     changed = true;
184   }
185   return success(changed);
186 }
187 
runOnFunction()188 void BroadcastFoldPass::runOnFunction() {
189   OwningRewritePatternList patterns;
190   auto func = getFunction();
191 
192   patterns.insert<ConvertResultsBroadcastableShapeOp>();
193   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
194 }
195 
196 }  // namespace
197 
198 namespace TF {
CreateBroadcastFoldPass()199 std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
200   return absl::make_unique<BroadcastFoldPass>();
201 }
202 }  // namespace TF
203 
204 static PassRegistration<BroadcastFoldPass> pass(
205     "tf-broadcast-fold",
206     "Fold explicit broadcasts into the following operations if they support "
207     "implicit broadcasting on their operand.");
208 
209 }  // namespace mlir
210