/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them
// with a real convolution op.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

#include <cstdint>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
//     SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects the "emulation" and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
//   BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
//   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so it is kept.
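//
// Illustrative example of the equivalence this pattern relies on: a
// SpaceToBatchND/BatchToSpaceND pair with block_shape = [2, 2] wrapped around
// a 'VALID' Conv2D computes the same result as a single Conv2D with
// dilations = [1, 2, 2, 1] (modulo the padding/crop bookkeeping handled in
// matchAndRewrite below). ExtractDilationsAttrFromBlockShape recovers exactly
// this correspondence between `block_shape` and the `dilations` attribute.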
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 private:
  using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;

  // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
  llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
      Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
      PatternRewriter& rewriter) const;

 public:
  LogicalResult matchAndRewrite(Conv2dOpTy op,
                                PatternRewriter& rewriter) const override;
};

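// A minimal usage sketch (illustrative only; the actual registration lives in
// the TFLite pass files, and this assumes a RewritePatternSet-based MLIR API
// where `func` and `ctx` are a FuncOp and its MLIRContext):
//
//   RewritePatternSet patterns(ctx);
//   patterns.add<ConvertTFDilatedConvOp<TF::Conv2DOp>,
//                ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
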
template <typename Conv2dOpTy>
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
    Conv2dOpTy op, PatternRewriter& rewriter) const {
  // Make sure Conv2D has 'VALID' padding.
  if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
    return failure();
  }
  // Make sure dilations are all ones if set.
  const ArrayAttr& dilations =
      op->template getAttrOfType<ArrayAttr>("dilations");
  if (dilations && !TFIntListIsAllOnes(dilations)) {
    return failure();
  }

  if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
    return failure();

  // Allow dynamic width and height dimensions only.
  auto result_ty = op.getResult().getType().template cast<TensorType>();
  if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
      result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3))
    return failure();

  // Check if the ConvOp is preceded by an `Expand` op and followed by a
  // `Squeeze` op.
  Operation* prev_op = op.getOperation()->getPrevNode();
  if (!prev_op) return failure();

  Operation* next_op = op.getOperation()->getNextNode();
  if (!next_op) return failure();

  TF::ExpandDimsOp expand_op;
  TF::SqueezeOp squeeze_op;
  int64_t expand_axis = -1;
  // Expand + Squeeze op.
  if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
    if (!llvm::isa<TF::SqueezeOp>(next_op)) {
      // Expand and Squeeze ops must come as a pair.
      return failure();
    }
    expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
    squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);

    // Make sure that the axis in `expand_op` is constant.
    if (auto const_op =
            llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
      expand_axis =
          (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
              .getSExtValue();
      // Canonicalize the axis. Some TF python functions, such as
      // `tf.nn.convolution`, use a negative axis.
      if (expand_axis < 0) {
        // Always expand 3D input to 4D input.
        expand_axis += 4;
      }
    } else {
      return failure();
    }
    // Make sure that `squeeze_dims` is equal to `expand_axis`.
    auto squeeze_dims = squeeze_op.squeeze_dims();
    if (squeeze_dims.size() != 1) {
      return failure();
    }
    int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
    if (squeeze_axis < 0) {
      // Always squeeze 4D input to 3D input.
      squeeze_axis += 4;
    }
    if (squeeze_axis != expand_axis) {
      return failure();
    }

    // Update the previous/next op pointers.
    prev_op = prev_op->getPrevNode();
    if (!prev_op) return failure();
    next_op = next_op->getNextNode();
    if (!next_op) return failure();
  }

  // SpaceToBatchND op.
  if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
  // TODO(b/149936532): Check `padding` input, currently ignored.
  TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);

  // Pad op.
  TF::PadOp pad_op;
  // TODO(b/149936532): Currently we just ignore the PadOp. However, note that
  // in real scenarios this may not always be correct: a user can put a PadOp
  // here with non-trivial consequences.
  if (llvm::isa<TF::PadOp>(next_op)) {
    pad_op = llvm::cast<TF::PadOp>(next_op);
    next_op = next_op->getNextNode();
    if (!next_op) return failure();
  }


  // BatchToSpaceND + BiasAdd.
  TF::BatchToSpaceNDOp bts_op;
  TF::BiasAddOp biasadd_op;
  bool final_op_is_bts = true;
  if (llvm::isa<TF::BiasAddOp>(next_op)) {
    // Must be BiasAdd + BatchToSpaceND.
    biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
    next_op = next_op->getNextNode();
    if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
  } else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
    // BatchToSpaceND + (optional) BiasAdd.
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
    next_op = next_op->getNextNode();
    if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
      biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
      final_op_is_bts = false;
    }
  } else {
    return failure();
  }

  llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
      stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
  if (!dilations_attr.hasValue()) return failure();

  if (expand_op) {
    if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
      return failure();
    }
  }

  // TODO(b/149936532): Check that the input width & height are multiples of
  // the dilation rate.
  // The TF python library rewrites a dilated conv into a
  // "SpaceToBatch -> Conv -> BatchToSpace" pattern, and the Conv in the middle
  // always has 'VALID' padding. The padding tensor of `SpaceToBatch` has two
  // contributions: one reduces the padding of the Conv from 'SAME' to 'VALID',
  // and the other makes the input shape a multiple of the dilation rate. The
  // first part, also called `base_paddings`, is used here to determine whether
  // the original padding format was 'SAME' or 'VALID'. We can compute
  // `base_paddings` when the inputs are constant: the `paddings` tensor of
  // `SpaceToBatch` and the `crops` tensor of `BatchToSpace` must satisfy the
  // following:
  //   paddings[i, 0] = base_paddings[i, 0].
  //   0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i].
  //   (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
  //   crops[i, 0] = 0.
  //   crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].
  //
  // If `paddings` - `crops` != 0, then `base_paddings` != 0, which tells us
  // the original padding was 'SAME' (with one caveat presented below). In that
  // case we need to reset the padding back to 'SAME'.
  // TODO(b/149936532): We should not simply rely on `paddings - crops != 0` to
  // determine the original padding format. For example, users can build
  // arbitrary valid examples of `STB->Conv->BTS` which don't represent a
  // dilated conv, and we shouldn't pattern match those. Instead, we need to
  // check the values of `paddings` and `crops` to make sure the graph really
  // stands for a dilated conv.
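  //
  // A worked example with illustrative numbers (3x3 kernel, dilation rate 2,
  // spatial input size 9 in dimension i): the effective kernel size is
  // 3 + (3 - 1) * (2 - 1) = 5, so 'SAME' needs base_paddings[i] = [2, 2].
  // Since 9 + 2 + 2 = 13 is not a multiple of block_shape[i] = 2,
  // SpaceToBatch pads one extra element: paddings[i] = [2, 3] and
  // crops[i] = [0, 3 - 2] = [0, 1]. Here paddings[i, :] != crops[i, :], so the
  // check below resets the Conv2D padding to 'SAME'.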
  auto stb_paddings = stb_op.paddings();
  auto bts_crops = bts_op.crops();
  ElementsAttr stb_paddings_attr, bts_crops_attr;
  if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
      matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
    if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
      return failure();
    // padding - crop.
    auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
    auto crops = bts_crops_attr.getValues<IntegerAttr>();
    for (auto it1 = paddings.begin(), it2 = crops.begin();
         it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
      if ((*it1).getInt() != (*it2).getInt()) {
        op->setAttr("padding", rewriter.getStringAttr("SAME"));
        break;
      }
    }
  }

  // Set the dilations attribute.
  op->setAttr("dilations", dilations_attr.getValue());

  if (expand_op) {
    // If there is an `expand_op`, we need to rewire the inputs to bypass the
    // `SpaceToBatch`, `BatchToSpace` and `Pad` ops. E.g., turning
    // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
    // BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.

    // Connect `expand_op` with the input of `stb_op`.
    expand_op.setOperand(0, stb_op.input());
    // Calculate the shape for expand.
    auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                         input_shape.end());
    expand_shape.insert(expand_shape.begin() + expand_axis, 1);
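    // For example (illustrative shapes): a 3-D `stb_op` input of shape
    // [N, W, C] with expand_axis == 1 yields expand_shape == [N, 1, W, C],
    // matching the NWC -> NHWC expansion performed by `tf.nn.convolution`.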

    auto expand_result_type = RankedTensorType::get(
        expand_shape, getElementTypeOrSelf(stb_op.input()));
    expand_op.getResult().setType(expand_result_type);

    // Update the conv op's output shape.
    auto bts_output_shape =
        bts_op.output().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
                                              bts_output_shape.end());
    conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
    auto conv_result_type = RankedTensorType::get(
        conv_result_shape, getElementTypeOrSelf(stb_op.input()));
    op.getResult().setType(conv_result_type);

    squeeze_op.getResult().setType(bts_op.output().getType());

    // Connect `biasadd_op` with the output of `squeeze_op`.
    if (biasadd_op) {
      biasadd_op.setOperand(0, squeeze_op.output());
      biasadd_op.output().setType(squeeze_op.output().getType());
    }
  } else {
    if (biasadd_op) biasadd_op.setOperand(0, op.output());
    op.setOperand(0, stb_op.input());
    op.getResult().setType(bts_op.getResult().getType());
  }

  if (final_op_is_bts) {
    bts_op.getResult().replaceAllUsesWith(bts_op.input());
  }

  stb_op.getResult().dropAllUses();
  return success();
}

template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
    Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
    PatternRewriter& rewriter) const {
  ElementsAttr stb_bs_attr, bts_bs_attr;
  if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
      !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
    // Return failure if `block_shape` is not a constant.
    return {};
  }
  // Check that the `block_shape` of `stb_op` and `bts_op` are equal.
  if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
  for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
    if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
  }

  int dilation_h_factor = -1, dilation_w_factor = -1;
  // Set the dilation factors.
  if (stb_bs_attr.getNumElements() >= 2) {
    dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
  } else if (stb_bs_attr.getNumElements() == 1) {
    // For 1-D conv, `tf.nn.convolution` expands NWC to NHWC format after
    // `SpaceToBatchND`. Therefore, the `block_shape` of `stb_op` only contains
    // the dilation factor for the W dim; the H dim's factor is set to 1.
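    // For example (illustrative): block_shape = [2] with expand_axis == 1
    // yields dilations [1, 1, 2, 1], while expand_axis == 2 yields
    // dilations [1, 2, 1, 1].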
    if (expand_axis == 1) {
      // NWC -> NHWC
      dilation_h_factor = 1;
      dilation_w_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
    } else if (expand_axis == 2) {
      // NHC -> NHWC
      dilation_h_factor =
          stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
      dilation_w_factor = 1;
    }
  }

  if (dilation_h_factor == -1 || dilation_w_factor == -1) {
    return {};
  }

  return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_