/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h"

#define DEBUG_TYPE "tf-resource-device-inference"

namespace mlir {
namespace TF {

namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

// A pass that propagates device assignment of resources on a module. It
// performs in-function propagation, as well as cross-function propagation from
// callers to callees.
//
// This pass changes the module by adding a "tf.device" attribute to function
// arguments and a "device" attribute to TF ops.
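//
// For example (the resource element type and the device name below are purely
// illustrative), given
//
//   func @f(%arg0: tensor<*x!tf.resource<tensor<f32>>>
//               {tf.device = "/device:CPU:0"}) {
//     %0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>)
//            -> tensor<*x!tf.resource<tensor<f32>>>
//     ...
//   }
//
// the pass would set device = "/device:CPU:0" on the tf.Identity op, since its
// result aliases the annotated resource argument.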
struct ResourceDeviceInference
    : public PassWrapper<ResourceDeviceInference, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// A class that records each resource's device assignment in a function.
class PerFunctionResult {
 public:
  explicit PerFunctionResult(
      FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis)
      : alias_analysis_(alias_analysis) {}

  // Returns the recorded device assignment for a resource, or llvm::None if
  // the resource is unknown or has conflicting recorded assignments.
  Optional<StringRef> DeviceForResource(Value resource) const {
    Optional<StringRef> result;
    if (alias_analysis_.IsUnknownResource(resource)) return llvm::None;
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto it = resource_id_to_device_.find(id);
      if (it == resource_id_to_device_.end()) continue;
      if (!result || result == it->second) {
        result = it->getSecond();
        continue;
      }
      // Got conflicting assignments.
      return llvm::None;
    }
    return result;
  }

  // Records the device assignment for a resource. If the new assignment
  // conflicts with an existing one, returns failure.
  //
  // If `changed` is provided, sets *changed to true if anything is modified.
  LogicalResult AddResourceDevice(Value resource, StringRef device,
                                  bool* changed = nullptr) {
    if (alias_analysis_.IsUnknownResource(resource)) return success();
    for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) {
      auto emplace_res = resource_id_to_device_.try_emplace(id, device);
      if (emplace_res.second) {
        if (changed) *changed = true;
      } else if (emplace_res.first->getSecond() != device) {
        // Existing assignment does not equal the new assignment.
        return failure();
      }
    }
    return success();
  }

 private:
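  // Maps each resource unique ID (as assigned by the resource alias analysis)
  // to the device recorded for it.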
  llvm::SmallDenseMap<int64_t, StringRef, 8> resource_id_to_device_;
  const TF::ResourceAliasAnalysis::Info& alias_analysis_;
};

// Tries to record the device assignment for a resource, reporting an error on
// `error_reporting_op` if the new assignment conflicts with an existing one.
LogicalResult AddResourceDeviceAndEmitError(Value resource, StringRef device,
                                            Operation* error_reporting_op,
                                            PerFunctionResult* result,
                                            bool* changed = nullptr) {
  auto res = result->AddResourceDevice(resource, device, changed);
  if (failed(res)) {
    error_reporting_op->emitError()
        << "Conflicting device assignment for resource";
  }
  return res;
}

// Returns the `tf.device` attribute of a function argument, or an empty
// string if it is not set.
inline StringRef GetDeviceAttr(FuncOp func, int arg_no) {
  auto device_attr =
      func.getArgAttrOfType<mlir::StringAttr>(arg_no, kFuncDeviceAttr);
  return device_attr ? device_attr.getValue() : "";
}

// Returns the `device` attribute of an op, or an empty string if it is not
// set.
inline StringRef GetDeviceAttr(Operation* op) {
  auto device_attr = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
  return device_attr ? device_attr.getValue() : "";
}

// Print operation with debug info (to get line number info for debugging).
void dump(StringRef message, Operation* op) {
  llvm::dbgs() << message;
  op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true));
  llvm::dbgs() << "\n";
}

// Propagates device assignments within a function, from annotated resource
// arguments and VarHandleOps to resource users such as Identity ops and
// WhileRegion region arguments.
LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
                                                  PerFunctionResult* result) {
  OpBuilder builder(func_op);
  // Function arguments.
  for (auto arg : filter_resources(func_op.getArguments())) {
    StringRef device_attr = GetDeviceAttr(func_op, arg.getArgNumber());
    if (device_attr.empty()) {
      // If device_attr does not exist, try to construct it from any recorded
      // assignment.
      if (auto device = result->DeviceForResource(arg)) {
        func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
                           builder.getStringAttr(*device));
      }
      continue;
    }
    // Record the attribute.
    auto res = AddResourceDeviceAndEmitError(arg, device_attr, func_op, result);
    if (failed(res)) return res;
  }

  // To support WhileRegion, we need to propagate device attributes from
  // WhileRegion operands to body/cond region arguments *prior* to visiting
  // these regions. Use the tensorflow walking utility
  // (tensorflow::GenericWalk) instead of the MLIR core walker to implement
  // such a pre-order walk.
  auto walk_res = tensorflow::GenericWalk(
      func_op, [&](Operation* op, const tensorflow::WalkStage& stage) {
        // We just need to visit operations in pre-order mode.
        if (!stage.IsBeforeAllRegions()) return WalkResult::advance();

        if (auto var_handle = dyn_cast<VarHandleOp>(op)) {
          // Record VarHandleOp's device attribute.
          StringRef device_attr = GetDeviceAttr(op);
          if (device_attr.empty()) return WalkResult::advance();
          auto res = AddResourceDeviceAndEmitError(var_handle.resource(),
                                                   device_attr, op, result);
          if (failed(res)) return WalkResult::interrupt();
        } else if (auto identity = dyn_cast<IdentityOp>(op)) {
          LLVM_DEBUG(dump("Visiting ", identity));
          // Try to construct IdentityOp's attribute from recorded assignment.
          if (!GetDeviceAttr(op).empty()) return WalkResult::advance();
          for (auto output : filter_resources(op->getResults())) {
            LLVM_DEBUG(llvm::dbgs() << " Processing output #"
                                    << output.getResultNumber() << "\n");
            if (auto device = result->DeviceForResource(output)) {
              LLVM_DEBUG(llvm::dbgs()
                         << " Setting device = " << *device << "\n");
              identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
            }
          }
        } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
          // For WhileRegion, do local analysis prior to visiting the attached
          // regions and propagate device annotations to the cond and body
          // region arguments. The annotations are the union of annotations on
          // the input and result. Resource alias analysis already propagates
          // resource IDs from the inputs to the results of a WhileRegion, so
          // we only need to consider the results.
          LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n");

          for (auto output : filter_resources(while_region.getResults())) {
            auto device = result->DeviceForResource(output);
            int output_index = output.getResultNumber();
            if (!device) {
              LLVM_DEBUG(llvm::dbgs()
                         << " No device for output #" << output_index << "\n");
              continue;
            }
            // Transfer the annotation to both region arguments.
            for (Region* region : while_region.getRegions()) {
              BlockArgument arg = region->getArgument(output_index);
              LLVM_DEBUG(llvm::dbgs()
                         << " Propagating device = '" << *device
                         << "' to arg #" << output_index << " of region #"
                         << region->getRegionNumber() << "\n");
              if (failed(AddResourceDeviceAndEmitError(arg, *device,
                                                       while_region, result)))
                return WalkResult::interrupt();
            }
          }
        }
        return WalkResult::advance();
      });
  return failure(walk_res.wasInterrupted());
}

void ResourceDeviceInference::runOnOperation() {
  auto module = getOperation();
  const auto& resource_alias_analysis =
      getAnalysis<TF::ResourceAliasAnalysis>();

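  // Seed the worklist with every function in the module; each function gets a
  // per-function result backed by its resource alias analysis info.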
  llvm::SmallDenseMap<FuncOp, PerFunctionResult, 4> per_function_results;
  llvm::SetVector<FuncOp> worklist;
  for (auto func_op : module.getOps<FuncOp>()) {
    worklist.insert(func_op);
    per_function_results.try_emplace(
        func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op));
  }
  // Helper that propagates an op's recorded operand device assignments to its
  // called function's arguments.
  auto propagate_operands_to_callee_arguments =
      [&](Operation* caller, Operation::operand_range caller_operands,
          ArrayRef<FuncOp> callees, const PerFunctionResult& caller_res) {
        for (FuncOp callee : callees) {
          assert(callee);
          auto& callee_res = per_function_results.find(callee)->getSecond();
          bool callee_needs_recompute = false;
          for (BlockArgument arg : filter_resources(callee.getArguments())) {
            Value arg_operand = caller_operands[arg.getArgNumber()];
            auto device = caller_res.DeviceForResource(arg_operand);
            if (!device) continue;
            LLVM_DEBUG(llvm::dbgs()
                       << "Propagating '" << *device << "' to arg #"
                       << arg.getArgNumber() << " of function @"
                       << callee.getName() << "\n");
            if (failed(AddResourceDeviceAndEmitError(arg, *device, caller,
                                                     &callee_res,
                                                     &callee_needs_recompute)))
              return failure();
          }
          // If the callee's recorded assignments were modified, make sure it
          // will be reprocessed.
          if (callee_needs_recompute) worklist.insert(callee);
        }
        return success();
      };

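  // Iterate to a fixed point: a callee is re-queued whenever propagation from
  // one of its callers records new device information for its arguments.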
  while (!worklist.empty()) {
    auto func_op = worklist.pop_back_val();
    auto& func_res = per_function_results.find(func_op)->getSecond();
    // In-function propagation.
    if (failed(ComputeResourceDevicesInComputation(func_op, &func_res)))
      return signalPassFailure();

    // Propagation to callees.
    auto walk_res = func_op.walk([&](Operation* op) {
      if (auto while_op = dyn_cast<WhileOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                while_op, while_op.getOperands(),
                {while_op.body_function(), while_op.cond_function()},
                func_res)))
          return WalkResult::interrupt();
      } else if (auto if_op = dyn_cast<IfOp>(op)) {
        if (failed(propagate_operands_to_callee_arguments(
                if_op, if_op.input(),
                {if_op.then_function(), if_op.else_function()}, func_res)))
          return WalkResult::interrupt();
      } else if (auto call = dyn_cast<CallOpInterface>(op)) {
        auto func = dyn_cast<FuncOp>(call.resolveCallable());
        if (!func) {
          op->emitError(
              "Cannot propagate device attribute to callee: Unable to resolve "
              "call");
          return WalkResult::interrupt();
        }
        LLVM_DEBUG(llvm::dbgs()
                   << "Visiting call to function @" << func.getName() << "\n");
        if (failed(propagate_operands_to_callee_arguments(
                call, call.getArgOperands(), {func}, func_res)))
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    if (walk_res.wasInterrupted()) return signalPassFailure();
  }
}

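// Registers the pass under the "tf-resource-device-inference" flag; with a
// tool that links this pass in (e.g. the tf-opt binary), it can be invoked as
// `-tf-resource-device-inference`.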
PassRegistration<ResourceDeviceInference> pass(
    "tf-resource-device-inference",
    "Propagates the device attribute on resources from callers to callees.");

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateResourceDeviceInferencePass() {
  return std::make_unique<ResourceDeviceInference>();
}

}  // namespace TF
}  // namespace mlir