• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass inserts corert.transfer ops to make sure that every argument of an
// op is on the same device as the op itself.
18 
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/IR/Attributes.h"  // from @llvm-project
22 #include "mlir/IR/Builders.h"  // from @llvm-project
23 #include "mlir/IR/Types.h"  // from @llvm-project
24 #include "mlir/Pass/PassManager.h"  // from @llvm-project
25 #include "mlir/Transforms/Passes.h"  // from @llvm-project
26 #include "tensorflow/core/util/device_name_utils.h"
27 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
28 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
29 #include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
30 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
31 
32 namespace tensorflow {
33 
34 namespace {
35 
36 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
37 
// Name of the op attribute (and of the corert.executeop op attribute entry)
// that carries an op's device.
constexpr char kDeviceAttr[] = "device";
// Name of the function argument attribute that carries an argument's device.
constexpr char kTFRTDeviceAttr[] = "tfrt.device";
// Device name used when no device is specified at all.
// TODO(b/175480458): Do not assign default device once every op in the TF
// dialect has the device attribute.
constexpr char kDefaultDevice[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
44 
45 // This method canonicalizes the device name so that we can use string
46 // comparison to see if two devices are the same. It does the following
47 // transformations:
48 // 1) Set device ID to 0 if device ID is not already specified.
49 // 2) Change the device type to uppercase string.
CanonicalizeDeviceName(const std::string & device)50 static std::string CanonicalizeDeviceName(const std::string &device) {
51   if (device.empty()) return kDefaultDevice;
52 
53   DeviceNameUtils::ParsedName parsed_name;
54   if (!device.empty() && device.at(0) == '/') {
55     DeviceNameUtils::ParseFullName(device, &parsed_name);
56   } else {
57     DeviceNameUtils::ParseFullName("/device:" + device, &parsed_name);
58   }
59 
60   if (!parsed_name.has_id) {
61     parsed_name.has_id = true;
62     parsed_name.id = 0;
63   }
64 
65   if (parsed_name.type == "cpu")
66     parsed_name.type = "CPU";
67   else if (parsed_name.type == "gpu")
68     parsed_name.type = "GPU";
69   else if (parsed_name.type == "tpu")
70     parsed_name.type = "TPU";
71   return DeviceNameUtils::ParsedNameToString(parsed_name);
72 }
73 
74 // Return the device of the given operation.
GetDevice(Operation * op)75 static std::string GetDevice(Operation *op) {
76   std::string device = "";
77   if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
78     device = device_attr.getValue().str();
79   } else if (auto execute_op = llvm::dyn_cast<tfrt::corert::ExecuteOp>(op)) {
80     SmallVector<std::pair<StringRef, Attribute>, 4> attrs;
81     execute_op.getOpAttrs(&attrs);
82     for (std::pair<StringRef, Attribute> entry : attrs) {
83       if (entry.first == kDeviceAttr && entry.second.isa<StringAttr>()) {
84         device = entry.second.cast<StringAttr>().getValue().str();
85         break;
86       }
87     }
88   }
89 
90   return CanonicalizeDeviceName(device);
91 }
92 
93 // Return the device of the given value.
GetDevice(mlir::Value value,FuncOp parent_func_op)94 static std::string GetDevice(mlir::Value value, FuncOp parent_func_op) {
95   std::string device = "";
96   if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
97     if (StringAttr device_attr = parent_func_op.getArgAttrOfType<StringAttr>(
98             block_arg.getArgNumber(), kTFRTDeviceAttr)) {
99       device = device_attr.getValue().str();
100     }
101   } else {
102     device = GetDevice(value.getDefiningOp());
103   }
104 
105   return CanonicalizeDeviceName(device);
106 }
107 
108 struct CrossDeviceTransferPass
109     : public PassWrapper<CrossDeviceTransferPass, FunctionPass> {
110   void runOnFunction() override;
111 
getArgumenttensorflow::__anoncfb66f6c0111::CrossDeviceTransferPass112   llvm::StringRef getArgument() const final {
113     return "tfrt-cross-device-transfer";
114   }
115 
getDescriptiontensorflow::__anoncfb66f6c0111::CrossDeviceTransferPass116   llvm::StringRef getDescription() const final {
117     return "This pass inserts corert.transfer op to make sure any argument of "
118            "any op is on the same device of the op itself.";
119   }
120 };
121 
runOnFunction()122 void CrossDeviceTransferPass::runOnFunction() {
123   FuncOp func_op = getOperation();
124   llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
125       transferred_value_by_value_and_device;
126 
127   func_op.getBody().walk([&](Operation *op) {
128     if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
129     // Do not transfer the argument of corert.transfer op.
130     if (llvm::isa<tfrt::corert::TransferOp>(op)) return WalkResult::advance();
131 
132     OpBuilder builder(op);
133     std::string dst_device = GetDevice(op);
134     mlir::Type tensor_type_type =
135         builder.getType<::tfrt::compiler::TensorTypeType>();
136     mlir::Type device_type = builder.getType<::tfrt::compiler::DeviceType>();
137 
138     for (mlir::Value arg : op->getOperands()) {
139       // Do not transfer non-TensorHandle values.
140       if (!arg.getType().isa<tfrt::corert::TensorHandleType>()) continue;
141 
142       // Do not transfer the result of corert.transfer op.
143       if (OpResult op_result = arg.dyn_cast<OpResult>()) {
144         Operation *defining_op = arg.getDefiningOp();
145         if (llvm::isa<tfrt::corert::TransferOp>(defining_op)) continue;
146       }
147 
148       std::string src_device = GetDevice(arg, func_op);
149 
150       if (DeviceNameUtils::LocalName(src_device) ==
151           DeviceNameUtils::LocalName(dst_device))
152         continue;
153 
154       // Re-use the value already transferred to the given device.
155       llvm::StringMap<mlir::Value> &transferred_value_by_device =
156           transferred_value_by_value_and_device[arg];
157       auto iter = transferred_value_by_device.find(dst_device);
158       if (iter != transferred_value_by_device.end()) {
159         op->replaceUsesOfWith(arg, iter->second);
160         continue;
161       }
162 
163       mlir::Value chain_in = func_op.getArgument(0);
164       auto get_device_op = builder.create<tfrt::compiler::GetDeviceOp>(
165           op->getLoc(), device_type, chain_in, dst_device);
166       auto get_tensor_type_op =
167           builder.create<tfrt::corert::GetDstTensorTypeOp>(
168               op->getLoc(), tensor_type_type, arg, get_device_op.getResult());
169       auto transfer_op = builder.create<tfrt::corert::TransferOp>(
170           op->getLoc(), arg.getType(), arg, get_device_op.getResult(),
171           get_tensor_type_op.getResult());
172       mlir::Value new_arg = transfer_op.getResult();
173       transferred_value_by_device[dst_device] = new_arg;
174       op->replaceUsesOfWith(arg, new_arg);
175     }
176     return WalkResult::advance();
177   });
178 }
179 
180 }  // namespace
181 
CreateCrossDeviceTransferPass()182 std::unique_ptr<FunctionPass> CreateCrossDeviceTransferPass() {
183   return std::make_unique<CrossDeviceTransferPass>();
184 }
185 
186 static PassRegistration<CrossDeviceTransferPass> pass;
187 
188 }  // namespace tensorflow
189