1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass lifts resource variable operations outside of device computation.
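//
// As a rough illustration (the IR below is schematic, not the exact output of
// the pass; %var and %cst stand for values defined elsewhere, and `!resource`
// abbreviates a `tensor<*x!tf.resource<...>>` type), a cluster such as
//
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//     %2 = "tf.AddV2"(%1, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     "tf.AssignVariableOp"(%var, %2) : (!resource, tensor<f32>) -> ()
//     tf_device.return %2 : tensor<f32>
//   }) : () -> tensor<f32>
//
// is rewritten so that the load happens before the cluster, the written value
// is returned through an extra cluster result, and the store happens after:
//
//   %read = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//   %0:2 = "tf_device.cluster"() ({
//     %2 = "tf.AddV2"(%read, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     tf_device.return %2, %2 : tensor<f32>, tensor<f32>
//   }) : () -> (tensor<f32>, tensor<f32>)
//   "tf.AssignVariableOp"(%var, %0#1) : (!resource, tensor<f32>) -> ()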
17
18 #include <cstddef>
19 #include <cstdint>
20
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
31 #include "mlir/IR/Attributes.h" // from @llvm-project
32 #include "mlir/IR/Block.h" // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
34 #include "mlir/IR/Builders.h" // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
37 #include "mlir/IR/Diagnostics.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/Region.h" // from @llvm-project
40 #include "mlir/IR/SymbolTable.h" // from @llvm-project
41 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
42 #include "mlir/IR/Types.h" // from @llvm-project
43 #include "mlir/IR/Value.h" // from @llvm-project
44 #include "mlir/IR/Verifier.h" // from @llvm-project
45 #include "mlir/IR/Visitors.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Support/LLVM.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
54 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
56 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
57 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
58 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
60 #include "tensorflow/core/framework/tensor_shape.pb.h"
61
62 namespace mlir {
63
64 namespace {
65
66 constexpr char kDeviceAttr[] = "device";
67
68 // Lift resource operations out of device computation.
69 struct ResourceOpLiftingPass
70 : public TFDevice::ResourceOpLiftingPassBase<ResourceOpLiftingPass> {
71 void runOnOperation() override;
72 };
73
bool IsResource(Value value) {
75 return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
76 }
77
78 // Get the type of the data contained in a resource. Returns null if there is
79 // no single type in the resource.
Type GetResourceSubtype(Value value) {
81 auto resource_type =
82 getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
83 auto subtypes = resource_type.getSubtypes();
84 if (subtypes.size() == 1) return subtypes[0];
85 return nullptr;
86 }
87
88 // Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
89 // TODO(b/171039585): Replace this with proper analysis of
90 // `tf.VarIsInitializedOp` in regards to resource writes and control flow.
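//
// Schematically (illustrative IR only), every
//
//   %init = "tf.VarIsInitializedOp"(%var) : (!resource) -> tensor<i1>
//
// in the block is erased and its uses are redirected to one shared constant:
//
//   %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>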
void SetAllVarIsInitializedToTrue(Block* block) {
92 auto builder = OpBuilder::atBlockBegin(block);
93 TF::ConstOp const_true = nullptr;
94 for (auto op :
95 llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
96 builder.setInsertionPoint(op);
97 if (!const_true)
98 const_true = builder.create<TF::ConstOp>(
99 op.getLoc(),
100 DenseIntElementsAttr::get(
101 RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
102
103 op.is_initialized().replaceAllUsesWith(const_true);
104 op.erase();
105 }
106 }
107
// Performs store-load forwarding. This effectively removes
// 1) Any resource load that follows a store to that same resource, and
// 2) Any resource store except the last one.
111 // TODO(ycao): Store-load forwarding implemented here is only correct when
112 // computation is purely sequential (no concurrency). Need to support concurrent
113 // computation as well.
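//
// A schematic example (illustrative IR only): given
//
//   "tf.AssignVariableOp"(%var, %v0) : (!resource, tensor<f32>) -> ()
//   %r = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//   "tf.AssignVariableOp"(%var, %v1) : (!resource, tensor<f32>) -> ()
//
// uses of %r are replaced with %v0 (through a tf.Cast if the types do not
// match exactly), the read is erased, and the first store is erased, leaving
// only the final store of %v1.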
void ForwardStoreToLoad(Block* block) {
115 // resource_handle_to_last_store_op keeps track of the most recent (last)
116 // store to each resource. Non-existent entry indicates that a resource has
117 // not been stored to yet.
118 llvm::SmallDenseMap<Value, TF::AssignVariableOp>
119 resource_handle_to_last_store_op;
120
121 // Only iterate through ops directly in the block as we can't handle ops
122 // nested deeper in regions.
123 for (Operation& op : llvm::make_early_inc_range(*block)) {
124 if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
125 Value resource = read_variable_op.resource();
126 auto last_store = resource_handle_to_last_store_op[resource];
127 if (!last_store) continue;
128
      // Use the value stored in last_store to replace all uses of the current
      // resource load's result, then erase this resource load. Add an
      // intermediate CastOp if the types don't exactly match.
132 Type read_type = read_variable_op.value().getType();
133 if (read_type != last_store.value().getType()) {
134 OpBuilder builder(last_store);
135 builder.setInsertionPointAfter(last_store);
136 auto cast = builder.create<TF::CastOp>(
137 last_store.getLoc(), read_type, last_store.value(),
138 /*Truncate=*/builder.getBoolAttr(false));
139 read_variable_op.value().replaceAllUsesWith(cast);
140 } else {
141 read_variable_op.value().replaceAllUsesWith(last_store.value());
142 }
143
144 read_variable_op.erase();
145 continue;
146 }
147
148 if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
149 Value resource = assign_variable_op.resource();
150 auto last_store = resource_handle_to_last_store_op[resource];
151 // Previous store ops to same resource can be erased.
152 if (last_store) last_store.erase();
153
154 resource_handle_to_last_store_op[resource] = assign_variable_op;
155 }
156 }
157 }
158
159 //===----------------------------------------------------------------------===//
160 // RegionResourceHoister
161 //===----------------------------------------------------------------------===//
162
163 // Helper class to hoist resource ops out of regions attached to an op.
164 class RegionResourceHoister {
165 public:
  explicit RegionResourceHoister(Operation* op) : op_(op) {}
167
168 // Analyzes attached regions to record resources read and written.
169 LogicalResult Analyze();
170
171 // Returns all resources accessed by the regions attached the op.
  auto& GetResources() { return resources_; }
173
174 // Returns if the given value is a resource that needs lifting.
  bool Contains(Value resource) const {
176 return resources_.find(resource) != resources_.end();
177 }
178
179 // Drops the given resource from lifting.
  void DropResource(Value resource) {
181 resources_.erase(resource);
182 written_resources_.remove(resource);
183 }
184
185 // Replaces all resource loads in all regions attached to the op.
  void ReplaceResourceLoads(bool read_only) {
187 llvm::for_each(op_->getRegions(), [&](Region& region) {
188 ReplaceResourceLoads(region, read_only);
189 });
190 }
191
192 static LogicalResult ReplaceOpWithNewOp(Operation* op);
193
194 private:
195 // Returns if any resources need lifting.
  bool NeedsLifting() const { return !resources_.empty(); }
197
198 // Returns the number of results generated by the lifted op.
  int GetLiftedNumResults() const { return num_new_results_; }
200
201 // Generates hoisted reads for resources that need them before the op.
202 void GenerateHoistedReads();
203
204 // Replaces all resource loads in the given region with hoisted loads. If
205 // `read_only` is true, limit this replacement to read only resources.
206 void ReplaceResourceLoads(Region& region, bool read_only);
207
  // Appends the final values written to resources to the region returns for
  // the given set of regions.
210 void AppendResourceStoreValueToReturn(RegionRange regions);
211
212 // Performs the final replacement of the op.
213 void ReplaceOpWithNewOp();
214
  // Returns whether this resource was written to in any of the regions.
  bool IsWritten(Value resource) const {
217 return written_resources_.contains(resource);
218 }
219
220 static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
221 static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);
222
223 Operation* op_;
224
225 // Per resource information about accesses to that resource.
226 struct ResourceInfo {
227 // Is this resource read in any of the regions?
228 bool is_read;
229 // Is this resource written in any of the regions?
230 bool is_written;
231 // Is this resource written in all of the regions?
232 bool is_written_all;
233 // The hoisted read used to replace region reads.
234 Value hoisted_read;
    // The type of the data held by the resource.
    Type data_type;
    // For written resources, the result # of the lifted op which will hold the
    // value of the resource. This result will be used to generate writes to
    // the resource after the lifted op.
240 int result_index;
241 // Attributes on the read operation.
242 DictionaryAttr read_attrs;
243 // Attributes on the write operation.
244 DictionaryAttr write_attrs;
245
    ResourceInfo()
247 : is_read(false),
248 is_written(false),
249 is_written_all(false),
250 hoisted_read(nullptr),
251 data_type(nullptr),
252 result_index(-1) {}
253
    bool IsResultIndexAssigned() { return result_index != -1; }
255
256 // Refine the resource type using the given type `type`.
    void RefineType(Type type) {
258 if (!data_type) {
259 data_type = type;
260 } else {
261 data_type = TF::GetCastCompatibleType(data_type, type,
262 /*may_ignore_ref_type_a=*/false);
263 assert(data_type != nullptr && "Resource used with incompatible types");
264 }
265 }
266 };
267 llvm::MapVector<Value, ResourceInfo> resources_;
268 llvm::SetVector<Value> written_resources_;
269 // number of new results after lifting.
270 int num_new_results_;
271 };
272
273 // Analyzes resources that are read or written within attached regions.
LogicalResult RegionResourceHoister::Analyze() {
  // Hoisting of child regions might have created opportunities for store-load
  // forwarding.
  for (Region& region : op_->getRegions()) {
    ForwardStoreToLoad(&region.front());
279 }
280
281 llvm::SetVector<Value> all_resources;
282 bool is_func = false;
  // For functions, the resources to analyze are the function arguments.
  // Otherwise, it's the region captures.
285 if (FuncOp func = dyn_cast<FuncOp>(op_)) {
286 is_func = true;
287 Region& body = func.getBody();
288 for (BlockArgument arg : body.getArguments()) {
289 if (IsResource(arg)) all_resources.insert(arg);
290 }
291 } else {
292 getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
293 all_resources.remove_if([](Value value) { return !IsResource(value); });
294 }
295
296 num_new_results_ = op_->getNumResults();
297
298 for (auto resource : all_resources) {
299 ResourceInfo info;
300 info.data_type = GetResourceSubtype(resource);
301 llvm::BitVector written_regions(op_->getNumRegions());
302 bool unsupported_use = false;
303 for (OpOperand& use : resource.getUses()) {
304 Operation* user = use.getOwner();
      // If the user is not in one of the regions, we are not interested in it.
      // Since all the sub-regions within this region (i.e., regions attached to
      // ops in this region) have themselves gone through lifting, all resource
      // users are expected to be operations in this region and not embedded
      // within other sub-regions attached to ops in this region. So the check
      // for whether a user is in one of the regions attached to this op is
      // straightforward.
312 if (user->getParentRegion()->getParentOp() != op_) continue;
313
314 // For functions, if the resource is used as a return operand, use that
315 // as its result index.
316 if (is_func && isa<ReturnOp>(user)) {
        assert(!info.IsResultIndexAssigned() &&
               "Expect resource argument to be returned no more than once");
319 info.result_index = use.getOperandNumber();
320 continue;
321 }
322
323 auto read = dyn_cast<TF::ReadVariableOp>(user);
324 auto write = dyn_cast<TF::AssignVariableOp>(user);
325 if (!read && !write) {
326 unsupported_use = true;
327 break;
328 }
329
330 if (read && !info.is_read) {
331 info.is_read = true;
332 info.RefineType(read.value().getType());
333 info.read_attrs = user->getAttrDictionary();
334 }
335
336 if (write) {
337 info.is_written = true;
338 info.RefineType(write.value().getType());
339 info.write_attrs = user->getAttrDictionary();
340 written_regions.set(user->getParentRegion()->getRegionNumber());
341 }
342 }
343
344 // If the resource is used in an op that we do not understand, skip
345 // lifting for that resource.
346 if (unsupported_use) continue;
347
348 info.is_written_all = written_regions.count() == op_->getNumRegions();
349
    // If the resource is written in some but not all regions, we would need
    // a read for the value before these regions. Note that this is applicable
    // only to multi-region ops:
    // If/Case: If not all regions write to the resource, post hoisting the read
    // value needs to be routed through all paths that don't write.
    // While: since the while condition cannot write, any resource written in
    // the while body will need to be read as well in case the while body is
    // never executed.
    // Both cases are handled by the condition below.
359 if (info.is_written && !info.is_written_all) info.is_read = true;
360
361 // Allocate a result index for written resources that don't have one.
362 if (info.is_written) {
363 written_resources_.insert(resource);
364 if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
365 }
366
367 resources_.insert({resource, info});
368 }
369 return success();
370 }
371
372 // Generates hoisted reads for all resources that need them just before the op.
void RegionResourceHoister::GenerateHoistedReads() {
374 OpBuilder builder(op_);
375 DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
376 for (auto& resource_it : GetResources()) {
377 Value resource = resource_it.first;
378 auto& info = resource_it.second;
379
380 if (info.is_read) {
381 Operation* read = builder.create<TF::ReadVariableOp>(
382 op_->getLoc(), info.data_type, resource);
383 read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
384 read->removeAttr(kDeviceAttr);
385 info.hoisted_read = read->getResult(0);
386 }
387 }
388 }
389
390 // Replaces all resource reads with the hoisted read.
void RegionResourceHoister::ReplaceResourceLoads(Region& region,
                                                 bool read_only) {
393 assert(llvm::hasSingleElement(region) && "Expected single block region");
394 // Only iterate through ops directly in the body as we can't handle
395 // ops nested deeper in regions.
396 auto all_reads = region.front().getOps<TF::ReadVariableOp>();
397 for (auto read_op : llvm::make_early_inc_range(all_reads)) {
398 Value resource = read_op.resource();
399 if (!Contains(resource)) continue;
400
401 ResourceInfo& info = resources_[resource];
402 // If replacing loads for read only resources, skip if the resource
403 // was written to.
404 if (read_only && info.is_written) continue;
405
406 read_op.replaceAllUsesWith(info.hoisted_read);
407 read_op.erase();
408 }
409 }
410
// For each written resource, add its value at the end of each region to that
// region's return values. For a region, its value at the end may be a value
413 // written to that resource in that region, or its hoisted read value if the
414 // resource is not written in that region. The return value can be vended out
415 // either as an existing return value, or a newly allocated return value.
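//
// For example (schematic), for a tf.IfRegion where only the then region writes
// a hoisted resource, the then region's terminator returns the assigned value
// and the else region's terminator returns the hoisted read, so both regions
// yield a value for the new result slot:
//
//   then: "tf.Yield"(<original results>..., %assigned_value)
//   else: "tf.Yield"(<original results>..., %hoisted_read)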
void RegionResourceHoister::AppendResourceStoreValueToReturn(
    RegionRange regions) {
418 for (Region* region : regions) {
419 assert(llvm::hasSingleElement(*region) && "Expected single block region");
420 Block& front = region->front();
421 auto old_return = front.getTerminator();
422 assert(old_return->getNumOperands() == op_->getNumResults());
423 auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
424 new_return_operands.resize(num_new_results_);
425
426 // initialize return values for written resources to be the hoisted reads.
427 for (Value resource : written_resources_) {
428 const ResourceInfo& info = resources_[resource];
429 new_return_operands[info.result_index] = info.hoisted_read;
430 }
431
    // Only iterate through ops directly in the body as ops embedded in child
    // regions should have been lifted out.
434 auto assign_ops = front.getOps<TF::AssignVariableOp>();
435 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
436 Value resource = assign_variable_op.resource();
437 if (!IsWritten(resource)) continue;
438
439 // TODO(ycao): Prevent same value from being returned multiple times.
440 // TODO(ycao): Do not return resource store value if it is defined outside
441 // of cluster. Both of these can be post-resource-op-lifting cleanup
442 // passes.
443 int result_index = resources_[resource].result_index;
444 new_return_operands[result_index] = assign_variable_op.value();
445 assign_variable_op.erase();
446 }
447 old_return->setOperands(new_return_operands);
448 }
449 }
450
451 // Replace the old op with a new op (with potentially additional results), and
452 // add stores to written resources after the new op.
void RegionResourceHoister::ReplaceOpWithNewOp() {
454 auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
455 int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
456 Operation* terminator = op_->getRegion(result_region).front().getTerminator();
457 auto extra_result_types =
458 terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
459 new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
460 extra_result_types.end());
461 OpBuilder builder(op_);
462 // Clone this old operation but with new result types.
463 Operation* new_op = Operation::create(
464 op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
465 op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
466 builder.insert(new_op);
467
468 // Move regions to the new op.
469 for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
470 Region& old_region = std::get<0>(it);
471 Region& new_region = std::get<1>(it);
472 new_region.takeBody(old_region);
473 }
474
475 // Insert stores to all written resources.
476 for (Value resource : written_resources_) {
477 ResourceInfo& info = resources_[resource];
478 Value value_to_write = new_op->getResult(info.result_index);
479 Operation* write = builder.create<TF::AssignVariableOp>(
480 op_->getLoc(), resource, value_to_write);
481 write->setAttrs(info.write_attrs);
482 write->removeAttr(kDeviceAttr);
483 }
484
485 // As a part of lifting, we either reuse an existing slot for resource type
486 // results or add a new slot. Resource type results should not have any uses
487 // to begin with. So we can safely replace each old op result with the
488 // corresponding new op result.
489 int old_num_results = op_->getNumResults();
490 op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
491 op_->erase();
492 op_ = nullptr;
493 }
494
// Lifts resource loads and stores out of regions attached to `op`, where op is
496 // an If/case/cluster op.
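//
// Rough sketch (illustrative IR; %var and %cond stand for values defined
// elsewhere): an access like
//
//   "tf.IfRegion"(%cond) ({
//     %r = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//     %n = "tf.Neg"(%r) : (tensor<f32>) -> tensor<f32>
//     "tf.AssignVariableOp"(%var, %n) : (!resource, tensor<f32>) -> ()
//     "tf.Yield"() : () -> ()
//   }, {
//     "tf.Yield"() : () -> ()
//   }) : (tensor<i1>) -> ()
//
// becomes a single read hoisted before the IfRegion, an extra IfRegion result
// that each branch populates (the negated value in the then branch, the
// hoisted read in the else branch), and a single write of that result after
// the IfRegion.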
LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
    Operation* op) {
499 RegionResourceHoister hoister(op);
500 if (failed(hoister.Analyze())) return failure();
501
502 // If there are no resource region captures, then nothing to do.
503 if (!hoister.NeedsLifting()) return success();
504
505 // Start the transformation. For each region, replace the resource read with
506 // the value read before the op.
507 hoister.GenerateHoistedReads();
508 hoister.ReplaceResourceLoads(/*read_only=*/false);
509 hoister.AppendResourceStoreValueToReturn(op->getRegions());
510 hoister.ReplaceOpWithNewOp();
511 return success();
512 }
513
514 // Lift resource loads and stores out of WhileRegion
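//
// Rough sketch (illustrative IR): a resource written in the body of a
// tf.WhileRegion, e.g.
//
//   "tf.WhileRegion"(%x) ({ /* cond region */ ... }, {
//   ^bb0(%arg: tensor<f32>):
//     %r = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//     %s = "tf.AddV2"(%r, %arg) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     "tf.AssignVariableOp"(%var, %s) : (!resource, tensor<f32>) -> ()
//     "tf.Yield"(%arg) : (tensor<f32>) -> ()
//   }) : (tensor<f32>) -> tensor<f32>
//
// is turned into a loop variable: a read of %var is hoisted before the while,
// its value is threaded through a new operand/argument/result slot, reads in
// the body are replaced by the new block argument, the body yields the updated
// value, and a single write of the extra while result is emitted after the
// loop.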
LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
    TF::WhileRegionOp op) {
  // For WhileRegion, post canonicalization all resources used within the
  // body and condition regions are replaced with captured values, so we do not
  // need to take into account the body and condition region arguments.
520 RegionResourceHoister hoister(op);
521
522 if (failed(hoister.Analyze())) return failure();
523
524 // If there are no resource region captures, then nothing to do.
525 if (!hoister.NeedsLifting()) return success();
526
  // The resources captured for While loop fall into two categories:
  // (a) read-only: these reads can be replaced by a hoisted read created
  //     before the WhileOp (similar to If and Case).
  // (b) written: since the value can only be written in the loop body, all
  //     these will become loop variables. Since all resource variables were
  //     removed from the loop variables during canonicalization, we need to
  //     create new operand/result slots. The input operands for these slots
  //     are the read values prior to the op, and all references to these are
  //     replaced by the corresponding slot argument. We need to generate
  //     writes following the while for these resources.
  //
  // Note that for WhileRegion ops, if a resource is written, it will be written
  // only in the body and not the condition, so the hoister analysis will infer
  // it as needing a read as well.
542
543 // Generate hoisted reads before the while.
544 hoister.GenerateHoistedReads();
545
546 // Replace just the read-only resources with the hoisted reads.
547 hoister.ReplaceResourceLoads(/*read_only=*/true);
548
549 // For written resources, add additional operands to the while op.
550 int num_old_results = op.getNumResults();
551 int num_new_results = hoister.GetLiftedNumResults();
552 int num_extra_results = num_new_results - num_old_results;
553
554 SmallVector<Type, 4> new_result_types;
555 SmallVector<Value, 4> new_while_operands;
556 new_result_types.resize(num_extra_results);
557 new_while_operands.resize(num_extra_results);
558
559 for (auto& it : hoister.GetResources()) {
560 if (!it.second.is_written) continue;
561 int index = it.second.result_index - num_old_results;
562 new_result_types[index] = it.second.data_type;
563 new_while_operands[index] = it.second.hoisted_read;
564 }
565 op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
566
567 // Patch the cond and body regions to have additional arguments, and replace
568 // the remaining resource reads (which will be resource reads for written
569 // resources) with these arguments.
570 for (Region* region : op.getRegions()) {
571 region->addArguments(new_result_types);
572 // Point hoisted read for written resources to the region's arguments.
573 for (auto& it : hoister.GetResources()) {
574 if (!it.second.is_written) continue;
575 it.second.hoisted_read = region->getArgument(it.second.result_index);
576 }
577 hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
578 }
579
580 // Add additional return values to body return. These correspond to values
581 // written to resources in the body region.
582 hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
583
584 // Finally, create a new while with additional return values.
585 hoister.ReplaceOpWithNewOp();
586 return success();
587 }
588
// Lifts resources out of the regions attached to `op`.
LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
591 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
592 return HoistResourcesOutOfWhileRegion(while_op);
593 return HoistResourcesOutOfIfCaseCluster(op);
594 }
595
596 // Holds information about a function's use of a resource argument.
597 struct ResourceArgUseInfo {
598 // Data type of the data contained in the resource.
599 Type data_type;
600 // Is the resource argument used in an assign op?
601 bool updated;
602 // Is the resource argument used in a read or assign op?
603 bool used;
604 };
605
606 // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
607 // output (i.e., the argument is an operand of the return op) is not considered
608 // as a use. This doesn't support nesting of ops, so before calling this, nested
609 // ops/functions need to be already resource-lifted.
LogicalResult FindResourceArgUseInfo(
    FuncOp func_op, llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
612 auto return_op = func_op.front().getTerminator();
613 for (auto arg : TF::filter_resources(func_op.getArguments())) {
614 ResourceArgUseInfo info;
615 info.used = false;
616 info.updated = false;
617 bool read_or_assigned = false;
618 bool used_in_unsupported_op = false;
619 for (auto user : arg.getUsers()) {
620 if (user == return_op) continue;
621 info.used = true;
622 if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
623 read_or_assigned = true;
624 info.data_type = read.getType();
625 continue;
626 }
627
628 if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
629 read_or_assigned = true;
630 info.updated = true;
631 info.data_type = assign.value().getType();
632 continue;
633 }
634
635 used_in_unsupported_op = true;
636 break;
637 }
638
639 // If the arg is used in an unsupported op, skip lifting it.
640 if (used_in_unsupported_op) continue;
641 (*result)[arg.getArgNumber()] = info;
642 }
643 return success();
644 }
645
646 // Merges two sets of resource arg use infos. An argument is considered used in
647 // the merged result as long as either set marks it as used. This is used to
648 // merge results from functions that have aliasing inputs, e.g., a while loop's
649 // body and condition. The sets of keys of the two maps must be the same.
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
653 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
654 for (const auto& entry : infos0) {
655 auto info1_it = infos1.find(entry.getFirst());
656 // If the entry is missing in any input, we should not touch this entry.
657 if (info1_it == infos1.end()) continue;
658 auto& info = result[entry.getFirst()];
659 info = entry.getSecond();
660 if (info.updated) continue;
661 if (info1_it->getSecond().used) {
662 info.used = true;
663 info.updated = info1_it->getSecond().updated;
664 info.data_type = info1_it->getSecond().data_type;
665 }
666 }
667 return result;
668 }
669
// Removes the unused resource arguments and the return values that forward the
671 // removed arguments. If old_to_new_arg_indices is provided, it will store the
672 // new argument index that corresponds to each original index (-1 means it is
673 // removed). If remaining_resource_data_types is provided, it will store the
674 // data types of the remaining resource arguments, where the indices are after
675 // removing unused ones.
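//
// For example (schematic), given a function
//
//   func @f(%arg0: !resource, %arg1: !resource) -> !resource {
//     %r = "tf.ReadVariableOp"(%arg1) : (!resource) -> tensor<f32>
//     return %arg0 : !resource
//   }
//
// where %arg0 is only forwarded to the return, %arg0 and the forwarded return
// value are removed, old_to_new_arg_indices would be {-1, 0}, and
// remaining_resource_data_types would map the new index 0 to tensor<f32>.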
void RemoveUnusedResourceArgumentsAndForwardedRetvals(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
    FuncOp func_op,
    llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
    llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
        nullptr) {
682 // Remove return values forwarded from unused arguments.
683 auto return_op = func_op.front().getTerminator();
684 auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
685 int64_t skipped_retvals = 0;
686 for (auto entry : llvm::enumerate(old_return_vals)) {
687 auto return_val = entry.value();
688 if (auto arg = return_val.dyn_cast<BlockArgument>()) {
689 auto it = infos.find(arg.getArgNumber());
690 if (it != infos.end() && !it->getSecond().used) {
691 return_op->eraseOperand(entry.index() - skipped_retvals++);
692 }
693 }
694 }
695 llvm::SmallVector<unsigned int, 4> indices_to_erase;
696 llvm::SmallVector<Type, 4> new_types;
697 int64_t skipped_args = 0;
698 for (auto arg : func_op.getArguments()) {
699 auto it = infos.find(arg.getArgNumber());
700 if (it != infos.end() && !it->getSecond().used) {
701 indices_to_erase.push_back(arg.getArgNumber());
702 skipped_args++;
703 if (old_to_new_arg_indices != nullptr) {
704 old_to_new_arg_indices->push_back(-1);
705 }
706 } else {
707 new_types.push_back(arg.getType());
708 if (old_to_new_arg_indices != nullptr) {
709 old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
710 }
711 if (it != infos.end() && remaining_resource_data_types != nullptr) {
712 (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
713 it->second.data_type;
714 }
715 }
716 }
717 func_op.eraseArguments(indices_to_erase);
718 func_op.setType(
719 FunctionType::get(func_op.getContext(), new_types,
720 llvm::to_vector<4>(return_op->getOperandTypes())));
721 }
722
723 // Lifts reads/writes of resource arguments from func_op and changes its
724 // signature. resource_data_types is the (index, data type) pair for each
725 // resource argument. handle_updated_arg_value is a caller-provided function
// that handles the updated value for a resource argument.
LogicalResult LiftArgRetResourcesForFunction(
    FuncOp func_op,
    const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
    llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
731 RegionResourceHoister hoister(func_op);
732 if (failed(hoister.Analyze())) return failure();
733
  // Each of these resources could be read or written in the function. If it's
  // read, we need to replace the resource arg with a value arg to get the
  // read value. If it's written, we need to replace the write with an
  // additional value to be written.
738
739 // Now create read values that will be used to replace each resource that
740 // is read in the function body. These read values are just the same argument
741 // with type replaced.
742 llvm::SmallVector<Value, 4> skipped_args;
743 for (auto& it : hoister.GetResources()) {
744 BlockArgument arg = it.first.dyn_cast<BlockArgument>();
745 assert(arg && "Expect resources for FuncOp to be its arguments");
746 auto type_iter = resource_data_types.find(arg.getArgNumber());
747 if (type_iter == resource_data_types.end()) {
748 // Skip lifting the resource if it's not present in the data type map.
749 // This indicates that the resource is not to be lifted because it is used
750 // in an unsupported op in some other function.
751 skipped_args.push_back(arg);
752 } else {
753 arg.setType(type_iter->second);
754 it.second.hoisted_read = arg;
755 }
756 }
757
758 // Drop all the args that have to be skipped.
759 for (Value arg : skipped_args) hoister.DropResource(arg);
760
761 hoister.ReplaceResourceLoads(/*read_only=*/false);
762
763 // For writes, invoke the callback and then erase the write.
764 auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
765 for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
766 Value resource = assign_variable_op.resource();
767 if (!hoister.Contains(resource)) continue;
768
769 auto arg = resource.dyn_cast<BlockArgument>();
770 handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
771 assign_variable_op.erase();
772 }
773
774 func_op.setType(FunctionType::get(
775 func_op.getContext(), func_op.front().getArgumentTypes(),
776 func_op.front().getTerminator()->getOperandTypes()));
777
778 return success();
779 }
780
781 // Returns a vector filtered from range where the unused elements (specified by
782 // resource_arg_uses) are removed.
783 template <typename T, typename Range>
llvm::SmallVector<T, 4> FilterRange(
    Range range,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
787 llvm::SmallVector<T, 4> filtered;
788 for (auto entry : llvm::enumerate(range)) {
789 auto it = resource_arg_uses.find(entry.index());
790 if (it == resource_arg_uses.end() || it->getSecond().used)
791 filtered.push_back(entry.value());
792 }
793 return filtered;
794 }
795
796 // Changes the types of the control flow op (e.g., while, if) and adds loads and
797 // stores around it. arg_data_type_and_updated_output_index maps an operand (to
798 // be changed) index to its data type and the updated value index in the output
799 // (-1 means not updated.)
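//
// Schematically (illustrative IR), if operand #1 of the rewritten op used to
// be a resource whose tensor<f32> value is now threaded through the op and
// updated at result #1, the op ends up wrapped as:
//
//   %v = "tf.ReadVariableOp"(%var) : (!resource) -> tensor<f32>
//   %w:2 = "tf.While"(%x, %v) {...} : (tensor<i32>, tensor<f32>) -> ...
//   "tf.AssignVariableOp"(%var, %w#1) : (!resource, tensor<f32>) -> ()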
void AddLoadsStoresOutsideControlFlowOp(
    Operation* caller,
    const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
        arg_data_type_and_updated_output_index) {
804 OpBuilder builder(caller);
805 auto new_operands = llvm::to_vector<8>(caller->getOperands());
806 llvm::SmallVector<int64_t, 8> changed_indices;
807 // Find the operands to change, and create the loads.
808 for (auto& entry : arg_data_type_and_updated_output_index) {
809 int64_t index = entry.getFirst();
810 Type new_type = entry.getSecond().first;
811 int64_t updated_index = entry.getSecond().second;
812 auto operand = caller->getOperand(index);
813 builder.setInsertionPoint(caller);
814 new_operands[index] = builder.create<TF::ReadVariableOp>(
815 caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
816 caller->setOperand(index, new_operands[index]);
817 if (updated_index < 0) continue;
818 builder.setInsertionPointAfter(caller);
819 builder.create<TF::AssignVariableOp>(
820 caller->getLoc(), ArrayRef<Type>{},
821 ArrayRef<Value>{operand, caller->getResult(updated_index)});
822 }
823 }
824
825 // Lifts loads/stores from while loop's body and cond functions.
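//
// Rough sketch (schematic signatures): for
//
//   body: func @body(%arg0: !resource) -> !resource
//   cond: func @cond(%arg0: !resource) -> tensor<i1>
//
// the body and cond are rewritten to take and return the underlying value type
// (e.g. tensor<f32>) instead, the corresponding tf.While operand becomes a
// read of the variable hoisted before the loop, and, if the body updates the
// value, a write of the matching tf.While result is emitted after the loop.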
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
827 auto return_op = body.front().getTerminator();
828 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
829 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
830 if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
831 failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
832 return failure();
833 }
834 // A resource is considered used as long as it is used in either body or cond.
835 auto resource_arg_uses =
836 MergeArgResourceUseInfo(body_use_info, cond_use_info);
837 if (resource_arg_uses.empty()) return success();
838
839 // Remove unused resources in functions.
840 llvm::SmallVector<int64_t, 4> old_to_new_indices;
841 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
842 RemoveUnusedResourceArgumentsAndForwardedRetvals(
843 resource_arg_uses, body, &old_to_new_indices,
844 &remaining_resource_data_types);
845 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
846 (void)LiftArgRetResourcesForFunction(
847 body, remaining_resource_data_types,
848 [&](int64_t index, Value value) { return_op->setOperand(index, value); });
849 (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
850 [&](int64_t index, Value value) {
851 // We already checked that cond should
852 // not have variable writes.
853 assert(false && "Should not happen");
854 });
855 // Recreate the while op.
856 OpBuilder builder(while_op);
857 // Now use the filtered original operands, which will be replaced by
858 // AddLoadsStoresOutsideControlFlowOp().
859 auto new_while = builder.create<TF::WhileOp>(
860 while_op.getLoc(), body.getType().getResults(),
861 FilterRange<Value, OperandRange>(while_op.getOperands(),
862 resource_arg_uses),
863 while_op->getAttrs());
864 // Prepare for AddLoadsStoresOutsideControlFlowOp().
865 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
866 arg_data_type_and_updated_output_index;
867 for (const auto& entry : remaining_resource_data_types) {
868 int64_t update_index = return_op->getOperand(entry.getFirst()) ==
869 body.getArgument(entry.getFirst())
870 ? -1
871 : entry.getFirst();
872 arg_data_type_and_updated_output_index[entry.getFirst()] = {
873 entry.getSecond(), update_index};
874 }
875 AddLoadsStoresOutsideControlFlowOp(new_while,
876 arg_data_type_and_updated_output_index);
877 // Replace uses.
878 for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
879 if (old_to_new_indices[i] >= 0) {
880 while_op.getResult(i).replaceAllUsesWith(
881 new_while.getResult(old_to_new_indices[i]));
882 }
883 }
884 while_op.erase();
885 return success();
886 }
887
888 // Lifts loads/stores from an IfOp or CaseOp's branches.
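//
// Rough sketch (schematic signatures): for branches
//
//   then: func @then(%arg0: !resource) -> tensor<f32>
//   else: func @else(%arg0: !resource) -> tensor<f32>
//
// where at least one branch writes the resource, every branch gets an extra
// result carrying the resource's final value, the resource operand of the
// tf.If is replaced by a hoisted read, and a write of the new tf.If result is
// emitted after it. Resource arguments unused in all branches are dropped.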
889 template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
891 // For canonicalized If/Case, there should not be any resource outputs
892 int64_t non_resource_results = op.getNumResults();
893
894 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
895 if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
896 return failure();
897
898 for (auto func : branches.drop_front()) {
899 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
900 if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
901 return failure();
902 // A resource is considered used as long as it is used in either branch.
903 resource_arg_uses =
904 MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
905 }
906
907 if (resource_arg_uses.empty()) return success();
908 // Remove unused resources in functions.
909 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
910 RemoveUnusedResourceArgumentsAndForwardedRetvals(
911 resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
912 &remaining_resource_data_types);
913 for (auto func : branches.drop_front())
914 RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
915
  // Forward resource inputs updated in any branch to the outputs of all
  // branches. First prepare the mapping from arg to new update output.
918 llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
919 {
920 int64_t removed_args = 0;
921 for (const auto& entry : resource_arg_uses) {
922 if (!entry.getSecond().used) {
923 removed_args++;
924 continue;
925 }
926 if (!entry.getSecond().updated) continue;
927 int64_t new_output_index =
928 non_resource_results + resource_arg_to_new_output.size();
929 resource_arg_to_new_output[entry.getFirst() - removed_args] =
930 new_output_index;
931 }
932 }
933
934 // Append resource updates to the return ops: now they are just forwarded
935 // input resources, but will be replaced by the data value in
936 // LiftArgRetResourcesForFunction().
937 for (auto branch : branches) {
938 auto new_retvals =
939 llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
940 new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
941 for (const auto& entry : resource_arg_to_new_output) {
942 int64_t resource_arg_index = entry.getFirst();
943 int64_t output_index = entry.getSecond();
944 new_retvals[output_index] = branch.getArgument(resource_arg_index);
945 }
946 auto old_return = branch.front().getTerminator();
947 OpBuilder builder(old_return);
948 auto new_return =
949 builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
950 old_return->erase();
951 (void)LiftArgRetResourcesForFunction(
952 branch, remaining_resource_data_types, [&](int64_t index, Value value) {
953 new_return.setOperand(resource_arg_to_new_output[index], value);
954 });
955 }
956
957 // Recreate the op without resource operands.
958 OpBuilder builder(op);
959 // Now use the filtered original operands, which will be replaced by
960 // AddLoadsStoresOutsideControlFlowOp().
961 auto new_operands =
962 FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
963 new_operands.insert(new_operands.begin(), op.getOperand(0));
964 FuncOp first_func = branches.front();
965 auto new_op =
966 builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
967 new_operands, op->getAttrs());
968 // Prepare for AddLoadsStoresOutsideControlFlowOp()
969 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
970 arg_data_type_and_updated_output_index;
971 for (const auto& entry : remaining_resource_data_types) {
972 auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
973 int64_t update_index = new_output_it == resource_arg_to_new_output.end()
974 ? -1
975 : new_output_it->getSecond();
976 arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
977 entry.getSecond(), update_index};
978 }
979 AddLoadsStoresOutsideControlFlowOp(new_op,
980 arg_data_type_and_updated_output_index);
981 // Replace uses.
982 op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
983 op.erase();
984 return success();
985 }
986
987 // A resource-lifted function for (potentially multiple) PartitionedCallOps and
988 // information about the lifting changes.
989 struct PartitionedCallLiftingInfo {
990 // Function with resources lifted. Can be nullptr if nothing needs to change.
991 FuncOp lifted_callee;
  // Mapping from old resource outputs to their aliasing inputs.
993 llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
994 // Mapping from old to new output indices in case any output is removed.
995 llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
996 // ResourceArgUseInfo for each old resource argument.
997 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
998 // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
999 llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
1000 arg_data_type_and_updated_output_index;
1001 };
1002
1003 // Lifts loads/stores from a PartitionedCallOp's callee function. If anything
1004 // needs to be changed, the original function will be preserved, and the lifting
1005 // happens on a clone, which will be stored in `result`.
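//
// Rough sketch (schematic signatures): a callee
//
//   func @callee(%arg0: !resource) -> tensor<f32>
//
// that reads and writes %arg0 is cloned into a private function named roughly
// @callee_resource_lifted with signature (tensor<f32>) -> (tensor<f32>,
// tensor<f32>), where the argument is the value read before the call and the
// extra result is the value the caller must write back after the call.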
LogicalResult HandlePartitionedCallOpCallee(
    FuncOp callee, PartitionedCallLiftingInfo* result) {
1008 // Sanity check: return of resources should be aliases of inputs. Such outputs
1009 // will be removed later.
1010 int64_t non_resource_results = 0;
1011 for (auto entry :
1012 llvm::enumerate(callee.front().getTerminator()->getOperands())) {
1013 auto retval = entry.value();
1014 if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
1015 result->old_to_new_output_indices.push_back(non_resource_results++);
1016 continue;
1017 }
1018 auto aliasing_arg = retval.dyn_cast<BlockArgument>();
1019 if (!aliasing_arg) {
1020 return callee.emitOpError("unsupported function call: ")
1021 << "resource return value does not alias an input.";
1022 }
1023 result->old_outputs_aliasing_old_inputs[entry.index()] =
1024 aliasing_arg.getArgNumber();
1025 result->old_to_new_output_indices.push_back(-1);
1026 }
1027
1028 if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
1029 return failure();
1030 }
1031 if (result->use_info.empty()) {
1032 result->lifted_callee = nullptr;
1033 return success();
1034 }
1035
1036 // Clone the callee before making changes.
1037 SmallString<64> name_base = callee.getName();
1038 auto module = callee->getParentOfType<ModuleOp>();
1039 name_base += "_resource_lifted";
1040 auto name = name_base;
1041 callee = callee.clone();
1042 callee.setPrivate();
1043 callee.setName(name);
1044 SymbolTable(module).insert(callee);
1045 result->lifted_callee = callee;
1046
1047 // Remove unused resources in functions.
1048 llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
1049 RemoveUnusedResourceArgumentsAndForwardedRetvals(
1050 result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
1051 &remaining_resource_data_types);
1052 for (const auto& entry : remaining_resource_data_types) {
1053 result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
1054 entry.getSecond(), -1};
1055 }
1056 llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
1057 for (auto& val : callee.front().getTerminator()->getOpOperands()) {
1058 // Store indices of results that are not resources.
1059 if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
1060 retval_indices_to_preserve.push_back(val.getOperandNumber());
1061 }
1062 int64_t num_retvals = retval_indices_to_preserve.size();
1063 llvm::SmallVector<Value, 4> new_retvals;
1064 // Lift resources.
1065 (void)LiftArgRetResourcesForFunction(
1066 callee, remaining_resource_data_types, [&](int64_t index, Value value) {
1067 result->arg_data_type_and_updated_output_index[index].second =
1068 num_retvals++;
1069 new_retvals.push_back(value);
1070 });
1071
1072 auto old_return = callee.front().getTerminator();
1073 llvm::SmallVector<Value, 4> old_and_new_retvals;
1074 old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
1075 new_retvals.size());
1076 for (int64_t retval_index : retval_indices_to_preserve)
1077 old_and_new_retvals.push_back(old_return->getOperand(retval_index));
1078
1079 old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
  // Replace the old return with a new one that carries the updated values.
1081 OpBuilder builder(old_return);
1082 auto new_return =
1083 builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
1084 old_return->erase();
1085 callee.setType(
1086 FunctionType::get(callee.getContext(), callee.getType().getInputs(),
1087 llvm::to_vector<4>(new_return.getOperandTypes())));
1088 return success();
1089 }
1090
1091 // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
1092 // resource-lifted new callee function in lifting_info.
1093 template <typename CallOpType>
void UpdatePartitionedCallOpWithNewCallee(
    CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
1096 if (!lifting_info.lifted_callee) return;
1097 // Replace output resource uses with the aliasing input, so that we can remove
1098 // this output.
1099 for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
1100 call_op.getResult(entry.getFirst())
1101 .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
1102 }
1103 // Recreate the call op.
1104 OpBuilder builder(call_op);
1105 // Now use the filtered original operands, which will be replaced by
1106 // AddLoadsStoresOutsideControlFlowOp().
1107 auto new_operands =
1108 FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
1109 auto new_call = builder.create<CallOpType>(
1110 call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
1111 new_operands, call_op->getAttrs());
1112 new_call->setAttr(
1113 "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
1114 AddLoadsStoresOutsideControlFlowOp(
1115 new_call, lifting_info.arg_data_type_and_updated_output_index);
1116 // Replace uses.
1117 for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
1118 i < end; ++i) {
1119 if (lifting_info.old_to_new_output_indices[i] >= 0) {
1120 call_op.getResult(i).replaceAllUsesWith(
1121 new_call.getResult(lifting_info.old_to_new_output_indices[i]));
1122 }
1123 }
1124 call_op.erase();
1125 }
1126
1127 LogicalResult HoistForControlFlow(
1128 Block*, ModuleOp, bool,
1129 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
1130
1131 // A templated routine for handling both PartitionedCallOp and
1132 // StatefulPartitionedCallOp. If the callee is already lifted, it just updates
1133 // the caller op itself; otherwise, it first recursively handles nested control
1134 // flow, then performs lifting on the callee.
1135 template <typename CallOpType>
LogicalResult HandlePartitionedCallOp(
    CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
        lifted_callees) {
1140 auto emplace_res = lifted_callees->try_emplace(callee.getName(),
1141 PartitionedCallLiftingInfo());
1142 if (emplace_res.second) {
1143 // Unseen callee. Perform resource lifting on it.
1144 if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
1145 lifted_callees)))
1146 return failure();
1147
1148 if (failed(HandlePartitionedCallOpCallee(
1149 callee, &emplace_res.first->getSecond()))) {
1150 return failure();
1151 }
1152 }
1153 UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
1154 return success();
1155 }
1156
1157 // Hoists resource loads/stores from control flow ops in `block` outside the
1158 // body/cond/branch/callee functions.
LogicalResult HoistForControlFlow(
    Block* block, ModuleOp module, bool vars_initialized,
    llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
        lifted_partitioned_call_callees) {
1163 if (vars_initialized) SetAllVarIsInitializedToTrue(block);
1164
1165 for (Operation& op : llvm::make_early_inc_range(*block)) {
1166 if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
1167 auto body = while_op.body_function();
1168 auto cond = while_op.cond_function();
1169 // Recursively handle the nested control flow.
1170 (void)HoistForControlFlow(&body.front(), module, vars_initialized,
1171 lifted_partitioned_call_callees);
1172 (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
1173 lifted_partitioned_call_callees);
1174 if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
1175 } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
1176 auto then_branch = if_op.then_function();
1177 auto else_branch = if_op.else_function();
1178 // Recursively handle the nested control flow.
1179 (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
1180 lifted_partitioned_call_callees);
1181 (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
1182 lifted_partitioned_call_callees);
1183 if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
1184 return failure();
1185 } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
1186 SmallVector<FuncOp, 4> branch_functions;
1187 case_op.get_branch_functions(branch_functions);
1188 for (FuncOp func : branch_functions) {
1189 // Recursively handle the nested control flow.
1190 (void)HoistForControlFlow(&func.front(), module, vars_initialized,
1191 lifted_partitioned_call_callees);
1192 }
1193 if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
1194 } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
1195 auto callee = call_op.func();
1196 if (!callee) {
1197 return call_op.emitOpError(
1198 "resource lifting does not support call with nested references.");
1199 }
1200 if (failed(HandlePartitionedCallOp(call_op, callee, module,
1201 vars_initialized,
1202 lifted_partitioned_call_callees))) {
1203 // Nested control flow handling is done in HandlePartitionedCallOp().
1204 return failure();
1205 }
1206 } else if (auto call_op =
1207 llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
1208 if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
1209 vars_initialized,
1210 lifted_partitioned_call_callees))) {
1211 return failure();
1212 }
1213 } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
1214 for (Region& region : op.getRegions())
        (void)HoistForControlFlow(&region.front(), module, vars_initialized,
1216 lifted_partitioned_call_callees);
1217 LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
1218 if (failed(result)) return failure();
1219 }
1220 }
1221
1222 // After we have hoisted operations in the block, we may have added new read
1223 // and writes of resources to this block. Clean them up by doing store-load
1224 // forwarding.
1225 ForwardStoreToLoad(block);
1226 return success();
1227 }
1228
// Lifts resource operations from tf_device.cluster ops nested in `op`.
// Returns failure if there are remaining resource-type values that cannot be
1231 // lifted.
void ResourceOpLiftingPass::runOnOperation() {
1233 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1234 lifted_partitioned_call_callees;
1235 ModuleOp module = getOperation();
1236
1237 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
1238 return signalPassFailure();
1239
1240 auto walk_result = module.walk([&](FuncOp func_op) {
1241 return func_op.walk([&](tf_device::ClusterOp cluster) {
1242 LogicalResult result = HoistForControlFlow(
1243 &cluster.GetBody(), module, /*vars_initialized=*/true,
1244 &lifted_partitioned_call_callees);
1245 if (failed(result)) return WalkResult::interrupt();
1246 result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
1247 if (failed(result)) return WalkResult::interrupt();
1248 return WalkResult::advance();
1249 });
1250 });
1251
1252 if (walk_result.wasInterrupted()) return signalPassFailure();
1253 }
1254
1255 struct ResourceOpLiftingForMainFunctionPass
1256 : public TFDevice::ResourceOpLiftingForMainFunctionPassBase<
1257 ResourceOpLiftingForMainFunctionPass> {
1258 void runOnOperation() override;
1259 };
1260
void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
1262 ModuleOp module = getOperation();
1263 FuncOp main_func = module.lookupSymbol<FuncOp>("main");
1264 if (!main_func) {
1265 return;
1266 }
1267
1268 if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
1269 return signalPassFailure();
1270 }
1271 }
1272
1273 } // namespace
1274
1275 namespace TFDevice {
std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
1277 return std::make_unique<ResourceOpLiftingPass>();
1278 }
1279
1280 std::unique_ptr<OperationPass<ModuleOp>>
CreateResourceOpLiftingForMainFunctionPass() {
1282 return std::make_unique<ResourceOpLiftingForMainFunctionPass>();
1283 }
1284
1285 } // namespace TFDevice
1286
1287 namespace TF {
LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
1289 // This routine should only be called when control flow operations are still
1290 // represented with TF IfOp and WhileOp operations. In this case, there should
  // be only one basic block in the MLIR representation.
1292 if (!llvm::hasSingleElement(function)) {
1293 return function.emitError()
1294 << "expect the function to have 1 block while it has "
1295 << function.getBlocks().size();
1296 }
1297
1298 if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
1299 return failure();
1300
1301 llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1302 lifted_partitioned_call_callees;
1303 if (failed(HoistForControlFlow(
1304 &function.front(), cast<ModuleOp>(function->getParentOp()),
1305 /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
1306 return failure();
1307
1308 // Clean up and canonicalize to remove dead local variables as some local
1309 // variables might be dead after hoisting resource loads/stores from control
1310 // flow ops.
1311 return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
1312 }
1313 } // namespace TF
1314
1315 } // namespace mlir
1316