• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for while_loop."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import tensorflow.compat.v1 as tf
22from tensorflow.lite.testing import zip_test_utils
23from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
24from tensorflow.lite.testing.zip_test_utils import register_make_test_function
25from tensorflow.python.framework import test_util
26
27
@register_make_test_function("make_while_tests")
@test_util.enable_control_flow_v2
def make_while_tests(options):
  """Make a set of tests for `tf.while_loop`.

  The parameter grid below covers int32 and string loop values, each run
  with num_iterations in range(20). The graph and input builders are handed
  to `make_zip_of_tests`.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
  """
  # Choose a set of parameters.
  test_parameters = [{
      "num_iterations": range(20),
      "increment_value": [[1]],
      "dtype": [tf.int32],
  }, {
      "num_iterations": range(20),
      "increment_value": [["a"]],
      "dtype": [tf.string],
  }]

  def build_graph(parameters):
    """Build the graph for while tests.

    Args:
      parameters: One combination from `test_parameters`; only
        parameters["dtype"] is read here.

    Returns:
      A pair ([num_iterations, increment_value] placeholders,
      [counter, value, result_increment_value] loop outputs).
    """
    # MLIR TFLite converter can't handle scalar inputs. This is a workaround
    # to input (1,) tensors and then reshape to scalar.
    # TODO(b/129003347): Remove the workaround after scalar inputs are
    # supported.
    num_iterations = tf.placeholder(
        dtype=tf.int32, name="num_iterations", shape=(1,))
    increment_value = tf.placeholder(
        dtype=parameters["dtype"], name="increment_value", shape=(1,))
    num_iterations_scalar = tf.reshape(num_iterations, ())

    # The loop counter starts at 1 and the body runs while
    # counter < num_iterations, i.e. max(num_iterations - 1, 0) times, so the
    # final counter is max(num_iterations, 1).
    # - For integer inputs, each iteration adds increment_value to value.
    #   Since value starts at increment_value, the final value is
    #   increment_value * max(num_iterations, 1).
    # - For string inputs, each iteration replaces value with a 1-element
    #   tensor holding increment_value, so the final value always equals
    #   increment_value.
    # The counter and increment value are also returned as outputs. This
    # makes sure the necessary control dependency of the model is generated
    # for testing the dynamic tensor cases.
    def cond_fn(counter, value, increment_value):
      del value
      del increment_value
      return counter < num_iterations_scalar

    def body_fn(counter, value, increment_value):
      new_counter = counter + 1
      if parameters["dtype"] == tf.string:
        # Use the fill op to build a new 1-element string tensor from the
        # (scalar-reshaped) increment value; the previous value is discarded.
        del value
        new_value = tf.fill([1], tf.reshape(increment_value, ()))
      else:
        new_value = value + increment_value
      return [new_counter, new_value, increment_value]

    counter, value, result_increment_value = tf.while_loop(
        cond_fn, body_fn, loop_vars=[1, increment_value, increment_value])
    return [num_iterations,
            increment_value], [counter, value, result_increment_value]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create feed values and compute reference outputs with TensorFlow.

    Args:
      parameters: One combination from `test_parameters`.
      sess: Session used to evaluate `outputs`.
      inputs: The placeholders returned by `build_graph`.
      outputs: The output tensors returned by `build_graph`.

    Returns:
      A pair (input_values, results of sess.run(outputs) on those values).
    """
    # First entry of TF_TYPE_INFO for the dtype is used as the numpy dtype.
    numpy_type = zip_test_utils.TF_TYPE_INFO[parameters["dtype"]][0]
    input_values = [
        np.array([parameters["num_iterations"]], dtype=np.int32),
        np.array(parameters["increment_value"], dtype=numpy_type)
    ]
    return input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
92