1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/None.h"
18 #include "llvm/ADT/Optional.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
35 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
36 #include "mlir/IR/Types.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/Pass/Pass.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
44 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor_shape.h"
49 #include "tensorflow/core/framework/tensor_shape.pb.h"
50 #include "tensorflow/core/framework/types.pb.h"
51 #include "tensorflow/core/platform/types.h"
52 
53 namespace mlir {
54 
55 namespace {
56 
57 namespace cutil = TF::collection_ops_util;
58 
59 using std::string;
60 
61 // A pass that converts tensor array operations to tensor operations and
62 // read/assign ops on local variables. A later resource lifting pass can further
63 // remove the local variables.
64 //
65 // This pass requires that the full shape of the tensor array can be inferred:
66 // 1) the size must be a constant, 2) the full element shape is either specified
67 // or can be inferred from a later write, and 3) all elements have the same
68 // shape.
69 //
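// As an illustrative sketch (assumed IR, not verbatim from any test case), a
// size-1 tensor array with element shape [8] and dtype f32:
//
//   %ta:2 = "tf.TensorArrayV3"(%size) {...}                  // handle, flow
//   %flow = "tf.TensorArrayWriteV3"(%ta#0, %idx, %val, %ta#1)
//   %elem = "tf.TensorArrayReadV3"(%ta#0, %idx, %flow)
//
// is rewritten, roughly, into reads/assigns of a local variable that holds a
// [1, 8] buffer:
//
//   %var = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1x8xf32>>>
//   "tf.AssignVariableOp"(%var, %zeros)   // initialize the buffer
//   ...                                   // a write reads the buffer, updates
//   ...                                   // the indexed slice, assigns it back
//   ...                                   // a read slices one element out
//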
70 struct TensorArrayOpsDecompositionPass
71     : public PassWrapper<TensorArrayOpsDecompositionPass,
72                          OperationPass<ModuleOp>> {
73   void runOnOperation() override;
74 };
75 
76 // Infers the element type and count for a TensorArraySplitV3Op. Requires
77 // constant lengths and static shape on the input value.
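// For example (illustrative values): with lengths = [2, 2, 2] and a split value
// of type tensor<6x4xf32>, *count is set to 3 and *elem_type to
// tensor<2x4xf32>.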
78 LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
79                                           RankedTensorType* elem_type,
80                                           int64_t* count) {
81   auto lengths_const =
82       llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp());
83   if (!lengths_const) return split.emitOpError("non-constant split lengths");
84   *count = lengths_const.value().getNumElements();
85   if (*count <= 0) return split.emitOpError("non-positive split count");
86   auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>();
87   if (!buffer_type || !buffer_type.hasStaticShape() ||
88       buffer_type.getRank() < 1) {
89     return split.emitOpError("unknown or invalid split tensor shape");
90   }
91   int64_t length = buffer_type.getDimSize(0) / *count;
92   for (const auto& len : lengths_const.value().getValues<APInt>()) {
93     if (length == len.getSExtValue()) continue;
94     return split.emitOpError("different split lengths are not supported");
95   }
96   llvm::SmallVector<int64_t, 8> elem_shape;
97   elem_shape.push_back(length);
98   for (int64_t dim : buffer_type.getShape().drop_front()) {
99     elem_shape.push_back(dim);
100   }
101   *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType());
102   return success();
103 }
104 
105 // Tries to infer the tensor array element shape.
106 llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
107     TF::TensorArrayV3Op ta, ModuleOp module) {
108   auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
109   if (element_shape.hasStaticShape()) {
110     auto shape = element_shape.getShape();
111     // Convert int64 to int64_t.
112     llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
113     return dims;
114   }
115 
116   bool has_failure = false;
117   auto elem_type = cutil::GetElementTypeFromAccess(
118       ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> {
119         if (has_failure) return llvm::None;
120         if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) {
121           return write.value().getType();
122         } else if (auto split =
123                        llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) {
124           if (!split.lengths().getDefiningOp() ||
125               !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) {
126             return llvm::None;
127           }
128           RankedTensorType t;
129           int64_t count;
130           if (failed(GetSplitElementTypeAndCount(split, &t, &count))) {
131             has_failure = true;
132             return llvm::None;
133           }
134           return t;
135         } else if (auto scatter =
136                        llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) {
137           // TensorArrayScatter writes a vector of tensors to the TensorArray. We
138           // can deduce the element shape of the TensorArray by dropping the 0th
139           // dim of the TensorArrayScatter `value`.
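          // For example (illustrative): a scatter `value` of type
          // tensor<5x3x7xf32> implies an element type of tensor<3x7xf32>.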
140           auto t = scatter.value().getType().dyn_cast<RankedTensorType>();
141           if (!t || t.getShape().empty()) return llvm::None;
142           return RankedTensorType::get(t.getShape().drop_front(),
143                                        t.getElementType());
144         } else if (auto gather =
145                        llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
146           // Try to infer from result type of gather.
147           auto t = gather.value().getType().dyn_cast<RankedTensorType>();
148           if (t && !t.getShape().empty())
149             return RankedTensorType::get(t.getShape().drop_front(),
150                                          t.getElementType());
151           // Try to infer from `element_shape` attribute of gather.
152           auto element_shape = gather.element_shapeAttr()
153                                    .dyn_cast_or_null<mlir::TF::ShapeAttr>();
154           if (element_shape && element_shape.hasStaticShape()) {
155             return RankedTensorType::get(element_shape.getShape(),
156                                          gather.dtype());
157           }
158         }
159         return llvm::None;
160       });
161   if (!elem_type) return llvm::None;
162   return llvm::to_vector<8>(elem_type->getShape());
163 }
164 
165 void ReplaceAllUsesWithCast(Value old_val, Value new_val) {
166   if (old_val.use_empty()) return;
167   auto cast_op =
168       OpBuilder(old_val.getDefiningOp())
169           .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val);
170   old_val.replaceAllUsesWith(cast_op);
171 }
172 
173 void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) {
174   if (old_val.getType() == new_val.getType()) {
175     old_val.replaceAllUsesWith(new_val);
176     return;
177   }
178   Operation* old_op = old_val.getDefiningOp();
179   Operation* terminator_op =
180       old_op->getParentOfType<FuncOp>().front().getTerminator();
181   llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op};
182   old_val.replaceAllUsesExcept(new_val, exceptions);
183 }
184 
185 struct TensorArrayStats {
186   // Whether a write op should accumulate with the old value. Set to true if
187   // this is a gradient.
188   bool accumulate_on_write;
190   // Maps from a gradient source string to the local variable for the gradient.
190   llvm::StringMap<Value> grads;
191 };
192 
193 LogicalResult HandleTensorArrayV3Op(
194     TF::TensorArrayV3Op ta, ModuleOp module,
195     llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
196   auto elem_shape = GetTensorArrayElementShape(ta, module);
197   if (!elem_shape) return ta.emitOpError("unknown element shape");
198   if (ta.dynamic_size()) {
199     return ta.emitOpError("dynamic tensor array size is unsupported");
200   }
201   Value buffer;
202   OpBuilder builder(ta);
203   if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta,
204                                           ta.dtype(), builder, &buffer))) {
205     return failure();
206   }
207   auto var_type = RankedTensorType::get(
208       {}, TF::ResourceType::get(
209               ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
210               ta.getContext()));
211   auto local_var = builder.create<TF::MlirLocalVarOp>(
212       ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
213   cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
214   ta.handle().replaceAllUsesWith(local_var);
215   // The flow output is just a way for the front end to enforce ordering among
216   // tensor array ops, but in the MLIR TF dialect they have sequential ordering.
217   // Just create a constant to replace its uses.
218   tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {});
219   scalar_tensor.scalar<float>()() = 0.0f;
220   auto flow = builder.create<TF::ConstOp>(
221       ta.getLoc(),
222       tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
223   ta.flow().replaceAllUsesWith(flow);
224   ta.erase();
225   (*stats)[local_var].accumulate_on_write = false;
226   return success();
227 }
228 
229 LogicalResult HandleTensorArrayReadV3Op(
230     TF::TensorArrayReadV3Op read,
231     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
232   auto local_var = read.handle();
233   if (stats.count(local_var) == 0) {
234     return read.emitOpError("unknown tensor array");
235   }
236   OpBuilder builder(read);
237   auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc());
238   auto index_reshape =
239       cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc());
240   auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc());
241   ReplaceAllUsesExceptTerminator(read.value(), elem);
242   ReplaceAllUsesWithCast(read.value(), elem);
243   read.erase();
244   // The clear_after_read attribute does not mean setting the tensor to 0 after
245   // read; instead it disallows a second read before the next write. We follow
246   // the old bridge's implementation and do nothing here.
247   return success();
248 }
249 
250 LogicalResult HandleTensorArrayWriteV3Op(
251     TF::TensorArrayWriteV3Op write,
252     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
253   auto local_var = write.handle();
254   auto stat_it = stats.find(local_var);
255   if (stat_it == stats.end()) return write.emitOpError("unknown tensor array");
256   OpBuilder builder(write);
257   auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc());
258   auto index_reshape =
259       cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc());
260   auto elem = write.value();
261   if (stat_it->getSecond().accumulate_on_write) {
262     // Get the old slice, and accumulate with it. We set keep_slice_shape
263     // (keeping the leading size-1 dimension) because it avoids reshaping back
264     // and forth.
265     auto original_elem =
266         cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
267                           /*keep_slice_shape=*/true);
268     // Add a size-1 leading dimension to elem.
269     auto slice_type = original_elem.getType().cast<RankedTensorType>();
270     elem = builder.create<TF::ReshapeOp>(
271         write.getLoc(), ArrayRef<Type>{slice_type},
272         ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
273                                                 write.getLoc())});
274     elem =
275         cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
276   }
277   buffer =
278       cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc());
279   cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc());
280   write.flow_out().replaceAllUsesWith(write.flow_in());
281   write.erase();
282   return success();
283 }
284 
285 LogicalResult HandleTensorArrayConcatV3Op(
286     TF::TensorArrayConcatV3Op concat,
287     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
288   auto local_var = concat.handle();
289   if (stats.count(local_var) == 0) {
290     return concat.emitOpError("unknown tensor array");
291   }
292   OpBuilder builder(concat);
293   auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc());
294   auto buffer_type = buffer.getType().cast<RankedTensorType>();
295   if (buffer_type.getShape().size() <= 1) {
296     return concat.emitOpError("cannot concat on scalar-element tensor array");
297   }
298   // Merge the first two dimensions.
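  // For example (illustrative): a buffer of type tensor<4x2x8xf32> is reshaped
  // to tensor<8x8xf32> before it replaces the concat result.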
299   auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front());
300   shape[0] *= buffer_type.getDimSize(0);
301   buffer = builder.create<TF::ReshapeOp>(
302       concat.getLoc(),
303       ArrayRef<Type>{
304           RankedTensorType::get(shape, buffer_type.getElementType())},
305       ArrayRef<Value>{buffer,
306                       cutil::GetR1Const(shape, builder, concat.getLoc())});
307   ReplaceAllUsesExceptTerminator(concat.value(), buffer);
308   ReplaceAllUsesWithCast(concat.value(), buffer);
309 
310   // Create the lengths as a list of the same value (element size).
311   tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64,
312                                     {buffer_type.getDimSize(0)});
313   for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) {
314     lengths_tensor.vec<tensorflow::int64>()(i) = buffer_type.getDimSize(1);
315   }
316   concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>(
317       concat.getLoc(),
318       tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie()));
319   concat.erase();
320   return success();
321 }
322 
323 LogicalResult HandleTensorArraySplitV3Op(
324     TF::TensorArraySplitV3Op split,
325     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
326   auto local_var = split.handle();
327   if (stats.count(local_var) == 0) {
328     return split.emitOpError("unknown tensor array");
329   }
330   OpBuilder builder(split);
331   int64_t count;
332   RankedTensorType elem_type;
333   if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) {
334     return failure();
335   }
336   llvm::SmallVector<int64_t, 8> buffer_shape;
337   buffer_shape.push_back(count);
338   for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim);
339   // Reshape the input to match the buffer of the tensor array.
340   auto buffer = builder
341                     .create<TF::ReshapeOp>(
342                         split.getLoc(),
343                         ArrayRef<Type>{RankedTensorType::get(
344                             buffer_shape, elem_type.getElementType())},
345                         ArrayRef<Value>{split.value(),
346                                         cutil::GetR1Const(buffer_shape, builder,
347                                                           split.getLoc())})
348                     .output();
349   // Accumulate with the old buffer.
350   auto old_buffer =
351       cutil::ReadLocalVariable(local_var, builder, split.getLoc());
352   buffer =
353       cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc());
354   cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc());
355   split.flow_out().replaceAllUsesWith(split.flow_in());
356   split.erase();
357   return success();
358 }
359 
360 LogicalResult HandleTensorArraySizeV3Op(
361     TF::TensorArraySizeV3Op size,
362     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
363   auto local_var = size.handle();
364   if (stats.count(local_var) == 0) {
365     return size.emitOpError("unknown tensor array");
366   }
367   auto buffer_type = getElementTypeOrSelf(local_var.getType())
368                          .cast<TF::ResourceType>()
369                          .getSubtypes()[0]
370                          .cast<RankedTensorType>();
371   OpBuilder builder(size);
372   auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder,
373                                          size.getLoc());
374   size.size().replaceAllUsesWith(result);
375   size.erase();
376   return success();
377 }
378 
379 LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
380                                               Operation* op, Value* var) {
381   OpBuilder builder(op);
382   *var = builder.create<TF::MlirLocalVarOp>(
383       op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
384   Value buffer;
385   auto buffer_type = getElementTypeOrSelf(local_var_type)
386                          .cast<TF::ResourceType>()
387                          .getSubtypes()[0]
388                          .cast<RankedTensorType>();
389   if (failed(cutil::CreateInitBufferValue(
390           buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
391           buffer_type.getElementType(), builder, &buffer))) {
392     return failure();
393   }
394   cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
395   return success();
396 }
397 
398 LogicalResult HandleTensorArrayGradV3Op(
399     TF::TensorArrayGradV3Op grad,
400     llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
401   auto local_var = grad.handle();
402   OpBuilder builder(grad);
403   Value grad_var;
404   auto sit = stats->find(local_var);
405   if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
406   auto emplace_res =
407       sit->getSecond().grads.try_emplace(grad.source().str(), Value());
408   if (!emplace_res.second) {
409     // If the source has been assigned a grad, use it.
410     grad_var = emplace_res.first->second;
411   } else {
412     if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
413                                                &grad_var))) {
414       return failure();
415     }
416     emplace_res.first->second = grad_var;
417     // A write to a grad accumulates with previous writes.
418     (*stats)[grad_var].accumulate_on_write = true;
419   }
420   grad.flow_out().replaceAllUsesWith(grad.flow_in());
421   grad.grad_handle().replaceAllUsesWith(grad_var);
422   grad.erase();
423   return success();
424 }
425 
426 LogicalResult HandleTensorArrayGatherV3Op(
427     TF::TensorArrayGatherV3Op gather,
428     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
429   auto local_var = gather.handle();
430   if (stats.count(local_var) == 0) {
431     return gather.emitOpError("unknown tensor array");
432   }
433   OpBuilder builder(gather);
434   auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc());
435   auto result =
436       cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc());
437   ReplaceAllUsesExceptTerminator(gather.value(), result);
438   ReplaceAllUsesWithCast(gather.value(), result);
439   gather.erase();
440   return success();
441 }
442 
443 LogicalResult HandleTensorArrayScatterV3Op(
444     TF::TensorArrayScatterV3Op scatter,
445     const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
446   auto local_var = scatter.handle();
447   if (stats.count(local_var) == 0) {
448     return scatter.emitOpError("unknown tensor array");
449   }
450   OpBuilder builder(scatter);
451   auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc());
452   buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(),
453                                             buffer, builder, scatter.getLoc());
454   cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc());
455   scatter.flow_out().replaceAllUsesWith(scatter.flow_in());
456   scatter.erase();
457   return success();
458 }
459 
460 // Updates func's type according to its current arguments and return values.
461 void UpdateFuncType(FuncOp func) {
462   llvm::SmallVector<Type, 8> arg_types;
463   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
464   func.setType(
465       FunctionType::get(func.getContext(), arg_types,
466                         func.front().getTerminator()->getOperandTypes()));
467 }
468 
469 // Finds the accessed gradient sources for each tensor array argument.
470 llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
471     ArrayRef<FuncOp> funcs, ModuleOp module) {
472   llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
473   llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
474   auto insert = [&](Value v, const string& source, const Block& func_block) {
475     auto arg = v.dyn_cast<BlockArgument>();
476     if (!arg || arg.getOwner() != &func_block) return;
477     auto insert_res = result_sets[arg.getArgNumber()].insert(source);
478     if (!insert_res.second) return;
479     result[arg.getArgNumber()].push_back(source);
480   };
481   for (FuncOp func : funcs) {
482     const Block& func_block = func.front();
483     // Walk all operations and nested regions to find accessed gradient sources
484     // for function arguments.
485     func.walk([&](Operation* op) {
486       if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
487         op->replaceAllUsesWith(op->getOperands());
488         return;
489       }
490       if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) {
491         insert(grad.handle(), grad.source().str(), func_block);
492       } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
493         for (const auto& entry : AccessedGradients(
494                  {while_op.body_function(), while_op.cond_function()}, module))
495           for (const string& source : entry.getSecond())
496             insert(while_op.getOperand(entry.getFirst()), source, func_block);
497       } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
498         for (const auto& entry : AccessedGradients(
499                  {if_op.then_function(), if_op.else_function()}, module))
500           for (const string& source : entry.getSecond())
501             insert(if_op.getOperand(entry.getFirst() + 1), source, func_block);
502       } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
503         auto callee = dyn_cast<FuncOp>(call.resolveCallable());
504         for (const auto& entry : AccessedGradients({callee}, module))
505           for (const string& source : entry.getSecond())
506             insert(call.getArgOperands()[entry.getFirst()], source, func_block);
507       }
508     });
509   }
510   return result;
511 }
512 
513 // Contains cached information for decomposed callee functions for (stateful)
514 // partitioned call ops.
515 struct PartitionedCallTensorArrayOpsInfo {
516   bool signature_change;
517   FuncOp decomposed_callee;
518   llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
519       arg_grads;
520   llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
521 };
522 
523 // Updates a called function's input signature by adjusting resource types, and
524 // adding required gradient arguments.
525 void ChangeFunctionInputSignature(
526     FuncOp func,
527     const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
528     llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
529     llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
530     llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
531   int64_t original_args = func.getNumArguments();
532   for (int64_t argnum = 0; argnum < original_args; ++argnum) {
533     auto arg = func.getArgument(argnum);
534     Type t = ta_arg_buffer_type(argnum);
535     if (!t) continue;
536     arg.setType(t);
537     auto grad_it = grads.find(argnum);
538     if (grad_it == grads.end()) continue;
539     llvm::StringMap<Value> grads_map;
540     for (const string& source : grad_it->getSecond()) {
541       auto g = func.front().addArgument(t);
542       (*stats)[g].accumulate_on_write = true;
543       grads_map[source] = g;
544     }
545     auto& stat = (*stats)[arg];
546     stat.accumulate_on_write = ta_accumulate_on_write(argnum);
547     stat.grads = std::move(grads_map);
548   }
549   UpdateFuncType(func);
550 }
551 
552 LogicalResult DecomposeTensorArrayOps(
553     Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
554     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*);
555 
556 LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
557                             llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
558                             llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
559                                 decomposed_partitioned_call_callees) {
560   auto body = while_op.body_function();
561   auto cond = while_op.cond_function();
562   auto grads = AccessedGradients({body, cond}, module);
563   auto ta_arg_buffer_type = [&](int64_t index) -> Type {
564     auto it = stats->find(while_op.getOperand(index));
565     if (it == stats->end()) return nullptr;
566     return it->getFirst().getType();
567   };
568   auto ta_accumulate_on_write = [&](int64_t index) {
569     auto it = stats->find(while_op.getOperand(index));
570     if (it == stats->end()) return false;
571     return it->getSecond().accumulate_on_write;
572   };
573   llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
574   ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
575                                ta_accumulate_on_write, &body_stats);
576   llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
577   ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
578                                ta_accumulate_on_write, &cond_stats);
579   if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
580                                      decomposed_partitioned_call_callees)) ||
581       failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
582                                      decomposed_partitioned_call_callees))) {
583     return failure();
584   }
585   if (body_stats.empty() && cond_stats.empty()) return success();
586   auto old_body_ret = body.front().getTerminator();
587   auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
588   for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
589     if (!ta_arg_buffer_type(i)) continue;
590     auto retval = old_body_ret->getOperand(i);
591     auto arg = retval.dyn_cast<BlockArgument>();
592     if (!arg) {
593       return while_op.emitOpError(
594           "output tensor array does not alias input in a while loop");
595     }
596     for (const string& source : grads[i]) {
597       new_retvals.push_back(body_stats[arg].grads[source]);
598     }
599   }
600   OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
601   old_body_ret->erase();
602   UpdateFuncType(body);
603   // Recreate the while op.
604   auto operands = llvm::to_vector<8>(while_op.getOperands());
605   for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
606     auto grad_it = grads.find(i);
607     auto& stat = (*stats)[operands[i]];
608     if (grad_it == grads.end()) continue;
609     for (const string& source : grad_it->getSecond()) {
610       auto it = stat.grads.find(source);
611       if (it != stat.grads.end()) {
612         operands.push_back(it->second);
613       } else {
614         Value grad_var;
615         if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
616                                                    while_op, &grad_var))) {
617           return failure();
618         }
619         stat.grads[source] = grad_var;
620         operands.push_back(grad_var);
621         (*stats)[grad_var].accumulate_on_write = true;
622       }
623     }
624   }
625   OpBuilder builder(while_op);
626   auto new_while =
627       builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
628                                   operands, while_op.getAttrs());
629   for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
630     if (ta_arg_buffer_type(i)) {
631       while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
632     } else {
633       while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
634     }
635   }
636   while_op.erase();
637   return success();
638 }
639 
640 LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
641                          llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
642                          llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
643                              decomposed_partitioned_call_callees) {
644   auto then_branch = if_op.then_function();
645   auto else_branch = if_op.else_function();
646   auto grads = AccessedGradients({then_branch, else_branch}, module);
647   auto ta_arg_buffer_type = [&](int64_t index) -> Type {
648     auto it = stats->find(if_op.getOperand(index + 1));
649     if (it == stats->end()) return nullptr;
650     return it->getFirst().getType();
651   };
652   auto ta_accumulate_on_write = [&](int64_t index) {
653     auto it = stats->find(if_op.getOperand(index + 1));
654     if (it == stats->end()) return false;
655     return it->getSecond().accumulate_on_write;
656   };
657   llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
658   ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
659                                ta_accumulate_on_write, &then_stats);
660   llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
661   ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
662                                ta_accumulate_on_write, &else_stats);
663   if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
664                                      decomposed_partitioned_call_callees)) ||
665       failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
666                                      decomposed_partitioned_call_callees))) {
667     return failure();
668   }
669   if (then_stats.empty() && else_stats.empty()) return success();
670   // Recreate the if op.
671   auto operands = llvm::to_vector<8>(if_op.getOperands());
672   for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
673     auto grad_it = grads.find(i);
674     auto& stat = (*stats)[operands[i + 1]];
675     if (grad_it == grads.end()) continue;
676     for (const string& source : grad_it->getSecond()) {
677       auto it = stat.grads.find(source);
678       if (it != stat.grads.end()) {
679         operands.push_back(it->second);
680       } else {
681         Value grad_var;
682         if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
683                                                    if_op, &grad_var))) {
684           return failure();
685         }
686         stat.grads[source] = grad_var;
687         operands.push_back(grad_var);
688         (*stats)[grad_var].accumulate_on_write = true;
689       }
690     }
691   }
692   OpBuilder builder(if_op);
693   auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
694                                          then_branch.getType().getResults(),
695                                          operands, if_op.getAttrs());
696   auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
697     auto retval = f.front().getTerminator()->getOperand(ret_ind);
698     auto arg = retval.dyn_cast<BlockArgument>();
699     if (!arg) return -1;
700     return arg.getArgNumber();
701   };
702   for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
703     if (!getElementTypeOrSelf(if_op.getResult(i).getType())
704              .isa<TF::ResourceType>()) {
705       if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
706       continue;
707     }
708     int64_t then_forward_input = ret_forwards_input(then_branch, i);
709     int64_t else_forward_input = ret_forwards_input(else_branch, i);
710     if (then_forward_input != else_forward_input || then_forward_input < 0) {
711       return if_op.emitOpError(
712           "branches do not forward the same input resource");
713     }
714     if_op.getResult(i).replaceAllUsesWith(
715         if_op.getOperand(then_forward_input + 1));
716   }
717   if_op.erase();
718   return success();
719 }
720 
721 template <typename CallOp>
722 LogicalResult HandlePartitionedCallOp(
723     CallOp call, FuncOp callee, ModuleOp module,
724     llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
725     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
726         decomposed_partitioned_call_callees) {
727   auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
728       callee.getName(), PartitionedCallTensorArrayOpsInfo());
729   auto& info = emplace_res.first->second;
730   // Recreates the call op with info.
731   auto recreate_caller = [&]() -> LogicalResult {
732     auto new_operands = llvm::to_vector<8>(call.getOperands());
733     for (const auto& entry : info.arg_grads) {
734       auto it = stats->find(call.getOperand(entry.first));
735       if (it == stats->end()) return call.emitOpError("unknown tensor array");
736       for (const string& source : entry.second) {
737         auto grad_it = it->getSecond().grads.find(source);
738         if (grad_it != it->getSecond().grads.end()) {
739           new_operands.push_back(grad_it->second);
740         } else {
741           Value grad_var;
742           if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
743                                                      call, &grad_var))) {
744             return failure();
745           }
746           it->getSecond().grads[source] = grad_var;
747           new_operands.push_back(grad_var);
748         }
749       }
750     }
751     OpBuilder builder(call);
752     auto new_call = builder.create<CallOp>(
753         call.getLoc(), info.decomposed_callee.getType().getResults(),
754         new_operands, call.getAttrs());
755     new_call->setAttr(
756         "f", builder.getSymbolRefAttr(
757                  const_cast<FuncOp&>(info.decomposed_callee).getName()));
758     for (const auto& entry : info.ret_forward_input) {
759       call.getResult(entry.first)
760           .replaceAllUsesWith(call.getOperand(entry.second));
761     }
762     call.replaceAllUsesWith(new_call);
763     call.erase();
764     return success();
765   };
766   if (!emplace_res.second) {
767     // This callee was handled before.
768     if (!info.signature_change) return success();
769     return recreate_caller();
770   }
771   // Rewrite the callee.
772   info.signature_change = false;
773   auto ta_arg_buffer_type = [&](int64_t index) -> Type {
774     auto it = stats->find(call.getOperand(index));
775     if (it == stats->end()) return nullptr;
776     info.signature_change = true;
777     return it->getFirst().getType();
778   };
779   auto ta_accumulate_on_write = [&](int64_t index) {
780     auto it = stats->find(call.getOperand(index));
781     if (it == stats->end()) return false;
782     return it->getSecond().accumulate_on_write;
783   };
784   FuncOp lowered_callee = callee;
785   if (!callee.isPrivate()) {
786     // Clone non-private callee in case of signature change.
787     lowered_callee = callee.clone();
788     lowered_callee.setPrivate();
789   }
790   auto grads = AccessedGradients({lowered_callee}, module);
791   for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) {
792     auto it = grads.find(i);
793     if (it == grads.end()) continue;
794     info.arg_grads.emplace_back(i, it->getSecond());
795   }
796   llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
797   ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type,
798                                ta_accumulate_on_write, &callee_stats);
799   if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module,
800                                      &callee_stats,
801                                      decomposed_partitioned_call_callees))) {
802     return failure();
803   }
804   for (int64_t i = 0; i < call.getNumResults(); ++i) {
805     auto ret = lowered_callee.front().getTerminator()->getOperand(i);
806     if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
807     auto arg = ret.dyn_cast<BlockArgument>();
808     if (!arg) continue;
809     info.ret_forward_input.emplace_back(i, arg.getArgNumber());
810   }
811 
812   info.decomposed_callee = lowered_callee;
813   if (lowered_callee != callee) {
814     if (!info.signature_change) {
815       // Signature is not modified. We do not need to keep two copies.
816       lowered_callee.setName(callee.getName());
817       callee.erase();
818     } else {
819       // Add the clone with a new name.
820       lowered_callee.setName(
821           llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str());
822     }
823     SymbolTable(module).insert(lowered_callee);
824   }
825   if (info.signature_change) return recreate_caller();
826   return success();
827 }
828 
829 LogicalResult HandleRegionControlFlowOps(
830     Operation& op, ModuleOp module,
831     llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
832     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
833         decomposed_partitioned_call_callees) {
834   for (OpOperand& operand : op.getOpOperands()) {
835     if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
836       return op.emitOpError()
837              << "found unexpected type " << operand.get().getType()
838              << " of operand #" << operand.getOperandNumber()
839              << ", resource type operands are expected to have been "
840                 "canonicalized away for region based control flow ops";
841     }
842   }
843   for (OpResult result : op.getResults()) {
844     if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
845       return op.emitOpError()
846              << "found unexpected type " << result.getType() << " of result #"
847              << result.getResultNumber()
848              << ", resource type results are expected to have been "
849                 "canonicalized away for region based control flow ops";
850     }
851   }
852 
853   for (Region& region : op.getRegions()) {
854     if (failed(DecomposeTensorArrayOps(&region.front(), module, stats,
855                                        decomposed_partitioned_call_callees)))
856       return failure();
857   }
858   return success();
859 }
860 
861 LogicalResult DecomposeTensorArrayOps(
862     Block* block, ModuleOp module,
863     llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
864     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
865         decomposed_partitioned_call_callees) {
866   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
867     if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
868       op.replaceAllUsesWith(op.getOperands());
869       op.erase();
870     } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
871       if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
872         return failure();
873       }
874     } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
875       if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
876     } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
877       if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
878     } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
879       if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
880     } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
881       if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
882     } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
883       if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
884     } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
885       if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
886     } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
887       if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
888     } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
889       if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
890         return failure();
891       }
892     } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
893       close.erase();
894     } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
895       if (failed(HandleWhileOp(while_op, module, stats,
896                                decomposed_partitioned_call_callees))) {
897         return failure();
898       }
899     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
900       if (failed(HandleIfOp(if_op, module, stats,
901                             decomposed_partitioned_call_callees))) {
902         return failure();
903       }
904     } else if (llvm::isa<TF::CaseRegionOp>(op) ||
905                llvm::isa<TF::IfRegionOp>(op) ||
906                llvm::isa<TF::WhileRegionOp>(op)) {
907       if (failed(HandleRegionControlFlowOps(
908               op, module, stats, decomposed_partitioned_call_callees)))
909         return failure();
910     } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
911       auto callee = pcall.func();
912       if (!callee)
913         return pcall.emitOpError(
914             "TensorArray decomposition does not support call with nested "
915             "references.");
916 
917       if (failed(
918               HandlePartitionedCallOp(pcall, callee, module, stats,
919                                       decomposed_partitioned_call_callees))) {
920         return failure();
921       }
922     } else if (auto spcall =
923                    llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
924       if (failed(
925               HandlePartitionedCallOp(spcall, spcall.func(), module, stats,
926                                       decomposed_partitioned_call_callees))) {
927         return failure();
928       }
929     }
930   }
931   return success();
932 }
933 
934 void TensorArrayOpsDecompositionPass::runOnOperation() {
935   auto module = getOperation();
936   auto main = module.lookupSymbol<FuncOp>("main");
937   if (!main) return;
938   llvm::SmallDenseMap<Value, TensorArrayStats> stats;
939   llvm::StringMap<PartitionedCallTensorArrayOpsInfo>
940       decomposed_partitioned_call_callees;
941   if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
942                                      &decomposed_partitioned_call_callees))) {
943     signalPassFailure();
944   }
945 }
946 
947 static PassRegistration<TensorArrayOpsDecompositionPass> pass(
948     "tf-tensor-array-ops-decomposition",
949     "Decompose tensor array operations into local variable operations.");
950 
951 }  // namespace
952 
953 namespace TF {
954 std::unique_ptr<OperationPass<ModuleOp>>
955 CreateTensorArrayOpsDecompositionPass() {
956   return std::make_unique<TensorArrayOpsDecompositionPass>();
957 }
958 
959 }  // namespace TF
960 }  // namespace mlir
961