• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
20 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
21 #include "mlir/Pass/Pass.h"  // from @llvm-project
22 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 
26 namespace mlir {
27 namespace TF {
28 namespace {
29 
30 // -------------------------------------------------------------------------- //
31 // Fuse ContractionFusableInterface operations into contraction operation.
32 // -------------------------------------------------------------------------- //
33 
34 template <typename BaseOp, typename FusedOp>
35 class FuseIntoContractionOp : public RewritePattern {
36  public:
FuseIntoContractionOp()37   FuseIntoContractionOp()
38       : RewritePattern(PatternBenefit(1), MatchAnyOpTypeTag()) {}
39 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const40   LogicalResult matchAndRewrite(Operation *op,
41                                 PatternRewriter &rewriter) const override {
42     auto fusable = dyn_cast<ContractionFusableInterface>(op);
43     if (!fusable) return failure();
44 
45     auto failed = [&](Twine message) -> LogicalResult {
46       return rewriter.notifyMatchFailure(op, message);
47     };
48 
49     // Check if the operation can be fused.
50     Optional<ContractionFusion> fusion = fusable.GetContractionFusion();
51     if (!fusion.hasValue()) {
52       return failed("returned empty contraction fusion specification");
53     }
54 
55     // Check if preceeding operation is a BaseOp or FusedOp that we can use for
56     // fusion.
57     Operation *fuse_into = nullptr;
58     Value operand = op->getOperand(0);
59 
60     if (BaseOp base_op = operand.getDefiningOp<BaseOp>()) {
61       fuse_into = base_op.getOperation();
62     } else if (FusedOp fused_op = operand.getDefiningOp<FusedOp>()) {
63       fuse_into = fused_op.getOperation();
64     } else {
65       return failed("input to the fusable op must be a " +
66                     BaseOp::getOperationName() + " or a " +
67                     FusedOp::getOperationName());
68     }
69 
70     // Operand result must have one use, because we do not want to compute
71     // tensor contraction twice.
72     if (!fuse_into->getResult(0).hasOneUse()) {
73       return failed("fused into op result must have one use");
74     }
75 
76     MLIRContext *ctx = op->getContext();
77 
78     // Build a fused MatMul operation from a base MatMul and a fusion.
79     SmallVector<Location, 3> locations = {fuse_into->getLoc(), op->getLoc()};
80     Location loc = rewriter.getFusedLoc(locations);
81 
82     // Fusion can't change the type of a fused operation.
83     Type result_ty = fuse_into->getResult(0).getType();
84 
85     // Copy all operands from a base op and add additional fusion arguments.
86     SmallVector<Value, 3> operands(fuse_into->getOperands());
87     for (int idx : fusion->additional_arguments) {
88       operands.push_back(op->getOperand(idx));
89     }
90 
91     // Copy attributes from a base op that we fuse into (e.g. copy all
92     // MatMul or Conv attributes to the fused operation).
93     SmallVector<NamedAttribute, 4> attrs(fuse_into->getAttrs().begin(),
94                                          fuse_into->getAttrs().end());
95 
96     // Add fusion specific additional attributes.
97     for (auto attr : fusion->additional_attributes) {
98       attrs.push_back(attr);
99     }
100 
101     // Add a fused output kernel name to the list of fusions.
102     Identifier fusion_id = Identifier::get("fusion", ctx);
103     StringAttr fusion_name = StringAttr::get(ctx, fusion->output_kernel);
104 
105     auto is_fusion = [&](const NamedAttribute &attr) -> bool {
106       return attr.first == fusion_id;
107     };
108 
109     if (isa<BaseOp>(fuse_into)) {
110       NamedAttribute fusion_attr(fusion_id, ArrayAttr::get(ctx, {fusion_name}));
111       attrs.push_back(fusion_attr);
112 
113     } else {
114       ArrayAttr arr =
115           llvm::find_if(attrs, is_fusion)->second.template cast<ArrayAttr>();
116       llvm::erase_if(attrs, is_fusion);
117 
118       auto rng = arr.getAsRange<Attribute>();
119       SmallVector<Attribute, 4> updated(rng.begin(), rng.end());
120       updated.push_back(fusion_name);
121 
122       attrs.push_back(NamedAttribute(fusion_id, ArrayAttr::get(ctx, updated)));
123     }
124 
125     // Update all uses of a fusable op with a new fused operation.
126     Value fused = rewriter.create<FusedOp>(loc, result_ty, operands, attrs);
127     rewriter.replaceOp(op, {fused});
128 
129     return failure();
130   }
131 };
132 
133 // -------------------------------------------------------------------------- //
134 
135 using FuseIntoMatMulOp = FuseIntoContractionOp<MatMulOp, _JitFusedMatMulOp>;
136 
137 struct ContractionFusionPass
138     : public PassWrapper<ContractionFusionPass, FunctionPass> {
139   void runOnFunction() override;
140 };
141 
runOnFunction()142 void ContractionFusionPass::runOnFunction() {
143   FuncOp func = getFunction();
144 
145   OwningRewritePatternList patterns;
146   patterns.insert<FuseIntoMatMulOp>();
147   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
148 }
149 
150 }  // namespace
151 
CreateContractionFusionPass()152 std::unique_ptr<OperationPass<FuncOp>> CreateContractionFusionPass() {
153   return std::make_unique<ContractionFusionPass>();
154 }
155 
156 static PassRegistration<ContractionFusionPass> pass(
157     "tf-contraction-fusion",
158     "Fuses operations implementing ContractionFusionInterface into the "
159     "contraction operations");
160 
161 }  // namespace TF
162 }  // namespace mlir
163