1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tfrt/function/function.h"
17
18 #include "absl/strings/match.h"
19 #include "absl/strings/str_split.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/MLIRContext.h" // from @llvm-project
23 #include "mlir/Pass/PassManager.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
27 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
30 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
31 #include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime
32 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
33 #include "tfrt/core_runtime/op_handler.h" // from @tf_runtime
34 #include "tfrt/host_context/host_context.h" // from @tf_runtime
35 #include "tfrt/tensor/dense_host_tensor_view.h" // from @tf_runtime
36
37 namespace tensorflow {
38
CompileTFMLIRToBEF(const TfrtFunctionCompileOptions & options,mlir::ModuleOp module,tfrt::BefBuffer * bef_buffer)39 Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options,
40 mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) {
41 mlir::OpPrintingFlags print_flags;
42 print_flags.elideLargeElementsAttrs();
43
44 if (VLOG_IS_ON(1)) {
45 VLOG(1) << "Input TF Executor dialect:";
46 DumpMlirOpToFile("tf_to_tfrt_tf_executor_dialect", module);
47 }
48
49 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
50
51 // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect.
52 mlir::PassManager pm(module.getContext());
53 tensorflow::applyTensorflowAndCLOptions(pm);
54
55 tensorflow::TfrtPipelineOptions pass_options;
56 if (!options.default_device.empty()) {
57 pass_options.default_device = options.default_device;
58 }
59 if (!options.force_data_format.empty()) {
60 pass_options.force_data_format = options.force_data_format;
61 }
62 // TODO(tfrt-devs): Current MaxPoolingOp only supports NHWC on device type
63 // CPU. Enable this layout optimization after we introduce TFRT native ops
64 // for training.
65 if (absl::StrContains(pass_options.default_device, "CPU")) {
66 pass_options.skip_fold_transpose_in_ops = true;
67 }
68 pass_options.enable_optimizer = options.enable_optimizer;
69 // Use TFRT TPU OpKernel for training.
70 pass_options.target_tpurt = false;
71 pass_options.tpu_use_core_selector = options.tpu_use_core_selector;
72 pass_options.tpu_use_bundled_transfer = options.tpu_use_bundled_transfer;
73 pass_options.tpu_lower_to_fallback = options.tpu_lower_to_fallback;
74 pass_options.tpu_fuse_ops = options.tpu_fuse_ops;
75 pass_options.tpu_transfer_result_to_host =
76 options.tpu_transfer_result_to_host;
77 pass_options.enable_native_ops = options.enable_native_ops;
78 tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options);
79
80 if (mlir::failed(pm.run(module)))
81 return diag_handler.Combine(tensorflow::errors::Internal(
82 "failed to lower TF Dialect to CoreRT dialect."));
83
84 if (VLOG_IS_ON(1)) {
85 VLOG(1) << "TFRT dialect: ";
86 DumpMlirOpToFile("tf_to_tfrt_tfrt_dialect", module);
87 }
88
89 *bef_buffer =
90 tfrt::ConvertMLIRToBEF(module, /* disable_optional_sections = */ true);
91 if (bef_buffer->empty())
92 return diag_handler.Combine(
93 tensorflow::errors::Internal("failed to convert MLIR to BEF."));
94
95 return OkStatus();
96 }
97
98 } // namespace tensorflow
99