1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
28
29 namespace mlir {
30 namespace TFTPU {
31 namespace {
32
33 // Pass that co-locates resource ops that use composite device resources
34 // (packed tensors) with the underlying physical TPU device.
35 struct TPUColocateCompositeResourceOps
36 : public PassWrapper<TPUColocateCompositeResourceOps, FunctionPass> {
getArgumentmlir::TFTPU::__anonbda32c8e0111::TPUColocateCompositeResourceOps37 StringRef getArgument() const final {
38 return "tf-tpu-colocate-composite-resource-ops";
39 }
40
getDescriptionmlir::TFTPU::__anonbda32c8e0111::TPUColocateCompositeResourceOps41 StringRef getDescription() const final {
42 return "Colocate resource with composite device assignment to TPU device.";
43 }
44
45 void runOnFunction() override;
46 };
47
48 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)49 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
50 llvm::StringRef device) {
51 builder->setInsertionPoint(op);
52 auto launch = builder->create<tf_device::LaunchOp>(
53 loc, builder->getStringAttr(device), op->getResultTypes());
54 launch.body().push_back(new Block);
55 op->replaceAllUsesWith(launch);
56
57 builder->setInsertionPointToEnd(&launch.GetBody());
58 builder->create<tf_device::ReturnOp>(loc, op->getResults());
59
60 // Move op inside cluster.
61 op->moveBefore(launch.GetBody().getTerminator());
62 }
63
GetResourceOpsUsingCompositeArgsInReplicate(tf_device::ReplicateOp replicate)64 llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
65 tf_device::ReplicateOp replicate) {
66 llvm::SmallVector<Operation*, 4> resource_users;
67 const auto add_resource_op_to_list = [&resource_users](Operation* op) {
68 if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;
69
70 resource_users.emplace_back(op);
71 };
72
73 llvm::SmallVector<Operation*, 4> resource_users_to_visit;
74 for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
75 for (auto resource_user : composite_arguments.getUsers())
76 resource_users_to_visit.emplace_back(resource_user);
77 }
78
79 while (!resource_users_to_visit.empty()) {
80 llvm::SmallVector<Operation*, 4> new_resource_users;
81
82 for (auto resource_user : resource_users_to_visit) {
83 add_resource_op_to_list(resource_user);
84
85 // Account for pass-through identity ops.
86 if (auto pass_through_identity =
87 llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
88 for (auto identity_user : pass_through_identity.output().getUsers()) {
89 new_resource_users.emplace_back(identity_user);
90 }
91 }
92 }
93 resource_users_to_visit.swap(new_resource_users);
94 }
95
96 return resource_users;
97 }
98
ColocateCompositeResourceOpsInReplicate(tf_device::ReplicateOp replicate_op,OpBuilder * builder)99 void ColocateCompositeResourceOpsInReplicate(
100 tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
101 auto devices = replicate_op.devices();
102 if (!devices) return;
103 if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
104 return;
105
106 const auto composite_resource_users =
107 GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
108 for (auto resource_user : composite_resource_users) {
109 WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
110 tensorflow::GetDeviceAliasForLogicalCore(0));
111 }
112 }
113
runOnFunction()114 void TPUColocateCompositeResourceOps::runOnFunction() {
115 // Find all the executes first, since we will mutate the nodes around each
116 // execute in the same tf_device.replicate op.
117 llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
118 getFunction().walk([&](tf_device::LaunchOp op) {
119 if (op.WrapsSingleOp() &&
120 llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
121 op.GetBody().front()))
122 execute_launches.push_back(op);
123 });
124
125 OpBuilder builder(&getContext());
126 for (auto execute_launch : execute_launches) {
127 auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
128 if (!replicate) continue;
129
130 ColocateCompositeResourceOpsInReplicate(replicate, &builder);
131 }
132 }
133
134 } // namespace
135
CreateTPUColocateCompositeResourceOps()136 std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps() {
137 return std::make_unique<TPUColocateCompositeResourceOps>();
138 }
139
140 static PassRegistration<TPUColocateCompositeResourceOps> pass;
141
142 } // namespace TFTPU
143 } // namespace mlir
144