1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test configs for parse example.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import string 21 22import numpy as np 23import tensorflow.compat.v1 as tf 24 25from tensorflow.lite.testing.zip_test_utils import ExtraTocoOptions 26from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests 27from tensorflow.lite.testing.zip_test_utils import register_make_test_function 28 29 30def create_example_data(feature_dtype, feature_shape): 31 """Create structured example data.""" 32 features = {} 33 if feature_dtype in (tf.float32, tf.float16, tf.float64): 34 data = np.random.rand(*feature_shape) 35 features["x"] = tf.train.Feature( 36 float_list=tf.train.FloatList(value=list(data))) 37 elif feature_dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): 38 data = np.random.randint(-100, 100, size=feature_shape) 39 features["x"] = tf.train.Feature( 40 int64_list=tf.train.Int64List(value=list(data))) 41 elif feature_dtype == tf.string: 42 letters = list(string.ascii_uppercase) 43 data = "".join(np.random.choice(letters, size=10)).encode("utf-8") 44 features["x"] = tf.train.Feature( 45 bytes_list=tf.train.BytesList(value=[data]*feature_shape[0])) 46 example = tf.train.Example(features=tf.train.Features(feature=features)) 47 return np.array([example.SerializeToString()]) 48 49 50@register_make_test_function("make_parse_example_tests") 51def make_parse_example_tests(options): 52 """Make a set of tests to use parse_example.""" 53 54 # Chose a set of parameters 55 test_parameters = [{ 56 "feature_dtype": [tf.string, tf.float32, tf.int64], 57 "is_dense": [True, False], 58 "feature_shape": [[1], [2], [16]], 59 }] 60 61 def build_graph(parameters): 62 """Build the graph for parse_example tests.""" 63 feature_dtype = parameters["feature_dtype"] 64 feature_shape = parameters["feature_shape"] 65 is_dense = parameters["is_dense"] 66 input_value = tf.compat.v1.placeholder( 67 dtype=tf.string, name="input", shape=[1]) 68 if is_dense: 69 feature_default_value = np.zeros(shape=feature_shape) 70 if feature_dtype == tf.string: 71 feature_default_value = np.array(["missing"]*feature_shape[0]) 72 features = {"x": tf.FixedLenFeature(shape=feature_shape, 73 dtype=feature_dtype, 74 default_value=feature_default_value)} 75 else: # Sparse 76 features = {"x": tf.VarLenFeature(dtype=feature_dtype)} 77 out = tf.parse_example(input_value, features) 78 output_tensor = out["x"] 79 if not is_dense: 80 output_tensor = out["x"].values 81 return [input_value], [output_tensor] 82 83 def build_inputs(parameters, sess, inputs, outputs): 84 feature_dtype = parameters["feature_dtype"] 85 feature_shape = parameters["feature_shape"] 86 input_values = [create_example_data(feature_dtype, feature_shape)] 87 return input_values, sess.run( 88 outputs, feed_dict=dict(zip(inputs, input_values))) 89 90 extra_toco_options = ExtraTocoOptions() 91 extra_toco_options.allow_custom_ops = True 92 make_zip_of_tests(options, test_parameters, build_graph, build_inputs, 93 extra_toco_options) 94