// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s

// On `index` operands, `shape.add` / `shape.mul` become plain `addi` / `muli`.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[X:.*]]: index, %[[Y:.*]]: index)
func @binary_ops(%x : index, %y : index) {
  // CHECK: addi %[[X]], %[[Y]] : index
  %s = shape.add %x, %y : index, index -> index
  // CHECK: muli %[[X]], %[[Y]] : index
  %p = shape.mul %x, %y : index, index -> index
  return
}

// -----

// With `!shape.size` operands the arithmetic stays in the shape dialect.
// CHECK-LABEL: @binary_ops_on_size
// CHECK-SAME: (%[[X:.*]]: !shape.size, %[[Y:.*]]: !shape.size)
func @binary_ops_on_size(%x : !shape.size, %y : !shape.size) {
  // CHECK: shape.add %[[X]], %[[Y]] : !shape.size, !shape.size -> !shape.size
  %s = shape.add %x, %y : !shape.size, !shape.size -> !shape.size
  // CHECK: shape.mul %[[X]], %[[Y]] : !shape.size, !shape.size -> !shape.size
  %p = shape.mul %x, %y : !shape.size, !shape.size -> !shape.size
  return
}

// -----

// The rank of an extent tensor is its size along dimension 0, read via `dim`.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>) -> index
func @rank(%extents : tensor<?xindex>) -> index {
  // CHECK: %[[ZERO:.*]] = constant 0 : index
  // CHECK: %[[LEN:.*]] = dim %[[EXTENTS]], %[[ZERO]]
  %r = shape.rank %extents : tensor<?xindex> -> index
  // CHECK: return %[[LEN]] : index
  return %r : index
}

// -----

// A `get_extent` whose index and result are `!shape.size` is left unlowered.
// CHECK-LABEL: @get_extent
func @get_extent(%extents : tensor<?xindex>, %i : !shape.size) -> !shape.size {
  // CHECK: shape.get_extent
  %e = shape.get_extent %extents, %i : tensor<?xindex>, !shape.size -> !shape.size
  return %e : !shape.size
}

// -----

// Leave `rank` alone when its types are not error-free.
// CHECK-LABEL: @rank
func @rank(%s : !shape.shape) {
  // CHECK: shape.rank
  %r = shape.rank %s : !shape.shape -> !shape.size
  return
}

// -----

// When the extent tensor comes straight from `shape_of`, `get_extent` becomes
// a `dim` on the original tensor.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME: (%[[T:.*]]: tensor<2x3xf32>, %[[D:.*]]: index) -> index
func @get_extent_shape_of(%t : tensor<2x3xf32>, %d : index) -> index {
  %s = shape.shape_of %t : tensor<2x3xf32> -> tensor<?xindex>
  // CHECK: %[[EXT:.*]] = dim %[[T]], %[[D]] : tensor<2x3xf32>
  %e = shape.get_extent %s, %d : tensor<?xindex>, index -> index
  // CHECK: return %[[EXT]] : index
  return %e : index
}

// -----

// On an arbitrary extent tensor, `get_extent` becomes `extract_element`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>, %[[POS:.*]]: index) -> index
func @get_extent_from_extent_tensor(%shape : tensor<?xindex>, %pos : index)
    -> index {
  // CHECK: %[[EXT:.*]] = extract_element %[[SHAPE]][%[[POS]]] : tensor<?xindex>
  %e = shape.get_extent %shape, %pos : tensor<?xindex>, index -> index
  // CHECK: return %[[EXT]] : index
  return %e : index
}

// -----

// A `const_shape` becomes a `tensor_from_elements` of index constants, cast to
// the dynamically sized extent-tensor type.
// CHECK-LABEL: @const_shape
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape() -> tensor<?xindex> {
  // CHECK: %[[ONE:.*]] = constant 1 : index
  // CHECK: %[[TWO:.*]] = constant 2 : index
  // CHECK: %[[THREE:.*]] = constant 3 : index
  // CHECK: %[[STATIC:.*]] = tensor_from_elements %[[ONE]], %[[TWO]], %[[THREE]]
  // CHECK: %[[DYNAMIC:.*]] = tensor_cast %[[STATIC]] : tensor<3xindex> to tensor<?xindex>
  %c = shape.const_shape [1, 2, 3] : tensor<?xindex>
  // CHECK: return %[[DYNAMIC]] : tensor<?xindex>
  return %c : tensor<?xindex>
}

