1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass inserts corert.transfer ops to make sure every argument of an op
// is on the same device as the op itself.
18
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/Types.h" // from @llvm-project
24 #include "mlir/Pass/PassManager.h" // from @llvm-project
25 #include "mlir/Transforms/Passes.h" // from @llvm-project
26 #include "tensorflow/core/util/device_name_utils.h"
27 #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
28 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
29 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
30 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
31
32 namespace tensorflow {
33
34 namespace {
35
36 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
37
// Name of the string attribute read from an op to find its device.
constexpr char kDeviceAttr[] = "device";

// Name of the string attribute read from a function argument to find its
// device.
constexpr char kTFRTDeviceAttr[] = "tfrt.device";

// Device used when a device name is empty.
// TODO(b/175480458): Do not assign default device once every op in the TF
// dialect has the device attribute.
constexpr char kDefaultDevice[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
44
45 // This method canonicalizes the device name so that we can use string
46 // comparison to see if two devices are the same. It does the following
47 // transformations:
48 // 1) Set device ID to 0 if device ID is not already specified.
49 // 2) Change the device type to uppercase string.
CanonicalizeDeviceName(const std::string & device)50 static std::string CanonicalizeDeviceName(const std::string &device) {
51 if (device.empty()) return kDefaultDevice;
52
53 DeviceNameUtils::ParsedName parsed_name;
54 if (!device.empty() && device.at(0) == '/') {
55 DeviceNameUtils::ParseFullName(device, &parsed_name);
56 } else {
57 DeviceNameUtils::ParseFullName("/device:" + device, &parsed_name);
58 }
59
60 if (!parsed_name.has_id) {
61 parsed_name.has_id = true;
62 parsed_name.id = 0;
63 }
64
65 if (parsed_name.type == "cpu")
66 parsed_name.type = "CPU";
67 else if (parsed_name.type == "gpu")
68 parsed_name.type = "GPU";
69 else if (parsed_name.type == "tpu")
70 parsed_name.type = "TPU";
71 return DeviceNameUtils::ParsedNameToString(parsed_name);
72 }
73
74 // Return the device of the given operation.
GetDevice(Operation * op)75 static std::string GetDevice(Operation *op) {
76 std::string device = "";
77 if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
78 device = device_attr.getValue().str();
79 } else if (auto execute_op = llvm::dyn_cast<tfrt::corert::ExecuteOp>(op)) {
80 SmallVector<std::pair<StringRef, Attribute>, 4> attrs;
81 execute_op.getOpAttrs(&attrs);
82 for (std::pair<StringRef, Attribute> entry : attrs) {
83 if (entry.first == kDeviceAttr && entry.second.isa<StringAttr>()) {
84 device = entry.second.cast<StringAttr>().getValue().str();
85 break;
86 }
87 }
88 }
89
90 return CanonicalizeDeviceName(device);
91 }
92
93 // Return the device of the given value.
GetDevice(mlir::Value value,FuncOp parent_func_op)94 static std::string GetDevice(mlir::Value value, FuncOp parent_func_op) {
95 std::string device = "";
96 if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
97 if (StringAttr device_attr = parent_func_op.getArgAttrOfType<StringAttr>(
98 block_arg.getArgNumber(), kTFRTDeviceAttr)) {
99 device = device_attr.getValue().str();
100 }
101 } else {
102 device = GetDevice(value.getDefiningOp());
103 }
104
105 return CanonicalizeDeviceName(device);
106 }
107
108 struct CrossDeviceTransferPass
109 : public PassWrapper<CrossDeviceTransferPass, FunctionPass> {
110 void runOnFunction() override;
111
getArgumenttensorflow::__anoncfb66f6c0111::CrossDeviceTransferPass112 llvm::StringRef getArgument() const final {
113 return "tfrt-cross-device-transfer";
114 }
115
getDescriptiontensorflow::__anoncfb66f6c0111::CrossDeviceTransferPass116 llvm::StringRef getDescription() const final {
117 return "This pass inserts corert.transfer op to make sure any argument of "
118 "any op is on the same device of the op itself.";
119 }
120 };
121
runOnFunction()122 void CrossDeviceTransferPass::runOnFunction() {
123 FuncOp func_op = getOperation();
124 llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
125 transferred_value_by_value_and_device;
126
127 func_op.getBody().walk([&](Operation *op) {
128 if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
129 // Do not transfer the argument of corert.transfer op.
130 if (llvm::isa<tfrt::corert::TransferOp>(op)) return WalkResult::advance();
131
132 OpBuilder builder(op);
133 std::string dst_device = GetDevice(op);
134 mlir::Type tensor_type_type =
135 builder.getType<::tfrt::compiler::TensorTypeType>();
136 mlir::Type device_type = builder.getType<::tfrt::compiler::DeviceType>();
137
138 for (mlir::Value arg : op->getOperands()) {
139 // Do not transfer non-TensorHandle values.
140 if (!arg.getType().isa<tfrt::corert::TensorHandleType>()) continue;
141
142 // Do not transfer the result of corert.transfer op.
143 if (OpResult op_result = arg.dyn_cast<OpResult>()) {
144 Operation *defining_op = arg.getDefiningOp();
145 if (llvm::isa<tfrt::corert::TransferOp>(defining_op)) continue;
146 }
147
148 std::string src_device = GetDevice(arg, func_op);
149
150 if (DeviceNameUtils::LocalName(src_device) ==
151 DeviceNameUtils::LocalName(dst_device))
152 continue;
153
154 // Re-use the value already transferred to the given device.
155 llvm::StringMap<mlir::Value> &transferred_value_by_device =
156 transferred_value_by_value_and_device[arg];
157 auto iter = transferred_value_by_device.find(dst_device);
158 if (iter != transferred_value_by_device.end()) {
159 op->replaceUsesOfWith(arg, iter->second);
160 continue;
161 }
162
163 mlir::Value chain_in = func_op.getArgument(0);
164 auto get_device_op = builder.create<tfrt::compiler::GetDeviceOp>(
165 op->getLoc(), device_type, chain_in, dst_device);
166 auto get_tensor_type_op =
167 builder.create<tfrt::corert::GetDstTensorTypeOp>(
168 op->getLoc(), tensor_type_type, arg, get_device_op.getResult());
169 auto transfer_op = builder.create<tfrt::corert::TransferOp>(
170 op->getLoc(), arg.getType(), arg, get_device_op.getResult(),
171 get_tensor_type_op.getResult());
172 mlir::Value new_arg = transfer_op.getResult();
173 transferred_value_by_device[dst_device] = new_arg;
174 op->replaceUsesOfWith(arg, new_arg);
175 }
176 return WalkResult::advance();
177 });
178 }
179
180 } // namespace
181
CreateCrossDeviceTransferPass()182 std::unique_ptr<FunctionPass> CreateCrossDeviceTransferPass() {
183 return std::make_unique<CrossDeviceTransferPass>();
184 }
185
186 static PassRegistration<CrossDeviceTransferPass> pass;
187
188 } // namespace tensorflow
189