// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s

// Shape reification through a unary op: the expected output takes the shape
// directly from the function argument instead of from the `mhlo.convert`
// result.
// CHECK-LABEL: @shape_of_unary
// CHECK-SAME: (%[[INPUT:.*]]: tensor<?x32xi16>)
func @shape_of_unary(%arg : tensor<?x32xi16>) {
  // CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<?x32xi16> -> tensor<?xindex>
  // CHECK: "use"(%[[INPUT_SHAPE]])
  %converted = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %shape = shape.shape_of %converted : tensor<?x32xf16> -> tensor<?xindex>
  // Unregistered consumer that keeps the shape value alive.
  "use"(%shape) : (tensor<?xindex>) -> ()
  return
}

// -----

// Shape reification through a chain of n-ary ops: the expected output takes
// the shape from the first argument, skipping both subtractions.
// CHECK-LABEL: @shape_of_nary
// CHECK-SAME: (%[[LHS:.*]]: tensor<?x32xf16>, %[[RHS:.*]]: tensor<?x32xf16>)
func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) {
  // CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<?x32xf16> -> tensor<?xindex>
  // CHECK: "use"(%[[LHS_SHAPE]])
  %diff = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16>
  %diff_of_diff = mhlo.subtract %diff, %arg1 : tensor<?x32xf16>
  %shape = shape.shape_of %diff_of_diff : tensor<?x32xf16> -> tensor<?xindex>
  // Unregistered consumer that keeps the shape value alive.
  "use"(%shape) : (tensor<?xindex>) -> ()
  return
}

// -----

// Broadcasts can be moved up over unary shape-preserving operations.
// CHECK-LABEL: @bcast_unary
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_unary(%arg : tensor<?x32xi16>, %out_dims : tensor<3xindex>)
    -> tensor<?x?x32xf16> {
  // Expected output: the i16 argument is broadcasted first and the conversion
  // to f16 then runs on the broadcasted value.
  // CHECK: %[[BCASTED:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[OUT_DIMS]])
  // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xi16>, tensor<3xindex>) -> tensor<?x?x32xi16>
  // CHECK: "mhlo.convert"(%[[BCASTED]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
  %converted = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %bcasted = "mhlo.dynamic_broadcast_in_dim"(%converted, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
      (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
  return %bcasted : tensor<?x?x32xf16>
}

// -----

// Broadcasts can be moved up over n-ary shape-preserving operations.
// CHECK-LABEL: @bcast_nary
// CHECK-SAME: (%[[LHS:.*]]: tensor<?x32xf32>, %[[RHS:.*]]: tensor<?x32xf32>, %[[OUT_DIMS:.*]]: tensor<3xindex>)
func @bcast_nary(%arg0 : tensor<?x32xf32>, %arg1 : tensor<?x32xf32>,
    %out_dims : tensor<3xindex>) -> tensor<?x?x32xf32> {
  // Expected output: no subtraction appears before the two operand
  // broadcasts; the single subtraction consumes the broadcasted operands.
  // CHECK-NOT: subtract
  // CHECK: %[[BCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[LHS]], %[[OUT_DIMS]])
  // CHECK: %[[BCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[OUT_DIMS]])
  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_LHS]], %[[BCASTED_RHS]] : tensor<?x?x32xf32>
  %diff = mhlo.subtract %arg0, %arg1 : tensor<?x32xf32>
  %bcasted = "mhlo.dynamic_broadcast_in_dim"(%diff, %out_dims) {
      broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } :
      (tensor<?x32xf32>, tensor<3xindex>) -> tensor<?x?x32xf32>
  return %bcasted : tensor<?x?x32xf32>
}

// -----

