• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Support/Casting.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
35 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
36 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/types.pb.h"
39 #include "tensorflow/core/platform/types.h"
40 
41 namespace mlir {
42 
43 namespace {
44 
45 namespace cutil = TF::collection_ops_util;
46 
47 // A pass that rewrites tensor list operations to tensor operations on buffers
48 // and size values.
49 //
50 // This pass requires that the full shape of the tensor list can be inferred: 1)
51 // the maximum size needs to be a constant and 2) the element shape needs to be
52 // constant.
53 //
54 // A tensor list creation op "tf.EmptyTensorList"/"tf.TensorListReserve" will be
55 // turned in to a zero-initialized buffer, and the size is initialized to a 0
56 // for "tf.EmptyTensorList" or the specified size for "tf.TensorListReserve".
57 // Each push will be turned into "tf.XlaDynamicUpdateSlice" with the incremented
58 // size, and each pop will be turned into a "tf.Slice" and a copy of the buffer
59 // with decremented size. Each SetItem will be turned into a
60 // "tf.XlaDynamicUpdateSlice" with unchanged size, and each GetItem will be
61 // turned into a "tf.Slice".
62 //
63 // The pass also works across control flow and functional calls.
64 struct TensorListOpsDecompositionPass
65     : public PassWrapper<TensorListOpsDecompositionPass,
66                          OperationPass<ModuleOp>> {
67   void runOnOperation() override;
68 };
69 
70 // Updates func's type according to its current arguments and return values.
UpdateFuncType(FuncOp func)71 void UpdateFuncType(FuncOp func) {
72   llvm::SmallVector<Type, 8> arg_types;
73   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
74   func.setType(
75       FunctionType::get(func.getContext(), arg_types,
76                         func.front().getTerminator()->getOperandTypes()));
77 }
78 
79 // Holds the size value of a tensor list and whether the size is statically
80 // known (fixed).
81 struct SizeInfo {
82   Value size;
83   bool fixed;
84 };
85 
86 // Modifies a function's signature to rewrite tensor list arguments to buffers
87 // and sizes.
ModifyFunctionSignature(FuncOp func,Type size_type,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::function_ref<llvm::Optional<Type> (int64_t)> arg_to_buffer_type,llvm::function_ref<bool (int64_t)> arg_buffer_size_is_fixed)88 void ModifyFunctionSignature(
89     FuncOp func, Type size_type,
90     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
91     llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
92     llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
93   auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
94   int64_t original_arg_count = new_input_types.size();
95   for (int64_t i = 0; i < original_arg_count; ++i) {
96     auto buffer_type = arg_to_buffer_type(i);
97     if (!buffer_type.hasValue()) continue;
98     func.getArgument(i).setType(*buffer_type);
99     new_input_types[i] = *buffer_type;
100     auto size_arg = func.front().addArgument(size_type);
101     new_input_types.push_back(size_arg.getType());
102     if (buffer_to_size) {
103       (*buffer_to_size)[func.getArgument(i)] = {size_arg,
104                                                 arg_buffer_size_is_fixed(i)};
105     }
106   }
107   UpdateFuncType(func);
108 }
109 
110 // Holds information about a decomposed callee function for
111 // PartitionedCall/StatefulPartitionedCall.
112 struct PartitionedCallDecompositionInfo {
113   bool signature_change;
114   FuncOp decomposed_callee;
115   llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
116   // Each element is a tuple of (buffer_return_index, size_return_index,
117   // fixed_size).
118   llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
119       buffer_ret_to_size_ret;
120 };
121 
122 LogicalResult DecomposeTensorListOpsInternal(
123     Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
124     llvm::StringMap<PartitionedCallDecompositionInfo>*);
125 
126 // Adds the corresponding sizes of tensor list buffers in block's terminator
127 // to the list of return values. Returns the mapping from the buffer
128 // indices to the added size indices, which is a list of tuples
129 // (buffer_return_index, size_return_index, fixed_size).
130 template <class TerminatorOp>
131 llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(Block & block,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)132 AddTensorListSizesToTerminator(
133     Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
134   auto old_terminator = block.getTerminator();
135   auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
136   llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
137       output_buffer_to_size;
138   for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
139     auto it = buffer_to_size.find(retval.value());
140     if (it == buffer_to_size.end()) continue;
141     output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
142                                        it->getSecond().fixed);
143     new_outputs.push_back(it->getSecond().size);
144   }
145   OpBuilder(old_terminator)
146       .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
147   old_terminator->erase();
148   return output_buffer_to_size;
149 }
150 
151 // Adds the corresponding sizes of tensor list buffers in func's return values
152 // to the list of return values. Returns the mapping from the buffer indices to
153 // the added size indices, which is a list of tuples (buffer_return_index,
154 // size_return_index, fixed_size).
ModifyFunctionReturn(FuncOp func,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)155 llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
156     FuncOp func, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
157   auto output_buffer_to_size =
158       AddTensorListSizesToTerminator<ReturnOp>(func.front(), buffer_to_size);
159   UpdateFuncType(func);
160   return output_buffer_to_size;
161 }
162 
HandleWhileOp(TF::WhileOp while_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)163 LogicalResult HandleWhileOp(
164     TF::WhileOp while_op, ModuleOp module,
165     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
166     llvm::StringMap<PartitionedCallDecompositionInfo>*
167         decomposed_partitioned_call_callees) {
168   // Rewrite body.
169   auto body = while_op.body_function();
170   llvm::SmallDenseMap<Value, SizeInfo> body_map;
171   auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
172     auto it = buffer_to_size->find(while_op.getOperand(index));
173     if (it == buffer_to_size->end()) return llvm::None;
174     return it->getFirst().getType();
175   };
176   auto arg_buffer_size_is_fixed = [&](int64_t index) {
177     return (*buffer_to_size)[while_op.getOperand(index)].fixed;
178   };
179   OpBuilder builder(while_op);
180   ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
181                           find_arg_tensor_list_type, arg_buffer_size_is_fixed);
182   if (failed(DecomposeTensorListOpsInternal(
183           &body.front(), module, &body_map,
184           decomposed_partitioned_call_callees))) {
185     return failure();
186   }
187   auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);
188 
189   // Rewrite cond.
190   auto cond = while_op.cond_function();
191   llvm::SmallDenseMap<Value, SizeInfo> cond_map;
192   ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
193                           find_arg_tensor_list_type, arg_buffer_size_is_fixed);
194   if (failed(DecomposeTensorListOpsInternal(
195           &cond.front(), module, &cond_map,
196           decomposed_partitioned_call_callees))) {
197     return failure();
198   }
199   if (output_buffer_to_size.empty()) {
200     return success();
201   }
202   // Create the new while op.
203   auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
204   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
205     auto it = buffer_to_size->find(while_op.getOperand(i));
206     if (it == buffer_to_size->end()) continue;
207     new_while_operands.push_back(it->getSecond().size);
208   }
209   auto new_while =
210       builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
211                                   new_while_operands, while_op.getAttrs());
212   for (const auto& entry : output_buffer_to_size) {
213     (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
214         new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
215   }
216   while_op.replaceAllUsesWith(
217       new_while.getResults().take_front(while_op.getNumResults()));
218   while_op.erase();
219   return success();
220 }
221 
222 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<FuncOp> branches,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)223 LogicalResult HandleCaseOrIfOp(
224     CaseOrIfOp op, ArrayRef<FuncOp> branches, ModuleOp module,
225     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
226     llvm::StringMap<PartitionedCallDecompositionInfo>*
227         decomposed_partitioned_call_callees) {
228   // Rewrite the branches.
229   SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
230   branch_maps.resize(branches.size());
231 
232   auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
233     auto it = buffer_to_size->find(op.getOperand(index + 1));
234     if (it == buffer_to_size->end()) return llvm::None;
235     return it->getFirst().getType();
236   };
237   auto arg_buffer_size_is_fixed = [&](int64_t index) {
238     return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
239   };
240   OpBuilder builder(op);
241   for (const auto& pair : llvm::zip(branches, branch_maps)) {
242     FuncOp branch = std::get<0>(pair);
243     llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
244     ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
245                             find_arg_buffer_type, arg_buffer_size_is_fixed);
246 
247     if (failed(DecomposeTensorListOpsInternal(
248             &branch.front(), module, &branch_map,
249             decomposed_partitioned_call_callees)))
250       return failure();
251   }
252 
253   const bool arg_no_changed = branch_maps.front().empty();
254   auto output_buffer_to_size =
255       ModifyFunctionReturn(branches.front(), branch_maps.front());
256   for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
257     ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));
258 
259   if (output_buffer_to_size.empty() && arg_no_changed) return success();
260 
261   // Recreate the op.
262   auto new_operands = llvm::to_vector<8>(op.getOperands());
263   for (int64_t i = 1; i < op.getNumOperands(); ++i) {
264     auto it = buffer_to_size->find(op.getOperand(i));
265     if (it == buffer_to_size->end()) continue;
266     new_operands.push_back(it->getSecond().size);
267   }
268   FuncOp first_branch = branches.front();
269   auto new_op = OpBuilder(op).create<CaseOrIfOp>(
270       op.getLoc(), first_branch.getType().getResults(), new_operands,
271       op.getAttrs());
272   for (const auto& entry : output_buffer_to_size) {
273     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
274         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
275   }
276   op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
277   op.erase();
278   return success();
279 }
280 
HandleWhileRegionOp(TF::WhileRegionOp while_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)281 LogicalResult HandleWhileRegionOp(
282     TF::WhileRegionOp while_op, ModuleOp module,
283     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
284     llvm::StringMap<PartitionedCallDecompositionInfo>*
285         decomposed_partitioned_call_callees) {
286   OpBuilder builder(while_op);
287   auto modify_region_arguments = [&](Region& region) {
288     int64_t original_arg_count = region.getNumArguments();
289     for (int64_t i = 0; i < original_arg_count; ++i) {
290       auto operand = while_op.getOperand(i);
291       auto it = buffer_to_size->find(operand);
292       if (it == buffer_to_size->end()) continue;
293       auto buffer_type = it->getFirst().getType();
294       region.getArgument(i).setType(buffer_type);
295       auto size_arg = region.addArgument(cutil::GetSizeType(builder));
296       (*buffer_to_size)[region.getArgument(i)] = {size_arg,
297                                                   it->getSecond().fixed};
298     }
299   };
300 
301   // Rewrite body.
302   Region& body_region = while_op.body();
303   modify_region_arguments(body_region);
304   if (failed(DecomposeTensorListOpsInternal(
305           &body_region.front(), module, buffer_to_size,
306           decomposed_partitioned_call_callees))) {
307     return failure();
308   }
309   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
310       body_region.front(), *buffer_to_size);
311 
312   // Rewrite cond.
313   Region& cond_region = while_op.cond();
314   modify_region_arguments(cond_region);
315   if (failed(DecomposeTensorListOpsInternal(
316           &cond_region.front(), module, buffer_to_size,
317           decomposed_partitioned_call_callees))) {
318     return failure();
319   }
320 
321   if (output_buffer_to_size.empty()) return success();
322 
323   // Create the new while op.
324   auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
325   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
326     auto it = buffer_to_size->find(while_op.getOperand(i));
327     if (it == buffer_to_size->end()) continue;
328     new_while_operands.push_back(it->getSecond().size);
329   }
330   auto new_while = builder.create<TF::WhileRegionOp>(
331       while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
332       new_while_operands, while_op.getAttrs());
333   new_while.body().takeBody(body_region);
334   new_while.cond().takeBody(cond_region);
335   for (const auto& entry : output_buffer_to_size) {
336     (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
337         new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
338   }
339   while_op.replaceAllUsesWith(
340       new_while.getResults().take_front(while_op.getNumResults()));
341   while_op.erase();
342   return success();
343 }
344 
HandleIfRegionOp(TF::IfRegionOp if_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)345 LogicalResult HandleIfRegionOp(
346     TF::IfRegionOp if_op, ModuleOp module,
347     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
348     llvm::StringMap<PartitionedCallDecompositionInfo>*
349         decomposed_partitioned_call_callees) {
350   // Rewrite the branches.
351   Region& then_branch = if_op.then_branch();
352   Region& else_branch = if_op.else_branch();
353   if (failed(DecomposeTensorListOpsInternal(
354           &then_branch.front(), module, buffer_to_size,
355           decomposed_partitioned_call_callees)))
356     return failure();
357   if (failed(DecomposeTensorListOpsInternal(
358           &else_branch.front(), module, buffer_to_size,
359           decomposed_partitioned_call_callees)))
360     return failure();
361 
362   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
363       then_branch.front(), *buffer_to_size);
364   AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
365                                               *buffer_to_size);
366 
367   if (output_buffer_to_size.empty()) return success();
368 
369   // Recreate the op.
370   auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
371       if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
372       if_op.getOperand(), if_op.getAttrs());
373   for (const auto& entry : output_buffer_to_size) {
374     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
375         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
376   }
377 
378   new_op.then_branch().takeBody(if_op.then_branch());
379   new_op.else_branch().takeBody(if_op.else_branch());
380 
381   if_op.replaceAllUsesWith(
382       new_op.getResults().take_front(if_op.getNumResults()));
383   if_op.erase();
384   return success();
385 }
386 
HandleCaseRegionOp(TF::CaseRegionOp case_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)387 LogicalResult HandleCaseRegionOp(
388     TF::CaseRegionOp case_op, ModuleOp module,
389     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
390     llvm::StringMap<PartitionedCallDecompositionInfo>*
391         decomposed_partitioned_call_callees) {
392   // Rewrite the branches.
393   RegionRange branches = case_op.getRegions();
394 
395   for (Region* branch : branches) {
396     if (failed(DecomposeTensorListOpsInternal(
397             &branch->front(), module, buffer_to_size,
398             decomposed_partitioned_call_callees)))
399       return failure();
400   }
401 
402   // Get the output buffer index to size index mapping one of the branches. It
403   // should be same for all the branches so we only get it for the first branch.
404   Region* first_branch = branches.front();
405   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
406       first_branch->front(), *buffer_to_size);
407   for (Region* branch : branches.drop_front()) {
408     AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
409                                                 *buffer_to_size);
410   }
411 
412   if (output_buffer_to_size.empty()) return success();
413 
414   // Recreate the op.
415   auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
416       case_op.getLoc(),
417       first_branch->front().getTerminator()->getOperandTypes(),
418       case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions());
419   for (const auto& entry : output_buffer_to_size) {
420     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
421         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
422   }
423 
424   for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
425     std::get<0>(pair)->takeBody(*std::get<1>(pair));
426   }
427   case_op.replaceAllUsesWith(
428       new_op.getResults().take_front(case_op.getNumResults()));
429   case_op.erase();
430   return success();
431 }
432 
433 template <typename CallOp>
HandlePartitionedCallOp(CallOp call,FuncOp callee,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)434 LogicalResult HandlePartitionedCallOp(
435     CallOp call, FuncOp callee, ModuleOp module,
436     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
437     llvm::StringMap<PartitionedCallDecompositionInfo>*
438         decomposed_partitioned_call_callees) {
439   auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
440       callee.getName(), PartitionedCallDecompositionInfo());
441   auto& info = emplace_res.first->second;
442   // Recreates the call op with info.
443   auto recreate_caller = [&] {
444     auto new_operands = llvm::to_vector<8>(call.getOperands());
445     for (int64_t i = 0; i < call.getNumOperands(); ++i) {
446       auto arg_it = info.buffer_arg_to_size_arg.find(i);
447       if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
448       auto it = buffer_to_size->find(call.getOperand(i));
449       if (it == buffer_to_size->end()) {
450         call.emitOpError("unknown tensor list.");
451         return failure();
452       }
453       assert(arg_it->second == new_operands.size());
454       new_operands.push_back(it->getSecond().size);
455     }
456     OpBuilder builder(call);
457     auto new_call = builder.create<CallOp>(
458         call.getLoc(), info.decomposed_callee.getType().getResults(),
459         new_operands, call.getAttrs());
460     new_call->setAttr(
461         "f", builder.getSymbolRefAttr(
462                  const_cast<FuncOp&>(info.decomposed_callee).getName()));
463     for (const auto& entry : info.buffer_ret_to_size_ret) {
464       (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
465           new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
466     }
467     call.replaceAllUsesWith(
468         new_call.getResults().take_front(call.getNumResults()));
469     call.erase();
470     return success();
471   };
472   if (!emplace_res.second) {
473     // This callee was handled before.
474     if (!info.signature_change) return success();
475     return recreate_caller();
476   }
477   // Rewrite the callee.
478   llvm::SmallDenseMap<Value, SizeInfo> callee_map;
479   FuncOp lowered_callee = callee;
480   if (!callee.isPrivate()) {
481     // Clone non-private callee in case of signature change.
482     lowered_callee = callee.clone();
483     lowered_callee.setPrivate();
484   }
485   auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
486     auto it = buffer_to_size->find(call.getOperand(index));
487     if (it == buffer_to_size->end()) return llvm::None;
488     return it->getFirst().getType();
489   };
490   auto arg_buffer_size_is_fixed = [&](int64_t index) {
491     return (*buffer_to_size)[call.getOperand(index)].fixed;
492   };
493   ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
494                           &callee_map, find_arg_buffer_type,
495                           arg_buffer_size_is_fixed);
496   const bool args_no_changed = callee_map.empty();
497   if (failed(DecomposeTensorListOpsInternal(
498           &lowered_callee.front(), module, &callee_map,
499           decomposed_partitioned_call_callees))) {
500     return failure();
501   }
502   info.buffer_ret_to_size_ret =
503       ModifyFunctionReturn(lowered_callee, callee_map);
504   info.decomposed_callee = lowered_callee;
505   if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
506     // Signature is not modified. We do not need to keep two copies.
507     info.signature_change = false;
508     if (lowered_callee != callee) {
509       lowered_callee.setName(callee.getName());
510       callee.erase();
511       SymbolTable(module).insert(lowered_callee);
512     }
513   } else {
514     info.signature_change = true;
515     for (auto& entry : callee_map) {
516       auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
517       if (!buffer_arg) continue;
518       info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
519           entry.getSecond().size.cast<BlockArgument>().getArgNumber();
520     }
521     if (lowered_callee != callee) {
522       // Add the clone with a new name.
523       lowered_callee.setName(
524           llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str());
525       SymbolTable(module).insert(lowered_callee);
526       callee = lowered_callee;
527     }
528   }
529   if (info.signature_change) return recreate_caller();
530   return success();
531 }
532 
533 // Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise,
534 // returns an error.
GetConstShapeValue(Value shape_value,llvm::SmallVector<int64_t,8> * shape)535 LogicalResult GetConstShapeValue(Value shape_value,
536                                  llvm::SmallVector<int64_t, 8>* shape) {
537   auto shape_op = shape_value.getDefiningOp();
538   if (!shape_op) return failure();
539   auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
540   if (!shape_const_op) return failure();
541   for (const auto& v : shape_const_op.value().getValues<APInt>()) {
542     int64_t dim_size = v.getSExtValue();
543     if (dim_size == ShapedType::kDynamicSize) return failure();
544     shape->push_back(dim_size);
545   }
546   return success();
547 }
548 
549 // Checks the result Variant type to infer the element shape if fully defined.
550 // If the Variant type has multiple subtypes or does not have static shape,
551 // return error.
GetElementShapeFromResultType(Type type,llvm::SmallVector<int64_t,8> * shape)552 LogicalResult GetElementShapeFromResultType(
553     Type type, llvm::SmallVector<int64_t, 8>* shape) {
554   auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
555   if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
556   TensorType tensor_type = variant_type.getSubtypes().front();
557   if (!tensor_type.hasStaticShape()) return failure();
558   for (auto d : tensor_type.getShape()) shape->push_back(d);
559   return success();
560 }
561 
HandleEmptyTensorListOp(TF::EmptyTensorListOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)562 LogicalResult HandleEmptyTensorListOp(
563     TF::EmptyTensorListOp list,
564     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
565   Value buffer;
566   OpBuilder builder(list);
567   llvm::SmallVector<int64_t, 8> element_shape;
568   // Infer TensorList element shape from the return type first, and then from
569   // the const element shape operand. We first check the return type because
570   // shape inference might have successfully inferred the element shape from
571   // write operations on the TensorList.
572   if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
573     if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
574       return list.emitOpError("unknown tensor list element shape");
575     }
576   }
577   if (failed(cutil::CreateInitBufferValue(
578           element_shape, list.max_num_elements(), list, list.element_dtype(),
579           builder, &buffer))) {
580     return failure();
581   }
582   Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
583   list.handle().replaceAllUsesWith(buffer);
584   (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
585   list.erase();
586   return success();
587 }
588 
HandleTensorListReserveOp(TF::TensorListReserveOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)589 LogicalResult HandleTensorListReserveOp(
590     TF::TensorListReserveOp list,
591     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
592   Value buffer;
593   OpBuilder builder(list);
594   llvm::SmallVector<int64_t, 8> element_shape;
595   // Infer TensorList element shape from the return type first, and then from
596   // the const element shape operand. We first check the return type because
597   // shape inference might have successfully inferred the element shape from
598   // write operations on the TensorList.
599   if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
600     if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
601       return list.emitOpError("unknown tensor list element shape");
602     }
603   }
604   if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
605                                           list, list.element_dtype(), builder,
606                                           &buffer))) {
607     return failure();
608   }
609   Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
610                                               list.getLoc());
611   (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
612   list.handle().replaceAllUsesWith(buffer);
613   list.erase();
614   return success();
615 }
616 
HandleTensorListFromTensorOp(TF::TensorListFromTensorOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)617 LogicalResult HandleTensorListFromTensorOp(
618     TF::TensorListFromTensorOp list,
619     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
620   OpBuilder builder(list);
621   Value buffer = builder.create<TF::IdentityOp>(
622       list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
623       ArrayRef<Value>{list.tensor()});
624   auto type = buffer.getType().cast<TensorType>();
625   if (!type.hasStaticShape()) {
626     return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
627   }
628   Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
629   (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
630   list.output_handle().replaceAllUsesWith(buffer);
631   list.erase();
632   return success();
633 }
634 
HandleTensorListPushBackOp(TF::TensorListPushBackOp push,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)635 LogicalResult HandleTensorListPushBackOp(
636     TF::TensorListPushBackOp push,
637     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
638   auto buffer = push.input_handle();
639   auto it = buffer_to_size->find(buffer);
640   if (it == buffer_to_size->end()) {
641     return push.emitOpError(
642         "found tf.TensorListPushBack on unknown TensorList.");
643   }
644   if (it->getSecond().fixed) {
645     return push.emitError("cannot push on a fixed-size tensor list");
646   }
647   auto size = it->getSecond().size;
648   OpBuilder builder(push);
649   auto new_buffer =
650       cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
651   auto new_size = builder.create<TF::AddV2Op>(
652       push.getLoc(), ArrayRef<Type>{size.getType()},
653       ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
654   push.output_handle().replaceAllUsesWith(new_buffer);
655   (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
656   push.erase();
657   return success();
658 }
659 
HandleTensorListPopBackOp(TF::TensorListPopBackOp pop,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)660 LogicalResult HandleTensorListPopBackOp(
661     TF::TensorListPopBackOp pop,
662     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
663   auto buffer = pop.input_handle();
664   auto it = buffer_to_size->find(buffer);
665   if (it == buffer_to_size->end()) {
666     pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
667     return failure();
668   }
669   if (it->getSecond().fixed) {
670     return pop.emitError("cannot pop on a fixed-size tensor list");
671   }
672   auto size = it->getSecond().size;
673   OpBuilder builder(pop);
674   auto new_buffer = builder.create<TF::IdentityOp>(
675       pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
676   auto new_size = builder.create<TF::SubOp>(
677       pop.getLoc(), ArrayRef<Type>{size.getType()},
678       ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
679   auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
680   pop.output_handle().replaceAllUsesWith(new_buffer);
681   pop.tensor().replaceAllUsesWith(element);
682   pop.erase();
683   (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
684   return success();
685 }
686 
HandleTensorListGetItemOp(TF::TensorListGetItemOp get_item,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)687 LogicalResult HandleTensorListGetItemOp(
688     TF::TensorListGetItemOp get_item,
689     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
690   auto buffer = get_item.input_handle();
691   auto it = buffer_to_size.find(buffer);
692   if (it == buffer_to_size.end()) {
693     get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList.");
694     return failure();
695   }
696   OpBuilder builder(get_item);
697   auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
698                                               get_item.getLoc());
699   auto element =
700       cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc());
701   get_item.item().replaceAllUsesWith(element);
702   get_item.erase();
703   return success();
704 }
705 
HandleTensorListSetItemOp(TF::TensorListSetItemOp set_item,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)706 LogicalResult HandleTensorListSetItemOp(
707     TF::TensorListSetItemOp set_item,
708     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
709   auto buffer = set_item.input_handle();
710   auto it = buffer_to_size->find(buffer);
711   if (it == buffer_to_size->end()) {
712     set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList.");
713     return failure();
714   }
715   OpBuilder builder(set_item);
716   auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
717                                               set_item.getLoc());
718   auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
719                                       set_item.getLoc());
720   set_item.output_handle().replaceAllUsesWith(new_buffer);
721   auto size = it->getSecond();
722   (*buffer_to_size)[new_buffer] = size;
723   set_item.erase();
724   return success();
725 }
726 
HandleTensorListLengthOp(TF::TensorListLengthOp length,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)727 LogicalResult HandleTensorListLengthOp(
728     TF::TensorListLengthOp length,
729     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
730   auto it = buffer_to_size.find(length.input_handle());
731   if (it == buffer_to_size.end()) {
732     length.emitOpError("found tf.TensorListLength on unknown TensorList.");
733     return failure();
734   }
735   OpBuilder builder(length);
736   if (it->getSecond().fixed) {
737     auto dim = cutil::CreateScalarConst(
738         length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
739         builder, length.getLoc());
740     length.length().replaceAllUsesWith(dim);
741   } else {
742     auto current_size = it->getSecond().size;
743     // Reshapes the R1 length to a scalar.
744     auto reshape = builder.create<TF::ReshapeOp>(
745         length.getLoc(),
746         ArrayRef<Type>{RankedTensorType::get(
747             {}, getElementTypeOrSelf(current_size.getType()))},
748         ArrayRef<Value>{current_size,
749                         cutil::GetR1Const({}, builder, length.getLoc())});
750     length.length().replaceAllUsesWith(reshape);
751   }
752   length.erase();
753   return success();
754 }
755 
HandleTensorListElementShapeOp(TF::TensorListElementShapeOp elem_shape,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)756 LogicalResult HandleTensorListElementShapeOp(
757     TF::TensorListElementShapeOp elem_shape,
758     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
759   if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
760     return elem_shape.emitOpError("unknown tensor list");
761   }
762   auto buffer = elem_shape.input_handle();
763   auto result = cutil::GetR1Const(
764       buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
765       OpBuilder(elem_shape), elem_shape.getLoc(),
766       elem_shape.shape_type().getIntOrFloatBitWidth());
767   elem_shape.element_shape().replaceAllUsesWith(result);
768   elem_shape.erase();
769   return success();
770 }
771 
HandleTensorListGatherOp(TF::TensorListGatherOp gather,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)772 LogicalResult HandleTensorListGatherOp(
773     TF::TensorListGatherOp gather,
774     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
775   auto it = buffer_to_size.find(gather.input_handle());
776   if (it == buffer_to_size.end()) {
777     return gather.emitOpError("unknown tensor list");
778   }
779   auto buffer = gather.input_handle();
780   auto result = cutil::GatherElements(gather.indices(), buffer,
781                                       OpBuilder(gather), gather.getLoc());
782   gather.values().replaceAllUsesWith(result);
783   gather.erase();
784   return success();
785 }
786 
HandleTensorListScatterIntoExistingListOp(TF::TensorListScatterIntoExistingListOp scatter,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)787 LogicalResult HandleTensorListScatterIntoExistingListOp(
788     TF::TensorListScatterIntoExistingListOp scatter,
789     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
790   auto it = buffer_to_size->find(scatter.input_handle());
791   if (it == buffer_to_size->end()) {
792     return scatter.emitOpError("unknown tensor list");
793   }
794   auto buffer = scatter.input_handle();
795   OpBuilder builder(scatter);
796   auto indices_type = scatter.indices().getType().cast<RankedTensorType>();
797   if (!indices_type) return scatter.emitOpError("unranked indices shape");
798   auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
799   auto shape = builder.create<TF::ConstOp>(
800       scatter.getLoc(),
801       DenseElementsAttr::get(
802           shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
803   auto indices =
804       builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
805   Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
806       scatter.getLoc(), buffer, indices, scatter.tensor());
807   scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
808   scatter.erase();
809   auto size = it->getSecond();
810   (*buffer_to_size)[tensor_scatter_update] = size;
811   return success();
812 }
813 
DecomposeTensorListOpsInternal(Block * block,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)814 LogicalResult DecomposeTensorListOpsInternal(
815     Block* block, ModuleOp module,
816     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
817     llvm::StringMap<PartitionedCallDecompositionInfo>*
818         decomposed_partitioned_call_callees) {
819   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
820     // TODO(yuanzx): Add a pass to remove identities in device computation.
821     if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
822       op.replaceAllUsesWith(op.getOperands());
823       op.erase();
824     } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
825       if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
826         return failure();
827       }
828     } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
829       if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
830         return failure();
831       }
832     } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
833       if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
834         return failure();
835       }
836     } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
837       if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
838         return failure();
839       }
840     } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
841       if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
842         return failure();
843       }
844     } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
845       if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
846         return failure();
847       }
848     } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
849       if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
850         return failure();
851       }
852     } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
853       if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
854         return failure();
855       }
856     } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
857       stack.tensor().replaceAllUsesWith(stack.input_handle());
858       stack.erase();
859     } else if (auto elem_shape =
860                    llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
861       if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
862         return failure();
863       }
864     } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
865       if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
866         return failure();
867       }
868     } else if (auto scatter =
869                    llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
870                        &op)) {
871       if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
872                                                            buffer_to_size))) {
873         return failure();
874       }
875     } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
876       auto it = buffer_to_size->find(addn.getOperand(0));
877       if (it != buffer_to_size->end()) {
878         addn.sum().setType(addn.getOperand(0).getType());
879         auto size = it->getSecond();
880         (*buffer_to_size)[addn.sum()] = size;
881       }
882     } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
883       if (buffer_to_size->count(zeros.x()) > 0) {
884         zeros.y().setType(zeros.x().getType());
885         auto size = (*buffer_to_size)[zeros.x()];
886         (*buffer_to_size)[zeros.y()] = size;
887       }
888     } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
889       if (failed(HandleWhileOp(while_op, module, buffer_to_size,
890                                decomposed_partitioned_call_callees))) {
891         return failure();
892       }
893     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
894       if (failed(HandleCaseOrIfOp(
895               if_op, {if_op.then_function(), if_op.else_function()}, module,
896               buffer_to_size, decomposed_partitioned_call_callees))) {
897         return failure();
898       }
899     } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
900       SmallVector<FuncOp, 2> branches;
901       case_op.get_branch_functions(branches);
902       if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
903                                   decomposed_partitioned_call_callees))) {
904         return failure();
905       }
906     } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
907       if (!pcall.func())
908         return pcall.emitOpError(
909             "TensorList decomposition does not support call with nested "
910             "references.");
911 
912       if (failed(HandlePartitionedCallOp(
913               pcall, pcall.func(), module, buffer_to_size,
914               decomposed_partitioned_call_callees))) {
915         return failure();
916       }
917     } else if (auto spcall =
918                    llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
919       if (failed(HandlePartitionedCallOp(
920               spcall, spcall.func(), module, buffer_to_size,
921               decomposed_partitioned_call_callees))) {
922         return failure();
923       }
924     } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
925       if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
926                                      decomposed_partitioned_call_callees))) {
927         return failure();
928       }
929     } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
930       if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
931                                   decomposed_partitioned_call_callees))) {
932         return failure();
933       }
934     } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
935       if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
936                                     decomposed_partitioned_call_callees))) {
937         return failure();
938       }
939     }
940   }
941   return success();
942 }
943 
DecomposeTensorListOps(Block * block,ModuleOp module)944 LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
945   llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
946   llvm::StringMap<PartitionedCallDecompositionInfo>
947       decomposed_partitioned_call_callees;
948   return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
949                                         &decomposed_partitioned_call_callees);
950 }
951 
runOnOperation()952 void TensorListOpsDecompositionPass::runOnOperation() {
953   auto module = getOperation();
954   auto main = module.lookupSymbol<FuncOp>("main");
955   if (!main) return;
956   if (failed(DecomposeTensorListOps(&main.front(), module))) {
957     signalPassFailure();
958   }
959 }
960 
961 static PassRegistration<TensorListOpsDecompositionPass> pass(
962     "tf-tensor-list-ops-decomposition",
963     "Decompose tensor list operations into operations on buffers and sizes. "
964     "Needs static shapes.");
965 
966 }  // namespace
967 
968 namespace TF {
969 std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass()970 CreateTensorListOpsDecompositionPass() {
971   return std::make_unique<TensorListOpsDecompositionPass>();
972 }
973 }  // namespace TF
974 }  // namespace mlir
975