• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Test configs for fully_connected."""
16import numpy as np
17import tensorflow.compat.v1 as tf
18from tensorflow.lite.testing.zip_test_utils import create_tensor_data
19from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
20from tensorflow.lite.testing.zip_test_utils import register_make_test_function
21
22
@register_make_test_function()
def make_fully_connected_tests(options):
  """Generate the zipped fully_connected (tf.matmul) test cases.

  Args:
    options: Test-generation options. Read here only for
      `use_experimental_converter`; otherwise forwarded unchanged to
      `make_zip_of_tests`.
  """

  def param_set(shape1,
                shape2,
                transpose_a=(False,),
                transpose_b=(False,),
                constant_filter=(True, False),
                fully_quantize=(False,),
                quant_16x8=(False,)):
    """Return one parameter dict; each option maps to a fresh list of values."""
    return {
        "shape1": shape1,
        "shape2": shape2,
        "transpose_a": list(transpose_a),
        "transpose_b": list(transpose_b),
        "constant_filter": list(constant_filter),
        "fully_quantize": list(fully_quantize),
        "quant_16x8": list(quant_16x8),
    }

  # Options shared by the fully-quantized cases; they only use a constant
  # second operand.
  quantized = {"constant_filter": [True], "fully_quantize": [True]}

  test_parameters = [
      # Float cases.
      param_set([[3, 3]], [[3, 3]],
                transpose_a=[True, False],
                transpose_b=[True, False]),
      param_set([[4, 4], [1, 4], [4]], [[4, 4], [4, 1], [4]]),
      param_set([[40, 37]], [[37, 40]]),
      param_set([[40, 37]], [[40, 37]], transpose_b=[True]),
      param_set([[5, 3]], [[5, 3]], transpose_a=[True]),
      # Fully-quantized cases.
      param_set([[1, 3]], [[3, 3]], **quantized),
      param_set([[1, 4], [4]], [[4, 4], [4, 1], [4]], **quantized),
      param_set([[1, 37], [2, 37]], [[37, 40]], **quantized),
      param_set([[1, 3], [2, 3]], [[3, 5], [3, 1]], **quantized),
      param_set([[2, 3]], [[3, 5]], quant_16x8=[True], **quantized),
      # Float cases with a zero-sized dimension.
      param_set([[0, 3]], [[3, 3]]),
      param_set([[3, 0]], [[0, 3]]),
  ]

  if options.use_experimental_converter:
    # Zero in input shape.
    # NOTE(review): this repeats the [0, 3] x [3, 3] entry already present in
    # the base list above; kept as-is to preserve behavior -- confirm whether
    # it is still needed.
    test_parameters.append(param_set([[0, 3]], [[3, 3]]))

  def build_graph(parameters):
    """Build a matmul graph for one concrete `parameters` combination.

    Returns:
      A pair `(placeholders, [output])`: the placeholder inputs of the graph
      (one when the second operand is a constant, two otherwise) and a
      single-element list holding the matmul result tensor.
    """
    lhs = tf.compat.v1.placeholder(
        dtype=tf.float32, name="input1", shape=parameters["shape1"])
    placeholders = [lhs]

    # The second operand is either baked in as constant data or exposed as a
    # second placeholder; only placeholders are reported as graph inputs.
    if parameters["constant_filter"]:
      rhs = create_tensor_data(
          np.float32, parameters["shape2"], min_value=-1, max_value=1)
    else:
      rhs = tf.compat.v1.placeholder(
          dtype=tf.float32, name="input2", shape=parameters["shape2"])
      placeholders.append(rhs)

    product = tf.matmul(
        lhs,
        rhs,
        transpose_a=parameters["transpose_a"],
        transpose_b=parameters["transpose_b"])
    return placeholders, [product]

  def build_inputs(parameters, sess, inputs, outputs):
    # pylint: disable=g-doc-return-or-yield, g-doc-args
    """Create the feed values and evaluate `outputs` with them.

    The feed list holds one array (for `shape1`) when the second operand is a
    constant, or two arrays (for `shape1` and `shape2`) when it is a
    placeholder. Returns the feed list together with the result of running
    `outputs` in `sess`.
    """
    feed_values = [
        create_tensor_data(
            np.float32, shape=parameters["shape1"], min_value=-1, max_value=1)
    ]
    if not parameters["constant_filter"]:
      feed_values.append(
          create_tensor_data(
              np.float32, parameters["shape2"], min_value=-1, max_value=1))

    feed = dict(zip(inputs, feed_values))
    return feed_values, sess.run(outputs, feed_dict=feed)

  # NOTE(review): expected_tf_failures=14 presumably counts the shape
  # combinations above that tf.matmul itself rejects (e.g. rank-1 operands);
  # re-check the number whenever the parameter list changes.
  make_zip_of_tests(
      options,
      test_parameters,
      build_graph,
      build_inputs,
      expected_tf_failures=14)
187