// Exemplary IR as it appears in the lowering with `tf.Sub` and `tf.Cast`.
// CHECK-LABEL: @cast_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xi16>, %[[ARG1:.*]]: tensor<?x?x32xf16>) -> tensor<?x?x32xf16>
func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
    -> tensor<?x?x32xf16> {
  // Expected output: no conversion precedes the broadcasts; the conversion
  // runs on the already broadcasted i16 operand.
  // CHECK-NOT: convert
  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %{{.*}})
  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %{{.*}})
  // CHECK: %[[CONVERTED_BCASTED_ARG0:.*]] = "mhlo.convert"(%[[BCASTED_ARG0]]) : (tensor<?x?x32xi16>) -> tensor<?x?x32xf16>
  // CHECK: %{{.*}} = mhlo.subtract %[[BCASTED_ARG1]], %[[CONVERTED_BCASTED_ARG0]] : tensor<?x?x32xf16>
  %0 = "mhlo.convert"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16>
  %1 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
  %2 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
  %3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
  %4 = shape.assuming %3 -> (tensor<?x?x32xf16>) {
    %5 = shape.shape_of %arg1 : tensor<?x?x32xf16> -> tensor<?xindex>
    %6 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex>
    %7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex>
        -> tensor<?xindex>
    %8 = tensor.cast %7 : tensor<?xindex> to tensor<3xindex>
    %9 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {
        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} :
        (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %10 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {
        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} :
        (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %11 = mhlo.subtract %9, %10 : tensor<?x?x32xf16>
    shape.assuming_yield %11 : tensor<?x?x32xf16>
  }
  return %4 : tensor<?x?x32xf16>
}

// -----

// A `shape.broadcast` feeding `shape.cstr_broadcastable` is expected to be
// folded away, with its operands becoming operands of the constraint.
// CHECK-LABEL: @inline_bcasted_shape_operands
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
    %c : tensor<?xindex>) -> !shape.witness {
  // CHECK-NOT: shape.broadcast
  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
  // CHECK: return %[[WITNESS]] : !shape.witness
  %0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
      -> tensor<?xindex>
  %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
  return %1 : !shape.witness
}

// -----

// A `shape.shape_of` applied to a result of the directly preceding assuming op
// is expected to move into the region and be yielded as an extra result.
// CHECK-LABEL: @move_shape_of_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?x32xf32>)
func @move_shape_of_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?x32xf32>) -> tensor<3xindex> {
  // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<?x32xf32>, tensor<?x32xf32>, tensor<3xindex>) {
  // CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<?x32xf32>
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[DUMMY_TENSOR]]
  // CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[SHAPE]]
  // CHECK: }
  // CHECK-NOT: shape_of
  // CHECK: return %[[ASSUMING_RESULTS]]#2
  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    %1 = "dummy.tensor"() : () -> tensor<?x32xf32>
    shape.assuming_yield %arg1, %1 : tensor<?x32xf32>, tensor<?x32xf32>
  }
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
  "use"(%0#0, %0#1) : (tensor<?x32xf32>, tensor<?x32xf32>) -> ()
  return %2 : tensor<3xindex>
}

// -----

// A `shape.cstr_broadcastable` consuming a result of the directly preceding
// assuming op is expected to move into the region and be yielded as an extra
// result.
// CHECK-LABEL: @move_cstr_broadcastable_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>)
func @move_cstr_broadcastable_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2xindex>) -> !shape.witness {
  // CHECK: %[[ASSUMING_RESULTS:.*]]:3 = shape.assuming %[[ARG0]] -> (tensor<2xindex>, tensor<3xindex>, !shape.witness) {
  // CHECK: %[[DUMMY_TENSOR:.*]] = "dummy.tensor"() : () -> tensor<3xindex>
  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[DUMMY_TENSOR]]
  // CHECK: shape.assuming_yield %[[ARG1]], %[[DUMMY_TENSOR]], %[[WITNESS]]
  // CHECK: }
  // CHECK-NOT: cstr_broadcastable
  // CHECK: return %[[ASSUMING_RESULTS]]#2
  %0:2 = shape.assuming %arg0 -> (tensor<2xindex>, tensor<3xindex>) {
    %1 = "dummy.tensor"() : () -> tensor<3xindex>
    shape.assuming_yield %arg1, %1 : tensor<2xindex>, tensor<3xindex>
  }
  %1 = shape.cstr_broadcastable %arg1, %0#1 : tensor<2xindex>, tensor<3xindex>
  "use"(%0#0, %0#1) : (tensor<2xindex>, tensor<3xindex>) -> ()
  return %1 : !shape.witness
}

// -----

