1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
28
29 namespace mlir {
30 namespace TFTPU {
31 namespace {
32
33 // Pass that co-locates resource ops that use composite device resources
34 // (packed tensors) with the underlying physical TPU device.
35 struct TPUColocateCompositeResourceOps
36 : public PassWrapper<TPUColocateCompositeResourceOps, FunctionPass> {
37 void runOnFunction() override;
38 };
39
40 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)41 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
42 llvm::StringRef device) {
43 builder->setInsertionPoint(op);
44 auto launch = builder->create<tf_device::LaunchOp>(
45 loc, builder->getStringAttr(device), op->getResultTypes());
46 launch.body().push_back(new Block);
47 op->replaceAllUsesWith(launch);
48
49 builder->setInsertionPointToEnd(&launch.GetBody());
50 builder->create<tf_device::ReturnOp>(loc, op->getResults());
51
52 // Move op inside cluster.
53 op->moveBefore(launch.GetBody().getTerminator());
54 }
55
GetResourceOpsUsingCompositeArgsInReplicate(tf_device::ReplicateOp replicate)56 llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
57 tf_device::ReplicateOp replicate) {
58 llvm::SmallVector<Operation*, 4> resource_users;
59 const auto add_resource_op_to_list = [&resource_users](Operation* op) {
60 if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;
61
62 resource_users.emplace_back(op);
63 };
64
65 llvm::SmallVector<Operation*, 4> resource_users_to_visit;
66 for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
67 for (auto resource_user : composite_arguments.getUsers())
68 resource_users_to_visit.emplace_back(resource_user);
69 }
70
71 while (!resource_users_to_visit.empty()) {
72 llvm::SmallVector<Operation*, 4> new_resource_users;
73
74 for (auto resource_user : resource_users_to_visit) {
75 add_resource_op_to_list(resource_user);
76
77 // Account for pass-through identity ops.
78 if (auto pass_through_identity =
79 llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
80 for (auto identity_user : pass_through_identity.output().getUsers()) {
81 new_resource_users.emplace_back(identity_user);
82 }
83 }
84 }
85 resource_users_to_visit.swap(new_resource_users);
86 }
87
88 return resource_users;
89 }
90
ColocateCompositeResourceOpsInReplicate(tf_device::ReplicateOp replicate_op,OpBuilder * builder)91 void ColocateCompositeResourceOpsInReplicate(
92 tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
93 auto devices = replicate_op.devices();
94 if (!devices) return;
95 if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
96 return;
97
98 const auto composite_resource_users =
99 GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
100 for (auto resource_user : composite_resource_users) {
101 WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
102 tensorflow::GetDeviceAliasForLogicalCore(0));
103 }
104 }
105
runOnFunction()106 void TPUColocateCompositeResourceOps::runOnFunction() {
107 // Find all the executes first, since we will mutate the nodes around each
108 // execute in the same tf_device.replicate op.
109 llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
110 getFunction().walk([&](tf_device::LaunchOp op) {
111 if (op.WrapsSingleOp() &&
112 llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
113 op.GetBody().front()))
114 execute_launches.push_back(op);
115 });
116
117 OpBuilder builder(&getContext());
118 for (auto execute_launch : execute_launches) {
119 auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
120 if (!replicate) continue;
121
122 ColocateCompositeResourceOpsInReplicate(replicate, &builder);
123 }
124 }
125
126 } // namespace
127
CreateTPUColocateCompositeResourceOps()128 std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps() {
129 return std::make_unique<TPUColocateCompositeResourceOps>();
130 }
131
132 static PassRegistration<TPUColocateCompositeResourceOps> pass(
133 "tf-tpu-colocate-composite-resource-ops",
134 "Colocate resource with composite device assignment to TPU device.");
135
136 } // namespace TFTPU
137 } // namespace mlir
138