• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass clusters the TensorFlow ops by host. The program generated by this
17 // pass will have one function per host where all operations in the same
18 // function are placed on the same host. Each result of the per-host function
19 // will have a "tf.device" attribute which specifies the device assignment of
20 // the result.
21 //
22 // The pass currently assumes that there is no circular dependency among the
23 // per-host functions. For example, if there exists an operation placed on
24 // host_A that consumes the result of an operation placed on host_B, then there
25 does not exist any operation placed on host_B that consumes any result of any
26 // operation placed on host_A.
27 
28 #include "mlir/IR/Builders.h"
29 #include "mlir/Pass/Pass.h"
30 #include "absl/strings/str_cat.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/core/util/device_name_utils.h"
37 
38 namespace mlir {
39 namespace TF {
40 namespace {
41 
42 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
43 using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
44 
45 constexpr const char *kHostAttr = "host";
46 constexpr const char *kDeviceAttr = "device";
47 constexpr const char *kTFDeviceAttr = "tf.device";
48 // TODO(donglin): Handle the case where the address of localhost is different
49 // from /job:localhost/replica:0/task:0.
50 constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
51 constexpr const char *kErrorMessage =
52     "The operation that uses the operand is on a different host than the "
53     "operation that defines the op. This pass does not support cross-host data "
54     "transfer yet";
55 
56 // The host address is identified by the job/replicate/task in the device name.
GetHost(llvm::StringRef device)57 std::string GetHost(llvm::StringRef device) {
58   ParsedName parsed_name;
59   DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
60   std::string result = DeviceNameUtils::ParsedNameToString(
61       DeviceNameUtils::AddressSpace(parsed_name));
62   return result.empty() ? kLocalhost : result;
63 }
64 
GetHost(Operation * op)65 std::string GetHost(Operation *op) {
66   std::string device = "";
67   if (StringAttr attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
68     device = attr.getValue().str();
69   }
70   return GetHost(device);
71 }
72 
73 // The device is considered to be on the localhost iff one of the following is
74 // true:
75 // 1) None of the job/replica/task is specified in the device name.
76 // 2) The job/replica/task in the device name are explicitly specified as
77 //    /job:localhost/replica:0/task:0.
IsOnLocalHost(llvm::StringRef device)78 bool IsOnLocalHost(llvm::StringRef device) {
79   std::string host = GetHost(device);
80   return host == kLocalhost;
81 }
82 
// This structure contains the metadata of the per-host function. All operations
// in this function should be on the same host.
struct FunctionMetadata {
  // The original function name before partition.
  llvm::StringRef original_name;
  // The insertion point of partition functions (the module position right
  // after the original function).
  Block::iterator insertion_point;
  // The partitioned function name, recorded after symbol-table insertion
  // (which may rename on collision).
  llvm::StringRef partition_name;
  // The input values of the function (deduplicated).
  llvm::SmallVector<Value, 4> inputs;
  // The result values of the function (deduplicated).
  llvm::SmallVector<Value, 4> results;
  // The devices of the input values. It should have the same size as inputs.
  llvm::SmallVector<std::string, 4> input_devices;
  // The devices of the result values. It should have the same size as results.
  llvm::SmallVector<std::string, 4> result_devices;
  // The operations to be included in the body of the function.
  llvm::SmallVector<Operation *, 4> ops;

  // NOTE(review): never assigned in this pass — confirm whether it is used
  // elsewhere or is dead.
  FuncOp partition_op;
};
105 
// Returns a map that maps the host address to the metadata of the function
// for that remote host. The metadata of the function specifies the input
// values, result values, result devices and the operations to be included in
// the function body. Returns llvm::None (after emitting an op error) when an
// unsupported cross-host data dependency is found.
llvm::Optional<llvm::StringMap<FunctionMetadata>> GetFunctionMetadatas(
    FuncOp func_op) {
  llvm::StringMap<FunctionMetadata> metadatas;
  WalkResult result = func_op.getBody().walk([&](Operation *op) {
    std::string op_host = GetHost(op);
    // operator[] creates the per-host entry on first access.
    FunctionMetadata &func_metadata = metadatas[op_host];
    func_metadata.original_name = func_op.getName();
    // Partition functions are inserted right after the original function.
    func_metadata.insertion_point = ++Block::iterator(func_op);
    func_metadata.ops.push_back(op);

    for (Value value : op->getOperands()) {
      std::string value_device = "";

      // If the value is defined as an argument of the func_op, adds it to
      // the argument list of the function that uses this op.
      if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
        if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
                block_arg.getArgNumber(), kTFDeviceAttr)) {
          value_device = attr.getValue().str();
        }

        // Cross-host transfer is unsupported: a block argument must live on
        // the same host as the op consuming it.
        if (GetHost(value_device) != op_host) {
          op->emitOpError() << kErrorMessage;
          return WalkResult::interrupt();
        }

        // Deduplicate so each argument value is passed at most once.
        if (llvm::find(func_metadata.inputs, value) ==
            func_metadata.inputs.end()) {
          func_metadata.inputs.push_back(value);
          func_metadata.input_devices.push_back(value_device);
        }
        continue;
      }

      Operation *defining_op = value.getDefiningOp();
      std::string defining_op_host = GetHost(defining_op);
      FunctionMetadata &defining_func_metadata = metadatas[defining_op_host];

      if (StringAttr attr =
              defining_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        value_device = attr.getValue().str();
      }

      // If the value is used as an operand of the terminator op, adds it to
      // the result list of function that defines this op.
      if (op->hasTrait<OpTrait::IsTerminator>()) {
        // Deduplicate so each value is returned at most once.
        if (llvm::find(defining_func_metadata.results, value) ==
            defining_func_metadata.results.end()) {
          defining_func_metadata.results.push_back(value);
          defining_func_metadata.result_devices.push_back(value_device);
        }
        continue;
      }

      // Any other cross-host use of an op result is unsupported.
      if (defining_op_host != op_host) {
        op->emitOpError() << kErrorMessage;
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return llvm::None;

  return metadatas;
}
176 
// Creates functions in the given module using the given FunctionMetadatas.
// One public FuncOp is created per remote host; no function is created for
// the localhost. Records the (possibly renamed) function name back into the
// metadata's partition_name.
void CreateFunctions(ModuleOp module_op,
                     llvm::StringMap<FunctionMetadata> &metadatas) {
  MLIRContext *context = module_op.getContext();
  SymbolTable symbol_table(module_op);
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    FunctionMetadata &metadata = iter.second;

    // Do not create any new function for the operations on the localhost.
    if (IsOnLocalHost(host)) continue;

    llvm::SmallVector<mlir::Type, 4> input_types;
    llvm::SmallVector<mlir::Type, 4> result_types;
    for (Value input : metadata.inputs) {
      input_types.push_back(input.getType());
    }
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }

    // Replaces ':' and '/' with '_' in the host name and uses the resulting
    // string as the function name.
    std::string func_name =
        absl::StrCat(iter.second.original_name.str(), ":", host.str());
    std::replace(func_name.begin(), func_name.end(), ':', '_');
    std::replace(func_name.begin(), func_name.end(), '/', '_');

    FunctionType func_type =
        FunctionType::get(context, input_types, result_types);
    Location loc = metadata.ops.front()->getLoc();
    FuncOp func_op = FuncOp::create(loc, func_name, func_type);
    // Sets the device attribute for every input and every result of the
    // function.
    for (int i : llvm::seq<int>(0, metadata.input_devices.size())) {
      func_op.setArgAttr(i, kTFDeviceAttr,
                         StringAttr::get(context, metadata.input_devices[i]));
    }
    for (int i : llvm::seq<int>(0, metadata.result_devices.size())) {
      func_op.setResultAttr(
          i, kTFDeviceAttr,
          StringAttr::get(context, metadata.result_devices[i]));
    }

    func_op->setAttr(kHostAttr, StringAttr::get(context, host));
    func_op.setPublic();
    Block *block = func_op.addEntryBlock();

    // Clones and moves the operations into the function's body. And the cloned
    // operation should use the arguments of the newly created func_op as
    // appropriate.
    OpBuilder builder(block, block->end());
    BlockAndValueMapping mapping;
    for (int i : llvm::seq<int>(0, metadata.inputs.size())) {
      Value original_value = metadata.inputs[i];
      Value new_value = func_op.getArgument(i);
      mapping.map(original_value, new_value);
    }
    for (Operation *op : metadata.ops) {
      // clone() also extends `mapping` with original-result -> cloned-result
      // entries, so later ops and the ReturnOp below see the cloned values.
      builder.clone(*op, mapping);
    }
    // Creates the ReturnOp so that the per-host function returns the
    // correct values of the cloned operations.
    llvm::SmallVector<Value, 4> results_after_mapping;
    for (Value result : metadata.results) {
      results_after_mapping.push_back(mapping.lookupOrDefault(result));
    }
    builder.create<ReturnOp>(loc, results_after_mapping);
    // Post-increment keeps subsequent partition functions after this one.
    symbol_table.insert(func_op, metadata.insertion_point++);
    // Record the actual name. The symbol table might rename the FuncOp if there
    // is name collision.
    metadata.partition_name = func_op.getName();
  }
}
251 
// Creates a tf_device.remote_run call for every remote function. And replaces
// usages of the results of the original operations with the results of the
// tf_device.remote_run calls.
// NOTE(review): `context` is not referenced in this function — confirm it is
// intentionally kept for signature symmetry.
void CreateRemoteRunCalls(MLIRContext *context,
                          const llvm::StringMap<FunctionMetadata> &metadatas) {
  // Accumulates original-value -> remote_run-result mappings across hosts so
  // that a later call can consume the results of an earlier one.
  BlockAndValueMapping mapping;
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    const FunctionMetadata &metadata = iter.second;

    // Do not create tf_device.remote_run call for the operations already placed
    // on the localhost.
    if (IsOnLocalHost(host)) continue;

    // Creates the tf_device.remote_run operation. The builder is anchored at
    // the last clustered op so the call appears after all replaced ops.
    OpBuilder builder(metadata.ops.back());
    llvm::SmallVector<Type, 4> result_types;
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }
    Location loc = metadata.ops.front()->getLoc();
    llvm::SmallVector<Value, 4> inputs_after_mapping;
    for (Value input : metadata.inputs) {
      inputs_after_mapping.push_back(mapping.lookupOrDefault(input));
    }

    tf_device::RemoteRunOp remote_run_op =
        builder.create<tf_device::RemoteRunOp>(loc, result_types, host,
                                               metadata.partition_name,
                                               inputs_after_mapping);
    // Clones the tf_device.remote_run operation to replace its callee args with
    // the results of the other tf_device.remote_run operations using the
    // `mapping` as appropriate. The un-remapped original is then erased.
    Operation *cloned_remote_run_op =
        builder.clone(*remote_run_op.getOperation(), mapping);
    remote_run_op.erase();

    // Replaces usages of the results of the original operations with the
    // results of the tf_device.remote_run operations.
    for (int i : llvm::seq<int>(0, metadata.results.size())) {
      Value original_value = metadata.results[i];
      Value new_value = cloned_remote_run_op->getResult(i);
      original_value.replaceAllUsesWith(new_value);
      mapping.map(original_value, new_value);
    }
  }
}
299 
// Module pass that partitions each function's ops by host: remote ops are
// moved into new per-host functions invoked via tf_device.remote_run.
class ClusterTFOpsByHostPass
    : public PassWrapper<ClusterTFOpsByHostPass, OperationPass<ModuleOp>> {
  // Command-line argument used to invoke this pass.
  StringRef getArgument() const final { return "cluster-tf-ops-by-host"; }

  StringRef getDescription() const final {
    return "Cluster the TensorFlow ops by host so that each function only "
           "contains ops placed on the same host";
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp module_op = getOperation();
    // Snapshot the existing functions first: CreateFunctions below inserts
    // new FuncOps into the module, and those must not be re-processed.
    SmallVector<FuncOp, 4> original_func;
    for (auto func_op : module_op.getOps<FuncOp>()) {
      original_func.push_back(func_op);
    }
    for (auto func_op : original_func) {
      // Group this function's ops by host; fail the pass on unsupported
      // cross-host data dependencies (an error has already been emitted).
      llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
          GetFunctionMetadatas(func_op);
      if (!metadatas) {
        signalPassFailure();
        return;
      }

      CreateFunctions(module_op, *metadatas);
      CreateRemoteRunCalls(context, *metadatas);

      // Erases the original operations which have been cloned in the remote
      // functions.
      for (auto &iter : *metadatas) {
        llvm::StringRef host = iter.first();
        FunctionMetadata &metadata = iter.second;
        // Do not erase operations placed on the localhost.
        if (IsOnLocalHost(host)) continue;

        // Erase in reverse walk order so users are removed before producers.
        for (int i = metadata.ops.size() - 1; i >= 0; i--) {
          metadata.ops[i]->erase();
        }
      }
    }
  }
};
342 
343 }  // namespace
344 
CreateClusterTFOpsByHostPass()345 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass() {
346   return std::make_unique<ClusterTFOpsByHostPass>();
347 }
348 
// Registers the pass with the global MLIR pass registry so it is available
// to mlir-opt style tools under "cluster-tf-ops-by-host".
static PassRegistration<ClusterTFOpsByHostPass> pass;
350 
351 }  // namespace TF
352 }  // namespace mlir
353