• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/macros.h"

29 namespace mlir {
30 namespace quant {
31 namespace {
32 
33 constexpr char kEntryFunctionAttr[] = "tf.entry_function";
34 constexpr char kExportedNameAttr[] = "tf_saved_model.exported_names";
35 constexpr char kIndexPathAttr[] = "tf_saved_model.index_path";
36 
37 // The ConvertMlirToGraphdef requires the provided input module to have a main
38 // function, which might not exist in case of multi-signature graphs. In that
39 // case, this pass will create a new main function, which calls signature
40 // functions.
41 class InsertMainFunctionPass
42     : public PassWrapper<InsertMainFunctionPass, OperationPass<ModuleOp>> {
43  public:
44   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMainFunctionPass)
45 
InsertMainFunctionPass()46   explicit InsertMainFunctionPass() {}
47 
getArgument() const48   StringRef getArgument() const override { return "quant-add-main-function"; }
49 
getDescription() const50   StringRef getDescription() const override {
51     return "Insert the main function to the module if it is missing.";
52   }
53 
54   void runOnOperation() override;
55 };
56 
57 // Checks if the module has a main function.
HasMainFunction(ModuleOp & module)58 bool HasMainFunction(ModuleOp& module) {
59   StringAttr main_func_id = StringAttr::get(module.getContext(), "main");
60   for (auto function : module.getOps<func::FuncOp>()) {
61     if (function.getName() == main_func_id) return true;
62   }
63   return false;
64 }
65 
66 // Checks if a FuncOp is exported.
IsExported(func::FuncOp & op)67 bool IsExported(func::FuncOp& op) {
68   auto exported_names = op->getAttrOfType<ArrayAttr>(kExportedNameAttr);
69   return exported_names && !exported_names.empty();
70 }
71 
72 // Check if a function is an entry function.
IsEntryFunction(func::FuncOp & op)73 bool IsEntryFunction(func::FuncOp& op) {
74   return op->hasAttr(kEntryFunctionAttr);
75 }
76 
77 // Sets a function to be private so it can be referred internally.
SetFunctionPrivate(func::FuncOp & func)78 void SetFunctionPrivate(func::FuncOp& func) {
79   func.setVisibility(SymbolTable::Visibility::Private);
80 
81   // The `tf_saved_model` attributes can only be appied to public functions.
82   for (auto& attr : func->getAttrs()) {
83     StringRef attr_name = attr.getName().getValue();
84     if (attr_name.startswith("tf_saved_model.")) {
85       func->removeAttr(attr_name);
86     }
87   }
88 
89   for (int i = 0; i < func.getNumArguments(); ++i) {
90     for (auto& attr : func.getArgAttrs(i)) {
91       const StringAttr& attr_name = attr.getName();
92       if (attr_name.getValue().startswith("tf_saved_model.")) {
93         func.removeArgAttr(i, attr_name);
94       }
95     }
96   }
97   for (int i = 0; i < func.getNumResults(); ++i) {
98     for (auto& attr : func.getResultAttrs(i)) {
99       const StringAttr& attr_name = attr.getName();
100       if (attr_name.getValue().startswith("tf_saved_model.")) {
101         func.removeResultAttr(i, attr_name);
102       }
103     }
104   }
105 }
106 
107 // Creates a main function which calls other exported functions.
CreateMainFunction(ModuleOp & module)108 bool CreateMainFunction(ModuleOp& module) {
109   MLIRContext* context = module.getContext();
110   OpBuilder builder(context);
111 
112   // Collects argument and result types.
113   llvm::SmallVector<Location> arg_locs;
114   llvm::SmallVector<Type> arg_types, result_types;
115   std::vector<std::string> input_names, output_names;
116   for (auto function : module.getOps<func::FuncOp>()) {
117     if (function.isPrivate() || !IsExported(function)) continue;
118     arg_types.append(function.getArgumentTypes().begin(),
119                      function.getArgumentTypes().end());
120     auto& return_op = function.getBody().getBlocks().front().back();
121     result_types.append(return_op.getOperandTypes().begin(),
122                         return_op.getOperandTypes().end());
123     for (const auto& arg : function.getArguments()) {
124       arg_locs.push_back(arg.getLoc());
125     }
126 
127     // Collects input and output node names. These names are prefixed with the
128     // signature key in SavedModel. They also contain the index suffix. Ex:
129     // "<signature key>_<name>:0", where 0 is the index.
130     if (auto tf_attrs =
131             function->getAttrOfType<DictionaryAttr>(kEntryFunctionAttr)) {
132       if (auto inputs_attr = tf_attrs.get("inputs")) {
133         std::string inputs_attr_str =
134             inputs_attr.cast<StringAttr>().getValue().str();
135         std::vector<std::string> inputs_attr_vec =
136             absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty());
137         input_names.insert(input_names.end(), inputs_attr_vec.begin(),
138                            inputs_attr_vec.end());
139       }
140       if (auto outputs_attr = tf_attrs.get("outputs")) {
141         std::string outputs_attr_str =
142             outputs_attr.cast<StringAttr>().getValue().str();
143         std::vector<std::string> outputs_attr_vec =
144             absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty());
145         output_names.insert(output_names.end(), outputs_attr_vec.begin(),
146                             outputs_attr_vec.end());
147       }
148     }
149   }
150 
151   // Creates a new main function.
152   auto func_type = FunctionType::get(context, arg_types, result_types);
153   auto main_func =
154       builder.create<func::FuncOp>(module.getLoc(), "main", func_type);
155   builder.createBlock(&main_func.getBody(), main_func.begin(), arg_types,
156                       arg_locs);
157   SmallVector<NamedAttribute> func_attrs;
158   func_attrs.push_back(
159       {StringAttr::get(context, "inputs"),
160        StringAttr::get(context, absl::StrJoin(input_names, ","))});
161   func_attrs.push_back(
162       {StringAttr::get(context, "outputs"),
163        StringAttr::get(context, absl::StrJoin(output_names, ","))});
164   auto dictAttr = DictionaryAttr::get(context, func_attrs);
165   main_func->setAttr(StringAttr::get(context, kEntryFunctionAttr), dictAttr);
166   main_func->setAttr(kExportedNameAttr, builder.getStrArrayAttr({"main"}));
167 
168   if (input_names.size() != main_func.getNumArguments() ||
169       output_names.size() != main_func.getNumResults()) {
170     module.emitError()
171         << "Number of inputs and outputs in the tf.entry_function attribute "
172            "mismatched. [Input] Expected: "
173         << input_names.size() << ", got: " << main_func.getNumArguments()
174         << ". [Output] Expected: " << output_names.size()
175         << ", got: " << main_func.getNumResults();
176     return false;
177   }
178 
179   int numArgs = main_func.getNumArguments();
180   for (int i = 0; i < numArgs; ++i) {
181     main_func.setArgAttr(
182         i, kIndexPathAttr,
183         mlir::ArrayAttr::get(context,
184                              {mlir::StringAttr::get(context, input_names[i])}));
185   }
186 
187   int numResults = main_func.getNumResults();
188   for (int i = 0; i < numResults; ++i) {
189     main_func.setResultAttr(
190         i, kIndexPathAttr,
191         mlir::ArrayAttr::get(
192             context, {mlir::StringAttr::get(context, output_names[i])}));
193   }
194 
195   // Creates PartitionedCall ops to call exported functions.
196   auto guard = OpBuilder::InsertionGuard(builder);
197   int arg_idx = 0;
198   int result_idx = 0;
199   llvm::SmallVector<Value> returning_values;
200   for (auto function : module.getOps<func::FuncOp>()) {
201     if (function.isPrivate() || !IsExported(function) ||
202         !IsEntryFunction(function)) {
203       continue;
204     }
205 
206     llvm::ArrayRef<BlockArgument> new_args = llvm::makeArrayRef(
207         main_func.getArguments().begin() + arg_idx, function.getNumArguments());
208     arg_idx += function.getNumArguments();
209     llvm::ArrayRef<Type> new_types = llvm::makeArrayRef(
210         result_types.begin() + result_idx, function.getNumResults());
211     result_idx += function.getNumResults();
212 
213     auto call_op = builder.create<TF::PartitionedCallOp>(
214         module.getLoc(), new_types, new_args,
215         SymbolRefAttr::get(context, function.getSymName()),
216         /*config=*/builder.getStringAttr(""),
217         /*config_proto=*/builder.getStringAttr(""),
218         /*executor_type=*/builder.getStringAttr(""));
219     returning_values.append(call_op.getResults().begin(),
220                             call_op.getResults().end());
221     SetFunctionPrivate(function);
222   }
223   builder.create<mlir::func::ReturnOp>(main_func.getBody().getLoc(),
224                                        returning_values);
225 
226   // Adds the new function to symbol table.
227   SymbolTable symbol_table(module);
228   symbol_table.insert(main_func);
229   return true;
230 }
231 
runOnOperation()232 void InsertMainFunctionPass::runOnOperation() {
233   ModuleOp module = getOperation();
234   if (!HasMainFunction(module)) {
235     if (!CreateMainFunction(module)) {
236       signalPassFailure();
237     }
238   }
239 }
240 
241 }  // namespace
242 
CreateInsertMainFunctionPass()243 std::unique_ptr<OperationPass<ModuleOp>> CreateInsertMainFunctionPass() {
244   return std::make_unique<InsertMainFunctionPass>();
245 }
246 
__anonf8a5ef800202null247 static PassRegistration<InsertMainFunctionPass> pass([] {
248   return CreateInsertMainFunctionPass();
249 });
250 
251 }  // namespace quant
252 }  // namespace mlir
253