# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for sparse_to_dense."""
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_scalar_data
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function


@register_make_test_function()
def make_sparse_to_dense_tests(options):
  """Make a set of tests to do sparse to dense.

  Each generated test feeds a `value` placeholder (a scalar or a 1-D tensor)
  into `tf.sparse_to_dense`. The indices are a constant built from distinct,
  randomly drawn positions inside `dense_shape`.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
  """

  test_parameters = [{
      "value_dtype": [tf.float32, tf.int32, tf.int64],
      "index_dtype": [tf.int32, tf.int64],
      "value_count": [1, 3, 6, 8],
      "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]],
      "default_value": [0, -1],
      "value_is_scalar": [True, False],
  }]

  def is_scalar_case(parameters):
    """Return whether `value` is fed as a scalar instead of a 1-D tensor.

    `value_is_scalar` only takes effect when `value_count` is 1. Any larger
    `value_count` falls back to a 1-D value tensor of length `value_count`.
    Shared by `build_graph` and `build_inputs` so the placeholder shape and
    the fed data always agree.
    """
    return parameters["value_is_scalar"] and parameters["value_count"] == 1

  # Return a single value for 1-D dense shape, but a tuple for other shapes.
  def generate_index(dense_shape):
    """Draw one random position inside `dense_shape`."""
    if len(dense_shape) == 1:
      return np.random.randint(dense_shape[0])
    # One draw per dimension, in dimension order.
    return tuple(np.random.randint(dim) for dim in dense_shape)

  def build_graph(parameters):
    """Build the sparse_to_dense op testing graph.

    Raises:
      ValueError: if `value_count` exceeds the number of distinct positions
        in `dense_shape`. No set of unique indices of that size exists, so
        the sampling loop below would never terminate.
    """
    dense_shape = parameters["dense_shape"]
    value_count = parameters["value_count"]

    if is_scalar_case(parameters):
      value = tf.placeholder(
          name="value", dtype=parameters["value_dtype"], shape=())
    else:
      value = tf.placeholder(
          name="value", dtype=parameters["value_dtype"], shape=[value_count])

    num_positions = int(np.prod(dense_shape))
    if value_count > num_positions:
      raise ValueError(
          "value_count (%d) exceeds the number of positions (%d) in "
          "dense_shape %s." % (value_count, num_positions, dense_shape))

    # Sample until `value_count` distinct positions have been collected.
    indices = set()
    while len(indices) < value_count:
      indices.add(generate_index(dense_shape))
    indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"])
    # TODO(renjieliu): Add test for validate_indices case.
    out = tf.sparse_to_dense(
        indices,
        dense_shape,
        value,
        parameters["default_value"],
        validate_indices=False)

    return [value], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the `value` input and compute the reference outputs."""
    if is_scalar_case(parameters):
      input_value = create_scalar_data(parameters["value_dtype"])
    else:
      input_value = create_tensor_data(parameters["value_dtype"],
                                       [parameters["value_count"]])
    return [input_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value])))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)