• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass lifts resource variable operations outside of device computation.
17 
18 #include <cstddef>
19 #include <cstdint>
20 
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Block.h"  // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
34 #include "mlir/IR/Builders.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
36 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
37 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/Region.h"  // from @llvm-project
40 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
41 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
42 #include "mlir/IR/Types.h"  // from @llvm-project
43 #include "mlir/IR/Value.h"  // from @llvm-project
44 #include "mlir/IR/Verifier.h"  // from @llvm-project
45 #include "mlir/IR/Visitors.h"  // from @llvm-project
46 #include "mlir/Pass/Pass.h"  // from @llvm-project
47 #include "mlir/Support/LLVM.h"  // from @llvm-project
48 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
49 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
51 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
52 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
54 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h"
56 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
57 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
58 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
59 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
60 #include "tensorflow/core/framework/tensor_shape.pb.h"
61 
62 namespace mlir {
63 
64 namespace {
65 
66 constexpr char kDeviceAttr[] = "device";
67 
68 // Lift resource operations out of device computation.
69 struct ResourceOpLiftingPass
70     : public TFDevice::ResourceOpLiftingPassBase<ResourceOpLiftingPass> {
71   void runOnOperation() override;
72 };
73 
IsResource(Value value)74 bool IsResource(Value value) {
75   return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
76 }
77 
78 // Get the type of the data contained in a resource. Returns null if there is
79 // no single type in the resource.
GetResourceSubtype(Value value)80 Type GetResourceSubtype(Value value) {
81   auto resource_type =
82       getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
83   auto subtypes = resource_type.getSubtypes();
84   if (subtypes.size() == 1) return subtypes[0];
85   return nullptr;
86 }
87 
88 // Replaces all `tf.VarIsInitializedOp` in a block with a constant true.
89 // TODO(b/171039585): Replace this with proper analysis of
90 // `tf.VarIsInitializedOp` in regards to resource writes and control flow.
SetAllVarIsInitializedToTrue(Block * block)91 void SetAllVarIsInitializedToTrue(Block* block) {
92   auto builder = OpBuilder::atBlockBegin(block);
93   TF::ConstOp const_true = nullptr;
94   for (auto op :
95        llvm::make_early_inc_range(block->getOps<TF::VarIsInitializedOp>())) {
96     builder.setInsertionPoint(op);
97     if (!const_true)
98       const_true = builder.create<TF::ConstOp>(
99           op.getLoc(),
100           DenseIntElementsAttr::get(
101               RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
102 
103     op.is_initialized().replaceAllUsesWith(const_true);
104     op.erase();
105   }
106 }
107 
108 // Performs store-load forwarding. This effectively removes
109 // 1) Any resource loads after a store to that same resource is done
110 // 2) Any resource stores except the last one.
111 // TODO(ycao): Store-load forwarding implemented here is only correct when
112 // computation is purely sequential (no concurrency). Need to support concurrent
113 // computation as well.
ForwardStoreToLoad(Block * block)114 void ForwardStoreToLoad(Block* block) {
115   // resource_handle_to_last_store_op keeps track of the most recent (last)
116   // store to each resource. Non-existent entry indicates that a resource has
117   // not been stored to yet.
118   llvm::SmallDenseMap<Value, TF::AssignVariableOp>
119       resource_handle_to_last_store_op;
120 
121   // Only iterate through ops directly in the block as we can't handle ops
122   // nested deeper in regions.
123   for (Operation& op : llvm::make_early_inc_range(*block)) {
124     if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
125       Value resource = read_variable_op.resource();
126       auto last_store = resource_handle_to_last_store_op[resource];
127       if (!last_store) continue;
128 
129       // Use stored value in last_store to replace all uses of current resource
130       // load's result, then erase this resource load. Add an intermediate
131       // CastOp if the shape of types doesn't exactly match.
132       Type read_type = read_variable_op.value().getType();
133       if (read_type != last_store.value().getType()) {
134         OpBuilder builder(last_store);
135         builder.setInsertionPointAfter(last_store);
136         auto cast = builder.create<TF::CastOp>(
137             last_store.getLoc(), read_type, last_store.value(),
138             /*Truncate=*/builder.getBoolAttr(false));
139         read_variable_op.value().replaceAllUsesWith(cast);
140       } else {
141         read_variable_op.value().replaceAllUsesWith(last_store.value());
142       }
143 
144       read_variable_op.erase();
145       continue;
146     }
147 
148     if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
149       Value resource = assign_variable_op.resource();
150       auto last_store = resource_handle_to_last_store_op[resource];
151       // Previous store ops to same resource can be erased.
152       if (last_store) last_store.erase();
153 
154       resource_handle_to_last_store_op[resource] = assign_variable_op;
155     }
156   }
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // RegionResourceHoister
161 //===----------------------------------------------------------------------===//
162 
163 // Helper class to hoist resource ops out of regions attached to an op.
164 class RegionResourceHoister {
165  public:
RegionResourceHoister(Operation * op)166   explicit RegionResourceHoister(Operation* op) : op_(op) {}
167 
168   // Analyzes attached regions to record resources read and written.
169   LogicalResult Analyze();
170 
171   // Returns all resources accessed by the regions attached the op.
GetResources()172   auto& GetResources() { return resources_; }
173 
174   // Returns if the given value is a resource that needs lifting.
Contains(Value resource) const175   bool Contains(Value resource) const {
176     return resources_.find(resource) != resources_.end();
177   }
178 
179   // Drops the given resource from lifting.
DropResource(Value resource)180   void DropResource(Value resource) {
181     resources_.erase(resource);
182     written_resources_.remove(resource);
183   }
184 
185   // Replaces all resource loads in all regions attached to the op.
ReplaceResourceLoads(bool read_only)186   void ReplaceResourceLoads(bool read_only) {
187     llvm::for_each(op_->getRegions(), [&](Region& region) {
188       ReplaceResourceLoads(region, read_only);
189     });
190   }
191 
192   static LogicalResult ReplaceOpWithNewOp(Operation* op);
193 
194  private:
195   // Returns if any resources need lifting.
NeedsLifting() const196   bool NeedsLifting() const { return !resources_.empty(); }
197 
198   // Returns the number of results generated by the lifted op.
GetLiftedNumResults() const199   int GetLiftedNumResults() const { return num_new_results_; }
200 
201   // Generates hoisted reads for resources that need them before the op.
202   void GenerateHoistedReads();
203 
204   // Replaces all resource loads in the given region with hoisted loads. If
205   // `read_only` is true, limit this replacement to read only resources.
206   void ReplaceResourceLoads(Region& region, bool read_only);
207 
208   // Appends final values writte to resources to the region returns for the
209   // given set of regions.
210   void AppendResourceStoreValueToReturn(RegionRange regions);
211 
212   // Performs the final replacement of the op.
213   void ReplaceOpWithNewOp();
214 
215   // Returns is this resource was written to in any of the regions.
IsWritten(Value resource) const216   bool IsWritten(Value resource) const {
217     return written_resources_.contains(resource);
218   }
219 
220   static LogicalResult HoistResourcesOutOfIfCaseCluster(Operation* op);
221   static LogicalResult HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op);
222 
223   Operation* op_;
224 
225   // Per resource information about accesses to that resource.
226   struct ResourceInfo {
227     // Is this resource read in any of the regions?
228     bool is_read;
229     // Is this resource written in any of the regions?
230     bool is_written;
231     // Is this resource written in all of the regions?
232     bool is_written_all;
233     // The hoisted read used to replace region reads.
234     Value hoisted_read;
235     // the type of the data held by the resource.
236     Type data_type;
237     // For written resources, the result # of the lifted op which will hold the
238     // value of the resource. This result will be used to generates writes to
239     // the resource after the lifted op.
240     int result_index;
241     // Attributes on the read operation.
242     DictionaryAttr read_attrs;
243     // Attributes on the write operation.
244     DictionaryAttr write_attrs;
245 
ResourceInfomlir::__anonfff803140111::RegionResourceHoister::ResourceInfo246     ResourceInfo()
247         : is_read(false),
248           is_written(false),
249           is_written_all(false),
250           hoisted_read(nullptr),
251           data_type(nullptr),
252           result_index(-1) {}
253 
IsResultIndexAssignedmlir::__anonfff803140111::RegionResourceHoister::ResourceInfo254     bool IsResultIndexAssigned() { return result_index != -1; }
255 
256     // Refine the resource type using the given type `type`.
RefineTypemlir::__anonfff803140111::RegionResourceHoister::ResourceInfo257     void RefineType(Type type) {
258       if (!data_type) {
259         data_type = type;
260       } else {
261         data_type = TF::GetCastCompatibleType(data_type, type,
262                                               /*may_ignore_ref_type_a=*/false);
263         assert(data_type != nullptr && "Resource used with incompatible types");
264       }
265     }
266   };
267   llvm::MapVector<Value, ResourceInfo> resources_;
268   llvm::SetVector<Value> written_resources_;
269   // number of new results after lifting.
270   int num_new_results_;
271 };
272 
273 // Analyzes resources that are read or written within attached regions.
Analyze()274 LogicalResult RegionResourceHoister::Analyze() {
275   // Hoisting of child regions might have created opportunity for store-load
276   // forwarding.
277   for (Region& region : op_->getRegions()) {
278     ForwardStoreToLoad(&region.front());
279   }
280 
281   llvm::SetVector<Value> all_resources;
282   bool is_func = false;
283   // For functions, the resources to analyze are the function arguments.
284   // Otherwise, its the region captures.
285   if (FuncOp func = dyn_cast<FuncOp>(op_)) {
286     is_func = true;
287     Region& body = func.getBody();
288     for (BlockArgument arg : body.getArguments()) {
289       if (IsResource(arg)) all_resources.insert(arg);
290     }
291   } else {
292     getUsedValuesDefinedAbove(op_->getRegions(), all_resources);
293     all_resources.remove_if([](Value value) { return !IsResource(value); });
294   }
295 
296   num_new_results_ = op_->getNumResults();
297 
298   for (auto resource : all_resources) {
299     ResourceInfo info;
300     info.data_type = GetResourceSubtype(resource);
301     llvm::BitVector written_regions(op_->getNumRegions());
302     bool unsupported_use = false;
303     for (OpOperand& use : resource.getUses()) {
304       Operation* user = use.getOwner();
305       // If the user is not in one of the regions, we are not interested in it.
306       // Since all the sub-regions within this region (i.e., regions attached to
307       // op's in this region) have themselves gone through lifting, all resource
308       // users are expected to be operations in this region and not embedded
309       // within other sub-regions attached to op's in this region. So the check
310       // for whether a user is in one of the regions attached to this op is
311       // straightforward.
312       if (user->getParentRegion()->getParentOp() != op_) continue;
313 
314       // For functions, if the resource is used as a return operand, use that
315       // as its result index.
316       if (is_func && isa<ReturnOp>(user)) {
317         assert(!info.IsResultIndexAssigned() &&
318                "Expect resource argument to returned no more than once");
319         info.result_index = use.getOperandNumber();
320         continue;
321       }
322 
323       auto read = dyn_cast<TF::ReadVariableOp>(user);
324       auto write = dyn_cast<TF::AssignVariableOp>(user);
325       if (!read && !write) {
326         unsupported_use = true;
327         break;
328       }
329 
330       if (read && !info.is_read) {
331         info.is_read = true;
332         info.RefineType(read.value().getType());
333         info.read_attrs = user->getAttrDictionary();
334       }
335 
336       if (write) {
337         info.is_written = true;
338         info.RefineType(write.value().getType());
339         info.write_attrs = user->getAttrDictionary();
340         written_regions.set(user->getParentRegion()->getRegionNumber());
341       }
342     }
343 
344     // If the resource is used in an op that we do not understand, skip
345     // lifting for that resource.
346     if (unsupported_use) continue;
347 
348     info.is_written_all = written_regions.count() == op_->getNumRegions();
349 
350     // If the resource is written in some but not all regions, we would need
351     // a read for the value before these regions. Note that this is applicable
352     // only to multi-region ops:
353     // If/Case: If not all regions write to the resource, post hoisting the read
354     //   value need to be routed through all paths that don't write.
355     // While: since while condition cannot write, any resource written in the
356     //   while body will need to be read as well in case the while body is never
357     //   executed.
358     // Both cases are handled by the condition below.
359     if (info.is_written && !info.is_written_all) info.is_read = true;
360 
361     // Allocate a result index for written resources that don't have one.
362     if (info.is_written) {
363       written_resources_.insert(resource);
364       if (!info.IsResultIndexAssigned()) info.result_index = num_new_results_++;
365     }
366 
367     resources_.insert({resource, info});
368   }
369   return success();
370 }
371 
372 // Generates hoisted reads for all resources that need them just before the op.
GenerateHoistedReads()373 void RegionResourceHoister::GenerateHoistedReads() {
374   OpBuilder builder(op_);
375   DictionaryAttr empty_attrs = builder.getDictionaryAttr({});
376   for (auto& resource_it : GetResources()) {
377     Value resource = resource_it.first;
378     auto& info = resource_it.second;
379 
380     if (info.is_read) {
381       Operation* read = builder.create<TF::ReadVariableOp>(
382           op_->getLoc(), info.data_type, resource);
383       read->setAttrs(info.read_attrs ? info.read_attrs : empty_attrs);
384       read->removeAttr(kDeviceAttr);
385       info.hoisted_read = read->getResult(0);
386     }
387   }
388 }
389 
390 // Replaces all resource reads with the hoisted read.
ReplaceResourceLoads(Region & region,bool read_only)391 void RegionResourceHoister::ReplaceResourceLoads(Region& region,
392                                                  bool read_only) {
393   assert(llvm::hasSingleElement(region) && "Expected single block region");
394   // Only iterate through ops directly in the body as we can't handle
395   // ops nested deeper in regions.
396   auto all_reads = region.front().getOps<TF::ReadVariableOp>();
397   for (auto read_op : llvm::make_early_inc_range(all_reads)) {
398     Value resource = read_op.resource();
399     if (!Contains(resource)) continue;
400 
401     ResourceInfo& info = resources_[resource];
402     // If replacing loads for read only resources, skip if the resource
403     // was written to.
404     if (read_only && info.is_written) continue;
405 
406     read_op.replaceAllUsesWith(info.hoisted_read);
407     read_op.erase();
408   }
409 }
410 
411 // For written resources, add its value at the end of each region to that
412 // regions return value. For a region, its value at the end may be a value
413 // written to that resource in that region, or its hoisted read value if the
414 // resource is not written in that region. The return value can be vended out
415 // either as an existing return value, or a newly allocated return value.
AppendResourceStoreValueToReturn(RegionRange regions)416 void RegionResourceHoister::AppendResourceStoreValueToReturn(
417     RegionRange regions) {
418   for (Region* region : regions) {
419     assert(llvm::hasSingleElement(*region) && "Expected single block region");
420     Block& front = region->front();
421     auto old_return = front.getTerminator();
422     assert(old_return->getNumOperands() == op_->getNumResults());
423     auto new_return_operands = llvm::to_vector<4>(old_return->getOperands());
424     new_return_operands.resize(num_new_results_);
425 
426     // initialize return values for written resources to be the hoisted reads.
427     for (Value resource : written_resources_) {
428       const ResourceInfo& info = resources_[resource];
429       new_return_operands[info.result_index] = info.hoisted_read;
430     }
431 
432     // Only iterate through ops directly in the body as op's embedded in child
433     // regions should have been lifted out.
434     auto assign_ops = front.getOps<TF::AssignVariableOp>();
435     for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
436       Value resource = assign_variable_op.resource();
437       if (!IsWritten(resource)) continue;
438 
439       // TODO(ycao): Prevent same value from being returned multiple times.
440       // TODO(ycao): Do not return resource store value if it is defined outside
441       // of cluster. Both of these can be post-resource-op-lifting cleanup
442       // passes.
443       int result_index = resources_[resource].result_index;
444       new_return_operands[result_index] = assign_variable_op.value();
445       assign_variable_op.erase();
446     }
447     old_return->setOperands(new_return_operands);
448   }
449 }
450 
451 // Replace the old op with a new op (with potentially additional results), and
452 // add stores to written resources after the new op.
ReplaceOpWithNewOp()453 void RegionResourceHoister::ReplaceOpWithNewOp() {
454   auto new_result_types = llvm::to_vector<4>(op_->getResultTypes());
455   int result_region = isa<TF::WhileRegionOp>(op_) ? 1 : 0;
456   Operation* terminator = op_->getRegion(result_region).front().getTerminator();
457   auto extra_result_types =
458       terminator->getOperands().drop_front(op_->getNumResults()).getTypes();
459   new_result_types.insert(new_result_types.end(), extra_result_types.begin(),
460                           extra_result_types.end());
461   OpBuilder builder(op_);
462   // Clone this old operation but with new result types.
463   Operation* new_op = Operation::create(
464       op_->getLoc(), op_->getName(), new_result_types, op_->getOperands(),
465       op_->getAttrs(), op_->getSuccessors(), op_->getNumRegions());
466   builder.insert(new_op);
467 
468   // Move regions to the new op.
469   for (auto it : llvm::zip(op_->getRegions(), new_op->getRegions())) {
470     Region& old_region = std::get<0>(it);
471     Region& new_region = std::get<1>(it);
472     new_region.takeBody(old_region);
473   }
474 
475   // Insert stores to all written resources.
476   for (Value resource : written_resources_) {
477     ResourceInfo& info = resources_[resource];
478     Value value_to_write = new_op->getResult(info.result_index);
479     Operation* write = builder.create<TF::AssignVariableOp>(
480         op_->getLoc(), resource, value_to_write);
481     write->setAttrs(info.write_attrs);
482     write->removeAttr(kDeviceAttr);
483   }
484 
485   // As a part of lifting, we either reuse an existing slot for resource type
486   // results or add a new slot. Resource type results should not have any uses
487   // to begin with. So we can safely replace each old op result with the
488   // corresponding new op result.
489   int old_num_results = op_->getNumResults();
490   op_->replaceAllUsesWith(new_op->getResults().take_front(old_num_results));
491   op_->erase();
492   op_ = nullptr;
493 }
494 
495 // Lift resource load and stores out of regions attached to `op`, where op is
496 // an If/case/cluster op.
HoistResourcesOutOfIfCaseCluster(Operation * op)497 LogicalResult RegionResourceHoister::HoistResourcesOutOfIfCaseCluster(
498     Operation* op) {
499   RegionResourceHoister hoister(op);
500   if (failed(hoister.Analyze())) return failure();
501 
502   // If there are no resource region captures, then nothing to do.
503   if (!hoister.NeedsLifting()) return success();
504 
505   // Start the transformation. For each region, replace the resource read with
506   // the value read before the op.
507   hoister.GenerateHoistedReads();
508   hoister.ReplaceResourceLoads(/*read_only=*/false);
509   hoister.AppendResourceStoreValueToReturn(op->getRegions());
510   hoister.ReplaceOpWithNewOp();
511   return success();
512 }
513 
514 // Lift resource loads and stores out of WhileRegion
HoistResourcesOutOfWhileRegion(TF::WhileRegionOp op)515 LogicalResult RegionResourceHoister::HoistResourcesOutOfWhileRegion(
516     TF::WhileRegionOp op) {
517   // For WhileRegion, post canonicalization all resource used within the
518   // body and condition regions are replaced with captured values, so we do not
519   // need to take into account the body and condition region arguments.
520   RegionResourceHoister hoister(op);
521 
522   if (failed(hoister.Analyze())) return failure();
523 
524   // If there are no resource region captures, then nothing to do.
525   if (!hoister.NeedsLifting()) return success();
526 
527   // The resources captured for While loop fall into two categories:
528   // (a) read-only. These reads can be replaced by a hoisted read created
529   //        before the WhileOp (similar to if and case).
530   // (b) written: since the value is written in the loop (which can only in
531   //        loop body, all these will become loop variables. Since all resource
532   //        variables are removed from the loop variabled during
533   //        canonicalizationW, we need to create new operand/result slots. The
534   //        input operands for these slots are the read values
535   //        prior to the op, and all references to these are replaced by the
536   //        corresponding slot argument. We need to generate writes following
537   //        the while for these resources.
538   //
539   // Note that for WhileRegion ops, if a resource is written, it will be written
540   // only in the body and not the condition, so the hoister analysis will infer
541   // it as needing a read as well.
542 
543   // Generate hoisted reads before the while.
544   hoister.GenerateHoistedReads();
545 
546   // Replace just the read-only resources with the hoisted reads.
547   hoister.ReplaceResourceLoads(/*read_only=*/true);
548 
549   // For written resources, add additional operands to the while op.
550   int num_old_results = op.getNumResults();
551   int num_new_results = hoister.GetLiftedNumResults();
552   int num_extra_results = num_new_results - num_old_results;
553 
554   SmallVector<Type, 4> new_result_types;
555   SmallVector<Value, 4> new_while_operands;
556   new_result_types.resize(num_extra_results);
557   new_while_operands.resize(num_extra_results);
558 
559   for (auto& it : hoister.GetResources()) {
560     if (!it.second.is_written) continue;
561     int index = it.second.result_index - num_old_results;
562     new_result_types[index] = it.second.data_type;
563     new_while_operands[index] = it.second.hoisted_read;
564   }
565   op.getOperation()->insertOperands(op.getNumOperands(), new_while_operands);
566 
567   // Patch the cond and body regions to have additional arguments, and replace
568   // the remaining resource reads (which will be resource reads for written
569   // resources) with these arguments.
570   for (Region* region : op.getRegions()) {
571     region->addArguments(new_result_types);
572     // Point hoisted read for written resources to the region's arguments.
573     for (auto& it : hoister.GetResources()) {
574       if (!it.second.is_written) continue;
575       it.second.hoisted_read = region->getArgument(it.second.result_index);
576     }
577     hoister.ReplaceResourceLoads(*region, /*read_only=*/false);
578   }
579 
580   // Add additional return values to body return. These correspond to values
581   // written to resources in the body region.
582   hoister.AppendResourceStoreValueToReturn(op.getRegions().drop_front());
583 
584   // Finally, create a new while with additional return values.
585   hoister.ReplaceOpWithNewOp();
586   return success();
587 }
588 
589 // Lift resources out of the regions attached to `op`
ReplaceOpWithNewOp(Operation * op)590 LogicalResult RegionResourceHoister::ReplaceOpWithNewOp(Operation* op) {
591   if (auto while_op = dyn_cast<TF::WhileRegionOp>(op))
592     return HoistResourcesOutOfWhileRegion(while_op);
593   return HoistResourcesOutOfIfCaseCluster(op);
594 }
595 
596 // Holds information about a function's use of a resource argument.
597 struct ResourceArgUseInfo {
598   // Data type of the data contained in the resource.
599   Type data_type;
600   // Is the resource argument used in an assign op?
601   bool updated;
602   // Is the resource argument used in a read or assign op?
603   bool used;
604 };
605 
606 // Finds the ResourceArgUseInfo for each resource argument. Forwarding to the
607 // output (i.e., the argument is an operand of the return op) is not considered
608 // as a use. This doesn't support nesting of ops, so before calling this, nested
609 // ops/functions need to be already resource-lifted.
FindResourceArgUseInfo(FuncOp func_op,llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> * result)610 LogicalResult FindResourceArgUseInfo(
611     FuncOp func_op, llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
612   auto return_op = func_op.front().getTerminator();
613   for (auto arg : TF::filter_resources(func_op.getArguments())) {
614     ResourceArgUseInfo info;
615     info.used = false;
616     info.updated = false;
617     bool read_or_assigned = false;
618     bool used_in_unsupported_op = false;
619     for (auto user : arg.getUsers()) {
620       if (user == return_op) continue;
621       info.used = true;
622       if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
623         read_or_assigned = true;
624         info.data_type = read.getType();
625         continue;
626       }
627 
628       if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
629         read_or_assigned = true;
630         info.updated = true;
631         info.data_type = assign.value().getType();
632         continue;
633       }
634 
635       used_in_unsupported_op = true;
636       break;
637     }
638 
639     // If the arg is used in an unsupported op, skip lifting it.
640     if (used_in_unsupported_op) continue;
641     (*result)[arg.getArgNumber()] = info;
642   }
643   return success();
644 }
645 
646 // Merges two sets of resource arg use infos. An argument is considered used in
647 // the merged result as long as either set marks it as used. This is used to
648 // merge results from functions that have aliasing inputs, e.g., a while loop's
649 // body and condition. The sets of keys of the two maps must be the same.
MergeArgResourceUseInfo(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos0,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos1)650 llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
651     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
652     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
653   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
654   for (const auto& entry : infos0) {
655     auto info1_it = infos1.find(entry.getFirst());
656     // If the entry is missing in any input, we should not touch this entry.
657     if (info1_it == infos1.end()) continue;
658     auto& info = result[entry.getFirst()];
659     info = entry.getSecond();
660     if (info.updated) continue;
661     if (info1_it->getSecond().used) {
662       info.used = true;
663       info.updated = info1_it->getSecond().updated;
664       info.data_type = info1_it->getSecond().data_type;
665     }
666   }
667   return result;
668 }
669 
670 // Removes the unused resource arguments, and the return values that forward the
671 // removed arguments. If old_to_new_arg_indices is provided, it will store the
672 // new argument index that corresponds to each original index (-1 means it is
673 // removed). If remaining_resource_data_types is provided, it will store the
674 // data types of the remaining resource arguments, where the indices are after
675 // removing unused ones.
RemoveUnusedResourceArgumentsAndForwardedRetvals(const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & infos,FuncOp func_op,llvm::SmallVector<int64_t,4> * old_to_new_arg_indices=nullptr,llvm::SmallDenseMap<int64_t,Type> * remaining_resource_data_types=nullptr)676 void RemoveUnusedResourceArgumentsAndForwardedRetvals(
677     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
678     FuncOp func_op,
679     llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
680     llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
681         nullptr) {
682   // Remove return values forwarded from unused arguments.
683   auto return_op = func_op.front().getTerminator();
684   auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
685   int64_t skipped_retvals = 0;
686   for (auto entry : llvm::enumerate(old_return_vals)) {
687     auto return_val = entry.value();
688     if (auto arg = return_val.dyn_cast<BlockArgument>()) {
689       auto it = infos.find(arg.getArgNumber());
690       if (it != infos.end() && !it->getSecond().used) {
691         return_op->eraseOperand(entry.index() - skipped_retvals++);
692       }
693     }
694   }
695   llvm::SmallVector<unsigned int, 4> indices_to_erase;
696   llvm::SmallVector<Type, 4> new_types;
697   int64_t skipped_args = 0;
698   for (auto arg : func_op.getArguments()) {
699     auto it = infos.find(arg.getArgNumber());
700     if (it != infos.end() && !it->getSecond().used) {
701       indices_to_erase.push_back(arg.getArgNumber());
702       skipped_args++;
703       if (old_to_new_arg_indices != nullptr) {
704         old_to_new_arg_indices->push_back(-1);
705       }
706     } else {
707       new_types.push_back(arg.getType());
708       if (old_to_new_arg_indices != nullptr) {
709         old_to_new_arg_indices->push_back(arg.getArgNumber() - skipped_args);
710       }
711       if (it != infos.end() && remaining_resource_data_types != nullptr) {
712         (*remaining_resource_data_types)[arg.getArgNumber() - skipped_args] =
713             it->second.data_type;
714       }
715     }
716   }
717   func_op.eraseArguments(indices_to_erase);
718   func_op.setType(
719       FunctionType::get(func_op.getContext(), new_types,
720                         llvm::to_vector<4>(return_op->getOperandTypes())));
721 }
722 
723 // Lifts reads/writes of resource arguments from func_op and changes its
724 // signature. resource_data_types is the (index, data type) pair for each
725 // resource argument. handle_updated_arg_value is a caller-provided function
726 // that handles the updated value for an resource argument.
LiftArgRetResourcesForFunction(FuncOp func_op,const llvm::SmallDenseMap<int64_t,Type> & resource_data_types,llvm::function_ref<void (int64_t,Value)> handle_updated_arg_value)727 LogicalResult LiftArgRetResourcesForFunction(
728     FuncOp func_op,
729     const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
730     llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
731   RegionResourceHoister hoister(func_op);
732   if (failed(hoister.Analyze())) return failure();
733 
734   // Each of these resources could be read or written in the function. If its
735   // read, we need to replace the resource arg with a value arg to get the
736   // read value. If its written, we need to replace the write with an additional
737   // value to be written.
738 
739   // Now create read values that will be used to replace each resource that
740   // is read in the function body. These read values are just the same argument
741   // with type replaced.
742   llvm::SmallVector<Value, 4> skipped_args;
743   for (auto& it : hoister.GetResources()) {
744     BlockArgument arg = it.first.dyn_cast<BlockArgument>();
745     assert(arg && "Expect resources for FuncOp to be its arguments");
746     auto type_iter = resource_data_types.find(arg.getArgNumber());
747     if (type_iter == resource_data_types.end()) {
748       // Skip lifting the resource if it's not present in the data type map.
749       // This indicates that the resource is not to be lifted because it is used
750       // in an unsupported op in some other function.
751       skipped_args.push_back(arg);
752     } else {
753       arg.setType(type_iter->second);
754       it.second.hoisted_read = arg;
755     }
756   }
757 
758   // Drop all the args that have to be skipped.
759   for (Value arg : skipped_args) hoister.DropResource(arg);
760 
761   hoister.ReplaceResourceLoads(/*read_only=*/false);
762 
763   // For writes, invoke the callback and then erase the write.
764   auto assign_ops = func_op.front().getOps<TF::AssignVariableOp>();
765   for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) {
766     Value resource = assign_variable_op.resource();
767     if (!hoister.Contains(resource)) continue;
768 
769     auto arg = resource.dyn_cast<BlockArgument>();
770     handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value());
771     assign_variable_op.erase();
772   }
773 
774   func_op.setType(FunctionType::get(
775       func_op.getContext(), func_op.front().getArgumentTypes(),
776       func_op.front().getTerminator()->getOperandTypes()));
777 
778   return success();
779 }
780 
781 // Returns a vector filtered from range where the unused elements (specified by
782 // resource_arg_uses) are removed.
783 template <typename T, typename Range>
FilterRange(Range range,const llvm::SmallDenseMap<int64_t,ResourceArgUseInfo> & resource_arg_uses)784 llvm::SmallVector<T, 4> FilterRange(
785     Range range,
786     const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
787   llvm::SmallVector<T, 4> filtered;
788   for (auto entry : llvm::enumerate(range)) {
789     auto it = resource_arg_uses.find(entry.index());
790     if (it == resource_arg_uses.end() || it->getSecond().used)
791       filtered.push_back(entry.value());
792   }
793   return filtered;
794 }
795 
796 // Changes the types of the control flow op (e.g., while, if) and adds loads and
797 // stores around it. arg_data_type_and_updated_output_index maps an operand (to
798 // be changed) index to its data type and the updated value index in the output
799 // (-1 means not updated.)
AddLoadsStoresOutsideControlFlowOp(Operation * caller,const llvm::SmallDenseMap<int64_t,std::pair<Type,int64_t>> & arg_data_type_and_updated_output_index)800 void AddLoadsStoresOutsideControlFlowOp(
801     Operation* caller,
802     const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
803         arg_data_type_and_updated_output_index) {
804   OpBuilder builder(caller);
805   auto new_operands = llvm::to_vector<8>(caller->getOperands());
806   llvm::SmallVector<int64_t, 8> changed_indices;
807   // Find the operands to change, and create the loads.
808   for (auto& entry : arg_data_type_and_updated_output_index) {
809     int64_t index = entry.getFirst();
810     Type new_type = entry.getSecond().first;
811     int64_t updated_index = entry.getSecond().second;
812     auto operand = caller->getOperand(index);
813     builder.setInsertionPoint(caller);
814     new_operands[index] = builder.create<TF::ReadVariableOp>(
815         caller->getLoc(), ArrayRef<Type>{new_type}, ArrayRef<Value>{operand});
816     caller->setOperand(index, new_operands[index]);
817     if (updated_index < 0) continue;
818     builder.setInsertionPointAfter(caller);
819     builder.create<TF::AssignVariableOp>(
820         caller->getLoc(), ArrayRef<Type>{},
821         ArrayRef<Value>{operand, caller->getResult(updated_index)});
822   }
823 }
824 
825 // Lifts loads/stores from while loop's body and cond functions.
HandleWhileLoop(TF::WhileOp while_op,FuncOp body,FuncOp cond)826 LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
827   auto return_op = body.front().getTerminator();
828   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
829   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
830   if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
831       failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
832     return failure();
833   }
834   // A resource is considered used as long as it is used in either body or cond.
835   auto resource_arg_uses =
836       MergeArgResourceUseInfo(body_use_info, cond_use_info);
837   if (resource_arg_uses.empty()) return success();
838 
839   // Remove unused resources in functions.
840   llvm::SmallVector<int64_t, 4> old_to_new_indices;
841   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
842   RemoveUnusedResourceArgumentsAndForwardedRetvals(
843       resource_arg_uses, body, &old_to_new_indices,
844       &remaining_resource_data_types);
845   RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
846   (void)LiftArgRetResourcesForFunction(
847       body, remaining_resource_data_types,
848       [&](int64_t index, Value value) { return_op->setOperand(index, value); });
849   (void)LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
850                                        [&](int64_t index, Value value) {
851                                          // We already checked that cond should
852                                          // not have variable writes.
853                                          assert(false && "Should not happen");
854                                        });
855   // Recreate the while op.
856   OpBuilder builder(while_op);
857   // Now use the filtered original operands, which will be replaced by
858   // AddLoadsStoresOutsideControlFlowOp().
859   auto new_while = builder.create<TF::WhileOp>(
860       while_op.getLoc(), body.getType().getResults(),
861       FilterRange<Value, OperandRange>(while_op.getOperands(),
862                                        resource_arg_uses),
863       while_op->getAttrs());
864   // Prepare for AddLoadsStoresOutsideControlFlowOp().
865   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
866       arg_data_type_and_updated_output_index;
867   for (const auto& entry : remaining_resource_data_types) {
868     int64_t update_index = return_op->getOperand(entry.getFirst()) ==
869                                    body.getArgument(entry.getFirst())
870                                ? -1
871                                : entry.getFirst();
872     arg_data_type_and_updated_output_index[entry.getFirst()] = {
873         entry.getSecond(), update_index};
874   }
875   AddLoadsStoresOutsideControlFlowOp(new_while,
876                                      arg_data_type_and_updated_output_index);
877   // Replace uses.
878   for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
879     if (old_to_new_indices[i] >= 0) {
880       while_op.getResult(i).replaceAllUsesWith(
881           new_while.getResult(old_to_new_indices[i]));
882     }
883   }
884   while_op.erase();
885   return success();
886 }
887 
888 // Lifts loads/stores from an IfOp or CaseOp's branches.
889 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<FuncOp> branches)890 LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
891   // For canonicalized If/Case, there should not be any resource outputs
892   int64_t non_resource_results = op.getNumResults();
893 
894   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
895   if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
896     return failure();
897 
898   for (auto func : branches.drop_front()) {
899     llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
900     if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
901       return failure();
902     // A resource is considered used as long as it is used in either branch.
903     resource_arg_uses =
904         MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
905   }
906 
907   if (resource_arg_uses.empty()) return success();
908   // Remove unused resources in functions.
909   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
910   RemoveUnusedResourceArgumentsAndForwardedRetvals(
911       resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
912       &remaining_resource_data_types);
913   for (auto func : branches.drop_front())
914     RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
915 
916   // Forward resource inputs updated in any branch to the outputs of both
917   // branches. First prepare the mapping from arg to new update output.
918   llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
919   {
920     int64_t removed_args = 0;
921     for (const auto& entry : resource_arg_uses) {
922       if (!entry.getSecond().used) {
923         removed_args++;
924         continue;
925       }
926       if (!entry.getSecond().updated) continue;
927       int64_t new_output_index =
928           non_resource_results + resource_arg_to_new_output.size();
929       resource_arg_to_new_output[entry.getFirst() - removed_args] =
930           new_output_index;
931     }
932   }
933 
934   // Append resource updates to the return ops: now they are just forwarded
935   // input resources, but will be replaced by the data value in
936   // LiftArgRetResourcesForFunction().
937   for (auto branch : branches) {
938     auto new_retvals =
939         llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
940     new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size());
941     for (const auto& entry : resource_arg_to_new_output) {
942       int64_t resource_arg_index = entry.getFirst();
943       int64_t output_index = entry.getSecond();
944       new_retvals[output_index] = branch.getArgument(resource_arg_index);
945     }
946     auto old_return = branch.front().getTerminator();
947     OpBuilder builder(old_return);
948     auto new_return =
949         builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
950     old_return->erase();
951     (void)LiftArgRetResourcesForFunction(
952         branch, remaining_resource_data_types, [&](int64_t index, Value value) {
953           new_return.setOperand(resource_arg_to_new_output[index], value);
954         });
955   }
956 
957   // Recreate the op without resource operands.
958   OpBuilder builder(op);
959   // Now use the filtered original operands, which will be replaced by
960   // AddLoadsStoresOutsideControlFlowOp().
961   auto new_operands =
962       FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
963   new_operands.insert(new_operands.begin(), op.getOperand(0));
964   FuncOp first_func = branches.front();
965   auto new_op =
966       builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
967                                  new_operands, op->getAttrs());
968   // Prepare for AddLoadsStoresOutsideControlFlowOp()
969   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
970       arg_data_type_and_updated_output_index;
971   for (const auto& entry : remaining_resource_data_types) {
972     auto new_output_it = resource_arg_to_new_output.find(entry.getFirst());
973     int64_t update_index = new_output_it == resource_arg_to_new_output.end()
974                                ? -1
975                                : new_output_it->getSecond();
976     arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
977         entry.getSecond(), update_index};
978   }
979   AddLoadsStoresOutsideControlFlowOp(new_op,
980                                      arg_data_type_and_updated_output_index);
981   // Replace uses.
982   op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
983   op.erase();
984   return success();
985 }
986 
987 // A resource-lifted function for (potentially multiple) PartitionedCallOps and
988 // information about the lifting changes.
// Filled in by HandlePartitionedCallOpCallee() and read by
// UpdatePartitionedCallOpWithNewCallee() to rewrite each call site.
989 struct PartitionedCallLiftingInfo {
990   // Function with resources lifted. Can be nullptr if nothing needs to change.
991   FuncOp lifted_callee;
992   // Mapping from the index of each old resource output to the index of the
  // old input that the output aliases.
993   llvm::SmallDenseMap<int64_t, int64_t> old_outputs_aliasing_old_inputs;
994   // Mapping from old to new output indices in case any output is removed.
  // One entry per old output; -1 marks an output that is removed.
995   llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
996   // ResourceArgUseInfo for each old resource argument.
997   llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> use_info;
998   // Input for AddLoadsStoresOutsideControlFlowOp(), see its comment.
999   llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
1000       arg_data_type_and_updated_output_index;
1001 };
1002 
1003 // Lifts loads/stores from a PartitionedCallOp's callee function. If anything
1004 // needs to be changed, the original function will be preserved, and the lifting
1005 // happens on a clone, which will be stored in `result`.
HandlePartitionedCallOpCallee(FuncOp callee,PartitionedCallLiftingInfo * result)1006 LogicalResult HandlePartitionedCallOpCallee(
1007     FuncOp callee, PartitionedCallLiftingInfo* result) {
1008   // Sanity check: return of resources should be aliases of inputs. Such outputs
1009   // will be removed later.
1010   int64_t non_resource_results = 0;
1011   for (auto entry :
1012        llvm::enumerate(callee.front().getTerminator()->getOperands())) {
1013     auto retval = entry.value();
1014     if (!getElementTypeOrSelf(retval.getType()).isa<TF::ResourceType>()) {
1015       result->old_to_new_output_indices.push_back(non_resource_results++);
1016       continue;
1017     }
1018     auto aliasing_arg = retval.dyn_cast<BlockArgument>();
1019     if (!aliasing_arg) {
1020       return callee.emitOpError("unsupported function call: ")
1021              << "resource return value does not alias an input.";
1022     }
1023     result->old_outputs_aliasing_old_inputs[entry.index()] =
1024         aliasing_arg.getArgNumber();
1025     result->old_to_new_output_indices.push_back(-1);
1026   }
1027 
1028   if (failed(FindResourceArgUseInfo(callee, &result->use_info))) {
1029     return failure();
1030   }
1031   if (result->use_info.empty()) {
1032     result->lifted_callee = nullptr;
1033     return success();
1034   }
1035 
1036   // Clone the callee before making changes.
1037   SmallString<64> name_base = callee.getName();
1038   auto module = callee->getParentOfType<ModuleOp>();
1039   name_base += "_resource_lifted";
1040   auto name = name_base;
1041   callee = callee.clone();
1042   callee.setPrivate();
1043   callee.setName(name);
1044   SymbolTable(module).insert(callee);
1045   result->lifted_callee = callee;
1046 
1047   // Remove unused resources in functions.
1048   llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
1049   RemoveUnusedResourceArgumentsAndForwardedRetvals(
1050       result->use_info, callee, /*old_to_new_arg_indices=*/nullptr,
1051       &remaining_resource_data_types);
1052   for (const auto& entry : remaining_resource_data_types) {
1053     result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
1054         entry.getSecond(), -1};
1055   }
1056   llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
1057   for (auto& val : callee.front().getTerminator()->getOpOperands()) {
1058     // Store indices of results that are not resources.
1059     if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
1060       retval_indices_to_preserve.push_back(val.getOperandNumber());
1061   }
1062   int64_t num_retvals = retval_indices_to_preserve.size();
1063   llvm::SmallVector<Value, 4> new_retvals;
1064   // Lift resources.
1065   (void)LiftArgRetResourcesForFunction(
1066       callee, remaining_resource_data_types, [&](int64_t index, Value value) {
1067         result->arg_data_type_and_updated_output_index[index].second =
1068             num_retvals++;
1069         new_retvals.push_back(value);
1070       });
1071 
1072   auto old_return = callee.front().getTerminator();
1073   llvm::SmallVector<Value, 4> old_and_new_retvals;
1074   old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
1075                               new_retvals.size());
1076   for (int64_t retval_index : retval_indices_to_preserve)
1077     old_and_new_retvals.push_back(old_return->getOperand(retval_index));
1078 
1079   old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
1080   // Replace old return with the new ones with update values.
1081   OpBuilder builder(old_return);
1082   auto new_return =
1083       builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
1084   old_return->erase();
1085   callee.setType(
1086       FunctionType::get(callee.getContext(), callee.getType().getInputs(),
1087                         llvm::to_vector<4>(new_return.getOperandTypes())));
1088   return success();
1089 }
1090 
1091 // Updates a PartitionedCallOp/StatefulPartitionedCallOp according to the
1092 // resource-lifted new callee function in lifting_info.
1093 template <typename CallOpType>
UpdatePartitionedCallOpWithNewCallee(CallOpType call_op,PartitionedCallLiftingInfo & lifting_info)1094 void UpdatePartitionedCallOpWithNewCallee(
1095     CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) {
1096   if (!lifting_info.lifted_callee) return;
1097   // Replace output resource uses with the aliasing input, so that we can remove
1098   // this output.
1099   for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) {
1100     call_op.getResult(entry.getFirst())
1101         .replaceAllUsesWith(call_op.getOperand(entry.getSecond()));
1102   }
1103   // Recreate the call op.
1104   OpBuilder builder(call_op);
1105   // Now use the filtered original operands, which will be replaced by
1106   // AddLoadsStoresOutsideControlFlowOp().
1107   auto new_operands =
1108       FilterRange<Value, OperandRange>(call_op.args(), lifting_info.use_info);
1109   auto new_call = builder.create<CallOpType>(
1110       call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
1111       new_operands, call_op->getAttrs());
1112   new_call->setAttr(
1113       "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
1114   AddLoadsStoresOutsideControlFlowOp(
1115       new_call, lifting_info.arg_data_type_and_updated_output_index);
1116   // Replace uses.
1117   for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
1118        i < end; ++i) {
1119     if (lifting_info.old_to_new_output_indices[i] >= 0) {
1120       call_op.getResult(i).replaceAllUsesWith(
1121           new_call.getResult(lifting_info.old_to_new_output_indices[i]));
1122     }
1123   }
1124   call_op.erase();
1125 }
1126 
// Forward declaration: HandlePartitionedCallOp() below calls
// HoistForControlFlow(), which in turn calls HandlePartitionedCallOp(), so one
// of the two has to be declared before it is defined.
1127 LogicalResult HoistForControlFlow(
1128     Block*, ModuleOp, bool,
1129     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*);
1130 
1131 // A templated routine for handling both PartitionedCallOp and
1132 // StatefulPartitionedCallOp. If the callee is already lifted, it just updates
1133 // the caller op itself; otherwise, it first recursively handles nested control
1134 // flow, then performs lifting on the callee.
1135 template <typename CallOpType>
HandlePartitionedCallOp(CallOpType call_op,FuncOp callee,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_callees)1136 LogicalResult HandlePartitionedCallOp(
1137     CallOpType call_op, FuncOp callee, ModuleOp module, bool vars_initialized,
1138     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1139         lifted_callees) {
1140   auto emplace_res = lifted_callees->try_emplace(callee.getName(),
1141                                                  PartitionedCallLiftingInfo());
1142   if (emplace_res.second) {
1143     // Unseen callee. Perform resource lifting on it.
1144     if (failed(HoistForControlFlow(&callee.front(), module, vars_initialized,
1145                                    lifted_callees)))
1146       return failure();
1147 
1148     if (failed(HandlePartitionedCallOpCallee(
1149             callee, &emplace_res.first->getSecond()))) {
1150       return failure();
1151     }
1152   }
1153   UpdatePartitionedCallOpWithNewCallee(call_op, emplace_res.first->getSecond());
1154   return success();
1155 }
1156 
1157 // Hoists resource loads/stores from control flow ops in `block` outside the
1158 // body/cond/branch/callee functions.
HoistForControlFlow(Block * block,ModuleOp module,bool vars_initialized,llvm::SmallDenseMap<llvm::StringRef,PartitionedCallLiftingInfo> * lifted_partitioned_call_callees)1159 LogicalResult HoistForControlFlow(
1160     Block* block, ModuleOp module, bool vars_initialized,
1161     llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>*
1162         lifted_partitioned_call_callees) {
1163   if (vars_initialized) SetAllVarIsInitializedToTrue(block);
1164 
1165   for (Operation& op : llvm::make_early_inc_range(*block)) {
1166     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
1167       auto body = while_op.body_function();
1168       auto cond = while_op.cond_function();
1169       // Recursively handle the nested control flow.
1170       (void)HoistForControlFlow(&body.front(), module, vars_initialized,
1171                                 lifted_partitioned_call_callees);
1172       (void)HoistForControlFlow(&cond.front(), module, vars_initialized,
1173                                 lifted_partitioned_call_callees);
1174       if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
1175     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
1176       auto then_branch = if_op.then_function();
1177       auto else_branch = if_op.else_function();
1178       // Recursively handle the nested control flow.
1179       (void)HoistForControlFlow(&then_branch.front(), module, vars_initialized,
1180                                 lifted_partitioned_call_callees);
1181       (void)HoistForControlFlow(&else_branch.front(), module, vars_initialized,
1182                                 lifted_partitioned_call_callees);
1183       if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
1184         return failure();
1185     } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
1186       SmallVector<FuncOp, 4> branch_functions;
1187       case_op.get_branch_functions(branch_functions);
1188       for (FuncOp func : branch_functions) {
1189         // Recursively handle the nested control flow.
1190         (void)HoistForControlFlow(&func.front(), module, vars_initialized,
1191                                   lifted_partitioned_call_callees);
1192       }
1193       if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
1194     } else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
1195       auto callee = call_op.func();
1196       if (!callee) {
1197         return call_op.emitOpError(
1198             "resource lifting does not support call with nested references.");
1199       }
1200       if (failed(HandlePartitionedCallOp(call_op, callee, module,
1201                                          vars_initialized,
1202                                          lifted_partitioned_call_callees))) {
1203         // Nested control flow handling is done in HandlePartitionedCallOp().
1204         return failure();
1205       }
1206     } else if (auto call_op =
1207                    llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
1208       if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module,
1209                                          vars_initialized,
1210                                          lifted_partitioned_call_callees))) {
1211         return failure();
1212       }
1213     } else if (isa<TF::IfRegionOp, TF::CaseRegionOp, TF::WhileRegionOp>(op)) {
1214       for (Region& region : op.getRegions())
1215         (void)HoistForControlFlow(&region.front(), module, vars_initialized,
1216                                   lifted_partitioned_call_callees);
1217       LogicalResult result = RegionResourceHoister::ReplaceOpWithNewOp(&op);
1218       if (failed(result)) return failure();
1219     }
1220   }
1221 
1222   // After we have hoisted operations in the block, we may have added new read
1223   // and writes of resources to this block. Clean them up by doing store-load
1224   // forwarding.
1225   ForwardStoreToLoad(block);
1226   return success();
1227 }
1228 
1229 // Lifts resource operation from tf_device.cluster ops nested in `op` outside.
1230 // Returns failure if there are remaining resource-type values that can not be
1231 // lifted.
runOnOperation()1232 void ResourceOpLiftingPass::runOnOperation() {
1233   llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1234       lifted_partitioned_call_callees;
1235   ModuleOp module = getOperation();
1236 
1237   if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(module)))
1238     return signalPassFailure();
1239 
1240   auto walk_result = module.walk([&](FuncOp func_op) {
1241     return func_op.walk([&](tf_device::ClusterOp cluster) {
1242       LogicalResult result = HoistForControlFlow(
1243           &cluster.GetBody(), module, /*vars_initialized=*/true,
1244           &lifted_partitioned_call_callees);
1245       if (failed(result)) return WalkResult::interrupt();
1246       result = RegionResourceHoister::ReplaceOpWithNewOp(cluster);
1247       if (failed(result)) return WalkResult::interrupt();
1248       return WalkResult::advance();
1249     });
1250   });
1251 
1252   if (walk_result.wasInterrupted()) return signalPassFailure();
1253 }
1254 
// Pass whose runOnOperation() applies
// TF::ResourceLiftingForFunctionalControlFlow() to the function named "main"
// in the module, if there is one.
1255 struct ResourceOpLiftingForMainFunctionPass
1256     : public TFDevice::ResourceOpLiftingForMainFunctionPassBase<
1257           ResourceOpLiftingForMainFunctionPass> {
1258   void runOnOperation() override;
1259 };
1260 
runOnOperation()1261 void ResourceOpLiftingForMainFunctionPass::runOnOperation() {
1262   ModuleOp module = getOperation();
1263   FuncOp main_func = module.lookupSymbol<FuncOp>("main");
1264   if (!main_func) {
1265     return;
1266   }
1267 
1268   if (failed(TF::ResourceLiftingForFunctionalControlFlow(main_func))) {
1269     return signalPassFailure();
1270   }
1271 }
1272 
1273 }  // namespace
1274 
1275 namespace TFDevice {
CreateResourceOpLiftingPass()1276 std::unique_ptr<OperationPass<ModuleOp>> CreateResourceOpLiftingPass() {
1277   return std::make_unique<ResourceOpLiftingPass>();
1278 }
1279 
1280 std::unique_ptr<OperationPass<ModuleOp>>
CreateResourceOpLiftingForMainFunctionPass()1281 CreateResourceOpLiftingForMainFunctionPass() {
1282   return std::make_unique<ResourceOpLiftingForMainFunctionPass>();
1283 }
1284 
1285 }  // namespace TFDevice
1286 
1287 namespace TF {
ResourceLiftingForFunctionalControlFlow(FuncOp function)1288 LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
1289   // This routine should only be called when control flow operations are still
1290   // represented with TF IfOp and WhileOp operations. In this case, there should
1291   // be only one basic blocks in the MLIR representation.
1292   if (!llvm::hasSingleElement(function)) {
1293     return function.emitError()
1294            << "expect the function to have 1 block while it has "
1295            << function.getBlocks().size();
1296   }
1297 
1298   if (failed(TF::CleanupAndCanonicalizeForResourceOpLifting(function)))
1299     return failure();
1300 
1301   llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
1302       lifted_partitioned_call_callees;
1303   if (failed(HoistForControlFlow(
1304           &function.front(), cast<ModuleOp>(function->getParentOp()),
1305           /*vars_initialized=*/false, &lifted_partitioned_call_callees)))
1306     return failure();
1307 
1308   // Clean up and canonicalize to remove dead local variables as some local
1309   // variables might be dead after hoisting resource loads/stores from control
1310   // flow ops.
1311   return TF::CleanupAndCanonicalizeForResourceOpLifting(function);
1312 }
1313 }  // namespace TF
1314 
1315 }  // namespace mlir
1316