/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them
// with a real convolution op.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

#include <cstdint>

#include "llvm/ADT/Optional.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
//     SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects this "emulation", and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
//   BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
//   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so it is kept.
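//
// As a rough illustration (hypothetical shapes and rates, not taken from any
// test), the emulated form
//
//   %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings)
//   %1 = "tf.Conv2D"(%0, %filter) {padding = "VALID", dilations = [1, 1, 1, 1]}
//   %2 = "tf.BatchToSpaceND"(%1, %block_shape, %crops)
//
// with a constant %block_shape of [2, 2] becomes a single
//
//   %0 = "tf.Conv2D"(%input, %filter) {padding = ..., dilations = [1, 2, 2, 1]}
//
// where the padding attribute is recomputed from %paddings and %crops.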
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 private:
  using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;

  // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
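  // For example, a constant `block_shape` of [2, 2] yields the dilations
  // attribute [1, 2, 2, 1].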
  llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
      Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
      PatternRewriter& rewriter) const;

 public:
  LogicalResult matchAndRewrite(Conv2dOpTy op,
                                PatternRewriter& rewriter) const override;
};

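// A minimal usage sketch (an illustrative assumption, not part of this
// header): the pattern is typically registered with a greedy rewrite driver,
// e.g. for `TF::Conv2DOp` and `TF::DepthwiseConv2dNativeOp`:
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>,
//                ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
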
template <typename Conv2dOpTy>
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
    Conv2dOpTy op, PatternRewriter& rewriter) const {
  if (!op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        op, "result for current op has more than 1 use");
  }
  // Make sure Conv2D has 'VALID' padding.
  if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
    return rewriter.notifyMatchFailure(op,
                                       "Conv2D op doesn't have valid padding");
  }
  // Make sure dilations are all ones if set.
  const ArrayAttr& dilations =
      op->template getAttrOfType<ArrayAttr>("dilations");
  if (dilations && !TFIntListIsAllOnes(dilations)) {
    return rewriter.notifyMatchFailure(op, "dilations should be all 1");
  }

  if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
    return rewriter.notifyMatchFailure(
        op, "op's input is not float or the data format isn't NHWC");
  }

  // Allow dynamic width and height dimensions only.
  auto result_ty = op.getResult().getType().template cast<TensorType>();
  if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
      result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
    return rewriter.notifyMatchFailure(
        op, "only dynamic width and height dimensions are allowed");
  }

  // Check if the ConvOp's input is defined by an `Expand` op and its output is
  // used by a `Squeeze` op.
  Operation* producer_op = op.getOperand(0).getDefiningOp();
  if (!producer_op || producer_op->getNumResults() != 1) {
    return rewriter.notifyMatchFailure(
        op, "op doesn't have a producer node that has a single result");
  }
  if (!producer_op->hasOneUse() ||
      *(producer_op->getResult(0).user_begin()) != op) {
    return rewriter.notifyMatchFailure(
        op, "op's input isn't produced by previous operation");
  }

  auto tryGetDirectConsumerOp =
      [&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
    // Check the current operation has a single result.
    if (current->getNumResults() != 1) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have single result"),
          nullptr};
    }
    // Check the current operation has a consumer node.
    if (current->getResult(0).use_empty()) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have consumer node"),
          nullptr};
    }
    Operation* consumer_op =
        current->getResult(0).getUses().begin()->getOwner();
    // Check the current operation's result is used by its successor node.
    if (!current->hasOneUse() ||
        *(current->getResult(0).user_begin()) != consumer_op) {
      return {
          rewriter.notifyMatchFailure(
              current, "op's result isn't directly consumed by the next op"),
          nullptr};
    }
    return {LogicalResult::success(), consumer_op};
  };

  std::pair<LogicalResult, Operation*> maybeConsumer =
      tryGetDirectConsumerOp(op.getOperation());
  if (failed(maybeConsumer.first)) {
    return maybeConsumer.first;
  }
  Operation* consumer_op = maybeConsumer.second;

  TF::ExpandDimsOp expand_op;
  TF::SqueezeOp squeeze_op;
  int64_t expand_axis = -1;
  // Expand + Squeeze op.
  if (llvm::isa<TF::ExpandDimsOp>(producer_op)) {
    if (!llvm::isa<TF::SqueezeOp>(consumer_op)) {
      // Expand/Squeeze ops must come in a pair.
      return rewriter.notifyMatchFailure(
          op, "ExpandDimsOp and SqueezeOp should come in a pair");
    }
    expand_op = llvm::cast<TF::ExpandDimsOp>(producer_op);
    squeeze_op = llvm::cast<TF::SqueezeOp>(consumer_op);
    if (!expand_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          expand_op, "result for current op has more than 1 use");
    }
    if (!squeeze_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "result for current op has more than 1 use");
    }
    // Make sure that the axis in `expand_op` is constant.
    if (auto const_op = llvm::dyn_cast_or_null<TF::ConstOp>(
            expand_op.dim().getDefiningOp())) {
      expand_axis =
          (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
              .getSExtValue();
      // Canonicalize axis. Some TF python functions, such as
      // `tf.nn.convolution`, use negative axis.
      if (expand_axis < 0) {
        // Always expand 3D input to 4D input.
        expand_axis += 4;
      }
    } else {
      return rewriter.notifyMatchFailure(
          expand_op, "ExpandDimsOp doesn't have a constant axis");
    }
    // Make sure that `squeeze_dims` matches `expand_axis`.
    auto squeeze_dims = squeeze_op.squeeze_dims();
    if (squeeze_dims.size() != 1) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "squeeze dims should have exactly 1 dimension specified");
    }
    int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
    if (squeeze_axis < 0) {
      // Always squeeze 4D input to 3D input.
      squeeze_axis += 4;
    }
    if (squeeze_axis != expand_axis) {
      return rewriter.notifyMatchFailure(
          op, "squeeze axis and expand axis don't match");
    }

    // Update previous/next op pointer.
    Operation* tmp = expand_op.input().getDefiningOp();
    if (!tmp || tmp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          producer_op,
          "op doesn't have a producer node that has a single result");
    }
    if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != producer_op) {
      return rewriter.notifyMatchFailure(
          producer_op, "op's input isn't defined by its previous node");
    }
    producer_op = tmp;
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
  }

  // SpaceToBatchND op.
  if (!llvm::isa<TF::SpaceToBatchNDOp>(producer_op)) {
    return rewriter.notifyMatchFailure(producer_op,
                                       "op should be a SpaceToBatchND op");
  }
  // TODO(b/149936532): Check `padding` input, currently ignored.
  TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(producer_op);
  if (!stb_op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        stb_op, "result for current op has more than 1 use");
  }

  // Pad op.
  TF::PadOp pad_op;
  ElementsAttr pad_attr;
  if (llvm::isa<TF::PadOp>(consumer_op)) {
    pad_op = llvm::cast<TF::PadOp>(consumer_op);
    if (!pad_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          pad_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
    if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
      // If the padding value isn't constant, we can't determine the padding
      // scheme for Conv2D below, so just reject the pattern.
      return rewriter.notifyMatchFailure(
          pad_op, "PadOp's padding value isn't constant");
    }
  }

  // BatchToSpaceND + BiasAdd.
  TF::BatchToSpaceNDOp bts_op;
  TF::BiasAddOp biasadd_op;
  bool final_op_is_bts = true;
  if (llvm::isa<TF::BiasAddOp>(consumer_op)) {
    // Must be BiasAdd + BatchToSpaceND.
    biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
    if (!biasadd_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          biasadd_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeConsumer.second)) {
      return rewriter.notifyMatchFailure(
          consumer_op, "op's next node isn't BatchToSpaceND op");
    }
    consumer_op = maybeConsumer.second;
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
  } else if (llvm::isa<TF::BatchToSpaceNDOp>(consumer_op)) {
    // BatchToSpaceND + (optional) BiasAdd.
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    Operation* tmp = maybeConsumer.second;
    if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
      consumer_op = tmp;
      biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
      final_op_is_bts = false;
    }
  } else {
    return rewriter.notifyMatchFailure(
        consumer_op, "next op is neither BiasAdd nor BatchToSpaceND");
  }

  llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
      stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
  if (!dilations_attr.hasValue()) {
    return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
  }

  if (expand_op) {
    if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
      return rewriter.notifyMatchFailure(
          stb_op, "SpaceToBatchND op's input should have RankedTensorType");
    }
  }

  // TODO(b/149936532): Check that the input width & height are multiples of
  // the dilation rate.
  // The TF python library rewrites a dilated conv into the
  // "SpaceToBatch -> Conv -> BatchToSpace" pattern, and the Conv in the middle
  // always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two
  // contributions: one reduces the Conv padding from 'SAME' to 'VALID', and
  // the other makes the input shape a multiple of the dilation rate. The first
  // part, also called `base_paddings`, is used here to determine whether the
  // original padding format was 'SAME' or 'VALID'. We recover `base_paddings`
  // from the relations below, provided the values are constant. Basically, the
  // `paddings` tensor in `SpaceToBatch` and the `crops` tensor in
  // `BatchToSpace` must satisfy the following:
  //  paddings[i, 0] = base_paddings[i, 0].
  //  0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
  //  (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
  //  crops[i, 0] = 0.
  //  crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].

  // If `paddings` - `crops` != 0, then `base_paddings` != 0, which tells us the
  // original padding was 'SAME' (with one caveat presented below). In that case
  // we need to reset the padding back to 'SAME'.
  // TODO(b/149936532): We should not simply rely on `paddings - crops != 0` to
  // determine the original padding format. For example, users can build
  // arbitrary valid examples of `STB->Conv->BTS` that don't represent a
  // dilated conv, and we shouldn't pattern match those. Instead, we need to
  // check the values of `paddings` and `crops` to make sure they really stand
  // for a dilated conv.
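  // As an illustrative example (hypothetical sizes, not from any test): for
  // one spatial dimension with input_shape[i] = 9, block_shape[i] = 2, and a
  // 3-tap filter (effective extent 5 at dilation 2):
  //  * original 'VALID': base_paddings[i] = [0, 0], so paddings[i] = [0, 1],
  //    crops[i] = [0, 1], and paddings - crops == 0;
  //  * original 'SAME': base_paddings[i] = [2, 2], so paddings[i] = [2, 3],
  //    crops[i] = [0, 1], and paddings - crops != 0.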
  auto stb_paddings = stb_op.paddings();
  auto bts_crops = bts_op.crops();
  ElementsAttr stb_paddings_attr, bts_crops_attr;
  if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
      !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
    return rewriter.notifyMatchFailure(
        op,
        "either SpaceToBatchND or BatchToSpaceND "
        "doesn't have constant padding/crops value");
  }
  if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
    return rewriter.notifyMatchFailure(
        stb_op,
        "SpaceToBatchND op's padding doesn't have the same shape/type as "
        "BatchToSpaceND op's crops");
  }
  int64_t m = stb_paddings_attr.getType().getDimSize(0);
  // padding - crop.
  for (uint64_t i = 0; i < m; ++i) {
    for (uint64_t j = 0; j < 2; ++j) {
      // `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end]
      // specifies the amount to crop from input dimension i + 1. If the input
      // of `BatchToSpaceND` has been padded explicitly, then we need to
      // take into account the additional padding when determining the padding
      // scheme for `Conv2D`.
      int64_t additional_pad =
          pad_attr ? pad_attr.getValue<IntegerAttr>({i + 1, j}).getInt() : 0;
      if (stb_paddings_attr.getValue<IntegerAttr>({i, j}).getInt() +
              additional_pad !=
          bts_crops_attr.getValue<IntegerAttr>({i, j}).getInt()) {
        op->setAttr("padding", rewriter.getStringAttr("SAME"));
        break;
      }
    }
  }

  // Set dilations.
  op->setAttr("dilations", dilations_attr.getValue());

  if (expand_op) {
    // If there is `expand_op`, we need to rewire the inputs to bypass the
    // `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g., turning
    // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
    // BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.

    // Connect `expand_op` with the input of `stb_op`.
    expand_op.setOperand(0, stb_op.input());
    // Calculate the shape for expand.
    auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                         input_shape.end());
    expand_shape.insert(expand_shape.begin() + expand_axis, 1);

    auto expand_result_type = RankedTensorType::get(
        expand_shape, getElementTypeOrSelf(stb_op.input()));
    expand_op.getResult().setType(expand_result_type);

    // Update the conv op's output shape.
    auto bts_output_shape =
        bts_op.output().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
                                              bts_output_shape.end());
    conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
    auto conv_result_type = RankedTensorType::get(
        conv_result_shape, getElementTypeOrSelf(stb_op.input()));
    op.getResult().setType(conv_result_type);

    squeeze_op.getResult().setType(bts_op.output().getType());

    // Connect `biasadd_op` with the output of `squeeze_op`.
    if (biasadd_op) {
      biasadd_op.setOperand(0, squeeze_op.output());
      biasadd_op.output().setType(squeeze_op.output().getType());
    }
  } else {
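    // No Expand/Squeeze pair: feed the conv directly from the SpaceToBatchND
    // input and, if a BiasAdd is present, let it consume the conv's output.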
    if (biasadd_op) biasadd_op.setOperand(0, op.output());
    op.setOperand(0, stb_op.input());
    op.getResult().setType(bts_op.getResult().getType());
  }

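  // When BatchToSpaceND is the last op in the matched pattern, bypass it (and
  // any explicit Pad feeding it, which was already accounted for when choosing
  // the padding attribute above) by forwarding its uses to the value that
  // feeds it.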
  if (final_op_is_bts) {
    if (bts_op.input().getDefiningOp<TF::PadOp>()) {
      bts_op.getResult().replaceAllUsesWith(pad_op.input());
    } else {
      bts_op.getResult().replaceAllUsesWith(bts_op.input());
    }
  }

  stb_op.getResult().dropAllUses();
  return success();
}

template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
    Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
    PatternRewriter& rewriter) const {
  ElementsAttr stb_bs_attr, bts_bs_attr;
  if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
      !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
    // Returns failure status if block_shape is not a constant.
    return {};
  }
  // Check that the block_shape of `stb_op` and `bts_op` are equal.
  if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
  for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
    if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
  }

  int dilation_h_factor = -1, dilation_w_factor = -1;
  // Set dilation factor.
  if (stb_bs_attr.getNumElements() >= 2) {
    dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
  } else if (stb_bs_attr.getNumElements() == 1) {
    // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
    // `SpaceToBatchND`. Therefore, the `block_shape` of `stb_op` only carries
    // the dilation factor for the W dim; the H dim's dilation factor is 1.
    if (expand_axis == 1) {
      // NWC -> NHWC
      dilation_h_factor = 1;
      dilation_w_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    } else if (expand_axis == 2) {
      // NHC -> NHWC
      dilation_h_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
      dilation_w_factor = 1;
    }
  }

  if (dilation_h_factor == -1 || dilation_w_factor == -1) {
    return {};
  }

  return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_