// RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s

// Function-based tf.If lowering: both operands are packed into one tuple that
// is passed to each mhlo.if branch. Each branch unpacks the tuple, calls its
// branch function (@cond_true / @cond_false), and returns the call result
// re-packed as a tuple; that tuple is unpacked again after the mhlo.if.
// CHECK-LABEL: @if
// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @if(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
  // CHECK: ^bb0([[THEN_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK:   [[VAL4:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 0 : i32}
  // CHECK:   [[VAL5:%.+]] = "mhlo.get_tuple_element"([[THEN_ARG]]) {index = 1 : i32}
  // CHECK:   [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]])
  // CHECK:   [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
  // CHECK:   "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
  // CHECK: },  {
  // CHECK: ^bb0([[ELSE_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>)
  // CHECK:   [[VAL4:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 0 : i32}
  // CHECK:   [[VAL5:%.+]] = "mhlo.get_tuple_element"([[ELSE_ARG]]) {index = 1 : i32}
  // CHECK:   [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]])
  // CHECK:   [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
  // CHECK:   "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
  // CHECK: })
  %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
  // CHECK: return [[VAL3]]
  return %1 : tensor<f32>
}

// Else branch of @if: returns exp(%arg1); %arg0 is unused.
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes  {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Then branch of @if: returns log(%arg0); %arg1 is unused.
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes  {tf._input_shapes = ["tfshape$", "tfshape$"]} {
  %0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}


// Region-based tf.IfRegion lowering: each branch gets its own tuple operand
// holding only the value it captures (%arg0 for the then branch, %arg1 for the
// else branch). The branch bodies are inlined, read that value through
// get_tuple_element, and yield their result wrapped in a tuple.
// CHECK-LABEL: @ifRegion
// CHECK-SAME:  ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @ifRegion(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "GT"}
  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]])
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL2]]) ( {
  %1 = "tf.IfRegion"(%0) ( {
  // CHECK: ^{{[a-z0-9]+}}([[TRUE_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[TRUE_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL5]])
    %2 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"(%2) : (tensor<f32>) -> ()
  }, {
  // CHECK: ^{{[a-z0-9]+}}([[FALSE_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[FALSE_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.exponential"([[VAL5]])
    %2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"(%2) : (tensor<f32>) -> ()
  // CHECK: }) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>
  }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: return [[VAL4]]
  return %1 : tensor<f32>
}


// Function-based tf.Case lowering: the same (%arg0, %arg1) tuple is passed to
// every mhlo.case branch. Each branch unpacks it, calls its branch function
// (@exponential, @log, @floor in branch-index order), and returns the call's
// two results directly, without re-packing them into a tuple.
// CHECK-LABEL: func @case
// CHECK-SAME:  %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
  // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
  // CHECK:   ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK:     %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK:     "mhlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK:   },  {
  // CHECK:   ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK:     %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK:     "mhlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK:   },  {
  // CHECK:   ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
  // CHECK:     %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
  // CHECK:     %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
  // CHECK:     "mhlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK:   }) : (tensor<i32>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
  // CHECK:   return %[[CASE]]#0, %[[CASE]]#1 : tensor<f32>, tensor<f32>
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

// Branch 0 of @case: returns (exp(%arg1), %arg1); %arg0 is unused.
func @exponential(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}

// Branch 1 of @case: returns (log(%arg0), %arg1).
func @log(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}

// Branch 2 of @case: returns (floor(%arg0), %arg1).
func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %0 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0, %arg1 : tensor<f32>, tensor<f32>
}


// Region-based tf.CaseRegion lowering: each branch gets its own tuple operand
// holding only the values that branch captures. Branch 0 uses only %arg1;
// branches 1 and 2 use both %arg0 and %arg1. The branch bodies are inlined and
// return their two results directly.
// CHECK-LABEL: func @caseRegion
// CHECK-SAME:  ([[BRANCH_INDEX:%.+]]: tensor<i32>, [[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>)
func @caseRegion(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  // CHECK: [[VAL0:%.+]] = "mhlo.tuple"([[ARG1]])
  // CHECK: [[VAL1:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[ARG1]])
  // CHECK: [[VAL3:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]], [[VAL0]], [[VAL1]], [[VAL2]]) ( {
  %0:2 = "tf.CaseRegion"(%index) ( {
  // CHECK: ^{{[a-z0-9]+}}([[BRANCH0_ARG:%.+]]: tuple<tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH0_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]])
    %1 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL5]], [[VAL4]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  // CHECK: ^{{[a-z0-9]+}}([[BRANCH1_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH1_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.log"([[VAL4]])
    %1 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  }, {
  // CHECK: ^{{[a-z0-9]+}}([[BRANCH2_ARG:%.+]]: tuple<tensor<f32>, tensor<f32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BRANCH2_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.floor"([[VAL4]])
    %1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
    // CHECK: "mhlo.return"([[VAL6]], [[VAL5]])
    "tf.Yield"(%1, %arg1) : (tensor<f32>, tensor<f32>) -> ()
  // CHECK: }) : (tensor<i32>, tuple<tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
  }) {is_stateless = true} : (tensor<i32>) -> (tensor<f32>, tensor<f32>)
  // CHECK: return [[VAL3]]#0, [[VAL3]]#1 : tensor<f32>, tensor<f32>
  return %0#0, %0#1 : tensor<f32>, tensor<f32>
}


