/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
  unsigned arg_num;
  unsigned num_users;
};

// TODO(wangtao): add a pass to check if it is profitable to apply the space
// to depth transform, and invoke the transform only when needed.
struct TPUSpaceToDepthPass
    : public TF::TPUSpaceToDepthPassBase<TPUSpaceToDepthPass> {
  void runOnOperation() override;
};

// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
  auto arg_types = func.front().getArgumentTypes();
  auto result_types = func.front().getTerminator()->getOperandTypes();
  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
}

void HandleFuncOp(Operation* op) {
  auto func = llvm::cast<FuncOp>(op);
  UpdateFuncType(func);
}

// Handles cast op between the first convolution and the block argument.
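// For instance, the input chain may look like
//   block_arg -> tf.Cast -> tf.Cast -> tf.Pad -> tf.Conv2D;
// each cast result along the chain is retyped to the transformed shape, and
// the enclosing function's type is updated once the block argument is reached.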
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
  auto cast_input = cast_op.x();
  // Update input type.
  auto transform_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
  cast_input.setType(transform_result_type);
  auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
  while (block_arg || cast_op_input) {
    if (block_arg) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
      block_arg = nullptr;
      cast_op_input = nullptr;
    } else {
      auto cast_input = cast_op_input.x();
      // Update input type.
      auto transform_result_type =
          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
      cast_input.setType(transform_result_type);
      // Update block arg and cast_op_input.
      block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
    }
  }
  return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
  auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return failure();
  auto pad_input_shape = ranked_type.getShape();
  Location loc = op.getLoc();
  OpBuilder builder(op);
  builder.setInsertionPoint(op);
  auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));

  // Calculate paddings.
  int32_t pad_total = kernel_size - 1;
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;
  int32_t pad_end = (pad_total / 2) / block_size;
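  // For instance, with kernel_size = 7 and block_size = 2: pad_total = 6,
  // pad_beg = (3 + 1) / 2 = 2 and pad_end = 3 / 2 = 1.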
  SmallVector<int32_t, 8> values = {0, 0, pad_beg, pad_end,
                                    pad_beg, pad_end, 0, 0};
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  // Update pad_op paddings.
  op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));

  // Set input type.
  auto input = op.getOperand(0);
  SmallVector<int64_t, 4> transform_shape = {
      pad_input_shape[0], pad_input_shape[1] / block_size,
      pad_input_shape[2] / block_size,
      pad_input_shape[3] * block_size * block_size};
  // Input of the pad op could be a cast op.
  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
    if (failed(HandleCast(cast_op, transform_shape))) return failure();

  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
  op.setOperand(0, input);
  return success();
}

// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
  MLIRContext* context = conv2d.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), v);
  });
  // TODO(b/157276506): change type of strides to DenseElementsAttr
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
  conv2d->setAttr("strides", strides);
}

// Transforms input shape for the first convolution.
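// For instance, an NHWC input of shape [1, 224, 224, 3] with block_size 2 is
// retyped to [1, 112, 112, 12]: each spatial dimension shrinks by a factor of
// block_size while the channel dimension grows by block_size * block_size.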
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
  auto input = conv2d.input();
  auto input_shape = input.getType().cast<RankedTensorType>().getShape();
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
}

// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
                                  OpBuilder* builder, int32_t pad_h,
                                  int32_t pad_w) {
  SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
  auto padding_type =
      RankedTensorType::get({4, 2}, builder->getIntegerType(32));
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
  std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
                                    filter_shape[1] + pad_w, filter_shape[2],
                                    filter_shape[3]};
  SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());

  auto expand_result_type =
      RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
  return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
                                    paddings_value);
}

// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
                                          Value input, OpBuilder* builder) {
  auto reshape_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
  auto reshape_type = RankedTensorType::get(
      {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
  auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
  auto reshape_value =
      builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
  return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
                                        input, reshape_value);
}

// Creates a transpose op for the space to depth transform.
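// The permutation [0, 2, 1, 3, 4, 5] swaps the two block dimensions produced
// by the preceding reshape, e.g. [4, 2, 4, 2, 3, 64] -> [4, 4, 2, 2, 3, 64].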
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
  SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
  auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
  auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
  auto permute_value =
      builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
  return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}

void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
  // For example, if the filter shape is [7, 7, 3, 64] with block_size 2, the
  // following transforms are applied to the filter:
  //   1. Pad the filter to [8, 8, 3, 64];
  //   2. Reshape to [4, 2, 4, 2, 3, 64];
  //   3. Transpose to [4, 4, 2, 2, 3, 64];
  //   4. Reshape to [4, 4, 12, 64].
  auto filter = conv2d.filter();
  OpBuilder builder(conv2d);
  builder.setInsertionPoint(conv2d);
  // Record filter information.
  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
  int64_t height = filter_shape[0];
  int64_t width = filter_shape[1];
  int64_t channel = filter_shape[2];
  int64_t out_channel = filter_shape[3];
  // Value/Op before reshape op.
  Value before_reshape_value = filter;
  if (height % block_size != 0 || width % block_size != 0) {
    // Calculate paddings for height and width.
    int32_t pad_h = block_size - height % block_size;
    int32_t pad_w = block_size - width % block_size;
    auto pad_op =
        GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
    // Update op, height and width before reshape.
    before_reshape_value = pad_op;
    height = height + pad_h;
    width = width + pad_w;
  }

  // Reshape.
  SmallVector<int64_t, 6> new_shape = {
      height / block_size, block_size, width / block_size,
      block_size, channel, out_channel};
  auto reshape_op =
      GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);

  // Reshape back.
  SmallVector<int64_t, 4> final_shape = {
      height / block_size, width / block_size,
      channel * block_size * block_size, out_channel};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
  // Update filter of Conv2D.
  conv2d.setOperand(1, final_reshape_op);
}

// Creates slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
    ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
  SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
                                     old_filter_shape.end());
  auto slice_result_type =
      RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
  auto slice_size_op = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder->getIntegerType(32)),
          old_filter_shape));
  SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
  auto start_position_type =
      RankedTensorType::get({4}, builder->getIntegerType(64));
  auto start_position = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(start_position_type, slice_start_position));
  return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
                                      start_position, slice_size_op);
}

// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                ArrayRef<int32_t> old_filter_shape,
                                ArrayRef<int32_t> new_filter_shape,
                                int64_t block_size) {
  OpBuilder builder(backprop);
  builder.setInsertionPoint(backprop);

  auto input = backprop.input();
  // Get new filter size from new_filter_shape.
  auto new_filter_sizes = builder.create<TF::ConstOp>(
      backprop.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder.getIntegerType(32)),
          new_filter_shape));

  // Set strides to [1, 1, 1, 1].
  MLIRContext* context = backprop.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));

  // New result type.
  SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
                                    new_filter_shape.end());
  auto new_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));

  // Build the new BackPropFilterOp.
  auto loc = backprop.getLoc();
  auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
      loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
      strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
      backprop.explicit_paddings(), backprop.data_format(),
      backprop.dilations());

  // For example, if the new filter shape is [4, 4, 12, 64] and the old filter
  // shape is [7, 7, 3, 64] with block_size 2, the following transforms are
  // applied to the filter:
  //   1. Reshape to [4, 4, 2, 2, 3, 64];
  //   2. Transpose to [4, 2, 4, 2, 3, 64];
  //   3. Reshape to [8, 8, 3, 64];
  //   4. Slice to [7, 7, 3, 64].
  SmallVector<int64_t, 6> first_reshape_shape = {
      new_filter_shape[0],
      new_filter_shape[1],
      block_size,
      block_size,
      new_filter_shape[2] / (block_size * block_size),
      new_filter_shape[3]};
  auto first_reshape_op =
      GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);

  // Last reshape op.
  SmallVector<int64_t, 4> last_reshape_shape = {
      new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
      new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);

  // Create the slice op.
  auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
                                                    final_reshape_op, &builder);

  // Replace backprop's users with the slice op.
  backprop.replaceAllUsesWith(slice_op.getResult());
}

// Checks if the input producer op is supported in this transform. Right now,
// we only check if it is a host tf.IteratorGetNext.
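// For instance, an IteratorGetNext op carrying the attribute
//   device = "/job:worker/replica:0/task:0/device:CPU:0"
// parses to a device of type "CPU" and is accepted.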
bool IsSupportedHostInputOp(Operation* op) {
  TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iter) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a SpaceToDepthOp for the given cluster_func input.
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                     Value input, int32_t block_size,
                                     ArrayRef<int64_t> input_shape) {
  auto input_op = input.getDefiningOp();
  OpBuilder builder(input_op);
  builder.setInsertionPointAfter(input_op);
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  return builder.create<TF::SpaceToDepthOp>(
      cluster_func.getLoc(), transform_result_type, input, block_size);
}

// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
                                   tf_device::ClusterFuncOp cluster_func,
                                   int32_t block_size,
                                   ArrayRef<int64_t> input_shape) {
  auto space_to_depth =
      BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
  cluster_func.setOperand(index, space_to_depth);
  return space_to_depth;
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
                                tf_device::ClusterFuncOp cluster_func,
                                BlockArgument block_arg,
                                tf_device::ReplicateOp replicate,
                                int32_t block_size) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(block_arg);
  for (auto& input : inputs) {
    auto input_op = input.get().getDefiningOp();
    if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
  }
  for (auto entry : llvm::enumerate(inputs)) {
    Value input = entry.value().get();
    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
    if (!ranked_type) return false;
    auto input_shape = ranked_type.getShape();
    auto space_to_depth =
        BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
    entry.value().set(space_to_depth);
    block_arg.setType(space_to_depth.getType());
  }
  return true;
}

// Performs transformation on a pair of execute and compile ops. The compile
// should not have other uses.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                   unsigned arg_num) {
  auto maybe_replicate =
      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());

  llvm::SmallVector<int64_t, 8> transform_input_indices;
  for (auto input : llvm::enumerate(cluster_func.operands())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getArgNumber() != arg_num) continue;
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
                                   maybe_replicate, block_size);
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (If the op is inside the replicate, it is probably not
      // on the host.)
      if (input.index() != arg_num) continue;
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedHostInputOp(input_op)) continue;
      auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
      if (!ranked_type) continue;
      auto input_shape = ranked_type.getShape();
      HandleHostInput(input.value(), input.index(), cluster_func, block_size,
                      input_shape);
    }
  }
}

// Checks if input shape of convolution is good for space to depth transform.
bool Conv2DInputShapeCanTransform(Value input) {
  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return false;
  auto input_shape = ranked_type.getShape();
  int32_t batch_size = input_shape[0];
  int32_t channel = input_shape[3];
  if (batch_size > 8 || channel > 8) {
    return false;
  }
  return true;
}

// Gets the block argument id and the number of users of the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
  if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
    if (!Conv2DInputShapeCanTransform(arg)) return None;
    unsigned num_users = std::distance(block_arg.getUsers().begin(),
                                       block_arg.getUsers().end());
    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
    return block_arg_info;
  }
  return None;
}

// Gets input block argument id and number of users for the input recursively.
// Currently supported ops between the convolution input and the block
// arguments are PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
  auto block_arg_num = GetBlockArgNum(input);
  if (block_arg_num.hasValue()) return block_arg_num;

  Value next_input = input;
  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

  while (pad_op || cast_op) {
    if (pad_op) {
      auto block_arg_num = GetBlockArgNum(pad_op.input());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = pad_op.input();
    } else {
      auto block_arg_num = GetBlockArgNum(cast_op.x());
      if (block_arg_num.hasValue()) return block_arg_num;
      next_input = cast_op.x();
    }
    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
  }

  return None;
}

// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph, whose batch size and input feature
// size are both no larger than 8, can be transformed.
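// For instance, a 7x7 Conv2D with strides [1, 2, 2, 1] whose input is an NHWC
// tensor of shape [8, 224, 224, 3] qualifies: the batch size is 8, the input
// feature size is 3, and the resulting block size is 2.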
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
    return None;
  }
  // Currently supported ops between the convolution input and the block
  // arguments are PadOp and CastOp.
  return GetInputBlockArgNum(conv2d.input());
}

// Applies the space to depth transform for the first convolution on the TPU
// device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter types are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Record the filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handle input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before the convolution.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Record the new filter shape for the backprop filter rewrite. The filter
  // shape is set in HandleConv2DFilter, so it is a RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackPropFilter ops that are users of the first convolution's
  // input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}

// Gets the block size, which equals the stride of the spatial dimensions of
// the convolution. The space to depth transform won't be triggered if the
// block size is <= 1.
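// For instance, strides [1, 2, 2, 1] yield block size 2, while strides such
// as [1, 2, 1, 1] (unequal spatial strides) or [2, 2, 2, 1] (striding at the
// batch dimension) yield 1 and thus skip the transform.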
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
  SmallVector<int32_t, 4> strides(4, 1);
  // Read all four stride entries (N, H, W, C); callers have already verified
  // that the strides attribute has size 4.
  for (int i = 0; i < 4; ++i) {
    strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
  }

  // Space to depth only supports striding at the spatial dimensions.
  if (strides[0] != 1 || strides[3] != 1) return 1;

  // Space to depth only supports the height_stride == width_stride case.
  if (strides[1] != strides[2]) return 1;

  return strides[1];
}

void TPUSpaceToDepthPass::runOnOperation() {
  Optional<tf_device::ClusterFuncOp> cluster_func;
  // Space to depth only supports the training loop.
  auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
    cluster_func = cluster;
    return WalkResult::interrupt();
  });

  // Return if there is no tf_device::ClusterFuncOp in the training loop.
  if (!func_result.wasInterrupted() || !cluster_func.hasValue()) {
    return;
  }

  // Get the function on device.
  auto device_func = cluster_func->getFunc();
  if (!device_func) return;

  TF::Conv2DOp first_conv;
  // Maps a block argument id to the convolutions that consume it.
  llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
      argnum_and_convolutions;
  // Maps a block argument id to its number of users.
  llvm::SmallDenseMap<unsigned, int> argnum_num_users;

  // Find the qualified convolutions and their block argument ids.
  auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
    Optional<BlockArgumentInfo> arg_num_and_num_users =
        GetConv2DInputArgNum(conv2d);
    if (arg_num_and_num_users.hasValue()) {
      // Get the block size for the first convolution.
      int64_t block_size = GetConv2DBlockSize(conv2d);
      auto arg_num = arg_num_and_num_users.getValue().arg_num;
      auto num_users = arg_num_and_num_users.getValue().num_users;
      argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
      argnum_num_users[arg_num] = num_users;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (!conv2d_result.wasInterrupted()) {
    return;
  }

  // Iterate through the block arguments and their convolution users. The space
  // to depth transform will be applied only if all of the following conditions
  // are satisfied:
  //   1. All users of the block argument lead to convolutions;
  //   2. The block sizes for the space to depth transform of these
  //      convolutions are the same;
  //   3. The block sizes for the space to depth transform of these
  //      convolutions are larger than 1.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto arg_num = argnum_and_convolution.getFirst();
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    // Skip if the number of users of the block argument doesn't equal the
    // number of transformable convolutions, or there is no qualified
    // convolution for the transform.
    if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
        conv2d_and_block_sizes.empty()) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Skip if the block size is smaller than 2.
    int64_t block_size = conv2d_and_block_sizes[0].second;
    if (block_size < 2) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Skip if not all the block sizes for the space to depth transform are the
    // same.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      if (conv2d_and_block_size.second != block_size) {
        argnum_and_convolutions.erase(arg_num);
        break;
      }
    }
  }

  // Return if there is no qualified space to depth transform.
  if (argnum_and_convolutions.empty()) {
    return;
  }

  // Apply the space to depth transform.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    int64_t block_size = conv2d_and_block_sizes[0].second;
    // Apply the space to depth transform to the input on the host.
    HandleCluster(cluster_func.getValue(), block_size,
                  argnum_and_convolution.getFirst());
    // Transform the convolution.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      HandleFirstConvolution(conv2d_and_block_size.first,
                             conv2d_and_block_size.second);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
  return std::make_unique<TPUSpaceToDepthPass>();
}

}  // namespace TFTPU
}  // namespace mlir