• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
24 
25 #define DEBUG_TYPE "tf-hoist-replicate-invariant-resource-writes"
26 
27 namespace mlir {
28 namespace TF {
29 
30 namespace {
31 
32 struct HoistReplicateInvariantResourceWritesPass
33     : public TF::HoistReplicateInvariantResourceWritesPassBase<
34           HoistReplicateInvariantResourceWritesPass> {
35   void runOnOperation() override;
36 };
37 
38 // TODO(prakalps): This is a common utility and other passes use something
39 // similar. Move to common utils.
IsResourceType(Type type)40 bool IsResourceType(Type type) {
41   return type.isa<TF::ResourceType>() ||
42          (type.isa<TensorType>() &&
43           type.cast<TensorType>().getElementType().isa<TF::ResourceType>());
44 }
45 
GetAccessedResources(Operation & op)46 SmallVector<Value> GetAccessedResources(Operation& op) {
47   SmallVector<Value, 4> accessed_resources;
48   for (auto operand : op.getOperands()) {
49     if (!IsResourceType(operand.getType())) continue;
50     accessed_resources.push_back(operand);
51   }
52   return std::move(accessed_resources);
53 }
54 
55 // Lifts the tail writes outside of tf_device.replicate. The written value is
56 // added to the values returned by tf_device.replicate op. Modify the assign
57 // variable ops to use the value from first replica.
MoveTailWritesAfterReplicate(tf_device::ReplicateOp replicate_op,llvm::ArrayRef<TF::AssignVariableOp> tail_assign_variable_ops)58 void MoveTailWritesAfterReplicate(
59     tf_device::ReplicateOp replicate_op,
60     llvm::ArrayRef<TF::AssignVariableOp> tail_assign_variable_ops) {
61   const auto num_replicas = replicate_op.n();
62   auto return_op = llvm::dyn_cast<tf_device::ReturnOp>(
63       replicate_op.getRegion().front().getTerminator());
64 
65   // Get the new result types.
66   // TODO(prakalps): Do not add a value to returned values if it is already
67   // returned.
68   auto new_result_types = llvm::to_vector<4>(replicate_op->getResultTypes());
69   for (auto assign : tail_assign_variable_ops) {
70     return_op->insertOperands(return_op->getNumOperands(), assign.value());
71     new_result_types.insert(new_result_types.end(), num_replicas,
72                             assign.value().getType());
73   }
74 
75   OpBuilder builder(replicate_op);
76   // Clone this old replicate op but with new result types.
77   auto new_replicate_op = builder.create<tf_device::ReplicateOp>(
78       replicate_op->getLoc(), new_result_types, replicate_op->getOperands(),
79       replicate_op->getAttrs());
80 
81   // Move region to the new op.
82   new_replicate_op.getRegion().takeBody(replicate_op.getRegion());
83 
84   // Replace all old uses with new op results.
85   int old_num_results = replicate_op->getNumResults();
86   replicate_op->replaceAllUsesWith(
87       new_replicate_op->getResults().take_front(old_num_results));
88 
89   // Move assign ops after replicate and use the output of first replica.
90   for (auto indexed_assign : llvm::enumerate(tail_assign_variable_ops)) {
91     auto assign_op = indexed_assign.value();
92     auto index = indexed_assign.index();
93     assign_op->moveAfter(new_replicate_op);
94     assign_op->setOperand(
95         1, new_replicate_op->getResult(old_num_results + num_replicas * index));
96   }
97   replicate_op->erase();
98 }
99 
100 // Looks for AssignVariable ops from the end of the tf_device.replicate op. It
101 // returns all the last writes to replicate invariant resource variables
102 // (resource handles defined outside the tf_device.replicate op).
GetTailWritesToReplicateInvariantResourceVars(tf_device::ReplicateOp replicate_op)103 SmallVector<TF::AssignVariableOp> GetTailWritesToReplicateInvariantResourceVars(
104     tf_device::ReplicateOp replicate_op) {
105   SmallVector<TF::AssignVariableOp, 16> tail_assign_variable_ops;
106   llvm::SmallDenseSet<Value, 16> visited_resources;
107   for (auto& op :
108        llvm::reverse(replicate_op.getRegion().front().getOperations())) {
109     SmallVector<Value> op_accessed_resources = GetAccessedResources(op);
110     if (op_accessed_resources.empty()) continue;
111 
112     if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(op)) {
113       Value resource_var = assign.resource();
114       if (visited_resources.contains(resource_var) ||
115           !resource_var.getParentRegion()->isProperAncestor(
116               &replicate_op.getRegion()))
117         continue;
118       tail_assign_variable_ops.push_back(assign);
119     }
120 
121     for (Value resource : op_accessed_resources)
122       visited_resources.insert(resource);
123   }
124   return std::move(tail_assign_variable_ops);
125 }
126 
runOnOperation()127 void HoistReplicateInvariantResourceWritesPass::runOnOperation() {
128   SmallVector<tf_device::ReplicateOp, 2> replicate_ops;
129   getOperation().walk([&](tf_device::ReplicateOp replicate_op) {
130     replicate_ops.push_back(replicate_op);
131   });
132   for (auto replicate_op : replicate_ops) {
133     SmallVector<TF::AssignVariableOp> tail_writes =
134         GetTailWritesToReplicateInvariantResourceVars(replicate_op);
135 
136     if (tail_writes.empty()) continue;
137     MoveTailWritesAfterReplicate(replicate_op, tail_writes);
138   }
139 }
140 
141 }  // namespace
142 
143 std::unique_ptr<OperationPass<func::FuncOp>>
CreateHoistReplicateInvariantResourceWritesPass()144 CreateHoistReplicateInvariantResourceWritesPass() {
145   return std::make_unique<HoistReplicateInvariantResourceWritesPass>();
146 }
147 
148 }  // namespace TF
149 }  // namespace mlir
150