/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <deque>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

namespace tensorflow {
namespace {

constexpr char kCpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

bool IsSessionInitializer(mlir::func::FuncOp op) {
  auto session_initializer_op = mlir::tf_saved_model::GetSessionInitializerOp(
      op->getParentOfType<mlir::ModuleOp>());
  if (!session_initializer_op) return false;

  for (auto sym_ref : session_initializer_op.initializers()) {
    if (op.getSymName() == sym_ref.cast<mlir::FlatSymbolRefAttr>().getValue())
      return true;
  }

  return false;
}

mlir::TF::ResourceHandle GetResourceHandle(mlir::Operation *op) {
  llvm::StringRef device;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
    device = attr.getValue();
  }

  llvm::StringRef container;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
    container = attr.getValue();
  }

  llvm::StringRef shared_name;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
    shared_name = attr.getValue();
  }

  return {container, shared_name, device, /*op=*/nullptr};
}

struct HoistInfo {
  // All hoisted ops, in topological order.
  llvm::SmallVector<mlir::Operation *, 4> hoists_in_topological_order;

  // Mapping from the old values produced by hoisted ops (before hoisting) to
  // the new values (after hoisting).
  mlir::BlockAndValueMapping value_mapping;

  // `hoisted_values` keeps all values that are produced by hoisted ops but
  // used by non-hoisted ops. These values will be replaced by the results of
  // a tf._TfrtGetResource op. The index of each value in this array is the
  // index used by the corresponding tf._TfrtGetResource and
  // tf._TfrtSetResource ops. Each entry also stores the ResourceHandle, whose
  // shared_name and container attributes are used by the later resource alias
  // analysis and side effect analysis passes.
  llvm::SmallVector<std::pair<mlir::Value, mlir::TF::ResourceHandle>, 4>
      hoisted_values;
};
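
// For example (an illustrative scenario, not from any real model): if a
// hoisted tf.ReadVariableOp result is used by a non-hoisted tf.MatMul, that
// result is appended to `hoisted_values`, say at index i. The init function
// will then store the cloned result via a tf._TfrtSetResource op carrying
// index i, and the tf.MatMul will instead consume the result of a
// tf._TfrtGetResource op carrying the same index i.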

void ReplaceHoistedValues(
    llvm::ArrayRef<std::pair<mlir::Value, mlir::TF::ResourceHandle>>
        hoisted_values,
    mlir::OpBuilder &builder) {
  struct HoistedValueInfo {
    llvm::SmallVector<mlir::Value, 4> hoisted_values;
    llvm::SmallVector<int64_t, 4> indices;
    llvm::SmallVector<llvm::StringRef, 4> shared_names;
    llvm::SmallVector<llvm::StringRef, 4> containers;
  };
  // Group the hoisted values by block and by device.
  llvm::DenseMap<mlir::Block *, llvm::StringMap<HoistedValueInfo>>
      hoisted_values_by_block_device;

  // Find the block in which to place the tf._TfrtGetResource operation. We do
  // not place get-resource operations inside `tf_device.cluster` operations,
  // because those bodies are intended for later on-device compilation.
  // Instead, insert resource reads in the closest block outside of the
  // `tf_device.cluster` operation.
  auto hoist_into_block = [](mlir::Value value) -> mlir::Block * {
    mlir::Operation *cluster_op =
        value.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();
    return cluster_op ? cluster_op->getBlock() : value.getParentBlock();
  };
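
  // For example (an illustrative sketch; the function name is hypothetical):
  // if a hoisted value is defined inside a tf_device.cluster,
  //
  //   func @f() {
  //     "tf_device.cluster"() ({
  //       %v = "tf.VarHandleOp"() ...  // defining op of the hoisted value
  //       ...
  //     }) ...
  //   }
  //
  // the resource read is placed in @f's body next to the cluster, not inside
  // its region.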

  for (auto iter : llvm::enumerate(hoisted_values)) {
    auto value = iter.value().first;
    auto index = iter.index();
    auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)];

    assert(value.getDefiningOp() && "hoisted values must not be arguments.");
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr =
            value.getDefiningOp()->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }

    auto &item = device_map[device];

    item.hoisted_values.push_back(value);
    item.indices.push_back(index);
    item.shared_names.push_back(iter.value().second.name);
    item.containers.push_back(iter.value().second.container);
  }

  // Create a tf._TfrtGetResource op for each block and device.
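  //
  // As a sketch (attribute names mirror the builder arguments below; the
  // exact textual form is illustrative), two values hoisted to the same block
  // and device may be grouped into a single op:
  //
  //   %r:2 = "tf._TfrtGetResource"()
  //       {indices = [0, 1], shared_name = ["w", "b"], container = ["", ""],
  //        device = "/job:localhost/replica:0/task:0/device:CPU:0"}
  //       : () -> (tensor<...>, tensor<...>)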
  for (const auto &block_iter : hoisted_values_by_block_device) {
    auto *block = block_iter.first;
    const auto &device_map = block_iter.second;

    builder.setInsertionPointToStart(block);
    for (const auto &device_iter : device_map) {
      llvm::StringRef device = device_iter.getKey();
      mlir::ValueRange old_values = device_iter.getValue().hoisted_values;
      const auto &indices = device_iter.getValue().indices;
      const auto &shared_name_arr = device_iter.getValue().shared_names;
      const auto &container_arr = device_iter.getValue().containers;

      auto get_resource_op = builder.create<mlir::TF::_TfrtGetResourceOp>(
          block->getParentOp()->getLoc(), old_values.getTypes(),
          builder.getI64ArrayAttr(indices),
          builder.getStrArrayAttr(shared_name_arr),
          builder.getStrArrayAttr(container_arr));
      get_resource_op->setAttr("device", builder.getStringAttr(device));

      auto new_values = get_resource_op.results();
      for (auto iter : llvm::zip(old_values, new_values)) {
        auto old_value = std::get<0>(iter);
        auto new_value = std::get<1>(iter);
        old_value.replaceAllUsesWith(new_value);
      }
    }
  }
}

bool OnlyHasReadEffect(mlir::Operation *op) {
  auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
  if (!interface) return false;
  return interface.onlyHasEffect<mlir::MemoryEffects::Read>();
}

bool CanHoist(const llvm::DenseSet<mlir::TF::ResourceHandle> &read_only_vars,
              mlir::Operation *op) {
  // Terminator ops (e.g. return ops) should not be hoisted.
  if (op->mightHaveTrait<mlir::OpTrait::IsTerminator>()) return false;

  // Non-side-effecting ops can be hoisted.
  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return true;

  // Ops that only produce resource handles can be hoisted.
  if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op))
    return true;

  // A tf.ReadVariableOp can be hoisted if the variable it reads is read-only.
  if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
    if (auto var_handle_op = llvm::dyn_cast_or_null<mlir::TF::VarHandleOp>(
            read_var_op.resource().getDefiningOp())) {
      if (read_only_vars.count(GetResourceHandle(var_handle_op)) > 0)
        return true;
    }
  }

  // A tf.LookupTableSizeV2 op can be hoisted, as the size of a hash table
  // cannot change after initialization.
  if (auto lookup_table_size_op =
          llvm::dyn_cast<mlir::TF::LookupTableSizeV2Op>(op)) {
    if (auto hash_table_op = llvm::dyn_cast_or_null<mlir::TF::HashTableV2Op>(
            lookup_table_size_op.table_handle().getDefiningOp())) {
      if (read_only_vars.count(GetResourceHandle(hash_table_op)) > 0)
        return true;
    }
  }

  // TODO(chky): Allow more read-only ops.

  return false;
}

void HoistInvariantOpsInFunction(
    mlir::func::FuncOp func,
    const llvm::DenseSet<mlir::TF::ResourceHandle> &read_only_vars,
    const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
    mlir::OpBuilder &builder, HoistInfo &module_hoist_info) {
  // Keep the ops hoisted from this function.
  llvm::DenseSet<mlir::Operation *> hoists;

  auto all_operands_in_hoists = [&module_hoist_info](mlir::Operation *op) {
    for (mlir::Value operand : op->getOperands()) {
      if (module_hoist_info.value_mapping.lookupOrNull(operand) == nullptr)
        return false;
    }
    return true;
  };

  auto all_control_predecessors_in_hoists = [&hoists, &side_effect_analysis](
                                                mlir::Operation *op) {
    auto preds = side_effect_analysis.DirectControlPredecessors(op);
    return std::all_of(
        preds.begin(), preds.end(),
        [&hoists](mlir::Operation *pred) { return hoists.count(pred) > 0; });
  };

  std::deque<mlir::Operation *> work_list;

  // Seed the work list with tf.VarHandleOp, tf.HashTableV2 and tf.Const ops.
  //
  // TODO(chky): Consider allowing other ops, including custom ops, to be
  // hoisted.
  func.walk([&work_list](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op,
                  mlir::TF::ConstOp>(op))
      work_list.push_back(op);
  });

  while (!work_list.empty()) {
    auto *op = work_list.front();
    work_list.pop_front();

    // Skip ops that are already hoisted.
    if (hoists.count(op) > 0) continue;

    // Hoist the op only if it can be hoisted and all of its data and control
    // dependencies are already hoisted; otherwise, skip it.
    if (!(CanHoist(read_only_vars, op) && all_operands_in_hoists(op) &&
          all_control_predecessors_in_hoists(op)))
      continue;

    // Record the hoisted operation.
    hoists.insert(op);
    module_hoist_info.hoists_in_topological_order.push_back(op);

    // Create a copy in the init function.
    builder.clone(*op, module_hoist_info.value_mapping);

    for (mlir::Operation *user : op->getUsers()) {
      work_list.push_back(user);
    }
  }

  // Find the values that are produced by hoisted ops but used by non-hoisted
  // ops. These values need to be replaced.
  for (auto *op : hoists) {
    for (auto result : op->getResults()) {
      if (std::any_of(result.getUsers().begin(), result.getUsers().end(),
                      [&hoists](mlir::Operation *user) {
                        return hoists.count(user) == 0;
                      })) {
        module_hoist_info.hoisted_values.push_back(
            {result, GetResourceHandle(op)});
      }
    }
  }
}

void FindCalleesRecursive(const mlir::SymbolTable &symbol_table,
                          mlir::func::FuncOp func, llvm::StringSet<> &callees) {
  assert(func);
  func.walk([&](mlir::Operation *op) {
    for (const auto &named_attr : op->getAttrs()) {
      if (auto symbol_attr =
              named_attr.getValue().dyn_cast<mlir::FlatSymbolRefAttr>()) {
        auto symbol = symbol_attr.getValue();
        if (!callees.contains(symbol)) {
          callees.insert(symbol);

          auto callee = symbol_table.lookup<mlir::func::FuncOp>(symbol);
          if (!callee) continue;

          FindCalleesRecursive(symbol_table, callee, callees);
        }
      }
    }
  });
}

void HoistInvariantOps(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);

  // Find all resources used in non-init functions.
  llvm::DenseMap<mlir::TF::ResourceHandle,
                 llvm::SmallVector<mlir::Operation *, 4>>
      resources;

  // Find all callees referenced in the initialization functions.
  llvm::StringSet<> init_callees;

  module.walk([&](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op)) {
      auto func = op->getParentOfType<mlir::func::FuncOp>();
      if (IsSessionInitializer(func)) return;
      resources[GetResourceHandle(op)].push_back(op);
    } else if (auto func = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
      if (!IsSessionInitializer(func)) return;
      FindCalleesRecursive(symbol_table, func, init_callees);
    }
  });

  llvm::DenseSet<mlir::TF::ResourceHandle> read_only_vars;
  for (const auto &iter : resources) {
    const auto &key = iter.first;
    const auto &vars = iter.second;
    if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) {
          for (auto *user : op->getUsers()) {
            if (!OnlyHasReadEffect(user)) return false;
          }
          return true;
        })) {
      read_only_vars.insert(key);
    }
  }

  mlir::TF::SideEffectAnalysis side_effect_analysis(module);

  mlir::OpBuilder builder(&module.getBodyRegion());
  // "_tfrt_resource_init" is the special function that executes all invariant
  // ops (e.g. reads of read-only variables) used in the model. This function
  // is executed after user-specified initialization.
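  //
  // As an illustrative sketch (op and value names are hypothetical), the
  // generated function may look like:
  //
  //   func @_tfrt_resource_init() {
  //     %handle = "tf.VarHandleOp"() {shared_name = "w", ...} ...
  //     %w = "tf.ReadVariableOp"(%handle) ...
  //     "tf._TfrtSetResource"(%w) {index = 0 : i64} ...
  //     return
  //   }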
  auto init_func_op = builder.create<mlir::func::FuncOp>(
      module.getLoc(), "_tfrt_resource_init",
      mlir::FunctionType::get(module.getContext(), /*inputs=*/{},
                              /*results=*/{}));
  auto *block = init_func_op.addEntryBlock();
  builder.setInsertionPointToStart(block);

  HoistInfo module_hoist_info;

  for (auto func : module.getOps<mlir::func::FuncOp>()) {
    // Skip hoisting if this function is an init function or a callee
    // (including transitive callees) of an init function, because otherwise
    // the hoisted values won't be initialized when this function is called.
    if (IsSessionInitializer(func) ||
        init_callees.contains(func.getSymName()) || func == init_func_op)
      continue;

    HoistInvariantOpsInFunction(func, read_only_vars,
                                side_effect_analysis.GetAnalysisForFunc(func),
                                builder, module_hoist_info);
  }

  // Create tf._TfrtSetResource ops in the init function.
  for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) {
    mlir::Value value = iter.value().first;
    int64_t index = iter.index();

    auto new_value = module_hoist_info.value_mapping.lookup(value);
    auto *new_op = new_value.getDefiningOp();
    assert(new_op);
    builder.setInsertionPointAfter(new_op);
    auto set_resource_op = builder.create<mlir::TF::_TfrtSetResourceOp>(
        new_op->getLoc(), new_value, index);

    // Preserve the device attribute.
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr = new_op->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }
    set_resource_op->setAttr("device", builder.getStringAttr(device));
  }

  builder.setInsertionPointToEnd(block);
  // Finish building the init function by inserting a return op.
  builder.create<mlir::func::ReturnOp>(init_func_op.getLoc());

  // Now that we have the index for each value that will be replaced, we can
  // create the tf._TfrtGetResource ops in each function using these indices.
  ReplaceHoistedValues(module_hoist_info.hoisted_values, builder);

  // Lastly, erase the hoisted ops in reverse topological order.
  for (auto *op :
       llvm::reverse(module_hoist_info.hoists_in_topological_order)) {
    assert(op->use_empty());
    op->erase();
  }
}

// This pass rewrites tf_saved_model dialect's ops according to TFRT's
// requirements:
//
// 1) Remove all tf_saved_model attributes and ops.
// 2) Create a function for every exported name of the original function.
// 3) Hoist invariant ops (i.e. ops guaranteed to return the same value on
//    every invocation) out of every non-init function.
//
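// For example (an illustrative sketch; names are hypothetical), a function
// exported under two names,
//
//   func @f(...) attributes {tf_saved_model.exported_names = ["a", "b"]}
//
// is cloned into two functions @a and @b, with the tf_saved_model attributes
// removed.
//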
class LowerTFSavedModelPass
    : public mlir::PassWrapper<LowerTFSavedModelPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerTFSavedModelPass)

  explicit LowerTFSavedModelPass(bool hoist_invariant_ops) {
    hoist_invariant_ops_ = hoist_invariant_ops;
  }
  LowerTFSavedModelPass() = default;
  LowerTFSavedModelPass(const LowerTFSavedModelPass &) {}

  llvm::StringRef getArgument() const final {
    return "tfrt-lower-tf-savedmodel";
  }
  llvm::StringRef getDescription() const final {
    return "Lower tf-saved-model ops according to TFRT's requirements.";
  }

  void runOnOperation() override {
    auto module = getOperation();

    // TODO(b/185928201): Create a standalone pass for hoisting invariant ops
    // so that it is reusable and configurable in contexts other than saved
    // models.
    if (hoist_invariant_ops_) HoistInvariantOps(module);

    // Skip MLIR modules without saved-model semantics.
    if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) return;

    mlir::SymbolTable symbol_table(module);

    module->removeAttr("tf_saved_model.semantics");

    mlir::OpBuilder builder(&getContext());
    auto resource_id = builder.getStringAttr("tf.resource_name");
    auto bound_id = builder.getStringAttr("tf_saved_model.bound_input");
    auto path_id = builder.getStringAttr("tf_saved_model.index_path");

    module.walk([resource_id, bound_id, path_id,
                 &builder](mlir::Operation *op) mutable {
      if (auto func_op = llvm::dyn_cast<mlir::func::FuncOp>(op)) {
        // Remove tf_saved_model-specific function argument attributes.
        for (unsigned i = 0, e = func_op.getNumArguments(); i != e; ++i) {
          if (auto sym = func_op.getArgAttrOfType<mlir::FlatSymbolRefAttr>(
                  i, bound_id)) {
            func_op.removeArgAttr(i, bound_id);
            func_op.setArgAttr(i, resource_id,
                               builder.getStringAttr(sym.getValue()));
          }
          func_op.removeArgAttr(i, path_id);
        }
        for (unsigned i = 0, e = func_op.getNumResults(); i != e; ++i) {
          func_op.removeResultAttr(i, bound_id);
          func_op.removeResultAttr(i, path_id);
        }
        if (auto exported_names = func_op->getAttrOfType<mlir::ArrayAttr>(
                "tf_saved_model.exported_names")) {
          bool is_session_initializer = IsSessionInitializer(func_op);

          // Create a function for each exported name.
          //
          // TODO(b/148477882): The TFRT dialect should have a similar concept
          // of exported names so that a function can be referenced by
          // multiple exported names.
          func_op->removeAttr("tf_saved_model.exported_names");
          for (auto exported_name : exported_names) {
            auto exported_func_op = func_op.clone();
            exported_func_op.setName(exported_name.cast<mlir::StringAttr>());

            // If it is a session initializer, we want to maximize parallelism
            // and perform no stream merging, to minimize latency.
            //
            // TODO(b/183219530): This is a workaround, as the cost model
            // currently used is not very accurate and leads to performance
            // regressions on the IO ops that are common in initialization
            // functions.
            if (is_session_initializer) {
              exported_func_op->setAttr("tfrt.cost_threshold",
                                        builder.getI64IntegerAttr(1));
            }

            builder.setInsertionPoint(func_op);
            builder.insert(exported_func_op);
          }
          func_op.erase();
        }
      }
    });

    module.walk([](mlir::Operation *op) {
      if (llvm::isa<mlir::tf_saved_model::TensorFlowSavedModelDialect>(
              op->getDialect())) {
        // Remove all remaining tf_saved_model ops.
        op->erase();
      }
    });
  }

 private:
  Option<bool> hoist_invariant_ops_{
      *this, "hoist-invariant-ops",
      llvm::cl::desc("Whether to hoist invariant ops to the init function"),
      llvm::cl::init(false)};
};

static llvm::SmallVector<unsigned, 4> CompareTypes(mlir::TypeRange x,
                                                   mlir::TypeRange y) {
  llvm::SmallVector<unsigned, 4> results;
  assert(x.size() == y.size());
  for (int i = 0, e = x.size(); i < e; ++i) {
    if (x[i] != y[i]) results.push_back(i);
  }
  return results;
}

// Converts ref variables to resource variables in a few cases.
//
// A variable is converted if every user of it in the entire module is one of
// the following:
//
// 1) a tf.Identity op
// 2) a tf.Assign op
// 3) a side-effect-free op: this matches the TF1 behavior, in which the TF
//    executor automatically converts ref tensors to non-ref tensors when the
//    user is not expecting a ref tensor. Refer to
//    http://cs?q=tensorflow/core/common_runtime/executor.cc:932%20at_cl:356873227
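//
// For example (an illustrative sketch; the variable name is hypothetical),
// the rewrite turns
//
//   %ref = "tf.VariableV2"() {shared_name = "v", container = ""} ...
//   %val = "tf.Identity"(%ref) ...
//
// into
//
//   %handle = "tf.VarHandleOp"() {shared_name = "v", container = ""} ...
//   %val = "tf.ReadVariableOp"(%handle) ...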
class ConvertReferenceVariableToResourceVariablePass
    : public mlir::PassWrapper<ConvertReferenceVariableToResourceVariablePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  llvm::StringRef getArgument() const final {
    return "tfrt-convert-ref-variables";
  }
  llvm::StringRef getDescription() const final {
    return "Convert reference variables to resource variables.";
  }
  void runOnOperation() override;

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      ConvertReferenceVariableToResourceVariablePass)
};

mlir::LogicalResult ConvertReferenceVariableToResourceVariable(
    mlir::TF::VariableV2Op var_op) {
  auto tensor_type =
      mlir::TF::DropRefType(var_op.ref().getType()).cast<mlir::TensorType>();

  llvm::SmallVector<mlir::TF::IdentityOp, 4> identity_ops;
  llvm::SmallVector<mlir::TF::AssignOp, 4> assign_ops;
  llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4>
      side_effect_free_ops;

  for (mlir::OpOperand &use : var_op.ref().getUses()) {
    mlir::Operation *user = use.getOwner();

    if (auto identity = llvm::dyn_cast<mlir::TF::IdentityOp>(user)) {
      identity_ops.push_back(identity);
      continue;
    } else if (auto assign = llvm::dyn_cast<mlir::TF::AssignOp>(user)) {
      // Conservatively, we only allow the case where the output of this
      // tf.Assign is not consumed by any other op.
      if (assign.output_ref().use_empty()) {
        assign_ops.push_back(assign);
        continue;
      }
    } else if (mlir::MemoryEffectOpInterface::hasNoEffect(user)) {
      side_effect_free_ops.push_back({user, use.getOperandNumber()});
      continue;
    }

    return var_op.emitOpError()
           << "failed to convert reference variables with unexpected users: "
           << *user;
  }

  mlir::OpBuilder builder(var_op);

  auto var_handle_op = builder.create<mlir::TF::VarHandleOp>(
      var_op.getLoc(),
      mlir::RankedTensorType::get(
          {}, mlir::TF::ResourceType::get(
                  llvm::ArrayRef<mlir::TensorType>{tensor_type},
                  builder.getContext())),
      var_op.container(), var_op.shared_name());

  for (auto op : identity_ops) {
    // Set the insertion point to this identity op so that the side-effect
    // ordering is preserved.
    builder.setInsertionPoint(op);
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op.getLoc(), op.getType(), var_handle_op);
    op.replaceAllUsesWith(read_var_op.value());
    op.erase();
  }

  for (auto op : assign_ops) {
    // Insert the new tf.AssignVariableOp at the position of the tf.Assign op,
    // where both the variable handle and the assigned value dominate.
    builder.setInsertionPoint(op);
    builder.create<mlir::TF::AssignVariableOp>(op.getLoc(), var_handle_op,
                                               op.value());
    op.erase();
  }

  for (auto pair : side_effect_free_ops) {
    mlir::Operation *op = pair.first;
    unsigned idx = pair.second;
    // Insert the new read right before the user op so that its result is
    // available to that op and the side-effects are preserved.
    builder.setInsertionPoint(op);
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op->getLoc(), tensor_type, var_handle_op);
    op->setOperand(idx, read_var_op.value());
  }

  return mlir::success();
}

void ConvertReferenceVariableToResourceVariablePass::runOnOperation() {
  auto module = getOperation();

  // The key here is a tuple of device, container and shared_name that
  // uniquely identifies a variable.
  llvm::DenseMap<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>,
                 llvm::SmallVector<mlir::TF::VariableV2Op, 4>>
      ref_vars;

  // First, collect each variable's corresponding tf.VariableV2 ops.
  module.walk([&ref_vars](mlir::TF::VariableV2Op op) {
    if (op.shared_name().empty()) {
      op.emitOpError()
          << "unable to convert reference variables with empty shared_names.";
      return mlir::WalkResult::interrupt();
    }

    llvm::StringRef device;
    if (auto device_attr = op->getAttrOfType<mlir::StringAttr>("device")) {
      device = device_attr.getValue();
    }

    ref_vars[{device, op.container(), op.shared_name()}].push_back(op);

    return mlir::WalkResult::advance();
  });

  // Then, rewrite each variable where possible.
  for (const auto &iter : ref_vars) {
    const auto &var_ops = iter.second;

    for (auto var_op : var_ops) {
      if (mlir::succeeded(ConvertReferenceVariableToResourceVariable(var_op)))
        var_op.erase();
    }
  }
}

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerTFSavedModelPass(bool hoist_invariant_ops) {
  return std::make_unique<LowerTFSavedModelPass>(hoist_invariant_ops);
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateConvertReferenceVariableToResourceVariablePass() {
  return std::make_unique<ConvertReferenceVariableToResourceVariablePass>();
}

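// A minimal usage sketch (illustrative; this pass-manager setup is an
// assumption, not part of this file):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(tensorflow::CreateLowerTFSavedModelPass(
//       /*hoist_invariant_ops=*/true));
//   pm.addPass(
//       tensorflow::CreateConvertReferenceVariableToResourceVariablePass());
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }
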
static mlir::PassRegistration<LowerTFSavedModelPass> saved_model_pass;

static mlir::PassRegistration<ConvertReferenceVariableToResourceVariablePass>
    ref_var_pass;

}  // namespace tensorflow