• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <tuple>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/iterator_range.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Builders.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/IR/Visitors.h"  // from @llvm-project
40 #include "mlir/Pass/Pass.h"  // from @llvm-project
41 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
42 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
47 
48 #define DEBUG_TYPE "tf-resource-device-inference"
49 
50 namespace mlir {
51 namespace TF {
52 
53 namespace {
// Attribute names used by this pass: "device" is read from and written to
// ops, while "tf.device" annotates function arguments.
constexpr char kDeviceAttr[]{"device"};
constexpr char kFuncDeviceAttr[]{"tf.device"};
56 
57 // A pass that propagates device assignment of resources on a module. It
58 // performs in-function propagation, as well as cross-function propagation from
59 // callers to callees.
60 //
61 // This pass changes the module by adding "tf.device" attribute to function
62 // arguments and adding "device" attribute to TF ops.
63 struct ResourceDeviceInference
64     : public ResourceDeviceInferencePassBase<ResourceDeviceInference> {
65   void runOnOperation() override;
66 };
67 
68 // A class that records each resource's device assignment in a function.
69 class PerFunctionResult {
70  public:
PerFunctionResult(func::FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)71   explicit PerFunctionResult(
72       func::FuncOp func_op,
73       const TF::ResourceAliasAnalysis::Info& alias_analysis)
74       : alias_analysis_(alias_analysis) {}
75 
76   // Returns the recorded device assignment for a resource, if any.
DeviceForResource(Value resource) const77   Optional<StringRef> DeviceForResource(Value resource) const {
78     Optional<StringRef> result;
79     if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
80     for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
81       auto it = resource_id_to_device_.find(id);
82       if (it == resource_id_to_device_.end()) continue;
83       if (!result || result == it->second) {
84         result = it->getSecond();
85         continue;
86       }
87       // Got conflicting assignments
88       return llvm::None;
89     }
90     return result;
91   }
92 
93   // Records the device assignment for a resource. If the new assignment
94   // conflicts with an existing one, returns an error.
95   //
96   // If `changed` is provided, assign *changed to true if anything is modified.
AddResourceDevice(Value resource,StringRef device,bool * changed=nullptr)97   LogicalResult AddResourceDevice(Value resource, StringRef device,
98                                   bool* changed = nullptr) {
99     if (alias_analysis_.IsUnknownResource(resource)) return success();
100     for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
101       auto emplace_res = resource_id_to_device_.try_emplace(id, device);
102       if (emplace_res.second) {
103         if (changed) *changed = true;
104       } else if (emplace_res.first->getSecond() != device) {
105         // Existing assignment does not equal the new assignment.
106         return failure();
107       }
108     }
109     return success();
110   }
111 
112  private:
113   llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
114   const TF::ResourceAliasAnalysis::Info& alias_analysis_;
115 };
116 
117 // Tries to record device assignment for a resource.
AddResourceDeviceAndEmitError(Value resource,StringRef device,Operation * error_reporting_op,PerFunctionResult * result,bool * changed=nullptr)118 LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
119                                             Operation* error_reporting_op,
120                                             PerFunctionResult* result,
121                                             bool* changed = nullptr) {
122   auto res = result->AddResourceDevice(resource, device, changed);
123   if (failed(res)) {
124     error_reporting_op->emitError()
125         << "Conflicting device assignment for resource";
126   }
127   return res;
128 }
129 
130 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(func::FuncOp func,int arg_no)131 inline StringRef GetDeviceAttr(func::FuncOp func, int arg_no) {
132   auto device_attr =
133       func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
134   return device_attr ? device_attr.getValue() : "";
135 }
136 
137 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(Operation * op)138 inline StringRef GetDeviceAttr(Operation* op) {
139   auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
140   return device_attr ? device_attr.getValue() : "";
141 }
142 
143 // Print operation with debug info (to get line number info for debugging)
dump(StringRef message,Operation * op)144 void dump(StringRef message, Operation* op) {
145   llvm::dbgs() << message;
146   op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
147   llvm::dbgs() << "\n";
148 }
149 
150 // Propagates device assignment inside a function.
ComputeResourceDevicesInComputation(func::FuncOp func_op,PerFunctionResult * result)151 LogicalResult ComputeResourceDevicesInComputation(func::FuncOp func_op,
152                                                   PerFunctionResult* result) {
153   OpBuilder builder(func_op);
154   // Function arguments.
155   for (auto arg : filter_resources(func_op.getArguments())) {
156     StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
157     if (device_attr.empty()) {
158       // If device_attr does not exist, try to construct it from any recorded
159       // assignment.
160       if (auto device = result->DeviceForResource(arg)) {
161         func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
162                            builder.getStringAttr(*device));
163       }
164       continue;
165     }
166     // Record the attribute.
167     auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
168     if (failed(res)) return res;
169   }
170 
171   // To support WhileRegion, we need to propagate device attributes from
172   // WhileRegion operands to body/cond region arguments *prior* to visiting
173   // these regions, so use a pre-order walk.
174   WalkResult walk_res = func_op.walk<WalkOrder::PreOrder>([&](Operation* op) {
175     if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
176       // Record VarHandleOp's device attribute.
177       StringRef device_attr = GetDeviceAttr(op);
178       if (device_attr.empty()) return WalkResult::advance();
179       auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
180                                                device_attr, op, result);
181       if (failed(res)) return WalkResult::interrupt();
182     } else if (auto identity = dyn_cast<IdentityOp>(op)) {
183       LLVM_DEBUG(dump("Visiting ", identity));
184       // Try to construct IdentityOp's attribute from recorded assignment.
185       if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
186       for (auto output : filter_resources(op->getResults())) {
187         LLVM_DEBUG(llvm::dbgs() << "  Processing output #"
188                                 << output.getResultNumber() << "\n");
189         if (auto device = result->DeviceForResource(output)) {
190           LLVM_DEBUG(llvm::dbgs() << " Setting device = " << *device << "\n");
191           identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
192         }
193       }
194     } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
195       // For WhileRegion, do local analysis prior to visiting the attached
196       // regions and propagate device annotations to the cond and body
197       // region arguments. The annotations are the union of annotations
198       // on the input and result. Resource alias analysis already propagates
199       // resource ID from the inputs to the results for a while, so just
200       // need to consider the results.
201       LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
202 
203       for (auto output : filter_resources(while_region.getResults())) {
204         auto device = result->DeviceForResource(output);
205         int output_index = output.getResultNumber();
206         if (!device) {
207           LLVM_DEBUG(llvm::dbgs()
208                      << "  No device for output #" << output_index << "\n");
209           continue;
210         }
211         // Transfer the annotation to both region arguments
212         for (Region* region : while_region.getRegions()) {
213           BlockArgument arg = region->getArgument(output_index);
214           LLVM_DEBUG(llvm::dbgs()
215                      << "  Propagating device = '" << *device << "' to arg #"
216                      << output_index << " of region #"
217                      << region->getRegionNumber() << "\n");
218           if (failed(AddResourceDeviceAndEmitError(arg, *device, while_region,
219                                                    result)))
220             return WalkResult::interrupt();
221         }
222       }
223     }
224     return WalkResult::advance();
225   });
226   return failure(walk_res.wasInterrupted());
227 }
228 
runOnOperation()229 void ResourceDeviceInference::runOnOperation() {
230   auto module = getOperation();
231   const auto& resource_alias_analysis =
232       getAnalysis<TF::ResourceAliasAnalysis>();
233 
234   llvm::SmallDenseMap<func::FuncOp, PerFunctionResult, 4> per_function_results;
235   llvm::SetVector<func::FuncOp> worklist;
236   for (auto func_op : module.getOps<func::FuncOp>()) {
237     worklist.insert(func_op);
238     per_function_results.try_emplace(
239         func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
240   }
241   // Helper that propagates an op's recorded operand device assignments to its
242   // called function's arguments.
243   auto propagate_operands_to_callee_arguments =
244       [&](Operation* caller, Operation::operand_range caller_operands,
245           ArrayRef<func::FuncOp> callees, const PerFunctionResult& caller_res) {
246         for (func::FuncOp callee : callees) {
247           assert(callee);
248           auto& callee_res = per_function_results.find(callee)->getSecond();
249           bool callee_needs_recompute = false;
250           for (BlockArgument arg : filter_resources(callee.getArguments())) {
251             Value arg_operand = caller_operands[arg.getArgNumber()];
252             auto device = caller_res.DeviceForResource(arg_operand);
253             if (!device) continue;
254             LLVM_DEBUG(llvm::dbgs()
255                        << "Propagating '" << *device << "' to arg #"
256                        << arg.getArgNumber() << " of function @"
257                        << callee.getName() << "\n");
258             if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
259                                                      &callee_res,
260                                                      &callee_needs_recompute)))
261               return failure();
262           }
263           // If the callee recording is modified, make sure that it will be
264           // reprocessed.
265           if (callee_needs_recompute) worklist.insert(callee);
266         }
267         return success();
268       };
269 
270   while (!worklist.empty()) {
271     auto func_op = worklist.pop_back_val();
272     auto& func_res = per_function_results.find(func_op)->getSecond();
273     // In-function propagation.
274     if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
275       return signalPassFailure();
276 
277     // Propagation to callees.
278     auto walk_res = func_op.walk([&](Operation* op) {
279       if (auto while_op = dyn_cast<WhileOp>(op)) {
280         if (failed(propagate_operands_to_callee_arguments(
281                 while_op, while_op.getOperands(),
282                 {while_op.body_function(), while_op.cond_function()},
283                 func_res)))
284           return WalkResult::interrupt();
285       } else if (auto if_op = dyn_cast<IfOp>(op)) {
286         if (failed(propagate_operands_to_callee_arguments(
287                 if_op, if_op.input(),
288                 {if_op.then_function(), if_op.else_function()}, func_res)))
289           return WalkResult::interrupt();
290       } else if (auto call = dyn_cast<CallOpInterface>(op)) {
291         auto func = dyn_cast<func::FuncOp>(call.resolveCallable());
292         if (!func) {
293           op->emitError(
294               "Cannot propagate device attribute to callee: Unable to resolve "
295               "call");
296           return WalkResult::interrupt();
297         }
298         LLVM_DEBUG(llvm::dbgs()
299                    << "Visiting call to function @" << func.getName() << "\n");
300         if (failed(propagate_operands_to_callee_arguments(
301                 call, call.getArgOperands(), {func}, func_res)))
302           return WalkResult::interrupt();
303       }
304       return WalkResult::advance();
305     });
306     if (walk_res.wasInterrupted()) return signalPassFailure();
307   }
308 }
309 
310 }  // namespace
311 
CreateResourceDeviceInferencePass()312 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
313   return std::make_unique<ResourceDeviceInference>();
314 }
315 
316 }  // namespace TF
317 }  // namespace mlir
318