/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h"

namespace mlir {
namespace TF {
namespace {

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_optimize.inc"

// Returns a TF Constant tensor with the passed in values.
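// For example, GetI64ConstantTensor(rewriter, {5, 2, 3}, loc) materializes a
// tf.Const op holding a tensor<3xi64> with those values.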
TF::ConstOp GetI64ConstantTensor(PatternRewriter &rewriter,
                                 ArrayRef<int64_t> values, Location location) {
  auto cst_attr = rewriter.getI64TensorAttr(values);
  return rewriter.create<TF::ConstOp>(location, cst_attr.getType(), cst_attr);
}

// Rewrites broadcast->reshape to a reshape->broadcast that reduces
// the rank of the input and output of the broadcast.
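// For example, a tensor<1x1x2x3> broadcast to [5, 1, 2, 3] and then reshaped
// to [5, 2, 3] becomes a reshape to [1, 2, 3] followed by a broadcast to
// [5, 2, 3], lowering the rank of the broadcast from 4 to 3.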
class SimplifyBroadcastReshape : public OpRewritePattern<BroadcastToOp> {
  using OpRewritePattern<BroadcastToOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastToOp op,
                                PatternRewriter &rewriter) const override {
    // Only rewrite if the Broadcast has only one consumer.
    if (!op.output().hasOneUse()) return failure();

    Operation *user = *op.output().getUsers().begin();

    auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(user);
    if (!reshape_op) return failure();

    auto reshape_type = reshape_op.output().getType().cast<ShapedType>();

    if (!reshape_type.hasStaticShape()) return failure();
    ArrayRef<int64_t> reshape_shape = reshape_type.getShape();

    auto input_type = op.input().getType().cast<ShapedType>();
    auto output_type = op.output().getType().cast<ShapedType>();

    if (!input_type.hasRank() || !output_type.hasRank()) return failure();

    // The pattern attempts to reduce the rank of the input to BroadcastTo.
    // Thus, we fail to match if the consuming reshape rank is larger.
    ArrayRef<int64_t> input_shape = input_type.getShape();
    if (reshape_shape.size() > input_shape.size()) return failure();

    // Extend the input shape with leading 1s to match the broadcast shape.
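    // For example, an input of shape [2, 3] broadcast to [5, 1, 2, 3] is
    // treated as [1, 1, 2, 3] below.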
    ArrayRef<int64_t> broadcast_shape = output_type.getShape();
    SmallVector<int64_t, 4> input_shape_extended;
    input_shape_extended.append(broadcast_shape.size() - input_shape.size(), 1);
    input_shape_extended.append(input_shape.begin(), input_shape.end());

    // Collect non-unit dims and corresponding dim in the input shape.
    SmallVector<int64_t, 4> input_carryover_dims;
    SmallVector<int64_t, 4> non_unit_dims;

    for (int i = 0; i < input_shape_extended.size(); i++) {
      int64_t dim = broadcast_shape[i];
      if (dim != 1) {
        non_unit_dims.push_back(dim);
        input_carryover_dims.push_back(input_shape_extended[i]);
      }
    }

    // If the reshape rank is less than the number of non-unit dimensions
    // of the broadcast, then the reshape collapses non-unit dimensions.
    // TODO(rahulsp) : Handle this case with more careful checks.
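    // For example, a tensor<1x2x3> broadcast to [5, 2, 3] and then reshaped
    // to [5, 6] collapses the two trailing non-unit dimensions, so the
    // pattern does not fire.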
    if (reshape_shape.size() < non_unit_dims.size()) return failure();

    SmallVector<int64_t, 4> old_reshape_non_unit_dims;
    SmallVector<int64_t, 4> new_reshape_dims;
    int new_reshape_dim_idx = 0;
    for (int64_t dim : reshape_shape) {
      int new_reshape_dim = 1;
      if (dim != 1) {
        old_reshape_non_unit_dims.push_back(dim);
        if (new_reshape_dim_idx < input_carryover_dims.size()) {
          new_reshape_dim = input_carryover_dims[new_reshape_dim_idx];
          new_reshape_dim_idx++;
        }
      }
      new_reshape_dims.push_back(new_reshape_dim);
    }

    if (non_unit_dims != old_reshape_non_unit_dims) return failure();

    if (failed(VerifyShapeOfReshapeOp(new_reshape_dims))) return failure();

    Type el_ty = getElementTypeOrSelf(op.getType());
    TF::ConstOp new_reshape_shape = GetI64ConstantTensor(
        rewriter, ArrayRef<int64_t>(new_reshape_dims), op.getLoc());
    auto new_reshape_type = RankedTensorType::get(new_reshape_dims, el_ty);
    ReshapeOp new_reshape =
        rewriter.create<ReshapeOp>(new_reshape_shape.getLoc(), new_reshape_type,
                                   op.input(), new_reshape_shape);
    TF::ConstOp new_broadcast_shape =
        GetI64ConstantTensor(rewriter, reshape_shape, op.getLoc());
    rewriter.replaceOpWithNewOp<BroadcastToOp>(
        reshape_op, reshape_op.output().getType(), new_reshape,
        new_broadcast_shape);
    return success();
  }
};

// Canonicalize operations in functions.
struct TFOptimizePass : public PassWrapper<TFOptimizePass, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    auto func = getFunction();
    populateWithGenerated(&getContext(), patterns);
    patterns.insert<SimplifyBroadcastReshape>(&getContext());
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }
};

}  // namespace

// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTFStandardPipeline(OpPassManager &pm,
                              const StandardPipelineOptions &options) {
  OpPassManager &func_pm = pm.nest<FuncOp>();

  // First operates on the executor dialect:
  // - remove dead islands.
  // - fuse islands as much as possible.
  // - materialize the eventual "pass-through" ops by inlining their content.
  func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass());
  func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass());
  func_pm.addPass(CreateMaterializePassthroughOpPass());
  if (options.form_clusters)
    func_pm.addPass(TFDevice::CreateClusterFormationPass());

  // Hopefully there is a single island left, or there wasn't any to begin with.
  // We now run the optimizer which operates mostly inside islands.
  func_pm.addPass(createCanonicalizerPass());
  pm.addPass(CreateTFShapeInferencePass());
  if (options.enable_inliner) {
    pm.addPass(createInlinerPass());
  }
  pm.addPass(createSymbolDCEPass());
  pm.addNestedPass<FuncOp>(CreateTFOptimizePass());
  pm.addNestedPass<FuncOp>(createCSEPass());
}

std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass() {
  return std::make_unique<TFOptimizePass>();
}

static PassRegistration<TFOptimizePass> pass("tf-optimize", "Optimizes TF.");

// Registers a pipeline builder function for the default canonicalize/optimizer.
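// Once registered, the pipeline (and the tf-optimize pass above) can be run
// from an mlir-opt-style driver, e.g. (assuming the tf-opt binary is built):
//   tf-opt -tf-standard-pipeline input.mlir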
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tf-standard-pipeline",
    "Run all the passes involved in transforming/optimizing the graph after "
    "importing into MLIR, without any target specialization.",
    CreateTFStandardPipeline);

}  // namespace TF
}  // namespace mlir