/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

using std::string;

// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can
// further remove the local variables.
//
// This pass requires that the full shape of the tensor array can be inferred:
// 1) the size needs to be a constant, 2) it either specifies the full element
// shape or the shape can be inferred from a later write, and 3) all elements
// have the same shape.
//
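// For example (shapes are purely illustrative): a tensor array of size 5 with
// element shape tensor<3xf32> is lowered to a local variable holding a
// tensor<5x3xf32> buffer; reads become slices of that buffer and writes become
// updates of it.
//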
struct TensorArrayOpsDecompositionPass
    : public PassWrapper<TensorArrayOpsDecompositionPass,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Infers the element type and count for a TensorArraySplitV3Op. Requires
// constant lengths and static shape on the input value.
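// For example (values are illustrative): lengths = [2, 2, 2] with a value of
// type tensor<6x3xf32> gives count = 3 and elem_type = tensor<2x3xf32>.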
LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
                                          RankedTensorType* elem_type,
                                          int64_t* count) {
  auto lengths_const =
      llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp());
  if (!lengths_const) return split.emitOpError("non-constant split lengths");
  *count = lengths_const.value().getNumElements();
  if (*count <= 0) return split.emitOpError("non-positive split count");
  auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>();
  if (!buffer_type || !buffer_type.hasStaticShape() ||
      buffer_type.getRank() < 1) {
    return split.emitOpError("unknown or invalid split tensor shape");
  }
  int64_t length = buffer_type.getDimSize(0) / *count;
  for (const auto& len : lengths_const.value().getValues<APInt>()) {
    if (length == len.getSExtValue()) continue;
    return split.emitOpError("different split lengths are not supported");
  }
  llvm::SmallVector<int64_t, 8> elem_shape;
  elem_shape.push_back(length);
  for (int64_t dim : buffer_type.getShape().drop_front()) {
    elem_shape.push_back(dim);
  }
  *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType());
  return success();
}

// Tries to infer the tensor array element shape.
llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
    TF::TensorArrayV3Op ta, ModuleOp module) {
  auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
  if (element_shape.hasStaticShape()) {
    auto shape = element_shape.getShape();
    // Convert int64 to int64_t.
    llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
    return dims;
  }

  bool has_failure = false;
  auto elem_type = cutil::GetElementTypeFromAccess(
      ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> {
        if (has_failure) return llvm::None;
        if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) {
          return write.value().getType();
        } else if (auto split =
                       llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) {
          if (!split.lengths().getDefiningOp() ||
              !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) {
            return llvm::None;
          }
          RankedTensorType t;
          int64_t count;
          if (failed(GetSplitElementTypeAndCount(split, &t, &count))) {
            has_failure = true;
            return llvm::None;
          }
          return t;
        } else if (auto scatter =
                       llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) {
          // TensorArrayScatter writes a vector of tensors to the TensorArray.
          // We can deduce the element shape of the TensorArray by dropping the
          // 0th dim of the TensorArrayScatter `value`.
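          // E.g. (illustrative), a scatter `value` of type tensor<4x2x3xf32>
          // implies an element type of tensor<2x3xf32>.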
          auto t = scatter.value().getType().dyn_cast<RankedTensorType>();
          if (!t || t.getShape().empty()) return llvm::None;
          return RankedTensorType::get(t.getShape().drop_front(),
                                       t.getElementType());
        } else if (auto gather =
                       llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
          // Try to infer from result type of gather.
          auto t = gather.value().getType().dyn_cast<RankedTensorType>();
          if (t && !t.getShape().empty())
            return RankedTensorType::get(t.getShape().drop_front(),
                                         t.getElementType());
          // Try to infer from `element_shape` attribute of gather.
          auto element_shape = gather.element_shapeAttr()
                                   .dyn_cast_or_null<mlir::TF::ShapeAttr>();
          if (element_shape && element_shape.hasStaticShape()) {
            return RankedTensorType::get(element_shape.getShape(),
                                         gather.dtype());
          }
        }
        return llvm::None;
      });
  if (!elem_type) return llvm::None;
  return llvm::to_vector<8>(elem_type->getShape());
}

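// Replaces all uses of `old_val` with a tensor.cast of `new_val` back to
// `old_val`'s original type, so existing users keep seeing the type they
// expect.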
void ReplaceAllUsesWithCast(Value old_val, Value new_val) {
  if (old_val.use_empty()) return;
  auto cast_op =
      OpBuilder(old_val.getDefiningOp())
          .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val);
  old_val.replaceAllUsesWith(cast_op);
}

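// Replaces all uses of `old_val` with `new_val`, except uses in the enclosing
// function's terminator when the types differ.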
void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) {
  if (old_val.getType() == new_val.getType()) {
    old_val.replaceAllUsesWith(new_val);
    return;
  }
  Operation* old_op = old_val.getDefiningOp();
  Operation* terminator_op =
      old_op->getParentOfType<FuncOp>().front().getTerminator();
  llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op};
  old_val.replaceAllUsesExcept(new_val, exceptions);
}

struct TensorArrayStats {
  // Whether a write op should accumulate with the old value. Set to true if
  // this is a gradient.
  bool accumulate_on_write;
  // Maps from a gradient source string to the local variable for the gradient.
  llvm::StringMap<Value> grads;
};

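// Replaces a TensorArrayV3 op with a local variable that holds the whole
// buffer, and a scalar constant that replaces the flow output.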
LogicalResult HandleTensorArrayV3Op(
    TF::TensorArrayV3Op ta, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto elem_shape = GetTensorArrayElementShape(ta, module);
  if (!elem_shape) return ta.emitOpError("unknown element shape");
  if (ta.dynamic_size()) {
    return ta.emitOpError("dynamic tensor array size is unsupported");
  }
  Value buffer;
  OpBuilder builder(ta);
  if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta,
                                          ta.dtype(), builder, &buffer))) {
    return failure();
  }
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              ta.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
  ta.handle().replaceAllUsesWith(local_var);
  // The flow output is just a way for the front end to enforce ordering among
  // tensor array ops, but in the MLIR TF dialect they have sequential ordering.
  // Just create a constant to replace its uses.
  tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {});
  scalar_tensor.scalar<float>()() = 0.0f;
  auto flow = builder.create<TF::ConstOp>(
      ta.getLoc(),
      tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
  ta.flow().replaceAllUsesWith(flow);
  ta.erase();
  (*stats)[local_var].accumulate_on_write = false;
  return success();
}

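// Lowers a TensorArrayReadV3 op to a read of the local-variable buffer
// followed by a slice at the requested index.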
LogicalResult HandleTensorArrayReadV3Op(
    TF::TensorArrayReadV3Op read,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = read.handle();
  if (stats.count(local_var) == 0) {
    return read.emitOpError("unknown tensor array");
  }
  OpBuilder builder(read);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc());
  auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc());
  ReplaceAllUsesExceptTerminator(read.value(), elem);
  ReplaceAllUsesWithCast(read.value(), elem);
  read.erase();
  // The clear_after_read attribute does not mean setting the tensor to 0 after
  // read; instead it does not allow a second read before the next write. We
  // follow the old bridge's implementation and do nothing here.
  return success();
}

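// Lowers a TensorArrayWriteV3 op to an update of the local-variable buffer at
// the requested index, accumulating with the existing element when the tensor
// array is a gradient.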
LogicalResult HandleTensorArrayWriteV3Op(
    TF::TensorArrayWriteV3Op write,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = write.handle();
  auto stat_it = stats.find(local_var);
  if (stat_it == stats.end()) return write.emitOpError("unknown tensor array");
  OpBuilder builder(write);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc());
  auto elem = write.value();
  if (stat_it->getSecond().accumulate_on_write) {
    // Get the old slice, and accumulate with it. We set keep_slice_shape
    // (keeping the leading size-1 dimension) because it avoids reshaping back
    // and forth.
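    // E.g. (illustrative), for an element shape of [3], the old slice is read
    // as tensor<1x3xf32>, the written value is reshaped to tensor<1x3xf32>,
    // and the two are added before being stored back.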
    auto original_elem =
        cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                          /*keep_slice_shape=*/true);
    // Add a size-1 leading dimension to elem.
    auto slice_type = original_elem.getType().cast<RankedTensorType>();
    elem = builder.create<TF::ReshapeOp>(
        write.getLoc(), ArrayRef<Type>{slice_type},
        ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
                                                write.getLoc())});
    elem =
        cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
  }
  buffer =
      cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc());
  write.flow_out().replaceAllUsesWith(write.flow_in());
  write.erase();
  return success();
}

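// Lowers a TensorArrayConcatV3 op to a reshape of the local-variable buffer
// that merges the element dimension into the leading dimension, plus a
// constant for the lengths output.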
LogicalResult HandleTensorArrayConcatV3Op(
    TF::TensorArrayConcatV3Op concat,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = concat.handle();
  if (stats.count(local_var) == 0) {
    return concat.emitOpError("unknown tensor array");
  }
  OpBuilder builder(concat);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc());
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  if (buffer_type.getShape().size() <= 1) {
    return concat.emitOpError("cannot concat on scalar-element tensor array");
  }
  // Merge the first two dimensions.
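  // E.g. (illustrative), a buffer of type tensor<4x8x3xf32> is reshaped to
  // tensor<32x3xf32>, and lengths becomes the constant [8, 8, 8, 8].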
  auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front());
  shape[0] *= buffer_type.getDimSize(0);
  buffer = builder.create<TF::ReshapeOp>(
      concat.getLoc(),
      ArrayRef<Type>{
          RankedTensorType::get(shape, buffer_type.getElementType())},
      ArrayRef<Value>{buffer,
                      cutil::GetR1Const(shape, builder, concat.getLoc())});
  ReplaceAllUsesExceptTerminator(concat.value(), buffer);
  ReplaceAllUsesWithCast(concat.value(), buffer);

  // Create the lengths as a list of the same value (element size).
  tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64,
                                    {buffer_type.getDimSize(0)});
  for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) {
    lengths_tensor.vec<tensorflow::int64>()(i) = buffer_type.getDimSize(1);
  }
  concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>(
      concat.getLoc(),
      tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie()));
  concat.erase();
  return success();
}

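// Lowers a TensorArraySplitV3 op to a reshape of its input into the buffer
// shape, accumulated into the local-variable buffer.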
LogicalResult HandleTensorArraySplitV3Op(
    TF::TensorArraySplitV3Op split,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = split.handle();
  if (stats.count(local_var) == 0) {
    return split.emitOpError("unknown tensor array");
  }
  OpBuilder builder(split);
  int64_t count;
  RankedTensorType elem_type;
  if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) {
    return failure();
  }
  llvm::SmallVector<int64_t, 8> buffer_shape;
  buffer_shape.push_back(count);
  for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim);
  // Reshape the input to match the buffer of the tensor array.
  auto buffer = builder
                    .create<TF::ReshapeOp>(
                        split.getLoc(),
                        ArrayRef<Type>{RankedTensorType::get(
                            buffer_shape, elem_type.getElementType())},
                        ArrayRef<Value>{split.value(),
                                        cutil::GetR1Const(buffer_shape, builder,
                                                          split.getLoc())})
                    .output();
  // Accumulate with the old buffer.
  auto old_buffer =
      cutil::ReadLocalVariable(local_var, builder, split.getLoc());
  buffer =
      cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc());
  split.flow_out().replaceAllUsesWith(split.flow_in());
  split.erase();
  return success();
}

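// Lowers a TensorArraySizeV3 op to a constant derived from the static buffer
// shape.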
LogicalResult HandleTensorArraySizeV3Op(
    TF::TensorArraySizeV3Op size,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = size.handle();
  if (stats.count(local_var) == 0) {
    return size.emitOpError("unknown tensor array");
  }
  auto buffer_type = getElementTypeOrSelf(local_var.getType())
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  OpBuilder builder(size);
  auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder,
                                         size.getLoc());
  size.size().replaceAllUsesWith(result);
  size.erase();
  return success();
}

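// Creates a local variable of the given resource type and initializes its
// buffer; used for tensor array gradients.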
LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
                                              Operation* op, Value* var) {
  OpBuilder builder(op);
  *var = builder.create<TF::MlirLocalVarOp>(
      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
  Value buffer;
  auto buffer_type = getElementTypeOrSelf(local_var_type)
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  if (failed(cutil::CreateInitBufferValue(
          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
          buffer_type.getElementType(), builder, &buffer))) {
    return failure();
  }
  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
  return success();
}

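// Lowers a TensorArrayGradV3 op to a gradient local variable, creating and
// caching one variable per gradient source string.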
LogicalResult HandleTensorArrayGradV3Op(
    TF::TensorArrayGradV3Op grad,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto local_var = grad.handle();
  OpBuilder builder(grad);
  Value grad_var;
  auto sit = stats->find(local_var);
  if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
  auto emplace_res =
      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
  if (!emplace_res.second) {
    // If the source has been assigned a grad, use it.
    grad_var = emplace_res.first->second;
  } else {
    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
                                               &grad_var))) {
      return failure();
    }
    emplace_res.first->second = grad_var;
    // Writes to a gradient accumulate with previous writes.
    (*stats)[grad_var].accumulate_on_write = true;
  }
  grad.flow_out().replaceAllUsesWith(grad.flow_in());
  grad.grad_handle().replaceAllUsesWith(grad_var);
  grad.erase();
  return success();
}

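// Lowers a TensorArrayGatherV3 op to a gather on the local-variable buffer.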
LogicalResult HandleTensorArrayGatherV3Op(
    TF::TensorArrayGatherV3Op gather,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = gather.handle();
  if (stats.count(local_var) == 0) {
    return gather.emitOpError("unknown tensor array");
  }
  OpBuilder builder(gather);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc());
  auto result =
      cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc());
  ReplaceAllUsesExceptTerminator(gather.value(), result);
  ReplaceAllUsesWithCast(gather.value(), result);
  gather.erase();
  return success();
}

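// Lowers a TensorArrayScatterV3 op to a scatter-accumulate into the
// local-variable buffer.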
LogicalResult HandleTensorArrayScatterV3Op(
    TF::TensorArrayScatterV3Op scatter,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = scatter.handle();
  if (stats.count(local_var) == 0) {
    return scatter.emitOpError("unknown tensor array");
  }
  OpBuilder builder(scatter);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc());
  buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(),
                                            buffer, builder, scatter.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc());
  scatter.flow_out().replaceAllUsesWith(scatter.flow_in());
  scatter.erase();
  return success();
}

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Finds the accessed gradient sources for each tensor array argument.
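// The result maps each argument index to the gradient source strings accessed
// for that argument, recursing into while/if branches and called functions.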
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
    ArrayRef<FuncOp> funcs, ModuleOp module) {
  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
  auto insert = [&](Value v, const string& source, const Block& func_block) {
    auto arg = v.dyn_cast<BlockArgument>();
    if (!arg || arg.getOwner() != &func_block) return;
    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
    if (!insert_res.second) return;
    result[arg.getArgNumber()].push_back(source);
  };
  for (FuncOp func : funcs) {
    const Block& func_block = func.front();
    // Walk all operations and nested regions to find accessed gradient sources
    // for function arguments.
    func.walk([&](Operation* op) {
      if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
        op->replaceAllUsesWith(op->getOperands());
        return;
      }
      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) {
        insert(grad.handle(), grad.source().str(), func_block);
      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {while_op.body_function(), while_op.cond_function()}, module))
          for (const string& source : entry.getSecond())
            insert(while_op.getOperand(entry.getFirst()), source, func_block);
      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {if_op.then_function(), if_op.else_function()}, module))
          for (const string& source : entry.getSecond())
            insert(if_op.getOperand(entry.getFirst() + 1), source, func_block);
      } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
        auto callee = dyn_cast<FuncOp>(call.resolveCallable());
        for (const auto& entry : AccessedGradients({callee}, module))
          for (const string& source : entry.getSecond())
            insert(call.getArgOperands()[entry.getFirst()], source, func_block);
      }
    });
  }
  return result;
}

// Contains cached information about decomposed callee functions of (stateful)
// partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
      arg_grads;
  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};

// Updates a called function's input signature by adjusting resource types, and
// adding required gradient arguments.
void ChangeFunctionInputSignature(
    FuncOp func,
    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  int64_t original_args = func.getNumArguments();
  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
    auto arg = func.getArgument(argnum);
    Type t = ta_arg_buffer_type(argnum);
    if (!t) continue;
    arg.setType(t);
    auto grad_it = grads.find(argnum);
    if (grad_it == grads.end()) continue;
    llvm::StringMap<Value> grads_map;
    for (const string& source : grad_it->getSecond()) {
      auto g = func.front().addArgument(t);
      (*stats)[g].accumulate_on_write = true;
      grads_map[source] = g;
    }
    auto& stat = (*stats)[arg];
    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
    stat.grads = std::move(grads_map);
  }
  UpdateFuncType(func);
}

LogicalResult DecomposeTensorArrayOps(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*);

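// Decomposes tensor array ops inside a while loop's body and condition,
// threading any accessed gradient variables through the loop as extra
// operands, and recreates the while op with the updated signature.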
LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
                            llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                            llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                                decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  auto cond = while_op.cond_function();
  auto grads = AccessedGradients({body, cond}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &body_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &cond_stats);
  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (body_stats.empty() && cond_stats.empty()) return success();
  auto old_body_ret = body.front().getTerminator();
  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!ta_arg_buffer_type(i)) continue;
    auto retval = old_body_ret->getOperand(i);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) {
      return while_op.emitOpError(
          "output tensor array does not alias input in a while loop");
    }
    for (const string& source : grads[i]) {
      new_retvals.push_back(body_stats[arg].grads[source]);
    }
  }
  OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
  old_body_ret->erase();
  UpdateFuncType(body);
  // Recreate the while op.
  auto operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
                                                   while_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(while_op);
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  operands, while_op.getAttrs());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    if (ta_arg_buffer_type(i)) {
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    } else {
      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
    }
  }
  while_op.erase();
  return success();
}

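// Decomposes tensor array ops inside an if op's branches, passing any accessed
// gradient variables as extra operands, and recreates the if op. Tensor array
// (resource) results must forward the same input resource in both branches.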
LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
                         llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                         llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                             decomposed_partitioned_call_callees) {
  auto then_branch = if_op.then_function();
  auto else_branch = if_op.else_function();
  auto grads = AccessedGradients({then_branch, else_branch}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &then_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &else_stats);
  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (then_stats.empty() && else_stats.empty()) return success();
  // Recreate the if op.
  auto operands = llvm::to_vector<8>(if_op.getOperands());
  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i + 1]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
                                                   if_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(if_op);
  auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
                                         then_branch.getType().getResults(),
                                         operands, if_op.getAttrs());
  auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
    auto retval = f.front().getTerminator()->getOperand(ret_ind);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) return -1;
    return arg.getArgNumber();
  };
  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
             .isa<TF::ResourceType>()) {
      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
      continue;
    }
    int64_t then_forward_input = ret_forwards_input(then_branch, i);
    int64_t else_forward_input = ret_forwards_input(else_branch, i);
    if (then_forward_input != else_forward_input || then_forward_input < 0) {
      return if_op.emitOpError(
          "branches do not forward the same input resource");
    }
    if_op.getResult(i).replaceAllUsesWith(
        if_op.getOperand(then_forward_input + 1));
  }
  if_op.erase();
  return success();
}

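// Decomposes tensor array ops in a (stateful) partitioned call's callee,
// caching the rewritten callee so a function shared by multiple call sites is
// processed only once, and recreates the call op when the signature changes.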
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallTensorArrayOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op with info.
  auto recreate_caller = [&]() -> LogicalResult {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (const auto& entry : info.arg_grads) {
      auto it = stats->find(call.getOperand(entry.first));
      if (it == stats->end()) return call.emitOpError("unknown tensor array");
      for (const string& source : entry.second) {
        auto grad_it = it->getSecond().grads.find(source);
        if (grad_it != it->getSecond().grads.end()) {
          new_operands.push_back(grad_it->second);
        } else {
          Value grad_var;
          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
                                                     call, &grad_var))) {
            return failure();
          }
          it->getSecond().grads[source] = grad_var;
          new_operands.push_back(grad_var);
        }
      }
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call.getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.ret_forward_input) {
      call.getResult(entry.first)
          .replaceAllUsesWith(call.getOperand(entry.second));
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  info.signature_change = false;
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return nullptr;
    info.signature_change = true;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto grads = AccessedGradients({lowered_callee}, module);
  for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) {
    auto it = grads.find(i);
    if (it == grads.end()) continue;
    info.arg_grads.emplace_back(i, it->getSecond());
  }
  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
  ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &callee_stats);
  if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module,
                                     &callee_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  for (int64_t i = 0; i < call.getNumResults(); ++i) {
    auto ret = lowered_callee.front().getTerminator()->getOperand(i);
    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
    auto arg = ret.dyn_cast<BlockArgument>();
    if (!arg) continue;
    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
  }

  info.decomposed_callee = lowered_callee;
  if (lowered_callee != callee) {
    if (!info.signature_change) {
      // Signature is not modified. We do not need to keep two copies.
      lowered_callee.setName(callee.getName());
      callee.erase();
    } else {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str());
    }
    SymbolTable(module).insert(lowered_callee);
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

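// Decomposes tensor array ops inside region-based control flow ops. Resource
// operands and results are expected to have been canonicalized away already.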
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }

  for (Region& region : op.getRegions()) {
    if (failed(DecomposeTensorArrayOps(&region.front(), module, stats,
                                       decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

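// Decomposes all tensor array ops in `block`, dispatching to the handlers
// above and recursing into control flow ops and (stateful) partitioned calls.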
LogicalResult DecomposeTensorArrayOps(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
        return failure();
      }
    } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
    } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
    } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
    } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
    } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
    } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
    } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
    } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, stats,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, stats,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::CaseRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::WhileRegionOp>(op)) {
      if (failed(HandleRegionControlFlowOps(
              op, module, stats, decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      auto callee = pcall.func();
      if (!callee)
        return pcall.emitOpError(
            "TensorArray decomposition does not support call with nested "
            "references.");

      if (failed(
              HandlePartitionedCallOp(pcall, callee, module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(
              HandlePartitionedCallOp(spcall, spcall.func(), module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

void TensorArrayOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
  llvm::StringMap<PartitionedCallTensorArrayOpsInfo>
      decomposed_partitioned_call_callees;
  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
                                     &decomposed_partitioned_call_callees))) {
    signalPassFailure();
  }
}

static PassRegistration<TensorArrayOpsDecompositionPass> pass(
    "tf-tensor-array-ops-decomposition",
    "Decompose tensor array operations into local variable operations.");

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass() {
  return std::make_unique<TensorArrayOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir