// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all
// This test runs twice:
//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled,
//      since verifying the emitted chlo ops produces more useful tests.
//   2. With chlo legalization enabled, verifying diagnostics to pick up any
//      issues with the full lowering (this can catch some broadcasting corner
//      cases which emit a warning).

//===----------------------------------------------------------------------===//
// BatchNorm op legalizations.
//===----------------------------------------------------------------------===//

// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so we only do a couple of basic checks.

// CHECK-LABEL: fusedBatchNormV2_noTraining
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV2_training
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}
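// Note on the mhlo.constant / chlo.broadcast_multiply pair checked above:
// mhlo.batch_norm_training returns the biased (population) batch variance, and
// the multiply rescales it before it is returned as the running variance. The
// multiplier matches the unbiased correction N / (N - 1) with
// N = 8 * 8 * 8 = 512 reduced elements per channel; 512 / 511 is the same
// 1.00195694 constant checked by value in the exponentialAvgFactor test below.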

// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
  // CHECK: [[Y_CONVERT:%.*]] = "mhlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32>
  // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance
func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: return %[[VAR]]
  return %0#4 : tensor<8xf32>
}

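// The next test pins down the running-stat update for a non-unit
// exponential_avg_factor f = 0.8:
//   new_mean = (1 - f) * old_mean + f * batch_mean
//   new_var  = (1 - f) * old_var  + f * corrected_var
// ALPHA is 1 - 0.8 evaluated in f32 (0.199999988), BETA is 0.8, and
// corrected_var is the batch variance scaled by 512 / 511 (1.00195694).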
// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor
func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-DAG: %[[BATCH_MEAN:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32}
  // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}

  // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
  // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]

  // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
  // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>

  // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
  // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
  // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]

  // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
  // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
  // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]

  // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
  return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  return %0#0 : tensor<8x8x8x8xbf16>
}

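// The feature_index attribute is derived from data_format: the channel
// dimension is index 3 for NHWC (the tests above), 1 for NCHW, and 4 for
// NDHWC, as the next two tests verify.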
// CHECK-LABEL: fusedBatchNormV3_NCHW
func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_NDHWC
func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) {
  // CHECK: feature_index = 4 : i64
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported
func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1
func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2
func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>)
  return %0#0 : tensor<?x6x?x?xf32>
}

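// For is_training = false the gradient is expanded inline rather than lowered
// to mhlo.batch_norm_grad (which implements the training-mode gradient). The
// pattern sequences in the next tests encode:
//   scr1            = rsqrt(variance + epsilon)
//   scr2            = sum(grad * (x - mean)) over the non-feature dims
//   x_backprop      = grad * broadcast(scale * scr1)
//   scale_backprop  = scr1 * scr2
//   offset_backprop = sum(grad) over the non-feature dims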
// CHECK-LABEL: fusedBatchNormGrad_noTraining
func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

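// For is_training = true the gradient lowers directly to mhlo.batch_norm_grad,
// and x_backprop, scale_backprop, and offset_backprop are unpacked from the
// result tuple with mhlo.get_tuple_element, as the next test verifies.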
// CHECK-LABEL: fusedBatchNormGrad_Training
func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_noTraining
func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training
func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

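// The mixed-precision gradient tests convert the bf16 activations to f32
// before the computation and convert only x_backprop back to bf16;
// scale_backprop and offset_backprop remain f32, matching the TF op's result
// types.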
// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

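// The FusedBatchNormGradV3 tests mirror the V2 ones; the extra
// reserve_space_3 operand (%arg5) is not used by the lowering, so none of the
// patterns below mention it.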
// CHECK-LABEL: fusedBatchNormGradV3_noTraining
func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training
func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK: return %[[x_backprop]]
  // CHECK-SAME: tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>)
  return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

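// In the NCHW variants below the feature dimension moves to index 1, so the
// per-channel broadcasts use broadcast_dimensions = dense<1> and the
// reductions sum over dimensions [0, 2, 3] instead of [0, 1, 2].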
// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
  // CHECK-NEXT: %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT: "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

"mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>> 427 %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 428 return %0#0 : tensor<8x8x8x8xf32> 429} 430 431//===----------------------------------------------------------------------===// 432// Bias op legalizations. 433//===----------------------------------------------------------------------===// 434 435// CHECK-LABEL: func @biasAdd_default 436func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 437 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 438 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 439 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 440 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 441 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 442 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 443 return %0 : tensor<1x32x10x32xi32> 444} 445 446// CHECK-LABEL: func @biasAdd_NHWC 447func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 448 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 449 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 450 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 451 // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} 452 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 453 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 454 return %0 : tensor<1x32x10x32xi32> 455} 456 457// CHECK-LABEL: func @biasAdd_NCHW 458func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { 459 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 460 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 461 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 462 // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} 463 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 464 %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> 465 return %0 : tensor<1x32x10x32xi32> 466} 467 468// CHECK-LABEL: func @biasAdd_dynamic 469func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> { 470 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 471 // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]] 472 // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) 473 // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} 474 // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]] 475 %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) 

// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
  return %0 : tensor<?x?x?x?xi32>
}

// CHECK-LABEL: func @biasAdd_partial_dynamic
func @biasAdd_partial_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<512xi32>) -> tensor<?x?x?x512xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  // CHECK: %[[CAST:.+]] = tensor.cast %[[RESULT]] : tensor<?x?x?x?xi32> to tensor<?x?x?x512xi32>
  // CHECK: return %[[CAST]] : tensor<?x?x?x512xi32>
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<?x?x?x?xi32>, tensor<512xi32>) -> tensor<?x?x?x512xi32>
  return %0 : tensor<?x?x?x512xi32>
}
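// When the op's result type is more refined than the one the lowering
// computes (tensor<?x?x?x512xi32> vs. tensor<?x?x?x?xi32> above), a
// tensor.cast bridges the two types.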
541//===----------------------------------------------------------------------===// 542// DiagPart 543//===----------------------------------------------------------------------===// 544 545// CHECK-LABEL: func @diag_part 546// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> 547func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { 548 // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> 549 // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32> 550 // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32> 551 // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> 552 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 553 // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32> 554 // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32> 555 // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]]) 556 // CHECK-DAG: mhlo.add 557 // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32> 558 // CHECK-DAG: %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32> 559 // CHECK-DAG: return %[[RES]] : tensor<4x3xf32> 560 %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> 561 return %0: tensor<4x3xf32> 562} 563 564//===----------------------------------------------------------------------===// 565// MatrixDiagPart 566//===----------------------------------------------------------------------===// 567 568// CHECK-LABEL: func @matrix_diag_part 569// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32> 570func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { 571 // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32> 572 // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32> 573 // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32> 574 // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32> 575 // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32> 576 // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 577 // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1> 578 // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 579 // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1> 580 // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1> 581 // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32> 582 // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 583 // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32> 584 // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32> 585 // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : 

//===----------------------------------------------------------------------===//
// DiagPart
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32>
func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
  // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]])
  // CHECK-DAG: mhlo.add
  // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32>
  // CHECK-DAG: %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32>
  // CHECK-DAG: return %[[RES]] : tensor<4x3xf32>
  %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32>
  return %0: tensor<4x3xf32>
}
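// tf.DiagPart reshapes the 4x3x4x3 input to a square 12x12 matrix
// (12 = 4 * 3), masks it with an iota-equality comparison so that only the
// diagonal survives (zeros elsewhere via mhlo.select), sums away dimension 0,
// and reshapes the 12 diagonal entries back to 4x3.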

//===----------------------------------------------------------------------===//
// MatrixDiagPart
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @matrix_diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32>
  // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32>
  // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V19:.*]] = "mhlo.negate"(%[[V18]]) : (tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V41:.*]] = "mhlo.reshape"(%[[V40]]) : (tensor<1x22x128xi1>) -> tensor<22x128xi1>
  // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32>
  // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = {collapsed_slice_dims = dense<[1, 2]> : tensor<2xi64>, index_vector_dim = 0 : i64, offset_dims = dense<0> : tensor<1xi64>, start_index_map = dense<[1, 2]> : tensor<2xi64>}, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1>
  // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK: return %[[V46]] : tensor<7x22x128xi32>
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_single_diagonal
func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<0> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x128xi32>
  // CHECK: %[[result:.*]] = "mhlo.reshape"({{.*}}) : (tensor<7x1x128xi32>) -> tensor<7x128xi32>
  // CHECK: return %[[result]] : tensor<7x128xi32>
  return %2: tensor<7x128xi32>
}
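// In matrix_diag_part above, k = [-10, 11] requests 11 - (-10) + 1 = 22
// diagonals of the 140x128 matrices, giving the 7x22x128 result; positions
// past the end of a diagonal are filled with the padding value 42 through the
// final mhlo.select. The align_* variants below differ only in the predicate
// of one mhlo.select: LEFT_LEFT and RIGHT_RIGHT use constant false/true
// masks, while LEFT_RIGHT and RIGHT_LEFT compare with "LE" and "GE"
// respectively.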

// CHECK-LABEL: func @matrix_diag_part_align_rl
func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rr
func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_7d
// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32>
func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> {
  %0 = mhlo.constant dense<-1.> : tensor<f32> // padding value
  %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = f32, align = "LEFT_RIGHT"
  } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32>
  return %2: tensor<3x5x7x9x11x4x10xf32>
}

//===----------------------------------------------------------------------===//
// Erf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erf
func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erf %arg0 : tensor<2x3xf32>
  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Erfc
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erfc
func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Einsum.
724//===----------------------------------------------------------------------===// 725 726// CHECK-LABEL: func @einsum 727func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { 728 // CHECK: mhlo.einsum 729 %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> 730 return %0: tensor<2x4xf32> 731} 732 733// CHECK-LABEL: func @unary_einsum 734func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { 735 // CHECK: mhlo.unary_einsum 736 %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> 737 return %0: tensor<2x2xf32> 738} 739 740//===----------------------------------------------------------------------===// 741// FloorDiv and FloorMod. 742//===----------------------------------------------------------------------===// 743 744// CHECK-LABEL: func @floordiv_broadcast_i32 745func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { 746 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 747 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 748 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"} 749 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 750 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} 751 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 752 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} 753 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} 754 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 755 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 756 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 757 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 758 // CHECK: return [[SELECT]] 759 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> 760 return %0: tensor<2x3xi32> 761} 762 763// CHECK-LABEL: func @floordiv_reverse_broadcast_i32 764func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { 765 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 766 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]] 767 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} 768 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 769 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} 770 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 771 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} 772 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} 773 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 774 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 775 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 776 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 777 // CHECK: return [[SELECT]] 778 %0 = 
"tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> 779 return %0: tensor<2x3xi32> 780} 781 782// CHECK-LABEL: func @floordiv_f32 783func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { 784 // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0 785 // CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]]) 786 // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> 787 %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 788 return %0: tensor<2xf32> 789} 790 791// CHECK-LABEL: func @floordiv_bf16 792func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { 793 // CHECK-NEXT: mhlo.convert 794 // CHECK-NEXT: mhlo.convert 795 // CHECK-NEXT: chlo.broadcast_divide 796 // CHECK-NEXT: mhlo.floor 797 // CHECK-NEXT: mhlo.convert 798 // CHECK-NEXT: return 799 %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> 800 return %0: tensor<2xbf16> 801} 802 803// CHECK-LABEL: func @floordiv_f16_broadcast 804func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { 805 // CHECK-NEXT: chlo.broadcast_divide 806 // CHECK-NEXT: mhlo.floor 807 // CHECK-NEXT: return 808 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> 809 return %0: tensor<2x3xf16> 810} 811 812// CHECK-LABEL: func @floordiv_dynamic 813func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> { 814 // CHECK-DAG: [[DIV:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 815 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[DIV]], %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} 816 // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[MUL]], %arg0 {comparison_direction = "NE"} 817 // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> 818 // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} 819 // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> 820 // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} 821 // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} 822 // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] 823 // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> 824 // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[DIV]], [[ONES]] 825 // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[SUB]], [[DIV]]) 826 // CHECK: return [[SELECT]] 827 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32> 828 return %0: tensor<?x?xi32> 829} 830 831// CHECK-LABEL: func @floordiv_unranked 832func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { 833 // CHECK-NOT: tf.FloorDiv 834 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 835 return %0: tensor<*xf32> 836} 837 838// CHECK-LABEL: func @floordiv_int 839func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { 840 // CHECK: tf.FloorDiv 841 %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> 842 return %0: tensor<*xi32> 843} 844 845// CHECK-LABEL: func @floormod_broadcast_numerator 846func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { 847 // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : 

// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_broadcast_denominator
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK-NOT: tf.FloorMod
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// OnesLike
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @ones_like
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>)
func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 1.0{{.*}}}
  // CHECK: return %[[RES]]
  %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

//===----------------------------------------------------------------------===//
// ZerosLike
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @zeros_like
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>)
func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
  // CHECK: %[[RES:.*]] = "chlo.constant_like"(%[[ARG]]) {value = 0.0{{.*}}}
  // CHECK: return %[[RES]]
  %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
  return %0 : tensor<2x?xf32>
}

//===----------------------------------------------------------------------===//
// BroadcastTo.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @broadcast_to
func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
  %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>

  // CHECK: [[CST:%.+]] = mhlo.constant
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CST]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
  return %0 : tensor<16x16x16x16xf32>
}

// CHECK-LABEL: func @broadcast_scalar_to_unranked
// CHECK: (%[[ARG0:.*]]: tensor<f32>, %[[SHAPE:.*]]: tensor<?xi32>)
func @broadcast_scalar_to_unranked(%arg0: tensor<f32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]])
  // CHECK-SAME: {broadcast_dimensions = dense<> : tensor<0xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<f32>, tensor<?xi32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
955//===----------------------------------------------------------------------===// 956 957// CHECK-LABEL: func @complex 958func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> { 959 // CHECK: chlo.broadcast_complex 960 %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> 961 return %1 : tensor<3xcomplex<f32>> 962} 963 964// CHECK-LABEL: func @imag 965func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> { 966 // CHECK: "mhlo.imag" 967 %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32> 968 return %1 : tensor<3xf32> 969} 970 971// CHECK-LABEL: func @real 972func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> { 973 // CHECK: "mhlo.real" 974 %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32> 975 return %1 : tensor<3xf32> 976} 977 978//===----------------------------------------------------------------------===// 979// Concat op legalizations. 980//===----------------------------------------------------------------------===// 981 982// CHECK-LABEL: func @concat_v2 983func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { 984 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> 985 %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64> 986 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32> 987 return %1 : tensor<6x3xf32> 988} 989 990// CHECK-LABEL: func @concat_v2_neg_axis 991func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { 992 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> 993 994 %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64> 995 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32> 996 return %1 : tensor<6x3xf32> 997} 998 999// CHECK-LABEL: func @concat_v2_1d_axis 1000func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { 1001 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> 1002 1003 %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> 1004 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> 1005 return %1 : tensor<3x6xf32> 1006} 1007 1008// CHECK-LABEL: func @concat_v2_non_const_axis 1009func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> { 1010 // CHECK: "tf.ConcatV2" 1011 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32> 1012 return %1 : tensor<3x6xf32> 1013} 1014 1015// CHECK-LABEL: func @concat_v2_unranked 1016func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { 1017 %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64> 1018 // CHECK: "tf.ConcatV2" 1019 %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32> 1020 return %1 : tensor<*xf32> 1021} 1022 1023//===----------------------------------------------------------------------===// 1024// Pad op legalizations. 
1025//===----------------------------------------------------------------------===// 1026 1027// CHECK-LABEL: func @padv2_1D 1028func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> { 1029 %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> 1030 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1031 // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>, 1032 // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>, 1033 // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64> 1034 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32> 1035 return %1 : tensor<6xf32> 1036} 1037 1038// CHECK-LABEL: func @padv2_2D 1039func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> { 1040 %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> 1041 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1042 // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, 1043 // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, 1044 // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> 1045 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32> 1046 return %1 : tensor<6x9xf32> 1047} 1048 1049// CHECK-LABEL: func @padv2_i32_paddings 1050func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> { 1051 %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> 1052 // CHECK: "mhlo.pad"(%arg0, %arg1) { 1053 // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>, 1054 // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>, 1055 // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64> 1056 %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32> 1057 return %1 : tensor<6x9xf32> 1058} 1059 1060// CHECK-LABEL: func @padv2_dynamic 1061func @padv2_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1x2xi64>) -> tensor<?xf32> { 1062 // CHECK: "mhlo.transpose"({{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x2xi64>) -> tensor<2x1xi64> 1063 // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2x1xi64>) -> tensor<2xi64> 1064 // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 1065 // CHECK: "mhlo.slice"({{.*}}) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 1066 // CHECK: "mhlo.dynamic_pad"({{.*}}) : (tensor<?xf32>, tensor<f32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<?xf32> 1067 %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor<?xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<?xf32> 1068 return %1 : tensor<?xf32> 1069} 1070 1071//===----------------------------------------------------------------------===// 1072// Identity op legalizations. 
1073//===----------------------------------------------------------------------===// 1074 1075// CHECK-LABEL: func @identity 1076func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1077 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1078 %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1079 return %0: tensor<1xi32> 1080} 1081 1082// CHECK-LABEL: func @identityN 1083func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { 1084 // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> 1085 %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) 1086 return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> 1087} 1088 1089// CHECK-LABEL: func @stopgradient 1090func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1091 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1092 %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1093 return %0: tensor<1xi32> 1094} 1095 1096// CHECK-LABEL: func @preventgradient 1097func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1098 // CHECK-NEXT: return %arg0 : tensor<1xi32> 1099 %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32> 1100 return %0: tensor<1xi32> 1101} 1102 1103// CHECK-LABEL: func @checkNumerics 1104func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { 1105 // CHECK-NEXT: return %arg0 : tensor<1xf32> 1106 %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> 1107 return %0: tensor<1xf32> 1108} 1109 1110//===----------------------------------------------------------------------===// 1111// InfeedDequeueTuple legalization 1112//===----------------------------------------------------------------------===// 1113 1114// CHECK-LABEL: func @infeed_dequeue_tuple 1115func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { 1116// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token 1117// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = "", layout = [{{\[\[1, 3, 2, 0], \[1, 2, 0]]}}, unit]} : (!mhlo.token) -> tuple<tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>, !mhlo.token> 1118// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>, !mhlo.token>) -> tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>> 1119// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>) -> tensor<1x8x4x4xi32> 1120// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<1x8x4x4xi32>, tensor<1x100x1xf32>>) -> tensor<1x100x1xf32> 1121// CHECK: return [[RES_1]], [[RES_2]] 1122 %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) 1123 return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> 1124} 1125 1126// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error 1127func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) { 1128 // We expect legalization to fail for dynamic shapes: 1129 // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}} 1130 %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>) 1131 return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32> 1132} 1133 1134// The following op sharding is used: 1135// Proto debug string: 1136// type: TUPLE 1137// tuple_shardings { 1138// type: MAXIMAL 1139// tile_assignment_dimensions: 1 1140// 
//     tile_assignment_devices: 0
//   }
// Serialized string:
//   "\08\02*\08\08\01\1A\01\01\22\01\00"

// CHECK-LABEL: infeed_dequeue_tuple_sharding
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
  // CHECK: "mhlo.infeed"
  // An additional sharding is added at the end to account for token result.
  // Proto debug string:
  //   type: TUPLE
  //   tuple_shardings {
  //     type: MAXIMAL
  //     tile_assignment_dimensions: 1
  //     tile_assignment_devices: 0
  //   }
  //   tuple_shardings {
  //     type: MAXIMAL
  //     tile_assignment_dimensions: 1
  //     tile_assignment_devices: 0
  //   }
  // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
  return %0 : tensor<8xi32>
}

//===----------------------------------------------------------------------===//
// Nullary op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @const
func @const() -> tensor<2xi32> {
  // CHECK: mhlo.constant dense<0> : tensor<2xi32>
  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
  return %0: tensor<2xi32>
}

// CHECK-LABEL: @const_dynamic_output
func @const_dynamic_output() -> tensor<*xi32> {
  // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32>
  // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
  // CHECK: return [[CAST]]
  return %0: tensor<*xi32>
}

// CHECK-LABEL: @opaque_const
func @opaque_const() -> tensor<!tf_type.variant<tensor<2xi32>>> {
  // CHECK-NOT: mhlo.constant
  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf_type.variant>} : () -> tensor<!tf_type.variant<tensor<2xi32>>>
  return %0 : tensor<!tf_type.variant<tensor<2xi32>>>
}

//===----------------------------------------------------------------------===//
// Matmul op legalizations.
1195//===----------------------------------------------------------------------===// 1196 1197// CHECK-LABEL: matmul_notranspose 1198// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) 1199func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { 1200 // CHECK: "mhlo.dot"(%[[A]], %[[B]]) 1201 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> 1202 1203 return %0 : tensor<5x11xf32> 1204} 1205 1206// CHECK-LABEL: matmul_transpose_b 1207// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) 1208func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { 1209 // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1210 // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]]) 1211 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> 1212 1213 return %0 : tensor<5x11xf32> 1214} 1215 1216// CHECK-LABEL: matmul_transpose_both 1217// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) 1218func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { 1219 // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1220 // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>} 1221 // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]]) 1222 %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> 1223 1224 return %0 : tensor<5x11xf32> 1225} 1226 1227// Verify that MatMul with ranked inputs are lowered to HLO. 1228// CHECK-LABEL: matmul_ranked 1229func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> { 1230 // CHECK: "mhlo.dot" 1231 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32> 1232 1233 return %0 : tensor<?x?xf32> 1234} 1235 1236// Verify that MatMul with unranked inputs are lowered to HLO. 1237// CHECK-LABEL: matmul_unranked 1238func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> { 1239 // CHECK: "mhlo.dot" 1240 %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> 1241 1242 return %0 : tensor<*xf32> 1243} 1244 1245// Verify SparseMatMul is legalized to dot. 1246// CHECK-LABEL: test_sparse_mat_mul 1247func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { 1248 // CHECK: "mhlo.dot" 1249 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> 1250 return %0: tensor<3x5xf32> 1251} 1252 1253// SparseMatMul where one operand needs to be transposed and the other one not. 
1254// 1255// CHECK-LABEL: @test_sparse_mat_mul_with_transpose 1256// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> 1257// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> 1258// CHECK-SAME: -> tensor<3x5xf32> 1259// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]]) 1260// CHECK-SAME: permutation = dense<[1, 0]> 1261// CHECK-SAME: -> tensor<4x5xf32> 1262// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]]) 1263// CHECK-SAME: -> tensor<3x5xf32> 1264// CHECK: return %[[RESULT]] 1265func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { 1266 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> 1267 return %0: tensor<3x5xf32> 1268} 1269 1270// SparseMatMul where one operand needs to be casted and the other one not. 1271// 1272// CHECK-LABEL: @test_sparse_mat_mul_with_cast 1273// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> 1274// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> 1275// CHECK-SAME: -> tensor<3x5xf32> 1276// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]]) 1277// CHECK-SAME: -> tensor<4x5xf32> 1278// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]]) 1279// CHECK-SAME: -> tensor<3x5xf32> 1280// CHECK: return %[[RESULT]] 1281func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { 1282 %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> 1283 return %0: tensor<3x5xf32> 1284} 1285 1286//===----------------------------------------------------------------------===// 1287// MatrixBandPart op legalizations. 
1288//===----------------------------------------------------------------------===// 1289 1290// CHECK-LABEL: matrix_band_part 1291// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1292func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { 1293 // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64> 1294 // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64> 1295 1296 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64> 1297 // CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 1298 // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 1299 1300 // CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 1301 // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 1302 // CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64> 1303 1304 // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64> 1305 // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64> 1306 // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64> 1307 // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1> 1308 1309 // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1> 1310 1311 // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> 1312 1313 // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16> 1314 1315 // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]] 1316 // CHECK-DAG: return %[[R]] 1317 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16> 1318 return %0 : tensor<64x64xbf16> 1319} 1320 1321// CHECK-LABEL: matrix_band_part_2 1322// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>) 1323func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> { 1324 // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64> 1325 // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64> 1326 // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64> 1327 1328 // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1> 1329 1330 // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1> 1331 // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> 1332 1333 // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> 1334 1335 // CHECK-DAG: %[[R:.*]] = chlo.broadcast_select %[[J]], %[[INPUT]], %[[ZERO2]] 1336 // CHECK-DAG: return %[[R]] 1337 %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, 
  return %0 : tensor<12x24x48xbf16>
}

// CHECK-LABEL: matrix_band_part_3
// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
  // CHECK: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
  return %0 : tensor<*xbf16>
}

// CHECK-LABEL: matrix_band_part_4
// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> {
  // This one should lower.
  // CHECK-NOT: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16>
  return %0 : tensor<24x48xbf16>
}

//===----------------------------------------------------------------------===//
// MaxPool op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: maxpool_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

// CHECK-LABEL: maxpool_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
  return %0 : tensor<2x4x7x7xi32>
}

// CHECK-LABEL: maxpool_3d_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
  return %0 : tensor<2x8x3x5x7xf32>
}
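
// For SAME pooling the output size is ceil(in / stride), and the checked
// padding attributes match TensorFlow's SAME rule:
// pad_total = (out - 1) * stride + ksize - in, split as low = pad_total / 2
// and high = pad_total - low. In the 2-D SAME test above, the width dimension
// has in = 25, ksize = 3, stride = 4, so out = 7 and pad_total = 2, giving
// the [1, 1] pair; the height dimension gets [0, 1] the same way. The 3-D
// test below checks the same computation with an extra leading spatial dim.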

// CHECK-LABEL: maxpool_3d_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
  return %0 : tensor<2x8x4x7x7xf32>
}

// CHECK-LABEL: maxpool_explicit_padding
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: tf.MaxPool
  // TODO(b/165938852): need to support explicit padding in max_pool.

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

//===----------------------------------------------------------------------===//
// MaxPoolGrad op legalizations.
//===----------------------------------------------------------------------===//
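
// MaxPoolGrad lowers to mhlo.select_and_scatter: the first region re-runs the
// pooling selection (the GE compare picks the window element that won the
// max), and the second region accumulates the incoming gradient into that
// position with an add, starting from a zero-initialized result.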
"mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> () 1454 // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32> 1455 // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> 1456 %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> 1457 return %result : tensor<10x8x24x24x64xf32> 1458} 1459 1460// CHECK-LABEL: @max_pool_grad_same 1461func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { 1462 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> 1463 %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { 1464 data_format = "NHWC", 1465 ksize = [1, 2, 3, 1], 1466 padding = "SAME", 1467 strides = [1, 4, 4, 1] 1468 } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> 1469 return %result : tensor<2x13x25x7xf32> 1470} 1471 1472// CHECK-LABEL: @max_pool_3d_grad_same 1473func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { 1474 // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> 1475 %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> 1476 return %result : tensor<2x8x13x25x7xf32> 1477} 1478 1479//===----------------------------------------------------------------------===// 1480// OneHot op legalizations. 
1481//===----------------------------------------------------------------------===// 1482 1483// CHECK-LABEL:one_hot 1484func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> { 1485 // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> 1486 // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> 1487 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> 1488 // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> 1489 // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> 1490 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> 1491 // CHECK: return %[[RESULT]] : tensor<3x5xf32> 1492 %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32> 1493 %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32> 1494 return %result : tensor<3x5xf32> 1495} 1496 1497//===----------------------------------------------------------------------===// 1498// tf.OutfeedEnqueueTuple legalization 1499//===----------------------------------------------------------------------===// 1500 1501// CHECK-LABEL: func @outfeed_enqueue_tuple 1502// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) 1503func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { 1504// CHECK: [[TUPLE:%.*]] = "mhlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>> 1505// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token 1506// CHECK: "mhlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token 1507 "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () 1508 return 1509} 1510 1511//===----------------------------------------------------------------------===// 1512// Pack op legalizations. 1513//===----------------------------------------------------------------------===// 1514 1515// CHECK-LABEL: func @pack 1516func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { 1517 // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> 1518 // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32> 1519 // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> 1520 1521 %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> 1522 return %0 : tensor<2x2xi32> 1523} 1524 1525//===----------------------------------------------------------------------===// 1526// PartitionedCall op legalization. 
1527//===----------------------------------------------------------------------===// 1528 1529// CHECK-LABEL: func @partitioned_call 1530func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> { 1531 // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32> 1532 %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>) 1533 return %0 : tensor<i32> 1534} 1535 1536func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> { 1537 return %arg0 : tensor<i32> 1538} 1539 1540// CHECK-LABEL: func @partitioned_call_multi_input 1541func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { 1542 // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32> 1543 %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) 1544 return %0 : tensor<i32> 1545} 1546 1547func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { 1548 return %arg0 : tensor<i32> 1549} 1550 1551// CHECK-LABEL: func @partitioned_call_multi_in_out 1552func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1553 // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1554 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1555 return %0, %1 : tensor<i32>, tensor<i32> 1556} 1557 1558func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1559 return %arg1, %arg0 : tensor<i32>, tensor<i32> 1560} 1561 1562// CHECK-LABEL: func @unhandled_partitioned_call 1563func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) { 1564 // The argument types don't match the parameter types for the 1565 // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not 1566 // for a standard CallOp, so this op can't be lowered. 1567 // CHECK: "tf.PartitionedCall" 1568 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>) 1569 return %0, %1 : tensor<i32>, tensor<i32> 1570} 1571 1572// CHECK-LABEL: func @unhandled_partitioned_call_2 1573func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) { 1574 // CHECK: "tf.PartitionedCall" 1575 %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>) 1576 return %0, %1 : tensor<i32>, tensor<i32> 1577} 1578 1579 1580//===----------------------------------------------------------------------===// 1581// ReverseV2 op legalization. 
1582//===----------------------------------------------------------------------===// 1583 1584// CHECK-LABEL: @reverse_func_32 1585func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { 1586 %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) 1587 1588 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} 1589 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> 1590 1591 // CHECK: return [[VAL]] : tensor<5xi32> 1592 return %reversed : tensor<5xi32> 1593} 1594 1595// CHECK-LABEL: @reverse_func_64 1596func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { 1597 %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) 1598 1599 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} 1600 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> 1601 1602 // CHECK: return [[VAL]] : tensor<5xi32> 1603 return %reversed : tensor<5xi32> 1604} 1605 1606// CHECK-LABEL: @reverse_func_neg 1607func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { 1608 %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1609 1610 // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} 1611 %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> 1612 1613 // CHECK: return [[VAL]] : tensor<5x5xi32> 1614 return %reversed : tensor<5x5xi32> 1615} 1616 1617//===----------------------------------------------------------------------===// 1618// StatefulPartitionedCall op legalization. 1619//===----------------------------------------------------------------------===// 1620 1621// CHECK-LABEL: func @stateful_partitioned_call 1622// CHECK-SAME: [[ARG:%.+]]: tensor<i32> 1623func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> { 1624 // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32> 1625 %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>) 1626 return %0 : tensor<i32> 1627} 1628 1629func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> { 1630 return %arg0 : tensor<i32> 1631} 1632 1633// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out 1634// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>) 1635func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1636 // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1637 %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>) 1638 return %0, %1 : tensor<i32>, tensor<i32> 1639} 1640 1641func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) { 1642 return %arg1, %arg0 : tensor<i32>, tensor<i32> 1643} 1644 1645//===----------------------------------------------------------------------===// 1646// Elu op legalizations. 
1647//===----------------------------------------------------------------------===// 1648 1649// CHECK-LABEL: func @elu 1650func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { 1651 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<1xf32>) -> tensor<1xf32> 1652 // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%arg0, %[[ZERO]]) {comparison_direction = "GT"} 1653 // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0) 1654 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) 1655 // CHECK: return %[[RESULT]] 1656 %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> 1657 return %0: tensor<1xf32> 1658} 1659 1660// CHECK-LABEL: func @elu_unranked 1661func @elu_unranked(%arg0: tensor<?xf32>) -> tensor<?xf32> { 1662 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<?xf32>) -> tensor<?xf32> 1663 // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%arg0, %[[ZERO]]) {comparison_direction = "GT"} 1664 // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0) 1665 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) 1666 // CHECK: return %[[RESULT]] 1667 %0 = "tf.Elu"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 1668 return %0: tensor<?xf32> 1669} 1670 1671// CHECK-LABEL: func @elu_grad 1672// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>) 1673func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> { 1674 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1675 // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 1676 // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"} 1677 // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} 1678 // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]]) 1679 // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]) 1680 // CHECK: return %[[RESULT]] 1681 %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32> 1682 return %2 : tensor<4x8xf32> 1683} 1684 1685//===----------------------------------------------------------------------===// 1686// Relu op legalizations. 
1687//===----------------------------------------------------------------------===// 1688 1689// CHECK-LABEL: func @relu 1690func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1691 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1692 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> 1693 %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1694 return %0: tensor<1xi32> 1695} 1696 1697// CHECK-LABEL: func @relu_unranked 1698func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1699 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1700 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> 1701 %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1702 return %0: tensor<?xi32> 1703} 1704 1705// CHECK-LABEL: func @relu6 1706func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1707 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1708 // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1709 // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32> 1710 %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1711 return %0: tensor<1xi32> 1712} 1713 1714// CHECK-LABEL: func @relu6_unranked 1715func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1716 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1717 // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1718 // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32> 1719 %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1720 return %0: tensor<?xi32> 1721} 1722 1723// CHECK-LABEL: func @relu_grad_unranked 1724// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<?x?xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>) 1725func @relu_grad_unranked(%gradients: tensor<?x?xf32>, %features: tensor<?x?xf32>) -> tensor<?x?xf32> { 1726 // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FEATURES]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32> 1727 // CHECK-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FEATURES]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> 1728 // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> 1729 // CHECK: return %[[RESULT]] : tensor<?x?xf32> 1730 %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> 1731 return %2 : tensor<?x?xf32> 1732} 1733 1734// CHECK-LABEL: func @leaky_relu 1735func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { 1736 // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32> 1737 // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4, 3]> : tensor<4xi64>} : (tensor<f32>) -> tensor<1x4x4x3xf32> 1738 // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4x3xf32> 1739 // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[BCASTALPHA]] : tensor<1x4x4x3xf32> 1740 // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> 1741 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[INP]], %[[LEAKY]]) : 
(tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1742 // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> 1743 %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> 1744 return %0 : tensor<1x4x4x3xf32> 1745} 1746 1747// CHECK-LABEL: func @leaky_relu_grad 1748func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { 1749 // CHECK-NEXT: %[[ALPHA:.*]] = mhlo.constant dense<2.000000e-01> : tensor<f32> 1750 // CHECK-NEXT: %[[BCASTALPHA:.*]] = "mhlo.broadcast"(%[[ALPHA]]) {broadcast_sizes = dense<[1, 4, 4]> : tensor<3xi64>} : (tensor<f32>) -> tensor<1x4x4xf32> 1751 // CHECK-NEXT: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<1x4x4xf32> 1752 // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[BCASTALPHA]] : tensor<1x4x4xf32> 1753 // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[INP:.*]], %[[ZERO]]) {comparison_direction = "GT"} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> 1754 // CHECK-NEXT: %[[RES:.*]] = "mhlo.select"(%[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]]) : (tensor<1x4x4xi1>, tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 1755 // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> 1756 %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> 1757 return %0 : tensor<1x4x4xf32> 1758} 1759 1760// CHECK-LABEL: func @softsign 1761func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { 1762 // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<4x10xf32>) -> tensor<4x10xf32> 1763 // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<4x10xf32>) -> tensor<4x10xf32> 1764 // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<4x10xf32> 1765 // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32> 1766 // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32> 1767 %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32> 1768 return %0 : tensor<4x10xf32> 1769} 1770 1771// CHECK-LABEL: func @softsign_unranked 1772func @softsign_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 1773 // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> 1774 // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<*xf32>) -> tensor<*xf32> 1775 // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : tensor<*xf32> 1776 // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<*xf32> 1777 // CHECK-NEXT: return %[[DIV]] : tensor<*xf32> 1778 %0 = "tf.Softsign"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 1779 return %0 : tensor<*xf32> 1780} 1781 1782// CHECK-LABEL: func @softsign_grad 1783func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> { 1784 1785 // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 1786 // CHECK-NEXT: %[[ABS:.*]] = "mhlo.abs"(%{{.*}}) : (tensor<4x10xf32>) -> tensor<4x10xf32> 1787 // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = chlo.broadcast_add %[[ONE]], %[[ABS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4x10xf32>) -> tensor<4x10xf32> 1788 // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] : tensor<4x10xf32> 1789 // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = chlo.broadcast_divide %{{.*}}, %[[MUL]] : (tensor<4x10xf32>, tensor<4x10xf32>) -> 
tensor<4x10xf32> 1790 // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> 1791 %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> 1792 return %0 : tensor<4x10xf32> 1793} 1794 1795//===----------------------------------------------------------------------===// 1796// Roll op legalizations. 1797//===----------------------------------------------------------------------===// 1798 1799// CHECK-LABEL: func @Roll_0D 1800func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor<i32>) -> tensor<512xi32> { 1801 %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>) 1802 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1803 // CHECK: %[[AXIS_SIZE:.*]] = mhlo.constant dense<512> : tensor<i32> 1804 // CHECK: %[[T1:.+]] = mhlo.remainder %arg1, %[[AXIS_SIZE]] : tensor<i32> 1805 // CHECK: %[[T2:.+]] = mhlo.add %[[T1]], %[[AXIS_SIZE]] : tensor<i32> 1806 // CHECK: %[[T3:.+]] = mhlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor<i32> 1807 // CHECK: %[[CONCAT:.+]] = "mhlo.concatenate"(%arg0, %arg0) {dimension = 0 : i64} 1808 // CHECK: %[[OFFSET:.+]] = mhlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor<i32> 1809 // CHECK: "mhlo.dynamic-slice"(%[[CONCAT]], %[[OFFSET]]) 1810 // CHECK-SAME: {slice_sizes = dense<512> : tensor<1xi64>} 1811 // CHECK-SAME: (tensor<1024xi32>, tensor<i32>) -> tensor<512xi32> 1812 %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor<i32>, tensor<i32>) -> tensor<512xi32> 1813 return %0 : tensor<512xi32> 1814} 1815 1816//===----------------------------------------------------------------------===// 1817// Select op legalizations. 1818//===----------------------------------------------------------------------===// 1819 1820// CHECK-LABEL: func @select_batch_static 1821func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 1822 // CHECK: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<2xi1>, tensor<3xindex>) -> tensor<2x6x8xi1> 1823 // CHECK: "mhlo.select"(%[[BCAST]], %arg1, %arg2) 1824 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 1825 return %0: tensor<2x6x8xi32> 1826} 1827 1828// CHECK-LABEL: func @select_batch_static_r1 1829func @select_batch_static_r1(%arg0: tensor<i1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 1830 // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) 1831 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 1832 return %0: tensor<2x6x8xi32> 1833} 1834 1835// CHECK-LABEL: func @select_batch_static_all_same 1836func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { 1837 // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) 1838 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> 1839 return %0: tensor<2x6x8xi32> 1840} 1841 1842// CHECK-LABEL: func @select_batch_dynamic_r1 1843func @select_batch_dynamic_r1(%arg0: tensor<?xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> { 1844 // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?xi1> -> tensor<1xindex> 1845 // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex> 1846 // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : 
tensor<?x?x8xi32> -> tensor<3xindex> 1847 // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> 1848 // CHECK-NEXT: %[[C1:.*]] = constant 1 : index 1849 // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<?xindex>, tensor<?xindex>) 1850 // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<?xindex> 1851 // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] 1852 // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) { 1853 // CHECK-NEXT: %[[SHAPE1E:.*]] = shape.to_extent_tensor %[[SHAPE1]] : tensor<3xindex> -> tensor<3xindex> 1854 // CHECK-NEXT: %[[BCAST:.*]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE1E]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi1>, tensor<3xindex>) -> tensor<?x?x8xi1> 1855 // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%[[BCAST]], %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 1856 // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32> 1857 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 1858 return %0: tensor<?x?x8xi32> 1859} 1860 1861// CHECK-LABEL: func @select_batch_dynamic 1862func @select_batch_dynamic(%arg0: tensor<?x?x8xi1>, %arg1: tensor<?x?x8xi32>, %arg2: tensor<?x?x8xi32>) -> tensor<?x?x8xi32> { 1863 // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?x8xi1> -> tensor<3xindex> 1864 // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor<?x?x8xi32> -> tensor<3xindex> 1865 // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor<?x?x8xi32> -> tensor<3xindex> 1866 // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> 1867 // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex> 1868 // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] 1869 // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor<?x?x8xi32>) { 1870 // CHECK-NEXT: %[[SELECT:.*]] = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 1871 // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor<?x?x8xi32> 1872 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?x?x8xi1>, tensor<?x?x8xi32>, tensor<?x?x8xi32>) -> tensor<?x?x8xi32> 1873 return %0: tensor<?x?x8xi32> 1874} 1875 1876// CHECK-LABEL: testSelectInvalidUnranked 1877func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { 1878 // CHECK-NEXT: tf.Select 1879 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> 1880 return %0: tensor<*xf16> 1881} 1882 1883// CHECK-LABEL: testSelectThenUnranked 1884func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { 1885 // CHECK-NEXT: tf.Select 1886 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> 1887 return %0: tensor<*xf16> 1888} 1889 1890// CHECK-LABEL: testSelectElseUnranked 1891func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { 1892 // CHECK-NEXT: tf.Select 1893 %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, 
tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> 1894 return %0: tensor<*xf16> 1895} 1896 1897// CHECK-LABEL: func @selectv2_dynamic_ranked 1898func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { 1899 // CHECK: chlo.broadcast_select 1900 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> 1901 return %0: tensor<2x?x8xi32> 1902} 1903 1904// CHECK-LABEL: func @selectv2_unranked 1905func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { 1906 // CHECK: chlo.broadcast_select 1907 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> 1908 return %0: tensor<*xi32> 1909} 1910 1911//===----------------------------------------------------------------------===// 1912// Fast Fourier Transform op legalization. 1913//===----------------------------------------------------------------------===// 1914 1915// CHECK-LABEL: func @fft_1D 1916func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1917 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>> 1918 %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1919 return %0 : tensor<8xcomplex<f32>> 1920} 1921 1922// CHECK-LABEL: func @ifft_1D 1923func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1924 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>> 1925 %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1926 return %0 : tensor<8xcomplex<f32>> 1927} 1928 1929// CHECK-LABEL: func @rfft_1D 1930func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> { 1931 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1932 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> 1933 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1934 return %0 : tensor<8xcomplex<f32>> 1935} 1936 1937// CHECK-LABEL: func @rfft_1D_padded 1938func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex<f32>> { 1939 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1940 // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %{{.*}}) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32> 1941 // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> 1942 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1943 return %0 : tensor<8xcomplex<f32>> 1944} 1945 1946// CHECK-LABEL: func @rfft_1D_sliced 1947func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex<f32>> { 1948 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1949 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> 1950 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32> 1951 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> 
tensor<2x8xcomplex<f32>> 1952 return %0 : tensor<2x8xcomplex<f32>> 1953} 1954 1955// CHECK-LABEL: func @irfft_1D 1956func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<5xf32> { 1957 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1958 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>> 1959 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>> 1960 %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<5xf32> 1961 return %0 : tensor<5xf32> 1962} 1963 1964// CHECK-LABEL: fft_1D_dynamic 1965func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1966 // CHECK: "tf.FFT" 1967 %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1968 return %0 : tensor<8xcomplex<f32>> 1969} 1970 1971// CHECK-LABEL: rfft_1D_dynamic 1972func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> { 1973 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1974 // CHECK: "tf.RFFT" 1975 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1976 return %0 : tensor<8xcomplex<f32>> 1977} 1978 1979//===----------------------------------------------------------------------===// 1980// Shape op legalization. 1981//===----------------------------------------------------------------------===// 1982 1983// CHECK-LABEL: func @shape_1D 1984func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> { 1985 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 1986 // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32> 1987 %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> 1988 1989 // CHECK: return [[TENSOR]] 1990 return %0 : tensor<1xi32> 1991} 1992 1993// CHECK-LABEL: func @shape_2D 1994func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> { 1995 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 1996 // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32> 1997 %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32> 1998 1999 // CHECK: return [[TENSOR]] 2000 return %0 : tensor<2xi32> 2001} 2002 2003// CHECK-LABEL: func @shape_rankless 2004func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> { 2005 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 2006 // CHECK: [[TENSOR:%.+]] = index_cast [[SHAPE]] : tensor<?xindex> to tensor<?xi32> 2007 %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32> 2008 2009 // CHECK: return [[TENSOR]] 2010 return %0 : tensor<?xi32> 2011} 2012 2013//===----------------------------------------------------------------------===// 2014// Transpose op legalization. 
2015//===----------------------------------------------------------------------===// 2016 2017// CHECK-LABEL: @transpose_noop 2018func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 2019 %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2020 // CHECK: return %arg0 2021 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> 2022 return %0 : tensor<2x3xf32> 2023} 2024 2025// CHECK-LABEL: @transpose_2d 2026func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { 2027 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2028 // CHECK: "mhlo.transpose" 2029 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> 2030 return %0 : tensor<3x2xf32> 2031} 2032 2033// CHECK-LABEL: @transpose_3d_int32 2034func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 2035 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2036 // CHECK: "mhlo.transpose" 2037 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> 2038 return %0 : tensor<3x2x1xf32> 2039} 2040 2041// CHECK-LABEL: @transpose_3d 2042func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 2043 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) 2044 // CHECK: "mhlo.transpose" 2045 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> 2046 return %0 : tensor<3x2x1xf32> 2047} 2048 2049// CHECK-LABEL: @transpose_dynamic_2d 2050func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> { 2051 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2052 // CHECK: "mhlo.transpose" 2053 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32> 2054 return %0 : tensor<4x?xf32> 2055} 2056 2057// CHECK-LABEL: @transpose_unranked_2d 2058func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2059 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2060 // CHECK: "mhlo.transpose" 2061 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> 2062 return %0 : tensor<*xf32> 2063} 2064 2065 2066//===----------------------------------------------------------------------===// 2067// Unary op legalizations. 
2068//===----------------------------------------------------------------------===// 2069 2070// CHECK-LABEL: @abs 2071func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2072 // CHECK: "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2073 %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2074 return %0 : tensor<2xf32> 2075} 2076 2077// CHECK-LABEL: func @abs_dynamic 2078func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2079 // CHECK: "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2080 %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2081 return %0 : tensor<?xf32> 2082} 2083 2084// CHECK-LABEL: func @abs_unranked 2085func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2086 // CHECK: "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2087 %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2088 return %0 : tensor<*xf32> 2089} 2090 2091// CHECK-LABEL: @acos 2092// CHLO-LABEL: @acos 2093func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2094 // CHECK: chlo.acos %arg0 : tensor<2xf32> 2095// CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} 2096// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> 2097// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> 2098// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 2099// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] 2100// CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) 2101// CHLO: %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00> 2102// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 2103// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] 2104// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] 2105// CHLO: %[[VAL_12:.*]] = mhlo.constant dense<3.14159274> 2106// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) 2107// CHLO: return %[[VAL_13]] : tensor<2xf32> 2108 %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2109 return %0 : tensor<2xf32> 2110} 2111 2112// CHECK-LABEL: @acos_complex 2113// CHLO-LABEL: @acos_complex 2114func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2115 // CHLO: tf.Acos 2116 %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2117 return %0 : tensor<2xcomplex<f32>> 2118} 2119 2120// CHECK-LABEL: @acos_dynamic 2121// CHLO-LABEL: @acos_dynamic 2122func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2123 // CHECK: chlo.acos %arg0 : tensor<*xf32> 2124 // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be 2125 // lowered further on ranked tensors. Unranked CHLO must be transformed to 2126 // ranked code before further lowering. 
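  // As a hedged sketch (an assumption for illustration, not CHECK-verified
  // output of this pass; the names %cst, %shape, %bcast, %x are hypothetical),
  // a ranked `chlo.constant_like` can expand along the lines of the
  // @sign_dynamic checks later in this file:
  //   %cst = mhlo.constant dense<1.000000e+00> : tensor<f32>
  //   %shape = shape.shape_of %x : tensor<?xf32> -> tensor<1xindex>
  //   %bcast = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape)
  //       {broadcast_dimensions = dense<> : tensor<0xi64>}
  //       : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
  // On an unranked operand there is no fixed-rank shape tensor to broadcast
  // into, which is why the op is left untouched here.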
2127 // CHLO: "tf.Acos" 2128 %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2129 return %0 : tensor<*xf32> 2130} 2131 2132// CHECK-LABEL: @tan 2133// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> 2134// CHLO-LABEL: @tan 2135// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> 2136func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { 2137 // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> 2138 // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) 2139 // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) 2140 // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]] 2141 %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> 2142 return %result : tensor<2xf32> 2143} 2144 2145// CHECK-LABEL: @tan_unranked 2146// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> 2147// CHLO-LABEL: @tan_unranked 2148// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> 2149func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { 2150 // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> 2151 // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]]) 2152 // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]]) 2153 // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]] 2154 %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> 2155 return %result : tensor<*xf32> 2156} 2157 2158// CHECK-LABEL: func @cast_dynamic_i2f 2159func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> { 2160 // CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32> 2161 %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32> 2162 return %0 : tensor<?xf32> 2163} 2164 2165// CHECK-LABEL: func @cast_i2f 2166func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { 2167 // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> 2168 %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> 2169 return %0 : tensor<2xf32> 2170} 2171 2172// CHECK-LABEL: func @cast_c2f 2173func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { 2174 // CHECK: tf.Cast 2175 %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2176 return %0 : tensor<2xf32> 2177} 2178 2179// CHECK-LABEL: @ceil 2180func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2181 // CHECK: "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2182 %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2183 return %0 : tensor<2xf32> 2184} 2185 2186// CHECK-LABEL: func @ceil_dynamic 2187func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2188 // CHECK: "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2189 %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2190 return %0 : tensor<?xf32> 2191} 2192 2193// CHECK-LABEL: func @ceil_unranked 2194func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2195 // CHECK: "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2196 %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2197 return %0 : tensor<*xf32> 2198} 2199 2200// CHECK-LABEL: @complex_abs 2201func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { 2202 // CHECK: "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2203 %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> 2204 return %0 : tensor<2xf32> 2205} 2206 2207// CHECK-LABEL: @cos 2208func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2209 // CHECK: "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2210 %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2211 return %0 : tensor<2xf32> 2212} 2213 2214// CHECK-LABEL: func @cos_dynamic 2215func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 
2216 // CHECK: "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2217 %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2218 return %0 : tensor<?xf32> 2219} 2220 2221// CHECK-LABEL: func @cos_unranked 2222func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2223 // CHECK: "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2224 %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2225 return %0 : tensor<*xf32> 2226} 2227 2228// CHECK-LABEL: @exp 2229func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2230 // CHECK: "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2231 %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2232 return %0 : tensor<2xf32> 2233} 2234 2235// CHECK-LABEL: @expm1 2236func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2237 // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2238 %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2239 return %0 : tensor<2xf32> 2240} 2241 2242// CHECK-LABEL: func @exp_dynamic 2243func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2244 // CHECK: "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2245 %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2246 return %0 : tensor<?xf32> 2247} 2248 2249// CHECK-LABEL: func @exp_unranked 2250func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2251 // CHECK: "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2252 %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2253 return %0 : tensor<*xf32> 2254} 2255 2256// CHECK-LABEL: @floor 2257func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2258 // CHECK: "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2259 %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2260 return %0 : tensor<2xf32> 2261} 2262 2263// CHECK-LABEL: func @floor_dynamic 2264func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2265 // CHECK: "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2266 %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2267 return %0 : tensor<?xf32> 2268} 2269 2270// CHECK-LABEL: func @floor_unranked 2271func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2272 // CHECK: "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2273 %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2274 return %0 : tensor<*xf32> 2275} 2276 2277// CHECK-LABEL: func @invert_op_unranked 2278func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { 2279 // CHECK: "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> 2280 %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> 2281 return %0 : tensor<*xi32> 2282} 2283 2284// CHECK-LABEL: @is_finite 2285func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { 2286 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2287 %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2288 return %0 : tensor<2xi1> 2289} 2290 2291// CHECK-LABEL: func @is_finite_dynamic 2292func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> { 2293 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2294 %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2295 return %0 : tensor<?xi1> 2296} 2297 2298// CHECK-LABEL: func @is_finite_unranked 2299func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { 2300 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2301 %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2302 return %0 : tensor<*xi1> 2303} 2304 2305// CHECK-LABEL: @log 2306func 
@log(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2307 // CHECK: "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2308 %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2309 return %0 : tensor<2xf32> 2310} 2311 2312// CHECK-LABEL: func @log_dynamic 2313func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2314 // CHECK: "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2315 %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2316 return %0 : tensor<?xf32> 2317} 2318 2319// CHECK-LABEL: func @log_unranked 2320func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2321 // CHECK: "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2322 %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2323 return %0 : tensor<*xf32> 2324} 2325 2326// CHECK-LABEL: @log1p 2327func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2328 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2329 %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2330 return %0 : tensor<2xf32> 2331} 2332 2333// CHECK-LABEL: func @log1p_dynamic 2334func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2335 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2336 %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2337 return %0 : tensor<?xf32> 2338} 2339 2340// CHECK-LABEL: func @log1p_unranked 2341func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2342 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2343 %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2344 return %0 : tensor<*xf32> 2345} 2346 2347// CHECK-LABEL: func @not_op_unranked 2348func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { 2349 // CHECK: "mhlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> 2350 %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> 2351 return %0 : tensor<*xi1> 2352} 2353 2354// CHECK-LABEL: @neg 2355func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2356 // CHECK: "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2357 %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2358 return %0 : tensor<2xf32> 2359} 2360 2361// CHECK-LABEL: func @neg_dynamic 2362func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2363 // CHECK: "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2364 %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2365 return %0 : tensor<?xf32> 2366} 2367 2368// CHECK-LABEL: func @neg_unranked 2369func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2370 // CHECK: "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2371 %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2372 return %0 : tensor<*xf32> 2373} 2374 2375// CHECK-LABEL: @sigmoid 2376func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2377 // CHECK: mhlo.logistic 2378 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2379 return %0 : tensor<2xf32> 2380} 2381 2382// CHECK-LABEL: @sigmoid_complex 2383func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2384 // CHECK: mhlo.logistic 2385 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2386 return %0 : tensor<2xcomplex<f32>> 2387} 2388 2389// CHECK-LABEL: @sigmoid_unranked 2390func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2391 // CHECK: mhlo.logistic 2392 %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2393 return %0 : tensor<*xf32> 2394} 2395 2396 2397// CHECK-LABEL: @sigmoid_grad 2398func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> 
tensor<2xf32> { 2399 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32> 2400 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> 2401 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32> 2402 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> 2403 // CHECK: return [[MUL1]] 2404 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 2405 return %0 : tensor<2xf32> 2406} 2407 2408// CHECK-LABEL: @sigmoid_grad_complex 2409func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2410 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>> 2411 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>> 2412 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>> 2413 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>> 2414 // CHECK: return [[MUL1]] 2415 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2416 return %0 : tensor<2xcomplex<f32>> 2417} 2418 2419// CHECK-LABEL: @sigmoid_grad_dynamic 2420func @sigmoid_grad_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> { 2421 // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2422 // CHECK: chlo.broadcast_subtract {{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> 2423 // CHECK: chlo.broadcast_multiply {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2424 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> 2425 return %0 : tensor<?xf32> 2426} 2427 2428// CHECK-LABEL: @sin 2429func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2430 // CHECK: "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2431 %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2432 return %0 : tensor<2xf32> 2433} 2434 2435// CHECK-LABEL: func @sin_dynamic 2436func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2437 // CHECK: "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2438 %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2439 return %0 : tensor<?xf32> 2440} 2441 2442// CHECK-LABEL: func @sin_unranked 2443func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2444 // CHECK: "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2445 %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2446 return %0 : tensor<*xf32> 2447} 2448 2449// CHECK-LABEL: func @rsqrt 2450func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2451 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2452 %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2453 return %0 : tensor<2xf32> 2454} 2455 2456// CHECK-LABEL: func @rsqrt_dynamic 2457func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2458 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2459 %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2460 return %0 : tensor<?xf32> 2461} 2462 2463// CHECK-LABEL: func @rsqrt_unranked 2464func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2465 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2466 %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2467 return %0 : tensor<*xf32> 2468} 2469 2470// CHECK-LABEL: func @sqrt 2471func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 
2472 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2473 %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2474 return %0 : tensor<2xf32> 2475} 2476 2477// CHECK-LABEL: func @sqrt_dynamic 2478func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2479 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2480 %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2481 return %0 : tensor<?xf32> 2482} 2483 2484// CHECK-LABEL: func @sqrt_unranked 2485func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2486 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2487 %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2488 return %0 : tensor<*xf32> 2489} 2490 2491// CHECK-LABEL: func @tanh 2492func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2493 // CHECK: "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2494 %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2495 return %0 : tensor<2xf32> 2496} 2497 2498// CHECK-LABEL: func @tanh_dynamic 2499func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2500 // CHECK: "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2501 %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2502 return %0 : tensor<?xf32> 2503} 2504 2505// CHECK-LABEL: func @tanh_unranked 2506func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2507 // CHECK: "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2508 %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2509 return %0 : tensor<*xf32> 2510} 2511 2512// CHECK-LABEL: func @bitcast 2513func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2514 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2515 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2516 return %0 : tensor<2xf32> 2517} 2518 2519// CHECK-LABEL: func @bitcast_dynamic 2520func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2521 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2522 %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2523 return %0 : tensor<?xf32> 2524} 2525 2526// CHECK-LABEL: func @bitcast_unranked 2527func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2528 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2529 %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2530 return %0 : tensor<*xf32> 2531} 2532 2533// CHECK-LABEL: func @bitcast_same_widths 2534func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { 2535 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> 2536 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> 2537 return %0 : tensor<2xi32> 2538} 2539 2540// CHECK-LABEL: func @bitcast_smaller_input_width 2541func @bitcast_smaller_input_width(%arg0: tensor<2xi8>) -> tensor<2xi64> { 2542 // CHECK: "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64> 2543 %0 = "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64> 2544 return %0 : tensor<2xi64> 2545} 2546 2547// CHECK-LABEL: func @bitcast_smaller_output_width 2548func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2xf16> { 2549 // CHECK: "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16> 2550 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16> 2551 return %0 : tensor<2xf16> 2552} 2553 2554// CHECK-LABEL: reshape 2555func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { 2556 // CHECK: "mhlo.reshape" 2557 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> 
tensor<2x1xf32> 2558 return %0 : tensor<2x1xf32> 2559} 2560 2561// CHECK-LABEL: reshape_dynamic 2562func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 2563 // CHECK: "chlo.dynamic_reshape" 2564 // CHLO: mhlo.compute_reshape_shape 2565 // CHLO: mhlo.dynamic_reshape 2566 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32> 2567 return %0 : tensor<?x?xf32> 2568} 2569 2570// CHECK-LABEL: reshape_unranked 2571// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> 2572// CHECK-SAME: %[[TARGET_SHAPE:.*]]: tensor<2xi32> 2573func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 2574 // CHECK: "chlo.dynamic_reshape" 2575 // CHLO: shape.shape_of 2576 // CHLO: shape.num_elements 2577 // CHLO: mhlo.cstr_reshapable 2578 // CHLO: assuming{{.*}}{ 2579 // CHLO: mhlo.compute_reshape_shape 2580 // CHLO: mhlo.dynamic_reshape 2581 // CHLO: } 2582 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32> 2583 return %0 : tensor<?x?xf32> 2584} 2585 2586// CHECK-LABEL: squeeze 2587func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { 2588 // CHECK: "mhlo.reshape" 2589 %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> 2590 return %0 : tensor<1x10xf32> 2591} 2592 2593// CHECK-LABEL: squeeze_dynamic 2594func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> { 2595 // CHECK: "tf.Squeeze" 2596 %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> 2597 return %0 : tensor<*xf32> 2598} 2599 2600// CHECK-LABEL: expand_dims 2601func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> { 2602 // CHECK: "mhlo.reshape" 2603 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32> 2604 return %0 : tensor<1x2xf32> 2605} 2606 2607// CHECK-LABEL: expand_dims_dynamic 2608func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> { 2609 %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>) 2610 2611 // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 2612 // CHECK-DAG: %[[CST0:.+]] = constant 0 2613 // CHECK-DAG: %[[CST1:.+]] = constant 1 2614 // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] 2615 // CHECK-DAG: %[[CST1_0:.+]] = constant 1 2616 // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] 2617 // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]] 2618 // CHECK-DAG: %[[RESHAPE:.+]] = "mhlo.dynamic_reshape"(%arg0, %[[TOEXTENTS]]) 2619 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32> 2620 2621 // CHECK: return %[[RESHAPE]] 2622 return %0 : tensor<?x1x?xf32> 2623} 2624 2625// CHECK-LABEL: expand_dynamic_dims_rank1_axis 2626func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> { 2627 %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 2628 2629 // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 2630 // CHECK-DAG: %[[CST0:.+]] = constant 0 2631 // CHECK-DAG: %[[CST1:.+]] = constant 1 2632 // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] 2633 // CHECK-DAG: %[[CST1_0:.+]] = constant 1 2634 // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] 2635 // CHECK-DAG: %[[CST2:.+]] = constant 2 2636 // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]] 2637 // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], 
%[[GETEXTENT1]], %[[GETEXTENT2]] 2638 // CHECK-DAG: %[[RESHAPE:.+]] = "mhlo.dynamic_reshape"(%arg0, %[[TOEXTENTS]]) 2639 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32> 2640 2641 // CHECK: return %[[RESHAPE]] 2642 return %0 : tensor<?x1x?x4xf32> 2643} 2644 2645// CHECK-LABEL: func @sign 2646// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> 2647func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { 2648 // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]]) 2649 // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> 2650 %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) 2651 return %0 : tensor<1x2x3x4xf32> 2652} 2653 2654// CHECK-LABEL: func @sign_dynamic 2655func @sign_dynamic(%arg0: tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> { 2656 // CHECK: "mhlo.sign"(%arg0) : (tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> 2657 // CHECK: "mhlo.compare"({{.*}}) {comparison_direction = "NE"} : (tensor<?x2x3x?xf32>, tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xi1> 2658 // CHECK: shape.shape_of %arg0 : tensor<?x2x3x?xf32> -> tensor<4xindex> 2659 // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<4xindex>) -> tensor<?x2x3x?xf32> 2660 // CHECK: "mhlo.select"({{.*}}) : (tensor<?x2x3x?xi1>, tensor<?x2x3x?xf32>, tensor<?x2x3x?xf32>) -> tensor<?x2x3x?xf32> 2661 // CHECK: return {{.*}} : tensor<?x2x3x?xf32> 2662 %0 = "tf.Sign"(%arg0) : (tensor<?x2x3x?xf32>) -> (tensor<?x2x3x?xf32>) 2663 return %0 : tensor<?x2x3x?xf32> 2664} 2665 2666// CHECK-LABEL: slice_constant_start 2667func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { 2668 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> 2669 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> 2670 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2671 // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2672 // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, 2673 // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : 2674 // CHECK-SAME: (tensor<1xi64>) -> tensor<1xi64> 2675 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : 2676 // CHECK-SAME: (tensor<1xi64>) -> tensor<i64> 2677 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) 2678 // CHECK-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : 2679 // CHECK-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 2680 // CHECK: return %[[RESULT]] : tensor<2xi32> 2681 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2682 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2683 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> 2684 return %0 : tensor<2xi32> 2685} 2686 2687// CHECK-LABEL: slice_i32_consts 2688func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { 2689 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi32> 2690 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi32>) -> tensor<1xi64> 2691 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2692 // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2693 // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, 2694 // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> 2695 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> 
tensor<i64> 2696 // CHECK: "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 2697 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2698 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2699 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> 2700 return %0 : tensor<2xi32> 2701} 2702 2703// CHECK-LABEL: slice_constant_start_negative_one_size 2704func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { 2705 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> 2706 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<1xi64>) -> tensor<1xi64> 2707 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2708 // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2709 // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, 2710 // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> 2711 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64> 2712 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32> 2713 // CHECK: return %[[RESULT]] : tensor<3xi32> 2714 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2715 %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2716 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> 2717 return %0 : tensor<3xi32> 2718} 2719 2720// CHECK-LABEL: slice_constant_start_dynamic_shape 2721func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2722 // CHECK: %[[START:.*]] = mhlo.constant dense<[1, 0]> : tensor<2xi64> 2723 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START]]) : (tensor<2xi64>) -> tensor<2xi64> 2724 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) 2725 // CHECK-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2726 // CHECK-SAME: start_indices = dense<0> : tensor<1xi64>, 2727 // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : 2728 // CHECK-SAME: (tensor<2xi64>) -> tensor<1xi64> 2729 // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : 2730 // CHECK-SAME: (tensor<1xi64>) -> tensor<i64> 2731 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) 2732 // CHECK-SAME: {limit_indices = dense<2> : tensor<1xi64>, 2733 // CHECK-SAME: start_indices = dense<1> : tensor<1xi64>, 2734 // CHECK-SAME: strides = dense<1> : tensor<1xi64>} : 2735 // CHECK-SAME: (tensor<2xi64>) -> tensor<1xi64> 2736 // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : 2737 // CHECK-SAME: (tensor<1xi64>) -> tensor<i64> 2738 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice" 2739 // CHECK-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) 2740 // CHECK-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : 2741 // CHECK-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 2742 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2743 %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2744 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2745 %0 = 
"tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2746 return %0 : tensor<1x4xi32> 2747} 2748 2749// CHECK-LABEL: slice_variable_start 2750func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2751 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> 2752 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) 2753 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2754 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2755 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 2756 // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64> 2757 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 2758 // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64> 2759 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 2760 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2761 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2762 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2763 return %0 : tensor<1x4xi32> 2764} 2765 2766// CHECK-LABEL: slice_mhlo_sizes 2767func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { 2768 // CHECK-NOT: "tf.Slice" 2769 %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> 2770 %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> 2771 return %1 : tensor<1x512x4xf32> 2772} 2773 2774// CHECK-LABEL: slice_variable_start_negative_one_size 2775func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2776 // CHECK: %[[RESULT:.*]] = "tf.Slice" 2777 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2778 %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2779 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2780 return %0 : tensor<1x4xi32> 2781} 2782 2783// CHECK-LABEL: slice_real_dynamic_slice 2784func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<*xi32> { 2785 // CHECK: tensor.extract {{.*}} : tensor<1xi64> 2786 // CHECK: tensor.extract {{.*}} : tensor<1xi64> 2787 // CHECK: index_cast {{.*}} : index to i64 2788 // CHECK: cmpi eq, {{.*}} : i64 2789 // CHECK: addi {{.*}} : i64 2790 // CHECK: tensor.dim {{.*}} : tensor<4xi32> 2791 // CHECK: index_cast {{.*}} : index to i64 2792 // CHECK: select {{.*}} : i64 2793 // CHECK: index_cast {{.*}} : i64 to index 2794 // CHECK: index_cast {{.*}} : i64 to index 2795 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 2796 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 2797 // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> 2798 %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xi32> 2799 return %0 : tensor<*xi32> 2800} 2801 
2802//===----------------------------------------------------------------------===// 2803// StridedSlice op legalizations. 2804//===----------------------------------------------------------------------===// 2805 2806// CHECK-LABEL: simple_strided_slice 2807func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 2808 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2809 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2810 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2811 2812 // CHECK: mhlo.slice 2813 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 2814 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 2815 // CHECK-DAG-SAME: strides = dense<[1, 3]> 2816 // CHECK-SAME: -> tensor<3x2xf32> 2817 2818 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2819 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 2820 return %output : tensor<3x2xf32> 2821} 2822 2823// CHECK-LABEL: dynamic_strided_slice 2824func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 2825 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2826 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2827 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2828 2829 // CHECK: "tf.StridedSlice" 2830 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2831 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 2832 return %output : tensor<?x2xf32> 2833} 2834 2835// CHECK-LABEL: strided_slice_negative_indices 2836func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 2837 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2838 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2839 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2840 2841 // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} 2842 2843 // CHECK: mhlo.slice 2844 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 2845 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 2846 // CHECK-DAG-SAME: strides = dense<[1, 3]> 2847 // CHECK-SAME: -> tensor<3x2xf32> 2848 2849 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2850 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 2851 return %output : tensor<3x2xf32> 2852} 2853 2854// CHECK-LABEL: dynamic_strided_slice_negative_indices 2855func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 2856 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2857 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2858 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2859 2860 // CHECK: tf.StridedSlice 2861 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2862 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 2863 return %output : tensor<?x2xf32> 2864} 2865 2866// CHECK-LABEL: strided_slice_range_clamping 2867func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { 2868 %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2869 %end = "tf.Const"() {value = dense<[1, 
10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2870 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2871
2872 // CHECK: mhlo.slice
2873 // CHECK-DAG-SAME: start_indices = dense<[0, 0]>
2874 // CHECK-DAG-SAME: limit_indices = dense<[1, 8]>
2875 // CHECK-DAG-SAME: strides = dense<[1, 3]>
2876 // CHECK-SAME: -> tensor<1x3xf32>
2877 %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2878 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
2879 return %output : tensor<1x3xf32>
2880}
2881
2882// CHECK-LABEL: strided_slice_empty
2883func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> {
2884 %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2885 %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2886 %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2887
2888 // CHECK: mhlo.constant dense<> : tensor<0xf32>
2889 %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2890 : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32>
2891 return %output : tensor<0xf32>
2892}
2893
2894// CHECK-LABEL: strided_slice_begin_end_mask
2895// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32>
2896func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {
2897
2898 // For StridedSlice
2899 // Dim #: 0, 1, 2
2900 // Input shape: [4, 128, 1024]
2901 // Begin: 1, 4, -3
2902 // End: 8, 65, 42
2903 // Stride: 1, 4, -1
2904 // Begin mask: 1, 0, 0 (= 1)
2905 // End mask: 0, 0, 1 (= 4)
2906
2907 // So result shape:
2908 // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
2909 // Dim #1: 4 to 65 stride 4: so 16
2910 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
2911 // result shape: [4, 16, 1022]
2912
2913 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2914 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2915 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2916
2917 // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]])
2918
2919 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]])
2920 // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]>
2921 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]>
2922 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
2923 // CHECK-SAME: -> tensor<4x16x1022xf32>
2924
2925 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32>
2926
2927 // CHECK: "mhlo.reshape"(%[[SLICE]])
2928 // CHECK-SAME: -> tensor<4x16x1022xf32>
2929
2930 return
2931}
2932
2933// CHECK-LABEL: strided_slice_shrink_axis_mask
2934// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32>
2935func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) {
2936
2937 // For StridedSlice
2938 // Dim #: 0, 1, 2
2939 // Input shape: [4, 128, 1024]
2940 // Begin: 1, 4, -3
2941 // End: 8, 65, 42
2942 // Stride: 1, 4, -1
2943 // Begin mask: 1, 0, 0 (= 1)
2944 // End mask: 0, 0, 1 (= 4)
2945 // Shrink axis mask: 1, 0, 1 (= 5)
2946
2947 // So result shape:
2948 // Dim #0: shrink axis, take value at [1]
2949 // Dim #1: 4 to 65 stride 4: so 16
2950 // Dim #2: shrink axis, take value at [-3]
2951 // result shape: [16]
2952
2953 // As the output shape of the StridedSlice differs, a reshape will follow.
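 // For reference, the dim #1 count follows the usual strided-slice formula:
 // ceil((end - begin) / stride) = ceil((65 - 4) / 4) = ceil(15.25) = 16.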
2954
2955 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2956 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2957 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2958
2959 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
2960 // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]>
2961 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]>
2962 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
2963 // CHECK-SAME: -> tensor<1x16x1xf32>
2964
2965 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32>
2966
2967 // CHECK: "mhlo.reshape"(%[[SLICE]])
2968 // CHECK-SAME: -> tensor<16xf32>
2969
2970 return
2971}
2972
2973// CHECK-LABEL: strided_slice_ellipsis_mask
2974// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
2975func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
2976 // For StridedSlice input[1, ..., 8:, :10, 2:6:2]
2977 // The ellipsis mask is applied to dim #1, #2, i.e., we get the canonicalized
2978 // slice input[1, :, :, 8:, :10, 2:6:2]
2979
2980 // The start, limit indices and strides attributes of mhlo.slice would
2981 // reflect the canonicalized slice.
2982 // As the output shape of the StridedSlice differs, a reshape will follow.
2983
2984 %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2985 %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2986 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2987
2988 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
2989 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
2990 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
2991 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
2992 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
2993 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32>
2994
2995 // CHECK: "mhlo.reshape"(%[[SLICE]])
2996 // CHECK-SAME: -> tensor<4x8x8x10x2xf32>
2997
2998 return
2999}
3000
3001// CHECK-LABEL: strided_slice_new_axis_mask
3002// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
3003func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
3004 // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
3005 // The new axis mask has bits set at indices 1 and 6 of the sparse spec, so
3006 // new_axis_mask = 2^1 + 2^6 = 66
3007 // The ellipsis mask is applied to dim #1, #2 of the input, i.e., we get the
3008 // canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
3009 // This is then reshaped to add the new axes.
3010
3011 // The start, limit indices and strides attributes of mhlo.slice would
3012 // reflect the canonicalized slice.
3013 // As the output shape of the StridedSlice differs, a reshape will follow to
3014 // reflect the new axes added.
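 // The other mask attributes encode sparse-spec positions as bits in the same
 // way: shrink_axis_mask = 2^0 = 1 (the leading "1"), ellipsis_mask = 2^2 = 4
 // (the "..."), end_mask = 2^3 = 8 ("8:"), and begin_mask = 2^4 = 16 (":10").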
3015
3016 %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3017 %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3018 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)
3019
3020 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
3021 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
3022 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
3023 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
3024 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
3025 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32>
3026
3027 // CHECK: "mhlo.reshape"(%[[SLICE]])
3028 // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32>
3029
3030 return
3031}
3032
3033// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
3034// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
3035func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
3036 // StridedSlice gets input[8:10], which is the same as input[8:10, ...]
3037 // The start_indices, limit_indices, and strides attributes of mhlo.slice
3038 // reflect the canonicalized slice.
3039 %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
3040 %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
3041 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3042 // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]])
3043 // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
3044 // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
3045 // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
3046 // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
3047 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
3048 // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
3049 return %0 : tensor<2x16x2xf32>
3050}
3051
3052// CHECK-LABEL: strided_slice_nonconstant_begin_end
3053func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
3054 // In this case, the `begin` and `end` inputs are unknown at compile time --
3055 // so the StridedSlice needs to slice these vectors and use that as input to
3056 // an HLO dynamic slice.
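 // The lowering below also wraps a possibly negative runtime index: the index
 // is compared against zero and, where negative, the dimension size (32) is
 // added to it before it feeds mhlo.dynamic-slice.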
3057 %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 3058 %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 3059 %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3060 %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32> 3061 %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 3062 // CHECK: %[[A:.*]] = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32> 3063 // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) 3064 // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> 3065 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 3066 // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) 3067 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 3068 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 3069 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> 3070 // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32> 3071 // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] 3072 // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> 3073 // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32> 3074 // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32> 3075 // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : 3076 // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3077 // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice" 3078 // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) 3079 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : 3080 // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32> 3081 // CHECK-NEXT: %[[FINAL:.*]] = "mhlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> 3082 %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3083 // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> 3084 return %result : tensor<1x97xi32> 3085} 3086 3087// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 3088func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { 3089 // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, 3090 // `strides` must be known. 
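 // Because the stride is a runtime value here, no lowering is attempted; the
 // CHECK below verifies that the op is left as tf.StridedSlice.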
3091 // CHECK: tf.StridedSlice 3092 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3093 return %result : tensor<1x97xi32> 3094} 3095 3096// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 3097func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3098 // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown 3099 // at compile time, `strides` must be known to have all 1 values. 3100 %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> 3101 // CHECK: tf.StridedSlice 3102 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3103 return %result : tensor<1x97xi32> 3104} 3105 3106// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count 3107func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { 3108 %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> 3109 // When begin/end are dynamic, the number of output elements must be equal to 3110 // the number of input elements sliced. 3111 // CHECK: tf.StridedSlice 3112 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> 3113 return %0 : tensor<6x10xf32> 3114} 3115 3116// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask 3117func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3118 // Begin mask: When `begin` and `end` inputs are unknown at compile time, we 3119 // can't support a begin mask. 3120 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3121 // CHECK: tf.StridedSlice 3122 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3123 return %result : tensor<1x97xi32> 3124} 3125 3126// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask 3127func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3128 // End mask: When `begin` and `end` inputs are unknown at compile time, we 3129 // can't support an end mask. 
3130 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3131 // CHECK: tf.StridedSlice
3132 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3133 return %result : tensor<1x97xi32>
3134}
3135
3136// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
3137func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3138 // New axis mask: When `begin` and `end` inputs are unknown at compile time,
3139 // we can't support a new_axis mask.
3140 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3141 // CHECK: tf.StridedSlice
3142 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3143 return %result : tensor<1x97xi32>
3144}
3145
3146// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
3147func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3148 // This ellipsis mask is not supported because it does not refer to the last
3149 // dimension.
3150 // [0, 1, 0] = 2
3151 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3152 // CHECK: tf.StridedSlice
3153 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3154 return %result : tensor<1x97xi32>
3155}
3156
3157// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
3158func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3159 // This ellipsis mask is supported because it refers to the last dimension.
3160 // [0, 0, 1] = 4
3161 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3162 // CHECK: mhlo.dynamic-slice
3163 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3164 return %result : tensor<1x97xi32>
3165}
3166
3167// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
3168func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3169 // This shrink_axis mask is supported because it refers to a major dimension.
3170 // [1, 1, 1] = 7
3171 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3172 // CHECK: mhlo.dynamic-slice
3173 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3174 return %result : tensor<1x97xi32>
3175}
3176
3177// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
3178func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3179 // This shrink_axis mask is unsupported because it does not refer to a major
3180 // dimension.
3181 // [0, 1, 0] = 2
3182 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3183 // CHECK: tf.StridedSlice
3184 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3185 return %result : tensor<1x97xi32>
3186}
3187
3188
3189//===----------------------------------------------------------------------===//
3190// Reduction op legalizations.
3191//===----------------------------------------------------------------------===//
3192
3193// CHECK-LABEL: func @mean
3194func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3195 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3196 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3197 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3198 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3199 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
3200 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3201 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3202 // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
3203 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
3204 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3205 // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3206 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3207 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3208 return %0 : tensor<4x1xf16>
3209}
3210
3211// CHECK-LABEL: func @mean_scalar_dim
3212func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3213 // Verify that a tf.Mean op with a scalar dimension attribute is lowered successfully.
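 // The expansion is the same sum-then-divide pattern checked in @mean above;
 // only the rank of the reduction dimensions constant differs (tensor<i64>
 // here instead of tensor<1xi64>).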
3214 3215 // CHECK-NOT: tf.Mean 3216 %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64> 3217 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16> 3218 return %0 : tensor<4x1xf16> 3219} 3220 3221// CHECK-LABEL: func @mean_dynamic 3222func @mean_dynamic(%arg0: tensor<?x?xf16>) -> tensor<?x1xf16> { 3223 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<?x?xf16>) -> tensor<?x?xf32> 3224 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3225 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3226 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3227 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3228 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3229 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> 3230 // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?xf16> -> tensor<2xindex> 3231 // CHECK: %[[C1_1:.*]] = constant 1 : index 3232 // CHECK: %[[C1_2:.*]] = constant 1 : index 3233 // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex> 3234 // CHECK: %[[MUL:.*]] = muli %[[C1_1]], %[[REDUCED_DIM]] : index 3235 // CHECK: %[[INDEX_CAST:.*]] = index_cast %[[MUL]] : index to i64 3236 // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor<1xi64> 3237 // CHECK: %[[SCALAR_TENSOR:.*]] = "mhlo.reshape"(%[[TENSOR]]) : (tensor<1xi64>) -> tensor<i64> 3238 // CHECK: %[[CONVERT:.*]] = "mhlo.convert"(%[[SCALAR_TENSOR]]) : (tensor<i64>) -> tensor<f32> 3239 // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32> 3240 // CHECK: %[[MEAN_CONVERTED:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<?xf32>) -> tensor<?xf16> 3241 // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor<?xf16> -> tensor<1xindex> 3242 // CHECK: %[[C1:.*]] = constant 1 : index 3243 // CHECK: %[[C0:.*]] = constant 0 : index 3244 // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex> 3245 // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex> 3246 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[MEAN_CONVERTED]], %[[RESULT_SHAPE]]) : (tensor<?xf16>, tensor<2xindex>) -> tensor<?x1xf16> 3247 // CHECK: return %[[RESULT]] : tensor<?x1xf16> 3248 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3249 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<?x?xf16>, tensor<1xi64>) -> tensor<?x1xf16> 3250 return %0 : tensor<?x1xf16> 3251} 3252 3253// CHECK-LABEL: func @sum 3254func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3255 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3256 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3257 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3258 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3259 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3260 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3261 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3262 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : 
(tensor<4xf32>) -> tensor<4xf16> 3263 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3264 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3265 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3266 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3267 return %0 : tensor<4x1xf16> 3268} 3269 3270// CHECK-LABEL: func @sum_dynamic 3271func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3272 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32> 3273 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3274 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3275 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3276 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3277 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3278 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32> 3279 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3280 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3281 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3282 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3283 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3284 return %0 : tensor<4x1xf16> 3285} 3286 3287// CHECK-LABEL: func @max 3288func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3289 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> 3290 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3291 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3292 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3293 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16> 3294 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3295 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3296 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3297 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3298 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3299 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3300 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3301 return %0 : tensor<4x1xf16> 3302} 3303 3304// CHECK-LABEL: func @max_qint 3305// Regression test to ensure we don't crash getting the initial value for 3306// tf.Max when using quantized integer types. 
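// The body below intentionally has no CHECK lines beyond the label: the test
// passes as long as legalization completes without crashing.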
3307func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { 3308 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3309 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> 3310 return %0 : tensor<4x1x!tf_type.qint8> 3311} 3312 3313// CHECK-LABEL: func @max_dynamic 3314func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3315 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16> 3316 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3317 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3318 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3319 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16> 3320 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3321 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16> 3322 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3323 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3324 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3325 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3326 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3327 return %0 : tensor<4x1xf16> 3328} 3329 3330// CHECK-LABEL: func @min 3331func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3332 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> 3333 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16> 3334 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3335 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3336 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.minimum %[[ARGA]], %[[ARGB]] : tensor<f16> 3337 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3338 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3339 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3340 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3341 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3342 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3343 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3344 return %0 : tensor<4x1xf16> 3345} 3346 3347// CHECK-LABEL: func @min_qint 3348// Regression test to ensure we don't crash getting the initial value for 3349// tf.Min when using quantized integer types. 
3350func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
3351 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3352 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
3353 return %0 : tensor<4x1x!tf_type.qint8>
3354}
3355
3356// CHECK-LABEL: func @prod
3357func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3358 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3359 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
3360 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3361 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3362 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.multiply %[[ARGA]], %[[ARGB]] : tensor<f32>
3363 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3364 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3365 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3366 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3367 // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3368 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3369 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3370 return %0 : tensor<4x1xf16>
3371}
3372
3373// CHECK-LABEL: func @prod_qint
3374// Regression test to ensure we don't crash getting the initial value for
3375// tf.Prod when using quantized integer types.
3376func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> {
3377 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3378 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8>
3379 return %0 : tensor<4x1x!tf_type.qint8>
3380}
3381
3382// CHECK-LABEL: @all
3383func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
3384 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3385 // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1>
3386 // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( {
3387 // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
3388 // CHECK: %[[AND:.*]] = mhlo.and %[[ARGA]], %[[ARGB]] : tensor<i1>
3389 // CHECK: "mhlo.return"(%[[AND]]) : (tensor<i1>) -> ()
3390 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
3391 %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
3392 return %0 : tensor<4xi1>
3393}
3394
3395// CHECK-LABEL: @all_keep_dim
3396func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
3397 // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
3398 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3399 %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3400 return %0 : tensor<4x1xi1>
3401}
3402
3403// CHECK-LABEL: @all_dynamic
3404func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
3405 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3406 // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
3407 // CHECK: "mhlo.reduce"(%[[ARG]]
3408 %0 =
"tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3409 return %0 : tensor<4x1xi1> 3410} 3411 3412// CHECK-LABEL: @any 3413func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { 3414 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3415 // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1> 3416 // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( { 3417 // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>): 3418 // CHECK: %[[AND:.*]] = mhlo.or %[[ARGA]], %[[ARGB]] : tensor<i1> 3419 // CHECK: "mhlo.return"(%[[AND]]) : (tensor<i1>) -> () 3420 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1> 3421 %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> 3422 return %0 : tensor<4xi1> 3423} 3424 3425// CHECK-LABEL: @any_keep_dim 3426func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { 3427 // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> 3428 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3429 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3430 return %0 : tensor<4x1xi1> 3431} 3432 3433// CHECk-LABEL: @any_dynamic 3434func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { 3435 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3436 // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> 3437 // CHECK: "mhlo.reduce"(%[[ARG]] 3438 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3439 return %0 : tensor<4x1xi1> 3440} 3441 3442//===----------------------------------------------------------------------===// 3443// Tile op legalizations. 
3444//===----------------------------------------------------------------------===// 3445 3446// CHECK-LABEL: func @tile_by_reshape 3447func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { 3448 // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> 3449 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> 3450 // CHECK: return %[[RESULT]] : tensor<28x24xf32> 3451 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 3452 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> 3453 return %0 : tensor<28x24xf32> 3454} 3455 3456// CHECK-LABEL: func @tile_just_broadcast 3457func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { 3458 // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> 3459 // CHECK: return %[[RESULT]] : tensor<7x3xf32> 3460 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 3461 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> 3462 return %0 : tensor<7x3xf32> 3463} 3464 3465// CHECK-LABEL: func @tile_dynamic_shape 3466func @tile_dynamic_shape(%arg0: tensor<?x8xf32>) -> tensor<?x24xf32> { 3467 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> 3468 // CHECK: tensor.dim {{.*}} : tensor<?x8xf32> 3469 // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> 3470 // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<?x8xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> 3471 // CHECK: muli {{.*}} : index 3472 // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> 3473 // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x?x?xf32>, tensor<2xindex>) -> tensor<?x24xf32> 3474 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<?x8xf32>, tensor<2xi32>) -> tensor<?x24xf32> 3475 return %0 : tensor<?x24xf32> 3476} 3477 3478//===----------------------------------------------------------------------===// 3479// ArgMax/ArgMin op legalizations. 
3480//===----------------------------------------------------------------------===// 3481 3482// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 3483func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { 3484 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64> 3485 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3486 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32> 3487 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3488 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>): 3489 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GE"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 3490 // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 3491 // CHECK: %[[COMPARE_EQ:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 3492 // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] 3493 // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3494 // CHECK: %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3495 // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT3]]) : (tensor<i64>, tensor<i32>) -> () 3496 // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> 3497 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 3498 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32> 3499 return %0 : tensor<7xi32> 3500} 3501 3502// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 3503func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { 3504 // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32> 3505 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64> 3506 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64> 3507 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3508 // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> 3509 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 3510 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64> 3511 return %0 : tensor<3xi64> 3512} 3513 3514// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1 3515func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> { 3516 // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1> 3517 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64> 3518 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64> 3519 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3520 // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> 3521 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 3522 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor<i32>) -> tensor<3xi64> 3523 return %0 : tensor<3xi64> 3524} 3525 3526// CHECK-LABEL: func @argmax_dynamic_shape_input_output 3527func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> { 3528 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : 
tensor<i32> 3529 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3530 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32> 3531 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3532 // CHECK: return %[[REDUCE]]#1 : tensor<?xi32> 3533 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 3534 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32> 3535 return %0 : tensor<?xi32> 3536} 3537 3538// CHECK-LABEL: func @argmax_dynamic_shape_input 3539func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { 3540 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 3541 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3542 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32> 3543 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3544 // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> 3545 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 3546 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32> 3547 return %0 : tensor<3xi32> 3548} 3549 3550// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0 3551func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { 3552 // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor<i64> 3553 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3554 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32> 3555 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3556 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>): 3557 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "LE"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 3558 // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 3559 // CHECK: %[[COMPARE_EQ:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 3560 // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] 3561 // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3562 // CHECK: %[[RESULT3:.*]] = "mhlo.select"(%[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3563 // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT3]]) : (tensor<i64>, tensor<i32>) -> () 3564 // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> 3565 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 3566 %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32> 3567 return %0 : tensor<7xi32> 3568} 3569 3570//===----------------------------------------------------------------------===// 3571// Random op legalizations. 
3572//===----------------------------------------------------------------------===// 3573 3574// CHECK-LABEL: func @rng_uniform 3575func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 3576 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3577 // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3578 // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 3579 // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> 3580 %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 3581 // CHECK: return %[[F32]] 3582 return %0 : tensor<12x?x64xf32> 3583} 3584 3585// CHECK-LABEL: func @rng_std_normal 3586func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 3587 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3588 // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3589 // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 3590 // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> 3591 %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 3592 // CHECK: return %[[F32]] 3593 return %0 : tensor<12x?x64xf32> 3594} 3595 3596//===----------------------------------------------------------------------===// 3597// Range op legalizations. 3598//===----------------------------------------------------------------------===// 3599 3600// CHECK-LABEL: func @range 3601// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 3602func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> { 3603 %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32> 3604 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" 3605 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3606 // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3607 %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32> 3608 return %3 : tensor<5xf32> 3609} 3610 3611// CHECK-LABEL: func @range_dynamic 3612// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 3613func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> { 3614 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 3615 // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) 3616 // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) 3617 // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) 3618 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 3619 // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) 3620 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) 3621 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) 3622 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 3623 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) 3624 // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) 3625 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3626 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3627 %2 = "tf.Range"(%arg0, %arg1, 
%arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> 3628 3629 // CHECK: return [[ADD]] 3630 return %2 : tensor<?xf32> 3631} 3632 3633// CHECK-LABEL: func @range_int_dynamic 3634// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32> 3635func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> { 3636 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 3637 // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) 3638 // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) 3639 // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) 3640 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 3641 // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) 3642 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) 3643 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) 3644 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 3645 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) 3646 // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) 3647 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3648 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3649 %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32> 3650 3651 // CHECK: return [[ADD]] 3652 return %2 : tensor<?xi32> 3653} 3654 3655// CHECK-LABEL: func @linspace_static 3656// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32> 3657func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> { 3658 // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> 3659 // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM]]) 3660 // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> 3661 // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] 3662 // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] 3663 // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] 3664 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} 3665 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3666 // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3667 // CHECK: return [[LINSPACE]] 3668 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32> 3669 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32> 3670 return %1 : tensor<4xf32> 3671} 3672 3673// CHECK-LABEL: func @linspace_dynamic 3674func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> { 3675 // CHECK: "tf.LinSpace" 3676 %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32> 3677 return %0 : tensor<?xf32> 3678} 3679 3680// CHECK-LABEL: func @linspace_invalid_num 3681func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> { 3682 // CHECK: mhlo.constant dense<> : tensor<0xi32> 3683 // CHECK: "tf.LinSpace" 3684 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> 
: tensor<0xi32>} : () -> tensor<0xi32> 3685 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32> 3686 return %1 : tensor<?xf32> 3687} 3688 3689//===----------------------------------------------------------------------===// 3690// LegacyCall op legalizations. 3691//===----------------------------------------------------------------------===// 3692 3693func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 3694 return %arg0: tensor<10x2xf32> 3695} 3696 3697// CHECK-LABEL: testSimpleLegacyCallOp 3698func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 3699 // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> 3700 %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> 3701 // CHECK: return %[[RESULT]] 3702 return %0: tensor<10x2xf32> 3703} 3704 3705func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 3706 return %arg0: tensor<10x2xf32> 3707} 3708 3709// CHECK-LABEL: testMultiInputLegacyCallOp 3710func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 3711 // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 3712 %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 3713 // CHECK: return %[[RESULT]] 3714 return %0: tensor<10x2xf32> 3715} 3716 3717//===----------------------------------------------------------------------===// 3718// Conv op legalizations. 3719//===----------------------------------------------------------------------===// 3720 3721// CHECK-LABEL: conv_simple 3722func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { 3723 3724 // CHECK: mhlo.convolution(%arg0, %arg1) 3725 // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] 3726 // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]} 3727 // CHECK-SAME: batch_group_count = 1 3728 // CHECK-SAME: feature_group_count = 2 3729 3730 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> 3731 return %0 : tensor<256x8x7x16xf32> 3732} 3733 3734// CHECK-LABEL: conv3d_simple 3735func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { 3736 3737 // CHECK: mhlo.convolution(%arg0, %arg1) 3738 // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] 3739 // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]} 3740 // CHECK-SAME: batch_group_count = 1 3741 // CHECK-SAME: feature_group_count = 2 3742 3743 %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> 3744 return %0 : tensor<256x7x6x5x16xf32> 3745} 3746 3747// CHECK-LABEL: depthwiseconv_simple 3748func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> { 3749 // CHECK: %[[RESHAPED_FILTER:.*]] = "mhlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> 
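 // The depthwise filter is in HWIM layout (height, width, in_channels,
 // channel_multiplier); folding it to 2x2x1x9 above lets the convolution
 // below express the depthwise case with feature_group_count = 3, the input
 // channel count.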
  // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]])
  // CHECK-SAME: feature_group_count = 3
  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
    data_format = "NHWC",
    device = "",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1]
  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
  return %0 : tensor<2x3x4x9xf32>
}

// CHECK-LABEL: conv_valid_padding
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
  // CHECK: mhlo.convolution(%arg0, %arg1)

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32>
  return %0 : tensor<1x2x3x1xf32>
}

// CHECK-LABEL: conv_explicit_paddings
func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]]

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
  return %0 : tensor<256x9x7x16xf32>
}

// CHECK-LABEL: @conv2d_backprop_input
func @conv2d_backprop_input(
    %filter: tensor<3x3x1x32xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
  // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_input_grouped
func @conv2d_backprop_input_grouped(
    %filter: tensor<2x2x5x21xf32>,
    %out_backprop: tensor<5x2x2x21xf32>
  ) -> tensor<5x3x3x15xf32> {
  %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32>

  // Verify filter transformation for grouped convolution.
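  // The [2, 2, 5, 21] filter holds 3 feature groups: the 21 output features
  // split into 3 groups of 7, and the 3 groups of 5 input features merge into
  // the 15-channel result. The reshape/transpose/reshape below rearranges the
  // filter into the [2, 2, 15, 7] layout expected by the transposed
  // convolution.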

  // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%arg0) : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32>
  // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]])
  // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]>
  // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32>
  // CHECK: "mhlo.reshape"(%[[TRANSPOSE]]) : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32>

  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32>
  return %result : tensor<5x3x3x15xf32>
}


// CHECK-LABEL: @conv3d_backprop_input
func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]])
  // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64,
  // CHECK-SAME: feature_group_count = 1 : i64

  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter
func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter_grouped
func @conv2d_backprop_filter_grouped(
    %input: tensor<1x2x2x2xf32>,
    %out_backprop: tensor<1x1x1x2xf32>
  ) -> tensor<2x2x1x2xf32> {

  // CHECK: mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: batch_group_count = 2 : i64
  // CHECK-SAME: feature_group_count = 1 : i64

  %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32>
  return %result : tensor<2x2x1x2xf32>
}


// CHECK-LABEL: @conv3d_backprop_filter
func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1)
  // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f]
  // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]}
  // CHECK-SAME: batch_group_count = 1 : i64
  // CHECK-SAME: feature_group_count = 1 : i64
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @collective_permute
func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
  %source_target_pairs = "tf.Const" () {
    value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
  } : () -> tensor<3x2xi32>

  // CHECK: "mhlo.collective_permute"
  // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
  %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
  } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>

  return %0 : tensor<128x32xf32>
}

// CHECK-LABEL: @cross_replica_sum
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
  %replica_groups = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
  } : () -> tensor<2x4xi32>

  // CHECK: mhlo.cross-replica-sum
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}

// CHECK-LABEL: conv_dynamic
func @conv_dynamic(%arg0: tensor<?x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32> {
  // CHECK: "mhlo.dynamic_conv"({{.*}}) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 2 : i64, rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<?x8x7x16xf32>
  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<?x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<?x8x7x16xf32>
  return %0 : tensor<?x8x7x16xf32>
}

//===----------------------------------------------------------------------===//
// tf.Split legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @split_not_match_non_const_split_dim
func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
  // CHECK: tf.Split
  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @split_not_match_unknown_input_dim
func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32>
  // CHECK: divi_signed {{.*}} : index
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
  // CHECK: muli {{.*}} : index
  // CHECK: muli {{.*}} : index
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xindex>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>)
  return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two
func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two_dynamic
func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
}

// CHECK-LABEL: @split_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
}

//===----------------------------------------------------------------------===//
// tf.TopKV2 legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: topk_v2_non_const_k
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
  return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
}

// CHECK-LABEL: topk_v2_unknown_input_last_dim
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
}

// CHECK-LABEL: topk_v2
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>

  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64}
  // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
  // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {compare_type = "TOTALORDER", comparison_direction = "GT"}
  // CHECK-NEXT: "mhlo.return"(%[[CMP]])
  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
  // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}

//===----------------------------------------------------------------------===//
// tf.SplitV legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @splitv_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
}

// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// tf.Assert legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @assert
func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
  // CHECK-NOT: tf.Assert
  "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// tf.Unpack legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unpack
func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>

  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}

// CHECK-LABEL: func @unpack_dynamic
func @unpack_dynamic(%arg0: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<3xi32>
  // CHECK: "mhlo.real_dynamic_slice"({{.*}}) : (tensor<?x?x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x?x1xf32>
  // CHECK: tensor.from_elements {{.*}} : tensor<2xi32>
  // CHECK: "mhlo.dynamic_reshape"({{.*}}) : (tensor<?x?x1xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  // CHECK: return {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>
  %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// CHECK-LABEL: @unpack_unranked
func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

  // CHECK: tf.Unpack
%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) 4123 return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> 4124} 4125 4126//===----------------------------------------------------------------------===// 4127// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization 4128//===----------------------------------------------------------------------===// 4129 4130// CHECK-LABEL: @unsorted_segment_sum 4131// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32> 4132// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> 4133func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { 4134 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4135 // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4136 // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32> 4137 // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { 4138 // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>): 4139 // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32> 4140 // CHECK: "mhlo.return"([[ADD]]) 4141 // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> 4142 // CHECK: return [[SCATTER]] 4143 %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>) 4144 return %0: tensor<4x64xf32> 4145} 4146 4147// CHECK-LABEL: @unsorted_segment_prod 4148// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32> 4149// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32> 4150func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4151 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4152 // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 4153 // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32> 4154 // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( { 4155 // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>): 4156 // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32> 4157 // CHECK: "mhlo.return"([[MUL]]) 4158 // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32> 4159 // CHECK: return [[SCATTER]] 4160 %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4161 return %0: tensor<4x?xf32> 4162} 4163 4164// CHECK-LABEL: @unsorted_segment_min 4165func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4166 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4167 // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32> 4168 // CHECK: 
mhlo.scatter 4169 // CHECK: mhlo.minimum 4170 %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4171 return %0: tensor<4x?xf32> 4172} 4173 4174// CHECK-LABEL: @unsorted_segment_max 4175func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) { 4176 %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> 4177 // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32> 4178 // CHECK: mhlo.scatter 4179 // CHECK: mhlo.maximum 4180 %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>) 4181 return %0: tensor<4x?xf32> 4182} 4183 4184//===----------------------------------------------------------------------===// 4185// tf.GatherNd legalization 4186//===----------------------------------------------------------------------===// 4187// CHECK-LABEL: func @gatherNd_dynamic 4188func @gatherNd_dynamic(%arg0: tensor<?x?x?xi32>, %arg1: tensor<?x6x2xi32>) -> tensor<?x6x?xi32> { 4189 // CHECK: tensor.dim 4190 // CHECK: index_cast 4191 // CHECK: tensor.from_elements 4192 // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = false} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>, tensor<3xi32>) -> tensor<?x6x?xi32> 4193 %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<?x?x?xi32>, tensor<?x6x2xi32>) -> tensor<?x6x?xi32> 4194 return %0 : tensor<?x6x?xi32> 4195} 4196 4197// CHECK-LABEL: func @gatherNd_static 4198func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { 4199 // CHECK: "mhlo.gather"({{.*}}) { 4200 // CHECK-SAME: dimension_numbers = { 4201 // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64> 4202 // CHECK-SAME: index_vector_dim = 1 : i64 4203 // CHECK-SAME: offset_dims = dense<[1, 2]> : tensor<2xi64> 4204 // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>} 4205 // CHECK-SAME: indices_are_sorted = false 4206 // CHECK-SAME: slice_sizes = dense<[1, 4, 128]> : tensor<3xi64> 4207 // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> 4208 %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> 4209 return %0 : tensor<2x4x128xf32> 4210} 4211 4212//===----------------------------------------------------------------------===// 4213// tf.GatherV2 legalization 4214//===----------------------------------------------------------------------===// 4215 4216// CHECK-LABEL: @gather_v2 4217func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> { 4218 // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> 4219 %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4220 %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> 4221 return %1 : tensor<16x2x5xf32> 4222} 4223 4224// CHECK-LABEL: @gather_v2_dynamic 4225func @gather_v2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> { 4226 // CHECK: tensor.dim {{.*}} : 
tensor<?x?x?xf32> 4227 // CHECK: index_cast {{.*}} : index to i32 4228 // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32> 4229 // CHECK: index_cast {{.*}} : index to i32 4230 // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> 4231 // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<3xi32>) -> tensor<*xf32> 4232 %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4233 %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32> 4234 return %1 : tensor<*xf32> 4235} 4236 4237// CHECK-LABEL: @gather_v2_dynamic_index_i64 4238func @gather_v2_dynamic_index_i64(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>) -> tensor<*xf32> { 4239 // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32> 4240 // CHECK: index_cast {{.*}} : index to i64 4241 // CHECK: tensor.dim {{.*}} : tensor<?x?x?xf32> 4242 // CHECK: index_cast {{.*}} : index to i64 4243 // CHECK: tensor.from_elements {{.*}} : tensor<3xi64> 4244 // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<3xi64>) -> tensor<*xf32> 4245 %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4246 %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<1xi32>) -> tensor<*xf32> 4247 return %1 : tensor<*xf32> 4248} 4249 4250// CHECK-LABEL: @gather_v2_unranked 4251func @gather_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> { 4252 // CHECK: tf.GatherV2 4253 %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4254 %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32> 4255 return %1 : tensor<*xf32> 4256} 4257 4258// CHECK-LABEL: @gather_v2_dynamic_shape 4259func @gather_v2_dynamic_shape(%arg0: tensor<?x2x3xf32>, %arg1: tensor<?x5xi32>) -> tensor<?x2x5xf32> { 4260 // CHECK: constant 0 4261 %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> 4262 // CHECK: tensor.dim {{.*}} : tensor<?x2x3xf32> 4263 // CHECK: index_cast {{.*}} : index to i32 4264 // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> 4265 // CHECK: "mhlo.dynamic_gather"({{.*}}) {dimension_numbers = {collapsed_slice_dims = dense<2> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<2> : tensor<1xi64>}, indices_are_sorted = false} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<3xi32>) -> tensor<?x2x5xf32> 4266 %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x2x3xf32>, tensor<?x5xi32>, tensor<1xi32>) -> tensor<?x2x5xf32> 4267 return %1 : tensor<?x2x5xf32> 4268} 4269 4270//===----------------------------------------------------------------------===// 4271// tf.StridedSliceGrad legalization 4272//===----------------------------------------------------------------------===// 4273 4274// CHECK-LABEL: strided_slice_grad 4275// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32> 4276func @strided_slice_grad(%grad: 
tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { 4277 4278 // For StridedSlice 4279 // Dim #: 0, 1, 2 4280 // Input shape: [4, 128, 1024] 4281 // Begin: 1, 4, -3 4282 // End: 8, 65, 42 4283 // Stride: 1, 4, -1 4284 // Begin mask: 1, 0, 0 (= 1) 4285 // End mask: 0, 0, 1 (= 4) 4286 4287 // So result shape: 4288 // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 4289 // Dim #1: 4 to 65 stride 4: so 16 4290 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 4291 // result shape: [4, 16, 1022] 4292 4293 // To pad back: 4294 // Dim #: 0, 1, 2 4295 // Pad low: 0, 4, 0 4296 // Pad interm: 0, 3, 0 4297 // Pad high: 0, 63, 2 4298 4299 %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>) 4300 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 4301 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 4302 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 4303 4304 // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> 4305 // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> 4306 // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4307 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32> 4308 4309 %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> 4310 // CHECK: return [[PAD]] 4311 return %0: tensor<4x128x1024xf32> 4312} 4313 4314// CHECK-LABEL: strided_slice_grad_shrink_axis_mask 4315// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32> 4316func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> { 4317 // Input to StridedSlice was of shape 4x8xf32 4318 // Strided slice gets input[2:3, 0:8] 4319 // shrink_axis_mask is 1 denoting that dim#0 is shrunk. So the output is 8xf32 4320 // which is the shape of gradient. 4321 // StridedSliceGrad would reshape the gradient to 1x8xf32 and 4322 // then pad to match the shape of input 4x8xf32. 
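  // The pad attributes checked below follow from the slice: dim #0 covered
  // rows [2, 3) of 4, giving 2 rows of low padding and 4 - 3 = 1 row of high
  // padding; dim #1 was taken whole, and stride 1 means no interior padding.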

  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
  // CHECK-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: strided_slice_grad_new_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
  // Input to StridedSlice was of shape 8xf32
  // Strided slice gets input[tf.new_axis, 2:4]
  // new_axis_mask is 1 denoting new axis is inserted at dim#0. So the output is
  // 1x2xf32 which is the shape of gradient.
  // StridedSliceGrad would reshape the gradient to 2xf32 and
  // then pad to match the shape of input 8xf32.

  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<4> : tensor<1xi64>
  // CHECK-SAME: edge_padding_low = dense<2> : tensor<1xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>

  // CHECK: return [[PAD]] : tensor<8xf32>
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: strided_slice_grad_ellipsis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
  // Input to StridedSlice was of shape 4x4x8xf32
  // Strided slice gets input[2:4, ...]
  // ellipsis_mask is 2 denoting that slice contains all elements in dim#1 and
  // dim#2, ignoring begin and end indices for these dimensions. So the output
  // is 2x4x8xf32 which is the shape of gradient.
  // StridedSliceGrad would pad the gradient to match the shape of
  // input 4x4x8xf32.
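  // Only dim #0 was sliced (rows [2, 4) of 4), so the pad checked below adds
  // 2 rows low and 4 - 4 = 0 rows high on dim #0 and nothing on the ellipsis
  // dims.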

  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-SAME: edge_padding_high = dense<0> : tensor<3xi64>
  // CHECK-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
  // CHECK-SAME: interior_padding = dense<0> : tensor<3xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
  return %0 : tensor<4x4x8xf32>
}


// CHECK-LABEL: strided_slice_grad_all_masks
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // New axis mask is at index 1 and 6 of sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dim #1, #2 of input, i.e., we get the
  // canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // The StridedSliceGrad op would propagate the gradient for the sliced tensor
  // to the original input tensor by padding with zeroes.

  %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>)
  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)

  // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0)
  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // The edge_padding_low, edge_padding_high and interior_padding attributes of
  // mhlo.pad would reflect the padding required to get the shape of the
  // input of StridedSlice op.
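  // Working dim by dim against the [2, 4, 8, 16, 32, 64] input: dim #0
  // slices [1, 2) -> pad 1 low; dim #3 slices [8, 16) -> pad 8 low; dim #4
  // slices [0, 10) of 32 -> pad 22 high; dim #5 slices 2:6:2 (elements 2 and
  // 4 of 64) -> pad 2 low, 1 interior, 59 high.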
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]])
  // CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64>
  // CHECK-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32>

  // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32>
  return %0 : tensor<2x4x8x16x32x64xf32>
}

// CHECK-LABEL: @tensor_scatter_update
func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: "mhlo.return"(%arg4) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_add
func @tensor_scatter_add(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor<f32>
  // CHECK: "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_sub
func @tensor_scatter_sub(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor<f32>
  // CHECK: "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_min
func @tensor_scatter_min(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor<f32>
  // CHECK: "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

// CHECK-LABEL: @tensor_scatter_max
func @tensor_scatter_max(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor<f32>
  // CHECK: "mhlo.return"(%1) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.RandomShuffle legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @random_shuffle_first_dim_1
// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {
  %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>)
  // CHECK-NEXT: return [[INPUT]]
  return %0: tensor<1x?xf32>
}

// CHECK-LABEL: @random_shuffle_1D_16
// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32>
func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
  // CHECK: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64>
  // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]])
  // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
  // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
  // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"}
  // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
  // CHECK: return [[SORT]]#1
  %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
  return %0: tensor<16xf32>
}

// CHECK-LABEL: @random_shuffle_1D_10240
func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
  return %0: tensor<10240xf32>
}

// CHECK-LABEL: @random_shuffle_3D
// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32>
func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
  // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>

  // CHECK: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64>
  // CHECK: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]])

  // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[WHILE_INIT:%.*]] = "mhlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]])

  // CHECK: [[WHILE_OUT:%.*]] = "mhlo.while"([[WHILE_INIT]]) ( {
  // CHECK: ^{{.*}}([[COND_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
  // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[CMP:%.*]] = "mhlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"}
  // CHECK: "mhlo.return"([[CMP]])
  // CHECK: }, {
  // CHECK: ^{{.*}}([[BODY_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
  // CHECK: [[SWAPS:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
  // CHECK: [[INDICES:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
  // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK: [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor<i64>}
  // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
  // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
  // CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
  // CHECK: "mhlo.return"([[NEW_TUPLE]])
  // CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>

  // CHECK: [[SWAPPED_INDICES:%.*]] = "mhlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32>
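  // The shuffled index vector then gathers whole rows of the input along
  // dim 0; slice_sizes = [1, -1, 16] keeps the trailing (possibly dynamic)
  // dimensions intact.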
// CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[SWAPED_INDICES]]) 4594 // CHECK-SAME: dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<[1, 2]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>} 4595 // CHECK-SAME: indices_are_sorted = false 4596 // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> : tensor<3xi64> 4597 // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32> 4598 4599 // CHECK: return [[GATHER]] 4600 4601 %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) 4602 return %0: tensor<4x?x16xf32> 4603} 4604 4605//===----------------------------------------------------------------------===// 4606// tf.AvgPool legalization 4607//===----------------------------------------------------------------------===// 4608 4609// CHECK-LABEL: @avgpool_valid_padding 4610// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> 4611// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> 4612// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4613// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4614// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4615// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4616// CHECK: "mhlo.return"([[ADD]]) 4617// CHECK: }) 4618// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 4619// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 4620// CHECK-SAME: -> tensor<2x3x5x7xf32> 4621// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4622// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4623// CHECK-SAME: broadcast_dimensions = dense<> 4624// CHECK-SAME: -> tensor<2x3x5x7xf32> 4625// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4626// CHECK-SAME: -> tensor<2x3x5x7xf16> 4627// CHECK: return [[CONV16]] 4628func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { 4629 %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> 4630 return %0 : tensor<2x3x5x7xf16> 4631} 4632 4633// CHECK-LABEL: @avgpool_3d_valid_padding 4634// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> 4635// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> 4636// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4637// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4638// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4639// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4640// CHECK: "mhlo.return"([[ADD]]) 4641// CHECK: }) 4642// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> 4643// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 4644// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 4645// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4646// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4647// CHECK-SAME: broadcast_dimensions = dense<> 4648// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 4649// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4650// CHECK-SAME: -> tensor<2x4x3x5x7xf16> 4651// CHECK: return [[CONV16]] 4652func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { 4653 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 

// CHECK-LABEL: @avgpool_valid_padding
// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x3x5x7xf16>
// CHECK: return [[CONV16]]
func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
  return %0 : tensor<2x3x5x7xf16>
}

// CHECK-LABEL: @avgpool_3d_valid_padding
// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x4x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x4x3x5x7xf16>
// CHECK: return [[CONV16]]
func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
  return %0 : tensor<2x4x3x5x7xf16>
}

// CHECK-LABEL: @avgpool_nchw_format
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x7x3x5xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x7x3x5xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x7x3x5xf16>
// CHECK: return [[CONV16]]
func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
  return %0 : tensor<2x7x3x5xf16>
}

// CHECK-LABEL: @avgpool_3d_ncdhw_format
// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
// CHECK: "mhlo.return"([[ADD]])
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]>
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<2x7x4x3x5xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
// CHECK-SAME: -> tensor<2x7x4x3x5xf16>
// CHECK: return [[CONV16]]
func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
  return %0 : tensor<2x7x4x3x5xf16>
}

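// With SAME padding the window element count varies per output position, so
// the divisor is itself computed by running the same mhlo.reduce_window over
// a tensor of ones and dividing elementwise.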
// CHECK-LABEL: @avgpool_same_padding(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x6x7xf32>
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x6x7xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32>
// CHECK: }
func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
  return %0 : tensor<2x4x6x7xf32>
}

// CHECK-LABEL: @avgpool_3d_same_padding(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME: -> tensor<2x4x4x6x7xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32>
// CHECK: }
func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
  return %0 : tensor<2x4x4x6x7xf32>
}

//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//
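// tf.AvgPoolGrad distributes each output-gradient element evenly over the
// window it came from: the incoming gradient is divided by the window element
// count, padded (with interior padding of stride - 1 to invert the forward
// striding), and then summed with a unit-stride mhlo.reduce_window.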

// CHECK-LABEL: @avgpool_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<10x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x25x33x64xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x24x32x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32>
func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NHWC",
    ksize = [1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
  return %result : tensor<10x24x32x64xf32>
}

// CHECK-LABEL: @avgpool_3d_grad_valid_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x8x25x33x64xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x8x24x32x64xf32>
// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32>
func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
  return %result : tensor<10x8x24x32x64xf32>
}

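// For SAME padding the divisor is the per-position window element count,
// computed by an mhlo.reduce_window over a tensor of ones using the forward
// pooling configuration.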
"tf.AvgPool3DGrad"(%orig_input_shape, %grad) { 4823 data_format = "NDHWC", 4824 ksize = [1, 1, 2, 2, 1], 4825 padding = "VALID", 4826 strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> 4827 return %result : tensor<10x8x24x32x64xf32> 4828} 4829 4830// CHECK-LABEL: @avgpool_grad_same_padding( 4831// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 4832// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4833// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> 4834// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { 4835// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4836// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4837// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4838// CHECK: }) 4839// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> 4840// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 4841// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 4842// CHECK-SAME: -> tensor<2x4x7x9xf32> 4843// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> 4844// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 4845// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> 4846// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 4847// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> 4848// CHECK-SAME: -> tensor<2x14x27x9xf32> 4849// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { 4850// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 4851// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 4852// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> () 4853// CHECK: }) 4854// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 4855// CHECK-SAME: window_strides = dense<1> 4856// CHECK-SAME: -> tensor<2x13x25x9xf32> 4857// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> 4858func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 4859 %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) 4860 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 4861 data_format = "NHWC", 4862 ksize = [1, 2, 3, 1], 4863 padding = "SAME", 4864 strides = [1, 4, 4, 1] 4865 } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> 4866 return %result : tensor<2x13x25x9xf32> 4867} 4868 4869// CHECK-LABEL: @avgpool_3d_grad_same_padding( 4870// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { 4871// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4872// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> 4873// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { 4874// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4875// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4876// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4877// CHECK: }) 4878// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> 4879// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> 4880// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 4881// CHECK-SAME: -> tensor<2x8x4x7x9xf32> 4882// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> 4883// CHECK: 
// CHECK-LABEL: @avgpool_grad_nchw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x14x27xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x9x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32>
func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NCHW",
    ksize = [1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 4, 4]
  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
  return %result : tensor<2x9x13x25xf32>
}

// CHECK-LABEL: @avgpool_3d_grad_ncdhw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x8x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x8x14x27xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME: -> tensor<2x9x8x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func @avgpool_3d_grad_ncdhw_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NCDHW",
    ksize = [1, 1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
  return %result : tensor<2x9x8x13x25xf32>
}

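// bf16 gradients are accumulated in f32: the padded gradient is converted to
// f32 for the reduce_window sum and the result is converted back to bf16.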
// CHECK-LABEL: @avgpool_grad_bf16(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<10x12x16x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x25x33x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "mhlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x24x32x64xf32>
// CHECK: %[[RESULT_CONVERTED:.*]] = "mhlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NHWC",
    ksize = [1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
  return %result : tensor<10x24x32x64xbf16>
}

// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {api_version = 1 : i32, backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: inplace_update_one
func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>

  // CHECK: return [[UPDATE]]
  return %0 : tensor<8x4xf32>
}

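// With several updated rows, tf.InplaceUpdate unrolls into one
// slice/reshape/dynamic-update-slice triple per row, threaded through the
// partially updated result.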
// CHECK-LABEL: inplace_update_three
func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]])
  // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>

  // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32>
  return %0 : tensor<8x8x4xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice
func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice2
func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//
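// The group assignment constant becomes the replica_groups attribute of
// mhlo.all_to_all, widened from i32 to i64.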

// CHECK-LABEL: func @alltoall_basic
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
  %group_assignment = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
  } : () -> tensor<3x4xi32>
  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
  // CHECK: mhlo.all_to_all
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
  return %result : tensor<10xf32>
}

//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//
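// tf.Cumsum lowers to a windowed prefix sum: an mhlo.reduce_window whose
// window spans the full scan axis, with low padding of size - 1. Exclusive
// scans shift the result by one with mhlo.pad (edge_padding_low = 1,
// edge_padding_high = -1), and reverse scans wrap the computation in a pair
// of mhlo.reverse ops.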
["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5135 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5136 return %1 : tensor<4xf32> 5137} 5138 5139// CHECK-LABEL: func @cumsum_reverse 5140// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5141func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5142 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5143 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5144 // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> 5145 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5146 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { 5147 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5148 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5149 // CHECK: "mhlo.return"([[SUM]]) : (tensor<f32>) -> () 5150 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5151 // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> 5152 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5153 // CHECK: return [[REVERSE_BACK]] 5154 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5155 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5156 return %1 : tensor<4xf32> 5157} 5158 5159// CHECK-LABEL: func @cumsum_exclusive_reverse 5160// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 5161func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5162 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 5163 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5164 // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> 5165 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 5166 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { 5167 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 5168 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 5169 // CHECK: "mhlo.return"([[SUM]]) : (tensor<f32>) -> () 5170 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5171 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 5172 // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> 5173 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 5174 // CHECK: return [[REVERSE_BACK]] 5175 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5176 %1 = 
"tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5177 return %1 : tensor<4xf32> 5178} 5179 5180// CHECK-LABEL: func @cumsum_empty 5181func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { 5182 %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 5183 5184 // CHECK: mhlo.constant dense<> : tensor<0xf32> 5185 %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32> 5186 return %1 : tensor<0xf32> 5187} 5188 5189// CHECK-LABEL: func @cumsum_dynamic 5190func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> { 5191 // CHECK: "tf.Cumsum" 5192 %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32> 5193 return %0 : tensor<?xf32> 5194} 5195 5196//===----------------------------------------------------------------------===// 5197// Cumprod op legalizations. 5198//===----------------------------------------------------------------------===// 5199 5200// CHECK-LABEL: func @cumprod 5201func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { 5202 // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 5203 // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( { 5204 // CHECK: mhlo.mul 5205 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 5206 %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 5207 return %1 : tensor<4xf32> 5208} 5209 5210//===----------------------------------------------------------------------===// 5211// Qr op legalization 5212//===----------------------------------------------------------------------===// 5213 5214// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) 5215func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { 5216 // The tf.Qr lowering is a full algorithm that is not effective to verify with 5217 // FileCheck. Just verify that it converted. 5218 // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is 5219 // really only applicable to certain legacy uses. 
  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//
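// Softplus is lowered to a pair of selects around a type-dependent threshold
// t = log(epsilon) + 2: inputs above -t pass through unchanged, inputs below
// t become exp(x), and everything in between uses log1p(exp(x)).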

// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16>
  return %0 : tensor<8x16xf16>
}

// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16>
  return %0 : tensor<8x16xbf16>
}

// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
  return %0 : tensor<8x16xf64>
}

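// tf.XlaGather carries its dimension numbers as a serialized proto string;
// legalization decodes it into the mhlo.gather dimension_numbers attribute,
// widening i32 slice sizes to i64 where necessary (see @xla_gather_i32).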
// CHECK-LABEL: @xla_gather
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

// CHECK-LABEL: @xla_gather_i32
func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

// CHECK: func @stridedslice_with_i32
func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} {
// CHECK-NOT: tf.StridedSlice
// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic-slice
// CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[DYNSLICE]])
// CHECK: return [[RESHAPE]]
  %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
  return %6 : tensor<4xf32>
}

// CHECK-LABEL: func @replica_id
func @replica_id() -> tensor<i32> {
  // CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%[[ID]]) : (tensor<ui32>) -> tensor<i32>
  %0 = "tf.XlaReplicaId"() : () -> tensor<i32>
  return %0 : tensor<i32>
}

// CHECK: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
  %0 = "tf.Angle"(%arg0): (tensor<complex<f32>>) -> tensor<f32>
  return %0 : tensor<f32>
}

//===----------------------------------------------------------------------===//
// tf.XlaDotV2 legalization
//===----------------------------------------------------------------------===//
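// Like tf.XlaGather, tf.XlaDotV2 encodes its dimension numbers as a
// serialized proto string; for this matmul it decodes to contracting
// dimensions lhs = 1 and rhs = 0 with no batching dimensions.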

// CHECK-LABEL: @xladotv2_matmul(
// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32>
func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> {
  // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) {
  // CHECK-SAME: dot_dimension_numbers = {
  // CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>,
  // CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
  // CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>,
  // CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
  // CHECK-SAME: }, precision_config = []} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32>
  return %res : tensor<64x16xi32>
}

//===----------------------------------------------------------------------===//
// tf.Print legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @simple_print
func @simple_print() -> (tensor<*xi32>) {
  // CHECK: mhlo.constant dense<1> : tensor<i32>
  // CHECK: tensor.cast {{.*}} : tensor<i32> to tensor<*xi32>
  // CHECK: "mhlo.print"({{.*}}) : (tensor<*xi32>) -> tensor<*xi32>
  %const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
  %print = "tf.Print"(%const) { message = "bla" } : (tensor<*xi32>) -> (tensor<*xi32>)
  return %print: tensor<*xi32>
}