// -----

// The rank-0 case builds an empty `tensor<0xindex>` which is then cast.
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
  // CHECK: %[[EMPTY:.*]] = tensor_from_elements : tensor<0xindex>
  // CHECK: %[[DYNAMIC:.*]] = tensor_cast %[[EMPTY]] : tensor<0xindex> to tensor<?xindex>
  %c = shape.const_shape [] : tensor<?xindex>
  // CHECK: return %[[DYNAMIC]] : tensor<?xindex>
  return %c : tensor<?xindex>
}

// -----

// `shape.any` over three operands is replaced by its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[FIRST:.*]]: tensor<?xindex>, %[[SECOND:.*]]: tensor<?xindex>, %[[THIRD:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_three(%x : tensor<?xindex>, %y : tensor<?xindex>, %z : tensor<?xindex>)
    -> tensor<?xindex> {
  %any = "shape.any"(%x, %y, %z)
      : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  // CHECK: return %[[FIRST]] : tensor<?xindex>
  return %any : tensor<?xindex>
}

// -----

// `shape.any` over a single operand is replaced by that operand.
// CHECK-LABEL: @any_of_one
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
  // CHECK: return %[[A]] : tensor<?xindex>
  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
  return %result : tensor<?xindex>
}

// -----

// Lower `const_size` to `std.constant`; after `size_to_index` the function
// returns that constant directly.
// CHECK-LABEL: @const_size
func @const_size() -> index {
  // CHECK: %[[RES:.*]] = constant 42 : index
  %size = shape.const_size 42
  %result = shape.size_to_index %size : !shape.size
  // CHECK: return %[[RES]]
  return %result : index
}

// -----

// Lower `to_extent_tensor` to `std.tensor_cast`. The operand is already an
// extent tensor, so no `to_extent_tensor` op may remain.
// CHECK-LABEL: @to_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xindex>
func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
  // CHECK-NOT: to_extent_tensor
  // CHECK: %[[RES:.*]] = tensor_cast %[[ARG]] : tensor<?xindex> to tensor<3xindex>
  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<3xindex>
  // CHECK: return %[[RES]]
  return %casted : tensor<3xindex>
}

// -----

// Lower `shape.reduce` to an `scf.for` loop:
// - the upper bound is the rank (`dim` 0 of the extent tensor);
// - each iteration reads one extent with `extract_element`;
// - the accumulator is threaded through `iter_args`.
// The expected lowering is listed after the function and must start on the
// line right after the signature.
// CHECK-LABEL: @shape_reduce
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
func @shape_reduce(%shape : tensor<?xindex>) -> index {
  %init = constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc: index):
      %new_acc = muli %acc, %extent : index
      shape.yield %new_acc : index
  }
  return %num_elements : index
}
// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
// CHECK-NEXT:   %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
// CHECK-NEXT:   %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
// CHECK-NEXT:   scf.yield %[[NEW_ACC]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return %[[RESULT]] : index

// -----

// Don't lower `shape_of` when the result type is `!shape.shape` (unranked operand).
// CHECK-LABEL: @shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of(%arg : tensor<*xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<*xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for unranked tensors: a `dynamic_tensor_from_elements` of
// size `rank` whose body queries each extent with `dim`.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
  // CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] {
  // CHECK: ^bb0(%[[I:.*]]: index):
  // CHECK:   %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
  // CHECK:   yield %[[EXTENT]] : index
  // CHECK: } : tensor<?xindex>
  %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  return
}

// -----

// Don't lower `shape_of` when the result type is `!shape.shape` (statically
// shaped operand).
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
  // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
  %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
  return
}

// -----

// Lower `shape_of` for statically shaped tensor.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[T:.*]]: tensor<1x2x3xf32>)
func @shape_of_stat(%t : tensor<1x2x3xf32>) {
  // Every extent is static, so each one materializes as an index constant.
  // CHECK-DAG: %[[E0:.*]] = constant 1 : index
  // CHECK-DAG: %[[E1:.*]] = constant 2 : index
  // CHECK-DAG: %[[E2:.*]] = constant 3 : index
  // CHECK-DAG: %[[EXTENTS:.*]] = tensor_from_elements %[[E0]], %[[E1]], %[[E2]] : tensor<3xindex>
  %s = shape.shape_of %t : tensor<1x2x3xf32> -> tensor<?xindex>
  return
}

