1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <deque>
18 #include <tuple>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
26 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/OperationSupport.h" // from @llvm-project
29 #include "mlir/IR/Types.h" // from @llvm-project
30 #include "mlir/Transforms/Passes.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
34 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
35
36 using llvm::ArrayRef;
37 using mlir::failure;
38 using mlir::success;
39
40 namespace tensorflow {
41 namespace {
42
43 constexpr char kCpuDeviceName[] =
44 "/job:localhost/replica:0/task:0/device:CPU:0";
45
46 // Tuple storing {device, container, shared_name} attributes in that order.
47 using ResourceKey =
48 std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>;
49
IsSessionInitializer(mlir::FuncOp op)50 bool IsSessionInitializer(mlir::FuncOp op) {
51 auto session_initializer_op = mlir::tf_saved_model::GetSessionInitializerOp(
52 op->getParentOfType<mlir::ModuleOp>());
53 if (!session_initializer_op) return false;
54
55 for (auto sym_ref : session_initializer_op.initializers()) {
56 if (op.sym_name() == sym_ref.cast<mlir::FlatSymbolRefAttr>().getValue())
57 return true;
58 }
59
60 return false;
61 }
62
GetResourceKey(mlir::Operation * op)63 ResourceKey GetResourceKey(mlir::Operation *op) {
64 llvm::StringRef device;
65 if (auto attr = op->getAttrOfType<mlir::StringAttr>("device")) {
66 device = attr.getValue();
67 }
68
69 llvm::StringRef container;
70 if (auto attr = op->getAttrOfType<mlir::StringAttr>("container")) {
71 container = attr.getValue();
72 }
73
74 llvm::StringRef shared_name;
75 if (auto attr = op->getAttrOfType<mlir::StringAttr>("shared_name")) {
76 shared_name = attr.getValue();
77 }
78
79 return {device, container, shared_name};
80 }
81
// Bookkeeping for invariant-op hoisting: records which ops were cloned into
// the resource-init function and how their old results map to the cloned
// values and to resource indices.
struct HoistInfo {
  // All hoisted ops in the topological order.
  llvm::SmallVector<mlir::Operation *, 4> hoists_in_topological_order;

  // Mapping from the old values produced by hoisted ops before hoisting to the
  // new values after hoisting.
  mlir::BlockAndValueMapping value_mapping;

  // `hoisted_values` is to keep all values that are produced by hoisted ops
  // but used by non-hoisted ops. These values will be replaced by results of
  // tf._TfrtGetResource op. The index of each value in this array will be the
  // index used in tf._TfrtGetResource and tf._TfrtSetResource op. This also
  // stores the ResourceKey which has the shared_name and container attributes
  // used by later resource alias analysis and side effect analysis passes.
  llvm::SmallVector<std::pair<mlir::Value, ResourceKey>, 4> hoisted_values;
};
98
// Replaces all remaining uses of the values in `hoisted_values` with results
// of newly created tf._TfrtGetResource ops. One get-resource op is created
// per (block, device) pair so that each created op carries a single device
// attribute.
void ReplaceHoistedValues(
    llvm::ArrayRef<std::pair<mlir::Value, ResourceKey>> hoisted_values,
    mlir::OpBuilder &builder) {
  // Batch of values (plus their resource metadata, kept index-aligned) that
  // will be fetched by a single tf._TfrtGetResource op.
  struct HoistedValueInfo {
    llvm::SmallVector<mlir::Value, 4> hoisted_values;
    llvm::SmallVector<int64_t, 4> indices;
    llvm::SmallVector<llvm::StringRef, 4> shared_names;
    llvm::SmallVector<llvm::StringRef, 4> containers;
  };
  // Rearrange the hoisted values by each function and each device.
  llvm::DenseMap<mlir::Block *, llvm::StringMap<HoistedValueInfo>>
      hoisted_values_by_block_device;

  // Find a block where to place tf._TfrtGetResource operation. We do not place
  // get resource operations inside the `tf_device.cluster` operations, because
  // these blocks are intended for later on-device compilation. Insert resource
  // reads to the closest block outside of the `tf_device.cluster` operation.
  auto hoist_into_block = [](mlir::Value value) -> mlir::Block * {
    mlir::Operation *cluster_op =
        value.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();
    return cluster_op ? cluster_op->getBlock() : value.getParentBlock();
  };

  // Bucket each hoisted value by (block, device). The enumeration index is
  // the resource index shared with the matching tf._TfrtSetResource op
  // created by the caller.
  for (auto iter : llvm::enumerate(hoisted_values)) {
    auto value = iter.value().first;
    auto container = std::get<1>(iter.value().second);
    auto shared_name = std::get<2>(iter.value().second);
    auto index = iter.index();
    auto &device_map = hoisted_values_by_block_device[hoist_into_block(value)];

    assert(value.getDefiningOp() && "hoisted values must not be arguments.");
    // Default to the CPU device unless the defining op carries a non-empty
    // device attribute.
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr =
            value.getDefiningOp()->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }

    auto &item = device_map[device];

    item.hoisted_values.push_back(value);
    item.indices.push_back(index);
    item.shared_names.push_back(shared_name);
    item.containers.push_back(container);
  }

  // Create tf._TfrtGetResource op for each function and device.
  for (const auto &block_iter : hoisted_values_by_block_device) {
    auto *block = block_iter.first;
    const auto &device_map = block_iter.second;

    // Get-resource ops are inserted at the top of the block so they dominate
    // all existing uses of the values they replace.
    builder.setInsertionPointToStart(block);
    for (const auto &device_iter : device_map) {
      llvm::StringRef device = device_iter.getKey();
      mlir::ValueRange old_values = device_iter.getValue().hoisted_values;
      const auto &indices = device_iter.getValue().indices;
      const auto &shared_name_arr = device_iter.getValue().shared_names;
      const auto &container_arr = device_iter.getValue().containers;

      auto get_resource_op = builder.create<mlir::TF::_TfrtGetResourceOp>(
          block->getParentOp()->getLoc(), old_values.getTypes(),
          builder.getI64ArrayAttr(indices),
          builder.getStrArrayAttr(shared_name_arr),
          builder.getStrArrayAttr(container_arr));
      get_resource_op->setAttr("device", builder.getStringAttr(device));

      // Redirect every remaining use of the old values to the results of the
      // get-resource op.
      auto new_values = get_resource_op.results();
      for (auto iter : llvm::zip(old_values, new_values)) {
        auto old_value = std::get<0>(iter);
        auto new_value = std::get<1>(iter);
        old_value.replaceAllUsesWith(new_value);
      }
    }
  }
}
173
OnlyHasReadEffect(mlir::Operation * op)174 bool OnlyHasReadEffect(mlir::Operation *op) {
175 auto interface = llvm::dyn_cast<mlir::MemoryEffectOpInterface>(op);
176 if (!interface) return false;
177 return interface.onlyHasEffect<mlir::MemoryEffects::Read>();
178 }
179
CanHoist(const llvm::DenseSet<ResourceKey> & read_only_vars,mlir::Operation * op)180 bool CanHoist(const llvm::DenseSet<ResourceKey> &read_only_vars,
181 mlir::Operation *op) {
182 // return ops should not be hoisted.
183 if (op->mightHaveTrait<mlir::OpTrait::IsTerminator>()) return false;
184
185 // Non-side-effecting ops can be hoisted.
186 if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return true;
187
188 // ResourceHandle ops can be hoisted.
189 if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op))
190 return true;
191
192 // If it is ReadVariableOp and the variable is readonly, it can be hoisted.
193 if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
194 if (auto var_handle_op = llvm::dyn_cast_or_null<mlir::TF::VarHandleOp>(
195 read_var_op.resource().getDefiningOp())) {
196 if (read_only_vars.count(GetResourceKey(var_handle_op)) > 0) return true;
197 }
198 }
199
200 // If it is LookupTableSizeOp, it can be hoisted as the size of the hash table
201 // cannot be changed after initialization.
202 if (auto lookup_table_size_op =
203 llvm::dyn_cast<mlir::TF::LookupTableSizeV2Op>(op)) {
204 if (auto hash_table_op = llvm::dyn_cast_or_null<mlir::TF::HashTableV2Op>(
205 lookup_table_size_op.table_handle().getDefiningOp())) {
206 if (read_only_vars.count(GetResourceKey(hash_table_op)) > 0) return true;
207 }
208 }
209
210 // TODO(chky): Allow more readonly ops.
211
212 return false;
213 }
214
// Hoists invariant ops of `func` into the resource-init function by cloning
// them with `builder` (whose insertion point must be inside the init
// function). Results are accumulated in `module_hoist_info`, which is shared
// across all functions of the module so resource indices stay globally
// unique.
void HoistInvariantOpsInFunction(
    mlir::FuncOp func, const llvm::DenseSet<ResourceKey> &read_only_vars,
    const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
    mlir::OpBuilder &builder, HoistInfo &module_hoist_info) {
  // Keep the hoisted ops in this function.
  llvm::DenseSet<mlir::Operation *> hoists;

  // True iff every operand of `op` is produced by an already-hoisted op
  // (i.e. has an entry in the value mapping).
  auto all_operands_in_hoists = [&module_hoist_info](mlir::Operation *op) {
    for (mlir::Value operand : op->getOperands()) {
      if (module_hoist_info.value_mapping.lookupOrNull(operand) == nullptr)
        return false;
    }
    return true;
  };

  // True iff every direct control predecessor of `op` (per side effect
  // analysis) has already been hoisted.
  auto all_control_predeccessors_in_hoists = [&hoists, &side_effect_analysis](
                                                 mlir::Operation *op) {
    auto preds = side_effect_analysis.DirectControlPredecessors(op);
    return std::all_of(
        preds.begin(), preds.end(),
        [&hoists](mlir::Operation *pred) { return hoists.count(pred) > 0; });
  };

  std::deque<mlir::Operation *> work_list;

  // Start with ops with tf.VarHandleOp ops and tf.Const ops.
  //
  // TODO(chky): Consider allowing other ops including custom ops to be hoisted.
  func.walk([&work_list](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op,
                  mlir::TF::ConstOp>(op))
      work_list.push_back(op);
  });

  // Fixpoint worklist: an op is hoisted once all of its data and control
  // dependencies are hoisted; hoisting an op may unblock its users, which are
  // re-queued. Processing order guarantees clones are created in topological
  // order.
  while (!work_list.empty()) {
    auto *op = work_list.front();
    work_list.pop_front();

    // Skip if it is already hoisted.
    if (hoists.count(op) > 0) continue;

    // If the op can be hoisted, and all of its data dependencies and control
    // dependencies are hoisted, then we hoist it. Otherwise, skip.
    if (!(CanHoist(read_only_vars, op) && all_operands_in_hoists(op) &&
          all_control_predeccessors_in_hoists(op)))
      continue;

    // Record the hoisted operation.
    hoists.insert(op);
    module_hoist_info.hoists_in_topological_order.push_back(op);

    // Create a copy in the init function. The clone's operands are remapped
    // through `value_mapping`, which is also updated with the clone's results.
    builder.clone(*op, module_hoist_info.value_mapping);

    // The users of this op may now be hoistable; revisit them.
    for (mlir::Operation *user : op->getUsers()) {
      work_list.push_back(user);
    }
  }

  // Find out the values that are produced by hoisted ops but used by
  // non-hoisted ops. These values need to be replaced.
  for (auto *op : hoists) {
    for (auto result : op->getResults()) {
      if (std::any_of(result.getUsers().begin(), result.getUsers().end(),
                      [&hoists](mlir::Operation *user) {
                        return hoists.count(user) == 0;
                      })) {
        module_hoist_info.hoisted_values.push_back(
            {result, GetResourceKey(op)});
      }
    }
  }
}
288
FindCalleesRecursive(const mlir::SymbolTable & symbol_table,mlir::FuncOp func,llvm::StringSet<> & callees)289 void FindCalleesRecursive(const mlir::SymbolTable &symbol_table,
290 mlir::FuncOp func, llvm::StringSet<> &callees) {
291 assert(func);
292 func.walk([&](mlir::Operation *op) {
293 for (const auto &named_attr : op->getAttrs()) {
294 if (auto symbol_attr =
295 named_attr.second.dyn_cast<mlir::FlatSymbolRefAttr>()) {
296 auto symbol = symbol_attr.getValue();
297 if (!callees.contains(symbol)) {
298 callees.insert(symbol);
299
300 auto func = symbol_table.lookup<mlir::FuncOp>(symbol);
301 if (!func) continue;
302
303 FindCalleesRecursive(symbol_table, func, callees);
304 }
305 }
306 }
307 });
308 }
309
// Hoists all invariant ops in `module` into a newly created
// "_tfrt_resource_init" function, wiring the hoisted results through
// tf._TfrtSetResource (in the init function) and tf._TfrtGetResource (at the
// original use sites), then erases the original ops.
void HoistInvariantOps(mlir::ModuleOp module) {
  mlir::SymbolTable symbol_table(module);

  // Find all resources used in non-init functions.
  llvm::DenseMap<ResourceKey, llvm::SmallVector<mlir::Operation *, 4>>
      resources;

  // Find all callees referenced in the initialization functions.
  llvm::StringSet<> init_callees;

  module.walk([&](mlir::Operation *op) {
    if (llvm::isa<mlir::TF::VarHandleOp, mlir::TF::HashTableV2Op>(op)) {
      auto func = op->getParentOfType<mlir::FuncOp>();
      // Resources touched only inside session initializers are not
      // candidates for hoisting.
      if (IsSessionInitializer(func)) return;
      resources[GetResourceKey(op)].push_back(op);
    } else if (auto func = llvm::dyn_cast<mlir::FuncOp>(op)) {
      if (!IsSessionInitializer(func)) return;
      FindCalleesRecursive(symbol_table, func, init_callees);
    }
  });

  // A resource is read-only iff every user of every handle-producing op for
  // that resource has only read effects.
  llvm::DenseSet<ResourceKey> read_only_vars;
  for (const auto &iter : resources) {
    const auto &key = iter.first;
    const auto &vars = iter.second;
    if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) {
          for (auto *user : op->getUsers()) {
            if (!OnlyHasReadEffect(user)) return false;
          }
          return true;
        })) {
      read_only_vars.insert(key);
    }
  }

  mlir::TF::SideEffectAnalysis side_effect_analysis(module);

  mlir::OpBuilder builder(&module.body());
  // "_tfrt_resource_init" is the special function that executes all invariant
  // ops (eg. read-only variables) used in the model. This function should be
  // executed after user-specified initialization.
  auto init_func_op = builder.create<mlir::FuncOp>(
      module.getLoc(), "_tfrt_resource_init",
      mlir::FunctionType::get(module.getContext(), /*inputs=*/{},
                              /*results=*/{}));
  auto *block = init_func_op.addEntryBlock();
  builder.setInsertionPointToStart(block);

  HoistInfo module_hoist_info;

  for (auto func : module.getOps<mlir::FuncOp>()) {
    // Skips hoisting if this function is an init function or any callees,
    // including recursive ones, of an init functions, because otherwise the
    // hoisted values won't be initialized when this function is called.
    if (IsSessionInitializer(func) || init_callees.contains(func.sym_name()) ||
        func == init_func_op)
      continue;

    HoistInvariantOpsInFunction(func, read_only_vars,
                                side_effect_analysis.GetAnalysisForFunc(func),
                                builder, module_hoist_info);
  }

  // Create tf._TfrtSetResource ops in the init function.
  for (auto iter : llvm::enumerate(module_hoist_info.hoisted_values)) {
    mlir::Value value = iter.value().first;
    int64_t index = iter.index();

    auto new_value = module_hoist_info.value_mapping.lookup(value);
    auto *new_op = new_value.getDefiningOp();
    assert(new_op);
    // Each set-resource op is placed directly after the cloned producer so
    // the init function stays topologically ordered.
    builder.setInsertionPointAfter(new_op);
    auto set_resource_op = builder.create<mlir::TF::_TfrtSetResourceOp>(
        new_op->getLoc(), new_value, index);

    // Preserve the device attribute.
    llvm::StringRef device = kCpuDeviceName;
    if (auto device_attr = new_op->getAttrOfType<mlir::StringAttr>("device")) {
      if (!device_attr.getValue().empty()) device = device_attr.getValue();
    }
    set_resource_op->setAttr("device", builder.getStringAttr(device));
  }

  builder.setInsertionPointToEnd(block);
  // Finish building the init function by inserting an return op.
  builder.create<mlir::ReturnOp>(init_func_op.getLoc());

  // Now that we have the index for each value that will be replaced, we can
  // create the tf._TfrtGetResource op in each function using these indices.
  ReplaceHoistedValues(module_hoist_info.hoisted_values, builder);

  // Lastly, erase the hoisted ops in reverse topological order.
  for (auto *op :
       llvm::reverse(module_hoist_info.hoists_in_topological_order)) {
    assert(op->use_empty());
    op->erase();
  }
}
408
409 // This pass rewrites tf_saved_model dialect's ops according to TFRT's
410 // requirements:
411 //
412 // 1) Remove all tf_saved_model's attributes and ops.
413 // 2) Create a function for every exported names of the original function.
414 // 3) Promote all uses of global tensors from resource handles to the underlying
415 // tensors.
416 // 4) Hoist invariant ops (ie. guaranteed to return the same value on every
417 // invocation) for every non-init function.
418 //
class LowerTFSavedModelPass
    : public mlir::PassWrapper<LowerTFSavedModelPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  explicit LowerTFSavedModelPass(bool hoist_invariant_ops) {
    hoist_invariant_ops_ = hoist_invariant_ops;
  }
  LowerTFSavedModelPass() = default;
  // Copy constructor is required by the pass framework for cloning; Option
  // members are not copyable, so nothing is copied here.
  LowerTFSavedModelPass(const LowerTFSavedModelPass &) {}

  llvm::StringRef getArgument() const final {
    return "tfrt-lower-tf-savedmodel";
  }
  llvm::StringRef getDescription() const final {
    return "Lower tf-saved-model ops according to TFRT's requirements.";
  }

  void runOnOperation() override {
    auto module = getOperation();

    // TODO(b/185928201): Create a standalone pass for hoisting invariant ops so
    // that it can be reusable and configurable in other contexts than saved
    // models.
    if (hoist_invariant_ops_) HoistInvariantOps(module);

    // Skip non-savedmodel MLIR module.
    if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) return;

    mlir::SymbolTable symbol_table(module);

    // TODO(b/177590991): Remove PromoteGlobalTensors() once non lite MLIR
    // importer is no longer used. PromoteGlobalTensors() is only used for non
    // lite MLIR importer which rewrites resource variables to global_tensors.
    // However, for many models it is not supported.
    for (auto func : module.getOps<mlir::FuncOp>()) {
      if (mlir::tf_saved_model::IsExported(func)) {
        if (mlir::failed(PromoteGlobalTensors(func, symbol_table))) {
          func.emitOpError("failed to promote resource variables.");
          signalPassFailure();
          return;
        }
      }
    }

    module->removeAttr("tf_saved_model.semantics");

    mlir::OpBuilder builder(&getContext());
    auto resource_id = builder.getIdentifier("tf.resource_name");
    auto bound_id = builder.getIdentifier("tf_saved_model.bound_input");
    auto path_id = builder.getIdentifier("tf_saved_model.index_path");

    module.walk([resource_id, bound_id, path_id,
                 &builder](mlir::Operation *op) mutable {
      if (auto func_op = llvm::dyn_cast<mlir::FuncOp>(op)) {
        // Remove tf_saved_model specific function arg attributes.
        for (unsigned i = 0, e = func_op.getNumArguments(); i != e; ++i) {
          // Bound inputs are rewritten to plain "tf.resource_name" arg
          // attributes carrying the referenced symbol name.
          if (auto sym = func_op.getArgAttrOfType<mlir::FlatSymbolRefAttr>(
                  i, bound_id)) {
            func_op.removeArgAttr(i, bound_id);
            func_op.setArgAttr(i, resource_id,
                               builder.getStringAttr(sym.getValue()));
          }
          func_op.removeArgAttr(i, path_id);
        }
        for (unsigned i = 0, e = func_op.getNumResults(); i != e; ++i) {
          func_op.removeResultAttr(i, bound_id);
          func_op.removeResultAttr(i, path_id);
        }
        if (auto exported_names = func_op->getAttrOfType<mlir::ArrayAttr>(
                "tf_saved_model.exported_names")) {
          bool is_session_initializer = IsSessionInitializer(func_op);

          // Create a function for each exported name.
          //
          // TODO(b/148477882): TFRT dialect should have similar concepts of
          // exported names so that a function can be referenced by multiple
          // exported names.
          func_op->removeAttr("tf_saved_model.exported_names");
          for (auto exported_name : exported_names) {
            auto exported_func_op = func_op.clone();
            exported_func_op.setName(
                exported_name.cast<mlir::StringAttr>().getValue());

            // If it is a session initializer, we want to maximize parallelism
            // and do not perform any stream merge, to minimize latency.
            //
            // TODO(b/183219530): This is a workaround as the cost model used
            // currently is not very accurate, and leads to performance
            // regression on IO ops that are common in initialization functions.
            if (is_session_initializer) {
              exported_func_op->setAttr("tfrt.cost_threshold",
                                        builder.getI64IntegerAttr(1));
            }

            builder.setInsertionPoint(func_op);
            builder.insert(exported_func_op);
          }
          // The original (non-exported) function is replaced by its clones.
          func_op.erase();
        }
      }
    });

    module.walk([](mlir::Operation *op) {
      if (llvm::isa<mlir::tf_saved_model::TensorFlowSavedModelDialect>(
              op->getDialect())) {
        // Remove all tf_saved_model ops.
        op->erase();
      }
    });
  }

 private:
  // Promote global tensors used by an exported function.
  mlir::LogicalResult PromoteGlobalTensors(
      mlir::FuncOp op, const mlir::SymbolTable &symbol_table);

  // Replace a function argument that is a resource hanndle with an argument of
  // the underlying tensor type. It also replaces all its uses recursively.
  mlir::LogicalResult PromoteFunctionArgument(
      mlir::FuncOp func, unsigned arg_index, mlir::Type promoted_type,
      const mlir::SymbolTable &symbol_table);

  // Replace an operand that is a resource handle with an operand of the
  // underlying type and replace all uses of this operation if the results are
  // also promoted. If it is a control flow op, it will process the callees
  // recursively. The original op will be invalidated in some cases.
  mlir::LogicalResult PromoteOpOperand(mlir::Operation *op,
                                       unsigned operand_number,
                                       mlir::Value promoted,
                                       const mlir::SymbolTable &symbol_table);

  // Replace all uses of a resource handle value with its promoted version
  // recursively.
  mlir::LogicalResult PromoteValueUses(mlir::Value old, mlir::Value promoted,
                                       const mlir::SymbolTable &symbol_table);

  // Whether to run invariant-op hoisting (see HoistInvariantOps()); defaults
  // to off.
  Option<bool> hoist_invariant_ops_{*this, "hoist-invariant-ops",
                                    llvm::cl::desc("hoist-invariant-ops"),
                                    llvm::cl::init(false)};
};
559
CompareTypes(mlir::TypeRange x,mlir::TypeRange y)560 static llvm::SmallVector<unsigned, 4> CompareTypes(mlir::TypeRange x,
561 mlir::TypeRange y) {
562 llvm::SmallVector<unsigned, 4> results;
563 assert(x.size() == y.size());
564 for (int i = 0, e = x.size(); i < e; ++i) {
565 if (x[i] != y[i]) results.push_back(i);
566 }
567 return results;
568 }
569
PromoteGlobalTensors(mlir::FuncOp op,const mlir::SymbolTable & symbol_table)570 mlir::LogicalResult LowerTFSavedModelPass::PromoteGlobalTensors(
571 mlir::FuncOp op, const mlir::SymbolTable &symbol_table) {
572 for (int i = 0, e = op.getNumArguments(); i < e; ++i) {
573 auto global_tensor_op = mlir::tf_saved_model::LookupBoundInputOfType<
574 mlir::tf_saved_model::GlobalTensorOp>(op, i, symbol_table);
575 if (!global_tensor_op) continue;
576
577 auto result_types = op.getType().getResults();
578 if (failed(PromoteFunctionArgument(op, i, global_tensor_op.type(),
579 symbol_table)))
580 return failure();
581
582 if (!CompareTypes(op.getType().getResults(), result_types).empty())
583 op.emitOpError("cannot promote exported functions's results");
584 }
585 return success();
586 }
587
// Changes the type of `func`'s argument `arg_index` to `promoted_type` and
// recursively promotes all uses of that argument. On failure the argument's
// original type is restored; on success the function type is updated to match
// the new argument and terminator operand types.
mlir::LogicalResult LowerTFSavedModelPass::PromoteFunctionArgument(
    mlir::FuncOp func, unsigned arg_index, mlir::Type promoted_type,
    const mlir::SymbolTable &symbol_table) {
  // Replace this argument before replacing its uses.
  auto &block = func.front();
  auto arg = block.getArgument(arg_index);

  // Roll back the argument type if any use fails to promote.
  auto cleanup_on_failure = llvm::make_scope_exit(
      [&, orig_type = arg.getType()]() { arg.setType(orig_type); });

  arg.setType(promoted_type);

  // Promote all uses of `arg`.
  if (failed(PromoteValueUses(arg, arg, symbol_table))) return failure();

  // Success: disarm the rollback.
  cleanup_on_failure.release();

  // Update the function type accordingly.
  auto return_op = llvm::cast<mlir::ReturnOp>(block.getTerminator());
  auto new_results = return_op.operands();

  func.setType(mlir::FunctionType::get(
      func.getContext(), block.getArgumentTypes(), new_results.getTypes()));
  return success();
}
613
// Replaces operand `operand_number` of `op` with `promoted` (whose type is
// the promoted tensor type). For control flow and call ops this recursively
// promotes the matching argument of every callee/branch; if callee results
// change type, `op` is replaced by a newly created op with the promoted
// result types and its uses are promoted recursively. `op` may be erased by
// this function.
mlir::LogicalResult LowerTFSavedModelPass::PromoteOpOperand(
    mlir::Operation *op, unsigned operand_number, mlir::Value promoted,
    const mlir::SymbolTable &symbol_table) {
  // TODO(chky): Consider a more scalable way to handling all read-only ops.

  // If it is a ReadVariableOp, we just need to replace all its uses and erase
  // this op.
  if (auto read_var_op = llvm::dyn_cast<mlir::TF::ReadVariableOp>(op)) {
    read_var_op.value().replaceAllUsesWith(promoted);
    op->erase();
    return success();
  }

  // Next, we handle control flow ops.
  if (!llvm::isa<mlir::TF::IfOp, mlir::TF::CaseOp, mlir::TF::WhileOp,
                 mlir::CallOpInterface, mlir::TF::BatchFunctionOp,
                 mlir::ReturnOp>(op))
    return op->emitOpError("unsupported users of resource variables");

  // Result indices of `op` whose types were promoted inside the callees.
  // Every callee/branch must agree on this set.
  llvm::SmallVector<unsigned, 2> promoted_result_indices;
  auto update_promoted_result_indices =
      [&promoted_result_indices](
          mlir::Operation *op,
          ArrayRef<mlir::Type> result_types) -> mlir::LogicalResult {
    if (op->getNumResults() != result_types.size())
      return op->emitOpError(
          "cannot promote call ops whose op resutls do not fully match the "
          "callee results");

    auto result = CompareTypes(op->getResultTypes(), result_types);
    if (promoted_result_indices.empty()) {
      promoted_result_indices.assign(result.begin(), result.end());
    } else {
      // We cannot handle the case where two branches' results are promoted
      // differently.
      if (promoted_result_indices != result)
        return op->emitOpError(
            "cannot promote callees with different result types");
    }
    return success();
  };

  if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
    // Operand 0 of tf.If is the boolean condition, not a forwarded argument.
    if (operand_number == 0)
      return if_op.emitOpError("cannot promote cond tensor for tf.If");

    auto then_branch = symbol_table.lookup<mlir::FuncOp>(if_op.then_branch());
    auto else_branch = symbol_table.lookup<mlir::FuncOp>(if_op.else_branch());
    assert(then_branch);
    assert(else_branch);

    // tf.If operand i+1 feeds branch argument i.
    unsigned arg_index = operand_number - 1;
    for (auto func : {then_branch, else_branch}) {
      if (func.getType().getInput(arg_index) != promoted.getType()) {
        // Rescursively promote the uses in branches.
        if (failed(PromoteFunctionArgument(func, arg_index, promoted.getType(),
                                           symbol_table)))
          return failure();
      }

      if (failed(update_promoted_result_indices(if_op,
                                                func.getType().getResults())))
        return failure();
    }
  } else if (auto case_op = llvm::dyn_cast<mlir::TF::CaseOp>(op)) {
    // Operand 0 of tf.Case is the branch index; remaining operands feed the
    // branch arguments.
    assert(operand_number > 0);
    unsigned arg_index = operand_number - 1;
    for (auto branch_attr : case_op.branches()) {
      auto branch = symbol_table.lookup<mlir::FuncOp>(
          branch_attr.cast<mlir::FlatSymbolRefAttr>().getValue());

      if (branch.getType().getInput(arg_index) != promoted.getType()) {
        // Rescursively promote the uses in branches.
        if (failed(PromoteFunctionArgument(branch, arg_index,
                                           promoted.getType(), symbol_table)))
          return failure();
      }

      if (failed(update_promoted_result_indices(case_op,
                                                branch.getType().getResults())))
        return failure();
    }
  } else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
    auto cond = symbol_table.lookup<mlir::FuncOp>(while_op.cond());
    auto body = symbol_table.lookup<mlir::FuncOp>(while_op.body());
    assert(cond);
    assert(body);

    // tf.While operands map 1:1 to cond/body arguments.
    unsigned arg_index = operand_number;
    if (cond.getType().getInput(arg_index) != promoted.getType()) {
      auto cond_result_type = cond.getType().getResult(0);
      if (failed(PromoteFunctionArgument(cond, arg_index, promoted.getType(),
                                         symbol_table)))
        return failure();

      // We cannot promote the result of cond branch as it may change the
      // behavior of this while op.
      if (cond_result_type != cond.getType().getResult(0))
        return while_op.emitOpError("failed to promote cond for tf.While");
    }

    if (body.getType().getInput(arg_index) != promoted.getType()) {
      if (failed(PromoteFunctionArgument(body, /*arg_index=*/operand_number,
                                         promoted.getType(), symbol_table)))
        return failure();
    }

    if (failed(update_promoted_result_indices(while_op,
                                              body.getType().getResults())))
      return failure();

  } else if (auto call_interface = llvm::dyn_cast<mlir::CallOpInterface>(op)) {
    auto callee_name = call_interface.getCallableForCallee()
                           .get<mlir::SymbolRefAttr>()
                           .cast<mlir::FlatSymbolRefAttr>()
                           .getValue();
    auto callee = symbol_table.lookup<mlir::FuncOp>(callee_name);
    assert(callee);

    // Call ops may have non-argument leading operands; offset by the first
    // argument operand's index.
    unsigned arg_index =
        operand_number - call_interface.getArgOperands().getBeginOperandIndex();
    if (callee.getType().getInput(arg_index) != promoted.getType()) {
      if (failed(PromoteFunctionArgument(callee, arg_index, promoted.getType(),
                                         symbol_table)))
        return failure();
    }

    if (failed(
            update_promoted_result_indices(op, callee.getType().getResults())))
      return failure();
  } else if (auto batch_function_op =
                 llvm::dyn_cast<mlir::TF::BatchFunctionOp>(op)) {
    auto batch_fn = symbol_table.lookup<mlir::FuncOp>(
        batch_function_op.f().getRootReference());
    assert(batch_fn);

    unsigned arg_index = operand_number;
    if (batch_fn.getType().getInput(arg_index) != promoted.getType()) {
      if (failed(PromoteFunctionArgument(batch_fn, arg_index,
                                         promoted.getType(), symbol_table)))
        return failure();
    }

    if (failed(update_promoted_result_indices(op,
                                              batch_fn.getType().getResults())))
      return failure();
  }

  // Replace the operand.
  op->setOperand(operand_number, promoted);

  if (promoted_result_indices.empty()) return success();

  // If results are also promoted, we need to create a new op with the new
  // results and replaces all uses recursively.

  mlir::OpBuilder builder(op);

  llvm::SmallVector<mlir::Type, 4> new_result_types(op->result_type_begin(),
                                                    op->result_type_end());
  for (unsigned result_number : promoted_result_indices) {
    new_result_types[result_number] = promoted.getType();
  }

  // Clone `op` with identical operands/attributes but promoted result types.
  mlir::OperationState state(op->getLoc(), op->getName());
  state.addOperands(op->getOperands());
  state.addTypes(new_result_types);
  state.addAttributes(op->getAttrs());

  auto *new_op = builder.createOperation(state);

  // Replace all uses of `op`, and recursively replace those promoted uses.
  for (unsigned i = 0, j = 0, e = op->getNumResults(); i < e; ++i) {
    if (j < promoted_result_indices.size() && promoted_result_indices[j] == i) {
      j++;
      if (failed(PromoteValueUses(op->getResult(i), new_op->getResult(i),
                                  symbol_table))) {
        // On failure, replace all uses of new_op with op and erase the new op.
        new_op->replaceAllUsesWith(op);
        new_op->erase();
        return failure();
      }
    } else {
      // Non-promoted results keep their type; just redirect the uses.
      op->getResult(i).replaceAllUsesWith(new_op->getResult(i));
    }
  }

  // On success, erase the original op.
  op->erase();

  return success();
}
806
PromoteValueUses(mlir::Value old,mlir::Value promoted,const mlir::SymbolTable & symbol_table)807 mlir::LogicalResult LowerTFSavedModelPass::PromoteValueUses(
808 mlir::Value old, mlir::Value promoted,
809 const mlir::SymbolTable &symbol_table) {
810 // Retrieve the current uses before replacing the uses as the use list can be
811 // invalidated later.
812 llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4> uses;
813 for (auto &use : old.getUses())
814 uses.push_back({use.getOwner(), use.getOperandNumber()});
815
816 // Replace uses recursively.
817 for (const auto &use : uses) {
818 if (failed(PromoteOpOperand(/*op=*/use.first,
819 /*operand_number=*/use.second, promoted,
820 symbol_table)))
821 return failure();
822 }
823
824 return success();
825 }
826
827 // Converts ref variables to resource variables in a few cases.
828 //
// If all users of a variable in the entire module satisfy the following
// conditions, the variable will be converted to a resource variable:
831 //
832 // 1) tf.Identity op
833 // 2) tf.Assign op
834 // 3) side-effect-free ops: This is also the TF1 behavior that the TF executor
835 // will automatically convert ref tensors to non-ref tensors if the user is
836 // not expecting a ref tensor. Refer to
837 // http://cs?q=tensorflow/core/common_runtime/executor.cc:932%20at_cl:356873227
class ConvertReferenceVariableToResourceVariablePass
    : public mlir::PassWrapper<ConvertReferenceVariableToResourceVariablePass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  // Command-line argument that selects this pass.
  llvm::StringRef getArgument() const final {
    return "tfrt-convert-ref-variables";
  }
  // One-line summary shown in generated pass documentation.
  llvm::StringRef getDescription() const final {
    return "Convert reference variable to resource variables.";
  }
  // Collects all tf.VariableV2 ops in the module and rewrites each variable
  // whose users allow it. See the comment above the class for the supported
  // user patterns.
  void runOnOperation() override;
};
849
// Rewrites the users of a single tf.VariableV2 ref variable to operate on a
// resource variable instead: a tf.VarHandleOp is created with the same
// container/shared_name, tf.Identity users become tf.ReadVariableOp,
// tf.Assign users become tf.AssignVariableOp, and side-effect-free users get
// their ref operand replaced by a fresh read of the resource.
//
// Returns failure (with an emitted error) if any user does not match one of
// the supported patterns; in that case no IR is mutated. Note the tf.VariableV2
// op itself is NOT erased here — the caller does that on success.
mlir::LogicalResult ConvertReferenceVariableToResourceVariable(
    mlir::TF::VariableV2Op var_op) {
  // The non-ref tensor type of the variable; used as the resource subtype and
  // as the result type of the replacement reads.
  auto tensor_type =
      mlir::TF::DropRefType(var_op.ref().getType()).cast<mlir::TensorType>();

  // Users of the ref tensor, bucketed by how they will be rewritten.
  llvm::SmallVector<mlir::TF::IdentityOp, 4> identity_ops;
  llvm::SmallVector<mlir::TF::AssignOp, 4> assign_ops;
  // {op, operand index} pairs for side-effect-free users.
  llvm::SmallVector<std::pair<mlir::Operation *, unsigned>, 4>
      side_effect_free_ops;

  // Classify every user first so we can bail out before mutating anything if
  // some user cannot be handled.
  for (mlir::OpOperand &use : var_op.ref().getUses()) {
    mlir::Operation *user = use.getOwner();

    if (auto identity = llvm::dyn_cast<mlir::TF::IdentityOp>(user)) {
      identity_ops.push_back(identity);
      continue;
    } else if (auto assign = llvm::dyn_cast<mlir::TF::AssignOp>(user)) {
      // Conservatively we only allow the case that the output of this tf.Assign
      // is not consumed by any other ops.
      if (assign.output_ref().use_empty()) {
        assign_ops.push_back(assign);
        continue;
      }
    } else if (mlir::MemoryEffectOpInterface::hasNoEffect(user)) {
      side_effect_free_ops.push_back({user, use.getOperandNumber()});
      continue;
    }

    // Unsupported user (or a tf.Assign whose result is used): give up.
    return var_op.emitOpError()
           << "failed to convert reference variables with unexpected users. "
           << *user;
  }

  mlir::OpBuilder builder(var_op);

  // Create the resource handle with the same container/shared_name so it
  // refers to the same underlying variable.
  auto var_handle_op = builder.create<mlir::TF::VarHandleOp>(
      var_op.getLoc(),
      mlir::RankedTensorType::get(
          {},
          mlir::TF::ResourceType::get(ArrayRef<mlir::TensorType>{tensor_type},
                                      builder.getContext())),
      var_op.container(), var_op.shared_name());

  // tf.Identity(ref) becomes tf.ReadVariableOp(handle).
  for (auto op : identity_ops) {
    // Set insertion point to this identity_op so that the side-effect
    // visibility is preserved.
    builder.setInsertionPoint(op);
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op.getLoc(), op.getType(), var_handle_op);
    op.replaceAllUsesWith(read_var_op.value());
    op.erase();
  }

  // tf.Assign(ref, value) becomes tf.AssignVariableOp(handle, value).
  for (auto op : assign_ops) {
    // Insert right before the assign op (setInsertionPoint places new ops
    // *before* `op`); both the handle and op.value() are defined above this
    // point, so all operands dominate the newly created op.
    builder.setInsertionPoint(op);
    builder.create<mlir::TF::AssignVariableOp>(op.getLoc(), var_handle_op,
                                               op.value());
    op.erase();
  }

  // Side-effect-free users just need the ref operand swapped for the result
  // of a fresh read of the resource variable.
  for (auto pair : side_effect_free_ops) {
    mlir::Operation *op = pair.first;
    unsigned idx = pair.second;
    // Insert the read right before `op` (its result feeds `op`'s operand),
    // which keeps the read's side-effects ordered at the user's position.
    builder.setInsertionPoint(op);
    // Create a new read variable op, so that the side-effects are preserved.
    auto read_var_op = builder.create<mlir::TF::ReadVariableOp>(
        op->getLoc(), tensor_type, var_handle_op);
    op->setOperand(idx, read_var_op.value());
  }

  return success();
}
926
runOnOperation()927 void ConvertReferenceVariableToResourceVariablePass::runOnOperation() {
928 auto module = getOperation();
929
930 // The key here is a tuple of device, container and shared_name to uniquely
931 // identify a variable.
932 llvm::DenseMap<std::tuple<llvm::StringRef, llvm::StringRef, llvm::StringRef>,
933 llvm::SmallVector<mlir::TF::VariableV2Op, 4>>
934 ref_vars;
935
936 // First, we collect all variables' corresponding tf.VariableV2 ops.
937 module.walk([&ref_vars](mlir::TF::VariableV2Op op) {
938 if (op.shared_name().empty()) {
939 op.emitOpError()
940 << "unable to convert reference variables with empty shared_names.";
941 return mlir::WalkResult::interrupt();
942 }
943
944 llvm::StringRef device;
945 if (auto device_attr = op->getAttrOfType<mlir::StringAttr>("device")) {
946 device = device_attr.getValue();
947 }
948
949 ref_vars[{device, op.container(), op.shared_name()}].push_back(op);
950
951 return mlir::WalkResult::advance();
952 });
953
954 // Then we perform rewrite for each variable if possible.
955 for (const auto &iter : ref_vars) {
956 const auto &var_ops = iter.second;
957
958 for (auto var_op : var_ops) {
959 if (mlir::succeeded(ConvertReferenceVariableToResourceVariable(var_op)))
960 var_op.erase();
961 }
962 }
963 }
964
965 } // namespace
966
// Creates an instance of LowerTFSavedModelPass; `hoist_invariant_ops` is
// forwarded to the pass constructor.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerTFSavedModelPass(bool hoist_invariant_ops) {
  return std::make_unique<LowerTFSavedModelPass>(hoist_invariant_ops);
}
971
// Creates an instance of the pass that converts reference variables
// (tf.VariableV2) to resource variables.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateConvertReferenceVariableToResourceVariablePass() {
  return std::make_unique<ConvertReferenceVariableToResourceVariablePass>();
}
976
// Static registrations make both passes available to pass-pipeline tooling
// under the names returned by their getArgument() overrides.
static mlir::PassRegistration<LowerTFSavedModelPass> saved_model_pass;

static mlir::PassRegistration<ConvertReferenceVariableToResourceVariablePass>
    ref_var_pass;
981
982 } // namespace tensorflow
983