1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Debug.h"
21 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
24
25 #define DEBUG_TYPE "tf-hoist-replicate-invariant-resource-writes"
26
27 namespace mlir {
28 namespace TF {
29
30 namespace {
31
32 struct HoistReplicateInvariantResourceWritesPass
33 : public TF::HoistReplicateInvariantResourceWritesPassBase<
34 HoistReplicateInvariantResourceWritesPass> {
35 void runOnOperation() override;
36 };
37
38 // TODO(prakalps): This is a common utility and other passes use something
39 // similar. Move to common utils.
IsResourceType(Type type)40 bool IsResourceType(Type type) {
41 return type.isa<TF::ResourceType>() ||
42 (type.isa<TensorType>() &&
43 type.cast<TensorType>().getElementType().isa<TF::ResourceType>());
44 }
45
GetAccessedResources(Operation & op)46 SmallVector<Value> GetAccessedResources(Operation& op) {
47 SmallVector<Value, 4> accessed_resources;
48 for (auto operand : op.getOperands()) {
49 if (!IsResourceType(operand.getType())) continue;
50 accessed_resources.push_back(operand);
51 }
52 return std::move(accessed_resources);
53 }
54
55 // Lifts the tail writes outside of tf_device.replicate. The written value is
56 // added to the values returned by tf_device.replicate op. Modify the assign
57 // variable ops to use the value from first replica.
MoveTailWritesAfterReplicate(tf_device::ReplicateOp replicate_op,llvm::ArrayRef<TF::AssignVariableOp> tail_assign_variable_ops)58 void MoveTailWritesAfterReplicate(
59 tf_device::ReplicateOp replicate_op,
60 llvm::ArrayRef<TF::AssignVariableOp> tail_assign_variable_ops) {
61 const auto num_replicas = replicate_op.n();
62 auto return_op = llvm::dyn_cast<tf_device::ReturnOp>(
63 replicate_op.getRegion().front().getTerminator());
64
65 // Get the new result types.
66 // TODO(prakalps): Do not add a value to returned values if it is already
67 // returned.
68 auto new_result_types = llvm::to_vector<4>(replicate_op->getResultTypes());
69 for (auto assign : tail_assign_variable_ops) {
70 return_op->insertOperands(return_op->getNumOperands(), assign.value());
71 new_result_types.insert(new_result_types.end(), num_replicas,
72 assign.value().getType());
73 }
74
75 OpBuilder builder(replicate_op);
76 // Clone this old replicate op but with new result types.
77 auto new_replicate_op = builder.create<tf_device::ReplicateOp>(
78 replicate_op->getLoc(), new_result_types, replicate_op->getOperands(),
79 replicate_op->getAttrs());
80
81 // Move region to the new op.
82 new_replicate_op.getRegion().takeBody(replicate_op.getRegion());
83
84 // Replace all old uses with new op results.
85 int old_num_results = replicate_op->getNumResults();
86 replicate_op->replaceAllUsesWith(
87 new_replicate_op->getResults().take_front(old_num_results));
88
89 // Move assign ops after replicate and use the output of first replica.
90 for (auto indexed_assign : llvm::enumerate(tail_assign_variable_ops)) {
91 auto assign_op = indexed_assign.value();
92 auto index = indexed_assign.index();
93 assign_op->moveAfter(new_replicate_op);
94 assign_op->setOperand(
95 1, new_replicate_op->getResult(old_num_results + num_replicas * index));
96 }
97 replicate_op->erase();
98 }
99
100 // Looks for AssignVariable ops from the end of the tf_device.replicate op. It
101 // returns all the last writes to replicate invariant resource variables
102 // (resource handles defined outside the tf_device.replicate op).
GetTailWritesToReplicateInvariantResourceVars(tf_device::ReplicateOp replicate_op)103 SmallVector<TF::AssignVariableOp> GetTailWritesToReplicateInvariantResourceVars(
104 tf_device::ReplicateOp replicate_op) {
105 SmallVector<TF::AssignVariableOp, 16> tail_assign_variable_ops;
106 llvm::SmallDenseSet<Value, 16> visited_resources;
107 for (auto& op :
108 llvm::reverse(replicate_op.getRegion().front().getOperations())) {
109 SmallVector<Value> op_accessed_resources = GetAccessedResources(op);
110 if (op_accessed_resources.empty()) continue;
111
112 if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(op)) {
113 Value resource_var = assign.resource();
114 if (visited_resources.contains(resource_var) ||
115 !resource_var.getParentRegion()->isProperAncestor(
116 &replicate_op.getRegion()))
117 continue;
118 tail_assign_variable_ops.push_back(assign);
119 }
120
121 for (Value resource : op_accessed_resources)
122 visited_resources.insert(resource);
123 }
124 return std::move(tail_assign_variable_ops);
125 }
126
runOnOperation()127 void HoistReplicateInvariantResourceWritesPass::runOnOperation() {
128 SmallVector<tf_device::ReplicateOp, 2> replicate_ops;
129 getOperation().walk([&](tf_device::ReplicateOp replicate_op) {
130 replicate_ops.push_back(replicate_op);
131 });
132 for (auto replicate_op : replicate_ops) {
133 SmallVector<TF::AssignVariableOp> tail_writes =
134 GetTailWritesToReplicateInvariantResourceVars(replicate_op);
135
136 if (tail_writes.empty()) continue;
137 MoveTailWritesAfterReplicate(replicate_op, tail_writes);
138 }
139 }
140
141 } // namespace
142
143 std::unique_ptr<OperationPass<func::FuncOp>>
CreateHoistReplicateInvariantResourceWritesPass()144 CreateHoistReplicateInvariantResourceWritesPass() {
145 return std::make_unique<HoistReplicateInvariantResourceWritesPass>();
146 }
147
148 } // namespace TF
149 } // namespace mlir
150