// RUN: xla-opt -split-input-file -verify-diagnostics -xla-legalize-tf-collective -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s


// tf.XlaAllReduce in CrossReplica mode: lowers to mhlo.all_reduce with no
// channel handle, and records the collective info on the module attributes.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @all_reduce_cross_replica
func.func @all_reduce_cross_replica(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  // CHECK-NOT: channel_handle
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// -----

// tf.XlaAllReduce in CrossReplicaAndPartition mode: each lowered
// mhlo.all_reduce carries a channel handle.
// NOTE(review): the handle numbers are matched in reverse of op order
// (2 then 1) — presumably an artifact of the pass's rewrite order; confirm
// against the pass implementation before relying on it.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @all_reduce_cross_replica_and_partition
func.func @all_reduce_cross_replica_and_partition(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 2, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 1, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  %1 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}

// -----

// reduce_op = "Add" lowers to an mhlo.add reduction body.
// CHECK-LABEL: func @xla_all_reduce_add
func.func @xla_all_reduce_add(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// reduce_op = "Max" lowers to an mhlo.maximum reduction body.
// CHECK-LABEL: func @xla_all_reduce_max
func.func @xla_all_reduce_max(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Max", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// reduce_op = "Mean" lowers to an add-all-reduce followed by a divide by the
// group size (1.0 here, since each replica group has a single member).
// CHECK-LABEL: func @xla_all_reduce_mean
func.func @xla_all_reduce_mean(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Mean", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// reduce_op = "Min" lowers to an mhlo.minimum reduction body.
// CHECK-LABEL: func @xla_all_reduce_min
func.func @xla_all_reduce_min(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Min", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// reduce_op = "Mul" lowers to an mhlo.mul reduction body.
// CHECK-LABEL: func @xla_all_reduce_mul
func.func @xla_all_reduce_mul(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.mul
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Mul", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}


// -----

// tf.CollectiveReduceV2 with constant group key/size: the collective info is
// recorded on the module, and each lowered mhlo.all_reduce gets a channel
// handle.
// NOTE(review): as with the XlaAllReduce case above, the handles are matched
// in reverse of op order (2 then 1); confirm against the pass implementation.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 1
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @collective_reduce_v2
func.func @collective_reduce_v2(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 2, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 1, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}

// -----

// merge_op = "Add", final_op = "Id": all-reduce result is returned directly.
// CHECK-LABEL: func @collective_reduce_v2_add_id
func.func @collective_reduce_v2_add_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Max", final_op = "Id": mhlo.maximum reduction, result returned
// directly.
// CHECK-LABEL: func @collective_reduce_v2_max_id
func.func @collective_reduce_v2_max_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Min", final_op = "Id": mhlo.minimum reduction, result returned
// directly.
// CHECK-LABEL: func @collective_reduce_v2_min_id
func.func @collective_reduce_v2_min_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Mul", final_op = "Id": mhlo.mul reduction, result returned
// directly.
// CHECK-LABEL: func @collective_reduce_v2_mul_id
func.func @collective_reduce_v2_mul_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.mul
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// final_op = "Div": the all-reduce result is divided by the group size (2.0).
// CHECK-LABEL: func @collective_reduce_v2_add_div
func.func @collective_reduce_v2_add_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Max" with final_op = "Div": maximum reduction followed by a
// divide by the group size.
// CHECK-LABEL: func @collective_reduce_v2_max_div
func.func @collective_reduce_v2_max_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Min" with final_op = "Div": minimum reduction followed by a
// divide by the group size.
// CHECK-LABEL: func @collective_reduce_v2_min_div
func.func @collective_reduce_v2_min_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// merge_op = "Mul" with final_op = "Div": product reduction followed by a
// divide by the group size.
// CHECK-LABEL: func @collective_reduce_v2_mul_div
func.func @collective_reduce_v2_mul_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.mul
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}


// -----

// tf.CollectiveAssignGroupV2 is folded away: the group assignment it encodes
// feeds the CollectiveReduceV2 lowering, and no tf.CollectiveAssignGroupV2
// survives legalization.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @collective_assign_group_v2
func.func @collective_assign_group_v2(%input: tensor<f32>) -> tensor<f32> {
  %rank = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
  %key_base = "tf.Const"() { value = dense<10> : tensor<i32> } : () -> tensor<i32>
  %group_assignment = "tf.Const"() { value = dense<[[0, 1]]> : tensor<1x2xi32> } : () -> tensor<1x2xi32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  %group_size, %group_key = "tf.CollectiveAssignGroupV2"(%group_assignment, %rank, %key_base) {} : (tensor<1x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  // CHECK-NOT: "tf.CollectiveAssignGroupV2"
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NOT: "tf.CollectiveAssignGroupV2"
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// -----

// Negative test: two collectives with the same group key but different group
// sizes must be rejected. The expected-error text below (including the
// "overwritting" misspelling) must match the diagnostic emitted by the pass
// verbatim — do not "fix" it here without changing the pass first.
func.func @inconsistent_collective_info(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<11> : tensor<i32> } : () -> tensor<i32>
  %group_size1 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size2 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // expected-error@below {{op module already contains an attribute tf2xla.collective_info.group_size=2, overwritting to a new value 1 is not allowed.}}
  %0 = "tf.CollectiveReduceV2"(%input, %group_size1, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %1 = "tf.CollectiveReduceV2"(%input, %group_size2, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}
