/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"

#include <utility>

#include "absl/strings/match.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime

namespace tensorflow {

ConvertFunctionToBef(mlir::StringRef function_name,const tensorflow::FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,tfrt::ArrayRef<tfrt::string_view> devices,const tensorflow::TfrtFunctionCompileOptions & options,tfrt::BefBuffer * bef_buffer)35 Status ConvertFunctionToBef(
36 mlir::StringRef function_name, const tensorflow::FunctionBody* fbody,
37 const FunctionLibraryDefinition& flib_def,
38 tfrt::ArrayRef<tfrt::string_view> devices,
39 const tensorflow::TfrtFunctionCompileOptions& options,
40 tfrt::BefBuffer* bef_buffer) {
41 mlir::MLIRContext context;
42 // FunctionDef -> TF Dialect
43 auto expected_module =
44 tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context);
45
46 if (!expected_module.ok())
47 return tensorflow::errors::Internal(
48 "Failed to convert function to mlir for function ", function_name.str(),
49 ". Error: ", expected_module.status().error_message());
50
51 auto module = std::move(expected_module).value();
52
53 // Attach devices to the MLIR module.
54 if (!devices.empty()) {
55 mlir::Builder builder(module->getContext());
56 module->getOperation()->setAttr("tf.devices",
57 builder.getStrArrayAttr(devices));
58 }
59
60 // TF Dialect -> BEF
61 return tensorflow::CompileTFMLIRToBEF(options, module.get(), bef_buffer);
62 }
63
// Lowers a TF dialect MLIR module to a BEF buffer according to `options`.
// Depending on the TPU target this first runs the MLIR TPU bridge (kTpurt) or
// the TPUPartitionedCall fallback conversion (kTfFallback), then runs the
// TF-executor-to-TFRT pipeline and finally serializes the module to BEF.
// Returns an Internal error (combined with any captured MLIR diagnostics) on
// failure.
Status ConvertTfMlirToBef(const TfrtCompileOptions& options,
                          mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) {
  // Capture MLIR diagnostics so they can be attached to returned errors.
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());

  const bool target_tpurt = options.tpu_target == TfrtTpuInfraTarget::kTpurt;

  if (target_tpurt) {
    VLOG(1) << "Running MLIR TPU bridge for tpurt";
    if (VLOG_IS_ON(1)) {
      tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module);
    }

    // Convert legacy TPU ops before running the bridge.
    TfrtTpuCompileOptions tpu_options;
    tpu_options.move_resource_gather_to_host =
        options.tpu_move_resource_gather_to_host;
    tpu_options.gather_table_width_threshold_bytes =
        options.tpu_gather_table_width_threshold_bytes;
    if (mlir::failed(
            tensorflow::RunTPUBackwardCompatConversion(module, tpu_options))) {
      return diag_handler.Combine(
          tensorflow::errors::Internal("Failed to handle legacy TPU Ops"));
    }

    if (VLOG_IS_ON(1)) {
      tensorflow::DumpMlirOpToFile("tpu_bct_conversion_after", module);
    }

    TF_RETURN_IF_ERROR(
        mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
  } else if (options.tpu_target == TfrtTpuInfraTarget::kTfFallback) {
    // Rewrite TPUPartitionedCallOp so it can execute via the TF fallback.
    if (mlir::failed(
            tensorflow::RunTPUPartitionedCallFallbackCompatConversion(
                module))) {
      return diag_handler.Combine(tensorflow::errors::Internal(
          "Failed to process TPUPartitionedCallOp for fallback execution"));
    }
  }

  if (VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("tf_dialect", module);
  }

  // Translate the compile options into pipeline options for the
  // TF-executor-to-TFRT (CoreRT) lowering pipeline.
  tensorflow::TfrtPipelineOptions pipeline_options;
  if (!options.default_device.empty()) {
    pipeline_options.default_device = options.default_device;
  }
  if (!options.force_data_format.empty()) {
    pipeline_options.force_data_format = options.force_data_format;
  }
  // TODO(b/187991150): Consider only decomposing read-only resource variable
  // ops.
  pipeline_options.decompose_resource_ops = options.decompose_resource_ops;
  pipeline_options.enable_optimizer = options.enable_optimizer;
  pipeline_options.enable_native_ops = options.enable_native_ops;
  pipeline_options.target_tpurt = target_tpurt;
  pipeline_options.tpu_fuse_ops = options.tpu_fuse_ops;
  pipeline_options.use_tpu_host_allocator_for_inputs =
      options.use_tpu_host_allocator_for_inputs;
  pipeline_options.hoist_invariant_ops = options.hoist_invariant_ops;
  pipeline_options.func_use_fallback_tensor = true;
  pipeline_options.enable_while_parallel_iterations =
      options.enable_while_parallel_iterations;
  pipeline_options.auto_fusion_oplist = options.auto_fusion_oplist;
  pipeline_options.auto_fusion_min_cluster_size =
      options.auto_fusion_min_cluster_size;
  pipeline_options.cost_threshold = options.cost_threshold;
  pipeline_options.upper_cost_threshold = options.upper_cost_threshold;
  pipeline_options.merge_inter_dependent_streams =
      options.merge_inter_dependent_streams;

  // Lower MLIR TF dialect to the MLIR TFRT CoreRT dialect.
  mlir::PassManager pm(module.getContext());
  tensorflow::CreateTfExecutorToTfrtPipeline(pm, pipeline_options);
  if (mlir::failed(pm.run(module))) {
    return diag_handler.Combine(tensorflow::errors::Internal(
        "failed to lower TF Dialect to CoreRT dialect."));
  }

  if (VLOG_IS_ON(1)) {
    tensorflow::DumpMlirOpToFile("tfrt_dialect", module);
  }

  // Serialize the lowered module to BEF; an empty buffer signals failure.
  *bef_buffer =
      tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true);
  if (bef_buffer->empty()) {
    return diag_handler.Combine(
        tensorflow::errors::Internal("failed to convert MLIR to BEF."));
  }
  bef_buffer->shrink_to_fit();

  return OkStatus();
}

}  // namespace tensorflow