• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation forms clusters from instructions in the same island
// that are assigned to the same device. Clusters are represented as regions.
// Note that side-effecting ops are not correctly handled yet.
19 
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Block.h"  // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace mlir {
36 namespace TFDevice {
37 
38 namespace {
39 
40 struct ClusterFormationPass
41     : public PassWrapper<ClusterFormationPass, FunctionPass> {
getDependentDialectsmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass42   void getDependentDialects(DialectRegistry& registry) const override {
43     registry.insert<tf_device::TensorFlowDeviceDialect>();
44   }
45 
46   void runOnFunction() override;
getArgumentmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass47   StringRef getArgument() const final {
48     // This is the argument used to refer to the pass in
49     // the textual format (on the commandline for example).
50     return "tf-device-cluster-formation";
51   }
getDescriptionmlir::TFDevice::__anon2117f2d20111::ClusterFormationPass52   StringRef getDescription() const final {
53     // This is a brief description of the pass.
54     return "Form clusters from instructions assigned to same device";
55   }
56 };
57 
58 // Cluster structure captures all the operations that are assigned to same
59 // device and can form a legal strict cluster.
60 // Ops must follow same ordering in their parent block. We rely on this
61 // assumption to perform analysis.
62 struct Cluster {
63   llvm::SmallVector<Operation*, 4> ops;
64   StringRef device;
65 };
66 
GetDevice(Operation * op)67 StringRef GetDevice(Operation* op) {
68   auto device_attr = op->getAttrOfType<StringAttr>("device");
69   return device_attr ? device_attr.getValue() : "";
70 }
71 
72 // An op can be merged into cluster if all of its operands are one of the
73 // following:
74 //  1) A block argument
75 //  2) A value produced by other islands
76 //  1) Defined before the cluster
77 //  2) Defined by an operation in the cluster
78 // TODO(ycao): This is not optimal as it doesn't consider the situation of
79 // defining_op's operands all meet the requirements above. In that case, the
80 // defining_op can be moved and to_merge op would be legal to absorb.
81 // TODO(ycao): Take op side-effects into consideration since they can not be
82 // re-ordered but forming clusters of non-continuous ops is effectively
83 // re-ordering them..
CanMergeIntoCluster(const Cluster & c,Operation * to_merge)84 bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
85   return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
86     // Block arguments.
87     if (operand.isa<BlockArgument>()) return true;
88 
89     Operation* defining_op = operand.getDefiningOp();
90 
91     // Operand produced by other islands.
92     if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
93 
94     // Defining op is before the cluster.
95     if (defining_op->isBeforeInBlock(c.ops.front())) return true;
96 
97     // Defining op is between first and last operation in cluster. Note that
98     // cluster may contain operations that are non-continuous in their original
99     // block, thus we also need to check defining_op is also assigned to
100     // cluster's device to be sure. This is a faster check than linearly
101     // searching through all ops in cluster.
102     if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
103         GetDevice(defining_op) == c.device)
104       return true;
105 
106     // Other cases, operand is generated after or outside the cluster, this
107     // means it is illegal to merge operation.
108     return false;
109   });
110 }
111 
ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,tf_device::LaunchOp launch_op)112 void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
113                                 tf_device::LaunchOp launch_op) {
114   Region* launch_op_region = &launch_op.body();
115   for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
116     Value from = std::get<0>(p);
117     // TODO(jingpu): move this to RegionUtils.h in MLIR core.
118     for (auto& use : llvm::make_early_inc_range(from.getUses())) {
119       if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
120         continue;
121       use.set(std::get<1>(p));
122     }
123   }
124 }
125 
126 // Get all escaped live-out values of a region.
GetLiveOuts(Region * region,llvm::SmallVectorImpl<Value> * live_outs)127 void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
128   live_outs->clear();
129 
130   for (Operation& op : region->front()) {
131     for (Value v : op.getResults()) {
132       // A value is live-out if any of its users are not inside value producer's
133       // region.
134       bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
135         return !region->isAncestor(user->getParentRegion());
136       });
137 
138       if (is_live_out) live_outs->emplace_back(v);
139     }
140   }
141 }
142 
143 // Build a `tf_device.launch` op with a region that contains all the operations
144 // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`.
BuildLaunchForCluster(const Cluster & c,OpBuilder * builder)145 void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
146   // Set insertion point to right after all operations in cluster.
147   builder->setInsertionPoint(c.ops.back()->getNextNode());
148 
149   // Create a stand-alone region to hold all instructions in the cluster.
150   Region region;
151   region.push_back(new Block);
152 
153   // Move all operations in cluster to newly created region, stripping their
154   // "device" attribute since launch op already carries device information.
155   Block* block = &region.front();
156   for (Operation* op : c.ops) {
157     op->moveBefore(block, block->end());
158     op->removeAttr(builder->getIdentifier("device"));
159   }
160 
161   // Get all escaped live-out values of region, they are used later to determine
162   // return values and types of launch op.
163   llvm::SmallVector<Value, 4> live_outs;
164   GetLiveOuts(&region, &live_outs);
165 
166   // Build a `tf_device.return` op at end of region, with all live-out values
167   // as operand.
168   OpBuilder return_builder(builder->getContext());
169   return_builder.setInsertionPointToEnd(block);
170   return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
171                                              live_outs);
172 
173   llvm::SmallVector<Type, 4> live_out_types;
174   live_out_types.reserve(live_outs.size());
175   for (Value v : live_outs) {
176     live_out_types.emplace_back(v.getType());
177   }
178 
179   tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
180       builder->getUnknownLoc(), builder->getStringAttr(c.device),
181       live_out_types);
182 
183   // Attach the region to launch_op.
184   launch_op.body().takeBody(region);
185 
186   // Replace any external uses of live-out values with return values of launch
187   // op. So live-out values no longer escape the region.
188   ReplaceLiveOutExternalUses(live_outs, launch_op);
189 }
190 
BuildClusters(Block * block,OpBuilder builder)191 void BuildClusters(Block* block, OpBuilder builder) {
192   // Iteratively find clusters of different devices within an island.
193   // Whenever we see an operation that is assigned to an accelerator device
194   // (ie. device != ""), we try to merge it into the last cluster of same
195   // device. If that is infeasible (say because of violating def-before-use),
196   // create a new cluster with that operation and move on.
197   llvm::MapVector<StringRef, Cluster> nearest_clusters;
198   for (Operation& op : llvm::make_early_inc_range(*block)) {
199     auto device = GetDevice(&op);
200     if (device.empty()) continue;
201 
202     // If no cluster of same device has been formed yet, create a new cluster
203     // with op alone.
204     auto it = nearest_clusters.find(device);
205     if (it == nearest_clusters.end()) {
206       nearest_clusters[device] = Cluster{{&op}, device};
207       continue;
208     }
209 
210     // Check if it is legal to merge op into nearest cluster of same device.
211     // If positive, update cluster and move on to next operation.
212     Cluster& nearest_cluster = it->second;
213     if (CanMergeIntoCluster(nearest_cluster, &op)) {
214       nearest_cluster.ops.emplace_back(&op);
215       continue;
216     }
217 
218     // If nearest cluster of same device can not absorb `op`, then that
219     // cluster needs to be finalized by building a `tf_device.launch` op with
220     // a region that contains all operations in clusters.
221     BuildLaunchForCluster(nearest_cluster, &builder);
222 
223     // Create a new cluster to hold op alone and update nearest_clusters.
224     nearest_clusters[device] = Cluster{{&op}, device};
225   }
226 
227   // At the end, there might be left-over found clusters that need to be
228   // built.
229   for (auto& device_cluster : nearest_clusters)
230     BuildLaunchForCluster(device_cluster.second, &builder);
231 }
232 
runOnFunction()233 void ClusterFormationPass::runOnFunction() {
234   OpBuilder builder(getFunction().getContext());
235 
236   // Operates on individual blocks independently of if they are directly in the
237   // function body or if they are nested in individual `tf_executor.island`.
238   for (Block& block : getFunction().getBody()) BuildClusters(&block, builder);
239   getFunction().walk([&](tf_executor::IslandOp island) {
240     BuildClusters(&island.GetBody(), builder);
241   });
242 }
243 
244 }  // namespace
245 
CreateClusterFormationPass()246 std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
247   return std::make_unique<ClusterFormationPass>();
248 }
249 
250 static PassRegistration<ClusterFormationPass> pass;
251 
252 }  // namespace TFDevice
253 }  // namespace mlir
254