• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the model analyzer."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import meta_graph
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.grappler import model_analyzer
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import test
28
29
class PyWrapOptimizeGraphTest(test.TestCase):
  """Tests for `model_analyzer.GenerateModelReport`.

  Each test builds a tiny graph of integer constants and additions, exports
  it as a MetaGraphDef, and checks the generated report (compared against
  bytes literals) for the expected entries.
  """

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant([10, 11], name="a")
    b = constant_op.constant([10], name="b")
    c = math_ops.add(a, b, name="c")
    d = math_ops.add_n([a, c], name="d")
    # Register the final op in the TRAIN_OP collection before exporting the
    # meta graph (presumably this is how the analyzer finds the op to
    # report on -- confirm against the grappler item builder).
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = model_analyzer.GenerateModelReport(mg)

    # Check the report headers: one "<name> [<op type>]" entry per node built
    # above. The second check previously repeated "a [Const]", which left the
    # constant `b` unverified.
    self.assertIn(b"a [Const]", report)
    self.assertIn(b"b [Const]", report)
    self.assertIn(b"c [Add]", report)
    self.assertIn(b"d [AddN]", report)

    # Also print the report to make it easier to debug
    print("{}".format(report))

  @test_util.run_deprecated_v1
  def testDebugMode(self):
    """Make sure debug mode reports the known values of constant inputs."""
    a = constant_op.constant([10, 11], name="a")
    b = constant_op.constant([10], name="b")
    c = math_ops.add(a, b, name="c")
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(c)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = model_analyzer.GenerateModelReport(mg, debug=True)

    # With debug=True the report should state that both inputs (presumably
    # those of `c`, fed by the constants `a` and `b`) have known values.
    self.assertIn(b"input 0 (int32) has known value", report)
    self.assertIn(b"input 1 (int32) has known value", report)

    # Also print the report to make it easier to debug
    print("{}".format(report))
72
73
# Hand control to the TensorFlow test runner when this file is executed
# directly as a script.
if __name__ == "__main__":
  test.main()
76