1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
16
17 #include <string>
18 #include <vector>
19
20 #include "absl/strings/str_cat.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/MemoryBuffer.h"
24 #include "llvm/Support/SMLoc.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
29 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/Location.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/Types.h" // from @llvm-project
36 #include "mlir/IR/Verifier.h" // from @llvm-project
37 #include "mlir/Parser/Parser.h" // from @llvm-project
38 #include "mlir/Pass/PassManager.h" // from @llvm-project
39 #include "mlir/Transforms/Passes.h" // from @llvm-project
40 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
49 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
50 #include "tensorflow/core/platform/path.h"
51 #include "tensorflow/core/platform/stringpiece.h"
52 #include "tensorflow/core/util/env_var.h"
53 #include "tensorflow/stream_executor/lib/statusor.h"
54
55 namespace tensorflow {
56 namespace tfr {
57
// Name of the environment variable that selects the TFR decomposition library
// directory. TFRDecomposeContext::Get reads it (defaulting to
// "tensorflow/compiler/mlir/tfr/resources") and joins the value onto the
// runfiles directory.
const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR";
59
Get(mlir::MLIRContext * mlir_ctx)60 StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
61 mlir::MLIRContext* mlir_ctx) {
62 Env* env = Env::Default();
63 std::string tfr_lib_dir;
64 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
65 kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir));
66 string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
67 std::vector<string> files;
68 TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
69 if (files.empty()) {
70 return errors::Internal(absl::StrCat(
71 "Failed to find the decomposition lib from path ", composite_mlir_dir));
72 }
73 std::string tfr_raw_text;
74 for (const auto& file : files) {
75 string fullpath = io::JoinPath(composite_mlir_dir, file);
76 if (env->MatchPath(fullpath, io::JoinPath(composite_mlir_dir, "*.mlir"))) {
77 std::string text;
78 TF_RETURN_IF_ERROR(ReadFileToString(env, fullpath, &text));
79 tfr_raw_text.append(text);
80 }
81 }
82
83 auto ctx = TFRDecomposeContext::GetFromText(tfr_raw_text, mlir_ctx);
84 if (!ctx) {
85 return errors::Internal(absl::StrCat(
86 "Failed to load the imported decomposition lib: ", tfr_raw_text));
87 }
88 return ctx;
89 }
90
GetFromText(StringPiece tfr_raw_text,mlir::MLIRContext * mlir_ctx)91 std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
92 StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
93 mlir_ctx->allowUnregisteredDialects(/*allow=*/true);
94 // Load dialects involved in the conversion
95 mlir::DialectRegistry registry;
96 // clang-format off
97 registry.insert<mlir::arith::ArithmeticDialect,
98 mlir::func::FuncDialect,
99 mlir::scf::SCFDialect,
100 mlir::shape::ShapeDialect,
101 mlir::TF::TensorFlowDialect,
102 mlir::tf_device::TensorFlowDeviceDialect,
103 mlir::tf_executor::TensorFlowExecutorDialect,
104 mlir::TFR::TFRDialect>();
105 // clang-format on
106 mlir_ctx->appendDialectRegistry(registry);
107 mlir_ctx->loadAllAvailableDialects();
108
109 // Load the TFR functions in a mlir::ModuleOp
110 auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(
111 llvm::StringRef(tfr_raw_text.data(), tfr_raw_text.size()));
112 llvm::SourceMgr source_mgr;
113 source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc());
114 mlir::OwningOpRef<mlir::ModuleOp> module =
115 mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, mlir_ctx);
116 // The MLIRContext owns the module
117 auto module_op = module.release();
118
119 // Create the context
120 return std::make_unique<TFRDecomposeContext>(module_op);
121 }
122
ExpandNode(const NodeDef & node_def,StringPiece func_name)123 StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
124 StringPiece func_name) {
125 const OpDef* op_def;
126 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
127 DataTypeVector input_dtys, output_dtys;
128 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys));
129 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys));
130
131 mlir::MLIRContext* context = tfr_module_.getContext();
132 llvm::SmallVector<mlir::Type, 4> input_tys, output_tys;
133 mlir::Builder builder(context);
134 for (auto ty : input_dtys) {
135 mlir::Type elt_ty;
136 TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty));
137 mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty);
138 input_tys.push_back(mlir_ty);
139 }
140 for (auto ty : output_dtys) {
141 mlir::Type elt_ty;
142 TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty));
143 mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty);
144 output_tys.push_back(mlir_ty);
145 }
146 llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
147 for (const auto& attr : node_def.attr()) {
148 TF_ASSIGN_OR_RETURN(auto mlir_attr,
149 ConvertAttributeValue(attr.second, &builder));
150 attrs.push_back({mlir::StringAttr::get(context, attr.first), mlir_attr});
151 }
152
153 mlir::Location loc = mlir::UnknownLoc::get(context);
154 mlir::ModuleOp module = mlir::ModuleOp::create(loc);
155 mlir::FunctionType func_type =
156 mlir::FunctionType::get(context, input_tys, output_tys);
157 llvm::StringRef func_name_str(func_name.data(), func_name.size());
158 auto func = mlir::func::FuncOp::create(loc, func_name_str, func_type, {});
159 module.push_back(func);
160 func.addEntryBlock();
161 mlir::OpBuilder op_builder(func.getBody());
162
163 // Create the TF op
164 const std::string tf_op_full_name = absl::StrCat("tf.", node_def.op());
165 mlir::OperationState op_state(loc, tf_op_full_name);
166 op_state.addOperands(func.getArguments());
167 op_state.addTypes(output_tys);
168 op_state.addAttributes(attrs);
169 mlir::Operation* tf_op = op_builder.create(op_state);
170 op_builder.create<mlir::func::ReturnOp>(loc, tf_op->getResults());
171
172 // Run the decompose passes on the module
173 TF_RETURN_IF_ERROR(DecomposeGraph(module));
174
175 // Export the result as a FunctionDef.
176 FunctionDef func_def;
177 TF_RETURN_IF_ERROR(
178 ConvertMlirFunctionToFunctionLibraryDef(func, export_confs_, &func_def));
179 module.erase();
180 return func_def;
181 }
182
DecomposeGraph(mlir::ModuleOp user_module)183 Status TFRDecomposeContext::DecomposeGraph(mlir::ModuleOp user_module) {
184 // Call the decompose passes by using the external symbol table.
185 if (failed(pm_.run(user_module))) {
186 return errors::Internal("Failed to run the decompose passes.");
187 }
188 return OkStatus();
189 }
190
191 // Constructor of the decompose context.
TFRDecomposeContext(mlir::ModuleOp tfr_module)192 TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module)
193 : tfr_module_(tfr_module), pm_(tfr_module_.getContext()) {
194 mlir::OpPassManager& func_pm = pm_.nest<mlir::func::FuncOp>();
195
196 // Prepare the imported graph.
197 func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass());
198
199 // Run TFR lowering, inlining and raising to tf.
200 func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_));
201 func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass(
202 tfr_module_, /*materialize_derived_attrs=*/true));
203
204 // Prepare to be exported.
205 func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass());
206 pm_.addPass(mlir::CreateBreakUpIslandsPass());
207 }
208
Destroy()209 void TFRDecomposeContext::Destroy() { tfr_module_.erase(); }
210
ExpandNode(const NodeDef & node_def,StringPiece func_name)211 StatusOr<FunctionDef> ExpandNode(const NodeDef& node_def,
212 StringPiece func_name) {
213 mlir::MLIRContext mlir_ctx;
214 TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx));
215 return ctx->ExpandNode(node_def, func_name);
216 }
217
DecomposeGraph(mlir::ModuleOp user_module)218 Status DecomposeGraph(mlir::ModuleOp user_module) {
219 mlir::MLIRContext* mlir_ctx = user_module.getContext();
220 TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(mlir_ctx));
221 return ctx->DecomposeGraph(user_module);
222 }
223
224 } // namespace tfr
225 } // namespace tensorflow
226