/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;
// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and current size. Each push will be turned into
//
//  %old_val = "tf.ReadVariableOp"(%buffer)
//  %old_size = "tf.ReadVariableOp"(%size)
//  %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//  %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
//  "tf.AssignVariableOp"(%buffer, %new_val)
//  %new_size = "tf.AddV2"(%old_size, %const1)
//  "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
//  %old_val = "tf.ReadVariableOp"(%buffer)
//  %old_size = "tf.ReadVariableOp"(%size)
//  %new_size = "tf.Sub"(%old_size, %const1)
//  %offsets = "tf.ConcatV2"(%new_size, %other_dims_0s, %const0)
//  %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
//  %pop_result = "tf.Reshape"(%slice, %elem_size_const)
//  "tf.AssignVariableOp"(%size, %new_size)
//
// The pass also works across control flow and functional calls.
struct StackOpsDecompositionPass
    : public PassWrapper<StackOpsDecompositionPass, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "tf-stack-ops-decomposition"; }

  StringRef getDescription() const final {
    return "Decompose stack operations into local variable operations. Needs "
           "static shapes.";
  }

  void runOnOperation() override;
};

// Returns the type of the local variable for the stack size.
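// This is a scalar resource type whose subtype is the rank-1 size tensor
// returned by cutil::GetSizeType.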
Type GetSizeVarType(OpBuilder builder) {
  auto size_type = cutil::GetSizeType(builder);
  return RankedTensorType::get(
      {}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
                                builder.getContext()));
}

// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
  Value return_val = func.front().getTerminator()->getOperand(return_index);
  auto maybe_arg = return_val.dyn_cast<BlockArgument>();
  if (!maybe_arg) return -1;
  return maybe_arg.getArgNumber();
}

// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
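//
// For example, assuming a stack holding at most 10 tensor<f32> elements and a
// tensor<1xi32> size tensor, a function of type
//
//  (tensor<!tf.resource>, tensor<i32>) -> tensor<i32>
//
// whose first argument is the stack would be rewritten to
//
//  (tensor<!tf.resource<tensor<10xf32>>>, tensor<i32>,
//   tensor<!tf.resource<tensor<1xi32>>>) -> tensor<i32>
//
// with the appended argument holding the size variable.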
void ModifyFunctionSignature(
    FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
    llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
        nullptr) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  auto size_var_type = GetSizeVarType(OpBuilder(func));
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto stack_type = arg_to_stack_type(i);
    if (!stack_type.hasValue()) continue;
    func.getArgument(i).setType(*stack_type);
    new_input_types[i] = *stack_type;
    auto size_arg = func.front().addArgument(size_var_type);
    new_input_types.push_back(size_arg.getType());
    if (stack_var_to_size_var) {
      (*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
    }
  }
  if (handle_new_size_vars) {
    handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
  }
  func.setType(
      FunctionType::get(func.getContext(), new_input_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallStackOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};

LogicalResult DecomposeStackOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
    llvm::StringMap<PartitionedCallStackOpsInfo>*);

// Handles stack usage by a tf.While. It will convert the body and conditional
// function signatures, and perform stack ops decomposition on them.
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, Value> body_map;
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(while_op.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
    if (new_args.empty()) return;
    auto body_ret = body.front().getTerminator();
    auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
    for (auto arg : new_args) new_body_returns.push_back(arg);
    OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
    body_ret->erase();
  };
  // Handle body.
  ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
                          add_size_vars_to_return);
  const bool signature_change = !body_map.empty();
  if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  // Cond should not change stacks in the arguments, so use an empty map.
  auto cond = while_op.cond_function();
  ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
  llvm::SmallDenseMap<Value, Value> empty_map;
  if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  OpBuilder builder(while_op);
  assert(while_op.getNumOperands() == while_op.getNumResults());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = data_var_to_size_var.find(while_op.getOperand(i));
    if (it == data_var_to_size_var.end()) continue;
    new_while_operands.push_back(it->getSecond());
  }
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getType().getInputs(), new_while_operands,
      while_op->getAttrs());
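  // For resource results that the body simply forwards from the corresponding
  // input, redirect uses of the old while op's result to the operand itself.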
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
             .isa<TF::ResourceType>()) {
      continue;
    }
    int64_t aliased_input = FindAliasedInput(body, i);
    if (aliased_input == i) {
      // Replace aliased stack output uses with input.
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    }
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

// Handles stack usage by a tf.If. It will convert the branch function
// signatures, and perform stack ops decomposition on them.
LogicalResult HandleIfOp(
    TF::IfOp if_op, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto then_func = if_op.then_function();
  auto else_func = if_op.else_function();
  llvm::SmallDenseMap<Value, Value> then_map;
  llvm::SmallDenseMap<Value, Value> else_map;

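  // Branch argument i corresponds to tf.If operand i + 1, since operand 0 is
  // the condition.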
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type);
  ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type);
  const bool signature_change = !then_map.empty() || !else_map.empty();
  if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map,
                                       decomposed_partitioned_call_callees)) ||
      failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (!signature_change) return success();
  auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
  for (auto operand : if_op.getOperands()) {
    auto it = data_var_to_size_var.find(operand);
    if (it == data_var_to_size_var.end()) continue;
    new_if_operands.push_back(it->getSecond());
  }
  auto new_if = OpBuilder(if_op).create<TF::IfOp>(
      if_op.getLoc(), then_func.getType().getResults(), new_if_operands,
      if_op->getAttrs());
  for (auto result : if_op.getResults()) {
    if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    int64_t then_aliased_input =
        FindAliasedInput(then_func, result.getResultNumber());
    int64_t else_aliased_input =
        FindAliasedInput(else_func, result.getResultNumber());
    if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
      // Replace aliased stack output uses with input.
      result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
    }
  }
  if_op.replaceAllUsesWith(new_if);
  if_op.erase();
  return success();
}

// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It will first check if the callee was previously handled, and try to reuse
// that result if so. Otherwise, it will clone and convert the callee function,
// and perform stack ops decomposition on it.
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallStackOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreate the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.stack_var_arg_to_size_arg.find(i);
      if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
      auto it = data_var_to_size_var.find(call.getOperand(i));
      if (it == data_var_to_size_var.end()) {
        call.emitOpError("unknown stack");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond());
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (int64_t i = 0; i < call.getNumResults(); ++i) {
      auto result = call.getResult(i);
      if (!getElementTypeOrSelf(result.getType())
               .template isa<TF::ResourceType>()) {
        continue;
      }
      int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
      if (aliased_input >= 0) {
        // Replace aliased stack output uses with input.
        result.replaceAllUsesWith(call.getOperand(aliased_input));
      }
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  llvm::SmallDenseMap<Value, Value> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = data_var_to_size_var.find(call.getOperand(index));
    if (it == data_var_to_size_var.end()) return llvm::None;
    return it->getFirst().getType();
  };
  ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type);
  info.signature_change = !callee_map.empty();
  if (!info.signature_change) {
    // Signature is not modified. We do not need the clone.
    if (lowered_callee != callee) {
      lowered_callee.erase();
    }
  } else {
    info.decomposed_callee = lowered_callee;
    for (auto& entry : callee_map) {
      info.stack_var_arg_to_size_arg
          [entry.getFirst().cast<BlockArgument>().getArgNumber()] =
          entry.getSecond().cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_stack_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
                                       decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

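// Handles stack creation by a tf.StackV2 op: infers the element type from a
// push user, then creates and zero-initializes a buffer variable and a size
// variable to replace the stack.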
LogicalResult HandleStackV2Op(
    TF::StackV2Op stack, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  // Create a buffer variable and a size variable to replace the stack.
  auto elem_type = cutil::GetElementTypeFromAccess(
      stack.handle(), module, [](Operation* user) -> llvm::Optional<Type> {
        auto push = llvm::dyn_cast<TF::StackPushV2Op>(user);
        if (!push) return llvm::None;
        return push.elem().getType();
      });
  if (!elem_type.hasValue()) {
    return stack.emitOpError("cannot infer element shape of stack");
  }
  OpBuilder builder(stack);
  Value buffer;
  if (failed(cutil::CreateInitBufferValue(
          elem_type->getShape(), stack.max_size(), stack,
          elem_type->getElementType(), builder, &buffer))) {
    return failure();
  }
  auto size_var_type = GetSizeVarType(builder);
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              stack.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  auto local_size_var = builder.create<TF::MlirLocalVarOp>(
      stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{});
  // Zero-initialize the local vars.
  cutil::WriteLocalVariable(local_size_var,
                            cutil::GetR1Const({0LL}, builder, stack.getLoc()),
                            builder, stack.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc());
  stack.handle().replaceAllUsesWith(local_var);
  (*data_var_to_size_var)[local_var] = local_size_var;
  stack.erase();
  return success();
}

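// Handles a tf.StackPushV2 op: writes the pushed element into the buffer
// variable at the current size, then increments the size variable.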
LogicalResult HandleStackPushV2Op(
    TF::StackPushV2Op push,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(push.handle());
  if (it == data_var_to_size_var->end()) {
    return push.emitOpError("unknown stack");
  }
  // The push output simply forwards the input element.
  push.replaceAllUsesWith(push.elem());
  OpBuilder builder(push);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(push.handle(), builder, push.getLoc());
  auto index =
      cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc());
  stack_val =
      cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc());
  // Assign the new buffer and size.
  cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
  index = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{index.getType()},
      ArrayRef<Value>{index, cutil::GetR1Const({1}, builder, push.getLoc())});
  cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
  push.erase();
  return success();
}

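// Handles a tf.StackPopV2 op: decrements the size variable and reads back the
// element at the decremented size.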
LogicalResult HandleStackPopV2Op(
    TF::StackPopV2Op pop,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
  auto it = data_var_to_size_var->find(pop.handle());
  if (it == data_var_to_size_var->end()) {
    return pop.emitOpError("unknown stack");
  }
  OpBuilder builder(pop);
  // Read the current buffer and size.
  auto stack_val =
      cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc());
  auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1}, builder, pop.getLoc())});
  auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc());
  pop.replaceAllUsesWith(pop_val);
  // Update the size.
  cutil::WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
  pop.erase();
  return success();
}

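// Handles region-based control flow ops (tf.WhileRegion, tf.IfRegion,
// tf.CaseRegion) by decomposing stack ops in their regions. Resource-typed
// operands and results are expected to have been canonicalized away for these
// ops before this pass runs.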
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (Region& region : op.getRegions()) {
    if (failed(DecomposeStackOpsInternal(&region.front(), module,
                                         data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

// Decomposes stack ops on a region and recursively decomposes called
// functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
      if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
      if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
      if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
      data_var_to_size_var->erase(close.handle());
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::WhileRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::CaseRegionOp>(op)) {
      if (failed(
              HandleRegionControlFlowOps(op, module, data_var_to_size_var,
                                         decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func()) {
        return pcall.emitOpError(
            "stack decomposition does not support call with nested references");
      }
      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, *data_var_to_size_var,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

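// Entry point: decomposes stack ops in `block` with fresh stack-variable and
// callee-info caches.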
LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
  llvm::StringMap<PartitionedCallStackOpsInfo>
      decomposed_partitioned_call_callees;
  return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
                                   &decomposed_partitioned_call_callees);
}

void StackOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeStackOps(&main.front(), module))) {
    signalPassFailure();
  }
}

static PassRegistration<StackOpsDecompositionPass> pass;

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>> CreateStackOpsDecompositionPass() {
  return std::make_unique<StackOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir
603