• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the swig wrapper of items."""
16
17from tensorflow.python.framework import constant_op
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import errors_impl
20from tensorflow.python.framework import meta_graph
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import test_util
24from tensorflow.python.grappler import item
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import gen_array_ops
27from tensorflow.python.ops import state_ops
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import test
30
31
class ItemTest(test.TestCase):
  """Tests for `item.Item`, built from metagraphs of small constant graphs."""

  def testInvalidItem(self):
    """Constructing an Item without a registered train op must fail."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b  # pylint: disable=unused-variable
      mg = meta_graph.create_meta_graph_def(graph=g)

    # The train op isn't specified: this should raise an InvalidArgumentError
    # exception.
    with self.assertRaises(errors_impl.InvalidArgumentError):
      item.Item(mg)

  def testImportantOps(self):
    """IdentifyImportantOps reports both constants and the add feeding them."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      # Register the sum as the train op so the item can be constructed.
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_list = grappler_item.IdentifyImportantOps()
      # Only membership matters here, not the order of the returned names.
      self.assertCountEqual(['Const', 'Const_1', 'add'], op_list)

  def testOpProperties(self):
    """GetOpProperties yields one int32 scalar per node, none for the no-op."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      # A node without outputs, to check that it gets an empty property list.
      z = control_flow_ops.no_op()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()

      # All the nodes in this model have one scalar output, except the no-op.
      for node in grappler_item.metagraph.graph_def.node:
        node_prop = op_properties[node.name]

        if node.name == z.name:
          self.assertEqual(0, len(node_prop))
        else:
          self.assertEqual(1, len(node_prop))
          self.assertEqual(dtypes.int32, node_prop[0].dtype)
          self.assertEqual(tensor_shape.TensorShape([]), node_prop[0].shape)

  def testUpdates(self):
    """`tf_item` compares unequal only after the metagraph really changes."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

    # Reading tf_item twice without touching the metagraph gives equal values.
    initial_tf_item = grappler_item.tf_item
    no_change_tf_item = grappler_item.tf_item
    self.assertEqual(initial_tf_item, no_change_tf_item)

    # Modify the placement.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    new_tf_item = grappler_item.tf_item
    self.assertNotEqual(initial_tf_item, new_tf_item)

    # Assign the same placement: rewriting identical values is not a change.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    newest_tf_item = grappler_item.tf_item
    self.assertEqual(new_tf_item, newest_tf_item)

  @test_util.run_v1_only('b/120545219')
  def testColocationConstraints(self):
    """GetColocationGroups puts the variable and its ref users in one group."""
    with ops.Graph().as_default() as g:
      c = constant_op.constant([10])
      v = variables.VariableV1([3], dtype=dtypes.int32)
      # Assign through a ref identity of the variable rather than directly.
      i = gen_array_ops.ref_identity(v)
      a = state_ops.assign(i, c)
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(a)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      groups = grappler_item.GetColocationGroups()
      self.assertEqual(len(groups), 1)
      # Order within the group is not checked, only its members.
      self.assertCountEqual(
          groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])
123
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
126