• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for concat."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
@register_make_test_function()
def make_concat_tests(options):
  """Generate the zipped conversion tests for `tf.concat`.

  Every parameter combination builds a graph that concatenates
  `num_tensors` placeholders of dtype `type` along `axis`. For an axis that
  is in range for `base_shape`, input i uses `base_shape` with that
  dimension enlarged by i. Combinations with an axis that is out of range
  for the shape are kept on purpose; they are counted through
  `expected_tf_failures`.

  Args:
    options: Options object, forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      # Float / plain-integer inputs, no quantization. Axes 2, 3 and -3 are
      # out of range for the rank-2 base shape [3, 4].
      {
          "base_shape": [[1, 3, 4, 3], [3, 4]],
          "num_tensors": [1, 2, 3, 4, 5, 6],
          "axis": [0, 1, 2, 3, -3, -2, -1],
          "type": [tf.float32, tf.uint8, tf.int32, tf.int64],
          "fully_quantize": [False],
          "quant_16x8": [False],
          "dynamic_range_quantize": [False],
      },
      # Full quantization of float32 inputs (axis 0 is not exercised here).
      {
          "base_shape": [[1, 3, 4, 3], [3, 4], [2, 3, 4, 3]],
          "num_tensors": [1, 2, 3, 4, 5, 6],
          "axis": [1, 2, 3, -3, -2, -1],
          "type": [tf.float32],
          "fully_quantize": [True],
          "quant_16x8": [False],
          "dynamic_range_quantize": [False],
      },
      # Full quantization with the 16x8 flag set; a single configuration.
      {
          "base_shape": [[1, 3, 4, 3]],
          "num_tensors": [6],
          "axis": [-1],
          "type": [tf.float32],
          "fully_quantize": [True],
          "quant_16x8": [True],
          "dynamic_range_quantize": [False],
      },
      # Dynamic-range quantization of float32 inputs.
      {
          "base_shape": [[1, 3, 4, 3]],
          "num_tensors": [6],
          "axis": [1],
          "type": [tf.float32],
          "fully_quantize": [False],
          "quant_16x8": [False],
          "dynamic_range_quantize": [True],
      },
      # Bool inputs with the dynamic-range-quantize flag set.
      {
          "base_shape": [[1, 3, 4, 3]],
          "num_tensors": [6],
          "axis": [1],
          "type": [tf.bool],
          "fully_quantize": [False],
          "quant_16x8": [False],
          "dynamic_range_quantize": [True],
      },
  ]

  def _input_shape(parameters, delta):
    """Return a copy of 'base_shape' with the 'axis' dimension grown by delta.

    A negative axis is wrapped once by adding the rank. If the wrapped axis
    is still >= rank, the copy is returned unchanged. If it is still
    negative, ordinary Python negative indexing picks the dimension.
    """
    dims = parameters["base_shape"][:]
    rank = len(dims)
    axis = parameters["axis"]
    if axis < 0:
      axis += rank
    if axis >= rank:
      return dims
    dims[axis] += delta
    return dims

  def build_graph(parameters):
    """Create one placeholder per input and concatenate them along 'axis'."""
    placeholders = [
        tf.compat.v1.placeholder(
            dtype=parameters["type"],
            name="input%d" % index,
            shape=_input_shape(parameters, index))
        for index in range(parameters["num_tensors"])
    ]
    concatenated = tf.concat(placeholders, parameters["axis"])
    return placeholders, [concatenated]

  def build_inputs(parameters, sess, inputs, outputs):
    """Produce input data for every placeholder and evaluate the outputs.

    Data comes from `create_tensor_data`, requested with min_value=-1 and
    max_value=1. Values are generated in the same order as `inputs`, so
    zipping them pairs each placeholder with its data.
    """
    values = [
        create_tensor_data(
            parameters["type"],
            _input_shape(parameters, index),
            min_value=-1,
            max_value=1)
        for index in range(parameters["num_tensors"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  # NOTE(review): 75 appears to equal the combinations whose axis (2, 3, -3)
  # is out of range for base_shape [3, 4] with num_tensors >= 2:
  # 3 axes * 5 counts * 4 dtypes (first set) + 3 axes * 5 counts (second
  # set). This presumes a single-tensor concat does not reject the axis.
  # Confirm before changing test_parameters.
  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=75)
111