1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass inserts tf_device.send and tf_device.receive ops to make sure every
// operand of every op is located on the same host as the op itself.
18
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/Types.h" // from @llvm-project
24 #include "mlir/Pass/PassManager.h" // from @llvm-project
25 #include "mlir/Transforms/Passes.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28
29 namespace mlir {
30 namespace TF {
31
32 namespace {
33
34 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
35
36 constexpr const char *kOpDeviceAttr = "device";
37 constexpr const char *kArgDeviceAttr = "tf.device";
38 // TODO(b/175480458): Do not assign default host once every op in the TF
39 // dialect has the device attribute.
40 constexpr const char *kDefaultHost = "/job:localhost/replica:0/task:0";
41 constexpr const char *kCPUDevice = "/device:CPU:0";
42
43 // Return the job/replica/task from the device name as the host address. If no
44 // job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
45 // default host address.
GetHost(const std::string & device)46 std::string GetHost(const std::string &device) {
47 DeviceNameUtils::ParsedName parsed_name;
48 DeviceNameUtils::ParseFullName(device, &parsed_name);
49 parsed_name.has_id = false;
50 parsed_name.has_type = false;
51
52 auto host = DeviceNameUtils::ParsedNameToString(parsed_name);
53 if (host.empty()) return kDefaultHost;
54
55 return host;
56 }
57
58 struct CrossHostTransferPass
59 : public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
60 void runOnOperation() override;
61
62 private:
63 // The key_count represents the total number of send/recv pairs generated
64 // before this method call. And the key_count should be incremented based
65 // on the send/recv pairs newly generated by this method call.
66 void runOnFunction(FuncOp func_op, int &key_count);
67 };
68
runOnOperation()69 void CrossHostTransferPass::runOnOperation() {
70 ModuleOp module = getOperation();
71 int key_count = 0;
72
73 module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
74 }
75
runOnFunction(FuncOp func_op,int & key_count)76 void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
77 // This map is used to avoid transferring the same value to the same host
78 // multiple times.
79 llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
80 transferred_value_by_value_and_host;
81
82 func_op.getBody().walk([&](Operation *op) {
83 if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
84
85 OpBuilder builder(op);
86 // Get the host address of the op.
87 std::string op_device = "";
88 if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
89 op_device = device_attr.getValue().str();
90 }
91 std::string dst_host = GetHost(op_device);
92
93 for (mlir::Value arg : op->getOperands()) {
94 // Get the host address of the argument.
95 std::string arg_device = "";
96 if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
97 // Do not send this argument if it is not a function's argument. This
98 // can happen when the argument is a while loop's argument.
99 if (block_arg.getParentRegion() != &func_op.getRegion()) continue;
100
101 if (StringAttr device_attr = func_op.getArgAttrOfType<StringAttr>(
102 block_arg.getArgNumber(), kArgDeviceAttr)) {
103 arg_device = device_attr.getValue().str();
104 }
105 } else {
106 Operation *defining_op = arg.getDefiningOp();
107 if (StringAttr device_attr =
108 defining_op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
109 arg_device = device_attr.getValue().str();
110 }
111 }
112 std::string src_host = GetHost(arg_device);
113
114 if (src_host == dst_host) continue;
115
116 // Re-use the transferred argument if the argument has already been
117 // transferred to the given host.
118 llvm::StringMap<mlir::Value> &transferred_value_by_host =
119 transferred_value_by_value_and_host[arg];
120 auto iter = transferred_value_by_host.find(dst_host);
121 if (iter != transferred_value_by_host.end()) {
122 op->replaceUsesOfWith(arg, iter->second);
123 continue;
124 }
125
126 // Create tf_device.send and tf_device.receive ops to send the argument to
127 // the same host of the operation.
128 std::string key = "key-" + std::to_string(key_count);
129 key_count++;
130
131 auto send_op =
132 builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
133 send_op->setAttr(kOpDeviceAttr,
134 builder.getStringAttr(src_host + kCPUDevice));
135
136 auto receive_op = builder.create<tf_device::ReceiveOp>(
137 op->getLoc(), arg.getType(), key, src_host);
138 receive_op->setAttr(kOpDeviceAttr,
139 builder.getStringAttr(dst_host + kCPUDevice));
140
141 transferred_value_by_host[dst_host] = receive_op.getResult();
142 op->replaceUsesOfWith(arg, receive_op.getResult());
143 }
144 return WalkResult::advance();
145 });
146 }
147
148 } // namespace
149
CreateCrossHostTransferPass()150 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
151 return std::make_unique<CrossHostTransferPass>();
152 }
153
154 static PassRegistration<CrossHostTransferPass> pass(
155 "tf-cross-host-transfer",
156 "This pass inserts tf_device.send and tf_device.receive ops to make sure "
157 "any argument of any op is on the same host of the op itself.");
158
159 } // namespace TF
160 } // namespace mlir
161