// -----

// A 0-D tensor has an empty extent tensor.
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[T:.*]]: tensor<f32>)
func @shape_of_zero_d(%t : tensor<f32>) {
  // CHECK-DAG: %[[EXTENTS:.*]] = tensor_from_elements : tensor<0xindex>
  %s = shape.shape_of %t : tensor<f32> -> tensor<?xindex>
  return
}

// -----

// Static extents become constants; the dynamic one is queried with `dim`.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[T:.*]]: tensor<1x5x?xf32>)
func @shape_of_dyn(%t : tensor<1x5x?xf32>) {
  // CHECK-DAG: %[[E0:.*]] = constant 1 : index
  // CHECK-DAG: %[[E1:.*]] = constant 5 : index
  // CHECK-DAG: %[[IDX2:.*]] = constant 2 : index
  // CHECK-DAG: %[[E2:.*]] = dim %[[T]], %[[IDX2]] : tensor<1x5x?xf32>
  // CHECK-DAG: %[[EXTENTS:.*]] = tensor_from_elements %[[E0]], %[[E1]], %[[E2]] : tensor<3xindex>
  %s = shape.shape_of %t : tensor<1x5x?xf32> -> tensor<?xindex>
  return
}

// -----

// `shape_eq` compares the ranks first. Only when they agree does it `and`
// together an element-wise extent comparison inside an `scf.for`; otherwise
// the result is `false`.
// CHECK-LABEL: @shape_eq
// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>) -> i1
func @shape_eq(%lhs : tensor<?xindex>, %rhs : tensor<?xindex>) -> i1 {
  // CHECK: %[[ZERO:.*]] = constant 0 : index
  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[ZERO]] : tensor<?xindex>
  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[ZERO]] : tensor<?xindex>
  // CHECK: %[[SAME_RANK:.*]] = cmpi "eq", %[[LHS_RANK]], %[[RHS_RANK]]
  // CHECK: %[[EQUAL:.*]] = scf.if %[[SAME_RANK]] -> (i1) {
  // CHECK:   %[[ONE:.*]] = constant 1 : index
  // CHECK:   %[[TRUE:.*]] = constant true
  // CHECK:   %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[ZERO]] to %[[LHS_RANK]] step %[[ONE]] iter_args(%[[ACC:.*]] = %[[TRUE]]) -> (i1) {
  // CHECK:     %[[LHS_EXT:.*]] = extract_element %[[LHS]][%[[IV]]] : tensor<?xindex>
  // CHECK:     %[[RHS_EXT:.*]] = extract_element %[[RHS]][%[[IV]]] : tensor<?xindex>
  // CHECK:     %[[EXT_EQ:.*]] = cmpi "eq", %[[LHS_EXT]], %[[RHS_EXT]]
  // CHECK:     %[[ACC_NEXT:.*]] = and %[[ACC]], %[[EXT_EQ]]
  // CHECK:     scf.yield %[[ACC_NEXT]] : i1
  // CHECK:   }
  // CHECK:   scf.yield %[[LOOP]] : i1
  // CHECK: } else {
  // CHECK:   %[[FALSE:.*]] = constant false
  // CHECK:   scf.yield %[[FALSE]] : i1
  // CHECK: }
  // CHECK: return %[[EQUAL]] : i1
  %eq = shape.shape_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
  return %eq : i1
}

// -----

// `shape.broadcast` is kept as-is whenever a `!shape.shape` operand or result is involved.
// CHECK-LABEL: @broadcast
func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
  // CHECK: shape.broadcast
  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
  return %c : !shape.shape
}

// -----

// Lower `broadcast` of two extent tensors of unknown size.
// - The operands are ordered by rank.
// - In the leading `rank_diff` dimensions, the output extent is the
//   greater-rank operand's extent.
// - In the trailing dimensions, a greater-rank extent of 1 is replaced by the
//   lesser-rank operand's extent.
// - The `scf.if` result is what the generator region yields.
// CHECK-LABEL: func @broadcast_unknown_extents(
// CHECK-SAME:    %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME:    %[[RHS:.*]]: tensor<?xindex>) {
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
  // CHECK:         %[[C0:.*]] = constant 0 : index
  // CHECK:         %[[C1:.*]] = constant 1 : index
  // CHECK:         %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
  // CHECK:         %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
  // CHECK:         %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:         %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:         %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
  // CHECK:         %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
  // CHECK:         %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
  // CHECK:         %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
  // CHECK:         %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
  // CHECK:         %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
  // CHECK:         %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
  // CHECK:         ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
  // CHECK:           %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:           %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
  // CHECK:           %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
  // CHECK:             scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:           } else {
  // CHECK:             %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:             %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
  // CHECK:             %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:             scf.yield %[[BROADCASTED_EXTENT]] : index
  // CHECK:           }
  // CHECK:           yield %[[OUTPUT_EXTENT]] : index
  // CHECK:         } : tensor<?xindex>
  // CHECK:         return
  // CHECK:       }
  %0 = shape.broadcast %a, %b
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return
}

