/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and the current size. Each push will be turned
// into
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
//   "tf.AssignVariableOp"(%buffer, %new_val)
//   %new_size = "tf.AddV2"(%old_size, %const1)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %new_size = "tf.Sub"(%old_size, %const1)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
//   %pop_result = "tf.Reshape"(%slice, %elem_size_const)
//   "tf.AssignVariableOp"(%size, %new_size)
//
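// The creation op itself is lowered to two "tf.MlirLocalVarOp"s (one for the
// buffer, one for the size), which are zero-initialized via writes to the
// local variables (see HandleStackV2Op below).
//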
// The pass also works across control flow and functional calls.
struct StackOpsDecompositionPass
    : public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "tf-stack-ops-decomposition"; }

  StringRef getDescription() const final {
    return "Decompose stack operations into local variable operations. Needs "
           "static shapes.";
  }

  void runOnOperation() override;
};

// Returns the type of the local variable for the stack size.
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}
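// Illustrative example (the exact element type comes from
// cutil::GetSizeType): the returned type looks like
//   tensor<!tf.resource<tensor<1xi32>>>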

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
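//
// For example (illustrative shapes), a signature
//   (tensor<i1>, tensor<!tf.resource>) -> tensor<i1>
// with a stack at argument 1 becomes
//   (tensor<i1>, tensor<!tf.resource<tensor<10xf32>>>,
//    tensor<!tf.resource<tensor<1xi32>>>) -> tensor<i1>
// where the trailing argument is the stack's size variable.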
void ModifyFunctionSignature(
    FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.hasValue()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type);
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
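// signature_change records whether decomposition changed the callee's
// signature; stack_var_arg_to_size_arg maps the index of a stack-variable
// argument to the index of its corresponding size-variable argument.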
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It will convert the body and conditional
// function signatures, and perform stack ops decomposition on them.
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
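  // Append the size-variable value for every stack operand, so the new while
  // op carries both the buffer and the size variables.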
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  new_while_operands, while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It will convert the branch function
// signatures, and perform stack ops decomposition on them.
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

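  // Operand 0 of tf.If is the condition, so branch argument index i
  // corresponds to if_op operand index i + 1.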
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
      if_op->getAttrs());
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It will first check if the callee was previously handled, and try to reuse
// that result if so. Otherwise, it will clone and convert the callee function,
// and perform stack ops decomposition on it.
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
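  // Look up or create the cache entry for this callee; emplace_res.second is
  // true iff the callee has not been processed before.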
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  llvm::SmallDenseMap<Value, Value> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
          entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.hasValue()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
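  // `buffer` now holds the zero-initialized buffer value whose leading
  // dimension is the stack's max_size.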
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}

LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // The push op's output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
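  // SetElement writes the pushed element into the buffer at offset `index`
  // (the current size), yielding the updated buffer value.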
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}

LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
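  // The top element lives at index size - 1, which is exactly new_size.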
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}

LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops in a block and recursively decomposes called
// functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
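  // Iterate with an early-increment range because handlers erase the current
  // op while we walk the block.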
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp, TF::IfRegionOp, TF::CaseRegionOp>(
                   op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
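  // Decomposition starts from the entry function "main"; modules without a
  // main function are left unchanged.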
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

static PassRegistration<StackOpsDecompositionPass> pass;

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir