# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `op_reg_gen` module."""

# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name
# pylint: disable=g-direct-tensorflow-import

import sys

from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw
from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.python.platform import test


# Bare-name alias for the decorator. The two fixtures below spell the
# decorator differently: the first as the attribute `composite.Composite`,
# the second as the bare name `Composite`. NOTE(review): presumably this is
# deliberate, to exercise both decorator forms in the generator's detection
# logic. Keep both spellings when editing the fixtures.
Composite = composite.Composite


# Fixture op "TestNoOp": no inputs and no explicit attrs. It has one derived
# attr `T: numbertype` and a single output `o1: T`. The body is irrelevant to
# the test; only the decorator arguments feed the expected REGISTER_OP text.
@composite.Composite(
    'TestNoOp', derived_attrs=['T: numbertype'], outputs=['o1: T'])
def _composite_no_op():
  pass


# Fixture op "TestCompositeOp". It has two inputs, an enum-style attr `act`,
# a bool attr `trans` with a default, a derived attr `T`, and two outputs.
# The parameter names mirror the declared inputs/attrs. Do not rename them
# without checking whether the generator relies on that correspondence.
@Composite(
    'TestCompositeOp',
    inputs=['x: T', 'y: T'],
    attrs=['act: {"", "relu"}', 'trans: bool = true'],
    derived_attrs=['T: numbertype'],
    outputs=['o1: T', 'o2: T'])
def _composite_op(x, y, act, trans):
  return x + act, y + trans


class TFRGenTensorTest(test.TestCase):
  """MLIR Generation Tests for MLIR TFR Program."""

  def test_op_reg_gen(self):
    # Generate C++ op-registration code from this very module. The expected
    # output lists both fixtures above, so the generator presumably scans the
    # module for Composite-decorated functions.
    cxx_code = gen_register_op(sys.modules[__name__])
    # FileCheck patterns for the generated C++:
    # - TestNoOp is expected before TestCompositeOp. NOTE(review): this
    #   ordering presumably follows how the generator enumerates module
    #   members; update it here if that changes.
    # - The double quotes in the `act` attr spec are expected to appear as
    #   single quotes in the generated C++ string literal.
    # - NOTE(review): FileCheck directives end in ':'. The `CHECK-EMPTY`
    #   lines below have no colon, so FileCheck likely treats them as plain
    #   text rather than "next line is empty" checks. Confirm the generator's
    #   blank-line layout before adding colons, because that would tighten
    #   what this test asserts.
    cxx_code_exp = r"""
      CHECK: #include "tensorflow/core/framework/op.h"
      CHECK-EMPTY
      CHECK: namespace tensorflow {
      CHECK-EMPTY
      CHECK-LABEL: REGISTER_OP("TestNoOp")
      CHECK-NEXT: .Attr("T: numbertype")
      CHECK-NEXT: .Output("o1: T");
      CHECK-EMPTY
      CHECK-LABEL: REGISTER_OP("TestCompositeOp")
      CHECK-NEXT: .Input("x: T")
      CHECK-NEXT: .Input("y: T")
      CHECK-NEXT: .Attr("act: {'', 'relu'}")
      CHECK-NEXT: .Attr("trans: bool = true")
      CHECK-NEXT: .Attr("T: numbertype")
      CHECK-NEXT: .Output("o1: T")
      CHECK-NEXT: .Output("o2: T");
      CHECK-EMPTY
      CHECK: } // namespace tensorflow
    """
    # On failure, the full generated code is shown as the assertion message.
    self.assertTrue(fw.check(str(cxx_code), cxx_code_exp), str(cxx_code))


if __name__ == '__main__':
  test.main()