/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

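// This pass decomposes TensorList ops (tf.EmptyTensorList,
// tf.TensorListPushBack, ...) into operations on a dense buffer tensor plus a
// separately tracked size value. Roughly (an illustrative sketch only; actual
// types depend on what shape inference produced):
//
//   %list = "tf.EmptyTensorList"(...) : ... -> tensor<!tf.variant<...>>
//   %list2 = "tf.TensorListPushBack"(%list, %elem)
//
// becomes a statically shaped buffer, a size value that is threaded through
// control flow and calls, and element-wise updates on that buffer. The
// handlers below rewrite the individual list ops as well as functional and
// region-based control flow (While/If/Case) and (Stateful)PartitionedCall.
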
namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

struct TensorListOpsDecompositionPass
    : public TF::TensorListOpsDecompositionPassBase<
          TensorListOpsDecompositionPass> {
  void runOnOperation() override;
};

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Holds the size value of a tensor list and whether the size is statically
// known (fixed).
struct SizeInfo {
  Value size;
  bool fixed;
};

// Modifies a function's signature to rewrite tensor list arguments to buffers
// and sizes.
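//
// For example (an illustrative sketch only; the buffer type comes from
// arg_to_buffer_type and the size type from `size_type`), an argument
//   %arg0: tensor<!tf.variant<tensor<16xf32>>>
// is retyped to a buffer such as tensor<8x16xf32>, and a new size argument of
// type `size_type` is appended to the function's arguments.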
void ModifyFunctionSignature(
    FuncOp func, Type size_type,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
    llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
  auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
  int64_t original_arg_count = new_input_types.size();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto buffer_type = arg_to_buffer_type(i);
    if (!buffer_type.hasValue()) continue;
    func.getArgument(i).setType(*buffer_type);
    new_input_types[i] = *buffer_type;
    auto size_arg = func.front().addArgument(size_type);
    new_input_types.push_back(size_arg.getType());
    if (buffer_to_size) {
      (*buffer_to_size)[func.getArgument(i)] = {size_arg,
                                                arg_buffer_size_is_fixed(i)};
    }
  }
  UpdateFuncType(func);
}

// Holds information about a decomposed callee function for
// PartitionedCall/StatefulPartitionedCall.
struct PartitionedCallDecompositionInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
  // Each element is a tuple of (buffer_return_index, size_return_index,
  // fixed_size).
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      buffer_ret_to_size_ret;
};

LogicalResult DecomposeTensorListOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
    llvm::StringMap<PartitionedCallDecompositionInfo>*);

// Adds the corresponding sizes of tensor list buffers in block's terminator
// to the list of return values. Returns the mapping from the buffer
// indices to the added size indices, which is a list of tuples
// (buffer_return_index, size_return_index, fixed_size).
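//
// For example, if the old terminator returns (%buffer, %x) and %buffer maps to
// size %size in buffer_to_size, the new terminator returns
// (%buffer, %x, %size) and the returned mapping holds the single tuple
// (0, 2, fixed).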
template <class TerminatorOp>
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(
    Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto old_terminator = block.getTerminator();
  auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      output_buffer_to_size;
  for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
    auto it = buffer_to_size.find(retval.value());
    if (it == buffer_to_size.end()) continue;
    output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
                                       it->getSecond().fixed);
    new_outputs.push_back(it->getSecond().size);
  }
  OpBuilder(old_terminator)
      .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
  old_terminator->erase();
  return output_buffer_to_size;
}

// Adds the corresponding sizes of tensor list buffers in func's return values
// to the list of return values. Returns the mapping from the buffer indices to
// the added size indices, which is a list of tuples (buffer_return_index,
// size_return_index, fixed_size).
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
    FuncOp func, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto output_buffer_to_size =
      AddTensorListSizesToTerminator<ReturnOp>(func.front(), buffer_to_size);
  UpdateFuncType(func);
  return output_buffer_to_size;
}

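// Rewrites a functional tf.While whose operands include decomposed tensor
// lists: the cond and body functions get (buffer, size) argument pairs, the
// body's return is extended with the sizes, and the while op is recreated with
// the extra size operands and results.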
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite body.
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, SizeInfo> body_map;
  auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(while_op.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[while_op.getOperand(index)].fixed;
  };
  OpBuilder builder(while_op);
  ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &body.front(), module, &body_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);

  // Rewrite cond.
  auto cond = while_op.cond_function();
  llvm::SmallDenseMap<Value, SizeInfo> cond_map;
  ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &cond.front(), module, &cond_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (output_buffer_to_size.empty()) {
    return success();
  }
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  new_while_operands, while_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

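// Rewrites a functional tf.If or tf.Case (operand 0 is the condition or the
// branch index) whose remaining operands include decomposed tensor lists. All
// branch functions are rewritten consistently, and the op is recreated with
// the extra size operands and results.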
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(
    CaseOrIfOp op, ArrayRef<FuncOp> branches, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
  branch_maps.resize(branches.size());

  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(op.getOperand(index + 1));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
  };
  OpBuilder builder(op);
  for (const auto& pair : llvm::zip(branches, branch_maps)) {
    FuncOp branch = std::get<0>(pair);
    llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
    ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
                            find_arg_buffer_type, arg_buffer_size_is_fixed);

    if (failed(DecomposeTensorListOpsInternal(
            &branch.front(), module, &branch_map,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  const bool arg_no_changed = branch_maps.front().empty();
  auto output_buffer_to_size =
      ModifyFunctionReturn(branches.front(), branch_maps.front());
  for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
    ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));

  if (output_buffer_to_size.empty() && arg_no_changed) return success();

  // Recreate the op.
  auto new_operands = llvm::to_vector<8>(op.getOperands());
  for (int64_t i = 1; i < op.getNumOperands(); ++i) {
    auto it = buffer_to_size->find(op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_operands.push_back(it->getSecond().size);
  }
  FuncOp first_branch = branches.front();
  auto new_op = OpBuilder(op).create<CaseOrIfOp>(
      op.getLoc(), first_branch.getType().getResults(), new_operands,
      op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
  op.erase();
  return success();
}

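// Region-based counterpart of HandleWhileOp: rewrites the block arguments of
// the cond and body regions in place, appends sizes to the body's yield, and
// recreates the tf.WhileRegion with the extra size operands and results.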
LogicalResult HandleWhileRegionOp(
    TF::WhileRegionOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  OpBuilder builder(while_op);
  auto modify_region_arguments = [&](Region& region) {
    int64_t original_arg_count = region.getNumArguments();
    for (int64_t i = 0; i < original_arg_count; ++i) {
      auto operand = while_op.getOperand(i);
      auto it = buffer_to_size->find(operand);
      if (it == buffer_to_size->end()) continue;
      auto buffer_type = it->getFirst().getType();
      region.getArgument(i).setType(buffer_type);
      auto size_arg = region.addArgument(cutil::GetSizeType(builder));
      (*buffer_to_size)[region.getArgument(i)] = {size_arg,
                                                  it->getSecond().fixed};
    }
  };

  // Rewrite body.
  Region& body_region = while_op.body();
  modify_region_arguments(body_region);
  if (failed(DecomposeTensorListOpsInternal(
          &body_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      body_region.front(), *buffer_to_size);

  // Rewrite cond.
  Region& cond_region = while_op.cond();
  modify_region_arguments(cond_region);
  if (failed(DecomposeTensorListOpsInternal(
          &cond_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }

  if (output_buffer_to_size.empty()) return success();

  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileRegionOp>(
      while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(),
      new_while_operands, while_op->getAttrs());
  new_while.body().takeBody(body_region);
  new_while.cond().takeBody(cond_region);
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

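// Region-based counterpart of HandleCaseOrIfOp for tf.IfRegion: both branch
// regions are decomposed, their yields are extended with buffer sizes, and the
// op is recreated with the additional size results.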
LogicalResult HandleIfRegionOp(
    TF::IfRegionOp if_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  Region& then_branch = if_op.then_branch();
  Region& else_branch = if_op.else_branch();
  if (failed(DecomposeTensorListOpsInternal(
          &then_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();
  if (failed(DecomposeTensorListOpsInternal(
          &else_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();

  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      then_branch.front(), *buffer_to_size);
  AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
                                              *buffer_to_size);

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
      if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
      if_op.getOperand(), if_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  new_op.then_branch().takeBody(if_op.then_branch());
  new_op.else_branch().takeBody(if_op.else_branch());

  if_op.replaceAllUsesWith(
      new_op.getResults().take_front(if_op.getNumResults()));
  if_op.erase();
  return success();
}

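// Same as HandleIfRegionOp, but for tf.CaseRegion with an arbitrary number of
// branch regions.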
LogicalResult HandleCaseRegionOp(
    TF::CaseRegionOp case_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  RegionRange branches = case_op.getRegions();

  for (Region* branch : branches) {
    if (failed(DecomposeTensorListOpsInternal(
            &branch->front(), module, buffer_to_size,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  // Get the mapping from output buffer index to size index. It must be the
  // same for all branches, so we only compute it for the first branch.
  Region* first_branch = branches.front();
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      first_branch->front(), *buffer_to_size);
  for (Region* branch : branches.drop_front()) {
    AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
                                                *buffer_to_size);
  }

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
      case_op.getLoc(),
      first_branch->front().getTerminator()->getOperandTypes(),
      case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
    std::get<0>(pair)->takeBody(*std::get<1>(pair));
  }
  case_op.replaceAllUsesWith(
      new_op.getResults().take_front(case_op.getNumResults()));
  case_op.erase();
  return success();
}

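// Rewrites a PartitionedCall/StatefulPartitionedCall whose callee uses tensor
// lists. Each callee is decomposed only once and the result is cached in
// decomposed_partitioned_call_callees. A non-private callee is cloned (and the
// clone made private) before rewriting; if the signature changes, the clone is
// registered under a new "<name>_tensorlist_decomposed" symbol and the call is
// recreated to match, otherwise the clone simply replaces the original.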
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallDecompositionInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op with info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.buffer_arg_to_size_arg.find(i);
      if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
      auto it = buffer_to_size->find(call.getOperand(i));
      if (it == buffer_to_size->end()) {
        call.emitOpError("unknown tensor list.");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond().size);
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.buffer_ret_to_size_ret) {
      (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
          new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
    }
    call.replaceAllUsesWith(
        new_call.getResults().take_front(call.getNumResults()));
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  llvm::SmallDenseMap<Value, SizeInfo> callee_map;
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(call.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[call.getOperand(index)].fixed;
  };
  ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
                          &callee_map, find_arg_buffer_type,
                          arg_buffer_size_is_fixed);
  const bool args_no_changed = callee_map.empty();
  if (failed(DecomposeTensorListOpsInternal(
          &lowered_callee.front(), module, &callee_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  info.buffer_ret_to_size_ret =
      ModifyFunctionReturn(lowered_callee, callee_map);
  info.decomposed_callee = lowered_callee;
  if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
    // Signature is not modified. We do not need to keep two copies.
    info.signature_change = false;
    if (lowered_callee != callee) {
      lowered_callee.setName(callee.getName());
      callee.erase();
      SymbolTable(module).insert(lowered_callee);
    }
  } else {
    info.signature_change = true;
    for (auto& entry : callee_map) {
      auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
      if (!buffer_arg) continue;
      info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
          entry.getSecond().size.cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str());
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

// Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise,
// returns an error.
LogicalResult GetConstShapeValue(Value shape_value,
                                 llvm::SmallVector<int64_t, 8>* shape) {
  auto shape_op = shape_value.getDefiningOp();
  if (!shape_op) return failure();
  auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
  if (!shape_const_op) return failure();
  for (const auto& v : shape_const_op.value().getValues<APInt>()) {
    int64_t dim_size = v.getSExtValue();
    if (dim_size == ShapedType::kDynamicSize) return failure();
    shape->push_back(dim_size);
  }
  return success();
}

// Checks the result Variant type to infer the element shape if fully defined.
// If the Variant type has multiple subtypes or does not have a static shape,
// returns an error.
LogicalResult GetElementShapeFromResultType(
    Type type, llvm::SmallVector<int64_t, 8>* shape) {
  auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
  TensorType tensor_type = variant_type.getSubtypes().front();
  if (!tensor_type.hasStaticShape()) return failure();
  for (auto d : tensor_type.getShape()) shape->push_back(d);
  return success();
}

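// Lowers tf.EmptyTensorList to an initialized buffer plus a size of 0.
// Schematically (an illustrative sketch only; the element shape must be
// inferable, see the comment in the body):
//
//   %list = "tf.EmptyTensorList"(%shape, %max) : ... -> tensor<!tf.variant<...>>
//
// becomes a buffer whose leading dimension is max_num_elements and whose
// remaining dimensions are the element shape, plus a size constant of 0. The
// list handle's uses are replaced by the buffer.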
LogicalResult HandleEmptyTensorListOp(
    TF::EmptyTensorListOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(
          element_shape, list.max_num_elements(), list, list.element_dtype(),
          builder, &buffer))) {
    return failure();
  }
  Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
  list.handle().replaceAllUsesWith(buffer);
  (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
  list.erase();
  return success();
}

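// Lowers tf.TensorListReserve the same way as tf.EmptyTensorList, except that
// the buffer's leading dimension comes from num_elements and the size is
// recorded as fixed, which makes later pushes and pops on this list an error.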
LogicalResult HandleTensorListReserveOp(
    TF::TensorListReserveOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
                                          list, list.element_dtype(), builder,
                                          &buffer))) {
    return failure();
  }
  Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
                                              list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

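// Lowers tf.TensorListFromTensor: the input tensor (behind a tf.Identity)
// becomes the buffer and its leading dimension becomes the fixed size, so the
// input must have a statically known shape.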
LogicalResult HandleTensorListFromTensorOp(
    TF::TensorListFromTensorOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  OpBuilder builder(list);
  Value buffer = builder.create<TF::IdentityOp>(
      list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
      ArrayRef<Value>{list.tensor()});
  auto type = buffer.getType().cast<TensorType>();
  if (!type.hasStaticShape()) {
    return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
  }
  Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.output_handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

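// Lowers tf.TensorListPushBack on a non-fixed-size list to a write at the
// current size followed by a size increment; roughly (pseudo-code sketch):
//
//   new_buffer = cutil::SetElement(size, buffer, tensor)
//   new_size   = size + 1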
LogicalResult HandleTensorListPushBackOp(
    TF::TensorListPushBackOp push,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = push.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    return push.emitOpError(
        "found tf.TensorListPushBack on unknown TensorList.");
  }
  if (it->getSecond().fixed) {
    return push.emitError("cannot push on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(push);
  auto new_buffer =
      cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
  auto new_size = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
  push.output_handle().replaceAllUsesWith(new_buffer);
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  push.erase();
  return success();
}

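// Lowers tf.TensorListPopBack on a non-fixed-size list: the size is
// decremented, the element at the decremented size is read back, and the
// buffer itself (behind a tf.Identity) becomes the new list handle.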
LogicalResult HandleTensorListPopBackOp(
    TF::TensorListPopBackOp pop,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = pop.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
    return failure();
  }
  if (it->getSecond().fixed) {
    return pop.emitError("cannot pop on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(pop);
  auto new_buffer = builder.create<TF::IdentityOp>(
      pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
  auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
  pop.output_handle().replaceAllUsesWith(new_buffer);
  pop.tensor().replaceAllUsesWith(element);
  pop.erase();
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  return success();
}

LogicalResult HandleTensorListGetItemOp(
    TF::TensorListGetItemOp get_item,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto buffer = get_item.input_handle();
  auto it = buffer_to_size.find(buffer);
  if (it == buffer_to_size.end()) {
    get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(get_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
                                              get_item.getLoc());
  auto element =
      cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc());
  get_item.item().replaceAllUsesWith(element);
  get_item.erase();
  return success();
}

LogicalResult HandleTensorListSetItemOp(
    TF::TensorListSetItemOp set_item,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = set_item.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(set_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
                                              set_item.getLoc());
  auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
                                      set_item.getLoc());
  set_item.output_handle().replaceAllUsesWith(new_buffer);
  auto size = it->getSecond();
  (*buffer_to_size)[new_buffer] = size;
  set_item.erase();
  return success();
}

LogicalResult HandleTensorListLengthOp(
    TF::TensorListLengthOp length,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(length.input_handle());
  if (it == buffer_to_size.end()) {
    length.emitOpError("found tf.TensorListLength on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(length);
  if (it->getSecond().fixed) {
    auto dim = cutil::CreateScalarConst(
        length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
        builder, length.getLoc());
    length.length().replaceAllUsesWith(dim);
  } else {
    auto current_size = it->getSecond().size;
    // Reshapes the R1 length to a scalar.
    auto reshape = builder.create<TF::ReshapeOp>(
        length.getLoc(),
        ArrayRef<Type>{RankedTensorType::get(
            {}, getElementTypeOrSelf(current_size.getType()))},
        ArrayRef<Value>{current_size,
                        cutil::GetR1Const({}, builder, length.getLoc())});
    length.length().replaceAllUsesWith(reshape);
  }
  length.erase();
  return success();
}

LogicalResult HandleTensorListElementShapeOp(
    TF::TensorListElementShapeOp elem_shape,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
    return elem_shape.emitOpError("unknown tensor list");
  }
  auto buffer = elem_shape.input_handle();
  auto result = cutil::GetR1Const(
      buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
      OpBuilder(elem_shape), elem_shape.getLoc(),
      elem_shape.shape_type().getIntOrFloatBitWidth());
  elem_shape.element_shape().replaceAllUsesWith(result);
  elem_shape.erase();
  return success();
}

LogicalResult HandleTensorListGatherOp(
    TF::TensorListGatherOp gather,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(gather.input_handle());
  if (it == buffer_to_size.end()) {
    return gather.emitOpError("unknown tensor list");
  }
  auto buffer = gather.input_handle();
  auto result = cutil::GatherElements(gather.indices(), buffer,
                                      OpBuilder(gather), gather.getLoc());
  gather.values().replaceAllUsesWith(result);
  gather.erase();
  return success();
}

LogicalResult HandleTensorListScatterIntoExistingListOp(
    TF::TensorListScatterIntoExistingListOp scatter,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto it = buffer_to_size->find(scatter.input_handle());
  if (it == buffer_to_size->end()) {
    return scatter.emitOpError("unknown tensor list");
  }
  auto buffer = scatter.input_handle();
  OpBuilder builder(scatter);
  auto indices_type = scatter.indices().getType().dyn_cast<RankedTensorType>();
  if (!indices_type) return scatter.emitOpError("unranked indices shape");
  auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
  auto shape = builder.create<TF::ConstOp>(
      scatter.getLoc(),
      DenseElementsAttr::get(
          shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
  auto indices =
      builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
  Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
      scatter.getLoc(), buffer, indices, scatter.tensor());
  scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
  scatter.erase();
  auto size = it->getSecond();
  (*buffer_to_size)[tensor_scatter_update] = size;
  return success();
}

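// Decomposes every supported tensor list op in `block`, dispatching to the
// handlers above. Identity-like ops are forwarded to their operands,
// tf.TensorListStack simply forwards the underlying buffer, and functional and
// region-based control flow as well as partitioned calls recurse into their
// bodies. buffer_to_size is updated as new buffers are created.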
LogicalResult DecomposeTensorListOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    // TODO(yuanzx): Add a pass to remove identities in device computation.
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
      if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
      if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
      if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
      if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
      if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
        return failure();
      }
    } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
      if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
        return failure();
      }
    } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
      if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
        return failure();
      }
    } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
      if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
        return failure();
      }
    } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
      stack.tensor().replaceAllUsesWith(stack.input_handle());
      stack.erase();
    } else if (auto elem_shape =
                   llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
      if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
        return failure();
      }
    } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
      if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
        return failure();
      }
    } else if (auto scatter =
                   llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
                       &op)) {
      if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
                                                           buffer_to_size))) {
        return failure();
      }
    } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
      auto it = buffer_to_size->find(addn.getOperand(0));
      if (it != buffer_to_size->end()) {
        addn.sum().setType(addn.getOperand(0).getType());
        auto size = it->getSecond();
        (*buffer_to_size)[addn.sum()] = size;
      }
    } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
      if (buffer_to_size->count(zeros.x()) > 0) {
        zeros.y().setType(zeros.x().getType());
        auto size = (*buffer_to_size)[zeros.x()];
        (*buffer_to_size)[zeros.y()] = size;
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, buffer_to_size,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleCaseOrIfOp(
              if_op, {if_op.then_function(), if_op.else_function()}, module,
              buffer_to_size, decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
      SmallVector<FuncOp, 2> branches;
      case_op.get_branch_functions(branches);
      if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func())
        return pcall.emitOpError(
            "TensorList decomposition does not support call with nested "
            "references.");

      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
      if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
                                     decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
      if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
      if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
                                    decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
  llvm::StringMap<PartitionedCallDecompositionInfo>
      decomposed_partitioned_call_callees;
  return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
                                        &decomposed_partitioned_call_callees);
}

void TensorListOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeTensorListOps(&main.front(), module))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass() {
  return std::make_unique<TensorListOpsDecompositionPass>();
}
}  // namespace TF
}  // namespace mlir