/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <deque>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

namespace tensorflow {
namespace {

constexpr char kCpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

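// Returns true if `op` is one of the initializer functions referenced by the
// module's tf_saved_model session initializer op.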
bool IsSessionInitializer(mlir::func::FuncOp op) {
  auto session_initializer_op = mlir::tf_saved_model::GetSessionInitializerOp(
      op->getParentOfType<mlir::ModuleOp>());
  if (!session_initializer_op) return false;

  for (auto sym_ref : session_initializer_op.initializers()) {
    if (op.getSymName() == sym_ref.cast<mlir::FlatSymbolRefAttr>().getValue())
      return true;
  }

  return false;
}

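// Builds a resource handle from `op`'s `device`, `container` and
// `shared_name` string attributes, using empty strings for any that are
// missing.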
mlir::TF::ResourceHandle GetResourceHandle(mlir::Operation *op) {
  llvm::StringRef device;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
    device = attr.getValue();
  }

  llvm::StringRef container;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
    container = attr.getValue();
  }

  llvm::StringRef shared_name;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
    shared_name = attr.getValue();
  }

  return {container, shared_name, device, /*op=*/nullptr};
}

struct HoistInfo {
  // All hoisted ops in topological order.
  llvm::SmallVector<mlir::Operation *, 4> hoists_in_topological_order;

  // Mapping from the values produced by hoisted ops before hoisting to the
  // corresponding new values after hoisting.
  mlir::BlockAndValueMapping value_mapping;

  // `hoisted_values` keeps all values that are produced by hoisted ops but
  // used by non-hoisted ops. These values will be replaced by results of the
  // tf._TfrtGetResource op. The index of each value in this array is the
  // index used in the corresponding tf._TfrtGetResource and
  // tf._TfrtSetResource ops. Each entry also stores the ResourceHandle, whose
  // shared_name and container attributes are used by the later resource alias
  // analysis and side effect analysis passes.
  llvm::SmallVector<std::pair<mlir::Value, mlir::TF::ResourceHandle>, 4>
      hoisted_values;
};

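// Replaces all uses of `hoisted_values` in non-hoisted ops with the results
// of newly created tf._TfrtGetResource ops, grouping the values by the block
// and device they belong to.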
void ReplaceHoistedValues(
    llvm::ArrayRef<std::pair<mlir::Value, mlir::TF::ResourceHandle>>
        hoisted_values,
    mlir::OpBuilder &builder) {
  struct HoistedValueInfo {
    llvm::SmallVector<mlir::Value, 4> hoisted_values;
    llvm::SmallVector<int64_t, 4> indices;
    llvm::SmallVector<llvm::StringRef, 4> shared_names;
    llvm::SmallVector<llvm::StringRef, 4> containers;
  };
  // Group the hoisted values by block and by device.
  llvm::DenseMap<mlir::Block *, llvm::StringMap<HoistedValueInfo>>
      hoisted_values_by_block_device;

  // Find the block in which to place the tf._TfrtGetResource operation. We do
  // not place get resource operations inside `tf_device.cluster` operations,
  // because those blocks are intended for later on-device compilation.
  // Instead, insert resource reads into the closest block outside of the
  // `tf_device.cluster` operation.
  auto hoist_into_block = [](mlir::Value value) -> mlir::Block * {
    mlir::Operation *cluster_op =
        value.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();
    return cluster_op ? cluster_op->getBlock() : value.getParentBlock();
  };

  for (auto iter : llvm::enumerate(hoisted_values)) {
    auto value = iter.value().first;
    auto index = iter.index();
    auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)];

    assert(value.getDefiningOp() && "hoisted values must not be arguments.");
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr =
            value.getDefiningOp()->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }

    auto &item = device_map[device];

    item.hoisted_values.push_back(value);
    item.indices.push_back(index);
    item.shared_names.push_back(iter.value().second.name);
    item.containers.push_back(iter.value().second.container);
  }

  // Create a tf._TfrtGetResource op for each block and device.
  for (const auto &block_iter : hoisted_values_by_block_device) {
    auto *block = block_iter.first;
    const auto &device_map = block_iter.second;

    builder.setInsertionPointToStart(block);
    for (const auto &device_iter : device_map) {
      llvm::StringRef device = device_iter.getKey();
      mlir::ValueRange old_values = device_iter.getValue().hoisted_values;
      const auto &indices = device_iter.getValue().indices;
      const auto &shared_name_arr = device_iter.getValue().shared_names;
      const auto &container_arr = device_iter.getValue().containers;

      auto get_resource_op = builder.create<mlir::TF::_TfrtGetResourceOp>(
          block->getParentOp()->getLoc(), old_values.getTypes(),
          builder.getI64ArrayAttr(indices),
          builder.getStrArrayAttr(shared_name_arr),
          builder.getStrArrayAttr(container_arr));
      get_resource_op->setAttr("device", builder.getStringAttr(device));

      auto new_values = get_resource_op.results();
      for (auto iter : llvm::zip(old_values, new_values)) {
        auto old_value = std::get<0>(iter);
        auto new_value = std::get<1>(iter);
        old_value.replaceAllUsesWith(new_value);
      }
    }
  }
}

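// Returns true if `op` implements the memory effect interface and only reads
// from its operands.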
bool OnlyHasReadEffect(mlir::Operation *op) {
  auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!interface) return false;
  return interface.onlyHasEffect<mlir::MemoryEffects::Read>();
}

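// Returns true if `op` is guaranteed to produce the same result on every
// invocation (given that the resources in `read_only_vars` are never written
// after initialization), and hence may be hoisted into the init function.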
bool CanHoist(const llvm::DenseSet<mlir::TF::ResourceHandle> &read_only_vars,
              mlir::Operation *op) {
  // Terminator ops (e.g. return ops) should not be hoisted.
  if (op->mightHaveTrait<mlir::OpTrait::IsTerminator>()) return false;

  // Non-side-effecting ops can be hoisted.
  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return true;

  // Resource-handle-producing ops can be hoisted.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op))
    return true;

  // A tf.ReadVariableOp can be hoisted if the variable it reads is read-only.
  if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
    if (auto var_handle_op = llvm::dyn_cast_or_null<mlir::TF::VarHandleOp>(
            read_var_op.resource().getDefiningOp())) {
      if (read_only_vars.count(GetResourceHandle(var_handle_op)) > 0)
        return true;
    }
  }

  // A tf.LookupTableSizeV2 op can be hoisted, as the size of a hash table
  // cannot change after initialization.
  if (auto lookup_table_size_op =
          llvm::dyn_cast<mlir::TF::LookupTableSizeV2Op>(op)) {
    if (auto hash_table_op = llvm::dyn_cast_or_null<mlir::TF::HashTableV2Op>(
            lookup_table_size_op.table_handle().getDefiningOp())) {
      if (read_only_vars.count(GetResourceHandle(hash_table_op)) > 0)
        return true;
    }
  }

  // TODO(chky): Allow more readonly ops.

  return false;
}

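// Hoists invariant ops in `func` into the init function by cloning them at
// the current insertion point of `builder`, and records in
// `module_hoist_info` which values must later be rewired through
// tf._TfrtSetResource / tf._TfrtGetResource pairs.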
void HoistInvariantOpsInFunction(
    mlir::func::FuncOp func,
    const llvm::DenseSet<mlir::TF::ResourceHandle> &read_only_vars,
    const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
    mlir::OpBuilder &builder, HoistInfo &module_hoist_info) {
  // Keep the ops hoisted from this function.
  llvm::DenseSet<mlir::Operation *> hoists;

  auto all_operands_in_hoists = [&module_hoist_info](mlir::Operation *op) {
    for (mlir::Value operand : op->getOperands()) {
      if (module_hoist_info.value_mapping.lookupOrNull(operand) == nullptr)
        return false;
    }
    return true;
  };

  auto all_control_predecessors_in_hoists = [&hoists, &side_effect_analysis](
                                                mlir::Operation *op) {
    auto preds = side_effect_analysis.DirectControlPredecessors(op);
    return std::all_of(
        preds.begin(), preds.end(),
        [&hoists](mlir::Operation *pred) { return hoists.count(pred) > 0; });
  };

  std::deque<mlir::Operation *> work_list;

  // Seed the work list with tf.VarHandleOp, tf.HashTableV2 and tf.Const ops.
  //
  // TODO(chky): Consider allowing other ops, including custom ops, to be
  // hoisted.
  func.walk([&work_list](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op,
                  mlir::TF::ConstOp>(op))
      work_list.push_back(op);
  });

  while (!work_list.empty()) {
    auto *op = work_list.front();
    work_list.pop_front();

    // Skip if it is already hoisted.
    if (hoists.count(op) > 0) continue;

    // If the op can be hoisted and all of its data and control dependencies
    // have already been hoisted, hoist it. Otherwise, skip it.
    if (!(CanHoist(read_only_vars, op) && all_operands_in_hoists(op) &&
          all_control_predecessors_in_hoists(op)))
      continue;

    // Record the hoisted operation.
    hoists.insert(op);
    module_hoist_info.hoists_in_topological_order.push_back(op);

    // Create a copy in the init function.
    builder.clone(*op, module_hoist_info.value_mapping);

    for (mlir::Operation *user : op->getUsers()) {
      work_list.push_back(user);
    }
  }

  // Find the values that are produced by hoisted ops but used by non-hoisted
  // ops. These values need to be replaced.
  for (auto *op : hoists) {
    for (auto result : op->getResults()) {
      if (std::any_of(result.getUsers().begin(), result.getUsers().end(),
                      [&hoists](mlir::Operation *user) {
                        return hoists.count(user) == 0;
                      })) {
        module_hoist_info.hoisted_values.push_back(
            {result, GetResourceHandle(op)});
      }
    }
  }
}

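// Recursively collects into `callees` the symbol names of all functions
// transitively referenced by symbol attributes in `func`.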
void FindCalleesRecursive(const mlir::SymbolTable &symbol_table,
                          mlir::func::FuncOp func, llvm::StringSet<> &callees) {
  assert(func);
  func.walk([&](mlir::Operation *op) {
    for (const auto &named_attr : op->getAttrs()) {
      if (auto symbol_attr =
              named_attr.getValue().dyn_cast<mlir::FlatSymbolRefAttr>()) {
        auto symbol = symbol_attr.getValue();
        if (!callees.contains(symbol)) {
          callees.insert(symbol);

          auto func = symbol_table.lookup<mlir::func::FuncOp>(symbol);
          if (!func) continue;

          FindCalleesRecursive(symbol_table, func, callees);
        }
      }
    }
  });
}

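// Hoists invariant ops (e.g. tf.VarHandleOp, tf.HashTableV2 and read-only
// tf.ReadVariableOp) out of all non-init functions in `module` into a newly
// created "_tfrt_resource_init" function, connecting the two sides with
// tf._TfrtSetResource / tf._TfrtGetResource pairs. For illustration, a rough
// sketch of the transformation (attribute and type details elided):
//
//   func @predict(...) {
//     %handle = "tf.VarHandleOp"() {shared_name = "v"} ...
//     %value = "tf.ReadVariableOp"(%handle) ...
//     ...
//   }
//
// becomes
//
//   func @_tfrt_resource_init() {
//     %handle = "tf.VarHandleOp"() {shared_name = "v"} ...
//     %value = "tf.ReadVariableOp"(%handle) ...
//     "tf._TfrtSetResource"(%value) {index = 0 : i64} ...
//   }
//   func @predict(...) {
//     %value = "tf._TfrtGetResource"() {indices = [0], ...} ...
//     ...
//   }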
void HoistInvariantOps(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);

  // Find all resources used in non-init functions.
  llvm::DenseMap<mlir::TF::ResourceHandle,
                 llvm::SmallVector<mlir::Operation *, 4>>
      resources;

  // Find all callees referenced in the initialization functions.
  llvm::StringSet<> init_callees;

  module.walk([&](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op)) {
      auto func = op->getParentOfType<mlir::func::FuncOp>();
      if (IsSessionInitializer(func)) return;
      resources[GetResourceHandle(op)].push_back(op);
    } else if (auto func = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
      if (!IsSessionInitializer(func)) return;
      FindCalleesRecursive(symbol_table, func, init_callees);
    }
  });

  llvm::DenseSet<mlir::TF::ResourceHandle> read_only_vars;
  for (const auto &iter : resources) {
    const auto &key = iter.first;
    const auto &vars = iter.second;
    if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) {
          for (auto *user : op->getUsers()) {
            if (!OnlyHasReadEffect(user)) return false;
          }
          return true;
        })) {
      read_only_vars.insert(key);
    }
  }

  mlir::TF::SideEffectAnalysis side_effect_analysis(module);

  mlir::OpBuilder builder(&module.getBodyRegion());
  // "_tfrt_resource_init" is a special function that executes all invariant
  // ops (e.g. reads of read-only variables) used in the model. It is executed
  // after any user-specified initialization.
  auto init_func_op = builder.create<mlir::func::FuncOp>(
      module.getLoc(), "_tfrt_resource_init",
      mlir::FunctionType::get(module.getContext(), /*inputs=*/{},
                              /*results=*/{}));
  auto *block = init_func_op.addEntryBlock();
  builder.setInsertionPointToStart(block);

  HoistInfo module_hoist_info;

  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    // Skip hoisting if this function is an init function or a callee,
    // including transitive ones, of an init function, because otherwise the
    // hoisted values would not be initialized when this function is called.
    if (IsSessionInitializer(func) ||
        init_callees.contains(func.getSymName()) || func == init_func_op)
      continue;

    HoistInvariantOpsInFunction(func, read_only_vars,
                                side_effect_analysis.GetAnalysisForFunc(func),
                                builder, module_hoist_info);
  }

  // Create tf._TfrtSetResource ops in the init function.
  for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) {
    mlir::Value value = iter.value().first;
    int64_t index = iter.index();

    auto new_value = module_hoist_info.value_mapping.lookup(value);
    auto *new_op = new_value.getDefiningOp();
    assert(new_op);
    builder.setInsertionPointAfter(new_op);
    auto set_resource_op = builder.create<mlir::TF::_TfrtSetResourceOp>(
        new_op->getLoc(), new_value, index);

    // Preserve the device attribute.
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr = new_op->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }
    set_resource_op->setAttr("device", builder.getStringAttr(device));
  }

  builder.setInsertionPointToEnd(block);
  // Finish building the init function by inserting a return op.
  builder.create<mlir::func::ReturnOp>(init_func_op.getLoc());

  // Now that we have the index for each value that will be replaced, we can
  // create the tf._TfrtGetResource op in each function using these indices.
  ReplaceHoistedValues(module_hoist_info.hoisted_values, builder);

  // Lastly, erase the hoisted ops in reverse topological order.
  for (auto *op :
       llvm::reverse(module_hoist_info.hoists_in_topological_order)) {
    assert(op->use_empty());
    op->erase();
  }
}

// This pass rewrites tf_saved_model dialect ops according to TFRT's
// requirements:
//
// 1) Remove all tf_saved_model attributes and ops.
// 2) Create a function for each exported name of the original function.
// 3) Hoist invariant ops (i.e. ops guaranteed to return the same value on
//    every invocation) out of every non-init function.
//
class LowerTFSavedModelPass
    : public mlir::PassWrapper<LowerTFSavedModelPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerTFSavedModelPass)

  explicit LowerTFSavedModelPass(bool hoist_invariant_ops) {
    hoist_invariant_ops_ = hoist_invariant_ops;
  }
  LowerTFSavedModelPass() = default;
  LowerTFSavedModelPass(const LowerTFSavedModelPass &) {}

  llvm::StringRef getArgument() const final {
    return "tfrt-lower-tf-savedmodel";
  }
  llvm::StringRef getDescription() const final {
    return "Lower tf-saved-model ops according to TFRT's requirements.";
  }

  void runOnOperation() override {
    auto module = getOperation();

    // TODO(b/185928201): Create a standalone pass for hoisting invariant ops
    // so that it can be reused and configured in contexts other than saved
    // models.
    if (hoist_invariant_ops_) HoistInvariantOps(module);

    // Skip MLIR modules without saved-model semantics.
    if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) return;

    mlir::SymbolTable symbol_table(module);

    module->removeAttr("tf_saved_model.semantics");

    mlir::OpBuilder builder(&getContext());
    auto resource_id = builder.getStringAttr("tf.resource_name");
    auto bound_id = builder.getStringAttr("tf_saved_model.bound_input");
    auto path_id = builder.getStringAttr("tf_saved_model.index_path");

    module.walk([resource_id, bound_id, path_id,
                 &builder](mlir::Operation *op) mutable {
      if (auto func_op = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
        // Remove tf_saved_model specific function arg attributes.
        for (unsigned i = 0, e = func_op.getNumArguments(); i != e; ++i) {
          if (auto sym = func_op.getArgAttrOfType<mlir::FlatSymbolRefAttr>(
                  i, bound_id)) {
            func_op.removeArgAttr(i, bound_id);
            func_op.setArgAttr(i, resource_id,
                               builder.getStringAttr(sym.getValue()));
          }
          func_op.removeArgAttr(i, path_id);
        }
        for (unsigned i = 0, e = func_op.getNumResults(); i != e; ++i) {
          func_op.removeResultAttr(i, bound_id);
          func_op.removeResultAttr(i, path_id);
        }
        if (auto exported_names = func_op->getAttrOfType<mlir::ArrayAttr>(
                "tf_saved_model.exported_names")) {
          bool is_session_initializer = IsSessionInitializer(func_op);

          // Create a function for each exported name.
          //
          // TODO(b/148477882): TFRT dialect should have a similar concept of
          // exported names so that a function can be referenced by multiple
          // exported names.
          func_op->removeAttr("tf_saved_model.exported_names");
          for (auto exported_name : exported_names) {
            auto exported_func_op = func_op.clone();
            exported_func_op.setName(exported_name.cast<mlir::StringAttr>());

            // If it is a session initializer, we want to maximize parallelism
            // and perform no stream merging, to minimize latency.
            //
            // TODO(b/183219530): This is a workaround as the cost model
            // currently used is not very accurate, and leads to performance
            // regressions on IO ops that are common in initialization
            // functions.
            if (is_session_initializer) {
              exported_func_op->setAttr("tfrt.cost_threshold",
                                        builder.getI64IntegerAttr(1));
            }

            builder.setInsertionPoint(func_op);
            builder.insert(exported_func_op);
          }
          func_op.erase();
        }
      }
    });

    module.walk([](mlir::Operation *op) {
      if (llvm::isa<mlir::tf_saved_model::TensorFlowSavedModelDialect>(
              op->getDialect())) {
        // Remove all tf_saved_model ops.
        op->erase();
      }
    });
  }

 private:
  Option<bool> hoist_invariant_ops_{
      *this, "hoist-invariant-ops",
      llvm::cl::desc("Hoist invariant ops to the init function"),
      llvm::cl::init(false)};
};

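// Returns the indices at which the element types of `x` and `y` differ. `x`
// and `y` must have the same number of types.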
static llvm::SmallVector<unsigned, 4> CompareTypes(mlir::TypeRange x,
                                                   mlir::TypeRange y) {
  llvm::SmallVector<unsigned, 4> results;
  assert(x.size() == y.size());
  for (int i = 0, e = x.size(); i < e; ++i) {
    if (x[i] != y[i]) results.push_back(i);
  }
  return results;
}

// Converts ref variables to resource variables in a few cases.
//
// A variable is converted to a resource variable if every user of it in the
// entire module is one of the following:
//
// 1) a tf.Identity op;
// 2) a tf.Assign op;
// 3) a side-effect-free op. This matches the TF1 behavior where the TF
//    executor automatically converts ref tensors to non-ref tensors if the
//    user does not expect a ref tensor. Refer to
//    http://cs?q=tensorflow/core/common_runtime/executor.cc:932%20at_cl:356873227
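//
// For illustration, a rough sketch of the rewrite (types and attributes
// elided):
//
//   %ref = "tf.VariableV2"() {shared_name = "v"} ...
//   %val = "tf.Identity"(%ref) ...
//   "tf.Assign"(%ref, %new) ...
//
// becomes
//
//   %handle = "tf.VarHandleOp"() {shared_name = "v"} ...
//   %val = "tf.ReadVariableOp"(%handle) ...
//   "tf.AssignVariableOp"(%handle, %new) ...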
class ConvertReferenceVariableToResourceVariablePass
    : public mlir::PassWrapper<ConvertReferenceVariableToResourceVariablePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  llvm::StringRef getArgument() const final {
    return "tfrt-convert-ref-variables";
  }
  llvm::StringRef getDescription() const final {
    return "Convert reference variables to resource variables.";
  }
  void runOnOperation() override;

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      ConvertReferenceVariableToResourceVariablePass)
};

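// Rewrites all supported users of `var_op`'s ref tensor to use a newly
// created resource variable (a tf.VarHandleOp together with
// tf.ReadVariableOp and tf.AssignVariableOp), or fails without modifying the
// IR if `var_op` has an unsupported user.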
mlir::LogicalResult ConvertReferenceVariableToResourceVariable(
    mlir::TF::VariableV2Op var_op) {
  auto tensor_type =
      mlir::TF::DropRefType(var_op.ref().getType()).cast<mlir::TensorType>();

  llvm::SmallVector<mlir::TF::IdentityOp, 4> identity_ops;
  llvm::SmallVector<mlir::TF::AssignOp, 4> assign_ops;
  llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4>
      side_effect_free_ops;

  for (mlir::OpOperand &use : var_op.ref().getUses()) {
    mlir::Operation *user = use.getOwner();

    if (auto identity = llvm::dyn_cast<mlir::TF::IdentityOp>(user)) {
      identity_ops.push_back(identity);
      continue;
    } else if (auto assign = llvm::dyn_cast<mlir::TF::AssignOp>(user)) {
      // Conservatively, we only handle the case where the output of this
      // tf.Assign is not consumed by any other op.
      if (assign.output_ref().use_empty()) {
        assign_ops.push_back(assign);
        continue;
      }
    } else if (mlir::MemoryEffectOpInterface::hasNoEffect(user)) {
      side_effect_free_ops.push_back({user, use.getOperandNumber()});
      continue;
    }

    return var_op.emitOpError()
           << "failed to convert reference variables with unexpected users. "
           << *user;
  }

  mlir::OpBuilder builder(var_op);

  auto var_handle_op = builder.create<mlir::TF::VarHandleOp>(
      var_op.getLoc(),
      mlir::RankedTensorType::get(
          {}, mlir::TF::ResourceType::get(
                  llvm::ArrayRef<mlir::TensorType>{tensor_type},
                  builder.getContext())),
      var_op.container(), var_op.shared_name());

  for (auto op : identity_ops) {
    // Set the insertion point to this identity op so that the side-effect
    // visibility is preserved.
    builder.setInsertionPoint(op);
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op.getLoc(), op.getType(), var_handle_op);
    op.replaceAllUsesWith(read_var_op.value());
    op.erase();
  }

  for (auto op : assign_ops) {
    // Set the insertion point to the assign op so that all operands of the
    // newly created op dominate it.
    builder.setInsertionPoint(op);
    builder.create<mlir::TF::AssignVariableOp>(op.getLoc(), var_handle_op,
                                               op.value());
    op.erase();
  }

  for (auto pair : side_effect_free_ops) {
    mlir::Operation *op = pair.first;
    unsigned idx = pair.second;
    // Set the insertion point to the op so that all operands of the newly
    // created op dominate it.
    builder.setInsertionPoint(op);
    // Create a new read variable op, so that the side effects are preserved.
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op->getLoc(), tensor_type, var_handle_op);
    op->setOperand(idx, read_var_op.value());
  }

  return mlir::success();
}

void ConvertReferenceVariableToResourceVariablePass::runOnOperation() {
  auto module = getOperation();

  // The key here is a tuple of device, container and shared_name that
  // uniquely identifies a variable.
  llvm::DenseMap<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>,
                 llvm::SmallVector<mlir::TF::VariableV2Op, 4>>
      ref_vars;

  // First, collect each variable's corresponding tf.VariableV2 ops.
  module.walk([&ref_vars](mlir::TF::VariableV2Op op) {
    if (op.shared_name().empty()) {
      op.emitOpError()
          << "unable to convert reference variables with empty shared_names.";
      return mlir::WalkResult::interrupt();
    }

    llvm::StringRef device;
    if (auto device_attr = op->getAttrOfType<mlir::StringAttr>("device")) {
      device = device_attr.getValue();
    }

    ref_vars[{device, op.container(), op.shared_name()}].push_back(op);

    return mlir::WalkResult::advance();
  });

  // Then rewrite each variable if possible.
  for (const auto &iter : ref_vars) {
    const auto &var_ops = iter.second;

    for (auto var_op : var_ops) {
      if (mlir::succeeded(ConvertReferenceVariableToResourceVariable(var_op)))
        var_op.erase();
    }
  }
}

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerTFSavedModelPass(bool hoist_invariant_ops) {
  return std::make_unique<LowerTFSavedModelPass>(hoist_invariant_ops);
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateConvertReferenceVariableToResourceVariablePass() {
  return std::make_unique<ConvertReferenceVariableToResourceVariablePass>();
}

static mlir::PassRegistration<LowerTFSavedModelPass> saved_model_pass;

static mlir::PassRegistration<ConvertReferenceVariableToResourceVariablePass>
    ref_var_pass;

}  // namespace tensorflow