// -----

// Same lowering when both extent tensors have (different) static sizes. The
// operands are first cast to `tensor<?xindex>` so they can be selected
// between.
// CHECK-LABEL: func @broadcast_known_different_extents(
// CHECK-SAME:    %[[LHS:.*]]: tensor<2xindex>,
// CHECK-SAME:    %[[RHS:.*]]: tensor<3xindex>) {
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
  // CHECK:         %[[C0:.*]] = constant 0 : index
  // CHECK:         %[[C1:.*]] = constant 1 : index
  // CHECK:         %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
  // CHECK:         %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
  // CHECK:         %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:         %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
  // CHECK:         %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
  // CHECK:         %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
  // CHECK:         %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
  // CHECK:         %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
  // CHECK:         %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
  // CHECK:         %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
  // CHECK:         %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
  // CHECK:         ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
  // CHECK:           %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:           %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
  // CHECK:           %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
  // CHECK:             scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:           } else {
  // CHECK:             %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
  // CHECK:             %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
  // CHECK:             %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
  // CHECK:             scf.yield %[[BROADCASTED_EXTENT]] : index
  // CHECK:           }
  // CHECK:           yield %[[OUTPUT_EXTENT]] : index
  // CHECK:         } : tensor<?xindex>
  // CHECK:         return
  // CHECK:       }
  %0 = shape.broadcast %a, %b
      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
  return
}

// -----

// Lower `is_broadcastable` to an `scf.for` over the trailing dimensions of the
// larger shape (from `rank_diff` to the larger rank). For each dimension the
// loop `and`s into the accumulator whether either extent is 1 or the two
// extents are equal.
func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
  return %0 : i1
}

// CHECK-LABEL: func @try_is_broadcastable(
// CHECK-SAME:    %[[LHS:.*]]: tensor<3xindex>,
// CHECK-SAME:    %[[RHS:.*]]: tensor<?xindex>) -> i1 {
// CHECK:         %[[C0:.*]] = constant 0 : index
// CHECK:         %[[C1:.*]] = constant 1 : index
// CHECK:         %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
// CHECK:         %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK:         %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:         %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:         %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK:         %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK:         %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:         %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK:         %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK:         %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK:         %[[TRUE:.*]] = constant true
// CHECK:         %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:           %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
// CHECK:           %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK:           %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
// CHECK:           %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK:           %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK:           %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK:           %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
// CHECK:           %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
// CHECK:           %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
// CHECK:           scf.yield %[[NEW_ALL_SO_FAR]] : i1
// CHECK:         }
// CHECK:         return %[[ALL_RESULT]] : i1
// CHECK:       }

// -----

// Lower `cstr_broadcastable` to the same compatibility loop as
// `is_broadcastable`. The loop result then feeds `shape.cstr_require`, which
// produces the witness.
func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
  return %0 : !shape.witness
}

// CHECK-LABEL: func @broadcast(
// CHECK-SAME:    %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME:    %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK:         %[[C0:.*]] = constant 0 : index
// CHECK:         %[[C1:.*]] = constant 1 : index
// CHECK:         %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK:         %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK:         %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:         %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK:         %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK:         %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:         %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK:         %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
// CHECK:         %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
// CHECK:         %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
// CHECK:         %[[TRUE:.*]] = constant true
// CHECK:         %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
// CHECK:           %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
// CHECK:           %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
// CHECK:           %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
// CHECK:           %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
// CHECK:           %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
// CHECK:           %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
// CHECK:           %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
// CHECK:           %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
// CHECK:           %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
// CHECK:           scf.yield %[[NEW_ALL_SO_FAR]] : i1
// CHECK:         }
// CHECK:         %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK:         return %[[RESULT]] : !shape.witness
// CHECK:       }