// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s

// Export of a single-result mhlo.case. Each of the three branch regions
// implicitly captures one value (%cst_0, %cst_1, %cst_2) defined in the
// enclosing function. The expected HLO has one branch computation per region.
// Each computation receives its captured value as parameter(0), and the
// captured values are passed, in branch order, as the `conditional` operands
// after the s32 index.
func.func @main() -> tensor<f32> {
  %cst = arith.constant dense<1> : tensor<i32>
  %cst_0 = arith.constant dense<5.600000e+01> : tensor<f32>
  %cst_1 = arith.constant dense<1.200000e+01> : tensor<f32>
  %cst_2 = arith.constant dense<1.300000e+01> : tensor<f32>
  %0 = "mhlo.case"(%cst) ({
    %1 = "mhlo.negate"(%cst_0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }, {
    %1 = "mhlo.copy"(%cst_1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }, {
    %1 = "mhlo.floor"(%cst_2) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  }) : (tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// Branch computations are expected before the ENTRY computation.
// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: }

// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> f32[]

// The constants may be emitted in any order; the conditional must consume
// them as index followed by one operand per branch.
// CHECK-DAG: %[[INDEX:.*]] = s32[] constant(1)
// CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}

// -----

// Same as the previous case, but every branch returns two results. The
// expected HLO gives each branch computation a tuple ROOT. The `conditional`
// therefore produces a tuple, which ENTRY unpacks with get-tuple-element and
// re-tuples as the function results.
func.func @main() -> (tensor<f32>, tensor<f32>) {
  %cst = arith.constant dense<1> : tensor<i32>
  %cst_0 = arith.constant dense<5.600000e+01> : tensor<f32>
  %cst_1 = arith.constant dense<1.200000e+01> : tensor<f32>
  %cst_2 = arith.constant dense<1.300000e+01> : tensor<f32>
  %0:2 = "mhlo.case"(%cst) ({
    %1 = "mhlo.negate"(%cst_0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    %1 = "mhlo.copy"(%cst_1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    %1 = "mhlo.floor"(%cst_2) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }) : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
  func.return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]])
// CHECK: }

// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> (f32[], f32[])

// CHECK-DAG: %[[INDEX:.*]] = s32[] constant(1)
// CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}
// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0
// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]])

// -----
// Test export mhlo::CaseOp with different number of block-arguments (even 0).
//
// The branches capture different numbers of values:
// - branch 0 captures one value (%cst_0);
// - branch 1 captures two values (%cst_1 and %cst_2);
// - branch 2 captures none, because it defines its own constant.
// In the expected HLO, the single capture is passed directly as an f32[]
// operand. The two captures are packed into a (f32[], f32[]) tuple operand
// that the branch unpacks with get-tuple-element. The zero-capture branch
// receives an empty tuple `()` operand.

func.func @main() -> (tensor<f32>, tensor<f32>) {
  %cst = arith.constant dense<1> : tensor<i32>
  %cst_0 = arith.constant dense<5.600000e+01> : tensor<f32>
  %cst_1 = arith.constant dense<1.200000e+01> : tensor<f32>
  %cst_2 = arith.constant dense<1.300000e+01> : tensor<f32>
  %0:2 = "mhlo.case"(%cst) ({
    %1 = "mhlo.negate"(%cst_0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    %1 = "mhlo.copy"(%cst_1) : (tensor<f32>) -> tensor<f32>
    %2 = "mhlo.copy"(%cst_2) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %2) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    %cst_3 = arith.constant dense<1.300000e+01> : tensor<f32>
    %1 = "mhlo.floor"(%cst_3) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
  }) : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
  func.return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = f32[] parameter(0)
// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]])
// CHECK: }

// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: (f32[], f32[])) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = (f32[], f32[]) parameter(0)
// CHECK-DAG: %[[GTE1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG]]), index=0
// CHECK-DAG: %[[COPY1:.*]] = f32[] copy(f32[] %[[GTE1]])
// CHECK-DAG: %[[GTE2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG]]), index=1
// CHECK-DAG: %[[COPY2:.*]] = f32[] copy(f32[] %[[GTE2]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY1]], f32[] %[[COPY2]])
// CHECK: }

// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: ()) -> (f32[], f32[]) {
// CHECK: %[[ARG:.*]] = () parameter(0)
// CHECK: %[[CST:.*]] = f32[] constant
// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[CST]])
// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]])
// CHECK: }

// CHECK-LABEL: ENTRY
// CHECK-SAME: () -> (f32[], f32[])

// CHECK-DAG: %[[INDEX:.*]] = s32[] constant(1)
// CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56)
// CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12)
// CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13)
// CHECK-DAG: %[[TUPLE1:.*]] = (f32[], f32[]) tuple(f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]])
// CHECK-DAG: %[[TUPLE2:.*]] = () tuple()

// CHECK: %[[COND:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], (f32[], f32[]) %[[TUPLE1]], () %[[TUPLE2]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]}

// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[COND]]), index=0
// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[COND]]), index=1
// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]])