• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <array>
#include <cstdint>
#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 
32 namespace mlir {
33 namespace {
34 
35 class ConvertResultsBroadcastableShapeOp : public RewritePattern {
36  public:
ConvertResultsBroadcastableShapeOp(MLIRContext * context)37   ConvertResultsBroadcastableShapeOp(MLIRContext* context)
38       : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
39 
40   LogicalResult matchAndRewrite(Operation* op,
41                                 PatternRewriter& rewriter) const override;
42 
43  private:
44   template <typename Op>
45   LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
46 
47   LogicalResult RewriteOp(
48       Operation* op, PatternRewriter& rewriter,
49       const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
50                                SmallVectorImpl<int64_t>&)>&
51           get_broadcasted_shape) const;
52 
53   LogicalResult RewriteBatchMatMulV2Op(Operation* op,
54                                        PatternRewriter& rewriter) const;
55 };
56 
57 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
58  public:
getArgument() const59   StringRef getArgument() const final { return "tf-broadcast-fold"; }
60 
getDescription() const61   StringRef getDescription() const final {
62     return "Fold explicit broadcasts into the following operations if they "
63            "support implicit broadcasting on their operand.";
64   }
65 
66   void runOnFunction() override;
67 };
68 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const69 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
70     Operation* op, PatternRewriter& rewriter) const {
71   if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
72     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
73 
74   // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
75   // incompatible_shape_error is `true` (what is also checked by the verifier).
76   if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
77   if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
78   if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
79 
80   return failure();
81 }
82 
RewriteBatchMatMulV2Op(Operation * op,PatternRewriter & rewriter) const83 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
84     Operation* op, PatternRewriter& rewriter) const {
85   auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
86   if (!matmul_op) return failure();
87 
88   // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
89   // shape of op's first/left-hand-side operand and `shape_y` is the shape of
90   // op's second/right-hand-side operand.
91   const auto get_broadcasted_shape =
92       [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
93           SmallVectorImpl<int64_t>& result_shape) {
94         if (shape_x.size() < 2 || shape_y.size() < 2) {
95           return false;
96         }
97 
98         // Checks outer dimensions (i.e., the dimensions higher than 2D) are
99         // broadcastable. If true, then get the broadcasted shape for outer
100         // dimension.
101         if (!OpTrait::util::getBroadcastedShape(
102                 shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
103           return false;
104         }
105 
106         const int x_row =
107             matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
108         const int x_col =
109             !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
110 
111         const int y_row =
112             matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
113         const int y_col =
114             !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
115 
116         // Checks that matrix multiply can perform a valid contraction.
117         if (x_col != y_row) {
118           result_shape.clear();
119           return false;
120         }
121 
122         result_shape.push_back(x_row);
123         result_shape.push_back(y_col);
124         return true;
125       };
126 
127   return RewriteOp(op, rewriter, get_broadcasted_shape);
128 }
129 
130 template <typename Op>
RewriteEqOp(Operation * op,PatternRewriter & rewriter) const131 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
132     Operation* op, PatternRewriter& rewriter) const {
133   auto eq_op = llvm::dyn_cast_or_null<Op>(op);
134   if (eq_op && eq_op.incompatible_shape_error())
135     return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
136   return failure();
137 }
138 
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::function<bool (ArrayRef<int64_t>,ArrayRef<int64_t>,SmallVectorImpl<int64_t> &)> & get_broadcasted_shape) const139 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
140     Operation* op, PatternRewriter& rewriter,
141     const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
142                              SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
143     const {
144   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
145     return failure();
146 
147   // Check that the result shape is fully defined.
148   auto result_type =
149       op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
150   if (!result_type || !result_type.hasStaticShape()) return failure();
151 
152   bool changed = false;
153   for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
154     // Check that the i'th operand is a broadcast.
155     auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
156         op->getOpOperand(i).get().getDefiningOp());
157     if (!broadcast) continue;
158 
159     // Check that the operand of the broadcast has fully defined shape.
160     auto broadcast_arg_type =
161         broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
162     if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
163 
164     // Check that the other argument has fully defined shape.
165     auto argument_type = op->getOpOperand(1 - i)
166                              .get()
167                              .getType()
168                              .dyn_cast_or_null<RankedTensorType>();
169     if (!argument_type || !argument_type.hasStaticShape()) continue;
170 
171     // Get the unbroadcasted shapes in the operand order.
172     std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
173     operand_shapes[i] = broadcast_arg_type.getShape();
174     operand_shapes[1 - i] = argument_type.getShape();
175 
176     // Check that the input of the broadcast and the other operand is broadcast
177     // compatible.
178     llvm::SmallVector<int64_t, 4> broadcasted_shape;
179     if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
180                                broadcasted_shape))
181       continue;
182 
183     // Check that an implicit broadcast between the operand of the broadcast and
184     // the other argument would result in the same type as the result type.
185     if (broadcasted_shape != result_type.getShape()) continue;
186 
187     // Update the operand of the op to be the operand of the broadcast.
188     rewriter.updateRootInPlace(
189         op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
190     changed = true;
191   }
192   return success(changed);
193 }
194 
runOnFunction()195 void BroadcastFoldPass::runOnFunction() {
196   OwningRewritePatternList patterns(&getContext());
197   auto func = getFunction();
198 
199   patterns.insert<ConvertResultsBroadcastableShapeOp>(func.getContext());
200   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
201 }
202 
203 }  // namespace
204 
205 namespace TF {
CreateBroadcastFoldPass()206 std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass() {
207   return absl::make_unique<BroadcastFoldPass>();
208 }
209 }  // namespace TF
210 
211 static PassRegistration<BroadcastFoldPass> pass;
212 
213 }  // namespace mlir
214