1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <queue>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Identifier.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
31
32 // The cmd line flag to specify the allowlist of functions. Rest are trimmed
33 // after this pass is run.
34 // NOLINTNEXTLINE
35 static llvm::cl::list<std::string> trim_funcs_allowlist(
36 "tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"),
37 llvm::cl::desc("comma separated list of allowlisted functions. The first "
38 "function specified will be used as main."),
39 llvm::cl::CommaSeparated);
40
41 namespace mlir {
42 namespace TFL {
43 namespace {
44
45 // The pass to trim functions before we legalize to TFL
46 // dialect using the specified allowlist.
47 class TrimFunctionsPass
48 : public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
49 public:
TrimFunctionsPass()50 explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {}
TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)51 explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)
52 : trim_funcs_allowlist_(trim_funcs_allowlist) {}
53
54 private:
55 void runOnOperation() override;
56 bool TrimModule();
57 void Verify();
58
59 llvm::ArrayRef<std::string> trim_funcs_allowlist_;
60 };
61
runOnOperation()62 void TrimFunctionsPass::runOnOperation() {
63 // trim the functions in the module using the trim_funcs_allowlist_
64 // by removing functions not in the allowlist.
65 if (TrimModule()) {
66 // verify the updated module is still valid, if not signal the
67 // pass as failed.
68 Verify();
69 }
70 }
71
TrimModule()72 bool TrimFunctionsPass::TrimModule() {
73 // if no trim_funcs_allowlist_ is specified, this pass is a no-op.
74 if (trim_funcs_allowlist_.empty()) return false;
75
76 llvm::SmallVector<FuncOp, 4> funcs_to_trim;
77 for (auto func : getOperation().getOps<FuncOp>()) {
78 if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) {
79 // If no main is specified in the allowlist, use the 1st func
80 // in trim_funcs_allowlist as the main.
81 // TODO(ashwinm): Currently tflite flatbuffer export assumes there is
82 // always a main. This is strictly not required for TFlite. We need to
83 // remove that restriction once we have support to attribute the main
84 // tensorflow function in MLIR TF import using an entry_point attr.
85 if (!llvm::is_contained(trim_funcs_allowlist_, "main") &&
86 func.getName() == trim_funcs_allowlist_[0]) {
87 func.setName("main");
88 }
89 } else {
90 funcs_to_trim.push_back(func);
91 }
92 }
93
94 // remove all unexported functions from the module.
95 for (auto func : funcs_to_trim) {
96 func.erase();
97 }
98 return true;
99 }
100
101 // validate that all reachable functions from the remaining functions are
102 // also in the allowlist.
Verify()103 void TrimFunctionsPass::Verify() {
104 // TODO(ashwinm): Instead, we should make sure that references to all
105 // SymbolRefAttrs of all ops are present.
106 SymbolTable symbol_table = SymbolTable(getOperation());
107 llvm::SetVector<FuncOp> reachable_funcs;
108 for (auto func : getOperation().getOps<FuncOp>()) {
109 auto walk_result = func.walk([&](CallOp op) -> WalkResult {
110 if (!symbol_table.lookup<FuncOp>(op.getCallee()))
111 return getOperation().emitError()
112 << func.getName() << " is not in the funcs allowlist";
113 return WalkResult::advance();
114 });
115 if (walk_result.wasInterrupted()) return signalPassFailure();
116 }
117 }
118
119 } // namespace
120
121 // Creates an instance of the TensorFlow Lite dialect TrimFunctions
122 /// pass.
CreateTrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)123 std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
124 llvm::ArrayRef<std::string> trim_funcs_allowlist) {
125 return std::make_unique<TrimFunctionsPass>(trim_funcs_allowlist);
126 }
127
128 static PassRegistration<TrimFunctionsPass> pass(
129 "tfl-trim-funcs-tf",
130 "Trim functions to restrict them to a specified allowlist prior to "
131 "legalization to TensorFlow lite dialect");
132
133 } // namespace TFL
134 } // namespace mlir
135