• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass forms `tf_executor.island` per region of
17 // `tf_device.parallel_execute`.
18 //
19 // For example, the following:
20 //
21 //  %0 = tf_executor.island {
22 //    tf_executor.yield
23 //  }
24 //  %1:2 = tf_executor.island {
25 //    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//    tf_executor.yield %2 : tensor<i1>
27 //  }
28 //  %3:2 = tf_executor.island(%0) {
29 //    %4 = "tf_device.parallel_execute"() ( {
30 //      %5 = "tf.opB"() : () -> tensor<i1>
31 //      tf_device.return %5 : tensor<i1>
32 //    }, {
33 //      %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
34 //      tf_device.return
35 //    }) {} : () -> (tensor<i1>)
36 //    tf_executor.yield %4 : tensor<i1>
37 //  }
38 //  tf_executor.fetch %3#0 : tensor<i1>
39 //
40 // gets lowered to:
41 //
42 //  %0 = tf_executor.island {
43 //    tf_executor.yield
44 //  }
45 //  %1:2 = tf_executor.island {
46 //    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
47 //    tf_executor.yield %2 : tensor<i1>
48 //  }
49 //
50 //  // Island for the first region of above parallel_execute.
51 //  %3:2 = tf_executor.island(%0) {
52 //    %4 = "tf.opB"() : () -> tensor<i1>
53 //    tf_executor.yield %4 : tensor<i1>
54 //  }
55 //
56 //  // Island for the second region of above parallel_execute.
57 //  %5 = tf_executor.island(%0) {
58 //    %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
59 //    tf_executor.yield
60 //  }
61 //
62 //  tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
63 //
//  When a tf_device.parallel_execute op is nested inside a tf_device.replicate
//  op, this pass is expected to run after the `replicate-to-island` and
//  `tf-executor-break-up-islands` passes.
67 
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "mlir/IR/Block.h"  // from @llvm-project
71 #include "mlir/IR/Builders.h"  // from @llvm-project
72 #include "mlir/IR/Value.h"  // from @llvm-project
73 #include "mlir/Pass/Pass.h"  // from @llvm-project
74 #include "mlir/Support/LLVM.h"  // from @llvm-project
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
77 
78 namespace mlir {
79 namespace TFDevice {
80 namespace {
81 
82 struct ParallelExecuteToIslandsPass
83     : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
getArgumentmlir::TFDevice::__anon203c14300111::ParallelExecuteToIslandsPass84   StringRef getArgument() const final {
85     // This is the argument used to refer to the pass in
86     // the textual format (on the commandline for example).
87     return "tf-parallel-execute-to-islands";
88   }
getDescriptionmlir::TFDevice::__anon203c14300111::ParallelExecuteToIslandsPass89   StringRef getDescription() const final {
90     // This is a brief description of the pass.
91     return "Lowers device parallel_execute to executor islands";
92   }
93   void runOnFunction() override;
94 };
95 
96 // Convert parallel_execute op to a set of islands where each region of
97 // parallel_execute op becomes a separate island. This ensures that the regions
98 // of the parallel_execute op gets executed concurrently.
ExpandParallelExecuteToIslands(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op,OpBuilder * builder,llvm::SmallVectorImpl<tf_executor::IslandOp> & executes)99 void ExpandParallelExecuteToIslands(
100     tf_executor::IslandOp island_op,
101     tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
102     llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
103   const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
104   executes.reserve(num_regions);
105 
106   for (int i : llvm::seq<int>(0, num_regions)) {
107     Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);
108 
109     // Replace terminator with tf_executor.YieldOp.
110     Operation* terminator = execute_block.getTerminator();
111     builder->setInsertionPoint(terminator);
112     auto yield = builder->create<tf_executor::YieldOp>(
113         terminator->getLoc(), terminator->getOperands());
114     terminator->erase();
115 
116     // Create new island for each region.
117     builder->setInsertionPoint(island_op);
118     auto execute_island = builder->create<tf_executor::IslandOp>(
119         island_op.getLoc(), yield.getOperandTypes(),
120         island_op.control().getType(), island_op.controlInputs());
121 
122     // Move over tf_device.parallel_execute body region into newly the created
123     // island.
124     execute_island.body().takeBody(*execute_block.getParent());
125     executes.push_back(execute_island);
126   }
127 }
128 
CreateIslandsFromParallelExecute(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op)129 void CreateIslandsFromParallelExecute(
130     tf_executor::IslandOp island_op,
131     tf_device::ParallelExecuteOp parallel_execute_op) {
132   OpBuilder builder(island_op);
133 
134   // Create islands for each region of the parallel_execute op.
135   llvm::SmallVector<tf_executor::IslandOp, 4> executes;
136   ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
137                                  executes);
138 
139   // Remap all results of parallel_execute op with outputs from newly created
140   // islands.
141   llvm::SmallVector<Value, 8> parallel_execute_outputs;
142   parallel_execute_outputs.reserve(
143       parallel_execute_op.getOperation()->getNumResults());
144 
145   for (auto& execute : executes)
146     parallel_execute_outputs.append(execute.outputs().begin(),
147                                     execute.outputs().end());
148 
149   for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
150     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
151 
152   // Add sink island to pin all islands as a control dependency if there is a
153   // control dependency leading from the parallel_execute originally.
154   if (!island_op.control().use_empty()) {
155     llvm::SmallVector<Value, 8> island_operands;
156     for (auto& execute : executes) island_operands.push_back(execute.control());
157 
158     builder.setInsertionPoint(island_op);
159     auto island_sink = builder.create<tf_executor::IslandOp>(
160         island_op.getLoc(), llvm::ArrayRef<Type>{},
161         island_op.control().getType(), island_operands);
162     island_sink.body().push_back(new Block);
163     builder.setInsertionPointToEnd(&island_sink.GetBody());
164     builder.create<tf_executor::YieldOp>(island_op.getLoc(),
165                                          llvm::ArrayRef<Value>{});
166     island_op.control().replaceAllUsesWith(island_sink.control());
167   }
168 
169   // Islands with no uses should be pinned to a graph fetch so they still
170   // execute.
171   llvm::SmallVector<Value, 8> unused_execute_controls;
172   for (auto& execute : executes)
173     if (execute.use_empty())
174       unused_execute_controls.push_back(execute.control());
175 
176   if (!unused_execute_controls.empty()) {
177     auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
178     tf_executor::FetchOp fetch = graph_op.GetFetch();
179     auto fetches = llvm::to_vector<8>(fetch.getOperands());
180     fetches.append(unused_execute_controls.begin(),
181                    unused_execute_controls.end());
182     builder.setInsertionPoint(fetch);
183     builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
184     fetch.erase();
185   }
186 
187   island_op.erase();
188 }
189 
runOnFunction()190 void ParallelExecuteToIslandsPass::runOnFunction() {
191   // Find islands with a single `tf_device.parallel_execute` and create
192   // individual islands per execute region of the parallel_execute.
193   llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
194   getFunction().walk([&](tf_executor::GraphOp graph_op) {
195     for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
196       if (!island_op.WrapsSingleOp()) continue;
197 
198       if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
199         parallel_execute_op_islands.push_back(island_op);
200     }
201   });
202 
203   for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
204     auto parallel_execute_op =
205         cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
206     CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
207   }
208 }
209 }  // anonymous namespace
210 
CreateParallelExecuteToIslandsPass()211 std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
212   return std::make_unique<ParallelExecuteToIslandsPass>();
213 }
214 
215 static PassRegistration<ParallelExecuteToIslandsPass> pass;
216 
217 }  // namespace TFDevice
218 }  // namespace mlir
219