1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass converts each tfrt_dist.remote_execute_func op into a combination
17 // of tfrt_dist.register_tfrt_function op and tfrt_dist.remote_execute op. The
18 // function to be executed in the remote host will be serialized as a string
19 // attribute of the tfrt_dist.register_tfrt_function op.
20
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/SymbolTable.h" // from @llvm-project
28 #include "mlir/IR/Types.h" // from @llvm-project
29 #include "mlir/IR/Visitors.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Pass/PassManager.h" // from @llvm-project
32 #include "mlir/Support/LLVM.h" // from @llvm-project
33 #include "mlir/Transforms/Passes.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
38 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
39 #include "tensorflow/core/util/device_name_utils.h"
40 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
41 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
42 #include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
43 #include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime
44 #include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime
45
46 namespace tensorflow {
47
48 namespace {
49
// Name of the callee function attribute holding the full device name of the
// remote host; it is parsed to derive a default device for results.
constexpr char kHost[] = "host";
// Name of the per-result attribute holding the device of a callee result.
constexpr char kTFRTDevice[] = "tfrt.device";
52
53 struct DistRemoteRunEncapsulatePass
54 : public PassWrapper<DistRemoteRunEncapsulatePass,
55 OperationPass<ModuleOp>> {
getArgumenttensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass56 llvm::StringRef getArgument() const final {
57 return "tfrt-dist-remote-run-encapsulate";
58 }
getDescriptiontensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass59 llvm::StringRef getDescription() const final {
60 return "This pass looks for a remote_run_func and serialize the callee to "
61 "a string attribute attached to a remote_register operation, "
62 "followed by a remote_execute invocation.";
63 }
64 void runOnOperation() override;
65
getDependentDialectstensorflow::__anon847aae630111::DistRemoteRunEncapsulatePass66 void getDependentDialects(DialectRegistry& registry) const override {
67 registry.insert<tfrt::dist::DistributedDialect>();
68 }
69 };
70
EncapsulateFuncAndSerialize(FuncOp entry_func,std::string * serialized_func_module)71 LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
72 std::string* serialized_func_module) {
73 ModuleOp module = entry_func->getParentOfType<ModuleOp>();
74 SymbolTable entry_module_table(module);
75 SmallVector<FuncOp, 4> referenced({entry_func});
76
77 // Create a new module to hold func and all referenced functions.
78 OwningModuleRef module_for_func =
79 ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
80 SymbolTable symbol_table(module_for_func.get());
81
82 while (!referenced.empty()) {
83 FuncOp func = referenced.pop_back_val();
84
85 // Skip functions that have already been cloned into new module.
86 if (symbol_table.lookup<FuncOp>(func.getName())) continue;
87
88 // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
89 // all found FuncOps to new_module to make sure new_module is
90 // self-contained.
91 Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
92 assert(uses && "expected to be able to collect symbol uses");
93 for (SymbolTable::SymbolUse use : *uses) {
94 FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
95 use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
96
97 // Skip Symbols that do not map to a function.
98 if (!referenced_func) continue;
99
100 referenced.emplace_back(referenced_func);
101 }
102
103 FuncOp clone = func.clone();
104 if (clone.getName() == entry_func.getName()) {
105 clone.setPublic();
106 } else {
107 clone.setPrivate();
108 }
109 symbol_table.insert(clone);
110 }
111
112 *serialized_func_module =
113 tensorflow::SerializeMlirModule(module_for_func.get());
114 return success();
115 }
116
runOnOperation()117 void DistRemoteRunEncapsulatePass::runOnOperation() {
118 mlir::TF::RuntimeDevices devices;
119 ModuleOp module = getOperation();
120 SymbolTable symtab(module);
121 Type chain_type = tfrt::compiler::ChainType::get(&getContext());
122 Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext());
123 Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext());
124 module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) {
125 FlatSymbolRefAttr callee_sym = remote_exec_op.calleeAttr();
126 FuncOp callee = symtab.lookup<FuncOp>(callee_sym.getValue());
127 if (!callee) {
128 remote_exec_op.emitOpError("callee function ")
129 << callee_sym.getValue() << " is not found";
130 signalPassFailure();
131 return WalkResult::interrupt();
132 }
133 std::string txt_module;
134 if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) {
135 remote_exec_op.emitOpError("failed to serialize the callee function ")
136 << callee.getName();
137 signalPassFailure();
138 return WalkResult::interrupt();
139 }
140 Location loc = remote_exec_op.getLoc();
141 StringAttr callee_name =
142 StringAttr::get(&getContext(), callee_sym.getValue());
143 OpBuilder builder(remote_exec_op);
144 auto register_op = builder.create<tfrt::dist::RegisterTFRTFunctionOp>(
145 loc, chain_type, remote_exec_op.in_op_chain(), remote_exec_op.context(),
146 remote_exec_op.remote_task(),
147 StringAttr::get(&getContext(), txt_module), callee_name);
148
149 // Build the device assignment for the results
150 // TODO(tfrt-devs): Define properly MLIR types and operations
151 SmallVector<Attribute, 8> result_devices;
152 for (const auto& result : llvm::enumerate(remote_exec_op.results())) {
153 StringAttr device =
154 callee.getResultAttrOfType<StringAttr>(result.index(), kTFRTDevice);
155 if (!device) {
156 // The result might not have the device attribute if it is added by
157 // the tf-to-tfrt pass. Use the first CPU on the remote host as the
158 // device of this result.
159 DeviceNameUtils::ParsedName parsed_name;
160 if (StringAttr host_attr = callee->getAttrOfType<StringAttr>(kHost)) {
161 auto host = host_attr.getValue();
162 DeviceNameUtils::ParseFullName({host.data(), host.size()},
163 &parsed_name);
164 }
165 parsed_name.has_type = true;
166 parsed_name.type = "CPU";
167 parsed_name.has_id = true;
168 parsed_name.id = 0;
169 device = StringAttr::get(
170 &getContext(), DeviceNameUtils::ParsedNameToString(parsed_name));
171 }
172 result_devices.push_back(std::move(device));
173 }
174 // IDEA(donglin): Update the create_remote_execute_spec kernel to use Device
175 // object instead of Device string.
176 Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext());
177 auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices);
178 auto remote_spec = builder.create<tfrt::dist::CreateRemoteExecuteSpecOp>(
179 loc, remote_spec_ty, remote_exec_op.context(), result_devices_attr);
180 // If original argument is already tfrt_dist.remote_object_id, use it
181 // directly. If it is TensorHandle, insert an op to extract the
182 // tfrt_dist.remote_object_id from it. Otherwise, emit an error.
183 SmallVector<Value, 4> arguments;
184 for (Value value : remote_exec_op.callee_args()) {
185 if (value.getType().isa<tfrt::dist::RemoteObjectIdType>()) {
186 arguments.push_back(value);
187 } else if (value.getType().isa<tfrt::corert::TensorHandleType>()) {
188 auto new_op = builder.create<tfrt::dist::GetRemoteObjectIdFromTHOp>(
189 loc, remote_object_id_ty, value);
190 arguments.push_back(new_op.result());
191 } else {
192 remote_exec_op.emitOpError(
193 "callee argument type should be either "
194 "TensorHandle or RemoteObjectId");
195 signalPassFailure();
196 return WalkResult::interrupt();
197 }
198 }
199 // Result types are 1 chain, followed by `num_th_results + 1`
200 // tfrt_dist.remote_object_id results, followed by `num_th_results`
201 // corert.tensorhandle results.
202 int32_t num_th_results = remote_exec_op.results().size() - 1;
203 SmallVector<Type, 8> result_types;
204 result_types.push_back(chain_type);
205 for (int count : llvm::seq<int>(0, num_th_results + 1)) {
206 (void)count;
207 result_types.push_back(remote_object_id_ty);
208 }
209 for (int count : llvm::seq<int>(0, num_th_results)) {
210 (void)count;
211 result_types.push_back(tensor_handle_ty);
212 }
213 auto new_remote_exec_th_op = builder.create<tfrt::dist::RemoteExecuteTHOp>(
214 loc, result_types, register_op.out_op_chain(), remote_exec_op.context(),
215 remote_exec_op.remote_task(), remote_spec, num_th_results,
216 callee_name.getValue(), std::move(arguments));
217 // The part of the new results to replace the original results are 2 chains,
218 // followed `num_th_results` corert.tesnorhandle results from the callee
219 // function.
220 SmallVector<Value, 4> new_results;
221 new_results.push_back(new_remote_exec_th_op.getResult(0));
222 new_results.push_back(new_remote_exec_th_op.getResult(1));
223 for (int i : llvm::seq<int>(0, num_th_results)) {
224 new_results.push_back(
225 new_remote_exec_th_op.getResult(i + 2 + num_th_results));
226 }
227 remote_exec_op.replaceAllUsesWith(new_results);
228 remote_exec_op.erase();
229
230 return WalkResult::advance();
231 });
232 }
233
234 } // namespace
235
CreateDistRemoteRunEncapsulatePass()236 std::unique_ptr<OperationPass<ModuleOp>> CreateDistRemoteRunEncapsulatePass() {
237 return std::make_unique<DistRemoteRunEncapsulatePass>();
238 }
239
240 static PassRegistration<DistRemoteRunEncapsulatePass> pass;
241
242 } // namespace tensorflow
243