/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them
// with a real convolution op.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

#include <cstdint>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects this "emulation" and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
//     BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
//   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so it is kept.
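//
// For illustration only (schematic IR, not exact MLIR syntax), the basic
// emulation with a block_shape of [2, 2] looks roughly like
//
//   %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings)
//   %1 = "tf.Conv2D"(%0, %filter) {padding = "VALID", dilations = [1,1,1,1]}
//   %2 = "tf.BatchToSpaceND"(%1, %block_shape, %crops)
//
// and is rewritten into a single convolution on the original input:
//
//   %2 = "tf.Conv2D"(%input, %filter) {padding = ..., dilations = [1,2,2,1]}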
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 private:
  using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;

  // Extract the dilation factor from `block_shape` and pack it in an
  // ArrayAttr.
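  // For example, a constant `block_shape` of [2, 2] is packed as the NHWC
  // dilations attribute [1, 2, 2, 1].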
  llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
      Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
      PatternRewriter& rewriter) const;

 public:
  LogicalResult matchAndRewrite(Conv2dOpTy op,
                                PatternRewriter& rewriter) const override;
};

template <typename Conv2dOpTy>
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
    Conv2dOpTy op, PatternRewriter& rewriter) const {
  if (!op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        op, "result for current op has more than 1 use");
  }
  // Make sure Conv2D has 'VALID' padding.
  if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
    return rewriter.notifyMatchFailure(op,
                                       "Conv2D op doesn't have valid padding");
  }
  // Make sure dilations are all ones if set.
  const ArrayAttr& dilations =
      op->template getAttrOfType<ArrayAttr>("dilations");
  if (dilations && !TFIntListIsAllOnes(dilations)) {
    return rewriter.notifyMatchFailure(op, "dilations should be all 1");
  }

  if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
    return rewriter.notifyMatchFailure(
        op, "op's input is not float or the data format isn't NHWC");
  }

  // Allow dynamic width and height dimensions only.
  auto result_ty = op.getResult().getType().template cast<TensorType>();
  if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
      result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
    return rewriter.notifyMatchFailure(
        op, "only dynamic width and height dimensions are allowed");
  }

  // Check that the ConvOp's input is defined by an `Expand` op and that its
  // output is used by a `Squeeze` op.
  Operation* producer_op = op.getOperand(0).getDefiningOp();
  if (!producer_op || producer_op->getNumResults() != 1) {
    return rewriter.notifyMatchFailure(
        op, "op doesn't have a producer node that has a single result");
  }
  if (!producer_op->hasOneUse() ||
      *(producer_op->getResult(0).user_begin()) != op) {
    return rewriter.notifyMatchFailure(
        op, "op's input isn't produced by previous operation");
  }

  auto tryGetDirectConsumerOp =
      [&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
    // Check that the current operation has a single result.
    if (current->getNumResults() != 1) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have single result"),
          nullptr};
    }
    // Check that the current operation has a consumer node.
    Operation* consumer_op =
        current->getResult(0).getUses().begin()->getOwner();
    if (!consumer_op) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have consumer node"),
          nullptr};
    }
    // Check that the current operation's result is used by its successor node.
    if (!current->hasOneUse() ||
        *(current->getResult(0).user_begin()) != consumer_op) {
      return {
          rewriter.notifyMatchFailure(
              current, "op's result isn't directly consumed by the next op"),
          nullptr};
    }
    return {LogicalResult::success(), consumer_op};
  };

  std::pair<LogicalResult, Operation*> maybeConsumer =
      tryGetDirectConsumerOp(op.getOperation());
  if (failed(maybeConsumer.first)) {
    return maybeConsumer.first;
  }
  Operation* consumer_op = maybeConsumer.second;

  TF::ExpandDimsOp expand_op;
  TF::SqueezeOp squeeze_op;
  int64_t expand_axis = -1;
  // Expand + Squeeze op.
  if (llvm::isa<TF::ExpandDimsOp>(producer_op)) {
    if (!llvm::isa<TF::SqueezeOp>(consumer_op)) {
      // Expand/Squeeze ops must come in a pair.
      return rewriter.notifyMatchFailure(
          op, "ExpandDimsOp and SqueezeOp should come in pair");
    }
    expand_op = llvm::cast<TF::ExpandDimsOp>(producer_op);
    squeeze_op = llvm::cast<TF::SqueezeOp>(consumer_op);
    if (!expand_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          expand_op, "result for current op has more than 1 use");
    }
    if (!squeeze_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "result for current op has more than 1 use");
    }
    // Make sure that the axis in `expand_op` is constant.
    if (auto const_op =
            llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
      expand_axis =
          (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
              .getSExtValue();
      // Canonicalize the axis. Some TF python functions, such as
      // `tf.nn.convolution`, use a negative axis.
      if (expand_axis < 0) {
        // Always expand a 3D input to a 4D input.
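        // E.g., an axis of -3 on a 3D input canonicalizes to axis 1, the H
        // dimension of the expanded 4D NHWC input.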
        expand_axis += 4;
      }
    } else {
      return rewriter.notifyMatchFailure(
          expand_op, "ExpandDimsOp doesn't have a constant axis");
    }
    // Make sure that `squeeze_dims` is equal to `expand_axis`.
    auto squeeze_dims = squeeze_op.squeeze_dims();
    if (squeeze_dims.size() != 1) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "squeeze dims should have exactly 1 dimension specified");
    }
    int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
    if (squeeze_axis < 0) {
      // Always squeeze a 4D input to a 3D input.
      squeeze_axis += 4;
    }
    if (squeeze_axis != expand_axis) {
      return rewriter.notifyMatchFailure(
          op, "squeeze axis and expand axis doesn't match");
    }

    // Update the previous/next op pointers.
    Operation* tmp = expand_op.input().getDefiningOp();
    if (!tmp || tmp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          producer_op,
          "op doesn't have a producer node that has a single result");
    }
    if (!tmp->hasOneUse() ||
        *(tmp->getResult(0).user_begin()) != producer_op) {
      return rewriter.notifyMatchFailure(
          producer_op, "op's input isn't defined by its previous node");
    }
    producer_op = tmp;
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
  }

  // SpaceToBatchND op.
  if (!llvm::isa<TF::SpaceToBatchNDOp>(producer_op)) {
    return rewriter.notifyMatchFailure(producer_op,
                                       "op should be a SpaceToBatchND op");
  }
  // TODO(b/149936532): Check `padding` input, currently ignored.
  TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(producer_op);
  if (!stb_op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        stb_op, "result for current op has more than 1 use");
  }

  // Pad op.
  TF::PadOp pad_op;
  ElementsAttr pad_attr;
  if (llvm::isa<TF::PadOp>(consumer_op)) {
    pad_op = llvm::cast<TF::PadOp>(consumer_op);
    if (!pad_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          pad_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
    if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
      // If the padding value isn't constant, we can't determine the padding
      // scheme for Conv2D below; in this case, just reject the pattern.
      return rewriter.notifyMatchFailure(
          pad_op, "PadOp's padding value isn't constant");
    }
  }

  // BatchToSpaceND + BiasAdd.
  TF::BatchToSpaceNDOp bts_op;
  TF::BiasAddOp biasadd_op;
  bool final_op_is_bts = true;
  if (llvm::isa<TF::BiasAddOp>(consumer_op)) {
    // Must be BiasAdd + BatchToSpaceND.
    biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
    if (!biasadd_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          biasadd_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeConsumer.second)) {
      return rewriter.notifyMatchFailure(
          consumer_op, "op's next node isn't BatchToSpaceND op");
    }
    consumer_op = maybeConsumer.second;
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
  } else if (llvm::isa<TF::BatchToSpaceNDOp>(consumer_op)) {
    // BatchToSpaceND + (optional) BiasAdd.
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    Operation* tmp = maybeConsumer.second;
    if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
      consumer_op = tmp;
      biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
      final_op_is_bts = false;
    }
  } else {
    return rewriter.notifyMatchFailure(
        consumer_op, "next op is neither BiasAdd nor BatchToSpaceND");
  }

  llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
      stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
  if (!dilations_attr.hasValue()) {
    return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
  }

  if (expand_op) {
    if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
      return rewriter.notifyMatchFailure(
          stb_op, "SpaceToBatchND op's input should have RankedTensorType");
    }
  }

  // TODO(b/149936532): Check that the input width & height are multiples of
  // the dilation rate.
  // The TF python library rewrites a dilated conv into a
  // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle
  // always has 'VALID' padding. The `paddings` tensor of `SpaceToBatch` has
  // two contributions: one reduces the Conv padding from 'SAME' to 'VALID',
  // and the other makes the input shape a multiple of the dilation rate. The
  // first contribution, also called `base_paddings`, is used here to determine
  // whether the original padding format was 'SAME' or 'VALID'. We compute
  // `base_paddings` from the relations below when the tensors are constant.
  // Basically, the `paddings` tensor in `SpaceToBatch` and the `crops` tensor
  // in `BatchToSpace` must satisfy the following:
  //   paddings[i, 0] = base_paddings[i, 0].
  //   0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
  //   (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
  //   crops[i, 0] = 0.
  //   crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].
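  //
  // For example (an illustrative calculation only): a 3x3 kernel with
  // dilation rate 2 has an effective size of 5, so 'SAME' padding implies
  // base_paddings = [2, 2]; with an input height of 9 and block_shape = 2,
  // paddings becomes [2, 3] so that 9 + 2 + 3 is divisible by 2, and
  // crops = [0, 1].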
  //
  // If `paddings` - `crops` != 0, then `base_paddings` != 0, which tells us
  // that the original padding was 'SAME' (with one caveat presented below).
  // Here we reset the padding back to 'SAME' if `base_paddings` != 0.
  // TODO(b/149936532): We should not rely solely on `paddings - crops != 0` to
  // determine the original padding format. For example, users can build
  // arbitrary valid examples of `STB->Conv->BTS` which don't represent a
  // dilated conv, hence we shouldn't pattern match here. Instead, we need to
  // check the values of `paddings` and `crops` to make sure they really stand
  // for a dilated conv.
  auto stb_paddings = stb_op.paddings();
  auto bts_crops = bts_op.crops();
  ElementsAttr stb_paddings_attr, bts_crops_attr;
  if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
      !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
    return rewriter.notifyMatchFailure(
        op,
        "either SpaceToBatchND or BatchToSpaceND "
        "doesn't have constant padding/crops value");
  }
  if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
    return rewriter.notifyMatchFailure(
        stb_op,
        "SpaceToBatchND op's padding doesn't have same shape/type with "
        "BatchToSpaceND op's crops");
  }
  int64_t m = stb_paddings_attr.getType().getDimSize(0);
  // padding - crop.
  for (uint64_t i = 0; i < m; ++i) {
    for (uint64_t j = 0; j < 2; ++j) {
      // The `crops` tensor has shape [M, 2]; crops[i] = [crop_start, crop_end]
      // specifies the amount to crop from input dimension i + 1. If the input
      // of `BatchToSpaceND` has been padded explicitly, then we need to take
      // the additional padding into account when determining the padding
      // scheme for `Conv2D`.
      int64_t additional_pad =
          pad_attr ? pad_attr.getValue<IntegerAttr>({i + 1, j}).getInt() : 0;
      if (stb_paddings_attr.getValue<IntegerAttr>({i, j}).getInt() +
              additional_pad !=
          bts_crops_attr.getValue<IntegerAttr>({i, j}).getInt()) {
        op->setAttr("padding", rewriter.getStringAttr("SAME"));
        break;
      }
    }
  }

  // Set dilations.
  op->setAttr("dilations", dilations_attr.getValue());

  if (expand_op) {
    // If there is an `expand_op`, we need to rewire the inputs to bypass the
    // `SpaceToBatch`, `BatchToSpace` and `Pad` ops. E.g., turning
    // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
    // BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.

    // Connect `expand_op` with the input of `stb_op`.
    expand_op.setOperand(0, stb_op.input());
    // Calculate the shape for the expand op.
    auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                         input_shape.end());
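    // E.g., an input of shape [N, W, C] with expand_axis == 1 becomes
    // [N, 1, W, C].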
    expand_shape.insert(expand_shape.begin() + expand_axis, 1);

    auto expand_result_type = RankedTensorType::get(
        expand_shape, getElementTypeOrSelf(stb_op.input()));
    expand_op.getResult().setType(expand_result_type);

    // Update the conv op's output shape.
    auto bts_output_shape =
        bts_op.output().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
                                              bts_output_shape.end());
    conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
    auto conv_result_type = RankedTensorType::get(
        conv_result_shape, getElementTypeOrSelf(stb_op.input()));
    op.getResult().setType(conv_result_type);

    squeeze_op.getResult().setType(bts_op.output().getType());

    // Connect `biasadd_op` with the output of `squeeze_op`.
    if (biasadd_op) {
      biasadd_op.setOperand(0, squeeze_op.output());
      biasadd_op.output().setType(squeeze_op.output().getType());
    }
  } else {
    if (biasadd_op) biasadd_op.setOperand(0, op.output());
    op.setOperand(0, stb_op.input());
    op.getResult().setType(bts_op.getResult().getType());
  }

  if (final_op_is_bts) {
    if (bts_op.input().getDefiningOp<TF::PadOp>()) {
      bts_op.getResult().replaceAllUsesWith(pad_op.input());
    } else {
      bts_op.getResult().replaceAllUsesWith(bts_op.input());
    }
  }

  stb_op.getResult().dropAllUses();
  return success();
}

template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
    Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
    PatternRewriter& rewriter) const {
  ElementsAttr stb_bs_attr, bts_bs_attr;
  if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
      !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
    // Return a failure status if `block_shape` is not a constant.
    return {};
  }
  // Check that the block_shape of `stb_op` and `bts_op` are equal.
  if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
  for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
    if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
  }

  int dilation_h_factor = -1, dilation_w_factor = -1;
  // Set the dilation factors.
  if (stb_bs_attr.getNumElements() >= 2) {
    dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
  } else if (stb_bs_attr.getNumElements() == 1) {
    // For a 1-D conv, `tf.nn.convolution` expands NWC to NHWC format after
    // `SpaceToBatchND`. Therefore, the `block_shape` of `stb_op` only carries
    // the dilation factor of the W dim, and the dilation factor of the H dim
    // is set to 1.
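    // E.g., a block_shape of [2] with expand_axis == 1 (NWC -> NHWC) yields
    // dilations of [1, 1, 2, 1].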
    if (expand_axis == 1) {
      // NWC -> NHWC
      dilation_h_factor = 1;
      dilation_w_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    } else if (expand_axis == 2) {
      // NHC -> NHWC
      dilation_h_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
      dilation_w_factor = 1;
    }
  }

  if (dilation_h_factor == -1 || dilation_w_factor == -1) {
    return {};
  }

  return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_