1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17 #include <memory>
18 #include <tuple>
19 #include <utility>
20
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/iterator_range.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/Debug.h"
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/OpImplementation.h" // from @llvm-project
34 #include "mlir/IR/Operation.h" // from @llvm-project
35 #include "mlir/IR/OperationSupport.h" // from @llvm-project
36 #include "mlir/IR/Types.h" // from @llvm-project
37 #include "mlir/IR/Value.h" // from @llvm-project
38 #include "mlir/IR/Visitors.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
41 #include "mlir/Support/LogicalResult.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"
46
47 #define DEBUG_TYPE "tf-resource-device-inference"
48
49 namespace mlir {
50 namespace TF {
51
52 namespace {
// Op attribute that carries an operation's device assignment.
constexpr char kDeviceAttr[]{"device"};
// Function-argument attribute that carries a resource argument's device
// assignment.
constexpr char kFuncDeviceAttr[]{"tf.device"};
55
56 // A pass that propagates device assignment of resources on a module. It
57 // performs in-function propagation, as well as cross-function propagation from
58 // callers to callees.
59 //
60 // This pass changes the module by adding "tf.device" attribute to function
61 // arguments and adding "device" attribute to TF ops.
62 struct ResourceDeviceInference
63 : public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
getArgumentmlir::TF::__anon2f7139870111::ResourceDeviceInference64 StringRef getArgument() const final { return "tf-resource-device-inference"; }
65
getDescriptionmlir::TF::__anon2f7139870111::ResourceDeviceInference66 StringRef getDescription() const final {
67 return "Propagates the device attribute on resources from callers to "
68 "callees.";
69 }
70
71 void runOnOperation() override;
72 };
73
74 // A class that records each resource's device assignment in a function.
75 class PerFunctionResult {
76 public:
PerFunctionResult(FuncOp func_op,const TF::ResourceAliasAnalysis::Info & alias_analysis)77 explicit PerFunctionResult(
78 FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
79 : alias_analysis_(alias_analysis) {}
80
81 // Returns the recorded device assignment for a resource, if any.
DeviceForResource(Value resource) const82 Optional<StringRef> DeviceForResource(Value resource) const {
83 Optional<StringRef> result;
84 if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
85 for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
86 auto it = resource_id_to_device_.find(id);
87 if (it == resource_id_to_device_.end()) continue;
88 if (!result || result == it->second) {
89 result = it->getSecond();
90 continue;
91 }
92 // Got conflicting assignments
93 return llvm::None;
94 }
95 return result;
96 }
97
98 // Records the device assignment for a resource. If the new assignment
99 // conflicts with an existing one, returns an error.
100 //
101 // If `changed` is provided, assign *changed to true if anything is modified.
AddResourceDevice(Value resource,StringRef device,bool * changed=nullptr)102 LogicalResult AddResourceDevice(Value resource, StringRef device,
103 bool* changed = nullptr) {
104 if (alias_analysis_.IsUnknownResource(resource)) return success();
105 for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
106 auto emplace_res = resource_id_to_device_.try_emplace(id, device);
107 if (emplace_res.second) {
108 if (changed) *changed = true;
109 } else if (emplace_res.first->getSecond() != device) {
110 // Existing assignment does not equal the new assignment.
111 return failure();
112 }
113 }
114 return success();
115 }
116
117 private:
118 llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
119 const TF::ResourceAliasAnalysis::Info& alias_analysis_;
120 };
121
122 // Tries to record device assignment for a resource.
AddResourceDeviceAndEmitError(Value resource,StringRef device,Operation * error_reporting_op,PerFunctionResult * result,bool * changed=nullptr)123 LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
124 Operation* error_reporting_op,
125 PerFunctionResult* result,
126 bool* changed = nullptr) {
127 auto res = result->AddResourceDevice(resource, device, changed);
128 if (failed(res)) {
129 error_reporting_op->emitError()
130 << "Conflicting device assignment for resource";
131 }
132 return res;
133 }
134
135 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(FuncOp func,int arg_no)136 inline StringRef GetDeviceAttr(FuncOp func, int arg_no) {
137 auto device_attr =
138 func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
139 return device_attr ? device_attr.getValue() : "";
140 }
141
142 // Extracts and canonicalizes the device attribute.
GetDeviceAttr(Operation * op)143 inline StringRef GetDeviceAttr(Operation* op) {
144 auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
145 return device_attr ? device_attr.getValue() : "";
146 }
147
148 // Print operation with debug info (to get line number info for debugging)
dump(StringRef message,Operation * op)149 void dump(StringRef message, Operation* op) {
150 llvm::dbgs() << message;
151 op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
152 llvm::dbgs() << "\n";
153 }
154
155 // Propagates device assignment inside a function.
ComputeResourceDevicesInComputation(FuncOp func_op,PerFunctionResult * result)156 LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
157 PerFunctionResult* result) {
158 OpBuilder builder(func_op);
159 // Function arguments.
160 for (auto arg : filter_resources(func_op.getArguments())) {
161 StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
162 if (device_attr.empty()) {
163 // If device_attr does not exist, try to construct it from any recorded
164 // assignment.
165 if (auto device = result->DeviceForResource(arg)) {
166 func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
167 builder.getStringAttr(*device));
168 }
169 continue;
170 }
171 // Record the attribute.
172 auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
173 if (failed(res)) return res;
174 }
175
176 // To support WhileRegion, we need to propagate device attributes from
177 // WhileRegion operands to body/cond region arguments *prior* to visiting
178 // these regions. Use tensorflow::walk() instead of MLIR core walker to
179 // implement such a pre-order walk.
180 auto walk_res = tensorflow::GenericWalk(
181 func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
182 // We just need to visit operations in pre-order mode.
183 if (!stage.IsBeforeAllRegions()) return WalkResult::advance();
184
185 if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
186 // Record VarHandleOp's device attribute.
187 StringRef device_attr = GetDeviceAttr(op);
188 if (device_attr.empty()) return WalkResult::advance();
189 auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
190 device_attr, op, result);
191 if (failed(res)) return WalkResult::interrupt();
192 } else if (auto identity = dyn_cast<IdentityOp>(op)) {
193 LLVM_DEBUG(dump("Visiting ", identity));
194 // Try to construct IdentityOp's attribute from recorded assignment.
195 if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
196 for (auto output : filter_resources(op->getResults())) {
197 LLVM_DEBUG(llvm::dbgs() << " Processing output #"
198 << output.getResultNumber() << "\n");
199 if (auto device = result->DeviceForResource(output)) {
200 LLVM_DEBUG(llvm::dbgs()
201 << " Setting device = " << *device << "\n");
202 identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
203 }
204 }
205 } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
206 // For WhileRegion, do local analysis prior to visiting the attached
207 // regions and propagate device annotations to the cond and body
208 // region arguments. The annotations are the union of annotations
209 // on the input and result. Resource alias analysis already propagates
210 // resource ID from the inputs to the results for a while, so just
211 // need to consider the results.
212 LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");
213
214 for (auto output : filter_resources(while_region.getResults())) {
215 auto device = result->DeviceForResource(output);
216 int output_index = output.getResultNumber();
217 if (!device) {
218 LLVM_DEBUG(llvm::dbgs()
219 << " No device for output #" << output_index << "\n");
220 continue;
221 }
222 // Transfer the annotation to both region arguments
223 for (Region* region : while_region.getRegions()) {
224 BlockArgument arg = region->getArgument(output_index);
225 LLVM_DEBUG(llvm::dbgs()
226 << " Propagating device = '" << *device
227 << "' to arg #" << output_index << " of region #"
228 << region->getRegionNumber() << "\n");
229 if (failed(AddResourceDeviceAndEmitError(arg, *device,
230 while_region, result)))
231 return WalkResult::interrupt();
232 }
233 }
234 }
235 return WalkResult::advance();
236 });
237 return failure(walk_res.wasInterrupted());
238 }
239
runOnOperation()240 void ResourceDeviceInference::runOnOperation() {
241 auto module = getOperation();
242 const auto& resource_alias_analysis =
243 getAnalysis<TF::ResourceAliasAnalysis>();
244
245 llvm::SmallDenseMap<FuncOp, PerFunctionResult, 4> per_function_results;
246 llvm::SetVector<FuncOp> worklist;
247 for (auto func_op : module.getOps<FuncOp>()) {
248 worklist.insert(func_op);
249 per_function_results.try_emplace(
250 func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
251 }
252 // Helper that propagates an op's recorded operand device assignments to its
253 // called function's arguments.
254 auto propagate_operands_to_callee_arguments =
255 [&](Operation* caller, Operation::operand_range caller_operands,
256 ArrayRef<FuncOp> callees, const PerFunctionResult& caller_res) {
257 for (FuncOp callee : callees) {
258 assert(callee);
259 auto& callee_res = per_function_results.find(callee)->getSecond();
260 bool callee_needs_recompute = false;
261 for (BlockArgument arg : filter_resources(callee.getArguments())) {
262 Value arg_operand = caller_operands[arg.getArgNumber()];
263 auto device = caller_res.DeviceForResource(arg_operand);
264 if (!device) continue;
265 LLVM_DEBUG(llvm::dbgs()
266 << "Propagating '" << *device << "' to arg #"
267 << arg.getArgNumber() << " of function @"
268 << callee.getName() << "\n");
269 if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
270 &callee_res,
271 &callee_needs_recompute)))
272 return failure();
273 }
274 // If the callee recording is modified, make sure that it will be
275 // reprocessed.
276 if (callee_needs_recompute) worklist.insert(callee);
277 }
278 return success();
279 };
280
281 while (!worklist.empty()) {
282 auto func_op = worklist.pop_back_val();
283 auto& func_res = per_function_results.find(func_op)->getSecond();
284 // In-function propagation.
285 if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
286 return signalPassFailure();
287
288 // Propagation to callees.
289 auto walk_res = func_op.walk([&](Operation* op) {
290 if (auto while_op = dyn_cast<WhileOp>(op)) {
291 if (failed(propagate_operands_to_callee_arguments(
292 while_op, while_op.getOperands(),
293 {while_op.body_function(), while_op.cond_function()},
294 func_res)))
295 return WalkResult::interrupt();
296 } else if (auto if_op = dyn_cast<IfOp>(op)) {
297 if (failed(propagate_operands_to_callee_arguments(
298 if_op, if_op.input(),
299 {if_op.then_function(), if_op.else_function()}, func_res)))
300 return WalkResult::interrupt();
301 } else if (auto call = dyn_cast<CallOpInterface>(op)) {
302 auto func = dyn_cast<FuncOp>(call.resolveCallable());
303 if (!func) {
304 op->emitError(
305 "Cannot propagate device attribute to callee: Unable to resolve "
306 "call");
307 return WalkResult::interrupt();
308 }
309 LLVM_DEBUG(llvm::dbgs()
310 << "Visiting call to function @" << func.getName() << "\n");
311 if (failed(propagate_operands_to_callee_arguments(
312 call, call.getArgOperands(), {func}, func_res)))
313 return WalkResult::interrupt();
314 }
315 return WalkResult::advance();
316 });
317 if (walk_res.wasInterrupted()) return signalPassFailure();
318 }
319 }
320
321 PassRegistration<ResourceDeviceInference> pass;
322
323 } // namespace
324
CreateResourceDeviceInferencePass()325 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
326 return std::make_unique<ResourceDeviceInference>();
327 }
328
329 } // namespace TF
330 } // namespace mlir
331