• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
28 
29 namespace mlir {
30 namespace TFTPU {
31 namespace {
32 
33 // Pass that co-locates resource ops that use composite device resources
34 // (packed tensors) with the underlying physical TPU device.
35 struct TPUColocateCompositeResourceOps
36     : public PassWrapper<TPUColocateCompositeResourceOps, FunctionPass> {
getArgumentmlir::TFTPU::__anonbda32c8e0111::TPUColocateCompositeResourceOps37   StringRef getArgument() const final {
38     return "tf-tpu-colocate-composite-resource-ops";
39   }
40 
getDescriptionmlir::TFTPU::__anonbda32c8e0111::TPUColocateCompositeResourceOps41   StringRef getDescription() const final {
42     return "Colocate resource with composite device assignment to TPU device.";
43   }
44 
45   void runOnFunction() override;
46 };
47 
48 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)49 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
50                     llvm::StringRef device) {
51   builder->setInsertionPoint(op);
52   auto launch = builder->create<tf_device::LaunchOp>(
53       loc, builder->getStringAttr(device), op->getResultTypes());
54   launch.body().push_back(new Block);
55   op->replaceAllUsesWith(launch);
56 
57   builder->setInsertionPointToEnd(&launch.GetBody());
58   builder->create<tf_device::ReturnOp>(loc, op->getResults());
59 
60   // Move op inside cluster.
61   op->moveBefore(launch.GetBody().getTerminator());
62 }
63 
GetResourceOpsUsingCompositeArgsInReplicate(tf_device::ReplicateOp replicate)64 llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
65     tf_device::ReplicateOp replicate) {
66   llvm::SmallVector<Operation*, 4> resource_users;
67   const auto add_resource_op_to_list = [&resource_users](Operation* op) {
68     if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;
69 
70     resource_users.emplace_back(op);
71   };
72 
73   llvm::SmallVector<Operation*, 4> resource_users_to_visit;
74   for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
75     for (auto resource_user : composite_arguments.getUsers())
76       resource_users_to_visit.emplace_back(resource_user);
77   }
78 
79   while (!resource_users_to_visit.empty()) {
80     llvm::SmallVector<Operation*, 4> new_resource_users;
81 
82     for (auto resource_user : resource_users_to_visit) {
83       add_resource_op_to_list(resource_user);
84 
85       // Account for pass-through identity ops.
86       if (auto pass_through_identity =
87               llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
88         for (auto identity_user : pass_through_identity.output().getUsers()) {
89           new_resource_users.emplace_back(identity_user);
90         }
91       }
92     }
93     resource_users_to_visit.swap(new_resource_users);
94   }
95 
96   return resource_users;
97 }
98 
ColocateCompositeResourceOpsInReplicate(tf_device::ReplicateOp replicate_op,OpBuilder * builder)99 void ColocateCompositeResourceOpsInReplicate(
100     tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
101   auto devices = replicate_op.devices();
102   if (!devices) return;
103   if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
104     return;
105 
106   const auto composite_resource_users =
107       GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
108   for (auto resource_user : composite_resource_users) {
109     WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
110                    tensorflow::GetDeviceAliasForLogicalCore(0));
111   }
112 }
113 
runOnFunction()114 void TPUColocateCompositeResourceOps::runOnFunction() {
115   // Find all the executes first, since we will mutate the nodes around each
116   // execute in the same tf_device.replicate op.
117   llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
118   getFunction().walk([&](tf_device::LaunchOp op) {
119     if (op.WrapsSingleOp() &&
120         llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
121             op.GetBody().front()))
122       execute_launches.push_back(op);
123   });
124 
125   OpBuilder builder(&getContext());
126   for (auto execute_launch : execute_launches) {
127     auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
128     if (!replicate) continue;
129 
130     ColocateCompositeResourceOpsInReplicate(replicate, &builder);
131   }
132 }
133 
134 }  // namespace
135 
CreateTPUColocateCompositeResourceOps()136 std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps() {
137   return std::make_unique<TPUColocateCompositeResourceOps>();
138 }
139 
140 static PassRegistration<TPUColocateCompositeResourceOps> pass;
141 
142 }  // namespace TFTPU
143 }  // namespace mlir
144