• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"
17 
18 #include <utility>
19 
20 #include "absl/strings/match.h"
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
24 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
27 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
28 #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
29 #include "tensorflow/core/common_runtime/function_body.h"
30 #include "tensorflow/core/common_runtime/function_def_utils.h"
31 #include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
32 
33 namespace tensorflow {
34 
ConvertFunctionToBef(mlir::StringRef function_name,const tensorflow::FunctionBody * fbody,const FunctionLibraryDefinition & flib_def,tfrt::ArrayRef<tfrt::string_view> devices,const tensorflow::TfrtFunctionCompileOptions & options,tfrt::BefBuffer * bef_buffer)35 Status ConvertFunctionToBef(
36     mlir::StringRef function_name, const tensorflow::FunctionBody* fbody,
37     const FunctionLibraryDefinition& flib_def,
38     tfrt::ArrayRef<tfrt::string_view> devices,
39     const tensorflow::TfrtFunctionCompileOptions& options,
40     tfrt::BefBuffer* bef_buffer) {
41   mlir::MLIRContext context;
42   // FunctionDef -> TF Dialect
43   auto expected_module =
44       tensorflow::ConvertFunctionToMlir(fbody, flib_def, &context);
45 
46   if (!expected_module.ok())
47     return tensorflow::errors::Internal(
48         "Failed to convert function to mlir for function ", function_name.str(),
49         ". Error: ", expected_module.status().error_message());
50 
51   auto module = std::move(expected_module).value();
52 
53   // Attach devices to the MLIR module.
54   if (!devices.empty()) {
55     mlir::Builder builder(module->getContext());
56     module->getOperation()->setAttr("tf.devices",
57                                     builder.getStrArrayAttr(devices));
58   }
59 
60   // TF Dialect -> BEF
61   return tensorflow::CompileTFMLIRToBEF(options, module.get(), bef_buffer);
62 }
63 
ConvertTfMlirToBef(const TfrtCompileOptions & options,mlir::ModuleOp module,tfrt::BefBuffer * bef_buffer)64 Status ConvertTfMlirToBef(const TfrtCompileOptions& options,
65                           mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) {
66   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
67 
68   if (options.tpu_target == TfrtTpuInfraTarget::kTpurt) {
69     VLOG(1) << "Running MLIR TPU bridge for tpurt";
70     if (VLOG_IS_ON(1)) {
71       tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module);
72     }
73 
74     TfrtTpuCompileOptions tpu_compile_options;
75     tpu_compile_options.move_resource_gather_to_host =
76         options.tpu_move_resource_gather_to_host;
77     tpu_compile_options.gather_table_width_threshold_bytes =
78         options.tpu_gather_table_width_threshold_bytes;
79 
80     auto backward_compat_result =
81         tensorflow::RunTPUBackwardCompatConversion(module, tpu_compile_options);
82     if (mlir::failed(backward_compat_result)) {
83       return diag_handler.Combine(
84           tensorflow::errors::Internal("Failed to handle legacy TPU Ops"));
85     }
86 
87     if (VLOG_IS_ON(1)) {
88       tensorflow::DumpMlirOpToFile("tpu_bct_conversion_after", module);
89     }
90 
91     TF_RETURN_IF_ERROR(
92         mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
93   } else if (options.tpu_target == TfrtTpuInfraTarget::kTfFallback) {
94     auto tpu_partitioned_call_fallback_compat_result =
95         tensorflow::RunTPUPartitionedCallFallbackCompatConversion(module);
96     if (mlir::failed(tpu_partitioned_call_fallback_compat_result)) {
97       return diag_handler.Combine(tensorflow::errors::Internal(
98           "Failed to process TPUPartitionedCallOp for fallback execution"));
99     }
100   }
101 
102   if (VLOG_IS_ON(1)) {
103     tensorflow::DumpMlirOpToFile("tf_dialect", module);
104   }
105 
106   // Lower MLIR TF Dialect to MLIR TFRT CoreRT dialect.
107   mlir::PassManager pm(module.getContext());
108 
109   tensorflow::TfrtPipelineOptions pass_options;
110   if (!options.default_device.empty()) {
111     pass_options.default_device = options.default_device;
112   }
113   if (!options.force_data_format.empty()) {
114     pass_options.force_data_format = options.force_data_format;
115   }
116 
117   // TODO(b/187991150): Consider only decomposing read-only resource variable
118   // ops.
119   pass_options.decompose_resource_ops = options.decompose_resource_ops;
120   pass_options.enable_optimizer = options.enable_optimizer;
121   pass_options.enable_native_ops = options.enable_native_ops;
122   pass_options.target_tpurt =
123       (options.tpu_target == TfrtTpuInfraTarget::kTpurt);
124   pass_options.tpu_fuse_ops = options.tpu_fuse_ops;
125   pass_options.use_tpu_host_allocator_for_inputs =
126       options.use_tpu_host_allocator_for_inputs;
127   pass_options.hoist_invariant_ops = options.hoist_invariant_ops;
128   pass_options.func_use_fallback_tensor = true;
129   pass_options.enable_while_parallel_iterations =
130       options.enable_while_parallel_iterations;
131   pass_options.auto_fusion_oplist = options.auto_fusion_oplist;
132   pass_options.auto_fusion_min_cluster_size =
133       options.auto_fusion_min_cluster_size;
134   pass_options.cost_threshold = options.cost_threshold;
135   pass_options.upper_cost_threshold = options.upper_cost_threshold;
136   pass_options.merge_inter_dependent_streams =
137       options.merge_inter_dependent_streams;
138   tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options);
139 
140   if (mlir::failed(pm.run(module)))
141     return diag_handler.Combine(tensorflow::errors::Internal(
142         "failed to lower TF Dialect to CoreRT dialect."));
143 
144   if (VLOG_IS_ON(1)) {
145     tensorflow::DumpMlirOpToFile("tfrt_dialect", module);
146   }
147 
148   *bef_buffer =
149       tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/true);
150   if (bef_buffer->empty())
151     return diag_handler.Combine(
152         tensorflow::errors::Internal("failed to convert MLIR to BEF."));
153 
154   bef_buffer->shrink_to_fit();
155 
156   return OkStatus();
157 }
158 
159 }  // namespace tensorflow
160