• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for gather."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow.compat.v1 as tf
21from tensorflow.lite.testing.zip_test_utils import create_tensor_data
22from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
23from tensorflow.lite.testing.zip_test_utils import register_make_test_function
24
25
@register_make_test_function()
def make_gather_tests(options):
  """Make a set of tests to do gather.

  Generates `tf.gather` test cases over:
    * numeric params (float32/int32/int64) of shape [1, 2, 20], gathered
      along axes -1, 0 and 1 with int32/int64 indices of shape [3] or [5];
    * string params of shape [8], gathered along axis 0 with int32 indices
      of shape [3] or [3, 2].
  Each case is generated with params either embedded as a constant or fed
  through a placeholder; indices are always fed through a placeholder.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "params_dtype": [tf.float32, tf.int32, tf.int64],
          "params_shape": [[1, 2, 20]],
          "indices_dtype": [tf.int32, tf.int64],
          "indices_shape": [[3], [5]],
          "axis": [-1, 0, 1],
          "constant_params": [False, True],
      },
      {
          "params_dtype": [tf.string],
          "params_shape": [[8]],
          "indices_dtype": [tf.int32],
          "indices_shape": [[3], [3, 2]],
          "axis": [0],
          "constant_params": [False, True],
      }
  ]

  def resolve_axis(parameters):
    """Returns the gather axis for `parameters`, clamped to the params rank.

    Positive axes larger than the last dimension of `params_shape` are
    clamped to the last dimension; negative axes are returned unchanged.
    Both `build_graph` and `build_inputs` use this so the gather axis in the
    graph and the axis used to bound the generated indices always agree.
    """
    # Clamp to rank - 1 (the last valid axis), not rank, which would be
    # out of range.
    return min(len(parameters["params_shape"]) - 1, parameters["axis"])

  def build_graph(parameters):
    """Build the gather op testing graph.

    Returns:
      A pair `(inputs, outputs)`. `inputs` holds the params placeholder
      (only when params are not constant) followed by the indices
      placeholder. `outputs` holds the single gather result.
    """
    inputs = []

    if parameters["constant_params"]:
      # Constant params are baked into the graph, so they are not an input.
      params = create_tensor_data(parameters["params_dtype"],
                                  parameters["params_shape"])
    else:
      params = tf.compat.v1.placeholder(
          dtype=parameters["params_dtype"],
          name="params",
          shape=parameters["params_shape"])
      inputs.append(params)

    indices = tf.compat.v1.placeholder(
        dtype=parameters["indices_dtype"],
        name="indices",
        shape=parameters["indices_shape"])
    inputs.append(indices)
    out = tf.gather(params, indices, axis=resolve_axis(parameters))
    return inputs, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Creates feed values for `inputs` and evaluates `outputs` with them.

    Values are appended in the same order `build_graph` appends its
    placeholders (params only when not constant, then indices), so zipping
    `inputs` with `input_values` pairs each placeholder with its data.
    """
    input_values = []
    if not parameters["constant_params"]:
      params = create_tensor_data(parameters["params_dtype"],
                                  parameters["params_shape"])
      input_values.append(params)
    # Indices select entries along the gather axis, so bound them by the
    # size of that dimension. Bounding by dimension 0 instead limited
    # indices to [0, 0] for params of shape [1, 2, 20], whatever the axis.
    gather_dim_size = parameters["params_shape"][resolve_axis(parameters)]
    indices = create_tensor_data(parameters["indices_dtype"],
                                 parameters["indices_shape"], 0,
                                 gather_dim_size - 1)
    input_values.append(indices)
    return input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))

  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=0)
91