# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper of clusters."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
30
31
class ClusterTest(test.TestCase):
  """Tests for the Python `cluster.Cluster` / `cluster.Provision` wrappers."""

  def _make_sum_item(self, graph, shape=()):
    """Adds `uniform + uniform` to the current graph and wraps it in an Item.

    Call this while `graph` is installed as the default graph. The ops are
    created under whatever `ops.device` scope the caller has open.

    Args:
      graph: Graph that is exported to the MetaGraphDef behind the item.
      shape: Shape passed to both `random_uniform` calls.

    Returns:
      An `item.Item` built from the exported meta graph; the sum tensor is
      registered in the graph's TRAIN_OP collection.
    """
    lhs = random_ops.random_uniform(shape=shape)
    rhs = random_ops.random_uniform(shape=shape)
    total = lhs + rhs
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    return item.Item(meta_graph.create_meta_graph_def(graph=graph))

  def testBasic(self):
    """MeasureCosts with stats and timeline enabled returns all three parts."""
    with ops.Graph().as_default() as g:
      grappler_item = self._make_sum_item(g)
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False, disable_timeline=False)
      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testNoDetailedStats(self):
    """With detailed stats disabled only the run time is populated."""
    with ops.Graph().as_default() as g:
      grappler_item = self._make_sum_item(g)
      grappler_cluster = cluster.Cluster(disable_detailed_stats=True)

      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 0)
      self.assertEqual(len(step_stats.dev_stats), 0)

  def testMemoryEstimates(self):
    """DeterminePeakMemoryUsage reports a snapshot for the local CPU."""
    cpu_name = '/job:localhost/replica:0/task:0/device:CPU:0'
    with ops.Graph().as_default() as g:
      with ops.device(cpu_name):
        grappler_item = self._make_sum_item(g)
        grappler_cluster = cluster.Cluster(
            disable_detailed_stats=True, disable_timeline=True)
        peak_mem = grappler_cluster.DeterminePeakMemoryUsage(grappler_item)
        self.assertLessEqual(1, len(peak_mem))
        # Each per-device snapshot is indexed as (peak usage, live tensors).
        snapshot = peak_mem[cpu_name]
        peak_usage = snapshot[0]
        self.assertEqual(12, peak_usage)
        live_tensors = snapshot[1]
        self.assertEqual(5, len(live_tensors))

  def testVirtualCluster(self):
    """A cluster built from explicit device properties yields cost estimates."""
    with ops.Graph().as_default() as g:
      with ops.device('/device:GPU:0'):
        grappler_item = self._make_sum_item(g, shape=[1024, 1024])
      device_properties = device_properties_pb2.DeviceProperties(
          type='GPU',
          frequency=1000,
          num_cores=60,
          environment={'architecture': '7'})
      named_device = device_properties_pb2.NamedDevice(
          properties=device_properties, name='/device:GPU:0')
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False,
          disable_timeline=False,
          devices=[named_device])
      op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
      # NOTE(review): exact float comparison; presumably the estimate for a
      # described (non-physical) device is deterministic -- confirm before
      # loosening this to an approximate check.
      self.assertEqual(run_time, 0.000209)
      self.assertEqual(len(op_perfs), 5)

      estimated_perf = grappler_cluster.EstimatePerformance(named_device)
      self.assertEqual(7680.0, estimated_perf)

  def testContext(self):
    """`cluster.Provision` works as a context manager yielding a cluster."""
    with ops.Graph().as_default() as g:
      grappler_item = self._make_sum_item(g)

    with cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False) as gcluster:
      op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testAvailableOps(self):
    """ListAvailableOps returns a sorted list that includes common ops."""
    with cluster.Provision() as gcluster:
      op_names = gcluster.ListAvailableOps()
      self.assertIn('Add', op_names)
      self.assertIn('MatMul', op_names)
      self.assertEqual(op_names, sorted(op_names))

  def testSupportDevices(self):
    """GetSupportedDevices maps op names to the devices that can run them."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=(2, 3))
      b = random_ops.random_uniform(shape=(2, 3))
      c = a + b
      # The reduction axes are derived from rank(c); that axis input is the
      # only path connecting the add op to the train op registered below.
      dims = math_ops.range(0, array_ops.rank(c), 1)
      d = math_ops.reduce_sum(a, axis=dims)
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(d)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

      gpu_properties = device_properties_pb2.DeviceProperties(
          type='GPU', frequency=1000, num_cores=60)
      named_gpu = device_properties_pb2.NamedDevice(
          properties=gpu_properties, name='/GPU:0')
      cpu_properties = device_properties_pb2.DeviceProperties(
          type='CPU', frequency=3000, num_cores=6)
      named_cpu = device_properties_pb2.NamedDevice(
          properties=cpu_properties, name='/CPU:0')
      virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
      supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
      self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])

      local_cpu = '/job:localhost/replica:0/task:0/device:CPU:0'
      local_gpu = '/job:localhost/replica:0/task:0/device:GPU:0'
      real_cluster = cluster.Cluster()
      supported_dev = real_cluster.GetSupportedDevices(grappler_item)
      if test.is_gpu_available():
        self.assertEqual(supported_dev['add'], [local_cpu, local_gpu])
        self.assertEqual(supported_dev['Sum'], [local_cpu, local_gpu])
        # The axis tensor must reside on the host
        self.assertEqual(supported_dev['range'], [local_cpu])
      else:
        self.assertEqual(supported_dev['add'], [local_cpu])
184
if __name__ == '__main__':
  # Run the tests in this module when the file is executed as a script.
  test.main()
187