/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
// This pass converts each tfrt_dist.remote_execute_func op into a combination
// of a tfrt_dist.register_tfrt_function op and a tfrt_dist.remote_execute op.
// The function to be executed on the remote host is serialized as a string
// attribute of the tfrt_dist.register_tfrt_function op.
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
28 #include "mlir/IR/Types.h"  // from @llvm-project
29 #include "mlir/IR/Visitors.h"  // from @llvm-project
30 #include "mlir/Pass/Pass.h"  // from @llvm-project
31 #include "mlir/Pass/PassManager.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "mlir/Transforms/Passes.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
38 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
39 #include "tensorflow/core/util/device_name_utils.h"
40 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
41 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
42 #include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
43 #include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
44 #include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime
45 
46 namespace tensorflow {
47 
48 namespace {
49 
// Name of the function attribute read from the callee to obtain the host part
// of a default result device name.
constexpr const char* kHost = "host";
// Name of the per-result attribute read from the callee to obtain the device
// assigned to that result.
constexpr const char* kTFRTDevice = "tfrt.device";
52 
53 struct DistRemoteRunEncapsulatePass
54     : public PassWrapper<DistRemoteRunEncapsulatePass,
55                          OperationPass<ModuleOp>> {
getArgumenttensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass56   llvm::StringRef getArgument() const final {
57     return "tfrt-dist-remote-run-encapsulate";
58   }
getDescriptiontensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass59   llvm::StringRef getDescription() const final {
60     return "This pass looks for a remote_run_func and serialize the callee to "
61            "a string attribute attached to a remote_register operation, "
62            "followed by a remote_execute invocation.";
63   }
64   void runOnOperation() override;
65 
getDependentDialectstensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass66   void getDependentDialects(DialectRegistry& registry) const override {
67     registry.insert<tfrt::dist::DistributedDialect>();
68   }
69 };
70 
EncapsulateFuncAndSerialize(FuncOp entry_func,std::string * serialized_func_module)71 LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
72                                           std::string* serialized_func_module) {
73   ModuleOp module = entry_func->getParentOfType<ModuleOp>();
74   SymbolTable entry_module_table(module);
75   SmallVector<FuncOp, 4> referenced({entry_func});
76 
77   // Create a new module to hold func and all referenced functions.
78   OwningModuleRef module_for_func =
79       ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
80   SymbolTable symbol_table(module_for_func.get());
81 
82   while (!referenced.empty()) {
83     FuncOp func = referenced.pop_back_val();
84 
85     // Skip functions that have already been cloned into new module.
86     if (symbol_table.lookup<FuncOp>(func.getName())) continue;
87 
88     // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
89     // all found FuncOps to new_module to make sure new_module is
90     // self-contained.
91     Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
92     assert(uses && "expected to be able to collect symbol uses");
93     for (SymbolTable::SymbolUse use : *uses) {
94       FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
95           use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
96 
97       // Skip Symbols that do not map to a function.
98       if (!referenced_func) continue;
99 
100       referenced.emplace_back(referenced_func);
101     }
102 
103     FuncOp clone = func.clone();
104     if (clone.getName() == entry_func.getName()) {
105       clone.setPublic();
106     } else {
107       clone.setPrivate();
108     }
109     symbol_table.insert(clone);
110   }
111 
112   *serialized_func_module =
113       tensorflow::SerializeMlirModule(module_for_func.get());
114   return success();
115 }
116 
runOnOperation()117 void DistRemoteRunEncapsulatePass::runOnOperation() {
118   mlir::TF::RuntimeDevices devices;
119   ModuleOp module = getOperation();
120   SymbolTable symtab(module);
121   Type chain_type = tfrt::compiler::ChainType::get(&getContext());
122   Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext());
123   Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext());
124   module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) {
125     FlatSymbolRefAttr callee_sym = remote_exec_op.calleeAttr();
126     FuncOp callee = symtab.lookup<FuncOp>(callee_sym.getValue());
127     if (!callee) {
128       remote_exec_op.emitOpError("callee function ")
129           << callee_sym.getValue() << " is not found";
130       signalPassFailure();
131       return WalkResult::interrupt();
132     }
133     std::string txt_module;
134     if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) {
135       remote_exec_op.emitOpError("failed to serialize the callee function ")
136           << callee.getName();
137       signalPassFailure();
138       return WalkResult::interrupt();
139     }
140     Location loc = remote_exec_op.getLoc();
141     StringAttr callee_name =
142         StringAttr::get(&getContext(), callee_sym.getValue());
143     OpBuilder builder(remote_exec_op);
144     auto register_op = builder.create<tfrt::dist::RegisterTFRTFunctionOp>(
145         loc, chain_type, remote_exec_op.in_op_chain(), remote_exec_op.context(),
146         remote_exec_op.remote_task(),
147         StringAttr::get(&getContext(), txt_module), callee_name);
148 
149     // Build the device assignment for the results
150     // TODO(tfrt-devs): Define properly MLIR types and operations
151     SmallVector<Attribute, 8> result_devices;
152     for (const auto& result : llvm::enumerate(remote_exec_op.results())) {
153       StringAttr device =
154           callee.getResultAttrOfType<StringAttr>(result.index(), kTFRTDevice);
155       if (!device) {
156         // The result might not have the device attribute if it is added by
157         // the tf-to-tfrt pass. Use the first CPU on the remote host as the
158         // device of this result.
159         DeviceNameUtils::ParsedName parsed_name;
160         if (StringAttr host_attr = callee->getAttrOfType<StringAttr>(kHost)) {
161           auto host = host_attr.getValue();
162           DeviceNameUtils::ParseFullName({host.data(), host.size()},
163                                          &parsed_name);
164         }
165         parsed_name.has_type = true;
166         parsed_name.type = "CPU";
167         parsed_name.has_id = true;
168         parsed_name.id = 0;
169         device = StringAttr::get(
170             &getContext(), DeviceNameUtils::ParsedNameToString(parsed_name));
171       }
172       result_devices.push_back(std::move(device));
173     }
174     // IDEA(donglin): Update the create_remote_execute_spec kernel to use Device
175     // object instead of Device string.
176     Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext());
177     auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices);
178     auto remote_spec = builder.create<tfrt::dist::CreateRemoteExecuteSpecOp>(
179         loc, remote_spec_ty, remote_exec_op.context(), result_devices_attr);
180     // If original argument is already tfrt_dist.remote_object_id, use it
181     // directly. If it is TensorHandle, insert an op to extract the
182     // tfrt_dist.remote_object_id from it. Otherwise, emit an error.
183     SmallVector<Value, 4> arguments;
184     for (Value value : remote_exec_op.callee_args()) {
185       if (value.getType().isa<tfrt::dist::RemoteObjectIdType>()) {
186         arguments.push_back(value);
187       } else if (value.getType().isa<tfrt::corert::TensorHandleType>()) {
188         auto new_op = builder.create<tfrt::dist::GetRemoteObjectIdFromTHOp>(
189             loc, remote_object_id_ty, value);
190         arguments.push_back(new_op.result());
191       } else {
192         remote_exec_op.emitOpError(
193             "callee argument type should be either "
194             "TensorHandle or RemoteObjectId");
195         signalPassFailure();
196         return WalkResult::interrupt();
197       }
198     }
199     // Result types are 1 chain, followed by `num_th_results + 1`
200     // tfrt_dist.remote_object_id results, followed by `num_th_results`
201     // corert.tensorhandle results.
202     int32_t num_th_results = remote_exec_op.results().size() - 1;
203     SmallVector<Type, 8> result_types;
204     result_types.push_back(chain_type);
205     for (int count : llvm::seq<int>(0, num_th_results + 1)) {
206       (void)count;
207       result_types.push_back(remote_object_id_ty);
208     }
209     for (int count : llvm::seq<int>(0, num_th_results)) {
210       (void)count;
211       result_types.push_back(tensor_handle_ty);
212     }
213     auto new_remote_exec_th_op = builder.create<tfrt::dist::RemoteExecuteTHOp>(
214         loc, result_types, register_op.out_op_chain(), remote_exec_op.context(),
215         remote_exec_op.remote_task(), remote_spec, num_th_results,
216         callee_name.getValue(), std::move(arguments));
217     // The part of the new results to replace the original results are 2 chains,
218     // followed `num_th_results` corert.tesnorhandle results from the callee
219     // function.
220     SmallVector<Value, 4> new_results;
221     new_results.push_back(new_remote_exec_th_op.getResult(0));
222     new_results.push_back(new_remote_exec_th_op.getResult(1));
223     for (int i : llvm::seq<int>(0, num_th_results)) {
224       new_results.push_back(
225           new_remote_exec_th_op.getResult(i + 2 + num_th_results));
226     }
227     remote_exec_op.replaceAllUsesWith(new_results);
228     remote_exec_op.erase();
229 
230     return WalkResult::advance();
231   });
232 }
233 
234 }  // namespace
235 
CreateDistRemoteRunEncapsulatePass()236 std::unique_ptr<OperationPass<ModuleOp>> CreateDistRemoteRunEncapsulatePass() {
237   return std::make_unique<DistRemoteRunEncapsulatePass>();
238 }
239 
240 static PassRegistration<DistRemoteRunEncapsulatePass> pass;
241 
242 }  // namespace tensorflow
243