// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all
// This test runs twice:
//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled,
//      since verifying the chlo ops that are emitted produces more useful
//      tests.
//   2. With chlo legalization enabled, verifying diagnostics to pick up any
//      issues with the full lowering (this can catch broadcasting corner
//      cases that are emitted with a warning).
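//
// For reference, either RUN line can be reproduced outside of lit by replacing
// %s with this file's path (placeholder path shown; assumes tf-opt and
// FileCheck are on PATH):
//
//   tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" \
//     /path/to/legalize-tf.mlir | FILECHECK_OPTS="" FileCheck /path/to/legalize-tf.mlir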

//===----------------------------------------------------------------------===//
// BatchNorm op legalizations.
//===----------------------------------------------------------------------===//

// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so only do a couple of basic checks.

// CHECK-LABEL: fusedBatchNormV2_noTraining
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

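// In training mode, the op lowers to "mhlo.batch_norm_training", whose tuple
// result (output, batch_mean, batch_variance) is unpacked with
// "mhlo.get_tuple_element". The batch variance is then rescaled by a constant
// (the chlo.broadcast_multiply checked below) before it feeds the running
// variance; the exponentialAvgFactor test later in this section works out
// where that constant comes from.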
// CHECK-LABEL: fusedBatchNormV2_training
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
// CHECK-SAME:  ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
  // CHECK: [[Y_CONVERT:%.*]] = "mhlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32>
  // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance
func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: return %[[VAR]]
  return %0#4 : tensor<8xf32>
}

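// A worked derivation of the constants checked below, assuming the standard
// running-statistics update
//   new_stat = (1 - exponential_avg_factor) * old_stat
//            + exponential_avg_factor * batch_stat
// with exponential_avg_factor = 0.8:
//   ALPHA  = 1.0 - 0.8, evaluated in f32 as 0.199999988
//   BETA   = 0.8                      (printed as 8.000000e-01)
//   FACTOR = n / (n - 1) = 512 / 511  (printed as 1.00195694), where
//            n = 8 * 8 * 8 = 512 is the per-feature element count of the
//            8x8x8x8 NHWC input; it rescales the biased batch variance into
//            an unbiased estimate before the running-variance update.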
// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor
func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-DAG: %[[BATCH_MEAN:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32}
  // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}

  // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
  // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]

  // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
  // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>

  // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
  // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
  // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]

  // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
  // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
  // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]

  // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
  return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormV3_NCHW
func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_NDHWC
func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) {
  // CHECK: feature_index = 4 : i64
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported
func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

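// Training mode is not lowered when dimensions needed by the lowering are
// dynamic: presumably the pattern must know the per-feature element count
// statically (e.g. to build the variance correction factor shown above), so
// the op is kept as tf.FusedBatchNormV3 both for a fully dynamic input and
// for a partially dynamic one whose feature dimension is static.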
// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1
func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2
func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>)
  return %0#0 : tensor<?x6x?x?xf32>
}

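// With is_training = false the gradient has no direct mhlo counterpart and is
// expanded inline. The checks below spell out, with eps = epsilon and
// reductions over the non-feature dimensions [0, 1, 2] (NHWC):
//   scr1            = rsqrt(variance + eps)
//   scr2            = sum(grad * (x - mean))
//   x_backprop      = grad * broadcast(scale * scr1)
//   scale_backprop  = scr1 * scr2
//   offset_backprop = sum(grad)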
// CHECK-LABEL: fusedBatchNormGrad_noTraining
func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

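// With is_training = true the gradient maps directly onto
// "mhlo.batch_norm_grad", whose tuple result carries
// (x_backprop, scale_backprop, offset_backprop).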
// CHECK-LABEL: fusedBatchNormGrad_Training
func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_noTraining
func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training
func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining
func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training
func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK: return %[[x_backprop]]
  // CHECK-SAME: tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>)
  return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %{{.*}}) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

//===----------------------------------------------------------------------===//
// Bias op legalizations.
//===----------------------------------------------------------------------===//

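// tf.BiasAdd lowers to a dynamic broadcast of the 1-D bias along the feature
// dimension selected by data_format (dimension 3 for NHWC, the default, and
// dimension 1 for NCHW), followed by an add. A minimal sketch of the pattern
// the tests below verify (types elided):
//
//   %shape   = shape.shape_of %input
//   %extents = shape.to_extent_tensor %shape
//   %bcast   = "mhlo.dynamic_broadcast_in_dim"(%bias, %extents)
//                  {broadcast_dimensions = dense<3> : tensor<1xi64>}
//   %sum     = mhlo.add %input, %bcast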
// CHECK-LABEL: func @biasAdd_default
func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
  return %0 : tensor<?x?x?x?xi32>
}

// CHECK-LABEL: func @biasAdd_partial_dynamic
func @biasAdd_partial_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<512xi32>) -> tensor<?x?x?x512xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor<?x?x?x?xi32> to tensor<?x?x?x512xi32>
  // CHECK: return %[[CAST]] : tensor<?x?x?x512xi32>
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<?x?x?x?xi32>, tensor<512xi32>) -> tensor<?x?x?x512xi32>
  return %0 : tensor<?x?x?x512xi32>
}


//===----------------------------------------------------------------------===//
// ClipByValue
//===----------------------------------------------------------------------===//

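// tf.ClipByValue lowers to "mhlo.clamp"(min, operand, max). When min/max are
// scalars but the operand is not, they are first broadcast to the operand's
// shape: a constant when the extents are statically known, otherwise a shape
// computed via shape.shape_of, as the broadcast tests below exercise.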
// CHECK-LABEL: @clip
func @clip(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
  // CHECK: [[VAL:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)

  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  // CHECK: return [[VAL]]
  return %0 : tensor<f32>
}

// CHECK-LABEL: @clip_dynamic
func @clip_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: @clip_static_broadcast
func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<5xf32> {
  // CHECK-DAG: [[SHPIDX:%.+]] = mhlo.constant dense<5>
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<5xf32>
}


// CHECK-LABEL: @clip_dynamic_broadcast
func @clip_dynamic_broadcast(%arg0 : tensor<?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?xf32> {
  // CHECK: [[SHP:%.+]] = shape.shape_of %arg0
  // CHECK: [[SHPIDX:%.+]] = index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32>
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// DiagPart
//===----------------------------------------------------------------------===//

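// tf.DiagPart on tensor<4x3x4x3xf32> flattens the two index groups into a
// 12x12 matrix (12 = 4 * 3), zeroes out everything off the main diagonal via
// an iota/compare/select, sums away one dimension, and reshapes the resulting
// tensor<12xf32> back to tensor<4x3xf32>. Conceptually:
//   result[i, j] = input[i, j, i, j]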
// CHECK-LABEL: func @diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32>
func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
  // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]])
  // CHECK-DAG:  mhlo.add
  // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32>
  // CHECK-DAG:  %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32>
  // CHECK-DAG:  return %[[RES]] : tensor<4x3xf32>
  %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32>
  return %0: tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// MatrixDiagPart
//===----------------------------------------------------------------------===//

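// tf.MatrixDiagPartV3 with k = [-10, 11] extracts 22 diagonals per batch
// (11 - (-10) + 1 = 22), so the tensor<7x140x128xi32> input yields
// tensor<7x22x128xi32>. The lowering computes per-diagonal row/column indices
// with iota arithmetic, fetches them all in a single "mhlo.gather", and uses
// "mhlo.select" to substitute the padding value (42) for positions falling
// outside the input matrix.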
568// CHECK-LABEL: func @matrix_diag_part
569// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
570func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
571  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
572  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
573  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
574  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
575  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
576  // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
577  // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1>
578  // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
579  // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1>
580  // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
581  // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32>
582  // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
583  // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32>
584  // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
585  // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V19:.*]] = "mhlo.negate"(%[[V18]]) : (tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V41:.*]] = "mhlo.reshape"(%[[V40]]) : (tensor<1x22x128xi1>) -> tensor<22x128xi1>
  // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32>
  // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = {collapsed_slice_dims = dense<[1, 2]> : tensor<2xi64>, index_vector_dim = 0 : i64, offset_dims = dense<0> : tensor<1xi64>, start_index_map = dense<[1, 2]> : tensor<2xi64>}, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1>
  // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK: return %[[V46]] : tensor<7x22x128xi32>
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_single_diagonal
func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<0> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x128xi32>
  // CHECK: %[[result:.*]] = "mhlo.reshape"({{.*}}) : (tensor<7x1x128xi32>) -> tensor<7x128xi32>
  // CHECK: return %[[result]] : tensor<7x128xi32>
  return %2: tensor<7x128xi32>
}

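// The remaining alignment tests pin down the behavior selected by the align
// attribute, whose two halves choose how superdiagonals and subdiagonals are
// packed into rows of the result. In the lowering this shows up in the
// per-element select predicate: LEFT_LEFT folds it to a constant false and
// RIGHT_RIGHT to a constant true, while the mixed modes compare the diagonal
// index against zero (LE for LEFT_RIGHT, GE for RIGHT_LEFT), as the CHECK
// patterns below verify.
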
// CHECK-LABEL: func @matrix_diag_part_align_ll
func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[false:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_lr
func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "LE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rl
func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rr
func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_7d
// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32>
func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> {
  %0 = mhlo.constant dense<-1.> : tensor<f32>  // padding value
  %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = f32, align = "LEFT_RIGHT"
  } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32>
  return %2: tensor<3x5x7x9x11x4x10xf32>
}

//===----------------------------------------------------------------------===//
// Erf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erf
func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erf %arg0 : tensor<2x3xf32>
  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Erfc
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erfc
func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Einsum.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @einsum
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
  // CHECK:  mhlo.einsum
  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  return %0: tensor<2x4xf32>
}

// CHECK-LABEL: func @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
  // CHECK:  mhlo.unary_einsum
  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
  return %0: tensor<2x2xf32>
}

//===----------------------------------------------------------------------===//
// FloorDiv and FloorMod.
//===----------------------------------------------------------------------===//

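// For signed integers, tf.FloorDiv cannot map onto a single HLO op: HLO's
// divide truncates toward zero, while TensorFlow's floor division rounds
// toward negative infinity. The lowering below computes the truncating
// divide and then subtracts one whenever the remainder is nonzero and the
// operands' signs differ. For example, floor(-7 / 2): the truncating divide
// gives -3 with remainder -1; the signs differ, so the result is -3 - 1 = -4.
// Floating-point inputs simply divide and take mhlo.floor of the quotient.
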
// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]]
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_f32
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-NEXT:  %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
  // CHECK-NEXT:  %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]])
  // CHECK-NEXT:  return %[[FLOOR]] : tensor<2xf32>
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0: tensor<2xf32>
}

// CHECK-LABEL: func @floordiv_bf16
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  chlo.broadcast_divide
  // CHECK-NEXT:  mhlo.floor
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  return
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
  return %0: tensor<2xbf16>
}

// CHECK-LABEL: func @floordiv_f16_broadcast
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
  // CHECK-NEXT:  chlo.broadcast_divide
  // CHECK-NEXT:  mhlo.floor
  // CHECK-NEXT:  return
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  return %0: tensor<2x3xf16>
}

// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"}
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NOT: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0: tensor<*xf32>
}

// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

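// tf.FloorMod follows the same pattern: take the truncating remainder, then
// add the divisor whenever the remainder is nonzero and its sign differs
// from the divisor's. For example, floormod(-7, 2): the truncating remainder
// is -1, whose sign differs from 2, so the result is -1 + 2 = 1.
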
// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_broadcast_denominator
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK-NOT: tf.FloorMod
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// OnesLike
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @ones_like
// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x?xf32>)
func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 1.0{{.*}}}
  // CHECK: return %[[RES]]
  %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

//===----------------------------------------------------------------------===//
// ZerosLike
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @zeros_like
// CHECK-SAME:  (%[[ARG:.*]]: tensor<2x?xf32>)
func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 0.0{{.*}}}
  // CHECK: return %[[RES]]
  %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

//===----------------------------------------------------------------------===//
// BroadcastTo.
//===----------------------------------------------------------------------===//

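// tf.BroadcastTo lowers to mhlo.dynamic_broadcast_in_dim with the requested
// shape as the shape operand. The broadcast_dimensions attribute maps each
// operand dimension to the matching trailing result dimension: dimension 3
// for the rank-1-to-rank-4 case below, and the empty list for a scalar.
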
// CHECK-LABEL: func @broadcast_to
func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
  %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>

  // CHECK: [[CST:%.+]] = mhlo.constant
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CST]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
  return %0 : tensor<16x16x16x16xf32>
}

// CHECK-LABEL: func @broadcast_scalar_to_unranked
// CHECK: (%[[ARG0:.*]]: tensor<f32>, %[[SHAPE:.*]]: tensor<?xi32>)
func @broadcast_scalar_to_unranked(%arg0: tensor<f32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]])
  // CHECK-SAME: {broadcast_dimensions = dense<> : tensor<0xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<f32>, tensor<?xi32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Complex op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @complex
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
  // CHECK: chlo.broadcast_complex
  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
  return %1 : tensor<3xcomplex<f32>>
}

// CHECK-LABEL: func @imag
func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.imag"
  %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

// CHECK-LABEL: func @real
func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.real"
  %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

//===----------------------------------------------------------------------===//
// Concat op legalizations.
//===----------------------------------------------------------------------===//

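// tf.ConcatV2 legalizes to mhlo.concatenate only when the axis operand is a
// constant; a negative axis is first normalized by the operand rank (so
// axis -2 on rank-2 operands becomes dimension 0 below). A non-constant axis
// or unranked operands keep the op in the TF dialect.
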
// CHECK-LABEL: func @concat_v2
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_neg_axis
func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>

  %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_1d_axis
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>

  %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_non_const_axis
func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> {
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_unranked
func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Pad op legalizations.
//===----------------------------------------------------------------------===//

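// tf.PadV2 with a constant paddings operand maps onto mhlo.pad: column 0 of
// the paddings matrix becomes edge_padding_low, column 1 becomes
// edge_padding_high, and interior_padding is always zero. When the paddings
// operand is not a constant, the lowering instead transposes and slices it
// into per-dimension low/high operands for mhlo.dynamic_pad.
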
// CHECK-LABEL: func @padv2_1D
func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> {
  %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>,
  // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>,
  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32>
  return %1 : tensor<6xf32>
}

// CHECK-LABEL: func @padv2_2D
func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

// CHECK-LABEL: func @padv2_i32_paddings
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

// CHECK-LABEL: func @padv2_dynamic
func @padv2_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi64>) -> tensor<?xf32> {
  // CHECK: "mhlo.transpose"({{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x2xi64>) -> tensor<2x1xi64>
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2x1xi64>) -> tensor<2xi64>
  // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: "mhlo.dynamic_pad"({{.*}}) : (tensor<?xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<?xf32>
  %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor<?xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<?xf32>
  return %1 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// Identity op legalizations.
//===----------------------------------------------------------------------===//

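// Identity, IdentityN, StopGradient, PreventGradient and CheckNumerics are
// all no-ops as far as HLO is concerned, so legalization simply replaces
// each result with the corresponding operand.
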
// CHECK-LABEL: func @identity
func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @identityN
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
  // CHECK-NEXT:  return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
  return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
}

// CHECK-LABEL: func @stopgradient
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @preventgradient
func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @checkNumerics
func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xf32>
  %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

//===----------------------------------------------------------------------===//
// InfeedDequeueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @infeed_dequeue_tuple
func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) {
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = "", layout = [{{\[\[1, 3, 2, 0], \[1, 2, 0]]}}, unit]} : (!mhlo.token) -> tuple<tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>, !mhlo.token>
// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>, !mhlo.token>) -> tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>
// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>) -> tensor<1x8x4x4xi32>
// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>) -> tensor<1x100x1xf32>
// CHECK: return [[RES_1]], [[RES_2]]
  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>)
  return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32>
}

// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error
func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) {
  // We expect legalization to fail for dynamic shapes:
  // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}}
  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>)
  return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32>
}

// The following op sharding is used:
// Proto debug string:
//   type: TUPLE
//   tuple_shardings {
//     type: MAXIMAL
//     tile_assignment_dimensions: 1
//     tile_assignment_devices: 0
//   }
// Serialized string:
//   "\08\02*\08\08\01\1A\01\01\22\01\00"

// CHECK-LABEL: infeed_dequeue_tuple_sharding
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
  // CHECK: "mhlo.infeed"
  // An additional sharding is added at the end to account for the token result.
1149  // Proto debug string:
1150  //   type: TUPLE
1151  //   tuple_shardings {
1152  //     type: MAXIMAL
1153  //     tile_assignment_dimensions: 1
1154  //     tile_assignment_devices: 0
1155  //   }
1156  //   tuple_shardings {
1157  //     type: MAXIMAL
1158  //     tile_assignment_dimensions: 1
1159  //     tile_assignment_devices: 0
1160  //   }
1161  // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
1162  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
1163  return %0 : tensor<8xi32>
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// Nullary op legalizations.
1168//===----------------------------------------------------------------------===//
1169
1170// CHECK-LABEL: @const
1171func @const() -> tensor<2xi32> {
1172  // CHECK: mhlo.constant dense<0> : tensor<2xi32>
1173  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
1174  return %0: tensor<2xi32>
1175}
1176
1177// CHECK-LABEL: @const_dynamic_output
1178func @const_dynamic_output() -> tensor<*xi32> {
1179  // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32>
1180  // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
1181  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
1182  // CHECK: return [[CAST]]
1183  return %0: tensor<*xi32>
1184}
1185
1186// CHECK-LABEL: @opaque_const
1187func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> {
1188  // CHECK-NOT: mhlo.constant
1189  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
1190  return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
1191}
1192
1193//===----------------------------------------------------------------------===//
1194// Matmul op legalizations.
1195//===----------------------------------------------------------------------===//
1196
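// tf.MatMul legalizes to mhlo.dot. The dot op carries no transpose
// attributes, so transpose_a/transpose_b materialize as explicit
// mhlo.transpose ops on the corresponding operands first.
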
// CHECK-LABEL: matmul_notranspose
// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>)
func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> {
  // CHECK: "mhlo.dot"(%[[A]], %[[B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// CHECK-LABEL: matmul_transpose_b
// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>)
func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// CHECK-LABEL: matmul_transpose_both
// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>)
func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
  // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// Verify that MatMul with ranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_ranked
func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32>

  return %0 : tensor<?x?xf32>
}

// Verify that MatMul with unranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_unranked
func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

  return %0 : tensor<*xf32>
}

// Verify SparseMatMul is legalized to dot.
// CHECK-LABEL: test_sparse_mat_mul
func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be transposed and the other does not.
//
// CHECK-LABEL:   @test_sparse_mat_mul_with_transpose
// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME:      %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME:      -> tensor<3x5xf32>
// CHECK:           %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
// CHECK-SAME:        permutation = dense<[1, 0]>
// CHECK-SAME:        -> tensor<4x5xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME:        -> tensor<3x5xf32>
// CHECK:           return %[[RESULT]]
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be cast and the other does not.
//
// CHECK-LABEL:   @test_sparse_mat_mul_with_cast
// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME:      %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME:      -> tensor<3x5xf32>
// CHECK:           %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
// CHECK-SAME:        -> tensor<4x5xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME:        -> tensor<3x5xf32>
// CHECK:           return %[[RESULT]]
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op legalizations.
//===----------------------------------------------------------------------===//

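// tf.MatrixBandPart keeps the entries inside the requested band and zeroes
// the rest: with offset = col - row computed from two iotas, an element
// survives when -num_lower <= offset <= num_upper. A negative num_lower or
// num_upper is first replaced by the matching dimension size, which keeps
// everything on that side of the diagonal.
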
// CHECK-LABEL: matrix_band_part
// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
  // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
  // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>

  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
  // CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>

  // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]]
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
  return %0 : tensor<64x64xbf16>
}

// CHECK-LABEL: matrix_band_part_2
// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>

  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>

  // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]]
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
  return %0 : tensor<12x24x48xbf16>
}

// CHECK-LABEL: matrix_band_part_3
// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
  // CHECK: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
  return %0 : tensor<*xbf16>
}

// CHECK-LABEL: matrix_band_part_4
// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> {
  // This one should lower.
  // CHECK-NOT: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16>
  return %0 : tensor<24x48xbf16>
}

//===----------------------------------------------------------------------===//
// MaxPool op legalizations.
//===----------------------------------------------------------------------===//

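// tf.MaxPool and tf.MaxPool3D lower to mhlo.reduce_window with a maximum
// reduction whose init value is the lowest value of the element type
// (INT32_MIN and -infinity below). SAME padding is resolved statically into
// the window's padding attribute, while EXPLICIT padding is not supported
// yet and leaves the op untouched.
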
// CHECK-LABEL: maxpool_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

// CHECK-LABEL: maxpool_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
  return %0 : tensor<2x4x7x7xi32>
}

// CHECK-LABEL: maxpool_3d_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
  return %0 : tensor<2x8x3x5x7xf32>
}

// CHECK-LABEL: maxpool_3d_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
  return %0 : tensor<2x8x4x7x7xf32>
}

// CHECK-LABEL: maxpool_explicit_padding
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: tf.MaxPool
  // TODO(b/165938852): need to support explicit padding in max_pool.

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

//===----------------------------------------------------------------------===//
// MaxPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

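// The max-pool gradients lower to mhlo.select_and_scatter: the select region
// compares with GE to identify each window's maximum, and the scatter region
// adds the incoming gradient at the selected positions.
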
// CHECK-LABEL: @max_pool_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: },  {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
  return %result : tensor<10x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: },  {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
  return %result : tensor<10x8x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_grad_same
func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
  return %result : tensor<2x13x25x7xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_same
func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
  return %result : tensor<2x8x13x25x7xf32>
}

//===----------------------------------------------------------------------===//
// OneHot op legalizations.
//===----------------------------------------------------------------------===//

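// tf.OneHot is lowered with an iota along the one-hot axis: broadcasting the
// indices and comparing them against the iota with EQ yields the one-hot
// mask, which then selects between the broadcast on_value and off_value. For
// example, indices [1, 0, 2] with depth 5 produce a 3x5 result whose only
// on_value entries are at (0, 1), (1, 0) and (2, 2).
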
// CHECK-LABEL: one_hot
func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> {
  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32>
  // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32>
  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1>
  // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
  // CHECK: return %[[RESULT]] : tensor<3x5xf32>
  %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
  %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32>
  return %result : tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.OutfeedEnqueueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @outfeed_enqueue_tuple
// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
// CHECK: [[TUPLE:%.*]] = "mhlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: "mhlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
  "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// Pack op legalizations.
//===----------------------------------------------------------------------===//

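// tf.Pack reshapes each operand to add a unit dimension at the pack axis and
// then concatenates the reshaped operands along that axis.
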
// CHECK-LABEL: func @pack
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

  %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}

//===----------------------------------------------------------------------===//
// PartitionedCall op legalization.
//===----------------------------------------------------------------------===//

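// tf.PartitionedCall lowers to a standard call when the callee's signature
// matches the call's operand and result types exactly; otherwise the op is
// left in the TF dialect (see the unhandled cases below).
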
// CHECK-LABEL: func @partitioned_call
func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_input
func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_in_out
func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call
func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // The argument types don't match the parameter types for the
  // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not
  // for a standard CallOp, so this op can't be lowered.
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call_2
func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

//===----------------------------------------------------------------------===//
// ReverseV2 op legalization.
//===----------------------------------------------------------------------===//

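// ReverseV2 lowers to mhlo.reverse: the constant axis operand becomes an i64
// dimensions attribute, and negative axes are normalized by adding the
// operand rank (axis -1 on a rank-2 tensor reverses dimension 1).
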
// CHECK-LABEL: @reverse_func_32
func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_64
func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_neg
func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>

  // CHECK: return [[VAL]] : tensor<5x5xi32>
  return %reversed : tensor<5x5xi32>
}

//===----------------------------------------------------------------------===//
// StatefulPartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @stateful_partitioned_call
// CHECK-SAME: [[ARG:%.+]]: tensor<i32>
func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>)
func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

//===----------------------------------------------------------------------===//
// Elu op legalizations.
//===----------------------------------------------------------------------===//

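// Elu(x) is x for x > 0 and exp(x) - 1 otherwise, so the lowering selects
// between the input and mhlo.exponential_minus_one on a greater-than-zero
// predicate.
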
// CHECK-LABEL: func @elu
func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1xf32>) -> tensor<1xf32>
  // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%arg0, %[[ZERO]]) {comparison_direction = "GT"}
  // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0)
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

// CHECK-LABEL: func @elu_unranked
func @elu_unranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%arg0, %[[ZERO]]) {comparison_direction = "GT"}
  // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0)
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0: tensor<?xf32>
}

// CHECK-LABEL: func @elu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]])
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]])
  // CHECK: return %[[RESULT]]
  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  return %2 : tensor<4x8xf32>
}

//===----------------------------------------------------------------------===//
// Relu op legalizations.
//===----------------------------------------------------------------------===//

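// Relu lowers to a broadcasted maximum against a zero constant; Relu6
// clamps the input between constants 0 and 6.
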
// CHECK-LABEL: func @relu
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @relu_unranked
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %0: tensor<?xi32>
}

// CHECK-LABEL: func @relu6
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @relu6_unranked
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %0: tensor<?xi32>
}

// CHECK-LABEL: func @relu_grad_unranked
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<?x?xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @relu_grad_unranked(%gradients: tensor<?x?xf32>, %features: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FEATURES]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
  %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}

// CHECK-LABEL: func @leaky_relu
func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
  // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4, 3]> : tensor<4xi64>} : (tensor<f32>) -> tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[BCASTALPHA]] : tensor<1x4x4x3xf32>
  // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : (tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32>
  %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32>
  return %0 : tensor<1x4x4x3xf32>
}

// CHECK-LABEL: func @leaky_relu_grad
func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} {
  // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32>
  // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4]> : tensor<3xi64>} : (tensor<f32>) -> tensor<1x4x4xf32>
  // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4xf32>
  // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[BCASTALPHA]] : tensor<1x4x4xf32>
  // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1>
  // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32>
  %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32>
  return %0 : tensor<1x4x4xf32>
}

// CHECK-LABEL: func @softsign
func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32>
  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32>
  // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32>
  %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32>
  return %0 : tensor<4x10xf32>
}

// CHECK-LABEL: func @softsign_unranked
func @softsign_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<*xf32>
  // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<*xf32>
  // CHECK-NEXT: return %[[DIV]] : tensor<*xf32>
  %0 = "tf.Softsign"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @softsign_grad
func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> {
  // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32>
  // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32>
  %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
  return %0 : tensor<4x10xf32>
}

//===----------------------------------------------------------------------===//
// Roll op legalizations.
//===----------------------------------------------------------------------===//

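// Roll is lowered by normalizing the shift into [0, axis_size) with two
// mhlo.remainder ops, concatenating the input with itself along the rolled
// axis, and dynamic-slicing out the original extent at offset
// axis_size - normalized_shift.
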
// CHECK-LABEL: func @Roll_0D
func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
  //      CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  //      CHECK: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32>
  //      CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32>
  //      CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) {dimension = 0 : i64}
  //      CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32>
  //      CHECK: "mhlo.dynamic-slice"(%[[CONCAT]], %[[OFFSET]])
  // CHECK-SAME:    {slice_sizes = dense<512> : tensor<1xi64>}
  // CHECK-SAME:    (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32>
  %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32>
  return %0 : tensor<512xi32>
}

//===----------------------------------------------------------------------===//
// Select op legalizations.
//===----------------------------------------------------------------------===//

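// tf.Select takes a predicate that is a scalar, matches the operand shapes,
// or is a vector matching the leading (batch) dimension. A batch predicate
// is expanded with mhlo.dynamic_broadcast_in_dim; dynamic shapes are guarded
// with shape.cstr_eq inside shape.assuming; unranked operands are not
// legalized.
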
// CHECK-LABEL: func @select_batch_static
func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1>
  // CHECK: "mhlo.select"(%[[BCAST]], %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  return %0: tensor<2x6x8xi32>
}

// CHECK-LABEL: func @select_batch_static_r1
func @select_batch_static_r1(%arg0: tensor<i1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  return %0: tensor<2x6x8xi32>
}

// CHECK-LABEL: func @select_batch_static_all_same
func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> {
  // CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32>
  return %0: tensor<2x6x8xi32>
}

// CHECK-LABEL: func @select_batch_dynamic_r1
func @select_batch_dynamic_r1(%arg0: tensor<?xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> {
  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?xi1> -> tensor<1xindex>
  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
  // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
  // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<?xindex>
  // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]]
  // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) {
  // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex>
  // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi1>, tensor<3xindex>) -> tensor<?x?x8xi1>
  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCAST]], %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32>
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  return %0: tensor<?x?x8xi32>
}

// CHECK-LABEL: func @select_batch_dynamic
func @select_batch_dynamic(%arg0: tensor<?x?x8xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> {
  // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?x8xi1> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex>
  // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]]
  // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) {
  // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32>
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32>
  return %0: tensor<?x?x8xi32>
}

// CHECK-LABEL: testSelectInvalidUnranked
func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
  return %0: tensor<*xf16>
}

// CHECK-LABEL: testSelectThenUnranked
func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
  return %0: tensor<*xf16>
}

// CHECK-LABEL: testSelectElseUnranked
func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
  // CHECK-NEXT: tf.Select
  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
  return %0: tensor<*xf16>
}

// CHECK-LABEL: func @selectv2_dynamic_ranked
func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
  // CHECK: chlo.broadcast_select
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
  return %0: tensor<2x?x8xi32>
}

// CHECK-LABEL: func @selectv2_unranked
func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: chlo.broadcast_select
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// Fast Fourier Transform op legalization.
//===----------------------------------------------------------------------===//

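// FFT and IFFT map directly onto mhlo.fft. For RFFT the constant fft_length
// operand becomes an attribute and the input is padded or sliced to match
// it; IRFFT first slices the input down to fft_length/2 + 1 coefficients.
// Dynamically shaped inputs are left as TF ops.
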
// CHECK-LABEL: func @fft_1D
func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>>
  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @ifft_1D
func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>>
  %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D
func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D_padded
func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D_sliced
func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x8xcomplex<f32>>
  return %0 : tensor<2x8xcomplex<f32>>
}

// CHECK-LABEL: func @irfft_1D
func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<5xf32> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>>
  %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<5xf32>
  return %0 : tensor<5xf32>
}

// CHECK-LABEL: fft_1D_dynamic
func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "tf.FFT"
  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: rfft_1D_dynamic
func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: "tf.RFFT"
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

//===----------------------------------------------------------------------===//
// Shape op legalization.
//===----------------------------------------------------------------------===//

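// tf.Shape becomes shape.shape_of followed by an index_cast to the requested
// integer element type.
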
// CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>

  // CHECK: return [[TENSOR]]
  return %0 : tensor<1xi32>
}

// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>

  // CHECK: return [[TENSOR]]
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<?xindex> to tensor<?xi32>
  %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>

  // CHECK: return [[TENSOR]]
  return %0 : tensor<?xi32>
}

//===----------------------------------------------------------------------===//
// Transpose op legalization.
//===----------------------------------------------------------------------===//

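// An identity permutation folds tf.Transpose away; any other constant
// permutation lowers to mhlo.transpose, for static, dynamic, and unranked
// operands alike.
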
// CHECK-LABEL: @transpose_noop
func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: return %arg0
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

// CHECK-LABEL: @transpose_2d
func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>
}

// CHECK-LABEL: @transpose_3d_int32
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32>
  return %0 : tensor<3x2x1xf32>
}

// CHECK-LABEL: @transpose_3d
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
  return %0 : tensor<3x2x1xf32>
}

// CHECK-LABEL: @transpose_dynamic_2d
func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32>
  return %0 : tensor<4x?xf32>
}

// CHECK-LABEL: @transpose_unranked_2d
func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Unary op legalizations.
//===----------------------------------------------------------------------===//

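// Elementwise unary ops generally map one-to-one onto their mhlo
// counterparts, regardless of whether the operand is static, dynamic, or
// unranked; ops without a direct counterpart (e.g. complex-to-real casts)
// stay as TF ops.
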
// CHECK-LABEL: @abs
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @abs_dynamic
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @abs_unranked
func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @acos
// CHLO-LABEL: @acos
func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<2xf32>
// CHLO:   %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"}
// CHLO:   %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00>
// CHLO:   %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0
// CHLO:   %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]]
// CHLO:   %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]])
// CHLO:   %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0
// CHLO:   %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]]
// CHLO:   %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]]
// CHLO:   %[[VAL_12:.*]] = mhlo.constant dense<3.14159274>
// CHLO:   %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]])
// CHLO:       return %[[VAL_13]] : tensor<2xf32>
  %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @acos_complex
// CHLO-LABEL: @acos_complex
func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHLO: tf.Acos
  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @acos_dynamic
// CHLO-LABEL: @acos_dynamic
func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<*xf32>
  // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be
  // lowered further on ranked tensors.  Unranked CHLO must be transformed to
  // ranked code before further lowering.
  // CHLO: "tf.Acos"
  %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

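// tf.Tan lowers to chlo.tan; with CHLO legalization enabled it is further
// expanded into sine divided by cosine.
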
// CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
// CHLO-LABEL: @tan
// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<2xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32>
  return %result : tensor<2xf32>
}

// CHECK-LABEL: @tan_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
// CHLO-LABEL: @tan_unranked
// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<*xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}

// CHECK-LABEL: func @cast_dynamic_i2f
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @cast_i2f
func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @cast_c2f
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK: tf.Cast
  %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @ceil
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @ceil_dynamic
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @ceil_unranked
func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @complex_abs
func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @cos
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @cos_dynamic
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @cos_unranked
func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @exp
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @expm1
func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @exp_dynamic
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @exp_unranked
func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @floor
func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @floor_dynamic
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @floor_unranked
func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @invert_op_unranked
func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK:  "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

// CHECK-LABEL: @is_finite
func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

// CHECK-LABEL: func @is_finite_dynamic
func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  return %0 : tensor<?xi1>
}

// CHECK-LABEL: func @is_finite_unranked
func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

// CHECK-LABEL: @log
func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @log_dynamic
func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @log_unranked
func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @log1p
func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @log1p_dynamic
func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @log1p_unranked
func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @not_op_unranked
func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
  // CHECK:  "mhlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
  %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

// CHECK-LABEL: @neg
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @neg_dynamic
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @neg_unranked
func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @sigmoid
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @sigmoid_complex
func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @sigmoid_unranked
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: mhlo.logistic
  %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

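// SigmoidGrad(y, dy) computes dy * y * (1 - y); the dynamically shaped case
// uses the broadcasting CHLO variants of multiply and subtract.
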
// CHECK-LABEL: @sigmoid_grad
func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32>
  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @sigmoid_grad_complex
func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @sigmoid_grad_dynamic
func @sigmoid_grad_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: @sin
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @sin_dynamic
func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @sin_unranked
func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @rsqrt
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @rsqrt_dynamic
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @rsqrt_unranked
func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @sqrt
func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @sqrt_dynamic
func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @sqrt_unranked
func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @tanh
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tanh_dynamic
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @tanh_unranked
func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @bitcast
func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @bitcast_dynamic
func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @bitcast_unranked
func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @bitcast_same_widths
func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @bitcast_smaller_input_width
func @bitcast_smaller_input_width(%arg0: tensor<2xi8>) -> tensor<2xi64> {
  // CHECK:  "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64>
  return %0 : tensor<2xi64>
}

// CHECK-LABEL: func @bitcast_smaller_output_width
func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2xf16> {
  // CHECK:  "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16>
  return %0 : tensor<2xf16>
}

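// A fully static tf.Reshape folds to mhlo.reshape. With dynamic operands it
// becomes chlo.dynamic_reshape, which CHLO legalization expands into
// mhlo.compute_reshape_shape (resolving a possible -1 dimension) followed by
// mhlo.dynamic_reshape; an unranked input additionally gets an
// mhlo.cstr_reshapable guard inside shape.assuming.
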
// CHECK-LABEL: reshape
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> {
  // CHECK:  "mhlo.reshape"
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}

// CHECK-LABEL: reshape_dynamic
func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "chlo.dynamic_reshape"
  // CHLO:  mhlo.compute_reshape_shape
  // CHLO:  mhlo.dynamic_reshape
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: reshape_unranked
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-SAME: %[[TARGET_SHAPE:.*]]: tensor<2xi32>
func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "chlo.dynamic_reshape"
  // CHLO:  shape.shape_of
  // CHLO:  shape.num_elements
  // CHLO:  mhlo.cstr_reshapable
  // CHLO:  assuming{{.*}}{
  // CHLO:   mhlo.compute_reshape_shape
  // CHLO:   mhlo.dynamic_reshape
  // CHLO:  }
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: squeeze
func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
  // CHECK: "mhlo.reshape"
  %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
  return %0 : tensor<1x10xf32>
}

// CHECK-LABEL: squeeze_dynamic
func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Squeeze"
  %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: expand_dims
func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
  // CHECK: "mhlo.reshape"
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
  return %0 : tensor<1x2xf32>
}

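// With a dynamic input, ExpandDims extracts each extent from shape.shape_of,
// inserts a constant 1 at the axis position, and rebuilds the target shape
// with tensor.from_elements for mhlo.dynamic_reshape.
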
// CHECK-LABEL: expand_dims_dynamic
func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> {
  %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>)

  // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0
  // CHECK-DAG: %[[CST0:.+]] = constant 0
  // CHECK-DAG: %[[CST1:.+]] = constant 1
  // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]]
  // CHECK-DAG: %[[CST1_0:.+]] = constant 1
  // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]]
  // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]]
  // CHECK-DAG: %[[RESHAPE:.+]] = "mhlo.dynamic_reshape"(%arg0, %[[TOEXTENTS]])
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32>

  // CHECK: return %[[RESHAPE]]
  return %0 : tensor<?x1x?xf32>
}

// CHECK-LABEL: expand_dynamic_dims_rank1_axis
func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> {
  %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>

  // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0
  // CHECK-DAG: %[[CST0:.+]] = constant 0
  // CHECK-DAG: %[[CST1:.+]] = constant 1
  // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]]
  // CHECK-DAG: %[[CST1_0:.+]] = constant 1
  // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]]
  // CHECK-DAG: %[[CST2:.+]] = constant 2
  // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]]
  // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]]
  // CHECK-DAG: %[[RESHAPE:.+]] = "mhlo.dynamic_reshape"(%arg0, %[[TOEXTENTS]])
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32>

  // CHECK: return %[[RESHAPE]]
  return %0 : tensor<?x1x?x4xf32>
}

// CHECK-LABEL: func @sign
// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
  // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]])
  // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32>
  %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>)
  return %0 : tensor<1x2x3x4xf32>
}

// CHECK-LABEL: func @sign_dynamic
func @sign_dynamic(%arg0: tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> {
  // CHECK: "mhlo.sign"(%arg0) : (tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
  // CHECK: "mhlo.compare"({{.*}}) {comparison_direction = "NE"} : (tensor<?x2x3x?xf32>, tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xi1>
  // CHECK: shape.shape_of %arg0 : tensor<?x2x3x?xf32> -> tensor<4xindex>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xindex>) -> tensor<?x2x3x?xf32>
  // CHECK: "mhlo.select"({{.*}}) : (tensor<?x2x3x?xi1>, tensor<?x2x3x?xf32>, tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32>
  // CHECK: return {{.*}} : tensor<?x2x3x?xf32>
  %0 = "tf.Sign"(%arg0) : (tensor<?x2x3x?xf32>) -> (tensor<?x2x3x?xf32>)
  return %0 : tensor<?x2x3x?xf32>
}

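// tf.Slice with constant starts lowers to mhlo.dynamic-slice: each start
// index is sliced out of the starts constant, reshaped to a scalar operand,
// and the sizes become the slice_sizes attribute. A size of -1 means
// "everything from the start index to the end of that dimension".
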
// CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-SAME: (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) :
  // CHECK-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]])
  // CHECK-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
  // CHECK-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
  // CHECK: return %[[RESULT]] : tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: slice_i32_consts
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi32>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64>
  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

2703// CHECK-LABEL: slice_constant_start_negative_one_size
2704func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
2705  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64>
2706  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64>
2707  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
2708  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
2709  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
2710  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
2711  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
2712  // CHECK: %[[RESULT:.*]] =  "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32>
2713  // CHECK: return %[[RESULT]] : tensor<3xi32>
2714  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
2715  %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
2716  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
2717  return %0 : tensor<3xi32>
2718}
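
// Illustrative arithmetic (not checked by FileCheck): a size of -1 in
// tf.Slice means "all remaining elements", and with a constant start it
// folds statically:
//   slice_size = dim_size - start = 4 - 1 = 3
// which is why the lowering above carries slice_sizes = dense<3>.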

// CHECK-LABEL: slice_constant_start_dynamic_shape
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64>
  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"
  // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]])
  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
  // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// CHECK-LABEL: slice_variable_start
func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// CHECK-LABEL: slice_mhlo_sizes
func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> {
  // CHECK-NOT: "tf.Slice"
  %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
  %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32>
  return %1 : tensor<1x512x4xf32>
}

// CHECK-LABEL: slice_variable_start_negative_one_size
func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[RESULT:.*]] = "tf.Slice"
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}
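
// Note (illustrative): this case stays as tf.Slice because the -1 size can
// only be folded to `dim_size - start` when the start is constant; here the
// start %arg1 is a runtime value, and mhlo.dynamic-slice requires its
// slice_sizes attribute to be static.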

// CHECK-LABEL: slice_real_dynamic_slice
func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<*xi32> {
  // CHECK: tensor.extract {{.*}} : tensor<1xi64>
  // CHECK: tensor.extract {{.*}} : tensor<1xi64>
  // CHECK: index_cast {{.*}} : index to i64
  // CHECK: cmpi eq, {{.*}} : i64
  // CHECK: addi {{.*}} : i64
  // CHECK: tensor.dim {{.*}} : tensor<4xi32>
  // CHECK: index_cast {{.*}} : index to i64
  // CHECK: select {{.*}} : i64
  // CHECK: index_cast {{.*}} : i64 to index
  // CHECK: index_cast {{.*}} : i64 to index
  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<1xindex>
  %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// StridedSlice op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: simple_strided_slice
func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: mhlo.slice
  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
  // CHECK-DAG-SAME: strides = dense<[1, 3]>
  // CHECK-SAME: -> tensor<3x2xf32>

  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %output : tensor<3x2xf32>
}
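
// Illustrative arithmetic (not checked by FileCheck): each static output
// extent is ceil((end - begin) / stride), so here
//   dim 0: ceil((3 - 0) / 1) = 3
//   dim 1: ceil((7 - 1) / 3) = 2
// which yields the tensor<3x2xf32> result type above.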

// CHECK-LABEL: dynamic_strided_slice
func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: "tf.StridedSlice"
  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
  return %output : tensor<?x2xf32>
}

// CHECK-LABEL: strided_slice_negative_indices
func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}

  // CHECK: mhlo.slice
  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
  // CHECK-DAG-SAME: strides = dense<[1, 3]>
  // CHECK-SAME: -> tensor<3x2xf32>

  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %output : tensor<3x2xf32>
}
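
// Illustrative arithmetic (not checked by FileCheck): negative begin/end
// indices first wrap by the dimension size; on dim 0 (size 4):
//   begin = -1 + 4 = 3, end = -4 + 4 = 0, stride = -1
// A negative stride is then rewritten as a reverse of that dimension
// followed by a forward slice: after reversing, index i maps to
// size - 1 - i, so begin 3 becomes start 0 and the slice runs [0, 3) with
// stride 1, matching the start/limit indices checked above.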

// CHECK-LABEL: dynamic_strided_slice_negative_indices
func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: tf.StridedSlice
  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
  return %output : tensor<?x2xf32>
}

// CHECK-LABEL: strided_slice_range_clamping
func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> {
  %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: mhlo.slice
  // CHECK-DAG-SAME: start_indices = dense<[0, 0]>
  // CHECK-DAG-SAME: limit_indices = dense<[1, 8]>
  // CHECK-DAG-SAME: strides = dense<[1, 3]>
  // CHECK-SAME: -> tensor<1x3xf32>
  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
  return %output : tensor<1x3xf32>
}

// CHECK-LABEL: strided_slice_empty
func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> {
  %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: mhlo.constant dense<> : tensor<0xf32>
  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
      : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32>
  return %output : tensor<0xf32>
}
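
// Illustrative arithmetic (not checked by FileCheck): after wrapping,
// begin = -4 + 4 = 0 and end = -1 + 4 = 3, but the stride is -1, which
// walks downward from 0 and immediately leaves the half-open range. No
// element is selected, so the op folds to the empty constant above.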

// CHECK-LABEL: strided_slice_begin_end_mask
// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32>
func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {

  // For StridedSlice
  // Dim #:        0,   1,    2
  // Input shape: [4, 128, 1024]
  // Begin:        1,   4,   -3
  // End:          8,  65,   42
  // Stride:       1,   4,   -1
  // Begin mask:   1,   0,    0  (= 1)
  // End mask:     0,   0,    1  (= 4)

  // So result shape:
  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
  // Dim #1: 4 to 65 stride 4: so 16
  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
  // result shape: [4, 16, 1022]

  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)

  // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]])

  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]])
  // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]>
  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]>
  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
  // CHECK-SAME: -> tensor<4x16x1022xf32>

  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32>

  // CHECK: "mhlo.reshape"(%[[SLICE]])
  // CHECK-SAME: -> tensor<4x16x1022xf32>

  return
}

// CHECK-LABEL: strided_slice_shrink_axis_mask
// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32>
func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) {

  // For StridedSlice
  // Dim #:            0,   1,    2
  // Input shape:     [4, 128, 1024]
  // Begin:            1,   4,   -3
  // End:              8,  65,   42
  // Stride:           1,   4,   -1
  // Begin mask:       1,   0,    0  (= 1)
  // End mask:         0,   0,    1  (= 4)
  // Shrink axis mask: 1,   0,    1  (= 5)

  // So result shape:
  // Dim #0: begin mask (1) -> begin = 0; shrink axis, take value at [0]
  // Dim #1: 4 to 65 stride 4: so 16
  // Dim #2: shrink axis, take value at [-3]
  // result shape: [16]

  // As the output shape of the StridedSlice differs, a reshape will follow.

  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)

  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
  // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]>
  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]>
  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
  // CHECK-SAME: -> tensor<1x16x1xf32>

  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32>

  // CHECK: "mhlo.reshape"(%[[SLICE]])
  // CHECK-SAME: -> tensor<16xf32>

  return
}

// CHECK-LABEL: strided_slice_ellipsis_mask
// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
  // For StridedSlice input[1, ..., 8:, :10, 2:6:2]
  // The ellipsis mask is applied to dims #1 and #2, i.e., we get the
  // canonicalized slice input[1, :, :, 8:, :10, 2:6:2]

  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
  // As the output shape of the StridedSlice differs, a reshape will follow.

  %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)

  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32>

  // CHECK: "mhlo.reshape"(%[[SLICE]])
  // CHECK-SAME: -> tensor<4x8x8x10x2xf32>

  return
}
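
// Illustrative arithmetic (not checked by FileCheck): the sparse spec has
// 5 entries and the input has rank 6, so the ellipsis (ellipsis_mask = 2,
// i.e. sparse index 1) expands to 6 - (5 - 1) = 2 full dimensions, #1 and
// #2. shrink_axis_mask = 1 then drops the leading unit dimension of the
// tensor<1x4x8x8x10x2xf32> slice, giving the tensor<4x8x8x10x2xf32> result.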

// CHECK-LABEL: strided_slice_new_axis_mask
// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // The new axis mask has bits set at sparse indices 1 and 6, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e., we
  // get the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // This is then reshaped to add the new axes.

  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
  // As the output shape of the StridedSlice differs, a reshape will follow
  // to reflect the new axes added.

  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)

  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32>

  // CHECK: "mhlo.reshape"(%[[SLICE]])
  // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32>

  return
}
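
// Illustrative arithmetic (not checked by FileCheck): the result rank is
//   6 (canonical slice) + 2 (new axes) - 1 (shrink axis) = 7,
// so the final reshape drops the shrunk leading dimension and inserts the
// two unit dimensions: tensor<1x4x8x8x10x2xf32> -> tensor<1x4x8x8x10x2x1xf32>.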

// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
  // StridedSlice gets input[8:10], which is the same as input[8:10, ...]
  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
  %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
  %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]])
  // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
  // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
  // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
  // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
  return %0 : tensor<2x16x2xf32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end
func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
  // In this case, the `begin` and `end` inputs are unknown at compile time,
  // so the lowering must slice the `begin` vector itself and use the
  // resulting scalars as the start indices of an HLO dynamic slice.
  %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  // CHECK: %[[A:.*]] = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
  // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]])
  // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
  // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
  // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
  // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
  // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice"
  // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
  // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32>
  // CHECK-NEXT: %[[FINAL:.*]] = "mhlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32>
  %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}
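
// Illustrative sketch (not checked by FileCheck): the compare/add/select
// sequence above wraps a possibly negative dynamic start index, roughly
//   %idx = (%idx < 0) ? %idx + 32 : %idx   // 32 = size of the sliced dim
// before passing it as the start operand of mhlo.dynamic-slice.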

// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1
func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Dynamic stride: when `begin` and `end` inputs are unknown at compile time,
  // `strides` must be known.
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2
func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Invalid stride (not equal to 1): when the `begin` and `end` inputs are
  // unknown at compile time, `strides` must be statically known and all 1s.
  %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count
func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> {
  %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
  // When begin/end are dynamic, the number of output elements must be equal to
  // the number of input elements sliced.
  // CHECK: tf.StridedSlice
  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32>
  return %0 : tensor<6x10xf32>
}
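
// Illustrative arithmetic (not checked by FileCheck): the requested output
// tensor<6x10xf32> holds 60 elements, but a slice of a 4x8 input can hold
// at most 32, so the element counts can never match and the op is left
// unlegalized.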

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask
func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Begin mask: When `begin` and `end` inputs are unknown at compile time, we
  // can't support a begin mask.
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask
func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // End mask: When `begin` and `end` inputs are unknown at compile time, we
  // can't support an end mask.
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // New axis mask: When `begin` and `end` inputs are unknown at compile time,
  // we can't support a new_axis mask.
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // This ellipsis mask is not supported because it does not refer to the last
  // dimension.
  // [0, 1, 0] = 2
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // This ellipsis mask is supported because it refers to the last dimension.
  // [1, 0, 0] = 4
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: mhlo.dynamic-slice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // This shrink_axis mask is supported because it refers to a major dimension.
  // [1, 1, 1] = 7
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: mhlo.dynamic-slice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}

// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // This shrink_axis mask is unsupported because it does not refer to a major
  // dimension.
  // [0, 1, 0] = 2
  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: tf.StridedSlice
  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
  return %result : tensor<1x97xi32>
}


//===----------------------------------------------------------------------===//
// Reduction op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @mean
func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}
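
// Illustrative arithmetic (not checked by FileCheck): tf.Mean lowers to a
// sum reduction followed by a division by the reduced element count, i.e.
//   mean = reduce_add(x, dim 1) / 8
// here, with the divisor presumably materialized as a scalar f32 constant;
// the f16 computation is widened to f32 around the reduction.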

// CHECK-LABEL: func @mean_scalar_dim
func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // Verify that a tf.Mean op with a scalar dimension attribute is lowered
  // successfully.

  // CHECK-NOT: tf.Mean
  %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @mean_dynamic
func @mean_dynamic(%arg0: tensor<?x?xf16>) -> tensor<?x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<?x?xf16>) -> tensor<?x?xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?xf16> -> tensor<2xindex>
  // CHECK: %[[C1_1:.*]] = constant 1 : index
  // CHECK: %[[C1_2:.*]] = constant 1 : index
  // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex>
  // CHECK: %[[MUL:.*]] = muli %[[C1_1]], %[[REDUCED_DIM]] : index
  // CHECK: %[[INDEX_CAST:.*]] = index_cast %[[MUL]] : index to i64
  // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor<1xi64>
  // CHECK: %[[SCALAR_TENSOR:.*]] = "mhlo.reshape"(%[[TENSOR]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[CONVERT:.*]] = "mhlo.convert"(%[[SCALAR_TENSOR]]) : (tensor<i64>) -> tensor<f32>
  // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
  // CHECK: %[[MEAN_CONVERTED:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<?xf32>) -> tensor<?xf16>
  // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor<?xf16> -> tensor<1xindex>
  // CHECK: %[[C1:.*]] = constant 1 : index
  // CHECK: %[[C0:.*]] = constant 0 : index
  // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex>
  // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[MEAN_CONVERTED]], %[[RESULT_SHAPE]]) : (tensor<?xf16>, tensor<2xindex>) -> tensor<?x1xf16>
  // CHECK: return %[[RESULT]] : tensor<?x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<?x?xf16>, tensor<1xi64>) -> tensor<?x1xf16>
  return %0 : tensor<?x1xf16>
}

// CHECK-LABEL: func @sum
func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @sum_dynamic
func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @max
func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @max_qint
// Regression test to ensure we don't crash getting the initial value for
// tf.Max when using quantized integer types.
func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
  return %0 : tensor<4x1x!tf_type.qint8>
}

// CHECK-LABEL: func @max_dynamic
func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @min
func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.minimum %[[ARGA]], %[[ARGB]] : tensor<f16>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @min_qint
// Regression test to ensure we don't crash getting the initial value for
// tf.Min when using quantized integer types.
func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
  return %0 : tensor<4x1x!tf_type.qint8>
}

// CHECK-LABEL: func @prod
func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.multiply %[[ARGA]], %[[ARGB]] : tensor<f32>
  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
  return %0 : tensor<4x1xf16>
}

// CHECK-LABEL: func @prod_qint
// Regression test to ensure we don't crash getting the initial value for
// tf.Prod when using quantized integer types.
func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
  return %0 : tensor<4x1x!tf_type.qint8>
}

// CHECK-LABEL: @all
func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( {
  // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
  // CHECK:  %[[AND:.*]] = mhlo.and %[[ARGA]], %[[ARGB]] : tensor<i1>
  // CHECK:  "mhlo.return"(%[[AND]]) : (tensor<i1>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
  %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
  return %0 : tensor<4xi1>
}

// CHECK-LABEL: @all_keep_dim
func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
  // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
  return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @all_dynamic
func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
  // CHECK: "mhlo.reduce"(%[[ARG]]
  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
  return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @any
func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> {
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( {
  // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
  // CHECK:  %[[OR:.*]] = mhlo.or %[[ARGA]], %[[ARGB]] : tensor<i1>
  // CHECK:  "mhlo.return"(%[[OR]]) : (tensor<i1>) -> ()
  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
  %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
  return %0 : tensor<4xi1>
}

// CHECK-LABEL: @any_keep_dim
func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
  // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
  return %0 : tensor<4x1xi1>
}

// CHECK-LABEL: @any_dynamic
func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
  // CHECK: "mhlo.reduce"(%[[ARG]]
  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
  return %0 : tensor<4x1xi1>
}

//===----------------------------------------------------------------------===//
// Tile op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @tile_by_reshape
func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> {
  // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32>
  // CHECK: return %[[RESULT]] : tensor<28x24xf32>
  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32>
  return %0 : tensor<28x24xf32>
}
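
// Illustrative arithmetic (not checked by FileCheck): tiling [4, 8] by
// multiples [7, 3] broadcasts each source dimension next to its multiple,
//   (4, 8) -> (7, 4, 3, 8)
// and the reshape collapses the pairs: 7 * 4 = 28 and 3 * 8 = 24, giving
// tensor<28x24xf32>.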

// CHECK-LABEL: func @tile_just_broadcast
func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> {
  // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32>
  // CHECK: return %[[RESULT]] : tensor<7x3xf32>
  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32>
  return %0 : tensor<7x3xf32>
}

// CHECK-LABEL: func @tile_dynamic_shape
func @tile_dynamic_shape(%arg0: tensor<?x8xf32>) -> tensor<?x24xf32> {
  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32>
  // CHECK: tensor.dim {{.*}} : tensor<?x8xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<4xindex>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<?x8xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
  // CHECK: muli {{.*}} : index
  // CHECK: tensor.from_elements {{.*}} : tensor<2xindex>
  // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x?x?xf32>, tensor<2xindex>) -> tensor<?x24xf32>
  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<?x8xf32>, tensor<2xi32>) -> tensor<?x24xf32>
  return %0 : tensor<?x24xf32>
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0
func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>):
  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK:  %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
  // CHECK: %[[COMPARE_EQ:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK:  %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]]
  // CHECK:  %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK:  %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT3]]) : (tensor<i64>, tensor<i32>) -> ()
  // CHECK: return %[[REDUCE]]#1 : tensor<7xi32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32>
  return %0 : tensor<7xi32>
}
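
// Note (illustrative, not checked by FileCheck): the EQ compare plus
// mhlo.minimum in the reduction body is the tie-breaker; when two compared
// values are equal, the smaller index is kept, so ties resolve to the
// lowest index.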

// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1
func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: return %[[REDUCE]]#1 : tensor<3xi64>
  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}

// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1
func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: return %[[REDUCE]]#1 : tensor<3xi64>
  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor<i32>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}

// CHECK-LABEL: func @argmax_dynamic_shape_input_output
func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: return %[[REDUCE]]#1 : tensor<?xi32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// CHECK-LABEL: func @argmax_dynamic_shape_input
func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: return %[[REDUCE]]#1 : tensor<3xi32>
  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0
func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor<i64>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32>
  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
  // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>):
  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "LE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK:  %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
  // CHECK: %[[COMPARE_EQ:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK:  %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]]
  // CHECK:  %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK:  %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
  // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT3]]) : (tensor<i64>, tensor<i32>) -> ()
  // CHECK: return %[[REDUCE]]#1 : tensor<7xi32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32>
  return %0 : tensor<7xi32>
}

//===----------------------------------------------------------------------===//
// Random op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @rng_uniform
func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
  // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32>
  %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
  // CHECK: return %[[F32]]
  return %0 : tensor<12x?x64xf32>
}

// CHECK-LABEL: func @rng_std_normal
func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
  // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32>
  %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
  // CHECK: return %[[F32]]
  return %0 : tensor<12x?x64xf32>
}

//===----------------------------------------------------------------------===//
// Range op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @range
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
  %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"
  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
  return %3 : tensor<5xf32>
}
3610
3611// CHECK-LABEL: func @range_dynamic
3612// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
3613func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> {
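  // The dynamic lowering first computes the result length as
  // ceil(abs(limit - start) / delta), reshapes it into a shape operand for
  // mhlo.dynamic_iota, and then materializes iota * delta + start.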
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]])
  // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]])
  // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2)
  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
  // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]])
  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]])
  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]])
  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
  // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[ADD]]
  return %2 : tensor<?xf32>
}

// CHECK-LABEL: func @range_int_dynamic
// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32>
func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> {
  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]])
  // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]])
  // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2)
  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
  // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]])
  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]])
  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]])
  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
  // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>

  // CHECK: return [[ADD]]
  return %2 : tensor<?xi32>
}

// CHECK-LABEL: func @linspace_static
// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
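  // The static lowering computes step = (stop - start) / (num - 1) and then
  // emits linspace = iota * step + start, as the checks below show.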
  // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4>
  // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM]])
  // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
  // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]]
  // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64}
  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK: return [[LINSPACE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: func @linspace_dynamic
func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> {
  // CHECK: "tf.LinSpace"
  %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @linspace_invalid_num
func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
  // CHECK: mhlo.constant dense<> : tensor<0xi32>
  // CHECK: "tf.LinSpace"
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
  return %1 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// LegacyCall op legalizations.
//===----------------------------------------------------------------------===//

func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
  return %arg0: tensor<10x2xf32>
}

// CHECK-LABEL: testSimpleLegacyCallOp
func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
  // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32>
  %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32>
  // CHECK: return %[[RESULT]]
  return %0: tensor<10x2xf32>
}

func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
  return %arg0: tensor<10x2xf32>
}

// CHECK-LABEL: testMultiInputLegacyCallOp
func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
  // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
  %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
  // CHECK: return %[[RESULT]]
  return %0: tensor<10x2xf32>
}

//===----------------------------------------------------------------------===//
// Conv op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: conv_simple
func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
  // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]}
  // CHECK-SAME: batch_group_count = 1
  // CHECK-SAME: feature_group_count = 2

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
  return %0 : tensor<256x8x7x16xf32>
}

// CHECK-LABEL: conv3d_simple
func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f]
  // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]}
  // CHECK-SAME: batch_group_count = 1
  // CHECK-SAME: feature_group_count = 2

  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32>
  return %0 : tensor<256x7x6x5x16xf32>
}

// CHECK-LABEL: depthwiseconv_simple
func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
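  // Depthwise convolution becomes a grouped mhlo.convolution: the HxWxCxM
  // filter (2x2x3x3) is reshaped to HxWx1x(C*M) (2x2x1x9) and
  // feature_group_count is set to the input channel count C = 3.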
  // CHECK: %[[RESHAPED_FILTER:.*]] = "mhlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
  // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]])
  // CHECK-SAME: feature_group_count = 3
  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
    data_format = "NHWC",
    device = "",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1]
  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
  return %0 : tensor<2x3x4x9xf32>
}

// CHECK-LABEL: conv_valid_padding
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
  // CHECK: mhlo.convolution(%arg0, %arg1)

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32>
  return %0 : tensor<1x2x3x1xf32>
}

// CHECK-LABEL: conv_explicit_paddings
func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]]

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
  return %0 : tensor<256x9x7x16xf32>
}

// CHECK-LABEL: @conv2d_backprop_input
func @conv2d_backprop_input(
    %filter: tensor<3x3x1x32xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
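    // The input gradient is a convolution of out_backprop with the spatially
    // reversed filter; the filter's input/output feature dimensions are
    // swapped ([0, 1, o, i]) and the padding is chosen to recover the
    // original 28x28 spatial size.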
    // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
    // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
    // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]
    // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    // CHECK-SAME: batch_group_count = 1 : i64
    // CHECK-SAME: feature_group_count = 1 : i64
    // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_input_grouped
func @conv2d_backprop_input_grouped(
    %filter: tensor<2x2x5x21xf32>,
    %out_backprop: tensor<5x2x2x21xf32>
  ) -> tensor<5x3x3x15xf32> {
  %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32>

  // Verify filter transformation for grouped convolution.
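  // The output-feature dimension (21) is split into 3 groups of 7, the group
  // dimension is transposed next to the input features, and a final reshape
  // merges the 3 groups of 5 input features into 15.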

  // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%arg0) : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32>
  // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]])
  // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]>
  // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32>
  // CHECK: "mhlo.reshape"(%[[TRANSPOSE]]) : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32>

  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32>
  return %result : tensor<5x3x3x15xf32>
}


// CHECK-LABEL: @conv3d_backprop_input
func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
  // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64,
  // CHECK-SAME: feature_group_count = 1 : i64

  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter
func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
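  // The filter gradient is a convolution of the input with out_backprop in
  // which the batch dimension is the contracted feature dimension: note the
  // [f, 0, 1, b]x[i, 0, 1, o] dimension numbers below.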
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
  // CHECK-SAME:  batch_group_count = 1 : i64
  // CHECK-SAME:  feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter_grouped
func @conv2d_backprop_filter_grouped(
    %input: tensor<1x2x2x2xf32>,
    %out_backprop: tensor<1x1x1x2xf32>
  ) -> tensor<2x2x1x2xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME:  batch_group_count = 2 : i64
  // CHECK-SAME:  feature_group_count = 1 : i64

  %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32>
  return %result : tensor<2x2x1x2xf32>
}


// CHECK-LABEL: @conv3d_backprop_filter
func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @collective_permute
func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
  %source_target_pairs = "tf.Const" () {
    value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
  } : () -> tensor<3x2xi32>

  // CHECK: "mhlo.collective_permute"
  // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
  %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
  } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>

  return %0 : tensor<128x32xf32>
}

// CHECK-LABEL: @cross_replica_sum
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
  %replica_groups = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
  } : () -> tensor<2x4xi32>

  // CHECK: mhlo.cross-replica-sum
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}

// CHECK-LABEL: conv_dynamic
func @conv_dynamic(%arg0: tensor<?x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32> {
  // CHECK: "mhlo.dynamic_conv"({{.*}}) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<?x8x7x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32>
  return %0 : tensor<?x8x7x16xf32>
}

//===----------------------------------------------------------------------===//
// tf.Split legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @split_not_match_non_const_split_dim
func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
  // CHECK: tf.Split
  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @split_not_match_unknown_input_dim
func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32>
  // CHECK: divi_signed {{.*}} : index
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
  // CHECK: muli {{.*}} : index
  // CHECK: muli {{.*}} : index
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>)
  return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two
func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two_dynamic
func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
}

// CHECK-LABEL: @split_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
}

//===----------------------------------------------------------------------===//
// tf.TopKV2 legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: topk_v2_non_const_k
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
  return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
}

// CHECK-LABEL: topk_v2_unknown_input_last_dim
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
}

// CHECK-LABEL: topk_v2
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
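  // The lowering sorts (value, index) pairs along the last dimension with a
  // stable, descending TOTALORDER comparison and then slices the first k
  // columns out of both sorted operands.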

  // CHECK:      %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64}
  // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
  // CHECK-NEXT:   %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {compare_type = "TOTALORDER", comparison_direction = "GT"}
  // CHECK-NEXT:   "mhlo.return"(%[[CMP]])
  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
  // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}

//===----------------------------------------------------------------------===//
// tf.SplitV legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @splitv_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
}

// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// tf.Assert legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @assert
func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
  // CHECK-NOT: tf.Assert
  "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// tf.Unpack legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unpack
func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>

  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}

// CHECK-LABEL: func @unpack_dynamic
func @unpack_dynamic(%arg0: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi32>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: return {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>
  %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// CHECK-LABEL: @unpack_unranked
func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

  // CHECK: tf.Unpack
  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unsorted_segment_sum
// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32>
func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) {
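  // The lowering broadcasts the reduction identity (0 for sum) to the output
  // shape [num_segments, 64] and scatters the data into it, combining
  // updates that hit the same segment id with the reduction (add below).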
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK:   [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32>
  // CHECK:   "mhlo.return"([[ADD]])
  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>)
  return %0: tensor<4x64xf32>
}

// CHECK-LABEL: @unsorted_segment_prod
// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32>
func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK:   [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32>
  // CHECK:   "mhlo.return"([[MUL]])
  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

// CHECK-LABEL: @unsorted_segment_min
func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.minimum
  %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

// CHECK-LABEL: @unsorted_segment_max
func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.maximum
  %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.GatherNd legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @gatherNd_dynamic
func @gatherNd_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<?x6x2xi32>) -> tensor<?x6x?xi32> {
  // CHECK: tensor.dim
  // CHECK: index_cast
  // CHECK: tensor.from_elements
  // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = false} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>, tensor<3xi32>) -> tensor<?x6x?xi32>
  %0 =  "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>) -> tensor<?x6x?xi32>
  return %0 : tensor<?x6x?xi32>
}

// CHECK-LABEL: func @gatherNd_static
func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> {
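  // Each rank-1 index selects along dim 0 (start_index_map), that dimension
  // is collapsed, and the remaining dimensions are carried as offset dims
  // with slice_sizes [1, 4, 128].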
  // CHECK:      "mhlo.gather"({{.*}}) {
  // CHECK-SAME:   dimension_numbers = {
  // CHECK-SAME:     collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME:     index_vector_dim = 1 : i64
  // CHECK-SAME:     offset_dims = dense<[1, 2]> : tensor<2xi64>
  // CHECK-SAME:     start_index_map = dense<0> : tensor<1xi64>}
  // CHECK-SAME:   indices_are_sorted = false
  // CHECK-SAME:   slice_sizes = dense<[1, 4, 128]> : tensor<3xi64>
  // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32>
  %0 =  "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32>
  return %0 : tensor<2x4x128xf32>
}

//===----------------------------------------------------------------------===//
// tf.GatherV2 legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @gather_v2
func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> {
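  // The negative axis (-1) and batch_dims (-1) are resolved against the
  // operand ranks, so the op lowers to torch_index_select with dim = 2 and
  // batch_dims = 1.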
  // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32>
  return %1 : tensor<16x2x5xf32>
}

// CHECK-LABEL: @gather_v2_dynamic
func @gather_v2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
  // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32>
  // CHECK: index_cast {{.*}} : index to i32
  // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32>
  // CHECK: index_cast {{.*}} : index to i32
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi32>
  // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<3xi32>) -> tensor<*xf32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// CHECK-LABEL: @gather_v2_dynamic_index_i64
func @gather_v2_dynamic_index_i64(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>) -> tensor<*xf32> {
  // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32>
  // CHECK: index_cast {{.*}} : index to i64
  // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32>
  // CHECK: index_cast {{.*}} : index to i64
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi64>
  // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<3xi64>) -> tensor<*xf32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<1xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// CHECK-LABEL: @gather_v2_unranked
func @gather_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
  // CHECK: tf.GatherV2
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// CHECK-LABEL: @gather_v2_dynamic_shape
func @gather_v2_dynamic_shape(%arg0: tensor<?x2x3xf32>, %arg1: tensor<?x5xi32>) -> tensor<?x2x5xf32> {
  // CHECK: constant 0
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  // CHECK: tensor.dim {{.*}} : tensor<?x2x3xf32>
  // CHECK: index_cast {{.*}} : index to i32
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi32>
  // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<3xi32>) -> tensor<?x2x5xf32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<1xi32>) -> tensor<?x2x5xf32>
  return %1 : tensor<?x2x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.StridedSliceGrad legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: strided_slice_grad
// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32>
func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> {

  // For StridedSlice
  // Dim #:        0,   1,    2
  // Input shape: [4, 128, 1024]
  // Begin:        1,   4,   -3
  // End:          8,  65,   42
  // Stride:       1,   4,   -1
  // Begin mask:   1,   0,    0  (= 1)
  // End mask:     0,   0,    1  (= 4)

  // So result shape:
  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
  // Dim #1: 4 to 65 stride 4: so 16
  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
  // result shape: [4, 16, 1022]

  // To pad back:
  // Dim #:        0,   1,   2
  // Pad low:      0,   4,   0
  // Pad interm:   0,   3,   0
  // Pad high:     0,  63,   2
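  // (In general, a dimension with positive stride holds
  // ceil((end - begin) / stride) elements; e.g. dim #1 holds
  // ceil((65 - 4) / 4) = 16, padded back as 4 low, stride - 1 = 3 interior,
  // and 128 - 4 - (16 + 15 * 3) = 63 high.)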

  %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32>

  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32>
  // CHECK: return [[PAD]]
  return %0: tensor<4x128x1024xf32>
}

// CHECK-LABEL: strided_slice_grad_shrink_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32>
func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> {
  // Input to StridedSlice was of shape 4x8xf32
  // Strided slice gets input[2:3, 0:8]
  // shrink_axis_mask is 1, denoting that dim #0 is shrunk. So the output is
  // 8xf32, which is the shape of the gradient.
  // StridedSliceGrad would reshape the gradient to 1x8xf32 and
  // then pad to match the shape of input 4x8xf32.
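  // Padding arithmetic for dim #0: 2 rows below the slice (edge_padding_low),
  // the 1 sliced row itself, and 4 - 2 - 1 = 1 row above (edge_padding_high).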

  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
  // CHECK-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: strided_slice_grad_new_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
  // Input to StridedSlice was of shape 8xf32
  // Strided slice gets input[tf.new_axis, 2:4]
  // new_axis_mask is 1, denoting that a new axis is inserted at dim #0. So the
  // output is 1x2xf32, which is the shape of the gradient.
  // StridedSliceGrad would reshape the gradient to 2xf32 and
  // then pad to match the shape of input 8xf32.

  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>
  // CHECK-SAME: edge_padding_low = dense<2> : tensor<1xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>

  // CHECK: return [[PAD]] : tensor<8xf32>
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: strided_slice_grad_ellipsis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
  // Input to StridedSlice was of shape 4x4x8xf32
  // Strided slice gets input[2:4, ...]
  // ellipsis_mask is 2, denoting that the slice contains all elements in dim #1
  // and dim #2, ignoring the begin and end indices for these dimensions. So the
  // output is 2x4x8xf32, which is the shape of the gradient.
  // StridedSliceGrad would pad the gradient to match the shape of
  // input 4x4x8xf32.

  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<0> : tensor<3xi64>
  // CHECK-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
  return %0 : tensor<4x4x8xf32>
}


// CHECK-LABEL: strided_slice_grad_all_masks
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // The new axis mask is set at indices 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dim #1, #2 of the input, i.e., we get the
  // canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // The StridedSliceGrad op would propagate the gradient for the sliced tensor
  // to the original input tensor by padding with zeroes.

  %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>)
  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)

  // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0)
  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // The edge_padding_low, edge_padding_high and interior_padding attributes of
  // mhlo.pad would reflect the padding required to get the shape of the
  // input of the StridedSlice op.
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]])
  // CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64>
  // CHECK-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32>

  // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32>
  return %0 : tensor<2x4x8x16x32x64xf32>
}

// CHECK-LABEL: @tensor_scatter_update
func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
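  // The combiner region just returns the update value (%arg4), so scattered
  // updates overwrite the corresponding elements of the operand.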
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    "mhlo.return"(%arg4) : (tensor<f32>) -> ()
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_add
func @tensor_scatter_add(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
  // CHECK:    "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_sub
func @tensor_scatter_sub(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.subtract %arg3, %arg4 : tensor<f32>
  // CHECK:    "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_min
func @tensor_scatter_min(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK:    %1 = mhlo.minimum %arg3, %arg4 : tensor<f32>
  // CHECK:    "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK:  })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_max
func @tensor_scatter_max(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
4503  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
4504  // CHECK:    %1 = mhlo.maximum %arg3, %arg4 : tensor<f32>
4505  // CHECK:    "mhlo.return"(%1) : (tensor<f32>) -> ()
4506  // CHECK:  })
4507  // CHECK-SAME: indices_are_sorted = false
4508  // CHECK-SAME: scatter_dimension_numbers
4509  // CHECK-SAME:   index_vector_dim = 1 : i64
4510  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
4511  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
4512  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
4513  // CHECK-SAME: unique_indices = false
4514  %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
4515  return %0 : tensor<?x?x?xf32>
4516}
4517
4518//===----------------------------------------------------------------------===//
4519// tf.RandomShuffle legalization
4520//===----------------------------------------------------------------------===//
4521
4522// CHECK-LABEL: @random_shuffle_first_dim_1
4523// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
4524func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {
4525  %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>)
4526  // CHECK-NEXT: return [[INPUT]]
4527  return %0: tensor<1x?xf32>
4528}
4529
4530// CHECK-LABEL: @random_shuffle_1D_16
4531// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32>
4532func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
4533  // CHECK: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64>
4534  // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
4535  // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
4536  // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]])
4537  // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
4538  // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
4539  // CHECK:   "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"}
4540  // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
4541  // CHECK: return [[SORT]]#1
4542  %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
4543  return %0: tensor<16xf32>
4544}
4545
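// For larger inputs the sort-based shuffle emits more than one
// rng_uniform + sort round; presumably a single round of 32-bit random keys
// has too high a collision probability at this size, so two rounds are
// expected here.
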
// CHECK-LABEL: @random_shuffle_1D_10240
func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
  return %0: tensor<10240xf32>
}

// CHECK-LABEL: @random_shuffle_3D
// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32>
func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
  // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>

  // CHECK: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64>
  // CHECK: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]])

  // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[WHILE_INIT:%.*]] = "mhlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]])

  // CHECK: [[WHILE_OUT:%.*]] = "mhlo.while"([[WHILE_INIT]]) ( {
  // CHECK: ^{{.*}}([[COND_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK:   [[IV:%.*]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
  // CHECK:   [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK:   [[CMP:%.*]] = "mhlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"}
  // CHECK:   "mhlo.return"([[CMP]])
  // CHECK: },  {
  // CHECK: ^{{.*}}([[BODY_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK:   [[IV:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
  // CHECK:   [[SWAPS:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
  // CHECK:   [[INDICES:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
  // CHECK:   [[SRC_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK:   [[SWP_IDX:%.*]] = "mhlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK:   [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK:   [[TGT_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor<i64>}
  // CHECK:   [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK:   [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK:   [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
  // CHECK:   [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
  // CHECK:   [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
  // CHECK:   "mhlo.return"([[NEW_TUPLE]])
  // CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>

  // CHECK: [[SWAPPED_INDICES:%.*]] = "mhlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32>
  // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[SWAPPED_INDICES]])
  // CHECK-SAME: dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<[1, 2]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> : tensor<3xi64>
  // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32>

  // CHECK: return [[GATHER]]

  %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>)
  return %0: tensor<4x?x16xf32>
}

//===----------------------------------------------------------------------===//
// tf.AvgPool legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL:   @avgpool_valid_padding
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x12x21x7xf16>
// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             "mhlo.return"([[ADD]])
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x3x5x7xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x3x5x7xf32>
// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x3x5x7xf16>
// CHECK:           return [[CONV16]]
func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
  return %0 : tensor<2x3x5x7xf16>
}

// CHECK-LABEL:   @avgpool_3d_valid_padding
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             "mhlo.return"([[ADD]])
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x4x3x5x7xf16>
// CHECK:           return [[CONV16]]
func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
  return %0 : tensor<2x4x3x5x7xf16>
}

// CHECK-LABEL:   @avgpool_nchw_format
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x12x21xf16>
// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             "mhlo.return"([[ADD]])
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x7x3x5xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x7x3x5xf32>
// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x7x3x5xf16>
// CHECK:           return [[CONV16]]
func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
  return %0 : tensor<2x7x3x5xf16>
}

// CHECK-LABEL:   @avgpool_3d_ncdhw_format
// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK:             "mhlo.return"([[ADD]])
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 2]>
// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME:        -> tensor<2x7x4x3x5xf16>
// CHECK:           return [[CONV16]]
func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
  return %0 : tensor<2x7x4x3x5xf16>
}

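// With SAME padding the pooling window may hang over the input edges, so the
// divisor is no longer a constant. The lowering materializes an all-ones
// tensor of the input shape, runs the same reduce_window over it to count the
// valid elements contributing to each output position, and divides
// elementwise.
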
// CHECK-LABEL:   @avgpool_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
// CHECK:           return %[[RESULT]] : tensor<2x4x6x7xf32>
// CHECK:         }
func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
  return %0 : tensor<2x4x6x7xf32>
}

// CHECK-LABEL:   @avgpool_3d_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
// CHECK:           return %[[RESULT]] : tensor<2x4x4x6x7xf32>
// CHECK:         }
func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
  return %0 : tensor<2x4x4x6x7xf32>
}

//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

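// The AvgPoolGrad lowerings below distribute each output gradient evenly over
// its pooling window: divide by the window size (or by the per-position count
// for SAME padding), dilate the gradient with interior padding of stride - 1
// and edge padding of ksize - 1, then sum overlapping windows with a stride-1
// reduce_window. E.g. in the first test a 12x16 gradient with stride 2 dilates
// to (12 + 11 + 2) x (16 + 15 + 2) = 25x33 before the final 2x2 reduction.
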
// CHECK-LABEL:   @avgpool_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x24x32x64xf32>
func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
  return %result : tensor<10x24x32x64xf32>
}

// CHECK-LABEL:   @avgpool_3d_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x8x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x8x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x8x24x32x64xf32>
func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
  return %result : tensor<10x8x24x32x64xf32>
}

// CHECK-LABEL:   @avgpool_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x13x25x9xf32>
func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
  return %result : tensor<2x13x25x9xf32>
}

// CHECK-LABEL:   @avgpool_3d_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x8x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x8x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x8x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x8x13x25x9xf32>
func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 3, 1],
    padding = "SAME",
    strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
  return %result : tensor<2x8x13x25x9xf32>
}

// CHECK-LABEL:   @avgpool_grad_nchw_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x9x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x13x25xf32>
func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NCHW",
     ksize = [1, 1, 2, 3],
     padding = "SAME",
     strides = [1, 1, 4, 4]
  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
  return %result : tensor<2x9x13x25xf32>
}

// CHECK-LABEL:   @avgpool_3d_grad_ncdhw_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x8x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x8x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME:        -> tensor<2x9x8x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func @avgpool_3d_grad_ncdhw_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NCDHW",
    ksize = [1, 1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
  return %result : tensor<2x9x8x13x25xf32>
}

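// For bf16 the accumulation is performed in f32: the padded gradient is
// converted to f32 before the reduce_window and the result is converted back
// to bf16, avoiding precision loss from summing in the narrow type.
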
// CHECK-LABEL:   @avgpool_grad_bf16(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "mhlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK:           %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           %[[RESULT_CONVERTED:.*]] = "mhlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK:           return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
  return %result : tensor<10x24x32x64xbf16>
}

// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {api_version = 1 : i32, backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

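// InplaceUpdate is decomposed one updated row at a time: each index is sliced
// out of the indices vector and reshaped to a scalar, the matching row is
// sliced out of the update tensor, and a chain of dynamic-update-slice ops
// writes the rows into the original tensor.
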
// CHECK-LABEL: inplace_update_one
func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>

  // CHECK: return [[UPDATE]]
  return %0 : tensor<8x4xf32>
}

// CHECK-LABEL: inplace_update_three
func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]])
  // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>

  // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32>
  return %0 : tensor<8x8x4xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice
func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice2
func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @alltoall_basic
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
  %group_assignment = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
  } : () -> tensor<3x4xi32>
  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
  // CHECK: mhlo.all_to_all
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
  return %result : tensor<10xf32>
}

//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//

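// Cumsum lowers to a reduce_window whose window spans the whole axis
// (window_dimensions = 4) with low padding of size - 1 = 3, so output element
// i sums inputs 0..i, i.e. an inclusive prefix sum along the axis.
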
// CHECK-LABEL: func @cumsum_static
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

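// The exclusive variant shifts the inclusive result right by one element:
// mhlo.pad with edge_padding_low = 1 and edge_padding_high = -1 drops the last
// element and prepends the identity value.
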
// CHECK-LABEL: func @cumsum_exclusive
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: func @cumsum_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: func @cumsum_exclusive_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: func @cumsum_empty
func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> {
  %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>

  // CHECK: mhlo.constant dense<> : tensor<0xf32>
  %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32>
  return %1 : tensor<0xf32>
}

// CHECK-LABEL: func @cumsum_dynamic
func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
  // CHECK: "tf.Cumsum"
  %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// Cumprod op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @cumprod
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
  // CHECK:   mhlo.mul
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// Qr op legalization
//===----------------------------------------------------------------------===//

// CHECK:  func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
  // The tf.Qr lowering expands into a full algorithm that is impractical to
  // verify with FileCheck. Just verify that it converted.
  // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is
  // really only applicable to certain legacy uses.
  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//

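// Softplus computes log(1 + exp(x)) with a numerically stable three-way
// select: for x > -(log(eps) + 2), where exp(x) dominates, it returns x
// directly; for x < log(eps) + 2, where log1p(exp(x)) ~= exp(x), it returns
// exp(x); otherwise it evaluates log_plus_one(exp(x)). eps is the machine
// epsilon of the element type, hence the per-type constants checked below.
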
5229// CHECK-LABEL: func @softplus_f16
5230// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
5231func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
5232  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
5233  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
5234  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
5235  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
5236  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
5237  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
5238  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
5239  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
5240  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
5241  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
5242  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
5243  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>
5244
5245  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf16>
5246  return %0 : tensor<8x16xf16>
5247}
5248
5249// CHECK-LABEL: func @softplus_bf16
5250// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
5251func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
5252  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
5253  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
5254  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
5255  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
5256  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
5257  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
5258  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
5259  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
5260  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
5261  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
5262  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
5263  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>
5264
5265  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xbf16>
5266  return %0 : tensor<8x16xbf16>
5267}
5268
// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf64>
  return %0 : tensor<8x16xf64>
}

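// The dimension_numbers attribute of tf.XlaGather is a serialized
// xla.GatherDimensionNumbers proto. Decoding the bytes used below (assuming
// the field numbering in xla_data.proto, and noting that the literal space in
// the string is the byte 0x20): offset_dims = [1], collapsed_slice_dims = [0],
// start_index_map = [0], index_vector_dim = 1, which is exactly what the
// mhlo.gather CHECK lines expect.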
// CHECK-LABEL: @xla_gather
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME:   start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

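// Same as @xla_gather, but the slice_sizes operand is i32; verifies that it is
// still converted to the i64 slice_sizes attribute on mhlo.gather.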
// CHECK-LABEL: @xla_gather_i32
func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME:   start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

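// Verifies that a tf.StridedSlice whose begin/end indices are only known at
// run time (they are computed from %arg0 below) lowers to mhlo.dynamic-slice,
// with the dimension named by shrink_axis_mask removed by a reshape.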
// CHECK-LABEL: func @stridedslice_with_i32
func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} {
// CHECK-NOT: tf.StridedSlice
// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic-slice"
// CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[DYNSLICE]])
// CHECK: return [[RESHAPE]]
  %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
  return %6 : tensor<4xf32>
}

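// mhlo.replica_id produces an unsigned (ui32) result, while tf.XlaReplicaId is
// typed i32, so the lowering must insert a convert.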
// CHECK-LABEL: func @replica_id
func @replica_id() -> tensor<i32> {
  // CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%[[ID]]) : (tensor<ui32>) -> tensor<i32>
  %0 = "tf.XlaReplicaId"() : () -> tensor<i32>
  return %0 : tensor<i32>
}

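// The angle (argument) of a complex number z is atan2(Im(z), Re(z)).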
// CHECK-LABEL: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
  %0 = "tf.Angle"(%arg0) : (tensor<complex<f32>>) -> tensor<f32>
  return %0 : tensor<f32>
}

//===----------------------------------------------------------------------===//
// tf.XlaDotV2 legalization
//===----------------------------------------------------------------------===//

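// The dimension_numbers attribute of tf.XlaDotV2 is a serialized
// xla.DotDimensionNumbers proto. Decoding the bytes used below (assuming the
// field numbering in xla_data.proto): lhs_contracting_dimensions = [1] and
// rhs_contracting_dimensions = [0] with no batching dimensions, i.e. a plain
// matrix multiply.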
// CHECK-LABEL: @xladotv2_matmul(
// CHECK-SAME:    %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32>
func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> {
  // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {
  // CHECK-SAME:  dot_dimension_numbers = {
  // CHECK-SAME:      lhs_batching_dimensions = dense<> : tensor<0xi64>,
  // CHECK-SAME:      lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
  // CHECK-SAME:      rhs_batching_dimensions = dense<> : tensor<0xi64>,
  // CHECK-SAME:      rhs_contracting_dimensions = dense<0> : tensor<1xi64>
  // CHECK-SAME:  }, precision_config = []} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  return %res : tensor<64x16xi32>
}

//===----------------------------------------------------------------------===//
// tf.Print legalization
//===----------------------------------------------------------------------===//

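// Verifies that tf.Print lowers to mhlo.print, forwarding its operand as the
// result, and that the unranked result type is preserved by wrapping the
// constant in a tensor.cast.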
// CHECK-LABEL: @simple_print
func @simple_print() -> (tensor<*xi32>) {
  // CHECK: mhlo.constant dense<1> : tensor<i32>
  // CHECK: tensor.cast {{.*}} : tensor<i32> to tensor<*xi32>
  // CHECK: "mhlo.print"({{.*}}) : (tensor<*xi32>) -> tensor<*xi32>
  %const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
  %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
  return %print : tensor<*xi32>
}
