// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s
// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON

// Tests for types, ops with custom constraints, verifiers, printer or parser
// methods.
//
// The first RUN line parses each split, prints it, and re-parses the printed
// form, so every valid case below is checked to round-trip through the custom
// printer and parser. The second RUN line applies canonicalization and checks
// the folded IR with the CANON prefix. Cases that must fail verification are
// annotated with the diagnostic they are required to produce.

// CHECK-LABEL: tensor_type_noconstraint
func private @tensor_type_noconstraint() -> !tfr.tensor

// -----

// CHECK-LABEL: tensor_type
func private @tensor_type() -> !tfr.tensor<T>

// -----

// CHECK-LABEL: tensor_list_type_noconstraint
func private @tensor_list_type_noconstraint() -> !tfr.tensor_list

// -----

// CHECK-LABEL: tensor_list_type_array_like
func private @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]>

// -----

// CHECK-LABEL: tensor_list_type_tuple_like
func private @tensor_list_type_tuple_like() -> !tfr.tensor_list<input_T>

// -----

// expected-error@+1 {{unbalanced '>' character in pretty dialect name}}
func private @tensor_invalid_1() -> !tfr.tensor<[N, T>

// -----

// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}}
func @tensor_invalid_2() -> !tfr.tensor<[N, T]

// -----

// CHECK-LABEL: call_op
func @call_op(%arg0: !tfr.tensor<T>, %arg1: !tfr.tensor_list<TL>, %arg2: i32) -> !tfr.tensor<K> {
  %0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor<T>, !tfr.tensor_list<TL>, i32) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K>
func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K> {
  %0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

func @call_op_invalid_1(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
  // expected-error@+1 {{got 'tensor<?xf32>'}}
  %0 = tfr.call @Huu(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// CHECK-LABEL: get_shape
func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) {
  %0 = tfr.get_shape %arg0 -> !shape.shape
  %1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape
  return %0, %1 : !shape.shape, !shape.shape
}

// -----

// CHECK-LABEL: get_real_shape
// CANON-LABEL: get_real_shape
func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> {
  %0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
  %1 = tfr.get_shape %0 -> !shape.shape
  %2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex>
  return %2 : tensor<1xindex>

// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor<?xindex>
// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor<?xindex> -> tensor<1xindex>
// CANON-NEXT: return %[[e]] : tensor<1xindex>
}

// -----

// Both the custom and the generic syntax of tfr.get_element_type must
// round-trip.
// CHECK-LABEL: get_element_type
func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) {
  %0 = tfr.get_element_type %arg0 -> !tfr.attr
  %1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr
  return %0, %1 : !tfr.attr, !tfr.attr
}

// -----

// CHECK-LABEL: from_tf_tensor
func @from_tf_tensor(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
  %0 = "tfr.cast"(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// CHECK-LABEL: to_tf_tensor
func @to_tf_tensor(%arg0: !tfr.tensor<T>) -> tensor<?xi32> {
  %0 = "tfr.cast"(%arg0) : (!tfr.tensor<T>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// -----

// CHECK-LABEL: constant
func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) {
  %0 = tfr.constant f32 -> !tfr.attr
  %1 = tfr.constant [f32, i32] -> !tfr.attr
  %2 = "tfr.constant"() {value = f32} : () -> !tfr.attr
  %3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr
  return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr
}

// -----

// CHECK-LABEL: equal
// CANON-LABEL: equal
func @equal() -> (i1, i1, i1, i1) {
  %0 = tfr.constant f32 -> !tfr.attr
  %1 = tfr.constant f32 -> !tfr.attr
  %2 = tfr.constant i32 -> !tfr.attr
  %same_type = tfr.equal %0,%1 -> i1
  %diff_type = tfr.equal %0,%2 -> i1

  %3 = tfr.constant "hello" -> !tfr.attr
  %4 = tfr.constant "hello" -> !tfr.attr
  %5 = tfr.constant "how are you" -> !tfr.attr
  %same_str = tfr.equal %3,%4 -> i1
  %diff_str = tfr.equal %3,%5 -> i1
  return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1

// CANON-NEXT: %true = constant true
// CANON-NEXT: %false = constant false
// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1
}

// -----

// CHECK-LABEL: constant_tensor_scalar
func @constant_tensor_scalar(%arg0: i32) -> tensor<i32> {
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<i32>
  return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: constant_tensor_vector
func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> {
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32>
  return %0 : tensor<1x2xi32>
}

// -----

// CHECK-LABEL: constant_tensor_array
// CANON-LABEL: constant_tensor_array
func @constant_tensor_array() -> !tfr.tensor {
  %0 = tfr.constant [1, -1, 3] -> !tfr.attr
  %1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
  return %1 : !tfr.tensor

// Capture SSA names instead of hard-coding them, so the checks tie the cast to
// the folded constant and the return to the cast.
// CANON-NEXT: %[[c:.*]] = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
// CANON-NEXT: %[[r:.*]] = "tfr.cast"(%[[c]]) : (tensor<3xi64>) -> !tfr.tensor
// CANON-NEXT: return %[[r]] : !tfr.tensor
}

// -----

// CHECK-LABEL: constant_tensor_scalar
// CANON-LABEL: constant_tensor_scalar
func @constant_tensor_scalar() -> !tfr.tensor {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
  return %1 : !tfr.tensor

// CANON-NEXT: %[[c:.*]] = "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
// CANON-NEXT: %[[r:.*]] = "tfr.cast"(%[[c]]) : (tensor<i32>) -> !tfr.tensor
// CANON-NEXT: return %[[r]] : !tfr.tensor
}

// -----

func @constant_tensor_invalid_0(%arg0: i32) -> tensor<f32> {
  // expected-error@+1 {{input and output should have the same scalar types.}}
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor<?xi32> {
  // expected-error@+1 {{output type should be static and ranked}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// -----

func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> {
  // expected-error@+1 {{input and output should have same shape and element type}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----

func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> {
  // expected-error@+1 {{input and output should have same shape and element type}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}

// -----

func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> {
  // expected-error@+1 {{input can not be converted to an output tensor}}
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}

// -----

// CHECK-LABEL: get_element
func @get_element(%arg0: !tfr.tensor_list<T>) -> !tfr.tensor {
  %cst = "std.constant"() {value = 1 : index} : () -> index
  %0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list<T>, index) -> !tfr.tensor
  return %0 : !tfr.tensor
}

// -----

// CHECK-LABEL: build_list
func @build_list(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> !tfr.tensor_list {
  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
  return %0 : !tfr.tensor_list
}

// -----

// CHECK-LABEL: build_const_list
// CANON-LABEL: build_const_list
func @build_const_list() -> !tfr.attr {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "std.constant"() {value = 41 : i32} : () -> i32
  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  return %2 : !tfr.attr

// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr
// CANON-NEXT: return %[[c]] : !tfr.attr
}

// -----

// CHECK-LABEL: build_high_dim_const_list
// CANON-LABEL: build_high_dim_const_list
func @build_high_dim_const_list() -> !tfr.attr {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "std.constant"() {value = 41 : i32} : () -> i32
  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  %3 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  %4 = "tfr.build_list"(%2, %3) : (!tfr.attr, !tfr.attr) -> !tfr.attr
  return %4 : !tfr.attr

// CANON-NEXT: %[[c:.*]] = tfr.constant {{\[}}[42 : i32, 41 : i32], [42 : i32, 41 : i32]] -> !tfr.attr
// CANON-NEXT: return %[[c]] : !tfr.attr
}

// -----

// CHECK-LABEL: get_length
// CANON-LABEL: get_length
func @get_length(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> index {
  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
  %1 = "tfr.get_length"(%0) : (!tfr.tensor_list) -> index
  return %1 : index

// CANON-NEXT: %[[c:.*]] = constant 2 : index
// CANON-NEXT: return %[[c]] : index
}

// -----

// Name each tfr.func label explicitly: a bare "tfr.func" label can match any
// tfr.func in the output and would not notice a function going missing.
// CHECK-LABEL: tfr.func @External
tfr.func @External(%arg0: !tfr.tensor<A>,
                   %arg1: !tfr.tensor_list<C>,
                   %arg2: i32 {tfr.name = "A"},
                   %arg3: !tfr.attr {tfr.name = "T"})
  -> (!tfr.tensor<A>, !tfr.tensor_list<C>)
  attributes {A, C}

// -----

// CHECK-LABEL: tfr.func @Foo
tfr.func @Foo(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<C>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"})
  -> (!tfr.tensor<A>, !tfr.tensor_list<C>)
  attributes {A, C} {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<C>
}

// -----

// CHECK-LABEL: tfr.func @Bar
tfr.func @Bar(%arg0: !tfr.tensor<A>,
              %arg2: i32 {tfr.name = "B"},
              %arg3: vector<1xi32> {tfr.name = "C"})
  -> (!tfr.tensor<A>, !tfr.tensor<A>)
  attributes {A} {
  tfr.return %arg0, %arg0 : !tfr.tensor<A>, !tfr.tensor<A>
}

// -----

// expected-error@+1 {{Undefined attributes are used: A}}
tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor<A>,
                             %arg1: !tfr.tensor_list<A>,
                             %arg2: i32 {tfr.name = "A"},
                             %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}}
tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor<A>,
                           %arg1: !tfr.tensor_list<A>,
                           %arg2: i32 {tfr.name = "A"},
                           %arg3: vector<1xi32>) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// expected-error@+1 {{tfr.tensor/tfr.tensor_list argument should be before non tensor arguments}}
tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor<A>,
                                %arg2: i32 {tfr.name = "A"},
                                %arg1: !tfr.tensor_list<A>,
                                %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// An unconstrained tfr.tensor_list may precede a tfr.tensor argument when the
// tensor's constraint (T) is declared in the attributes.
// CHECK-LABEL: tfr.func @Foo_valid_arg_order0
tfr.func @Foo_valid_arg_order0(
    %arg1: !tfr.tensor_list,
    %arg0: !tfr.tensor<T>,
    %arg2: i32 {tfr.name = "A"},
    %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor, !tfr.tensor_list) attributes {T}{
  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
}

// -----

// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
tfr.func @Foo_invalid_arg_order0(
    %arg1: !tfr.tensor_list,
    %arg0: !tfr.tensor<T>,
    %arg2: i32 {tfr.name = "A"},
    %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor, !tfr.tensor_list) {
  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
}

// -----

// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}}
tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor<A>,
                                   %arg1: !tfr.tensor_list<A>,
                                   %arg2: i32 {tfr.name = "A"},
                                   %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor_list<A>, !tfr.tensor<A>) {
  tfr.return %arg1, %arg0 : !tfr.tensor_list<A>, !tfr.tensor<A>
}

// -----

// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}}
tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor<A>,
                                        %arg1: !tfr.tensor_list<A>,
                                        %arg2: !tfr.tensor_list<A>,
                                        %arg3: i32 {tfr.name = "A"},
                                        %arg4: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}}
tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor<C>,
                                           %arg1: !tfr.tensor_list<A>,
                                           %arg2: i32 {tfr.name = "A"},
                                           %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor_list<A>, !tfr.tensor_list<A>) {
  tfr.return %arg1, %arg1 : !tfr.tensor_list<A>, !tfr.tensor_list<A>
}

// -----

// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}}
tfr.func @Foo_return_attr(%arg0: !tfr.tensor<C>,
                          %arg1: !tfr.tensor_list<A>,
                          %arg2: i32 {tfr.name = "A"},
                          %arg3: vector<1xi32> {tfr.name = "C"}) -> i32 {
  tfr.return %arg2 : i32
}