// RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s

// Lowering tests for -xla-legalize-tf-control-flow: TF control-flow ops
// (tf.If, tf.IfRegion, tf.Case, tf.CaseRegion, tf.While, tf.WhileRegion) are
// rewritten into mhlo.if / mhlo.case / mhlo.while. Operands are packed with
// mhlo.tuple and unpacked inside each region with mhlo.get_tuple_element.

// Function-based tf.If: both branches receive the same (arg0, arg1) operand
// tuple; each region unpacks it, calls its branch function and re-tuples the
// single result. The tuple produced by mhlo.if is unpacked after the op.
// CHECK-LABEL: @if
// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @if(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
  // CHECK: ^bb0([[THEN_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]])
  // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
  // CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0([[ELSE_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]])
  // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
  // CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
  // CHECK: })
  %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
  // CHECK: return [[VAL3]]
  return %1 : tensor<f32>
}

// Else-branch function for @if (exponential of the second operand).
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Then-branch function for @if (log of the first operand).
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}


// Region-based tf.IfRegion: each branch's operand tuple holds only the values
// that branch captures from above (then: %arg0, else: %arg1), and the branch
// bodies are inlined into the mhlo.if regions instead of being called.
// CHECK-LABEL: @ifRegion
// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @ifRegion(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"}
  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]])
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL2]]) ( {
  %1 = "tf.IfRegion"(%0) ( {
    // CHECK: ^{{[a-z0-9]+}}([[TRUE_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[TRUE_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL5]])
    %2 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"(%2) : (tensor<f32>) -> ()
  }, {
    // CHECK: ^{{[a-z0-9]+}}([[FALSE_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[FALSE_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.exponential"([[VAL5]])
    %2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"(%2) : (tensor<f32>) -> ()
  // CHECK: }) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>
  }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: return [[VAL4]]
  return %1 : tensor<f32>
}


// Function-based tf.Case with three branches: every branch receives the same
// (arg0, arg1) tuple, unpacks it and calls its branch function. Unlike
// mhlo.if above, branch results are returned from mhlo.case without tupling.
// CHECK-LABEL: func @case
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
  // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
  // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: "mhlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: "mhlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: "mhlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK: }) : (tensor<i32>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
  // CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor<f32>, tensor<f32>
}

// Branch 0 of @case: exponential of %arg1, %arg1 passed through.
func @exponential(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}

// Branch 1 of @case: log of %arg0, %arg1 passed through.
func @log(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}

// Branch 2 of @case: floor of %arg0, %arg1 passed through.
func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}


// Region-based tf.CaseRegion: each branch's operand tuple holds only the
// values that branch captures (branch 0 uses %arg1 only; branches 1 and 2 use
// %arg0 and %arg1), and branch bodies are inlined.
// CHECK-LABEL: func @caseRegion
// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor<i32>, [[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @caseRegion(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.tuple"([[ARG1]])
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL3:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]], [[VAL0]], [[VAL1]], [[VAL2]]) ( {
  %0:2 = "tf.CaseRegion"(%index) ( {
    // CHECK: ^{{[a-z0-9]+}}([[BRANCH0_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH0_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]])
    %1 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL5]], [[VAL4]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    // CHECK: ^{{[a-z0-9]+}}([[BRANCH1_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL4]])
    %1 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
    // CHECK: ^{{[a-z0-9]+}}([[BRANCH2_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.floor"([[VAL4]])
    %1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK: }) : (tensor<i32>, tuple<tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
  }) {is_stateless = true} : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: return [[VAL3]]#0, [[VAL3]]#1 : tensor<f32>, tensor<f32>
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}


// Function-based tf.While: the three loop-carried values travel as a single
// tuple; the cond and body regions unpack it and call @while_cond and
// @while_body, and the body re-tuples the three results for the next step.
// CHECK-LABEL: func @while
func @while() -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
  // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
  // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
  // CHECK: [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]])
  // CHECK: "mhlo.return"([[VAL10]])
  // CHECK: }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
  // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
  // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
  // CHECK: [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]])
  // CHECK: [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
  // CHECK: "mhlo.return"([[VAL11]])
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
  // CHECK: return [[VAL6]]
  %2:3 = "tf.While"(%0, %1, %0) {body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
  return %2#2 : tensor<i32>
}
// Loop condition for @while: continue while the third value is < 10.
func @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i1> {
  %0 = mhlo.constant dense<10> : tensor<i32>
  %1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
  return %1 : tensor<i1>
}
// Loop body for @while: increments the first and third values by 1 and passes
// the second through unchanged.
func @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
  %0 = mhlo.constant dense<1> : tensor<i32>
  %1 = mhlo.add %arg2, %0 : tensor<i32>
  %2 = mhlo.add %arg0, %0 : tensor<i32>
  return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
}


