• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <deque>
18 #include <tuple>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
26 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
29 #include "mlir/IR/Types.h"  // from @llvm-project
30 #include "mlir/Transforms/Passes.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
34 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
35 
36 using llvm::ArrayRef;
37 using mlir::failure;
38 using mlir::success;
39 
40 namespace tensorflow {
41 namespace {
42 
// Device assigned to hoisted resource ops whose defining op has no (or an
// empty) "device" attribute.
constexpr char kCpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

// Tuple storing {device, container, shared_name} attributes in that order.
// Together these identify a resource instance across the ops that reference
// it, so the tuple is used as a map/set key below.
using ResourceKey =
    std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>;
49 
IsSessionInitializer(mlir::FuncOp op)50 bool IsSessionInitializer(mlir::FuncOp op) {
51   auto session_initializer_op = mlir::tf_saved_model::GetSessionInitializerOp(
52       op->getParentOfType<mlir::ModuleOp>());
53   if (!session_initializer_op) return false;
54 
55   for (auto sym_ref : session_initializer_op.initializers()) {
56     if (op.sym_name() == sym_ref.cast<mlir::FlatSymbolRefAttr>().getValue())
57       return true;
58   }
59 
60   return false;
61 }
62 
GetResourceKey(mlir::Operation * op)63 ResourceKey GetResourceKey(mlir::Operation *op) {
64   llvm::StringRef device;
65   if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
66     device = attr.getValue();
67   }
68 
69   llvm::StringRef container;
70   if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
71     container = attr.getValue();
72   }
73 
74   llvm::StringRef shared_name;
75   if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
76     shared_name = attr.getValue();
77   }
78 
79   return {device, container, shared_name};
80 }
81 
// Bookkeeping shared across all functions while hoisting invariant ops into
// the init function.
struct HoistInfo {
  // All hoisted ops in the topological order. Erasure happens in the reverse
  // of this order so users are removed before producers.
  llvm::SmallVector<mlir::Operation *, 4> hoists_in_topological_order;

  // Mapping from the old values produced by hoisted ops before hoisting to the
  // new values (the clones in the init function) after hoisting.
  mlir::BlockAndValueMapping value_mapping;

