1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test configs for roll.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import tensorflow.compat.v1 as tf 22 23from tensorflow.lite.testing.zip_test_utils import create_tensor_data 24from tensorflow.lite.testing.zip_test_utils import ExtraTocoOptions 25from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests 26from tensorflow.lite.testing.zip_test_utils import register_make_test_function 27 28test_parameters = [ 29 # Scalar axis. 30 { 31 "input_dtype": [tf.float32, tf.int32], 32 "input_shape": [[2, 4, 5], [3, 8, 4]], 33 "shift": [1, -3, 5], 34 "axis": [0, 1, 2], 35 }, 36 # 1-D axis. 37 { 38 "input_dtype": [tf.float32, tf.int32], 39 "input_shape": [[2, 4, 5], [3, 8, 4]], 40 "shift": [[1], [-3], [5]], 41 "axis": [[0], [1], [2]], 42 }, 43 # Multiple axis. 44 { 45 "input_dtype": [tf.float32, tf.int32], 46 "input_shape": [[2, 4, 5], [3, 8, 4]], 47 "shift": [[1, 3, 2], [3, -6, 5], [-5, 7, 8]], 48 "axis": [[0, 1, 2]], 49 }, 50 # Duplicate axis. 51 { 52 "input_dtype": [tf.float32], 53 "input_shape": [[2, 4, 5], [3, 8, 4]], 54 "shift": [[1, 3, -2]], 55 "axis": [[0, 1, 1]], 56 }, 57] 58 59 60@register_make_test_function() 61def make_roll_with_constant_tests(options): 62 """Make a set of tests to do roll with constant shift and axis.""" 63 64 def build_graph(parameters): 65 input_value = tf.compat.v1.placeholder( 66 dtype=parameters["input_dtype"], 67 name="input", 68 shape=parameters["input_shape"]) 69 outs = tf.roll( 70 input_value, shift=parameters["shift"], axis=parameters["axis"]) 71 return [input_value], [outs] 72 73 def build_inputs(parameters, sess, inputs, outputs): 74 input_value = create_tensor_data(parameters["input_dtype"], 75 parameters["input_shape"]) 76 return [input_value], sess.run( 77 outputs, feed_dict=dict(zip(inputs, [input_value]))) 78 79 make_zip_of_tests(options, test_parameters, build_graph, build_inputs) 80 81 82@register_make_test_function() 83def make_roll_tests(options): 84 """Make a set of tests to do roll.""" 85 86 ext_test_parameters = test_parameters + [ 87 # Scalar axis. 88 { 89 "input_dtype": [tf.float32, tf.int32], 90 "input_shape": [[None, 8, 4]], 91 "shift": [-3, 5], 92 "axis": [1, 2], 93 } 94 ] 95 96 def set_dynamic_shape(shape): 97 return [4 if x is None else x for x in shape] 98 99 def get_shape(param): 100 if np.isscalar(param): 101 return [] 102 return [len(param)] 103 104 def get_value(param, dtype): 105 if np.isscalar(param): 106 return np.dtype(dtype).type(param) 107 return np.array(param).astype(dtype) 108 109 def build_graph(parameters): 110 input_tensor = tf.compat.v1.placeholder( 111 dtype=parameters["input_dtype"], 112 name="input", 113 shape=parameters["input_shape"]) 114 shift_tensor = tf.compat.v1.placeholder( 115 dtype=tf.int64, name="shift", shape=get_shape(parameters["shift"])) 116 axis_tensor = tf.compat.v1.placeholder( 117 dtype=tf.int64, name="axis", shape=get_shape(parameters["axis"])) 118 outs = tf.roll(input_tensor, shift_tensor, axis_tensor) 119 return [input_tensor, shift_tensor, axis_tensor], [outs] 120 121 def build_inputs(parameters, sess, inputs, outputs): 122 input_value = create_tensor_data( 123 parameters["input_dtype"], set_dynamic_shape(parameters["input_shape"])) 124 shift_value = get_value(parameters["shift"], np.int64) 125 axis_value = get_value(parameters["axis"], np.int64) 126 return [input_value, shift_value, axis_value], sess.run( 127 outputs, 128 feed_dict=dict(zip(inputs, [input_value, shift_value, axis_value]))) 129 130 extra_toco_options = ExtraTocoOptions() 131 extra_toco_options.allow_custom_ops = True 132 make_zip_of_tests(options, ext_test_parameters, build_graph, build_inputs, 133 extra_toco_options) 134