# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs."""

import os
import tensorflow as tf

from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs
from tensorflow.compiler.mlir.tfr.python import test_utils
from tensorflow.python.framework import load_library
from tensorflow.python.platform import test

# Derive the custom-op shared library path from the generated wrapper module:
# drop the extension, strip the 4-character 'gen_' prefix and append '.so'
# (e.g. gen_mnist_ops.py -> mnist_ops.so). The library is presumably built next
# to the wrapper -- NOTE(review): confirm against the BUILD rule.
# os.path.splitext is used instead of str.replace('.py', '.so') so that a
# '.pyc' __file__ (sourceless install) does not produce 'mnist_ops.soc'.
_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
_lib_stem = os.path.splitext(os.path.basename(gen_mnist_ops.__file__))[0]
_lib_name = _lib_stem[len('gen_'):] + '.so'
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))


class MnistOpsDefsTest(test_utils.OpsDefsTest):
  """Compares each MNIST custom op with its TFR composite definition."""

  def _check_new_conv2d(self, act):
    """Checks `new_conv2d` with activation `act` against its composite.

    Builds a random [1, 4, 4, 1] input, a random [2, 2, 1, 8] filter and a
    zero [8] bias, then compares `gen_mnist_ops.new_conv2d` (wrapped in
    `tf.function`) with `ops_defs._composite_conv_add_relu` using the inherited
    `_assertOpAndComposite` helper.

    Args:
      act: Activation attribute forwarded to the op, e.g. 'RELU'.
    """
    input_ = tf.random.uniform([1, 4, 4, 1])
    filter_ = tf.random.uniform([2, 2, 1, 8])
    bias = tf.zeros([8])
    kwargs = {
        'input_': input_,
        'filter_': filter_,
        'bias': bias,
        'stride_w': 2,
        'stride_h': 2,
        'dilation_w': 1,
        'dilation_h': 1,
        'padding': 'SAME',
        'act': act,
    }

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_conv2d),
                               ops_defs._composite_conv_add_relu, kwargs)

  def test_new_conv2d_relu(self):
    self._check_new_conv2d('RELU')

  def test_new_conv2d_relu6(self):
    self._check_new_conv2d('RELU6')

  def test_new_conv2d_tanh(self):
    # Disabled until tanh gradients are fixed; the skip happens before any
    # tensors are created.
    self.skipTest('Fix tanh gradients')
    self._check_new_conv2d('TANH')

  def test_new_fully_connected(self):
    """Checks `new_fully_connected` (RELU) against its composite."""
    input_ = tf.random.uniform([2, 4])
    filter_ = tf.random.uniform([3, 4])
    bias = tf.zeros([3])
    kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'}

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_fully_connected),
                               ops_defs._composite_fully_connected, kwargs)

  def test_new_max_pool(self):
    """Checks `new_max_pool` against its composite."""
    input_ = tf.random.uniform([8, 4, 4, 1])
    kwargs = {
        'input_': input_,
        'stride_w': 2,
        'stride_h': 2,
        'filter_width': 1,
        'filter_height': 1,
        'padding': 'SAME',
    }

    self._assertOpAndComposite([input_],
                               tf.function(gen_mnist_ops.new_max_pool),
                               ops_defs._composite_max_pool, kwargs)


if __name__ == '__main__':
  # Tell the TFR test utilities where to find this example's artifacts when the
  # file is run directly.
  os.environ[
      'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist'
  test.main()