• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for transpose."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import tensorflow.compat.v1 as tf
22from tensorflow.lite.testing.zip_test_utils import create_tensor_data
23from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
24from tensorflow.lite.testing.zip_test_utils import register_make_test_function
25
26
@register_make_test_function()
def make_transpose_tests(options):
  """Make a set of tests to do transpose.

  Each entry of `test_parameters` lists candidate values per key; presumably
  `make_zip_of_tests` expands every combination into one test case (the
  previous `expected_tf_failures=9` matched that count for the
  `constant_perm=False` combinations).

  Args:
    options: Test-generation options, passed through unchanged to
      `make_zip_of_tests`.
  """

  # TODO(nupurgarg): Add test for uint8.
  # `constant_perm=False` feeds `perm` at run time through a placeholder
  # instead of baking it into the graph. The `fully_quantize=True` groups
  # only use a constant perm. `fully_quantize` is not read by build_graph /
  # build_inputs; presumably it is consumed by `make_zip_of_tests`.
  test_parameters = [{
      "dtype": [tf.int32, tf.int64, tf.float32],
      "input_shape": [[2, 2, 3]],
      "perm": [[0, 1, 2], [0, 2, 1]],
      "constant_perm": [True, False],
      "fully_quantize": [False],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4]],
      "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
      "constant_perm": [True, False],
      "fully_quantize": [False],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4, 5]],
      "perm": [[4, 3, 2, 1, 0]],
      "constant_perm": [True, False],
      "fully_quantize": [False],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[2, 2, 3]],
      "perm": [[0, 1, 2], [0, 2, 1]],
      "constant_perm": [True],
      "fully_quantize": [True],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4]],
      "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
      "constant_perm": [True],
      "fully_quantize": [True],
  }, {
      "dtype": [tf.float32],
      "input_shape": [[1, 2, 3, 4, 5]],
      "perm": [[0, 1, 2, 3, 4], [3, 4, 0, 1, 2]],
      "constant_perm": [True],
      "fully_quantize": [True, False],
  }]

  def build_graph(parameters):
    """Build a transpose graph given `parameters`.

    Returns:
      A pair `(input_tensors, output_tensors)`. `input_tensors` also contains
      the `perm` placeholder when `constant_perm` is False.
    """
    input_tensor = tf.compat.v1.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])

    if parameters["constant_perm"]:
      perm = parameters["perm"]
      input_tensors = [input_tensor]
    else:
      # `perm` must be a 1-D vector with one entry per input dimension. A
      # [len(perm), 2] placeholder is rejected by tf.transpose when the graph
      # is built, so those cases never produced a runnable test.
      perm = tf.compat.v1.placeholder(
          dtype=tf.int32, name="perm", shape=[len(parameters["perm"])])
      input_tensors = [input_tensor, perm]

    out = tf.transpose(input_tensor, perm=perm)
    return input_tensors, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create input values and compute reference outputs.

    Returns:
      A pair `(values, reference_outputs)`. `values` lines up positionally
      with `inputs`, and `reference_outputs` is the result of running
      `outputs` in `sess` with those values fed.
    """
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"],
                           min_value=-1, max_value=1)
    ]
    if not parameters["constant_perm"]:
      # Match the int32 dtype of the `perm` placeholder. Without an explicit
      # dtype, np.array uses the platform default integer (typically int64).
      values.append(np.array(parameters["perm"], dtype=np.int32))
    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

  # With `perm` fed as a proper 1-D tensor, every configuration is expected
  # to build and run in TensorFlow.
  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=0)
103