1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ 18 19 #include <string> 20 21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 22 #include "mlir/IR/MLIRContext.h" // from @llvm-project 23 #include "mlir/IR/OperationSupport.h" // from @llvm-project 24 #include "mlir/Support/LLVM.h" // from @llvm-project 25 #include "tensorflow/cc/saved_model/bundle_v2.h" 26 #include "tensorflow/cc/saved_model/loader.h" 27 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" 28 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" 29 #include "tensorflow/core/framework/function.h" 30 #include "tensorflow/core/framework/graph.pb.h" 31 #include "tensorflow/core/graph/graph.h" 32 #include "tensorflow/core/protobuf/graph_debug_info.pb.h" 33 #include "tensorflow/stream_executor/lib/statusor.h" 34 35 namespace tensorflow { 36 37 ABSL_CONST_INIT extern const char kImportModelDefaultGraphFuncName[]; 38 39 // Given a GraphDef, returns a MLIR module containing the graph, expressed with 40 // tf_executor dialect. 41 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir( 42 const GraphDef& graphdef, const GraphDebugInfo& debug_info, 43 const GraphImportConfig& specs, mlir::MLIRContext* context, 44 bool add_default_attributes = true); 45 46 // Given a Graph, returns a MLIR module containing the graph, expressed with 47 // tf_executor dialect. 48 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir( 49 const Graph& graph, const GraphDebugInfo& debug_info, 50 const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, 51 mlir::MLIRContext* context); 52 53 // [Experimental] 54 // Given a Function, returns a MLIR module containing the graph, expressed with 55 // tf_executor dialect. 56 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir( 57 const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, 58 mlir::MLIRContext* context); 59 60 // Given a SavedModel, returns a MLIR module containing the functions, expressed 61 // with tf_executor dialect. 62 stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir( 63 SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, 64 absl::Span<std::string> exported_names, bool add_default_attributes = true); 65 66 // Given a V1 SavedModel, returns a MLIR module containing the functions, 67 // expressed with tf_executor dialect. 68 stream_executor::port::StatusOr<mlir::OwningModuleRef> 69 ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, 70 absl::Span<std::string> exported_names, 71 mlir::MLIRContext* context, MLIRImportOptions options, 72 bool lift_variables = true); 73 74 // Given a V1 SavedModel, returns a MLIR module containing the functions, 75 // expressed with tf_executor dialect. It does not require a session to be 76 // created and it does not perform any graph transformation. If `exported_names` 77 // is absl::nullopt, all signatures will be imported. Otherwise, only names 78 // in `exported_names` are imported. 79 // 80 // Note that the word `Lite` means it is a lighter version compared to 81 // ConvertSavedModelV1ToMlir(), and is not related to TFLite. 82 // 83 // TODO(b/179683149): Rename this class to avoid confusion with TFLite. 84 stream_executor::port::StatusOr<mlir::OwningModuleRef> 85 ConvertSavedModelV1ToMlirLite( 86 const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, 87 absl::optional<absl::Span<const std::string>> exported_names, 88 mlir::MLIRContext* context, MLIRImportOptions options); 89 90 // SavedModelMLIRImportInput is an adapter class for users to inject custom 91 // graph transformation logic on Tensorflow graphs before importing to MLIR. It 92 // serves as the source that provides the subgraphs requested by the savedmodel 93 // MLIR importer, and at the same time it allows the implementation of this 94 // class to transform the graph before feeding it to the importer. 95 class SavedModelMLIRImportInput { 96 public: SavedModelMLIRImportInput(const MetaGraphDef * meta_graph_def,const GraphDebugInfo & debug_info)97 SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def, 98 const GraphDebugInfo& debug_info) 99 : meta_graph_def_(meta_graph_def), debug_info_(debug_info) { 100 DCHECK(meta_graph_def); 101 } 102 103 virtual ~SavedModelMLIRImportInput(); 104 105 // The original MetaGraphDef of the savedmodel. meta_graph_def()106 const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; } 107 debug_info()108 const GraphDebugInfo& debug_info() const { return debug_info_; } 109 110 // GetSubGraph() is expected to return a tensorflow::Graph that contains the 111 // node set specified in `specs`. The implementation is free to transform the 112 // graph in the original savedmodel as needed, as long as it produces the same 113 // results and effects. `name` is a unique identifier for this subgraph, so 114 // the implementation can use it for eg. debugging or caching compilation 115 // results. 116 virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph( 117 absl::string_view name, const GraphImportConfig& specs) = 0; 118 119 private: 120 const MetaGraphDef* meta_graph_def_ = nullptr; 121 GraphDebugInfo debug_info_; 122 }; 123 124 // Given the SavedModelMLIRImportInput for a saved model, returns a MLIR module 125 // containing the functions, expressed with tf_executor dialect. It does not 126 // require a session to be created. If `exported_names` is absl::nullopt, all 127 // signatures will be imported. Otherwise, only names in `exported_names` are 128 // imported. 129 130 // 131 // Note that the word `Lite` means it is a lighter version compared to 132 // ConvertSavedModelV1ToMlir(), and is not related to TFLite. 133 // 134 // TODO(b/179683149): Rename this class to avoid confusion with TFLite. 135 stream_executor::port::StatusOr<mlir::OwningModuleRef> 136 ConvertSavedModelV1ToMlirLite( 137 SavedModelMLIRImportInput& input, 138 absl::optional<absl::Span<const std::string>> exported_names, 139 mlir::MLIRContext* context); 140 141 // Serialize a MLIR module to a string. 142 std::string MlirModuleToString(mlir::ModuleOp module, 143 mlir::OpPrintingFlags flags); 144 std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false); 145 146 } // namespace tensorflow 147 148 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ 149