• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
18 
#include <string>
#include <utility>

#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
34 
35 namespace tensorflow {
36 
// Given a GraphDef, returns an MLIR module containing the graph, expressed with
// the tf_executor dialect. Failures are reported through the returned StatusOr.
//
// `graphdef`: the graph to import.
// `debug_info`: debug information for the graph's nodes.
// `specs`: import configuration (see GraphImportConfig).
// `context`: presumably the MLIRContext the returned module is created in —
//   TODO confirm against the implementation.
// `add_default_attributes`: defaults to true; presumably controls whether
//   default-valued attributes are added to nodes before import — TODO confirm.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    bool add_default_attributes = true);
43 
// Given a Graph, returns an MLIR module containing the graph, expressed with
// the tf_executor dialect. Failures are reported through the returned StatusOr.
//
// `graph`: the in-memory graph to import.
// `debug_info`: debug information for the graph's nodes.
// `flib_def`: function library; presumably supplies definitions of functions
//   referenced by `graph` — TODO confirm.
// `specs`: import configuration (see GraphImportConfig).
// `context`: presumably the MLIRContext the returned module is created in.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context);
50 
// [Experimental]
// Given a Function, returns an MLIR module containing the graph, expressed with
// the tf_executor dialect. Failures are reported through the returned StatusOr.
//
// `fbody`: the function body to import; passed as a raw pointer, presumably
//   required to be non-null — TODO confirm.
// `flib_def`: function library; presumably supplies definitions of functions
//   referenced by the body — TODO confirm.
// `context`: presumably the MLIRContext the returned module is created in.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
    mlir::MLIRContext* context);
57 
// Given a (V2) SavedModel bundle, returns an MLIR module containing the
// functions, expressed with the tf_executor dialect. Failures are reported
// through the returned StatusOr.
//
// `saved_model`: the loaded SavedModelV2Bundle to import.
// `context`: presumably the MLIRContext the returned module is created in.
// `exported_names`: presumably restricts which exported names are imported —
//   TODO confirm the semantics of an empty span.
// `add_default_attributes`: defaults to true; presumably controls whether
//   default-valued attributes are added to nodes before import — TODO confirm.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes = true);
63 
// Given a V1 SavedModel, returns an MLIR module containing the functions,
// expressed with the tf_executor dialect. Failures are reported through the
// returned StatusOr.
//
// `saved_model`: the loaded V1 SavedModelBundle to import.
// `exported_names`: presumably restricts which exported names are imported —
//   TODO confirm the semantics of an empty span.
// `context`: presumably the MLIRContext the returned module is created in.
// `options`: importer options (see MLIRImportOptions).
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
                          absl::Span<std::string> exported_names,
                          mlir::MLIRContext* context,
                          MLIRImportOptions options);
71 
// Given a V1 SavedModel, returns an MLIR module containing the functions,
// expressed with the tf_executor dialect. It does not require a session to be
// created and it does not perform any graph transformation.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// `meta_graph_def`: the MetaGraphDef of the SavedModel to import.
// `debug_info`: debug information for the graph's nodes.
// `exported_names`: presumably restricts which exported names are imported —
//   TODO confirm the semantics of an empty span.
// `context`: presumably the MLIRContext the returned module is created in.
// `options`: importer options (see MLIRImportOptions).
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def,
                              const GraphDebugInfo& debug_info,
                              absl::Span<std::string> exported_names,
                              mlir::MLIRContext* context,
                              MLIRImportOptions options);
86 
87 // SavedModelMLIRImportInput is an adapter class for users to inject custom
88 // graph transformation logic on Tensorflow graphs before importing to MLIR. It
89 // serves as the source that provides the subgraphs requested by the savedmodel
90 // MLIR importer, and at the same time it allows the implementation of this
91 // class to transform the graph before feeding it to the importer.
92 class SavedModelMLIRImportInput {
93  public:
SavedModelMLIRImportInput(const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info)94   SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
95                             const GraphDebugInfo& debug_info)
96       : meta_graph_def_(meta_graph_def), debug_info_(debug_info) {
97     DCHECK(meta_graph_def);
98   }
99 
100   virtual ~SavedModelMLIRImportInput();
101 
102   // The original MetaGraphDef of the savedmodel.
meta_graph_def()103   const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; }
104 
debug_info()105   const GraphDebugInfo& debug_info() const { return debug_info_; }
106 
107   // GetSubGraph() is expected to return a tensorflow::Graph that contains the
108   // node set specified in `specs`. The implementation is free to transform the
109   // graph in the original savedmodel as needed, as long as it produces the same
110   // results and effects. `name` is a unique identifier for this subgraph, so
111   // the implementation can use it for eg. debugging or caching compilation
112   // results.
113   virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph(
114       absl::string_view name, const GraphImportConfig& specs) = 0;
115 
116  private:
117   const MetaGraphDef* meta_graph_def_ = nullptr;
118   GraphDebugInfo debug_info_;
119 };
120 
// Given the SavedModelMLIRImportInput for a SavedModel, returns an MLIR module
// containing the functions, expressed with the tf_executor dialect. It does not
// require a session to be created. Failures are reported through the returned
// StatusOr.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// `input`: provides the MetaGraphDef, debug info, and the (possibly
//   transformed) subgraphs via SavedModelMLIRImportInput::GetSubGraph().
// `exported_names`: presumably restricts which exported names are imported —
//   TODO confirm the semantics of an empty span.
// `context`: presumably the MLIRContext the returned module is created in.
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlirLite(SavedModelMLIRImportInput& input,
                              absl::Span<std::string> exported_names,
                              mlir::MLIRContext* context);
133 
// Serializes an MLIR module to a string.
//
// The first overload prints `module` using the given `flags`.
// The second overload takes `show_debug_info` (default: false); presumably
// it controls whether location/debug information is included in the output —
// TODO confirm against the implementation.
std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags);
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);
138 
139 }  // namespace tensorflow
140 
141 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_
142