• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow.compiler.mlir.tfr.examples.pad.ops_defs."""
15
16import os
17from absl.testing import parameterized
18import tensorflow as tf
19
20from tensorflow.compiler.mlir.tfr.examples.pad import gen_pad_ops
21from tensorflow.compiler.mlir.tfr.examples.pad import ops_defs
22from tensorflow.compiler.mlir.tfr.python import test_utils
23from tensorflow.python.framework import load_library
24from tensorflow.python.platform import test
25
# Load the op library that sits next to the generated `gen_pad_ops` module:
# same directory, file name with the leading 'gen_' dropped and '.py' swapped
# for '.so'.
_lib_dir, _lib_name = os.path.split(gen_pad_ops.__file__)
_lib_name = _lib_name[len('gen_'):].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
29
30
class PadOpsDefsTest(test_utils.OpsDefsTest, parameterized.TestCase):
  """Checks the mirror-pad composites against the raw and generated ops."""

  def _check_composite(self, data, mode, raw_op, generated_op, composite):
    """Runs the two op-vs-composite comparisons shared by every test here.

    Args:
      data: tensor used as the op input; also handed to
        `_assertOpAndComposite` as a one-element list.
      mode: padding mode string, forwarded under the 'mode' key.
      raw_op: reference op taken from `tf.raw_ops`.
      generated_op: function from `gen_pad_ops`; it is wrapped in
        `tf.function` before being compared.
      composite: composition function from `ops_defs`.
    """
    pad_widths = tf.constant([[1, 1], [2, 2]])
    shared = {'paddings': pad_widths, 'mode': mode}
    # Two kwargs dicts that differ only in the name of the input key:
    # 'input' in one, 'input_' in the other.
    raw_op_kwargs = {'input': data, **shared}
    composite_kwargs = {'input_': data, **shared}

    # Make sure the composition python function is correct
    self._assertOpAndComposite([data], raw_op, composite, composite_kwargs,
                               raw_op_kwargs)
    # Make sure the translation and decomposition is correct
    self._assertOpAndComposite([data], tf.function(generated_op), composite,
                               composite_kwargs)

  @parameterized.named_parameters(('ReflectMode', 'REFLECT'),
                                  ('SymmetricMode', 'SYMMETRIC'))
  def test_mirror_pad(self, mode):
    """Compares MirrorPad with its composite on a 2x3 float input."""
    data = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    self._check_composite(data, mode, tf.raw_ops.MirrorPad,
                          gen_pad_ops.new_mirror_pad,
                          ops_defs._composite_mirror_pad)

  @parameterized.named_parameters(('ReflectMode', 'REFLECT'),
                                  ('SymmetricMode', 'SYMMETRIC'))
  def test_mirror_pad_grad(self, mode):
    """Compares MirrorPadGrad with its composite on a 4x7 float input."""
    top_rows = [2, 1, 1, 2, 3, 3, 2]
    bottom_rows = [5, 4, 4, 5, 6, 6, 5]
    data = tf.constant([top_rows, top_rows, bottom_rows, bottom_rows],
                       dtype=tf.float32)
    self._check_composite(data, mode, tf.raw_ops.MirrorPadGrad,
                          gen_pad_ops.new_mirror_pad_grad,
                          ops_defs._composite_mirror_pad_grad)
87
88
if __name__ == '__main__':
  # Export the TFR library directory before handing control to the test
  # runner.
  os.environ.update(
      {'TF_MLIR_TFR_LIB_DIR': 'tensorflow/compiler/mlir/tfr/examples/pad'})
  test.main()
93