/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {
namespace {

namespace cutil = TF::collection_ops_util;

using std::string;

// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can further
// remove the local variables.
//
// This pass requires that the full shape of the tensor array can be inferred:
// 1) the size needs to be a constant, 2) it either specifies the full element
// shape, or the shape can be inferred from a later write, and 3) all elements
// have the same shape.
//
struct TensorArrayOpsDecompositionPass
    : public TF::TensorArrayOpsDecompositionPassBase<
          TensorArrayOpsDecompositionPass> {
  void runOnOperation() override;
};

// Infers the element type and count for a TensorArraySplitV3Op. Requires
// constant lengths and static shape on the input value.
LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
                                          RankedTensorType* elem_type,
                                          int64_t* count) {
  auto lengths_const =
      llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp());
  if (!lengths_const) return split.emitOpError("non-constant split lengths");
  *count = lengths_const.value().getNumElements();
  if (*count <= 0) return split.emitOpError("non-positive split count");
  auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>();
  if (!buffer_type || !buffer_type.hasStaticShape() ||
      buffer_type.getRank() < 1) {
    return split.emitOpError("unknown or invalid split tensor shape");
  }
  int64_t length = buffer_type.getDimSize(0) / *count;
  for (const auto& len : lengths_const.value().getValues<APInt>()) {
    if (length == len.getSExtValue()) continue;
    return split.emitOpError("different split lengths are not supported");
  }
  llvm::SmallVector<int64_t, 8> elem_shape;
  elem_shape.push_back(length);
  for (int64_t dim : buffer_type.getShape().drop_front()) {
    elem_shape.push_back(dim);
  }
  *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType());
  return success();
}

// Tries to infer the tensor array element shape.
llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape(
    TF::TensorArrayV3Op ta, ModuleOp module) {
  auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>();
  if (element_shape.hasStaticShape()) {
    auto shape = element_shape.getShape();
    // Convert int64 to int64_t.
    llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end());
    return dims;
  }

  bool has_failure = false;
  auto elem_type = cutil::GetElementTypeFromAccess(
      ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> {
        if (has_failure) return llvm::None;
        if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) {
          return write.value().getType();
        } else if (auto split =
                       llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) {
          if (!split.lengths().getDefiningOp() ||
              !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) {
            return llvm::None;
          }
          RankedTensorType t;
          int64_t count;
          if (failed(GetSplitElementTypeAndCount(split, &t, &count))) {
            has_failure = true;
            return llvm::None;
          }
          return t;
        } else if (auto scatter =
                       llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) {
          // TensorArrayScatter writes a vector of tensors to the TensorArray.
          // We can deduce the element shape of the TensorArray by dropping the
          // 0th dim of TensorArrayScatter's `value`.
          auto t = scatter.value().getType().dyn_cast<RankedTensorType>();
          if (!t || t.getShape().empty()) return llvm::None;
          return RankedTensorType::get(t.getShape().drop_front(),
                                       t.getElementType());
        } else if (auto gather =
                       llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) {
          // Try to infer from result type of gather.
          auto t = gather.value().getType().dyn_cast<RankedTensorType>();
          if (t && !t.getShape().empty())
            return RankedTensorType::get(t.getShape().drop_front(),
                                         t.getElementType());
          // Try to infer from `element_shape` attribute of gather.
          auto element_shape = gather.element_shapeAttr()
                                   .dyn_cast_or_null<mlir::TF::ShapeAttr>();
          if (element_shape && element_shape.hasStaticShape()) {
            return RankedTensorType::get(element_shape.getShape(),
                                         gather.dtype());
          }
        }
        return llvm::None;
      });
  if (!elem_type) return llvm::None;
  return llvm::to_vector<8>(elem_type->getShape());
}

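// Replaces all uses of `old_val` with a tensor.cast of `new_val` back to
// `old_val`'s type. No-op if `old_val` has no uses.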
void ReplaceAllUsesWithCast(Value old_val, Value new_val) {
  if (old_val.use_empty()) return;
  auto cast_op =
      OpBuilder(old_val.getDefiningOp())
          .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val);
  old_val.replaceAllUsesWith(cast_op);
}

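// Replaces all uses of `old_val` with `new_val`; if the types differ, uses in
// the enclosing function's terminator are left untouched so that they can be
// patched up with a cast separately.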
void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) {
  if (old_val.getType() == new_val.getType()) {
    old_val.replaceAllUsesWith(new_val);
    return;
  }
  Operation* old_op = old_val.getDefiningOp();
  Operation* terminator_op =
      old_op->getParentOfType<FuncOp>().front().getTerminator();
  llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op};
  old_val.replaceAllUsesExcept(new_val, exceptions);
}

struct TensorArrayStats {
  // Whether a write op should accumulate with the old value. Set to true if
  // this is a gradient.
  bool accumulate_on_write;
  // Maps from a gradient source string to the local variable for the gradient.
  llvm::StringMap<Value> grads;
};

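// Replaces a TensorArrayV3 op with a local variable holding a statically
// shaped buffer, and replaces its flow output with a scalar constant.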
LogicalResult HandleTensorArrayV3Op(
    TF::TensorArrayV3Op ta, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto elem_shape = GetTensorArrayElementShape(ta, module);
  if (!elem_shape) return ta.emitOpError("unknown element shape");
  if (ta.dynamic_size()) {
    return ta.emitOpError("dynamic tensor array size is unsupported");
  }
  Value buffer;
  OpBuilder builder(ta);
  if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta,
                                          ta.dtype(), builder, &buffer))) {
    return failure();
  }
  auto var_type = RankedTensorType::get(
      {}, TF::ResourceType::get(
              ArrayRef<TensorType>{buffer.getType().cast<TensorType>()},
              ta.getContext()));
  auto local_var = builder.create<TF::MlirLocalVarOp>(
      ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{});
  cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc());
  ta.handle().replaceAllUsesWith(local_var);
  // The flow output is just a way for the front end to enforce ordering among
  // tensor array ops, but in the MLIR TF dialect they have sequential ordering.
  // Just create a constant to replace its uses.
  tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {});
  scalar_tensor.scalar<float>()() = 0.0f;
  auto flow = builder.create<TF::ConstOp>(
      ta.getLoc(),
      tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
  ta.flow().replaceAllUsesWith(flow);
  ta.erase();
  (*stats)[local_var].accumulate_on_write = false;
  return success();
}

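// Lowers a TensorArrayReadV3 op to a read of the local variable's buffer
// followed by a slice at the requested index.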
LogicalResult HandleTensorArrayReadV3Op(
    TF::TensorArrayReadV3Op read,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = read.handle();
  if (stats.count(local_var) == 0) {
    return read.emitOpError("unknown tensor array");
  }
  OpBuilder builder(read);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc());
  auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc());
  ReplaceAllUsesExceptTerminator(read.value(), elem);
  ReplaceAllUsesWithCast(read.value(), elem);
  read.erase();
  // The clear_after_read attribute does not mean setting the tensor to 0 after
  // read; instead it does not allow a second read before the next write. We
  // follow the old bridge's implementation not to do anything here.
  return success();
}

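// Lowers a TensorArrayWriteV3 op to an update of the local variable's buffer
// at the given index, accumulating with the old element for gradient arrays.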
LogicalResult HandleTensorArrayWriteV3Op(
    TF::TensorArrayWriteV3Op write,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = write.handle();
  auto stat_it = stats.find(local_var);
  if (stat_it == stats.end()) return write.emitOpError("unknown tensor array");
  OpBuilder builder(write);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc());
  auto index_reshape =
      cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc());
  auto elem = write.value();
  if (stat_it->getSecond().accumulate_on_write) {
    // Get the old slice, and accumulate with it. We set keep_slice_shape
    // (keeping the leading size-1 dimension) because it avoids reshape back and
    // forth.
    auto original_elem =
        cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
                          /*keep_slice_shape=*/true);
    // Add a size-1 leading dimension to elem.
    auto slice_type = original_elem.getType().cast<RankedTensorType>();
    elem = builder.create<TF::ReshapeOp>(
        write.getLoc(), ArrayRef<Type>{slice_type},
        ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
                                                write.getLoc())});
    elem =
        cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc());
  }
  buffer =
      cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc());
  write.flow_out().replaceAllUsesWith(write.flow_in());
  write.erase();
  return success();
}

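// Lowers a TensorArrayConcatV3 op to a reshape of the local variable's buffer
// that merges the leading two dimensions, plus a constant lengths output.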
LogicalResult HandleTensorArrayConcatV3Op(
    TF::TensorArrayConcatV3Op concat,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = concat.handle();
  if (stats.count(local_var) == 0) {
    return concat.emitOpError("unknown tensor array");
  }
  OpBuilder builder(concat);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc());
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  if (buffer_type.getShape().size() <= 1) {
    return concat.emitOpError("cannot concat on scalar-element tensor array");
  }
  // Merge the first two dimensions.
  auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front());
  shape[0] *= buffer_type.getDimSize(0);
  buffer = builder.create<TF::ReshapeOp>(
      concat.getLoc(),
      ArrayRef<Type>{
          RankedTensorType::get(shape, buffer_type.getElementType())},
      ArrayRef<Value>{buffer,
                      cutil::GetR1Const(shape, builder, concat.getLoc())});
  ReplaceAllUsesExceptTerminator(concat.value(), buffer);
  ReplaceAllUsesWithCast(concat.value(), buffer);

  // Create the lengths as a list of the same value (element size).
  tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64,
                                    {buffer_type.getDimSize(0)});
  for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) {
    lengths_tensor.vec<tensorflow::int64>()(i) = buffer_type.getDimSize(1);
  }
  concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>(
      concat.getLoc(),
      tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie()));
  concat.erase();
  return success();
}

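// Lowers a TensorArraySplitV3 op to a reshape of its input into the buffer
// shape, accumulated into the local variable's current buffer.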
LogicalResult HandleTensorArraySplitV3Op(
    TF::TensorArraySplitV3Op split,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = split.handle();
  if (stats.count(local_var) == 0) {
    return split.emitOpError("unknown tensor array");
  }
  OpBuilder builder(split);
  int64_t count;
  RankedTensorType elem_type;
  if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) {
    return failure();
  }
  llvm::SmallVector<int64_t, 8> buffer_shape;
  buffer_shape.push_back(count);
  for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim);
  // Reshape the input to match the buffer of the tensor array.
  auto buffer = builder
                    .create<TF::ReshapeOp>(
                        split.getLoc(),
                        ArrayRef<Type>{RankedTensorType::get(
                            buffer_shape, elem_type.getElementType())},
                        ArrayRef<Value>{split.value(),
                                        cutil::GetR1Const(buffer_shape, builder,
                                                          split.getLoc())})
                    .output();
  // Accumulate with the old buffer.
  auto old_buffer =
      cutil::ReadLocalVariable(local_var, builder, split.getLoc());
  buffer =
      cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc());
  split.flow_out().replaceAllUsesWith(split.flow_in());
  split.erase();
  return success();
}

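// Replaces a TensorArraySizeV3 op with a scalar constant derived from the
// leading dimension of the buffer type.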
LogicalResult HandleTensorArraySizeV3Op(
    TF::TensorArraySizeV3Op size,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = size.handle();
  if (stats.count(local_var) == 0) {
    return size.emitOpError("unknown tensor array");
  }
  auto buffer_type = getElementTypeOrSelf(local_var.getType())
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  OpBuilder builder(size);
  auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder,
                                         size.getLoc());
  size.size().replaceAllUsesWith(result);
  size.erase();
  return success();
}

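// Creates a local variable of `local_var_type` backed by a freshly initialized
// buffer, for use as a gradient tensor array.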
LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
                                              Operation* op, Value* var) {
  OpBuilder builder(op);
  *var = builder.create<TF::MlirLocalVarOp>(
      op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{});
  Value buffer;
  auto buffer_type = getElementTypeOrSelf(local_var_type)
                         .cast<TF::ResourceType>()
                         .getSubtypes()[0]
                         .cast<RankedTensorType>();
  if (failed(cutil::CreateInitBufferValue(
          buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
          buffer_type.getElementType(), builder, &buffer))) {
    return failure();
  }
  cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
  return success();
}

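// Lowers a TensorArrayGradV3 op by creating (or reusing) a gradient local
// variable keyed by the gradient source string.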
LogicalResult HandleTensorArrayGradV3Op(
    TF::TensorArrayGradV3Op grad,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  auto local_var = grad.handle();
  OpBuilder builder(grad);
  Value grad_var;
  auto sit = stats->find(local_var);
  if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
  auto emplace_res =
      sit->getSecond().grads.try_emplace(grad.source().str(), Value());
  if (!emplace_res.second) {
    // If the source has been assigned a grad, use it.
    grad_var = emplace_res.first->second;
  } else {
    if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
                                               &grad_var))) {
      return failure();
    }
    emplace_res.first->second = grad_var;
    // Writes to a gradient accumulate with previous writes.
    (*stats)[grad_var].accumulate_on_write = true;
  }
  grad.flow_out().replaceAllUsesWith(grad.flow_in());
  grad.grad_handle().replaceAllUsesWith(grad_var);
  grad.erase();
  return success();
}

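// Lowers a TensorArrayGatherV3 op to a gather on the local variable's buffer.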
LogicalResult HandleTensorArrayGatherV3Op(
    TF::TensorArrayGatherV3Op gather,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = gather.handle();
  if (stats.count(local_var) == 0) {
    return gather.emitOpError("unknown tensor array");
  }
  OpBuilder builder(gather);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc());
  auto result =
      cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc());
  ReplaceAllUsesExceptTerminator(gather.value(), result);
  ReplaceAllUsesWithCast(gather.value(), result);
  gather.erase();
  return success();
}

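// Lowers a TensorArrayScatterV3 op to a scatter-accumulate into the local
// variable's buffer.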
LogicalResult HandleTensorArrayScatterV3Op(
    TF::TensorArrayScatterV3Op scatter,
    const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) {
  auto local_var = scatter.handle();
  if (stats.count(local_var) == 0) {
    return scatter.emitOpError("unknown tensor array");
  }
  OpBuilder builder(scatter);
  auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc());
  buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(),
                                            buffer, builder, scatter.getLoc());
  cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc());
  scatter.flow_out().replaceAllUsesWith(scatter.flow_in());
  scatter.erase();
  return success();
}

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Finds the accessed gradient sources for each tensor array argument.
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
    ArrayRef<FuncOp> funcs, ModuleOp module) {
  llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
  llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
  auto insert = [&](Value v, const string& source, const Block& func_block) {
    auto arg = v.dyn_cast<BlockArgument>();
    if (!arg || arg.getOwner() != &func_block) return;
    auto insert_res = result_sets[arg.getArgNumber()].insert(source);
    if (!insert_res.second) return;
    result[arg.getArgNumber()].push_back(source);
  };
  for (FuncOp func : funcs) {
    const Block& func_block = func.front();
    // Walk all operations and nested regions to find accessed gradient sources
    // for function arguments.
    func.walk([&](Operation* op) {
      if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
        op->replaceAllUsesWith(op->getOperands());
        return;
      }
      if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) {
        insert(grad.handle(), grad.source().str(), func_block);
      } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {while_op.body_function(), while_op.cond_function()}, module))
          for (const string& source : entry.getSecond())
            insert(while_op.getOperand(entry.getFirst()), source, func_block);
      } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
        for (const auto& entry : AccessedGradients(
                 {if_op.then_function(), if_op.else_function()}, module))
          for (const string& source : entry.getSecond())
            insert(if_op.getOperand(entry.getFirst() + 1), source, func_block);
      } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) {
        auto callee = dyn_cast<FuncOp>(call.resolveCallable());
        for (const auto& entry : AccessedGradients({callee}, module))
          for (const string& source : entry.getSecond())
            insert(call.getArgOperands()[entry.getFirst()], source, func_block);
      }
    });
  }
  return result;
}

// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
  bool signature_change;
  FuncOp decomposed_callee;
  llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
      arg_grads;
  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};

// Updates a called function's input signature by adjusting resource types, and
// adding required gradient arguments.
void ChangeFunctionInputSignature(
    FuncOp func,
    const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
    llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
    llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
  int64_t original_args = func.getNumArguments();
  for (int64_t argnum = 0; argnum < original_args; ++argnum) {
    auto arg = func.getArgument(argnum);
    Type t = ta_arg_buffer_type(argnum);
    if (!t) continue;
    arg.setType(t);
    auto grad_it = grads.find(argnum);
    if (grad_it == grads.end()) continue;
    llvm::StringMap<Value> grads_map;
    for (const string& source : grad_it->getSecond()) {
      auto g = func.front().addArgument(t);
      (*stats)[g].accumulate_on_write = true;
      grads_map[source] = g;
    }
    auto& stat = (*stats)[arg];
    stat.accumulate_on_write = ta_accumulate_on_write(argnum);
    stat.grads = std::move(grads_map);
  }
  UpdateFuncType(func);
}

LogicalResult DecomposeTensorArrayOps(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*);

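// Rewrites a while op whose operands carry decomposed tensor arrays: updates
// the body and cond signatures, threads gradient variables through as extra
// operands, and recreates the while op with the new operand list.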
LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
                            llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                            llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                                decomposed_partitioned_call_callees) {
  auto body = while_op.body_function();
  auto cond = while_op.cond_function();
  auto grads = AccessedGradients({body, cond}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(while_op.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
  ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &body_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
  ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &cond_stats);
  if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (body_stats.empty() && cond_stats.empty()) return success();
  auto old_body_ret = body.front().getTerminator();
  auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    if (!ta_arg_buffer_type(i)) continue;
    auto retval = old_body_ret->getOperand(i);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) {
      return while_op.emitOpError(
          "output tensor array does not alias input in a while loop");
    }
    for (const string& source : grads[i]) {
      new_retvals.push_back(body_stats[arg].grads[source]);
    }
  }
  OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
  old_body_ret->erase();
  UpdateFuncType(body);
  // Recreate the while op.
  auto operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
                                                   while_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(while_op);
  auto new_while =
      builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
                                  operands, while_op->getAttrs());
  for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
    if (ta_arg_buffer_type(i)) {
      while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
    } else {
      while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
    }
  }
  while_op.erase();
  return success();
}

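// Rewrites an if op whose operands carry decomposed tensor arrays: updates the
// branch signatures, threads gradient variables through as extra operands, and
// forwards resource results to the corresponding forwarded inputs.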
LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
                         llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
                         llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
                             decomposed_partitioned_call_callees) {
  auto then_branch = if_op.then_function();
  auto else_branch = if_op.else_function();
  auto grads = AccessedGradients({then_branch, else_branch}, module);
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return nullptr;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(if_op.getOperand(index + 1));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
  ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &then_stats);
  llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
  ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &else_stats);
  if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
                                     decomposed_partitioned_call_callees)) ||
      failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (then_stats.empty() && else_stats.empty()) return success();
  // Recreate the if op.
  auto operands = llvm::to_vector<8>(if_op.getOperands());
  for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
    auto grad_it = grads.find(i);
    auto& stat = (*stats)[operands[i + 1]];
    if (grad_it == grads.end()) continue;
    for (const string& source : grad_it->getSecond()) {
      auto it = stat.grads.find(source);
      if (it != stat.grads.end()) {
        operands.push_back(it->second);
      } else {
        Value grad_var;
        if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
                                                   if_op, &grad_var))) {
          return failure();
        }
        stat.grads[source] = grad_var;
        operands.push_back(grad_var);
        (*stats)[grad_var].accumulate_on_write = true;
      }
    }
  }
  OpBuilder builder(if_op);
  auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
                                         then_branch.getType().getResults(),
                                         operands, if_op->getAttrs());
  auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
    auto retval = f.front().getTerminator()->getOperand(ret_ind);
    auto arg = retval.dyn_cast<BlockArgument>();
    if (!arg) return -1;
    return arg.getArgNumber();
  };
  for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
    if (!getElementTypeOrSelf(if_op.getResult(i).getType())
             .isa<TF::ResourceType>()) {
      if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
      continue;
    }
    int64_t then_forward_input = ret_forwards_input(then_branch, i);
    int64_t else_forward_input = ret_forwards_input(else_branch, i);
    if (then_forward_input != else_forward_input || then_forward_input < 0) {
      return if_op.emitOpError(
          "branches do not forward the same input resource");
    }
    if_op.getResult(i).replaceAllUsesWith(
        if_op.getOperand(then_forward_input + 1));
  }
  if_op.erase();
  return success();
}

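// Rewrites a (stateful) partitioned call op whose callee accesses tensor
// arrays. The callee is decomposed (cloned first if it is public), the result
// is cached in `decomposed_partitioned_call_callees`, and the caller is
// recreated if the callee's signature changed.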
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallTensorArrayOpsInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op using the cached decomposition info.
  auto recreate_caller = [&]() -> LogicalResult {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (const auto& entry : info.arg_grads) {
      auto it = stats->find(call.getOperand(entry.first));
      if (it == stats->end()) return call.emitOpError("unknown tensor array");
      for (const string& source : entry.second) {
        auto grad_it = it->getSecond().grads.find(source);
        if (grad_it != it->getSecond().grads.end()) {
          new_operands.push_back(grad_it->second);
        } else {
          Value grad_var;
          if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
                                                     call, &grad_var))) {
            return failure();
          }
          it->getSecond().grads[source] = grad_var;
          new_operands.push_back(grad_var);
        }
      }
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", builder.getSymbolRefAttr(
                 const_cast<FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.ret_forward_input) {
      call.getResult(entry.first)
          .replaceAllUsesWith(call.getOperand(entry.second));
    }
    call.replaceAllUsesWith(new_call);
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  info.signature_change = false;
  auto ta_arg_buffer_type = [&](int64_t index) -> Type {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return nullptr;
    info.signature_change = true;
    return it->getFirst().getType();
  };
  auto ta_accumulate_on_write = [&](int64_t index) {
    auto it = stats->find(call.getOperand(index));
    if (it == stats->end()) return false;
    return it->getSecond().accumulate_on_write;
  };
  FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto grads = AccessedGradients({lowered_callee}, module);
  for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) {
    auto it = grads.find(i);
    if (it == grads.end()) continue;
    info.arg_grads.emplace_back(i, it->getSecond());
  }
  llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
  ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type,
                               ta_accumulate_on_write, &callee_stats);
  if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module,
                                     &callee_stats,
                                     decomposed_partitioned_call_callees))) {
    return failure();
  }
  for (int64_t i = 0; i < call.getNumResults(); ++i) {
    auto ret = lowered_callee.front().getTerminator()->getOperand(i);
    if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
    auto arg = ret.dyn_cast<BlockArgument>();
    if (!arg) continue;
    info.ret_forward_input.emplace_back(i, arg.getArgNumber());
  }

  info.decomposed_callee = lowered_callee;
  if (lowered_callee != callee) {
    if (!info.signature_change) {
      // Signature is not modified. We do not need to keep two copies.
      lowered_callee.setName(callee.getName());
      callee.erase();
    } else {
      // Add the clone with a new name.
      lowered_callee.setName(
          llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str());
    }
    SymbolTable(module).insert(lowered_callee);
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

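// Checks that a region-based control flow op has no resource-typed operands or
// results (they should have been canonicalized away), then decomposes tensor
// array ops inside its regions.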
LogicalResult HandleRegionControlFlowOps(
    Operation& op, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (OpOperand& operand : op.getOpOperands()) {
    if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << operand.get().getType()
             << " of operand #" << operand.getOperandNumber()
             << ", resource type operands are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }
  for (OpResult result : op.getResults()) {
    if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
      return op.emitOpError()
             << "found unexpected type " << result.getType() << " of result #"
             << result.getResultNumber()
             << ", resource type results are expected to have been "
                "canonicalized away for region based control flow ops";
    }
  }

  for (Region& region : op.getRegions()) {
    if (failed(DecomposeTensorArrayOps(&region.front(), module, stats,
                                       decomposed_partitioned_call_callees)))
      return failure();
  }
  return success();
}

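// Decomposes all tensor array ops in `block`, dispatching to the handlers
// above and recursing into control flow and call ops.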
LogicalResult DecomposeTensorArrayOps(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
      if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
        return failure();
      }
    } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
      if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
    } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
      if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
    } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
      if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
    } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
      if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
    } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
      if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
    } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
      if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
    } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
      if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
    } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
      if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
        return failure();
      }
    } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
      close.erase();
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, stats,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleIfOp(if_op, module, stats,
                            decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (llvm::isa<TF::CaseRegionOp>(op) ||
               llvm::isa<TF::IfRegionOp>(op) ||
               llvm::isa<TF::WhileRegionOp>(op)) {
      if (failed(HandleRegionControlFlowOps(
              op, module, stats, decomposed_partitioned_call_callees)))
        return failure();
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      auto callee = pcall.func();
      if (!callee)
        return pcall.emitOpError(
            "TensorArray decomposition does not support call with nested "
            "references.");

      if (failed(
              HandlePartitionedCallOp(pcall, callee, module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(
              HandlePartitionedCallOp(spcall, spcall.func(), module, stats,
                                      decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

void TensorArrayOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<FuncOp>("main");
  if (!main) return;
  llvm::SmallDenseMap<Value, TensorArrayStats> stats;
  llvm::StringMap<PartitionedCallTensorArrayOpsInfo>
      decomposed_partitioned_call_callees;
  if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
                                     &decomposed_partitioned_call_callees))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {

std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorArrayOpsDecompositionPass() {
  return std::make_unique<TensorArrayOpsDecompositionPass>();
}

}  // namespace TF
}  // namespace mlir