• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
21 #include "mlir/IR/Value.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
27 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
28 
29 namespace mlir {
30 namespace TFTPU {
31 namespace {
32 
33 // Pass that co-locates resource ops that use composite device resources
34 // (packed tensors) with the underlying physical TPU device.
35 struct TPUColocateCompositeResourceOps
36     : public PassWrapper<TPUColocateCompositeResourceOps, FunctionPass> {
37   void runOnFunction() override;
38 };
39 
40 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)41 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
42                     llvm::StringRef device) {
43   builder->setInsertionPoint(op);
44   auto launch = builder->create<tf_device::LaunchOp>(
45       loc, builder->getStringAttr(device), op->getResultTypes());
46   launch.body().push_back(new Block);
47   op->replaceAllUsesWith(launch);
48 
49   builder->setInsertionPointToEnd(&launch.GetBody());
50   builder->create<tf_device::ReturnOp>(loc, op->getResults());
51 
52   // Move op inside cluster.
53   op->moveBefore(launch.GetBody().getTerminator());
54 }
55 
GetResourceOpsUsingCompositeArgsInReplicate(tf_device::ReplicateOp replicate)56 llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
57     tf_device::ReplicateOp replicate) {
58   llvm::SmallVector<Operation*, 4> resource_users;
59   const auto add_resource_op_to_list = [&resource_users](Operation* op) {
60     if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;
61 
62     resource_users.emplace_back(op);
63   };
64 
65   llvm::SmallVector<Operation*, 4> resource_users_to_visit;
66   for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
67     for (auto resource_user : composite_arguments.getUsers())
68       resource_users_to_visit.emplace_back(resource_user);
69   }
70 
71   while (!resource_users_to_visit.empty()) {
72     llvm::SmallVector<Operation*, 4> new_resource_users;
73 
74     for (auto resource_user : resource_users_to_visit) {
75       add_resource_op_to_list(resource_user);
76 
77       // Account for pass-through identity ops.
78       if (auto pass_through_identity =
79               llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
80         for (auto identity_user : pass_through_identity.output().getUsers()) {
81           new_resource_users.emplace_back(identity_user);
82         }
83       }
84     }
85     resource_users_to_visit.swap(new_resource_users);
86   }
87 
88   return resource_users;
89 }
90 
ColocateCompositeResourceOpsInReplicate(tf_device::ReplicateOp replicate_op,OpBuilder * builder)91 void ColocateCompositeResourceOpsInReplicate(
92     tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
93   auto devices = replicate_op.devices();
94   if (!devices) return;
95   if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
96     return;
97 
98   const auto composite_resource_users =
99       GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
100   for (auto resource_user : composite_resource_users) {
101     WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
102                    tensorflow::GetDeviceAliasForLogicalCore(0));
103   }
104 }
105 
runOnFunction()106 void TPUColocateCompositeResourceOps::runOnFunction() {
107   // Find all the executes first, since we will mutate the nodes around each
108   // execute in the same tf_device.replicate op.
109   llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
110   getFunction().walk([&](tf_device::LaunchOp op) {
111     if (op.WrapsSingleOp() &&
112         llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
113             op.GetBody().front()))
114       execute_launches.push_back(op);
115   });
116 
117   OpBuilder builder(&getContext());
118   for (auto execute_launch : execute_launches) {
119     auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
120     if (!replicate) continue;
121 
122     ColocateCompositeResourceOpsInReplicate(replicate, &builder);
123   }
124 }
125 
126 }  // namespace
127 
CreateTPUColocateCompositeResourceOps()128 std::unique_ptr<OperationPass<FuncOp>> CreateTPUColocateCompositeResourceOps() {
129   return std::make_unique<TPUColocateCompositeResourceOps>();
130 }
131 
132 static PassRegistration<TPUColocateCompositeResourceOps> pass(
133     "tf-tpu-colocate-composite-resource-ops",
134     "Colocate resource with composite device assignment to TPU device.");
135 
136 }  // namespace TFTPU
137 }  // namespace mlir
138