/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFTPU {

namespace {

// A pass that moves `tf.AssignVariableOp` into a `tf_device.parallel_execute`
// region if the `tf.AssignVariableOp` is the only consumer of a
// `tf_device.parallel_execute` result. This will allow
// TPUMergeVariablesWithExecute to merge resource writes without special
// handling for `tf_device.parallel_execute`.
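//
// For example, the following (an illustrative, simplified sketch; exact IR
// syntax, types, and the other parallel_execute regions are elided):
//
//   %result = tf_device.parallel_execute {
//     ...
//     tf_device.return %value
//   }
//   tf.AssignVariableOp(%resource, %result)
//
// is rewritten so the write happens inside the region and the now unused
// result is pruned:
//
//   tf_device.parallel_execute {
//     ...
//     tf.AssignVariableOp(%resource, %value)
//     tf_device.return
//   }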
struct TPUParallelExecuteSinkResourceWrite
    : public PassWrapper<TPUParallelExecuteSinkResourceWrite, FunctionPass> {
  void runOnFunction() override;
};

// Finds an AssignVariableOp that can be moved into the parallel_execute
// region. Such an AssignVariableOp must be the only consumer of the
// respective parallel_execute result, and its resource handle must be
// produced by an op before or above the parallel_execute.
TF::AssignVariableOp GetSingleUseResourceWrite(
    tf_device::ParallelExecuteOp parallel_execute, Value result) {
  if (!result.hasOneUse()) return nullptr;

  OpOperand& use = *result.getUses().begin();
  auto assign_var = dyn_cast<TF::AssignVariableOp>(use.getOwner());
  if (!assign_var) return nullptr;

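  // The parallel_execute result must feed the value operand of the
  // AssignVariableOp (the data being written), not its resource operand.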
  if (use.get() != assign_var.value()) return nullptr;

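  // The resource handle must be available before the parallel_execute: reject
  // handles produced by the parallel_execute itself or by ops that appear
  // after it in the same block.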
  auto* resource_handle_op = assign_var.resource().getDefiningOp();
  if (resource_handle_op == parallel_execute) return nullptr;

  if (resource_handle_op &&
      resource_handle_op->getBlock() ==
          parallel_execute.getOperation()->getBlock() &&
      parallel_execute.getOperation()->isBeforeInBlock(resource_handle_op))
    return nullptr;

  return assign_var;
}

// Finds AssignVariableOps that can be moved into a parallel_execute region and
// moves them. Leftover parallel_execute results that were used by such
// AssignVariableOps are also pruned.
void SinkResourceWritesIntoParallelExecute(
    tf_device::ParallelExecuteOp parallel_execute) {
  bool rewrite = false;
  const int num_regions = parallel_execute.getNumRegions();
  llvm::SmallVector<Value, 4> results_to_remap;

  // Go through each region and find AssignVariableOps that can be moved into
  // the parallel_execute region. Result indices are collected per region so
  // the corresponding results can be removed afterwards.
  llvm::SmallVector<llvm::SmallVector<int, 4>, 4> results_to_remove_by_region;
  results_to_remove_by_region.resize(num_regions);
  for (int i = 0; i < num_regions; ++i) {
    Block& block = parallel_execute.GetRegionBlockWithIndex(i);
    auto results = parallel_execute.GetRegionOutputs(i);
    auto& results_to_remove = results_to_remove_by_region[i];
    results_to_remove.reserve(results.size());
    Operation* terminator = block.getTerminator();
    for (auto result : llvm::enumerate(results)) {
      TF::AssignVariableOp assign_var =
          GetSingleUseResourceWrite(parallel_execute, result.value());
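      // Results that are not consumed by a sinkable AssignVariableOp keep
      // their uses and will be remapped to the new parallel_execute below.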
      if (!assign_var) {
        results_to_remap.push_back(result.value());
        continue;
      }

      // Move the AssignVariableOp into the region and update its value
      // operand to the value produced inside the parallel_execute region,
      // rather than the forwarded result.
      assign_var.getOperation()->moveBefore(terminator);
      assign_var.valueMutable().assign(terminator->getOperand(result.index()));
      results_to_remove.push_back(result.index());
    }

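    // Rebuild the parallel_execute only if at least one result was removed.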
    rewrite |= !results_to_remove.empty();
  }

  if (!rewrite) return;

  // Remove leftover unused results (terminator operands) from moving
  // AssignVariableOps into the parallel_execute region.
  for (auto results_to_remove : llvm::enumerate(results_to_remove_by_region)) {
    Block& block =
        parallel_execute.GetRegionBlockWithIndex(results_to_remove.index());
    Operation* terminator = block.getTerminator();
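    // Erase operands from the highest index to the lowest so the remaining
    // indices stay valid as operands are removed.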
    for (int index_to_remove : llvm::reverse(results_to_remove.value()))
      terminator->eraseOperand(index_to_remove);
  }

  // Replace old parallel_execute with new parallel_execute by moving the
  // regions to a new parallel_execute and remapping the results.
  llvm::SmallVector<Type, 4> new_result_types;
  new_result_types.reserve(results_to_remap.size());
  for (Value old_result : results_to_remap)
    new_result_types.push_back(old_result.getType());

  OpBuilder builder(parallel_execute);
  auto new_parallel_execute = builder.create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, new_result_types);

  for (auto region : llvm::zip(new_parallel_execute.getRegions(),
                               parallel_execute.getRegions()))
    std::get<0>(region)->takeBody(*std::get<1>(region));

  for (auto result :
       llvm::zip(results_to_remap, new_parallel_execute.getResults()))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  parallel_execute.erase();
}

void TPUParallelExecuteSinkResourceWrite::runOnFunction() {
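  // Collect parallel_execute ops first so the IR is not mutated (and ops
  // erased) while it is being walked.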
  llvm::SmallVector<tf_device::ParallelExecuteOp, 4> parallel_executes;
  getFunction().walk([&](tf_device::ParallelExecuteOp parallel_execute) {
    parallel_executes.push_back(parallel_execute);
  });

  for (tf_device::ParallelExecuteOp parallel_execute : parallel_executes)
    SinkResourceWritesIntoParallelExecute(parallel_execute);
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUParallelExecuteSinkResourceWritePass() {
  return std::make_unique<TPUParallelExecuteSinkResourceWrite>();
}

static PassRegistration<TPUParallelExecuteSinkResourceWrite> pass(
    "tf-tpu-parallel-execute-sink-resource-write",
    "Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into "
    "tf_device.parallel_execute regions");

}  // namespace TFTPU
}  // namespace mlir