// Region-based tf.WhileRegion with explicit loop-carried block arguments:
// same computation as @while, but the cond and body are inlined into the
// mhlo.while regions and the body tuples its yielded values directly.
// CHECK-LABEL: func @whileRegion
func @whileRegion() -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  %2:3 = "tf.WhileRegion"(%0, %1, %0) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^cond(%carg0: tensor<i32>, %carg1: tensor<i32>, %carg2: tensor<i32>):
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<10>
    %3 = mhlo.constant dense<10> : tensor<i32>
    // CHECK: [[VAL11:%.+]] = "mhlo.compare"([[VAL9]], [[VAL10]]) {comparison_direction = "LT"}
    %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL11]])
    "tf.Yield"(%4) : (tensor<i1>) -> ()
  }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^body(%barg0: tensor<i32>, %barg1: tensor<i32>, %barg2: tensor<i32>):
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<1>
    %5 = mhlo.constant dense<1> : tensor<i32>
    // CHECK: [[VAL11:%.+]] = mhlo.add [[VAL9]], [[VAL10]]
    %6 = mhlo.add %barg2, %5 : tensor<i32>
    // CHECK: [[VAL12:%.+]] = mhlo.add [[VAL7]], [[VAL10]]
    %7 = mhlo.add %barg0, %5 : tensor<i32>
    // CHECK: [[VAL13:%.+]] = "mhlo.tuple"([[VAL12]], [[VAL8]], [[VAL11]])
    // CHECK: "mhlo.return"([[VAL13]])
    "tf.Yield"(%7, %barg1, %6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
  // CHECK: return [[VAL6]]
  return %2#2 : tensor<i32>
}


// tf.WhileRegion that also captures values from above: the captured %0 and %1
// are appended to the loop tuple after the explicit operand %arg0 and passed
// through unchanged by the body; only tuple element 0 becomes the op's result.
// CHECK-LABEL: func @whileRegionImplicitInputs
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
func @whileRegionImplicitInputs(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[VAL0]], [[VAL1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  %2 = "tf.WhileRegion"(%arg0) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^cond(%carg0: tensor<i32>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.compare"([[VAL5]], [[VAL6]]) {comparison_direction = "LT"}
    %3 = "mhlo.compare"(%carg0, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL8]])
    "tf.Yield"(%3) : (tensor<i1>) -> ()
  }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^body(%barg0: tensor<i32>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL8:%.+]] = mhlo.add [[VAL5]], [[VAL7]]
    %3 = mhlo.add %barg0, %1 : tensor<i32>
    // CHECK: [[VAL9:%.+]] = mhlo.add [[VAL5]], [[VAL8]]
    %4 = mhlo.add %barg0, %3 : tensor<i32>
    // CHECK: [[VAL10:%.+]] = "mhlo.tuple"([[VAL9]], [[VAL6]], [[VAL7]])
    // CHECK: "mhlo.return"([[VAL10]])
    "tf.Yield"(%4) : (tensor<i32>) -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: return [[VAL4]]
  return %2 : tensor<i32>
}


// tf.WhileRegion with no explicit operands or results: the two captured
// constants alone form the loop tuple, and the body passes them through.
// CHECK-LABEL: func @whileRegionMultipleImplicitInputs
func @whileRegionMultipleImplicitInputs() {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  "tf.WhileRegion"() ( {
    // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.compare"([[VAL4]], [[VAL5]]) {comparison_direction = "LT"}
    %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL6]])
    "tf.Yield"(%2) : (tensor<i1>) -> ()
  }, {
    // The body region must unpack its own block argument (BODY_ARG), not the
    // cond region's argument captured above.
    // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
    %2 = mhlo.add %0, %1 : tensor<i32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL4]], [[VAL5]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"() : () -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> ()
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>>
  // CHECK: return
  return
}