• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <iostream>
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/Location.h"  // from @llvm-project
30 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
33 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
34 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
35 #include "mlir/IR/Types.h"  // from @llvm-project
36 #include "mlir/IR/Value.h"  // from @llvm-project
37 #include "mlir/Pass/Pass.h"  // from @llvm-project
38 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
44 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
47 #include "tensorflow/core/framework/tensor_shape.pb.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 
50 namespace mlir {
51 namespace TFTPU {
52 
53 namespace {
54 
55 constexpr char kDeviceAttr[] = "device";
56 typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;
57 
// Describes a block argument that feeds a transformable convolution.
struct BlockArgumentInfo {
  unsigned arg_num;    // Index of the block argument in its block.
  unsigned num_users;  // Number of users of the block argument.
};
62 
// Module pass that applies the space-to-depth transform to the first
// convolution of a TPU training loop (input and filter rewrites on the
// device function, SpaceToDepth insertion on the host inputs).
// TODO(wangtao): add a pass to check if it is profitable to space to depth
// transform and invoke the transform if it is needed.
struct TPUSpaceToDepthPass
    : public TF::TPUSpaceToDepthPassBase<TPUSpaceToDepthPass> {
  void runOnOperation() override;
};
69 
70 // Updates func argument type to have the updated input shape.
UpdateFuncType(FuncOp func)71 void UpdateFuncType(FuncOp func) {
72   auto arg_types = func.front().getArgumentTypes();
73   auto result_types = func.front().getTerminator()->getOperandTypes();
74   func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
75 }
76 
HandleFuncOp(Operation * op)77 void HandleFuncOp(Operation* op) {
78   auto func = llvm::cast<FuncOp>(op);
79   UpdateFuncType(func);
80 }
81 
82 // Handles cast op between the first convolution and the block argument.
HandleCast(TF::CastOp cast_op,ArrayRef<int64_t> new_shape)83 LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
84   auto cast_input = cast_op.x();
85   // Update input type.
86   auto transform_result_type =
87       RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
88   cast_input.setType(transform_result_type);
89   auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
90   auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
91   while (block_arg || cast_op_input) {
92     if (block_arg) {
93       // Change on device function type/shape.
94       HandleFuncOp(block_arg.getOwner()->getParentOp());
95       block_arg = nullptr;
96       cast_op_input = nullptr;
97     } else {
98       auto cast_input = cast_op_input.x();
99       // Update input type.
100       auto transform_result_type =
101           RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
102       cast_input.setType(transform_result_type);
103       // Update block arg and cast_op_input.
104       block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
105       cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
106     }
107   }
108   return success();
109 }
110 
111 // Handles padding before convolution for space to depth transform.
HandlePad(TF::PadOp op,int32_t kernel_size,int32_t block_size)112 LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
113   auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
114   if (!ranked_type) return failure();
115   auto pad_input_shape = ranked_type.getShape();
116   Location loc = op.getLoc();
117   OpBuilder builder(op);
118   builder.setInsertionPoint(op);
119   auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));
120 
121   // Calculate paddings.
122   int32_t pad_total = kernel_size - 1;
123   int32_t pad_beg = (pad_total / 2 + 1) / block_size;
124   int32_t pad_end = (pad_total / 2) / block_size;
125   SmallVector<int32_t, 8> values = {0,       0,       pad_beg, pad_end,
126                                     pad_beg, pad_end, 0,       0};
127   auto paddings = DenseIntElementsAttr::get(padding_type, values);
128   // Update pad_op paddings.
129   op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));
130 
131   // Set input type.
132   auto input = op.getOperand(0);
133   SmallVector<int64_t, 4> transform_shape = {
134       pad_input_shape[0], pad_input_shape[1] / block_size,
135       pad_input_shape[2] / block_size,
136       pad_input_shape[3] * block_size * block_size};
137   // Input of the pad op could be a cast op.
138   if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
139     if (failed(HandleCast(cast_op, transform_shape))) return failure();
140 
141   auto transform_result_type =
142       RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
143   input.setType(transform_result_type);
144   op.setOperand(0, input);
145   return success();
146 }
147 
148 // Handles stride for the first convolution for the transform.
HandleConv2DStride(TF::Conv2DOp conv2d)149 void HandleConv2DStride(TF::Conv2DOp conv2d) {
150   MLIRContext* context = conv2d.getContext();
151   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
152   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
153     return IntegerAttr::get(IntegerType::get(context, 64), v);
154   });
155   // TODO(b/157276506): change type of strides to DenseElementsAttr
156   auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
157   conv2d->setAttr("strides", strides);
158 }
159 
160 // Transforms input shape for the first convolution.
HandleConv2DInput(TF::Conv2DOp conv2d,int64_t block_size)161 void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
162   auto input = conv2d.input();
163   auto input_shape = input.getType().cast<RankedTensorType>().getShape();
164   SmallVector<int64_t, 4> transform_shape = {
165       input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
166       input_shape[3] * block_size * block_size};
167   auto transform_result_type =
168       RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
169   input.setType(transform_result_type);
170 }
171 
172 // Adds padding for convolution filter for space to depth transform.
GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape,Value filter,OpBuilder * builder,int32_t pad_h,int32_t pad_w)173 TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
174                                   OpBuilder* builder, int32_t pad_h,
175                                   int32_t pad_w) {
176   SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
177   auto padding_type =
178       RankedTensorType::get({4, 2}, builder->getIntegerType(32));
179   auto paddings = DenseIntElementsAttr::get(padding_type, values);
180   auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
181   std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
182                                     filter_shape[1] + pad_w, filter_shape[2],
183                                     filter_shape[3]};
184   SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());
185 
186   auto expand_result_type =
187       RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
188   return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
189                                     paddings_value);
190 }
191 
192 // Creates reshape op for space to depth transform.
GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,Value input,OpBuilder * builder)193 TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
194                                           Value input, OpBuilder* builder) {
195   auto reshape_result_type =
196       RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
197   auto reshape_type = RankedTensorType::get(
198       {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
199   auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
200   auto reshape_value =
201       builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
202   return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
203                                         input, reshape_value);
204 }
205 
206 // Creates transpose op for shape to depth transform.
GetTransposeOpForConv2DFilter(OpBuilder * builder,Value input)207 TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
208   SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
209   auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
210   auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
211   auto permute_value =
212       builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
213   return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
214 }
215 
HandleConv2DFilter(TF::Conv2DOp conv2d,int64_t block_size)216 void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
217   // For example, if filter shape is [7, 7, 3, 64] with block_size 2,
218   // will apply below transforms to the filter:
219   // 1. Pad the filter to [8, 8, 3, 64]
220   // 2. Reshape to [4, 2, 4, 2, 3, 64]
221   // 3. Transpose to [4, 4, 2, 2, 3, 64]
222   // 4. Reshape to [4, 4, 12, 64]
223   auto filter = conv2d.filter();
224   OpBuilder builder(conv2d);
225   builder.setInsertionPoint(conv2d);
226   // Book keeping filter information.
227   auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
228   int64_t height = filter_shape[0];
229   int64_t width = filter_shape[1];
230   int64_t channel = filter_shape[2];
231   int64_t out_channel = filter_shape[3];
232   // Value/Op before reshape op.
233   Value before_reshape_value = filter;
234   if (height % block_size != 0 || width % block_size != 0) {
235     // Calculate paddings for height and width.
236     int32_t pad_h = block_size - height % block_size;
237     int32_t pad_w = block_size - width % block_size;
238     auto pad_op =
239         GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
240     // Update op, height and width before reshape.
241     before_reshape_value = pad_op;
242     height = height + pad_h;
243     width = width + pad_w;
244   }
245 
246   // Reshape.
247   SmallVector<int64_t, 6> new_shape = {
248       height / block_size, block_size, width / block_size,
249       block_size,          channel,    out_channel};
250   auto reshape_op =
251       GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);
252 
253   // Transpose.
254   auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);
255 
256   // Reshape Back.
257   SmallVector<int64_t, 4> final_shape = {
258       height / block_size, width / block_size,
259       channel * block_size * block_size, out_channel};
260   auto final_reshape_op =
261       GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
262   // Update filter of Conv2D.
263   conv2d.setOperand(1, final_reshape_op);
264 }
265 
266 // Creates slice op for filter in back prop pass.
GetSliceOpForConv2DBackPropFilter(ArrayRef<int32_t> old_filter_shape,Value input,OpBuilder * builder)267 TF::SliceOp GetSliceOpForConv2DBackPropFilter(
268     ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
269   SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
270                                      old_filter_shape.end());
271   auto slice_result_type =
272       RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
273   auto slice_size_op = builder->create<TF::ConstOp>(
274       input.getLoc(),
275       DenseIntElementsAttr::get(
276           RankedTensorType::get({4}, builder->getIntegerType(32)),
277           old_filter_shape));
278   SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
279   auto start_position_type =
280       RankedTensorType::get({4}, builder->getIntegerType(64));
281   auto start_position = builder->create<TF::ConstOp>(
282       input.getLoc(),
283       DenseIntElementsAttr::get(start_position_type, slice_start_position));
284   return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
285                                       start_position, slice_size_op);
286 }
287 
288 // Transforms Conv2DBackPropFilter for space to depth.
HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,ArrayRef<int32_t> old_filter_shape,ArrayRef<int32_t> new_filter_shape,int64_t block_size)289 void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
290                                 ArrayRef<int32_t> old_filter_shape,
291                                 ArrayRef<int32_t> new_filter_shape,
292                                 int64_t block_size) {
293   OpBuilder builder(backprop);
294   builder.setInsertionPoint(backprop);
295 
296   auto input = backprop.input();
297   // Get new filter size from new_filter_shape.
298   auto new_filter_sizes = builder.create<TF::ConstOp>(
299       backprop.getLoc(),
300       DenseIntElementsAttr::get(
301           RankedTensorType::get({4}, builder.getIntegerType(32)),
302           new_filter_shape));
303 
304   // Set stride to [1, 1, 1, 1].
305   MLIRContext* context = backprop.getContext();
306   SmallVector<int64_t, 4> values = {1, 1, 1, 1};
307   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
308     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
309   });
310   auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
311 
312   // new result type.
313   SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
314                                     new_filter_shape.end());
315   auto new_result_type =
316       RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
317 
318   // Build new BackPropFilterOp.
319   auto loc = backprop.getLoc();
320   auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
321       loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
322       strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
323       backprop.explicit_paddings(), backprop.data_format(),
324       backprop.dilations());
325 
326   // For example, if new filter shape is [4, 4, 12, 64], old filter shape
327   // is [7, 7, 3, 64] with block_size 2.
328   // Below transforms will be applied to the filter:
329   // 1. Reshape to [4, 4, 2, 2, 3, 64];
330   // 2. Transpose to [4, 2, 4, 2, 3, 64];
331   // 3. Reshape to [8, 8, 3, 64];
332   // 4. Slice to [7, 7, 3, 64].
333   SmallVector<int64_t, 6> first_reshape_shape = {
334       new_filter_shape[0],
335       new_filter_shape[1],
336       block_size,
337       block_size,
338       new_filter_shape[2] / (block_size * block_size),
339       new_filter_shape[3]};
340   auto first_reshape_op =
341       GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);
342 
343   // Transpose.
344   auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);
345 
346   // Last Reshape op.
347   SmallVector<int64_t, 4> last_reshape_shape = {
348       new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
349       new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
350   auto final_reshape_op =
351       GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);
352 
353   // create slice op.
354   auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
355                                                     final_reshape_op, &builder);
356 
357   // Update backprop's user with the slice op.
358   backprop.replaceAllUsesWith(slice_op.getResult());
359 }
360 
361 // Checks if the input producer op is supported in this transform. Right now, we
362 // only check if it is a host tf.IteratorGetNext.
IsSupportedHostInputOp(Operation * op)363 bool IsSupportedHostInputOp(Operation* op) {
364   TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
365   if (!iter) return false;
366   auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
367   if (!device) return false;
368   tensorflow::DeviceNameUtils::ParsedName parsed_device;
369   if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
370                                                   &parsed_device)) {
371     return false;
372   }
373   return parsed_device.type == "CPU";
374 }
375 
376 // Builds a SpaceToDepthOp with the given get_layout op and input.
BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,Value input,int32_t block_size,ArrayRef<int64_t> input_shape)377 TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
378                                      Value input, int32_t block_size,
379                                      ArrayRef<int64_t> input_shape) {
380   auto input_op = input.getDefiningOp();
381   OpBuilder builder(input_op);
382   builder.setInsertionPointAfter(input_op);
383   SmallVector<int64_t, 4> transform_shape = {
384       input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
385       input_shape[3] * block_size * block_size};
386   auto transform_result_type =
387       RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
388   return builder.create<TF::SpaceToDepthOp>(
389       cluster_func.getLoc(), transform_result_type, input, block_size);
390 }
391 
392 // Performs transformation for a non-replicated input.
HandleHostInput(Value input,int64_t index,tf_device::ClusterFuncOp cluster_func,int32_t block_size,ArrayRef<int64_t> input_shape)393 TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
394                                    tf_device::ClusterFuncOp cluster_func,
395                                    int32_t block_size,
396                                    ArrayRef<int64_t> input_shape) {
397   auto space_to_depth =
398       BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
399   cluster_func.setOperand(index, space_to_depth);
400   return space_to_depth;
401 }
402 
403 // Performs transformation for replicated inputs. Returns true if this is a
404 // supported case (thus transform happened).
HandleHostReplicatedInputs(int64_t index,tf_device::ClusterFuncOp cluster_func,BlockArgument block_arg,tf_device::ReplicateOp replicate,int32_t block_size)405 bool HandleHostReplicatedInputs(int64_t index,
406                                 tf_device::ClusterFuncOp cluster_func,
407                                 BlockArgument block_arg,
408                                 tf_device::ReplicateOp replicate,
409                                 int32_t block_size) {
410   // We need to know the devices to copy to.
411   if (!replicate.devices()) return false;
412 
413   MutableArrayRef<OpOperand> inputs =
414       replicate.GetOperandsForBlockArgument(block_arg);
415   for (auto& input : inputs) {
416     auto input_op = input.get().getDefiningOp();
417     if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
418   }
419   for (auto entry : llvm::enumerate(inputs)) {
420     Value input = entry.value().get();
421     auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
422     if (!ranked_type) return false;
423     auto input_shape = ranked_type.getShape();
424     auto space_to_depth =
425         BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
426     entry.value().set(space_to_depth);
427     block_arg.setType(space_to_depth.getType());
428   }
429   return true;
430 }
431 
432 // Performs transformation on a pair of execute and compile ops. The compile
433 // should not have other uses.
HandleCluster(tf_device::ClusterFuncOp cluster_func,int32_t block_size,unsigned arg_num)434 void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
435                    unsigned arg_num) {
436   auto maybe_replicate =
437       llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());
438 
439   llvm::SmallVector<int64_t, 8> transform_input_indices;
440   for (auto input : llvm::enumerate(cluster_func.operands())) {
441     if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
442       if (block_arg.getArgNumber() != arg_num) continue;
443       // For a block argument, consider transforms only when it is a replicated
444       // input (defining ops will be outside the replicate node).
445       if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
446         HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
447                                    maybe_replicate, block_size);
448       }
449     } else {
450       // For an op output, consider transforms only when 1) there is no
451       // replicateion or 2) it is outside the replicate node that encloses the
452       // execute node. (Because if the op is inside replicate, it is probably
453       // not on the host.)
454       if (input.index() != arg_num) continue;
455       auto input_op = input.value().getDefiningOp();
456       if (maybe_replicate &&
457           maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
458         continue;
459       }
460       if (!IsSupportedHostInputOp(input_op)) continue;
461       auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
462       if (!ranked_type) continue;
463       auto input_shape = ranked_type.getShape();
464       HandleHostInput(input.value(), input.index(), cluster_func, block_size,
465                       input_shape);
466     }
467   }
468 }
469 
470 // Checks if input shape of convolution is good for space to depth transform.
Conv2DInputShapeCanTransform(Value input)471 bool Conv2DInputShapeCanTransform(Value input) {
472   auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
473   if (!ranked_type) return false;
474   auto input_shape = ranked_type.getShape();
475   int32_t batch_size = input_shape[0];
476   int32_t channel = input_shape[3];
477   if (batch_size > 8 || channel > 8) {
478     return false;
479   }
480   return true;
481 }
482 
483 // Get block argument id and number of users for the input arg.
GetBlockArgNum(Value arg)484 Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
485   if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
486     if (!Conv2DInputShapeCanTransform(arg)) return None;
487     unsigned num_users =
488         std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
489     BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
490     return block_arg_info;
491   }
492   return None;
493 }
494 
495 // Gets input block argument id and number of users for the input recursively.
496 // Current supported ops between convolution input and the block arguments are
497 // PadOp and CastOp.
GetInputBlockArgNum(Value input)498 Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
499   auto block_arg_num = GetBlockArgNum(input);
500   if (block_arg_num.hasValue()) return block_arg_num;
501 
502   Value next_input = input;
503   auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
504   auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
505 
506   while (pad_op || cast_op) {
507     if (pad_op) {
508       auto block_arg_num = GetBlockArgNum(pad_op.input());
509       if (block_arg_num.hasValue()) return block_arg_num;
510       next_input = pad_op.input();
511     } else {
512       auto block_arg_num = GetBlockArgNum(cast_op.x());
513       if (block_arg_num.hasValue()) return block_arg_num;
514       next_input = cast_op.x();
515     }
516     pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
517     cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
518   }
519 
520   return None;
521 }
522 
523 // Checks if a convoluton can apply SpaceToDepth transform.
524 // Only the first convolution in the graph whose batch size smaller than 8
525 // and its input feature size smaller than 8 can be transformed.
GetConv2DInputArgNum(TF::Conv2DOp conv2d)526 Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
527   if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
528     return None;
529   }
530   // Current supported ops between convolution input and the block arguments are
531   // PadOp and CastOp.
532   return GetInputBlockArgNum(conv2d.input());
533 }
534 
535 // Applies space to depth transform for the first convolution on TPU device.
// Applies space to depth transform for the first convolution on TPU device.
// Rewrites, in order: any tf.Pad feeding the conv (and the cast chain behind
// it), the conv input type, the conv strides, the conv filter, and finally
// any Conv2DBackpropFilter that shares the conv's input producer. The order
// matters: old_filter_shape is captured before the filter rewrite and
// new_filter_shape after it.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter type are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Book keeping filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handles input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before Convolution. filter_shape[0] is the kernel
    // height, used to recompute the paddings.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Book keeping new filter shape for backprop filter rewrite.
  // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackPropFilter that is the user of first convolution's input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}
