1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass forms one `tf_executor.island` per region of each
// `tf_device.parallel_execute` op that is the only op inside its enclosing
// `tf_executor.island`.
18 //
19 // For example, the following:
20 //
21 // %0 = tf_executor.island {
22 // tf_executor.yield
23 // }
24 // %1:2 = tf_executor.island {
25 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
26 // tf_executor.yield %2 : tensor<i1>
27 // }
28 // %3:2 = tf_executor.island(%0) {
29 // %4 = "tf_device.parallel_execute"() ( {
30 // %5 = "tf.opB"() : () -> tensor<i1>
31 // tf_device.return %5 : tensor<i1>
32 // }, {
33 // %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
34 // tf_device.return
35 // }) {} : () -> (tensor<i1>)
36 // tf_executor.yield %4 : tensor<i1>
37 // }
38 // tf_executor.fetch %3#0 : tensor<i1>
39 //
40 // gets lowered to:
41 //
42 // %0 = tf_executor.island {
43 // tf_executor.yield
44 // }
45 // %1:2 = tf_executor.island {
46 // %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
47 // tf_executor.yield %2 : tensor<i1>
48 // }
49 //
50 // // Island for the first region of above parallel_execute.
51 // %3:2 = tf_executor.island(%0) {
52 // %4 = "tf.opB"() : () -> tensor<i1>
53 // tf_executor.yield %4 : tensor<i1>
54 // }
55 //
56 // // Island for the second region of above parallel_execute.
57 // %5 = tf_executor.island(%0) {
58 // %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
59 // tf_executor.yield
60 // }
61 //
62 // tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
63 //
// When a `tf_device.parallel_execute` op is nested inside a
// `tf_device.replicate` op, this pass runs after the `replicate-to-island` and
// `tf-executor-break-up-islands` passes.
67
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "mlir/IR/Block.h" // from @llvm-project
71 #include "mlir/IR/Builders.h" // from @llvm-project
72 #include "mlir/IR/Value.h" // from @llvm-project
73 #include "mlir/Pass/Pass.h" // from @llvm-project
74 #include "mlir/Support/LLVM.h" // from @llvm-project
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
77
78 namespace mlir {
79 namespace TFDevice {
80 namespace {
81
82 struct ParallelExecuteToIslandsPass
83 : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
84 void runOnFunction() override;
85 };
86
87 // Convert parallel_execute op to a set of islands where each region of
88 // parallel_execute op becomes a separate island. This ensures that the regions
89 // of the parallel_execute op gets executed concurrently.
ExpandParallelExecuteToIslands(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op,OpBuilder * builder,llvm::SmallVectorImpl<tf_executor::IslandOp> & executes)90 void ExpandParallelExecuteToIslands(
91 tf_executor::IslandOp island_op,
92 tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
93 llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
94 const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
95 executes.reserve(num_regions);
96
97 for (int i : llvm::seq<int>(0, num_regions)) {
98 Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);
99
100 // Replace terminator with tf_executor.YieldOp.
101 Operation* terminator = execute_block.getTerminator();
102 builder->setInsertionPoint(terminator);
103 auto yield = builder->create<tf_executor::YieldOp>(
104 terminator->getLoc(), terminator->getOperands());
105 terminator->erase();
106
107 // Create new island for each region.
108 builder->setInsertionPoint(island_op);
109 auto execute_island = builder->create<tf_executor::IslandOp>(
110 island_op.getLoc(), yield.getOperandTypes(),
111 island_op.control().getType(), island_op.controlInputs());
112
113 // Move over tf_device.parallel_execute body region into newly the created
114 // island.
115 execute_island.body().takeBody(*execute_block.getParent());
116 executes.push_back(execute_island);
117 }
118 }
119
CreateIslandsFromParallelExecute(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op)120 void CreateIslandsFromParallelExecute(
121 tf_executor::IslandOp island_op,
122 tf_device::ParallelExecuteOp parallel_execute_op) {
123 OpBuilder builder(island_op);
124
125 // Create islands for each region of the parallel_execute op.
126 llvm::SmallVector<tf_executor::IslandOp, 4> executes;
127 ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
128 executes);
129
130 // Remap all results of parallel_execute op with outputs from newly created
131 // islands.
132 llvm::SmallVector<Value, 8> parallel_execute_outputs;
133 parallel_execute_outputs.reserve(
134 parallel_execute_op.getOperation()->getNumResults());
135
136 for (auto& execute : executes)
137 parallel_execute_outputs.append(execute.outputs().begin(),
138 execute.outputs().end());
139
140 for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
141 std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
142
143 // Add sink island to pin all islands as a control dependency if there is a
144 // control dependency leading from the parallel_execute originally.
145 if (!island_op.control().use_empty()) {
146 llvm::SmallVector<Value, 8> island_operands;
147 for (auto& execute : executes) island_operands.push_back(execute.control());
148
149 builder.setInsertionPoint(island_op);
150 auto island_sink = builder.create<tf_executor::IslandOp>(
151 island_op.getLoc(), llvm::ArrayRef<Type>{},
152 island_op.control().getType(), island_operands);
153 island_sink.body().push_back(new Block);
154 builder.setInsertionPointToEnd(&island_sink.GetBody());
155 builder.create<tf_executor::YieldOp>(island_op.getLoc(),
156 llvm::ArrayRef<Value>{});
157 island_op.control().replaceAllUsesWith(island_sink.control());
158 }
159
160 // Islands with no uses should be pinned to a graph fetch so they still
161 // execute.
162 llvm::SmallVector<Value, 8> unused_execute_controls;
163 for (auto& execute : executes)
164 if (execute.use_empty())
165 unused_execute_controls.push_back(execute.control());
166
167 if (!unused_execute_controls.empty()) {
168 auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
169 tf_executor::FetchOp fetch = graph_op.GetFetch();
170 auto fetches = llvm::to_vector<8>(fetch.getOperands());
171 fetches.append(unused_execute_controls.begin(),
172 unused_execute_controls.end());
173 builder.setInsertionPoint(fetch);
174 builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
175 fetch.erase();
176 }
177
178 island_op.erase();
179 }
180
runOnFunction()181 void ParallelExecuteToIslandsPass::runOnFunction() {
182 // Find islands with a single `tf_device.parallel_execute` and create
183 // individual islands per execute region of the parallel_execute.
184 llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
185 getFunction().walk([&](tf_executor::GraphOp graph_op) {
186 for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
187 if (!island_op.WrapsSingleOp()) continue;
188
189 if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
190 parallel_execute_op_islands.push_back(island_op);
191 }
192 });
193
194 for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
195 auto parallel_execute_op =
196 cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
197 CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
198 }
199 }
200 } // anonymous namespace
201
CreateParallelExecuteToIslandsPass()202 std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
203 return std::make_unique<ParallelExecuteToIslandsPass>();
204 }
205
206 static PassRegistration<ParallelExecuteToIslandsPass> pass(
207 "tf-parallel-execute-to-islands",
208 "Lowers device parallel_execute to executor islands");
209
210 } // namespace TFDevice
211 } // namespace mlir
212