1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/macros.h"
28
29 namespace mlir {
30 namespace quant {
31 namespace {
32
// Dictionary attribute whose "inputs"/"outputs" entries hold comma-separated
// node names of an entry function.
constexpr char kEntryFunctionAttr[]{"tf.entry_function"};
// Array attribute of exported names; a function counts as exported when this
// array is present and non-empty.
constexpr char kExportedNameAttr[]{"tf_saved_model.exported_names"};
// Per-argument / per-result array attribute holding the node name (index path).
constexpr char kIndexPathAttr[]{"tf_saved_model.index_path"};
36
37 // The ConvertMlirToGraphdef requires the provided input module to have a main
38 // function, which might not exist in case of multi-signature graphs. In that
39 // case, this pass will create a new main function, which calls signature
40 // functions.
41 class InsertMainFunctionPass
42 : public PassWrapper<InsertMainFunctionPass, OperationPass<ModuleOp>> {
43 public:
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMainFunctionPass)
45
InsertMainFunctionPass()46 explicit InsertMainFunctionPass() {}
47
getArgument() const48 StringRef getArgument() const override { return "quant-add-main-function"; }
49
getDescription() const50 StringRef getDescription() const override {
51 return "Insert the main function to the module if it is missing.";
52 }
53
54 void runOnOperation() override;
55 };
56
57 // Checks if the module has a main function.
HasMainFunction(ModuleOp & module)58 bool HasMainFunction(ModuleOp& module) {
59 StringAttr main_func_id = StringAttr::get(module.getContext(), "main");
60 for (auto function : module.getOps<func::FuncOp>()) {
61 if (function.getName() == main_func_id) return true;
62 }
63 return false;
64 }
65
66 // Checks if a FuncOp is exported.
IsExported(func::FuncOp & op)67 bool IsExported(func::FuncOp& op) {
68 auto exported_names = op->getAttrOfType<ArrayAttr>(kExportedNameAttr);
69 return exported_names && !exported_names.empty();
70 }
71
72 // Check if a function is an entry function.
IsEntryFunction(func::FuncOp & op)73 bool IsEntryFunction(func::FuncOp& op) {
74 return op->hasAttr(kEntryFunctionAttr);
75 }
76
77 // Sets a function to be private so it can be referred internally.
SetFunctionPrivate(func::FuncOp & func)78 void SetFunctionPrivate(func::FuncOp& func) {
79 func.setVisibility(SymbolTable::Visibility::Private);
80
81 // The `tf_saved_model` attributes can only be appied to public functions.
82 for (auto& attr : func->getAttrs()) {
83 StringRef attr_name = attr.getName().getValue();
84 if (attr_name.startswith("tf_saved_model.")) {
85 func->removeAttr(attr_name);
86 }
87 }
88
89 for (int i = 0; i < func.getNumArguments(); ++i) {
90 for (auto& attr : func.getArgAttrs(i)) {
91 const StringAttr& attr_name = attr.getName();
92 if (attr_name.getValue().startswith("tf_saved_model.")) {
93 func.removeArgAttr(i, attr_name);
94 }
95 }
96 }
97 for (int i = 0; i < func.getNumResults(); ++i) {
98 for (auto& attr : func.getResultAttrs(i)) {
99 const StringAttr& attr_name = attr.getName();
100 if (attr_name.getValue().startswith("tf_saved_model.")) {
101 func.removeResultAttr(i, attr_name);
102 }
103 }
104 }
105 }
106
107 // Creates a main function which calls other exported functions.
CreateMainFunction(ModuleOp & module)108 bool CreateMainFunction(ModuleOp& module) {
109 MLIRContext* context = module.getContext();
110 OpBuilder builder(context);
111
112 // Collects argument and result types.
113 llvm::SmallVector<Location> arg_locs;
114 llvm::SmallVector<Type> arg_types, result_types;
115 std::vector<std::string> input_names, output_names;
116 for (auto function : module.getOps<func::FuncOp>()) {
117 if (function.isPrivate() || !IsExported(function)) continue;
118 arg_types.append(function.getArgumentTypes().begin(),
119 function.getArgumentTypes().end());
120 auto& return_op = function.getBody().getBlocks().front().back();
121 result_types.append(return_op.getOperandTypes().begin(),
122 return_op.getOperandTypes().end());
123 for (const auto& arg : function.getArguments()) {
124 arg_locs.push_back(arg.getLoc());
125 }
126
127 // Collects input and output node names. These names are prefixed with the
128 // signature key in SavedModel. They also contain the index suffix. Ex:
129 // "<signature key>_<name>:0", where 0 is the index.
130 if (auto tf_attrs =
131 function->getAttrOfType<DictionaryAttr>(kEntryFunctionAttr)) {
132 if (auto inputs_attr = tf_attrs.get("inputs")) {
133 std::string inputs_attr_str =
134 inputs_attr.cast<StringAttr>().getValue().str();
135 std::vector<std::string> inputs_attr_vec =
136 absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty());
137 input_names.insert(input_names.end(), inputs_attr_vec.begin(),
138 inputs_attr_vec.end());
139 }
140 if (auto outputs_attr = tf_attrs.get("outputs")) {
141 std::string outputs_attr_str =
142 outputs_attr.cast<StringAttr>().getValue().str();
143 std::vector<std::string> outputs_attr_vec =
144 absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty());
145 output_names.insert(output_names.end(), outputs_attr_vec.begin(),
146 outputs_attr_vec.end());
147 }
148 }
149 }
150
151 // Creates a new main function.
152 auto func_type = FunctionType::get(context, arg_types, result_types);
153 auto main_func =
154 builder.create<func::FuncOp>(module.getLoc(), "main", func_type);
155 builder.createBlock(&main_func.getBody(), main_func.begin(), arg_types,
156 arg_locs);
157 SmallVector<NamedAttribute> func_attrs;
158 func_attrs.push_back(
159 {StringAttr::get(context, "inputs"),
160 StringAttr::get(context, absl::StrJoin(input_names, ","))});
161 func_attrs.push_back(
162 {StringAttr::get(context, "outputs"),
163 StringAttr::get(context, absl::StrJoin(output_names, ","))});
164 auto dictAttr = DictionaryAttr::get(context, func_attrs);
165 main_func->setAttr(StringAttr::get(context, kEntryFunctionAttr), dictAttr);
166 main_func->setAttr(kExportedNameAttr, builder.getStrArrayAttr({"main"}));
167
168 if (input_names.size() != main_func.getNumArguments() ||
169 output_names.size() != main_func.getNumResults()) {
170 module.emitError()
171 << "Number of inputs and outputs in the tf.entry_function attribute "
172 "mismatched. [Input] Expected: "
173 << input_names.size() << ", got: " << main_func.getNumArguments()
174 << ". [Output] Expected: " << output_names.size()
175 << ", got: " << main_func.getNumResults();
176 return false;
177 }
178
179 int numArgs = main_func.getNumArguments();
180 for (int i = 0; i < numArgs; ++i) {
181 main_func.setArgAttr(
182 i, kIndexPathAttr,
183 mlir::ArrayAttr::get(context,
184 {mlir::StringAttr::get(context, input_names[i])}));
185 }
186
187 int numResults = main_func.getNumResults();
188 for (int i = 0; i < numResults; ++i) {
189 main_func.setResultAttr(
190 i, kIndexPathAttr,
191 mlir::ArrayAttr::get(
192 context, {mlir::StringAttr::get(context, output_names[i])}));
193 }
194
195 // Creates PartitionedCall ops to call exported functions.
196 auto guard = OpBuilder::InsertionGuard(builder);
197 int arg_idx = 0;
198 int result_idx = 0;
199 llvm::SmallVector<Value> returning_values;
200 for (auto function : module.getOps<func::FuncOp>()) {
201 if (function.isPrivate() || !IsExported(function) ||
202 !IsEntryFunction(function)) {
203 continue;
204 }
205
206 llvm::ArrayRef<BlockArgument> new_args = llvm::makeArrayRef(
207 main_func.getArguments().begin() + arg_idx, function.getNumArguments());
208 arg_idx += function.getNumArguments();
209 llvm::ArrayRef<Type> new_types = llvm::makeArrayRef(
210 result_types.begin() + result_idx, function.getNumResults());
211 result_idx += function.getNumResults();
212
213 auto call_op = builder.create<TF::PartitionedCallOp>(
214 module.getLoc(), new_types, new_args,
215 SymbolRefAttr::get(context, function.getSymName()),
216 /*config=*/builder.getStringAttr(""),
217 /*config_proto=*/builder.getStringAttr(""),
218 /*executor_type=*/builder.getStringAttr(""));
219 returning_values.append(call_op.getResults().begin(),
220 call_op.getResults().end());
221 SetFunctionPrivate(function);
222 }
223 builder.create<mlir::func::ReturnOp>(main_func.getBody().getLoc(),
224 returning_values);
225
226 // Adds the new function to symbol table.
227 SymbolTable symbol_table(module);
228 symbol_table.insert(main_func);
229 return true;
230 }
231
runOnOperation()232 void InsertMainFunctionPass::runOnOperation() {
233 ModuleOp module = getOperation();
234 if (!HasMainFunction(module)) {
235 if (!CreateMainFunction(module)) {
236 signalPassFailure();
237 }
238 }
239 }
240
241 } // namespace
242
CreateInsertMainFunctionPass()243 std::unique_ptr<OperationPass<ModuleOp>> CreateInsertMainFunctionPass() {
244 return std::make_unique<InsertMainFunctionPass>();
245 }
246
__anonf8a5ef800202null247 static PassRegistration<InsertMainFunctionPass> pass([] {
248 return CreateInsertMainFunctionPass();
249 });
250
251 } // namespace quant
252 } // namespace mlir
253