1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass clusters the TensorFlow ops by host. The program generated by this
17 // pass will have one function per host where all operations in the same
18 // function are placed on the same host. Each result of the per-host function
19 // will have a "tf.device" attribute which specifies the device assignment of
20 // the result.
21 //
22 // The pass currently assumes that there is no circular dependency among the
23 // per-host functions. For example, if there exists an operation placed on
24 // host_A that consumes the result of an operation placed on host_B, then there
// does not exist any operation placed on host_B that consumes any result of any
26 // operation placed on host_A.
27
#include <algorithm>
#include <memory>
#include <string>

#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/device_name_utils.h"
37
38 namespace mlir {
39 namespace TF {
40 namespace {
41
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;

// Attribute names used to record placement on ops and functions.
constexpr const char *kHostAttr = "host";
constexpr const char *kDeviceAttr = "device";
constexpr const char *kTFDeviceAttr = "tf.device";
// TODO(donglin): Handle the case where the address of localhost is different
// from /job:localhost/replica:0/task:0.
constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
// Diagnostic emitted when a value would have to cross hosts, which this pass
// does not support yet.
constexpr const char *kErrorMessage =
    "The operation that uses the operand is on a different host than the "
    "operation that defines the op. This pass does not support cross-host data "
    "transfer yet";
55
56 // The host address is identified by the job/replicate/task in the device name.
GetHost(llvm::StringRef device)57 std::string GetHost(llvm::StringRef device) {
58 ParsedName parsed_name;
59 DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
60 std::string result = DeviceNameUtils::ParsedNameToString(
61 DeviceNameUtils::AddressSpace(parsed_name));
62 return result.empty() ? kLocalhost : result;
63 }
64
GetHost(Operation * op)65 std::string GetHost(Operation *op) {
66 std::string device = "";
67 if (StringAttr attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
68 device = attr.getValue().str();
69 }
70 return GetHost(device);
71 }
72
73 // The device is considered to be on the localhost iff one of the following is
74 // true:
75 // 1) None of the job/replica/task is specified in the device name.
76 // 2) The job/replica/task in the device name are explicitly specified as
77 // /job:localhost/replica:0/task:0.
IsOnLocalHost(llvm::StringRef device)78 bool IsOnLocalHost(llvm::StringRef device) {
79 std::string host = GetHost(device);
80 return host == kLocalhost;
81 }
82
// This structure contains the metadata of the per-host function. All operations
// in this function should be on the same host.
struct FunctionMetadata {
  // The original function name before partition.
  llvm::StringRef original_name;
  // The insertion point of partition functions (just after the original
  // function in its parent block).
  Block::iterator insertion_point;
  // The partitioned function name, as recorded after symbol-table insertion
  // (which may rename it on collision).
  llvm::StringRef partition_name;
  // The input values of the function.
  llvm::SmallVector<Value, 4> inputs;
  // The result values of the function.
  llvm::SmallVector<Value, 4> results;
  // The devices of the input values. It should have the same size as inputs.
  llvm::SmallVector<std::string, 4> input_devices;
  // The devices of the result values. It should have the same size as results.
  llvm::SmallVector<std::string, 4> result_devices;
  // The operations to be included in the body of the function, in original
  // walk order.
  llvm::SmallVector<Operation *, 4> ops;

  // NOTE(review): never assigned in this file — presumably reserved for the
  // created partition FuncOp; confirm before relying on it.
  FuncOp partition_op;
};
105
// Returns a map that maps the host address to the metadata of the function
// for that remote host. The metadata of the function specifies the input
// values, result values, result devices and the operations to be included in
// the function body. Returns llvm::None (after emitting an op error) when an
// operand would have to cross hosts, which this pass does not support yet.
llvm::Optional<llvm::StringMap<FunctionMetadata>> GetFunctionMetadatas(
    FuncOp func_op) {
  llvm::StringMap<FunctionMetadata> metadatas;
  WalkResult result = func_op.getBody().walk([&](Operation *op) {
    std::string op_host = GetHost(op);
    // operator[] creates the per-host entry on first touch.
    FunctionMetadata &func_metadata = metadatas[op_host];
    func_metadata.original_name = func_op.getName();
    // New per-host functions are inserted right after func_op.
    func_metadata.insertion_point = ++Block::iterator(func_op);
    func_metadata.ops.push_back(op);

    for (Value value : op->getOperands()) {
      std::string value_device = "";

      // If the value is defined as an argument of the func_op, adds it to
      // the argument list of the function that uses this op.
      if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
        if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
                block_arg.getArgNumber(), kTFDeviceAttr)) {
          value_device = attr.getValue().str();
        }

        // The argument must live on the same host as its consumer; this pass
        // does not insert cross-host transfers.
        if (GetHost(value_device) != op_host) {
          op->emitOpError() << kErrorMessage;
          return WalkResult::interrupt();
        }

        // Deduplicate: record each argument at most once per host function.
        if (llvm::find(func_metadata.inputs, value) ==
            func_metadata.inputs.end()) {
          func_metadata.inputs.push_back(value);
          func_metadata.input_devices.push_back(value_device);
        }
        continue;
      }

      Operation *defining_op = value.getDefiningOp();
      std::string defining_op_host = GetHost(defining_op);
      FunctionMetadata &defining_func_metadata = metadatas[defining_op_host];

      if (StringAttr attr =
              defining_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        value_device = attr.getValue().str();
      }

      // If the value is used as an operand of the terminator op, adds it to
      // the result list of function that defines this op.
      if (op->hasTrait<OpTrait::IsTerminator>()) {
        // Deduplicate results, keeping result_devices in lockstep.
        if (llvm::find(defining_func_metadata.results, value) ==
            defining_func_metadata.results.end()) {
          defining_func_metadata.results.push_back(value);
          defining_func_metadata.result_devices.push_back(value_device);
        }
        continue;
      }

      // A non-terminator consumer must be on the same host as the producer.
      if (defining_op_host != op_host) {
        op->emitOpError() << kErrorMessage;
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return llvm::None;

  return metadatas;
}
176
177 // Creates functions in the given module using the given FunctionMetadatas.
CreateFunctions(ModuleOp module_op,llvm::StringMap<FunctionMetadata> & metadatas)178 void CreateFunctions(ModuleOp module_op,
179 llvm::StringMap<FunctionMetadata> &metadatas) {
180 MLIRContext *context = module_op.getContext();
181 SymbolTable symbol_table(module_op);
182 for (auto &iter : metadatas) {
183 llvm::StringRef host = iter.first();
184 FunctionMetadata &metadata = iter.second;
185
186 // Do not create any new function for the operations on the localhost.
187 if (IsOnLocalHost(host)) continue;
188
189 llvm::SmallVector<mlir::Type, 4> input_types;
190 llvm::SmallVector<mlir::Type, 4> result_types;
191 for (Value input : metadata.inputs) {
192 input_types.push_back(input.getType());
193 }
194 for (Value result : metadata.results) {
195 result_types.push_back(result.getType());
196 }
197
198 // Replaces ':' and '/' with '_' in the host name and uses the resulting
199 // string as the function name.
200 std::string func_name =
201 absl::StrCat(iter.second.original_name.str(), ":", host.str());
202 std::replace(func_name.begin(), func_name.end(), ':', '_');
203 std::replace(func_name.begin(), func_name.end(), '/', '_');
204
205 FunctionType func_type =
206 FunctionType::get(context, input_types, result_types);
207 Location loc = metadata.ops.front()->getLoc();
208 FuncOp func_op = FuncOp::create(loc, func_name, func_type);
209 // Sets the device attribute for every input and every result of the
210 // function.
211 for (int i : llvm::seq<int>(0, metadata.input_devices.size())) {
212 func_op.setArgAttr(i, kTFDeviceAttr,
213 StringAttr::get(context, metadata.input_devices[i]));
214 }
215 for (int i : llvm::seq<int>(0, metadata.result_devices.size())) {
216 func_op.setResultAttr(
217 i, kTFDeviceAttr,
218 StringAttr::get(context, metadata.result_devices[i]));
219 }
220
221 func_op->setAttr(kHostAttr, StringAttr::get(context, host));
222 func_op.setPublic();
223 Block *block = func_op.addEntryBlock();
224
225 // Clones and moves the operations into the function's body. And the cloned
226 // operation should use the arguments of the newly created func_op as
227 // appropriate.
228 OpBuilder builder(block, block->end());
229 BlockAndValueMapping mapping;
230 for (int i : llvm::seq<int>(0, metadata.inputs.size())) {
231 Value original_value = metadata.inputs[i];
232 Value new_value = func_op.getArgument(i);
233 mapping.map(original_value, new_value);
234 }
235 for (Operation *op : metadata.ops) {
236 builder.clone(*op, mapping);
237 }
238 // Creates the ReturnOp so that the per-host function returns the
239 // correct values of the cloned operations.
240 llvm::SmallVector<Value, 4> results_after_mapping;
241 for (Value result : metadata.results) {
242 results_after_mapping.push_back(mapping.lookupOrDefault(result));
243 }
244 builder.create<ReturnOp>(loc, results_after_mapping);
245 symbol_table.insert(func_op, metadata.insertion_point++);
246 // Record the actual name. The symbol table might rename the FuncOp if there
247 // is name collision.
248 metadata.partition_name = func_op.getName();
249 }
250 }
251
// Creates a tf_device.remote_run call for every remote function. And replaces
// usages of the results of the original operations with the results of the
// tf_device.remote_run calls.
// NOTE(review): `context` is currently unused in this body — kept for
// interface stability; confirm before removing.
void CreateRemoteRunCalls(MLIRContext *context,
                          const llvm::StringMap<FunctionMetadata> &metadatas) {
  // Maps each original remote value to the remote_run result that replaces
  // it. Shared across iterations so later calls consume earlier results.
  BlockAndValueMapping mapping;
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    const FunctionMetadata &metadata = iter.second;

    // Do not create tf_device.remote_run call for the operations already placed
    // on the localhost.
    if (IsOnLocalHost(host)) continue;

    // Creates the tf_device.remote_run operation, inserted just before the
    // last original op for this host.
    OpBuilder builder(metadata.ops.back());
    llvm::SmallVector<Type, 4> result_types;
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }
    Location loc = metadata.ops.front()->getLoc();
    llvm::SmallVector<Value, 4> inputs_after_mapping;
    for (Value input : metadata.inputs) {
      inputs_after_mapping.push_back(mapping.lookupOrDefault(input));
    }

    tf_device::RemoteRunOp remote_run_op =
        builder.create<tf_device::RemoteRunOp>(loc, result_types, host,
                                               metadata.partition_name,
                                               inputs_after_mapping);
    // Clones the tf_device.remote_run operation to replace its callee args with
    // the results of the other tf_device.remote_run operations using the
    // `mapping` as appropriate.
    Operation *cloned_remote_run_op =
        builder.clone(*remote_run_op.getOperation(), mapping);
    remote_run_op.erase();

    // Replaces usages of the results of the original operations with the
    // results of the tf_device.remote_run operations.
    for (int i : llvm::seq<int>(0, metadata.results.size())) {
      Value original_value = metadata.results[i];
      Value new_value = cloned_remote_run_op->getResult(i);
      original_value.replaceAllUsesWith(new_value);
      mapping.map(original_value, new_value);
    }
  }
}
299
// Module pass that clusters TF ops by host: for each original function, it
// builds one per-host function containing clones of that host's ops, calls
// each via tf_device.remote_run, and erases the original remote ops.
class ClusterTFOpsByHostPass
    : public PassWrapper<ClusterTFOpsByHostPass, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "cluster-tf-ops-by-host"; }

  StringRef getDescription() const final {
    return "Cluster the TensorFlow ops by host so that each function only "
           "contains ops placed on the same host";
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp module_op = getOperation();
    // Snapshot the existing functions first: CreateFunctions inserts new
    // FuncOps into the module, and those must not be re-processed.
    SmallVector<FuncOp, 4> original_func;
    for (auto func_op : module_op.getOps<FuncOp>()) {
      original_func.push_back(func_op);
    }
    for (auto func_op : original_func) {
      llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
          GetFunctionMetadatas(func_op);
      if (!metadatas) {
        // GetFunctionMetadatas already emitted a diagnostic.
        signalPassFailure();
        return;
      }

      CreateFunctions(module_op, *metadatas);
      CreateRemoteRunCalls(context, *metadatas);

      // Erases the original operations which have been cloned in the remote
      // functions.
      for (auto &iter : *metadatas) {
        llvm::StringRef host = iter.first();
        FunctionMetadata &metadata = iter.second;
        // Do not erase operations placed on the localhost.
        if (IsOnLocalHost(host)) continue;

        // Erase in reverse walk order so users are removed before producers.
        for (int i = metadata.ops.size() - 1; i >= 0; i--) {
          metadata.ops[i]->erase();
        }
      }
    }
  }
};
342
343 } // namespace
344
CreateClusterTFOpsByHostPass()345 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass() {
346 return std::make_unique<ClusterTFOpsByHostPass>();
347 }
348
// Registers the pass so it is selectable as "cluster-tf-ops-by-host" in pass
// pipelines and MLIR tools.
static PassRegistration<ClusterTFOpsByHostPass> pass;
350
351 } // namespace TF
352 } // namespace mlir
353