• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for StatSummarizer Python wrapper."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.python import pywrap_tensorflow
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import variables
27from tensorflow.python.platform import test
28
29
class StatSummarizerTest(test.TestCase):
  """Tests the StatSummarizer wrapper exposed through pywrap_tensorflow."""

  def testStatSummarizer(self):
    """Profiles a tiny MatMul graph and checks the summarizer's text report.

    Builds a graph of two constants feeding a MatMul and runs it repeatedly
    with full tracing. Each step's serialized StepStats is fed to the
    summarizer. The test then asserts that the report mentions the run
    count, a header line, every node, and a CDF value of 100.
    """
    num_steps = 20

    with ops.Graph().as_default() as graph:
      matrix1 = constant_op.constant([[3., 3.]], name=r"m1")
      matrix2 = constant_op.constant([[2.], [2.]], name=r"m2")
      product = math_ops.matmul(matrix1, matrix2, name=r"product")

      graph_def = graph.as_graph_def()
      ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
      # The summarizer is only released by an explicit DeleteStatSummarizer
      # call. Free it even if a session call below raises.
      try:
        # FULL_TRACE asks the runtime to fill run_metadata.step_stats. The
        # options are identical for every step, so build them once.
        run_options = config_pb2.RunOptions(
            trace_level=config_pb2.RunOptions.FULL_TRACE)

        # Presumably runs against the current default graph (`graph`).
        with self.cached_session() as sess:
          # The graph defines no variables; this initializer is a no-op
          # setup step kept from the original test.
          sess.run(variables.global_variables_initializer())

          for _ in range(num_steps):
            # Fresh metadata per step so each StepStats covers one run only.
            run_metadata = config_pb2.RunMetadata()
            sess.run(product, options=run_options, run_metadata=run_metadata)
            ss.ProcessStepStatsStr(
                run_metadata.step_stats.SerializeToString())

        output_string = ss.GetOutputString()
      finally:
        pywrap_tensorflow.DeleteStatSummarizer(ss)

      print(output_string)

      # Test it recorded running the expected number of times. The trailing
      # \b prevents e.g. "count=200" from satisfying the check.
      self.assertRegexpMatches(output_string, r"count=%d\b" % num_steps)

      # Test that a header line got printed.
      self.assertRegexpMatches(output_string, r"====== .* ======")

      # Test that the nodes we added were analyzed.
      # The line for the op should contain both the op type (MatMul)
      # and the name of the node (product).
      self.assertRegexpMatches(output_string, r"MatMul.*product")
      self.assertRegexpMatches(output_string, r"Const.*m1")
      self.assertRegexpMatches(output_string, r"Const.*m2")

      # Test that a CDF summed to 100%.
      self.assertRegexpMatches(output_string, r"100\.")
74
if __name__ == "__main__":
  # Entry point when run as a script: hand control to TensorFlow's test
  # runner (tensorflow.python.platform.test), which discovers the TestCase
  # classes in this module.
  test.main()
77