  // `hoisted_values` is to keep all values that are produced by hoisted ops
  // but used by non-hoisted ops. These values will be replaced by results of
  // tf._TfrtGetResource op. The index of each value in this array will be the
  // index used in tf._TfrtGetResource and tf._TfrtSetResource op. This also
  // stores the ResourceKey which has the shared_name and container attributes
  // used by later resource alias analysis and side effect analysis passes.
  llvm::SmallVector<std::pair<mlir::Value, ResourceKey>, 4> hoisted_values;
};
98 
ReplaceHoistedValues(llvm::ArrayRef<std::pair<mlir::Value,ResourceKey>> hoisted_values,mlir::OpBuilder & builder)99 void ReplaceHoistedValues(
100     llvm::ArrayRef<std::pair<mlir::Value, ResourceKey>> hoisted_values,
101     mlir::OpBuilder &builder) {
102   struct HoistedValueInfo {
103     llvm::SmallVector<mlir::Value, 4> hoisted_values;
104     llvm::SmallVector<int64_t, 4> indices;
105     llvm::SmallVector<llvm::StringRef, 4> shared_names;
106     llvm::SmallVector<llvm::StringRef, 4> containers;
107   };
108   // Rearrange the hoisted values by each function and each device.
109   llvm::DenseMap<mlir::Block *, llvm::StringMap<HoistedValueInfo>>
110       hoisted_values_by_block_device;
111 
112   // Find a block where to place tf._TfrtGetResource operation. We do not place
113   // get resource operations inside the `tf_device.cluster` operations, because
114   // these blocks are intended for later on-device compilation. Insert resource
115   // reads to the closest block outside of the `tf_device.cluster` operation.
116   auto hoist_into_block = [](mlir::Value value) -> mlir::Block * {
117     mlir::Operation *cluster_op =
118         value.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();
119     return cluster_op ? cluster_op->getBlock() : value.getParentBlock();
120   };
121 
122   for (auto iter : llvm::enumerate(hoisted_values)) {
123     auto value = iter.value().first;
124     auto container = std::get<1>(iter.value().second);
125     auto shared_name = std::get<2>(iter.value().second);
126     auto index = iter.index();
127     auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)];
128 
129     assert(value.getDefiningOp() && "hoisted values must not be arguments.");
130     llvm::StringRef device = kCpuDeviceName;
131     if (auto device_attr =
132             value.getDefiningOp()->getAttrOfType<mlir::StringAttr>("device")) {
133       if (!device_attr.getValue().empty()) device = device_attr.getValue();
134     }
135 
136     auto &item = device_map[device];
137 
138     item.hoisted_values.push_back(value);
139     item.indices.push_back(index);
140     item.shared_names.push_back(shared_name);
141     item.containers.push_back(container);
142   }
143 
144   // Create tf._TfrtGetResource op for each function and device.
145   for (const auto &block_iter : hoisted_values_by_block_device) {
146     auto *block = block_iter.first;
147     const auto &device_map = block_iter.second;
148 
149     builder.setInsertionPointToStart(block);
150     for (const auto &device_iter : device_map) {
151       llvm::StringRef device = device_iter.getKey();
152       mlir::ValueRange old_values = device_iter.getValue().hoisted_values;
153       const auto &indices = device_iter.getValue().indices;
154       const auto &shared_name_arr = device_iter.getValue().shared_names;
155       const auto &container_arr = device_iter.getValue().containers;
156 
157       auto get_resource_op = builder.create<mlir::TF::_TfrtGetResourceOp>(
158           block->getParentOp()->getLoc(), old_values.getTypes(),
159           builder.getI64ArrayAttr(indices),
160           builder.getStrArrayAttr(shared_name_arr),
161           builder.getStrArrayAttr(container_arr));
162       get_resource_op->setAttr("device", builder.getStringAttr(device));
163 
164       auto new_values = get_resource_op.results();
165       for (auto iter : llvm::zip(old_values, new_values)) {
166         auto old_value = std::get<0>(iter);
167         auto new_value = std::get<1>(iter);
168         old_value.replaceAllUsesWith(new_value);
169       }
170     }
171   }
172 }
173 
OnlyHasReadEffect(mlir::Operation * op)174 bool OnlyHasReadEffect(mlir::Operation *op) {
175   auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
176   if (!interface) return false;
177   return interface.onlyHasEffect<mlir::MemoryEffects::Read>();
178 }
179 
CanHoist(const llvm::DenseSet<ResourceKey> & read_only_vars,mlir::Operation * op)180 bool CanHoist(const llvm::DenseSet<ResourceKey> &read_only_vars,
181               mlir::Operation *op) {
182   // return ops should not be hoisted.
183   if (op->mightHaveTrait<mlir::OpTrait::IsTerminator>()) return false;
184 
185   // Non-side-effecting ops can be hoisted.
186   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return true;
187 
188   // ResourceHandle ops can be hoisted.
189   if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op))
190     return true;
191 
192   // If it is ReadVariableOp and the variable is readonly, it can be hoisted.
193   if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
194     if (auto var_handle_op = llvm::dyn_cast_or_null<mlir::TF::VarHandleOp>(
195             read_var_op.resource().getDefiningOp())) {
196       if (read_only_vars.count(GetResourceKey(var_handle_op)) > 0) return true;
197     }
198   }
199 
200   // If it is LookupTableSizeOp, it can be hoisted as the size of the hash table
201   // cannot be changed after initialization.
202   if (auto lookup_table_size_op =
203           llvm::dyn_cast<mlir::TF::LookupTableSizeV2Op>(op)) {
204     if (auto hash_table_op = llvm::dyn_cast_or_null<mlir::TF::HashTableV2Op>(
205             lookup_table_size_op.table_handle().getDefiningOp())) {
206       if (read_only_vars.count(GetResourceKey(hash_table_op)) > 0) return true;
207     }
208   }
209 
210   // TODO(chky): Allow more readonly ops.
211 
212   return false;
213 }
214 
HoistInvariantOpsInFunction(mlir::FuncOp func,const llvm::DenseSet<ResourceKey> & read_only_vars,const mlir::TF::SideEffectAnalysis::Info & side_effect_analysis,mlir::OpBuilder & builder,HoistInfo & module_hoist_info)215 void HoistInvariantOpsInFunction(
216     mlir::FuncOp func, const llvm::DenseSet<ResourceKey> &read_only_vars,
217     const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
218     mlir::OpBuilder &builder, HoistInfo &module_hoist_info) {
219   // Keep the hoisted ops in this function.
220   llvm::DenseSet<mlir::Operation *> hoists;
221 
222   auto all_operands_in_hoists = [&module_hoist_info](mlir::Operation *op) {
223     for (mlir::Value operand : op->getOperands()) {
224       if (module_hoist_info.value_mapping.lookupOrNull(operand) == nullptr)
225         return false;
226     }
227     return true;
228   };
229 
230   auto all_control_predeccessors_in_hoists = [&hoists, &side_effect_analysis](
231                                                  mlir::Operation *op) {
232     auto preds = side_effect_analysis.DirectControlPredecessors(op);
233     return std::all_of(
234         preds.begin(), preds.end(),
235         [&hoists](mlir::Operation *pred) { return hoists.count(pred) > 0; });
236   };
237 
238   std::deque<mlir::Operation *> work_list;
239 
240   // Start with ops with tf.VarHandleOp ops and tf.Const ops.
241   //
242   // TODO(chky): Consider allowing other ops including custom ops to be hoisted.
243   func.walk([&work_list](mlir::Operation *op) {
244     if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op,
245                   mlir::TF::ConstOp>(op))
246       work_list.push_back(op);
247   });
248 
249   while (!work_list.empty()) {
250     auto *op = work_list.front();
251     work_list.pop_front();
252 
253     // Skip if it is already hoisted.
254     if (hoists.count(op) > 0) continue;
255 
256     // If the op can be hoisted, and all of its data dependencies and control
257     // dependencies are hoisted, then we hoist it. Otherwise, skip.
258     if (!(CanHoist(read_only_vars, op) && all_operands_in_hoists(op) &&
259           all_control_predeccessors_in_hoists(op)))
260       continue;
261 
262     // Record the hoisted operation.
263     hoists.insert(op);
264     module_hoist_info.hoists_in_topological_order.push_back(op);
265 
266     // Create a copy in the init function.
267     builder.clone(*op, module_hoist_info.value_mapping);
268 
269     for (mlir::Operation *user : op->getUsers()) {
270       work_list.push_back(user);
271     }
272   }
273 
274   // Find out the values that are produced by hoisted ops but used by
275   // non-hoisted ops. These values need to be replaced.
276   for (auto *op : hoists) {
277     for (auto result : op->getResults()) {
278       if (std::any_of(result.getUsers().begin(), result.getUsers().end(),
279                       [&hoists](mlir::Operation *user) {
280                         return hoists.count(user) == 0;
281                       })) {
282         module_hoist_info.hoisted_values.push_back(
283             {result, GetResourceKey(op)});
284       }
285     }
286   }
287 }
288 
FindCalleesRecursive(const mlir::SymbolTable & symbol_table,mlir::FuncOp func,llvm::StringSet<> & callees)289 void FindCalleesRecursive(const mlir::SymbolTable &symbol_table,
290                           mlir::FuncOp func, llvm::StringSet<> &callees) {
291   assert(func);
292   func.walk([&](mlir::Operation *op) {
293     for (const auto &named_attr : op->getAttrs()) {
294       if (auto symbol_attr =
295               named_attr.second.dyn_cast<mlir::FlatSymbolRefAttr>()) {
296         auto symbol = symbol_attr.getValue();
297         if (!callees.contains(symbol)) {
298           callees.insert(symbol);
299 
300           auto func = symbol_table.lookup<mlir::FuncOp>(symbol);
301           if (!func) continue;
302 
303           FindCalleesRecursive(symbol_table, func, callees);
304         }
305       }
306     }
307   });
308 }
309 
// Hoists invariant ops in `module` into a newly created "_tfrt_resource_init"
// function, wires up tf._TfrtSetResource / tf._TfrtGetResource pairs so the
// original functions read the precomputed results, and erases the originals.
void HoistInvariantOps(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);

  // Find all resources used in non-init functions.
  llvm::DenseMap<ResourceKey, llvm::SmallVector<mlir::Operation *, 4>>
      resources;

  // Find all callees referenced in the initialization functions.
  llvm::StringSet<> init_callees;

  module.walk([&](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op)) {
      auto func = op->getParentOfType<mlir::FuncOp>();
      // Resource uses inside init functions are not candidates for hoisting.
      if (IsSessionInitializer(func)) return;
      resources[GetResourceKey(op)].push_back(op);
    } else if (auto func = llvm::dyn_cast<mlir::FuncOp>(op)) {
      if (!IsSessionInitializer(func)) return;
      FindCalleesRecursive(symbol_table, func, init_callees);
    }
  });

  // A resource key is read-only when every user of every handle-producing op
  // with that key has only read effects.
  llvm::DenseSet<ResourceKey> read_only_vars;
  for (const auto &iter : resources) {
    const auto &key = iter.first;
    const auto &vars = iter.second;
    if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) {
          for (auto *user : op->getUsers()) {
            if (!OnlyHasReadEffect(user)) return false;
          }
          return true;
        })) {
      read_only_vars.insert(key);
    }
  }

  mlir::TF::SideEffectAnalysis side_effect_analysis(module);

  mlir::OpBuilder builder(&module.body());
  // "_tfrt_resource_init" is the special function that executes all invariant
  // ops (eg. read-only variables) used in the model. This function should be
  // executed after user-specified initialization.
  auto init_func_op = builder.create<mlir::FuncOp>(
      module.getLoc(), "_tfrt_resource_init",
      mlir::FunctionType::get(module.getContext(), /*inputs=*/{},
                              /*results=*/{}));
  auto *block = init_func_op.addEntryBlock();
  builder.setInsertionPointToStart(block);

  HoistInfo module_hoist_info;

  for (auto func : module.getOps<mlir::FuncOp>()) {
    // Skips hoisting if this function is an init function or any callees,
    // including recursive ones, of an init function, because otherwise the
    // hoisted values won't be initialized when this function is called.
    if (IsSessionInitializer(func) || init_callees.contains(func.sym_name()) ||
        func == init_func_op)
      continue;

    HoistInvariantOpsInFunction(func, read_only_vars,
                                side_effect_analysis.GetAnalysisForFunc(func),
                                builder, module_hoist_info);
  }

  // Create tf._TfrtSetResource ops in the init function.
  for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) {
    mlir::Value value = iter.value().first;
    int64_t index = iter.index();

    // Look up the clone of the defining op in the init function and store its
    // result under `index` right after it is produced.
    auto new_value = module_hoist_info.value_mapping.lookup(value);
    auto *new_op = new_value.getDefiningOp();
    assert(new_op);
    builder.setInsertionPointAfter(new_op);
    auto set_resource_op = builder.create<mlir::TF::_TfrtSetResourceOp>(
        new_op->getLoc(), new_value, index);

    // Preserve the device attribute.
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr = new_op->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }
    set_resource_op->setAttr("device", builder.getStringAttr(device));
  }

  builder.setInsertionPointToEnd(block);
  // Finish building the init function by inserting a return op.
  builder.create<mlir::ReturnOp>(init_func_op.getLoc());

  // Now that we have the index for each value that will be replaced, we can
  // create the tf._TfrtGetResource op in each function using these indices.
  ReplaceHoistedValues(module_hoist_info.hoisted_values, builder);

  // Lastly, erase the hoisted ops in reverse topological order.
  for (auto *op :
       llvm::reverse(module_hoist_info.hoists_in_topological_order)) {
    assert(op->use_empty());
    op->erase();
  }
}
408 
// This pass rewrites tf_saved_model dialect's ops according to TFRT's
// requirements:
//
// 1) Remove all tf_saved_model's attributes and ops.
// 2) Create a function for every exported names of the original function.
// 3) Promote all uses of global tensors from resource handles to the underlying
// tensors.
// 4) Hoist invariant ops (ie. guaranteed to return the same value on every
// invocation) for every non-init function.
//
class LowerTFSavedModelPass
    : public mlir::PassWrapper<LowerTFSavedModelPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  explicit LowerTFSavedModelPass(bool hoist_invariant_ops) {
    hoist_invariant_ops_ = hoist_invariant_ops;
  }
  LowerTFSavedModelPass() = default;
  // Intentionally does not copy any state; presumably required because the
  // Option member is not copyable — TODO confirm against PassWrapper's needs.
  LowerTFSavedModelPass(const LowerTFSavedModelPass &) {}

  llvm::StringRef getArgument() const final {
    return "tfrt-lower-tf-savedmodel";
  }
  llvm::StringRef getDescription() const final {
    return "Lower tf-saved-model ops according to TFRT's requirements.";
  }

  void runOnOperation() override {
    auto module = getOperation();

    // TODO(b/185928201): Create a standalone pass for hoisting invariant ops so
    // that it can be reusable and configurable in other contexts than saved
    // models.
    if (hoist_invariant_ops_) HoistInvariantOps(module);

    // Skip non-savedmodel MLIR module.
    if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) return;

    mlir::SymbolTable symbol_table(module);

    // TODO(b/177590991): Remove PromoteGlobalTensors() once non lite MLIR
    // importer is no longer used. PromoteGlobalTensors() is only used for non
    // lite MLIR importer which rewrites resource variables to global_tensors.
    // However, for many models it is not supported.
    for (auto func : module.getOps<mlir::FuncOp>()) {
      if (mlir::tf_saved_model::IsExported(func)) {
        if (mlir::failed(PromoteGlobalTensors(func, symbol_table))) {
          func.emitOpError("failed to promote resource variables.");
          signalPassFailure();
          return;
        }
      }
    }

    module->removeAttr("tf_saved_model.semantics");

    mlir::OpBuilder builder(&getContext());
    auto resource_id = builder.getIdentifier("tf.resource_name");
    auto bound_id = builder.getIdentifier("tf_saved_model.bound_input");
    auto path_id = builder.getIdentifier("tf_saved_model.index_path");

    module.walk([resource_id, bound_id, path_id,
                 &builder](mlir::Operation *op) mutable {
      if (auto func_op = llvm::dyn_cast<mlir::FuncOp>(op)) {
        // Remove tf_saved_model specific function arg attributes, rewriting
        // bound_input symbols to plain tf.resource_name string attributes.
        for (unsigned i = 0, e = func_op.getNumArguments(); i != e; ++i) {
          if (auto sym = func_op.getArgAttrOfType<mlir::FlatSymbolRefAttr>(
                  i, bound_id)) {
            func_op.removeArgAttr(i, bound_id);
            func_op.setArgAttr(i, resource_id,
                               builder.getStringAttr(sym.getValue()));
          }
          func_op.removeArgAttr(i, path_id);
        }
        for (unsigned i = 0, e = func_op.getNumResults(); i != e; ++i) {
          func_op.removeResultAttr(i, bound_id);
          func_op.removeResultAttr(i, path_id);
        }
        if (auto exported_names = func_op->getAttrOfType<mlir::ArrayAttr>(
                "tf_saved_model.exported_names")) {
          bool is_session_initializer = IsSessionInitializer(func_op);

          // Create a function for each exported name.
          //
          // TODO(b/148477882): TFRT dialect should have similar concepts of
          // exported names so that a function can be referenced by multiple
          // exported names.
          func_op->removeAttr("tf_saved_model.exported_names");
          for (auto exported_name : exported_names) {
            auto exported_func_op = func_op.clone();
            exported_func_op.setName(
                exported_name.cast<mlir::StringAttr>().getValue());

            // If it is a session initializer, we want to maximize parallelism
            // and do not perform any stream merge, to minimize latency.
            //
            // TODO(b/183219530): This is a workaround as the cost model used
            // currently is not very accurate, and leads to performance
            // regression on IO ops that are common in initialization functions.
            if (is_session_initializer) {
              exported_func_op->setAttr("tfrt.cost_threshold",
                                        builder.getI64IntegerAttr(1));
            }

            builder.setInsertionPoint(func_op);
            builder.insert(exported_func_op);
          }
          // The original (unexported) function is replaced by its clones.
          func_op.erase();
        }
      }
    });

    module.walk([](mlir::Operation *op) {
      if (llvm::isa<mlir::tf_saved_model::TensorFlowSavedModelDialect>(
              op->getDialect())) {
        // Remove all tf_saved_model ops.
        op->erase();
      }
    });
  }

 private:
  // Promote global tensors used by an exported function.
  mlir::LogicalResult PromoteGlobalTensors(
      mlir::FuncOp op, const mlir::SymbolTable &symbol_table);

  // Replace a function argument that is a resource handle with an argument of
  // the underlying tensor type. It also replaces all its uses recursively.
  mlir::LogicalResult PromoteFunctionArgument(
      mlir::FuncOp func, unsigned arg_index, mlir::Type promoted_type,
      const mlir::SymbolTable &symbol_table);

  // Replace an operand that is a resource handle with an operand of the
  // underlying type and replace all uses of this operation if the results are
  // also promoted. If it is a control flow op, it will process the callees
  // recursively. The original op will be invalidated in some cases.
  mlir::LogicalResult PromoteOpOperand(mlir::Operation *op,
                                       unsigned operand_number,
                                       mlir::Value promoted,
                                       const mlir::SymbolTable &symbol_table);

  // Replace all uses of a resource handle value with its promoted version
  // recursively.
  mlir::LogicalResult PromoteValueUses(mlir::Value old, mlir::Value promoted,
                                       const mlir::SymbolTable &symbol_table);

  Option<bool> hoist_invariant_ops_{*this, "hoist-invariant-ops",
                                    llvm::cl::desc("hoist-invariant-ops"),
                                    llvm::cl::init(false)};
};
559 
CompareTypes(mlir::TypeRange x,mlir::TypeRange y)560 static llvm::SmallVector<unsigned, 4> CompareTypes(mlir::TypeRange x,
561                                                    mlir::TypeRange y) {
562   llvm::SmallVector<unsigned, 4> results;
563   assert(x.size() == y.size());
564   for (int i = 0, e = x.size(); i < e; ++i) {
565     if (x[i] != y[i]) results.push_back(i);
566   }
567   return results;
568 }
569 
PromoteGlobalTensors(mlir::FuncOp op,const mlir::SymbolTable & symbol_table)570 mlir::LogicalResult LowerTFSavedModelPass::PromoteGlobalTensors(
571     mlir::FuncOp op, const mlir::SymbolTable &symbol_table) {
572   for (int i = 0, e = op.getNumArguments(); i < e; ++i) {
573     auto global_tensor_op = mlir::tf_saved_model::LookupBoundInputOfType<
574         mlir::tf_saved_model::GlobalTensorOp>(op, i, symbol_table);
575     if (!global_tensor_op) continue;
576 
577     auto result_types = op.getType().getResults();
578     if (failed(PromoteFunctionArgument(op, i, global_tensor_op.type(),
579                                        symbol_table)))
580       return failure();
581 
582     if (!CompareTypes(op.getType().getResults(), result_types).empty())
583       op.emitOpError("cannot promote exported functions's results");
584   }
585   return success();
586 }
587 
// Replaces the argument at `arg_index` of `func` — a resource handle — with
// an argument of `promoted_type` (the underlying tensor type), then rewrites
// all of its transitive uses. On failure the argument's original type is
// restored; on success the function type is recomputed.
mlir::LogicalResult LowerTFSavedModelPass::PromoteFunctionArgument(
    mlir::FuncOp func, unsigned arg_index, mlir::Type promoted_type,
    const mlir::SymbolTable &symbol_table) {
  // Replace this argument before replacing its uses.
  auto &block = func.front();
  auto arg = block.getArgument(arg_index);

  // Roll back the argument type if promoting any use below fails.
  auto cleanup_on_failure = llvm::make_scope_exit(
      [&, orig_type = arg.getType()]() { arg.setType(orig_type); });

  arg.setType(promoted_type);

  // Promote all uses of `arg`.
  if (failed(PromoteValueUses(arg, arg, symbol_table))) return failure();

  // Success: disarm the rollback so the new type sticks.
  cleanup_on_failure.release();

  // Update the function type accordingly. Result types are re-derived from the
  // terminator's operands since promotion may have rewritten them.
  auto return_op = llvm::cast<mlir::ReturnOp>(block.getTerminator());
  auto new_results = return_op.operands();

  func.setType(mlir::FunctionType::get(
      func.getContext(), block.getArgumentTypes(), new_results.getTypes()));
  return success();
}
613 
PromoteOpOperand(mlir::Operation * op,unsigned operand_number,mlir::Value promoted,const mlir::SymbolTable & symbol_table)614 mlir::LogicalResult LowerTFSavedModelPass::PromoteOpOperand(
615     mlir::Operation *op, unsigned operand_number, mlir::Value promoted,
616     const mlir::SymbolTable &symbol_table) {
617   // TODO(chky): Consider a more scalable way to handling all read-only ops.
618 
619   // If it is a ReadVariableOp, we just need to replace all its uses and erase
620   // this op.
621   if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
622     read_var_op.value().replaceAllUsesWith(promoted);
623     op->erase();
624     return success();
625   }
626 
627   // Next, we handle control flow ops.
628   if (!llvm::isa<mlir::TF::IfOp, mlir::TF::CaseOp, mlir::TF::WhileOp,
629                  mlir::CallOpInterface, mlir::TF::BatchFunctionOp,
630                  mlir::ReturnOp>(op))
631     return op->emitOpError("unsupported users of resource variables");
632 
633   llvm::SmallVector<unsigned, 2> promoted_result_indices;
634   auto update_promoted_result_indices =
635       [&promoted_result_indices](
636           mlir::Operation *op,
637           ArrayRef<mlir::Type> result_types) -> mlir::LogicalResult {
638     if (op->getNumResults() != result_types.size())
639       return op->emitOpError(
640           "cannot promote call ops whose op resutls do not fully match the "
641           "callee results");
642 
643     auto result = CompareTypes(op->getResultTypes(), result_types);
644     if (promoted_result_indices.empty()) {
645       promoted_result_indices.assign(result.begin(), result.end());
646     } else {
647       // We cannot handle the case where two branches' results are promoted
648       // differently.
649       if (promoted_result_indices != result)
650         return op->emitOpError(
651             "cannot promote callees with different result types");
652     }
653     return success();
654   };
655 
656   if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
657     if (operand_number == 0)
658       return if_op.emitOpError("cannot promote cond tensor for tf.If");
659 
660     auto then_branch = symbol_table.lookup<mlir::FuncOp>(if_op.then_branch());
661     auto else_branch = symbol_table.lookup<mlir::FuncOp>(if_op.else_branch());
662     assert(then_branch);
663     assert(else_branch);
664 
665     unsigned arg_index = operand_number - 1;
666     for (auto func : {then_branch, else_branch}) {
667       if (func.getType().getInput(arg_index) != promoted.getType()) {
668         // Rescursively promote the uses in branches.
669         if (failed(PromoteFunctionArgument(func, arg_index, promoted.getType(),
670                                            symbol_table)))
671           return failure();
672       }
673 
674       if (failed(update_promoted_result_indices(if_op,
675                                                 func.getType().getResults())))
676         return failure();
677     }
678   } else if (auto case_op = llvm::dyn_cast<mlir::TF::CaseOp>(op)) {
679     assert(operand_number > 0);
680     unsigned arg_index = operand_number - 1;
681     for (auto branch_attr : case_op.branches()) {
682       auto branch = symbol_table.lookup<mlir::FuncOp>(
683           branch_attr.cast<mlir::FlatSymbolRefAttr>().getValue());
684 
685       if (branch.getType().getInput(arg_index) != promoted.getType()) {
686         // Rescursively promote the uses in branches.
687         if (failed(PromoteFunctionArgument(branch, arg_index,
688                                            promoted.getType(), symbol_table)))
689           return failure();
690       }
691 
692       if (failed(update_promoted_result_indices(case_op,
693                                                 branch.getType().getResults())))
694         return failure();
695     }
696   } else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
697     auto cond = symbol_table.lookup<mlir::FuncOp>(while_op.cond());
698     auto body = symbol_table.lookup<mlir::FuncOp>(while_op.body());
699     assert(cond);
700     assert(body);
701 
702     unsigned arg_index = operand_number;
703     if (cond.getType().getInput(arg_index) != promoted.getType()) {
704       auto cond_result_type = cond.getType().getResult(0);
705       if (failed(PromoteFunctionArgument(cond, arg_index, promoted.getType(),
706                                          symbol_table)))
707         return failure();
708 
709       // We cannot promote the result of cond branch as it may change the
710       // behavior of this while op.
711       if (cond_result_type != cond.getType().getResult(0))
712         return while_op.emitOpError("failed to promote cond for tf.While");
713     }
714 
715     if (body.getType().getInput(arg_index) != promoted.getType()) {
716       if (failed(PromoteFunctionArgument(body, /*arg_index=*/operand_number,
717                                          promoted.getType(), symbol_table)))
718         return failure();
719     }
720 
721     if (failed(update_promoted_result_indices(while_op,
722                                               body.getType().getResults())))
723       return failure();
724 
725   } else if (auto call_interface = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
726     auto callee_name = call_interface.getCallableForCallee()
727                            .get<mlir::SymbolRefAttr>()
728                            .cast<mlir::FlatSymbolRefAttr>()
729                            .getValue();
730     auto callee = symbol_table.lookup<mlir::FuncOp>(callee_name);
731     assert(callee);
732 
733     unsigned arg_index =
734         operand_number - call_interface.getArgOperands().getBeginOperandIndex();
735     if (callee.getType().getInput(arg_index) != promoted.getType()) {
736       if (failed(PromoteFunctionArgument(callee, arg_index, promoted.getType(),
737                                          symbol_table)))
738         return failure();
739     }
740 
741     if (failed(
742             update_promoted_result_indices(op, callee.getType().getResults())))
743       return failure();
744   } else if (auto batch_function_op =
745                  llvm::dyn_cast<mlir::TF::BatchFunctionOp>(op)) {
746     auto batch_fn = symbol_table.lookup<mlir::FuncOp>(
747         batch_function_op.f().getRootReference());
748     assert(batch_fn);
749 
750     unsigned arg_index = operand_number;
751     if (batch_fn.getType().getInput(arg_index) != promoted.getType()) {
752       if (failed(PromoteFunctionArgument(batch_fn, arg_index,
753                                          promoted.getType(), symbol_table)))
754         return failure();
755     }
756 
757     if (failed(update_promoted_result_indices(op,
758                                               batch_fn.getType().getResults())))
759       return failure();
760   }
761 
762   // Replace the operand.
763   op->setOperand(operand_number, promoted);
764 
765   if (promoted_result_indices.empty()) return success();
766 
767   // If results are also promoted, we need to create a new op with the new
768   // results and replaces all uses recursively.
769 
770   mlir::OpBuilder builder(op);
771 
772   llvm::SmallVector<mlir::Type, 4> new_result_types(op->result_type_begin(),
773                                                     op->result_type_end());
774   for (unsigned result_number : promoted_result_indices) {
775     new_result_types[result_number] = promoted.getType();
776   }
777 
778   mlir::OperationState state(op->getLoc(), op->getName());
779   state.addOperands(op->getOperands());
780   state.addTypes(new_result_types);
781   state.addAttributes(op->getAttrs());
782 
783   auto *new_op = builder.createOperation(state);
784 
785   // Replace all uses of `op`, and recursively replace those promoted uses.
786   for (unsigned i = 0, j = 0, e = op->getNumResults(); i < e; ++i) {
787     if (j < promoted_result_indices.size() && promoted_result_indices[j] == i) {
788       j++;
789       if (failed(PromoteValueUses(op->getResult(i), new_op->getResult(i),
790                                   symbol_table))) {
791         // On failure, replace all uses of new_op with op and erase the new op.
792         new_op->replaceAllUsesWith(op);
793         new_op->erase();
794         return failure();
795       }
796     } else {
797       op->getResult(i).replaceAllUsesWith(new_op->getResult(i));
798     }
799   }
800 
801   // On success, erase the original op.
802   op->erase();
803 
804   return success();
805 }
806 
PromoteValueUses(mlir::Value old,mlir::Value promoted,const mlir::SymbolTable & symbol_table)807 mlir::LogicalResult LowerTFSavedModelPass::PromoteValueUses(
808     mlir::Value old, mlir::Value promoted,
809     const mlir::SymbolTable &symbol_table) {
810   // Retrieve the current uses before replacing the uses as the use list can be
811   // invalidated later.
812   llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4> uses;
813   for (auto &use : old.getUses())
814     uses.push_back({use.getOwner(), use.getOperandNumber()});
815 
816   // Replace uses recursively.
817   for (const auto &use : uses) {
818     if (failed(PromoteOpOperand(/*op=*/use.first,
819                                 /*operand_number=*/use.second, promoted,
820                                 symbol_table)))
821       return failure();
822   }
823 
824   return success();
825 }
826 
827 // Converts ref variables to resource variables in a few cases.
828 //
829 // If the users of one variable in the entire module satisfies the following
830 // condition, it will be converted to resource variable:
831 //
832 // 1) tf.Identity op
833 // 2) tf.Assign op
834 // 3) side-effect-free ops: This is also the TF1 behavior that the TF executor
835 //    will automatically convert ref tensors to non-ref tensors if the user is
836 //    not expecting a ref tensor. Refer to
837 //    http://cs?q=tensorflow/core/common_runtime/executor.cc:932%20at_cl:356873227
838 class ConvertReferenceVariableToResourceVariablePass
839     : public mlir::PassWrapper<ConvertReferenceVariableToResourceVariablePass,
840                                mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const841   llvm::StringRef getArgument() const final {
842     return "tfrt-convert-ref-variables";
843   }
getDescription() const844   llvm::StringRef getDescription() const final {
845     return "Convert reference variable to resource variables.";
846   }
847   void runOnOperation() override;
848 };
849 
// Rewrites a single tf.VariableV2 (reference variable) into resource-variable
// form: creates a tf.VarHandleOp with the same container/shared_name, then
// replaces each supported user of the ref tensor (tf.Identity, tf.Assign, or
// a side-effect-free op) with the corresponding resource op
// (tf.ReadVariableOp / tf.AssignVariableOp). Fails with an op error if any
// user is not one of the supported kinds; in that case no IR outside the
// use-classification has been modified yet. The caller is responsible for
// erasing `var_op` itself on success.
mlir::LogicalResult ConvertReferenceVariableToResourceVariable(
    mlir::TF::VariableV2Op var_op) {
  // The non-ref element type of the variable, used for the resource subtype
  // and for the values produced by the reads created below.
  auto tensor_type =
      mlir::TF::DropRefType(var_op.ref().getType()).cast<mlir::TensorType>();

  // Classify all users first; only mutate the IR once every user is known to
  // be convertible.
  llvm::SmallVector<mlir::TF::IdentityOp, 4> identity_ops;
  llvm::SmallVector<mlir::TF::AssignOp, 4> assign_ops;
  // Pairs of (user op, operand index of the ref value in that op).
  llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4>
      side_effect_free_ops;

  for (mlir::OpOperand &use : var_op.ref().getUses()) {
    mlir::Operation *user = use.getOwner();

    if (auto identity = llvm::dyn_cast<mlir::TF::IdentityOp>(user)) {
      identity_ops.push_back(identity);
      continue;
    } else if (auto assign = llvm::dyn_cast<mlir::TF::AssignOp>(user)) {
      // Conservatively we only allow the case that the output of this tf.Assign
      // is not consumed by any other ops.
      if (assign.output_ref().use_empty()) {
        assign_ops.push_back(assign);
        continue;
      }
    } else if (mlir::MemoryEffectOpInterface::hasNoEffect(user)) {
      side_effect_free_ops.push_back({user, use.getOperandNumber()});
      continue;
    }

    // Any other user (including a tf.Assign whose result is consumed) blocks
    // the conversion.
    return var_op.emitOpError()
           << "failed to convert reference variables with unexpected users. "
           << *user;
  }

  // Builder initially inserts right before var_op, so the handle dominates
  // all of the variable's users.
  mlir::OpBuilder builder(var_op);

  // A 0-d resource tensor whose subtype records the variable's value type.
  auto var_handle_op = builder.create<mlir::TF::VarHandleOp>(
      var_op.getLoc(),
      mlir::RankedTensorType::get(
          {},
          mlir::TF::ResourceType::get(ArrayRef<mlir::TensorType>{tensor_type},
                                      builder.getContext())),
      var_op.container(), var_op.shared_name());

  for (auto op : identity_ops) {
    // Insert the read right before this identity_op so that the side-effect
    // visibility (ordering relative to surrounding effects) is preserved.
    builder.setInsertionPoint(op);
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op.getLoc(), op.getType(), var_handle_op);
    op.replaceAllUsesWith(read_var_op.value());
    op.erase();
  }

  for (auto op : assign_ops) {
    // Insert the resource assign right before the ref assign; its operands
    // (the handle created above and the assigned value) both dominate this
    // point.
    builder.setInsertionPoint(op);
    builder.create<mlir::TF::AssignVariableOp>(op.getLoc(), var_handle_op,
                                               op.value());
    op.erase();
  }

  for (auto pair : side_effect_free_ops) {
    mlir::Operation *op = pair.first;
    unsigned idx = pair.second;
    // Insert the read right before `op` so its result can replace the ref
    // operand at index `idx` while still dominating the use.
    builder.setInsertionPoint(op);
    // Create a new read variable op, so that the side-effects are preserved.
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op->getLoc(), tensor_type, var_handle_op);
    op->setOperand(idx, read_var_op.value());
  }

  return success();
}
926 
runOnOperation()927 void ConvertReferenceVariableToResourceVariablePass::runOnOperation() {
928   auto module = getOperation();
929 
930   // The key here is a tuple of device, container and shared_name to uniquely
931   // identify a variable.
932   llvm::DenseMap<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>,
933                  llvm::SmallVector<mlir::TF::VariableV2Op, 4>>
934       ref_vars;
935 
936   // First, we collect all variables' corresponding tf.VariableV2 ops.
937   module.walk([&ref_vars](mlir::TF::VariableV2Op op) {
938     if (op.shared_name().empty()) {
939       op.emitOpError()
940           << "unable to convert reference variables with empty shared_names.";
941       return mlir::WalkResult::interrupt();
942     }
943 
944     llvm::StringRef device;
945     if (auto device_attr = op->getAttrOfType<mlir::StringAttr>("device")) {
946       device = device_attr.getValue();
947     }
948 
949     ref_vars[{device, op.container(), op.shared_name()}].push_back(op);
950 
951     return mlir::WalkResult::advance();
952   });
953 
954   // Then we perform rewrite for each variable if possible.
955   for (const auto &iter : ref_vars) {
956     const auto &var_ops = iter.second;
957 
958     for (auto var_op : var_ops) {
959       if (mlir::succeeded(ConvertReferenceVariableToResourceVariable(var_op)))
960         var_op.erase();
961     }
962   }
963 }
964 
965 }  // namespace
966 
967 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerTFSavedModelPass(bool hoist_invariant_ops)968 CreateLowerTFSavedModelPass(bool hoist_invariant_ops) {
969   return std::make_unique<LowerTFSavedModelPass>(hoist_invariant_ops);
970 }
971 
972 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateConvertReferenceVariableToResourceVariablePass()973 CreateConvertReferenceVariableToResourceVariablePass() {
974   return std::make_unique<ConvertReferenceVariableToResourceVariablePass>();
975 }
976 
// Static registrations with the global MLIR pass registry, making both passes
// available to textual pass-pipeline specifications and pass-registration
// driven tools.
static mlir::PassRegistration<LowerTFSavedModelPass> saved_model_pass;

static mlir::PassRegistration<ConvertReferenceVariableToResourceVariablePass>
    ref_var_pass;
981 
982 }  // namespace tensorflow
983