/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"

#define DEBUG_TYPE "tf-resource-device-inference"

namespace mlir {
namespace TF {

namespace {
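// Name of the device attribute recorded on TF ops ("device") and of the
// corresponding attribute recorded on function arguments ("tf.device").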
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

// A pass that propagates device assignment of resources on a module. It
// performs in-function propagation, as well as cross-function propagation from
// callers to callees.
//
// This pass changes the module by adding "tf.device" attribute to function
// arguments and adding "device" attribute to TF ops.
struct ResourceDeviceInference
    : public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "tf-resource-device-inference"; }

  StringRef getDescription() const final {
    return "Propagates the device attribute on resources from callers to "
           "callees.";
  }

  void runOnOperation() override;
};

// A class that records each resource's device assignment in a function.
class PerFunctionResult {
 public:
  explicit PerFunctionResult(
      FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
      : alias_analysis_(alias_analysis) {}

  // Returns the recorded device assignment for a resource, if any.
  Optional<StringRef> DeviceForResource(Value resource) const {
    Optional<StringRef> result;
    if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto it = resource_id_to_device_.find(id);
      if (it == resource_id_to_device_.end()) continue;
      if (!result || result == it->second) {
        result = it->getSecond();
        continue;
      }
      // Got conflicting assignments
      return llvm::None;
    }
    return result;
  }

  // Records the device assignment for a resource. If the new assignment
  // conflicts with an existing one, returns an error.
  //
  // If `changed` is provided, *changed is set to true if anything is modified.
  LogicalResult AddResourceDevice(Value resource, StringRef device,
                                  bool* changed = nullptr) {
    if (alias_analysis_.IsUnknownResource(resource)) return success();
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto emplace_res = resource_id_to_device_.try_emplace(id, device);
      if (emplace_res.second) {
        if (changed) *changed = true;
      } else if (emplace_res.first->getSecond() != device) {
        // Existing assignment does not equal the new assignment.
        return failure();
      }
    }
    return success();
  }

 private:
  llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
  const TF::ResourceAliasAnalysis::Info& alias_analysis_;
};

// Tries to record device assignment for a resource.
LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
                                            Operation* error_reporting_op,
                                            PerFunctionResult* result,
                                            bool* changed = nullptr) {
  auto res = result->AddResourceDevice(resource, device, changed);
  if (failed(res)) {
    error_reporting_op->emitError()
        << "Conflicting device assignment for resource";
  }
  return res;
}

// Extracts and canonicalizes the device attribute.
inline StringRef GetDeviceAttr(FuncOp func, int arg_no) {
  auto device_attr =
      func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
  return device_attr ? device_attr.getValue() : "";
}

// Extracts and canonicalizes the device attribute.
inline StringRef GetDeviceAttr(Operation* op) {
  auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
  return device_attr ? device_attr.getValue() : "";
}

// Print operation with debug info (to get line number info for debugging)
void dump(StringRef message, Operation* op) {
  llvm::dbgs() << message;
  op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
  llvm::dbgs() << "\n";
}

// Propagates device assignment inside a function.
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
                                                  PerFunctionResult* result) {
  OpBuilder builder(func_op);
  // Function arguments.
  for (auto arg : filter_resources(func_op.getArguments())) {
    StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
    if (device_attr.empty()) {
      // If device_attr does not exist, try to construct it from any recorded
      // assignment.
      if (auto device = result->DeviceForResource(arg)) {
        func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
                           builder.getStringAttr(*device));
      }
      continue;
    }
    // Record the attribute.
    auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
    if (failed(res)) return res;
  }

  // To support WhileRegion, we need to propagate device attributes from
  // WhileRegion operands to body/cond region arguments *prior* to visiting
  // these regions. Use tensorflow::GenericWalk() instead of the MLIR core
  // walker to implement such a pre-order walk.
  auto walk_res = tensorflow::GenericWalk(
      func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
        // We just need to visit operations in pre-order mode.
        if (!stage.IsBeforeAllRegions()) return WalkResult::advance();

        if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
          // Record VarHandleOp's device attribute.
          StringRef device_attr = GetDeviceAttr(op);
          if (device_attr.empty()) return WalkResult::advance();
          auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
                                                   device_attr, op, result);
          if (failed(res)) return WalkResult::interrupt();
        } else if (auto identity = dyn_cast<IdentityOp>(op)) {
          LLVM_DEBUG(dump("Visiting ", identity));
          // Try to construct IdentityOp's attribute from recorded assignment.
          if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
          for (auto output : filter_resources(op->getResults())) {
            LLVM_DEBUG(llvm::dbgs() << "  Processing output #"
                                    << output.getResultNumber() << "\n");
            if (auto device = result->DeviceForResource(output)) {
              LLVM_DEBUG(llvm::dbgs()
                         << " Setting device = " << *device << "\n");
              identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
            }
          }
        } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
          // For WhileRegion, do local analysis prior to visiting the attached
          // regions and propagate device annotations to the cond and body
          // region arguments. The annotations are the union of annotations
          // on the input and result. Resource alias analysis already propagates
          // resource IDs from the inputs to the results of a WhileRegion, so we
          // only need to consider the results.
          LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");

          for (auto output : filter_resources(while_region.getResults())) {
            auto device = result->DeviceForResource(output);
            int output_index = output.getResultNumber();
            if (!device) {
              LLVM_DEBUG(llvm::dbgs()
                         << "  No device for output #" << output_index << "\n");
              continue;
            }
            // Transfer the annotation to both region arguments
            for (Region* region : while_region.getRegions()) {
              BlockArgument arg = region->getArgument(output_index);
              LLVM_DEBUG(llvm::dbgs()
                         << "  Propagating device = '" << *device
                         << "' to arg #" << output_index << " of region #"
                         << region->getRegionNumber() << "\n");
              if (failed(AddResourceDeviceAndEmitError(arg, *device,
                                                       while_region, result)))
                return WalkResult::interrupt();
            }
          }
        }
        return WalkResult::advance();
      });
  return failure(walk_res.wasInterrupted());
}

void ResourceDeviceInference::runOnOperation() {
  auto module = getOperation();
  const auto& resource_alias_analysis =
      getAnalysis<TF::ResourceAliasAnalysis>();

  llvm::SmallDenseMap<FuncOp, PerFunctionResult, 4> per_function_results;
  llvm::SetVector<FuncOp> worklist;
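  // Seed the worklist with every function in the module and create its
  // per-function result from the resource alias analysis.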
  for (auto func_op : module.getOps<FuncOp>()) {
    worklist.insert(func_op);
    per_function_results.try_emplace(
        func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
  }
  // Helper that propagates an op's recorded operand device assignments to its
  // called function's arguments.
  auto propagate_operands_to_callee_arguments =
      [&](Operation* caller, Operation::operand_range caller_operands,
          ArrayRef<FuncOp> callees, const PerFunctionResult& caller_res) {
        for (FuncOp callee : callees) {
          assert(callee);
          auto& callee_res = per_function_results.find(callee)->getSecond();
          bool callee_needs_recompute = false;
          for (BlockArgument arg : filter_resources(callee.getArguments())) {
            Value arg_operand = caller_operands[arg.getArgNumber()];
            auto device = caller_res.DeviceForResource(arg_operand);
            if (!device) continue;
            LLVM_DEBUG(llvm::dbgs()
                       << "Propagating '" << *device << "' to arg #"
                       << arg.getArgNumber() << " of function @"
                       << callee.getName() << "\n");
            if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
                                                     &callee_res,
                                                     &callee_needs_recompute)))
              return failure();
          }
          // If the callee recording is modified, make sure that it will be
          // reprocessed.
          if (callee_needs_recompute) worklist.insert(callee);
        }
        return success();
      };

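  // Iterate to a fixed point: a function is reprocessed whenever propagation
  // from one of its callers records a new device assignment for its arguments.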
  while (!worklist.empty()) {
    auto func_op = worklist.pop_back_val();
    auto& func_res = per_function_results.find(func_op)->getSecond();
    // In-function propagation.
    if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
      return signalPassFailure();

    // Propagation to callees.
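    // Functional While/If ops call into their cond/body (then/else) functions;
    // other calls are resolved through CallOpInterface.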
    auto walk_res = func_op.walk([&](Operation* op) {
      if (auto while_op = dyn_cast<WhileOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                while_op, while_op.getOperands(),
                {while_op.body_function(), while_op.cond_function()},
                func_res)))
          return WalkResult::interrupt();
      } else if (auto if_op = dyn_cast<IfOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                if_op, if_op.input(),
                {if_op.then_function(), if_op.else_function()}, func_res)))
          return WalkResult::interrupt();
      } else if (auto call = dyn_cast<CallOpInterface>(op)) {
        auto func = dyn_cast<FuncOp>(call.resolveCallable());
        if (!func) {
          op->emitError(
              "Cannot propagate device attribute to callee: Unable to resolve "
              "call");
          return WalkResult::interrupt();
        }
        LLVM_DEBUG(llvm::dbgs()
                   << "Visiting call to function @" << func.getName() << "\n");
        if (failed(propagate_operands_to_callee_arguments(
                call, call.getArgOperands(), {func}, func_res)))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    if (walk_res.wasInterrupted()) return signalPassFailure();
  }
}

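// Registers the pass with the MLIR pass registry so that it can be referenced
// by its "tf-resource-device-inference" argument.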
PassRegistration<ResourceDeviceInference> pass;

}  // namespace

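// Creates a new instance of the resource device inference pass.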
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
  return std::make_unique<ResourceDeviceInference>();
}

}  // namespace TF
}  // namespace mlir