/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
  unsigned arg_num;
  unsigned num_users;
};

// A pass that applies an automatic space to depth transform for the first or
// frontier convolutions that consume host inputs on TPU.
// This is done by adding a space to depth transform op after the host input
// and applying the space to depth transform to the first convolution and its
// backprop filter on TPU.
//
// Example: original program:
//
// module {
//   func @while_body {
//     %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}:
//              -> tensor<2x224x224x3xf32>
//     %device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...}
//     return ...
//   }
//   func @_func(%input: tensor<2x224x224x3xf32>,
//               %filter: tensor<7x7x3x64xf32>) {
//     %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}:
//          (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) ->
//          tensor<2x112x112x64xf32>
//   }
// }
//
// With this pass, the program will be transformed into:
// module {
//   func @while_body {
//     %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//              -> tensor<2x224x224x3xf32>
//     %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}:
//                       (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
//     %device_launch = "tf_device.cluster_func"(%space_to_depth,...)
//                      {func = @_func,...}
//     return ...
//   }
//   func @_func(%input: tensor<2x112x112x12xf32>,
//               %filter: tensor<7x7x3x64xf32>) {
//     %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter):
//                         (tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
//     %conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}:
//             (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) ->
//             tensor<2x112x112x64xf32>
//   }
// }
//
// This way, the first convolution with a 3-channel feature dimension is
// transformed into one with a 12-channel feature dimension, which performs
// better on TPU.
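//
// (For reference: tf.SpaceToDepth with block_size b rearranges each b x b
// spatial block into depth, so an NHWC tensor of shape [B, H, W, C] becomes
// [B, H/b, W/b, C*b*b]; here b = 2 turns 3 channels into 12.)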
//
// TODO(wangtao): add a pass to check if it is profitable to do the space to
// depth transform and invoke the transform if it is needed.
struct TPUSpaceToDepthPass
    : public PassWrapper<TPUSpaceToDepthPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
  auto arg_types = func.front().getArgumentTypes();
  auto result_types = func.front().getTerminator()->getOperandTypes();
  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
}

void HandleFuncOp(Operation* op) {
  auto func = llvm::cast<FuncOp>(op);
  UpdateFuncType(func);
}

// Handles cast op between the first convolution and the block argument.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
  auto cast_input = cast_op.x();
  // Update input type.
  auto transform_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
  cast_input.setType(transform_result_type);
  auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
  while (block_arg || cast_op_input) {
    if (block_arg) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
      block_arg = nullptr;
      cast_op_input = nullptr;
    } else {
      auto cast_input = cast_op_input.x();
      // Update input type.
      auto transform_result_type =
          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
      cast_input.setType(transform_result_type);
      // Update block arg and cast_op_input.
      block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
    }
  }
  return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
  auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return failure();
  auto pad_input_shape = ranked_type.getShape();
  Location loc = op.getLoc();
  OpBuilder builder(op);
  builder.setInsertionPoint(op);
  auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));

  // Calculate paddings.
  int32_t pad_total = kernel_size - 1;
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;
  int32_t pad_end = (pad_total / 2) / block_size;
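  // For example, with kernel_size = 7 and block_size = 2: pad_total = 6,
  // pad_beg = (3 + 1) / 2 = 2, and pad_end = 3 / 2 = 1.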
  SmallVector<int32_t, 8> values = {0, 0, pad_beg, pad_end,
                                    pad_beg, pad_end, 0, 0};
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  // Update pad_op paddings.
  op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));

  // Set input type.
  auto input = op.getOperand(0);
  SmallVector<int64_t, 4> transform_shape = {
      pad_input_shape[0], pad_input_shape[1] / block_size,
      pad_input_shape[2] / block_size,
      pad_input_shape[3] * block_size * block_size};
  // Input of the pad op could be a cast op.
  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
    if (failed(HandleCast(cast_op, transform_shape))) return failure();

  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
  op.setOperand(0, input);
  return success();
}

// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
  MLIRContext* context = conv2d.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), v);
  });
  // TODO(b/157276506): change type of strides to DenseElementsAttr
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
  conv2d->setAttr("strides", strides);
}

// Transforms input shape for the first convolution.
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
  auto input = conv2d.input();
  auto input_shape = input.getType().cast<RankedTensorType>().getShape();
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
}

// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
                                  OpBuilder* builder, int32_t pad_h,
                                  int32_t pad_w) {
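  // Paddings are [before, after] pairs per dimension: pad_h rows are added
  // before the filter height and pad_w columns before the filter width; the
  // channel and output dimensions are left unpadded.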
  SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
  auto padding_type =
      RankedTensorType::get({4, 2}, builder->getIntegerType(32));
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
  std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
                                    filter_shape[1] + pad_w, filter_shape[2],
                                    filter_shape[3]};
  SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());

  auto expand_result_type =
      RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
  return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
                                    paddings_value);
}

// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
                                          Value input, OpBuilder* builder) {
  auto reshape_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
  auto reshape_type = RankedTensorType::get(
      {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
  auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
  auto reshape_value =
      builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
  return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
                                        input, reshape_value);
}

// Creates transpose op for space to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
  SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
  auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
  auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
  auto permute_value =
      builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
  return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}

void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
  // For example, if the filter shape is [7, 7, 3, 64] with block_size 2,
  // the following transforms will be applied to the filter:
  // 1. Pad the filter to [8, 8, 3, 64];
  // 2. Reshape to [4, 2, 4, 2, 3, 64];
  // 3. Transpose to [4, 4, 2, 2, 3, 64];
  // 4. Reshape to [4, 4, 12, 64].
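  // In general, an [H, W, C, O] filter becomes [ceil(H/b), ceil(W/b), C*b*b, O]
  // for block size b, matching the depth multiplication of the input transform.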
  auto filter = conv2d.filter();
  OpBuilder builder(conv2d);
  builder.setInsertionPoint(conv2d);
  // Bookkeeping of filter information.
  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
  int64_t height = filter_shape[0];
  int64_t width = filter_shape[1];
  int64_t channel = filter_shape[2];
  int64_t out_channel = filter_shape[3];
  // Value/Op before reshape op.
  Value before_reshape_value = filter;
  if (height % block_size != 0 || width % block_size != 0) {
    // Calculate paddings for height and width.
    int32_t pad_h = block_size - height % block_size;
    int32_t pad_w = block_size - width % block_size;
    auto pad_op =
        GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
    // Update op, height and width before reshape.
    before_reshape_value = pad_op;
    height = height + pad_h;
    width = width + pad_w;
  }

  // Reshape.
  SmallVector<int64_t, 6> new_shape = {
      height / block_size, block_size, width / block_size,
      block_size,          channel,    out_channel};
  auto reshape_op =
      GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);

  // Reshape back.
  SmallVector<int64_t, 4> final_shape = {
      height / block_size, width / block_size,
      channel * block_size * block_size, out_channel};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
  // Update filter of Conv2D.
  conv2d.setOperand(1, final_reshape_op);
}

// Creates slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
    ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
  SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
                                     old_filter_shape.end());
  auto slice_result_type =
      RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
  auto slice_size_op = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder->getIntegerType(32)),
          old_filter_shape));
  SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
  auto start_position_type =
      RankedTensorType::get({4}, builder->getIntegerType(64));
  auto start_position = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(start_position_type, slice_start_position));
  return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
                                      start_position, slice_size_op);
}

// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                ArrayRef<int32_t> old_filter_shape,
                                ArrayRef<int32_t> new_filter_shape,
                                int64_t block_size) {
  OpBuilder builder(backprop);
  builder.setInsertionPoint(backprop);

  auto input = backprop.input();
  // Get new filter size from new_filter_shape.
  auto new_filter_sizes = builder.create<TF::ConstOp>(
      backprop.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder.getIntegerType(32)),
          new_filter_shape));

  // Set stride to [1, 1, 1, 1].
  MLIRContext* context = backprop.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));

  // New result type.
  SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
                                    new_filter_shape.end());
  auto new_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));

  // Build new BackPropFilterOp.
  auto loc = backprop.getLoc();
  auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
      loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
      strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
      backprop.explicit_paddings(), backprop.data_format(),
      backprop.dilations());

  // For example, if the new filter shape is [4, 4, 12, 64] and the old filter
  // shape is [7, 7, 3, 64] with block_size 2, the below transforms will be
  // applied to the filter:
  // 1. Reshape to [4, 4, 2, 2, 3, 64];
  // 2. Transpose to [4, 2, 4, 2, 3, 64];
  // 3. Reshape to [8, 8, 3, 64];
  // 4. Slice to [7, 7, 3, 64].
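  // This is the inverse of the filter transform in HandleConv2DFilter: it maps
  // the gradient computed in the transformed [4, 4, 12, 64] layout back to the
  // original [7, 7, 3, 64] filter layout.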
  SmallVector<int64_t, 6> first_reshape_shape = {
      new_filter_shape[0],
      new_filter_shape[1],
      block_size,
      block_size,
      new_filter_shape[2] / (block_size * block_size),
      new_filter_shape[3]};
  auto first_reshape_op =
      GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);

  // Last reshape op.
  SmallVector<int64_t, 4> last_reshape_shape = {
      new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
      new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);

  // Create slice op.
  auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
                                                    final_reshape_op, &builder);

  // Update backprop's user with the slice op.
  backprop.replaceAllUsesWith(slice_op.getResult());
}

// Checks if the input producer op is supported in this transform. Right now,
// we only check if it is a host tf.IteratorGetNext.
bool IsSupportedHostInputOp(Operation* op) {
  TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iter) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a SpaceToDepthOp for the given cluster_func op and input.
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                     Value input, int32_t block_size,
                                     ArrayRef<int64_t> input_shape) {
  auto input_op = input.getDefiningOp();
  OpBuilder builder(input_op);
  builder.setInsertionPointAfter(input_op);
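  // Space to depth shrinks each spatial dimension by block_size and grows the
  // feature dimension by block_size * block_size, so the transformed type
  // below has the same element count as the input.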
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  return builder.create<TF::SpaceToDepthOp>(
      cluster_func.getLoc(), transform_result_type, input, block_size);
}

// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
                                   tf_device::ClusterFuncOp cluster_func,
                                   int32_t block_size,
                                   ArrayRef<int64_t> input_shape) {
  auto space_to_depth =
      BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
  cluster_func.setOperand(index, space_to_depth);
  return space_to_depth;
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
                                tf_device::ClusterFuncOp cluster_func,
                                BlockArgument block_arg,
                                tf_device::ReplicateOp replicate,
                                int32_t block_size) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(block_arg);
  for (auto& input : inputs) {
    auto input_op = input.get().getDefiningOp();
    if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
  }
  for (auto entry : llvm::enumerate(inputs)) {
    Value input = entry.value().get();
    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
    if (!ranked_type) return false;
    auto input_shape = ranked_type.getShape();
    auto space_to_depth =
        BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
    entry.value().set(space_to_depth);
    block_arg.setType(space_to_depth.getType());
  }
  return true;
}

// Performs transformation on a pair of execute and compile ops. The compile
// op should not have other uses.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                   unsigned arg_num) {
  auto maybe_replicate =
      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());

  llvm::SmallVector<int64_t, 8> transform_input_indices;
  for (auto input : llvm::enumerate(cluster_func.operands())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getArgNumber() != arg_num) continue;
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
                                   maybe_replicate, block_size);
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (Because if the op is inside replicate, it is probably
      // not on the host.)
      if (input.index() != arg_num) continue;
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedHostInputOp(input_op)) continue;
      auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
      if (!ranked_type) continue;
      auto input_shape = ranked_type.getShape();
      HandleHostInput(input.value(), input.index(), cluster_func, block_size,
                      input_shape);
    }
  }
}

// Checks if input shape of convolution is good for space to depth transform.
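// Assumes NHWC layout: input_shape[0] is the batch size and input_shape[3] is
// the input feature (channel) dimension.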
bool Conv2DInputShapeCanTransform(Value input) {
  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return false;
  auto input_shape = ranked_type.getShape();
  int32_t batch_size = input_shape[0];
  int32_t channel = input_shape[3];
  if (batch_size > 8 || channel > 8) {
    return false;
  }
  return true;
}

// Gets the block argument number and the number of users for the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
  if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
    if (!Conv2DInputShapeCanTransform(arg)) return None;
    unsigned num_users =
        std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
    return block_arg_info;
  }
  return None;
}

// Gets the input block argument number and its number of users recursively.
// Currently supported ops between the convolution input and the block argument
// are PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
  auto block_arg_num = GetBlockArgNum(input);
  if (block_arg_num.hasValue()) return block_arg_num;

  Value next_input = input;
  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

  while (pad_op || cast_op) {
    if (pad_op) {
      auto block_arg_num = GetBlockArgNum(pad_op.input());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = pad_op.input();
    } else {
      auto block_arg_num = GetBlockArgNum(cast_op.x());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = cast_op.x();
    }
    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
  }

  return None;
}

// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph whose batch size and input feature
// size are both no larger than 8 can be transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
    return None;
  }
  // Currently supported ops between the convolution input and the block
  // argument are PadOp and CastOp.
  return GetInputBlockArgNum(conv2d.input());
}

// Applies space to depth transform for the first convolution on TPU device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter type are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Bookkeeping of the filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handle input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before the convolution.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Bookkeeping of the new filter shape for backprop filter rewrite.
  // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackPropFilter that is the user of the first convolution's
  // input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}

// Gets the block size, which equals the convolution's stride in the spatial
// dimensions.
// The space to depth transform won't be triggered if block size <= 1.
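// For example, strides = [1, 2, 2, 1] yields block size 2, while
// strides = [1, 2, 3, 1] yields 1 (unequal spatial strides, no transform).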
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
  SmallVector<int32_t, 4> strides(4, 1);
  for (int i = 0; i < 4; ++i) {
    strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
  }

  // Space to depth only supports striding at the spatial dimensions.
  if (strides[0] != 1 || strides[3] != 1) return 1;

  // Space to depth only supports the height_stride == width_stride case.
  if (strides[1] != strides[2]) return 1;

  return strides[1];
}

void TPUSpaceToDepthPass::runOnOperation() {
  Optional<tf_device::ClusterFuncOp> cluster_func;
  // Space to depth only supports training loops.
  auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
    cluster_func = cluster;
    return WalkResult::interrupt();
  });

  // Return if there is no tf_device::ClusterFuncOp in the training loop.
  if (!func_result.wasInterrupted() || !cluster_func.hasValue()) {
    return;
  }

  // Get the function on device.
  auto device_func = cluster_func->getFunc();
  if (!device_func) return;

  TF::Conv2DOp first_conv;
  // Maps a block argument number to the convolutions that consume it.
  llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
      argnum_and_convolutions;
  // Maps a block argument number to its number of users.
  llvm::SmallDenseMap<unsigned, int> argnum_num_users;

  // Find the qualified convolutions and their block argument numbers.
  auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
    Optional<BlockArgumentInfo> arg_num_and_num_users =
        GetConv2DInputArgNum(conv2d);
    if (arg_num_and_num_users.hasValue()) {
      // Get block size for the first convolution.
      int64_t block_size = GetConv2DBlockSize(conv2d);
      auto arg_num = arg_num_and_num_users.getValue().arg_num;
      auto num_users = arg_num_and_num_users.getValue().num_users;
      argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
      argnum_num_users[arg_num] = num_users;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (!conv2d_result.wasInterrupted()) {
    return;
  }

  // Iterate through the block arguments and their convolution users. The space
  // to depth transform will be applied only if all of the below conditions are
  // satisfied:
  // 1. All the users of the block argument lead to convolutions;
  // 2. The block_size for the space to depth transform is the same for all of
  //    these convolutions;
  // 3. That block_size is larger than 1.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto arg_num = argnum_and_convolution.getFirst();
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    // Remove the entry if the number of users of the block argument doesn't
    // equal the number of transformable convolutions, or if there is no
    // qualified convolution for the transform.
    if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
        conv2d_and_block_sizes.empty()) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Remove the entry if the block size is smaller than 2.
    int64_t block_size = conv2d_and_block_sizes[0].second;
    if (block_size < 2) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Remove the entry if not all the block sizes for the space to depth
    // transform are the same.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      if (conv2d_and_block_size.second != block_size) {
        argnum_and_convolutions.erase(arg_num);
        break;
      }
    }
  }

  // Return if there is no qualified space to depth transform.
  if (argnum_and_convolutions.empty()) {
    return;
  }

  // Apply the space to depth transform.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    int64_t block_size = conv2d_and_block_sizes[0].second;
    // Apply the space to depth transform to the input on the host.
    HandleCluster(cluster_func.getValue(), block_size,
                  argnum_and_convolution.getFirst());
    // Transform the convolution.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      HandleFirstConvolution(conv2d_and_block_size.first,
                             conv2d_and_block_size.second);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
  return std::make_unique<TPUSpaceToDepthPass>();
}

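// The registration below exposes the pass under the flag registered with it,
// so it can be run from an MLIR opt-style driver, e.g. (assuming the tf-opt
// tool is available): `tf-opt -tf-tpu-space-to-depth-pass input.mlir`.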
static PassRegistration<TPUSpaceToDepthPass> pass(
    "tf-tpu-space-to-depth-pass",
    "Adds ops that allow the TPU program to enable automatic space to depth "
    "for the convolution determined at JIT compile time.");

}  // namespace TFTPU
}  // namespace mlir