• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Support/Casting.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
36 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/types.pb.h"
40 #include "tensorflow/core/platform/types.h"
41 
42 namespace mlir {
43 
44 namespace {
45 
46 namespace cutil = TF::collection_ops_util;
47 
48 struct TensorListOpsDecompositionPass
49     : public TF::TensorListOpsDecompositionPassBase<
50           TensorListOpsDecompositionPass> {
51   void runOnOperation() override;
52 };
53 
54 // Updates func's type according to its current arguments and return values.
UpdateFuncType(func::FuncOp func)55 void UpdateFuncType(func::FuncOp func) {
56   llvm::SmallVector<Type, 8> arg_types;
57   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
58   func.setType(
59       FunctionType::get(func.getContext(), arg_types,
60                         func.front().getTerminator()->getOperandTypes()));
61 }
62 
63 // Holds the size value of a tensor list and whether the size is statically
64 // known (fixed).
65 struct SizeInfo {
66   Value size;
67   bool fixed;
68 };
69 
70 // Modifies a function's signature to rewrite tensor list arguments to buffers
71 // and sizes.
ModifyFunctionSignature(func::FuncOp func,Type size_type,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::function_ref<llvm::Optional<Type> (int64_t)> arg_to_buffer_type,llvm::function_ref<bool (int64_t)> arg_buffer_size_is_fixed)72 void ModifyFunctionSignature(
73     func::FuncOp func, Type size_type,
74     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
75     llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
76     llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
77   auto new_input_types = llvm::to_vector<8>(func.getFunctionType().getInputs());
78   int64_t original_arg_count = new_input_types.size();
79   Location loc = func.getLoc();
80   for (int64_t i = 0; i < original_arg_count; ++i) {
81     auto buffer_type = arg_to_buffer_type(i);
82     if (!buffer_type.has_value()) continue;
83     func.getArgument(i).setType(*buffer_type);
84     new_input_types[i] = *buffer_type;
85     auto size_arg = func.front().addArgument(size_type, loc);
86     new_input_types.push_back(size_arg.getType());
87     if (buffer_to_size) {
88       (*buffer_to_size)[func.getArgument(i)] = {size_arg,
89                                                 arg_buffer_size_is_fixed(i)};
90     }
91   }
92   UpdateFuncType(func);
93 }
94 
95 // Holds information about a decomposed callee function for
96 // PartitionedCall/StatefulPartitionedCall.
97 struct PartitionedCallDecompositionInfo {
98   bool signature_change;
99   func::FuncOp decomposed_callee;
100   llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
101   // Each element is a tuple of (buffer_return_index, size_return_index,
102   // fixed_size).
103   llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
104       buffer_ret_to_size_ret;
105 };
106 
107 LogicalResult DecomposeTensorListOpsInternal(
108     Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
109     llvm::StringMap<PartitionedCallDecompositionInfo>*);
110 
111 // Adds the corresponding sizes of tensor list buffers in block's terminator
112 // to the list of return values. Returns the mapping from the buffer
113 // indices to the added size indices, which is a list of tuples
114 // (buffer_return_index, size_return_index, fixed_size).
115 template <class TerminatorOp>
116 llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(Block & block,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)117 AddTensorListSizesToTerminator(
118     Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
119   auto old_terminator = block.getTerminator();
120   auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
121   llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
122       output_buffer_to_size;
123   for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
124     auto it = buffer_to_size.find(retval.value());
125     if (it == buffer_to_size.end()) continue;
126     output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
127                                        it->getSecond().fixed);
128     new_outputs.push_back(it->getSecond().size);
129   }
130   OpBuilder(old_terminator)
131       .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
132   old_terminator->erase();
133   return output_buffer_to_size;
134 }
135 
136 // Adds the corresponding sizes of tensor list buffers in func's return values
137 // to the list of return values. Returns the mapping from the buffer indices to
138 // the added size indices, which is a list of tuples (buffer_return_index,
139 // size_return_index, fixed_size).
ModifyFunctionReturn(func::FuncOp func,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)140 llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
141     func::FuncOp func,
142     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
143   auto output_buffer_to_size = AddTensorListSizesToTerminator<func::ReturnOp>(
144       func.front(), buffer_to_size);
145   UpdateFuncType(func);
146   return output_buffer_to_size;
147 }
148 
HandleWhileOp(TF::WhileOp while_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)149 LogicalResult HandleWhileOp(
150     TF::WhileOp while_op, ModuleOp module,
151     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
152     llvm::StringMap<PartitionedCallDecompositionInfo>*
153         decomposed_partitioned_call_callees) {
154   // Rewrite body.
155   auto body = while_op.body_function();
156   llvm::SmallDenseMap<Value, SizeInfo> body_map;
157   auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
158     auto it = buffer_to_size->find(while_op.getOperand(index));
159     if (it == buffer_to_size->end()) return llvm::None;
160     return it->getFirst().getType();
161   };
162   auto arg_buffer_size_is_fixed = [&](int64_t index) {
163     return (*buffer_to_size)[while_op.getOperand(index)].fixed;
164   };
165   OpBuilder builder(while_op);
166   ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
167                           find_arg_tensor_list_type, arg_buffer_size_is_fixed);
168   if (failed(DecomposeTensorListOpsInternal(
169           &body.front(), module, &body_map,
170           decomposed_partitioned_call_callees))) {
171     return failure();
172   }
173   auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);
174 
175   // Rewrite cond.
176   auto cond = while_op.cond_function();
177   llvm::SmallDenseMap<Value, SizeInfo> cond_map;
178   ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
179                           find_arg_tensor_list_type, arg_buffer_size_is_fixed);
180   if (failed(DecomposeTensorListOpsInternal(
181           &cond.front(), module, &cond_map,
182           decomposed_partitioned_call_callees))) {
183     return failure();
184   }
185   if (output_buffer_to_size.empty()) {
186     return success();
187   }
188   // Create the new while op.
189   auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
190   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
191     auto it = buffer_to_size->find(while_op.getOperand(i));
192     if (it == buffer_to_size->end()) continue;
193     new_while_operands.push_back(it->getSecond().size);
194   }
195   auto new_while = builder.create<TF::WhileOp>(
196       while_op.getLoc(), body.getFunctionType().getInputs(), new_while_operands,
197       while_op->getAttrs());
198   for (const auto& entry : output_buffer_to_size) {
199     (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
200         new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
201   }
202   while_op.replaceAllUsesWith(
203       new_while.getResults().take_front(while_op.getNumResults()));
204   while_op.erase();
205   return success();
206 }
207 
208 template <class CaseOrIfOp>
HandleCaseOrIfOp(CaseOrIfOp op,ArrayRef<func::FuncOp> branches,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)209 LogicalResult HandleCaseOrIfOp(
210     CaseOrIfOp op, ArrayRef<func::FuncOp> branches, ModuleOp module,
211     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
212     llvm::StringMap<PartitionedCallDecompositionInfo>*
213         decomposed_partitioned_call_callees) {
214   // Rewrite the branches.
215   SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
216   branch_maps.resize(branches.size());
217 
218   auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
219     auto it = buffer_to_size->find(op.getOperand(index + 1));
220     if (it == buffer_to_size->end()) return llvm::None;
221     return it->getFirst().getType();
222   };
223   auto arg_buffer_size_is_fixed = [&](int64_t index) {
224     return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
225   };
226   OpBuilder builder(op);
227   for (const auto& pair : llvm::zip(branches, branch_maps)) {
228     func::FuncOp branch = std::get<0>(pair);
229     llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
230     ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
231                             find_arg_buffer_type, arg_buffer_size_is_fixed);
232 
233     if (failed(DecomposeTensorListOpsInternal(
234             &branch.front(), module, &branch_map,
235             decomposed_partitioned_call_callees)))
236       return failure();
237   }
238 
239   const bool arg_no_changed = branch_maps.front().empty();
240   auto output_buffer_to_size =
241       ModifyFunctionReturn(branches.front(), branch_maps.front());
242   for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
243     ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));
244 
245   if (output_buffer_to_size.empty() && arg_no_changed) return success();
246 
247   // Recreate the op.
248   auto new_operands = llvm::to_vector<8>(op.getOperands());
249   for (int64_t i = 1; i < op.getNumOperands(); ++i) {
250     auto it = buffer_to_size->find(op.getOperand(i));
251     if (it == buffer_to_size->end()) continue;
252     new_operands.push_back(it->getSecond().size);
253   }
254   func::FuncOp first_branch = branches.front();
255   auto new_op = OpBuilder(op).create<CaseOrIfOp>(
256       op.getLoc(), first_branch.getFunctionType().getResults(), new_operands,
257       op->getAttrs());
258   for (const auto& entry : output_buffer_to_size) {
259     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
260         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
261   }
262   op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
263   op.erase();
264   return success();
265 }
266 
HandleWhileRegionOp(TF::WhileRegionOp while_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)267 LogicalResult HandleWhileRegionOp(
268     TF::WhileRegionOp while_op, ModuleOp module,
269     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
270     llvm::StringMap<PartitionedCallDecompositionInfo>*
271         decomposed_partitioned_call_callees) {
272   OpBuilder builder(while_op);
273   auto modify_region_arguments = [&](Region& region) {
274     int64_t original_arg_count = region.getNumArguments();
275     for (int64_t i = 0; i < original_arg_count; ++i) {
276       auto operand = while_op.getOperand(i);
277       auto it = buffer_to_size->find(operand);
278       if (it == buffer_to_size->end()) continue;
279       auto buffer_type = it->getFirst().getType();
280       region.getArgument(i).setType(buffer_type);
281       auto size_arg =
282           region.addArgument(cutil::GetSizeType(builder), region.getLoc());
283       (*buffer_to_size)[region.getArgument(i)] = {size_arg,
284                                                   it->getSecond().fixed};
285     }
286   };
287 
288   // Rewrite body.
289   Region& body_region = while_op.body();
290   modify_region_arguments(body_region);
291   if (failed(DecomposeTensorListOpsInternal(
292           &body_region.front(), module, buffer_to_size,
293           decomposed_partitioned_call_callees))) {
294     return failure();
295   }
296   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
297       body_region.front(), *buffer_to_size);
298 
299   // Rewrite cond.
300   Region& cond_region = while_op.cond();
301   modify_region_arguments(cond_region);
302   if (failed(DecomposeTensorListOpsInternal(
303           &cond_region.front(), module, buffer_to_size,
304           decomposed_partitioned_call_callees))) {
305     return failure();
306   }
307 
308   if (output_buffer_to_size.empty()) return success();
309 
310   // Create the new while op.
311   auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
312   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
313     auto it = buffer_to_size->find(while_op.getOperand(i));
314     if (it == buffer_to_size->end()) continue;
315     new_while_operands.push_back(it->getSecond().size);
316   }
317   auto new_while = builder.create<TF::WhileRegionOp>(
318       while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
319       new_while_operands, while_op->getAttrs());
320   new_while.body().takeBody(body_region);
321   new_while.cond().takeBody(cond_region);
322   for (const auto& entry : output_buffer_to_size) {
323     (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
324         new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
325   }
326   while_op.replaceAllUsesWith(
327       new_while.getResults().take_front(while_op.getNumResults()));
328   while_op.erase();
329   return success();
330 }
331 
HandleIfRegionOp(TF::IfRegionOp if_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)332 LogicalResult HandleIfRegionOp(
333     TF::IfRegionOp if_op, ModuleOp module,
334     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
335     llvm::StringMap<PartitionedCallDecompositionInfo>*
336         decomposed_partitioned_call_callees) {
337   // Rewrite the branches.
338   Region& then_branch = if_op.then_branch();
339   Region& else_branch = if_op.else_branch();
340   if (failed(DecomposeTensorListOpsInternal(
341           &then_branch.front(), module, buffer_to_size,
342           decomposed_partitioned_call_callees)))
343     return failure();
344   if (failed(DecomposeTensorListOpsInternal(
345           &else_branch.front(), module, buffer_to_size,
346           decomposed_partitioned_call_callees)))
347     return failure();
348 
349   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
350       then_branch.front(), *buffer_to_size);
351   AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
352                                               *buffer_to_size);
353 
354   if (output_buffer_to_size.empty()) return success();
355 
356   // Recreate the op.
357   auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
358       if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
359       if_op.getOperand(), if_op->getAttrs());
360   for (const auto& entry : output_buffer_to_size) {
361     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
362         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
363   }
364 
365   new_op.then_branch().takeBody(if_op.then_branch());
366   new_op.else_branch().takeBody(if_op.else_branch());
367 
368   if_op.replaceAllUsesWith(
369       new_op.getResults().take_front(if_op.getNumResults()));
370   if_op.erase();
371   return success();
372 }
373 
HandleCaseRegionOp(TF::CaseRegionOp case_op,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)374 LogicalResult HandleCaseRegionOp(
375     TF::CaseRegionOp case_op, ModuleOp module,
376     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
377     llvm::StringMap<PartitionedCallDecompositionInfo>*
378         decomposed_partitioned_call_callees) {
379   // Rewrite the branches.
380   RegionRange branches = case_op.getRegions();
381 
382   for (Region* branch : branches) {
383     if (failed(DecomposeTensorListOpsInternal(
384             &branch->front(), module, buffer_to_size,
385             decomposed_partitioned_call_callees)))
386       return failure();
387   }
388 
389   // Get the output buffer index to size index mapping one of the branches. It
390   // should be same for all the branches so we only get it for the first branch.
391   Region* first_branch = branches.front();
392   auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
393       first_branch->front(), *buffer_to_size);
394   for (Region* branch : branches.drop_front()) {
395     AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
396                                                 *buffer_to_size);
397   }
398 
399   if (output_buffer_to_size.empty()) return success();
400 
401   // Recreate the op.
402   auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
403       case_op.getLoc(),
404       first_branch->front().getTerminator()->getOperandTypes(),
405       case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
406   for (const auto& entry : output_buffer_to_size) {
407     (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
408         new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
409   }
410 
411   for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
412     std::get<0>(pair)->takeBody(*std::get<1>(pair));
413   }
414   case_op.replaceAllUsesWith(
415       new_op.getResults().take_front(case_op.getNumResults()));
416   case_op.erase();
417   return success();
418 }
419 
420 template <typename CallOp>
HandlePartitionedCallOp(CallOp call,func::FuncOp callee,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)421 LogicalResult HandlePartitionedCallOp(
422     CallOp call, func::FuncOp callee, ModuleOp module,
423     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
424     llvm::StringMap<PartitionedCallDecompositionInfo>*
425         decomposed_partitioned_call_callees) {
426   auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
427       callee.getName(), PartitionedCallDecompositionInfo());
428   auto& info = emplace_res.first->second;
429   // Recreates the call op with info.
430   auto recreate_caller = [&] {
431     auto new_operands = llvm::to_vector<8>(call.getOperands());
432     for (int64_t i = 0; i < call.getNumOperands(); ++i) {
433       auto arg_it = info.buffer_arg_to_size_arg.find(i);
434       if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
435       auto it = buffer_to_size->find(call.getOperand(i));
436       if (it == buffer_to_size->end()) {
437         call.emitOpError("unknown tensor list.");
438         return failure();
439       }
440       assert(arg_it->second == new_operands.size());
441       new_operands.push_back(it->getSecond().size);
442     }
443     OpBuilder builder(call);
444     auto new_call = builder.create<CallOp>(
445         call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
446         new_operands, call->getAttrs());
447     new_call->setAttr(
448         "f", SymbolRefAttr::get(
449                  builder.getContext(),
450                  const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
451     for (const auto& entry : info.buffer_ret_to_size_ret) {
452       (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
453           new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
454     }
455     call.replaceAllUsesWith(
456         new_call.getResults().take_front(call.getNumResults()));
457     call.erase();
458     return success();
459   };
460   if (!emplace_res.second) {
461     // This callee was handled before.
462     if (!info.signature_change) return success();
463     return recreate_caller();
464   }
465   // Rewrite the callee.
466   llvm::SmallDenseMap<Value, SizeInfo> callee_map;
467   func::FuncOp lowered_callee = callee;
468   if (!callee.isPrivate()) {
469     // Clone non-private callee in case of signature change.
470     lowered_callee = callee.clone();
471     lowered_callee.setPrivate();
472   }
473   auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
474     auto it = buffer_to_size->find(call.getOperand(index));
475     if (it == buffer_to_size->end()) return llvm::None;
476     return it->getFirst().getType();
477   };
478   auto arg_buffer_size_is_fixed = [&](int64_t index) {
479     return (*buffer_to_size)[call.getOperand(index)].fixed;
480   };
481   ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
482                           &callee_map, find_arg_buffer_type,
483                           arg_buffer_size_is_fixed);
484   const bool args_no_changed = callee_map.empty();
485   if (failed(DecomposeTensorListOpsInternal(
486           &lowered_callee.front(), module, &callee_map,
487           decomposed_partitioned_call_callees))) {
488     return failure();
489   }
490   info.buffer_ret_to_size_ret =
491       ModifyFunctionReturn(lowered_callee, callee_map);
492   info.decomposed_callee = lowered_callee;
493   if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
494     // Signature is not modified. We do not need to keep two copies.
495     info.signature_change = false;
496     if (lowered_callee != callee) {
497       lowered_callee.setName(
498           StringAttr::get(callee->getContext(), callee.getName()));
499       callee.erase();
500       SymbolTable(module).insert(lowered_callee);
501     }
502   } else {
503     info.signature_change = true;
504     for (auto& entry : callee_map) {
505       auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
506       if (!buffer_arg) continue;
507       info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
508           entry.getSecond().size.cast<BlockArgument>().getArgNumber();
509     }
510     if (lowered_callee != callee) {
511       // Add the clone with a new name.
512       lowered_callee.setName(StringAttr::get(
513           callee->getContext(),
514           llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str()));
515       SymbolTable(module).insert(lowered_callee);
516       callee = lowered_callee;
517     }
518   }
519   if (info.signature_change) return recreate_caller();
520   return success();
521 }
522 
523 // Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise,
524 // returns an error.
GetConstShapeValue(Value shape_value,llvm::SmallVector<int64_t,8> * shape)525 LogicalResult GetConstShapeValue(Value shape_value,
526                                  llvm::SmallVector<int64_t, 8>* shape) {
527   auto shape_op = shape_value.getDefiningOp();
528   if (!shape_op) return failure();
529   auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
530   if (!shape_const_op) return failure();
531   for (const auto& v : shape_const_op.value().getValues<APInt>()) {
532     int64_t dim_size = v.getSExtValue();
533     if (dim_size == ShapedType::kDynamicSize) return failure();
534     shape->push_back(dim_size);
535   }
536   return success();
537 }
538 
539 // Checks the result Variant type to infer the element shape if fully defined.
540 // If the Variant type has multiple subtypes or does not have static shape,
541 // return error.
GetElementShapeFromResultType(Type type,llvm::SmallVector<int64_t,8> * shape)542 LogicalResult GetElementShapeFromResultType(
543     Type type, llvm::SmallVector<int64_t, 8>* shape) {
544   auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
545   if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
546   TensorType tensor_type = variant_type.getSubtypes().front();
547   if (!tensor_type.hasStaticShape()) return failure();
548   for (auto d : tensor_type.getShape()) shape->push_back(d);
549   return success();
550 }
551 
HandleEmptyTensorListOp(TF::EmptyTensorListOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)552 LogicalResult HandleEmptyTensorListOp(
553     TF::EmptyTensorListOp list,
554     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
555   Value buffer;
556   OpBuilder builder(list);
557   llvm::SmallVector<int64_t, 8> element_shape;
558   // Infer TensorList element shape from the return type first, and then from
559   // the const element shape operand. We first check the return type because
560   // shape inference might have successfully inferred the element shape from
561   // write operations on the TensorList.
562   if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
563     if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
564       return list.emitOpError("unknown tensor list element shape");
565     }
566   }
567   if (failed(cutil::CreateInitBufferValue(
568           element_shape, list.max_num_elements(), list, list.element_dtype(),
569           builder, &buffer))) {
570     return failure();
571   }
572   Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
573   list.handle().replaceAllUsesWith(buffer);
574   (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
575   list.erase();
576   return success();
577 }
578 
HandleTensorListReserveOp(TF::TensorListReserveOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)579 LogicalResult HandleTensorListReserveOp(
580     TF::TensorListReserveOp list,
581     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
582   Value buffer;
583   OpBuilder builder(list);
584   llvm::SmallVector<int64_t, 8> element_shape;
585   // Infer TensorList element shape from the return type first, and then from
586   // the const element shape operand. We first check the return type because
587   // shape inference might have successfully inferred the element shape from
588   // write operations on the TensorList.
589   if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
590     if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
591       return list.emitOpError("unknown tensor list element shape");
592     }
593   }
594   if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
595                                           list, list.element_dtype(), builder,
596                                           &buffer))) {
597     return failure();
598   }
599   Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
600                                               list.getLoc());
601   (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
602   list.handle().replaceAllUsesWith(buffer);
603   list.erase();
604   return success();
605 }
606 
HandleTensorListFromTensorOp(TF::TensorListFromTensorOp list,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)607 LogicalResult HandleTensorListFromTensorOp(
608     TF::TensorListFromTensorOp list,
609     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
610   OpBuilder builder(list);
611   Value buffer = builder.create<TF::IdentityOp>(
612       list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
613       ArrayRef<Value>{list.tensor()});
614   auto type = buffer.getType().cast<TensorType>();
615   if (!type.hasStaticShape()) {
616     return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
617   }
618   Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
619   (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
620   list.output_handle().replaceAllUsesWith(buffer);
621   list.erase();
622   return success();
623 }
624 
HandleTensorListPushBackOp(TF::TensorListPushBackOp push,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)625 LogicalResult HandleTensorListPushBackOp(
626     TF::TensorListPushBackOp push,
627     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
628   auto buffer = push.input_handle();
629   auto it = buffer_to_size->find(buffer);
630   if (it == buffer_to_size->end()) {
631     return push.emitOpError(
632         "found tf.TensorListPushBack on unknown TensorList.");
633   }
634   if (it->getSecond().fixed) {
635     return push.emitError("cannot push on a fixed-size tensor list");
636   }
637   auto size = it->getSecond().size;
638   OpBuilder builder(push);
639   auto new_buffer =
640       cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
641   auto new_size = builder.create<TF::AddV2Op>(
642       push.getLoc(), ArrayRef<Type>{size.getType()},
643       ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
644   push.output_handle().replaceAllUsesWith(new_buffer);
645   (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
646   push.erase();
647   return success();
648 }
649 
HandleTensorListPopBackOp(TF::TensorListPopBackOp pop,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)650 LogicalResult HandleTensorListPopBackOp(
651     TF::TensorListPopBackOp pop,
652     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
653   auto buffer = pop.input_handle();
654   auto it = buffer_to_size->find(buffer);
655   if (it == buffer_to_size->end()) {
656     pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
657     return failure();
658   }
659   if (it->getSecond().fixed) {
660     return pop.emitError("cannot pop on a fixed-size tensor list");
661   }
662   auto size = it->getSecond().size;
663   OpBuilder builder(pop);
664   auto new_buffer = builder.create<TF::IdentityOp>(
665       pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
666   auto new_size = builder.create<TF::SubOp>(
667       pop.getLoc(), ArrayRef<Type>{size.getType()},
668       ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
669   auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
670   pop.output_handle().replaceAllUsesWith(new_buffer);
671   pop.tensor().replaceAllUsesWith(element);
672   pop.erase();
673   (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
674   return success();
675 }
676 
HandleTensorListGetItemOp(TF::TensorListGetItemOp get_item,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)677 LogicalResult HandleTensorListGetItemOp(
678     TF::TensorListGetItemOp get_item,
679     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
680   auto buffer = get_item.input_handle();
681   auto it = buffer_to_size.find(buffer);
682   if (it == buffer_to_size.end()) {
683     get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList.");
684     return failure();
685   }
686   OpBuilder builder(get_item);
687   auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
688                                               get_item.getLoc());
689   auto element =
690       cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc());
691   get_item.item().replaceAllUsesWith(element);
692   get_item.erase();
693   return success();
694 }
695 
HandleTensorListSetItemOp(TF::TensorListSetItemOp set_item,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)696 LogicalResult HandleTensorListSetItemOp(
697     TF::TensorListSetItemOp set_item,
698     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
699   auto buffer = set_item.input_handle();
700   auto it = buffer_to_size->find(buffer);
701   if (it == buffer_to_size->end()) {
702     set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList.");
703     return failure();
704   }
705   OpBuilder builder(set_item);
706   auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
707                                               set_item.getLoc());
708   auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
709                                       set_item.getLoc());
710   set_item.output_handle().replaceAllUsesWith(new_buffer);
711   auto size = it->getSecond();
712   (*buffer_to_size)[new_buffer] = size;
713   set_item.erase();
714   return success();
715 }
716 
HandleTensorListLengthOp(TF::TensorListLengthOp length,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)717 LogicalResult HandleTensorListLengthOp(
718     TF::TensorListLengthOp length,
719     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
720   auto it = buffer_to_size.find(length.input_handle());
721   if (it == buffer_to_size.end()) {
722     length.emitOpError("found tf.TensorListLength on unknown TensorList.");
723     return failure();
724   }
725   OpBuilder builder(length);
726   if (it->getSecond().fixed) {
727     auto dim = cutil::CreateScalarConst(
728         length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
729         builder, length.getLoc());
730     length.length().replaceAllUsesWith(dim);
731   } else {
732     auto current_size = it->getSecond().size;
733     // Reshapes the R1 length to a scalar.
734     auto reshape = builder.create<TF::ReshapeOp>(
735         length.getLoc(),
736         ArrayRef<Type>{RankedTensorType::get(
737             {}, getElementTypeOrSelf(current_size.getType()))},
738         ArrayRef<Value>{current_size,
739                         cutil::GetR1Const({}, builder, length.getLoc())});
740     length.length().replaceAllUsesWith(reshape);
741   }
742   length.erase();
743   return success();
744 }
745 
HandleTensorListElementShapeOp(TF::TensorListElementShapeOp elem_shape,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)746 LogicalResult HandleTensorListElementShapeOp(
747     TF::TensorListElementShapeOp elem_shape,
748     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
749   if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
750     return elem_shape.emitOpError("unknown tensor list");
751   }
752   auto buffer = elem_shape.input_handle();
753   auto result = cutil::GetR1Const(
754       buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
755       OpBuilder(elem_shape), elem_shape.getLoc(),
756       elem_shape.shape_type().getIntOrFloatBitWidth());
757   elem_shape.element_shape().replaceAllUsesWith(result);
758   elem_shape.erase();
759   return success();
760 }
761 
HandleTensorListGatherOp(TF::TensorListGatherOp gather,const llvm::SmallDenseMap<Value,SizeInfo> & buffer_to_size)762 LogicalResult HandleTensorListGatherOp(
763     TF::TensorListGatherOp gather,
764     const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
765   auto it = buffer_to_size.find(gather.input_handle());
766   if (it == buffer_to_size.end()) {
767     return gather.emitOpError("unknown tensor list");
768   }
769   auto buffer = gather.input_handle();
770   auto result = cutil::GatherElements(gather.indices(), buffer,
771                                       OpBuilder(gather), gather.getLoc());
772   gather.values().replaceAllUsesWith(result);
773   gather.erase();
774   return success();
775 }
776 
HandleTensorListScatterIntoExistingListOp(TF::TensorListScatterIntoExistingListOp scatter,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size)777 LogicalResult HandleTensorListScatterIntoExistingListOp(
778     TF::TensorListScatterIntoExistingListOp scatter,
779     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
780   auto it = buffer_to_size->find(scatter.input_handle());
781   if (it == buffer_to_size->end()) {
782     return scatter.emitOpError("unknown tensor list");
783   }
784   auto buffer = scatter.input_handle();
785   OpBuilder builder(scatter);
786   auto indices_type = scatter.indices().getType().cast<RankedTensorType>();
787   if (!indices_type) return scatter.emitOpError("unranked indices shape");
788   auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
789   auto shape = builder.create<TF::ConstOp>(
790       scatter.getLoc(),
791       DenseElementsAttr::get(
792           shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
793   auto indices =
794       builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
795   Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
796       scatter.getLoc(), buffer, indices, scatter.tensor());
797   scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
798   scatter.erase();
799   auto size = it->getSecond();
800   (*buffer_to_size)[tensor_scatter_update] = size;
801   return success();
802 }
803 
DecomposeTensorListOpsInternal(Block * block,ModuleOp module,llvm::SmallDenseMap<Value,SizeInfo> * buffer_to_size,llvm::StringMap<PartitionedCallDecompositionInfo> * decomposed_partitioned_call_callees)804 LogicalResult DecomposeTensorListOpsInternal(
805     Block* block, ModuleOp module,
806     llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
807     llvm::StringMap<PartitionedCallDecompositionInfo>*
808         decomposed_partitioned_call_callees) {
809   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
810     // TODO(yuanzx): Add a pass to remove identities in device computation.
811     if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
812       op.replaceAllUsesWith(op.getOperands());
813       op.erase();
814     } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
815       if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
816         return failure();
817       }
818     } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
819       if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
820         return failure();
821       }
822     } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
823       if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
824         return failure();
825       }
826     } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
827       if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
828         return failure();
829       }
830     } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
831       if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
832         return failure();
833       }
834     } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
835       if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
836         return failure();
837       }
838     } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
839       if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
840         return failure();
841       }
842     } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
843       if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
844         return failure();
845       }
846     } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
847       stack.tensor().replaceAllUsesWith(stack.input_handle());
848       stack.erase();
849     } else if (auto elem_shape =
850                    llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
851       if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
852         return failure();
853       }
854     } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
855       if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
856         return failure();
857       }
858     } else if (auto scatter =
859                    llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
860                        &op)) {
861       if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
862                                                            buffer_to_size))) {
863         return failure();
864       }
865     } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
866       auto it = buffer_to_size->find(addn.getOperand(0));
867       if (it != buffer_to_size->end()) {
868         addn.sum().setType(addn.getOperand(0).getType());
869         auto size = it->getSecond();
870         (*buffer_to_size)[addn.sum()] = size;
871       }
872     } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
873       if (buffer_to_size->count(zeros.x()) > 0) {
874         zeros.y().setType(zeros.x().getType());
875         auto size = (*buffer_to_size)[zeros.x()];
876         (*buffer_to_size)[zeros.y()] = size;
877       }
878     } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
879       if (failed(HandleWhileOp(while_op, module, buffer_to_size,
880                                decomposed_partitioned_call_callees))) {
881         return failure();
882       }
883     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
884       if (failed(HandleCaseOrIfOp(
885               if_op, {if_op.then_function(), if_op.else_function()}, module,
886               buffer_to_size, decomposed_partitioned_call_callees))) {
887         return failure();
888       }
889     } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
890       SmallVector<func::FuncOp, 2> branches;
891       case_op.get_branch_functions(branches);
892       if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
893                                   decomposed_partitioned_call_callees))) {
894         return failure();
895       }
896     } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
897       if (!pcall.func())
898         return pcall.emitOpError(
899             "TensorList decomposition does not support call with nested "
900             "references.");
901 
902       if (failed(HandlePartitionedCallOp(
903               pcall, pcall.func(), module, buffer_to_size,
904               decomposed_partitioned_call_callees))) {
905         return failure();
906       }
907     } else if (auto spcall =
908                    llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
909       if (failed(HandlePartitionedCallOp(
910               spcall, spcall.func(), module, buffer_to_size,
911               decomposed_partitioned_call_callees))) {
912         return failure();
913       }
914     } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
915       if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
916                                      decomposed_partitioned_call_callees))) {
917         return failure();
918       }
919     } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
920       if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
921                                   decomposed_partitioned_call_callees))) {
922         return failure();
923       }
924     } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
925       if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
926                                     decomposed_partitioned_call_callees))) {
927         return failure();
928       }
929     }
930   }
931   return success();
932 }
933 
DecomposeTensorListOps(Block * block,ModuleOp module)934 LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
935   llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
936   llvm::StringMap<PartitionedCallDecompositionInfo>
937       decomposed_partitioned_call_callees;
938   return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
939                                         &decomposed_partitioned_call_callees);
940 }
941 
runOnOperation()942 void TensorListOpsDecompositionPass::runOnOperation() {
943   auto module = getOperation();
944   auto main = module.lookupSymbol<func::FuncOp>("main");
945   if (!main) return;
946   if (failed(DecomposeTensorListOps(&main.front(), module))) {
947     signalPassFailure();
948   }
949 }
950 
951 }  // namespace
952 
953 namespace TF {
954 std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass()955 CreateTensorListOpsDecompositionPass() {
956   return std::make_unique<TensorListOpsDecompositionPass>();
957 }
958 }  // namespace TF
959 }  // namespace mlir
960