/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass hoists replicate invariant ops, or ops that yield the same
// result(s) regardless of replication, out of their respective replicate.
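//
// For example (an illustrative sketch; the values and ops below are
// hypothetical, not taken from a test case):
//
// %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// tf_device.replicate([%a, %b] as %ri: tensor<i32>) {n = 2 : i32} {
//   %1 = "tf.AddV2"(%0, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
//   %2 = "tf.Mul"(%ri, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
//   tf_device.return
// }
//
// gets converted to:
//
// %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// %1 = "tf.AddV2"(%0, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
// tf_device.replicate([%a, %b] as %ri: tensor<i32>) {n = 2 : i32} {
//   %2 = "tf.Mul"(%ri, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
//   tf_device.return
// }
//
// %1 is hoisted because its operands are defined outside the replicate, while
// %2 stays because it depends on the replicated argument %ri.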

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kDeviceAttr[] = "device";

struct ReplicateInvariantOpHoistingPass
    : public PassWrapper<ReplicateInvariantOpHoistingPass, FunctionPass> {
  void runOnFunction() override;
};

// Makes a ShapeOp replicate invariant if possible. This currently updates or
// replaces ShapeOps of replicated arguments, either tensors or resources.
//
// For example, the following:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//   %2 = "tf.Shape"(%ri) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// gets converted to:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//   %2 = "tf.Shape"(%0) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// and for resource variables:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*x!tf.resource>) {n = 2 : i32} {
//   %2 = "tf.ReadVariableOp"(%ri) : tensor<*x!tf.resource> -> tensor<*xi32>
//   %3 = "tf.Shape"(%2) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.return
// }
//
// gets converted to:
//
// tf_device.replicate([%0, %1] as %ri: tensor<*x!tf.resource>) {n = 2 : i32} {
//   %2 = "tf.ReadVariableOp"(%ri) : tensor<*x!tf.resource> -> tensor<*xi32>
//   %3 = "tf.VariableShape"(%0) : (tensor<*x!tf.resource>) -> tensor<?xi32>
//   tf_device.return
// }
void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
                          Block* replicate_block, TF::ShapeOp shape_op) {
  Value input = shape_op.input();
  // If ShapeOp operand is replicate tensor block argument, replace with the
  // associated first replica operand.
  if (auto block_arg = input.dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument(
        block_arg, /*replica=*/0));

    return;
  }

  Operation* input_def = input.getDefiningOp();

  // If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
  // operand is a replicate resource block argument, replace ShapeOp with
  // VariableShapeOp and use the associated first replica operand as its
  // operand.
  auto read_var_op = llvm::dyn_cast<TF::ReadVariableOp>(input_def);
  if (!read_var_op) return;

  // TODO(lyandy): Check if resource (first replica or replicate block arg)
  // shape has not changed in replicate prior to read. Currently after both
  // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
  // to resources prior to their respective ReadVariableOp.
  if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    OpBuilder builder(shape_op);
    auto new_shape_op = builder.create<TF::VariableShapeOp>(
        shape_op.getLoc(), shape_op.getType(),
        replicate_op.GetReplicaOperandForBlockArgument(block_arg,
                                                       /*replica=*/0));
    shape_op.replaceAllUsesWith(new_shape_op.getOperation());
    shape_op.erase();
  }
}

// Checks if an op, or any op nested within it, uses a device from a list of
// virtual devices.
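//
// For example (illustrative; the alias and device names are hypothetical),
// given a replicate op carrying
//
//   devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}
//
// an inner op with attribute device = "TPU_REPLICATED_CORE_0" uses a virtual
// device, so it must stay inside the replicate and will not be hoisted.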
bool UsesVirtualDevice(const Optional<DictionaryAttr>& virtual_devices,
                       Operation* operation) {
  if (!virtual_devices.hasValue()) return false;

  auto result = operation->walk([&](Operation* op) {
    StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
    if (!op_device) return WalkResult::advance();

    if (virtual_devices.getValue().get(op_device.getValue()))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}

// Checks if an op's operands, and the operands of any ops nested within it,
// are all replicate invariant.
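//
// For example (an illustrative sketch; the ops shown are hypothetical), given
//
// %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// tf_device.replicate([%a, %b] as %ri: tensor<i32>) {n = 2 : i32} {
//   %1 = "tf.Identity"(%0) : (tensor<i32>) -> tensor<i32>
//   %2 = "tf.Identity"(%ri) : (tensor<i32>) -> tensor<i32>
//   tf_device.return
// }
//
// %1 is replicate invariant (%0 is defined above the replicate), while %2 is
// not (%ri is a replicated block argument).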
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
  auto ancestor_of_replicate = [&](Region* region) {
    return region && region->isProperAncestor(replicate_region);
  };

  for (Value operand : op->getOperands())
    if (!ancestor_of_replicate(operand.getParentRegion())) return false;

  bool has_replicate_operands = false;
  visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
    if (!ancestor_of_replicate(operand->get().getParentRegion()))
      has_replicate_operands = true;
  });

  return !has_replicate_operands;
}

// Hoists replicate invariant ops out of the associated `tf_device.replicate`
// op. An op is hoisted if all of its operands are replicate invariant. Shape
// ops are rewritten to be invariant when possible, prior to hoisting.
void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) {
  const int num_replicas = replicate_op.n();
  Block* replicate_block = &replicate_op.GetBody();

  replicate_op.walk([&](TF::ShapeOp shape_op) {
    MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op);
  });

  Region* replicate_region = &replicate_op.body();
  Optional<DictionaryAttr> virtual_device_list = replicate_op.devices();
  for (Operation& inner_op :
       llvm::make_early_inc_range(replicate_op.GetBody())) {
    if (llvm::isa<tf_device::ReturnOp>(inner_op)) continue;
    // Skip hoisting if the inner op device attribute is a virtual device
    // defined by tf_device.replicate.
    if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue;

    if (IsOpReplicateInvariant(replicate_region, &inner_op))
      inner_op.moveBefore(replicate_op);
  }
}

void ReplicateInvariantOpHoistingPass::runOnFunction() {
  getFunction().walk(
      [](tf_device::ReplicateOp op) { HoistReplicateInvariantOps(op); });
}
}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateReplicateInvariantOpHoistingPass() {
  return std::make_unique<ReplicateInvariantOpHoistingPass>();
}

static PassRegistration<ReplicateInvariantOpHoistingPass> pass(
    "tf-replicate-invariant-op-hoisting",
    "Hoists replicate invariant operations out of replicate");
190 
191 }  // namespace TFDevice
192 }  // namespace mlir
193