# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the swig wrapper of clusters."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import item
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test

# Fully qualified device names that the local (real) cluster uses in the
# assertions below.
_LOCAL_CPU = '/job:localhost/replica:0/task:0/device:CPU:0'
_LOCAL_GPU = '/job:localhost/replica:0/task:0/device:GPU:0'


def _build_add_item(graph, shape=()):
  """Builds `a + b` of two random tensors in `graph` and wraps it in an Item.

  The sum is appended to the TRAIN_OP collection (presumably so that grappler
  treats it as the fetch node of the exported meta graph).

  Args:
    graph: The `ops.Graph` to add the ops to.
    shape: Shape of the two random operands; scalars by default.

  Returns:
    An `item.Item` built from the meta graph exported from `graph`.
  """
  with graph.as_default():
    lhs = random_ops.random_uniform(shape=shape)
    rhs = random_ops.random_uniform(shape=shape)
    total = lhs + rhs
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    metagraph = meta_graph.create_meta_graph_def(graph=graph)
    return item.Item(metagraph)


class ClusterTest(test.TestCase):
  """Exercises `cluster.Cluster` and `cluster.Provision` on small graphs."""

  def testBasic(self):
    """MeasureCosts reports a run time, per-op costs and step stats."""
    with ops.Graph().as_default() as g:
      grappler_item = _build_add_item(g)
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False, disable_timeline=False)
      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)

  def testNoDetailedStats(self):
    """Without detailed stats only the overall run time is reported."""
    with ops.Graph().as_default() as g:
      grappler_item = _build_add_item(g)
      grappler_cluster = cluster.Cluster(disable_detailed_stats=True)

      op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
          grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 0)
      self.assertEqual(len(step_stats.dev_stats), 0)

  def testMemoryEstimates(self):
    """DeterminePeakMemoryUsage reports peak usage and live tensors per device."""
    with ops.Graph().as_default() as g:
      with ops.device(_LOCAL_CPU):
        a = random_ops.random_uniform(shape=())
        b = random_ops.random_uniform(shape=())
        c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=True, disable_timeline=True)
      peak_mem = grappler_cluster.DeterminePeakMemoryUsage(grappler_item)
      self.assertGreaterEqual(len(peak_mem), 1)
      # Each entry is indexed as (peak usage, live tensors).
      snapshot = peak_mem[_LOCAL_CPU]
      peak_usage = snapshot[0]
      live_tensors = snapshot[1]
      self.assertEqual(12, peak_usage)
      self.assertEqual(5, len(live_tensors))

  def testVirtualCluster(self):
    """A cluster built from explicit device properties gives fixed estimates."""
    with ops.Graph().as_default() as g:
      with ops.device('/device:GPU:0'):
        a = random_ops.random_uniform(shape=[1024, 1024])
        b = random_ops.random_uniform(shape=[1024, 1024])
        c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      gpu_properties = device_properties_pb2.DeviceProperties(
          type='GPU',
          frequency=1000,
          num_cores=60,
          environment={'architecture': '7'})
      named_device = device_properties_pb2.NamedDevice(
          properties=gpu_properties, name='/device:GPU:0')
      grappler_cluster = cluster.Cluster(
          disable_detailed_stats=False,
          disable_timeline=False,
          devices=[named_device])
      op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
      # run_time is a float; compare with a tight tolerance rather than exact
      # equality so harmless floating-point rounding cannot fail the test.
      self.assertAlmostEqual(run_time, 0.000209, places=9)
      self.assertEqual(len(op_perfs), 5)

      estimated_perf = grappler_cluster.EstimatePerformance(named_device)
      self.assertEqual(7680.0, estimated_perf)

  def testContext(self):
    """cluster.Provision can be used as a context manager."""
    with ops.Graph().as_default() as g:
      grappler_item = _build_add_item(g)

      with cluster.Provision(
          disable_detailed_stats=False, disable_timeline=False) as gcluster:
        op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
        self.assertGreater(run_time, 0)
        self.assertEqual(len(op_perfs), 4)
        self.assertTrue(step_stats.dev_stats)

  def testAvailableOps(self):
    """ListAvailableOps returns a sorted list containing common ops."""
    with cluster.Provision() as gcluster:
      op_names = gcluster.ListAvailableOps()
      self.assertIn('Add', op_names)
      self.assertIn('MatMul', op_names)
      self.assertEqual(op_names, sorted(op_names))

  def testSupportDevices(self):
    """GetSupportedDevices lists the devices each node may be placed on."""
    with ops.Graph().as_default() as g:
      a = random_ops.random_uniform(shape=(2, 3))
      b = random_ops.random_uniform(shape=(2, 3))
      c = a + b
      dims = math_ops.range(0, array_ops.rank(c), 1)
      d = math_ops.reduce_sum(a, axis=dims)
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)

      gpu_properties = device_properties_pb2.DeviceProperties(
          type='GPU', frequency=1000, num_cores=60)
      named_gpu = device_properties_pb2.NamedDevice(
          properties=gpu_properties, name='/GPU:0')
      cpu_properties = device_properties_pb2.DeviceProperties(
          type='CPU', frequency=3000, num_cores=6)
      named_cpu = device_properties_pb2.NamedDevice(
          properties=cpu_properties, name='/CPU:0')
      virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
      supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
      self.assertEqual(supported_dev['add'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['Sum'], ['/CPU:0', '/GPU:0'])
      self.assertEqual(supported_dev['range'], ['/CPU:0', '/GPU:0'])

      real_cluster = cluster.Cluster()
      supported_dev = real_cluster.GetSupportedDevices(grappler_item)
      if test.is_gpu_available():
        self.assertEqual(supported_dev['add'], [_LOCAL_CPU, _LOCAL_GPU])
        self.assertEqual(supported_dev['Sum'], [_LOCAL_CPU, _LOCAL_GPU])
        # The axis tensor must reside on the host
        self.assertEqual(supported_dev['range'], [_LOCAL_CPU])
      else:
        self.assertEqual(supported_dev['add'], [_LOCAL_CPU])


if __name__ == '__main__':
  test.main()