• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <memory>
18 
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 
29 namespace mlir {
30 namespace TFTPU {
31 
32 namespace {
33 
34 // A pass that moves `tf.AssignVariableOp` into a `tf_device.parallel_execute`
35 // region if the `tf.AssignVariableOp` is the only consumer of a
36 // `tf_device.parallel_execute` result. This will allow
37 // TPUMergeVariablesWithExecute to merge resource writes without special
38 // handling for `tf_device.parallel_execute`.
39 struct TPUParallelExecuteSinkResourceWrite
40     : public PassWrapper<TPUParallelExecuteSinkResourceWrite, FunctionPass> {
41   void runOnFunction() override;
42 
getArgumentmlir::TFTPU::__anondf8a44be0111::TPUParallelExecuteSinkResourceWrite43   StringRef getArgument() const final {
44     return "tf-tpu-parallel-execute-sink-resource-write";
45   }
46 
getDescriptionmlir::TFTPU::__anondf8a44be0111::TPUParallelExecuteSinkResourceWrite47   StringRef getDescription() const final {
48     return "Moves tf.AssignVariableOp consumers of tf_device.parallel_execute "
49            "into tf_device.parallel_execute regions";
50   }
51 };
52 
53 // Finds an AssignVariableOp that can be moved into the parallel_execute region.
54 // These AssignVariableOps must be the only consumer of the respective
55 // parallel_execute result, and the resource handle producer must be from an op
56 // before or above the parallel_execute.
GetSingleUseResourceWrite(tf_device::ParallelExecuteOp parallel_execute,Value result)57 TF::AssignVariableOp GetSingleUseResourceWrite(
58     tf_device::ParallelExecuteOp parallel_execute, Value result) {
59   if (!result.hasOneUse()) return nullptr;
60 
61   OpOperand& use = *result.getUses().begin();
62   auto assign_var = dyn_cast<TF::AssignVariableOp>(use.getOwner());
63   if (!assign_var) return nullptr;
64 
65   if (use.get() != assign_var.value()) return nullptr;
66 
67   auto* resource_handle_op = assign_var.resource().getDefiningOp();
68   if (resource_handle_op == parallel_execute) return nullptr;
69 
70   if (resource_handle_op &&
71       resource_handle_op->getBlock() ==
72           parallel_execute.getOperation()->getBlock() &&
73       parallel_execute.getOperation()->isBeforeInBlock(resource_handle_op))
74     return nullptr;
75 
76   return assign_var;
77 }
78 
79 // Finds AssignVariableOps that can be moved into a parallel_execute region and
80 // moves them. Leftover parallel_execute results that were used by the
81 // such AssignVariableOp are also pruned.
SinkResourceWritesIntoParallelExecute(tf_device::ParallelExecuteOp parallel_execute)82 void SinkResourceWritesIntoParallelExecute(
83     tf_device::ParallelExecuteOp parallel_execute) {
84   bool rewrite = false;
85   const int num_regions = parallel_execute.getNumRegions();
86   llvm::SmallVector<Value, 4> results_to_remap;
87 
88   // Go through each region and find AssignVariableOps that can be moved into
89   // the parallel_execute region. Result indices by region index are collected,
90   // so they can be removed afterwards.
91   llvm::SmallVector<llvm::SmallVector<int, 4>, 4> results_to_remove_by_region;
92   results_to_remove_by_region.resize(num_regions);
93   for (int i = 0; i < num_regions; ++i) {
94     Block& block = parallel_execute.GetRegionBlockWithIndex(i);
95     auto results = parallel_execute.GetRegionOutputs(i);
96     auto& results_to_remove = results_to_remove_by_region[i];
97     results_to_remove.reserve(results.size());
98     Operation* terminator = block.getTerminator();
99     for (auto result : llvm::enumerate(results)) {
100       TF::AssignVariableOp assign_var =
101           GetSingleUseResourceWrite(parallel_execute, result.value());
102       if (!assign_var) {
103         results_to_remap.push_back(result.value());
104         continue;
105       }
106 
107       // Move AssignVariableOp and update the value to be written to the
108       // resource variable to be the non forwarded value from within the
109       // parallel_execute region.
110       assign_var.getOperation()->moveBefore(terminator);
111       assign_var.valueMutable().assign(terminator->getOperand(result.index()));
112       results_to_remove.push_back(result.index());
113     }
114 
115     rewrite |= !results_to_remove.empty();
116   }
117 
118   if (!rewrite) return;
119 
120   // Remove leftover unused results (terminator operands) from moving
121   // AssignVariabeOps into the parallel_execute region.
122   for (auto results_to_remove : llvm::enumerate(results_to_remove_by_region)) {
123     Block& block =
124         parallel_execute.GetRegionBlockWithIndex(results_to_remove.index());
125     Operation* terminator = block.getTerminator();
126     for (int index_to_remove : llvm::reverse(results_to_remove.value()))
127       terminator->eraseOperand(index_to_remove);
128   }
129 
130   // Replace old parallel_execute with new parallel_execute by moving the
131   // regions to a new parallel_execute and remapping the results.
132   llvm::SmallVector<Type, 4> new_result_types;
133   new_result_types.reserve(results_to_remap.size());
134   for (Value old_result : results_to_remap)
135     new_result_types.push_back(old_result.getType());
136 
137   OpBuilder builder(parallel_execute);
138   auto new_parallel_execute = builder.create<tf_device::ParallelExecuteOp>(
139       parallel_execute.getLoc(), num_regions, new_result_types);
140 
141   for (auto region : llvm::zip(new_parallel_execute.getRegions(),
142                                parallel_execute.getRegions()))
143     std::get<0>(region)->takeBody(*std::get<1>(region));
144 
145   for (auto result :
146        llvm::zip(results_to_remap, new_parallel_execute.getResults()))
147     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
148 
149   parallel_execute.erase();
150 }
151 
runOnFunction()152 void TPUParallelExecuteSinkResourceWrite::runOnFunction() {
153   llvm::SmallVector<tf_device::ParallelExecuteOp, 4> parallel_executes;
154   getFunction().walk([&](tf_device::ParallelExecuteOp parallel_execute) {
155     parallel_executes.push_back(parallel_execute);
156   });
157 
158   for (tf_device::ParallelExecuteOp parallel_execute : parallel_executes)
159     SinkResourceWritesIntoParallelExecute(parallel_execute);
160 }
161 
162 }  // anonymous namespace
163 
164 std::unique_ptr<OperationPass<FuncOp>>
CreateTPUParallelExecuteSinkResourceWritePass()165 CreateTPUParallelExecuteSinkResourceWritePass() {
166   return std::make_unique<TPUParallelExecuteSinkResourceWrite>();
167 }
168 
169 static PassRegistration<TPUParallelExecuteSinkResourceWrite> pass;
170 
171 }  // namespace TFTPU
172 }  // namespace mlir
173