• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for legalizing HLO to TensorFlow.
17 
18 #include <cstdint>
19 #include <functional>
20 #include <memory>
21 #include <numeric>
22 #include <vector>
23 
24 #include "llvm/ADT/APInt.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
32 #include "mlir/IR/Attributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Location.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/IR/Matchers.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
40 #include "mlir/IR/Value.h"  // from @llvm-project
41 #include "mlir/Pass/Pass.h"  // from @llvm-project
42 #include "mlir/Support/LLVM.h"  // from @llvm-project
43 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
44 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
49 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
50 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
51 #include "tensorflow/core/framework/kernel_shape_util.h"
52 #include "tensorflow/core/lib/math/math_util.h"
53 
54 namespace mlir {
55 namespace TF {
56 namespace {
57 
58 using mhlo::DotDimensionNumbers;
59 
60 class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
61  public:
62   using OpConversionPattern::OpConversionPattern;
63 
matchAndRewrite(mhlo::ConvOp conv_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const64   LogicalResult matchAndRewrite(
65       mhlo::ConvOp conv_op, ArrayRef<Value> args,
66       ConversionPatternRewriter &rewriter) const final {
67     if (!IsSupportedConvOp(conv_op)) {
68       return failure();
69     }
70 
71     // Constructs strides array.
72     // For example, [2, 3] -> [1, 2, 3, 1].
73     SmallVector<int64_t, 4> strides({1});
74     for (const auto v :
75          conv_op.window_strides().getValue().getValues<int64_t>()) {
76       strides.emplace_back(v);
77     }
78     strides.emplace_back(1);
79 
80     // Constructs dilation array.
81     SmallVector<int64_t, 4> dilation;
82     if (auto rhs_dilation = conv_op.rhs_dilation()) {
83       // For example, [2, 3] -> [1, 2, 3, 1].
84       dilation.emplace_back(1);
85       dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
86                       rhs_dilation.getValue().getValues<int64_t>().end());
87       dilation.emplace_back(1);
88     } else {
89       // Default value
90       dilation = {1, 1, 1, 1};
91     }
92 
93     const int input_feature_dimension =
94         conv_op.dimension_numbers().input_feature_dimension().getInt();
95     const int input_channels =
96         conv_op.lhs().getType().cast<ShapedType>().getDimSize(
97             input_feature_dimension);
98     int feature_group_count = conv_op.feature_group_count();
99 
100     if (feature_group_count != 1 && feature_group_count != input_channels) {
101       // Group convolution is not supported yet.
102       return failure();
103     }
104 
105     const bool is_depthwise_conv = input_channels == feature_group_count;
106     std::string padding;
107 
108     if (!conv_op.padding().hasValue() ||
109         (conv_op.padding().getValue().isSplat() &&
110          conv_op.padding()->getSplatValue<int64_t>() == 0)) {
111       padding = "VALID";
112     } else {
113       // Check if padding is "SAME".
114       // TODO(chhe): To support "EXPLICIT" padding.
115       SmallVector<int64_t, 8> padding_array;
116       for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
117         padding_array.emplace_back(v);
118       }
119 
120       const int num_spatial_dims = conv_op.dimension_numbers()
121                                        .input_spatial_dimensions()
122                                        .getNumElements();
123       if (!IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
124                          padding_array))
125         return failure();
126 
127       padding = "SAME";
128     }
129 
130     CreateConvOp(conv_op, strides, padding, dilation, is_depthwise_conv,
131                  input_channels, rewriter);
132     return success();
133   };
134 
135  private:
IsSamePadding(mhlo::ConvOp conv_op,int num_spatial_dims,ArrayRef<int64_t> strides,ArrayRef<int64_t> dilation,ArrayRef<int64_t> padding_array) const136   bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
137                      ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
138                      ArrayRef<int64_t> padding_array) const {
139     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
140       int dim = i + 1;
141       tensorflow::int64 output_size;
142       tensorflow::int64 pad_low_int64;
143       tensorflow::int64 pad_high_int64;
144       tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
145           conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim),
146           conv_op.rhs().getType().cast<ShapedType>().getDimSize(i),
147           dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
148           &pad_low_int64, &pad_high_int64);
149       if (!status.ok()) return false;
150       if (padding_array[2 * i] != pad_low_int64 ||
151           padding_array[2 * i + 1] != pad_high_int64)
152         return false;
153     }
154 
155     return true;
156   }
157 
CreateConvOp(mhlo::ConvOp conv_op,ArrayRef<int64_t> strides,StringRef padding,ArrayRef<int64_t> dilation,bool is_depthwise_conv,int input_channels,ConversionPatternRewriter & rewriter) const158   void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
159                     StringRef padding, ArrayRef<int64_t> dilation,
160                     bool is_depthwise_conv, int input_channels,
161                     ConversionPatternRewriter &rewriter) const {
162     // TODO(chhe): To support more data formats other than "NHWC".
163     if (is_depthwise_conv) {
164       // Reshapes filter format to [filter_height, filter_width, in_channels,
165       // channel_multiplier] from HLO's [filter_height, filter_width, 1,
166       // in_channels * channel_multiplier] format.
167       auto filter_type = conv_op.rhs().getType().cast<ShapedType>();
168       llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
169       llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
170                                                     hlo_filter_shape.end());
171       tf_filter_shape[2] = input_channels;
172       tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
173       auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
174           conv_op.rhs().getLoc(),
175           RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
176           conv_op.rhs());
177 
178       rewriter.replaceOpWithNewOp<DepthwiseConv2dNativeOp>(
179           conv_op, conv_op.getType(), conv_op.lhs(), reshaped_filter,
180           rewriter.getI64ArrayAttr(strides),
181           /*padding=*/rewriter.getStringAttr(padding),
182           /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
183           /*data_format=*/rewriter.getStringAttr("NHWC"),
184           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
185     } else {
186       rewriter.replaceOpWithNewOp<Conv2DOp>(
187           conv_op, conv_op.getType(), conv_op.lhs(), conv_op.rhs(),
188           rewriter.getI64ArrayAttr(strides),
189           /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
190           /*padding=*/rewriter.getStringAttr(padding),
191           /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
192           /*data_format=*/rewriter.getStringAttr("NHWC"),
193           /*dilations=*/rewriter.getI64ArrayAttr(dilation));
194     }
195   }
196 
IsSupportedConvOp(mhlo::ConvOp conv_op) const197   bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
198     if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
199         !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
200         !conv_op.getType().cast<ShapedType>().hasStaticShape())
201       return false;
202 
203     // All ones in "lhs_dilation" means this "mhlo.conv" op should be
204     // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
205     if (conv_op.lhs_dilation().hasValue()) {
206       auto lhs_dilation = conv_op.lhs_dilation().getValue();
207       if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
208         return false;
209     }
210 
211     if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
212                                                         .getValue()
213                                                         .getType()
214                                                         .cast<ShapedType>()
215                                                         .getRank() != 1)
216       return false;
217 
218     int num_spatial_dims =
219         conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
220     // TODO(b/158636600): Currently we don't support 3D Convolution.
221     if (num_spatial_dims != 2) return false;
222 
223     // TODO(chhe): To support more data formats other than "NHWC".
224     // Checks input dimensions.
225     if (conv_op.dimension_numbers().input_batch_dimension().getInt() != 0 ||
226         conv_op.dimension_numbers().input_feature_dimension().getInt() !=
227             num_spatial_dims + 1)
228       return false;
229     DenseIntElementsAttr input_spatial_dimensions =
230         conv_op.dimension_numbers().input_spatial_dimensions();
231     for (auto p :
232          llvm::enumerate(input_spatial_dimensions.getValues<int64_t>())) {
233       if (p.value() != p.index() + 1) return false;
234     }
235 
236     // Checks output dimensions.
237     if (conv_op.dimension_numbers().output_batch_dimension().getInt() != 0 ||
238         conv_op.dimension_numbers().output_feature_dimension().getInt() !=
239             num_spatial_dims + 1)
240       return false;
241     DenseIntElementsAttr output_spatial_dimensions =
242         conv_op.dimension_numbers().output_spatial_dimensions();
243     for (auto p :
244          llvm::enumerate(output_spatial_dimensions.getValues<int64_t>())) {
245       if (p.value() != p.index() + 1) return false;
246     }
247 
248     // Checks kernel dimensions.
249     if (conv_op.dimension_numbers().kernel_input_feature_dimension().getInt() !=
250             num_spatial_dims ||
251         conv_op.dimension_numbers()
252                 .kernel_output_feature_dimension()
253                 .getInt() != num_spatial_dims + 1)
254       return false;
255     DenseIntElementsAttr kernal_spatial_dimensions =
256         conv_op.dimension_numbers().kernel_spatial_dimensions();
257     for (auto p :
258          llvm::enumerate(kernal_spatial_dimensions.getValues<int64_t>())) {
259       if (p.value() != p.index()) return false;
260     }
261 
262     return true;
263   }
264 };
265 
266 class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
267  public:
268   using OpConversionPattern::OpConversionPattern;
269 
matchAndRewrite(mhlo::SliceOp slice_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const270   LogicalResult matchAndRewrite(
271       mhlo::SliceOp slice_op, ArrayRef<Value> args,
272       ConversionPatternRewriter &rewriter) const final {
273     DenseIntElementsAttr strides = slice_op.strides();
274     // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
275     if (!strides.isSplat() ||
276         strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
277       return failure();
278 
279     rewriter.setInsertionPointAfter(slice_op.getOperation());
280     auto start_indices = slice_op.start_indices();
281     auto limit_indices = slice_op.limit_indices();
282     std::vector<int64_t> size_values;
283     for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
284                                limit_indices.getValues<APInt>())) {
285       size_values.emplace_back(std::get<1>(pair).getSExtValue() -
286                                std::get<0>(pair).getSExtValue());
287     }
288 
289     RankedTensorType ty =
290         RankedTensorType::get({static_cast<int64_t>(size_values.size())},
291                               rewriter.getIntegerType(64));
292     auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
293     auto size = rewriter.create<ConstOp>(
294         slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
295     rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
296                                          slice_op.operand(), start, size);
297     return success();
298   };
299 };
300 
301 class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
302  public:
303   using OpConversionPattern::OpConversionPattern;
304 
matchAndRewrite(mhlo::DynamicSliceOp op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const305   LogicalResult matchAndRewrite(
306       mhlo::DynamicSliceOp op, ArrayRef<Value> args,
307       ConversionPatternRewriter &rewriter) const final {
308     ShapedType input_type = op.operand().getType().cast<ShapedType>();
309     if (!input_type.hasStaticShape()) return failure();
310     Type start_indices_element_type = op.start_indices()
311                                           .front()
312                                           .getType()
313                                           .cast<ShapedType>()
314                                           .getElementType();
315 
316     // Clamp indices to [0, input_size - output_size]
317     llvm::SmallVector<Value, 4> start_indices_vector;
318     start_indices_vector.reserve(op.start_indices().size());
319     Value clamp_min = rewriter.create<ConstOp>(
320         op.getLoc(), rewriter.getIntegerAttr(start_indices_element_type, 0));
321     for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
322       Value clamp_max = rewriter.create<ConstOp>(
323           op.getLoc(),
324           rewriter.getIntegerAttr(start_indices_element_type,
325                                   input_type.getShape()[i] -
326                                       op.slice_sizes().getValue<int64_t>({i})));
327       Value clamped_index = rewriter.create<mhlo::ClampOp>(
328           op.getLoc(), op.start_indices()[i].getType(), op.start_indices()[i],
329           clamp_min, clamp_max);
330       start_indices_vector.push_back(clamped_index);
331     }
332 
333     // Pack individual start indices to start indices tensor.
334     Type start_indices_type = RankedTensorType::get(
335         {static_cast<int64_t>(start_indices_vector.size())},
336         start_indices_element_type);
337     Value start_indices_op = rewriter.create<PackOp>(
338         op.getLoc(), start_indices_type, ValueRange(start_indices_vector));
339 
340     Value slice_sices_op =
341         rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
342     rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
343                                          start_indices_op, slice_sices_op);
344     return success();
345   };
346 };
347 
348 // Appends all elements in `range` to `values`.
349 template <typename ValueT, typename Range>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range)350 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
351   values.insert(values.end(), range.begin(), range.end());
352 }
353 
354 // Appends all elements in `range` to `values`.
355 template <typename ValueT, typename Range, typename... RangeTs>
Append(llvm::SmallVectorImpl<ValueT> & values,Range && range,RangeTs &&...ranges)356 void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
357             RangeTs &&...ranges) {
358   values.insert(values.end(), range.begin(), range.end());
359   Append(values, ranges...);
360 }
361 
// Returns how many elements `range` holds (its size()).
template <typename Range>
size_t Size(Range &&range) {
  const size_t count = range.size();
  return count;
}

