/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"

#include "absl/strings/match.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
29
30 namespace tensorflow {
31
ConvertFunctionToBef(mlir::StringRef function_name,const tensorflow::FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,tfrt::ArrayRef<tfrt::string_view> devices,tensorflow::TfrtFunctionCompileOptions options,tfrt::BefBuffer * bef_buffer)32 Status ConvertFunctionToBef(mlir::StringRef function_name,
33 const tensorflow::FunctionBody* fbody,
34 const FunctionLibraryDefinition& flib_def,
35 tfrt::ArrayRef<tfrt::string_view> devices,
36 tensorflow::TfrtFunctionCompileOptions options,
37 tfrt::BefBuffer* bef_buffer) {
38 mlir::MLIRContext context;
39 // FunctionDef -> TF Dialect
40 auto expected_module =
41 tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context);
42
43 if (!expected_module.ok())
44 return tensorflow::errors::Internal(
45 "Failed to convert function to mlir for function ", function_name.str(),
46 ". Error: ", expected_module.status().error_message());
47
48 auto module = expected_module.ConsumeValueOrDie();
49
50 // Attach devices to the MLIR module.
51 if (!devices.empty()) {
52 mlir::Builder builder(module->getContext());
53 module->getOperation()->setAttr("tf.devices",
54 builder.getStrArrayAttr(devices));
55 }
56
57 // TF Dialect -> BEF
58 return tensorflow::CompileTFMLIRToBEF(options, module.get(), bef_buffer);
59 }
60
ConvertTfMlirToBef(const TfrtCompileOptions & options,mlir::ModuleOp module,tfrt::BefBuffer * bef_buffer)61 Status ConvertTfMlirToBef(const TfrtCompileOptions& options,
62 mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) {
63 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
64
65 if (options.target_tpu) {
66 if (VLOG_IS_ON(1)) {
67 tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module);
68 }
69
70 auto backward_compat_result =
71 tensorflow::RunTPUBackwardCompatConversion(module);
72 if (mlir::failed(backward_compat_result)) {
73 return diag_handler.Combine(
74 tensorflow::errors::Internal("Failed to handle legacy TPU Ops"));
75 }
76
77 if (VLOG_IS_ON(1)) {
78 tensorflow::DumpMlirOpToFile("tpu_bct_conversion_after", module);
79 }
80
81 TF_RETURN_IF_ERROR(
82 mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
83 }
84
85 if (VLOG_IS_ON(1)) {
86 tensorflow::DumpMlirOpToFile("tf_dialect", module);
87 }
88
89 // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect.
90 mlir::PassManager pm(module.getContext());
91
92 tensorflow::TfrtPipelineOptions pass_options;
93 if (!options.default_device.empty()) {
94 pass_options.default_device = options.default_device;
95 }
96 if (!options.force_data_format.empty()) {
97 pass_options.force_data_format = options.force_data_format;
98 }
99
100 // TODO(b/187991150): Consider only decomposing read-only resource variable
101 // ops.
102 pass_options.decompose_resource_ops = true;
103 pass_options.enable_optimizer = options.enable_optimizer;
104 pass_options.enable_native_ops = options.enable_native_ops;
105 pass_options.target_tpu = options.target_tpu;
106 pass_options.hoist_invariant_ops = options.hoist_invariant_ops;
107 pass_options.func_use_fallback_tensor = true;
108 pass_options.auto_fusion_oplist = options.auto_fusion_oplist;
109 pass_options.auto_fusion_min_cluster_size =
110 options.auto_fusion_min_cluster_size;
111 pass_options.cost_threshold = options.cost_threshold;
112 pass_options.upper_cost_threshold = options.upper_cost_threshold;
113 pass_options.merge_inter_dependent_streams =
114 options.merge_inter_dependent_streams;
115 tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options);
116
117 if (mlir::failed(pm.run(module)))
118 return diag_handler.Combine(tensorflow::errors::Internal(
119 "failed to lower TF Dialect to CoreRT dialect."));
120
121 if (VLOG_IS_ON(1)) {
122 tensorflow::DumpMlirOpToFile("tfrt_dialect", module);
123 }
124
125 *bef_buffer =
126 tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true);
127 if (bef_buffer->empty())
128 return diag_handler.Combine(
129 tensorflow::errors::Internal("failed to convert MLIR to BEF."));
130
131 bef_buffer->shrink_to_fit();
132
133 return Status::OK();
134 }
135
136 } // namespace tensorflow
137