1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass forms `tf_executor.island` per region of
17 // `tf_device.parallel_execute`.
18 //
19 // For example, the following:
20 //
21 // %0 = tf_executor.island {
22 // tf_executor.yield
23 // }
24 // %1:2 = tf_executor.island {
25 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
26 // tf_executor.yield %2 : tensor<i1>
27 // }
28 // %3:2 = tf_executor.island(%0) {
29 // %4 = "tf_device.parallel_execute"() ( {
30 // %5 = "tf.opB"() : () -> tensor<i1>
31 // tf_device.return %5 : tensor<i1>
32 // }, {
33 // %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
34 // tf_device.return
35 // }) {} : () -> (tensor<i1>)
36 // tf_executor.yield %4 : tensor<i1>
37 // }
38 // tf_executor.fetch %3#0 : tensor<i1>
39 //
40 // gets lowered to:
41 //
42 // %0 = tf_executor.island {
43 // tf_executor.yield
44 // }
45 // %1:2 = tf_executor.island {
46 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
47 // tf_executor.yield %2 : tensor<i1>
48 // }
49 //
50 // // Island for the first region of above parallel_execute.
51 // %3:2 = tf_executor.island(%0) {
52 // %4 = "tf.opB"() : () -> tensor<i1>
53 // tf_executor.yield %4 : tensor<i1>
54 // }
55 //
56 // // Island for the second region of above parallel_execute.
57 // %5 = tf_executor.island(%0) {
58 // %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
59 // tf_executor.yield
60 // }
61 //
62 // tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
63 //
// When a tf_device.parallel_execute op is enclosed within a
// tf_device.replicate, this pass runs after the `replicate-to-island` and
// `tf-executor-break-up-islands` passes.
67
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "mlir/IR/Block.h" // from @llvm-project
71 #include "mlir/IR/Builders.h" // from @llvm-project
72 #include "mlir/IR/Value.h" // from @llvm-project
73 #include "mlir/Pass/Pass.h" // from @llvm-project
74 #include "mlir/Support/LLVM.h" // from @llvm-project
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
77
78 namespace mlir {
79 namespace TFDevice {
80 namespace {
81
82 struct ParallelExecuteToIslandsPass
83 : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
getArgumentmlir::TFDevice::__anon203c14300111::ParallelExecuteToIslandsPass84 StringRef getArgument() const final {
85 // This is the argument used to refer to the pass in
86 // the textual format (on the commandline for example).
87 return "tf-parallel-execute-to-islands";
88 }
getDescriptionmlir::TFDevice::__anon203c14300111::ParallelExecuteToIslandsPass89 StringRef getDescription() const final {
90 // This is a brief description of the pass.
91 return "Lowers device parallel_execute to executor islands";
92 }
93 void runOnFunction() override;
94 };
95
96 // Convert parallel_execute op to a set of islands where each region of
97 // parallel_execute op becomes a separate island. This ensures that the regions
98 // of the parallel_execute op gets executed concurrently.
ExpandParallelExecuteToIslands(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op,OpBuilder * builder,llvm::SmallVectorImpl<tf_executor::IslandOp> & executes)99 void ExpandParallelExecuteToIslands(
100 tf_executor::IslandOp island_op,
101 tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
102 llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
103 const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
104 executes.reserve(num_regions);
105
106 for (int i : llvm::seq<int>(0, num_regions)) {
107 Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);
108
109 // Replace terminator with tf_executor.YieldOp.
110 Operation* terminator = execute_block.getTerminator();
111 builder->setInsertionPoint(terminator);
112 auto yield = builder->create<tf_executor::YieldOp>(
113 terminator->getLoc(), terminator->getOperands());
114 terminator->erase();
115
116 // Create new island for each region.
117 builder->setInsertionPoint(island_op);
118 auto execute_island = builder->create<tf_executor::IslandOp>(
119 island_op.getLoc(), yield.getOperandTypes(),
120 island_op.control().getType(), island_op.controlInputs());
121
122 // Move over tf_device.parallel_execute body region into newly the created
123 // island.
124 execute_island.body().takeBody(*execute_block.getParent());
125 executes.push_back(execute_island);
126 }
127 }
128
CreateIslandsFromParallelExecute(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op)129 void CreateIslandsFromParallelExecute(
130 tf_executor::IslandOp island_op,
131 tf_device::ParallelExecuteOp parallel_execute_op) {
132 OpBuilder builder(island_op);
133
134 // Create islands for each region of the parallel_execute op.
135 llvm::SmallVector<tf_executor::IslandOp, 4> executes;
136 ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
137 executes);
138
139 // Remap all results of parallel_execute op with outputs from newly created
140 // islands.
141 llvm::SmallVector<Value, 8> parallel_execute_outputs;
142 parallel_execute_outputs.reserve(
143 parallel_execute_op.getOperation()->getNumResults());
144
145 for (auto& execute : executes)
146 parallel_execute_outputs.append(execute.outputs().begin(),
147 execute.outputs().end());
148
149 for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
150 std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
151
152 // Add sink island to pin all islands as a control dependency if there is a
153 // control dependency leading from the parallel_execute originally.
154 if (!island_op.control().use_empty()) {
155 llvm::SmallVector<Value, 8> island_operands;
156 for (auto& execute : executes) island_operands.push_back(execute.control());
157
158 builder.setInsertionPoint(island_op);
159 auto island_sink = builder.create<tf_executor::IslandOp>(
160 island_op.getLoc(), llvm::ArrayRef<Type>{},
161 island_op.control().getType(), island_operands);
162 island_sink.body().push_back(new Block);
163 builder.setInsertionPointToEnd(&island_sink.GetBody());
164 builder.create<tf_executor::YieldOp>(island_op.getLoc(),
165 llvm::ArrayRef<Value>{});
166 island_op.control().replaceAllUsesWith(island_sink.control());
167 }
168
169 // Islands with no uses should be pinned to a graph fetch so they still
170 // execute.
171 llvm::SmallVector<Value, 8> unused_execute_controls;
172 for (auto& execute : executes)
173 if (execute.use_empty())
174 unused_execute_controls.push_back(execute.control());
175
176 if (!unused_execute_controls.empty()) {
177 auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
178 tf_executor::FetchOp fetch = graph_op.GetFetch();
179 auto fetches = llvm::to_vector<8>(fetch.getOperands());
180 fetches.append(unused_execute_controls.begin(),
181 unused_execute_controls.end());
182 builder.setInsertionPoint(fetch);
183 builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
184 fetch.erase();
185 }
186
187 island_op.erase();
188 }
189
runOnFunction()190 void ParallelExecuteToIslandsPass::runOnFunction() {
191 // Find islands with a single `tf_device.parallel_execute` and create
192 // individual islands per execute region of the parallel_execute.
193 llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
194 getFunction().walk([&](tf_executor::GraphOp graph_op) {
195 for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
196 if (!island_op.WrapsSingleOp()) continue;
197
198 if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
199 parallel_execute_op_islands.push_back(island_op);
200 }
201 });
202
203 for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
204 auto parallel_execute_op =
205 cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
206 CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
207 }
208 }
209 } // anonymous namespace
210
CreateParallelExecuteToIslandsPass()211 std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
212 return std::make_unique<ParallelExecuteToIslandsPass>();
213 }
214
215 static PassRegistration<ParallelExecuteToIslandsPass> pass;
216
217 } // namespace TFDevice
218 } // namespace mlir
219