1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17 #include <memory>
18 #include <tuple>
19 #include <utility>
20
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/iterator_range.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Builders.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/OpImplementation.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/OperationSupport.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/IR/Visitors.h" // from @llvm-project
40 #include "mlir/Pass/Pass.h" // from @llvm-project
41 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
42 #include "mlir/Support/LogicalResult.h" // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
47
48 #define DEBUG_TYPE "tf-resource-device-inference"
49
50 namespace mlir {
51 namespace TF {
52
53 namespace {
// Name of the string attribute read from / written to operations to carry
// their device assignment.
constexpr char kDeviceAttr[] = "device";
// Name of the string attribute read from / written to function arguments to
// carry their device assignment.
constexpr char kFuncDeviceAttr[] = "tf.device";
56
57 // A pass that propagates device assignment of resources on a module. It
58 // performs in-function propagation, as well as cross-function propagation from
59 // callers to callees.
60 //
61 // This pass changes the module by adding "tf.device" attribute to function
62 // arguments and adding "device" attribute to TF ops.
63 struct ResourceDeviceInference
64 : public ResourceDeviceInferencePassBase<ResourceDeviceInference> {
65 void runOnOperation() override;
66 };
67
68 // A class that records each resource's device assignment in a function.
69 class PerFunctionResult {
70 public:
PerFunctionResult(func::FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)71 explicit PerFunctionResult(
72 func::FuncOp func_op,
73 const TF::ResourceAliasAnalysis::Info& alias_analysis)
74 : alias_analysis_(alias_analysis) {}
75
76 // Returns the recorded device assignment for a resource, if any.
DeviceForResource(Value resource) const77 Optional<StringRef> DeviceForResource(Value resource) const {
78 Optional<StringRef> result;
79 if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
80 for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
81 auto it = resource_id_to_device_.find(id);
82 if (it == resource_id_to_device_.end()) continue;
83 if (!result || result == it->second) {
84 result = it->getSecond();
85 continue;
86 }
87 // Got conflicting assignments
88 return llvm::None;
89 }
90 return result;
91 }
92
93 // Records the device assignment for a resource. If the new assignment
94 // conflicts with an existing one, returns an error.
95 //
96 // If `changed` is provided, assign *changed to true if anything is modified.
AddResourceDevice(Value resource,StringRef device,bool * changed=nullptr)97 LogicalResult AddResourceDevice(Value resource, StringRef device,
98 bool* changed = nullptr) {
99 if (alias_analysis_.IsUnknownResource(resource)) return success();
100 for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
101 auto emplace_res = resource_id_to_device_.try_emplace(id, device);
102 if (emplace_res.second) {
103 if (changed) *changed = true;
104 } else if (emplace_res.first->getSecond() != device) {
105 // Existing assignment does not equal the new assignment.
106 return failure();
107 }
108 }
109 return success();
110 }
111
112 private:
113 llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
114 const TF::ResourceAliasAnalysis::Info& alias_analysis_;
115 };
116
117 // Tries to record device assignment for a resource.
AddResourceDeviceAndEmitError(Value resource,StringRef device,Operation * error_reporting_op,PerFunctionResult * result,bool * changed=nullptr)118 LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
119 Operation* error_reporting_op,
120 PerFunctionResult* result,
121 bool* changed = nullptr) {
122 auto res = result->AddResourceDevice(resource, device, changed);
123 if (failed(res)) {
124 error_reporting_op->emitError()
125 << "Conflicting device assignment for resource";
126 }
127 return res;
128 }
129
130 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(func::FuncOp func,int arg_no)131 inline StringRef GetDeviceAttr(func::FuncOp func, int arg_no) {
132 auto device_attr =
133 func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
134 return device_attr ? device_attr.getValue() : "";
135 }
136
137 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(Operation * op)138 inline StringRef GetDeviceAttr(Operation* op) {
139 auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
140 return device_attr ? device_attr.getValue() : "";
141 }
142
143 // Print operation with debug info (to get line number info for debugging)
dump(StringRef message,Operation * op)144 void dump(StringRef message, Operation* op) {
145 llvm::dbgs() << message;
146 op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
147 llvm::dbgs() << "\n";
148 }
149
150 // Propagates device assignment inside a function.
ComputeResourceDevicesInComputation(func::FuncOp func_op,PerFunctionResult * result)151 LogicalResult ComputeResourceDevicesInComputation(func::FuncOp func_op,
152 PerFunctionResult* result) {
153 OpBuilder builder(func_op);
154 // Function arguments.
155 for (auto arg : filter_resources(func_op.getArguments())) {
156 StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
157 if (device_attr.empty()) {
158 // If device_attr does not exist, try to construct it from any recorded
159 // assignment.
160 if (auto device = result->DeviceForResource(arg)) {
161 func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
162 builder.getStringAttr(*device));
163 }
164 continue;
165 }
166 // Record the attribute.
167 auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
168 if (failed(res)) return res;
169 }
170
171 // To support WhileRegion, we need to propagate device attributes from
172 // WhileRegion operands to body/cond region arguments *prior* to visiting
173 // these regions, so use a pre-order walk.
174 WalkResult walk_res = func_op.walk<WalkOrder::PreOrder>([&](Operation* op) {
175 if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
176 // Record VarHandleOp's device attribute.
177 StringRef device_attr = GetDeviceAttr(op);
178 if (device_attr.empty()) return WalkResult::advance();
179 auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
180 device_attr, op, result);
181 if (failed(res)) return WalkResult::interrupt();
182 } else if (auto identity = dyn_cast<IdentityOp>(op)) {
183 LLVM_DEBUG(dump("Visiting ", identity));
184 // Try to construct IdentityOp's attribute from recorded assignment.
185 if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
186 for (auto output : filter_resources(op->getResults())) {
187 LLVM_DEBUG(llvm::dbgs() << " Processing output #"
188 << output.getResultNumber() << "\n");
189 if (auto device = result->DeviceForResource(output)) {
190 LLVM_DEBUG(llvm::dbgs() << " Setting device = " << *device << "\n");
191 identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
192 }
193 }
194 } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
195 // For WhileRegion, do local analysis prior to visiting the attached
196 // regions and propagate device annotations to the cond and body
197 // region arguments. The annotations are the union of annotations
198 // on the input and result. Resource alias analysis already propagates
199 // resource ID from the inputs to the results for a while, so just
200 // need to consider the results.
201 LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
202
203 for (auto output : filter_resources(while_region.getResults())) {
204 auto device = result->DeviceForResource(output);
205 int output_index = output.getResultNumber();
206 if (!device) {
207 LLVM_DEBUG(llvm::dbgs()
208 << " No device for output #" << output_index << "\n");
209 continue;
210 }
211 // Transfer the annotation to both region arguments
212 for (Region* region : while_region.getRegions()) {
213 BlockArgument arg = region->getArgument(output_index);
214 LLVM_DEBUG(llvm::dbgs()
215 << " Propagating device = '" << *device << "' to arg #"
216 << output_index << " of region #"
217 << region->getRegionNumber() << "\n");
218 if (failed(AddResourceDeviceAndEmitError(arg, *device, while_region,
219 result)))
220 return WalkResult::interrupt();
221 }
222 }
223 }
224 return WalkResult::advance();
225 });
226 return failure(walk_res.wasInterrupted());
227 }
228
runOnOperation()229 void ResourceDeviceInference::runOnOperation() {
230 auto module = getOperation();
231 const auto& resource_alias_analysis =
232 getAnalysis<TF::ResourceAliasAnalysis>();
233
234 llvm::SmallDenseMap<func::FuncOp, PerFunctionResult, 4> per_function_results;
235 llvm::SetVector<func::FuncOp> worklist;
236 for (auto func_op : module.getOps<func::FuncOp>()) {
237 worklist.insert(func_op);
238 per_function_results.try_emplace(
239 func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
240 }
241 // Helper that propagates an op's recorded operand device assignments to its
242 // called function's arguments.
243 auto propagate_operands_to_callee_arguments =
244 [&](Operation* caller, Operation::operand_range caller_operands,
245 ArrayRef<func::FuncOp> callees, const PerFunctionResult& caller_res) {
246 for (func::FuncOp callee : callees) {
247 assert(callee);
248 auto& callee_res = per_function_results.find(callee)->getSecond();
249 bool callee_needs_recompute = false;
250 for (BlockArgument arg : filter_resources(callee.getArguments())) {
251 Value arg_operand = caller_operands[arg.getArgNumber()];
252 auto device = caller_res.DeviceForResource(arg_operand);
253 if (!device) continue;
254 LLVM_DEBUG(llvm::dbgs()
255 << "Propagating '" << *device << "' to arg #"
256 << arg.getArgNumber() << " of function @"
257 << callee.getName() << "\n");
258 if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
259 &callee_res,
260 &callee_needs_recompute)))
261 return failure();
262 }
263 // If the callee recording is modified, make sure that it will be
264 // reprocessed.
265 if (callee_needs_recompute) worklist.insert(callee);
266 }
267 return success();
268 };
269
270 while (!worklist.empty()) {
271 auto func_op = worklist.pop_back_val();
272 auto& func_res = per_function_results.find(func_op)->getSecond();
273 // In-function propagation.
274 if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
275 return signalPassFailure();
276
277 // Propagation to callees.
278 auto walk_res = func_op.walk([&](Operation* op) {
279 if (auto while_op = dyn_cast<WhileOp>(op)) {
280 if (failed(propagate_operands_to_callee_arguments(
281 while_op, while_op.getOperands(),
282 {while_op.body_function(), while_op.cond_function()},
283 func_res)))
284 return WalkResult::interrupt();
285 } else if (auto if_op = dyn_cast<IfOp>(op)) {
286 if (failed(propagate_operands_to_callee_arguments(
287 if_op, if_op.input(),
288 {if_op.then_function(), if_op.else_function()}, func_res)))
289 return WalkResult::interrupt();
290 } else if (auto call = dyn_cast<CallOpInterface>(op)) {
291 auto func = dyn_cast<func::FuncOp>(call.resolveCallable());
292 if (!func) {
293 op->emitError(
294 "Cannot propagate device attribute to callee: Unable to resolve "
295 "call");
296 return WalkResult::interrupt();
297 }
298 LLVM_DEBUG(llvm::dbgs()
299 << "Visiting call to function @" << func.getName() << "\n");
300 if (failed(propagate_operands_to_callee_arguments(
301 call, call.getArgOperands(), {func}, func_res)))
302 return WalkResult::interrupt();
303 }
304 return WalkResult::advance();
305 });
306 if (walk_res.wasInterrupted()) return signalPassFailure();
307 }
308 }
309
310 } // namespace
311
CreateResourceDeviceInferencePass()312 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
313 return std::make_unique<ResourceDeviceInference>();
314 }
315
316 } // namespace TF
317 } // namespace mlir
318