• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs."""
15
16import os
17import tensorflow as tf
18
19from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
20from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs
21from tensorflow.compiler.mlir.tfr.python import test_utils
22from tensorflow.python.framework import load_library
23from tensorflow.python.platform import test
24
# Load the compiled kernels for the custom MNIST ops. The shared object sits
# next to the generated Python wrapper and is named after it without the
# `gen_` prefix, e.g. `gen_mnist_ops.py` -> `mnist_ops.so`.
_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
# Use splitext rather than a substring replace so only the real extension is
# swapped (a `.pyc` wrapper would otherwise yield `mnist_ops.soc`).
_lib_stem = os.path.splitext(os.path.basename(gen_mnist_ops.__file__))[0]
if _lib_stem.startswith('gen_'):
  _lib_stem = _lib_stem[len('gen_'):]
_lib_name = _lib_stem + '.so'
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
28
29
class MnistOpsDefsTest(test_utils.OpsDefsTest):
  """Compares each generated MNIST op with its composite from `ops_defs`.

  Every test builds small random inputs and passes them, together with the
  `tf.function`-wrapped op from `gen_mnist_ops` and the matching composite
  function, to the inherited `_assertOpAndComposite` helper.
  NOTE(review): what exactly that helper checks (outputs, and judging by the
  tanh skip reason, gradients) is defined in `test_utils` — confirm there.
  """

  def _assertNewConv2d(self, act):
    """Runs the shared `new_conv2d` vs. composite check for one activation.

    The three conv2d tests differ only in the activation, so the input
    construction and attribute set live here in one place.

    Args:
      act: Activation name passed as the op's `act` attribute
        ('RELU', 'RELU6' or 'TANH' in this file).
    """
    input_ = tf.random.uniform([1, 4, 4, 1])
    filter_ = tf.random.uniform([2, 2, 1, 8])
    bias = tf.zeros([8])
    kwargs = {
        'input_': input_,
        'filter_': filter_,
        'bias': bias,
        'stride_w': 2,
        'stride_h': 2,
        'dilation_w': 1,
        'dilation_h': 1,
        'padding': 'SAME',
        'act': act,
    }

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_conv2d),
                               ops_defs._composite_conv_add_relu, kwargs)

  def test_new_conv2d_relu(self):
    self._assertNewConv2d('RELU')

  def test_new_conv2d_relu6(self):
    self._assertNewConv2d('RELU6')

  def test_new_conv2d_tanh(self):
    # TODO: re-enable once tanh gradients are fixed; skipTest raises, so the
    # check below is intentionally unreachable until then.
    self.skipTest('Fix tanh gradients')
    self._assertNewConv2d('TANH')

  def test_new_fully_connected(self):
    input_ = tf.random.uniform([2, 4])
    filter_ = tf.random.uniform([3, 4])
    bias = tf.zeros([3])
    kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'}

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_fully_connected),
                               ops_defs._composite_fully_connected, kwargs)

  def test_new_max_pool(self):
    input_ = tf.random.uniform([8, 4, 4, 1])
    kwargs = {
        'input_': input_,
        'stride_w': 2,
        'stride_h': 2,
        'filter_width': 1,
        'filter_height': 1,
        'padding': 'SAME',
    }

    self._assertOpAndComposite([input_],
                               tf.function(gen_mnist_ops.new_max_pool),
                               ops_defs._composite_max_pool, kwargs)
117
118
if __name__ == '__main__':
  # Point TF_MLIR_TFR_LIB_DIR at this example's directory before running the
  # tests. NOTE(review): presumably read by the TFR tooling to locate the op
  # definitions — confirm against test_utils.
  _tfr_lib_dir = 'tensorflow/compiler/mlir/tfr/examples/mnist'
  os.environ['TF_MLIR_TFR_LIB_DIR'] = _tfr_lib_dir
  test.main()
123