584 
585 // Gets block size that is equal to stride from spatial dimension
586 // from convolution.
587 // Space to depth transform won't be triggered if block size <= 1.
GetConv2DBlockSize(TF::Conv2DOp conv2d)588 int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
589   SmallVector<int32_t, 4> strides(4, 1);
590   for (int i = 0; i < 3; ++i) {
591     strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
592   }
593 
594   // Space to depth only supports striding at spatial dimension.
595   if (strides[0] != 1 || strides[3] != 1) return 1;
596 
597   // Space to depth only supports height_stride == width_stride case.
598   if (strides[1] != strides[2]) return 1;
599 
600   return strides[1];
601 }
602 
runOnOperation()603 void TPUSpaceToDepthPass::runOnOperation() {
604   Optional<tf_device::ClusterFuncOp> cluster_func;
605   // Space to depth only supports training loop.
606   auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
607     cluster_func = cluster;
608     return WalkResult::interrupt();
609   });
610 
611   // Return if there is no tf_device::ClusterFuncOp in training loop.
612   if (!func_result.wasInterrupted() || !cluster_func.hasValue()) {
613     return;
614   }
615 
616   // Get the function on device.
617   auto device_func = cluster_func->getFunc();
618   if (!device_func) return;
619 
620   TF::Conv2DOp first_conv;
621   // A map maps block argument id to the convolutions consumes them.
622   llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
623       argnum_and_convolutions;
624   // A map maps block argument id to the number of users.
625   llvm::SmallDenseMap<unsigned, int> argnum_num_users;
626 
627   // Find out the qualified convolutions and its block argument ids.
628   auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
629     Optional<BlockArgumentInfo> arg_num_and_num_users =
630         GetConv2DInputArgNum(conv2d);
631     if (arg_num_and_num_users.hasValue()) {
632       // Get block size for the first convolution.
633       int64_t block_size = GetConv2DBlockSize(conv2d);
634       auto arg_num = arg_num_and_num_users.getValue().arg_num;
635       auto num_users = arg_num_and_num_users.getValue().num_users;
636       argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
637       argnum_num_users[arg_num] = num_users;
638       return WalkResult::interrupt();
639     }
640     return WalkResult::advance();
641   });
642   if (!conv2d_result.wasInterrupted()) {
643     return;
644   }
645 
646   // Iterate through block argument and its convolution users. Space to depth
647   // transform will be applied only if all the below conditions are satisfied:
648   //  1. All the users of the block argument will lead to convolutions;
649   //  2. block_size of for the space to depth transform for these convolutions
650   //     are the same;
651   //  3. block_size of for the space to depth transform for these convolutions
652   //     are larger than 1.
653   for (auto argnum_and_convolution : argnum_and_convolutions) {
654     auto arg_num = argnum_and_convolution.getFirst();
655     auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
656     // Continue if number of users of the block argment doesn't equal to number
657     // of transformable convolutions and there is no qualified convolution
658     // for transform or block size is smaller than 2.
659     if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
660         conv2d_and_block_sizes.empty()) {
661       argnum_and_convolutions.erase(arg_num);
662       continue;
663     }
664     int64_t block_size = conv2d_and_block_sizes[0].second;
665     if (block_size < 2) {
666       argnum_and_convolutions.erase(arg_num);
667       continue;
668     }
669     // Continue if not all the block sizes for space to depth transform are the
670     // same.
671     for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
672       if (conv2d_and_block_size.second != block_size) {
673         argnum_and_convolutions.erase(arg_num);
674         break;
675       }
676     }
677   }
678 
679   // If there is no qualified space to depth transform.
680   if (argnum_and_convolutions.empty()) {
681     return;
682   }
683 
684   // Apply space to depth transform.
685   for (auto argnum_and_convolution : argnum_and_convolutions) {
686     auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
687     int64_t block_size = conv2d_and_block_sizes[0].second;
688     // Apply space to depth transform to the input on the host.
689     HandleCluster(cluster_func.getValue(), block_size,
690                   argnum_and_convolution.getFirst());
691     // Transform the convolution.
692     for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
693       HandleFirstConvolution(conv2d_and_block_size.first,
694                              conv2d_and_block_size.second);
695     }
696   }
697 }
698 
699 }  // namespace
700 
CreateTPUSpaceToDepthPass()701 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
702   return std::make_unique<TPUSpaceToDepthPass>();
703 }
704 
705 }  // namespace TFTPU
706 }  // namespace mlir
707