/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation forms clusters from instructions in the same island
// that are assigned to the same device. Clusters are represented as regions.
// Note that side-effecting ops are not correctly handled yet.
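//
// As a rough illustration (hand-written pseudo-IR with made-up op names, not
// the exact printed form), two adjacent ops assigned to the same device
//
//   %a = "tf.OpA"(%arg0) {device = "/device:GPU:0"}
//   %b = "tf.OpB"(%a) {device = "/device:GPU:0"}
//
// are wrapped into a single `tf_device.launch` region whose attribute carries
// the device, with the per-op "device" attributes stripped and the live-out
// value returned through `tf_device.return`:
//
//   %launch = "tf_device.launch"() ({
//     %a = "tf.OpA"(%arg0)
//     %b = "tf.OpB"(%a)
//     "tf_device.return"(%b)
//   }) {device = "/device:GPU:0"}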

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace TFDevice {

namespace {

struct ClusterFormationPass
    : public PassWrapper<ClusterFormationPass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tf_device::TensorFlowDeviceDialect>();
  }

  void runOnFunction() override;
};

// A Cluster structure captures all the operations that are assigned to the
// same device and can form a legal strict cluster.
// Ops must follow the same ordering as in their parent block. We rely on this
// assumption to perform the analysis.
struct Cluster {
  llvm::SmallVector<Operation*, 4> ops;
  StringRef device;
};

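// Returns the device the given op is assigned to via its "device" attribute,
// or an empty string if the attribute is not set.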
StringRef GetDevice(Operation* op) {
  auto device_attr = op->getAttrOfType<StringAttr>("device");
  return device_attr ? device_attr.getValue() : "";
}

// An op can be merged into a cluster if all of its operands are one of the
// following:
//  1) A block argument
//  2) A value produced by other islands
//  3) Defined before the cluster
//  4) Defined by an operation in the cluster
// TODO(ycao): This is not optimal as it doesn't consider the situation where
// all of defining_op's operands meet the requirements above. In that case,
// the defining_op can be moved and the to_merge op would be legal to absorb.
// TODO(ycao): Take op side-effects into consideration since they can not be
// re-ordered, but forming clusters of non-contiguous ops effectively
// re-orders them.
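//
// For illustration (made-up ops; cluster c = {A, B} on device d):
//   %x = A()     // in the cluster
//   %y = C()     // between A and B in the block, but not assigned to d
//   %z = B(%x)   // in the cluster
// An op D(%x) can be merged, since %x is defined by an op in the cluster. An
// op E(%y) cannot: %y's defining op lies between the cluster's first and last
// op but is not on the cluster's device, so it is not part of the cluster.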
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
  return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
    // Block arguments.
    if (operand.isa<BlockArgument>()) return true;

    Operation* defining_op = operand.getDefiningOp();

    // Operand produced by other islands.
    if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;

    // Defining op is before the cluster.
    if (defining_op->isBeforeInBlock(c.ops.front())) return true;

    // Defining op is between the first and last operation in the cluster. Note
    // that the cluster may contain operations that are non-contiguous in their
    // original block, so we also need to check that defining_op is assigned to
    // the cluster's device to be sure. This is a faster check than linearly
    // searching through all ops in the cluster.
    if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
        GetDevice(defining_op) == c.device)
      return true;

    // In all other cases the operand is generated after or outside the
    // cluster, which means it is illegal to merge the operation.
    return false;
  });
}

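// Replaces every use of a live-out value that lies outside `launch_op`'s
// region with the corresponding result of `launch_op`.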
void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
                                tf_device::LaunchOp launch_op) {
  Region* launch_op_region = &launch_op.body();
  for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
    Value from = std::get<0>(p);
    // TODO(jingpu): move this to RegionUtils.h in MLIR core.
    for (auto& use : llvm::make_early_inc_range(from.getUses())) {
      if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
        continue;
      use.set(std::get<1>(p));
    }
  }
}

// Get all escaped live-out values of a region.
void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
  live_outs->clear();

  for (Operation& op : region->front()) {
    for (Value v : op.getResults()) {
      // A value is live-out if any of its users are not inside the value
      // producer's region.
      bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
        return !region->isAncestor(user->getParentRegion());
      });

      if (is_live_out) live_outs->emplace_back(v);
    }
  }
}

// Build a `tf_device.launch` op with a region that contains all the operations
// in the given cluster. All ops in the cluster are then replaced by the
// `tf_device.launch` op.
void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
  // Set the insertion point to right after all operations in the cluster.
  builder->setInsertionPoint(c.ops.back()->getNextNode());

  // Create a stand-alone region to hold all instructions in the cluster.
  Region region;
  region.push_back(new Block);

  // Move all operations in the cluster to the newly created region, stripping
  // their "device" attribute since the launch op already carries the device
  // information.
  Block* block = &region.front();
  for (Operation* op : c.ops) {
    op->moveBefore(block, block->end());
    op->removeAttr(builder->getIdentifier("device"));
  }

  // Get all escaped live-out values of the region; they are used later to
  // determine the return values and types of the launch op.
  llvm::SmallVector<Value, 4> live_outs;
  GetLiveOuts(&region, &live_outs);

  // Build a `tf_device.return` op at the end of the region, with all live-out
  // values as operands.
  OpBuilder return_builder(builder->getContext());
  return_builder.setInsertionPointToEnd(block);
  return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
                                             live_outs);

  llvm::SmallVector<Type, 4> live_out_types;
  live_out_types.reserve(live_outs.size());
  for (Value v : live_outs) {
    live_out_types.emplace_back(v.getType());
  }

  tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
      builder->getUnknownLoc(), builder->getStringAttr(c.device),
      live_out_types);

  // Attach the region to launch_op.
  launch_op.body().takeBody(region);

  // Replace any external uses of live-out values with the return values of the
  // launch op, so that live-out values no longer escape the region.
  ReplaceLiveOutExternalUses(live_outs, launch_op);
}

void BuildClusters(Block* block, OpBuilder builder) {
  // Iteratively find clusters of different devices within an island.
  // Whenever we see an operation that is assigned to an accelerator device
  // (i.e. device != ""), we try to merge it into the last cluster of the same
  // device. If that is infeasible (say because it would violate
  // def-before-use), create a new cluster with that operation and move on.
  llvm::MapVector<StringRef, Cluster> nearest_clusters;
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    auto device = GetDevice(&op);
    if (device == "") continue;

    // If no cluster of the same device has been formed yet, create a new
    // cluster with op alone.
    auto it = nearest_clusters.find(device);
    if (it == nearest_clusters.end()) {
      nearest_clusters[device] = Cluster{{&op}, device};
      continue;
    }

    // Check if it is legal to merge op into the nearest cluster of the same
    // device. If so, update the cluster and move on to the next operation.
    Cluster& nearest_cluster = it->second;
    if (CanMergeIntoCluster(nearest_cluster, &op)) {
      nearest_cluster.ops.emplace_back(&op);
      continue;
    }

    // If the nearest cluster of the same device cannot absorb `op`, then that
    // cluster needs to be finalized by building a `tf_device.launch` op with
    // a region that contains all operations in the cluster.
    BuildLaunchForCluster(nearest_cluster, &builder);

    // Create a new cluster to hold op alone and update nearest_clusters.
    nearest_clusters[device] = Cluster{{&op}, device};
  }

  // At the end, there might be left-over clusters that still need to be
  // built.
  for (auto& device_cluster : nearest_clusters)
    BuildLaunchForCluster(device_cluster.second, &builder);
}

void ClusterFormationPass::runOnFunction() {
  OpBuilder builder(getFunction().getContext());

  // Operate on individual blocks independently of whether they are directly in
  // the function body or nested in individual `tf_executor.island` ops.
  for (Block& block : getFunction().getBody()) BuildClusters(&block, builder);
  getFunction().walk([&](tf_executor::IslandOp island) {
    BuildClusters(&island.GetBody(), builder);
  });
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass() {
  return std::make_unique<ClusterFormationPass>();
}

static PassRegistration<ClusterFormationPass> pass(
    "tf-device-cluster-formation",
    "Form clusters from instructions assigned to same device");

}  // namespace TFDevice
}  // namespace mlir