• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <tuple>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/iterator_range.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
36 #include "mlir/IR/Types.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/IR/Visitors.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
46 
47 #define DEBUG_TYPE "tf-resource-device-inference"
48 
49 namespace mlir {
50 namespace TF {
51 
52 namespace {
// Name of the string attribute on function arguments that holds the device
// of a resource argument.
constexpr char kFuncDeviceAttr[] = "tf.device";
// Name of the string attribute on TF ops that holds the op's device.
constexpr char kDeviceAttr[] = "device";
55 
56 // A pass that propagates device assignment of resources on a module. It
57 // performs in-function propagation, as well as cross-function propagation from
58 // callers to callees.
59 //
60 // This pass changes the module by adding "tf.device" attribute to function
61 // arguments and adding "device" attribute to TF ops.
62 struct ResourceDeviceInference
63     : public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
64   void runOnOperation() override;
65 };
66 
67 // A class that records each resource's device assignment in a function.
68 class PerFunctionResult {
69  public:
PerFunctionResult(FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)70   explicit PerFunctionResult(
71       FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
72       : alias_analysis_(alias_analysis) {}
73 
74   // Returns the recorded device assignment for a resource, if any.
DeviceForResource(Value resource) const75   Optional<StringRef> DeviceForResource(Value resource) const {
76     Optional<StringRef> result;
77     if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
78     for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
79       auto it = resource_id_to_device_.find(id);
80       if (it == resource_id_to_device_.end()) continue;
81       if (!result || result == it->second) {
82         result = it->getSecond();
83         continue;
84       }
85       // Got conflicting assignments
86       return llvm::None;
87     }
88     return result;
89   }
90 
91   // Records the device assignment for a resource. If the new assignment
92   // conflicts with an existing one, returns an error.
93   //
94   // If `changed` is provided, assign *changed to true if anything is modified.
AddResourceDevice(Value resource,StringRef device,bool * changed=nullptr)95   LogicalResult AddResourceDevice(Value resource, StringRef device,
96                                   bool* changed = nullptr) {
97     if (alias_analysis_.IsUnknownResource(resource)) return success();
98     for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
99       auto emplace_res = resource_id_to_device_.try_emplace(id, device);
100       if (emplace_res.second) {
101         if (changed) *changed = true;
102       } else if (emplace_res.first->getSecond() != device) {
103         // Existing assignment does not equal the new assignment.
104         return failure();
105       }
106     }
107     return success();
108   }
109 
110  private:
111   llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
112   const TF::ResourceAliasAnalysis::Info& alias_analysis_;
113 };
114 
115 // Tries to record device assignment for a resource.
AddResourceDeviceAndEmitError(Value resource,StringRef device,Operation * error_reporting_op,PerFunctionResult * result,bool * changed=nullptr)116 LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
117                                             Operation* error_reporting_op,
118                                             PerFunctionResult* result,
119                                             bool* changed = nullptr) {
120   auto res = result->AddResourceDevice(resource, device, changed);
121   if (failed(res)) {
122     error_reporting_op->emitError()
123         << "Conflicting device assignment for resource";
124   }
125   return res;
126 }
127 
128 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(FuncOp func,int arg_no)129 inline StringRef GetDeviceAttr(FuncOp func, int arg_no) {
130   auto device_attr =
131       func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
132   return device_attr ? device_attr.getValue() : "";
133 }
134 
135 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(Operation * op)136 inline StringRef GetDeviceAttr(Operation* op) {
137   auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
138   return device_attr ? device_attr.getValue() : "";
139 }
140 
141 // Print operation with debug info (to get line number info for debugging)
dump(StringRef message,Operation * op)142 void dump(StringRef message, Operation* op) {
143   llvm::dbgs() << message;
144   op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
145   llvm::dbgs() << "\n";
146 }
147 
148 // Propagates device assignment inside a function.
ComputeResourceDevicesInComputation(FuncOp func_op,PerFunctionResult * result)149 LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
150                                                   PerFunctionResult* result) {
151   OpBuilder builder(func_op);
152   // Function arguments.
153   for (auto arg : filter_resources(func_op.getArguments())) {
154     StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
155     if (device_attr.empty()) {
156       // If device_attr does not exist, try to construct it from any recorded
157       // assignment.
158       if (auto device = result->DeviceForResource(arg)) {
159         func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
160                            builder.getStringAttr(*device));
161       }
162       continue;
163     }
164     // Record the attribute.
165     auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
166     if (failed(res)) return res;
167   }
168 
169   // To support WhileRegion, we need to propagate device attributes from
170   // WhileRegion operands to body/cond region arguments *prior* to visiting
171   // these regions. Use tensorflow::walk() instead of MLIR core walker to
172   // implement such a pre-order walk.
173   auto walk_res = tensorflow::GenericWalk(
174       func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
175         // We just need to visit operations in pre-order mode.
176         if (!stage.IsBeforeAllRegions()) return WalkResult::advance();
177 
178         if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
179           // Record VarHandleOp's device attribute.
180           StringRef device_attr = GetDeviceAttr(op);
181           if (device_attr.empty()) return WalkResult::advance();
182           auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
183                                                    device_attr, op, result);
184           if (failed(res)) return WalkResult::interrupt();
185         } else if (auto identity = dyn_cast<IdentityOp>(op)) {
186           LLVM_DEBUG(dump("Visiting ", identity));
187           // Try to construct IdentityOp's attribute from recorded assignment.
188           if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
189           for (auto output : filter_resources(op->getResults())) {
190             LLVM_DEBUG(llvm::dbgs() << "  Processing output #"
191                                     << output.getResultNumber() << "\n");
192             if (auto device = result->DeviceForResource(output)) {
193               LLVM_DEBUG(llvm::dbgs()
194                          << " Setting device = " << *device << "\n");
195               identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
196             }
197           }
198         } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
199           // For WhileRegion, do local analysis prior to visiting the attached
200           // regions and propagate device annotations to the cond and body
201           // region arguments. The annotations are the union of annotations
202           // on the input and result. Resource alias analysis already propagates
203           // resource ID from the inputs to the results for a while, so just
204           // need to consider the results.
205           LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
206 
207           for (auto output : filter_resources(while_region.getResults())) {
208             auto device = result->DeviceForResource(output);
209             int output_index = output.getResultNumber();
210             if (!device) {
211               LLVM_DEBUG(llvm::dbgs()
212                          << "  No device for output #" << output_index << "\n");
213               continue;
214             }
215             // Transfer the annotation to both region arguments
216             for (Region* region : while_region.getRegions()) {
217               BlockArgument arg = region->getArgument(output_index);
218               LLVM_DEBUG(llvm::dbgs()
219                          << "  Propagating device = '" << *device
220                          << "' to arg #" << output_index << " of region #"
221                          << region->getRegionNumber() << "\n");
222               if (failed(AddResourceDeviceAndEmitError(arg, *device,
223                                                        while_region, result)))
224                 return WalkResult::interrupt();
225             }
226           }
227         }
228         return WalkResult::advance();
229       });
230   return failure(walk_res.wasInterrupted());
231 }
232 
runOnOperation()233 void ResourceDeviceInference::runOnOperation() {
234   auto module = getOperation();
235   const auto& resource_alias_analysis =
236       getAnalysis<TF::ResourceAliasAnalysis>();
237 
238   llvm::SmallDenseMap<FuncOp, PerFunctionResult, 4> per_function_results;
239   llvm::SetVector<FuncOp> worklist;
240   for (auto func_op : module.getOps<FuncOp>()) {
241     worklist.insert(func_op);
242     per_function_results.try_emplace(
243         func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
244   }
245   // Helper that propagates an op's recorded operand device assignments to its
246   // called function's arguments.
247   auto propagate_operands_to_callee_arguments =
248       [&](Operation* caller, Operation::operand_range caller_operands,
249           ArrayRef<FuncOp> callees, const PerFunctionResult& caller_res) {
250         for (FuncOp callee : callees) {
251           assert(callee);
252           auto& callee_res = per_function_results.find(callee)->getSecond();
253           bool callee_needs_recompute = false;
254           for (BlockArgument arg : filter_resources(callee.getArguments())) {
255             Value arg_operand = caller_operands[arg.getArgNumber()];
256             auto device = caller_res.DeviceForResource(arg_operand);
257             if (!device) continue;
258             LLVM_DEBUG(llvm::dbgs()
259                        << "Propagating '" << *device << "' to arg #"
260                        << arg.getArgNumber() << " of function @"
261                        << callee.getName() << "\n");
262             if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
263                                                      &callee_res,
264                                                      &callee_needs_recompute)))
265               return failure();
266           }
267           // If the callee recording is modified, make sure that it will be
268           // reprocessed.
269           if (callee_needs_recompute) worklist.insert(callee);
270         }
271         return success();
272       };
273 
274   while (!worklist.empty()) {
275     auto func_op = worklist.pop_back_val();
276     auto& func_res = per_function_results.find(func_op)->getSecond();
277     // In-function propagation.
278     if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
279       return signalPassFailure();
280 
281     // Propagation to callees.
282     auto walk_res = func_op.walk([&](Operation* op) {
283       if (auto while_op = dyn_cast<WhileOp>(op)) {
284         if (failed(propagate_operands_to_callee_arguments(
285                 while_op, while_op.getOperands(),
286                 {while_op.body_function(), while_op.cond_function()},
287                 func_res)))
288           return WalkResult::interrupt();
289       } else if (auto if_op = dyn_cast<IfOp>(op)) {
290         if (failed(propagate_operands_to_callee_arguments(
291                 if_op, if_op.input(),
292                 {if_op.then_function(), if_op.else_function()}, func_res)))
293           return WalkResult::interrupt();
294       } else if (auto call = dyn_cast<CallOpInterface>(op)) {
295         auto func = dyn_cast<FuncOp>(call.resolveCallable());
296         if (!func) {
297           op->emitError(
298               "Cannot propagate device attribute to callee: Unable to resolve "
299               "call");
300           return WalkResult::interrupt();
301         }
302         LLVM_DEBUG(llvm::dbgs()
303                    << "Visiting call to function @" << func.getName() << "\n");
304         if (failed(propagate_operands_to_callee_arguments(
305                 call, call.getArgOperands(), {func}, func_res)))
306           return WalkResult::interrupt();
307       }
308       return WalkResult::advance();
309     });
310     if (walk_res.wasInterrupted()) return signalPassFailure();
311   }
312 }
313 
314 PassRegistration<ResourceDeviceInference> pass(
315     "tf-resource-device-inference",
316     "Propagates the device attribute on resources from callers to callees.");
317 
318 }  // namespace
319 
CreateResourceDeviceInferencePass()320 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
321   return std::make_unique<ResourceDeviceInference>();
322 }
323 
324 }  // namespace TF
325 }  // namespace mlir
326