• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
17 
18 #include "absl/strings/match.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
22 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
23 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
24 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
25 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
26 #include "tensorflow/core/common_runtime/function_body.h"
27 #include "tensorflow/core/common_runtime/function_def_utils.h"
28 #include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
29 
30 namespace tensorflow {
31 
ConvertFunctionToBef(mlir::StringRef function_name,const tensorflow::FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,tfrt::ArrayRef<tfrt::string_view> devices,tensorflow::TfrtFunctionCompileOptions options,tfrt::BefBuffer * bef_buffer)32 Status ConvertFunctionToBef(mlir::StringRef function_name,
33                             const tensorflow::FunctionBody* fbody,
34                             const FunctionLibraryDefinition& flib_def,
35                             tfrt::ArrayRef<tfrt::string_view> devices,
36                             tensorflow::TfrtFunctionCompileOptions options,
37                             tfrt::BefBuffer* bef_buffer) {
38   mlir::MLIRContext context;
39   // FunctionDef -> TF Dialect
40   auto expected_module =
41       tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context);
42 
43   if (!expected_module.ok())
44     return tensorflow::errors::Internal(
45         "Failed to convert function to mlir for function ", function_name.str(),
46         ". Error: ", expected_module.status().error_message());
47 
48   auto module = expected_module.ConsumeValueOrDie();
49 
50   // Attach devices to the MLIR module.
51   if (!devices.empty()) {
52     mlir::Builder builder(module->getContext());
53     module->getOperation()->setAttr("tf.devices",
54                                     builder.getStrArrayAttr(devices));
55   }
56 
57   // TF Dialect -> BEF
58   return tensorflow::CompileTFMLIRToBEF(options, module.get(), bef_buffer);
59 }
60 
ConvertTfMlirToBef(const TfrtCompileOptions & options,mlir::ModuleOp module,tfrt::BefBuffer * bef_buffer)61 Status ConvertTfMlirToBef(const TfrtCompileOptions& options,
62                           mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) {
63   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
64 
65   if (options.target_tpu) {
66     if (VLOG_IS_ON(1)) {
67       tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module);
68     }
69 
70     auto backward_compat_result =
71         tensorflow::RunTPUBackwardCompatConversion(module);
72     if (mlir::failed(backward_compat_result)) {
73       return diag_handler.Combine(
74           tensorflow::errors::Internal("Failed to handle legacy TPU Ops"));
75     }
76 
77     if (VLOG_IS_ON(1)) {
78       tensorflow::DumpMlirOpToFile("tpu_bct_conversion_after", module);
79     }
80 
81     TF_RETURN_IF_ERROR(
82         mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
83   }
84 
85   if (VLOG_IS_ON(1)) {
86     tensorflow::DumpMlirOpToFile("tf_dialect", module);
87   }
88 
89   // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect.
90   mlir::PassManager pm(module.getContext());
91 
92   tensorflow::TfrtPipelineOptions pass_options;
93   if (!options.default_device.empty()) {
94     pass_options.default_device = options.default_device;
95   }
96   if (!options.force_data_format.empty()) {
97     pass_options.force_data_format = options.force_data_format;
98   }
99 
100   // TODO(b/187991150): Consider only decomposing read-only resource variable
101   // ops.
102   pass_options.decompose_resource_ops = true;
103   pass_options.enable_optimizer = options.enable_optimizer;
104   pass_options.enable_native_ops = options.enable_native_ops;
105   pass_options.target_tpu = options.target_tpu;
106   pass_options.hoist_invariant_ops = options.hoist_invariant_ops;
107   pass_options.func_use_fallback_tensor = true;
108   pass_options.auto_fusion_oplist = options.auto_fusion_oplist;
109   pass_options.auto_fusion_min_cluster_size =
110       options.auto_fusion_min_cluster_size;
111   pass_options.cost_threshold = options.cost_threshold;
112   pass_options.upper_cost_threshold = options.upper_cost_threshold;
113   pass_options.merge_inter_dependent_streams =
114       options.merge_inter_dependent_streams;
115   tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options);
116 
117   if (mlir::failed(pm.run(module)))
118     return diag_handler.Combine(tensorflow::errors::Internal(
119         "failed to lower TF Dialect to CoreRT dialect."));
120 
121   if (VLOG_IS_ON(1)) {
122     tensorflow::DumpMlirOpToFile("tfrt_dialect", module);
123   }
124 
125   *bef_buffer =
126       tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true);
127   if (bef_buffer->empty())
128     return diag_handler.Combine(
129         tensorflow::errors::Internal("failed to convert MLIR to BEF."));
130 
131   bef_buffer->shrink_to_fit();
132 
133   return Status::OK();
134 }
135 
136 }  // namespace tensorflow
137