/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

// A pass that rewrites tensor list operations to tensor operations on buffers
// and size values.
//
// This pass requires that the full shape of the tensor list can be inferred:
// 1) the maximum size must be a constant and 2) the element shape must be
// constant.
//
// A tensor list creation op, "tf.EmptyTensorList" or "tf.TensorListReserve",
// is turned into a zero-initialized buffer, and the size is initialized to 0
// for "tf.EmptyTensorList" or to the specified size for
// "tf.TensorListReserve". Each push is turned into a
// "tf.XlaDynamicUpdateSlice" with the incremented size, and each pop is turned
// into a "tf.Slice" and a copy of the buffer with the decremented size. Each
// SetItem is turned into a "tf.XlaDynamicUpdateSlice" with unchanged size, and
// each GetItem is turned into a "tf.Slice".
//
// The pass also works across control flow and functional calls.
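//
// For illustration, a schematic sketch of the rewrite (op operands and types
// abbreviated, with hypothetical shapes):
//
//   %tl = "tf.TensorListReserve"(%element_shape, %num_elements)
//       : (...) -> tensor<!tf.variant<tensor<8xf32>>>
//   %tl2 = "tf.TensorListSetItem"(%tl, %index, %item) : ...
//
// becomes, conceptually, a zero-initialized buffer paired with a size value:
//
//   %buffer = <zero-initialized tensor<10x8xf32>>
//   %size = <tensor<1xi32> holding the value 10>
//   %buffer2 = "tf.XlaDynamicUpdateSlice"(%buffer, %item, <index>) : ...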
struct TensorListOpsDecompositionPass
    : public PassWrapper<TensorListOpsDecompositionPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Holds the size value of a tensor list and whether the size is statically
// known (fixed).
struct SizeInfo {
  Value size;
  bool fixed;
};

// Modifies a function's signature to rewrite tensor list arguments to buffers
// and sizes.
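// For example (with hypothetical types), an argument of type
// tensor<!tf.variant<tensor<8xf32>>> is rewritten in place to a buffer
// argument of type tensor<10x8xf32>, and a new size argument of type
// `size_type` is appended at the end of the argument list.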
void ModifyFunctionSignature(
    FuncOp func, Type size_type,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
    llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto buffer_type = arg_to_buffer_type(i);
    if (!buffer_type.hasValue()) continue;
    func.getArgument(i).setType(*buffer_type);
    new_input_types[i] = *buffer_type;
    auto size_arg = func.front().addArgument(size_type);
    new_input_types.push_back(size_arg.getType());
    if (buffer_to_size) {
      (*buffer_to_size)[func.getArgument(i)] = {size_arg,
                                                arg_buffer_size_is_fixed(i)};
    }
  }
  UpdateFuncType(func);
}

// Holds information about a decomposed callee function for
// PartitionedCall/StatefulPartitionedCall.
struct PartitionedCallDecompositionInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
  // Each element is a tuple of (buffer_return_index, size_return_index,
  // fixed_size).
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      buffer_ret_to_size_ret;
};

LogicalResult DecomposeTensorListOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
    llvm::StringMap<PartitionedCallDecompositionInfo>*);

// Adds the corresponding sizes of tensor list buffers in block's terminator
// to the list of return values. Returns the mapping from the buffer
// indices to the added size indices, which is a list of tuples
// (buffer_return_index, size_return_index, fixed_size).
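// For example, if the terminator originally has 3 operands and its operand 2
// is a decomposed buffer, the buffer's size is appended as operand 3 and the
// tuple (2, 3, fixed) is recorded.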
template <class TerminatorOp>
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(
    Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto old_terminator = block.getTerminator();
  auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      output_buffer_to_size;
  for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
    auto it = buffer_to_size.find(retval.value());
    if (it == buffer_to_size.end()) continue;
    output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
                                       it->getSecond().fixed);
    new_outputs.push_back(it->getSecond().size);
  }
  OpBuilder(old_terminator)
      .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
  old_terminator->erase();
  return output_buffer_to_size;
}

// Adds the corresponding sizes of tensor list buffers in func's return values
// to the list of return values. Returns the mapping from the buffer indices to
// the added size indices, which is a list of tuples (buffer_return_index,
// size_return_index, fixed_size).
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
    FuncOp func, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto output_buffer_to_size =
      AddTensorListSizesToTerminator<ReturnOp>(func.front(), buffer_to_size);
  UpdateFuncType(func);
  return output_buffer_to_size;
}

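// Decomposes tensor list ops in the cond and body functions of a tf.While op
// and, if any tensor list operands were decomposed, recreates the while op
// with the corresponding size values appended as extra operands and results.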
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite body.
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, SizeInfo> body_map;
  auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(while_op.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[while_op.getOperand(index)].fixed;
  };
  OpBuilder builder(while_op);
  ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &body.front(), module, &body_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);

  // Rewrite cond.
  auto cond = while_op.cond_function();
  llvm::SmallDenseMap<Value, SizeInfo> cond_map;
  ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &cond.front(), module, &cond_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (output_buffer_to_size.empty()) {
    return success();
  }
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getType().getInputs(), new_while_operands,
      while_op.getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

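// Decomposes tensor list ops in the branch functions of a tf.Case or tf.If
// op, and recreates the op with size values appended as extra operands and
// results when any branch signature changed.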
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(
    CaseOrIfOp op, ArrayRef<FuncOp> branches, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
  branch_maps.resize(branches.size());

  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(op.getOperand(index + 1));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
  };
  OpBuilder builder(op);
  for (const auto& pair : llvm::zip(branches, branch_maps)) {
    FuncOp branch = std::get<0>(pair);
    llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
    ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
                            find_arg_buffer_type, arg_buffer_size_is_fixed);

    if (failed(DecomposeTensorListOpsInternal(
            &branch.front(), module, &branch_map,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  const bool arg_no_changed = branch_maps.front().empty();
  auto output_buffer_to_size =
      ModifyFunctionReturn(branches.front(), branch_maps.front());
  for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
    ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));

  if (output_buffer_to_size.empty() && arg_no_changed) return success();

  // Recreate the op.
  auto new_operands = llvm::to_vector<8>(op.getOperands());
  for (int64_t i = 1; i < op.getNumOperands(); ++i) {
    auto it = buffer_to_size->find(op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_operands.push_back(it->getSecond().size);
  }
  FuncOp first_branch = branches.front();
  auto new_op = OpBuilder(op).create<CaseOrIfOp>(
      op.getLoc(), first_branch.getType().getResults(), new_operands,
      op.getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
  op.erase();
  return success();
}

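// Decomposes tensor list ops in the cond and body regions of a
// tf.WhileRegion op, and recreates the op with size values added as extra
// block arguments, operands, and results.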
LogicalResult HandleWhileRegionOp(
    TF::WhileRegionOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  OpBuilder builder(while_op);
  auto modify_region_arguments = [&](Region& region) {
    int64_t original_arg_count = region.getNumArguments();
    for (int64_t i = 0; i < original_arg_count; ++i) {
      auto operand = while_op.getOperand(i);
      auto it = buffer_to_size->find(operand);
      if (it == buffer_to_size->end()) continue;
      auto buffer_type = it->getFirst().getType();
      region.getArgument(i).setType(buffer_type);
      auto size_arg = region.addArgument(cutil::GetSizeType(builder));
      (*buffer_to_size)[region.getArgument(i)] = {size_arg,
                                                  it->getSecond().fixed};
    }
  };

  // Rewrite body.
  Region& body_region = while_op.body();
  modify_region_arguments(body_region);
  if (failed(DecomposeTensorListOpsInternal(
          &body_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      body_region.front(), *buffer_to_size);

  // Rewrite cond.
  Region& cond_region = while_op.cond();
  modify_region_arguments(cond_region);
  if (failed(DecomposeTensorListOpsInternal(
          &cond_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }

  if (output_buffer_to_size.empty()) return success();

  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileRegionOp>(
      while_op.getLoc(),
      body_region.front().getTerminator()->getOperandTypes(),
      new_while_operands, while_op.getAttrs());
  new_while.body().takeBody(body_region);
  new_while.cond().takeBody(cond_region);
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

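// Decomposes tensor list ops in the branch regions of a tf.IfRegion op, and
// recreates the op with size values appended to the yielded values.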
LogicalResult HandleIfRegionOp(
    TF::IfRegionOp if_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  Region& then_branch = if_op.then_branch();
  Region& else_branch = if_op.else_branch();
  if (failed(DecomposeTensorListOpsInternal(
          &then_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();
  if (failed(DecomposeTensorListOpsInternal(
          &else_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();

  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      then_branch.front(), *buffer_to_size);
  AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
                                              *buffer_to_size);

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
      if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
      if_op.getOperand(), if_op.getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  new_op.then_branch().takeBody(if_op.then_branch());
  new_op.else_branch().takeBody(if_op.else_branch());

  if_op.replaceAllUsesWith(
      new_op.getResults().take_front(if_op.getNumResults()));
  if_op.erase();
  return success();
}

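// Decomposes tensor list ops in the branch regions of a tf.CaseRegion op, and
// recreates the op with size values appended to the yielded values.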
LogicalResult HandleCaseRegionOp(
    TF::CaseRegionOp case_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  RegionRange branches = case_op.getRegions();

  for (Region* branch : branches) {
    if (failed(DecomposeTensorListOpsInternal(
            &branch->front(), module, buffer_to_size,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  // Get the mapping from output buffer indices to size indices from one of
  // the branches. It should be the same for all branches, so we only compute
  // it for the first branch.
  Region* first_branch = branches.front();
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      first_branch->front(), *buffer_to_size);
  for (Region* branch : branches.drop_front()) {
    AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
                                                *buffer_to_size);
  }

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
      case_op.getLoc(),
      first_branch->front().getTerminator()->getOperandTypes(),
      case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
    std::get<0>(pair)->takeBody(*std::get<1>(pair));
  }
  case_op.replaceAllUsesWith(
      new_op.getResults().take_front(case_op.getNumResults()));
  case_op.erase();
  return success();
}

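// Decomposes tensor list ops in the callee of a
// PartitionedCall/StatefulPartitionedCall op. Each callee is rewritten at
// most once (results are cached in decomposed_partitioned_call_callees);
// non-private callees are cloned before rewriting, and the call op is
// recreated if the callee signature changed.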
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallDecompositionInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.buffer_arg_to_size_arg.find(i);
      if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
      auto it = buffer_to_size->find(call.getOperand(i));
      if (it == buffer_to_size->end()) {
        call.emitOpError("unknown tensor list.");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond().size);
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call.getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.buffer_ret_to_size_ret) {
      (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
          new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
    }
    call.replaceAllUsesWith(
        new_call.getResults().take_front(call.getNumResults()));
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  llvm::SmallDenseMap<Value, SizeInfo> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(call.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[call.getOperand(index)].fixed;
  };
  ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
                          &callee_map, find_arg_buffer_type,
                          arg_buffer_size_is_fixed);
  const bool args_no_changed = callee_map.empty();
  if (failed(DecomposeTensorListOpsInternal(
          &lowered_callee.front(), module, &callee_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  info.buffer_ret_to_size_ret =
      ModifyFunctionReturn(lowered_callee, callee_map);
  info.decomposed_callee = lowered_callee;
  if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
    // Signature is not modified. We do not need to keep two copies.
    info.signature_change = false;
    if (lowered_callee != callee) {
      lowered_callee.setName(callee.getName());
      callee.erase();
      SymbolTable(module).insert(lowered_callee);
    }
  } else {
    info.signature_change = true;
    for (auto& entry : callee_map) {
      auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
      if (!buffer_arg) continue;
      info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
          entry.getSecond().size.cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

// Parses an R1 constant value into `shape` if it is the output of a
// TF::ConstOp with fully static dimensions. Otherwise, returns failure.
LogicalResult GetConstShapeValue(Value shape_value,
                                 llvm::SmallVector<int64_t, 8>* shape) {
  auto shape_op = shape_value.getDefiningOp();
  if (!shape_op) return failure();
  auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
  if (!shape_const_op) return failure();
  for (const auto& v : shape_const_op.value().getValues<APInt>()) {
    int64_t dim_size = v.getSExtValue();
    if (dim_size == ShapedType::kDynamicSize) return failure();
    shape->push_back(dim_size);
  }
  return success();
}

// Checks the result Variant type to infer the element shape if it is fully
// defined. If the Variant type has multiple subtypes or does not have a
// static shape, returns an error.
LogicalResult GetElementShapeFromResultType(
    Type type, llvm::SmallVector<int64_t, 8>* shape) {
  auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
  TensorType tensor_type = variant_type.getSubtypes().front();
  if (!tensor_type.hasStaticShape()) return failure();
  for (auto d : tensor_type.getShape()) shape->push_back(d);
  return success();
}

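// Rewrites a tf.EmptyTensorList op into a zero-initialized buffer with a
// size of 0.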
LogicalResult HandleEmptyTensorListOp(
    TF::EmptyTensorListOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(
          element_shape, list.max_num_elements(), list, list.element_dtype(),
          builder, &buffer))) {
    return failure();
  }
  Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
  list.handle().replaceAllUsesWith(buffer);
  (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
  list.erase();
  return success();
}

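// Rewrites a tf.TensorListReserve op into a zero-initialized buffer whose
// size is fixed to num_elements.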
LogicalResult HandleTensorListReserveOp(
    TF::TensorListReserveOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
                                          list, list.element_dtype(), builder,
                                          &buffer))) {
    return failure();
  }
  Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
                                              list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

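// Rewrites a tf.TensorListFromTensor op into an identity of the input tensor;
// the size is fixed to the tensor's leading dimension, which must be static.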
LogicalResult HandleTensorListFromTensorOp(
    TF::TensorListFromTensorOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  OpBuilder builder(list);
  Value buffer = builder.create<TF::IdentityOp>(
      list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
      ArrayRef<Value>{list.tensor()});
  auto type = buffer.getType().cast<TensorType>();
  if (!type.hasStaticShape()) {
    return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
  }
  Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.output_handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

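// Rewrites a tf.TensorListPushBack op into a write of the element at the
// current size, and increments the size by 1. Only valid on non-fixed-size
// lists.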
LogicalResult HandleTensorListPushBackOp(
    TF::TensorListPushBackOp push,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = push.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    return push.emitOpError(
        "found tf.TensorListPushBack on unknown TensorList.");
  }
  if (it->getSecond().fixed) {
    return push.emitError("cannot push on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(push);
  auto new_buffer =
      cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
  auto new_size = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
  push.output_handle().replaceAllUsesWith(new_buffer);
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  push.erase();
  return success();
}

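// Rewrites a tf.TensorListPopBack op into a read of the last element and a
// copy of the buffer with the size decremented by 1. Only valid on
// non-fixed-size lists.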
LogicalResult HandleTensorListPopBackOp(
    TF::TensorListPopBackOp pop,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = pop.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
    return failure();
  }
  if (it->getSecond().fixed) {
    return pop.emitError("cannot pop on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(pop);
  auto new_buffer = builder.create<TF::IdentityOp>(
      pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
  auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
  pop.output_handle().replaceAllUsesWith(new_buffer);
  pop.tensor().replaceAllUsesWith(element);
  pop.erase();
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  return success();
}

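// Rewrites a tf.TensorListGetItem op into a read (tf.Slice) of the buffer at
// the given index.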
LogicalResult HandleTensorListGetItemOp(
    TF::TensorListGetItemOp get_item,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto buffer = get_item.input_handle();
  auto it = buffer_to_size.find(buffer);
  if (it == buffer_to_size.end()) {
    get_item.emitOpError(
        "found tf.TensorListGetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(get_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
                                              get_item.getLoc());
  auto element = cutil::GetElement(index, buffer, builder, get_item.getLoc());
  get_item.item().replaceAllUsesWith(element);
  get_item.erase();
  return success();
}

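// Rewrites a tf.TensorListSetItem op into a write (tf.XlaDynamicUpdateSlice)
// of the buffer at the given index; the size is unchanged.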
LogicalResult HandleTensorListSetItemOp(
    TF::TensorListSetItemOp set_item,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = set_item.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    set_item.emitOpError(
        "found tf.TensorListSetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(set_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
                                              set_item.getLoc());
  auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
                                      set_item.getLoc());
  set_item.output_handle().replaceAllUsesWith(new_buffer);
  auto size = it->getSecond();
  (*buffer_to_size)[new_buffer] = size;
  set_item.erase();
  return success();
}

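// Rewrites a tf.TensorListLength op to the list's size: the buffer's static
// leading dimension for fixed-size lists, or the tracked size value reshaped
// to a scalar otherwise.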
LogicalResult HandleTensorListLengthOp(
    TF::TensorListLengthOp length,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(length.input_handle());
  if (it == buffer_to_size.end()) {
    length.emitOpError("found tf.TensorListLength on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(length);
  if (it->getSecond().fixed) {
    auto dim = cutil::CreateScalarConst(
        length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
        builder, length.getLoc());
    length.length().replaceAllUsesWith(dim);
  } else {
    auto current_size = it->getSecond().size;
    // Reshapes the R1 length to a scalar.
    auto reshape = builder.create<TF::ReshapeOp>(
        length.getLoc(),
        ArrayRef<Type>{RankedTensorType::get(
            {}, getElementTypeOrSelf(current_size.getType()))},
        ArrayRef<Value>{current_size,
                        cutil::GetR1Const({}, builder, length.getLoc())});
    length.length().replaceAllUsesWith(reshape);
  }
  length.erase();
  return success();
}

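// Rewrites a tf.TensorListElementShape op to a constant holding the buffer's
// shape without the leading (size) dimension.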
LogicalResult HandleTensorListElementShapeOp(
    TF::TensorListElementShapeOp elem_shape,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
    return elem_shape.emitOpError("unknown tensor list");
  }
  auto buffer = elem_shape.input_handle();
  auto result = cutil::GetR1Const(
      buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
      OpBuilder(elem_shape), elem_shape.getLoc(),
      elem_shape.shape_type().getIntOrFloatBitWidth());
  elem_shape.element_shape().replaceAllUsesWith(result);
  elem_shape.erase();
  return success();
}

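// Rewrites a tf.TensorListGather op into a gather of the requested indices
// from the buffer.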
LogicalResult HandleTensorListGatherOp(
    TF::TensorListGatherOp gather,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(gather.input_handle());
  if (it == buffer_to_size.end()) {
    return gather.emitOpError("unknown tensor list");
  }
  auto buffer = gather.input_handle();
  auto result = cutil::GatherElements(gather.indices(), buffer,
                                      OpBuilder(gather), gather.getLoc());
  gather.values().replaceAllUsesWith(result);
  gather.erase();
  return success();
}

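// Rewrites a tf.TensorListScatterIntoExistingList op into a
// tf.TensorScatterUpdate on the buffer; the size is unchanged.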
LogicalResult HandleTensorListScatterIntoExistingListOp(
    TF::TensorListScatterIntoExistingListOp scatter,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto it = buffer_to_size->find(scatter.input_handle());
  if (it == buffer_to_size->end()) {
    return scatter.emitOpError("unknown tensor list");
  }
  auto buffer = scatter.input_handle();
  OpBuilder builder(scatter);
  // Use dyn_cast so that unranked indices are reported as an error rather
  // than asserting in cast<>.
  auto indices_type = scatter.indices().getType().dyn_cast<RankedTensorType>();
  if (!indices_type) return scatter.emitOpError("unranked indices shape");
  auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
  auto shape = builder.create<TF::ConstOp>(
      scatter.getLoc(),
      DenseElementsAttr::get(
          shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
  auto indices =
      builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
  Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
      scatter.getLoc(), buffer, indices, scatter.tensor());
  scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
  scatter.erase();
  auto size = it->getSecond();
  (*buffer_to_size)[tensor_scatter_update] = size;
  return success();
}

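// Decomposes all tensor list ops in `block`, recursing into control flow
// regions and the callees of (stateful) partitioned calls.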
LogicalResult DecomposeTensorListOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    // TODO(yuanzx): Add a pass to remove identities in device computation.
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
      if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
      if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
      if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
      if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
      if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
        return failure();
      }
    } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
      if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
        return failure();
      }
    } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
      if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
        return failure();
      }
    } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
      if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
        return failure();
      }
    } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
      stack.tensor().replaceAllUsesWith(stack.input_handle());
      stack.erase();
    } else if (auto elem_shape =
                   llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
      if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
        return failure();
      }
    } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
      if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
        return failure();
      }
    } else if (auto scatter =
                   llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
                       &op)) {
      if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
                                                           buffer_to_size))) {
        return failure();
      }
    } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
      auto it = buffer_to_size->find(addn.getOperand(0));
      if (it != buffer_to_size->end()) {
        addn.sum().setType(addn.getOperand(0).getType());
        auto size = it->getSecond();
        (*buffer_to_size)[addn.sum()] = size;
      }
    } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
      if (buffer_to_size->count(zeros.x()) > 0) {
        zeros.y().setType(zeros.x().getType());
        auto size = (*buffer_to_size)[zeros.x()];
        (*buffer_to_size)[zeros.y()] = size;
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, buffer_to_size,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleCaseOrIfOp(
              if_op, {if_op.then_function(), if_op.else_function()}, module,
              buffer_to_size, decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
      SmallVector<FuncOp, 2> branches;
      case_op.get_branch_functions(branches);
      if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func())
        return pcall.emitOpError(
            "TensorList decomposition does not support call with nested "
            "references.");

      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
      if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
                                     decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
      if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
      if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
                                    decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

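// Decomposes tensor list ops in `block`, starting from empty tracking maps
// for buffer sizes and decomposed callees.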
LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
  llvm::StringMap<PartitionedCallDecompositionInfo>
      decomposed_partitioned_call_callees;
  return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
                                        &decomposed_partitioned_call_callees);
}

void TensorListOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeTensorListOps(&main.front(), module))) {
    signalPassFailure();
  }
}

static PassRegistration<TensorListOpsDecompositionPass> pass(
    "tf-tensor-list-ops-decomposition",
    "Decompose tensor list operations into operations on buffers and sizes. "
    "Needs static shapes.");

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass() {
  return std::make_unique<TensorListOpsDecompositionPass>();
}
}  // namespace TF
}  // namespace mlir