// Returns the sum of size() over `range` and every range in `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  const size_t remaining = Size(std::forward<RangeTs>(ranges)...);
  return remaining + range.size();
}
373 
374 // Concats all elements in `ranges` and returns a small vector as a result.
375 template <typename ValueT, typename... RangeTs>
Concat(RangeTs &&...ranges)376 llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
377   llvm::SmallVector<int64_t, 4> results;
378   results.reserve(Size(std::forward<RangeTs>(ranges)...));
379   Append(results, std::forward<RangeTs>(ranges)...);
380   return results;
381 }
382 
383 // A struct to hold axes and sizes for a set of dimensions.
384 struct DimensionVector {
AxesArraymlir::TF::__anon0527096c0111::DimensionVector385   llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
SizesArraymlir::TF::__anon0527096c0111::DimensionVector386   llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }
387 
388   llvm::SmallVector<int64_t, 4> axes;
389   llvm::SmallVector<int64_t, 4> sizes;
390 };
391 
392 // A struct to hold information about dimensions of dot_general operands.
393 class DotDimensionsInfo {
394  public:
DotDimensionsInfo(ShapedType type,DenseIntElementsAttr batch_dimensions,DenseIntElementsAttr contracting_dimensions)395   DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
396                     DenseIntElementsAttr contracting_dimensions) {
397     const int rank = type.getRank();
398     for (const int dim : batch_dimensions.getValues<int64_t>()) {
399       batch_dimensions_.axes.push_back(dim);
400       batch_dimensions_.sizes.push_back(type.getDimSize(dim));
401     }
402 
403     for (const int dim : contracting_dimensions.getValues<int64_t>()) {
404       contracting_dimensions_.axes.push_back(dim);
405       contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
406     }
407 
408     for (int dim = 0; dim < rank; ++dim) {
409       if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
410           llvm::count(batch_dimensions_.axes, dim) > 0) {
411         continue;
412       }
413       out_dimensions_.axes.push_back(dim);
414       out_dimensions_.sizes.push_back(type.getDimSize(dim));
415     }
416   }
417 
batch_dimensions() const418   const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
contracting_dimensions() const419   const DimensionVector &contracting_dimensions() const {
420     return contracting_dimensions_;
421   }
422   // Out dimensions are any dimensions that are neither batch nor contracting
423   // dimensions, hence will be propagated to output shape.
out_dimensions() const424   const DimensionVector &out_dimensions() const { return out_dimensions_; }
425 
426   // Returns the total dimension size after flattening all contracting
427   // dimensions.
FlattenedContractingDimensionSize() const428   int FlattenedContractingDimensionSize() const {
429     return std::accumulate(contracting_dimensions_.sizes.begin(),
430                            contracting_dimensions_.sizes.end(), 1,
431                            std::multiplies<int64_t>());
432   }
433 
434   // Returns the total dimension size after flattening all out dimensions.
FlattenedOutDimensionSize() const435   int FlattenedOutDimensionSize() const {
436     return std::accumulate(out_dimensions_.sizes.begin(),
437                            out_dimensions_.sizes.end(), 1,
438                            std::multiplies<int64_t>());
439   }
440 
441  private:
442   DimensionVector batch_dimensions_;
443   DimensionVector contracting_dimensions_;
444   // Out dimensions are any dimensions that are neither batch nor contracting
445   // dimensions, hence will be propagated to output shape.
446   DimensionVector out_dimensions_;
447 };
448 
ConvertDot(PatternRewriter & rewriter,Value lhs,Value rhs,DotDimensionNumbers dot_dimension_numbers,ShapedType result_type,mlir::Location loc)449 Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
450                  DotDimensionNumbers dot_dimension_numbers,
451                  ShapedType result_type, mlir::Location loc) {
452   auto lhs_type = lhs.getType().cast<ShapedType>();
453   auto rhs_type = rhs.getType().cast<ShapedType>();
454   const int lhs_rank = lhs_type.getRank();
455   const int rhs_rank = rhs_type.getRank();
456 
457   // Collects lhs and rhs dimensions information.
458   DotDimensionsInfo lhs_dot_dimensions_info(
459       lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
460       dot_dimension_numbers.lhs_contracting_dimensions());
461   DotDimensionsInfo rhs_dot_dimensions_info(
462       rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
463       dot_dimension_numbers.rhs_contracting_dimensions());
464 
465   // Transposes lhs shape to be in the order of {batch_dimensions,
466   // out_dimensions, contracting dimensions}.
467   llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
468       lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
469       lhs_dot_dimensions_info.out_dimensions().AxesArray(),
470       lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
471   llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
472       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
473       lhs_dot_dimensions_info.out_dimensions().SizesArray(),
474       lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
475   auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
476       loc,
477       RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
478       lhs,
479       DenseIntElementsAttr::get(
480           RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
481           lhs_permutation));
482 
483   // Transposes rhs shape to be in the order of {batch_dimensions, contracting
484   // dimensions, out_dimensions}.
485   llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
486       rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
487       rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
488       rhs_dot_dimensions_info.out_dimensions().AxesArray());
489   llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
490       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
491       rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
492       rhs_dot_dimensions_info.out_dimensions().SizesArray());
493   auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
494       loc,
495       RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
496       rhs,
497       DenseIntElementsAttr::get(
498           RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
499           rhs_permutation));
500 
501   // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
502   llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
503       lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
504       llvm::ArrayRef<int64_t>{
505           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
506       llvm::ArrayRef<int64_t>{
507           lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
508   auto lhs_flattend = rewriter.create<mhlo::ReshapeOp>(
509       loc,
510       RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
511       lhs_transposed.getResult());
512 
513   // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
514   llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
515       rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
516       llvm::ArrayRef<int64_t>{
517           rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
518       llvm::ArrayRef<int64_t>{
519           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
520   auto rhs_flattend = rewriter.create<mhlo::ReshapeOp>(
521       loc,
522       RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
523       rhs_transposed.getResult());
524 
525   // Creates matmul op of `lhs_flattend` and `rhs_flattend`.
526   llvm::SmallVector<int64_t, 4> matmul_shape =
527       Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
528                       llvm::ArrayRef<int64_t>{
529                           lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
530                       llvm::ArrayRef<int64_t>{
531                           rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
532   auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
533       loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
534       lhs_flattend.getResult(), rhs_flattend.getResult());
535   auto reshaped =
536       rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
537   return reshaped.getResult();
538 }
539 
540 // Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when
541 // necessary.
ConvertDotOp(PatternRewriter & rewriter,Operation * old_op)542 Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
543   auto dot_op = cast<mhlo::DotOp>(old_op);
544   auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
545   auto dot_dimension_numbers = DotDimensionNumbers::get(
546       /*lhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
547       /*rhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
548       /*lhs_contracting_dimensions=*/
549       rewriter.getI64TensorAttr({lhs_rank == 1 ? 0 : 1}),
550       /*rhs_contracting_dimensions=*/rewriter.getI64TensorAttr({0}),
551       rewriter.getContext());
552   return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(), dot_dimension_numbers,
553                     dot_op.getResult().getType().cast<ShapedType>(),
554                     dot_op.getLoc());
555 }
556 
557 // Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
558 // inserted to convert to well-formed matrix multiply.
ConvertDotGeneralOp(PatternRewriter & rewriter,Operation * old_op)559 Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
560   auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
561   return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
562                     dot_general_op.dot_dimension_numbers(),
563                     dot_general_op.getResult().getType().cast<ShapedType>(),
564                     dot_general_op.getLoc());
565 }
566 
567 // Checks if the specified region is a binary reduction function what takes 2
568 // inputs, passes it to an instance of the specifiied reduction op and then
569 // returns the result.
570 template <typename ReductionOp>
MatchBinaryReduceFunction(mlir::Region & function)571 LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
572   Block &body = function.front();
573   if (body.getNumArguments() != 2) return failure();
574 
575   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
576   if (!return_op) return failure();
577   if (return_op.getNumOperands() != 1) return failure();
578 
579   ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
580       return_op.getOperands().front().getDefiningOp());
581   if (!reduce_op) return failure();
582   if (reduce_op.lhs() != body.getArgument(0) ||
583       reduce_op.rhs() != body.getArgument(1))
584     return failure();
585 
586   return success();
587 }
588 
589 // Check if the specified region is a binary reduction function what takes 2
590 // inputs and returns the second input. Functions like this are used by update
591 // scatter like ops.
592 template <>
MatchBinaryReduceFunction(mlir::Region & function)593 LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
594   Block &body = function.front();
595   if (body.getNumArguments() != 2) return failure();
596 
597   mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
598   if (!return_op) return failure();
599   if (return_op.getNumOperands() != 1) return failure();
600   if (return_op.getOperands().front() != body.getArgument(1)) return failure();
601   return success();
602 }
603 
604 // Converts an mhlo.reduce op with the specified BinaryOp as the reduction
605 // operation into the specified TfOp.
606 template <typename BinaryOp, typename TfOp>
607 class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
608  public:
609   using OpConversionPattern::OpConversionPattern;
610 
matchAndRewrite(mhlo::ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const611   LogicalResult matchAndRewrite(
612       mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
613       ConversionPatternRewriter &rewriter) const final {
614     if (failed(MatchReduceOpInput(reduce_op))) return failure();
615 
616     if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
617       return failure();
618 
619     // In `MatchReduceOpInput` function, we already match that the
620     // "mhlo::ReduceOp" only has one input, one init_value and one result.
621     if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();
622 
623     auto input = reduce_op.operands()[0];
624 
625     // Get reduction dimension.
626     DenseIntElementsAttr dimension = reduce_op.dimensions();
627     SmallVector<int64_t, 4> reduce_dims;
628     for (const int64_t &dim : dimension.getValues<int64_t>()) {
629       reduce_dims.emplace_back(dim);
630     }
631     auto dim_type = RankedTensorType::get(
632         {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
633     auto reduction_indices = rewriter.create<ConstOp>(
634         reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
635 
636     rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
637                                       reduction_indices,
638                                       /*keep_dim=*/rewriter.getBoolAttr(false));
639     return success();
640   }
641 
642  private:
643   // Checks that the init value matches with the init value expected for the
644   // target TfOp.
645   virtual LogicalResult MatchInitValue(Value init_value) const = 0;
646 
647   // This function tries to match that the "mhlo::ReduceOp" only has one
648   // input, one init_value and one result.
MatchReduceOpInput(mhlo::ReduceOp reduce_op) const649   LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
650     if (reduce_op.operands().size() != 1 ||
651         reduce_op.init_values().size() != 1 ||
652         reduce_op.getResults().size() != 1)
653       return failure();
654 
655     if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
656       return failure();
657     if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
658     return success();
659   }
660 };
661 
662 class ConvertReduceOpToTfSum
663     : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
664  public:
665   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
666 
MatchInitValue(Value init_value) const667   LogicalResult MatchInitValue(Value init_value) const override {
668     DenseFPElementsAttr init_attr;
669     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
670         !init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
671       return failure();
672     return success();
673   }
674 };
675 
676 class ConvertReduceOpToTfMax
677     : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
678  public:
679   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
680 
MatchInitValue(Value init_value) const681   LogicalResult MatchInitValue(Value init_value) const override {
682     DenseFPElementsAttr init_attr;
683     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
684         !init_attr.isSplat() ||
685         !init_attr.getSplatValue<APFloat>().isInfinity() ||
686         !init_attr.getSplatValue<APFloat>().isNegative())
687       return failure();
688     return success();
689   }
690 };
691 
692 class ConvertReduceOpToTfMin
693     : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
694  public:
695   using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;
696 
MatchInitValue(Value init_value) const697   LogicalResult MatchInitValue(Value init_value) const override {
698     DenseFPElementsAttr init_attr;
699     if (!matchPattern(init_value, m_Constant(&init_attr)) ||
700         !init_attr.isSplat() ||
701         !init_attr.getSplatValue<APFloat>().isInfinity() ||
702         init_attr.getSplatValue<APFloat>().isNegative())
703       return failure();
704     return success();
705   }
706 };
707 
708 class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
709  public:
710   using OpConversionPattern::OpConversionPattern;
711 
matchAndRewrite(mhlo::IotaOp iota_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const712   LogicalResult matchAndRewrite(
713       mhlo::IotaOp iota_op, ArrayRef<Value> args,
714       ConversionPatternRewriter &rewriter) const final {
715     RankedTensorType type =
716         iota_op.getType().dyn_cast_or_null<RankedTensorType>();
717     if (!type) return failure();
718 
719     const uint64_t dimension = iota_op.iota_dimension();
720     Type element_type = type.getElementType();
721     Attribute start, limit, delta;
722     if (element_type.isa<FloatType>()) {
723       start = rewriter.getFloatAttr(element_type, 0.0);
724       limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
725       delta = rewriter.getFloatAttr(element_type, 1.0);
726     } else if (element_type.isa<IntegerType>()) {
727       start = rewriter.getIntegerAttr(element_type, 0);
728       limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
729       delta = rewriter.getIntegerAttr(element_type, 1);
730     } else {
731       return failure();
732     }
733 
734     auto range_type =
735         RankedTensorType::get({type.getShape()[dimension]}, element_type);
736     Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
737     Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
738     Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
739     Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
740                                                 start_op, limit_op, delta_op);
741 
742     if (type.getRank() > 1) {
743       std::vector<int64_t> reshape_shape(type.getRank(), 1);
744       reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
745       auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
746       Value reshape_shape_op = rewriter.create<TF::ConstOp>(
747           iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
748       result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
749                                               result, reshape_shape_op);
750 
751       Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
752           iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
753       result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
754                                                   result, broadcast_shape_op);
755     }
756 
757     rewriter.replaceOp(iota_op, result);
758     return success();
759   }
760 };
761 
762 // Maps the following represenattions of AvgPool in MHLO into a tf.AvgPool{3D}
763 // operation when they cleanly map to 2D or 3D average pool with VALID or SAME
764 // padding:
765 // * div(reduce_sum_window(x), constant(sizeof(window)))
766 // * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
767 class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
768  public:
769   using OpConversionPattern::OpConversionPattern;
770 
matchAndRewrite(mhlo::DivOp div_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const771   LogicalResult matchAndRewrite(
772       mhlo::DivOp div_op, ArrayRef<Value> args,
773       ConversionPatternRewriter &rewriter) const final {
774     auto rw =
775         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
776     if (!rw) return failure();
777 
778     // Check that the reduce-window is a sum-reduce-window.
779     if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
780       return failure();
781 
782     // Check that this is a floating point reduce window with a rank of 4 or 5.
783     RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
784     if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
785         rw_type.getRank() <= 3 || rw_type.getRank() > 5)
786       return failure();
787 
788     // Check that the Div op doesn't do broadcasting on the output of the reduce
789     // window.
790     if (div_op.getType() != rw.getType()) return failure();
791 
792     // tf.avg_pool need at least 3 dimensions (batch, spatial, channel)
793     const uint64_t rank = rw.window_dimensions().size();
794     if (rank <= 2) return failure();
795 
796     // If the init value isn't zero then it can't be an average pool.
797     if (!isFloatZero(rw.init_value())) return failure();
798 
799     llvm::SmallVector<int64_t, 5> window_strides;
800     if (rw.window_strides().hasValue()) {
801       window_strides.insert(window_strides.end(),
802                             rw.window_strides()->getValues<int64_t>().begin(),
803                             rw.window_strides()->getValues<int64_t>().end());
804     } else {
805       window_strides.resize(rank, 1);
806     }
807 
808     llvm::SmallVector<int64_t, 10> padding;
809     if (rw.padding().hasValue()) {
810       padding.insert(padding.begin(),
811                      rw.padding()->getValues<int64_t>().begin(),
812                      rw.padding()->getValues<int64_t>().end());
813     } else {
814       padding.resize(2 * rank, 0);
815     }
816 
817     // Check that we don't do any reduction along the batch (first) and channel
818     // (last) dimensions.
819     const uint64_t batch_dim = 0;
820     const uint64_t channel_dim = rank - 1;
821     if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
822         rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
823         window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
824         padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
825         padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
826       return failure();
827 
828     if (rw.window_dilations().hasValue() &&
829         !(rw.window_dilations()->isSplat() &&
830           rw.window_dilations()->getSplatValue<APInt>() == 1))
831       return failure();
832 
833     if (rw.base_dilations().hasValue() &&
834         !(rw.base_dilations()->isSplat() &&
835           rw.base_dilations()->getSplatValue<APInt>() == 1))
836       return failure();
837 
838     DenseFPElementsAttr divisor;
839     if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
840       // If the divisor is a constant then check that it matches with the number
841       // of elements inside the window what is required for a VALID AvgPool.
842       if (!divisor.isSplat()) return failure();
843       int64_t window_size = 1;
844       for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
845         window_size *= w;
846       }
847       if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
848         return failure();
849 
850       // Check that we have no padding.
851       if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
852         return failure();
853 
854       return replaceWithAvgPool(
855           div_op, rw.operand(),
856           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
857           window_strides, "VALID", rewriter);
858     }
859 
860     auto rw_rhs =
861         dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
862     if (rw_rhs) {
863       // Check that RHS is a sum-reduce-window.
864       if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
865         return failure();
866 
867       // Check that the RHS is a reduce_window over a constant 1 input with 0 as
868       // the init value.
869       DenseFPElementsAttr rhs_input;
870       if (!isFloatZero(rw_rhs.init_value()) ||
871           !matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
872           !rhs_input.isSplat() ||
873           !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
874         return failure();
875 
876       // Check that the two reduce window have the same window configuration.
877       if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
878           rw.window_strides() != rw_rhs.window_strides() ||
879           rw.window_dilations() != rw_rhs.window_dilations() ||
880           rw.base_dilations() != rw_rhs.base_dilations() ||
881           rw.padding() != rw_rhs.padding())
882         return failure();
883 
884       if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
885         return replaceWithAvgPool(
886             div_op, rw.operand(),
887             llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
888             window_strides, "VALID", rewriter);
889 
890       RankedTensorType input_type =
891           rw.operand().getType().dyn_cast<RankedTensorType>();
892       RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
893       if (!input_type || !output_type) return failure();
894 
895       // Check that the individual padding values are corresponding to SAME
896       // padding from TensorFlow.
897       for (uint64_t i = 1; i < rank - 1; ++i) {
898         int64_t padding_size =
899             (output_type.getShape()[i] - 1) * window_strides[i] +
900             rw.window_dimensions().getValue<int64_t>({i}) -
901             input_type.getShape()[i];
902         if (padding[2 * i] !=
903                 tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
904             padding[2 * i + 1] !=
905                 tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
906           return failure();
907       }
908       return replaceWithAvgPool(
909           div_op, rw.operand(),
910           llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
911           window_strides, "SAME", rewriter);
912     }
913     return failure();
914   }
915 
916  private:
isFloatZero(Value value) const917   bool isFloatZero(Value value) const {
918     DenseFPElementsAttr initial_value;
919     return matchPattern(value, m_Constant(&initial_value)) &&
920            initial_value.getNumElements() == 1 &&
921            initial_value.getValue<APFloat>({}).isZero();
922   }
923 
replaceWithAvgPool(mhlo::DivOp op,Value input,llvm::ArrayRef<int64_t> ksizes,llvm::ArrayRef<int64_t> kstrides,llvm::StringRef padding,ConversionPatternRewriter & rewriter) const924   LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
925                                    llvm::ArrayRef<int64_t> ksizes,
926                                    llvm::ArrayRef<int64_t> kstrides,
927                                    llvm::StringRef padding,
928                                    ConversionPatternRewriter &rewriter) const {
929     if (ksizes.size() == 4) {
930       rewriter.replaceOpWithNewOp<AvgPoolOp>(
931           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
932           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
933           rewriter.getStringAttr("NHWC"));
934       return success();
935     } else if (ksizes.size() == 5) {
936       rewriter.replaceOpWithNewOp<AvgPool3DOp>(
937           op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
938           rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
939           rewriter.getStringAttr("NDHWC"));
940       return success();
941     }
942     return failure();
943   }
944 };
945 
946 class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
getDependentDialects(DialectRegistry & registry) const947   void getDependentDialects(DialectRegistry &registry) const override {
948     registry.insert<TF::TensorFlowDialect>();
949   }
950 
951  public:
952   LegalizeHloToTf() = default;
LegalizeHloToTf(const LegalizeHloToTf &)953   LegalizeHloToTf(const LegalizeHloToTf &) {}
954 
955   /// Performs the legalization to the TF dialect.
956   void runOnFunction() override;
957 };
958 
959 // Returns the shape of the given value in a Constant Op.
ShapeToConst(PatternRewriter & rewriter,Value value)960 ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
961   ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
962   auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
963                                          rewriter.getIntegerType(64));
964   auto attr = DenseElementsAttr::get(attr_type, shape);
965   return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
966 }
967 
968 // If index_vector_dim == indices.rank() then insert the implicit extra
969 // dimension into indices to normalize everything to index_vector_dim ==
970 // indices.rank() - 1.
NormalizeIndexVector(Operation * parent_op,Value & indices,ShapedType & indices_type,int64_t index_vector_dim,ConversionPatternRewriter & rewriter)971 LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
972                                    ShapedType &indices_type,
973                                    int64_t index_vector_dim,
974                                    ConversionPatternRewriter &rewriter) {
975   if (index_vector_dim == indices_type.getRank()) {
976     llvm::SmallVector<int64_t, 4> new_start_indices_shape(
977         indices_type.getShape().begin(), indices_type.getShape().end());
978     new_start_indices_shape.push_back(1);
979     indices_type = RankedTensorType::get(new_start_indices_shape,
980                                          indices_type.getElementType());
981     indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
982                                                indices_type, indices);
983   } else if (index_vector_dim != indices_type.getRank() - 1) {
984     // If index_vector_dim isn't the last dimension in indices then it isn't
985     // supported yet.
986     // TODO(tberghammer): Transpose indices to support this usecase.
987     return rewriter.notifyMatchFailure(
988         parent_op,
989         "index vector dim isn't the last dimension in start indices");
990   }
991   return success();
992 }
993 
994 // Check that `attr` is an R1 iota with integer element type starting from `0`
995 // with `size` number of values.
IsIotaAttr(const DenseIntElementsAttr & attr,int64_t size)996 bool IsIotaAttr(const DenseIntElementsAttr &attr, int64_t size) {
997   if (!attr.getType().getElementType().isa<IntegerType>()) return false;
998   if (attr.getType().getRank() != 1) return false;
999   if (attr.getNumElements() != size) return false;
1000   int64_t iota = 0;
1001   for (auto s : attr.getIntValues()) {
1002     if (s != iota) return false;
1003     ++iota;
1004   }
1005   return true;
1006 }
1007 
1008 class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
1009  public:
1010   using OpConversionPattern::OpConversionPattern;
1011 
matchAndRewrite(mhlo::GatherOp gather_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1012   LogicalResult matchAndRewrite(
1013       mhlo::GatherOp gather_op, ArrayRef<Value> args,
1014       ConversionPatternRewriter &rewriter) const final {
1015     Value operand = gather_op.operand();
1016     Value start_indices = gather_op.start_indices();
1017 
1018     // Can only convert with static shaped gather.
1019     ShapedType operand_type = operand.getType().cast<ShapedType>();
1020     ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
1021     ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
1022     if (!operand_type.hasStaticShape() ||
1023         !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
1024       return failure();
1025     }
1026 
1027     // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
1028     int64_t index_vector_dim =
1029         gather_op.dimension_numbers().index_vector_dim().getInt();
1030     if (failed(NormalizeIndexVector(gather_op, start_indices,
1031                                     start_indices_type, index_vector_dim,
1032                                     rewriter))) {
1033       return failure();
1034     }
1035 
1036     // Verify that start_index_map and collapsed_slice_dims are both an iota
1037     // with the same number of elements as the last dimension of start_indices.
1038     auto start_index_map = gather_op.dimension_numbers().start_index_map();
1039     auto collapsed_slice_dims =
1040         gather_op.dimension_numbers().collapsed_slice_dims();
1041     if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) ||
1042         !IsIotaAttr(collapsed_slice_dims,
1043                     start_indices_type.getShape().back())) {
1044       // TODO(tberghammer): Transform start_indices to support non-standard
1045       // start_index_maps.
1046       return rewriter.notifyMatchFailure(
1047           gather_op, "unsupported start index map and/or collapsed slice dims");
1048     }
1049 
1050     // Verify that slice_sizes is 1 for the indexed dimensions and the full
1051     // shape for the rest of the dimensions.
1052     auto slice_sizes = gather_op.slice_sizes();
1053     int64_t index = 0;
1054     for (int64_t s : slice_sizes.getValues<int64_t>()) {
1055       if (index < start_indices_type.getShape().back()) {
1056         if (s != 1) {
1057           return rewriter.notifyMatchFailure(gather_op,
1058                                              "unsupported slice sizes");
1059         }
1060       } else {
1061         if (s != operand_type.getShape()[index]) {
1062           return rewriter.notifyMatchFailure(gather_op,
1063                                              "unsupported slice sizes");
1064         }
1065       }
1066       ++index;
1067     }
1068 
1069     // Verify that offset_dims are the tailing dimensions in the output tensor.
1070     auto offset_dims = gather_op.dimension_numbers().offset_dims();
1071     int64_t offset = start_indices_type.getRank() - 1;
1072     for (int64_t o : offset_dims.getValues<int64_t>()) {
1073       if (o != offset) {
1074         return rewriter.notifyMatchFailure(gather_op,
1075                                            "unsupported offset dims");
1076       }
1077       ++offset;
1078     }
1079 
1080     rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
1081                                                 start_indices);
1082     return success();
1083   }
1084 };
1085 
1086 template <typename BinaryOp, typename TfOp>
1087 class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
1088  public:
1089   using OpConversionPattern::OpConversionPattern;
1090 
matchAndRewrite(mhlo::ScatterOp scatter_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1091   LogicalResult matchAndRewrite(
1092       mhlo::ScatterOp scatter_op, ArrayRef<Value> args,
1093       ConversionPatternRewriter &rewriter) const final {
1094     Value operand = scatter_op.operand();
1095     Value indices = scatter_op.scatter_indices();
1096     Value updates = scatter_op.updates();
1097     ShapedType operand_type = operand.getType().cast<ShapedType>();
1098     ShapedType indices_type = indices.getType().cast<ShapedType>();
1099     ShapedType updates_type = updates.getType().cast<ShapedType>();
1100 
1101     // Can only convert with static shaped scatter.
1102     if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
1103         !updates_type.hasStaticShape()) {
1104       return failure();
1105     }
1106 
1107     // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
1108     int64_t index_vector_dim =
1109         scatter_op.scatter_dimension_numbers().index_vector_dim().getInt();
1110     if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
1111                                     index_vector_dim, rewriter))) {
1112       return failure();
1113     }
1114 
1115     // Verify that inserted_window_dims and scatter_dims_to_operand_dims are
1116     // both an iota with the same number of elements as the last dimension of
1117     // start_indices.
1118     auto inserted_window_dims =
1119         scatter_op.scatter_dimension_numbers().inserted_window_dims();
1120     auto scatter_dims_to_operand_dims =
1121         scatter_op.scatter_dimension_numbers().scatter_dims_to_operand_dims();
1122     if (!IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) ||
1123         !IsIotaAttr(scatter_dims_to_operand_dims,
1124                     indices_type.getShape().back())) {
1125       // TODO(tberghammer): Transform indices to support non-standard
1126       // scatter_dims_to_operand_dims.
1127       return rewriter.notifyMatchFailure(
1128           scatter_op,
1129           "unsupported inserted window dims and/or scatter dims to operand "
1130           "dims");
1131     }
1132 
1133     // Verify that update window dims are the tailing dimensions in the update
1134     // tensor.
1135     auto update_window_dims =
1136         scatter_op.scatter_dimension_numbers().update_window_dims();
1137     int64_t offset = indices_type.getRank() - 1;
1138     for (int64_t o : update_window_dims.getValues<int64_t>()) {
1139       if (o != offset) {
1140         return rewriter.notifyMatchFailure(scatter_op,
1141                                            "unsupported update window dims");
1142       }
1143       ++offset;
1144     }
1145 
1146     // Match the scatter computation against computations supported by TF.
1147     if (failed(MatchBinaryReduceFunction<BinaryOp>(
1148             scatter_op.update_computation()))) {
1149       return failure();
1150     }
1151 
1152     rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
1153                                       scatter_op.getResult().getType(), operand,
1154                                       indices, updates);
1155     return success();
1156   }
1157 };
1158 using ConvertScatterAddOp =
1159     ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
1160 using ConvertScatterMaxOp =
1161     ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
1162 using ConvertScatterMinOp =
1163     ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
1164 using ConvertScatterSubOp =
1165     ConvertScatterOp<mhlo::SubOp, TF::TensorScatterSubOp>;
1166 using ConvertScatterUpdateOp =
1167     ConvertScatterOp<void, TF::TensorScatterUpdateOp>;
1168 
1169 // Converts mhlo.pad to tf.PadV2
ConvertPadOp(PatternRewriter & rewriter,Operation * old_op)1170 Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
1171   auto pad_op = cast<mhlo::PadOp>(old_op);
1172   mlir::Location loc = pad_op.getLoc();
1173 
1174   llvm::SmallVector<APInt, 8> padding;
1175   for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
1176                           pad_op.edge_padding_high().getValues<APInt>())) {
1177     padding.push_back(std::get<0>(p));
1178     padding.push_back(std::get<1>(p));
1179   }
1180   auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
1181                                          rewriter.getI64Type());
1182   auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
1183   auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
1184   return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
1185                                   padding_op, pad_op.padding_value());
1186 }
1187 
1188 // Returns true if broadcast_dimensions obey Tensorflow convention, as in new
1189 // dimensions are added as prefix.
IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,Value output)1190 bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
1191                         Value output) {
1192   // broadcast_dimensions is an increasing list by definition, thus it suffices
1193   // to check the first element.
1194   int64_t input_rank = broadcast_dimensions.getNumElements();
1195   int64_t output_rank = output.getType().cast<ShapedType>().getRank();
1196   return input_rank == 0 ||
1197          (broadcast_dimensions.getValue({0}).cast<IntegerAttr>().getInt() ==
1198           output_rank - input_rank);
1199 }
1200 
1201 // Returns the intermediate shape that input tensor should be reshaped to during
1202 // legalization of BroadcastInDimOp.
ExpandedShape(PatternRewriter & rewriter,Value input,DenseIntElementsAttr broadcast_dimensions,Value output)1203 ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
1204                          DenseIntElementsAttr broadcast_dimensions,
1205                          Value output) {
1206   // Initialize expanded shape with output rank and dimensions of 1.
1207   SmallVector<Attribute, 4> expanded_shape(
1208       output.getType().cast<ShapedType>().getRank(),
1209       /*Value=*/rewriter.getI64IntegerAttr(1));
1210 
1211   // Set dimension sizes specified by broadcast_dimensions.
1212   ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
1213   for (auto x : llvm::enumerate(broadcast_dimensions)) {
1214     expanded_shape[x.value().getSExtValue()] =
1215         rewriter.getI64IntegerAttr(input_shape[x.index()]);
1216   }
1217 
1218   // Create the expanded type wrapped in a ConstantOp.
1219   auto attr_type =
1220       RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
1221                             rewriter.getIntegerType(64));
1222   auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
1223   return rewriter.create<ConstantOp>(output.getLoc(), attr_type, attr);
1224 }
1225 
1226 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
1227 
1228 /// Performs the lowering to XLA dialect.
runOnFunction()1229 void LegalizeHloToTf::runOnFunction() {
1230   MLIRContext &context = getContext();
1231 
1232   // Add legalization patterns to the list.
1233   OwningRewritePatternList patterns;
1234   PopulateLegalizeHloToTfPatterns(&patterns, &context);
1235 
1236   ConversionTarget target(context);
1237   target.addLegalDialect<TensorFlowDialect>();
1238   target.addLegalOp<CallOp, ConstantOp>();
1239   if (failed(
1240           applyPartialConversion(getFunction(), target, std::move(patterns)))) {
1241     getFunction().emitError("mhlo to TF legalization failed.");
1242     signalPassFailure();
1243   }
1244 }
1245 
1246 static PassRegistration<LegalizeHloToTf> pass(
1247     "tf-legalize-hlo", "Legalize from HLO to the TF dialect");
1248 
1249 }  // end namespace
1250 
PopulateLegalizeHloToTfPatterns(OwningRewritePatternList * patterns,MLIRContext * context)1251 void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
1252                                      MLIRContext *context) {
1253   patterns
1254       ->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertDynamicSliceOp,
1255                ConvertGatherOp, ConvertScatterAddOp, ConvertScatterMaxOp,
1256                ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp,
1257                ConvertSliceOp, ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
1258                ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
1259   populateWithGenerated(context, *patterns);
1260 }
1261 
CreateLegalizeHloToTfPass()1262 std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
1263   return std::make_unique<LegalizeHloToTf>();
1264 }
1265 
1266 }  // end namespace TF
1267 }  // end namespace mlir
1268