• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `op_reg_gen` module."""
16
17# pylint: disable=missing-function-docstring
18# pylint: disable=invalid-name
19# pylint: disable=g-direct-tensorflow-import
20
21import sys
22
23from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw
24from tensorflow.compiler.mlir.tfr.python import composite
25from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
26from tensorflow.python.platform import test
27
28
# Bare-name alias for `composite.Composite`. The fixtures below are decorated
# with both spellings (`@composite.Composite` and `@Composite`), presumably so
# that gen_register_op's decorator detection is exercised for each form --
# confirm against op_reg_gen before removing either spelling.
Composite = composite.Composite
30
31
# Fixture op `TestNoOp`: no inputs and no user attrs, only a derived numeric
# type attr `T` and a single output of that type. Decorated through the
# module-qualified `composite.Composite` spelling.
@composite.Composite(
    'TestNoOp',
    derived_attrs=['T: numbertype'],
    outputs=['o1: T'],
)
def _composite_no_op():
  pass
36
37
# Fixture op `TestCompositeOp`: two inputs and two outputs sharing the derived
# numeric type `T`, plus an attr `act` restricted to "" / "relu" and a bool
# attr `trans` that defaults to true. Decorated through the bare-name
# `Composite` alias. The parameter names mirror the declared inputs (`x`, `y`)
# followed by the declared attrs (`act`, `trans`).
@Composite(
    'TestCompositeOp',
    inputs=['x: T', 'y: T'],
    attrs=['act: {"", "relu"}', 'trans: bool = true'],
    derived_attrs=['T: numbertype'],
    outputs=['o1: T', 'o2: T'],
)
def _composite_op(x, y, act, trans):
  # NOTE(review): combining inputs with attrs (`x + act`, `y + trans`) looks
  # like a placeholder body. Presumably only the decorator arguments and the
  # signature are read by gen_register_op -- confirm before relying on it.
  first = x + act
  second = y + trans
  return first, second
46
47
class TFRGenTensorTest(test.TestCase):
  """Checks the C++ op registration code generated for this module."""

  def test_op_reg_gen(self):
    # The generator is handed this test module itself, so the expected output
    # below covers the `TestNoOp` and `TestCompositeOp` fixtures defined above.
    this_module = sys.modules[__name__]
    generated = str(gen_register_op(this_module))

    # FileCheck patterns describing the expected C++ output, in order.
    # NOTE(review): the bare `CHECK-EMPTY` lines have no trailing colon, and
    # FileCheck directives require one. They are therefore most likely ignored
    # and do not actually assert blank lines. Adding the colon would tighten
    # the test, but first confirm the generator's exact whitespace layout.
    expected_checks = r"""
      CHECK: #include "tensorflow/core/framework/op.h"
      CHECK-EMPTY
      CHECK: namespace tensorflow {
      CHECK-EMPTY
      CHECK-LABEL: REGISTER_OP("TestNoOp")
      CHECK-NEXT:      .Attr("T: numbertype")
      CHECK-NEXT:      .Output("o1: T");
      CHECK-EMPTY
      CHECK-LABEL: REGISTER_OP("TestCompositeOp")
      CHECK-NEXT:      .Input("x: T")
      CHECK-NEXT:      .Input("y: T")
      CHECK-NEXT:      .Attr("act: {'', 'relu'}")
      CHECK-NEXT:      .Attr("trans: bool = true")
      CHECK-NEXT:      .Attr("T: numbertype")
      CHECK-NEXT:      .Output("o1: T")
      CHECK-NEXT:      .Output("o2: T");
      CHECK-EMPTY
      CHECK:  }  // namespace tensorflow
    """
    # On mismatch, the full generated source becomes the failure message.
    self.assertTrue(fw.check(generated, expected_checks), generated)
74
75
# Run this module's tests through TensorFlow's test runner when the file is
# executed directly as a script.
if __name__ == '__main__':
  test.main()
78