• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for concat."""
16import tensorflow.compat.v1 as tf
17from tensorflow.lite.testing.zip_test_utils import create_tensor_data
18from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
19from tensorflow.lite.testing.zip_test_utils import register_make_test_function
20
21
@register_make_test_function()
def make_concat_tests(options):
  """Make a set of tests to do concatenation.

  Each generated graph concatenates `num_tensors` placeholders along `axis`
  with `tf.concat`. Input number `n` uses `base_shape` with the `axis`
  dimension grown by `n`, so the inputs differ only along the concatenation
  axis. When `axis` is out of range for the rank of `base_shape`, the shape
  is left unchanged and the case is counted in `expected_tf_failures`.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
  """

  test_parameters = [{
      "base_shape": [[1, 3, 4, 3], [3, 4]],
      "num_tensors": [1, 2, 3, 4, 5, 6],
      "axis": [0, 1, 2, 3, -3, -2, -1],
      "type": [tf.float32, tf.uint8, tf.int32, tf.int64],
      "fully_quantize": [False],
      "quant_16x8": [False],
      "dynamic_range_quantize": [False],
  }, {
      "base_shape": [[1, 3, 4, 3], [3, 4], [2, 3, 4, 3]],
      "num_tensors": [1, 2, 3, 4, 5, 6],
      "axis": [1, 2, 3, -3, -2, -1],
      "type": [tf.float32],
      "fully_quantize": [True],
      "quant_16x8": [False],
      "dynamic_range_quantize": [False],
  }, {
      "base_shape": [[1, 3, 4, 3]],
      "num_tensors": [6],
      "axis": [-1],
      "type": [tf.float32],
      "fully_quantize": [True],
      "quant_16x8": [True],
      "dynamic_range_quantize": [False],
  }, {
      "base_shape": [[1, 3, 4, 3]],
      "num_tensors": [6],
      "axis": [1],
      "type": [tf.float32],
      "fully_quantize": [False],
      "quant_16x8": [False],
      "dynamic_range_quantize": [True],
  }, {
      "base_shape": [[1, 3, 4, 3]],
      "num_tensors": [6],
      "axis": [1],
      "type": [tf.bool],
      "fully_quantize": [False],
      "quant_16x8": [False],
      "dynamic_range_quantize": [True],
  }]

  def get_shape(parameters, delta):
    """Return a copy of `base_shape` with the `axis` dimension grown by `delta`.

    A negative `axis` is normalized against the rank of `base_shape`. If the
    axis is out of range for that rank, the copy is returned unmodified.
    """
    axis = parameters["axis"]
    shape = list(parameters["base_shape"])
    rank = len(shape)
    if axis < 0:
      axis += rank
    # Check both bounds: an axis below -rank (e.g. -3 on a rank-2 shape) is
    # still negative after normalization, and indexing with it would silently
    # grow the wrong (trailing) dimension.
    if 0 <= axis < rank:
      shape[axis] += delta
    return shape

  def build_graph(parameters):
    """Build `num_tensors` placeholders and concatenate them along `axis`."""
    all_tensors = [
        tf.compat.v1.placeholder(
            dtype=parameters["type"],
            name=("input%d" % n),
            shape=get_shape(parameters, n))
        for n in range(parameters["num_tensors"])
    ]
    out = tf.concat(all_tensors, parameters["axis"])
    return all_tensors, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create input data for every placeholder and run the graph on it.

    Values come from `create_tensor_data` with min_value=-1 and max_value=1.
    Returns the input values and the session outputs computed from them.
    """
    all_values = [
        create_tensor_data(
            parameters["type"],
            get_shape(parameters, n),
            min_value=-1,
            max_value=1)
        for n in range(parameters["num_tensors"])
    ]
    return all_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, all_values)))

  # 75 = combinations where the rank-2 base_shape [3, 4] meets an axis that
  # is out of range for rank 2 (2, 3 or -3) with num_tensors in 2..6:
  #   first dict:  3 axes x 5 counts x 4 types = 60
  #   second dict: 3 axes x 5 counts x 1 type  = 15
  # num_tensors == 1 is not counted; presumably tf.concat of a single tensor
  # skips axis validation (NOTE(review): confirm if the parameters change).
  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=75)
107