/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for legalizing HLO to TensorFlow.

#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace mlir {
namespace TF {
namespace {

using mhlo::DotDimensionNumbers;

class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp conv_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (!IsSupportedConvOp(conv_op)) {
      return failure();
    }

    // Constructs strides array.
    // For example, [2, 3] -> [1, 2, 3, 1].
    SmallVector<int64_t, 4> strides({1});
    for (const auto v :
         conv_op.window_strides().getValue().getValues<int64_t>()) {
      strides.emplace_back(v);
    }
    strides.emplace_back(1);

    // Constructs dilation array.
    SmallVector<int64_t, 4> dilation;
    if (auto rhs_dilation = conv_op.rhs_dilation()) {
      // For example, [2, 3] -> [1, 2, 3, 1].
      dilation.emplace_back(1);
      dilation.append(rhs_dilation.getValue().getValues<int64_t>().begin(),
                      rhs_dilation.getValue().getValues<int64_t>().end());
      dilation.emplace_back(1);
    } else {
      // Default value
      dilation = {1, 1, 1, 1};
    }

    const int input_feature_dimension =
        conv_op.dimension_numbers().input_feature_dimension().getInt();
    const int input_channels =
        conv_op.lhs().getType().cast<ShapedType>().getDimSize(
            input_feature_dimension);
    int feature_group_count = conv_op.feature_group_count();

    if (feature_group_count != 1 && feature_group_count != input_channels) {
      // Group convolution is not supported yet.
      return failure();
    }

    const bool is_depthwise_conv = input_channels == feature_group_count;
    std::string padding;

    if (!conv_op.padding().hasValue() ||
        (conv_op.padding().getValue().isSplat() &&
         conv_op.padding()->getSplatValue<int64_t>() == 0)) {
      padding = "VALID";
    } else {
      // Check if padding is "SAME".
      // TODO(chhe): Support "EXPLICIT" padding.
      SmallVector<int64_t, 8> padding_array;
      for (const auto v : conv_op.padding().getValue().getValues<int64_t>()) {
        padding_array.emplace_back(v);
      }

      const int num_spatial_dims = conv_op.dimension_numbers()
                                       .input_spatial_dimensions()
                                       .getNumElements();
      if (!IsSamePadding(conv_op, num_spatial_dims, strides, dilation,
                         padding_array))
        return failure();

      padding = "SAME";
    }

    CreateConvOp(conv_op, strides, padding, dilation, is_depthwise_conv,
                 input_channels, rewriter);
    return success();
  }

 private:
  bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
                     ArrayRef<int64_t> padding_array) const {
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
      int dim = i + 1;
      tensorflow::int64 output_size;
      tensorflow::int64 pad_low_int64;
      tensorflow::int64 pad_high_int64;
      tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
          conv_op.lhs().getType().cast<ShapedType>().getDimSize(dim),
          conv_op.rhs().getType().cast<ShapedType>().getDimSize(i),
          dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size,
          &pad_low_int64, &pad_high_int64);
      if (!status.ok()) return false;
      if (padding_array[2 * i] != pad_low_int64 ||
          padding_array[2 * i + 1] != pad_high_int64)
        return false;
    }

    return true;
  }

  void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
                    StringRef padding, ArrayRef<int64_t> dilation,
                    bool is_depthwise_conv, int input_channels,
                    ConversionPatternRewriter &rewriter) const {
    // TODO(chhe): Support data formats other than "NHWC".
    if (is_depthwise_conv) {
      // Reshapes filter format to [filter_height, filter_width, in_channels,
      // channel_multiplier] from HLO's [filter_height, filter_width, 1,
      // in_channels * channel_multiplier] format.
      auto filter_type = conv_op.rhs().getType().cast<ShapedType>();
      llvm::ArrayRef<int64_t> hlo_filter_shape = filter_type.getShape();
      llvm::SmallVector<int64_t, 4> tf_filter_shape(hlo_filter_shape.begin(),
                                                    hlo_filter_shape.end());
      tf_filter_shape[2] = input_channels;
      tf_filter_shape[3] = hlo_filter_shape.back() / input_channels;
      auto reshaped_filter = rewriter.create<mhlo::ReshapeOp>(
          conv_op.rhs().getLoc(),
          RankedTensorType::get(tf_filter_shape, filter_type.getElementType()),
          conv_op.rhs());

      rewriter.replaceOpWithNewOp<DepthwiseConv2dNativeOp>(
          conv_op, conv_op.getType(), conv_op.lhs(), reshaped_filter,
          rewriter.getI64ArrayAttr(strides),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    } else {
      rewriter.replaceOpWithNewOp<Conv2DOp>(
          conv_op, conv_op.getType(), conv_op.lhs(), conv_op.rhs(),
          rewriter.getI64ArrayAttr(strides),
          /*use_cudnn_on_gpu=*/rewriter.getBoolAttr(true),
          /*padding=*/rewriter.getStringAttr(padding),
          /*explicit_paddings=*/rewriter.getI64ArrayAttr({}),
          /*data_format=*/rewriter.getStringAttr("NHWC"),
          /*dilations=*/rewriter.getI64ArrayAttr(dilation));
    }
  }

  bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return false;

    // All ones in "lhs_dilation" means this "mhlo.conv" op should be
    // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
    if (conv_op.lhs_dilation().hasValue()) {
      auto lhs_dilation = conv_op.lhs_dilation().getValue();
      if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue<int64_t>() != 1)
        return false;
    }

    if (!conv_op.window_strides().hasValue() || conv_op.window_strides()
                                                        .getValue()
                                                        .getType()
                                                        .cast<ShapedType>()
                                                        .getRank() != 1)
      return false;

    int num_spatial_dims =
        conv_op.dimension_numbers().input_spatial_dimensions().getNumElements();
    // TODO(b/158636600): Currently we don't support 3D Convolution.
    if (num_spatial_dims != 2) return false;

    // TODO(chhe): Support data formats other than "NHWC".
    // Checks input dimensions.
    if (conv_op.dimension_numbers().input_batch_dimension().getInt() != 0 ||
        conv_op.dimension_numbers().input_feature_dimension().getInt() !=
            num_spatial_dims + 1)
      return false;
    DenseIntElementsAttr input_spatial_dimensions =
        conv_op.dimension_numbers().input_spatial_dimensions();
    for (auto p :
         llvm::enumerate(input_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index() + 1) return false;
    }

    // Checks output dimensions.
    if (conv_op.dimension_numbers().output_batch_dimension().getInt() != 0 ||
        conv_op.dimension_numbers().output_feature_dimension().getInt() !=
            num_spatial_dims + 1)
      return false;
    DenseIntElementsAttr output_spatial_dimensions =
        conv_op.dimension_numbers().output_spatial_dimensions();
    for (auto p :
         llvm::enumerate(output_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index() + 1) return false;
    }

    // Checks kernel dimensions.
    if (conv_op.dimension_numbers().kernel_input_feature_dimension().getInt() !=
            num_spatial_dims ||
        conv_op.dimension_numbers()
                .kernel_output_feature_dimension()
                .getInt() != num_spatial_dims + 1)
      return false;
    DenseIntElementsAttr kernel_spatial_dimensions =
        conv_op.dimension_numbers().kernel_spatial_dimensions();
    for (auto p :
         llvm::enumerate(kernel_spatial_dimensions.getValues<int64_t>())) {
      if (p.value() != p.index()) return false;
    }

    return true;
  }
};

class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    DenseIntElementsAttr strides = slice_op.strides();
    // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
    if (!strides.isSplat() ||
        strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
      return failure();

    rewriter.setInsertionPointAfter(slice_op.getOperation());
    auto start_indices = slice_op.start_indices();
    auto limit_indices = slice_op.limit_indices();
    std::vector<int64_t> size_values;
    for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
                               limit_indices.getValues<APInt>())) {
      size_values.emplace_back(std::get<1>(pair).getSExtValue() -
                               std::get<0>(pair).getSExtValue());
    }

    RankedTensorType ty =
        RankedTensorType::get({static_cast<int64_t>(size_values.size())},
                              rewriter.getIntegerType(64));
    auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
    auto size = rewriter.create<ConstOp>(
        slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
    rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
                                         slice_op.operand(), start, size);
    return success();
  }
};

class ConvertDynamicSliceOp : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    ShapedType input_type = op.operand().getType().cast<ShapedType>();
    if (!input_type.hasStaticShape()) return failure();
    Type start_indices_element_type = op.start_indices()
                                          .front()
                                          .getType()
                                          .cast<ShapedType>()
                                          .getElementType();

    // Clamp indices to [0, input_size - output_size]
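    // For example, with an input dimension of size 10 and a slice size of 3,
    // a start index of 9 is clamped to 7 so that the slice [start, start + 3)
    // stays in bounds.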
    llvm::SmallVector<Value, 4> start_indices_vector;
    start_indices_vector.reserve(op.start_indices().size());
    Value clamp_min = rewriter.create<ConstOp>(
        op.getLoc(), rewriter.getIntegerAttr(start_indices_element_type, 0));
    for (uint64_t i = 0, e = op.start_indices().size(); i < e; ++i) {
      Value clamp_max = rewriter.create<ConstOp>(
          op.getLoc(),
          rewriter.getIntegerAttr(start_indices_element_type,
                                  input_type.getShape()[i] -
                                      op.slice_sizes().getValue<int64_t>({i})));
      Value clamped_index = rewriter.create<mhlo::ClampOp>(
          op.getLoc(), op.start_indices()[i].getType(), op.start_indices()[i],
          clamp_min, clamp_max);
      start_indices_vector.push_back(clamped_index);
    }

    // Pack individual start indices to start indices tensor.
    Type start_indices_type = RankedTensorType::get(
        {static_cast<int64_t>(start_indices_vector.size())},
        start_indices_element_type);
    Value start_indices_op = rewriter.create<PackOp>(
        op.getLoc(), start_indices_type, ValueRange(start_indices_vector));

    Value slice_sizes_op =
        rewriter.create<ConstOp>(op.getLoc(), op.slice_sizes());
    rewriter.replaceOpWithNewOp<SliceOp>(op, op.getType(), op.operand(),
                                         start_indices_op, slice_sizes_op);
    return success();
  }
};

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range) {
  values.insert(values.end(), range.begin(), range.end());
}

// Appends all elements in `range` to `values`.
template <typename ValueT, typename Range, typename... RangeTs>
void Append(llvm::SmallVectorImpl<ValueT> &values, Range &&range,
            RangeTs &&...ranges) {
  values.insert(values.end(), range.begin(), range.end());
  Append(values, ranges...);
}

// Returns the number of elements in `range`.
template <typename Range>
size_t Size(Range &&range) {
  return range.size();
}

// Returns the total number of elements in a variadic number of `ranges`.
template <typename Range, typename... RangeTs>
size_t Size(Range &&range, RangeTs &&...ranges) {
  return range.size() + Size(std::forward<RangeTs>(ranges)...);
}

// Concats all elements in `ranges` and returns a small vector as a result.
template <typename ValueT, typename... RangeTs>
llvm::SmallVector<ValueT, 4> Concat(RangeTs &&...ranges) {
  llvm::SmallVector<ValueT, 4> results;
  results.reserve(Size(std::forward<RangeTs>(ranges)...));
  Append(results, std::forward<RangeTs>(ranges)...);
  return results;
}
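// For example, Concat<int64_t>(llvm::ArrayRef<int64_t>{0},
// llvm::ArrayRef<int64_t>{1, 2}) returns {0, 1, 2}; ConvertDot below uses this
// to splice batch, out, and contracting axes into a single permutation.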

// A struct to hold axes and sizes for a set of dimensions.
struct DimensionVector {
  llvm::ArrayRef<int64_t> AxesArray() const { return axes; }
  llvm::ArrayRef<int64_t> SizesArray() const { return sizes; }

  llvm::SmallVector<int64_t, 4> axes;
  llvm::SmallVector<int64_t, 4> sizes;
};

// A struct to hold information about dimensions of dot_general operands.
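// For example, given an operand of type tensor<2x3x4xf32> with batch dimension
// {0} and contracting dimension {2}, the remaining dimension {1} is the out
// dimension, so the axes/sizes split into batch {0}/{2}, contracting {2}/{4},
// and out {1}/{3}.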
class DotDimensionsInfo {
 public:
  DotDimensionsInfo(ShapedType type, DenseIntElementsAttr batch_dimensions,
                    DenseIntElementsAttr contracting_dimensions) {
    const int rank = type.getRank();
    for (const int dim : batch_dimensions.getValues<int64_t>()) {
      batch_dimensions_.axes.push_back(dim);
      batch_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (const int dim : contracting_dimensions.getValues<int64_t>()) {
      contracting_dimensions_.axes.push_back(dim);
      contracting_dimensions_.sizes.push_back(type.getDimSize(dim));
    }

    for (int dim = 0; dim < rank; ++dim) {
      if (llvm::count(contracting_dimensions_.axes, dim) > 0 ||
          llvm::count(batch_dimensions_.axes, dim) > 0) {
        continue;
      }
      out_dimensions_.axes.push_back(dim);
      out_dimensions_.sizes.push_back(type.getDimSize(dim));
    }
  }

  const DimensionVector &batch_dimensions() const { return batch_dimensions_; }
  const DimensionVector &contracting_dimensions() const {
    return contracting_dimensions_;
  }
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, hence will be propagated to output shape.
  const DimensionVector &out_dimensions() const { return out_dimensions_; }

  // Returns the total dimension size after flattening all contracting
  // dimensions.
  int FlattenedContractingDimensionSize() const {
    return std::accumulate(contracting_dimensions_.sizes.begin(),
                           contracting_dimensions_.sizes.end(), 1,
                           std::multiplies<int64_t>());
  }

  // Returns the total dimension size after flattening all out dimensions.
  int FlattenedOutDimensionSize() const {
    return std::accumulate(out_dimensions_.sizes.begin(),
                           out_dimensions_.sizes.end(), 1,
                           std::multiplies<int64_t>());
  }

 private:
  DimensionVector batch_dimensions_;
  DimensionVector contracting_dimensions_;
  // Out dimensions are any dimensions that are neither batch nor contracting
  // dimensions, hence will be propagated to output shape.
  DimensionVector out_dimensions_;
};

Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs,
                 DotDimensionNumbers dot_dimension_numbers,
                 ShapedType result_type, mlir::Location loc) {
  auto lhs_type = lhs.getType().cast<ShapedType>();
  auto rhs_type = rhs.getType().cast<ShapedType>();
  const int lhs_rank = lhs_type.getRank();
  const int rhs_rank = rhs_type.getRank();
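
  // Lowering strategy: transpose each operand so its axes are grouped as
  // {batch, out, contracting} for lhs and {batch, contracting, out} for rhs,
  // flatten each group to a single axis, multiply with tf.BatchMatMulV2, and
  // reshape back to the expected result shape. For example, lhs
  // tensor<2x3x4xf32> and rhs tensor<2x4x5xf32> with batch dimensions {0} and
  // contracting dimensions {2} (lhs) and {1} (rhs) become a [2, 3, 4] x
  // [2, 4, 5] batch matmul producing [2, 3, 5].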

  // Collects lhs and rhs dimensions information.
  DotDimensionsInfo lhs_dot_dimensions_info(
      lhs_type, dot_dimension_numbers.lhs_batching_dimensions(),
      dot_dimension_numbers.lhs_contracting_dimensions());
  DotDimensionsInfo rhs_dot_dimensions_info(
      rhs_type, dot_dimension_numbers.rhs_batching_dimensions(),
      dot_dimension_numbers.rhs_contracting_dimensions());

  // Transposes lhs shape to be in the order of {batch_dimensions,
  // out_dimensions, contracting dimensions}.
  llvm::SmallVector<int64_t, 4> lhs_permutation = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      lhs_dot_dimensions_info.out_dimensions().AxesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> lhs_transposed_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
  auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
      lhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({lhs_rank}, rewriter.getI64Type()),
          lhs_permutation));

  // Transposes rhs shape to be in the order of {batch_dimensions, contracting
  // dimensions, out_dimensions}.
  llvm::SmallVector<int64_t, 4> rhs_permutation = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().AxesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().AxesArray(),
      rhs_dot_dimensions_info.out_dimensions().AxesArray());
  llvm::SmallVector<int64_t, 4> rhs_transposed_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
      rhs_dot_dimensions_info.out_dimensions().SizesArray());
  auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
      rhs,
      DenseIntElementsAttr::get(
          RankedTensorType::get({rhs_rank}, rewriter.getI64Type()),
          rhs_permutation));

  // Reshapes lhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> lhs_flattened_shape = Concat<int64_t>(
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
  auto lhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
      lhs_transposed.getResult());

  // Reshapes rhs to flatten out_dimensions and contracting_dimensions.
  llvm::SmallVector<int64_t, 4> rhs_flattened_shape = Concat<int64_t>(
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto rhs_flattened = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
      rhs_transposed.getResult());

  // Creates matmul op of `lhs_flattened` and `rhs_flattened`.
  llvm::SmallVector<int64_t, 4> matmul_shape =
      Concat<int64_t>(lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
                      llvm::ArrayRef<int64_t>{
                          lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
                      llvm::ArrayRef<int64_t>{
                          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto matmul = rewriter.create<TF::BatchMatMulV2Op>(
      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
      lhs_flattened.getResult(), rhs_flattened.getResult());
  auto reshaped =
      rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
  return reshaped.getResult();
}

// Converts mhlo.dot to tf.BatchMatMulV2. Reshape ops will be inserted when
// necessary.
Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_op = cast<mhlo::DotOp>(old_op);
  auto lhs_rank = dot_op.lhs().getType().cast<ShapedType>().getRank();
  auto dot_dimension_numbers = DotDimensionNumbers::get(
      /*lhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
      /*rhs_batching_dimensions=*/rewriter.getI64TensorAttr({}),
      /*lhs_contracting_dimensions=*/
      rewriter.getI64TensorAttr({lhs_rank == 1 ? 0 : 1}),
      /*rhs_contracting_dimensions=*/rewriter.getI64TensorAttr({0}),
      rewriter.getContext());
  return ConvertDot(rewriter, dot_op.lhs(), dot_op.rhs(), dot_dimension_numbers,
                    dot_op.getResult().getType().cast<ShapedType>(),
                    dot_op.getLoc());
}

// Converts mhlo.dot_general to tf.BatchMatMulV2. Reshape or Transpose ops will
// also be inserted to convert to well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
  auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
  return ConvertDot(rewriter, dot_general_op.lhs(), dot_general_op.rhs(),
                    dot_general_op.dot_dimension_numbers(),
                    dot_general_op.getResult().getType().cast<ShapedType>(),
                    dot_general_op.getLoc());
}

// Checks if the specified region is a binary reduction function that takes two
// inputs, passes them to an instance of the specified reduction op, and then
// returns the result.
template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();

  ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
      return_op.getOperands().front().getDefiningOp());
  if (!reduce_op) return failure();
  if (reduce_op.lhs() != body.getArgument(0) ||
      reduce_op.rhs() != body.getArgument(1))
    return failure();

  return success();
}
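
// For example, MatchBinaryReduceFunction<mhlo::AddOp> succeeds on a region of
// roughly the following form:
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %sum = mhlo.add %lhs, %rhs : tensor<f32>
//     "mhlo.return"(%sum) : (tensor<f32>) -> ()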
// Checks if the specified region is a binary reduction function that takes two
// inputs and returns the second input. Functions like this are used by
// scatter-update-like ops.
template <>
LogicalResult MatchBinaryReduceFunction<void>(mlir::Region &function) {
  Block &body = function.front();
  if (body.getNumArguments() != 2) return failure();

  mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
  if (!return_op) return failure();
  if (return_op.getNumOperands() != 1) return failure();
  if (return_op.getOperands().front() != body.getArgument(1)) return failure();
  return success();
}

// Converts an mhlo.reduce op with the specified BinaryOp as the reduction
// operation into the specified TfOp.
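// For example, an mhlo.reduce over dimension 1 whose body adds its two block
// arguments becomes a tf.Sum with reduction_indices = [1] and
// keep_dims = false.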
template <typename BinaryOp, typename TfOp>
class ConvertReduceOpToTfOp : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpInput(reduce_op))) return failure();

    if (failed(MatchBinaryReduceFunction<BinaryOp>(reduce_op.body())))
      return failure();

    // In `MatchReduceOpInput` function, we already match that the
    // "mhlo::ReduceOp" only has one input, one init_value and one result.
    if (failed(MatchInitValue(reduce_op.init_values()[0]))) return failure();

    auto input = reduce_op.operands()[0];

    // Get reduction dimension.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
    SmallVector<int64_t, 4> reduce_dims;
    for (const int64_t &dim : dimension.getValues<int64_t>()) {
      reduce_dims.emplace_back(dim);
    }
    auto dim_type = RankedTensorType::get(
        {static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
    auto reduction_indices = rewriter.create<ConstOp>(
        reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));

    rewriter.replaceOpWithNewOp<TfOp>(reduce_op, reduce_op.getType(0), input,
                                      reduction_indices,
                                      /*keep_dims=*/rewriter.getBoolAttr(false));
    return success();
  }

 private:
  // Checks that the init value matches with the init value expected for the
  // target TfOp.
  virtual LogicalResult MatchInitValue(Value init_value) const = 0;

  // This function tries to match that the "mhlo::ReduceOp" only has one
  // input, one init_value and one result.
  LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) const {
    if (reduce_op.operands().size() != 1 ||
        reduce_op.init_values().size() != 1 ||
        reduce_op.getResults().size() != 1)
      return failure();

    if (!reduce_op.operands()[0].getType().isa<RankedTensorType>())
      return failure();
    if (!reduce_op.getType(0).isa<RankedTensorType>()) return failure();
    return success();
  }
};

class ConvertReduceOpToTfSum
    : public ConvertReduceOpToTfOp<mhlo::AddOp, TF::SumOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() || !init_attr.getSplatValue<APFloat>().isZero())
      return failure();
    return success();
  }
};

class ConvertReduceOpToTfMax
    : public ConvertReduceOpToTfOp<mhlo::MaxOp, TF::MaxOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() ||
        !init_attr.getSplatValue<APFloat>().isInfinity() ||
        !init_attr.getSplatValue<APFloat>().isNegative())
      return failure();
    return success();
  }
};

class ConvertReduceOpToTfMin
    : public ConvertReduceOpToTfOp<mhlo::MinOp, TF::MinOp> {
 public:
  using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp;

  LogicalResult MatchInitValue(Value init_value) const override {
    DenseFPElementsAttr init_attr;
    if (!matchPattern(init_value, m_Constant(&init_attr)) ||
        !init_attr.isSplat() ||
        !init_attr.getSplatValue<APFloat>().isInfinity() ||
        init_attr.getSplatValue<APFloat>().isNegative())
      return failure();
    return success();
  }
};

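// Converts an mhlo.iota op to a tf.Range op over the iota dimension, adding
// tf.Reshape and tf.BroadcastTo ops when the result rank is greater than 1.
// For example, an iota along dimension 1 of tensor<3x4xf32> becomes
// tf.Range(0, 4, 1), reshaped to [1, 4] and broadcast to [3, 4].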
class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::IotaOp iota_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    RankedTensorType type =
        iota_op.getType().dyn_cast_or_null<RankedTensorType>();
    if (!type) return failure();

    const uint64_t dimension = iota_op.iota_dimension();
    Type element_type = type.getElementType();
    Attribute start, limit, delta;
    if (element_type.isa<FloatType>()) {
      start = rewriter.getFloatAttr(element_type, 0.0);
      limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getFloatAttr(element_type, 1.0);
    } else if (element_type.isa<IntegerType>()) {
      start = rewriter.getIntegerAttr(element_type, 0);
      limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]);
      delta = rewriter.getIntegerAttr(element_type, 1);
    } else {
      return failure();
    }

    auto range_type =
        RankedTensorType::get({type.getShape()[dimension]}, element_type);
    Value start_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), start);
    Value limit_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), limit);
    Value delta_op = rewriter.create<TF::ConstOp>(iota_op.getLoc(), delta);
    Value result = rewriter.create<TF::RangeOp>(iota_op.getLoc(), range_type,
                                                start_op, limit_op, delta_op);

    if (type.getRank() > 1) {
      std::vector<int64_t> reshape_shape(type.getRank(), 1);
      reshape_shape[iota_op.iota_dimension()] = type.getShape()[dimension];
      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
      Value reshape_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(reshape_shape));
      result = rewriter.create<TF::ReshapeOp>(iota_op.getLoc(), reshape_type,
                                              result, reshape_shape_op);

      Value broadcast_shape_op = rewriter.create<TF::ConstOp>(
          iota_op.getLoc(), rewriter.getI64TensorAttr(type.getShape()));
      result = rewriter.create<TF::BroadcastToOp>(iota_op.getLoc(), type,
                                                  result, broadcast_shape_op);
    }

    rewriter.replaceOp(iota_op, result);
    return success();
  }
};

// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
// padding:
//   * div(reduce_sum_window(x), constant(sizeof(window)))
//   * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DivOp div_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    auto rw =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
    if (!rw) return failure();

    // Check that the reduce-window is a sum-reduce-window.
    if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
      return failure();

    // Check that this is a floating point reduce window with a rank of 4 or 5.
    RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
    if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
        rw_type.getRank() <= 3 || rw_type.getRank() > 5)
      return failure();

    // Check that the Div op doesn't do broadcasting on the output of the
    // reduce window.
    if (div_op.getType() != rw.getType()) return failure();

    // tf.AvgPool needs at least 3 dimensions (batch, spatial, channel).
    const uint64_t rank = rw.window_dimensions().size();
    if (rank <= 2) return failure();

    // If the init value isn't zero then it can't be an average pool.
    if (!isFloatZero(rw.init_value())) return failure();

    llvm::SmallVector<int64_t, 5> window_strides;
    if (rw.window_strides().hasValue()) {
      window_strides.insert(window_strides.end(),
                            rw.window_strides()->getValues<int64_t>().begin(),
                            rw.window_strides()->getValues<int64_t>().end());
    } else {
      window_strides.resize(rank, 1);
    }

    llvm::SmallVector<int64_t, 10> padding;
    if (rw.padding().hasValue()) {
      padding.insert(padding.begin(),
                     rw.padding()->getValues<int64_t>().begin(),
                     rw.padding()->getValues<int64_t>().end());
    } else {
      padding.resize(2 * rank, 0);
    }

    // Check that we don't do any reduction along the batch (first) and channel
    // (last) dimensions.
    const uint64_t batch_dim = 0;
    const uint64_t channel_dim = rank - 1;
    if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
        rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
        window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
        padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
        padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
      return failure();

    if (rw.window_dilations().hasValue() &&
        !(rw.window_dilations()->isSplat() &&
          rw.window_dilations()->getSplatValue<APInt>() == 1))
      return failure();

    if (rw.base_dilations().hasValue() &&
        !(rw.base_dilations()->isSplat() &&
          rw.base_dilations()->getSplatValue<APInt>() == 1))
      return failure();

    DenseFPElementsAttr divisor;
    if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
      // If the divisor is a constant then check that it matches with the
      // number of elements inside the window, which is required for a VALID
      // AvgPool.
      if (!divisor.isSplat()) return failure();
      int64_t window_size = 1;
      for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
        window_size *= w;
      }
      if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
        return failure();

      // Check that we have no padding.
      if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
        return failure();

      return replaceWithAvgPool(
          div_op, rw.operand(),
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, "VALID", rewriter);
    }

    auto rw_rhs =
        dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
    if (rw_rhs) {
      // Check that RHS is a sum-reduce-window.
      if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
        return failure();

      // Check that the RHS is a reduce_window over a constant 1 input with 0
      // as the init value.
      DenseFPElementsAttr rhs_input;
      if (!isFloatZero(rw_rhs.init_value()) ||
          !matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
          !rhs_input.isSplat() ||
          !rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
        return failure();

      // Check that the two reduce window ops have the same window
      // configuration.
      if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
          rw.window_strides() != rw_rhs.window_strides() ||
          rw.window_dilations() != rw_rhs.window_dilations() ||
          rw.base_dilations() != rw_rhs.base_dilations() ||
          rw.padding() != rw_rhs.padding())
        return failure();

      if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
        return replaceWithAvgPool(
            div_op, rw.operand(),
            llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
            window_strides, "VALID", rewriter);

      RankedTensorType input_type =
          rw.operand().getType().dyn_cast<RankedTensorType>();
      RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
      if (!input_type || !output_type) return failure();

      // Check that the individual padding values correspond to SAME padding
      // from TensorFlow.
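      // For example, with input size 8, output size 4, stride 2, and window
      // size 3: padding_size = (4 - 1) * 2 + 3 - 8 = 1, so SAME requires
      // pad_low = 0 and pad_high = 1.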
      for (uint64_t i = 1; i < rank - 1; ++i) {
        int64_t padding_size =
            (output_type.getShape()[i] - 1) * window_strides[i] +
            rw.window_dimensions().getValue<int64_t>({i}) -
            input_type.getShape()[i];
        if (padding[2 * i] !=
                tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
            padding[2 * i + 1] !=
                tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
          return failure();
      }
      return replaceWithAvgPool(
          div_op, rw.operand(),
          llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
          window_strides, "SAME", rewriter);
    }
    return failure();
  }

 private:
  bool isFloatZero(Value value) const {
    DenseFPElementsAttr initial_value;
    return matchPattern(value, m_Constant(&initial_value)) &&
           initial_value.getNumElements() == 1 &&
           initial_value.getValue<APFloat>({}).isZero();
  }

  LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
                                   llvm::ArrayRef<int64_t> ksizes,
                                   llvm::ArrayRef<int64_t> kstrides,
                                   llvm::StringRef padding,
                                   ConversionPatternRewriter &rewriter) const {
    if (ksizes.size() == 4) {
      rewriter.replaceOpWithNewOp<AvgPoolOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NHWC"));
      return success();
    } else if (ksizes.size() == 5) {
      rewriter.replaceOpWithNewOp<AvgPool3DOp>(
          op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
          rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
          rewriter.getStringAttr("NDHWC"));
      return success();
    }
    return failure();
  }
};

class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TF::TensorFlowDialect>();
  }

 public:
  LegalizeHloToTf() = default;
  LegalizeHloToTf(const LegalizeHloToTf &) {}

  /// Performs the legalization to the TF dialect.
  void runOnFunction() override;
};

// Returns the shape of the given value in a Constant Op.
ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
  ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
  auto attr_type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                         rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, shape);
  return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
}

// If index_vector_dim == indices.rank() then insert the implicit extra
// dimension into indices to normalize everything to index_vector_dim ==
// indices.rank() - 1.
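// For example, indices of type tensor<2x3xi32> with index_vector_dim == 2 are
// reshaped to tensor<2x3x1xi32>, while an index_vector_dim that is neither
// indices.rank() nor indices.rank() - 1 is rejected.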
LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices,
                                   ShapedType &indices_type,
                                   int64_t index_vector_dim,
                                   ConversionPatternRewriter &rewriter) {
  if (index_vector_dim == indices_type.getRank()) {
    llvm::SmallVector<int64_t, 4> new_start_indices_shape(
        indices_type.getShape().begin(), indices_type.getShape().end());
    new_start_indices_shape.push_back(1);
    indices_type = RankedTensorType::get(new_start_indices_shape,
                                         indices_type.getElementType());
    indices = rewriter.create<mhlo::ReshapeOp>(parent_op->getLoc(),
                                               indices_type, indices);
  } else if (index_vector_dim != indices_type.getRank() - 1) {
    // If index_vector_dim isn't the last dimension in indices then it isn't
    // supported yet.
    // TODO(tberghammer): Transpose indices to support this usecase.
    return rewriter.notifyMatchFailure(
        parent_op,
        "index vector dim isn't the last dimension in start indices");
  }
  return success();
}

// Check that `attr` is an R1 iota with integer element type starting from `0`
// with `size` number of values.
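// For example, dense<[0, 1, 2]> with size == 3 matches, while dense<[1, 2]>
// or dense<[0, 2, 1]> does not.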
bool IsIotaAttr(const DenseIntElementsAttr &attr, int64_t size) {
  if (!attr.getType().getElementType().isa<IntegerType>()) return false;
  if (attr.getType().getRank() != 1) return false;
  if (attr.getNumElements() != size) return false;
  int64_t iota = 0;
  for (auto s : attr.getIntValues()) {
    if (s != iota) return false;
    ++iota;
  }
  return true;
}

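// Converts mhlo.gather to tf.GatherNd when the gather indexes the leading
// operand dimensions: start_index_map and collapsed_slice_dims must be iotas,
// slice_sizes must be 1 for the indexed dimensions and the full dimension size
// elsewhere, and offset_dims must be the trailing output dimensions.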
class ConvertGatherOp : public OpConversionPattern<mhlo::GatherOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::GatherOp gather_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    Value operand = gather_op.operand();
    Value start_indices = gather_op.start_indices();

    // Can only convert with static shaped gather.
    ShapedType operand_type = operand.getType().cast<ShapedType>();
    ShapedType start_indices_type = start_indices.getType().cast<ShapedType>();
    ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>();
    if (!operand_type.hasStaticShape() ||
        !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) {
      return failure();
    }

    // Normalize start_indices so index_vector_dim == start_indices.rank() - 1.
    int64_t index_vector_dim =
        gather_op.dimension_numbers().index_vector_dim().getInt();
    if (failed(NormalizeIndexVector(gather_op, start_indices,
                                    start_indices_type, index_vector_dim,
                                    rewriter))) {
      return failure();
    }

    // Verify that start_index_map and collapsed_slice_dims are both an iota
    // with the same number of elements as the last dimension of start_indices.
    auto start_index_map = gather_op.dimension_numbers().start_index_map();
    auto collapsed_slice_dims =
        gather_op.dimension_numbers().collapsed_slice_dims();
    if (!IsIotaAttr(start_index_map, start_indices_type.getShape().back()) ||
        !IsIotaAttr(collapsed_slice_dims,
                    start_indices_type.getShape().back())) {
      // TODO(tberghammer): Transform start_indices to support non-standard
      // start_index_maps.
      return rewriter.notifyMatchFailure(
          gather_op, "unsupported start index map and/or collapsed slice dims");
    }

    // Verify that slice_sizes is 1 for the indexed dimensions and the full
    // shape for the rest of the dimensions.
    auto slice_sizes = gather_op.slice_sizes();
    int64_t index = 0;
    for (int64_t s : slice_sizes.getValues<int64_t>()) {
      if (index < start_indices_type.getShape().back()) {
        if (s != 1) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      } else {
        if (s != operand_type.getShape()[index]) {
          return rewriter.notifyMatchFailure(gather_op,
                                             "unsupported slice sizes");
        }
      }
      ++index;
    }

    // Verify that offset_dims are the trailing dimensions in the output tensor.
    auto offset_dims = gather_op.dimension_numbers().offset_dims();
    int64_t offset = start_indices_type.getRank() - 1;
    for (int64_t o : offset_dims.getValues<int64_t>()) {
      if (o != offset) {
        return rewriter.notifyMatchFailure(gather_op,
                                           "unsupported offset dims");
      }
      ++offset;
    }

    rewriter.replaceOpWithNewOp<TF::GatherNdOp>(gather_op, result_type, operand,
                                                start_indices);
    return success();
  }
};

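// Converts mhlo.scatter ops whose update computation matches BinaryOp (or, for
// the void specialization, simply returns the update value) into the
// corresponding TfOp, e.g. tf.TensorScatterAdd or tf.TensorScatterUpdate.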
template <typename BinaryOp, typename TfOp>
class ConvertScatterOp : public OpConversionPattern<mhlo::ScatterOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp scatter_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    Value operand = scatter_op.operand();
    Value indices = scatter_op.scatter_indices();
    Value updates = scatter_op.updates();
    ShapedType operand_type = operand.getType().cast<ShapedType>();
    ShapedType indices_type = indices.getType().cast<ShapedType>();
    ShapedType updates_type = updates.getType().cast<ShapedType>();

    // Can only convert with static shaped scatter.
    if (!operand_type.hasStaticShape() || !indices_type.hasStaticShape() ||
        !updates_type.hasStaticShape()) {
      return failure();
    }

    // Normalize scatter_indices so index_vector_dim ==
    // scatter_indices.rank() - 1.
    int64_t index_vector_dim =
        scatter_op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (failed(NormalizeIndexVector(scatter_op, indices, indices_type,
                                    index_vector_dim, rewriter))) {
      return failure();
    }

    // Verify that inserted_window_dims and scatter_dims_to_operand_dims are
    // both an iota with the same number of elements as the last dimension of
    // scatter_indices.
    auto inserted_window_dims =
        scatter_op.scatter_dimension_numbers().inserted_window_dims();
    auto scatter_dims_to_operand_dims =
        scatter_op.scatter_dimension_numbers().scatter_dims_to_operand_dims();
    if (!IsIotaAttr(inserted_window_dims, indices_type.getShape().back()) ||
        !IsIotaAttr(scatter_dims_to_operand_dims,
                    indices_type.getShape().back())) {
      // TODO(tberghammer): Transform indices to support non-standard
      // scatter_dims_to_operand_dims.
      return rewriter.notifyMatchFailure(
          scatter_op,
          "unsupported inserted window dims and/or scatter dims to operand "
          "dims");
    }

    // Verify that update window dims are the trailing dimensions in the update
    // tensor.
    auto update_window_dims =
        scatter_op.scatter_dimension_numbers().update_window_dims();
    int64_t offset = indices_type.getRank() - 1;
    for (int64_t o : update_window_dims.getValues<int64_t>()) {
      if (o != offset) {
        return rewriter.notifyMatchFailure(scatter_op,
                                           "unsupported update window dims");
      }
      ++offset;
    }

    // Match the scatter computation against computations supported by TF.
    if (failed(MatchBinaryReduceFunction<BinaryOp>(
            scatter_op.update_computation()))) {
      return failure();
    }

    rewriter.replaceOpWithNewOp<TfOp>(scatter_op,
                                      scatter_op.getResult().getType(), operand,
                                      indices, updates);
    return success();
  }
};
using ConvertScatterAddOp =
    ConvertScatterOp<mhlo::AddOp, TF::TensorScatterAddOp>;
using ConvertScatterMaxOp =
    ConvertScatterOp<mhlo::MaxOp, TF::TensorScatterMaxOp>;
using ConvertScatterMinOp =
    ConvertScatterOp<mhlo::MinOp, TF::TensorScatterMinOp>;
using ConvertScatterSubOp =
    ConvertScatterOp<mhlo::SubOp, TF::TensorScatterSubOp>;
using ConvertScatterUpdateOp =
    ConvertScatterOp<void, TF::TensorScatterUpdateOp>;

// Converts mhlo.pad to tf.PadV2.
Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
  auto pad_op = cast<mhlo::PadOp>(old_op);
  mlir::Location loc = pad_op.getLoc();

  llvm::SmallVector<APInt, 8> padding;
  for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
                          pad_op.edge_padding_high().getValues<APInt>())) {
    padding.push_back(std::get<0>(p));
    padding.push_back(std::get<1>(p));
  }
  auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
                                         rewriter.getI64Type());
  auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
  auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
  return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
                                  padding_op, pad_op.padding_value());
}

// Returns true if broadcast_dimensions obey TensorFlow convention, as in new
// dimensions are added as a prefix.
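// For example, broadcasting tensor<3x4xf32> (input rank 2) into a rank-4
// output is TF-style only when broadcast_dimensions is [2, 3], i.e. the input
// maps onto the trailing output dimensions.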
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
                        Value output) {
  // broadcast_dimensions is an increasing list by definition, thus it suffices
  // to check the first element.
  int64_t input_rank = broadcast_dimensions.getNumElements();
  int64_t output_rank = output.getType().cast<ShapedType>().getRank();
  return input_rank == 0 ||
         (broadcast_dimensions.getValue({0}).cast<IntegerAttr>().getInt() ==
          output_rank - input_rank);
}

// Returns the intermediate shape that input tensor should be reshaped to during
// legalization of BroadcastInDimOp.
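// For example, an input of type tensor<3xf32> broadcast into a rank-3 output
// with broadcast_dimensions = [1] is first reshaped to the intermediate shape
// [1, 3, 1].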
ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input,
                         DenseIntElementsAttr broadcast_dimensions,
                         Value output) {
  // Initialize expanded shape with output rank and dimensions of 1.
  SmallVector<Attribute, 4> expanded_shape(
      output.getType().cast<ShapedType>().getRank(),
      /*Value=*/rewriter.getI64IntegerAttr(1));

  // Set dimension sizes specified by broadcast_dimensions.
  ArrayRef<int64_t> input_shape = input.getType().cast<ShapedType>().getShape();
  for (auto x : llvm::enumerate(broadcast_dimensions)) {
    expanded_shape[x.value().getSExtValue()] =
        rewriter.getI64IntegerAttr(input_shape[x.index()]);
  }

  // Create the expanded type wrapped in a ConstantOp.
  auto attr_type =
      RankedTensorType::get({static_cast<int64_t>(expanded_shape.size())},
                            rewriter.getIntegerType(64));
  auto attr = DenseElementsAttr::get(attr_type, expanded_shape);
  return rewriter.create<ConstantOp>(output.getLoc(), attr_type, attr);
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"

/// Performs the lowering to the TF dialect.
void LegalizeHloToTf::runOnFunction() {
  MLIRContext &context = getContext();

  // Add legalization patterns to the list.
  OwningRewritePatternList patterns;
  PopulateLegalizeHloToTfPatterns(&patterns, &context);

  ConversionTarget target(context);
  target.addLegalDialect<TensorFlowDialect>();
  target.addLegalOp<CallOp, ConstantOp>();
  if (failed(
          applyPartialConversion(getFunction(), target, std::move(patterns)))) {
    getFunction().emitError("mhlo to TF legalization failed.");
    signalPassFailure();
  }
}

static PassRegistration<LegalizeHloToTf> pass(
    "tf-legalize-hlo", "Legalize from HLO to the TF dialect");

}  // end namespace

void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
                                     MLIRContext *context) {
  patterns
      ->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertDynamicSliceOp,
               ConvertGatherOp, ConvertScatterAddOp, ConvertScatterMaxOp,
               ConvertScatterMinOp, ConvertScatterSubOp, ConvertScatterUpdateOp,
               ConvertSliceOp, ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
               ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
  populateWithGenerated(context, *patterns);
}

std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
  return std::make_unique<LegalizeHloToTf>();
}

}  // end namespace TF
}  // end namespace mlir