// Function-based tf.While lowering: the three loop-carried values are packed
// into one tuple. The cond region unpacks it and calls @while_cond; the body
// region unpacks it, calls @while_body, and re-packs the three results. The
// final tuple is unpacked after the mhlo.while.
// CHECK-LABEL: func @while
func @while() -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  // CHECK:   [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
  // CHECK:   [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
  // CHECK:   [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
  // CHECK:   [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]])
  // CHECK:   "mhlo.return"([[VAL10]])
  // CHECK: },  {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  // CHECK:   [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
  // CHECK:   [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
  // CHECK:   [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
  // CHECK:   [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]])
  // CHECK:   [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
  // CHECK:   "mhlo.return"([[VAL11]])
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
  // CHECK: return [[VAL6]]
  %2:3 = "tf.While"(%0, %1, %0) {body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
  return %2#2 : tensor<i32>
}
// Loop condition for @while: true while %arg2 < 10; %arg0 and %arg1 are unused.
func @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i1> {
  %0 = mhlo.constant dense<10> : tensor<i32>
  %1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
  return %1 : tensor<i1>
}
// Loop body for @while: increments %arg0 and %arg2 by 1 and passes %arg1
// through unchanged.
func @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
  %0 = mhlo.constant dense<1> : tensor<i32>
  %1 = mhlo.add %arg2, %0 : tensor<i32>
  %2 = mhlo.add %arg0, %0 : tensor<i32>
  return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
}


// Region-based tf.WhileRegion with explicit operands only: the three operands
// are packed into one tuple. Both regions are inlined and read their block
// arguments through get_tuple_element, and the body re-packs the values it
// yields into a tuple.
// CHECK-LABEL: func @whileRegion
func @whileRegion() -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  %2:3 = "tf.WhileRegion"(%0, %1, %0) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^cond(%carg0: tensor<i32>, %carg1: tensor<i32>, %carg2: tensor<i32>):
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<10>
    %3 = mhlo.constant dense<10> : tensor<i32>
    // CHECK: [[VAL11:%.+]] = "mhlo.compare"([[VAL9]], [[VAL10]]) {comparison_direction = "LT"}
    %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL11]])
    "tf.Yield"(%4) : (tensor<i1>) -> ()
  }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^body(%barg0: tensor<i32>, %barg1: tensor<i32>, %barg2: tensor<i32>):
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL10:%.+]] = mhlo.constant dense<1>
    %5 = mhlo.constant dense<1> : tensor<i32>
    // CHECK: [[VAL11:%.+]] = mhlo.add [[VAL9]], [[VAL10]]
    %6 = mhlo.add %barg2, %5 : tensor<i32>
    // CHECK: [[VAL12:%.+]] = mhlo.add [[VAL7]], [[VAL10]]
    %7 = mhlo.add %barg0, %5 : tensor<i32>
    // CHECK: [[VAL13:%.+]] = "mhlo.tuple"([[VAL12]], [[VAL8]], [[VAL11]])
    // CHECK: "mhlo.return"([[VAL13]])
    "tf.Yield"(%7, %barg1, %6) : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
  // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
  // CHECK: return [[VAL6]]
  return %2#2 : tensor<i32>
}


// tf.WhileRegion with one explicit operand plus two implicitly captured values:
// %0 (read in cond) and %1 (read in body) are appended to the loop-carried
// tuple after the explicit operand. The body yields them back unchanged, and
// only tuple element 0 is returned.
// CHECK-LABEL: func @whileRegionImplicitInputs
// CHECK-SAME:  ([[ARG0:%.+]]: tensor<i32>)
func @whileRegionImplicitInputs(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[ARG0]], [[VAL0]], [[VAL1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  %2 = "tf.WhileRegion"(%arg0) ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^cond(%carg0: tensor<i32>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL8:%.+]] = "mhlo.compare"([[VAL5]], [[VAL6]]) {comparison_direction = "LT"}
    %3 = "mhlo.compare"(%carg0, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL8]])
    "tf.Yield"(%3) : (tensor<i1>) -> ()
  }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
  ^body(%barg0: tensor<i32>):
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
    // CHECK: [[VAL8:%.+]] = mhlo.add [[VAL5]], [[VAL7]]
    %3 = mhlo.add %barg0, %1 : tensor<i32>
    // CHECK: [[VAL9:%.+]] = mhlo.add [[VAL5]], [[VAL8]]
    %4 = mhlo.add %barg0, %3 : tensor<i32>
    // CHECK: [[VAL10:%.+]] = "mhlo.tuple"([[VAL9]], [[VAL6]], [[VAL7]])
    // CHECK: "mhlo.return"([[VAL10]])
    "tf.Yield"(%4) : (tensor<i32>) -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>) -> tensor<i32>
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
  // CHECK: return [[VAL4]]
  return %2 : tensor<i32>
}


// tf.WhileRegion with no explicit operands: the captured values %0 and %1 form
// the entire loop-carried tuple, and the body passes both through unchanged.
// The body-region directives bind to the body's own block argument (BODY_ARG)
// rather than to the cond region's argument.
// CHECK-LABEL: func @whileRegionMultipleImplicitInputs
func @whileRegionMultipleImplicitInputs() {
  // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
  %0 = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
  %1 = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]])
  // CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
  "tf.WhileRegion"() ( {
  // CHECK: ^bb0([[COND_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = "mhlo.compare"([[VAL4]], [[VAL5]]) {comparison_direction = "LT"}
    %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    // CHECK: "mhlo.return"([[VAL6]])
    "tf.Yield"(%2) : (tensor<i1>) -> ()
  }, {
  // CHECK: ^bb0([[BODY_ARG:%.+]]: tuple<tensor<i32>, tensor<i32>>):
    // CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
    // CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
    // CHECK: [[VAL6:%.+]] = mhlo.add [[VAL4]], [[VAL5]]
    %2 = mhlo.add %0, %1 : tensor<i32>
    // CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL4]], [[VAL5]])
    // CHECK: "mhlo.return"([[VAL7]])
    "tf.Yield"() : () -> ()
  }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> ()
  // CHECK: }) : (tuple<tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>>
  // CHECK: return
  return
}
