• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <iostream>
16 
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
18 #include "mlir/IR/Attributes.h"  // from @llvm-project
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/Operation.h"  // from @llvm-project
21 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Pass/PassManager.h"  // from @llvm-project
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
25 #include "mlir/Transforms/Passes.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"
31 
32 namespace mlir {
33 namespace TF {
34 namespace {
35 
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"
37 
38 // Returns a TF Constant tensor with the passed in values.
GetI64ConstantTensor(PatternRewriter & rewriter,ArrayRef<int64_t> values,Location location)39 TF::ConstOp GetI64ConstantTensor(PatternRewriter &rewriter,
40                                  ArrayRef<int64_t> values, Location location) {
41   auto cst_attr = rewriter.getI64TensorAttr(values);
42   return rewriter.create<TF::ConstOp>(location, cst_attr.getType(), cst_attr);
43 }
44 
45 // Rewrites broadcast->reshape to a reshape->broadcast that reduces
46 // the rank of the input and output of the broadcast.
47 class SimplifyBroadcastReshape : public OpRewritePattern<BroadcastToOp> {
48   using OpRewritePattern<BroadcastToOp>::OpRewritePattern;
49 
matchAndRewrite(BroadcastToOp op,PatternRewriter & rewriter) const50   LogicalResult matchAndRewrite(BroadcastToOp op,
51                                 PatternRewriter &rewriter) const override {
52     // Only rewrite if the Broadcast has only one consumer.
53     if (!op.output().hasOneUse()) return failure();
54 
55     Operation *user = *op.output().getUsers().begin();
56 
57     auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(user);
58     if (!reshape_op) return failure();
59 
60     auto reshape_type = reshape_op.output().getType().cast<ShapedType>();
61 
62     if (!reshape_type.hasStaticShape()) return failure();
63     ArrayRef<int64_t> reshape_shape = reshape_type.getShape();
64 
65     auto input_type = op.input().getType().cast<ShapedType>();
66     auto output_type = op.output().getType().cast<ShapedType>();
67 
68     if (!input_type.hasRank() || !output_type.hasRank()) return failure();
69 
70     // The pattern attempts to reduce the rank of the input to BroadcastTo.
71     // Thus, we fail to match if the consuming reshape rank is larger.
72     ArrayRef<int64_t> input_shape = input_type.getShape();
73     if (reshape_shape.size() > input_shape.size()) return failure();
74 
75     // Extend the input shape with leading 1s to match the broadcast shape.
76     ArrayRef<int64_t> broadcast_shape = output_type.getShape();
77     SmallVector<int64_t, 4> input_shape_extended;
78     input_shape_extended.append(broadcast_shape.size() - input_shape.size(), 1);
79     input_shape_extended.append(input_shape.begin(), input_shape.end());
80 
81     // Collect non-unit dims and corresponding dim in the input shape.
82     SmallVector<int64_t, 4> input_carryover_dims;
83     SmallVector<int64_t, 4> non_unit_dims;
84 
85     for (int i = 0; i < input_shape_extended.size(); i++) {
86       int64_t dim = broadcast_shape[i];
87       if (dim != 1) {
88         non_unit_dims.push_back(dim);
89         input_carryover_dims.push_back(input_shape_extended[i]);
90       }
91     }
92 
93     // If the reshape rank is less than the number of non-unit dimensions
94     // of the broadcast, then the reshape collapses non-unit dimensions.
95     // TODO(rahulsp) : Handle this case with more careful checks.
96     if (reshape_shape.size() < non_unit_dims.size()) return failure();
97 
98     SmallVector<int64_t, 4> old_reshape_non_unit_dims;
99     SmallVector<int64_t, 4> new_reshape_dims;
100     int new_reshape_dim_idx = 0;
101     for (int64_t dim : reshape_shape) {
102       int new_reshape_dim = 1;
103       if (dim != 1) {
104         old_reshape_non_unit_dims.push_back(dim);
105         if (new_reshape_dim_idx < input_carryover_dims.size()) {
106           new_reshape_dim = input_carryover_dims[new_reshape_dim_idx];
107           new_reshape_dim_idx++;
108         }
109       }
110       new_reshape_dims.push_back(new_reshape_dim);
111     }
112 
113     if (non_unit_dims != old_reshape_non_unit_dims) return failure();
114 
115     if (failed(VerifyShapeOfReshapeOp(new_reshape_dims))) return failure();
116 
117     Type el_ty = getElementTypeOrSelf(op.getType());
118     TF::ConstOp new_reshape_shape = GetI64ConstantTensor(
119         rewriter, ArrayRef<int64_t>(new_reshape_dims), op.getLoc());
120     auto new_reshape_type = RankedTensorType::get(new_reshape_dims, el_ty);
121     ReshapeOp new_reshape =
122         rewriter.create<ReshapeOp>(new_reshape_shape.getLoc(), new_reshape_type,
123                                    op.input(), new_reshape_shape);
124     TF::ConstOp new_broadcast_shape =
125         GetI64ConstantTensor(rewriter, reshape_shape, op.getLoc());
126     rewriter.replaceOpWithNewOp<BroadcastToOp>(
127         reshape_op, reshape_op.output().getType(), new_reshape,
128         new_broadcast_shape);
129     return success();
130   }
131 };
132 
133 // Canonicalize operations in functions.
134 struct TensorFlowOptimizePass
135     : public TensorFlowOptimizePassBase<TensorFlowOptimizePass> {
initializemlir::TF::__anon92670cf30111::TensorFlowOptimizePass136   LogicalResult initialize(MLIRContext *context) override {
137     OwningRewritePatternList pattern_list(context);
138     populateWithGenerated(pattern_list);
139     pattern_list.insert<SimplifyBroadcastReshape>(context);
140     patterns = std::move(pattern_list);
141     return success();
142   }
143 
runOnFunctionmlir::TF::__anon92670cf30111::TensorFlowOptimizePass144   void runOnFunction() override {
145     auto func = getFunction();
146     if (failed(applyPatternsAndFoldGreedily(func, patterns)))
147       signalPassFailure();
148   }
149 
150   FrozenRewritePatternSet patterns;
151 };
152 
153 }  // namespace
154 
CreateTFStandardPipeline(OpPassManager & pm,const StandardPipelineOptions & options)155 void CreateTFStandardPipeline(OpPassManager &pm,
156                               const StandardPipelineOptions &options) {
157   OpPassManager &func_pm = pm.nest<FuncOp>();
158 
159   // First operates on the executor dialect:
160   // - remove dead islands.
161   // - fuse islands as much as possible.
162   // - materialize the eventual "pass-through" ops by inlining their content.
163   func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
164   func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
165   func_pm.addPass(CreateMaterializePassthroughOpPass());
166   if (options.form_clusters)
167     func_pm.addPass(TFDevice::CreateClusterFormationPass());
168 
169   // Hopefully there is a single island left, or there wasn't any to begin with.
170   // We now run the optimizer which operates mostly inside islands.
171   func_pm.addPass(createCanonicalizerPass());
172   pm.addPass(CreateTFShapeInferencePass());
173   if (options.enable_inliner) {
174     pm.addPass(createInlinerPass());
175   }
176   pm.addPass(createSymbolDCEPass());
177   pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
178   pm.addNestedPass<FuncOp>(createCSEPass());
179 }
180 
CreateTFOptimizePass()181 std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass() {
182   return std::make_unique<TensorFlowOptimizePass>();
183 }
184 
185 // Registers a pipeline builder function for the default canonicalize/optimizer.
186 static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
187     "tf-standard-pipeline",
188     "Run all the passes involved in transforming/optimizing the graph after "
189     "importing into MLIR, without any target specialization.",
190     CreateTFStandardPipeline);
191 
192 }  // namespace TF
193 }  // namespace mlir
194