• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass forms `tf_executor.island` per region of
17 // `tf_device.parallel_execute`.
18 //
19 // For example, the following:
20 //
21 //  %0 = tf_executor.island {
22 //    tf_executor.yield
23 //  }
24 //  %1:2 = tf_executor.island {
25 //    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
26 //      tf_executor.yield %2 : tensor<i1>
27 //  }
28 //  %3:2 = tf_executor.island(%0) {
29 //    %4 = "tf_device.parallel_execute"() ( {
30 //      %5 = "tf.opB"() : () -> tensor<i1>
31 //      tf_device.return %5 : tensor<i1>
32 //    }, {
33 //      %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
34 //      tf_device.return
35 //    }) {} : () -> (tensor<i1>)
36 //    tf_executor.yield %4 : tensor<i1>
37 //  }
38 //  tf_executor.fetch %3#0 : tensor<i1>
39 //
40 // gets lowered to:
41 //
42 //  %0 = tf_executor.island {
43 //    tf_executor.yield
44 //  }
45 //  %1:2 = tf_executor.island {
46 //    %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
47 //    tf_executor.yield %2 : tensor<i1>
48 //  }
49 //
50 //  // Island for the first region of above parallel_execute.
51 //  %3:2 = tf_executor.island(%0) {
52 //    %4 = "tf.opB"() : () -> tensor<i1>
53 //    tf_executor.yield %4 : tensor<i1>
54 //  }
55 //
56 //  // Island for the second region of above parallel_execute.
57 //  %5 = tf_executor.island(%0) {
58 //    %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
59 //    tf_executor.yield
60 //  }
61 //
62 //  tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
63 //
64 //  When a tf_device.parallel_execute op is enclosed in a
65 //  tf_device.replicate op, this pass runs after the `replicate-to-island`
66 //  pass and the `tf-executor-break-up-islands` pass.
67 
68 #include "llvm/ADT/STLExtras.h"
69 #include "llvm/ADT/SmallVector.h"
70 #include "mlir/IR/Block.h"  // from @llvm-project
71 #include "mlir/IR/Builders.h"  // from @llvm-project
72 #include "mlir/IR/Value.h"  // from @llvm-project
73 #include "mlir/Pass/Pass.h"  // from @llvm-project
74 #include "mlir/Support/LLVM.h"  // from @llvm-project
75 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
76 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
77 
78 namespace mlir {
79 namespace TFDevice {
80 namespace {
81 
82 struct ParallelExecuteToIslandsPass
83     : public PassWrapper<ParallelExecuteToIslandsPass, FunctionPass> {
84   void runOnFunction() override;
85 };
86 
87 // Convert parallel_execute op to a set of islands where each region of
88 // parallel_execute op becomes a separate island. This ensures that the regions
89 // of the parallel_execute op gets executed concurrently.
ExpandParallelExecuteToIslands(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op,OpBuilder * builder,llvm::SmallVectorImpl<tf_executor::IslandOp> & executes)90 void ExpandParallelExecuteToIslands(
91     tf_executor::IslandOp island_op,
92     tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
93     llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
94   const int num_regions = parallel_execute_op.getOperation()->getNumRegions();
95   executes.reserve(num_regions);
96 
97   for (int i : llvm::seq<int>(0, num_regions)) {
98     Block& execute_block = parallel_execute_op.GetRegionBlockWithIndex(i);
99 
100     // Replace terminator with tf_executor.YieldOp.
101     Operation* terminator = execute_block.getTerminator();
102     builder->setInsertionPoint(terminator);
103     auto yield = builder->create<tf_executor::YieldOp>(
104         terminator->getLoc(), terminator->getOperands());
105     terminator->erase();
106 
107     // Create new island for each region.
108     builder->setInsertionPoint(island_op);
109     auto execute_island = builder->create<tf_executor::IslandOp>(
110         island_op.getLoc(), yield.getOperandTypes(),
111         island_op.control().getType(), island_op.controlInputs());
112 
113     // Move over tf_device.parallel_execute body region into newly the created
114     // island.
115     execute_island.body().takeBody(*execute_block.getParent());
116     executes.push_back(execute_island);
117   }
118 }
119 
CreateIslandsFromParallelExecute(tf_executor::IslandOp island_op,tf_device::ParallelExecuteOp parallel_execute_op)120 void CreateIslandsFromParallelExecute(
121     tf_executor::IslandOp island_op,
122     tf_device::ParallelExecuteOp parallel_execute_op) {
123   OpBuilder builder(island_op);
124 
125   // Create islands for each region of the parallel_execute op.
126   llvm::SmallVector<tf_executor::IslandOp, 4> executes;
127   ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
128                                  executes);
129 
130   // Remap all results of parallel_execute op with outputs from newly created
131   // islands.
132   llvm::SmallVector<Value, 8> parallel_execute_outputs;
133   parallel_execute_outputs.reserve(
134       parallel_execute_op.getOperation()->getNumResults());
135 
136   for (auto& execute : executes)
137     parallel_execute_outputs.append(execute.outputs().begin(),
138                                     execute.outputs().end());
139 
140   for (auto result : llvm::zip(island_op.outputs(), parallel_execute_outputs))
141     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
142 
143   // Add sink island to pin all islands as a control dependency if there is a
144   // control dependency leading from the parallel_execute originally.
145   if (!island_op.control().use_empty()) {
146     llvm::SmallVector<Value, 8> island_operands;
147     for (auto& execute : executes) island_operands.push_back(execute.control());
148 
149     builder.setInsertionPoint(island_op);
150     auto island_sink = builder.create<tf_executor::IslandOp>(
151         island_op.getLoc(), llvm::ArrayRef<Type>{},
152         island_op.control().getType(), island_operands);
153     island_sink.body().push_back(new Block);
154     builder.setInsertionPointToEnd(&island_sink.GetBody());
155     builder.create<tf_executor::YieldOp>(island_op.getLoc(),
156                                          llvm::ArrayRef<Value>{});
157     island_op.control().replaceAllUsesWith(island_sink.control());
158   }
159 
160   // Islands with no uses should be pinned to a graph fetch so they still
161   // execute.
162   llvm::SmallVector<Value, 8> unused_execute_controls;
163   for (auto& execute : executes)
164     if (execute.use_empty())
165       unused_execute_controls.push_back(execute.control());
166 
167   if (!unused_execute_controls.empty()) {
168     auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
169     tf_executor::FetchOp fetch = graph_op.GetFetch();
170     auto fetches = llvm::to_vector<8>(fetch.getOperands());
171     fetches.append(unused_execute_controls.begin(),
172                    unused_execute_controls.end());
173     builder.setInsertionPoint(fetch);
174     builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
175     fetch.erase();
176   }
177 
178   island_op.erase();
179 }
180 
runOnFunction()181 void ParallelExecuteToIslandsPass::runOnFunction() {
182   // Find islands with a single `tf_device.parallel_execute` and create
183   // individual islands per execute region of the parallel_execute.
184   llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
185   getFunction().walk([&](tf_executor::GraphOp graph_op) {
186     for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
187       if (!island_op.WrapsSingleOp()) continue;
188 
189       if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
190         parallel_execute_op_islands.push_back(island_op);
191     }
192   });
193 
194   for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
195     auto parallel_execute_op =
196         cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
197     CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
198   }
199 }
200 }  // anonymous namespace
201 
CreateParallelExecuteToIslandsPass()202 std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass() {
203   return std::make_unique<ParallelExecuteToIslandsPass>();
204 }
205 
206 static PassRegistration<ParallelExecuteToIslandsPass> pass(
207     "tf-parallel-execute-to-islands",
208     "Lowers device parallel_execute to executor islands");
209 
210 }  // namespace TFDevice
211 }  // namespace mlir
212