# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the model analyzer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import model_analyzer
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class PyWrapOptimizeGraphTest(test.TestCase):
  """Exercises `model_analyzer.GenerateModelReport` on small graphs."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Make sure arguments can be passed correctly."""
    a = constant_op.constant([10, 11], name="a")
    b = constant_op.constant([10], name="b")
    c = math_ops.add(a, b, name="c")
    d = math_ops.add_n([a, c], name="d")
    # Register `d` as a train op so it is recorded in the exported MetaGraphDef.
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = model_analyzer.GenerateModelReport(mg)

    # Check that every node created above has a "<name> [<op>]" header.
    # The report is compared as bytes, which is how the assertions treat it.
    self.assertIn(b"a [Const]", report)
    self.assertIn(b"b [Const]", report)
    self.assertIn(b"c [Add]", report)
    self.assertIn(b"d [AddN]", report)

    # Also print the report to make it easier to debug
    print("{}".format(report))

  @test_util.run_deprecated_v1
  def testDebugMode(self):
    """Make sure debug mode reports known input values for constant inputs."""
    a = constant_op.constant([10, 11], name="a")
    b = constant_op.constant([10], name="b")
    c = math_ops.add(a, b, name="c")
    # Register `c` as a train op so it is recorded in the exported MetaGraphDef.
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(c)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = model_analyzer.GenerateModelReport(mg, debug=True)

    # Both inputs of `c` are int32 constants (`a` and `b`), so debug mode
    # should mark each of them as having a known value.
    self.assertIn(b"input 0 (int32) has known value", report)
    self.assertIn(b"input 1 (int32) has known value", report)

    # Also print the report to make it easier to debug
    print("{}".format(report))


if __name__ == "__main__":
  test.main()