1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test configs for while_loop.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import tensorflow.compat.v1 as tf 22from tensorflow.lite.testing import zip_test_utils 23from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests 24from tensorflow.lite.testing.zip_test_utils import register_make_test_function 25from tensorflow.python.framework import test_util 26 27 28@register_make_test_function("make_while_tests") 29@test_util.enable_control_flow_v2 30def make_while_tests(options): 31 """Make a set of tests to do while.""" 32 # Chose a set of parameters 33 test_parameters = [{ 34 "num_iterations": range(20), 35 "increment_value": [[1]], 36 "dtype": [tf.int32], 37 }, { 38 "num_iterations": range(20), 39 "increment_value": [["a"]], 40 "dtype": [tf.string], 41 }] 42 43 def build_graph(parameters): 44 """Build the graph for while tests.""" 45 # MLIR TFLite converter can't handle scalar inputs. This is a workaround 46 # to input (1,) tensors and then reshape to scalar. 47 # TODO(b/129003347): Remove the workaround after scalar inputs are 48 # supported. 49 num_iterations = tf.placeholder( 50 dtype=tf.int32, name="num_iterations", shape=(1,)) 51 increment_value = tf.placeholder( 52 dtype=parameters["dtype"], name="increment_value", shape=(1,)) 53 num_iterations_scalar = tf.reshape(num_iterations, ()) 54 55 # For intger inputs, this simple model calucates i-th number of triangular 56 # sequence. For string inputs, the model returns the string value, filled 57 # with the given increment value times the given num_iterations. 58 # The model also returns the counter variable and increment value in the 59 # outputs. The counter and increment value are passed to the result to make 60 # sure the necessary control depenecy of the model is generated for testing 61 # the dynamic tensor cases. 62 def cond_fn(counter, value, increment_value): 63 del value 64 del increment_value 65 return counter < num_iterations_scalar 66 67 def body_fn(counter, value, increment_value): 68 new_counter = counter + 1 69 if parameters["dtype"] == tf.string: 70 # Use fill op to create new string value with the given counter value. 71 del value 72 new_value = tf.fill([1], tf.reshape(increment_value, ())) 73 else: 74 new_value = value + increment_value 75 return [new_counter, new_value, increment_value] 76 77 counter, value, result_increment_value = tf.while_loop( 78 cond_fn, body_fn, loop_vars=[1, increment_value, increment_value]) 79 return [num_iterations, 80 increment_value], [counter, value, result_increment_value] 81 82 def build_inputs(parameters, sess, inputs, outputs): 83 numpy_type = zip_test_utils.TF_TYPE_INFO[parameters["dtype"]][0] 84 input_values = [ 85 np.array([parameters["num_iterations"]], dtype=np.int32), 86 np.array(parameters["increment_value"], dtype=numpy_type) 87 ] 88 return input_values, sess.run( 89 outputs, feed_dict=dict(zip(inputs, input_values))) 90 91 make_zip_of_tests(options, test_parameters, build_graph, build_inputs) 92