• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for sparse_to_dense."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import tensorflow.compat.v1 as tf
22from tensorflow.lite.testing.zip_test_utils import create_scalar_data
23from tensorflow.lite.testing.zip_test_utils import create_tensor_data
24from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
25from tensorflow.lite.testing.zip_test_utils import register_make_test_function
26
27
@register_make_test_function()
def make_sparse_to_dense_tests(options):
  """Make a set of tests to do sparse to dense.

  Each generated test feeds a `value` placeholder into `tf.sparse_to_dense`
  together with a constant tensor of randomly sampled, distinct indices into
  a dense tensor of shape `dense_shape`.

  Args:
    options: Options object forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [{
      "value_dtype": [tf.float32, tf.int32, tf.int64],
      "index_dtype": [tf.int32, tf.int64],
      "value_count": [1, 3, 6, 8],
      "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]],
      "default_value": [0, -1],
      # NOTE: a scalar `value` is only used when value_count == 1; with a
      # larger value_count, value_is_scalar=True falls back to a 1-D value.
      "value_is_scalar": [True, False],
  }]

  def use_scalar_value(parameters):
    """Returns True when `value` should be a scalar instead of a 1-D tensor.

    A scalar value only makes sense for a single index, so this requires both
    `value_is_scalar` and `value_count == 1`. Shared by `build_graph` and
    `build_inputs` so the placeholder shape and the fed data always agree.
    """
    return parameters["value_is_scalar"] and parameters["value_count"] == 1

  # Return a single value for 1-D dense shape, but a tuple for other shapes.
  def generate_index(dense_shape):
    """Draws one random position inside a tensor of shape `dense_shape`."""
    if len(dense_shape) == 1:
      return np.random.randint(dense_shape[0])
    else:
      index = []
      for shape in dense_shape:
        index.append(np.random.randint(shape))
      return tuple(index)

  def build_graph(parameters):
    """Build the sparse_to_dense op testing graph."""
    dense_shape = parameters["dense_shape"]
    value_count = parameters["value_count"]

    if use_scalar_value(parameters):
      value = tf.compat.v1.placeholder(
          name="value", dtype=parameters["value_dtype"], shape=())
    else:
      value = tf.compat.v1.placeholder(
          name="value",
          dtype=parameters["value_dtype"],
          shape=[value_count])

    # The loop below rejection-samples *distinct* positions, so it can only
    # terminate when the dense tensor has at least `value_count` cells.
    num_cells = int(np.prod(dense_shape))
    if value_count > num_cells:
      raise ValueError(
          "value_count ({}) exceeds the number of cells ({}) in dense_shape "
          "{}".format(value_count, num_cells, dense_shape))

    indices = set()
    while len(indices) < value_count:
      indices.add(generate_index(dense_shape))
    indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"])
    # TODO(renjieliu): Add test for validate_indices case.
    out = tf.sparse_to_dense(
        indices,
        dense_shape,
        value,
        parameters["default_value"],
        validate_indices=False)

    return [value], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Creates random data for `value` and computes the reference output."""
    if use_scalar_value(parameters):
      input_value = create_scalar_data(parameters["value_dtype"])
    else:
      input_value = create_tensor_data(parameters["value_dtype"],
                                       [parameters["value_count"]])
    return [input_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value])))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
89