/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
  unsigned arg_num;
  unsigned num_users;
};

// A pass that applies an automatic space to depth transform for the first or
// frontier convolutions that consume host inputs on TPU.
// This is done by adding a space to depth transform op after the host input
// and applying the space to depth transform to the first convolution and its
// backprop filter on TPU.
//
// Example: original program:
//
// module {
//   func @while_body {
//     %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//              -> tensor<2x224x224x3xf32>
//     %device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...}
//     return ...
//   }
//   func @_func(%input: tensor<2x224x224x3xf32>,
//               %filter: tensor<7x7x3x64xf32>) {
//     %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}:
//      (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) ->
//      tensor<2x112x112x64xf32>
//   }
// }
//
// With this pass, the program will be transformed into:
//
// module {
//   func @while_body {
//     %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//              -> tensor<2x224x224x3xf32>
//     %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}:
//        (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
//     %device_launch = "tf_device.cluster_func"(%space_to_depth,...)
//       {func = @_func,...}
//     return ...
//   }
//   func @_func(%input: tensor<2x112x112x12xf32>,
//               %filter: tensor<7x7x3x64xf32>) {
//     %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter):
//       (tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
//     %conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}:
//       (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) ->
//       tensor<2x112x112x64xf32>
//   }
// }
//
// This way, the first convolution with a 3-channel feature dimension is
// transformed into one with a 12-channel feature dimension, which performs
// better on TPU.
//
// TODO(wangtao): add a pass to check whether the space to depth transform is
// profitable, and invoke the transform only when it is.
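//
// In general, space to depth with block size b rewrites an NHWC tensor of
// shape [N, H, W, C] to [N, H/b, W/b, C*b*b]; the filter and the strides of
// the consuming convolution are rewritten to match, so the convolution
// computes the same result on the transformed operands.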
struct TPUSpaceToDepthPass
    : public PassWrapper<TPUSpaceToDepthPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
  auto arg_types = func.front().getArgumentTypes();
  auto result_types = func.front().getTerminator()->getOperandTypes();
  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
}

void HandleFuncOp(Operation* op) {
  auto func = llvm::cast<FuncOp>(op);
  UpdateFuncType(func);
}

// Handles cast op between the first convolution and the block argument.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
  auto cast_input = cast_op.x();
  // Update input type.
  auto transform_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
  cast_input.setType(transform_result_type);
  auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
  while (block_arg || cast_op_input) {
    if (block_arg) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
      block_arg = nullptr;
      cast_op_input = nullptr;
    } else {
      auto cast_input = cast_op_input.x();
      // Update input type.
      auto transform_result_type =
          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
      cast_input.setType(transform_result_type);
      // Update block arg and cast_op_input.
      block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
    }
  }
  return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
  auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return failure();
  auto pad_input_shape = ranked_type.getShape();
  Location loc = op.getLoc();
  OpBuilder builder(op);
  builder.setInsertionPoint(op);
  auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));

  // Calculate paddings.
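  // For example, with kernel_size = 7 and block_size = 2: pad_total = 6,
  // pad_beg = (3 + 1) / 2 = 2 and pad_end = 3 / 2 = 1, i.e. the padding is
  // expressed in the downsampled (space to depth) coordinates.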
  int32_t pad_total = kernel_size - 1;
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;
  int32_t pad_end = (pad_total / 2) / block_size;
  SmallVector<int32_t, 8> values = {0,       0,       pad_beg, pad_end,
                                    pad_beg, pad_end, 0,       0};
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  // Update pad_op paddings.
  op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));

  // Set input type.
  auto input = op.getOperand(0);
  SmallVector<int64_t, 4> transform_shape = {
      pad_input_shape[0], pad_input_shape[1] / block_size,
      pad_input_shape[2] / block_size,
      pad_input_shape[3] * block_size * block_size};
  // Input of the pad op could be a cast op.
  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
    if (failed(HandleCast(cast_op, transform_shape))) return failure();

  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
  op.setOperand(0, input);
  return success();
}

// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
  MLIRContext* context = conv2d.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), v);
  });
  // TODO(b/157276506): change type of strides to DenseElementsAttr
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
  conv2d->setAttr("strides", strides);
}

// Transforms input shape for the first convolution.
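// For example, a [2, 224, 224, 3] input with block_size 2 becomes
// [2, 112, 112, 12], matching the shape produced by tf.SpaceToDepth on the
// host.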
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
  auto input = conv2d.input();
  auto input_shape = input.getType().cast<RankedTensorType>().getShape();
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
}

// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
                                  OpBuilder* builder, int32_t pad_h,
                                  int32_t pad_w) {
  SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
  auto padding_type =
      RankedTensorType::get({4, 2}, builder->getIntegerType(32));
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
  std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
                                    filter_shape[1] + pad_w, filter_shape[2],
                                    filter_shape[3]};
  SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());

  auto expand_result_type =
      RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
  return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
                                    paddings_value);
}

// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
                                          Value input, OpBuilder* builder) {
  auto reshape_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
  auto reshape_type = RankedTensorType::get(
      {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
  auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
  auto reshape_value =
      builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
  return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
                                        input, reshape_value);
}

// Creates transpose op for space to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
  SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
  auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
  auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
  auto permute_value =
      builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
  return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}

void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
  // For example, if the filter shape is [7, 7, 3, 64] with block_size 2, the
  // following transforms are applied to the filter:
  // 1. Pad the filter to [8, 8, 3, 64];
  // 2. Reshape to [4, 2, 4, 2, 3, 64];
  // 3. Transpose to [4, 4, 2, 2, 3, 64];
  // 4. Reshape to [4, 4, 12, 64].
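  // The transpose permutation {0, 2, 1, 3, 4, 5} swaps the inner block
  // dimension of the height with the outer dimension of the width, making
  // the two block_size dimensions adjacent so the final reshape can fold
  // them into the input-channel dimension.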
  auto filter = conv2d.filter();
  OpBuilder builder(conv2d);
  builder.setInsertionPoint(conv2d);
  // Bookkeeping: filter information.
  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
  int64_t height = filter_shape[0];
  int64_t width = filter_shape[1];
  int64_t channel = filter_shape[2];
  int64_t out_channel = filter_shape[3];
  // Value/Op before reshape op.
  Value before_reshape_value = filter;
  if (height % block_size != 0 || width % block_size != 0) {
    // Calculate paddings for height and width.
    int32_t pad_h = block_size - height % block_size;
    int32_t pad_w = block_size - width % block_size;
    auto pad_op =
        GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
    // Update op, height and width before reshape.
    before_reshape_value = pad_op;
    height = height + pad_h;
    width = width + pad_w;
  }

  // Reshape.
  SmallVector<int64_t, 6> new_shape = {
      height / block_size, block_size, width / block_size,
      block_size,          channel,    out_channel};
  auto reshape_op =
      GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);

  // Reshape back.
  SmallVector<int64_t, 4> final_shape = {
      height / block_size, width / block_size,
      channel * block_size * block_size, out_channel};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
  // Update filter of Conv2D.
  conv2d.setOperand(1, final_reshape_op);
}

// Creates slice op for filter in backprop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
    ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
  SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
                                     old_filter_shape.end());
  auto slice_result_type =
      RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
  auto slice_size_op = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder->getIntegerType(32)),
          old_filter_shape));
  SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
  auto start_position_type =
      RankedTensorType::get({4}, builder->getIntegerType(64));
  auto start_position = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(start_position_type, slice_start_position));
  return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
                                      start_position, slice_size_op);
}

// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                ArrayRef<int32_t> old_filter_shape,
                                ArrayRef<int32_t> new_filter_shape,
                                int64_t block_size) {
  OpBuilder builder(backprop);
  builder.setInsertionPoint(backprop);

  auto input = backprop.input();
  // Get new filter size from new_filter_shape.
  auto new_filter_sizes = builder.create<TF::ConstOp>(
      backprop.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder.getIntegerType(32)),
          new_filter_shape));

  // Set stride to [1, 1, 1, 1].
  MLIRContext* context = backprop.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));

  // New result type.
  SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
                                    new_filter_shape.end());
  auto new_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));

  // Build new BackPropFilterOp.
  auto loc = backprop.getLoc();
  auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
      loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
      strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
      backprop.explicit_paddings(), backprop.data_format(),
      backprop.dilations());

  // For example, if the new filter shape is [4, 4, 12, 64] and the old filter
  // shape is [7, 7, 3, 64] with block_size 2, the following transforms are
  // applied to the filter gradient:
  // 1. Reshape to [4, 4, 2, 2, 3, 64];
  // 2. Transpose to [4, 2, 4, 2, 3, 64];
  // 3. Reshape to [8, 8, 3, 64];
  // 4. Slice to [7, 7, 3, 64].
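  // These steps invert the forward filter transform in HandleConv2DFilter
  // (the permutation {0, 2, 1, 3, 4, 5} is its own inverse, and the slice
  // undoes the padding), mapping the gradient back to the original layout.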
  SmallVector<int64_t, 6> first_reshape_shape = {
      new_filter_shape[0],
      new_filter_shape[1],
      block_size,
      block_size,
      new_filter_shape[2] / (block_size * block_size),
      new_filter_shape[3]};
  auto first_reshape_op =
      GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);

  // Last reshape op.
  SmallVector<int64_t, 4> last_reshape_shape = {
      new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
      new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);

  // Create slice op.
  auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
                                                    final_reshape_op, &builder);

  // Update backprop's users with the slice op.
  backprop.replaceAllUsesWith(slice_op.getResult());
}

// Checks if the input producer op is supported in this transform. Right now,
// we only check whether it is a host tf.IteratorGetNext.
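// For example, a device attribute of
// "/job:worker/replica:0/task:0/device:CPU:0" parses to device type "CPU" and
// is accepted, while a TPU device string is rejected.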
bool IsSupportedHostInputOp(Operation* op) {
  TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iter) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a SpaceToDepthOp for the given input of the cluster_func.
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                     Value input, int32_t block_size,
                                     ArrayRef<int64_t> input_shape) {
  auto input_op = input.getDefiningOp();
  OpBuilder builder(input_op);
  builder.setInsertionPointAfter(input_op);
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  return builder.create<TF::SpaceToDepthOp>(
      cluster_func.getLoc(), transform_result_type, input, block_size);
}

// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
                                   tf_device::ClusterFuncOp cluster_func,
                                   int32_t block_size,
                                   ArrayRef<int64_t> input_shape) {
  auto space_to_depth =
      BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
  cluster_func.setOperand(index, space_to_depth);
  return space_to_depth;
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (and thus the transform happened).
bool HandleHostReplicatedInputs(int64_t index,
                                tf_device::ClusterFuncOp cluster_func,
                                BlockArgument block_arg,
                                tf_device::ReplicateOp replicate,
                                int32_t block_size) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(block_arg);
  for (auto& input : inputs) {
    auto input_op = input.get().getDefiningOp();
    if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
  }
  for (auto entry : llvm::enumerate(inputs)) {
    Value input = entry.value().get();
    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
    if (!ranked_type) return false;
    auto input_shape = ranked_type.getShape();
    auto space_to_depth =
        BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
    entry.value().set(space_to_depth);
    block_arg.setType(space_to_depth.getType());
  }
  return true;
}

// Performs transformation on a pair of execute and compile ops. The compile
// op should not have other uses.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                   unsigned arg_num) {
  auto maybe_replicate =
      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());

  llvm::SmallVector<int64_t, 8> transform_input_indices;
  for (auto input : llvm::enumerate(cluster_func.operands())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getArgNumber() != arg_num) continue;
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
                                   maybe_replicate, block_size);
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (If the op is inside the replicate, it is probably not
      // on the host.)
      if (input.index() != arg_num) continue;
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedHostInputOp(input_op)) continue;
      auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
      if (!ranked_type) continue;
      auto input_shape = ranked_type.getShape();
      HandleHostInput(input.value(), input.index(), cluster_func, block_size,
                      input_shape);
    }
  }
}

// Checks if the input shape of the convolution is fit for the space to depth
// transform.
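// Note: the small thresholds (8) fit frontier convolutions, whose channel
// count is typically small (e.g., 3 for RGB inputs); space to depth widens
// the channel dimension by block_size * block_size.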
bool Conv2DInputShapeCanTransform(Value input) {
  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return false;
  auto input_shape = ranked_type.getShape();
  int32_t batch_size = input_shape[0];
  int32_t channel = input_shape[3];
  if (batch_size > 8 || channel > 8) {
    return false;
  }
  return true;
}

// Gets the block argument number and the number of its users for the input
// arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
  if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
    if (!Conv2DInputShapeCanTransform(arg)) return None;
    unsigned num_users =
        std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
    return block_arg_info;
  }
  return None;
}

// Gets the input's block argument number and number of users recursively.
// The currently supported ops between the convolution input and the block
// arguments are PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
  auto block_arg_num = GetBlockArgNum(input);
  if (block_arg_num.hasValue()) return block_arg_num;

  Value next_input = input;
  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

  while (pad_op || cast_op) {
    if (pad_op) {
      auto block_arg_num = GetBlockArgNum(pad_op.input());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = pad_op.input();
    } else {
      auto block_arg_num = GetBlockArgNum(cast_op.x());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = cast_op.x();
    }
    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
  }

  return None;
}

// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph whose batch size and input feature
// dimension are no larger than 8 can be transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
    return None;
  }
  // The currently supported ops between the convolution input and the block
  // arguments are PadOp and CastOp.
  return GetInputBlockArgNum(conv2d.input());
}

// Applies the space to depth transform for the first convolution on the TPU
// device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter type are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Bookkeeping: filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handles input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on-device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before the convolution.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on-device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Bookkeeping: new filter shape for backprop filter rewrite. The filter
  // shape is defined in HandleConv2DFilter, thus it is a RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackPropFilter that is a user of the first convolution's
  // input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}

// Gets the block size, which equals the convolution's stride in the spatial
// dimensions.
// The space to depth transform is not triggered if the block size is <= 1.
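// For example, strides = [1, 2, 2, 1] gives a block size of 2, while
// [1, 2, 3, 1] (unequal spatial strides) or [1, 1, 1, 1] give 1, so no
// transform is applied.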
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
  SmallVector<int32_t, 4> strides(4, 1);
  for (int i = 0; i < 4; ++i) {
    strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
  }

  // Space to depth only supports striding at the spatial dimensions.
  if (strides[0] != 1 || strides[3] != 1) return 1;

  // Space to depth only supports the height_stride == width_stride case.
  if (strides[1] != strides[2]) return 1;

  return strides[1];
}

void TPUSpaceToDepthPass::runOnOperation() {
  Optional<tf_device::ClusterFuncOp> cluster_func;
  // Space to depth only supports the training loop.
  auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
    cluster_func = cluster;
    return WalkResult::interrupt();
  });

  // Return if there is no tf_device::ClusterFuncOp in the training loop.
  if (!func_result.wasInterrupted() || !cluster_func.hasValue()) {
    return;
  }

  // Get the function on device.
  auto device_func = cluster_func->getFunc();
  if (!device_func) return;

  TF::Conv2DOp first_conv;
  // Maps a block argument number to the convolutions that consume it.
  llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
      argnum_and_convolutions;
  // Maps a block argument number to its number of users.
  llvm::SmallDenseMap<unsigned, int> argnum_num_users;

  // Find the qualified convolutions and their block argument numbers.
  auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
    Optional<BlockArgumentInfo> arg_num_and_num_users =
        GetConv2DInputArgNum(conv2d);
    if (arg_num_and_num_users.hasValue()) {
      // Get the block size for the first convolution.
      int64_t block_size = GetConv2DBlockSize(conv2d);
      auto arg_num = arg_num_and_num_users.getValue().arg_num;
      auto num_users = arg_num_and_num_users.getValue().num_users;
      argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
      argnum_num_users[arg_num] = num_users;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (!conv2d_result.wasInterrupted()) {
    return;
  }

  // Iterate through the block arguments and their convolution users. The space
  // to depth transform is applied only if all of the conditions below are
  // satisfied:
  //  1. All the users of the block argument lead to convolutions;
  //  2. The block sizes for the space to depth transform of these convolutions
  //     are the same;
  //  3. The block sizes for the space to depth transform of these convolutions
  //     are larger than 1.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto arg_num = argnum_and_convolution.getFirst();
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    // Remove the entry if the number of users of the block argument doesn't
    // equal the number of transformable convolutions, or if there is no
    // qualified convolution for the transform.
    if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
        conv2d_and_block_sizes.empty()) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Remove the entry if the block size is smaller than 2.
    int64_t block_size = conv2d_and_block_sizes[0].second;
    if (block_size < 2) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Remove the entry if not all the block sizes for the space to depth
    // transform are the same.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      if (conv2d_and_block_size.second != block_size) {
        argnum_and_convolutions.erase(arg_num);
        break;
      }
    }
  }

  // Return if there is no qualified space to depth transform.
  if (argnum_and_convolutions.empty()) {
    return;
  }

  // Apply the space to depth transform.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    int64_t block_size = conv2d_and_block_sizes[0].second;
    // Apply the space to depth transform to the input on the host.
    HandleCluster(cluster_func.getValue(), block_size,
                  argnum_and_convolution.getFirst());
    // Transform the convolutions.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      HandleFirstConvolution(conv2d_and_block_size.first,
                             conv2d_and_block_size.second);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
  return std::make_unique<TPUSpaceToDepthPass>();
}

static PassRegistration<TPUSpaceToDepthPass> pass(
    "tf-tpu-space-to-depth-pass",
    "Adds ops that allow the TPU program to enable the automatic space to "
    "depth transform for the convolution, determined at JIT compile time.");
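
// Example invocation (a sketch, assuming the standard tf-opt binary built
// from the TensorFlow MLIR tree):
//   tf-opt -tf-tpu-space-to-depth-pass input.mlir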

}  // namespace TFTPU
}  // namespace mlir