1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation forms clusters from instructions in same island and
// assigned to same devices. Clusters are represented as regions.
// Note that side-effecting ops are not correctly handled yet.
19
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Block.h" // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/core/platform/logging.h"
34
35 namespace mlir {
36 namespace TFDevice {
37
38 namespace {
39
40 struct ClusterFormationPass
41 : public PassWrapper<ClusterFormationPass, FunctionPass> {
getDependentDialectsmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass42 void getDependentDialects(DialectRegistry& registry) const override {
43 registry.insert<tf_device::TensorFlowDeviceDialect>();
44 }
45
46 void runOnFunction() override;
getArgumentmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass47 StringRef getArgument() const final {
48 // This is the argument used to refer to the pass in
49 // the textual format (on the commandline for example).
50 return "tf-device-cluster-formation";
51 }
getDescriptionmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass52 StringRef getDescription() const final {
53 // This is a brief description of the pass.
54 return "Form clusters from instructions assigned to same device";
55 }
56 };
57
58 // Cluster structure captures all the operations that are assigned to same
59 // device and can form a legal strict cluster.
60 // Ops must follow same ordering in their parent block. We rely on this
61 // assumption to perform analysis.
62 struct Cluster {
63 llvm::SmallVector<Operation*, 4> ops;
64 StringRef device;
65 };
66
GetDevice(Operation * op)67 StringRef GetDevice(Operation* op) {
68 auto device_attr = op->getAttrOfType<StringAttr>("device");
69 return device_attr ? device_attr.getValue() : "";
70 }
71
72 // An op can be merged into cluster if all of its operands are one of the
73 // following:
74 // 1) A block argument
75 // 2) A value produced by other islands
76 // 1) Defined before the cluster
77 // 2) Defined by an operation in the cluster
78 // TODO(ycao): This is not optimal as it doesn't consider the situation of
79 // defining_op's operands all meet the requirements above. In that case, the
80 // defining_op can be moved and to_merge op would be legal to absorb.
81 // TODO(ycao): Take op side-effects into consideration since they can not be
82 // re-ordered but forming clusters of non-continuous ops is effectively
83 // re-ordering them..
CanMergeIntoCluster(const Cluster & c,Operation * to_merge)84 bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
85 return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
86 // Block arguments.
87 if (operand.isa<BlockArgument>()) return true;
88
89 Operation* defining_op = operand.getDefiningOp();
90
91 // Operand produced by other islands.
92 if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
93
94 // Defining op is before the cluster.
95 if (defining_op->isBeforeInBlock(c.ops.front())) return true;
96
97 // Defining op is between first and last operation in cluster. Note that
98 // cluster may contain operations that are non-continuous in their original
99 // block, thus we also need to check defining_op is also assigned to
100 // cluster's device to be sure. This is a faster check than linearly
101 // searching through all ops in cluster.
102 if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
103 GetDevice(defining_op) == c.device)
104 return true;
105
106 // Other cases, operand is generated after or outside the cluster, this
107 // means it is illegal to merge operation.
108 return false;
109 });
110 }
111
ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,tf_device::LaunchOp launch_op)112 void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
113 tf_device::LaunchOp launch_op) {
114 Region* launch_op_region = &launch_op.body();
115 for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
116 Value from = std::get<0>(p);
117 // TODO(jingpu): move this to RegionUtils.h in MLIR core.
118 for (auto& use : llvm::make_early_inc_range(from.getUses())) {
119 if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
120 continue;
121 use.set(std::get<1>(p));
122 }
123 }
124 }
125
126 // Get all escaped live-out values of a region.
GetLiveOuts(Region * region,llvm::SmallVectorImpl<Value> * live_outs)127 void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
128 live_outs->clear();
129
130 for (Operation& op : region->front()) {
131 for (Value v : op.getResults()) {
132 // A value is live-out if any of its users are not inside value producer's
133 // region.
134 bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
135 return !region->isAncestor(user->getParentRegion());
136 });
137
138 if (is_live_out) live_outs->emplace_back(v);
139 }
140 }
141 }
142
143 // Build a `tf_device.launch` op with a region that contains all the operations
144 // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`.
BuildLaunchForCluster(const Cluster & c,OpBuilder * builder)145 void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
146 // Set insertion point to right after all operations in cluster.
147 builder->setInsertionPoint(c.ops.back()->getNextNode());
148
149 // Create a stand-alone region to hold all instructions in the cluster.
150 Region region;
151 region.push_back(new Block);
152
153 // Move all operations in cluster to newly created region, stripping their
154 // "device" attribute since launch op already carries device information.
155 Block* block = ®ion.front();
156 for (Operation* op : c.ops) {
157 op->moveBefore(block, block->end());
158 op->removeAttr(builder->getIdentifier("device"));
159 }
160
161 // Get all escaped live-out values of region, they are used later to determine
162 // return values and types of launch op.
163 llvm::SmallVector<Value, 4> live_outs;
164 GetLiveOuts(®ion, &live_outs);
165
166 // Build a `tf_device.return` op at end of region, with all live-out values
167 // as operand.
168 OpBuilder return_builder(builder->getContext());
169 return_builder.setInsertionPointToEnd(block);
170 return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
171 live_outs);
172
173 llvm::SmallVector<Type, 4> live_out_types;
174 live_out_types.reserve(live_outs.size());
175 for (Value v : live_outs) {
176 live_out_types.emplace_back(v.getType());
177 }
178
179 tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
180 builder->getUnknownLoc(), builder->getStringAttr(c.device),
181 live_out_types);
182
183 // Attach the region to launch_op.
184 launch_op.body().takeBody(region);
185
186 // Replace any external uses of live-out values with return values of launch
187 // op. So live-out values no longer escape the region.
188 ReplaceLiveOutExternalUses(live_outs, launch_op);
189 }
190
BuildClusters(Block * block,OpBuilder builder)191 void BuildClusters(Block* block, OpBuilder builder) {
192 // Iteratively find clusters of different devices within an island.
193 // Whenever we see an operation that is assigned to an accelerator device
194 // (ie. device != ""), we try to merge it into the last cluster of same
195 // device. If that is infeasible (say because of violating def-before-use),
196 // create a new cluster with that operation and move on.
197 llvm::MapVector<StringRef, Cluster> nearest_clusters;
198 for (Operation& op : llvm::make_early_inc_range(*block)) {
199 auto device = GetDevice(&op);
200 if (device.empty()) continue;
201
202 // If no cluster of same device has been formed yet, create a new cluster
203 // with op alone.
204 auto it = nearest_clusters.find(device);
205 if (it == nearest_clusters.end()) {
206 nearest_clusters[device] = Cluster{{&op}, device};
207 continue;
208 }
209
210 // Check if it is legal to merge op into nearest cluster of same device.
211 // If positive, update cluster and move on to next operation.
212 Cluster& nearest_cluster = it->second;
213 if (CanMergeIntoCluster(nearest_cluster, &op)) {
214 nearest_cluster.ops.emplace_back(&op);
215 continue;
216 }
217
218 // If nearest cluster of same device can not absorb `op`, then that
219 // cluster needs to be finalized by building a `tf_device.launch` op with
220 // a region that contains all operations in clusters.
221 BuildLaunchForCluster(nearest_cluster, &builder);
222
223 // Create a new cluster to hold op alone and update nearest_clusters.
224 nearest_clusters[device] = Cluster{{&op}, device};
225 }
226
227 // At the end, there might be left-over found clusters that need to be
228 // built.
229 for (auto& device_cluster : nearest_clusters)
230 BuildLaunchForCluster(device_cluster.second, &builder);
231 }
232
runOnFunction()233 void ClusterFormationPass::runOnFunction() {
234 OpBuilder builder(getFunction().getContext());
235
236 // Operates on individual blocks independently of if they are directly in the
237 // function body or if they are nested in individual `tf_executor.island`.
238 for (Block& block : getFunction().getBody()) BuildClusters(&block, builder);
239 getFunction().walk([&](tf_executor::IslandOp island) {
240 BuildClusters(&island.GetBody(), builder);
241 });
242 }
243
244 } // namespace
245
CreateClusterFormationPass()246 std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
247 return std::make_unique<ClusterFormationPass>();
248 }
249
250 static PassRegistration<ClusterFormationPass> pass;
251
252 } // namespace TFDevice
253 } // namespace mlir
254