1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test configs for sparse_to_dense.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import tensorflow.compat.v1 as tf 22from tensorflow.lite.testing.zip_test_utils import create_scalar_data 23from tensorflow.lite.testing.zip_test_utils import create_tensor_data 24from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests 25from tensorflow.lite.testing.zip_test_utils import register_make_test_function 26 27 28@register_make_test_function() 29def make_sparse_to_dense_tests(options): 30 """Make a set of tests to do sparse to dense.""" 31 32 test_parameters = [{ 33 "value_dtype": [tf.float32, tf.int32, tf.int64], 34 "index_dtype": [tf.int32, tf.int64], 35 "value_count": [1, 3, 6, 8], 36 "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], 37 "default_value": [0, -1], 38 "value_is_scalar": [True, False], 39 }] 40 41 # Return a single value for 1-D dense shape, but a tuple for other shapes. 42 def generate_index(dense_shape): 43 if len(dense_shape) == 1: 44 return np.random.randint(dense_shape[0]) 45 else: 46 index = [] 47 for shape in dense_shape: 48 index.append(np.random.randint(shape)) 49 return tuple(index) 50 51 def build_graph(parameters): 52 """Build the sparse_to_dense op testing graph.""" 53 dense_shape = parameters["dense_shape"] 54 55 # Special handle for value_is_scalar case. 56 # value_count must be 1. 57 if parameters["value_is_scalar"] and parameters["value_count"] == 1: 58 value = tf.compat.v1.placeholder( 59 name="value", dtype=parameters["value_dtype"], shape=()) 60 else: 61 value = tf.compat.v1.placeholder( 62 name="value", 63 dtype=parameters["value_dtype"], 64 shape=[parameters["value_count"]]) 65 indices = set() 66 while len(indices) < parameters["value_count"]: 67 indices.add(generate_index(dense_shape)) 68 indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) 69 # TODO(renjieliu): Add test for validate_indices case. 70 out = tf.sparse_to_dense( 71 indices, 72 dense_shape, 73 value, 74 parameters["default_value"], 75 validate_indices=False) 76 77 return [value], [out] 78 79 def build_inputs(parameters, sess, inputs, outputs): 80 if parameters["value_is_scalar"] and parameters["value_count"] == 1: 81 input_value = create_scalar_data(parameters["value_dtype"]) 82 else: 83 input_value = create_tensor_data(parameters["value_dtype"], 84 [parameters["value_count"]]) 85 return [input_value], sess.run( 86 outputs, feed_dict=dict(zip(inputs, [input_value]))) 87 88 make_zip_of_tests(options, test_parameters, build_graph, build_inputs) 89