• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for gather."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
@register_make_test_function()
def make_gather_tests(options):
  """Make a set of tests to do gather.

  Builds the parameter grid for `tf.gather` tests and hands it, together with
  the graph/input builders below, to `make_zip_of_tests`.

  Args:
    options: Test-generation options forwarded to `make_zip_of_tests`. When
      its `use_experimental_converter` attribute is truthy, extra `batch_dims`
      test cases are added to the grid.
  """

  test_parameters = [{
      "params_dtype": [tf.float32, tf.int32, tf.int64],
      "params_shape": [[1, 2, 20]],
      "indices_dtype": [tf.int32, tf.int64],
      "indices_shape": [[3], [5]],
      "axis": [-1, 0, 1],
      "batch_dims": [0],
      "constant_params": [False, True],
  }, {
      "params_dtype": [tf.string],
      "params_shape": [[8]],
      "indices_dtype": [tf.int32],
      "indices_shape": [[3], [3, 2]],
      "axis": [0],
      "batch_dims": [0],
      "constant_params": [False, True],
  }]

  if options.use_experimental_converter:
    test_parameters = test_parameters + [
        # Test with batch_dims.
        # NOTE(review): the grid also produces axis=0 with batch_dims in
        # {1, 2}. tf.gather documents that batch_dims must not exceed axis, so
        # confirm how these combinations are handled and that they test what
        # is intended.
        {
            "params_dtype": [tf.float32, tf.int32],
            "params_shape": [[2, 2, 3, 5]],
            "indices_dtype": [tf.int32],
            "indices_shape": [[2, 2, 2]],
            "axis": [0, 2],
            "batch_dims": [1, 2],
            "constant_params": [False, True],
        }
    ]

  def get_axis(parameters):
    """Return the gather axis, clamped to the last dimension of `params`.

    Clamping to `rank - 1` (not `rank`) keeps the result a valid axis when a
    shared axis list is combined with lower-rank `params_shape` values.
    """
    rank = len(parameters["params_shape"])
    return min(rank - 1, parameters["axis"])

  def build_graph(parameters):
    """Build the gather op testing graph.

    `params` is either a constant tensor baked into the graph or a
    placeholder (when `constant_params` is False); `indices` is always a
    placeholder.

    Returns:
      A tuple `(inputs, outputs)`: the list of placeholders to feed (params
      first if present, then indices) and a one-element list with the gather
      result.
    """
    inputs = []

    if parameters["constant_params"]:
      params = create_tensor_data(parameters["params_dtype"],
                                  parameters["params_shape"])
    else:
      params = tf.compat.v1.placeholder(
          dtype=parameters["params_dtype"],
          name="params",
          shape=parameters["params_shape"])
      inputs.append(params)

    indices = tf.compat.v1.placeholder(
        dtype=parameters["indices_dtype"],
        name="indices",
        shape=parameters["indices_shape"])
    inputs.append(indices)
    out = tf.gather(
        params,
        indices,
        axis=get_axis(parameters),
        batch_dims=parameters["batch_dims"])
    return inputs, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create feed values for `inputs` and run the graph to get `outputs`.

    Values are generated in the same order `build_graph` appended the
    placeholders (params first when not constant, then indices), so zipping
    them pairs each placeholder with its value.

    Returns:
      A tuple `(input_values, output_values)`.
    """
    input_values = []
    if not parameters["constant_params"]:
      params = create_tensor_data(parameters["params_dtype"],
                                  parameters["params_shape"])
      input_values.append(params)
    # Indices select positions along the gathered axis, so bound them by that
    # axis's size rather than by dimension 0. The `- 1` assumes the max value
    # passed to create_tensor_data is inclusive for integer dtypes, as the
    # original bound implied.
    axis_size = parameters["params_shape"][get_axis(parameters)]
    indices = create_tensor_data(parameters["indices_dtype"],
                                 parameters["indices_shape"], 0,
                                 axis_size - 1)
    input_values.append(indices)
    return input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))

  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=0)
105