• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tfr_gen` module."""
16
17# pylint: disable=missing-function-docstring
18
19import sys
20
21from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw
22from tensorflow.compiler.mlir.tfr.python import composite
23from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module as tfr_gen
24from tensorflow.compiler.mlir.tfr.resources import gen_test_ops as test_ops
25from tensorflow.python.framework import dtypes
26from tensorflow.python.ops import gen_array_ops as array_ops
27from tensorflow.python.ops import gen_math_ops as math_ops
28from tensorflow.python.platform import test
29
30
# Short alias for the composite decorator. `_tfr_loc_test` below is decorated
# through this alias; the other composites spell out `composite.Composite`.
Composite = composite.Composite
32
33#--- test fn for mlir location ---
34
35
# TestInputNOp composite used by `test_tfr_loc` (presumably selected through
# the `_tfr_loc` name prefix passed to `tfr_gen`). The expected output in that
# test refers to source columns of these statements, so keep their layout as is.
@Composite('TestInputNOp')
def _tfr_loc_test(x):
  n = 10
  x_sum = x[0]
  for i in range(1, n):
    x_sum = math_ops.Add(x_sum, x[i])
  return x_sum
43
44
45#--- test fn for tfr tensors ---
46
47
# TestNoOp composite with no arguments and an empty body; `test_tfr_tensors`
# expects it to lower to a bare `tfr.return`.
@composite.Composite('TestNoOp')
def _tfr_tensor_empty_arg():
  pass
51
52
# TestIdentityOp composite that returns its single tensor argument unchanged.
@composite.Composite('TestIdentityOp')
def _tfr_tensor_tensor(x):
  return x
56
57
# TestIdentityNOp composite that returns its tensor-list argument unchanged
# (`x` is typed `!tfr.tensor_list` in the expected output).
@composite.Composite('TestIdentityNOp')
def _tfr_tensor_tensor_list(x):
  return x
61
62
# TestInputNOp composite that indexes the input tensor list with a constant;
# expected to lower to an index constant plus `tfr.get_element`.
@composite.Composite('TestInputNOp')
def _tfr_tensor_tensor_list_get_elt(x):
  return x[1]
66
67
# TestOutputNOp composite that returns a Python list of tensors; the list is
# expected to lower to `tfr.build_list`.
@composite.Composite('TestOutputNOp')
def _tfr_tensor_tensor_list_output(x):
  return [x, x]
71
72
# TestTwoInputsOp composite that tuple-unpacks the outputs of `Split`; the
# unpacking is expected to lower to `tfr.get_element` on the returned list.
@composite.Composite('TestTwoInputsOp')
def _tfr_tensor_tensor_list_split(x, y, pred):
  z, _ = array_ops.Split(axis=0, value=x, num_split=2)
  # Bare reference to the otherwise unused arguments (presumably to keep
  # linters from flagging them as unused).
  (y, pred)  # pylint: disable=pointless-statement
  return z
78
79
# TestTwoOutputsOp composite that indexes the `Split` result (instead of
# unpacking it) and returns the two elements as separate outputs.
@composite.Composite('TestTwoOutputsOp')
def _tfr_tensor_two_output(x):
  z = array_ops.Split(axis=0, value=x, num_split=2)
  return z[0], z[1]
84
85
# TestNumAttrsOp composite. It builds a list that mixes integer literals
# (including the negative -1) with the attribute `x1`. It also passes
# attributes and Python constants where `OneHot` takes tensors; the expected
# output wraps those values with `tfr.constant_tensor`. The op has no outputs,
# hence the bare `return`.
@composite.Composite('TestNumAttrsOp')
def _tfr_tensor_tensor_with_cst(x1, y1, x2, y2):
  x = array_ops.OneHot(
      indices=[0, 2, -1, x1], depth=y1, on_value=True, off_value=False)
  # Bare reference to otherwise unused values (presumably for linters).
  (x, x2, y2)  # pylint: disable=pointless-statement
  return
92
93#--- test fn for scf control flow ---
94
95
# TestTwoInputsOp composite that returns from both branches of an `if`/`else`
# on the bool attribute `pred`; expected to lower to a single `scf.if`.
@composite.Composite('TestTwoInputsOp')
def _tfr_control_flow_if(x, y, pred):
  if pred:
    return x
  else:
    return y
102
103
# TestThreeInputsOp composite with an `if`/`elif`/`else` chain that compares an
# attribute with string literals. It is expected to lower to nested `scf.if`
# ops guarded by `tfr.equal`. In the expected output, `select` binds to the op
# attribute named "act".
@composite.Composite('TestThreeInputsOp')
def _tfr_control_flow_nested_if(x, y, z, select):
  if select == 'x':
    return x
  elif select == 'y':
    return y
  else:
    return z
112
113
# TestInputNOp composite that accumulates `Add` over list elements with a
# `range` loop. `x_sum` is carried across iterations, which the expected output
# lowers to `scf.for` with `iter_args`.
@composite.Composite('TestInputNOp')
def _tfr_control_flow_range_for(x):
  # TODO(fengliuai): use len(x) instead
  n = 10
  x_sum = x[0]
  for i in range(1, n):
    x_sum = math_ops.Add(x_sum, x[i])
  return x_sum
122
123
# TestInputNOp composite that branches on `len(ins)`. It returns a constant
# int64 tensor when the list is empty, and the `AddN` of the inputs otherwise.
@composite.Composite('TestInputNOp')
def _tfr_control_flow_tensor_list_size(ins):
  n = len(ins)
  if n == 0:
    return array_ops.Const(value=[[0, 1], [2, 3]], dtype=dtypes.int64)
  else:
    return math_ops.AddN(ins)
131
132
133#--- test fn for tf ops ---
134
135
# TestComplexTFOp composite that calls `SplitV` twice with `size_splits` lists
# built from the tensor `rhs` (first `[rhs, -1]`, then `[rhs, rhs]`). The
# expected output packs each list with `tf__pack`. The composite returns the
# first part of the first split and the second part of the second split.
@composite.Composite('TestComplexTFOp')
def _tfr_tf_ops_complex(lhs, rhs):
  left_padding, _ = array_ops.SplitV(
      value=lhs, size_splits=[rhs, -1], axis=0, num_split=2)
  _, right_padding = array_ops.SplitV(
      value=lhs, size_splits=[rhs, rhs], axis=1, num_split=2)
  return [left_padding, right_padding]
143
144
# TestIdentityOp composite that forwards its input through the TF `Identity` op.
@composite.Composite('TestIdentityOp')
def _tfr_tf_ops_tensor(x):
  return array_ops.Identity(x)
148
149
# TestTwoInputsOp composite that returns `Add(x, y)` when `pred` is true, and
# otherwise the `Concat` of `x` and `y` along axis 0.
@composite.Composite('TestTwoInputsOp')
def _tfr_tf_ops_tensors(x, y, pred):
  if pred:
    return math_ops.Add(x, y)
  else:
    return array_ops.Concat(0, [x, y])
156
157
# TestInputNOp composite that calls TestTwoInputsOp without passing `pred`.
# The expected output passes the op's default (`false`) explicitly.
@composite.Composite('TestInputNOp')
def _tfr_tf_ops_with_defaults(ins):
  return test_ops.TestTwoInputsOp(ins[0], ins[1])
161
162
163#--- test fn for tfr attributes ---
164
165
# TestNumAttrsOp composite exercising arithmetic and comparisons on numeric
# attributes. Per the expected output in `test_tfr_attrs`, the arguments bind
# to the op attributes by position: `x`/`y` are the i64 attributes "x1"/"y1",
# and `x1`/`y1` are the f32 attributes "x2"/"y2".
@composite.Composite('TestNumAttrsOp')
def _tfr_attrs_num_type(x, y, x1, y1):
  # int
  z0 = [x, y]
  z1 = x == y
  z2 = x < y
  z3 = x <= y
  z4 = x > y
  z5 = x >= y
  z6 = x != y
  z7 = x + y
  z8 = x - y
  z8 += x
  z8 += 1
  # Bare reference to the results (presumably so linters don't flag them).
  (z0, z1, z2, z3, z4, z5, z6, z7, z8)  # pylint: disable=pointless-statement

  # float
  z9 = x1 > y1
  z10 = x1 + y1
  z11 = [x1, y1]
  (z9, z10, z11)  # pylint: disable=pointless-statement
  return
188
189
# TestNonNumAttrsOp composite comparing non-numeric attributes with each other
# and with a string literal; the comparisons lower to `tfr.equal`. Per the
# expected output, the arguments x, y, z bind to the op attributes "z", "x",
# "y" respectively.
@composite.Composite('TestNonNumAttrsOp')
def _tfr_attrs_tfr_type(x, y, z):
  z1 = x == y
  z2 = x == 'test'
  z3 = y == z
  (z1, z2, z3)  # pylint: disable=pointless-statement
  return
197
198
199#--- test fn for shapes ---
200
201
# TestIdentityOp composite exercising shape queries: `x.shape`,
# `x.shape.as_list()`, and `range` loops over the shape list (default and
# explicit start/step). It also calls the TF `Shape` op. The input is
# returned unchanged.
@composite.Composite('TestIdentityOp')
def _tfr_shapes(x):
  s1 = x.shape
  s3 = x.shape.as_list()

  for i in range(len(s3)):
    s3[i]  # pylint: disable=pointless-statement

  for i in range(1, len(s3), 2):
    s3[i]  # pylint: disable=pointless-statement

  s5 = array_ops.Shape(x)
  (s1, s3, s5)  # pylint: disable=pointless-statement
  return x
216
217
218#--- test fn for nested functions ---
219
220
# TestIdentityNOp composite that `_tfr_temp_use_op` calls directly; it
# returns its tensor-list argument unchanged.
@composite.Composite('TestIdentityNOp')
def _tfr_temp_op(x):
  return x
224
225
# TestIdentityOp composite that calls another composite with a one-element
# list; the call is expected to lower to `tfr.call @tf__test_identity_n_op`.
@composite.Composite('TestIdentityOp')
def _tfr_temp_use_op(x):
  y = _tfr_temp_op([x])
  return y[0]
230
231#--- test fn for quant built-ins ---
232
233
234# pylint: disable=undefined-variable
# TestIdentityOp composite exercising the `_tfr_quant_*` built-ins. They are
# not defined in this module (hence the `undefined-variable` pragma above);
# the expected output maps them to `tfr.quant_*` ops.
@composite.Composite('TestIdentityOp')
def _tfr_quant_test(x):
  y = _tfr_quant_raw_data(x)
  s, z = _tfr_quant_qparam(x)
  s = _tfr_quant_scale_factor(1.0, [s, s])
  s = _tfr_quant_scale_factor(1.0, [s])
  y = math_ops.Sub(y, z)
  qmin, qmax = _tfr_quant_act_range('RELU', 1.0, 0)
  (qmin, qmax)  # pylint: disable=pointless-statement
  d = _tfr_quant_rescale(y, s, 0)
  e = math_ops.Cast(x=d, DstT=dtypes.int16)
  f = math_ops.Cast(x=e, DstT=dtypes.int8)
  return f
248
249
# TestIdentityNOp composite that applies the `_tfr_quant_raw_data` built-in to
# a tensor list and returns the result.
@composite.Composite('TestIdentityNOp')
def _tfr_quant_test_n(x):
  y = _tfr_quant_raw_data(x)
  return y
254
255
class TFRGenTestBase(test.TestCase):
  """Common base class for the `tfr_gen` tests in this module."""

  def _check_code(self, tfr_code, exp_tfr_code):
    """Asserts that generated TFR code satisfies a set of FileCheck patterns.

    Args:
      tfr_code: The generated code; it is converted with `str()` before
        matching.
      exp_tfr_code: FileCheck directives (`CHECK`, `CHECK-NEXT`, ...) that the
        generated text must satisfy.

    Returns:
      The result of `assertTrue`.
    """
    generated_text = str(tfr_code)
    matched = fw.check(generated_text, exp_tfr_code)
    # The full generated text doubles as the failure message so a mismatch can
    # be inspected directly in the test log.
    return self.assertTrue(matched, generated_text)
260
261
class TFRGenTensorTest(TFRGenTestBase):
  """MLIR generation tests for TFR programs.

  Each test passes this module, a function-name prefix and the op-definition
  module `test_ops` to `tfr_gen`. It then matches the generated MLIR text
  against FileCheck patterns via `_check_code`. The prefix presumably selects
  the `_tfr_*` composites above whose names start with it.

  NOTE(review): several patterns spell out printer-assigned SSA names literally
  (e.g. `%cst_1`, `%idx`, `%idx_1`, `%cst_4`, `%if_stmt`, `%SplitV`, `%itr_1`)
  instead of reusing FileCheck captures. They are therefore sensitive to
  changes in MLIR value naming.
  """

  def test_tfr_loc(self):
    """Checks the ops generated for `_tfr_loc_test` and their locations."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_loc', [test_ops])
    # NOTE(review): the `CHECK-SAME` lines below have no trailing ':', so
    # FileCheck likely does not treat them as directives, and the `loc(...)`
    # expectations are not enforced -- confirm. If they were enabled, they
    # would require a literal '%' before the line number, whereas MLIR file
    # locations are normally printed as "file":line:col. `%{{def_line:.*}}`
    # also looks like a malformed capture (`%[[def_line:.*]]`?).
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[n:.*]] = arith.constant 10 : i64
      CHECK-SAME        loc("tfr_gen_test.py":%{{.*}}:6)
      CHECK-NEXT:   %[[cst:.*]] = arith.constant 0 : index
      CHECK-SAME        loc("tfr_gen_test.py":%[[sum_line:.*]]:10)
      CHECK-NEXT:   %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-SAME        loc("tfr_gen_test.py":%[[sum_line]]:10)
      CHECK-NEXT:   %[[cst_1:.*]] = arith.constant 1 : i64
      CHECK-SAME        loc("tfr_gen_test.py":%[[for_line:.*]]:2)
      CHECK-NEXT:   %[[begin:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-SAME        loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT:   %[[end:.*]] = arith.index_cast %[[n]] : i64 to index
      CHECK-SAME        loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT:   %[[step:.*]] = arith.constant 1 : index
      CHECK-SAME        loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT:   %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]]
      CHECK-SAME:       iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) {
      CHECK-NEXT:     %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-SAME        loc("tfr_gen_test.py":%[[add_line:.*]]:34)
      CHECK-NEXT:     %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-SAME        loc("tfr_gen_test.py":%[[add_line]]:12)
      CHECK-NEXT:     scf.yield %[[Add]] : !tfr.tensor
      CHECK-SAME        loc(unknown)
      CHECK-NEXT:   }
      CHECK-SAME        loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT:   %{{.*}} = arith.constant true
      CHECK-SAME        loc(unknown)
      CHECK-NEXT:   tfr.return %[[for_stmt]] : !tfr.tensor
      CHECK-SAME        loc(unknown)
      CHECK-NEXT: }
      CHECK-SAME        loc("tfr_gen_test.py":%{{def_line:.*}}:0)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_tensors(self):
    """Checks lowering of tensor/tensor-list arguments, results and indexing."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tensor', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_no_op() -> () {
      CHECK-NEXT:    tfr.return
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: constant true
      CHECK-NEXT:    tfr.return %x : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) {
      CHECK-NEXT: constant true
      CHECK-NEXT:    tfr.return %x : !tfr.tensor_list
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[index:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[sub:.*]] = tfr.get_element %x[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: tfr.return %[[sub]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_output_n_op(%x: !tfr.tensor) -> (!tfr.tensor_list) {
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %x) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: tfr.return %[[list]] : !tfr.tensor_list
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list)
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%idx] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return %[[elt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_outputs_op(%x: !tfr.tensor) -> (!tfr.tensor, !tfr.tensor) {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list)
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%cst_4] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%cst_5] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: tfr.return %[[elt]], %[[elt_1]] : !tfr.tensor, !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_num_attrs_op(%x1: i64{tfr.name="x1",tfr.default=-10}, %y1: i64{tfr.name="y1",tfr.default=1}, %x2: f32{tfr.name="x2",tfr.default=0.0}, %y2: f32{tfr.name="y2",tfr.default=-3.0}) -> () {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[zero:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_3:.*]] = arith.subi %zero, %cst_2 : i64
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%[[cst]], %[[cst_1]], %[[cst_3]], %x1) : (i64, i64, i64, i64) -> !tfr.attr
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant true
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant false
      CHECK-NEXT: %[[cst_6:.*]] = "tfr.constant_tensor"(%[[list]]) : (!tfr.attr) -> !tfr.tensor
      CHECK-NEXT: %[[cst_7:.*]] = "tfr.constant_tensor"(%y1) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_4]]) : (i1) -> !tfr.tensor
      CHECK-NEXT: %[[cst_9:.*]] = "tfr.constant_tensor"(%[[cst_5]]) : (i1) -> !tfr.tensor
      CHECK-NEXT: %[[cst_10:.*]] = arith.constant -1 : i64
      CHECK-NEXT: %[[OneHot:.*]] = tfr.call @tf__one_hot(%[[cst_6]], %[[cst_7]], %[[cst_8]], %[[cst_9]], %[[cst_10]])
      CHECK-SAME:   (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor)
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_control_flow(self):
    """Checks `if`/`elif`/`else`, `for ... in range` and `len()` lowering."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_control_flow', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor,
      CHECK-SAME:     %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[if:.*]] = scf.if %pred -> (!tfr.tensor) {
      CHECK-NEXT:   arith.constant true
      CHECK-NEXT:   scf.yield %x : !tfr.tensor
      CHECK-NEXT: } else {
      CHECK-NEXT:   arith.constant true
      CHECK-NEXT:   scf.yield %y : !tfr.tensor
      CHECK-NEXT:   }
      CHECK-NEXT:   tfr.return %if_stmt : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_three_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %z: !tfr.tensor,
      CHECK-SAME:     %select: !tfr.attr{tfr.name="act",tfr.default="z"}) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[cst:.*]] = tfr.constant "x" -> !tfr.attr
      CHECK-NEXT:   %[[eq:.*]] = tfr.equal %select, %[[cst]] -> i1
      CHECK-NEXT:   %[[if_stmt:.*]] = scf.if %[[eq]] -> (!tfr.tensor) {
      CHECK-NEXT:     %[[cst_1:.*]] = arith.constant true
      CHECK-NEXT:     scf.yield %x : !tfr.tensor
      CHECK-NEXT:   } else {
      CHECK-NEXT:     %[[cst_2:.*]] = tfr.constant "y" -> !tfr.attr
      CHECK-NEXT:     %[[eq_1:.*]] = tfr.equal %select, %[[cst_2]] -> i1
      CHECK-NEXT:     %[[if_stmt1:.*]] = scf.if %[[eq_1]] -> (!tfr.tensor) {
      CHECK-NEXT:       %[[cst_3:.*]] = arith.constant true
      CHECK-NEXT:       scf.yield %y : !tfr.tensor
      CHECK-NEXT:     } else {
      CHECK-NEXT:       %[[cst_4:.*]] = arith.constant true
      CHECK-NEXT:       scf.yield %z : !tfr.tensor
      CHECK-NEXT:     }
      CHECK-NEXT:     scf.yield %[[if_stmt1]] : !tfr.tensor
      CHECK-NEXT:   }
      CHECK-NEXT:   tfr.return %[[if_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[n:.*]] = arith.constant 10 : i64
      CHECK-NEXT:   %[[cst:.*]] = arith.constant 0 : index
      CHECK-NEXT:   %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_1:.*]] = arith.constant 1 : i64
      CHECK-NEXT:   %[[begin:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-NEXT:   %[[end:.*]] = arith.index_cast %[[n]] : i64 to index
      CHECK-NEXT:   %[[step:.*]] = arith.constant 1 : index
      CHECK-NEXT:   %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]]
      CHECK-SAME:       iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) {
      CHECK-NEXT:     %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:     %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:     scf.yield %[[Add]] : !tfr.tensor
      CHECK-NEXT:   }
      CHECK-NEXT:   %{{.*}} = arith.constant true
      CHECK-NEXT:   tfr.return %[[for_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK: %[[attr:.*]] = tfr.constant i64 -> !tfr.attr
      CHECK: %Const = tfr.call @tf__const(%{{.*}}, %[[attr]]) : (!tfr.attr, !tfr.attr) -> (!tfr.tensor)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_tf_ops(self):
    """Checks calls to TF ops and the `tfr.func` declarations emitted for them.

    Also checks that an omitted op attribute (`pred`) is filled from its
    default when a composite calls another op.
    """
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tf_ops', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_complex_tf_op(%lhs: !tfr.tensor, %rhs: !tfr.tensor) -> (!tfr.tensor_list) {
      CHECK-NEXT:   %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT:   %[[zero:.*]] = arith.constant 0 : i64
      CHECK-NEXT:   %[[cst_1:.*]] = arith.subi %[[zero]], %cst : i64
      CHECK-NEXT:   %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst_1]]) : (i64) -> !tfr.tensor
      CHECK-NEXT:   %[[list:.*]] = "tfr.build_list"(%rhs, %[[cst_2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:   %[[cst_3:.*]] = arith.constant 0 : i64
      CHECK-NEXT:   %[[cst_4:.*]] = arith.constant 2 : i64
      CHECK-NEXT:   %[[zero_1:.*]] = arith.constant 0 : i64
      CHECK-NEXT:   %[[pack:.*]] = tfr.call @tf__pack(%[[list]], %[[zero_1]]) : (!tfr.tensor_list, i64) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_5:.*]] = "tfr.constant_tensor"(%[[cst_3]]) : (i64) -> !tfr.tensor
      CHECK-NEXT:   %[[SplitV:.*]] = tfr.call @tf__split_v(%lhs, %[[pack]], %[[cst_5]], %[[cst_4]])
      CHECK-NEXT:   %[[idx:.*]] = arith.constant 0 : index
      CHECK-NEXT:   %[[elt:.*]] = tfr.get_element %SplitV[%idx] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[idx_1:.*]] = arith.constant 1 : index
      CHECK-NEXT:   %[[elt_1:.*]] = tfr.get_element %SplitV[%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[list_1:.*]] = "tfr.build_list"(%rhs, %rhs) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:   %[[cst_6:.*]] = arith.constant 1 : i64
      CHECK-NEXT:   %[[cst_7:.*]] = arith.constant 2 : i64
      CHECK-NEXT:   %[[zero_2:.*]] = arith.constant 0 : i64
      CHECK-NEXT:   %[[pack_1:.*]] = tfr.call @tf__pack(%[[list_1]], %[[zero_2]]) : (!tfr.tensor_list, i64) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_6]]) : (i64) -> !tfr.tensor
      CHECK-NEXT:   %[[SplitV_1:.*]] = tfr.call @tf__split_v(%lhs, %[[pack_1]], %[[cst_8]], %[[cst_7]])
      CHECK-NEXT:   %[[idx_2:.*]] = arith.constant 0 : index
      CHECK-NEXT:   %[[elt_2:.*]] = tfr.get_element %SplitV_1[%idx_2] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[idx_3:.*]] = arith.constant 1 : index
      CHECK-NEXT:   %[[elt_3:.*]] = tfr.get_element %SplitV_1[%idx_3] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_9:.*]] = arith.constant true
      CHECK-NEXT:   %[[list_2:.*]] = "tfr.build_list"(%[[elt]], %[[elt_3]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:   tfr.return %[[list_2]] : !tfr.tensor_list
      CHECK-NEXT:   }

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT:    %cst = arith.constant true
      CHECK-NEXT:    %[[Id:.*]] = tfr.call @tf__identity(%x) : (!tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:    tfr.return %[[Id]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor,
      CHECK-SAME:     %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[if_stmt:.*]] = scf.if %pred -> (!tfr.tensor) {
      CHECK-NEXT:     %cst = arith.constant true
      CHECK-NEXT:     %[[Add:.*]] = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:     scf.yield %[[Add]] : !tfr.tensor
      CHECK-NEXT:   } else {
      CHECK-NEXT:     %cst_1 = arith.constant true
      CHECK-NEXT:     %[[cst_2:.*]] = arith.constant 0 : i64
      CHECK-NEXT:     %[[list:.*]] = "tfr.build_list"(%x, %y) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:     %[[Concat:.*]] = tfr.call @tf__concat(%[[cst_2]], %[[list]]) : (i64, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK-NEXT:     scf.yield %[[Concat]] : !tfr.tensor
      CHECK-NEXT:   }
      CHECK-NEXT:   tfr.return %[[if_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT:   %cst = arith.constant true
      CHECK-NEXT:   %[[cst_1:.*]] = arith.constant 0 : index
      CHECK-NEXT:   %[[elt:.*]] = tfr.get_element %ins[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_2:.*]] = arith.constant 1 : index
      CHECK-NEXT:   %[[elt_1:.*]] = tfr.get_element %ins[%cst_2] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[cst_3:.*]] = arith.constant false
      CHECK-NEXT:   %[[call:.*]] = tfr.call @tf__test_two_inputs_op(
      CHECK-SAME:     %[[elt]], %[[elt_1]], %[[cst_3]]) : (!tfr.tensor, !tfr.tensor, i1) -> (!tfr.tensor)
      CHECK-NEXT:   tfr.return %[[call]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor<T>,!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor<i32_>,!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list<N,T>,i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>) attributes {N,T,axis,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor<T>,!tfr.tensor<Tlen>,!tfr.tensor<i32_>,i64{tfr.name="num_split",tfr.type="int"}) -> (!tfr.tensor_list<num_split,T>) attributes {T,Tlen,f32_,i1_,i32_,i64_,num_split}

      CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor<T>,!tfr.tensor<Tlen>,i64{tfr.name="N",tfr.type="int"}) -> (!tfr.tensor_list<N,T>) attributes {N,T,Tlen,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor<T>,!tfr.tensor<T>,i1{tfr.name="pred",tfr.type="bool"}) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_,pred}

      CHECK-LABEL: tfr.func @tf__test_two_outputs_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>,!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_attrs(self):
    """Checks arithmetic/comparisons on numeric attrs and equality on others."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_attrs', [test_ops])
    # NOTE(review): the integer comparisons below are expected to use unsigned
    # predicates (`ult`, `ule`, `ugt`, `uge`). However, the i64 attributes can
    # be negative (the default of "x1" is -10). If these attributes are meant
    # to be signed, this may be pinning a generator bug -- confirm.
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_num_attrs_op(
      CHECK-SAME:     %x: i64{tfr.name="x1",tfr.default=-10},
      CHECK-SAME:     %y: i64{tfr.name="y1",tfr.default=1},
      CHECK-SAME:     %x1: f32{tfr.name="x2",tfr.default=0.0},
      CHECK-SAME:     %y1: f32{tfr.name="y2",tfr.default=-3.0}) -> () {
      CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x, %y) : (i64, i64) -> !tfr.attr
      CHECK-NEXT: %{{.*}} = arith.cmpi "eq", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ult", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ule", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ugt", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "uge", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ne", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.addi %x, %y : i64
      CHECK-NEXT: %[[sub_1:.*]] = arith.subi %x, %y : i64
      CHECK-NEXT: %[[add_1:.*]] = arith.addi %[[sub_1]], %x : i64
      CHECK-NEXT: %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %{{.*}} = arith.addi %[[add_1]], %[[cst]] : i64
      CHECK-NEXT: %{{.*}} = arith.cmpf "ugt", %x1, %y1 : f32
      CHECK-NEXT: %{{.*}} = arith.addf %x1, %y1 : f32
      CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x1, %y1) : (f32, f32) -> !tfr.attr
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_non_num_attrs_op(
      CHECK-SAME:     %x: !tfr.attr{tfr.name="z"},
      CHECK-SAME:     %y: !tfr.attr{tfr.name="x",tfr.default="hello"},
      CHECK-SAME:     %z: !tfr.attr{tfr.name="y",tfr.default=f32}) -> () {
      CHECK-NEXT: %{{.*}} = tfr.equal %x, %y -> i1
      CHECK-NEXT: %[[cst:.*]] = tfr.constant "test" -> !tfr.attr
      CHECK-NEXT: %{{.*}} = tfr.equal %x, %[[cst]] -> i1
      CHECK-NEXT: %{{.*}} = tfr.equal %y, %z -> i1
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tf_tensor_shape(self):
    """Checks lowering of `x.shape`, `as_list()`, shape loops and `Shape`."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_shapes', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[shape:.*]] = tfr.get_shape %x -> !shape.shape

      CHECK-NEXT:   %[[shape_1:.*]] = tfr.get_shape %x -> !shape.shape
      CHECK-NEXT:   %[[len:.*]] = shape.rank %[[shape_1]] : !shape.shape -> !shape.size
      CHECK-NEXT:   %[[index:.*]] = shape.size_to_index %[[len]] : !shape.size
      CHECK-NEXT:   %[[begin:.*]] = arith.constant 0 : index
      CHECK-NEXT:   %[[step:.*]] = arith.constant 1 : index
      CHECK-NEXT:   scf.for %[[itr_1:.*]] = %[[begin]] to %[[index]] step %[[step]]  {
      CHECK-NEXT:     %[[size:.*]] = shape.get_extent %[[shape_1]], %[[itr_1]]: !shape.shape, index -> !shape.size
      CHECK-NEXT:     %[[elt:.*]] = shape.size_to_index %[[size]] : !shape.size
      CHECK-NEXT:     scf.yield
      CHECK-NEXT:   }

      CHECK-NEXT:   %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT:   %[[len_1:.*]] = shape.rank %shape_1 : !shape.shape -> !shape.size
      CHECK-NEXT:   %[[len_size_1:.*]] = shape.size_to_index %[[len_1]] : !shape.size
      CHECK-NEXT:   %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT:   %[[begin_1:.*]] = arith.index_cast %[[cst]] : i64 to index
      CHECK-NEXT:   %[[step_1:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-NEXT:   scf.for %[[itr_3:.*]] = %[[begin_1]] to %[[len_size_1]] step %[[step_1]]

      CHECK:        %[[cst:.*]] = tfr.constant i32 -> !tfr.attr
      CHECK-NEXT:   %[[Shape:.*]] = tfr.call @tf__shape(%x, %[[cst]]) : (!tfr.tensor, !tfr.attr) -> (!tfr.tensor)
      CHECK-NEXT:   %{{.*}} = arith.constant true
      CHECK-NEXT:   tfr.return %x : !tfr.tensor
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_temp_function(self):
    """Checks that calling one composite from another becomes a `tfr.call`."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_temp', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list)

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[list:.*]] = "tfr.build_list"(%x) : (!tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:   %[[call:.*]] = tfr.call @tf__test_identity_n_op(%[[list]]) : (!tfr.tensor_list)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_quant_builtins(self):
    """Checks lowering of the `_tfr_quant_*` built-ins to `tfr.quant_*` ops."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_quant', [test_ops])
    # NOTE(review): the final `tfr.return %[[raw_data:.*]]` line re-defines the
    # `raw_data` capture instead of reusing it (`%[[raw_data]]`). As a result,
    # it does not verify that the returned value is the `tfr.quant_raw_data`
    # result.
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[raw_data:.*]] = tfr.quant_raw_data(%x) : (!tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:   %[[qparam:.*]]:2 = tfr.quant_qparam(%x) : (!tfr.tensor) -> (!tfr.tensor, !tfr.tensor)
      CHECK:        %[[list:.*]] = "tfr.build_list"(%[[qparam]]#0, %[[qparam]]#0) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK:        %[[factor:.*]] = tfr.quant_scale_factor(%{{.*}}, %[[list]]) : (f32, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK:        %[[list1:.*]] = "tfr.build_list"(%[[factor]]) : (!tfr.tensor) -> !tfr.tensor_list
      CHECK:        %[[factor1:.*]] = tfr.quant_scale_factor(%{{.*}}, %[[list1]]) : (f32, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK-NEXT:   %[[Sub:.*]] = tfr.call @tf__sub(%[[raw_data]], %[[qparam]]#1) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK:        %[[act_range:.*]]:2 = tfr.quant_act_range(%{{.*}}, %{{.*}}, %{{.*}}) : (!tfr.attr, f32, i64) -> (!tfr.tensor, !tfr.tensor)
      CHECK:        %[[rescale:.*]] = tfr.quant_rescale(%[[Sub]], %[[factor1]], %{{.*}}) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor)
      CHECK:        %[[attr:.*]] = tfr.constant i16 -> !tfr.attr
      CHECK:        %[[Cast:.*]] = tfr.call @tf__cast(%[[rescale]], %[[attr]], %{{.*}}) : (!tfr.tensor, !tfr.attr, i1) -> (!tfr.tensor)
      CHECK:        %[[attr_1:.*]] = tfr.constant i8 -> !tfr.attr
      CHECK:        tfr.call @tf__cast(%[[Cast]], %[[attr_1]], %{{.*}}) : (!tfr.tensor, !tfr.attr, i1) -> (!tfr.tensor)
      CHECK:       }

      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) {
      CHECK-NEXT:   %[[raw_data:.*]] = tfr.quant_raw_data(%x) : (!tfr.tensor_list) -> (!tfr.tensor_list)
      CHECK:        tfr.return %[[raw_data:.*]] : !tfr.tensor_list
      CHECK:       }
    """
    self._check_code(mlir_code, mlir_code_exp)
642
643
if __name__ == '__main__':
  # Run the test cases in this module through TensorFlow's test entry point
  # (`tensorflow.python.platform.test`).
  test.main()
646