1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation forms clusters from instructions in the same island that
// are assigned to the same device. Clusters are represented as regions.
18 // Note that side-effecting ops are not correctly handled yet.
19
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Block.h" // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/core/platform/logging.h"
34
35 namespace mlir {
36 namespace TFDevice {
37
38 namespace {
39
40 struct ClusterFormationPass
41 : public PassWrapper<ClusterFormationPass, FunctionPass> {
getDependentDialectsmlir::TFDevice::__anonde985a2d0111::ClusterFormationPass42 void getDependentDialects(DialectRegistry& registry) const override {
43 registry.insert<tf_device::TensorFlowDeviceDialect>();
44 }
45
46 void runOnFunction() override;
47 };
48
49 // Cluster structure captures all the operations that are assigned to same
50 // device and can form a legal strict cluster.
51 // Ops must follow same ordering in their parent block. We rely on this
52 // assumption to perform analysis.
53 struct Cluster {
54 llvm::SmallVector<Operation*, 4> ops;
55 StringRef device;
56 };
57
GetDevice(Operation * op)58 StringRef GetDevice(Operation* op) {
59 auto device_attr = op->getAttrOfType<StringAttr>("device");
60 return device_attr ? device_attr.getValue() : "";
61 }
62
63 // An op can be merged into cluster if all of its operands are one of the
64 // following:
65 // 1) A block argument
66 // 2) A value produced by other islands
67 // 1) Defined before the cluster
68 // 2) Defined by an operation in the cluster
69 // TODO(ycao): This is not optimal as it doesn't consider the situation of
70 // defining_op's operands all meet the requirements above. In that case, the
71 // defining_op can be moved and to_merge op would be legal to absorb.
72 // TODO(ycao): Take op side-effects into consideration since they can not be
73 // re-ordered but forming clusters of non-continuous ops is effectively
74 // re-ordering them..
CanMergeIntoCluster(const Cluster & c,Operation * to_merge)75 bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
76 return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
77 // Block arguments.
78 if (operand.isa<BlockArgument>()) return true;
79
80 Operation* defining_op = operand.getDefiningOp();
81
82 // Operand produced by other islands.
83 if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
84
85 // Defining op is before the cluster.
86 if (defining_op->isBeforeInBlock(c.ops.front())) return true;
87
88 // Defining op is between first and last operation in cluster. Note that
89 // cluster may contain operations that are non-continuous in their original
90 // block, thus we also need to check defining_op is also assigned to
91 // cluster's device to be sure. This is a faster check than linearly
92 // searching through all ops in cluster.
93 if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
94 GetDevice(defining_op) == c.device)
95 return true;
96
97 // Other cases, operand is generated after or outside the cluster, this
98 // means it is illegal to merge operation.
99 return false;
100 });
101 }
102
ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,tf_device::LaunchOp launch_op)103 void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
104 tf_device::LaunchOp launch_op) {
105 Region* launch_op_region = &launch_op.body();
106 for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
107 Value from = std::get<0>(p);
108 // TODO(jingpu): move this to RegionUtils.h in MLIR core.
109 for (auto& use : llvm::make_early_inc_range(from.getUses())) {
110 if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
111 continue;
112 use.set(std::get<1>(p));
113 }
114 }
115 }
116
117 // Get all escaped live-out values of a region.
GetLiveOuts(Region * region,llvm::SmallVectorImpl<Value> * live_outs)118 void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
119 live_outs->clear();
120
121 for (Operation& op : region->front()) {
122 for (Value v : op.getResults()) {
123 // A value is live-out if any of its users are not inside value producer's
124 // region.
125 bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
126 return !region->isAncestor(user->getParentRegion());
127 });
128
129 if (is_live_out) live_outs->emplace_back(v);
130 }
131 }
132 }
133
134 // Build a `tf_device.launch` op with a region that contains all the operations
135 // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`.
BuildLaunchForCluster(const Cluster & c,OpBuilder * builder)136 void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
137 // Set insertion point to right after all operations in cluster.
138 builder->setInsertionPoint(c.ops.back()->getNextNode());
139
140 // Create a stand-alone region to hold all instructions in the cluster.
141 Region region;
142 region.push_back(new Block);
143
144 // Move all operations in cluster to newly created region, stripping their
145 // "device" attribute since launch op already carries device information.
146 Block* block = ®ion.front();
147 for (Operation* op : c.ops) {
148 op->moveBefore(block, block->end());
149 op->removeAttr(builder->getIdentifier("device"));
150 }
151
152 // Get all escaped live-out values of region, they are used later to determine
153 // return values and types of launch op.
154 llvm::SmallVector<Value, 4> live_outs;
155 GetLiveOuts(®ion, &live_outs);
156
157 // Build a `tf_device.return` op at end of region, with all live-out values
158 // as operand.
159 OpBuilder return_builder(builder->getContext());
160 return_builder.setInsertionPointToEnd(block);
161 return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
162 live_outs);
163
164 llvm::SmallVector<Type, 4> live_out_types;
165 live_out_types.reserve(live_outs.size());
166 for (Value v : live_outs) {
167 live_out_types.emplace_back(v.getType());
168 }
169
170 tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
171 builder->getUnknownLoc(), builder->getStringAttr(c.device),
172 live_out_types);
173
174 // Attach the region to launch_op.
175 launch_op.body().takeBody(region);
176
177 // Replace any external uses of live-out values with return values of launch
178 // op. So live-out values no longer escape the region.
179 ReplaceLiveOutExternalUses(live_outs, launch_op);
180 }
181
BuildClusters(Block * block,OpBuilder builder)182 void BuildClusters(Block* block, OpBuilder builder) {
183 // Iteratively find clusters of different devices within an island.
184 // Whenever we see an operation that is assigned to an accelerator device
185 // (ie. device != ""), we try to merge it into the last cluster of same
186 // device. If that is infeasible (say because of violating def-before-use),
187 // create a new cluster with that operation and move on.
188 llvm::MapVector<StringRef, Cluster> nearest_clusters;
189 for (Operation& op : llvm::make_early_inc_range(*block)) {
190 auto device = GetDevice(&op);
191 if (device == "") continue;
192
193 // If no cluster of same device has been formed yet, create a new cluster
194 // with op alone.
195 auto it = nearest_clusters.find(device);
196 if (it == nearest_clusters.end()) {
197 nearest_clusters[device] = Cluster{{&op}, device};
198 continue;
199 }
200
201 // Check if it is legal to merge op into nearest cluster of same device.
202 // If positive, update cluster and move on to next operation.
203 Cluster& nearest_cluster = it->second;
204 if (CanMergeIntoCluster(nearest_cluster, &op)) {
205 nearest_cluster.ops.emplace_back(&op);
206 continue;
207 }
208
209 // If nearest cluster of same device can not absorb `op`, then that
210 // cluster needs to be finalized by building a `tf_device.launch` op with
211 // a region that contains all operations in clusters.
212 BuildLaunchForCluster(nearest_cluster, &builder);
213
214 // Create a new cluster to hold op alone and update nearest_clusters.
215 nearest_clusters[device] = Cluster{{&op}, device};
216 }
217
218 // At the end, there might be left-over found clusters that need to be
219 // built.
220 for (auto& device_cluster : nearest_clusters)
221 BuildLaunchForCluster(device_cluster.second, &builder);
222 }
223
runOnFunction()224 void ClusterFormationPass::runOnFunction() {
225 OpBuilder builder(getFunction().getContext());
226
227 // Operates on individual blocks independently of if they are directly in the
228 // function body or if they are nested in individual `tf_executor.island`.
229 for (Block& block : getFunction().getBody()) BuildClusters(&block, builder);
230 getFunction().walk([&](tf_executor::IslandOp island) {
231 BuildClusters(&island.GetBody(), builder);
232 });
233 }
234
235 } // namespace
236
CreateClusterFormationPass()237 std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
238 return std::make_unique<ClusterFormationPass>();
239 }
240
241 static PassRegistration<ClusterFormationPass> pass(
242 "tf-device-cluster-formation",
243 "Form clusters from instructions assigned to same device");
244
245 } // namespace TFDevice
246 } // namespace mlir
247