/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for legalizing HLO to TensorFlow.

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace mlir {
namespace TF {
namespace {

using mhlo::DotDimensionNumbers;
class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp conv_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (!IsSupportedConvOp(conv_op)) {
      return failure();
    }

    // Constructs strides array.
    // For example, [2, 3] -> [1, 2, 3, 1].
    SmallVector<int64_t, 4> strides({1});
    for (const auto v :
         conv_op.window_strides().getValue().getValues<int64_t>()) {
      strides.emplace_back(v);
    }
    strides.emplace_back(1);

    // Constructs dilation array.
    SmallVector<int64_t, 4> dilation;
    if (auto rhs_dilation = conv_op.rhs_dilation()) {
      // For example, [2, 3] -> [1, 2, 3, 1].
      dilation.emplace_back(1);
      dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
                      rhs_dilation.getValue().getValues<int64_t>().end());
      dilation.emplace_back(1);
    } else {
      // Default value
      dilation = {1, 1, 1, 1};
    }

    const int input_feature_dimension =
        conv_op.dimension_numbers().input_feature_dimension().getInt();
    const int input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    int feature_group_count = conv_op.feature_group_count();

    if (feature_group_count != 1 && feature_group_count != input_channels) {
      // Group convolution is not supported yet.
      return failure();
    }

    const int num_spatial_dims =
        conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
    const bool is_depthwise_conv = input_channels == feature_group_count;
    std::string padding;
    SmallVector<int64_t, 8> explicit_padding;
    if (!conv_op.padding().hasValue() ||
        (conv_op.padding().getValue().isSplat() &&
         conv_op.padding()->getSplatValue<int64_t>() == 0)) {
      padding = "VALID";
    } else {
      SmallVector<int64_t, 4> padding_array;
      for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
        padding_array.emplace_back(v);
      }

      if (IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
                        padding_array)) {
        // Check if padding is "SAME".
        padding = "SAME";
      } else {
        padding = "EXPLICIT";
        explicit_padding.push_back(0);
        explicit_padding.push_back(0);
        explicit_padding.append(padding_array);
        explicit_padding.push_back(0);
        explicit_padding.push_back(0);
      }
    }

    CreateConvOp(conv_op, strides, padding, explicit_padding, dilation,
                 is_depthwise_conv, input_channels, num_spatial_dims, rewriter);
    return success();
  };

 private:
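  // Returns true if the explicit HLO padding matches the padding that
  // TensorFlow would compute for SAME mode in every spatial dimension.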
  bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
                     ArrayRef<int64_t> padding_array) const {
    auto input_spatial_dim_iter = conv_op.dimension_numbers()
                                      .input_spatial_dimensions()
                                      .getValues<int64_t>()
                                      .begin();
    auto kernel_spatial_dim_iter = conv_op.dimension_numbers()
                                       .kernel_spatial_dimensions()
                                       .getValues<int64_t>()
                                       .begin();
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
      int dim = i + 1;
      int64_t output_size;
      int64_t pad_low_int64;
      int64_t pad_high_int64;
      tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
          conv_op.lhs().getType().cast<ShapedType>().getDimSize(
              *(input_spatial_dim_iter + i)),
          conv_op.rhs().getType().cast<ShapedType>().getDimSize(
              *(kernel_spatial_dim_iter + i)),
          dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
          &pad_low_int64, &pad_high_int64);
      if (!status.ok()) return false;
      if (padding_array[2 * i] != pad_low_int64 ||
          padding_array[2 * i + 1] != pad_high_int64)
        return false;
    }

    return true;
  }

  // Returns true if the op needs reformat.
  bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim,
                                       int spatial_dim_start,
                                       int default_batch_dim,
                                       int default_feature_dim,
                                       int default_spatial_dim_start) const {
    return batch_dim != default_batch_dim ||
           feature_dim != default_feature_dim ||
           spatial_dim_start != default_spatial_dim_start;
  }

  // Gets reformat type and permutation attribute. Call this function only if
  // NeedsReformatTypeAndPermutation returns true.
  std::pair<RankedTensorType, DenseIntElementsAttr>
  GetReformatTypeAndPermutation(int batch_dim, int feature_dim,
                                int spatial_dim_start, int default_batch_dim,
                                int default_feature_dim,
                                int default_spatial_dim_start,
                                int num_spatial_dims, RankedTensorType type,
                                ConversionPatternRewriter &rewriter) const {
    auto shape = type.getShape();
    llvm::SmallVector<int64_t, 4> permutation_array(num_spatial_dims + 2);
    permutation_array[default_batch_dim] = batch_dim;
    permutation_array[default_feature_dim] = feature_dim;
    llvm::SmallVector<int64_t, 4> transposed_shape(num_spatial_dims + 2);
    transposed_shape[default_batch_dim] = shape[batch_dim];
    transposed_shape[default_feature_dim] = shape[feature_dim];
    for (int i : llvm::seq<int>(0, num_spatial_dims)) {
      permutation_array[default_spatial_dim_start + i] = spatial_dim_start + i;
      transposed_shape[default_spatial_dim_start + i] =
          shape[spatial_dim_start + i];
    }
    auto new_type =
        RankedTensorType::get(transposed_shape, type.getElementType());
    auto permutation = DenseIntElementsAttr::get(
        RankedTensorType::get({type.getRank()}, rewriter.getI64Type()),
        permutation_array);
    return {new_type, permutation};
  }

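  // Transposes `value` into NHWC layout if its current layout differs from
  // the default one; otherwise returns `value` unchanged.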
  Value FormatToNHWC(Value value, int batch_dim, int feature_dim,
                     DenseIntElementsAttr spatial_dimensions,
                     int default_batch_dim, int default_feature_dim,
                     int default_spatial_dim_start, int num_spatial_dims,
                     ConversionPatternRewriter &rewriter) const {
    auto type = value.getType().cast<RankedTensorType>();
    DenseIntElementsAttr permutation;
    const int spatial_dim_start =
        *spatial_dimensions.getValues<int64_t>().begin();
    if (!NeedsReformatTypeAndPermutation(
            batch_dim, feature_dim, spatial_dim_start, default_batch_dim,
            default_feature_dim, default_spatial_dim_start)) {
      // Transpose is not needed because the current format is "NHWC".
      return value;
    }
    std::pair<RankedTensorType &, DenseIntElementsAttr &>(type, permutation) =
        GetReformatTypeAndPermutation(batch_dim, feature_dim,
                                      spatial_dim_start, default_batch_dim,
                                      default_feature_dim,
                                      default_spatial_dim_start,
                                      num_spatial_dims, type, rewriter);
    return rewriter.create<mhlo::TransposeOp>(value.getLoc(), type, value,
                                              permutation);
  }

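  // Creates the tf.Conv2D or tf.DepthwiseConv2dNative op, transposing the
  // operands into NHWC layout beforehand and the result back to the original
  // output layout afterwards when necessary.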
  void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
                    StringRef padding, ArrayRef<int64_t> explicit_padding,
                    ArrayRef<int64_t> dilation, bool is_depthwise_conv,
                    int input_channels, int num_spatial_dims,
                    ConversionPatternRewriter &rewriter) const {
    // Transposes lhs and rhs if their formats are not NHWC.
    Value lhs = FormatToNHWC(
        conv_op.lhs(),
        conv_op.dimension_numbers().input_batch_dimension().getInt(),
        conv_op.dimension_numbers().input_feature_dimension().getInt(),
        conv_op.dimension_numbers().input_spatial_dimensions(),
        /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/1, num_spatial_dims, rewriter);
    Value rhs = FormatToNHWC(
        conv_op.rhs(),
        conv_op.dimension_numbers().kernel_input_feature_dimension().getInt(),
        conv_op.dimension_numbers().kernel_output_feature_dimension().getInt(),
        conv_op.dimension_numbers().kernel_spatial_dimensions(),
        /*default_batch_dim=*/num_spatial_dims,
        /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/0, num_spatial_dims, rewriter);

    auto conv_output_type = conv_op.getType().cast<RankedTensorType>();
    DenseIntElementsAttr permutation;
    const bool need_transpose_output = NeedsReformatTypeAndPermutation(
        conv_op.dimension_numbers().output_batch_dimension().getInt(),
        conv_op.dimension_numbers().output_feature_dimension().getInt(),
        *conv_op.dimension_numbers()
             .output_spatial_dimensions()
             .getValues<int64_t>()
             .begin(),
        /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
        /*default_spatial_dim_start=*/1);
    if (need_transpose_output) {
      std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
                                                            permutation) =
          GetReformatTypeAndPermutation(
              conv_op.dimension_numbers().output_batch_dimension().getInt(),
              conv_op.dimension_numbers().output_feature_dimension().getInt(),
              *conv_op.dimension_numbers()
                   .output_spatial_dimensions()
                   .getValues<int64_t>()
                   .begin(),
              /*default_batch_dim=*/0,
              /*default_feature_dim=*/num_spatial_dims + 1,
              /*default_spatial_dim_start=*/1, num_spatial_dims,
              conv_output_type, rewriter);
    }
    Value output;
    if (is_depthwise_conv) {
      // Reshapes filter format to [filter_height, filter_width, in_channels,
      // channel_multiplier] from HLO's [filter_height, filter_width, 1,
      // in_channels * channel_multiplier] format.
      auto filter_type = rhs.getType().cast<ShapedType>();
      llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
      llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
                                                    hlo_filter_shape.end());
      tf_filter_shape[2] = input_channels;
      tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
      auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
          rhs.getLoc(),
          RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
          rhs);

      output = rewriter.create<DepthwiseConv2dNativeOp>(
          conv_op.getLoc(), conv_output_type, lhs, reshaped_filter,
          rewriter.getI64ArrayAttr(strides),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr(explicit_padding),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    } else {
      output = rewriter.create<Conv2DOp>(
          conv_op.getLoc(), conv_output_type, lhs, rhs,
          rewriter.getI64ArrayAttr(strides),
          /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr(explicit_padding),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    }

    if (need_transpose_output) {
      // Converts from "NHWC" format back to the original output format.
      std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
                                                            permutation) =
          GetReformatTypeAndPermutation(
              /*batch_dim=*/0, /*feature_dim=*/num_spatial_dims + 1,
              /*spatial_dim_start=*/1,
              conv_op.dimension_numbers().output_batch_dimension().getInt(),
              conv_op.dimension_numbers().output_feature_dimension().getInt(),
              *conv_op.dimension_numbers()
                   .output_spatial_dimensions()
                   .getValues<int64_t>()
                   .begin(),
              num_spatial_dims, conv_output_type, rewriter);
      output = rewriter.create<mhlo::TransposeOp>(
          conv_op.getLoc(), conv_op.getType(), output, permutation);
    }
    rewriter.replaceOp(conv_op, {output});
  }

  bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return false;

    // All ones in "lhs_dilation" means this "mhlo.conv" op should be
    // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNative".
    if (conv_op.lhs_dilation().hasValue()) {
      auto lhs_dilation = conv_op.lhs_dilation().getValue();
      if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
        return false;
    }

    if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
                                                        .getValue()
                                                        .getType()
                                                        .cast<ShapedType>()
                                                        .getRank() != 1)
      return false;

    int num_spatial_dims =
        conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
    // TODO(b/158636600): Currently we don't support 3D Convolution.
    if (num_spatial_dims != 2) return false;

    return true;
  }
};

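// Converts an mhlo.convolution with non-trivial lhs_dilation, i.e. the
// gradient of a convolution with respect to its input, to
// tf.Conv2DBackpropInput.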
class ConvertConvBackpropInputOp : public OpConversionPattern<mhlo::ConvOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp conv_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (IsSupportedConvOp(conv_op, rewriter).failed()) {
      return rewriter.notifyMatchFailure(
          conv_op, "unsupported conversion to ConvBackpropInputOp");
    }

    // Constructs strides array from lhs_dilation.
    // For example, [2, 3] -> [1, 2, 3, 1].
    SmallVector<int64_t, 4> strides({1});
    strides.append(
        conv_op.lhs_dilation().getValue().getValues<int64_t>().begin(),
        conv_op.lhs_dilation().getValue().getValues<int64_t>().end());
    strides.emplace_back(1);

    // Constructs dilation array.
    SmallVector<int64_t, 4> dilation;
    if (auto rhs_dilation = conv_op.rhs_dilation()) {
      // For example, [2, 3] -> [1, 2, 3, 1].
      dilation.emplace_back(1);
      dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
                      rhs_dilation.getValue().getValues<int64_t>().end());
      dilation.emplace_back(1);
    } else {
      // Default value
      dilation = {1, 1, 1, 1};
    }

    std::string padding;
    if (!conv_op.padding().hasValue() ||
        (conv_op.padding().getValue().isSplat() &&
         conv_op.padding()->getSplatValue<int64_t>() == 0)) {
      padding = "VALID";
    } else {
      const int num_spatial_dims = conv_op.dimension_numbers()
                                       .input_spatial_dimensions()
                                       .getNumElements();
      if (!IsSamePadding(conv_op, num_spatial_dims, strides)) {
        return rewriter.notifyMatchFailure(
            conv_op, "requires padding to be SAME or VALID");
      }
      padding = "SAME";
    }

    // Converts int64_t to int32_t.
    llvm::SmallVector<int, 4> input_shape;
    for (int64_t dim : conv_op.getType().cast<RankedTensorType>().getShape()) {
      input_shape.push_back(dim);
    }
    auto input_sizes = rewriter.create<ConstOp>(
        conv_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get({static_cast<int64_t>(input_shape.size())},
                                  rewriter.getI32Type()),
            input_shape));
    // Mirror the filter in the spatial dimensions.
    auto filter = rewriter.create<mhlo::ReverseOp>(
        conv_op.getLoc(), conv_op.rhs(),
        conv_op.dimension_numbers().kernel_spatial_dimensions());
    rewriter.replaceOpWithNewOp<Conv2DBackpropInputOp>(
        conv_op, conv_op.getType(), input_sizes, filter, conv_op.lhs(),
        rewriter.getI64ArrayAttr(strides),
        /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
        /*padding=*/rewriter.getStringAttr(padding),
        /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
        /*data_format=*/rewriter.getStringAttr("NHWC"),
        /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    return success();
  };

 private:
  bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides) const {
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
      int dim = i + 1;
      int stride = strides[dim];
      int input_size = conv_op.getType().cast<ShapedType>().getDimSize(dim);
      int output_size =
          conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim);
      if (output_size != (input_size + stride - 1) / stride) {
        return false;
      }
    }

    return true;
  }

  LogicalResult IsSupportedConvOp(mhlo::ConvOp conv_op,
                                  ConversionPatternRewriter &rewriter) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return rewriter.notifyMatchFailure(conv_op, "requires static shape");

    const int input_feature_dimension =
        conv_op.dimension_numbers().input_feature_dimension().getInt();
    const int input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    int feature_group_count = conv_op.feature_group_count();

    if (feature_group_count != 1 && feature_group_count != input_channels) {
      // Group convolution is not supported yet.
      return rewriter.notifyMatchFailure(conv_op,
                                         "doesn't support group convolution");
    }

    // Checks that lhs_dilation is non-trivial.
    if (!conv_op.lhs_dilation().hasValue()) {
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires lhs_dilation attribute");
    }
    auto lhs_dilation = conv_op.lhs_dilation().getValue();
    if (lhs_dilation.isSplat() && lhs_dilation.getSplatValue<int64_t>() == 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires non-trivial lhs_dilation");

    if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
                                                        .getValue()
                                                        .getType()
                                                        .cast<ShapedType>()
                                                        .getRank() != 1)
      return rewriter.notifyMatchFailure(
          conv_op, "requires window_strides to equal one");

    int num_spatial_dims =
        conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
    // TODO(chhe): Currently we don't support 3D Convolution.
    if (num_spatial_dims != 2)
      return rewriter.notifyMatchFailure(conv_op,
                                         "doesn't support more than 2D");

    // TODO(chhe): Support more data formats other than "NHWC".
    // Checks format [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f].
    if (conv_op.dimension_numbers().input_batch_dimension().getInt() != 0 ||
        conv_op.dimension_numbers().input_feature_dimension().getInt() !=
            num_spatial_dims + 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires input format [b, 0, 1, f]");
    DenseIntElementsAttr input_spatial_dimensions =
        conv_op.dimension_numbers().input_spatial_dimensions();
    for (auto p :
         llvm::enumerate(input_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index() + 1)
        return rewriter.notifyMatchFailure(
            conv_op, "requires input format [b, 0, 1, f]");
    }

    // Checks output dimensions.
    if (conv_op.dimension_numbers().output_batch_dimension().getInt() != 0 ||
        conv_op.dimension_numbers().output_feature_dimension().getInt() !=
            num_spatial_dims + 1)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires output format [b, 0, 1, f]");
    DenseIntElementsAttr output_spatial_dimensions =
        conv_op.dimension_numbers().output_spatial_dimensions();
    for (auto p :
         llvm::enumerate(output_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index() + 1)
        return rewriter.notifyMatchFailure(
            conv_op, "requires output format [b, 0, 1, f]");
    }

    // Checks kernel dimensions.
    if (conv_op.dimension_numbers().kernel_input_feature_dimension().getInt() !=
            num_spatial_dims + 1 ||
        conv_op.dimension_numbers()
                .kernel_output_feature_dimension()
                .getInt() != num_spatial_dims)
      return rewriter.notifyMatchFailure(conv_op,
                                         "requires kernel format [0, 1, o, i]");
    DenseIntElementsAttr kernel_spatial_dimensions =
        conv_op.dimension_numbers().kernel_spatial_dimensions();
    for (auto p :
         llvm::enumerate(kernel_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index())
        return rewriter.notifyMatchFailure(
            conv_op, "requires kernel format [0, 1, o, i]");
    }

    return success();
  }
};

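// Converts mhlo.slice with unit strides to tf.Slice; the slice sizes are
// computed as limit_indices - start_indices.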
class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    DenseIntElementsAttr strides = slice_op.strides();
    // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
    if (!strides.isSplat() ||
        strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
      return failure();

    rewriter.setInsertionPointAfter(slice_op.getOperation());
    auto start_indices = slice_op.start_indices();
    auto limit_indices = slice_op.limit_indices();
    std::vector<int64_t> size_values;
    for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
                               limit_indices.getValues<APInt>())) {
      size_values.emplace_back(std::get<1>(pair).getSExtValue() -
                               std::get<0>(pair).getSExtValue());
    }

    RankedTensorType ty =
        RankedTensorType::get({static_cast<int64_t>(size_values.size())},
                              rewriter.getIntegerType(64));
    auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
    auto size = rewriter.create<ConstOp>(
        slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
    rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
                                         slice_op.operand(), start, size);
    return success();
  };
};

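// Converts mhlo.dynamic_slice to tf.Slice. Each start index is clamped to
// [0, input_size - slice_size] and the indices are packed into one tensor.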
class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    ShapedType input_type = op.operand().getType().cast<ShapedType>();
    if (!input_type.hasStaticShape()) return failure();
    Type start_indices_element_type = op.start_indices()
                                          .front()
                                          .getType()
                                          .cast<ShapedType>()
                                          .getElementType();

    // Clamp indices to [0, input_size - output_size].
    llvm::SmallVector<Value, 4> start_indices_vector;
    start_indices_vector.reserve(op.start_indices().size());
    Value clamp_min = rewriter.create<ConstOp>(
        op.getLoc(), rewriter.getIntegerAttr(start_indices_element_type, 0));
    for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
      Value clamp_max = rewriter.create<ConstOp>(
          op.getLoc(),
          rewriter.getIntegerAttr(start_indices_element_type,
                                  input_type.getShape()[i] -
                                      op.slice_sizes().getValue<int64_t>({i})));
      Value clamped_index = rewriter.create<mhlo::ClampOp>(
          op.getLoc(), op.start_indices()[i].getType(), clamp_min,
          op.start_indices()[i], clamp_max);
      start_indices_vector.push_back(clamped_index);
    }

    // Pack individual start indices into a start indices tensor.
    Type start_indices_type = RankedTensorType::get(
        {static_cast<int64_t>(start_indices_vector.size())},
        start_indices_element_type);
    Value start_indices_op = rewriter.create<PackOp>(
        op.getLoc(), start_indices_type, ValueRange(start_indices_vector));

    Value slice_sizes_op =
        rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
    rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
                                         start_indices_op, slice_sizes_op);
    return success();
  };
};

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
  values.insert(values.end(), range.begin(), range.end());
}

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
            RangeTs &&...ranges) {
  values.insert(values.end(), range.begin(), range.end());
  Append(values, ranges...);
}

// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
  return range.size();
}

// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  return range.size() + Size(std::forward<RangeTs>(ranges)...);
}

// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
  llvm::SmallVector<ValueT, 4> results;
  results.reserve(Size(std::forward<RangeTs>(ranges)...));
  Append(results, std::forward<RangeTs>(ranges)...);
  return results;
}

// A struct to hold axes and sizes for a set of dimensions.
struct DimensionVector {
  llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }

  llvm::SmallVector<int64_t, 4> axes;
  llvm::SmallVector<int64_t, 4> sizes;
};

// A struct to hold information about dimensions of dot_general operands.
class DotDimensionsInfo {
 public:
  DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
                    DenseIntElementsAttr contracting_dimensions) {
    const int rank = type.getRank();
    for (const int dim : batch_dimensions.getValues<int64_t>()) {
      batch_dimensions_.axes.push_back(dim);
      batch_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (const int dim : contracting_dimensions.getValues<int64_t>()) {
      contracting_dimensions_.axes.push_back(dim);
      contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (int dim = 0; dim < rank; ++dim) {
      if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
          llvm::count(batch_dimensions_.axes, dim) > 0) {
        continue;
      }
      out_dimensions_.axes.push_back(dim);
      out_dimensions_.sizes.push_back(type.getDimSize(dim));
    }
  }

  const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
  const DimensionVector &contracting_dimensions() const {
    return contracting_dimensions_;
  }
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, hence will be propagated to output shape.
  const DimensionVector &out_dimensions() const { return out_dimensions_; }

  // Returns the total dimension size after flattening all contracting
  // dimensions.
  int FlattenedContractingDimensionSize() const {
    return std::accumulate(contracting_dimensions_.sizes.begin(),
                           contracting_dimensions_.sizes.end(), 1,
                           std::multiplies<int64_t>());
  }

  // Returns the total dimension size after flattening all out dimensions.
  int FlattenedOutDimensionSize() const {
    return std::accumulate(out_dimensions_.sizes.begin(),
                           out_dimensions_.sizes.end(), 1,
                           std::multiplies<int64_t>());
  }

 private:
  DimensionVector batch_dimensions_;
  DimensionVector contracting_dimensions_;
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, hence will be propagated to output shape.
  DimensionVector out_dimensions_;
};

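// Converts a dot-style contraction into tf.BatchMatMulV2: transposes lhs into
// {batch, out, contracting} order and rhs into {batch, contracting, out}
// order, flattens the out and contracting dimensions of each operand,
// multiplies, and reshapes the product back to `result_type`.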
Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
                 DotDimensionNumbers dot_dimension_numbers,
                 ShapedType result_type, mlir::Location loc) {
  auto lhs_type = lhs.getType().cast<ShapedType>();
  auto rhs_type = rhs.getType().cast<ShapedType>();
  const int lhs_rank = lhs_type.getRank();
  const int rhs_rank = rhs_type.getRank();

  // Collects lhs and rhs dimensions information.
  DotDimensionsInfo lhs_dot_dimensions_info(
      lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
      dot_dimension_numbers.lhs_contracting_dimensions());
  DotDimensionsInfo rhs_dot_dimensions_info(
      rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
      dot_dimension_numbers.rhs_contracting_dimensions());

  // Transposes lhs shape to be in the order of {batch_dimensions,
  // out_dimensions, contracting_dimensions}.
  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
  auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
      lhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
          lhs_permutation));

  // Transposes rhs shape to be in the order of {batch_dimensions,
  // contracting_dimensions, out_dimensions}.
  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
      rhs_dot_dimensions_info.out_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
      rhs_dot_dimensions_info.out_dimensions().SizesArray());
  auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
      rhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
          rhs_permutation));

  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
  auto lhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
      lhs_transposed.getResult());

  // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto rhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
      rhs_transposed.getResult());

  // Creates a matmul op of `lhs_flattened` and `rhs_flattened`.
  llvm::SmallVector<int64_t, 4> matmul_shape =
      Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
                      llvm::ArrayRef<int64_t>{
                          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
                      llvm::ArrayRef<int64_t>{
                          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
      lhs_flattened.getResult(), rhs_flattened.getResult());
  auto reshaped =
      rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
  return reshaped.getResult();
}

// Converts mhlo.dot to tf.BatchMatMul. Reshape ops will be inserted when
// necessary.
Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_op = cast<mhlo::DotOp>(old_op);
  auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
  auto dot_dimension_numbers = DotDimensionNumbers::get(
      /*lhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
      /*rhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
      /*lhs_contracting_dimensions=*/
      rewriter.getI64TensorAttr({lhs_rank == 1 ? 0 : 1}),
      /*rhs_contracting_dimensions=*/rewriter.getI64TensorAttr({0}),
      rewriter.getContext());
  return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(),
                    dot_dimension_numbers,
                    dot_op.getResult().getType().cast<ShapedType>(),
                    dot_op.getLoc());
}

// Converts mhlo.dot_general to tf.BatchMatMul. Reshape or Transpose ops will
// also be inserted to convert to well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
  return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
                    dot_general_op.dot_dimension_numbers(),
                    dot_general_op.getResult().getType().cast<ShapedType>(),
                    dot_general_op.getLoc());
}

// Checks if the specified region is a binary reduction function that takes 2
// inputs, passes them to an instance of the specified reduction op, and then
// returns the result.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();

  ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
      return_op.getOperands().front().getDefiningOp());
  if (!reduce_op) return failure();
  if (reduce_op.lhs() != body.getArgument(0) ||
      reduce_op.rhs() != body.getArgument(1))
    return failure();

  return success();
}

// Checks if the specified region is a binary reduction function that takes 2
// inputs and returns the second input. Functions like this are used by
// update-scatter-like ops.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();
  if (return_op.getOperands().front() != body.getArgument(1)) return failure();
  return success();
}

// Converts an mhlo.reduce op with the specified BinaryOp as the reduction
// operation into the specified TfOp.
template <typename BinaryOp, typename TfOp>
class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpInput(reduce_op))) return failure();

    if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
      return failure();

    // `MatchReduceOpInput` already verified that the "mhlo::ReduceOp" has
    // exactly one input, one init_value and one result.
    if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();

    auto input = reduce_op.inputs()[0];

    // Get reduction dimension.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
    SmallVector<int64_t, 4> reduce_dims;
    for (const int64_t &dim : dimension.getValues<int64_t>()) {
      reduce_dims.emplace_back(dim);
    }
    auto dim_type = RankedTensorType::get(
        {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
    auto reduction_indices = rewriter.create<ConstOp>(
        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));

    rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
                                      reduction_indices,
                                      /*keep_dim=*/rewriter.getBoolAttr(false));
    return success();
  }

 private:
  // Checks that the init value matches with the init value expected for the
  // target TfOp.
  virtual LogicalResult MatchInitValue(Value init_value) const = 0;

  // This function tries to match that the "mhlo::ReduceOp" only has one
  // input, one init_value and one result.
  LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
    if (reduce_op.inputs().size() != 1 || reduce_op.init_values().size() != 1 ||
        reduce_op.getResults().size() != 1)
      return failure();

    if (!reduce_op.inputs()[0].getType().isa<RankedTensorType>())
      return failure();
    if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
    return success();
  }
};

class ConvertReduceOpToTfSum
    : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
      return failure();
    return success();
  }
};

class ConvertReduceOpToTfMax
    : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() ||
        !init_attr.getSplatValue<APFloat>().isInfinity() ||
        !init_attr.getSplatValue<APFloat>().isNegative())
      return failure();
    return success();
  }
};

class ConvertReduceOpToTfMin
    : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() ||
        !init_attr.getSplatValue<APFloat>().isInfinity() ||
        init_attr.getSplatValue<APFloat>().isNegative())
      return failure();
    return success();
  }
};

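// Converts an mhlo.reduce that simultaneously computes the max (or min) value
// and its index (fed by an iota along the reduction dimension) into a pair of
// tf.Max/tf.ArgMax (or tf.Min/tf.ArgMin) ops.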
template <typename TfReduce, typename TfArgReduce>
class ConvertReduceOpToTfArgMinMax
    : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (reduce_op.inputs().size() != 2) return failure();
    if (reduce_op.dimensions().getNumElements() != 1) return failure();

    // Check that the input init is the expected value.
    DenseElementsAttr input_init;
    if (!matchPattern(reduce_op.init_values().front(), m_Constant(&input_init)))
      return failure();
    if (!IsValueInitValue(input_init)) return failure();

    // Check that the iota init is zero.
    DenseElementsAttr iota_init;
    if (!matchPattern(reduce_op.init_values().back(), m_Constant(&iota_init)))
      return failure();
    if (*iota_init.getIntValues().begin() != 0) return failure();

    // Verify that the second argument is an Iota op along the same dimension
    // as the reduction.
    Value iota = reduce_op.inputs().back();
    mhlo::BroadcastInDimOp iota_broadcast =
        llvm::dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
    if (!iota_broadcast ||
        iota_broadcast.broadcast_dimensions() != reduce_op.dimensions())
      return failure();
    if (!llvm::isa<mhlo::IotaOp>(iota_broadcast.operand().getDefiningOp()))
      return failure();

    // Match the reduction computation.
    if (failed(matchReduceComputation(reduce_op.body()))) return failure();

    Value input = reduce_op.inputs().front();
    int64_t axis = reduce_op.dimensions().getValue<int64_t>({0});

    auto dim_type = RankedTensorType::get({1}, rewriter.getI64Type());
    auto reduction_indices = rewriter.create<ConstOp>(
        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr({axis}));

    // Generate a Max and an ArgMax, as the mhlo op returns both while in TF
    // we have separate ops for them. If only one of them is used then the
    // other one will be garbage collected later.
    auto result_type = reduce_op.getType(0).cast<TupleType>();
    auto tf_reduce_op = rewriter.create<TfReduce>(
        reduce_op.getLoc(), result_type.getType(0), input, reduction_indices,
        /*keep_dim=*/rewriter.getBoolAttr(false));
    auto tf_argreduce_op = rewriter.create<TfArgReduce>(
        reduce_op.getLoc(), result_type.getType(1), input, reduction_indices);

    // Pack the result into a TupleOp to match the return type. The Tuple will
    // be optimised out by a subsequent pass.
    SmallVector<Value, 2> result{tf_reduce_op, tf_argreduce_op};
    rewriter.replaceOpWithNewOp<mhlo::TupleOp>(reduce_op, result);
    return success();
  }

  // Pattern matches the following reduction function for ArgMax/ArgMin:
  // %0 = compare{GT}(%lhs_value, %rhs_value)
  // %1 = compare{NE}(%lhs_value, %lhs_value)
  // %2 = or(%0, %1)
  // %3 = select(%2, %lhs_value, %rhs_value)
  // %4 = compare{EQ}(%lhs_value, %rhs_value)
  // %5 = compare{LT}(%lhs_index, %rhs_index)
  // %6 = and(%4, %5)
  // %7 = or(%2, %6)
  // %8 = select(%7, %lhs_index, %rhs_index)
  // %9 = tuple(%3, %8)
  // return %9
  LogicalResult matchReduceComputation(Region &computation) const {
    Block &body = computation.front();
    if (body.getNumArguments() != 4) return failure();

    mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
    if (!return_op) return failure();
    if (return_op.getNumOperands() != 1) return failure();

    mhlo::TupleOp return_tuple = llvm::dyn_cast_or_null<mhlo::TupleOp>(
        return_op.getOperand(0).getDefiningOp());
    if (!return_tuple ||
        return_tuple.getType().cast<TupleType>().getTypes().size() != 2)
      return failure();

    mhlo::SelectOp value_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
        return_tuple.getOperand(0).getDefiningOp());
    if (!value_select || value_select.on_true() != body.getArgument(0) ||
        value_select.on_false() != body.getArgument(2))
      return failure();

    mhlo::OrOp value_or = llvm::dyn_cast_or_null<mhlo::OrOp>(
        value_select.getOperand(0).getDefiningOp());
    if (!value_or) return failure();

    mhlo::SelectOp index_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
        return_tuple.getOperand(1).getDefiningOp());
    if (!index_select || index_select.on_true() != body.getArgument(1) ||
        index_select.on_false() != body.getArgument(3))
      return failure();

    mhlo::CompareOp value_gt =
        llvm::dyn_cast_or_null<mhlo::CompareOp>(value_or.lhs().getDefiningOp());
    if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
        value_gt.lhs() != body.getArgument(0) ||
        value_gt.rhs() != body.getArgument(2))
      return failure();

    mhlo::CompareOp value_ne =
        llvm::dyn_cast_or_null<mhlo::CompareOp>(value_or.rhs().getDefiningOp());
    if (!value_ne || value_ne.comparison_direction() != "NE" ||
        value_ne.lhs() != body.getArgument(0) ||
        value_ne.rhs() != body.getArgument(0))
      return failure();

    mhlo::OrOp index_or =
        llvm::dyn_cast_or_null<mhlo::OrOp>(index_select.pred().getDefiningOp());

    if (!index_or || index_or.lhs() != value_or) return failure();

    mhlo::AndOp index_and =
        llvm::dyn_cast_or_null<mhlo::AndOp>(index_or.rhs().getDefiningOp());
    if (!index_and) return failure();

    mhlo::CompareOp value_eq = llvm::dyn_cast_or_null<mhlo::CompareOp>(
        index_and.lhs().getDefiningOp());
    if (!value_eq || value_eq.comparison_direction() != "EQ" ||
        value_eq.lhs() != body.getArgument(0) ||
        value_eq.rhs() != body.getArgument(2))
      return failure();

    mhlo::CompareOp index_lt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
        index_and.rhs().getDefiningOp());
    if (!index_lt || index_lt.comparison_direction() != "LT" ||
        index_lt.lhs() != body.getArgument(1) ||
        index_lt.rhs() != body.getArgument(3))
      return failure();

    return success();
  }

  virtual const char *CompareDirection() const = 0;

  virtual bool IsValueInitValue(const DenseElementsAttr &attr) const = 0;
};

class ConvertReduceOpToTfArgmax
    : public ConvertReduceOpToTfArgMinMax<TF::MaxOp, TF::ArgMaxOp> {
 public:
  using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;

  const char *CompareDirection() const override { return "GT"; }
  bool IsValueInitValue(const DenseElementsAttr &attr) const override {
    if (attr.getNumElements() != 1 ||
        !attr.getType().getElementType().isa<FloatType>())
      return false;
    auto value = *attr.getFloatValues().begin();
    return value.isNegative() && value.isInfinity();
  }
};

class ConvertReduceOpToTfArgmin
    : public ConvertReduceOpToTfArgMinMax<TF::MinOp, TF::ArgMinOp> {
 public:
  using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;

  const char *CompareDirection() const override { return "LT"; }
  bool IsValueInitValue(const DenseElementsAttr &attr) const override {
    if (attr.getNumElements() != 1 ||
        !attr.getType().getElementType().isa<FloatType>())
      return false;
    auto value = *attr.getFloatValues().begin();
    return !value.isNegative() && value.isInfinity();
  }
};

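// Converts mhlo.iota to tf.Range. For ranks greater than 1 the range is
// reshaped and then broadcast to the iota's full output shape.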
class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::IotaOp iota_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    RankedTensorType type =
        iota_op.getType().dyn_cast_or_null<RankedTensorType>();
    if (!type) return failure();

    const uint64_t dimension = iota_op.iota_dimension();
    Type element_type = type.getElementType();
    Attribute start, limit, delta;
    if (element_type.isa<FloatType>()) {
      start = rewriter.getFloatAttr(element_type, 0.0);
      limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getFloatAttr(element_type, 1.0);
    } else if (element_type.isa<IntegerType>()) {
      start = rewriter.getIntegerAttr(element_type, 0);
      limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getIntegerAttr(element_type, 1);
    } else {
      return failure();
    }

    auto range_type =
        RankedTensorType::get({type.getShape()[dimension]}, element_type);
    Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
    Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
    Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
    Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
                                                start_op, limit_op, delta_op);

    if (type.getRank() > 1) {
      std::vector<int64_t> reshape_shape(type.getRank(), 1);
      reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      Value reshape_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
      result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
                                              result, reshape_shape_op);

      Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
      result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
                                                  result, broadcast_shape_op);
    }

    rewriter.replaceOp(iota_op, result);
    return success();
  }
};

// A helper function for ConvertMaxPoolOp and ConvertAvgPoolOp. Returns true
// if the given ReduceWindowOp is a spatial pooling without dilation. If it
// returns true, also outputs the window strides and the TF padding mode
// ("VALID" or "SAME").
bool IsSpatialPoolingWithoutDilation(
    mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl<int64_t> *window_strides,
    std::string *padding_mode) {
  // tf.max_pool or tf.avg_pool need at least 3 dimensions (batch, spatial,
  // channel).
  const uint64_t rank = rw.window_dimensions().size();
  if (rank <= 2) return false;

  if (rw.window_strides().hasValue()) {
    window_strides->insert(window_strides->end(),
                           rw.window_strides()->getValues<int64_t>().begin(),
                           rw.window_strides()->getValues<int64_t>().end());
  } else {
    window_strides->resize(rank, 1);
  }

  llvm::SmallVector<int64_t, 10> padding;
  if (rw.padding().hasValue()) {
    padding.insert(padding.begin(), rw.padding()->getValues<int64_t>().begin(),
                   rw.padding()->getValues<int64_t>().end());
  } else {
    padding.resize(2 * rank, 0);
  }

  // Check that we don't do any reduction along the batch (first) and channel
  // (last) dimensions.
  const uint64_t batch_dim = 0;
  const uint64_t channel_dim = rank - 1;
  if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
      rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
      (*window_strides)[batch_dim] != 1 ||
      (*window_strides)[channel_dim] != 1 || padding[2 * batch_dim] != 0 ||
      padding[2 * batch_dim + 1] != 0 || padding[2 * channel_dim] != 0 ||
      padding[2 * channel_dim + 1] != 0)
    return false;

  if (rw.window_dilations().hasValue() &&
      !(rw.window_dilations()->isSplat() &&
        rw.window_dilations()->getSplatValue<APInt>() == 1))
    return false;

  if (rw.base_dilations().hasValue() &&
      !(rw.base_dilations()->isSplat() &&
        rw.base_dilations()->getSplatValue<APInt>() == 1))
    return false;

  if (llvm::all_of(padding, [](int64_t i) { return i == 0; })) {
    *padding_mode = "VALID";
    return true;
  }

  // Check that the individual padding values correspond to SAME padding from
  // TensorFlow.
  RankedTensorType input_type =
      rw.inputs()[0].getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      rw.getResult(0).getType().dyn_cast<RankedTensorType>();
  if (!input_type || !output_type) return false;

  for (uint64_t i = 1; i < rank - 1; ++i) {
    int64_t padding_size =
        (output_type.getShape()[i] - 1) * (*window_strides)[i] +
        rw.window_dimensions().getValue<int64_t>({i}) -
        input_type.getShape()[i];
    if (padding[2 * i] != tensorflow::MathUtil::FloorOfRatio(
                              padding_size, static_cast<int64_t>(2)) ||
        padding[2 * i + 1] != tensorflow::MathUtil::CeilOfRatio(
                                  padding_size, static_cast<int64_t>(2)))
      return false;
  }

  *padding_mode = "SAME";
  return true;
}
1314
// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
// padding:
// * div(reduce_sum_window(x), constant(sizeof(window)))
// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
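// Schematically (illustrative pseudo-IR, attribute details elided), the first
// form
//   %sum = reduce_window(%x, 0) { add }
//   %avg = divide(%sum, <window element count>)
// is rewritten to
//   %avg = tf.AvgPool(%x) { ksize, strides, padding = "VALID" }.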
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DivOp div_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    auto rw =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
    if (!rw || rw->getNumResults() != 1) return failure();

    // Check that the reduce-window is a sum-reduce-window.
    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
      return failure();

    // Check that this is a floating point reduce window with a rank of 4 or 5.
    const RankedTensorType rw_type =
        rw.getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
      return failure();

    // Check that the Div op doesn't do broadcasting on the output of the
    // reduce window.
    if (div_op.getType() != rw_type) return failure();

    // If the init value isn't zero then it can't be an average pool.
    if (!isFloatZero(rw.init_values()[0])) return failure();

    llvm::SmallVector<int64_t, 5> window_strides;
    std::string padding_mode;
    if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
      return rewriter.notifyMatchFailure(
          div_op, "not the root of spatial pooling without dilation");
    }

    DenseFPElementsAttr divisor;
    if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
      // If the divisor is a constant then check that it matches the number
      // of elements inside the window, which is required for a VALID AvgPool.
      if (!divisor.isSplat()) return failure();
      int64_t window_size = 1;
      for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
        window_size *= w;
      }
      if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
        return failure();

      if (padding_mode != "VALID") {
        return failure();
      }

      return replaceWithAvgPool(
          div_op, rw.inputs()[0],
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, "VALID", rewriter);
    }

    auto rw_rhs =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
    if (rw_rhs && rw_rhs.getNumResults() == 1) {
      // Check that RHS is a sum-reduce-window.
      if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
        return failure();

      // Check that the RHS is a reduce_window over a constant 1 input with 0
      // as the init value.
      DenseFPElementsAttr rhs_input;
      if (!isFloatZero(rw_rhs.init_values()[0]) ||
          !matchPattern(rw_rhs.inputs()[0], m_Constant(&rhs_input)) ||
          !rhs_input.isSplat() ||
          !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
        return failure();

      // Check that the two reduce_window ops have the same window
      // configuration.
      if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
          rw.window_strides() != rw_rhs.window_strides() ||
          rw.window_dilations() != rw_rhs.window_dilations() ||
          rw.base_dilations() != rw_rhs.base_dilations() ||
          rw.padding() != rw_rhs.padding())
        return failure();

      return replaceWithAvgPool(
          div_op, rw.inputs()[0],
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, padding_mode, rewriter);
    }

    return failure();
  }

 private:
  bool isFloatZero(Value value) const {
    DenseFPElementsAttr initial_value;
    return matchPattern(value, m_Constant(&initial_value)) &&
           initial_value.getNumElements() == 1 &&
           initial_value.getValue<APFloat>({}).isZero();
  }

  LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
                                   llvm::ArrayRef<int64_t> ksizes,
                                   llvm::ArrayRef<int64_t> kstrides,
                                   llvm::StringRef padding,
                                   ConversionPatternRewriter &rewriter) const {
    if (ksizes.size() == 4) {
      rewriter.replaceOpWithNewOp<AvgPoolOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NHWC"));
      return success();
    } else if (ksizes.size() == 5) {
      rewriter.replaceOpWithNewOp<AvgPool3DOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NDHWC"));
      return success();
    }
    return failure();
  }
};

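// Maps an mhlo.reduce_window with a max reduction body and a -infinity init
// value to tf.MaxPool or tf.MaxPool3D when it is a 2D or 3D spatial pooling
// without dilation (see IsSpatialPoolingWithoutDilation above).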
class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp rw, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    // Check that the reduce-window is a max-reduce-window.
    if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(rw.body())))
      return failure();

    // Check that this is a floating point reduce window with a rank of 4 or 5.
    const RankedTensorType rw_type =
        rw.getResult(0).getType().dyn_cast<RankedTensorType>();
    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
      return failure();

    if (!isFloatMinusInfinity(rw.init_values()[0])) {
      return failure();
    }

    llvm::SmallVector<int64_t, 5> window_strides;
    std::string padding_mode;
    if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
      return rewriter.notifyMatchFailure(
          rw, "not the root of spatial pooling without dilation");
    }

    return replaceWithMaxPool(
        rw, rw.inputs()[0],
        llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
        window_strides, padding_mode, rewriter);
  }

 private:
  bool isFloatMinusInfinity(Value value) const {
    DenseFPElementsAttr float_value;
    if (!matchPattern(value, m_Constant(&float_value))) {
      return false;
    }

    if (float_value.getNumElements() != 1) {
      return false;
    }

    APFloat element = float_value.getValue<APFloat>({});
    if (!element.isInfinity()) {
      return false;
    }
    if (!element.isNegative()) {
      return false;
    }

    return true;
  }

  LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
                                   llvm::ArrayRef<int64_t> ksizes,
                                   llvm::ArrayRef<int64_t> kstrides,
                                   llvm::StringRef padding,
                                   ConversionPatternRewriter &rewriter) const {
    if (ksizes.size() == 4) {
      rewriter.replaceOpWithNewOp<MaxPoolOp>(
          op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
          rewriter.getStringAttr("NHWC"));
      return success();
    } else if (ksizes.size() == 5) {
      rewriter.replaceOpWithNewOp<MaxPool3DOp>(
          op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NDHWC"));
      return success();
    }
    return failure();
  }
};

class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TF::TensorFlowDialect>();
  }

 public:
  LegalizeHloToTf() = default;
  LegalizeHloToTf(const LegalizeHloToTf &) {}

  StringRef getArgument() const final { return "tf-legalize-hlo"; }

  StringRef getDescription() const final {
    return "Legalize from HLO to the TF dialect";
  }

  /// Performs the legalization to the TF dialect.
  void runOnFunction() override;
};

// Returns the shape of the given value in a Constant Op.
ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
  ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
  auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                         rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, shape);
  return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
}

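// Returns true if `sign` matches the sign of `a`: exactly -1.0 for negative
// values, exactly +1.0 for positive values, and, for NaN or zero inputs,
// `sign` must compare equal to `a` itself.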
bool IsSign(APFloat a, APFloat sign) {
  if (a.isNaN() || a.isZero()) return a == sign;
  if (a.isNegative()) return sign.isExactlyValue(-1.0);
  return sign.isExactlyValue(1.0);
}

// Returns whether the splat constant is the sign of the float tensor.
bool FloatTensorIsSign(PatternRewriter &rewriter, ElementsAttr floatv,
                       ElementsAttr sgn_cst) {
  if (!sgn_cst.isa<SplatElementsAttr>()) return false;
  auto sgn_cst_spl = sgn_cst.cast<SplatElementsAttr>().getSplatValue<APFloat>();
  if (floatv.isa<SplatElementsAttr>()) {
    auto floatv_spl = floatv.cast<SplatElementsAttr>().getSplatValue<APFloat>();
    return IsSign(floatv_spl, sgn_cst_spl);
  } else if (floatv.isa<DenseElementsAttr>()) {
    auto floatv_dns = floatv.cast<DenseFPElementsAttr>();
    return llvm::all_of(floatv_dns.getAttributeValues(), [&](Attribute value) {
      FloatAttr value_f = value.cast<FloatAttr>();
      return IsSign(value_f.getValue(), sgn_cst_spl);
    });
  }
  return false;
}

// If index_vector_dim == indices.rank() then insert the implicit extra
// dimension into indices to normalize everything to index_vector_dim ==
// indices.rank() - 1.
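// For example, indices of shape [4, 3] with index_vector_dim == 2 are
// reshaped to [4, 3, 1], so the index vector dimension is always the last
// one.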
LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
                                   ShapedType &indices_type,
                                   int64_t index_vector_dim,
                                   ConversionPatternRewriter &rewriter) {
  if (index_vector_dim == indices_type.getRank()) {
    llvm::SmallVector<int64_t, 4> new_start_indices_shape(
        indices_type.getShape().begin(), indices_type.getShape().end());
    new_start_indices_shape.push_back(1);
    indices_type = RankedTensorType::get(new_start_indices_shape,
                                         indices_type.getElementType());
    indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
                                               indices_type, indices);
  } else if (index_vector_dim != indices_type.getRank() - 1) {
    // If index_vector_dim isn't the last dimension in indices then it isn't
    // supported yet.
    // TODO(tberghammer): Transpose indices to support this use case.
    return rewriter.notifyMatchFailure(
        parent_op,
        "index vector dim isn't the last dimension in start indices");
  }
  return success();
}

// Returns true if `attr` is an R1 iota with integer element type, starting
// from `0` and containing `size` values.
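// For example, dense<[0, 1, 2, 3]> : tensor<4xi64> is an iota of size 4,
// while dense<[0, 2, 1]> or any attribute of rank != 1 is not.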
bool IsIotaAttr(const DenseIntElementsAttr &attr, int64_t size) {
  if (!attr.getType().getElementType().isa<IntegerType>()) return false;
  if (attr.getType().getRank() != 1) return false;
  if (attr.getNumElements() != size) return false;
  int64_t iota = 0;
  for (auto s : attr.getIntValues()) {
    if (s != iota) return false;
    ++iota;
  }
  return true;
}

class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::GatherOp gather_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    Value operand = gather_op.operand();
    Value start_indices = gather_op.start_indices();

    // Can only convert with static shaped gather.
    ShapedType operand_type = operand.getType().cast<ShapedType>();
    ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
    ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
    if (!operand_type.hasStaticShape() ||
        !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
      return failure();
    }

    // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
    int64_t index_vector_dim =
        gather_op.dimension_numbers().index_vector_dim().getInt();
    if (failed(NormalizeIndexVector(gather_op, start_indices,
                                    start_indices_type, index_vector_dim,
                                    rewriter))) {
      return failure();
    }

    // Verify that start_index_map and collapsed_slice_dims contain the same
    // values.
    auto start_index_map = gather_op.dimension_numbers().start_index_map();
    auto collapsed_slice_dims =
        gather_op.dimension_numbers().collapsed_slice_dims();
    if (start_index_map.getNumElements() !=
        collapsed_slice_dims.getNumElements()) {
      return rewriter.notifyMatchFailure(
          gather_op,
          "different size for start index map and collapsed slice dims");
    }
    for (auto c : collapsed_slice_dims) {
      if (llvm::count(start_index_map, c) == 0) {
        return rewriter.notifyMatchFailure(
            gather_op, "collapsed slice dim isn't present in start index map");
      }
    }

    // Verify that slice_sizes is 1 for the indexed dimensions and the full
    // shape for the rest of the dimensions.
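    // For example (an assumed configuration), gathering from a
    // tensor<4x5xf32> operand with start_index_map = [0] requires
    // slice_sizes = [1, 5]: size 1 on the indexed dimension and the full
    // extent elsewhere.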
    auto slice_sizes = gather_op.slice_sizes();
    int64_t index = 0;
    for (int64_t s : slice_sizes.getValues<int64_t>()) {
      if (llvm::count(start_index_map, index)) {
        if (s != 1) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      } else {
        if (s != operand_type.getShape()[index]) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      }
      ++index;
    }

    // Verify that offset_dims are the trailing dimensions in the output
    // tensor.
    auto offset_dims = gather_op.dimension_numbers().offset_dims();
    int64_t offset = start_indices_type.getRank() - 1;
    for (int64_t o : offset_dims.getValues<int64_t>()) {
      if (o != offset) {
        return rewriter.notifyMatchFailure(gather_op,
                                           "unsupported offset dims");
      }
      ++offset;
    }

    // Transpose the operand to handle non-iota start index map.
    llvm::SmallVector<int64_t, 4> transpose_dimensions;
    llvm::SmallVector<int64_t, 4> transpose_shape;
    for (auto s : start_index_map) {
      transpose_dimensions.push_back(s.getZExtValue());
      transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
    }
    for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
      if (llvm::count(start_index_map, i) == 0) {
        transpose_dimensions.push_back(i);
        transpose_shape.push_back(operand_type.getShape()[i]);
      }
    }
    operand_type =
        RankedTensorType::get(transpose_shape, operand_type.getElementType());
    operand = rewriter.create<mhlo::TransposeOp>(
        gather_op.getLoc(), operand_type, operand,
        rewriter.getI64TensorAttr(transpose_dimensions));

    rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type,
                                                operand, start_indices);
    return success();
  }
};

class ConvertWhileOp : public OpConversionPattern<mhlo::WhileOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::WhileOp while_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    // HLO WhileOp should have two regions: cond and body.
    if (while_op->getNumRegions() != 2) return failure();

    // This rule doesn't support mhlo::WhileOp with tuple inputs.
    for (auto type : while_op->getOperandTypes()) {
      if (type.isa<TupleType>()) return failure();
    }

    // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp
    // currently doesn't carry the stateless and shape-invariant attributes,
    // so these parameters are set to their default values.
    OpBuilder builder(while_op);
    auto new_while = builder.create<TF::WhileRegionOp>(
        while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(),
        /*parallel_iterations=*/10,
        /*is_stateless=*/false, /*shape_invariant=*/false);
    new_while.cond().takeBody(while_op.getRegion(0));
    new_while.body().takeBody(while_op.getRegion(1));
    ReplaceReturnOp(new_while.cond(), rewriter);
    ReplaceReturnOp(new_while.body(), rewriter);
    rewriter.replaceOp(while_op, new_while.getResults());
    return success();
  }

 private:
  // Replaces mhlo::ReturnOp with TF::YieldOp.
  static void ReplaceReturnOp(Region &region,
                              ConversionPatternRewriter &rewriter) {
    for (auto &block : region.getBlocks()) {
      Operation *terminator = block.getTerminator();
      auto return_op = llvm::dyn_cast_or_null<mhlo::ReturnOp>(terminator);
      if (return_op == nullptr) continue;

      OpBuilder builder(return_op);
      builder.create<TF::YieldOp>(return_op.getLoc(),
                                  return_op->getOperands());
      rewriter.eraseOp(return_op);
    }
  }
};

template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp scatter_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    Value operand = scatter_op.operand();
    Value indices = scatter_op.scatter_indices();
    Value updates = scatter_op.updates();
    ShapedType operand_type = operand.getType().cast<ShapedType>();
    ShapedType indices_type = indices.getType().cast<ShapedType>();
    ShapedType updates_type = updates.getType().cast<ShapedType>();

    // Can only convert with static shaped scatter.
    if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
        !updates_type.hasStaticShape()) {
      return failure();
    }

    // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
    int64_t index_vector_dim =
        scatter_op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
                                    index_vector_dim, rewriter))) {
      return failure();
    }

    // Verify that inserted_window_dims and scatter_dims_to_operand_dims are
    // both an iota with the same number of elements as the last dimension of
    // start_indices.
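    // For example, with indices of shape [..., 2], both attributes must equal
    // [0, 1] for the scatter to map onto a TF tensor-scatter op.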
    auto inserted_window_dims =
        scatter_op.scatter_dimension_numbers().inserted_window_dims();
    auto scatter_dims_to_operand_dims =
        scatter_op.scatter_dimension_numbers().scatter_dims_to_operand_dims();
    if (!IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) ||
        !IsIotaAttr(scatter_dims_to_operand_dims,
                    indices_type.getShape().back())) {
      // TODO(tberghammer): Transform indices to support non-standard
      // scatter_dims_to_operand_dims.
      return rewriter.notifyMatchFailure(
          scatter_op,
          "unsupported inserted window dims and/or scatter dims to operand "
          "dims");
    }

    // Verify that update window dims are the trailing dimensions in the
    // update tensor.
    auto update_window_dims =
        scatter_op.scatter_dimension_numbers().update_window_dims();
    int64_t offset = indices_type.getRank() - 1;
    for (int64_t o : update_window_dims.getValues<int64_t>()) {
      if (o != offset) {
        return rewriter.notifyMatchFailure(scatter_op,
                                           "unsupported update window dims");
      }
      ++offset;
    }

    // Match the scatter computation against computations supported by TF.
    if (failed(MatchBinaryReduceFunction<BinaryOp>(
            scatter_op.update_computation()))) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
                                      scatter_op.getResult().getType(),
                                      operand, indices, updates);
    return success();
  }
};
using ConvertScatterAddOp =
    ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
using ConvertScatterMaxOp =
    ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
using ConvertScatterMinOp =
    ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
using ConvertScatterSubOp =
    ConvertScatterOp<mhlo::SubOp, TF::TensorScatterSubOp>;
using ConvertScatterUpdateOp =
    ConvertScatterOp<void, TF::TensorScatterUpdateOp>;

// Converts mhlo.pad to tf.PadV2.
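// For example (values assumed), edge_padding_low = [0, 1] and
// edge_padding_high = [2, 3] become the tf.PadV2 paddings constant
// [[0, 2], [1, 3]], one (low, high) pair per dimension.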
Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
  auto pad_op = cast<mhlo::PadOp>(old_op);
  mlir::Location loc = pad_op.getLoc();

  llvm::SmallVector<APInt, 8> padding;
  for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
                          pad_op.edge_padding_high().getValues<APInt>())) {
    padding.push_back(std::get<0>(p));
    padding.push_back(std::get<1>(p));
  }
  auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
                                         rewriter.getI64Type());
  auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
  auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
  return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
                                  padding_op, pad_op.padding_value());
}

// Returns true if broadcast_dimensions obey the TensorFlow convention, i.e.
// new dimensions are added as a prefix.
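// For example, broadcasting a tensor<3x4xf32> into a tensor<5x6x3x4xf32>
// output with broadcast_dimensions = [2, 3] is TF-style, since the first
// broadcast dimension (2) equals output rank minus input rank (4 - 2).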
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
                        Value output) {
  // broadcast_dimensions is an increasing list by definition, thus it suffices
  // to check the first element.
  int64_t input_rank = broadcast_dimensions.getNumElements();
  int64_t output_rank = output.getType().cast<ShapedType>().getRank();
  return input_rank == 0 ||
         (broadcast_dimensions.getValue({0}).cast<IntegerAttr>().getInt() ==
          output_rank - input_rank);
}

// Returns the intermediate shape that the input tensor should be reshaped to
// during legalization of BroadcastInDimOp.
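// For example, an input of shape [3, 4] broadcast into a rank-4 output with
// broadcast_dimensions = [1, 3] yields the intermediate shape [1, 3, 1, 4].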
ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
                         DenseIntElementsAttr broadcast_dimensions,
                         Value output) {
  // Initialize expanded shape with output rank and dimensions of 1.
  SmallVector<Attribute, 4> expanded_shape(
      output.getType().cast<ShapedType>().getRank(),
      /*Value=*/rewriter.getI64IntegerAttr(1));

  // Set dimension sizes specified by broadcast_dimensions.
  ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
  for (auto x : llvm::enumerate(broadcast_dimensions)) {
    expanded_shape[x.value().getSExtValue()] =
        rewriter.getI64IntegerAttr(input_shape[x.index()]);
  }

  // Create the expanded type wrapped in a ConstantOp.
  auto attr_type =
      RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
                            rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
  return rewriter.create<ConstantOp>(output.getLoc(), attr_type, attr);
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"

/// Performs the legalization from the HLO dialect to the TF dialect.
void LegalizeHloToTf::runOnFunction() {
  MLIRContext &context = getContext();

  // Add legalization patterns to the list.
  OwningRewritePatternList patterns(&getContext());
  PopulateLegalizeHloToTfPatterns(&patterns, &context);

  ConversionTarget target(context);
  target.addLegalDialect<TensorFlowDialect>();
  target.addLegalOp<CallOp, ConstantOp>();
  target.addLegalOp<mhlo::TupleOp>();
  if (failed(
          applyPartialConversion(getFunction(), target, std::move(patterns)))) {
    getFunction().emitError("mhlo to TF legalization failed.");
    signalPassFailure();
  }
}

static PassRegistration<LegalizeHloToTf> pass;

}  // end namespace

void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
                                     MLIRContext *context) {
  patterns->insert<ConvertWhileOp, ConvertAvgPoolOp, ConvertConvOp,
                   ConvertConvBackpropInputOp, ConvertDynamicSliceOp,
                   ConvertGatherOp, ConvertMaxPoolOp, ConvertScatterAddOp,
                   ConvertScatterMaxOp, ConvertScatterMinOp,
                   ConvertScatterSubOp, ConvertScatterUpdateOp, ConvertSliceOp,
                   ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
                   ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
                   ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
  populateWithGenerated(*patterns);
}

std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
  return std::make_unique<LegalizeHloToTf>();
}

}  // end namespace TF
}  // end namespace mlir