1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ 18 19 #include <string> 20 21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 22 #include "mlir/IR/MLIRContext.h" // from @llvm-project 23 #include "mlir/IR/OperationSupport.h" // from @llvm-project 24 #include "mlir/Support/LLVM.h" // from @llvm-project 25 #include "tensorflow/cc/saved_model/bundle_v2.h" 26 #include "tensorflow/cc/saved_model/loader.h" 27 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" 28 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" 29 #include "tensorflow/core/framework/function.h" 30 #include "tensorflow/core/framework/graph.pb.h" 31 #include "tensorflow/core/graph/graph.h" 32 #include "tensorflow/core/protobuf/graph_debug_info.pb.h" 33 #include "tensorflow/stream_executor/lib/statusor.h" 34 35 namespace tensorflow { 36 37 // Given a GraphDef, returns a MLIR module containing the graph, expressed with 38 // tf_executor dialect. 
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphdefToMlir(
    const GraphDef& graphdef, const GraphDebugInfo& debug_info,
    const GraphImportConfig& specs, mlir::MLIRContext* context,
    bool add_default_attributes = true);

// Given a Graph, returns a MLIR module containing the graph, expressed with
// tf_executor dialect. `flib_def` supplies the function library referenced by
// the graph's nodes; `debug_info` is used to attach source locations.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
    const Graph& graph, const GraphDebugInfo& debug_info,
    const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
    mlir::MLIRContext* context);

// [Experimental]
// Given a Function, returns a MLIR module containing the graph, expressed with
// tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
    const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def,
    mlir::MLIRContext* context);

// Given a V2 SavedModel, returns a MLIR module containing the functions,
// expressed with tf_executor dialect. `exported_names` restricts which
// exported objects are imported.
// NOTE(review): presumably an empty `exported_names` means "import all
// signatures" — confirm against the implementation in the .cc file.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
    SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
    absl::Span<std::string> exported_names, bool add_default_attributes = true);

// Given a V1 SavedModel, returns a MLIR module containing the functions,
// expressed with tf_executor dialect. Unlike the `Lite` variants below, this
// overload consumes a fully loaded SavedModelBundle (i.e. a session has
// already been created by the loader).
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
                          absl::Span<std::string> exported_names,
                          mlir::MLIRContext* context,
                          MLIRImportOptions options);

// Given a V1 SavedModel, returns a MLIR module containing the functions,
// expressed with tf_executor dialect. It does not require a session to be
// created and it does not perform any graph transformation.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlirLite(const MetaGraphDef& meta_graph_def,
                              const GraphDebugInfo& debug_info,
                              absl::Span<std::string> exported_names,
                              mlir::MLIRContext* context,
                              MLIRImportOptions options);

// SavedModelMLIRImportInput is an adapter class for users to inject custom
// graph transformation logic on Tensorflow graphs before importing to MLIR. It
// serves as the source that provides the subgraphs requested by the savedmodel
// MLIR importer, and at the same time it allows the implementation of this
// class to transform the graph before feeding it to the importer.
//
// Lifetime note: the class stores only a raw, non-owning pointer to
// `meta_graph_def`, so the caller must keep the MetaGraphDef alive for the
// lifetime of this object. `debug_info`, by contrast, is copied into the
// object.
class SavedModelMLIRImportInput {
 public:
  // `meta_graph_def` must be non-null (DCHECK'd below) and must outlive this
  // object; `debug_info` is copied.
  SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def,
                            const GraphDebugInfo& debug_info)
      : meta_graph_def_(meta_graph_def), debug_info_(debug_info) {
    DCHECK(meta_graph_def);
  }

  // Virtual destructor: this is a polymorphic base class (see the pure
  // virtual GetSubGraph() below), so deletion through a base pointer is safe.
  virtual ~SavedModelMLIRImportInput();

  // The original MetaGraphDef of the savedmodel.
  const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; }

  // Debug (source-location) info associated with the savedmodel graph.
  const GraphDebugInfo& debug_info() const { return debug_info_; }

  // GetSubGraph() is expected to return a tensorflow::Graph that contains the
  // node set specified in `specs`. The implementation is free to transform the
  // graph in the original savedmodel as needed, as long as it produces the same
  // results and effects. `name` is a unique identifier for this subgraph, so
  // the implementation can use it for eg. debugging or caching compilation
  // results.
  //
  // Ownership of the returned Graph* is not expressed in the type; the
  // returned pointer is presumably owned by the implementation and must stay
  // valid while the importer uses it — confirm with concrete subclasses.
  virtual stream_executor::port::StatusOr<const Graph*> GetSubGraph(
      absl::string_view name, const GraphImportConfig& specs) = 0;

 private:
  // Non-owning; never null after construction (enforced by the DCHECK above).
  const MetaGraphDef* meta_graph_def_ = nullptr;
  // Owned copy of the debug info passed to the constructor.
  GraphDebugInfo debug_info_;
};

// Given the SavedModelMLIRImportInput for a saved model, returns a MLIR module
// containing the functions, expressed with tf_executor dialect. It does not
// require a session to be created.
//
// Note that the word `Lite` means it is a lighter version compared to
// ConvertSavedModelV1ToMlir(), and is not related to TFLite.
//
// TODO(b/179683149): Rename this function to avoid confusion with TFLite.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlirLite(SavedModelMLIRImportInput& input,
                              absl::Span<std::string> exported_names,
                              mlir::MLIRContext* context);

// Serialize a MLIR module to a string, printed with the given `flags`.
std::string MlirModuleToString(mlir::ModuleOp module,
                               mlir::OpPrintingFlags flags);
// Convenience overload: serialize `m`, optionally including debug (location)
// info in the printed output.
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_