• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for legalizing HLO to TensorFlow.
17 
18 #include <cstdint>
19 #include <cstdlib>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "llvm/ADT/APInt.h"
28 #include "llvm/ADT/ArrayRef.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringRef.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
36 #include "mlir/IR/Attributes.h"  // from @llvm-project
37 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
40 #include "mlir/IR/Location.h"  // from @llvm-project
41 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
42 #include "mlir/IR/Matchers.h"  // from @llvm-project
43 #include "mlir/IR/Operation.h"  // from @llvm-project
44 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
45 #include "mlir/IR/Region.h"  // from @llvm-project
46 #include "mlir/IR/Value.h"  // from @llvm-project
47 #include "mlir/Pass/Pass.h"  // from @llvm-project
48 #include "mlir/Support/LLVM.h"  // from @llvm-project
49 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
50 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
53 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
54 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
55 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
56 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
57 #include "tensorflow/core/framework/kernel_shape_util.h"
58 #include "tensorflow/core/lib/math/math_util.h"
59 
60 namespace mlir {
61 namespace TF {
62 namespace {
63 
64 using mhlo::DotDimensionNumbers;
65 
66 class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
67  public:
68   using OpConversionPattern::OpConversionPattern;
69 
matchAndRewrite(mhlo::ConvOp conv_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const70   LogicalResult matchAndRewrite(
71       mhlo::ConvOp conv_op, ArrayRef<Value> args,
72       ConversionPatternRewriter &rewriter) const final {
73     if (!IsSupportedConvOp(conv_op)) {
74       return failure();
75     }
76 
77     // Constructs strides array.
78     // For example, [2, 3] -> [1, 2, 3, 1].
79     SmallVector<int64_t, 4> strides({1});
80     for (const auto v :
81          conv_op.window_strides().getValue().getValues<int64_t>()) {
82       strides.emplace_back(v);
83     }
84     strides.emplace_back(1);
85 
86     // Constructs dilation array.
87     SmallVector<int64_t, 4> dilation;
88     if (auto rhs_dilation = conv_op.rhs_dilation()) {
89       // For example, [2, 3] -> [1, 2, 3, 1].
90       dilation.emplace_back(1);
91       dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
92                       rhs_dilation.getValue().getValues<int64_t>().end());
93       dilation.emplace_back(1);
94     } else {
95       // Default value
96       dilation = {1, 1, 1, 1};
97     }
98 
99     const int input_feature_dimension =
100         conv_op.dimension_numbers().input_feature_dimension().getInt();
101     const int input_channels =
102         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
103             input_feature_dimension);
104     int feature_group_count = conv_op.feature_group_count();
105 
106     if (feature_group_count != 1 && feature_group_count != input_channels) {
107       // Group convolution is not supported yet.
108       return failure();
109     }
110 
111     const int num_spatial_dims =
112         conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
113     const bool is_depthwise_conv = input_channels == feature_group_count;
114     std::string padding;
115     SmallVector<int64_t, 8> explicit_padding;
116     if (!conv_op.padding().hasValue() ||
117         (conv_op.padding().getValue().isSplat() &&
118          conv_op.padding()->getSplatValue<int64_t>() == 0)) {
119       padding = "VALID";
120     } else {
121       SmallVector<int64_t, 4> padding_array;
122       for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
123         padding_array.emplace_back(v);
124       }
125 
126       if (IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
127                         padding_array)) {
128         // Check if padding is "SAME".
129         padding = "SAME";
130       } else {
131         padding = "EXPLICIT";
132         explicit_padding.push_back(0);
133         explicit_padding.push_back(0);
134         explicit_padding.append(padding_array);
135         explicit_padding.push_back(0);
136         explicit_padding.push_back(0);
137       }
138     }
139 
140     CreateConvOp(conv_op, strides, padding, explicit_padding, dilation,
141                  is_depthwise_conv, input_channels, num_spatial_dims, rewriter);
142     return success();
143   };
144 
145  private:
IsSamePadding(mhlo::ConvOp conv_op,int num_spatial_dims,ArrayRef<int64_t> strides,ArrayRef<int64_t> dilation,ArrayRef<int64_t> padding_array) const146   bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
147                      ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
148                      ArrayRef<int64_t> padding_array) const {
149     auto input_spatial_dim_iter = conv_op.dimension_numbers()
150                                       .input_spatial_dimensions()
151                                       .getValues<int64_t>()
152                                       .begin();
153     auto kernel_spatial_dim_iter = conv_op.dimension_numbers()
154                                        .kernel_spatial_dimensions()
155                                        .getValues<int64_t>()
156                                        .begin();
157     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
158       int dim = i + 1;
159       int64_t output_size;
160       int64_t pad_low_int64;
161       int64_t pad_high_int64;
162       tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
163           conv_op.lhs().getType().cast<ShapedType>().getDimSize(
164               *(input_spatial_dim_iter + i)),
165           conv_op.rhs().getType().cast<ShapedType>().getDimSize(
166               *(kernel_spatial_dim_iter + i)),
167           dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
168           &pad_low_int64, &pad_high_int64);
169       if (!status.ok()) return false;
170       if (padding_array[2 * i] != pad_low_int64 ||
171           padding_array[2 * i + 1] != pad_high_int64)
172         return false;
173     }
174 
175     return true;
176   }
177 
178   // Returns true if the op needs reformat.
NeedsReformatTypeAndPermutation(int batch_dim,int feature_dim,int spatial_dim_start,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start) const179   bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim,
180                                        int spatial_dim_start,
181                                        int default_batch_dim,
182                                        int default_feature_dim,
183                                        int default_spatial_dim_start) const {
184     return batch_dim != default_batch_dim ||
185            feature_dim != default_feature_dim ||
186            spatial_dim_start != default_spatial_dim_start;
187   }
188 
189   // Gets reformat type and permutation attribute. Call this function only if
190   // NeedsReformatTypeAndPermutation returns true.
191   std::pair<RankedTensorType, DenseIntElementsAttr>
GetReformatTypeAndPermutation(int batch_dim,int feature_dim,int spatial_dim_start,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start,int num_spatial_dims,RankedTensorType type,ConversionPatternRewriter & rewriter) const192   GetReformatTypeAndPermutation(int batch_dim, int feature_dim,
193                                 int spatial_dim_start, int default_batch_dim,
194                                 int default_feature_dim,
195                                 int default_spatial_dim_start,
196                                 int num_spatial_dims, RankedTensorType type,
197                                 ConversionPatternRewriter &rewriter) const {
198     auto shape = type.getShape();
199     llvm::SmallVector<int64_t, 4> permutation_array(num_spatial_dims + 2);
200     permutation_array[default_batch_dim] = batch_dim;
201     permutation_array[default_feature_dim] = feature_dim;
202     llvm::SmallVector<int64_t, 4> transposed_shape(num_spatial_dims + 2);
203     transposed_shape[default_batch_dim] = shape[batch_dim];
204     transposed_shape[default_feature_dim] = shape[feature_dim];
205     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
206       permutation_array[default_spatial_dim_start + i] = spatial_dim_start + i;
207       transposed_shape[default_spatial_dim_start + i] =
208           shape[spatial_dim_start + i];
209     }
210     auto new_type =
211         RankedTensorType::get(transposed_shape, type.getElementType());
212     auto permutation = DenseIntElementsAttr::get(
213         RankedTensorType::get({type.getRank()}, rewriter.getI64Type()),
214         permutation_array);
215     return {new_type, permutation};
216   }
217 
FormatToNHWC(Value value,int batch_dim,int feature_dim,DenseIntElementsAttr spatial_dimensions,int default_batch_dim,int default_feature_dim,int default_spatial_dim_start,int num_spatial_dims,ConversionPatternRewriter & rewriter) const218   Value FormatToNHWC(Value value, int batch_dim, int feature_dim,
219                      DenseIntElementsAttr spatial_dimensions,
220                      int default_batch_dim, int default_feature_dim,
221                      int default_spatial_dim_start, int num_spatial_dims,
222                      ConversionPatternRewriter &rewriter) const {
223     auto type = value.getType().cast<RankedTensorType>();
224     DenseIntElementsAttr permutation;
225     const int spatial_dim_start =
226         *spatial_dimensions.getValues<int64_t>().begin();
227     if (!NeedsReformatTypeAndPermutation(
228             batch_dim, feature_dim, spatial_dim_start, default_batch_dim,
229             default_feature_dim, default_spatial_dim_start)) {
230       // Transpose is not needed becasue the current format is "NHWC".
231       return value;
232     }
233     std::pair<RankedTensorType &, DenseIntElementsAttr &>(type, permutation) =
234         GetReformatTypeAndPermutation(batch_dim, feature_dim, spatial_dim_start,
235                                       default_batch_dim, default_feature_dim,
236                                       default_spatial_dim_start,
237                                       num_spatial_dims, type, rewriter);
238     return rewriter.create<mhlo::TransposeOp>(value.getLoc(), type, value,
239                                               permutation);
240   }
241 
CreateConvOp(mhlo::ConvOp conv_op,ArrayRef<int64_t> strides,StringRef padding,ArrayRef<int64_t> explicit_padding,ArrayRef<int64_t> dilation,bool is_depthwise_conv,int input_channels,int num_spatial_dims,ConversionPatternRewriter & rewriter) const242   void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
243                     StringRef padding, ArrayRef<int64_t> explicit_padding,
244                     ArrayRef<int64_t> dilation, bool is_depthwise_conv,
245                     int input_channels, int num_spatial_dims,
246                     ConversionPatternRewriter &rewriter) const {
247     // Transposes lhs and rhs if their formats are not NHWC.
248     Value lhs = FormatToNHWC(
249         conv_op.lhs(),
250         conv_op.dimension_numbers().input_batch_dimension().getInt(),
251         conv_op.dimension_numbers().input_feature_dimension().getInt(),
252         conv_op.dimension_numbers().input_spatial_dimensions(),
253         /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
254         /*default_spatial_dim_start=*/1, num_spatial_dims, rewriter);
255     Value rhs = FormatToNHWC(
256         conv_op.rhs(),
257         conv_op.dimension_numbers().kernel_input_feature_dimension().getInt(),
258         conv_op.dimension_numbers().kernel_output_feature_dimension().getInt(),
259         conv_op.dimension_numbers().kernel_spatial_dimensions(),
260         /*default_batch_dim=*/num_spatial_dims,
261         /*default_feature_dim=*/num_spatial_dims + 1,
262         /*default_spatial_dim_start=*/0, num_spatial_dims, rewriter);
263 
264     auto conv_output_type = conv_op.getType().cast<RankedTensorType>();
265     DenseIntElementsAttr permutation;
266     const bool need_transpose_output = NeedsReformatTypeAndPermutation(
267         conv_op.dimension_numbers().output_batch_dimension().getInt(),
268         conv_op.dimension_numbers().output_feature_dimension().getInt(),
269         *conv_op.dimension_numbers()
270              .output_spatial_dimensions()
271              .getValues<int64_t>()
272              .begin(),
273         /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1,
274         /*default_spatial_dim_start=*/1);
275     if (need_transpose_output) {
276       std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
277                                                             permutation) =
278           GetReformatTypeAndPermutation(
279               conv_op.dimension_numbers().output_batch_dimension().getInt(),
280               conv_op.dimension_numbers().output_feature_dimension().getInt(),
281               *conv_op.dimension_numbers()
282                    .output_spatial_dimensions()
283                    .getValues<int64_t>()
284                    .begin(),
285               /*default_batch_dim=*/0,
286               /*default_feature_dim=*/num_spatial_dims + 1,
287               /*default_spatial_dim_start=*/1, num_spatial_dims,
288               conv_output_type, rewriter);
289     }
290     Value output;
291     if (is_depthwise_conv) {
292       // Reshapes filter format to [filter_height, filter_width, in_channels,
293       // channel_multiplier] from HLO's [filter_height, filter_width, 1,
294       // in_channels * channel_multiplier] format.
295       auto filter_type = rhs.getType().cast<ShapedType>();
296       llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
297       llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
298                                                     hlo_filter_shape.end());
299       tf_filter_shape[2] = input_channels;
300       tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
301       auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
302           rhs.getLoc(),
303           RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
304           rhs);
305 
306       output = rewriter.create<DepthwiseConv2dNativeOp>(
307           conv_op.getLoc(), conv_output_type, lhs, reshaped_filter,
308           rewriter.getI64ArrayAttr(strides),
309           /*padding=*/rewriter.getStringAttr(padding),
310           /*explicit_paddings=*/rewriter.getI64ArrayAttr(explicit_padding),
311           /*data_format=*/rewriter.getStringAttr("NHWC"),
312           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
313     } else {
314       output = rewriter.create<Conv2DOp>(
315           conv_op.getLoc(), conv_output_type, lhs, rhs,
316           rewriter.getI64ArrayAttr(strides),
317           /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
318           /*padding=*/rewriter.getStringAttr(padding),
319           /*explicit_paddings=*/rewriter.getI64ArrayAttr(explicit_padding),
320           /*data_format=*/rewriter.getStringAttr("NHWC"),
321           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
322     }
323 
324     if (need_transpose_output) {
325       // Converts from "NHWC" format back to the original output format.
326       std::pair<RankedTensorType &, DenseIntElementsAttr &>(conv_output_type,
327                                                             permutation) =
328           GetReformatTypeAndPermutation(
329               /*batch_dim=*/0, /*feature_dim=*/num_spatial_dims + 1,
330               /*spatial_dim_start=*/1,
331               conv_op.dimension_numbers().output_batch_dimension().getInt(),
332               conv_op.dimension_numbers().output_feature_dimension().getInt(),
333               *conv_op.dimension_numbers()
334                    .output_spatial_dimensions()
335                    .getValues<int64_t>()
336                    .begin(),
337               num_spatial_dims, conv_output_type, rewriter);
338       output = rewriter.create<mhlo::TransposeOp>(
339           conv_op.getLoc(), conv_op.getType(), output, permutation);
340     }
341     rewriter.replaceOp(conv_op, {output});
342   }
343 
IsSupportedConvOp(mhlo::ConvOp conv_op) const344   bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
345     if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
346         !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
347         !conv_op.getType().cast<ShapedType>().hasStaticShape())
348       return false;
349 
350     // All ones in "lhs_dilation" means this "mhlo.conv" op should be
351     // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
352     if (conv_op.lhs_dilation().hasValue()) {
353       auto lhs_dilation = conv_op.lhs_dilation().getValue();
354       if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
355         return false;
356     }
357 
358     if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
359                                                         .getValue()
360                                                         .getType()
361                                                         .cast<ShapedType>()
362                                                         .getRank() != 1)
363       return false;
364 
365     int num_spatial_dims =
366         conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
367     // TODO(b/158636600): Currently we don't support 3D Convolution.
368     if (num_spatial_dims != 2) return false;
369 
370     return true;
371   }
372 };
373 
374 class ConvertConvBackpropInputOp : public OpConversionPattern<mhlo::ConvOp> {
375  public:
376   using OpConversionPattern::OpConversionPattern;
377 
matchAndRewrite(mhlo::ConvOp conv_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const378   LogicalResult matchAndRewrite(
379       mhlo::ConvOp conv_op, ArrayRef<Value> args,
380       ConversionPatternRewriter &rewriter) const final {
381     if (IsSupportedConvOp(conv_op, rewriter).failed()) {
382       return rewriter.notifyMatchFailure(
383           conv_op, "doesn't support to convert to ConvBackpropInputOp");
384     }
385 
386     // Constructs strides array from lhs_dilation.
387     // For example, [2, 3] -> [1, 2, 3, 1].
388     SmallVector<int64_t, 4> strides({1});
389     strides.append(
390         conv_op.lhs_dilation().getValue().getValues<int64_t>().begin(),
391         conv_op.lhs_dilation().getValue().getValues<int64_t>().end());
392     strides.emplace_back(1);
393 
394     // Constructs dilation array.
395     SmallVector<int64_t, 4> dilation;
396     if (auto rhs_dilation = conv_op.rhs_dilation()) {
397       // For example, [2, 3] -> [1, 2, 3, 1].
398       dilation.emplace_back(1);
399       dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
400                       rhs_dilation.getValue().getValues<int64_t>().end());
401       dilation.emplace_back(1);
402     } else {
403       // Default value
404       dilation = {1, 1, 1, 1};
405     }
406 
407     std::string padding;
408     if (!conv_op.padding().hasValue() ||
409         (conv_op.padding().getValue().isSplat() &&
410          conv_op.padding()->getSplatValue<int64_t>() == 0)) {
411       padding = "VALID";
412     } else {
413       const int num_spatial_dims = conv_op.dimension_numbers()
414                                        .input_spatial_dimensions()
415                                        .getNumElements();
416       if (!IsSamePadding(conv_op, num_spatial_dims, strides)) {
417         return rewriter.notifyMatchFailure(
418             conv_op, "requires padding to be SAME or VALID");
419       }
420       padding = "SAME";
421     }
422 
423     // Converts int64_t to int32_t.
424     llvm::SmallVector<int, 4> input_shape;
425     for (int64_t dim : conv_op.getType().cast<RankedTensorType>().getShape()) {
426       input_shape.push_back(dim);
427     }
428     auto input_sizes = rewriter.create<ConstOp>(
429         conv_op.getLoc(),
430         DenseIntElementsAttr::get(
431             RankedTensorType::get({static_cast<int64_t>(input_shape.size())},
432                                   rewriter.getI32Type()),
433             input_shape));
434     // Mirror the filter in the spatial dimensions.
435     auto filter = rewriter.create<mhlo::ReverseOp>(
436         conv_op.getLoc(), conv_op.rhs(),
437         conv_op.dimension_numbers().kernel_spatial_dimensions());
438     rewriter.replaceOpWithNewOp<Conv2DBackpropInputOp>(
439         conv_op, conv_op.getType(), input_sizes, filter, conv_op.lhs(),
440         rewriter.getI64ArrayAttr(strides),
441         /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
442         /*padding=*/rewriter.getStringAttr(padding),
443         /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
444         /*data_format=*/rewriter.getStringAttr("NHWC"),
445         /*dilations=*/rewriter.getI64ArrayAttr(dilation));
446     return success();
447   };
448 
449  private:
IsSamePadding(mhlo::ConvOp conv_op,int num_spatial_dims,ArrayRef<int64_t> strides) const450   bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
451                      ArrayRef<int64_t> strides) const {
452     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
453       int dim = i + 1;
454       int stride = strides[dim];
455       int input_size = conv_op.getType().cast<ShapedType>().getDimSize(dim);
456       int output_size =
457           conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim);
458       if (output_size != (input_size + stride - 1) / stride) {
459         return false;
460       }
461     }
462 
463     return true;
464   }
465 
IsSupportedConvOp(mhlo::ConvOp conv_op,ConversionPatternRewriter & rewriter) const466   LogicalResult IsSupportedConvOp(mhlo::ConvOp conv_op,
467                                   ConversionPatternRewriter &rewriter) const {
468     if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
469         !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
470         !conv_op.getType().cast<ShapedType>().hasStaticShape())
471       return rewriter.notifyMatchFailure(conv_op, "requires static shape");
472 
473     const int input_feature_dimension =
474         conv_op.dimension_numbers().input_feature_dimension().getInt();
475     const int input_channels =
476         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
477             input_feature_dimension);
478     int feature_group_count = conv_op.feature_group_count();
479 
480     if (feature_group_count != 1 && feature_group_count != input_channels) {
481       // Group convolution is not supported yet.
482       return rewriter.notifyMatchFailure(conv_op,
483                                          "doesn't support group convolution");
484     }
485 
486     // Checks lhs_dilation is non-trivial.
487     if (!conv_op.lhs_dilation().hasValue()) {
488       return rewriter.notifyMatchFailure(conv_op,
489                                          "requires lhs_dilation attribute");
490     }
491     auto lhs_dilation = conv_op.lhs_dilation().getValue();
492     if (lhs_dilation.isSplat() && lhs_dilation.getSplatValue<int64_t>() == 1)
493       return rewriter.notifyMatchFailure(conv_op,
494                                          "requires non-trivial lhs_dilation");
495 
496     if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
497                                                         .getValue()
498                                                         .getType()
499                                                         .cast<ShapedType>()
500                                                         .getRank() != 1)
501       return rewriter.notifyMatchFailure(
502           conv_op, "requires window_strides to equal to one");
503 
504     int num_spatial_dims =
505         conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
506     // TODO(chhe): Currently we don't support 3D Convolution.
507     if (num_spatial_dims != 2)
508       return rewriter.notifyMatchFailure(conv_op,
509                                          "doesn't support more than 2D");
510 
511     // TODO(chhe): To support more data formats other than "NHWC".
512     // Checks format [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f].
513     if (conv_op.dimension_numbers().input_batch_dimension().getInt() != 0 ||
514         conv_op.dimension_numbers().input_feature_dimension().getInt() !=
515             num_spatial_dims + 1)
516       return rewriter.notifyMatchFailure(conv_op,
517                                          "requires input format [b, 0, 1, f]");
518     DenseIntElementsAttr input_spatial_dimensions =
519         conv_op.dimension_numbers().input_spatial_dimensions();
520     for (auto p :
521          llvm::enumerate(input_spatial_dimensions.getValues<int64_t>())) {
522       if (p.value() != p.index() + 1)
523         return rewriter.notifyMatchFailure(
524             conv_op, "requires input format [b, 0, 1, f]");
525     }
526 
527     // Checks output dimensions.
528     if (conv_op.dimension_numbers().output_batch_dimension().getInt() != 0 ||
529         conv_op.dimension_numbers().output_feature_dimension().getInt() !=
530             num_spatial_dims + 1)
531       return rewriter.notifyMatchFailure(conv_op,
532                                          "requires output format [b, 0, 1, f]");
533     DenseIntElementsAttr output_spatial_dimensions =
534         conv_op.dimension_numbers().output_spatial_dimensions();
535     for (auto p :
536          llvm::enumerate(output_spatial_dimensions.getValues<int64_t>())) {
537       if (p.value() != p.index() + 1)
538         return rewriter.notifyMatchFailure(
539             conv_op, "requires output format [0, 1, o, i]");
540     }
541 
542     // Checks kernel dimensions.
543     if (conv_op.dimension_numbers().kernel_input_feature_dimension().getInt() !=
544             num_spatial_dims + 1 ||
545         conv_op.dimension_numbers()
546                 .kernel_output_feature_dimension()
547                 .getInt() != num_spatial_dims)
548       return rewriter.notifyMatchFailure(conv_op,
549                                          "requires kernel format [b, 0, 1, f]");
550     DenseIntElementsAttr kernal_spatial_dimensions =
551         conv_op.dimension_numbers().kernel_spatial_dimensions();
552     for (auto p :
553          llvm::enumerate(kernal_spatial_dimensions.getValues<int64_t>())) {
554       if (p.value() != p.index())
555         return rewriter.notifyMatchFailure(
556             conv_op, "requires kernel format [0, 1, o, i]");
557     }
558 
559     return success();
560   }
561 };
562 
563 class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
564  public:
565   using OpConversionPattern::OpConversionPattern;
566 
matchAndRewrite(mhlo::SliceOp slice_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const567   LogicalResult matchAndRewrite(
568       mhlo::SliceOp slice_op, ArrayRef<Value> args,
569       ConversionPatternRewriter &rewriter) const final {
570     DenseIntElementsAttr strides = slice_op.strides();
571     // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
572     if (!strides.isSplat() ||
573         strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
574       return failure();
575 
576     rewriter.setInsertionPointAfter(slice_op.getOperation());
577     auto start_indices = slice_op.start_indices();
578     auto limit_indices = slice_op.limit_indices();
579     std::vector<int64_t> size_values;
580     for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
581                                limit_indices.getValues<APInt>())) {
582       size_values.emplace_back(std::get<1>(pair).getSExtValue() -
583                                std::get<0>(pair).getSExtValue());
584     }
585 
586     RankedTensorType ty =
587         RankedTensorType::get({static_cast<int64_t>(size_values.size())},
588                               rewriter.getIntegerType(64));
589     auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
590     auto size = rewriter.create<ConstOp>(
591         slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
592     rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
593                                          slice_op.operand(), start, size);
594     return success();
595   };
596 };
597 
598 class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
599  public:
600   using OpConversionPattern::OpConversionPattern;
601 
matchAndRewrite(mhlo::DynamicSliceOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const602   LogicalResult matchAndRewrite(
603       mhlo::DynamicSliceOp op, ArrayRef<Value> args,
604       ConversionPatternRewriter &rewriter) const final {
605     ShapedType input_type = op.operand().getType().cast<ShapedType>();
606     if (!input_type.hasStaticShape()) return failure();
607     Type start_indices_element_type = op.start_indices()
608                                           .front()
609                                           .getType()
610                                           .cast<ShapedType>()
611                                           .getElementType();
612 
613     // Clamp indices to [0, input_size - output_size]
614     llvm::SmallVector<Value, 4> start_indices_vector;
615     start_indices_vector.reserve(op.start_indices().size());
616     Value clamp_min = rewriter.create<ConstOp>(
617         op.getLoc(), rewriter.getIntegerAttr(start_indices_element_type, 0));
618     for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
619       Value clamp_max = rewriter.create<ConstOp>(
620           op.getLoc(),
621           rewriter.getIntegerAttr(start_indices_element_type,
622                                   input_type.getShape()[i] -
623                                       op.slice_sizes().getValue<int64_t>({i})));
624       Value clamped_index = rewriter.create<mhlo::ClampOp>(
625           op.getLoc(), op.start_indices()[i].getType(), clamp_min,
626           op.start_indices()[i], clamp_max);
627       start_indices_vector.push_back(clamped_index);
628     }
629 
630     // Pack individual start indices to start indices tensor.
631     Type start_indices_type = RankedTensorType::get(
632         {static_cast<int64_t>(start_indices_vector.size())},
633         start_indices_element_type);
634     Value start_indices_op = rewriter.create<PackOp>(
635         op.getLoc(), start_indices_type, ValueRange(start_indices_vector));
636 
637     Value slice_sices_op =
638         rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
639     rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
640                                          start_indices_op, slice_sices_op);
641     return success();
642   };
643 };
644 
645 // Appends all elements in `range` to `values`.
646 template <typename ValueT, typename Range>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range)647 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
648   values.insert(values.end(), range.begin(), range.end());
649 }
650 
651 // Appends all elements in `range` to `values`.
652 template <typename ValueT, typename Range, typename... RangeTs>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range,RangeTs &&...ranges)653 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
654             RangeTs &&...ranges) {
655   values.insert(values.end(), range.begin(), range.end());
656   Append(values, ranges...);
657 }
658 
// Base case of the variadic `Size`: the element count reported by a single
// `range` through its `size()` member.
template <typename Range>
size_t Size(Range &&range) {
  const size_t element_count = range.size();
  return element_count;
}
664 
// Sums the element counts of `range` and of every range in `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  const size_t head_count = range.size();
  const size_t tail_count = Size(std::forward<RangeTs>(ranges)...);
  return head_count + tail_count;
}
670 
671 // Concats all elements in `ranges` and returns a small vector as a result.
672 template <typename ValueT, typename... RangeTs>
Concat(RangeTs &&...ranges)673 llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
674   llvm::SmallVector<int64_t, 4> results;
675   results.reserve(Size(std::forward<RangeTs>(ranges)...));
676   Append(results, std::forward<RangeTs>(ranges)...);
677   return results;
678 }
679 
680 // A struct to hold axes and sizes for a set of dimensions.
681 struct DimensionVector {
AxesArraymlir::TF::__anon964666f10111::DimensionVector682   llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
SizesArraymlir::TF::__anon964666f10111::DimensionVector683   llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
684 
685   llvm::SmallVector<int64_t, 4> axes;
686   llvm::SmallVector<int64_t, 4> sizes;
687 };
688 
689 // A struct to hold information about dimensions of dot_general operands.
690 class DotDimensionsInfo {
691  public:
DotDimensionsInfo(ShapedType type,DenseIntElementsAttr batch_dimensions,DenseIntElementsAttr contracting_dimensions)692   DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
693                     DenseIntElementsAttr contracting_dimensions) {
694     const int rank = type.getRank();
695     for (const int dim : batch_dimensions.getValues<int64_t>()) {
696       batch_dimensions_.axes.push_back(dim);
697       batch_dimensions_.sizes.push_back(type.getDimSize(dim));
698     }
699 
700     for (const int dim : contracting_dimensions.getValues<int64_t>()) {
701       contracting_dimensions_.axes.push_back(dim);
702       contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
703     }
704 
705     for (int dim = 0; dim < rank; ++dim) {
706       if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
707           llvm::count(batch_dimensions_.axes, dim) > 0) {
708         continue;
709       }
710       out_dimensions_.axes.push_back(dim);
711       out_dimensions_.sizes.push_back(type.getDimSize(dim));
712     }
713   }
714 
batch_dimensions() const715   const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
contracting_dimensions() const716   const DimensionVector &contracting_dimensions() const {
717     return contracting_dimensions_;
718   }
719   // Out dimensions are any dimensions that are neither batch nor contracting
720   // dimensions, hence will be propagated to output shape.
out_dimensions() const721   const DimensionVector &out_dimensions() const { return out_dimensions_; }
722 
723   // Returns the total dimension size after flattening all contracting
724   // dimensions.
FlattenedContractingDimensionSize() const725   int FlattenedContractingDimensionSize() const {
726     return std::accumulate(contracting_dimensions_.sizes.begin(),
727                            contracting_dimensions_.sizes.end(), 1,
728                            std::multiplies<int64_t>());
729   }
730 
731   // Returns the total dimension size after flattening all out dimensions.
FlattenedOutDimensionSize() const732   int FlattenedOutDimensionSize() const {
733     return std::accumulate(out_dimensions_.sizes.begin(),
734                            out_dimensions_.sizes.end(), 1,
735                            std::multiplies<int64_t>());
736   }
737 
738  private:
739   DimensionVector batch_dimensions_;
740   DimensionVector contracting_dimensions_;
741   // Out dimensions are any dimensions that are neither batch nor contracting
742   // dimensions, hence will be propagated to output shape.
743   DimensionVector out_dimensions_;
744 };
745 
ConvertDot(PatternRewriter & rewriter,Value lhs,Value rhs,DotDimensionNumbers dot_dimension_numbers,ShapedType result_type,mlir::Location loc)746 Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
747                  DotDimensionNumbers dot_dimension_numbers,
748                  ShapedType result_type, mlir::Location loc) {
749   auto lhs_type = lhs.getType().cast<ShapedType>();
750   auto rhs_type = rhs.getType().cast<ShapedType>();
751   const int lhs_rank = lhs_type.getRank();
752   const int rhs_rank = rhs_type.getRank();
753 
754   // Collects lhs and rhs dimensions information.
755   DotDimensionsInfo lhs_dot_dimensions_info(
756       lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
757       dot_dimension_numbers.lhs_contracting_dimensions());
758   DotDimensionsInfo rhs_dot_dimensions_info(
759       rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
760       dot_dimension_numbers.rhs_contracting_dimensions());
761 
762   // Transposes lhs shape to be in the order of {batch_dimensions,
763   // out_dimensions, contracting dimensions}.
764   llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
765       lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
766       lhs_dot_dimensions_info.out_dimensions().AxesArray(),
767       lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
768   llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
769       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
770       lhs_dot_dimensions_info.out_dimensions().SizesArray(),
771       lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
772   auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
773       loc,
774       RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
775       lhs,
776       DenseIntElementsAttr::get(
777           RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
778           lhs_permutation));
779 
780   // Transposes rhs shape to be in the order of {batch_dimensions, contracting
781   // dimensions, out_dimensions}.
782   llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
783       rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
784       rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
785       rhs_dot_dimensions_info.out_dimensions().AxesArray());
786   llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
787       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
788       rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
789       rhs_dot_dimensions_info.out_dimensions().SizesArray());
790   auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
791       loc,
792       RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
793       rhs,
794       DenseIntElementsAttr::get(
795           RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
796           rhs_permutation));
797 
798   // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
799   llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
800       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
801       llvm::ArrayRef<int64_t>{
802           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
803       llvm::ArrayRef<int64_t>{
804           lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
805   auto lhs_flattend = rewriter.create<mhlo::ReshapeOp>(
806       loc,
807       RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
808       lhs_transposed.getResult());
809 
810   // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
811   llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
812       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
813       llvm::ArrayRef<int64_t>{
814           rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
815       llvm::ArrayRef<int64_t>{
816           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
817   auto rhs_flattend = rewriter.create<mhlo::ReshapeOp>(
818       loc,
819       RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
820       rhs_transposed.getResult());
821 
822   // Creates matmul op of `lhs_flattend` and `rhs_flattend`.
823   llvm::SmallVector<int64_t, 4> matmul_shape =
824       Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
825                       llvm::ArrayRef<int64_t>{
826                           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
827                       llvm::ArrayRef<int64_t>{
828                           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
829   auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
830       loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
831       lhs_flattend.getResult(), rhs_flattend.getResult());
832   auto reshaped =
833       rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
834   return reshaped.getResult();
835 }
836 
837 // Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when
838 // necessary.
ConvertDotOp(PatternRewriter & rewriter,Operation * old_op)839 Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
840   auto dot_op = cast<mhlo::DotOp>(old_op);
841   auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
842   auto dot_dimension_numbers = DotDimensionNumbers::get(
843       /*lhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
844       /*rhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
845       /*lhs_contracting_dimensions=*/
846       rewriter.getI64TensorAttr({lhs_rank == 1 ? 0 : 1}),
847       /*rhs_contracting_dimensions=*/rewriter.getI64TensorAttr({0}),
848       rewriter.getContext());
849   return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(), dot_dimension_numbers,
850                     dot_op.getResult().getType().cast<ShapedType>(),
851                     dot_op.getLoc());
852 }
853 
854 // Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
855 // inserted to convert to well-formed matrix multiply.
ConvertDotGeneralOp(PatternRewriter & rewriter,Operation * old_op)856 Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
857   auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
858   return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
859                     dot_general_op.dot_dimension_numbers(),
860                     dot_general_op.getResult().getType().cast<ShapedType>(),
861                     dot_general_op.getLoc());
862 }
863 
864 // Checks if the specified region is a binary reduction function what takes 2
865 // inputs, passes it to an instance of the specifiied reduction op and then
866 // returns the result.
867 template <typename ReductionOp>
MatchBinaryReduceFunction(mlir::Region & function)868 LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
869   Block &body = function.front();
870   if (body.getNumArguments() != 2) return failure();
871 
872   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
873   if (!return_op) return failure();
874   if (return_op.getNumOperands() != 1) return failure();
875 
876   ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
877       return_op.getOperands().front().getDefiningOp());
878   if (!reduce_op) return failure();
879   if (reduce_op.lhs() != body.getArgument(0) ||
880       reduce_op.rhs() != body.getArgument(1))
881     return failure();
882 
883   return success();
884 }
885 
886 // Check if the specified region is a binary reduction function what takes 2
887 // inputs and returns the second input. Functions like this are used by update
888 // scatter like ops.
889 template <>
MatchBinaryReduceFunction(mlir::Region & function)890 LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
891   Block &body = function.front();
892   if (body.getNumArguments() != 2) return failure();
893 
894   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
895   if (!return_op) return failure();
896   if (return_op.getNumOperands() != 1) return failure();
897   if (return_op.getOperands().front() != body.getArgument(1)) return failure();
898   return success();
899 }
900 
901 // Converts an mhlo.reduce op with the specified BinaryOp as the reduction
902 // operation into the specified TfOp.
903 template <typename BinaryOp, typename TfOp>
904 class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
905  public:
906   using OpConversionPattern::OpConversionPattern;
907 
matchAndRewrite(mhlo::ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const908   LogicalResult matchAndRewrite(
909       mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
910       ConversionPatternRewriter &rewriter) const final {
911     if (failed(MatchReduceOpInput(reduce_op))) return failure();
912 
913     if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
914       return failure();
915 
916     // In `MatchReduceOpInput` function, we already match that the
917     // "mhlo::ReduceOp" only has one input, one init_value and one result.
918     if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();
919 
920     auto input = reduce_op.inputs()[0];
921 
922     // Get reduction dimension.
923     DenseIntElementsAttr dimension = reduce_op.dimensions();
924     SmallVector<int64_t, 4> reduce_dims;
925     for (const int64_t &dim : dimension.getValues<int64_t>()) {
926       reduce_dims.emplace_back(dim);
927     }
928     auto dim_type = RankedTensorType::get(
929         {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
930     auto reduction_indices = rewriter.create<ConstOp>(
931         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
932 
933     rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
934                                       reduction_indices,
935                                       /*keep_dim=*/rewriter.getBoolAttr(false));
936     return success();
937   }
938 
939  private:
940   // Checks that the init value matches with the init value expected for the
941   // target TfOp.
942   virtual LogicalResult MatchInitValue(Value init_value) const = 0;
943 
944   // This function tries to match that the "mhlo::ReduceOp" only has one
945   // input, one init_value and one result.
MatchReduceOpInput(mhlo::ReduceOp reduce_op) const946   LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
947     if (reduce_op.inputs().size() != 1 || reduce_op.init_values().size() != 1 ||
948         reduce_op.getResults().size() != 1)
949       return failure();
950 
951     if (!reduce_op.inputs()[0].getType().isa<RankedTensorType>())
952       return failure();
953     if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
954     return success();
955   }
956 };
957 
958 class ConvertReduceOpToTfSum
959     : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
960  public:
961   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
962 
MatchInitValue(Value init_value) const963   LogicalResult MatchInitValue(Value init_value) const override {
964     DenseFPElementsAttr init_attr;
965     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
966         !init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
967       return failure();
968     return success();
969   }
970 };
971 
972 class ConvertReduceOpToTfMax
973     : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
974  public:
975   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
976 
MatchInitValue(Value init_value) const977   LogicalResult MatchInitValue(Value init_value) const override {
978     DenseFPElementsAttr init_attr;
979     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
980         !init_attr.isSplat() ||
981         !init_attr.getSplatValue<APFloat>().isInfinity() ||
982         !init_attr.getSplatValue<APFloat>().isNegative())
983       return failure();
984     return success();
985   }
986 };
987 
988 class ConvertReduceOpToTfMin
989     : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
990  public:
991   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
992 
MatchInitValue(Value init_value) const993   LogicalResult MatchInitValue(Value init_value) const override {
994     DenseFPElementsAttr init_attr;
995     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
996         !init_attr.isSplat() ||
997         !init_attr.getSplatValue<APFloat>().isInfinity() ||
998         init_attr.getSplatValue<APFloat>().isNegative())
999       return failure();
1000     return success();
1001   }
1002 };
1003 
1004 template <typename TfReduce, typename TfArgReduce>
1005 class ConvertReduceOpToTfArgMinMax
1006     : public OpConversionPattern<mhlo::ReduceOp> {
1007  public:
1008   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(mhlo::ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1009   LogicalResult matchAndRewrite(
1010       mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
1011       ConversionPatternRewriter &rewriter) const final {
1012     if (reduce_op.inputs().size() != 2) return failure();
1013     if (reduce_op.dimensions().getNumElements() != 1) return failure();
1014 
1015     // Check that the input init is the expected value.
1016     DenseElementsAttr input_init;
1017     if (!matchPattern(reduce_op.init_values().front(), m_Constant(&input_init)))
1018       return failure();
1019     if (!IsValueInitValue(input_init)) return failure();
1020 
1021     // Check that the iota init is zero.
1022     DenseElementsAttr iota_init;
1023     if (!matchPattern(reduce_op.init_values().back(), m_Constant(&iota_init)))
1024       return failure();
1025     if (*iota_init.getIntValues().begin() != 0) return failure();
1026 
1027     // Verify that the second argument is an Iota op along the same dimenion as
1028     // the reduction.
1029     Value iota = reduce_op.inputs().back();
1030     mhlo::BroadcastInDimOp iota_broadcast =
1031         llvm::dyn_cast_or_null<mhlo::BroadcastInDimOp>(iota.getDefiningOp());
1032     if (!iota_broadcast ||
1033         iota_broadcast.broadcast_dimensions() != reduce_op.dimensions())
1034       return failure();
1035     if (!llvm::isa<mhlo::IotaOp>(iota_broadcast.operand().getDefiningOp()))
1036       return failure();
1037 
1038     // Match the reduction computation.
1039     if (failed(matchReduceComputation(reduce_op.body()))) return failure();
1040 
1041     Value input = reduce_op.inputs().front();
1042     int64_t axis = reduce_op.dimensions().getValue<int64_t>({0});
1043 
1044     auto dim_type = RankedTensorType::get({1}, rewriter.getI64Type());
1045     auto reduction_indices = rewriter.create<ConstOp>(
1046         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr({axis}));
1047 
1048     // Generate a Max and an ArgMax of as the mhlo op returns both while in TF
1049     // we have separate ops for them. If only one of them is used then the other
1050     // one will be garbage collected later.
1051     auto result_type = reduce_op.getType(0).cast<TupleType>();
1052     auto tf_reduce_op = rewriter.create<TfReduce>(
1053         reduce_op.getLoc(), result_type.getType(0), input, reduction_indices,
1054         /*keep_dim=*/rewriter.getBoolAttr(false));
1055     auto tf_argreduce_op = rewriter.create<TfArgReduce>(
1056         reduce_op.getLoc(), result_type.getType(1), input, reduction_indices);
1057 
1058     // Pack the result into a TupleOp to match return type. The Tuple will be
1059     // optimised out by a subsequent pass.
1060     SmallVector<Value, 2> result{tf_reduce_op, tf_argreduce_op};
1061     rewriter.replaceOpWithNewOp<mhlo::TupleOp>(reduce_op, result);
1062     return success();
1063   }
1064 
1065   // Pattern matches the following reduction function for ArgMax/ArgMin:
1066   // %0 = compare{GT}(%lhs_value, %rhs_value)
1067   // %1 = compare{NE}(%lhs_value, %lhs_value)
1068   // %2 = or(%0, %1)
1069   // %3 = select(%2, %lhs_value, %rhs_value)
1070   // %4 = compare{EQ}(%lhs_value, %rhs_value)
1071   // %5 = compare{LT}(%lhs_index, %rhs_index)
1072   // %6 = and(%4, %5)
1073   // %7 = or(%2, %6)
1074   // %8 = select(%7, %lhs_index, %rhs_index)
1075   // %9 = tuple(%3, %8)
1076   // return %9
matchReduceComputation(Region & computation) const1077   LogicalResult matchReduceComputation(Region &computation) const {
1078     Block &body = computation.front();
1079     if (body.getNumArguments() != 4) return failure();
1080 
1081     mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
1082     if (!return_op) return failure();
1083     if (return_op.getNumOperands() != 1) return failure();
1084 
1085     mhlo::TupleOp return_tuple = llvm::dyn_cast_or_null<mhlo::TupleOp>(
1086         return_op.getOperand(0).getDefiningOp());
1087     if (!return_tuple ||
1088         return_tuple.getType().cast<TupleType>().getTypes().size() != 2)
1089       return failure();
1090 
1091     mhlo::SelectOp value_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
1092         return_tuple.getOperand(0).getDefiningOp());
1093     if (!value_select || value_select.on_true() != body.getArgument(0) ||
1094         value_select.on_false() != body.getArgument(2))
1095       return failure();
1096 
1097     mhlo::OrOp value_or = llvm::dyn_cast_or_null<mhlo::OrOp>(
1098         value_select.getOperand(0).getDefiningOp());
1099     if (!value_or) return failure();
1100 
1101     mhlo::SelectOp index_select = llvm::dyn_cast_or_null<mhlo::SelectOp>(
1102         return_tuple.getOperand(1).getDefiningOp());
1103     if (!index_select || index_select.on_true() != body.getArgument(1) ||
1104         index_select.on_false() != body.getArgument(3))
1105       return failure();
1106 
1107     mhlo::CompareOp value_gt =
1108         llvm::dyn_cast_or_null<mhlo::CompareOp>(value_or.lhs().getDefiningOp());
1109     if (!value_gt || value_gt.comparison_direction() != CompareDirection() ||
1110         value_gt.lhs() != body.getArgument(0) ||
1111         value_gt.rhs() != body.getArgument(2))
1112       return failure();
1113 
1114     mhlo::CompareOp value_ne =
1115         llvm::dyn_cast_or_null<mhlo::CompareOp>(value_or.rhs().getDefiningOp());
1116     if (!value_ne || value_ne.comparison_direction() != "NE" ||
1117         value_ne.lhs() != body.getArgument(0) ||
1118         value_ne.rhs() != body.getArgument(0))
1119       return failure();
1120 
1121     mhlo::OrOp index_or =
1122         llvm::dyn_cast_or_null<mhlo::OrOp>(index_select.pred().getDefiningOp());
1123 
1124     if (!index_or || index_or.lhs() != value_or) return failure();
1125 
1126     mhlo::AndOp index_and =
1127         llvm::dyn_cast_or_null<mhlo::AndOp>(index_or.rhs().getDefiningOp());
1128     if (!index_and) return failure();
1129 
1130     mhlo::CompareOp value_eq = llvm::dyn_cast_or_null<mhlo::CompareOp>(
1131         index_and.lhs().getDefiningOp());
1132     if (!value_eq || value_eq.comparison_direction() != "EQ" ||
1133         value_eq.lhs() != body.getArgument(0) ||
1134         value_eq.rhs() != body.getArgument(2))
1135       return failure();
1136 
1137     mhlo::CompareOp index_lt = llvm::dyn_cast_or_null<mhlo::CompareOp>(
1138         index_and.rhs().getDefiningOp());
1139     if (!index_lt || index_lt.comparison_direction() != "LT" ||
1140         index_lt.lhs() != body.getArgument(1) ||
1141         index_lt.rhs() != body.getArgument(3))
1142       return failure();
1143 
1144     return success();
1145   }
1146 
1147   virtual const char *CompareDirection() const = 0;
1148 
1149   virtual bool IsValueInitValue(const DenseElementsAttr &attr) const = 0;
1150 };
1151 
1152 class ConvertReduceOpToTfArgmax
1153     : public ConvertReduceOpToTfArgMinMax<TF::MaxOp, TF::ArgMaxOp> {
1154  public:
1155   using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;
1156 
CompareDirection() const1157   const char *CompareDirection() const override { return "GT"; }
IsValueInitValue(const DenseElementsAttr & attr) const1158   bool IsValueInitValue(const DenseElementsAttr &attr) const override {
1159     if (attr.getNumElements() != 1 ||
1160         !attr.getType().getElementType().isa<FloatType>())
1161       return false;
1162     auto value = *attr.getFloatValues().begin();
1163     return value.isNegative() && value.isInfinity();
1164   }
1165 };
1166 
1167 class ConvertReduceOpToTfArgmin
1168     : public ConvertReduceOpToTfArgMinMax<TF::MinOp, TF::ArgMinOp> {
1169  public:
1170   using ConvertReduceOpToTfArgMinMax::ConvertReduceOpToTfArgMinMax;
1171 
CompareDirection() const1172   const char *CompareDirection() const override { return "LT"; }
IsValueInitValue(const DenseElementsAttr & attr) const1173   bool IsValueInitValue(const DenseElementsAttr &attr) const override {
1174     if (attr.getNumElements() != 1 ||
1175         !attr.getType().getElementType().isa<FloatType>())
1176       return false;
1177     auto value = *attr.getFloatValues().begin();
1178     return !value.isNegative() && value.isInfinity();
1179   }
1180 };
1181 
1182 class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
1183  public:
1184   using OpConversionPattern::OpConversionPattern;
1185 
matchAndRewrite(mhlo::IotaOp iota_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1186   LogicalResult matchAndRewrite(
1187       mhlo::IotaOp iota_op, ArrayRef<Value> args,
1188       ConversionPatternRewriter &rewriter) const final {
1189     RankedTensorType type =
1190         iota_op.getType().dyn_cast_or_null<RankedTensorType>();
1191     if (!type) return failure();
1192 
1193     const uint64_t dimension = iota_op.iota_dimension();
1194     Type element_type = type.getElementType();
1195     Attribute start, limit, delta;
1196     if (element_type.isa<FloatType>()) {
1197       start = rewriter.getFloatAttr(element_type, 0.0);
1198       limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
1199       delta = rewriter.getFloatAttr(element_type, 1.0);
1200     } else if (element_type.isa<IntegerType>()) {
1201       start = rewriter.getIntegerAttr(element_type, 0);
1202       limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
1203       delta = rewriter.getIntegerAttr(element_type, 1);
1204     } else {
1205       return failure();
1206     }
1207 
1208     auto range_type =
1209         RankedTensorType::get({type.getShape()[dimension]}, element_type);
1210     Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
1211     Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
1212     Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
1213     Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
1214                                                 start_op, limit_op, delta_op);
1215 
1216     if (type.getRank() > 1) {
1217       std::vector<int64_t> reshape_shape(type.getRank(), 1);
1218       reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
1219       auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
1220       Value reshape_shape_op = rewriter.create<TF::ConstOp>(
1221           iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
1222       result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
1223                                               result, reshape_shape_op);
1224 
1225       Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
1226           iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
1227       result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
1228                                                   result, broadcast_shape_op);
1229     }
1230 
1231     rewriter.replaceOp(iota_op, result);
1232     return success();
1233   }
1234 };
1235 
1236 // A helper function for ConvertMaxPoolOp and ConvertAvgMaxPoolOp. Returns true
1237 // if the given ReduceWindowOp is a spatial pooling without dilation. If returns
1238 // true, also outputs the window strides and the TF padding mode ("VALID" or
1239 // "SAME").
IsSpatialPoolingWithoutDilation(mhlo::ReduceWindowOp rw,llvm::SmallVectorImpl<int64_t> * window_strides,std::string * padding_mode)1240 bool IsSpatialPoolingWithoutDilation(
1241     mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl<int64_t> *window_strides,
1242     std::string *padding_mode) {
1243   // tf.max_pool or tf.avg_pool need at least 3 dimensions (batch, spatial,
1244   // channel).
1245   const uint64_t rank = rw.window_dimensions().size();
1246   if (rank <= 2) return false;
1247 
1248   if (rw.window_strides().hasValue()) {
1249     window_strides->insert(window_strides->end(),
1250                            rw.window_strides()->getValues<int64_t>().begin(),
1251                            rw.window_strides()->getValues<int64_t>().end());
1252   } else {
1253     window_strides->resize(rank, 1);
1254   }
1255 
1256   llvm::SmallVector<int64_t, 10> padding;
1257   if (rw.padding().hasValue()) {
1258     padding.insert(padding.begin(), rw.padding()->getValues<int64_t>().begin(),
1259                    rw.padding()->getValues<int64_t>().end());
1260   } else {
1261     padding.resize(2 * rank, 0);
1262   }
1263 
1264   // Check that we don't do any reduction along the batch (first) and channel
1265   // (last) dimensions.
1266   const uint64_t batch_dim = 0;
1267   const uint64_t channel_dim = rank - 1;
1268   if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
1269       rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
1270       (*window_strides)[batch_dim] != 1 ||
1271       (*window_strides)[channel_dim] != 1 || padding[2 * batch_dim] != 0 ||
1272       padding[2 * batch_dim + 1] != 0 || padding[2 * channel_dim] != 0 ||
1273       padding[2 * channel_dim + 1] != 0)
1274     return false;
1275 
1276   if (rw.window_dilations().hasValue() &&
1277       !(rw.window_dilations()->isSplat() &&
1278         rw.window_dilations()->getSplatValue<APInt>() == 1))
1279     return false;
1280 
1281   if (rw.base_dilations().hasValue() &&
1282       !(rw.base_dilations()->isSplat() &&
1283         rw.base_dilations()->getSplatValue<APInt>() == 1))
1284     return false;
1285 
1286   if (llvm::all_of(padding, [](int64_t i) { return i == 0; })) {
1287     *padding_mode = "VALID";
1288     return true;
1289   }
1290 
1291   // Check that the individual padding values are corresponding to SAME
1292   // padding from TensorFlow.
1293   RankedTensorType input_type =
1294       rw.inputs()[0].getType().dyn_cast<RankedTensorType>();
1295   RankedTensorType output_type =
1296       rw.getResult(0).getType().dyn_cast<RankedTensorType>();
1297   if (!input_type || !output_type) return false;
1298 
1299   for (uint64_t i = 1; i < rank - 1; ++i) {
1300     int64_t padding_size =
1301         (output_type.getShape()[i] - 1) * (*window_strides)[i] +
1302         rw.window_dimensions().getValue<int64_t>({i}) -
1303         input_type.getShape()[i];
1304     if (padding[2 * i] != tensorflow::MathUtil::FloorOfRatio(
1305                               padding_size, static_cast<int64_t>(2)) ||
1306         padding[2 * i + 1] != tensorflow::MathUtil::CeilOfRatio(
1307                                   padding_size, static_cast<int64_t>(2)))
1308       return false;
1309   }
1310 
1311   *padding_mode = "SAME";
1312   return true;
1313 }
1314 
1315 // Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
1316 // operation when they cleanly map to 2D or 3D average pool with VALID or SAME
1317 // padding:
1318 // * div(reduce_sum_window(x), constant(sizeof(window)))
1319 // * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
1320 class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
1321  public:
1322   using OpConversionPattern::OpConversionPattern;
1323 
matchAndRewrite(mhlo::DivOp div_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1324   LogicalResult matchAndRewrite(
1325       mhlo::DivOp div_op, ArrayRef<Value> args,
1326       ConversionPatternRewriter &rewriter) const final {
1327     auto rw =
1328         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
1329     if (!rw || rw->getNumResults() != 1) return failure();
1330 
1331     // Check that the reduce-window is a sum-reduce-window.
1332     if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
1333       return failure();
1334 
1335     // Check that this is a floating point reduce window with a rank of 4 or 5.
1336     const RankedTensorType rw_type =
1337         rw.getResult(0).getType().dyn_cast<RankedTensorType>();
1338     if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
1339         rw_type.getRank() <= 3 || rw_type.getRank() > 5)
1340       return failure();
1341 
1342     // Check that the Div op doesn't do broadcasting on the output of the reduce
1343     // window.
1344     if (div_op.getType() != rw_type) return failure();
1345 
1346     // If the init value isn't zero then it can't be an average pool.
1347     if (!isFloatZero(rw.init_values()[0])) return failure();
1348 
1349     llvm::SmallVector<int64_t, 5> window_strides;
1350     std::string padding_mode;
1351     if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
1352       return rewriter.notifyMatchFailure(
1353           div_op, "not the root of spatial pooling without dilation");
1354     }
1355 
1356     DenseFPElementsAttr divisor;
1357     if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
1358       // If the divisor is a constant then check that it matches with the number
1359       // of elements inside the window what is required for a VALID AvgPool.
1360       if (!divisor.isSplat()) return failure();
1361       int64_t window_size = 1;
1362       for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
1363         window_size *= w;
1364       }
1365       if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
1366         return failure();
1367 
1368       if (padding_mode != "VALID") {
1369         return failure();
1370       }
1371 
1372       return replaceWithAvgPool(
1373           div_op, rw.inputs()[0],
1374           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
1375           window_strides, "VALID", rewriter);
1376     }
1377 
1378     auto rw_rhs =
1379         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
1380     if (rw_rhs && rw_rhs.getNumResults() == 1) {
1381       // Check that RHS is a sum-reduce-window.
1382       if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
1383         return failure();
1384 
1385       // Check that the RHS is a reduce_window over a constant 1 input with 0 as
1386       // the init value.
1387       DenseFPElementsAttr rhs_input;
1388       if (!isFloatZero(rw_rhs.init_values()[0]) ||
1389           !matchPattern(rw_rhs.inputs()[0], m_Constant(&rhs_input)) ||
1390           !rhs_input.isSplat() ||
1391           !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
1392         return failure();
1393 
1394       // Check that the two reduce window have the same window configuration.
1395       if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
1396           rw.window_strides() != rw_rhs.window_strides() ||
1397           rw.window_dilations() != rw_rhs.window_dilations() ||
1398           rw.base_dilations() != rw_rhs.base_dilations() ||
1399           rw.padding() != rw_rhs.padding())
1400         return failure();
1401 
1402       return replaceWithAvgPool(
1403           div_op, rw.inputs()[0],
1404           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
1405           window_strides, padding_mode, rewriter);
1406     }
1407 
1408     return failure();
1409   }
1410 
1411  private:
isFloatZero(Value value) const1412   bool isFloatZero(Value value) const {
1413     DenseFPElementsAttr initial_value;
1414     return matchPattern(value, m_Constant(&initial_value)) &&
1415            initial_value.getNumElements() == 1 &&
1416            initial_value.getValue<APFloat>({}).isZero();
1417   }
1418 
replaceWithAvgPool(mhlo::DivOp op,Value input,llvm::ArrayRef<int64_t> ksizes,llvm::ArrayRef<int64_t> kstrides,llvm::StringRef padding,ConversionPatternRewriter & rewriter) const1419   LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
1420                                    llvm::ArrayRef<int64_t> ksizes,
1421                                    llvm::ArrayRef<int64_t> kstrides,
1422                                    llvm::StringRef padding,
1423                                    ConversionPatternRewriter &rewriter) const {
1424     if (ksizes.size() == 4) {
1425       rewriter.replaceOpWithNewOp<AvgPoolOp>(
1426           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
1427           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
1428           rewriter.getStringAttr("NHWC"));
1429       return success();
1430     } else if (ksizes.size() == 5) {
1431       rewriter.replaceOpWithNewOp<AvgPool3DOp>(
1432           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
1433           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
1434           rewriter.getStringAttr("NDHWC"));
1435       return success();
1436     }
1437     return failure();
1438   }
1439 };
1440 
1441 class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
1442  public:
1443   using OpConversionPattern::OpConversionPattern;
1444 
matchAndRewrite(mhlo::ReduceWindowOp rw,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1445   LogicalResult matchAndRewrite(
1446       mhlo::ReduceWindowOp rw, ArrayRef<Value> args,
1447       ConversionPatternRewriter &rewriter) const final {
1448     // Check that the reduce-window is a max-reduce-window.
1449     if (failed(MatchBinaryReduceFunction<mhlo::MaxOp>(rw.body())))
1450       return failure();
1451 
1452     // Check that this is a floating point reduce window with a rank of 4 or 5.
1453     const RankedTensorType rw_type =
1454         rw.getResult(0).getType().dyn_cast<RankedTensorType>();
1455     if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
1456         rw_type.getRank() <= 3 || rw_type.getRank() > 5)
1457       return failure();
1458 
1459     if (!isFloatMinusInfinity(rw.init_values()[0])) {
1460       return failure();
1461     }
1462 
1463     llvm::SmallVector<int64_t, 5> window_strides;
1464     std::string padding_mode;
1465     if (!IsSpatialPoolingWithoutDilation(rw, &window_strides, &padding_mode)) {
1466       return rewriter.notifyMatchFailure(
1467           rw, "not the root of spatial pooling without dilation");
1468     }
1469 
1470     return replaceWithMaxPool(
1471         rw, rw.inputs()[0],
1472         llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
1473         window_strides, padding_mode, rewriter);
1474   }
1475 
1476  private:
isFloatMinusInfinity(Value value) const1477   bool isFloatMinusInfinity(Value value) const {
1478     DenseFPElementsAttr float_value;
1479     if (!matchPattern(value, m_Constant(&float_value))) {
1480       return false;
1481     }
1482 
1483     if (float_value.getNumElements() != 1) {
1484       return false;
1485     }
1486 
1487     APFloat element = float_value.getValue<APFloat>({});
1488     if (!element.isInfinity()) {
1489       return false;
1490     }
1491     if (!element.isNegative()) {
1492       return false;
1493     }
1494 
1495     return true;
1496   }
1497 
replaceWithMaxPool(mhlo::ReduceWindowOp op,Value input,llvm::ArrayRef<int64_t> ksizes,llvm::ArrayRef<int64_t> kstrides,llvm::StringRef padding,ConversionPatternRewriter & rewriter) const1498   LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
1499                                    llvm::ArrayRef<int64_t> ksizes,
1500                                    llvm::ArrayRef<int64_t> kstrides,
1501                                    llvm::StringRef padding,
1502                                    ConversionPatternRewriter &rewriter) const {
1503     if (ksizes.size() == 4) {
1504       rewriter.replaceOpWithNewOp<MaxPoolOp>(
1505           op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
1506           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
1507           /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
1508           rewriter.getStringAttr("NHWC"));
1509       return success();
1510     } else if (ksizes.size() == 5) {
1511       rewriter.replaceOpWithNewOp<MaxPool3DOp>(
1512           op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes),
1513           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
1514           rewriter.getStringAttr("NDHWC"));
1515       return success();
1516     }
1517     return failure();
1518   }
1519 };
1520 
1521 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const1522   void getDependentDialects(DialectRegistry &registry) const override {
1523     registry.insert<TF::TensorFlowDialect>();
1524   }
1525 
1526  public:
1527   LegalizeHloToTf() = default;
LegalizeHloToTf(const LegalizeHloToTf &)1528   LegalizeHloToTf(const LegalizeHloToTf &) {}
1529 
getArgument() const1530   StringRef getArgument() const final { return "tf-legalize-hlo"; }
1531 
getDescription() const1532   StringRef getDescription() const final {
1533     return "Legalize from HLO to the TF dialect";
1534   }
1535 
1536   /// Performs the legalization to the TF dialect.
1537   void runOnFunction() override;
1538 };
1539 
1540 // Returns the shape of the given value in a Constant Op.
ShapeToConst(PatternRewriter & rewriter,Value value)1541 ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
1542   ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
1543   auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1544                                          rewriter.getIntegerType(64));
1545   auto attr = DenseElementsAttr::get(attr_type, shape);
1546   return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
1547 }
1548 
IsSign(APFloat a,APFloat sign)1549 bool IsSign(APFloat a, APFloat sign) {
1550   if (a.isNaN() || a.isZero()) return a == sign;
1551   if (a.isNegative()) return sign.isExactlyValue(-1.0);
1552   return sign.isExactlyValue(1.0);
1553 }
1554 
1555 // Returns whether the splat constant is the sign of the FloatTensor
FloatTensorIsSign(PatternRewriter & rewriter,ElementsAttr floatv,ElementsAttr sgn_cst)1556 bool FloatTensorIsSign(PatternRewriter &rewriter, ElementsAttr floatv,
1557                        ElementsAttr sgn_cst) {
1558   if (!sgn_cst.isa<SplatElementsAttr>()) return false;
1559   auto sgn_cst_spl = sgn_cst.cast<SplatElementsAttr>().getSplatValue<APFloat>();
1560   if (floatv.isa<SplatElementsAttr>()) {
1561     auto floatv_spl = floatv.cast<SplatElementsAttr>().getSplatValue<APFloat>();
1562     return IsSign(floatv_spl, sgn_cst_spl);
1563   } else if (floatv.isa<DenseElementsAttr>()) {
1564     auto floatv_dns = floatv.cast<DenseFPElementsAttr>();
1565     return llvm::all_of(floatv_dns.getAttributeValues(), [&](Attribute value) {
1566       FloatAttr value_f = value.cast<FloatAttr>();
1567       return IsSign(value_f.getValue(), sgn_cst_spl);
1568     });
1569   }
1570   return false;
1571 }
1572 
1573 // If index_vector_dim == indices.rank() then insert the implicit extra
1574 // dimension into indices to normalize everything to index_vector_dim ==
1575 // indices.rank() - 1.
NormalizeIndexVector(Operation * parent_op,Value & indices,ShapedType & indices_type,int64_t index_vector_dim,ConversionPatternRewriter & rewriter)1576 LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
1577                                    ShapedType &indices_type,
1578                                    int64_t index_vector_dim,
1579                                    ConversionPatternRewriter &rewriter) {
1580   if (index_vector_dim == indices_type.getRank()) {
1581     llvm::SmallVector<int64_t, 4> new_start_indices_shape(
1582         indices_type.getShape().begin(), indices_type.getShape().end());
1583     new_start_indices_shape.push_back(1);
1584     indices_type = RankedTensorType::get(new_start_indices_shape,
1585                                          indices_type.getElementType());
1586     indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
1587                                                indices_type, indices);
1588   } else if (index_vector_dim != indices_type.getRank() - 1) {
1589     // If index_vector_dim isn't the last dimension in indices then it isn't
1590     // supported yet.
1591     // TODO(tberghammer): Transpose indices to support this usecase.
1592     return rewriter.notifyMatchFailure(
1593         parent_op,
1594         "index vector dim isn't the last dimension in start indices");
1595   }
1596   return success();
1597 }
1598 
1599 // Check that `attr` is an R1 iota with integer element type starting from `0`
1600 // with `size` number of values.
IsIotaAttr(const DenseIntElementsAttr & attr,int64_t size)1601 bool IsIotaAttr(const DenseIntElementsAttr &attr, int64_t size) {
1602   if (!attr.getType().getElementType().isa<IntegerType>()) return false;
1603   if (attr.getType().getRank() != 1) return false;
1604   if (attr.getNumElements() != size) return false;
1605   int64_t iota = 0;
1606   for (auto s : attr.getIntValues()) {
1607     if (s != iota) return false;
1608     ++iota;
1609   }
1610   return true;
1611 }
1612 
1613 class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
1614  public:
1615   using OpConversionPattern::OpConversionPattern;
1616 
matchAndRewrite(mhlo::GatherOp gather_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1617   LogicalResult matchAndRewrite(
1618       mhlo::GatherOp gather_op, ArrayRef<Value> args,
1619       ConversionPatternRewriter &rewriter) const final {
1620     Value operand = gather_op.operand();
1621     Value start_indices = gather_op.start_indices();
1622 
1623     // Can only convert with static shaped gather.
1624     ShapedType operand_type = operand.getType().cast<ShapedType>();
1625     ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
1626     ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
1627     if (!operand_type.hasStaticShape() ||
1628         !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
1629       return failure();
1630     }
1631 
1632     // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
1633     int64_t index_vector_dim =
1634         gather_op.dimension_numbers().index_vector_dim().getInt();
1635     if (failed(NormalizeIndexVector(gather_op, start_indices,
1636                                     start_indices_type, index_vector_dim,
1637                                     rewriter))) {
1638       return failure();
1639     }
1640 
1641     // Verify that start_index_map and collapsed_slice_dims contains the same
1642     // values.
1643     auto start_index_map = gather_op.dimension_numbers().start_index_map();
1644     auto collapsed_slice_dims =
1645         gather_op.dimension_numbers().collapsed_slice_dims();
1646     if (start_index_map.getNumElements() !=
1647         collapsed_slice_dims.getNumElements()) {
1648       return rewriter.notifyMatchFailure(
1649           gather_op,
1650           "different size for start index map and collapsed slice dims");
1651     }
1652     for (auto c : collapsed_slice_dims) {
1653       if (llvm::count(start_index_map, c) == 0) {
1654         return rewriter.notifyMatchFailure(
1655             gather_op, "collapsed slice dim isn't present in start index map");
1656       }
1657     }
1658 
1659     // Verify that slice_sizes is 1 for the indexed dimensions and the full
1660     // shape for the rest of the dimensions.
1661     auto slice_sizes = gather_op.slice_sizes();
1662     int64_t index = 0;
1663     for (int64_t s : slice_sizes.getValues<int64_t>()) {
1664       if (llvm::count(start_index_map, index)) {
1665         if (s != 1) {
1666           return rewriter.notifyMatchFailure(gather_op,
1667                                              "unsupported slice sizes");
1668         }
1669       } else {
1670         if (s != operand_type.getShape()[index]) {
1671           return rewriter.notifyMatchFailure(gather_op,
1672                                              "unsupported slice sizes");
1673         }
1674       }
1675       ++index;
1676     }
1677 
1678     // Verify that offset_dims are the tailing dimensions in the output tensor.
1679     auto offset_dims = gather_op.dimension_numbers().offset_dims();
1680     int64_t offset = start_indices_type.getRank() - 1;
1681     for (int64_t o : offset_dims.getValues<int64_t>()) {
1682       if (o != offset) {
1683         return rewriter.notifyMatchFailure(gather_op,
1684                                            "unsupported offset dims");
1685       }
1686       ++offset;
1687     }
1688 
1689     // Transpose the operand to handle non-iota start index map.
1690     llvm::SmallVector<int64_t, 4> transpose_dimensions;
1691     llvm::SmallVector<int64_t, 4> transpose_shape;
1692     for (auto s : start_index_map) {
1693       transpose_dimensions.push_back(s.getZExtValue());
1694       transpose_shape.push_back(operand_type.getShape()[s.getZExtValue()]);
1695     }
1696     for (int64_t i = 0, e = operand_type.getRank(); i < e; ++i) {
1697       if (llvm::count(start_index_map, i) == 0) {
1698         transpose_dimensions.push_back(i);
1699         transpose_shape.push_back(operand_type.getShape()[i]);
1700       }
1701     }
1702     operand_type =
1703         RankedTensorType::get(transpose_shape, operand_type.getElementType());
1704     operand = rewriter.create<mhlo::TransposeOp>(
1705         gather_op.getLoc(), operand_type, operand,
1706         rewriter.getI64TensorAttr(transpose_dimensions));
1707 
1708     rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
1709                                                 start_indices);
1710     return success();
1711   }
1712 };
1713 
1714 class ConvertWhileOp : public OpConversionPattern<mhlo::WhileOp> {
1715  public:
1716   using OpConversionPattern::OpConversionPattern;
1717 
matchAndRewrite(mhlo::WhileOp while_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1718   LogicalResult matchAndRewrite(
1719       mhlo::WhileOp while_op, ArrayRef<Value> args,
1720       ConversionPatternRewriter &rewriter) const final {
1721     // HLO WhileOp should have two regions: cond and body.
1722     if (while_op->getNumRegions() != 2) return failure();
1723 
1724     // This rule doesn't support mhlo::WhileOp with tuple inputs.
1725     for (auto type : while_op->getOperandTypes()) {
1726       if (type.isa<TupleType>()) return failure();
1727     }
1728 
1729     // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp
1730     // currently doesn't support stateless and shape invariant, so these
1731     // parameters are set to the default values.
1732     OpBuilder builder(while_op);
1733     auto new_while = builder.create<TF::WhileRegionOp>(
1734         while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(),
1735         /*parallel_iterations=*/10,
1736         /*is_stateless=*/false, /*shape_invariant=*/false);
1737     new_while.cond().takeBody(while_op.getRegion(0));
1738     new_while.body().takeBody(while_op.getRegion(1));
1739     ReplaceReturnOp(new_while.cond(), rewriter);
1740     ReplaceReturnOp(new_while.body(), rewriter);
1741     rewriter.replaceOp(while_op, new_while.getResults());
1742     return success();
1743   }
1744 
1745  private:
1746   // Replaces mhlo::ReturnOp to TF::Yield.
ReplaceReturnOp(Region & region,ConversionPatternRewriter & rewriter)1747   static void ReplaceReturnOp(Region &region,
1748                               ConversionPatternRewriter &rewriter) {
1749     for (auto &block : region.getBlocks()) {
1750       Operation *terminator = block.getTerminator();
1751       auto return_op = llvm::dyn_cast_or_null<mhlo::ReturnOp>(terminator);
1752       if (return_op == nullptr) continue;
1753 
1754       OpBuilder builder(return_op);
1755       builder.create<TF::YieldOp>(return_op.getLoc(), return_op->getOperands());
1756       rewriter.eraseOp(return_op);
1757     }
1758   }
1759 };
1760 
1761 template <typename BinaryOp, typename TfOp>
1762 class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
1763  public:
1764   using OpConversionPattern::OpConversionPattern;
1765 
matchAndRewrite(mhlo::ScatterOp scatter_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1766   LogicalResult matchAndRewrite(
1767       mhlo::ScatterOp scatter_op, ArrayRef<Value> args,
1768       ConversionPatternRewriter &rewriter) const final {
1769     Value operand = scatter_op.operand();
1770     Value indices = scatter_op.scatter_indices();
1771     Value updates = scatter_op.updates();
1772     ShapedType operand_type = operand.getType().cast<ShapedType>();
1773     ShapedType indices_type = indices.getType().cast<ShapedType>();
1774     ShapedType updates_type = updates.getType().cast<ShapedType>();
1775 
1776     // Can only convert with static shaped scatter.
1777     if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
1778         !updates_type.hasStaticShape()) {
1779       return failure();
1780     }
1781 
1782     // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
1783     int64_t index_vector_dim =
1784         scatter_op.scatter_dimension_numbers().index_vector_dim().getInt();
1785     if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
1786                                     index_vector_dim, rewriter))) {
1787       return failure();
1788     }
1789 
1790     // Verify that inserted_window_dims and scatter_dims_to_operand_dims are
1791     // both an iota with the same number of elements as the last dimension of
1792     // start_indices.
1793     auto inserted_window_dims =
1794         scatter_op.scatter_dimension_numbers().inserted_window_dims();
1795     auto scatter_dims_to_operand_dims =
1796         scatter_op.scatter_dimension_numbers().scatter_dims_to_operand_dims();
1797     if (!IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) ||
1798         !IsIotaAttr(scatter_dims_to_operand_dims,
1799                     indices_type.getShape().back())) {
1800       // TODO(tberghammer): Transform indices to support non-standard
1801       // scatter_dims_to_operand_dims.
1802       return rewriter.notifyMatchFailure(
1803           scatter_op,
1804           "unsupported inserted window dims and/or scatter dims to operand "
1805           "dims");
1806     }
1807 
1808     // Verify that update window dims are the tailing dimensions in the update
1809     // tensor.
1810     auto update_window_dims =
1811         scatter_op.scatter_dimension_numbers().update_window_dims();
1812     int64_t offset = indices_type.getRank() - 1;
1813     for (int64_t o : update_window_dims.getValues<int64_t>()) {
1814       if (o != offset) {
1815         return rewriter.notifyMatchFailure(scatter_op,
1816                                            "unsupported update window dims");
1817       }
1818       ++offset;
1819     }
1820 
1821     // Match the scatter computation against computations supported by TF.
1822     if (failed(MatchBinaryReduceFunction<BinaryOp>(
1823             scatter_op.update_computation()))) {
1824       return failure();
1825     }
1826 
1827     rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
1828                                       scatter_op.getResult().getType(), operand,
1829                                       indices, updates);
1830     return success();
1831   }
1832 };
1833 using ConvertScatterAddOp =
1834     ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
1835 using ConvertScatterMaxOp =
1836     ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
1837 using ConvertScatterMinOp =
1838     ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
1839 using ConvertScatterSubOp =
1840     ConvertScatterOp<mhlo::SubOp, TF::TensorScatterSubOp>;
1841 using ConvertScatterUpdateOp =
1842     ConvertScatterOp<void, TF::TensorScatterUpdateOp>;
1843 
1844 // Converts mhlo.pad to tf.PadV2
ConvertPadOp(PatternRewriter & rewriter,Operation * old_op)1845 Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
1846   auto pad_op = cast<mhlo::PadOp>(old_op);
1847   mlir::Location loc = pad_op.getLoc();
1848 
1849   llvm::SmallVector<APInt, 8> padding;
1850   for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
1851                           pad_op.edge_padding_high().getValues<APInt>())) {
1852     padding.push_back(std::get<0>(p));
1853     padding.push_back(std::get<1>(p));
1854   }
1855   auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
1856                                          rewriter.getI64Type());
1857   auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
1858   auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
1859   return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
1860                                   padding_op, pad_op.padding_value());
1861 }
1862 
1863 // Returns true if broadcast_dimensions obey Tensorflow convention, as in new
1864 // dimensions are added as prefix.
IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,Value output)1865 bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
1866                         Value output) {
1867   // broadcast_dimensions is an increasing list by definition, thus it suffices
1868   // to check the first element.
1869   int64_t input_rank = broadcast_dimensions.getNumElements();
1870   int64_t output_rank = output.getType().cast<ShapedType>().getRank();
1871   return input_rank == 0 ||
1872          (broadcast_dimensions.getValue({0}).cast<IntegerAttr>().getInt() ==
1873           output_rank - input_rank);
1874 }
1875 
1876 // Returns the intermediate shape that input tensor should be reshaped to during
1877 // legalization of BroadcastInDimOp.
ExpandedShape(PatternRewriter & rewriter,Value input,DenseIntElementsAttr broadcast_dimensions,Value output)1878 ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
1879                          DenseIntElementsAttr broadcast_dimensions,
1880                          Value output) {
1881   // Initialize expanded shape with output rank and dimensions of 1.
1882   SmallVector<Attribute, 4> expanded_shape(
1883       output.getType().cast<ShapedType>().getRank(),
1884       /*Value=*/rewriter.getI64IntegerAttr(1));
1885 
1886   // Set dimension sizes specified by broadcast_dimensions.
1887   ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
1888   for (auto x : llvm::enumerate(broadcast_dimensions)) {
1889     expanded_shape[x.value().getSExtValue()] =
1890         rewriter.getI64IntegerAttr(input_shape[x.index()]);
1891   }
1892 
1893   // Create the expanded type wrapped in a ConstantOp.
1894   auto attr_type =
1895       RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
1896                             rewriter.getIntegerType(64));
1897   auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
1898   return rewriter.create<ConstantOp>(output.getLoc(), attr_type, attr);
1899 }
1900 
1901 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
1902 
1903 /// Performs the lowering to XLA dialect.
runOnFunction()1904 void LegalizeHloToTf::runOnFunction() {
1905   MLIRContext &context = getContext();
1906 
1907   // Add legalization patterns to the list.
1908   OwningRewritePatternList patterns(&getContext());
1909   PopulateLegalizeHloToTfPatterns(&patterns, &context);
1910 
1911   ConversionTarget target(context);
1912   target.addLegalDialect<TensorFlowDialect>();
1913   target.addLegalOp<CallOp, ConstantOp>();
1914   target.addLegalOp<mhlo::TupleOp>();
1915   if (failed(
1916           applyPartialConversion(getFunction(), target, std::move(patterns)))) {
1917     getFunction().emitError("mhlo to TF legalization failed.");
1918     signalPassFailure();
1919   }
1920 }
1921 
1922 static PassRegistration<LegalizeHloToTf> pass;
1923 
1924 }  // end namespace
1925 
PopulateLegalizeHloToTfPatterns(OwningRewritePatternList * patterns,MLIRContext * context)1926 void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
1927                                      MLIRContext *context) {
1928   patterns->insert<ConvertWhileOp, ConvertAvgPoolOp, ConvertConvOp,
1929                    ConvertConvBackpropInputOp, ConvertDynamicSliceOp,
1930                    ConvertGatherOp, ConvertMaxPoolOp, ConvertScatterAddOp,
1931                    ConvertScatterMaxOp, ConvertScatterMinOp,
1932                    ConvertScatterSubOp, ConvertScatterUpdateOp, ConvertSliceOp,
1933                    ConvertReduceOpToTfArgmax, ConvertReduceOpToTfArgmin,
1934                    ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
1935                    ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
1936   populateWithGenerated(*patterns);
1937 }
1938 
CreateLegalizeHloToTfPass()1939 std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
1940   return std::make_unique<LegalizeHloToTf>();
1941 }
1942 
1943 }  // end namespace TF
1944 }  // end namespace mlir
1945