// RUN: tfr-opt %s -verify-diagnostics -split-input-file | tfr-opt | FileCheck %s
// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s -check-prefix=CANON

// Tests for TFR dialect types, ops with custom constraints, verifiers, and
// printer/parser methods.

// An unconstrained !tfr.tensor type is accepted by the parser and printer.
// CHECK-LABEL: tensor_type_noconstraint
func private @tensor_type_noconstraint() -> !tfr.tensor

// -----

// A !tfr.tensor type carrying a single constraint name.
// CHECK-LABEL: tensor_type
func private @tensor_type() -> !tfr.tensor<T>

// -----

// An unconstrained !tfr.tensor_list type.
// CHECK-LABEL: tensor_list_type_noconstraint
func private @tensor_list_type_noconstraint() -> !tfr.tensor_list

// -----

// A !tfr.tensor_list whose constraint is a bracketed list of names.
// CHECK-LABEL: tensor_list_type_array_like
func private @tensor_list_type_array_like() -> !tfr.tensor_list<[N, T]>

// -----

// A !tfr.tensor_list whose constraint is a single name.
// CHECK-LABEL: tensor_list_type_tuple_like
func private @tensor_list_type_tuple_like() -> !tfr.tensor_list<input_T>

// -----

// A '[' inside the type parameters that is closed by '>' instead of ']'
// must be rejected by the parser.
// expected-error@+1 {{unbalanced '>' character in pretty dialect name}}
func private @tensor_invalid_1() -> !tfr.tensor<[N, T>

// -----

// Type parameters that are never closed with '>' run to the end of the
// input and must be rejected by the parser. Declared private like the other
// bodiless declarations in this file.
// expected-error@+1 {{unexpected nul or EOF in pretty dialect name}}
func private @tensor_invalid_2() -> !tfr.tensor<[N, T]

// -----

// tfr.call with !tfr.tensor, !tfr.tensor_list and scalar operands. The
// callee @Foo is not defined in this split section.
// CHECK-LABEL: call_op
func @call_op(%arg0: !tfr.tensor<T>, %arg1: !tfr.tensor_list<TL>, %arg2: i32) -> !tfr.tensor<K> {
  %0 = tfr.call @Foo(%arg0, %arg1, %arg2) : (!tfr.tensor<T>, !tfr.tensor_list<TL>, i32) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// The label spells out the full printed signature, so this also pins the
// printer output for a function taking only a scalar argument.
// CHECK-LABEL: call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K>
func @call_op_arg_attr(%arg0: i32) -> !tfr.tensor<K> {
  %0 = tfr.call @Bar(%arg0) : (i32) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// A builtin tensor type is not accepted as a tfr.call operand.
func @call_op_invalid_1(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
  // expected-error@+1 {{got 'tensor<?xf32>'}}
  %0 = tfr.call @Huu(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// tfr.get_shape in both its custom and generic assembly forms.
// CHECK-LABEL: get_shape
func @get_shape(%arg0: !tfr.tensor) -> (!shape.shape, !shape.shape) {
  %0 = tfr.get_shape %arg0 -> !shape.shape
  %1 = "tfr.get_shape"(%arg0) : (!tfr.tensor) -> !shape.shape
  return %0, %1 : !shape.shape, !shape.shape
}

// -----

// With -canonicalize, get_shape of a tfr.cast from the statically shaped
// tensor<1x2xf32> folds to shape.const_shape [1, 2].
// CHECK-LABEL: get_real_shape
// CANON-LABEL: get_real_shape
func @get_real_shape(%arg0: tensor<1x2xf32>) -> tensor<1xindex> {
  %0 = "tfr.cast"(%arg0) : (tensor<1x2xf32>) -> !tfr.tensor
  %1 = tfr.get_shape %0 -> !shape.shape
  %2 = shape.to_extent_tensor %1 : !shape.shape -> tensor<1xindex>
  return %2 : tensor<1xindex>

// CANON-NEXT: %[[s:.*]] = shape.const_shape [1, 2] : tensor<?xindex>
// CANON-NEXT: %[[e:.*]] = shape.to_extent_tensor %[[s]] : tensor<?xindex> -> tensor<1xindex>
// CANON-NEXT: return %[[e]] : tensor<1xindex>
}

// -----

// tfr.get_element_type in both its custom and generic assembly forms.
// CHECK-LABEL: get_element_type
func @get_element_type(%arg0: !tfr.tensor) -> (!tfr.attr, !tfr.attr) {
  %0 = tfr.get_element_type %arg0 -> !tfr.attr
  %1 = "tfr.get_element_type"(%arg0) : (!tfr.tensor) -> !tfr.attr
  return %0, %1 : !tfr.attr, !tfr.attr
}

// -----

// tfr.cast from a builtin tensor to a constrained !tfr.tensor.
// CHECK-LABEL: from_tf_tensor
func @from_tf_tensor(%arg0: tensor<?xf32>) -> !tfr.tensor<K> {
  %0 = "tfr.cast"(%arg0) : (tensor<?xf32>) -> !tfr.tensor<K>
  return %0 : !tfr.tensor<K>
}

// -----

// tfr.cast from a constrained !tfr.tensor to a builtin tensor.
// CHECK-LABEL: to_tf_tensor
func @to_tf_tensor(%arg0: !tfr.tensor<T>) -> tensor<?xi32> {
  %0 = "tfr.cast"(%arg0) : (!tfr.tensor<T>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// -----

// tfr.constant holding a type attribute and an array-of-types attribute,
// each in both custom and generic assembly forms.
// CHECK-LABEL: constant
func @constant() -> (!tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr) {
  %0 = tfr.constant f32 -> !tfr.attr
  %1 = tfr.constant [f32, i32] -> !tfr.attr
  %2 = "tfr.constant"() {value = f32} : () -> !tfr.attr
  %3 = "tfr.constant"() {value = [f32, i32]} : () -> !tfr.attr
  return %0, %1, %2, %3 : !tfr.attr, !tfr.attr, !tfr.attr, !tfr.attr
}

// -----

// tfr.equal on tfr.constant operands: identical types or strings compare
// equal, different ones do not. With -canonicalize, every comparison folds
// to a constant i1.
// CHECK-LABEL: equal
// CANON-LABEL: equal
func @equal() -> (i1, i1, i1, i1) {
  %0 = tfr.constant f32 -> !tfr.attr
  %1 = tfr.constant f32 -> !tfr.attr
  %2 = tfr.constant i32 -> !tfr.attr
  %same_type = tfr.equal %0, %1 -> i1
  %diff_type = tfr.equal %0, %2 -> i1

  %3 = tfr.constant "hello" -> !tfr.attr
  %4 = tfr.constant "hello" -> !tfr.attr
  %5 = tfr.constant "how are you" -> !tfr.attr
  %same_str = tfr.equal %3, %4 -> i1
  %diff_str = tfr.equal %3, %5 -> i1
  return %same_type, %diff_type, %same_str, %diff_str : i1, i1, i1, i1

// CANON-NEXT: %true = constant true
// CANON-NEXT: %false = constant false
// CANON-NEXT: return %true, %false, %true, %false : i1, i1, i1, i1
}

// -----

// tfr.constant_tensor from an i32 scalar to a rank-0 tensor of the same
// element type.
// CHECK-LABEL: constant_tensor_scalar
func @constant_tensor_scalar(%arg0: i32) -> tensor<i32> {
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<i32>
  return %0 : tensor<i32>
}

// -----

// tfr.constant_tensor from a vector to a tensor with the same shape and
// element type.
// CHECK-LABEL: constant_tensor_vector
func @constant_tensor_vector(%arg0: vector<1x2xi32>) -> tensor<1x2xi32> {
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1x2xi32>) -> tensor<1x2xi32>
  return %0 : tensor<1x2xi32>
}

// -----

// With -canonicalize, tfr.constant_tensor of an integer-array attribute
// becomes a tf.Const with i64 elements followed by a tfr.cast.
// CHECK-LABEL: constant_tensor_array
// CANON-LABEL: constant_tensor_array
func @constant_tensor_array() -> !tfr.tensor {
  %0 = tfr.constant [1, -1, 3] -> !tfr.attr
  %1 = "tfr.constant_tensor"(%0) : (!tfr.attr) -> !tfr.tensor
  return %1 : !tfr.tensor

// The tf.Const result is captured rather than hard-coded as %0, so the
// match does not depend on SSA numbering.
// CANON-NEXT: %[[CST:.*]] = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi64>} : () -> tensor<3xi64>
// CANON-NEXT: "tfr.cast"(%[[CST]]) : (tensor<3xi64>) -> !tfr.tensor
// CANON-NEXT: return
}

// -----

// With -canonicalize, tfr.constant_tensor of a std.constant i32 becomes a
// tf.Const followed by a tfr.cast. The name is distinct from the earlier
// @constant_tensor_scalar, so each label identifies exactly one function.
// CHECK-LABEL: constant_tensor_scalar_fold
// CANON-LABEL: constant_tensor_scalar_fold
func @constant_tensor_scalar_fold() -> !tfr.tensor {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "tfr.constant_tensor"(%0) : (i32) -> !tfr.tensor
  return %1 : !tfr.tensor

// CANON-NEXT: %[[CST:.*]] = "tf.Const"() {value = dense<42> : tensor<i32>} : () -> tensor<i32>
// CANON-NEXT: "tfr.cast"(%[[CST]]) : (tensor<i32>) -> !tfr.tensor
// CANON-NEXT: return
}

// -----

// An i32 scalar cannot become a tensor with a different (f32) element type.
func @constant_tensor_invalid_0(%arg0: i32) -> tensor<f32> {
  // expected-error@+1 {{input and output should have the same scalar types.}}
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

// A dynamically shaped result type (tensor<?xi32>) is rejected.
func @constant_tensor_invalid_1(%arg0: vector<1xi32>) -> tensor<?xi32> {
  // expected-error@+1 {{output type should be static and ranked}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// -----

// The vector input and tensor result disagree on element type (i32 vs f32).
func @constant_tensor_invalid_2(%arg0: vector<1xi32>) -> tensor<1xf32> {
  // expected-error@+1 {{input and output should have same shape and element type}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----

// The vector input and tensor result disagree on shape (1 vs 1x1).
func @constant_tensor_invalid_3(%arg0: vector<1xi32>) -> tensor<1x1xi32> {
  // expected-error@+1 {{input and output should have same shape and element type}}
  %0 = "tfr.constant_tensor"(%arg0) : (vector<1xi32>) -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}

// -----

// A scalar input paired with a non-scalar result type is rejected.
func @constant_tensor_invalid_4(%arg0: i32) -> tensor<1x1xi32> {
  // expected-error@+1 {{input can not be converted to an output tensor}}
  %0 = "tfr.constant_tensor"(%arg0) : (i32) -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}

// -----

// tfr.get_element indexes a !tfr.tensor_list with an `index` value.
// CHECK-LABEL: get_element
func @get_element(%arg0: !tfr.tensor_list<T>) -> !tfr.tensor {
  %cst = "std.constant"() {value = 1 : index} : () -> index
  %0 = tfr.get_element %arg0[%cst] : (!tfr.tensor_list<T>, index) -> !tfr.tensor
  return %0 : !tfr.tensor
}

// -----

// tfr.build_list packs !tfr.tensor values with different constraints into
// an unconstrained !tfr.tensor_list.
// CHECK-LABEL: build_list
func @build_list(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> !tfr.tensor_list {
  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
  return %0 : !tfr.tensor_list
}

// -----

// With -canonicalize, build_list of constant i32 values folds to a single
// tfr.constant holding an array attribute.
// CHECK-LABEL: build_const_list
// CANON-LABEL: build_const_list
func @build_const_list() -> !tfr.attr {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "std.constant"() {value = 41 : i32} : () -> i32
  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  return %2 : !tfr.attr

// CANON-NEXT: %[[c:.*]] = tfr.constant [42 : i32, 41 : i32] -> !tfr.attr
// CANON-NEXT: return %[[c]] : !tfr.attr
}

// -----

// With -canonicalize, nested build_list ops over constants fold to one
// tfr.constant holding a nested array attribute.
// CHECK-LABEL: build_high_dim_const_list
// CANON-LABEL: build_high_dim_const_list
func @build_high_dim_const_list() -> !tfr.attr {
  %0 = "std.constant"() {value = 42 : i32} : () -> i32
  %1 = "std.constant"() {value = 41 : i32} : () -> i32
  %2 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  %3 = "tfr.build_list"(%0, %1) : (i32, i32) -> !tfr.attr
  %4 = "tfr.build_list"(%2, %3) : (!tfr.attr, !tfr.attr) -> !tfr.attr
  return %4 : !tfr.attr

// The first '[' is written as a regex block so FileCheck does not read the
// following "[[" as the start of a pattern variable.
// CANON-NEXT: %[[c:.*]] = tfr.constant {{\[}}[42 : i32, 41 : i32], [42 : i32, 41 : i32]] -> !tfr.attr
// CANON-NEXT: return %[[c]] : !tfr.attr
}

// -----

// With -canonicalize, get_length of a two-element build_list folds to the
// index constant 2.
// CHECK-LABEL: get_length
// CANON-LABEL: get_length
func @get_length(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> index {
  %0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
  %1 = "tfr.get_length"(%0) : (!tfr.tensor_list) -> index
  return %1 : index

// CANON-NEXT: %[[c:.*]] = constant 2 : index
// CANON-NEXT: return %[[c]] : index
}

// -----

// A bodiless tfr.func declaration with tensor, tensor_list and named
// attribute arguments. The label names the function so it matches only
// this declaration.
// CHECK-LABEL: tfr.func @External(
tfr.func @External(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<C>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: !tfr.attr {tfr.name = "T"})
  -> (!tfr.tensor<A>, !tfr.tensor_list<C>)
  attributes {A, C}

// -----

// A tfr.func with a body; the constraint names A and C used in the types
// are declared in `attributes`.
// CHECK-LABEL: tfr.func @Foo(
tfr.func @Foo(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<C>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"})
  -> (!tfr.tensor<A>, !tfr.tensor_list<C>)
  attributes {A, C} {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<C>
}

// -----

// A tfr.func with no tensor_list argument that returns the same tensor
// twice.
// CHECK-LABEL: tfr.func @Bar(
tfr.func @Bar(%arg0: !tfr.tensor<A>,
              %arg2: i32 {tfr.name = "B"},
              %arg3: vector<1xi32> {tfr.name = "C"})
  -> (!tfr.tensor<A>, !tfr.tensor<A>)
  attributes {A} {
  tfr.return %arg0, %arg0 : !tfr.tensor<A>, !tfr.tensor<A>
}

// -----

// The constraint name A is used in argument and result types but the
// function has no `attributes` list declaring it.
// expected-error@+1 {{Undefined attributes are used: A}}
tfr.func @Foo_undefined_attr(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// The non-tensor argument %arg3 has no tfr.name; the number in the
// diagnostic is its 0-based argument position.
// expected-error@+1 {{3 attribute argument doesn't have a tfr.name attribute}}
tfr.func @Foo_unnamed_attr(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32>) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// The !tfr.tensor_list argument appears after the non-tensor argument %arg2.
// expected-error@+1 {{tfr.tensor/tfr.tensor_list argument should be before non tensor arguments}}
tfr.func @Foo_invalid_arg_order(%arg0: !tfr.tensor<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg1: !tfr.tensor_list<A>,
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// Accepted: an unconstrained !tfr.tensor_list before a !tfr.tensor, with T
// declared in `attributes`. Compare @Foo_invalid_arg_order0, which uses the
// same argument order without an `attributes` list and is rejected.
// CHECK-LABEL: tfr.func @Foo_valid_arg_order0(
tfr.func @Foo_valid_arg_order0(
              %arg1: !tfr.tensor_list,
              %arg0: !tfr.tensor<T>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor, !tfr.tensor_list) attributes {T} {
  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
}

// -----

// The same argument order as @Foo_valid_arg_order0, but without
// `attributes {T}`; the !tfr.tensor following the !tfr.tensor_list is
// reported.
// expected-error@+1 {{tfr.tensor argument should be before tfr.tensor_list argument.}}
tfr.func @Foo_invalid_arg_order0(
              %arg1: !tfr.tensor_list,
              %arg0: !tfr.tensor<T>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor, !tfr.tensor_list) {
  tfr.return %arg0, %arg1 : !tfr.tensor<T>, !tfr.tensor_list
}

// -----

// The !tfr.tensor_list result is listed before the !tfr.tensor result.
// expected-error@+1 {{tfr.tensor result should be before tfr.tensor_list result}}
tfr.func @Foo_invalid_result_order(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor_list<A>, !tfr.tensor<A>) {
  tfr.return %arg1, %arg0 : !tfr.tensor_list<A>, !tfr.tensor<A>
}

// -----

// Two !tfr.tensor_list arguments (%arg1 and %arg2).
// expected-error@+1 {{More than one tfr.tensor_list argument isn't allowed}}
tfr.func @Foo_multiple_tensor_list_args(%arg0: !tfr.tensor<A>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: !tfr.tensor_list<A>,
              %arg3: i32 {tfr.name = "A"},
              %arg4: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor<A>, !tfr.tensor_list<A>) {
  tfr.return %arg0, %arg1 : !tfr.tensor<A>, !tfr.tensor_list<A>
}

// -----

// Two !tfr.tensor_list results.
// expected-error@+1 {{More than one tfr.tensor_list result isn't allowed}}
tfr.func @Foo_multiple_tensor_list_results(%arg0: !tfr.tensor<C>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) ->
    (!tfr.tensor_list<A>, !tfr.tensor_list<A>) {
  tfr.return %arg1, %arg1 : !tfr.tensor_list<A>, !tfr.tensor_list<A>
}

// -----

// A scalar i32 result is rejected. The diagnostic text, including "None",
// must match the verifier's wording exactly.
// expected-error@+1 {{None tfr.tensor/tfr.tensor_list results aren't allowed as a result}}
tfr.func @Foo_return_attr(%arg0: !tfr.tensor<C>,
              %arg1: !tfr.tensor_list<A>,
              %arg2: i32 {tfr.name = "A"},
              %arg3: vector<1xi32> {tfr.name = "C"}) -> i32 {
  tfr.return %arg2 : i32
}
