1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tfr_gen` module.""" 16 17# pylint: disable=missing-function-docstring 18 19import sys 20 21from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw 22from tensorflow.compiler.mlir.tfr.python import composite 23from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module as tfr_gen 24from tensorflow.compiler.mlir.tfr.resources import gen_test_ops as test_ops 25from tensorflow.python.framework import dtypes 26from tensorflow.python.ops import gen_array_ops as array_ops 27from tensorflow.python.ops import gen_math_ops as math_ops 28from tensorflow.python.platform import test 29 30 31Composite = composite.Composite 32 33#--- test fn for mlir location --- 34 35 36@Composite('TestInputNOp') 37def _tfr_loc_test(x): 38 n = 10 39 x_sum = x[0] 40 for i in range(1, n): 41 x_sum = math_ops.Add(x_sum, x[i]) 42 return x_sum 43 44 45#--- test fn for tfr tensors --- 46 47 48@composite.Composite('TestNoOp') 49def _tfr_tensor_empty_arg(): 50 pass 51 52 53@composite.Composite('TestIdentityOp') 54def _tfr_tensor_tensor(x): 55 return x 56 57 58@composite.Composite('TestIdentityNOp') 59def _tfr_tensor_tensor_list(x): 60 return x 61 62 63@composite.Composite('TestInputNOp') 64def _tfr_tensor_tensor_list_get_elt(x): 65 return x[1] 66 67 
# Builds a tensor list from two tensors (expected: "tfr.build_list").
@composite.Composite('TestOutputNOp')
def _tfr_tensor_tensor_list_output(x):
  return [x, x]


# Tuple-unpacks a list-returning TF op. `y` and `pred` are referenced only so
# they count as used (presumably to keep linters quiet).
@composite.Composite('TestTwoInputsOp')
def _tfr_tensor_tensor_list_split(x, y, pred):
  z, _ = array_ops.Split(axis=0, value=x, num_split=2)
  (y, pred)  # pylint: disable=pointless-statement
  return z


# Returns two elements of a list-returning TF op as two separate results.
@composite.Composite('TestTwoOutputsOp')
def _tfr_tensor_two_output(x):
  z = array_ops.Split(axis=0, value=x, num_split=2)
  return z[0], z[1]


# Mixes constants, a negative literal and an attr argument in a list that is
# converted to a tensor operand of OneHot.
@composite.Composite('TestNumAttrsOp')
def _tfr_tensor_tensor_with_cst(x1, y1, x2, y2):
  x = array_ops.OneHot(
      indices=[0, 2, -1, x1], depth=y1, on_value=True, off_value=False)
  (x, x2, y2)  # pylint: disable=pointless-statement
  return

#--- test fn for scf control flow ---


# if/else on a bool attr (expected: scf.if yielding a tensor).
@composite.Composite('TestTwoInputsOp')
def _tfr_control_flow_if(x, y, pred):
  if pred:
    return x
  else:
    return y


# if/elif/else on a string attr (expected: nested scf.if with tfr.equal).
@composite.Composite('TestThreeInputsOp')
def _tfr_control_flow_nested_if(x, y, z, select):
  if select == 'x':
    return x
  elif select == 'y':
    return y
  else:
    return z


# Range-based for loop carrying `x_sum` (expected: scf.for with iter_args).
@composite.Composite('TestInputNOp')
def _tfr_control_flow_range_for(x):
  # TODO(fengliuai): use len(x) instead
  n = 10
  x_sum = x[0]
  for i in range(1, n):
    x_sum = math_ops.Add(x_sum, x[i])
  return x_sum


# len() of a tensor list used in a branch condition.
@composite.Composite('TestInputNOp')
def _tfr_control_flow_tensor_list_size(ins):
  n = len(ins)
  if n == 0:
    return array_ops.Const(value=[[0, 1], [2, 3]], dtype=dtypes.int64)
  else:
    return math_ops.AddN(ins)


#--- test fn for tf ops ---


# SplitV with `size_splits` lists mixing a tensor arg and int literals.
@composite.Composite('TestComplexTFOp')
def _tfr_tf_ops_complex(lhs, rhs):
  left_padding, _ = array_ops.SplitV(
      value=lhs, size_splits=[rhs, -1], axis=0, num_split=2)
  _, right_padding = array_ops.SplitV(
      value=lhs, size_splits=[rhs, rhs], axis=1, num_split=2)
  return [left_padding, right_padding]


# Single TF op call on a tensor argument.
@composite.Composite('TestIdentityOp')
def _tfr_tf_ops_tensor(x):
  return array_ops.Identity(x)


# TF op calls in both branches of an if/else.
@composite.Composite('TestTwoInputsOp')
def _tfr_tf_ops_tensors(x, y, pred):
  if pred:
    return math_ops.Add(x, y)
  else:
    return array_ops.Concat(0, [x, y])


# Calls a test op without its `pred` attr; the default (false) is expected to
# be materialized in the call.
@composite.Composite('TestInputNOp')
def _tfr_tf_ops_with_defaults(ins):
  return test_ops.TestTwoInputsOp(ins[0], ins[1])


#--- test fn for tfr attributes ---


# Comparisons, arithmetic and list building on int (x, y) and float (x1, y1)
# attrs.
@composite.Composite('TestNumAttrsOp')
def _tfr_attrs_num_type(x, y, x1, y1):
  # int
  z0 = [x, y]
  z1 = x == y
  z2 = x < y
  z3 = x <= y
  z4 = x > y
  z5 = x >= y
  z6 = x != y
  z7 = x + y
  z8 = x - y
  z8 += x
  z8 += 1
  (z0, z1, z2, z3, z4, z5, z6, z7, z8)  # pylint: disable=pointless-statement

  # float
  z9 = x1 > y1
  z10 = x1 + y1
  z11 = [x1, y1]
  (z9, z10, z11)  # pylint: disable=pointless-statement
  return


# Equality on non-numeric attrs (expected: tfr.equal).
@composite.Composite('TestNonNumAttrsOp')
def _tfr_attrs_tfr_type(x, y, z):
  z1 = x == y
  z2 = x == 'test'
  z3 = y == z
  (z1, z2, z3)  # pylint: disable=pointless-statement
  return


#--- test fn for shapes ---


# Static shape access and loops over its dimensions, plus the Shape op.
@composite.Composite('TestIdentityOp')
def _tfr_shapes(x):
  s1 = x.shape
  s3 = x.shape.as_list()

  for i in range(len(s3)):
    s3[i]  # pylint: disable=pointless-statement

  for i in range(1, len(s3), 2):
    s3[i]  # pylint: disable=pointless-statement

  s5 = array_ops.Shape(x)
  (s1, s3, s5)  # pylint: disable=pointless-statement
  return x


#--- test fn for nested functions ---


# Composite that another composite calls below.
@composite.Composite('TestIdentityNOp')
def _tfr_temp_op(x):
  return x


# Calls the composite above (expected: tfr.call @tf__test_identity_n_op).
@composite.Composite('TestIdentityOp')
def _tfr_temp_use_op(x):
  y = _tfr_temp_op([x])
  return y[0]

#--- test fn for quant built-ins ---


# The `_tfr_quant_*` names below are not defined in this module; tfr_gen maps
# them to `tfr.quant_*` ops (see test_quant_builtins).
# pylint: disable=undefined-variable
@composite.Composite('TestIdentityOp')
def _tfr_quant_test(x):
  y = _tfr_quant_raw_data(x)
  s, z = _tfr_quant_qparam(x)
  s = _tfr_quant_scale_factor(1.0, [s, s])
  s = _tfr_quant_scale_factor(1.0, [s])
  y = math_ops.Sub(y, z)
  qmin, qmax = _tfr_quant_act_range('RELU', 1.0, 0)
  (qmin, qmax)  # pylint: disable=pointless-statement
  d = _tfr_quant_rescale(y, s, 0)
  e = math_ops.Cast(x=d, DstT=dtypes.int16)
  f = math_ops.Cast(x=e, DstT=dtypes.int8)
  return f


# quant_raw_data applied to a tensor list.
@composite.Composite('TestIdentityNOp')
def _tfr_quant_test_n(x):
  y = _tfr_quant_raw_data(x)
  return y


class TFRGenTestBase(test.TestCase):
  """Base test case providing a FileCheck-based assertion helper."""

  def _check_code(self, tfr_code, exp_tfr_code):
    """Asserts that `tfr_code` matches the FileCheck patterns `exp_tfr_code`.

    Args:
      tfr_code: generated code; it is converted with `str()` before checking.
      exp_tfr_code: FileCheck pattern text.

    The generated code is used as the failure message to ease debugging.
    """
    self.assertTrue(fw.check(str(tfr_code), exp_tfr_code), str(tfr_code))


class TFRGenTensorTest(TFRGenTestBase):
  """MLIR generation tests for MLIR TFR programs.

  Each test generates TFR MLIR from the `_tfr_*` composites selected by a
  name prefix and checks it against FileCheck patterns. SSA values produced on
  one pattern line are captured (`%[[name:.*]]`) and re-used (`%[[name]]`)
  rather than hard-coded, so the checks do not depend on generated SSA names.
  """

  def test_tfr_loc(self):
    """Checks the ops generated for a loop-carried sum over a tensor list."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_loc', [test_ops])
    # NOTE(review): the `CHECK-SAME loc(...)` lines below have no colon after
    # `CHECK-SAME`, so FileCheck does not treat them as directives and source
    # locations are currently not verified. Enabling them would also require
    # the generated code to be printed with location info, and the final
    # `%{{def_line:.*}}` would need to become a real capture -- TODO confirm.
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT: %[[n:.*]] = arith.constant 10 : i64
      CHECK-SAME loc("tfr_gen_test.py":%{{.*}}:6)
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : index
      CHECK-SAME loc("tfr_gen_test.py":%[[sum_line:.*]]:10)
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-SAME loc("tfr_gen_test.py":%[[sum_line]]:10)
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 1 : i64
      CHECK-SAME loc("tfr_gen_test.py":%[[for_line:.*]]:2)
      CHECK-NEXT: %[[begin:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT: %[[end:.*]] = arith.index_cast %[[n]] : i64 to index
      CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT: %[[step:.*]] = arith.constant 1 : index
      CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]]
      CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[elt_1:.*]] = tfr.get_element %x[%[[itr_1]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-SAME loc("tfr_gen_test.py":%[[add_line:.*]]:34)
      CHECK-NEXT:   %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-SAME loc("tfr_gen_test.py":%[[add_line]]:12)
      CHECK-NEXT:   scf.yield %[[Add]] : !tfr.tensor
      CHECK-SAME loc(unknown)
      CHECK-NEXT: }
      CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2)
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-SAME loc(unknown)
      CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor
      CHECK-SAME loc(unknown)
      CHECK-NEXT: }
      CHECK-SAME loc("tfr_gen_test.py":%{{def_line:.*}}:0)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_tensors(self):
    """Checks tensor / tensor-list arguments, indexing and list results."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tensor', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_no_op() -> () {
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return %x : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) {
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return %x : !tfr.tensor_list
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[index:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[sub:.*]] = tfr.get_element %x[%[[index]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: tfr.return %[[sub]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_output_n_op(%x: !tfr.tensor) -> (!tfr.tensor_list) {
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %x) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: tfr.return %[[list]] : !tfr.tensor_list
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list)
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%[[cst_4]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%[[cst_5]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return %[[elt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_outputs_op(%x: !tfr.tensor) -> (!tfr.tensor, !tfr.tensor) {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list)
      CHECK-NEXT: constant true
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%[[cst_4]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%[[cst_5]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: tfr.return %[[elt]], %[[elt_1]] : !tfr.tensor, !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_num_attrs_op(%x1: i64{tfr.name="x1",tfr.default=-10}, %y1: i64{tfr.name="y1",tfr.default=1}, %x2: f32{tfr.name="x2",tfr.default=0.0}, %y2: f32{tfr.name="y2",tfr.default=-3.0}) -> () {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[cst_2:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[zero:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_3:.*]] = arith.subi %[[zero]], %[[cst_2]] : i64
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%[[cst]], %[[cst_1]], %[[cst_3]], %x1) : (i64, i64, i64, i64) -> !tfr.attr
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant true
      CHECK-NEXT: %[[cst_5:.*]] = arith.constant false
      CHECK-NEXT: %[[cst_6:.*]] = "tfr.constant_tensor"(%[[list]]) : (!tfr.attr) -> !tfr.tensor
      CHECK-NEXT: %[[cst_7:.*]] = "tfr.constant_tensor"(%y1) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_4]]) : (i1) -> !tfr.tensor
      CHECK-NEXT: %[[cst_9:.*]] = "tfr.constant_tensor"(%[[cst_5]]) : (i1) -> !tfr.tensor
      CHECK-NEXT: %[[cst_10:.*]] = arith.constant -1 : i64
      CHECK-NEXT: %[[OneHot:.*]] = tfr.call @tf__one_hot(%[[cst_6]], %[[cst_7]], %[[cst_8]], %[[cst_9]], %[[cst_10]])
      CHECK-SAME: (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor)
      CHECK-NEXT: constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_control_flow(self):
    """Checks lowering of if/elif/else, range-for and len() to scf ops."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_control_flow', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor,
      CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[if:.*]] = scf.if %pred -> (!tfr.tensor) {
      CHECK-NEXT:   arith.constant true
      CHECK-NEXT:   scf.yield %x : !tfr.tensor
      CHECK-NEXT: } else {
      CHECK-NEXT:   arith.constant true
      CHECK-NEXT:   scf.yield %y : !tfr.tensor
      CHECK-NEXT: }
      CHECK-NEXT: tfr.return %[[if]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_three_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %z: !tfr.tensor,
      CHECK-SAME: %select: !tfr.attr{tfr.name="act",tfr.default="z"}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[cst:.*]] = tfr.constant "x" -> !tfr.attr
      CHECK-NEXT: %[[eq:.*]] = tfr.equal %select, %[[cst]] -> i1
      CHECK-NEXT: %[[if_stmt:.*]] = scf.if %[[eq]] -> (!tfr.tensor) {
      CHECK-NEXT:   %[[cst_1:.*]] = arith.constant true
      CHECK-NEXT:   scf.yield %x : !tfr.tensor
      CHECK-NEXT: } else {
      CHECK-NEXT:   %[[cst_2:.*]] = tfr.constant "y" -> !tfr.attr
      CHECK-NEXT:   %[[eq_1:.*]] = tfr.equal %select, %[[cst_2]] -> i1
      CHECK-NEXT:   %[[if_stmt1:.*]] = scf.if %[[eq_1]] -> (!tfr.tensor) {
      CHECK-NEXT:     %[[cst_3:.*]] = arith.constant true
      CHECK-NEXT:     scf.yield %y : !tfr.tensor
      CHECK-NEXT:   } else {
      CHECK-NEXT:     %[[cst_4:.*]] = arith.constant true
      CHECK-NEXT:     scf.yield %z : !tfr.tensor
      CHECK-NEXT:   }
      CHECK-NEXT:   scf.yield %[[if_stmt1]] : !tfr.tensor
      CHECK-NEXT: }
      CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT: %[[n:.*]] = arith.constant 10 : i64
      CHECK-NEXT: %[[cst:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[begin:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-NEXT: %[[end:.*]] = arith.index_cast %[[n]] : i64 to index
      CHECK-NEXT: %[[step:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]]
      CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) {
      CHECK-NEXT:   %[[elt_1:.*]] = tfr.get_element %x[%[[itr_1]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT:   %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:   scf.yield %[[Add]] : !tfr.tensor
      CHECK-NEXT: }
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK: %[[attr:.*]] = tfr.constant i64 -> !tfr.attr
      CHECK: %Const = tfr.call @tf__const(%{{.*}}, %[[attr]]) : (!tfr.attr, !tfr.attr) -> (!tfr.tensor)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_tf_ops(self):
    """Checks TF op calls, default attrs and the emitted op declarations."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tf_ops', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_complex_tf_op(%lhs: !tfr.tensor, %rhs: !tfr.tensor) -> (!tfr.tensor_list) {
      CHECK-NEXT: %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[zero:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_1:.*]] = arith.subi %[[zero]], %[[cst]] : i64
      CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst_1]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%rhs, %[[cst_2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: %[[cst_3:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[cst_4:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[zero_1:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[pack:.*]] = tfr.call @tf__pack(%[[list]], %[[zero_1]]) : (!tfr.tensor_list, i64) -> !tfr.tensor
      CHECK-NEXT: %[[cst_5:.*]] = "tfr.constant_tensor"(%[[cst_3]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[SplitV:.*]] = tfr.call @tf__split_v(%lhs, %[[pack]], %[[cst_5]], %[[cst_4]])
      CHECK-NEXT: %[[idx:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[SplitV]][%[[idx]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[idx_1:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[SplitV]][%[[idx_1]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[list_1:.*]] = "tfr.build_list"(%rhs, %rhs) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: %[[cst_6:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[cst_7:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[zero_2:.*]] = arith.constant 0 : i64
      CHECK-NEXT: %[[pack_1:.*]] = tfr.call @tf__pack(%[[list_1]], %[[zero_2]]) : (!tfr.tensor_list, i64) -> !tfr.tensor
      CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_6]]) : (i64) -> !tfr.tensor
      CHECK-NEXT: %[[SplitV_1:.*]] = tfr.call @tf__split_v(%lhs, %[[pack_1]], %[[cst_8]], %[[cst_7]])
      CHECK-NEXT: %[[idx_2:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt_2:.*]] = tfr.get_element %[[SplitV_1]][%[[idx_2]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[idx_3:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_3:.*]] = tfr.get_element %[[SplitV_1]][%[[idx_3]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_9:.*]] = arith.constant true
      CHECK-NEXT: %[[list_2:.*]] = "tfr.build_list"(%[[elt]], %[[elt_3]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: tfr.return %[[list_2]] : !tfr.tensor_list
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: %cst = arith.constant true
      CHECK-NEXT: %[[Id:.*]] = tfr.call @tf__identity(%x) : (!tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT: tfr.return %[[Id]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor,
      CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) {
      CHECK-NEXT: %[[if_stmt:.*]] = scf.if %pred -> (!tfr.tensor) {
      CHECK-NEXT:   %cst = arith.constant true
      CHECK-NEXT:   %[[Add:.*]] = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT:   scf.yield %[[Add]] : !tfr.tensor
      CHECK-NEXT: } else {
      CHECK-NEXT:   %cst_1 = arith.constant true
      CHECK-NEXT:   %[[cst_2:.*]] = arith.constant 0 : i64
      CHECK-NEXT:   %[[list:.*]] = "tfr.build_list"(%x, %y) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT:   %[[Concat:.*]] = tfr.call @tf__concat(%[[cst_2]], %[[list]]) : (i64, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK-NEXT:   scf.yield %[[Concat]] : !tfr.tensor
      CHECK-NEXT: }
      CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
      CHECK-NEXT: %cst = arith.constant true
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[elt:.*]] = tfr.get_element %ins[%[[cst_1]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_2:.*]] = arith.constant 1 : index
      CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %ins[%[[cst_2]]] : (!tfr.tensor_list, index) -> !tfr.tensor
      CHECK-NEXT: %[[cst_3:.*]] = arith.constant false
      CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_two_inputs_op(
      CHECK-SAME:   %[[elt]], %[[elt_1]], %[[cst_3]]) : (!tfr.tensor, !tfr.tensor, i1) -> (!tfr.tensor)
      CHECK-NEXT: tfr.return %[[call]] : !tfr.tensor
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor<T>,!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor<i32_>,!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list<N,T>,i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>) attributes {N,T,axis,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor<T>,!tfr.tensor<Tlen>,!tfr.tensor<i32_>,i64{tfr.name="num_split",tfr.type="int"}) -> (!tfr.tensor_list<num_split,T>) attributes {T,Tlen,f32_,i1_,i32_,i64_,num_split}

      CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor<T>,!tfr.tensor<Tlen>,i64{tfr.name="N",tfr.type="int"}) -> (!tfr.tensor_list<N,T>) attributes {N,T,Tlen,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}

      CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor<T>,!tfr.tensor<T>,i1{tfr.name="pred",tfr.type="bool"}) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_,pred}

      CHECK-LABEL: tfr.func @tf__test_two_outputs_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>,!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tfr_attrs(self):
    """Checks arithmetic, comparisons and equality on numeric/non-numeric attrs."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_attrs', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_num_attrs_op(
      CHECK-SAME: %x: i64{tfr.name="x1",tfr.default=-10},
      CHECK-SAME: %y: i64{tfr.name="y1",tfr.default=1},
      CHECK-SAME: %x1: f32{tfr.name="x2",tfr.default=0.0},
      CHECK-SAME: %y1: f32{tfr.name="y2",tfr.default=-3.0}) -> () {
      CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x, %y) : (i64, i64) -> !tfr.attr
      CHECK-NEXT: %{{.*}} = arith.cmpi "eq", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ult", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ule", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ugt", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "uge", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.cmpi "ne", %x, %y : i64
      CHECK-NEXT: %{{.*}} = arith.addi %x, %y : i64
      CHECK-NEXT: %[[sub_1:.*]] = arith.subi %x, %y : i64
      CHECK-NEXT: %[[add_1:.*]] = arith.addi %[[sub_1]], %x : i64
      CHECK-NEXT: %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %{{.*}} = arith.addi %[[add_1]], %[[cst]] : i64
      CHECK-NEXT: %{{.*}} = arith.cmpf "ugt", %x1, %y1 : f32
      CHECK-NEXT: %{{.*}} = arith.addf %x1, %y1 : f32
      CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x1, %y1) : (f32, f32) -> !tfr.attr
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }

      CHECK-LABEL: tfr.func @tf__test_non_num_attrs_op(
      CHECK-SAME: %x: !tfr.attr{tfr.name="z"},
      CHECK-SAME: %y: !tfr.attr{tfr.name="x",tfr.default="hello"},
      CHECK-SAME: %z: !tfr.attr{tfr.name="y",tfr.default=f32}) -> () {
      CHECK-NEXT: %{{.*}} = tfr.equal %x, %y -> i1
      CHECK-NEXT: %[[cst:.*]] = tfr.constant "test" -> !tfr.attr
      CHECK-NEXT: %{{.*}} = tfr.equal %x, %[[cst]] -> i1
      CHECK-NEXT: %{{.*}} = tfr.equal %y, %z -> i1
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_tf_tensor_shape(self):
    """Checks shape access, rank-bounded loops and the Shape op call."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_shapes', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: %[[shape:.*]] = tfr.get_shape %x -> !shape.shape

      CHECK-NEXT: %[[shape_1:.*]] = tfr.get_shape %x -> !shape.shape
      CHECK-NEXT: %[[len:.*]] = shape.rank %[[shape_1]] : !shape.shape -> !shape.size
      CHECK-NEXT: %[[index:.*]] = shape.size_to_index %[[len]] : !shape.size
      CHECK-NEXT: %[[begin:.*]] = arith.constant 0 : index
      CHECK-NEXT: %[[step:.*]] = arith.constant 1 : index
      CHECK-NEXT: scf.for %[[itr_1:.*]] = %[[begin]] to %[[index]] step %[[step]] {
      CHECK-NEXT:   %[[size:.*]] = shape.get_extent %[[shape_1]], %[[itr_1]]: !shape.shape, index -> !shape.size
      CHECK-NEXT:   %[[elt:.*]] = shape.size_to_index %[[size]] : !shape.size
      CHECK-NEXT:   scf.yield
      CHECK-NEXT: }

      CHECK-NEXT: %[[cst:.*]] = arith.constant 1 : i64
      CHECK-NEXT: %[[len_1:.*]] = shape.rank %[[shape_1]] : !shape.shape -> !shape.size
      CHECK-NEXT: %[[len_size_1:.*]] = shape.size_to_index %[[len_1]] : !shape.size
      CHECK-NEXT: %[[cst_1:.*]] = arith.constant 2 : i64
      CHECK-NEXT: %[[begin_1:.*]] = arith.index_cast %[[cst]] : i64 to index
      CHECK-NEXT: %[[step_1:.*]] = arith.index_cast %[[cst_1]] : i64 to index
      CHECK-NEXT: scf.for %[[itr_3:.*]] = %[[begin_1]] to %[[len_size_1]] step %[[step_1]]

      CHECK: %[[cst:.*]] = tfr.constant i32 -> !tfr.attr
      CHECK-NEXT: %[[Shape:.*]] = tfr.call @tf__shape(%x, %[[cst]]) : (!tfr.tensor, !tfr.attr) -> (!tfr.tensor)
      CHECK-NEXT: %{{.*}} = arith.constant true
      CHECK-NEXT: tfr.return %x : !tfr.tensor
      CHECK-NEXT: }
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_temp_function(self):
    """Checks that a composite calling another composite emits a tfr.call."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_temp', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list)

      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x) : (!tfr.tensor) -> !tfr.tensor_list
      CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_identity_n_op(%[[list]]) : (!tfr.tensor_list)
    """
    self._check_code(mlir_code, mlir_code_exp)

  def test_quant_builtins(self):
    """Checks that `_tfr_quant_*` built-ins lower to `tfr.quant_*` ops."""
    mlir_code = tfr_gen(sys.modules[__name__], '_tfr_quant', [test_ops])
    mlir_code_exp = r"""
      CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) {
      CHECK-NEXT: %[[raw_data:.*]] = tfr.quant_raw_data(%x) : (!tfr.tensor) -> (!tfr.tensor)
      CHECK-NEXT: %[[qparam:.*]]:2 = tfr.quant_qparam(%x) : (!tfr.tensor) -> (!tfr.tensor, !tfr.tensor)
      CHECK: %[[list:.*]] = "tfr.build_list"(%[[qparam]]#0, %[[qparam]]#0) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list
      CHECK: %[[factor:.*]] = tfr.quant_scale_factor(%{{.*}}, %[[list]]) : (f32, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK: %[[list1:.*]] = "tfr.build_list"(%[[factor]]) : (!tfr.tensor) -> !tfr.tensor_list
      CHECK: %[[factor1:.*]] = tfr.quant_scale_factor(%{{.*}}, %[[list1]]) : (f32, !tfr.tensor_list) -> (!tfr.tensor)
      CHECK-NEXT: %[[Sub:.*]] = tfr.call @tf__sub(%[[raw_data]], %[[qparam]]#1) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor)
      CHECK: %[[act_range:.*]]:2 = tfr.quant_act_range(%{{.*}}, %{{.*}}, %{{.*}}) : (!tfr.attr, f32, i64) -> (!tfr.tensor, !tfr.tensor)
      CHECK: %[[rescale:.*]] = tfr.quant_rescale(%[[Sub]], %[[factor1]], %{{.*}}) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor)
      CHECK: %[[attr:.*]] = tfr.constant i16 -> !tfr.attr
      CHECK: %[[Cast:.*]] = tfr.call @tf__cast(%[[rescale]], %[[attr]], %{{.*}}) : (!tfr.tensor, !tfr.attr, i1) -> (!tfr.tensor)
      CHECK: %[[attr_1:.*]] = tfr.constant i8 -> !tfr.attr
      CHECK: tfr.call @tf__cast(%[[Cast]], %[[attr_1]], %{{.*}}) : (!tfr.tensor, !tfr.attr, i1) -> (!tfr.tensor)
      CHECK: }

      CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) {
      CHECK-NEXT: %[[raw_data:.*]] = tfr.quant_raw_data(%x) : (!tfr.tensor_list) -> (!tfr.tensor_list)
      CHECK: tfr.return %[[raw_data]] : !tfr.tensor_list
      CHECK: }
    """
    self._check_code(mlir_code, mlir_code_exp)


if __name__ == '__main__':
  test.main()