• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test configs for unroll_batch_matmul."""
16import tensorflow.compat.v1 as tf
17from tensorflow.lite.testing.zip_test_utils import create_tensor_data
18from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
19from tensorflow.lite.testing.zip_test_utils import register_make_test_function
20
21
@register_make_test_function()
def make_unroll_batch_matmul_tests(options):
  """Make a set of tests to test unroll_batch_matmul.

  Each generated test builds a graph containing a single `tf.matmul` over two
  placeholders, with shapes and transpose flags taken from one entry of the
  `shape` parameter list below.

  Args:
    options: Test-generation options, forwarded unchanged to
      `make_zip_of_tests`.
  """

  # Every "shape" entry is [shape_a, shape_b, transpose_a, transpose_b].

  # Operands with identical batch dimensions.
  same_batch_shape_params = [
      [(2, 2, 3), (2, 3, 2), False, False],
      [(2, 2, 3), (2, 3, 2), True, True],
      [(2, 2, 3), (2, 2, 3), False, True],
      [(2, 2, 3), (2, 2, 3), True, False],
      [(4, 2, 2, 3), (4, 2, 3, 2), False, False],
      [(4, 2, 2, 3), (4, 2, 3, 2), True, True],
      [(4, 2, 2, 3), (4, 2, 2, 3), False, True],
      [(4, 2, 2, 3), (4, 2, 2, 3), True, False],
  ]

  # The cases below require broadcasting of the batch dimensions
  # (BatchMatMulV2 semantics). They are included in the generated tests
  # unconditionally.
  broadcast_shape_params = [
      # Simple broadcast.
      [(1, 2, 3), (3, 5), False, False],
      # Empty batch broadcast.
      [(2, 5, 3), (3, 7), False, False],
      # Single batch with non-empty batch broadcast.
      [(1, 5, 3), (4, 3, 7), False, False],
      # Broadcast both operands.
      [(3, 1, 5, 3), (1, 4, 3, 7), False, False],
  ]

  test_parameters = [{
      "dtype": [tf.float32],
      "shape": same_batch_shape_params + broadcast_shape_params,
  }]

  def build_graph(parameters):
    """Build the batch_matmul op testing graph.

    Args:
      parameters: One combination of `test_parameters`; reads the "dtype" and
        "shape" keys.

    Returns:
      A pair ([input_a, input_b], [out]) of the two placeholder tensors and
      the matmul result tensor.
    """
    shape_a, shape_b, transpose_a, transpose_b = parameters["shape"]
    # `tf` is already `tensorflow.compat.v1`, so `tf.placeholder` is the same
    # function as `tf.compat.v1.placeholder`.
    input_a = tf.placeholder(dtype=parameters["dtype"], shape=shape_a)
    input_b = tf.placeholder(dtype=parameters["dtype"], shape=shape_b)
    # Should be unrolled and replaced with fully_connected ops in the end.
    out = tf.matmul(
        input_a, input_b, transpose_a=transpose_a, transpose_b=transpose_b)
    return [input_a, input_b], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create input values and compute reference outputs for one test.

    Args:
      parameters: One combination of `test_parameters`; reads the "dtype" and
        the first two elements (operand shapes) of "shape".
      sess: Session used to evaluate `outputs`.
      inputs: Placeholders returned by `build_graph`, fed in order.
      outputs: Output tensors returned by `build_graph`.

    Returns:
      A pair (input_values, output_values), where output_values is the result
      of `sess.run(outputs)` with `inputs` fed from input_values.
    """
    shape_a, shape_b = parameters["shape"][:2]
    # Values are created in operand order (a, then b), matching `inputs`.
    input_values = [
        create_tensor_data(parameters["dtype"], shape=shape_a),
        create_tensor_data(parameters["dtype"], shape=shape_b),
    ]
    return input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
80