• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
17 
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
20 #include "mlir/Parser.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/platform/errors.h"
24 
25 namespace tensorflow {
26 
SerializeMlirModule(mlir::ModuleOp module_op)27 std::string SerializeMlirModule(mlir::ModuleOp module_op) {
28   std::string serialized_mlir_module;
29   llvm::raw_string_ostream os(serialized_mlir_module);
30   mlir::OpPrintingFlags print_flags;
31   print_flags.enableDebugInfo();
32   module_op.print(os, print_flags);
33   return std::move(os.str());
34 }
35 
DeserializeMlirModule(llvm::StringRef serialized_mlir_module,mlir::MLIRContext * mlir_context,mlir::OwningModuleRef * mlir_module)36 Status DeserializeMlirModule(llvm::StringRef serialized_mlir_module,
37                              mlir::MLIRContext* mlir_context,
38                              mlir::OwningModuleRef* mlir_module) {
39   TF_RET_CHECK(!serialized_mlir_module.empty())
40       << "unexpected empty serialized MLIR module string";
41   TF_RET_CHECK(mlir_module) << "unexpected null MLIR module pointer";
42 
43   // Make sure we catch any error reported by MLIR and forward it to the TF
44   // error reporting system.
45   mlir::StatusScopedDiagnosticHandler error_handler(mlir_context);
46 
47   // Parse the module.
48   *mlir_module = mlir::parseSourceString(serialized_mlir_module, mlir_context);
49   if (!*mlir_module)
50     return error_handler.Combine(
51         errors::InvalidArgument("could not parse MLIR module"));
52 
53   return Status::OK();
54 }
55 
56 }  // namespace tensorflow
57