// Negative case: with another op between the assuming op and the
// `shape.shape_of`, the shape computation is expected to stay outside the
// region.
// CHECK-LABEL: @not_move_shape_of_into_assuming
func @not_move_shape_of_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?x32xf32>, %arg2 : tensor<?x32xf32>) -> tensor<3xindex> {
  // CHECK: shape.assuming
  // CHECK-SAME: {
  // CHECK-NOT: shape_of
  // CHECK: }
  // CHECK: "some.other.op"
  // CHECK: shape_of
  %0:2 = shape.assuming %arg0 -> (tensor<?x32xf32>, tensor<?x32xf32>) {
    shape.assuming_yield %arg1, %arg2 : tensor<?x32xf32>, tensor<?x32xf32>
  }
  "some.other.op"() : () -> ()
  %2 = shape.shape_of %0#1 : tensor<?x32xf32> -> tensor<3xindex>
  return %2 : tensor<3xindex>
}

// -----

// A `shape.cstr_broadcastable` whose operands are all defined outside the
// region is expected to be hoisted, leaving no assuming op behind.
// CHECK-LABEL: @move_cstr_broadcastable_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2xindex>, %[[ARG2:.*]]: tensor<3xindex>)
func @move_cstr_broadcastable_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2xindex>, %arg2 : tensor<3xindex>) -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
  // CHECK-NOT: assuming
  // CHECK-NOT: cstr_broadcastable
  // CHECK: return %[[WITNESS]]
  %0 = shape.assuming %arg0 -> (!shape.witness) {
    %1 = shape.cstr_broadcastable %arg1, %arg2 : tensor<2xindex>, tensor<3xindex>
    shape.assuming_yield %1 : !shape.witness
  }
  return %0 : !shape.witness
}

// -----

// Element-wise ops following an assuming op are expected to be pulled into
// its region so that only the final result is yielded.
// CHECK-LABEL: @move_elementwise_into_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<?xf32>)
func @move_elementwise_into_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: %[[RES:.*]] = shape.assuming %[[ARG0]]
  // CHECK: %[[SOME:.*]] = "some.op"
  // CHECK: %[[TANH:.*]] = "mhlo.tanh"(%[[ARG1]])
  // CHECK: %[[BCAST_ADD:.*]] = chlo.broadcast_add %[[TANH]], %[[SOME]]
  // CHECK: shape.assuming_yield %[[BCAST_ADD]]
  // CHECK-NOT: tanh
  // CHECK-NOT: broadcast_add
  // CHECK: return %[[RES]]
  %0:2 = shape.assuming %arg0 -> (tensor<?xf32>, tensor<?xf32>) {
    %1 = "some.op"() : () -> tensor<?xf32>
    shape.assuming_yield %arg1, %1 : tensor<?xf32>, tensor<?xf32>
  }
  %1 = "mhlo.tanh"(%arg1) : (tensor<?xf32>) -> tensor<?xf32>
  %2 = chlo.broadcast_add %1, %0#1
      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %2 : tensor<?xf32>
}

// -----

// A `shape.shape_of` whose operand is defined outside the region is expected
// to be hoisted; here it is the only op, so no assuming op remains.
// CHECK-LABEL: @move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-NOT: assuming
  // CHECK-NOT: cstr_broadcastable
  // CHECK: return %[[SHAPE]]
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %1 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}

// -----

// Same hoisting as above, but the region also produces an unrelated value:
// only the `shape.shape_of` is expected to leave, and the assuming op stays
// with a single result. The test name is distinct from the previous case so
// that a failure can be attributed unambiguously.
// CHECK-LABEL: @move_only_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @move_only_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG1]]
  // CHECK: %{{.*}} = shape.assuming %[[ARG0]] -> (tensor<2x?xf32>) {
  // CHECK: %[[SOME_VAL:.*]] = "some.op"() : () -> tensor<2x?xf32>
  // CHECK: shape.assuming_yield %[[SOME_VAL]] : tensor<2x?xf32>
  // CHECK: }
  // CHECK: return %[[SHAPE]]
  %0:2 = shape.assuming %arg0 -> (tensor<2x?xf32>, tensor<2xindex>) {
    %1 = "some.op"() : () -> (tensor<2x?xf32>)
    %2 = shape.shape_of %arg1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %1, %2 : tensor<2x?xf32>, tensor<2xindex>
  }
  "use"(%0#0, %0#1) : (tensor<2x?xf32>, tensor<2xindex>) -> ()
  return %0#1 : tensor<2xindex>
}

// -----

