• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h"
16 
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/Support/MemoryBuffer.h"
24 #include "llvm/Support/SMLoc.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
29 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Types.h"  // from @llvm-project
36 #include "mlir/IR/Verifier.h"  // from @llvm-project
37 #include "mlir/Parser/Parser.h"  // from @llvm-project
38 #include "mlir/Pass/PassManager.h"  // from @llvm-project
39 #include "mlir/Transforms/Passes.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
49 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
50 #include "tensorflow/core/platform/path.h"
51 #include "tensorflow/core/platform/stringpiece.h"
52 #include "tensorflow/core/util/env_var.h"
53 #include "tensorflow/stream_executor/lib/statusor.h"
54 
55 namespace tensorflow {
56 namespace tfr {
57 
// Name of the environment variable that overrides the (runfiles-relative)
// directory from which the TFR decomposition library `*.mlir` files are read.
const char* const kTFRLibEnv{"TF_MLIR_TFR_LIB_DIR"};
59 
Get(mlir::MLIRContext * mlir_ctx)60 StatusOr<std::unique_ptr<TFRDecomposeContext>> TFRDecomposeContext::Get(
61     mlir::MLIRContext* mlir_ctx) {
62   Env* env = Env::Default();
63   std::string tfr_lib_dir;
64   TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
65       kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir));
66   string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir);
67   std::vector<string> files;
68   TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files));
69   if (files.empty()) {
70     return errors::Internal(absl::StrCat(
71         "Failed to find the decomposition lib from path ", composite_mlir_dir));
72   }
73   std::string tfr_raw_text;
74   for (const auto& file : files) {
75     string fullpath = io::JoinPath(composite_mlir_dir, file);
76     if (env->MatchPath(fullpath, io::JoinPath(composite_mlir_dir, "*.mlir"))) {
77       std::string text;
78       TF_RETURN_IF_ERROR(ReadFileToString(env, fullpath, &text));
79       tfr_raw_text.append(text);
80     }
81   }
82 
83   auto ctx = TFRDecomposeContext::GetFromText(tfr_raw_text, mlir_ctx);
84   if (!ctx) {
85     return errors::Internal(absl::StrCat(
86         "Failed to load the imported decomposition lib: ", tfr_raw_text));
87   }
88   return ctx;
89 }
90 
GetFromText(StringPiece tfr_raw_text,mlir::MLIRContext * mlir_ctx)91 std::unique_ptr<TFRDecomposeContext> TFRDecomposeContext::GetFromText(
92     StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) {
93   mlir_ctx->allowUnregisteredDialects(/*allow=*/true);
94   // Load dialects involved in the conversion
95   mlir::DialectRegistry registry;
96   // clang-format off
97   registry.insert<mlir::arith::ArithmeticDialect,
98                   mlir::func::FuncDialect,
99                   mlir::scf::SCFDialect,
100                   mlir::shape::ShapeDialect,
101                   mlir::TF::TensorFlowDialect,
102                   mlir::tf_device::TensorFlowDeviceDialect,
103                   mlir::tf_executor::TensorFlowExecutorDialect,
104                   mlir::TFR::TFRDialect>();
105   // clang-format on
106   mlir_ctx->appendDialectRegistry(registry);
107   mlir_ctx->loadAllAvailableDialects();
108 
109   // Load the TFR functions in a mlir::ModuleOp
110   auto memory_buffer = llvm::MemoryBuffer::getMemBuffer(
111       llvm::StringRef(tfr_raw_text.data(), tfr_raw_text.size()));
112   llvm::SourceMgr source_mgr;
113   source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc());
114   mlir::OwningOpRef<mlir::ModuleOp> module =
115       mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, mlir_ctx);
116   // The MLIRContext owns the module
117   auto module_op = module.release();
118 
119   // Create the context
120   return std::make_unique<TFRDecomposeContext>(module_op);
121 }
122 
ExpandNode(const NodeDef & node_def,StringPiece func_name)123 StatusOr<FunctionDef> TFRDecomposeContext::ExpandNode(const NodeDef& node_def,
124                                                       StringPiece func_name) {
125   const OpDef* op_def;
126   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
127   DataTypeVector input_dtys, output_dtys;
128   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys));
129   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys));
130 
131   mlir::MLIRContext* context = tfr_module_.getContext();
132   llvm::SmallVector<mlir::Type, 4> input_tys, output_tys;
133   mlir::Builder builder(context);
134   for (auto ty : input_dtys) {
135     mlir::Type elt_ty;
136     TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty));
137     mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty);
138     input_tys.push_back(mlir_ty);
139   }
140   for (auto ty : output_dtys) {
141     mlir::Type elt_ty;
142     TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty));
143     mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty);
144     output_tys.push_back(mlir_ty);
145   }
146   llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
147   for (const auto& attr : node_def.attr()) {
148     TF_ASSIGN_OR_RETURN(auto mlir_attr,
149                         ConvertAttributeValue(attr.second, &builder));
150     attrs.push_back({mlir::StringAttr::get(context, attr.first), mlir_attr});
151   }
152 
153   mlir::Location loc = mlir::UnknownLoc::get(context);
154   mlir::ModuleOp module = mlir::ModuleOp::create(loc);
155   mlir::FunctionType func_type =
156       mlir::FunctionType::get(context, input_tys, output_tys);
157   llvm::StringRef func_name_str(func_name.data(), func_name.size());
158   auto func = mlir::func::FuncOp::create(loc, func_name_str, func_type, {});
159   module.push_back(func);
160   func.addEntryBlock();
161   mlir::OpBuilder op_builder(func.getBody());
162 
163   // Create the TF op
164   const std::string tf_op_full_name = absl::StrCat("tf.", node_def.op());
165   mlir::OperationState op_state(loc, tf_op_full_name);
166   op_state.addOperands(func.getArguments());
167   op_state.addTypes(output_tys);
168   op_state.addAttributes(attrs);
169   mlir::Operation* tf_op = op_builder.create(op_state);
170   op_builder.create<mlir::func::ReturnOp>(loc, tf_op->getResults());
171 
172   // Run the decompose passes on the module
173   TF_RETURN_IF_ERROR(DecomposeGraph(module));
174 
175   // Export the result as a FunctionDef.
176   FunctionDef func_def;
177   TF_RETURN_IF_ERROR(
178       ConvertMlirFunctionToFunctionLibraryDef(func, export_confs_, &func_def));
179   module.erase();
180   return func_def;
181 }
182 
DecomposeGraph(mlir::ModuleOp user_module)183 Status TFRDecomposeContext::DecomposeGraph(mlir::ModuleOp user_module) {
184   // Call the decompose passes by using the external symbol table.
185   if (failed(pm_.run(user_module))) {
186     return errors::Internal("Failed to run the decompose passes.");
187   }
188   return OkStatus();
189 }
190 
191 // Constructor of the decompose context.
TFRDecomposeContext(mlir::ModuleOp tfr_module)192 TFRDecomposeContext::TFRDecomposeContext(mlir::ModuleOp tfr_module)
193     : tfr_module_(tfr_module), pm_(tfr_module_.getContext()) {
194   mlir::OpPassManager& func_pm = pm_.nest<mlir::func::FuncOp>();
195 
196   // Prepare the imported graph.
197   func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass());
198 
199   // Run TFR lowering, inlining and raising to tf.
200   func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_));
201   func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass(
202       tfr_module_, /*materialize_derived_attrs=*/true));
203 
204   // Prepare to be exported.
205   func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass());
206   pm_.addPass(mlir::CreateBreakUpIslandsPass());
207 }
208 
Destroy()209 void TFRDecomposeContext::Destroy() { tfr_module_.erase(); }
210 
ExpandNode(const NodeDef & node_def,StringPiece func_name)211 StatusOr<FunctionDef> ExpandNode(const NodeDef& node_def,
212                                  StringPiece func_name) {
213   mlir::MLIRContext mlir_ctx;
214   TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(&mlir_ctx));
215   return ctx->ExpandNode(node_def, func_name);
216 }
217 
DecomposeGraph(mlir::ModuleOp user_module)218 Status DecomposeGraph(mlir::ModuleOp user_module) {
219   mlir::MLIRContext* mlir_ctx = user_module.getContext();
220   TF_ASSIGN_OR_RETURN(auto ctx, TFRDecomposeContext::Get(mlir_ctx));
221   return ctx->DecomposeGraph(user_module);
222 }
223 
224 }  // namespace tfr
225 }  // namespace tensorflow
226