/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass hoists replicate invariant ops, or ops that yield the same
// result(s) regardless of replication, out of their respective replicate.

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kDeviceAttr[] = "device";

struct ReplicateInvariantOpHoistingPass
    : public PassWrapper<ReplicateInvariantOpHoistingPass, FunctionPass> {
  void runOnFunction() override;
};
// Makes a ShapeOp replicate invariant if possible. This currently updates or
// replaces ShapeOps of replicated arguments, either tensors or resources.
//
// For example, the following:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//   %2 = "tf.Shape"(%ri) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// gets converted to:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//   %2 = "tf.Shape"(%0) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// and for resource variables:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*x!tf.resource>) {n = 2 : i32} {
//   %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xi32>
//   %3 = "tf.Shape"(%2) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// gets converted to:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*x!tf.resource>) {n = 2 : i32} {
//   %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xi32>
//   %3 = "tf.VariableShape"(%0) : (tensor<*x!tf.resource>) -> tensor<?xi32>
//   tf_device.return
// }
void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
                          Block* replicate_block, TF::ShapeOp shape_op) {
  Value input = shape_op.input();
  // If the ShapeOp operand is a replicated tensor block argument, replace it
  // with the associated first replica operand.
  if (auto block_arg = input.dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument(
        block_arg, /*replica=*/0));

    return;
  }

  Operation* input_def = input.getDefiningOp();

  // If the ShapeOp operand is a ReadVariableOp result where the
  // ReadVariableOp operand is a replicated resource block argument, replace
  // the ShapeOp with a VariableShapeOp using the associated first replica
  // operand as its operand.
  auto read_var_op = llvm::dyn_cast<TF::ReadVariableOp>(input_def);
  if (!read_var_op) return;

  // TODO(lyandy): Check if resource (first replica or replicate block arg)
  // shape has not changed in replicate prior to read. Currently after both
  // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
  // to resources prior to their respective ReadVariableOp.
  if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    OpBuilder builder(shape_op);
    auto new_shape_op = builder.create<TF::VariableShapeOp>(
        shape_op.getLoc(), shape_op.getType(),
        replicate_op.GetReplicaOperandForBlockArgument(block_arg,
                                                       /*replica=*/0));
    shape_op.replaceAllUsesWith(new_shape_op.getOperation());
    shape_op.erase();
  }
}

// Checks if an op, or any op nested within it, uses a virtual device from the
// `devices` attribute of the enclosing `tf_device.replicate` op.
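//
// For example (the device alias and device names here are illustrative), given
//
//   devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}
//
// an op with the attribute device = "TPU_REPLICATED_CORE_0" uses a virtual
// device and will not be hoisted out of the replicate.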
bool UsesVirtualDevice(const Optional<DictionaryAttr>& virtual_devices,
                       Operation* operation) {
  if (!virtual_devices.hasValue()) return false;

  auto result = operation->walk([&](Operation* op) {
    StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
    if (!op_device) return WalkResult::advance();

    if (virtual_devices.getValue().get(op_device.getValue()))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}

// Checks if all operands of an op, and of any ops nested within it, are
// replicate invariant (i.e. defined outside the replicate region).
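//
// For example ("tf.A" and "tf.B" are illustrative op names), in:
//
//   %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   tf_device.replicate([%1, %2] as %ri: tensor<i32>) {n = 2 : i32} {
//     %3 = "tf.A"(%0) : (tensor<i32>) -> tensor<i32>
//     %4 = "tf.B"(%ri) : (tensor<i32>) -> tensor<i32>
//     tf_device.return
//   }
//
// "tf.A" is replicate invariant (its only operand is defined above the
// replicate), while "tf.B" is not (its operand is the replicated block
// argument %ri).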
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
  auto ancestor_of_replicate = [&](Region* region) {
    return region && region->isProperAncestor(replicate_region);
  };

  for (Value operand : op->getOperands())
    if (!ancestor_of_replicate(operand.getParentRegion())) return false;

  // Also check values captured by regions nested within `op`.
  bool has_replicate_operands = false;
  visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
    if (!ancestor_of_replicate(operand->get().getParentRegion()))
      has_replicate_operands = true;
  });

  return !has_replicate_operands;
}

// Hoists replicate invariant ops out of the associated `tf_device.replicate`
// op. An op is hoisted only if all of its operands (including those of ops
// nested within it) are replicate invariant. ShapeOps are rewritten to be
// invariant when possible, prior to hoisting.
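//
// For example (an illustrative sketch; "tf.SomeOp" stands in for any op whose
// operands are all defined above the replicate), the following:
//
//   %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   tf_device.replicate([%1, %2] as %ri: tensor<i32>) {n = 2 : i32} {
//     %3 = "tf.SomeOp"(%0) : (tensor<i32>) -> tensor<i32>
//     tf_device.return
//   }
//
// gets converted to:
//
//   %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   %3 = "tf.SomeOp"(%0) : (tensor<i32>) -> tensor<i32>
//   tf_device.replicate([%1, %2] as %ri: tensor<i32>) {n = 2 : i32} {
//     tf_device.return
//   }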
void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) {
  const int num_replicas = replicate_op.n();
  Block* replicate_block = &replicate_op.GetBody();

  replicate_op.walk([&](TF::ShapeOp shape_op) {
    MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op);
  });

  Region* replicate_region = &replicate_op.body();
  Optional<DictionaryAttr> virtual_device_list = replicate_op.devices();
  for (Operation& inner_op :
       llvm::make_early_inc_range(replicate_op.GetBody())) {
    if (llvm::isa<tf_device::ReturnOp>(inner_op)) continue;
    // Skip hoisting if the inner op device attribute is a virtual device
    // defined by tf_device.replicate.
    if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue;

    if (IsOpReplicateInvariant(replicate_region, &inner_op))
      inner_op.moveBefore(replicate_op);
  }
}

void ReplicateInvariantOpHoistingPass::runOnFunction() {
  getFunction().walk(
      [](tf_device::ReplicateOp op) { HoistReplicateInvariantOps(op); });
}
}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateReplicateInvariantOpHoistingPass() {
  return std::make_unique<ReplicateInvariantOpHoistingPass>();
}

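// Usage sketch (assuming the standard tf-opt pass driver):
//   tf-opt -tf-replicate-invariant-op-hoisting input.mlir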
static PassRegistration<ReplicateInvariantOpHoistingPass> pass(
    "tf-replicate-invariant-op-hoisting",
    "Hoists replicate invariant operations out of replicate");

}  // namespace TFDevice
}  // namespace mlir