• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for InputToOps class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.quantize.python import input_to_ops
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import nn_ops
26from tensorflow.python.platform import googletest
27
28
class InputToOpsTest(test_util.TensorFlowTestCase):
  """Tests for `input_to_ops.InputToOps.ConsumerOperations`.

  Each test builds a small graph, wraps it in `InputToOps`, and checks the
  set of operations reported as consumers of the producer op's output.
  """

  def testNoConsumerOperations(self):
    """An op whose output is not used by any other op has no consumers."""
    graph = ops.Graph()
    with graph.as_default():
      input_tensor = array_ops.zeros((1, 2, 3, 4))

    input_to_ops_map = input_to_ops.InputToOps(graph)
    consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op)

    # Compare against an empty set, consistent with the other tests, instead
    # of checking len(): on failure the message then shows which unexpected
    # consumer ops were found rather than just a count.
    self.assertEqual(consumer_operations, set())

  def testOneConsumerOperation(self):
    """A single op reading the tensor is reported as its only consumer."""
    graph = ops.Graph()
    with graph.as_default():
      input_tensor = array_ops.zeros((1, 2, 3, 4))
      output_tensor = nn_ops.relu6(input_tensor)

    input_to_ops_map = input_to_ops.InputToOps(graph)
    consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op)

    self.assertEqual(consumer_operations, {output_tensor.op})

  def testSeveralConsumerOperations(self):
    """Every op that takes the tensor as a direct input is reported."""
    graph = ops.Graph()
    with graph.as_default():
      input_tensor = array_ops.zeros((1, 2, 3, 4))
      # All three ops below read input_tensor directly. The add and mul ops
      # also read downstream tensors, which must not hide them as consumers.
      output_tensor_1 = nn_ops.relu6(input_tensor)
      output_tensor_2 = input_tensor + output_tensor_1
      output_tensor_3 = input_tensor * output_tensor_2

    input_to_ops_map = input_to_ops.InputToOps(graph)
    consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op)

    self.assertEqual(consumer_operations,
                     {output_tensor_1.op, output_tensor_2.op,
                      output_tensor_3.op})
66
if __name__ == '__main__':
  # Only run the test entry point when executed as a script, not on import.
  googletest.main()
69