/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them with
// a real convolution op.
21 #include <cstdint>
23 #include "llvm/Support/Casting.h"
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Matchers.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
29 #include "mlir/Pass/Pass.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 namespace mlir {
34 namespace TFL {
36 // A dilated convolution can be emulated with a regular convolution by chaining
37 // SpaceToBatch and BatchToSpace ops before and after it:
38 //
39 // SpaceToBatchND -> Conv2D -> BatchToSpaceND
40 //
41 // This method was common before Conv2D fully supported dilated convolution in
42 // TensorFlow. This transformation detects this "emulation", and replaces it
43 // with a true dilated convolution, eliminating the SpaceToBatch and
44 // BatchtoSpace ops.
45 //
46 // Detecting this alone would be relatively easy. However, in practice some
47 // extra ops are used, so we detect the following patterns:
48 //
49 //
50 // SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
51 //
52 // SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
53 // BiasAdd
54 //
55 // SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
56 //
57 // SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
58 //
59 // SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
60 //
61 //
62 // The Expand/Squeeze combination is used to adapt a 3D array (such as in
63 // WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
64 // thrown in just for the extra headache. Padding adapts non-conforming input
65 // sizes, and can be discarded. The bias is necessary, so is kept.
66 template <typename Conv2dOpTy>
67 class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
68 private:
69 using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;
71 // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
72 llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
73 Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
74 PatternRewriter& rewriter) const;
76 public:
77 LogicalResult matchAndRewrite(Conv2dOpTy op,
78 PatternRewriter& rewriter) const override;
79 };
81 template <typename Conv2dOpTy>
matchAndRewrite(Conv2dOpTy op,PatternRewriter & rewriter)82 LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
83 Conv2dOpTy op, PatternRewriter& rewriter) const {
84 // Make sure Conv2D has 'VALID' padding.
85 if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
86 return failure();
87 }
88 // Make sure dilations are all ones if set.
89 const ArrayAttr& dilations =
90 op->template getAttrOfType<ArrayAttr>("dilations");
91 if (dilations && !TFIntListIsAllOnes(dilations)) {
92 return failure();
93 }
95 if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op))
96 return failure();
98 // Allow dynamic width and height dimensions only.
99 auto result_ty = op.getResult().getType().template cast<TensorType>();
100 if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
101 result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3))
102 return failure();
104 // Check if the ConvOp is preceded by a `Expand` op and succeeded by a
105 // `Squeeze` op.
106 Operation* prev_op = op.getOperation()->getPrevNode();
107 if (!prev_op) return failure();
109 Operation* next_op = op.getOperation()->getNextNode();
110 if (!next_op) return failure();
112 TF::ExpandDimsOp expand_op;
113 TF::SqueezeOp squeeze_op;
114 int64_t expand_axis = -1;
115 // Expand + Squeeze op.
116 if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
117 if (!llvm::isa<TF::SqueezeOp>(next_op)) {
118 // Expand/Squeeze op must come in pair.
119 return failure();
120 }
121 expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
122 squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
124 // Make sure that the axis in `expand_op` is constant.
125 if (auto const_op =
126 llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
127 expand_axis =
128 (*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
129 .getSExtValue();
130 // Canonicalize axis. Some TF python functions, such as
131 // `tf.nn.convolution`, use negative axis.
132 if (expand_axis < 0) {
133 // Always expand 3D input to 4D input.
134 expand_axis += 4;
135 }
136 } else {
137 return failure();
138 }
139 // Make sure that the `squeeze_dims` is equal to `expand_axis`.
140 auto squeeze_dims = squeeze_op.squeeze_dims();
141 if (squeeze_dims.size() != 1) {
142 return failure();
143 }
144 int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
145 if (squeeze_axis < 0) {
146 // Always squeeze 4D input to 3D input.
147 squeeze_axis += 4;
148 }
149 if (squeeze_axis != expand_axis) {
150 return failure();
151 }
153 // Update previous/next op pointer.
154 prev_op = prev_op->getPrevNode();
155 if (!prev_op) return failure();
156 next_op = next_op->getNextNode();
157 if (!next_op) return failure();
158 }
160 // SpaceToBatchND op.
161 if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return failure();
162 // TODO(b/149936532): Check `padding` input, currently ignored.
163 TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
165 // Pad op.
166 TF::PadOp pad_op;
167 // TODO(b/149936532): Currently we just ignore the PadOp. However note that
168 // in real scenarios this may not always be correct: user can put a PadOp here
169 // with non-trivial consequences.
170 if (llvm::isa<TF::PadOp>(next_op)) {
171 pad_op = llvm::cast<TF::PadOp>(next_op);
172 next_op = next_op->getNextNode();
173 if (!next_op) return failure();
174 }
176 // BatchToSpaceND + BiasAdd.
177 TF::BatchToSpaceNDOp bts_op;
178 TF::BiasAddOp biasadd_op;
179 bool final_op_is_bts = true;
180 if (llvm::isa<TF::BiasAddOp>(next_op)) {
181 // Must be BiasAdd + BatchToSpaceND.
182 biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
183 next_op = next_op->getNextNode();
184 if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op)) return failure();
185 bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
186 } else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
187 // BatchToSpaceND + (optional) BiasAdd.
188 bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
189 next_op = next_op->getNextNode();
190 if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
191 biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
192 final_op_is_bts = false;
193 }
194 } else {
195 return failure();
196 }
198 llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
199 stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
200 if (!dilations_attr.hasValue()) return failure();
202 if (expand_op) {
203 if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
204 return failure();
205 }
206 }
208 // TODO(b/149936532): Check that the input width & height are multiples of
209 // dilation rate.
210 // TF python library will rewrite dilated conv to
211 // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle
212 // always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two
213 // parts of contributions, one is to reduce padding of CONV from 'SAME' to
214 // 'VALID', and another is to make input shape multiples of dilation rate. The
215 // first part of padding, which is also called `base_padding` will be used
216 // here to determine if the original padding format is 'SAME' or 'VALID'.
217 // According to the following formula we will compute the `base_padding` if
218 // it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops`
219 // tensor in `BatchToSpace` must satisfy the following:
220 // paddings[i, 0] = base_paddings[i, 0].
221 // 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
222 // (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
223 // crops[i, 0] = 0.
224 // crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].
226 // If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which
227 // tells us the original padding is 'SAME' (with one caveat presented below).
228 // Here we need to reset the padding back to `SAME` if `base_padding`
229 // != 0.
230 // TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to
231 // determine the original padding format. For example, users can build
232 // arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a
233 // dilated conv, hence we shouldn't pattern match here. Instead, we need to
234 // check values of `paddings` and `crops` to make sure it really stands for
235 // a dilated conv.
236 auto stb_paddings = stb_op.paddings();
237 auto bts_crops = bts_op.crops();
238 ElementsAttr stb_paddings_attr, bts_crops_attr;
239 if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) &&
240 matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
241 if (stb_paddings_attr.getNumElements() != bts_crops_attr.getNumElements())
242 return failure();
243 // padding - crop.
244 auto paddings = stb_paddings_attr.getValues<IntegerAttr>();
245 auto crops = bts_crops_attr.getValues<IntegerAttr>();
246 for (auto it1 = paddings.begin(), it2 = crops.begin();
247 it1 != paddings.end() && it2 != crops.end(); it1++, it2++) {
248 if ((*it1).getInt() != (*it2).getInt()) {
249 op->setAttr("padding", rewriter.getStringAttr("SAME"));
250 break;
251 }
252 }
253 }
255 // Set dilations
256 op->setAttr("dilations", dilations_attr.getValue());
258 if (expand_op) {
259 // If there is `expand_op`, we need to rewire the inputs to bypass the
260 // `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning
261 // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
262 // BiasAdd' to 'Expand -> Conv2D ->Squeeze -> BiasAdd'.
264 // Connect `expand_op` with the input of `stb_op`.
265 expand_op.setOperand(0, stb_op.input());
266 // Calculate the shape for expand.
267 auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
268 SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
269 input_shape.end());
270 expand_shape.insert(expand_shape.begin() + expand_axis, 1);
272 auto expand_result_type = RankedTensorType::get(
273 expand_shape, getElementTypeOrSelf(stb_op.input()));
274 expand_op.getResult().setType(expand_result_type);
276 // Update the conv op's output shape.
277 auto bts_output_shape =
278 bts_op.output().getType().cast<ShapedType>().getShape();
279 SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
280 bts_output_shape.end());
281 conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
282 auto conv_result_type = RankedTensorType::get(
283 conv_result_shape, getElementTypeOrSelf(stb_op.input()));
284 op.getResult().setType(conv_result_type);
286 squeeze_op.getResult().setType(bts_op.output().getType());
288 // Connect `biasadd_op` with the output of `squeeze_op`.
289 if (biasadd_op) {
290 biasadd_op.setOperand(0, squeeze_op.output());
291 biasadd_op.output().setType(squeeze_op.output().getType());
292 }
293 } else {
294 if (biasadd_op) biasadd_op.setOperand(0, op.output());
295 op.setOperand(0, stb_op.input());
296 op.getResult().setType(bts_op.getResult().getType());
297 }
299 if (final_op_is_bts) {
300 bts_op.getResult().replaceAllUsesWith(bts_op.input());
301 }
303 stb_op.getResult().dropAllUses();
304 return success();
305 }
307 template <typename Conv2dOpTy>
308 llvm::Optional<ArrayAttr>
ExtractDilationsAttrFromBlockShape(Value stb_block_shape,Value bts_block_shape,int64_t expand_axis,PatternRewriter & rewriter)309 ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
310 Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
311 PatternRewriter& rewriter) const {
312 ElementsAttr stb_bs_attr, bts_bs_attr;
313 if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
314 !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
315 // Returns failure status if block_shape is not a constant.
316 return {};
317 }
318 // Check that the block_shape of `stb_op` and `bts_op` are equal.
319 if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
320 for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
321 if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
322 }
324 int dilation_h_factor = -1, dilation_w_factor = -1;
325 // Set dilation factor.
326 if (stb_bs_attr.getNumElements() >= 2) {
327 dilation_h_factor = stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
328 dilation_w_factor = stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
329 } else if (stb_bs_attr.getNumElements() == 1) {
330 // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
331 // `SpaceToBatchND`. Therefore, `block_shape` of `stb_op` only has one
332 // dilation factor of W dim, and dilation factor of H dim is set to 1.
333 if (expand_axis == 1) {
334 // NWC -> NHWC
335 dilation_h_factor = 1;
336 dilation_w_factor =
337 stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
338 } else if (expand_axis == 2) {
339 // NHC -> NHWC
340 dilation_h_factor =
341 stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
342 dilation_w_factor = 1;
343 }
344 }
346 if (dilation_h_factor == -1 || dilation_w_factor == -1) {
347 return {};
348 }
350 return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
351 }
353 } // namespace TFL
354 } // namespace mlir