// Negative case: the `shape.shape_of` operand is produced inside the region,
// so the shape computation is expected to stay where it is.
// CHECK-LABEL: @not_move_shape_of_out_of_assuming
// CHECK-SAME: (%[[ARG0:.*]]: !shape.witness, %[[ARG1:.*]]: tensor<2x?xf32>)
func @not_move_shape_of_out_of_assuming(%arg0 : !shape.witness,
    %arg1 : tensor<2x?xf32>) -> tensor<2xindex> {
  // CHECK-NOT: shape_of
  // CHECK: shape.assuming
  // CHECK-SAME: {
  // CHECK: "some.tensor"
  // CHECK: shape_of
  // CHECK: }
  %0 = shape.assuming %arg0 -> (tensor<2xindex>) {
    %1 = "some.tensor"() : () -> tensor<2x?xf32>
    %2 = shape.shape_of %1 : tensor<2x?xf32> -> tensor<2xindex>
    shape.assuming_yield %2 : tensor<2xindex>
  }
  return %0 : tensor<2xindex>
}

// -----

// Two consecutive assuming ops are expected to be merged into a single op
// that is guarded by the `shape.assuming_all` of both witnesses and yields
// the results of both original regions.
// CHECK-LABEL: @merge_assuming_ops
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @merge_assuming_ops(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
  // CHECK: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
  // CHECK: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]], %[[SHAPE2]]
  // CHECK: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
  // CHECK: %[[MERGED:.*]]:2 = shape.assuming %[[COMBINED_WITNESS]]
  // CHECK-SAME: {
  // CHECK: "some.op"
  // CHECK: %[[RESULT0:.*]] = "some.producer"
  // CHECK: "another.op"
  // CHECK: %[[RESULT1:.*]] = "another.producer"
  // CHECK: shape.assuming_yield %[[RESULT0]], %[[RESULT1]]
  // CHECK: }
  // CHECK: return %[[MERGED]]#1
  %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
  %2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
  %3 = shape.cstr_broadcastable %0, %1 : tensor<2xindex>, tensor<2xindex>
  %4 = shape.cstr_broadcastable %0, %1, %2 : tensor<2xindex>, tensor<2xindex>,
      tensor<3xindex>
  %5 = shape.assuming %3 -> (tensor<?x32xf16>) {
    "some.op"() : () -> ()
    %6 = "some.producer"() : () -> tensor<?x32xf16>
    shape.assuming_yield %6 : tensor<?x32xf16>
  }
  %7 = shape.assuming %4 -> (tensor<?x?x32xf16>) {
    "another.op"() : () -> ()
    %8 = "another.producer"() : () -> tensor<?x?x32xf16>
    shape.assuming_yield %8 : tensor<?x?x32xf16>
  }
  "use"(%5, %7) : (tensor<?x32xf16>, tensor<?x?x32xf16>) -> ()
  return %7 : tensor<?x?x32xf16>
}

// -----

// Do not merge assuming ops if witness will not dominate use.
// CHECK-LABEL: @do_not_merge_assuming_ops
func @do_not_merge_assuming_ops() {
  // The witness of the second assuming op is a result of the first one. The
  // witnesses are captured and chained so that the second directive cannot be
  // satisfied by the `shape.assuming_yield` terminator of the first region.
  // CHECK: %[[WITNESS:.*]] = "some.witness"
  // CHECK: %[[YIELDED_WITNESS:.*]] = shape.assuming %[[WITNESS]]
  // CHECK: shape.assuming %[[YIELDED_WITNESS]]
  %0 = "some.witness"() : () -> !shape.witness
  %1 = shape.assuming %0 -> (!shape.witness) {
    %2 = "some.witness"() : () -> !shape.witness
    shape.assuming_yield %2 : !shape.witness
  }
  shape.assuming %1 {
    "some.op"() : () -> ()
    shape.assuming_yield
  }
  return
}

// -----

// A `tensor.cast` from a dynamically sized extent tensor to a statically sized
// one is expected to be folded into the `shape.shape_of` result type.
// CHECK-LABEL: @eliminate_extent_tensor_cast
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
  // CHECK-NOT: shape_of
  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
  // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
  %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  "use"(%1) : (tensor<3xindex>) -> ()
  return
}

// -----

// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
// CHECK-LABEL: @sub_sub
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
    %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> {
  // Expected output: one assuming op guarded by the combination of both
  // witnesses, in which all three arguments are broadcasted to the final shape
  // before either subtraction runs.
  // CHECK-DAG: %[[SHAPE0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK-DAG: %[[SHAPE1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-DAG: %[[SHAPE2:.*]] = shape.shape_of %[[ARG2]]
  // CHECK-DAG: %[[WITNESS0:.*]] = shape.cstr_broadcastable %[[SHAPE0]], %[[SHAPE1]]
  // CHECK-DAG: %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
  // CHECK-DAG: %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
  // CHECK: %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
  // CHECK: %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
  // CHECK: %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
  // CHECK: %[[BCASTED_ARG0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
  // CHECK: %[[BCASTED_ARG1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
  // CHECK: %[[INNER_SUB:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
  // CHECK: %[[OUTER_SUB:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[INNER_SUB]]
  // CHECK: shape.assuming_yield %[[OUTER_SUB]]
  // CHECK: return %[[ASSUMING_RESULT]]

  // First subtraction: %arg0 - %arg1, guarded by its own witness.
  %shape0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
  %shape1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
  %witness01 = shape.cstr_broadcastable %shape0, %shape1
      : tensor<2xindex>, tensor<2xindex>
  %sub01 = shape.assuming %witness01 -> (tensor<?x32xf16>) {
    %lhs_shape = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
    %rhs_shape = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
    %bcast_shape = shape.broadcast %lhs_shape, %rhs_shape
        : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
    %out_dims = tensor.cast %bcast_shape : tensor<?xindex> to tensor<2xindex>
    %bcast_lhs = "mhlo.dynamic_broadcast_in_dim"(%arg0, %out_dims) {
        broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %bcast_rhs = "mhlo.dynamic_broadcast_in_dim"(%arg1, %out_dims) {
        broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
    %diff = mhlo.subtract %bcast_lhs, %bcast_rhs : tensor<?x32xf16>
    shape.assuming_yield %diff : tensor<?x32xf16>
  }

  // Second subtraction: %arg2 - (first result), guarded by a second witness.
  %shape2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
  %shape_sub01 = shape.shape_of %sub01 : tensor<?x32xf16> -> tensor<2xindex>
  %witness2 = shape.cstr_broadcastable %shape2, %shape_sub01
      : tensor<3xindex>, tensor<2xindex>
  %sub2 = shape.assuming %witness2 -> (tensor<?x?x32xf16>) {
    %lhs_shape2 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
    %rhs_shape2 = shape.shape_of %sub01 : tensor<?x32xf16> -> tensor<2xindex>
    %bcast_shape2 = shape.broadcast %lhs_shape2, %rhs_shape2
        : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
    %out_dims2 = tensor.cast %bcast_shape2 : tensor<?xindex> to tensor<3xindex>
    %bcast_lhs2 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %out_dims2) {
        broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
        : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %bcast_rhs2 = "mhlo.dynamic_broadcast_in_dim"(%sub01, %out_dims2) {
        broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
        : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
    %diff2 = mhlo.subtract %bcast_lhs2, %bcast_rhs2 : tensor<?x?x32xf16>
    shape.assuming_yield %diff2 : tensor<?x?x32xf16>
  }
  return %sub2 : tensor<?x?x32xf16>
}

// -----

// Two identical broadcastability constraints combined with
// `shape.assuming_all`: the assuming op is expected to be guarded directly by
// a single constraint.
// CHECK-LABEL: @redundant_cstr_broadcastable
// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
func @redundant_cstr_broadcastable(%arg0: tensor<?xindex>,
    %arg1 : tensor<?xindex>) {
  // CHECK-DAG: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[LHS]], %[[RHS]]
  // CHECK: shape.assuming %[[WITNESS]]
  %witness_a = shape.cstr_broadcastable %arg0, %arg1
      : tensor<?xindex>, tensor<?xindex>
  %witness_b = shape.cstr_broadcastable %arg0, %arg1
      : tensor<?xindex>, tensor<?xindex>
  %combined = shape.assuming_all %witness_a, %witness_b
  shape.assuming %combined -> () {
    "some.op"() : () -> ()
    shape.assuming_yield
  }
  return
}