// RUN: xla-opt -split-input-file -verify-diagnostics -xla-legalize-tf-collective -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s


// Lowering tests for TF collective ops (tf.XlaAllReduce,
// tf.CollectiveReduceV2, tf.CollectiveAssignGroupV2) to mhlo.all_reduce.
// The pass also records the collective group configuration as module
// attributes (tf2xla.collective_info.*), checked per split-input module.

// tf.XlaAllReduce in "CrossReplica" mode: lowers to mhlo.all_reduce with
// the group assignment converted to replica_groups, and — unlike the
// partitioned modes below — no channel_handle is attached.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @all_reduce_cross_replica
func.func @all_reduce_cross_replica(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  // CHECK-NOT: channel_handle
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// -----

// "CrossReplicaAndPartition" mode: each lowered all-reduce carries a
// channel_handle. With two collectives in the function, the first op in
// program order gets handle 2 and the second gets handle 1 (handles are
// assigned in decreasing order by the pass, per the CHECKs below).
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @all_reduce_cross_replica_and_partition
func.func @all_reduce_cross_replica_and_partition(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 2, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 1, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
  %1 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}

// -----

// Each tf.XlaAllReduce reduce_op string maps to the corresponding mhlo
// reduction body: Add -> mhlo.add, Max -> mhlo.maximum, Min -> mhlo.minimum,
// Mul -> mhlo.mul. "Mean" lowers to an Add all-reduce followed by a divide
// by the per-group size (here each group in the 2x1 assignment has a single
// replica, hence the divisor constant 1.0).

// CHECK-LABEL: func @xla_all_reduce_add
func.func @xla_all_reduce_add(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @xla_all_reduce_max
func.func @xla_all_reduce_max(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Max", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @xla_all_reduce_mean
func.func @xla_all_reduce_mean(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Mean", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @xla_all_reduce_min
func.func @xla_all_reduce_min(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Min", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @xla_all_reduce_mul
func.func @xla_all_reduce_mul(%input: tensor<f32>) -> tensor<f32> {
  %group_assignment = "tf.Const"() { value = dense<[[0],[1]]> : tensor<2x1xi32> } : () -> tensor<2x1xi32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.mul
  %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Mul", mode = "CrossReplica"} : (tensor<f32>, tensor<2x1xi32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}


// -----

// tf.CollectiveReduceV2: group key/size come from constant operands and are
// recorded as module attributes. As above, two collectives get channel
// handles 2 then 1 in program order.
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 1
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @collective_reduce_v2
func.func @collective_reduce_v2(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 2, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT: channel_handle = #mhlo.channel_handle<handle = 1, type = 1>
  // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}

// -----

// merge_op/final_op combinations. final_op = "Id" returns the all-reduce
// result unchanged; final_op = "Div" divides by the group size (2.0 here).

// CHECK-LABEL: func @collective_reduce_v2_add_id
func.func @collective_reduce_v2_add_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_max_id
func.func @collective_reduce_v2_max_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_min_id
func.func @collective_reduce_v2_min_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_mul_id
func.func @collective_reduce_v2_mul_id(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.mul
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: return %[[REDUCE]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_add_div
func.func @collective_reduce_v2_add_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_max_div
func.func @collective_reduce_v2_max_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Max", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_min_div
func.func @collective_reduce_v2_min_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.minimum
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Min", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// CHECK-LABEL: func @collective_reduce_v2_mul_div
func.func @collective_reduce_v2_mul_div(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // CHECK: %[[GROUP_SIZE:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[REDUCE:.*]] = "mhlo.all_reduce"
  // CHECK: mhlo.mul
  // CHECK: mhlo.return
  // CHECK-NEXT{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NEXT: %[[RESULT:.*]] = mhlo.divide %[[REDUCE]], %[[GROUP_SIZE]]
  // CHECK-NEXT: return %[[RESULT]]
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Div"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}


// -----

// tf.CollectiveAssignGroupV2 feeding a CollectiveReduceV2: the pass resolves
// group size/key from the group assignment constant and the assign-group op
// is eliminated from the output (CHECK-NOTs below).
// CHECK: module attributes
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_key = 0
// CHECK-SAME{LITERAL}: tf2xla.collective_info.group_size = 2
// CHECK-LABEL: func @collective_assign_group_v2
func.func @collective_assign_group_v2(%input: tensor<f32>) -> tensor<f32> {
  %rank = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
  %key_base = "tf.Const"() { value = dense<10> : tensor<i32> } : () -> tensor<i32>
  %group_assignment = "tf.Const"() { value = dense<[[0, 1]]> : tensor<1x2xi32> } : () -> tensor<1x2xi32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  %group_size, %group_key = "tf.CollectiveAssignGroupV2"(%group_assignment, %rank, %key_base) {} : (tensor<1x2xi32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  // CHECK-NOT: "tf.CollectiveAssignGroupV2"
  // CHECK: "mhlo.all_reduce"
  // CHECK: mhlo.add
  // CHECK{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // CHECK-NOT: "tf.CollectiveAssignGroupV2"
  %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  func.return %0 : tensor<f32>
}

// -----

// Two CollectiveReduceV2 ops with conflicting group sizes in the same module
// must be rejected. NOTE(review): the expected-error text below (including
// the "overwritting" spelling) must match the pass's diagnostic verbatim —
// do not "fix" the typo here unless the pass's message is changed in lockstep.
func.func @inconsistent_collective_info(%input: tensor<f32>) -> tensor<f32> {
  %group_key = "tf.Const"() { value = dense<11> : tensor<i32> } : () -> tensor<i32>
  %group_size1 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %group_size2 = "tf.Const"() { value = dense<2> : tensor<i32> } : () -> tensor<i32>
  %instance_key = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
  // expected-error@below {{op module already contains an attribute tf2xla.collective_info.group_size=2, overwritting to a new value 1 is not allowed.}}
  %0 = "tf.CollectiveReduceV2"(%input, %group_size1, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %1 = "tf.CollectiveReduceV2"(%input, %group_size2, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor<f32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<f32>
  %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %2 : tensor<f32>
}
