• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
18 
#include <string>

#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
34 
namespace tensorflow {
36 
37 ABSL_CONST_INIT extern const char kImportModelDefaultGraphFuncName[];
38 
39 // Given a GraphDef, returns a MLIR module containing the graph, expressed with
40 // tf_executor dialect.
41 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
42     const GraphDef& graphdef, const GraphDebugInfo& debug_info,
43     const GraphImportConfig& specs, mlir::MLIRContext* context,
44     bool add_default_attributes = true);
45 
46 // Given a Graph, returns a MLIR module containing the graph, expressed with
47 // tf_executor dialect.
48 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
49     const Graph& graph, const GraphDebugInfo& debug_info,
50     const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
51     mlir::MLIRContext* context);
52 
53 // [Experimental]
54 // Given a Function, returns a MLIR module containing the graph, expressed with
55 // tf_executor dialect.
56 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
57     const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
58     mlir::MLIRContext* context);
59 
60 // Given a SavedModel, returns a MLIR module containing the functions, expressed
61 // with tf_executor dialect.
62 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
63     SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
64     absl::Span<std::string> exported_names, bool add_default_attributes = true);
65 
66 // Given a V1 SavedModel, returns a MLIR module containing the functions,
67 // expressed with tf_executor dialect.
68 stream_executor::port::StatusOr<mlir::OwningModuleRef>
69 ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
70                           absl::Span<std::string> exported_names,
71                           mlir::MLIRContext* context, MLIRImportOptions options,
72                           bool lift_variables = true);
73 
74 // Given a V1 SavedModel, returns a MLIR module containing the functions,
75 // expressed with tf_executor dialect. It does not require a session to be
76 // created and it does not perform any graph transformation. If `exported_names`
77 // is absl::nullopt, all signatures will be imported. Otherwise, only names
78 // in `exported_names` are imported.
79 //
80 // Note that the word `Lite` means it is a lighter version compared to
81 // ConvertSavedModelV1ToMlir(), and is not related to TFLite.
82 //
83 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
84 stream_executor::port::StatusOr<mlir::OwningModuleRef>
85 ConvertSavedModelV1ToMlirLite(
86     const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info,
87     absl::optional<absl::Span<const std::string>> exported_names,
88     mlir::MLIRContext* context, MLIRImportOptions options);
89 
90 // SavedModelMLIRImportInput is an adapter class for users to inject custom
91 // graph transformation logic on Tensorflow graphs before importing to MLIR. It
92 // serves as the source that provides the subgraphs requested by the savedmodel
93 // MLIR importer, and at the same time it allows the implementation of this
94 // class to transform the graph before feeding it to the importer.
95 class SavedModelMLIRImportInput {
96  public:
SavedModelMLIRImportInput(const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info)97   SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
98                             const GraphDebugInfo& debug_info)
99       : meta_graph_def_(meta_graph_def), debug_info_(debug_info) {
100     DCHECK(meta_graph_def);
101   }
102 
103   virtual ~SavedModelMLIRImportInput();
104 
105   // The original MetaGraphDef of the savedmodel.
meta_graph_def()106   const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; }
107 
debug_info()108   const GraphDebugInfo& debug_info() const { return debug_info_; }
109 
110   // GetSubGraph() is expected to return a tensorflow::Graph that contains the
111   // node set specified in `specs`. The implementation is free to transform the
112   // graph in the original savedmodel as needed, as long as it produces the same
113   // results and effects. `name` is a unique identifier for this subgraph, so
114   // the implementation can use it for eg. debugging or caching compilation
115   // results.
116   virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph(
117       absl::string_view name, const GraphImportConfig& specs) = 0;
118 
119  private:
120   const MetaGraphDef* meta_graph_def_ = nullptr;
121   GraphDebugInfo debug_info_;
122 };
123 
124 // Given the SavedModelMLIRImportInput for a saved model, returns a MLIR module
125 // containing the functions, expressed with tf_executor dialect. It does not
126 // require a session to be created. If `exported_names` is absl::nullopt, all
127 // signatures will be imported. Otherwise, only names in `exported_names` are
128 // imported.
129 
130 //
131 // Note that the word `Lite` means it is a lighter version compared to
132 // ConvertSavedModelV1ToMlir(), and is not related to TFLite.
133 //
134 // TODO(b/179683149): Rename this class to avoid confusion with TFLite.
135 stream_executor::port::StatusOr<mlir::OwningModuleRef>
136 ConvertSavedModelV1ToMlirLite(
137     SavedModelMLIRImportInput& input,
138     absl::optional<absl::Span<const std::string>> exported_names,
139     mlir::MLIRContext* context);
140 
141 // Serialize a MLIR module to a string.
142 std::string MlirModuleToString(mlir::ModuleOp module,
143                                mlir::OpPrintingFlags flags);
144 std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);
145 
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
149