/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

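// This pass decomposes TensorList ops into operations on a dense "buffer"
// tensor that stacks all list elements along a new leading dimension, paired
// with a separately tracked size value holding the number of valid elements.
// As a rough illustration (the exact types are determined by
// collection_ops_util), a list of at most 8 f32 scalars becomes a
// tensor<8xf32> buffer plus an i32 size; tf.TensorListPushBack then lowers to
// a write at row `size` and an increment of the size by 1.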
struct TensorListOpsDecompositionPass
    : public TF::TensorListOpsDecompositionPassBase<
          TensorListOpsDecompositionPass> {
  void runOnOperation() override;
};

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Holds the size value of a tensor list and whether the size is statically
// known (fixed).
struct SizeInfo {
  Value size;
  bool fixed;
};

// Modifies a function's signature to rewrite tensor list arguments to buffers
// and sizes.
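// Each rewritten argument keeps its position but is retyped to its buffer
// type, and a corresponding size argument of `size_type` is appended at the
// end of the argument list and recorded in `buffer_to_size` (if provided).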
void ModifyFunctionSignature(
    FuncOp func, Type size_type,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
    llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto buffer_type = arg_to_buffer_type(i);
    if (!buffer_type.hasValue()) continue;
    func.getArgument(i).setType(*buffer_type);
    new_input_types[i] = *buffer_type;
    auto size_arg = func.front().addArgument(size_type);
    new_input_types.push_back(size_arg.getType());
    if (buffer_to_size) {
      (*buffer_to_size)[func.getArgument(i)] = {size_arg,
                                                arg_buffer_size_is_fixed(i)};
    }
  }
  UpdateFuncType(func);
}

// Holds information about a decomposed callee function for
// PartitionedCall/StatefulPartitionedCall.
struct PartitionedCallDecompositionInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
  // Each element is a tuple of (buffer_return_index, size_return_index,
  // fixed_size).
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      buffer_ret_to_size_ret;
};

LogicalResult DecomposeTensorListOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
    llvm::StringMap<PartitionedCallDecompositionInfo>*);

// Adds the corresponding sizes of tensor list buffers in block's terminator
// to the list of return values. Returns the mapping from the buffer
// indices to the added size indices, which is a list of tuples
// (buffer_return_index, size_return_index, fixed_size).
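// For example, if the old terminator returns (buffer, x) and buffer's size is
// s, the new terminator returns (buffer, x, s) and the returned mapping holds
// the single tuple (0, 2, fixed).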
template <class TerminatorOp>
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(
    Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto old_terminator = block.getTerminator();
  auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      output_buffer_to_size;
  for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
    auto it = buffer_to_size.find(retval.value());
    if (it == buffer_to_size.end()) continue;
    output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
                                       it->getSecond().fixed);
    new_outputs.push_back(it->getSecond().size);
  }
  OpBuilder(old_terminator)
      .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
  old_terminator->erase();
  return output_buffer_to_size;
}

// Adds the corresponding sizes of tensor list buffers in func's return values
// to the list of return values. Returns the mapping from the buffer indices to
// the added size indices, which is a list of tuples (buffer_return_index,
// size_return_index, fixed_size).
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
    FuncOp func, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto output_buffer_to_size =
      AddTensorListSizesToTerminator<ReturnOp>(func.front(), buffer_to_size);
  UpdateFuncType(func);
  return output_buffer_to_size;
}

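// Decomposes tensor lists flowing through a tf.While op: rewrites the cond
// and body signatures to carry buffers and sizes, decomposes the bodies
// recursively, and, if any list crosses the op boundary, recreates the while
// op with the size values appended to its operands and results.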
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite body.
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, SizeInfo> body_map;
  auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(while_op.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[while_op.getOperand(index)].fixed;
  };
  OpBuilder builder(while_op);
  ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &body.front(), module, &body_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);

  // Rewrite cond.
  auto cond = while_op.cond_function();
  llvm::SmallDenseMap<Value, SizeInfo> cond_map;
  ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &cond.front(), module, &cond_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (output_buffer_to_size.empty()) {
    return success();
  }
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  new_while_operands, while_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

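// Decomposes tensor lists passed into a tf.Case or tf.If op: rewrites every
// branch function and recreates the op with size operands and results
// appended. Branch argument i corresponds to operand i + 1 because operand 0
// is the branch index (Case) or condition (If).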
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(
    CaseOrIfOp op, ArrayRef<FuncOp> branches, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
  branch_maps.resize(branches.size());

  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(op.getOperand(index + 1));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
  };
  OpBuilder builder(op);
  for (const auto& pair : llvm::zip(branches, branch_maps)) {
    FuncOp branch = std::get<0>(pair);
    llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
    ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
                            find_arg_buffer_type, arg_buffer_size_is_fixed);

    if (failed(DecomposeTensorListOpsInternal(
            &branch.front(), module, &branch_map,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  const bool arg_no_changed = branch_maps.front().empty();
  auto output_buffer_to_size =
      ModifyFunctionReturn(branches.front(), branch_maps.front());
  for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
    ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));

  if (output_buffer_to_size.empty() && arg_no_changed) return success();

  // Recreate the op.
  auto new_operands = llvm::to_vector<8>(op.getOperands());
  for (int64_t i = 1; i < op.getNumOperands(); ++i) {
    auto it = buffer_to_size->find(op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_operands.push_back(it->getSecond().size);
  }
  FuncOp first_branch = branches.front();
  auto new_op = OpBuilder(op).create<CaseOrIfOp>(
      op.getLoc(), first_branch.getType().getResults(), new_operands,
      op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
  op.erase();
  return success();
}

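// Region-based variant of HandleWhileOp for tf.WhileRegion: retypes the cond
// and body region arguments in place, decomposes the region bodies, and, if
// needed, recreates the op with size operands and results appended, moving
// the old regions into the new op.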
LogicalResult HandleWhileRegionOp(
    TF::WhileRegionOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  OpBuilder builder(while_op);
  auto modify_region_arguments = [&](Region& region) {
    int64_t original_arg_count = region.getNumArguments();
    for (int64_t i = 0; i < original_arg_count; ++i) {
      auto operand = while_op.getOperand(i);
      auto it = buffer_to_size->find(operand);
      if (it == buffer_to_size->end()) continue;
      auto buffer_type = it->getFirst().getType();
      region.getArgument(i).setType(buffer_type);
      auto size_arg = region.addArgument(cutil::GetSizeType(builder));
      (*buffer_to_size)[region.getArgument(i)] = {size_arg,
                                                  it->getSecond().fixed};
    }
  };

  // Rewrite body.
  Region& body_region = while_op.body();
  modify_region_arguments(body_region);
  if (failed(DecomposeTensorListOpsInternal(
          &body_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      body_region.front(), *buffer_to_size);

  // Rewrite cond.
  Region& cond_region = while_op.cond();
  modify_region_arguments(cond_region);
  if (failed(DecomposeTensorListOpsInternal(
          &cond_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }

  if (output_buffer_to_size.empty()) return success();

  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileRegionOp>(
      while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
      new_while_operands, while_op->getAttrs());
  new_while.body().takeBody(body_region);
  new_while.cond().takeBody(cond_region);
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

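// Decomposes tensor lists inside a tf.IfRegion op: decomposes both branch
// regions, appends the tracked sizes to each branch's yield, and recreates
// the op with the extended result list.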
LogicalResult HandleIfRegionOp(
    TF::IfRegionOp if_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  Region& then_branch = if_op.then_branch();
  Region& else_branch = if_op.else_branch();
  if (failed(DecomposeTensorListOpsInternal(
          &then_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();
  if (failed(DecomposeTensorListOpsInternal(
          &else_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();

  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      then_branch.front(), *buffer_to_size);
  AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
                                              *buffer_to_size);

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
      if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
      if_op.getOperand(), if_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  new_op.then_branch().takeBody(if_op.then_branch());
  new_op.else_branch().takeBody(if_op.else_branch());

  if_op.replaceAllUsesWith(
      new_op.getResults().take_front(if_op.getNumResults()));
  if_op.erase();
  return success();
}

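// Decomposes tensor lists inside a tf.CaseRegion op, analogous to
// HandleIfRegionOp but over an arbitrary number of branch regions.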
LogicalResult HandleCaseRegionOp(
    TF::CaseRegionOp case_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  RegionRange branches = case_op.getRegions();

  for (Region* branch : branches) {
    if (failed(DecomposeTensorListOpsInternal(
            &branch->front(), module, buffer_to_size,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  // Get the output buffer-index-to-size-index mapping from one of the
  // branches. It should be the same for all branches, so we only compute it
  // for the first branch.
  Region* first_branch = branches.front();
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      first_branch->front(), *buffer_to_size);
  for (Region* branch : branches.drop_front()) {
    AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
                                                *buffer_to_size);
  }

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
      case_op.getLoc(),
      first_branch->front().getTerminator()->getOperandTypes(),
      case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
    std::get<0>(pair)->takeBody(*std::get<1>(pair));
  }
  case_op.replaceAllUsesWith(
      new_op.getResults().take_front(case_op.getNumResults()));
  case_op.erase();
  return success();
}

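// Decomposes tensor lists passed to a PartitionedCall/StatefulPartitionedCall
// callee. Each callee is decomposed only once; the result is cached in
// `decomposed_partitioned_call_callees` and reused by later call sites. When
// the callee's signature changes, the call op is recreated with size operands
// and results appended.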
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallDecompositionInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.buffer_arg_to_size_arg.find(i);
      if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
      auto it = buffer_to_size->find(call.getOperand(i));
      if (it == buffer_to_size->end()) {
        call.emitOpError("unknown tensor list.");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond().size);
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.buffer_ret_to_size_ret) {
      (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
          new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
    }
    call.replaceAllUsesWith(
        new_call.getResults().take_front(call.getNumResults()));
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  llvm::SmallDenseMap<Value, SizeInfo> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(call.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[call.getOperand(index)].fixed;
  };
  ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
                          &callee_map, find_arg_buffer_type,
                          arg_buffer_size_is_fixed);
  const bool args_no_changed = callee_map.empty();
  if (failed(DecomposeTensorListOpsInternal(
          &lowered_callee.front(), module, &callee_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  info.buffer_ret_to_size_ret =
      ModifyFunctionReturn(lowered_callee, callee_map);
  info.decomposed_callee = lowered_callee;
  if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
    // Signature is not modified. We do not need to keep two copies.
    info.signature_change = false;
    if (lowered_callee != callee) {
      lowered_callee.setName(callee.getName());
      callee.erase();
      SymbolTable(module).insert(lowered_callee);
    }
  } else {
    info.signature_change = true;
    for (auto& entry : callee_map) {
      auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
      if (!buffer_arg) continue;
      info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
          entry.getSecond().size.cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

// Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise,
// returns an error.
LogicalResult GetConstShapeValue(Value shape_value,
                                 llvm::SmallVector<int64_t, 8>* shape) {
  auto shape_op = shape_value.getDefiningOp();
  if (!shape_op) return failure();
  auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
  if (!shape_const_op) return failure();
  for (const auto& v : shape_const_op.value().getValues<APInt>()) {
    int64_t dim_size = v.getSExtValue();
    if (dim_size == ShapedType::kDynamicSize) return failure();
    shape->push_back(dim_size);
  }
  return success();
}

// Checks the result Variant type to infer the element shape if it is fully
// defined. If the Variant type has multiple subtypes or does not have a
// static shape, returns an error.
LogicalResult GetElementShapeFromResultType(
    Type type, llvm::SmallVector<int64_t, 8>* shape) {
  auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
  TensorType tensor_type = variant_type.getSubtypes().front();
  if (!tensor_type.hasStaticShape()) return failure();
  for (auto d : tensor_type.getShape()) shape->push_back(d);
  return success();
}

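// Rewrites tf.EmptyTensorList into an initial buffer whose leading dimension
// is max_num_elements, with a size constant of 0. The list is not fixed-size,
// so it still supports push and pop.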
LogicalResult HandleEmptyTensorListOp(
    TF::EmptyTensorListOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(
          element_shape, list.max_num_elements(), list, list.element_dtype(),
          builder, &buffer))) {
    return failure();
  }
  Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
  list.handle().replaceAllUsesWith(buffer);
  (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
  list.erase();
  return success();
}

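// Rewrites tf.TensorListReserve into an initial buffer with num_elements rows
// and a size equal to num_elements. The result is marked fixed-size, so later
// pushes and pops on it are rejected.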
LogicalResult HandleTensorListReserveOp(
    TF::TensorListReserveOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
                                          list, list.element_dtype(), builder,
                                          &buffer))) {
    return failure();
  }
  Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
                                              list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

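// Rewrites tf.TensorListFromTensor into an identity of the input tensor,
// which already has the buffer layout. The size is the input's static leading
// dimension, and the result is marked fixed-size.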
LogicalResult HandleTensorListFromTensorOp(
    TF::TensorListFromTensorOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  OpBuilder builder(list);
  Value buffer = builder.create<TF::IdentityOp>(
      list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
      ArrayRef<Value>{list.tensor()});
  auto type = buffer.getType().cast<TensorType>();
  if (!type.hasStaticShape()) {
    return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
  }
  Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.output_handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

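// Rewrites tf.TensorListPushBack into a write of the pushed tensor at row
// `size` of the buffer followed by a size increment of 1. Fails on fixed-size
// lists.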
LogicalResult HandleTensorListPushBackOp(
    TF::TensorListPushBackOp push,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = push.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    return push.emitOpError(
        "found tf.TensorListPushBack on unknown TensorList.");
  }
  if (it->getSecond().fixed) {
    return push.emitError("cannot push on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(push);
  auto new_buffer =
      cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
  auto new_size = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
  push.output_handle().replaceAllUsesWith(new_buffer);
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  push.erase();
  return success();
}

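// Rewrites tf.TensorListPopBack into a size decrement of 1 and a read of the
// element at the decremented index; the buffer contents are unchanged, so the
// buffer is passed through an identity. Fails on fixed-size lists.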
LogicalResult HandleTensorListPopBackOp(
    TF::TensorListPopBackOp pop,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = pop.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
    return failure();
  }
  if (it->getSecond().fixed) {
    return pop.emitError("cannot pop on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(pop);
  auto new_buffer = builder.create<TF::IdentityOp>(
      pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
  auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
  pop.output_handle().replaceAllUsesWith(new_buffer);
  pop.tensor().replaceAllUsesWith(element);
  pop.erase();
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  return success();
}

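// Rewrites tf.TensorListGetItem into a read of the buffer at the given index;
// the buffer and its size are unchanged.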
LogicalResult HandleTensorListGetItemOp(
    TF::TensorListGetItemOp get_item,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto buffer = get_item.input_handle();
  auto it = buffer_to_size.find(buffer);
  if (it == buffer_to_size.end()) {
    get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(get_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
                                              get_item.getLoc());
  auto element =
      cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc());
  get_item.item().replaceAllUsesWith(element);
  get_item.erase();
  return success();
}

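// Rewrites tf.TensorListSetItem into a write of the item at the given index,
// producing a new buffer that inherits the old size.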
LogicalResult HandleTensorListSetItemOp(
    TF::TensorListSetItemOp set_item,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = set_item.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(set_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
                                              set_item.getLoc());
  auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
                                      set_item.getLoc());
  set_item.output_handle().replaceAllUsesWith(new_buffer);
  auto size = it->getSecond();
  (*buffer_to_size)[new_buffer] = size;
  set_item.erase();
  return success();
}

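// Rewrites tf.TensorListLength into either a constant (for fixed-size lists,
// where the length equals the buffer's static leading dimension) or a reshape
// of the tracked R1 size value to a scalar.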
LogicalResult HandleTensorListLengthOp(
    TF::TensorListLengthOp length,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(length.input_handle());
  if (it == buffer_to_size.end()) {
    length.emitOpError("found tf.TensorListLength on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(length);
  if (it->getSecond().fixed) {
    auto dim = cutil::CreateScalarConst(
        length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
        builder, length.getLoc());
    length.length().replaceAllUsesWith(dim);
  } else {
    auto current_size = it->getSecond().size;
    // Reshapes the R1 length to a scalar.
    auto reshape = builder.create<TF::ReshapeOp>(
        length.getLoc(),
        ArrayRef<Type>{RankedTensorType::get(
            {}, getElementTypeOrSelf(current_size.getType()))},
        ArrayRef<Value>{current_size,
                        cutil::GetR1Const({}, builder, length.getLoc())});
    length.length().replaceAllUsesWith(reshape);
  }
  length.erase();
  return success();
}

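// Rewrites tf.TensorListElementShape into an R1 constant holding the buffer's
// static shape with the leading (list) dimension dropped.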
LogicalResult HandleTensorListElementShapeOp(
    TF::TensorListElementShapeOp elem_shape,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
    return elem_shape.emitOpError("unknown tensor list");
  }
  auto buffer = elem_shape.input_handle();
  auto result = cutil::GetR1Const(
      buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
      OpBuilder(elem_shape), elem_shape.getLoc(),
      elem_shape.shape_type().getIntOrFloatBitWidth());
  elem_shape.element_shape().replaceAllUsesWith(result);
  elem_shape.erase();
  return success();
}

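// Rewrites tf.TensorListGather into a gather of the requested rows from the
// buffer.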
LogicalResult HandleTensorListGatherOp(
    TF::TensorListGatherOp gather,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(gather.input_handle());
  if (it == buffer_to_size.end()) {
    return gather.emitOpError("unknown tensor list");
  }
  auto buffer = gather.input_handle();
  auto result = cutil::GatherElements(gather.indices(), buffer,
                                      OpBuilder(gather), gather.getLoc());
  gather.values().replaceAllUsesWith(result);
  gather.erase();
  return success();
}

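// Rewrites tf.TensorListScatterIntoExistingList into a tf.TensorScatterUpdate
// on the buffer, reshaping the R1 indices to the [N, 1] form that
// TensorScatterUpdate expects. The new buffer inherits the old size.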
LogicalResult HandleTensorListScatterIntoExistingListOp(
    TF::TensorListScatterIntoExistingListOp scatter,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto it = buffer_to_size->find(scatter.input_handle());
  if (it == buffer_to_size->end()) {
    return scatter.emitOpError("unknown tensor list");
  }
  auto buffer = scatter.input_handle();
  OpBuilder builder(scatter);
  // Use dyn_cast so an unranked indices type is reported as an error rather
  // than asserting inside cast<>.
  auto indices_type = scatter.indices().getType().dyn_cast<RankedTensorType>();
  if (!indices_type) return scatter.emitOpError("unranked indices shape");
  auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
  auto shape = builder.create<TF::ConstOp>(
      scatter.getLoc(),
      DenseElementsAttr::get(
          shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
  auto indices =
      builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
  Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
      scatter.getLoc(), buffer, indices, scatter.tensor());
  scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
  scatter.erase();
  auto size = it->getSecond();
  (*buffer_to_size)[tensor_scatter_update] = size;
  return success();
}

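// Decomposes all supported tensor list ops in `block` by dispatching each op
// to the matching handler above, recursing into control flow and calls.
// `buffer_to_size` tracks the size of every live buffer value, and
// `decomposed_partitioned_call_callees` caches already-decomposed callees.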
LogicalResult DecomposeTensorListOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    // TODO(yuanzx): Add a pass to remove identities in device computation.
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
      if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
      if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
      if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
      if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
      if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
        return failure();
      }
    } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
      if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
        return failure();
      }
    } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
      if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
        return failure();
      }
    } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
      if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
        return failure();
      }
    } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
      stack.tensor().replaceAllUsesWith(stack.input_handle());
      stack.erase();
    } else if (auto elem_shape =
                   llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
      if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
        return failure();
      }
    } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
      if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
        return failure();
      }
    } else if (auto scatter =
                   llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
                       &op)) {
      if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
                                                           buffer_to_size))) {
        return failure();
      }
    } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
      auto it = buffer_to_size->find(addn.getOperand(0));
      if (it != buffer_to_size->end()) {
        addn.sum().setType(addn.getOperand(0).getType());
        auto size = it->getSecond();
        (*buffer_to_size)[addn.sum()] = size;
      }
    } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
      if (buffer_to_size->count(zeros.x()) > 0) {
        zeros.y().setType(zeros.x().getType());
        auto size = (*buffer_to_size)[zeros.x()];
        (*buffer_to_size)[zeros.y()] = size;
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, buffer_to_size,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleCaseOrIfOp(
              if_op, {if_op.then_function(), if_op.else_function()}, module,
              buffer_to_size, decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
      SmallVector<FuncOp, 2> branches;
      case_op.get_branch_functions(branches);
      if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func())
        return pcall.emitOpError(
            "TensorList decomposition does not support call with nested "
            "references.");

      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
      if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
                                     decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
      if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
      if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
                                    decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

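// Entry point for decomposing the tensor list ops in `block`, starting from
// fresh tracking state.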
LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
  llvm::StringMap<PartitionedCallDecompositionInfo>
      decomposed_partitioned_call_callees;
  return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
                                        &decomposed_partitioned_call_callees);
}

void TensorListOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeTensorListOps(&main.front(), module))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass() {
  return std::make_unique<TensorListOpsDecompositionPass>();
}
}  // namespace TF
}  // namespace mlir