• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass inserts tf_device.send and tf_device.receive ops to make sure any
17 // argument of any op is on the same host of the op itself.
18 
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/Types.h"  // from @llvm-project
24 #include "mlir/Pass/PassManager.h"  // from @llvm-project
25 #include "mlir/Transforms/Passes.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28 
29 namespace mlir {
30 namespace TF {
31 
32 namespace {
33 
34 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
35 
36 constexpr const char *kOpDeviceAttr = "device";
37 constexpr const char *kArgDeviceAttr = "tf.device";
38 // TODO(b/175480458): Do not assign default host once every op in the TF
39 // dialect has the device attribute.
40 constexpr const char *kDefaultHost = "/job:localhost/replica:0/task:0";
41 constexpr const char *kCPUDevice = "/device:CPU:0";
42 
43 // Return the job/replica/task from the device name as the host address. If no
44 // job/replica/task is specified, return /job:localhost/replica:0/task:0 as the
45 // default host address.
GetHost(const std::string & device)46 std::string GetHost(const std::string &device) {
47   DeviceNameUtils::ParsedName parsed_name;
48   DeviceNameUtils::ParseFullName(device, &parsed_name);
49   parsed_name.has_id = false;
50   parsed_name.has_type = false;
51 
52   auto host = DeviceNameUtils::ParsedNameToString(parsed_name);
53   if (host.empty()) return kDefaultHost;
54 
55   return host;
56 }
57 
58 struct CrossHostTransferPass
59     : public PassWrapper<CrossHostTransferPass, OperationPass<ModuleOp>> {
60   void runOnOperation() override;
61 
62  private:
63   // The key_count represents the total number of send/recv pairs generated
64   // before this method call. And the key_count should be incremented based
65   // on the send/recv pairs newly generated by this method call.
66   void runOnFunction(FuncOp func_op, int &key_count);
67 };
68 
runOnOperation()69 void CrossHostTransferPass::runOnOperation() {
70   ModuleOp module = getOperation();
71   int key_count = 0;
72 
73   module.walk([&](FuncOp func_op) { runOnFunction(func_op, key_count); });
74 }
75 
runOnFunction(FuncOp func_op,int & key_count)76 void CrossHostTransferPass::runOnFunction(FuncOp func_op, int &key_count) {
77   // This map is used to avoid transferring the same value to the same host
78   // multiple times.
79   llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
80       transferred_value_by_value_and_host;
81 
82   func_op.getBody().walk([&](Operation *op) {
83     if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
84 
85     OpBuilder builder(op);
86     // Get the host address of the op.
87     std::string op_device = "";
88     if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
89       op_device = device_attr.getValue().str();
90     }
91     std::string dst_host = GetHost(op_device);
92 
93     for (mlir::Value arg : op->getOperands()) {
94       // Get the host address of the argument.
95       std::string arg_device = "";
96       if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
97         // Do not send this argument if it is not a function's argument. This
98         // can happen when the argument is a while loop's argument.
99         if (block_arg.getParentRegion() != &func_op.getRegion()) continue;
100 
101         if (StringAttr device_attr = func_op.getArgAttrOfType<StringAttr>(
102                 block_arg.getArgNumber(), kArgDeviceAttr)) {
103           arg_device = device_attr.getValue().str();
104         }
105       } else {
106         Operation *defining_op = arg.getDefiningOp();
107         if (StringAttr device_attr =
108                 defining_op->getAttrOfType<StringAttr>(kOpDeviceAttr)) {
109           arg_device = device_attr.getValue().str();
110         }
111       }
112       std::string src_host = GetHost(arg_device);
113 
114       if (src_host == dst_host) continue;
115 
116       // Re-use the transferred argument if the argument has already been
117       // transferred to the given host.
118       llvm::StringMap<mlir::Value> &transferred_value_by_host =
119           transferred_value_by_value_and_host[arg];
120       auto iter = transferred_value_by_host.find(dst_host);
121       if (iter != transferred_value_by_host.end()) {
122         op->replaceUsesOfWith(arg, iter->second);
123         continue;
124       }
125 
126       // Create tf_device.send and tf_device.receive ops to send the argument to
127       // the same host of the operation.
128       std::string key = "key-" + std::to_string(key_count);
129       key_count++;
130 
131       auto send_op =
132           builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
133       send_op->setAttr(kOpDeviceAttr,
134                        builder.getStringAttr(src_host + kCPUDevice));
135 
136       auto receive_op = builder.create<tf_device::ReceiveOp>(
137           op->getLoc(), arg.getType(), key, src_host);
138       receive_op->setAttr(kOpDeviceAttr,
139                           builder.getStringAttr(dst_host + kCPUDevice));
140 
141       transferred_value_by_host[dst_host] = receive_op.getResult();
142       op->replaceUsesOfWith(arg, receive_op.getResult());
143     }
144     return WalkResult::advance();
145   });
146 }
147 
148 }  // namespace
149 
CreateCrossHostTransferPass()150 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateCrossHostTransferPass() {
151   return std::make_unique<CrossHostTransferPass>();
152 }
153 
// Static registration object for CrossHostTransferPass. The first string is
// the pass argument and the second one is its description (see the MLIR
// PassRegistration API for how they are used).
static PassRegistration<CrossHostTransferPass> pass(
    "tf-cross-host-transfer",
    "This pass inserts tf_device.send and tf_device.receive ops to make sure "
    "any argument of any op is on the same host of the op itself.");
158 
159 }  // namespace TF
160 }  // namespace mlir
161