• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for matmul_benchmark.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import numpy as np
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.core.framework import node_def_pb2
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import matmul_benchmark
28from tensorflow.python.platform import test as googletest
29from tensorflow.python.platform import tf_logging
30
31
def BuildGraphTest(n, m, k, transpose_a, transpose_b, dtype):
  """Returns a test method that verifies the benchmark graph for one config.

  All arguments are forwarded unchanged to `self._VerifyBuildGraph`.

  Returns:
    A function taking `self`, meant to be attached to the test class with
    `setattr` so each configuration shows up as its own test.
  """
  config = (n, m, k, transpose_a, transpose_b, dtype)

  def Test(self):
    if not googletest.is_gpu_available():
      tf_logging.info("Skipping BuildGraphTest %s", config)
      # skipTest raises, so the test is reported as skipped instead of
      # silently passing on machines without a GPU.
      self.skipTest("BuildGraphTest requires a GPU")
    tf_logging.info("Testing BuildGraphTest %s", config)
    self._VerifyBuildGraph(*config)

  return Test
44
45
def RunGraphTest(n, m, k, transpose_a, transpose_b, dtype):
  """Returns a test method that runs the benchmark graph for one config.

  All arguments are forwarded unchanged to `self._VerifyRunGraph`.

  Returns:
    A function taking `self`, meant to be attached to the test class with
    `setattr` so each configuration shows up as its own test.
  """
  config = (n, m, k, transpose_a, transpose_b, dtype)

  def Test(self):
    if not googletest.is_gpu_available():
      tf_logging.info("Skipping RunGraphTest %s", config)
      # skipTest raises, so the test is reported as skipped instead of
      # silently passing on machines without a GPU.
      self.skipTest("RunGraphTest requires a GPU")
    tf_logging.info("Testing RunGraphTest %s", config)
    self._VerifyRunGraph(*config)

  return Test
58
59
class MatmulBenchmarkTest(googletest.TestCase):
  """Checks the graph built by, and the run of, `matmul_benchmark`.

  Concrete `testBuildGraph_*` / `testRunGraph_*` methods are attached to this
  class at runtime in the `__main__` block at the bottom of the file.
  """

  # (name, op, inputs) for every node expected in the graph produced by
  # `matmul_benchmark.build_graph`, in order. Node attributes are dropped by
  # `_StripNode` before comparison, and the same expectation is used for
  # every (n, m, k, transpose, dtype) configuration.
  _EXPECTED_NODES = (
      ("random_uniform/shape", "Const", ()),
      ("random_uniform/min", "Const", ()),
      ("random_uniform/max", "Const", ()),
      ("random_uniform/RandomUniform", "RandomUniform",
       ("random_uniform/shape",)),
      ("random_uniform/sub", "Sub",
       ("random_uniform/max", "random_uniform/min")),
      ("random_uniform/mul", "Mul",
       ("random_uniform/RandomUniform", "random_uniform/sub")),
      ("random_uniform", "Add", ("random_uniform/mul", "random_uniform/min")),
      ("Variable", "VariableV2", ()),
      ("Variable/Assign", "Assign", ("Variable", "random_uniform")),
      ("Variable/read", "Identity", ("Variable",)),
      ("random_uniform_1/shape", "Const", ()),
      ("random_uniform_1/min", "Const", ()),
      ("random_uniform_1/max", "Const", ()),
      ("random_uniform_1/RandomUniform", "RandomUniform",
       ("random_uniform_1/shape",)),
      ("random_uniform_1/sub", "Sub",
       ("random_uniform_1/max", "random_uniform_1/min")),
      ("random_uniform_1/mul", "Mul",
       ("random_uniform_1/RandomUniform", "random_uniform_1/sub")),
      ("random_uniform_1", "Add",
       ("random_uniform_1/mul", "random_uniform_1/min")),
      ("Variable_1", "VariableV2", ()),
      ("Variable_1/Assign", "Assign", ("Variable_1", "random_uniform_1")),
      ("Variable_1/read", "Identity", ("Variable_1",)),
      ("MatMul", "MatMul", ("Variable/read", "Variable_1/read")),
      ("group_deps", "NoOp", ("^MatMul",)),
  )

  def _StripNode(self, nd):
    """Returns a copy of NodeDef `nd` keeping only name, op, inputs, device."""
    snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
      snode.device = nd.device
    return snode

  def _StripGraph(self, gd):
    """Returns a GraphDef holding the stripped form of every node in `gd`."""
    return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])

  def _ExpectedGraph(self, dev):
    """Builds the stripped GraphDef `build_graph` should produce on `dev`."""
    # Constructing the proto directly (rather than concatenating text proto)
    # avoids quoting mistakes; assertProtoEquals accepts a message of the
    # same type as the actual value.
    return graph_pb2.GraphDef(node=[
        node_def_pb2.NodeDef(name=name, op=op, input=inputs, device=dev)
        for name, op, inputs in self._EXPECTED_NODES
    ])

  def _VerifyBuildGraph(self, n, m, k, transpose_a, transpose_b, dtype):
    """Builds the benchmark graph on the GPU and checks its node structure."""
    dev = googletest.gpu_device_name()
    graph = ops.Graph()
    with graph.as_default():
      matmul_benchmark.build_graph(dev, n, m, k, transpose_a, transpose_b,
                                   dtype)
      gd = graph.as_graph_def()
      self.assertProtoEquals(self._ExpectedGraph(dev), self._StripGraph(gd))

  def _VerifyRunGraph(self, n, m, k, transpose_a, transpose_b, dtype):
    """Runs the benchmark graph on the GPU and checks duration > 1e-6."""
    benchmark_instance = matmul_benchmark.MatmulBenchmark()
    # NOTE(review): the positional `1` is presumably run_graph's iteration
    # count -- confirm against matmul_benchmark.run_graph.
    duration = benchmark_instance.run_graph(googletest.gpu_device_name(), n, m,
                                            k, transpose_a, transpose_b, 1,
                                            dtype)
    # assertGreater reports the observed duration when it fails.
    self.assertGreater(duration, 1e-6)
110
111
if __name__ == "__main__":
  # Attach one graph-structure test and one run test per configuration to
  # MatmulBenchmarkTest. The numeric suffix counts consecutively across all
  # dtypes, in dtype-major order.
  dtypes = [np.float32, np.float64]
  configs = itertools.product(
      dtypes, [512, 1024], [1, 8, 16, 128],
      [(False, False), (True, False), (False, True)])
  for index, (_dtype, _n, _m, (_transpose_a, _transpose_b)) in enumerate(
      configs):
    _k = _n  # k is always set equal to n
    suffix = str(index)
    setattr(MatmulBenchmarkTest, "testBuildGraph_" + suffix,
            BuildGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
    setattr(MatmulBenchmarkTest, "testRunGraph_" + suffix,
            RunGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
  googletest.main()
126